diff --git a/CHANGELOG.md b/CHANGELOG.md index 5ada4af36..a4d515678 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). ### Changed - Updated CI so that Postgres tests run against v18 which is GA and not against v13 which is EOL (https://github.com/authzed/spicedb/pull/2926) - Added tracing to request validation (https://github.com/authzed/spicedb/pull/2950) +- Query Planner optimization: in Check requests, prune branches that cannot lead to the subject type specified. ### Fixed - Regression introduced in 1.49.2: missing spans in ReadSchema calls (https://github.com/authzed/spicedb/pull/2947) diff --git a/internal/services/v1/permissions_queryplan.go b/internal/services/v1/permissions_queryplan.go index e5a4104f9..a68ea7b05 100644 --- a/internal/services/v1/permissions_queryplan.go +++ b/internal/services/v1/permissions_queryplan.go @@ -63,6 +63,13 @@ func (ps *permissionServer) checkPermissionWithQueryPlan(ctx context.Context, re return nil, ps.rewriteError(ctx, err) } + // Prune branches that can never reach the requested subject type. + it, _, err = query.ApplyReachabilityPruning(it, req.Subject.Object.ObjectType) + if err != nil { + return nil, ps.rewriteError(ctx, err) + } + // TODO apply to LR and LS too? + // Parse caveat context if provided caveatContext, err := GetCaveatContext(ctx, req.Context, ps.config.MaxCaveatContextSize) if err != nil { diff --git a/pkg/query/optimize_reachability.go b/pkg/query/optimize_reachability.go new file mode 100644 index 000000000..99ccd652e --- /dev/null +++ b/pkg/query/optimize_reachability.go @@ -0,0 +1,140 @@ +package query + +import "slices" + +// ApplyReachabilityPruning applies subject-type reachability pruning to an +// iterator tree, replacing subtrees with empty FixedIterators when they can +// never produce the target subject type. +// +// For arrows (e.g. editor->view), the pruning decision is based on whether the +// right side (computed userset) can reach the target subject type. If not, the +// entire arrow is elided. The left side's subject types are intermediate hops +// and are not considered for pruning. +func ApplyReachabilityPruning(it Iterator, targetSubjectType string) (Iterator, bool, error) { + if it == nil { + return nil, false, nil + } + + // For arrows, only recurse into the right child. + if arrow, ok := it.(*ArrowIterator); ok { + return applyReachabilityToArrow(arrow, targetSubjectType) + } + if arrow, ok := it.(*IntersectionArrowIterator); ok { + return applyReachabilityToIntersectionArrow(arrow, targetSubjectType) + } + + // For all other iterators, recurse into all children first (bottom-up). + origSubs := it.Subiterators() + changed := false + if len(origSubs) > 0 { + subs := make([]Iterator, len(origSubs)) + copy(subs, origSubs) + + subChanged := false + for i, sub := range subs { + newSub, ok, err := ApplyReachabilityPruning(sub, targetSubjectType) + if err != nil { + return nil, false, err + } + if ok { + subs[i] = newSub + subChanged = true + } + } + if subChanged { + changed = true + var err error + it, err = it.ReplaceSubiterators(subs) + if err != nil { + return nil, false, err + } + } + } + + // Now check this node's subject types. + subjectTypes, err := it.SubjectTypes() + if err != nil || len(subjectTypes) == 0 { + return it, changed, nil + } + + if hasMatchingSubjectType(subjectTypes, targetSubjectType) { + return it, changed, nil + } + + return newEmptyWithKey(it.CanonicalKey()), true, nil +} + +func applyReachabilityToArrow(arrow *ArrowIterator, targetSubjectType string) (Iterator, bool, error) { + // Only recurse into the right child - left side types are intermediates. + newRight, rightChanged, err := ApplyReachabilityPruning(arrow.right, targetSubjectType) + if err != nil { + return nil, false, err + } + + if rightChanged { + newArrow, err := arrow.ReplaceSubiterators([]Iterator{arrow.left, newRight}) + if err != nil { + return nil, false, err + } + arrow = newArrow.(*ArrowIterator) + } + + // Check if the right side can produce the target subject type. + if shouldPruneArrowRight(arrow.right, targetSubjectType, rightChanged) { + return newEmptyWithKey(arrow.CanonicalKey()), true, nil + } + return arrow, rightChanged, nil +} + +func applyReachabilityToIntersectionArrow(arrow *IntersectionArrowIterator, targetSubjectType string) (Iterator, bool, error) { + // Only recurse into the right child - left side types are intermediates. + newRight, rightChanged, err := ApplyReachabilityPruning(arrow.right, targetSubjectType) + if err != nil { + return nil, false, err + } + + if rightChanged { + newArrow, err := arrow.ReplaceSubiterators([]Iterator{arrow.left, newRight}) + if err != nil { + return nil, false, err + } + arrow = newArrow.(*IntersectionArrowIterator) + } + + // Check if the right side can produce the target subject type. + if shouldPruneArrowRight(arrow.right, targetSubjectType, rightChanged) { + return newEmptyWithKey(arrow.CanonicalKey()), true, nil + } + return arrow, rightChanged, nil +} + +// shouldPruneArrowRight checks whether the right side of an arrow can produce +// the target subject type. Since applyReachabilityToArrow always recurses into +// the right child first, by the time this runs any non-matching leaves have +// already been pruned. So this only needs to check two cases: +// 1. Right side has no subject types after pruning → prune the arrow. +// 2. Right side still has matching subject types → keep the arrow. +func shouldPruneArrowRight(right Iterator, targetSubjectType string, rightAlreadyPruned bool) bool { + rightTypes, err := right.SubjectTypes() + if err != nil { + return false + } + + if len(rightTypes) == 0 { + return rightAlreadyPruned + } + + return !hasMatchingSubjectType(rightTypes, targetSubjectType) +} + +func hasMatchingSubjectType(subjectTypes []ObjectType, targetType string) bool { + return slices.ContainsFunc(subjectTypes, func(s ObjectType) bool { + return s.Type == targetType + }) +} + +func newEmptyWithKey(key CanonicalKey) *FixedIterator { + empty := NewFixedIterator() + empty.canonicalKey = key + return empty +} diff --git a/pkg/query/optimize_reachability_test.go b/pkg/query/optimize_reachability_test.go new file mode 100644 index 000000000..8dfdbbce7 --- /dev/null +++ b/pkg/query/optimize_reachability_test.go @@ -0,0 +1,304 @@ +package query + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/authzed/spicedb/pkg/schema/v2" +) + +func dsIter(defName, relName, subjectType, subrelation string) *DatastoreIterator { + return NewDatastoreIterator(schema.NewTestBaseRelation(defName, relName, subjectType, subrelation)) +} + +func TestPruneUnreachableSubjectTypes(t *testing.T) { + t.Parallel() + + t.Run("does not prune empty FixedIterator", func(t *testing.T) { + t.Parallel() + + it := NewFixedIterator() + + result, changed, err := ApplyReachabilityPruning(it, "user") + require.NoError(t, err) + require.False(t, changed) + require.IsType(t, &FixedIterator{}, result) + }) + + t.Run("preserves canonical key on pruned iterator", func(t *testing.T) { + t.Parallel() + + it := dsIter("document", "viewer", "group", "...") + it.canonicalKey = CanonicalKey("test-key") + + result, changed, err := ApplyReachabilityPruning(it, "user") + require.NoError(t, err) + require.True(t, changed) + require.Equal(t, CanonicalKey("test-key"), result.CanonicalKey()) + }) + + t.Run("prunes leaf with subject type that does not match", func(t *testing.T) { + t.Parallel() + + it := dsIter("document", "viewer", "group", "...") + + result, changed, err := ApplyReachabilityPruning(it, "user") + require.NoError(t, err) + require.True(t, changed) + require.IsType(t, &FixedIterator{}, result) + require.Empty(t, result.(*FixedIterator).paths) + }) + + t.Run("keeps leaf with matching subject type", func(t *testing.T) { + t.Parallel() + + it := dsIter("document", "viewer", "user", "...") + + result, changed, err := ApplyReachabilityPruning(it, "user") + require.NoError(t, err) + require.False(t, changed) + require.IsType(t, &DatastoreIterator{}, result) + }) + + t.Run("union", func(t *testing.T) { + t.Run("prunes one branch of union", func(t *testing.T) { + t.Parallel() + + userIt := dsIter("document", "viewer", "user", "...") + groupIt := dsIter("document", "editor", "group", "...") + union := NewUnionIterator(userIt, groupIt) + + result, changed, err := ApplyReachabilityPruning(union, "user") + require.NoError(t, err) + require.True(t, changed) + + require.IsType(t, &UnionIterator{}, result) + subs := result.Subiterators() + require.Len(t, subs, 2) + require.IsType(t, &DatastoreIterator{}, subs[0], "user branch should remain") + require.IsType(t, &FixedIterator{}, subs[1], "group branch should be pruned") + }) + + t.Run("does not prune union when both branches match the subject type", func(t *testing.T) { + t.Parallel() + + userIt1 := dsIter("document", "viewer", "user", "...") + userIt2 := dsIter("document", "editor", "user", "...") + union := NewUnionIterator(userIt1, userIt2) + + result, changed, err := ApplyReachabilityPruning(union, "user") + require.NoError(t, err) + require.False(t, changed) + require.IsType(t, &UnionIterator{}, result) + }) + }) + + t.Run("arrows", func(t *testing.T) { + t.Run("prunes entire arrow when right side subject type doesn't match", func(t *testing.T) { + t.Parallel() + + // Arrow[document#parent->folder, folder#viewer->group] + // Right side produces group subjects, not user, so the whole arrow should be replaced + left := dsIter("document", "parent", "folder", "...") + right := dsIter("folder", "viewer", "group", "...") + arrow := NewArrowIterator(left, right) + + result, changed, err := ApplyReachabilityPruning(arrow, "user") + require.NoError(t, err) + require.True(t, changed) + require.IsType(t, &FixedIterator{}, result, "entire arrow should be pruned") + }) + + t.Run("keeps arrow when right side subject type matches", func(t *testing.T) { + t.Parallel() + + // Arrow[document#parent->folder, folder#viewer->user] + // Right side produces user subjects - matches target. + left := dsIter("document", "parent", "folder", "...") + right := dsIter("folder", "viewer", "user", "...") + arrow := NewArrowIterator(left, right) + + result, changed, err := ApplyReachabilityPruning(arrow, "user") + require.NoError(t, err) + require.False(t, changed) + require.IsType(t, &ArrowIterator{}, result, "arrow should remain") + }) + + t.Run("keeps arrow when right side has multiple subject types and one matches", func(t *testing.T) { + t.Parallel() + + // Arrow[document#parent->folder, Union[folder#viewer->user, folder#viewer->group, folder#viewer->team]] + // Right side produces user, group, and team subjects. User matches target. + // The non-matching branches (group, team) inside the union get pruned, + // so changed=true, but the arrow itself is kept. + left := dsIter("document", "parent", "folder", "...") + rightUser := dsIter("folder", "viewer", "user", "...") + rightGroup := dsIter("folder", "viewer", "group", "...") + rightTeam := dsIter("folder", "viewer", "team", "...") + right := NewUnionIterator(rightUser, rightGroup, rightTeam) + arrow := NewArrowIterator(left, right) + + result, changed, err := ApplyReachabilityPruning(arrow, "user") + require.NoError(t, err) + require.True(t, changed, "non-matching branches inside the right-side union are pruned") + require.IsType(t, &ArrowIterator{}, result, "arrow should remain because user is reachable") + + // The right side union should have the group and team branches pruned + rightResult := result.Subiterators()[1] + require.IsType(t, &UnionIterator{}, rightResult) + rightSubs := rightResult.Subiterators() + require.Len(t, rightSubs, 3) + require.IsType(t, &DatastoreIterator{}, rightSubs[0], "user branch should remain") + require.IsType(t, &FixedIterator{}, rightSubs[1], "group branch should be pruned") + require.IsType(t, &FixedIterator{}, rightSubs[2], "team branch should be pruned") + }) + + t.Run("prunes arrow when right side has multiple subject types and none match", func(t *testing.T) { + t.Parallel() + + // Arrow[document#parent->folder, Union[folder#viewer->group, folder#viewer->team, folder#viewer->org]] + // Right side produces group, team, and org subjects. None match user. + left := dsIter("document", "parent", "folder", "...") + rightGroup := dsIter("folder", "viewer", "group", "...") + rightTeam := dsIter("folder", "viewer", "team", "...") + rightOrg := dsIter("folder", "viewer", "org", "...") + right := NewUnionIterator(rightGroup, rightTeam, rightOrg) + arrow := NewArrowIterator(left, right) + + result, changed, err := ApplyReachabilityPruning(arrow, "user") + require.NoError(t, err) + require.True(t, changed) + require.IsType(t, &FixedIterator{}, result, "entire arrow should be pruned") + }) + + t.Run("prunes arrow inside union when unreachable", func(t *testing.T) { + t.Parallel() + + // Union[ + // document#viewer->user, + // Arrow[document#parent->folder, folder#viewer->group] + // ] + // The arrow's right side produces group, not user, so arrow should get pruned. + directUser := dsIter("document", "viewer", "user", "...") + left := dsIter("document", "parent", "folder", "...") + right := dsIter("folder", "viewer", "group", "...") + arrow := NewArrowIterator(left, right) + union := NewUnionIterator(directUser, arrow) + + result, changed, err := ApplyReachabilityPruning(union, "user") + require.NoError(t, err) + require.True(t, changed) + + require.IsType(t, &UnionIterator{}, result) + subs := result.Subiterators() + require.Len(t, subs, 2) + require.IsType(t, &DatastoreIterator{}, subs[0], "direct user branch should remain") + require.IsType(t, &FixedIterator{}, subs[1], "arrow branch should be pruned") + }) + }) + + t.Run("intersection arrows", func(t *testing.T) { + t.Run("prunes entire intersection arrow when right side subject type doesn't match", func(t *testing.T) { + t.Parallel() + + // IntersectionArrow[document#parent->folder, folder#viewer->group] + // Right side produces group subjects, not user, so the whole IntersectionArrow should be pruned + left := dsIter("document", "parent", "folder", "...") + right := dsIter("folder", "viewer", "group", "...") + arrow := NewIntersectionArrowIterator(left, right) + + result, changed, err := ApplyReachabilityPruning(arrow, "user") + require.NoError(t, err) + require.True(t, changed) + require.IsType(t, &FixedIterator{}, result, "entire intersection arrow should be pruned") + }) + + t.Run("keeps intersection arrow when right side subject type matches", func(t *testing.T) { + t.Parallel() + + // IntersectionArrow[document#parent->folder, folder#viewer->user] + // Right side produces user subjects - matches target. + left := dsIter("document", "parent", "folder", "...") + right := dsIter("folder", "viewer", "user", "...") + arrow := NewIntersectionArrowIterator(left, right) + + result, changed, err := ApplyReachabilityPruning(arrow, "user") + require.NoError(t, err) + require.False(t, changed) + require.IsType(t, &IntersectionArrowIterator{}, result, "intersection arrow should remain") + }) + + t.Run("keeps intersection arrow when right side has multiple subject types and one matches", func(t *testing.T) { + t.Parallel() + + // IntersectionArrow[document#parent->folder, Union[folder#viewer->user, folder#viewer->group, folder#viewer->team]] + // Right side produces user, group, and team subjects. User matches target. + // The non-matching branches (group, team) inside the union get pruned, + // so changed=true, but the intersection arrow itself is kept. + left := dsIter("document", "parent", "folder", "...") + rightUser := dsIter("folder", "viewer", "user", "...") + rightGroup := dsIter("folder", "viewer", "group", "...") + rightTeam := dsIter("folder", "viewer", "team", "...") + right := NewUnionIterator(rightUser, rightGroup, rightTeam) + arrow := NewIntersectionArrowIterator(left, right) + + result, changed, err := ApplyReachabilityPruning(arrow, "user") + require.NoError(t, err) + require.True(t, changed, "non-matching branches inside the right-side union are pruned") + require.IsType(t, &IntersectionArrowIterator{}, result, "intersection arrow should remain because user is reachable") + + // The right side union should have the group and team branches pruned + rightResult := result.Subiterators()[1] + require.IsType(t, &UnionIterator{}, rightResult) + rightSubs := rightResult.Subiterators() + require.Len(t, rightSubs, 3) + require.IsType(t, &DatastoreIterator{}, rightSubs[0], "user branch should remain") + require.IsType(t, &FixedIterator{}, rightSubs[1], "group branch should be pruned") + require.IsType(t, &FixedIterator{}, rightSubs[2], "team branch should be pruned") + }) + + t.Run("prunes intersection arrow when right side has multiple subject types and none match", func(t *testing.T) { + t.Parallel() + + // IntersectionArrow[document#parent->folder, Union[folder#viewer->group, folder#viewer->team, folder#viewer->org]] + // Right side produces group, team, and org subjects. None match user. + left := dsIter("document", "parent", "folder", "...") + rightGroup := dsIter("folder", "viewer", "group", "...") + rightTeam := dsIter("folder", "viewer", "team", "...") + rightOrg := dsIter("folder", "viewer", "org", "...") + right := NewUnionIterator(rightGroup, rightTeam, rightOrg) + arrow := NewIntersectionArrowIterator(left, right) + + result, changed, err := ApplyReachabilityPruning(arrow, "user") + require.NoError(t, err) + require.True(t, changed) + require.IsType(t, &FixedIterator{}, result, "entire intersection arrow should be pruned") + }) + + t.Run("prunes intersection arrow inside union when unreachable", func(t *testing.T) { + t.Parallel() + + // Union[ + // document#viewer->user, + // IntersectionArrow[document#parent->folder, folder#viewer->group] + // ] + // The intersection arrow's right side produces group, not user, so it should get pruned. + directUser := dsIter("document", "viewer", "user", "...") + left := dsIter("document", "parent", "folder", "...") + right := dsIter("folder", "viewer", "group", "...") + arrow := NewIntersectionArrowIterator(left, right) + union := NewUnionIterator(directUser, arrow) + + result, changed, err := ApplyReachabilityPruning(union, "user") + require.NoError(t, err) + require.True(t, changed) + + require.IsType(t, &UnionIterator{}, result) + subs := result.Subiterators() + require.Len(t, subs, 2) + require.IsType(t, &DatastoreIterator{}, subs[0], "direct user branch should remain") + require.IsType(t, &FixedIterator{}, subs[1], "intersection arrow branch should be pruned") + }) + }) +}