From 4380ec2ee177ee0548c08daf38d3bf1cb2ae3e43 Mon Sep 17 00:00:00 2001 From: Barak Michener Date: Wed, 11 Mar 2026 12:50:13 -0700 Subject: [PATCH 1/3] feat: add queryopt package --- pkg/query/queryopt/caveat_pushdown.go | 18 ++++++++++++++ pkg/query/queryopt/registry.go | 35 +++++++++++++++++++++++++++ 2 files changed, 53 insertions(+) create mode 100644 pkg/query/queryopt/caveat_pushdown.go create mode 100644 pkg/query/queryopt/registry.go diff --git a/pkg/query/queryopt/caveat_pushdown.go b/pkg/query/queryopt/caveat_pushdown.go new file mode 100644 index 000000000..425d6c029 --- /dev/null +++ b/pkg/query/queryopt/caveat_pushdown.go @@ -0,0 +1,18 @@ +package queryopt + +import "github.com/authzed/spicedb/pkg/query" + +func init() { + MustRegisterOptimization(Optimizer{ + Name: "simple-caveat-pushdown", + Description: ` + Pushes caveat evalution to the lowest point in the tree. + Cannot push through intersection arrows + `, + Mutation: caveatPushdown, + }) +} + +func caveatPushdown(outline query.Outline) query.Outline { + // CLAUDE +} diff --git a/pkg/query/queryopt/registry.go b/pkg/query/queryopt/registry.go new file mode 100644 index 000000000..405419ad2 --- /dev/null +++ b/pkg/query/queryopt/registry.go @@ -0,0 +1,35 @@ +package queryopt + +import ( + "fmt" + + "github.com/authzed/spicedb/pkg/query" +) + +var optimizationRegistry = make(map[string]Optimizer) + +func MustRegisterOptimization(opt Optimizer) { + if _, ok := optimizationRegistry[opt.Name]; ok { + panic("queryopt: `" + opt.Name + "` registered twice at initialization") + } + optimizationRegistry[opt.Name] = opt +} + +func GetOptimization(name string) (Optimizer, error) { + v, ok := optimizationRegistry[name] + if !ok { + return Optimizer{}, fmt.Errorf("queryopt: no optimizer named `%s`", name) + } + return v, nil +} + +var StandardOptimzations = []string{ + "simple-caveat-pushdown", +} + +type Optimizer struct { + Name string + Description string + Mutation query.OutlineMutation + Priority int +} From 6d88bf26397666d1e07dededc1ef60246417f36b Mon Sep 17 00:00:00 2001 From: Barak Michener Date: Wed, 11 Mar 2026 13:03:06 -0700 Subject: [PATCH 2/3] feat: port caveat pushdown to queryopt --- pkg/query/advisor.go | 9 +- pkg/query/canonicalize.go | 26 ++ pkg/query/queryopt/caveat_pushdown.go | 105 +++++++- pkg/query/queryopt/caveat_pushdown_test.go | 285 +++++++++++++++++++++ 4 files changed, 423 insertions(+), 2 deletions(-) create mode 100644 pkg/query/queryopt/caveat_pushdown_test.go diff --git a/pkg/query/advisor.go b/pkg/query/advisor.go index 673a6bbfd..5489d5170 100644 --- a/pkg/query/advisor.go +++ b/pkg/query/advisor.go @@ -62,8 +62,11 @@ func ApplyAdvisor(co CanonicalOutline, advisor PlanAdvisor) (CanonicalOutline, e // For each node it calls GetMutations on the advisor and applies the returned // mutations in sequence. After each mutation it verifies that the resulting // node's ID matches the original node's ID; a mismatch is a programmer bug. +// After all mutations are applied, any newly synthesized nodes (ID==0) receive +// fresh IDs via FillMissingNodeIDs, and their CanonicalKeys are recorded in +// the provided keys map. func applyAdvisorMutations(outline Outline, co CanonicalOutline, advisor PlanAdvisor) (Outline, error) { - return WalkOutlineBottomUp(outline, func(node Outline) (Outline, error) { + mutated, err := WalkOutlineBottomUp(outline, func(node Outline) (Outline, error) { mutations, err := advisor.GetMutations(node, co) if err != nil { return Outline{}, err @@ -81,6 +84,10 @@ func applyAdvisorMutations(outline Outline, co CanonicalOutline, advisor PlanAdv } return result, nil }) + if err != nil { + return Outline{}, err + } + return FillMissingNodeIDs(mutated, co.CanonicalKeys), nil } // collectAdvisorHints walks the outline tree pre-order via WalkOutlinePreOrder, diff --git a/pkg/query/canonicalize.go b/pkg/query/canonicalize.go index 67ac88458..1b2e63c90 100644 --- a/pkg/query/canonicalize.go +++ b/pkg/query/canonicalize.go @@ -273,3 +273,29 @@ func assignNodeIDs(outline Outline, keys map[OutlineNodeID]CanonicalKey) Outline keys[id] = outline.Serialize() return outline } + +// FillMissingNodeIDs walks the outline tree bottom-up and assigns fresh +// OutlineNodeIDs to any node where ID == 0, recording their CanonicalKeys +// in the provided map. Nodes that already have an ID are left unchanged, +// and their existing map entries are preserved. +// +// This is used after mutations that may introduce new structural nodes +// (e.g. caveat wrappers, rotated arrows) into an already-canonicalized tree. +func FillMissingNodeIDs(outline Outline, keys map[OutlineNodeID]CanonicalKey) Outline { + // Recurse on children first (bottom-up) + if len(outline.SubOutlines) > 0 { + newSubs := make([]Outline, len(outline.SubOutlines)) + for i, sub := range outline.SubOutlines { + newSubs[i] = FillMissingNodeIDs(sub, keys) + } + outline.SubOutlines = newSubs + } + + if outline.ID == 0 { + id := OutlineNodeID(nodeIDCounter.Add(1)) + outline.ID = id + keys[id] = outline.Serialize() + } + + return outline +} diff --git a/pkg/query/queryopt/caveat_pushdown.go b/pkg/query/queryopt/caveat_pushdown.go index 425d6c029..fd158d3ae 100644 --- a/pkg/query/queryopt/caveat_pushdown.go +++ b/pkg/query/queryopt/caveat_pushdown.go @@ -13,6 +13,109 @@ func init() { }) } +// caveatPushdown is an OutlineMutation that implements caveat pushdown on Outline trees. +// It is called bottom-up by MutateOutline, so by the time a CaveatIteratorType node is +// visited, its children have already been processed. +// +// For a node Caveat(child) it attempts to push the caveat one level deeper: +// +// Caveat(Union[A, B]) → Union[Caveat(A), B] (only A contains the caveat) +// Caveat(Union[A, B]) → Union[Caveat(A), Caveat(B)] (both contain the caveat) +// +// Pushdown is blocked (outline returned unchanged) when: +// - the node is not a CaveatIteratorType +// - the child is an IntersectionArrowIteratorType (special all() semantics) +// - the child is another CaveatIteratorType (would cause infinite recursion) +// - the child has no SubOutlines (leaf node) +// - none of the child's SubOutlines contain the caveat +// +// The original caveat node's ID is threaded through to all newly created caveat +// wrappers. Because CaveatIterator nodes serialize solely by caveat name (independent +// of position), all relocated instances share the same CanonicalKey and therefore +// legitimately carry the same ID. func caveatPushdown(outline query.Outline) query.Outline { - // CLAUDE + if outline.Type != query.CaveatIteratorType { + return outline + } + return caveatPushdownInner(outline, outline.ID) +} + +// caveatPushdownInner is the recursive implementation of caveatPushdown. +// originalID is the ID of the topmost caveat node being pushed; it is threaded +// through to every caveat wrapper created during the descent so that all +// relocated instances of the same caveat carry the original ID. +func caveatPushdownInner(outline query.Outline, originalID query.OutlineNodeID) query.Outline { + // A CaveatIterator must have exactly one child. + if len(outline.SubOutlines) != 1 { + return outline + } + child := outline.SubOutlines[0] + + // Do not push through IntersectionArrow (all() semantics require post-intersection caveat eval). + if child.Type == query.IntersectionArrowIteratorType { + return outline + } + + // Do not push through another CaveatIterator (prevents infinite recursion). + if child.Type == query.CaveatIteratorType { + return outline + } + + // Nothing to push into if the child has no children of its own. + if len(child.SubOutlines) == 0 { + return outline + } + + caveat := outline.Args.Caveat + + // For each grandchild, wrap it in the caveat if it contains the caveat relation, + // then recursively push the new caveat wrapper as deep as it will go. + // The originalID is threaded through so every new caveat wrapper carries the + // same ID as the caveat node being relocated. + newSubs := make([]query.Outline, len(child.SubOutlines)) + changed := false + for i, sub := range child.SubOutlines { + if outlineContainsCaveat(sub, caveat.CaveatName) { + wrapped := query.Outline{ + Type: query.CaveatIteratorType, + Args: outline.Args, + SubOutlines: []query.Outline{sub}, + ID: originalID, + } + newSubs[i] = caveatPushdownInner(wrapped, originalID) + changed = true + } else { + newSubs[i] = sub + } + } + + if !changed { + return outline + } + + // Return the child with updated grandchildren, dropping the outer caveat wrapper. + return query.Outline{ + Type: child.Type, + Args: child.Args, + SubOutlines: newSubs, + ID: child.ID, + } +} + +// outlineContainsCaveat reports whether the outline subtree contains any +// DatastoreIteratorType node whose Relation carries the given caveat name. +func outlineContainsCaveat(outline query.Outline, caveatName string) bool { + if outline.Type == query.DatastoreIteratorType { + if outline.Args != nil && outline.Args.Relation != nil { + if outline.Args.Relation.Caveat() == caveatName { + return true + } + } + } + for _, sub := range outline.SubOutlines { + if outlineContainsCaveat(sub, caveatName) { + return true + } + } + return false } diff --git a/pkg/query/queryopt/caveat_pushdown_test.go b/pkg/query/queryopt/caveat_pushdown_test.go new file mode 100644 index 000000000..5bc267690 --- /dev/null +++ b/pkg/query/queryopt/caveat_pushdown_test.go @@ -0,0 +1,285 @@ +package queryopt + +import ( + "testing" + + "github.com/stretchr/testify/require" + + core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/query" + "github.com/authzed/spicedb/pkg/schema/v2" +) + +// dsOutline returns a DatastoreIteratorType outline whose BaseRelation carries +// the given caveat name (empty string means no caveat). +func dsOutline(caveatName string) query.Outline { + rel := schema.NewTestBaseRelationWithFeatures("document", "viewer", "user", "", caveatName, false) + return query.Outline{ + Type: query.DatastoreIteratorType, + Args: &query.IteratorArgs{Relation: rel}, + } +} + +// caveatArgs returns an IteratorArgs carrying the named caveat. +func caveatArgs(name string) *query.IteratorArgs { + return &query.IteratorArgs{ + Caveat: &core.ContextualizedCaveat{CaveatName: name}, + } +} + +// caveatOutline wraps child in a CaveatIteratorType outline for the named caveat. +func caveatOutline(name string, child query.Outline) query.Outline { + return query.Outline{ + Type: query.CaveatIteratorType, + Args: caveatArgs(name), + SubOutlines: []query.Outline{child}, + } +} + +// unionOutline returns a UnionIteratorType outline with the given children. +func unionOutline(children ...query.Outline) query.Outline { + return query.Outline{ + Type: query.UnionIteratorType, + SubOutlines: children, + } +} + +// intersectionOutline returns an IntersectionIteratorType outline with the given children. +func intersectionOutline(children ...query.Outline) query.Outline { + return query.Outline{ + Type: query.IntersectionIteratorType, + SubOutlines: children, + } +} + +// intersectionArrowOutline returns an IntersectionArrowIteratorType outline with left/right children. +func intersectionArrowOutline(left, right query.Outline) query.Outline { + return query.Outline{ + Type: query.IntersectionArrowIteratorType, + SubOutlines: []query.Outline{left, right}, + } +} + +// applyPushdown runs caveatPushdown bottom-up over outline via MutateOutline. +func applyPushdown(outline query.Outline) query.Outline { + return query.MutateOutline(outline, []query.OutlineMutation{caveatPushdown}) +} + +func TestCaveatPushdown(t *testing.T) { + t.Parallel() + + t.Run("pushes caveat through union when both sides have caveat", func(t *testing.T) { + t.Parallel() + + // Caveat(Union[DS(cav), DS(cav)]) + input := caveatOutline("test_caveat", unionOutline( + dsOutline("test_caveat"), + dsOutline("test_caveat"), + )) + + result := applyPushdown(input) + + // → Union[Caveat(DS(cav)), Caveat(DS(cav))] + require.Equal(t, query.UnionIteratorType, result.Type) + require.Len(t, result.SubOutlines, 2) + require.Equal(t, query.CaveatIteratorType, result.SubOutlines[0].Type) + require.Equal(t, query.CaveatIteratorType, result.SubOutlines[1].Type) + }) + + t.Run("pushes caveat through union only on side with caveat", func(t *testing.T) { + t.Parallel() + + // Caveat(Union[DS(cav), DS(no cav)]) + input := caveatOutline("test_caveat", unionOutline( + dsOutline("test_caveat"), + dsOutline(""), + )) + + result := applyPushdown(input) + + // → Union[Caveat(DS(cav)), DS(no cav)] + require.Equal(t, query.UnionIteratorType, result.Type) + require.Len(t, result.SubOutlines, 2) + require.Equal(t, query.CaveatIteratorType, result.SubOutlines[0].Type) + require.Equal(t, "test_caveat", result.SubOutlines[0].Args.Caveat.CaveatName) + require.Equal(t, query.DatastoreIteratorType, result.SubOutlines[1].Type) + }) + + t.Run("does not push caveat through intersection arrow", func(t *testing.T) { + t.Parallel() + + // Caveat(IntersectionArrow[DS(cav), DS(no cav)]) + input := caveatOutline("test_caveat", intersectionArrowOutline( + dsOutline("test_caveat"), + dsOutline(""), + )) + + result := applyPushdown(input) + + // → unchanged: Caveat(IntersectionArrow[...]) + require.Equal(t, query.CaveatIteratorType, result.Type) + require.Equal(t, query.IntersectionArrowIteratorType, result.SubOutlines[0].Type) + }) + + t.Run("does not push when no children contain the caveat", func(t *testing.T) { + t.Parallel() + + // Caveat(Union[DS(no cav), DS(no cav)]) + input := caveatOutline("test_caveat", unionOutline( + dsOutline(""), + dsOutline(""), + )) + + result := applyPushdown(input) + + // → unchanged: Caveat(Union[...]) + require.Equal(t, query.CaveatIteratorType, result.Type) + require.Equal(t, query.UnionIteratorType, result.SubOutlines[0].Type) + }) + + t.Run("does not push through leaf datastore node", func(t *testing.T) { + t.Parallel() + + // Caveat(DS(cav)) — DS has no SubOutlines, nothing to push into + input := caveatOutline("test_caveat", dsOutline("test_caveat")) + + result := applyPushdown(input) + + // → unchanged: Caveat(DS(cav)) + require.Equal(t, query.CaveatIteratorType, result.Type) + require.Equal(t, query.DatastoreIteratorType, result.SubOutlines[0].Type) + }) + + t.Run("pushes all the way through nested union recursively", func(t *testing.T) { + t.Parallel() + + // Three levels deep: Caveat(Union[Union[Union[DS(cav), DS(no cav)], DS(no cav)], DS(cav)]) + // + // MutateOutline walks bottom-up, so by the time the outer Caveat node is + // visited, all three inner Union nodes have already been processed without + // a caveat. The recursive call inside caveatPushdown is what drives the + // caveat all the way down through the two intermediate Union levels to reach + // the leaf DS(cav). Without it, pushdown would stop one level short. + input := caveatOutline("test_caveat", unionOutline( + unionOutline( + unionOutline( + dsOutline("test_caveat"), + dsOutline(""), + ), + dsOutline(""), + ), + dsOutline("test_caveat"), + )) + + result := applyPushdown(input) + + // → Union[Union[Union[Caveat(DS(cav)), DS(no cav)], DS(no cav)], Caveat(DS(cav))] + require.Equal(t, query.UnionIteratorType, result.Type) + require.Len(t, result.SubOutlines, 2) + + // left branch: the caveat must be pushed all three levels deep + mid := result.SubOutlines[0] + require.Equal(t, query.UnionIteratorType, mid.Type) + require.Len(t, mid.SubOutlines, 2) + require.Equal(t, query.DatastoreIteratorType, mid.SubOutlines[1].Type) // DS(no cav) unchanged + + inner := mid.SubOutlines[0] + require.Equal(t, query.UnionIteratorType, inner.Type) + require.Len(t, inner.SubOutlines, 2) + require.Equal(t, query.CaveatIteratorType, inner.SubOutlines[0].Type) // Caveat(DS(cav)) at the leaf + require.Equal(t, query.DatastoreIteratorType, inner.SubOutlines[1].Type) // DS(no cav) unchanged + + // right branch: leaf DS wrapped directly + require.Equal(t, query.CaveatIteratorType, result.SubOutlines[1].Type) + require.Equal(t, query.DatastoreIteratorType, result.SubOutlines[1].SubOutlines[0].Type) + }) + + t.Run("works with intersection of relations", func(t *testing.T) { + t.Parallel() + + // Caveat(Intersection[DS(cav), DS(no cav)]) + input := caveatOutline("test_caveat", intersectionOutline( + dsOutline("test_caveat"), + dsOutline(""), + )) + + result := applyPushdown(input) + + // → Intersection[Caveat(DS(cav)), DS(no cav)] + require.Equal(t, query.IntersectionIteratorType, result.Type) + require.Len(t, result.SubOutlines, 2) + require.Equal(t, query.CaveatIteratorType, result.SubOutlines[0].Type) + require.Equal(t, query.DatastoreIteratorType, result.SubOutlines[1].Type) + }) + + t.Run("does not push through nested caveat node", func(t *testing.T) { + t.Parallel() + + // Caveat(Caveat(DS(cav))) + // Bottom-up: inner Caveat(DS) is visited first — leaf below, no push. + // Then outer Caveat sees a CaveatIteratorType child — blocked. + input := caveatOutline("test_caveat", caveatOutline("test_caveat", dsOutline("test_caveat"))) + + result := applyPushdown(input) + + // → unchanged: Caveat(Caveat(DS)) + require.Equal(t, query.CaveatIteratorType, result.Type) + require.Equal(t, query.CaveatIteratorType, result.SubOutlines[0].Type) + require.Equal(t, query.DatastoreIteratorType, result.SubOutlines[0].SubOutlines[0].Type) + }) +} + +func TestOutlineContainsCaveat(t *testing.T) { + t.Parallel() + + t.Run("detects caveat directly in datastore node", func(t *testing.T) { + t.Parallel() + + input := caveatOutline("test_caveat", dsOutline("test_caveat")) + // Pushdown is blocked at leaf, so result stays Caveat(DS) — proving the + // caveat was detected (otherwise nothing would have changed to check). + // Test containment indirectly: wrap in union with a no-caveat peer. + input = caveatOutline("test_caveat", unionOutline(dsOutline("test_caveat"), dsOutline(""))) + result := applyPushdown(input) + // The caveat-bearing branch got wrapped, proving detection worked. + require.Equal(t, query.UnionIteratorType, result.Type) + require.Equal(t, query.CaveatIteratorType, result.SubOutlines[0].Type) + require.Equal(t, query.DatastoreIteratorType, result.SubOutlines[1].Type) + }) + + t.Run("does not detect when caveat name differs", func(t *testing.T) { + t.Parallel() + + // Caveat("test_caveat") over a DS with "other_caveat" — no match + input := caveatOutline("test_caveat", unionOutline( + dsOutline("other_caveat"), + dsOutline(""), + )) + result := applyPushdown(input) + + // Neither child contains "test_caveat", so nothing is pushed + require.Equal(t, query.CaveatIteratorType, result.Type) + }) + + t.Run("detects caveat in nested datastore", func(t *testing.T) { + t.Parallel() + + // Caveat(Union[DS(no cav), Union[DS(no cav), DS(cav)]]) + input := caveatOutline("test_caveat", unionOutline( + dsOutline(""), + unionOutline(dsOutline(""), dsOutline("test_caveat")), + )) + result := applyPushdown(input) + + // First child has no caveat anywhere → left unwrapped. + // Second child contains the caveat deep → wrapped and recursively pushed. + require.Equal(t, query.UnionIteratorType, result.Type) + require.Equal(t, query.DatastoreIteratorType, result.SubOutlines[0].Type) + + // The second branch should be a union with the caveat pushed all the way down + innerResult := result.SubOutlines[1] + require.Equal(t, query.UnionIteratorType, innerResult.Type) + require.Equal(t, query.DatastoreIteratorType, innerResult.SubOutlines[0].Type) + require.Equal(t, query.CaveatIteratorType, innerResult.SubOutlines[1].Type) + }) +} From 90a93cb228a604699ceb6a3a60b0e558874bafa8 Mon Sep 17 00:00:00 2001 From: Barak Michener Date: Wed, 11 Mar 2026 13:54:04 -0700 Subject: [PATCH 3/3] refactor: move optimize into the new package, and change callsites --- .../query_plan_consistency_test.go | 16 +- internal/services/v1/permissions_queryplan.go | 15 +- pkg/query/build_tree.go | 10 - pkg/query/build_tree_test.go | 65 ++-- pkg/query/optimize.go | 93 ----- pkg/query/optimize_caveat.go | 93 ----- pkg/query/optimize_caveat_test.go | 340 ------------------ pkg/query/optimize_test.go | 67 ---- pkg/query/queryopt/registry.go | 54 ++- 9 files changed, 114 insertions(+), 639 deletions(-) delete mode 100644 pkg/query/optimize.go delete mode 100644 pkg/query/optimize_caveat.go delete mode 100644 pkg/query/optimize_caveat_test.go delete mode 100644 pkg/query/optimize_test.go diff --git a/internal/services/integrationtesting/query_plan_consistency_test.go b/internal/services/integrationtesting/query_plan_consistency_test.go index e0f04a8a6..d5ca99880 100644 --- a/internal/services/integrationtesting/query_plan_consistency_test.go +++ b/internal/services/integrationtesting/query_plan_consistency_test.go @@ -22,6 +22,7 @@ import ( "github.com/authzed/spicedb/pkg/datalayer" "github.com/authzed/spicedb/pkg/datastore" "github.com/authzed/spicedb/pkg/query" + "github.com/authzed/spicedb/pkg/query/queryopt" "github.com/authzed/spicedb/pkg/schema/v2" "github.com/authzed/spicedb/pkg/tuple" "github.com/authzed/spicedb/pkg/validationfile" @@ -129,15 +130,18 @@ func runQueryPlanAssertions(t *testing.T, handle *queryPlanConsistencyHandle) { require := require.New(t) rel := assertion.Relationship - it, err := query.BuildIteratorFromSchema(handle.schema, rel.Resource.ObjectType, rel.Resource.Relation) + co, err := query.BuildOutlineFromSchema(handle.schema, rel.Resource.ObjectType, rel.Resource.Relation) require.NoError(err) // Apply static optimizations if requested if optimizationMode.optimize { - it, _, err = query.ApplyOptimizations(it, query.StaticOptimizations) + co, err = queryopt.ApplyOptimizations(co, queryopt.StandardOptimzations) require.NoError(err) } + it, err := co.Compile() + require.NoError(err) + qctx := handle.buildContext(t) // Add caveat context from assertion if available @@ -200,7 +204,9 @@ func runQueryPlanLookupResources(t *testing.T, handle *queryPlanConsistencyHandl t.Run(tuple.StringONR(subject), func(t *testing.T) { accessibleResources := accessibilitySet.LookupAccessibleResources(resourceRelation, subject) queryCtx := handle.buildContext(t) - it, err := query.BuildIteratorFromSchema(handle.schema, resourceRelation.ObjectType, resourceRelation.Relation) + co, err := query.BuildOutlineFromSchema(handle.schema, resourceRelation.ObjectType, resourceRelation.Relation) + require.NoError(t, err) + it, err := co.Compile() require.NoError(t, err) // Perform a lookup call and ensure it returns the at least the same set of object IDs. @@ -254,7 +260,9 @@ func runQueryPlanLookupSubjects(t *testing.T, handle *queryPlanConsistencyHandle t.Run(tuple.StringONR(resource), func(t *testing.T) { accessibleSubjects := accessibilitySet.LookupAccessibleSubjects(resource) queryCtx := handle.buildContext(t) - it, err := query.BuildIteratorFromSchema(handle.schema, resourceRelation.ObjectType, resourceRelation.Relation) + co, err := query.BuildOutlineFromSchema(handle.schema, resourceRelation.ObjectType, resourceRelation.Relation) + require.NoError(t, err) + it, err := co.Compile() require.NoError(t, err) // Perform a lookup call and ensure it returns the at least the same set of subject IDs. diff --git a/internal/services/v1/permissions_queryplan.go b/internal/services/v1/permissions_queryplan.go index 077dd67fe..79f8179da 100644 --- a/internal/services/v1/permissions_queryplan.go +++ b/internal/services/v1/permissions_queryplan.go @@ -10,6 +10,7 @@ import ( "github.com/authzed/spicedb/pkg/datastore" "github.com/authzed/spicedb/pkg/middleware/consistency" "github.com/authzed/spicedb/pkg/query" + "github.com/authzed/spicedb/pkg/query/queryopt" "github.com/authzed/spicedb/pkg/schema/v2" ) @@ -50,15 +51,19 @@ func (ps *permissionServer) checkPermissionWithQueryPlan(ctx context.Context, re return nil, ps.rewriteError(ctx, err) } - // Build iterator tree from schema - // TODO: Better iterator caching - it, err := query.BuildIteratorFromSchema(fullSchema, req.Resource.ObjectType, req.Permission) + // Build and optimize the outline, then compile to an iterator tree. + // TODO: Better outline caching + co, err := query.BuildOutlineFromSchema(fullSchema, req.Resource.ObjectType, req.Permission) if err != nil { return nil, ps.rewriteError(ctx, err) } - // Apply basic optimizations to the iterator tree - it, _, err = query.ApplyOptimizations(it, query.StaticOptimizations) + optimized, err := queryopt.ApplyOptimizations(co, queryopt.StandardOptimzations) + if err != nil { + return nil, ps.rewriteError(ctx, err) + } + + it, err := optimized.Compile() if err != nil { return nil, ps.rewriteError(ctx, err) } diff --git a/pkg/query/build_tree.go b/pkg/query/build_tree.go index 247931266..9f084b5c8 100644 --- a/pkg/query/build_tree.go +++ b/pkg/query/build_tree.go @@ -22,16 +22,6 @@ type outlineBuilder struct { recursiveSentinels []*recursiveSentinelInfo // Track recursion points for wrapping in RecursiveIterator } -// BuildIteratorFromSchema takes a schema and walks the schema tree for a given definition namespace and a relationship or -// permission therein. From this, it generates an iterator tree, rooted on that relationship. -func BuildIteratorFromSchema(fullSchema *schema.Schema, definitionName string, relationName string) (Iterator, error) { - canonical, err := BuildOutlineFromSchema(fullSchema, definitionName, relationName) - if err != nil { - return nil, err - } - return canonical.Compile() -} - // BuildOutlineFromSchema builds a canonical Outline tree from the schema. func BuildOutlineFromSchema(fullSchema *schema.Schema, definitionName string, relationName string) (CanonicalOutline, error) { builder := &outlineBuilder{ diff --git a/pkg/query/build_tree_test.go b/pkg/query/build_tree_test.go index e7b616258..2944833d9 100644 --- a/pkg/query/build_tree_test.go +++ b/pkg/query/build_tree_test.go @@ -14,6 +14,19 @@ import ( "github.com/authzed/spicedb/pkg/schema/v2" ) +// buildIterator is a test helper that builds a CanonicalOutline from the schema +// and compiles it into an Iterator. It mirrors the old BuildIteratorFromSchema +// convenience function which has been removed in favour of the explicit +// BuildOutlineFromSchema → Compile() pipeline. +func buildIterator(t *testing.T, fullSchema *schema.Schema, defName, relName string) (Iterator, error) { + t.Helper() + co, err := BuildOutlineFromSchema(fullSchema, defName, relName) + if err != nil { + return nil, err + } + return co.Compile() +} + func TestBuildTree(t *testing.T) { t.Parallel() @@ -28,7 +41,7 @@ func TestBuildTree(t *testing.T) { dsSchema, err := schema.BuildSchemaFromDefinitions(objectDefs, nil) require.NoError(err) - it, err := BuildIteratorFromSchema(dsSchema, "document", "edit") + it, err := buildIterator(t, dsSchema, "document", "edit") require.NoError(err) ctx := NewLocalContext(t.Context(), @@ -55,7 +68,7 @@ func TestBuildTreeMultipleRelations(t *testing.T) { require.NoError(err) // Test building iterator for edit permission which creates a union - it, err := BuildIteratorFromSchema(dsSchema, "document", "edit") + it, err := buildIterator(t, dsSchema, "document", "edit") require.NoError(err) explain := it.Explain() @@ -81,12 +94,12 @@ func TestBuildTreeInvalidDefinition(t *testing.T) { require.NoError(err) // Test with invalid definition name - _, err = BuildIteratorFromSchema(dsSchema, "nonexistent", "edit") + _, err = buildIterator(t, dsSchema, "nonexistent", "edit") require.Error(err) require.Contains(err.Error(), "couldn't find a schema definition named `nonexistent`") // Test with invalid relation/permission name - _, err = BuildIteratorFromSchema(dsSchema, "document", "nonexistent") + _, err = buildIterator(t, dsSchema, "document", "nonexistent") require.Error(err) require.Contains(err.Error(), "couldn't find a relation or permission named `nonexistent`") } @@ -105,7 +118,7 @@ func TestBuildTreeSubRelations(t *testing.T) { require.NoError(err) // Test building iterator for a relation with subrelations - it, err := BuildIteratorFromSchema(dsSchema, "document", "parent") + it, err := buildIterator(t, dsSchema, "document", "parent") require.NoError(err) // Should have created a relation iterator @@ -152,7 +165,7 @@ func TestBuildTreeRecursion(t *testing.T) { // This should detect recursion and create a RecursiveIterator // The arrow operation parent->member creates recursion: group->parent->member->parent->member... - it, err := BuildIteratorFromSchema(dsSchema, "group", "member") + it, err := buildIterator(t, dsSchema, "group", "member") require.NoError(err) require.NotNil(it) @@ -203,7 +216,7 @@ func TestBuildTreeIntersectionOperation(t *testing.T) { require.NoError(err) // Test building iterator for view_and_edit permission which uses intersection operations - it, err := BuildIteratorFromSchema(dsSchema, "document", "view_and_edit") + it, err := buildIterator(t, dsSchema, "document", "view_and_edit") require.NoError(err) // Should create an intersection iterator @@ -245,7 +258,7 @@ func TestBuildTreeExclusionOperation(t *testing.T) { require.NoError(err) // Test building iterator for exclusion permission - should succeed - it, err := BuildIteratorFromSchema(dsSchema, "document", "excluded_perm") + it, err := buildIterator(t, dsSchema, "document", "excluded_perm") require.NoError(err) require.NotNil(it) // Should be wrapped in an Alias @@ -297,7 +310,7 @@ func TestBuildTreeExclusionEdgeCases(t *testing.T) { dsSchema, err := schema.BuildSchemaFromDefinitions(objectDefs, nil) require.NoError(err) - it, err := BuildIteratorFromSchema(dsSchema, "document", "can_view") + it, err := buildIterator(t, dsSchema, "document", "can_view") require.NoError(err) require.NotNil(it) // Should be wrapped in an Alias @@ -336,7 +349,7 @@ func TestBuildTreeExclusionEdgeCases(t *testing.T) { dsSchema, err := schema.BuildSchemaFromDefinitions(objectDefs, nil) require.NoError(err) - it, err := BuildIteratorFromSchema(dsSchema, "document", "restricted_viewers") + it, err := buildIterator(t, dsSchema, "document", "restricted_viewers") require.NoError(err) require.NotNil(it) // Should be wrapped in an Alias @@ -377,7 +390,7 @@ func TestBuildTreeExclusionEdgeCases(t *testing.T) { dsSchema, err := schema.BuildSchemaFromDefinitions(objectDefs, nil) require.NoError(err) - it, err := BuildIteratorFromSchema(dsSchema, "document", "restricted_view") + it, err := buildIterator(t, dsSchema, "document", "restricted_view") require.NoError(err) require.NotNil(it) // Should be wrapped in an Alias @@ -419,7 +432,7 @@ func TestBuildTreeExclusionEdgeCases(t *testing.T) { dsSchema, err := schema.BuildSchemaFromDefinitions(objectDefs, nil) require.NoError(err) - it, err := BuildIteratorFromSchema(dsSchema, "document", "allowed_users") + it, err := buildIterator(t, dsSchema, "document", "allowed_users") require.NoError(err) require.NotNil(it) // Should be wrapped in an Alias @@ -457,7 +470,7 @@ func TestBuildTreeExclusionEdgeCases(t *testing.T) { require.NoError(err) // Building iterator should fail due to missing relation - _, err = BuildIteratorFromSchema(dsSchema, "document", "bad_exclusion") + _, err = buildIterator(t, dsSchema, "document", "bad_exclusion") require.Error(err) require.Contains(err.Error(), "couldn't find a relation or permission named `nonexistent_relation`") }) @@ -480,7 +493,7 @@ func TestBuildTreeExclusionEdgeCases(t *testing.T) { require.NoError(err) // Building iterator should fail due to missing relation - _, err = BuildIteratorFromSchema(dsSchema, "document", "bad_exclusion") + _, err = buildIterator(t, dsSchema, "document", "bad_exclusion") require.Error(err) require.Contains(err.Error(), "couldn't find a relation or permission named `nonexistent_relation`") }) @@ -507,7 +520,7 @@ func TestBuildTreeArrowMissingLeftRelation(t *testing.T) { require.NoError(err) // Test building iterator for arrow with missing left relation - _, err = BuildIteratorFromSchema(dsSchema, "document", "bad_arrow") + _, err = buildIterator(t, dsSchema, "document", "bad_arrow") require.Error(err) require.Contains(err.Error(), "couldn't find left-hand relation for arrow") } @@ -526,7 +539,7 @@ func TestBuildTreeSingleRelationOptimization(t *testing.T) { require.NoError(err) // Test building iterator for a simple relation - should not create unnecessary unions - it, err := BuildIteratorFromSchema(dsSchema, "document", "owner") + it, err := buildIterator(t, dsSchema, "document", "owner") require.NoError(err) // Should create a simple relation iterator without extra union wrappers @@ -580,7 +593,7 @@ func TestBuildTreeSubrelationHandling(t *testing.T) { require.NoError(err) // Should create an alias wrapping union with arrow - it, err := BuildIteratorFromSchema(dsSchema, "document", "viewer") + it, err := buildIterator(t, dsSchema, "document", "viewer") require.NoError(err) require.NotNil(it) @@ -617,7 +630,7 @@ func TestBuildTreeSubrelationHandling(t *testing.T) { dsSchema, err := schema.BuildSchemaFromDefinitions(objectDefs, nil) require.NoError(err) - it, err := BuildIteratorFromSchema(dsSchema, "document", "viewer") + it, err := buildIterator(t, dsSchema, "document", "viewer") require.NoError(err) require.NotNil(it) @@ -649,7 +662,7 @@ func TestBuildTreeSubrelationHandling(t *testing.T) { require.NoError(err) // Should create RecursiveIterator for arrow recursion - it, err := BuildIteratorFromSchema(dsSchema, "document", "viewer") + it, err := buildIterator(t, dsSchema, "document", "viewer") require.NoError(err) require.NotNil(it) @@ -680,7 +693,7 @@ func TestBuildTreeSubrelationHandling(t *testing.T) { require.NoError(err) // Should fail when trying to build iterator due to missing subrelation - _, err = BuildIteratorFromSchema(dsSchema, "document", "viewer") + _, err = buildIterator(t, dsSchema, "document", "viewer") require.Error(err) require.Contains(err.Error(), "couldn't find a relation or permission named `nonexistent`") }) @@ -708,7 +721,7 @@ func TestBuildTreeSubrelationHandling(t *testing.T) { dsSchema, err := schema.BuildSchemaFromDefinitions(objectDefs, nil) require.NoError(err) - it, err := BuildIteratorFromSchema(dsSchema, "document", "viewer") + it, err := buildIterator(t, dsSchema, "document", "viewer") require.NoError(err) require.NotNil(it) @@ -770,7 +783,7 @@ func TestBuildTreeWildcardIterator(t *testing.T) { t.Run("Schema with wildcard creates WildcardIterator", func(t *testing.T) { t.Parallel() - it, err := BuildIteratorFromSchema(dsSchema, "document", "viewer") + it, err := buildIterator(t, dsSchema, "document", "viewer") require.NoError(err) require.NotNil(it) @@ -801,7 +814,7 @@ func TestBuildTreeWildcardIterator(t *testing.T) { mixedSchema, err := schema.BuildSchemaFromDefinitions(mixedObjectDefs, nil) require.NoError(err) - it, err := BuildIteratorFromSchema(mixedSchema, "document", "viewer") + it, err := buildIterator(t, mixedSchema, "document", "viewer") require.NoError(err) require.NotNil(it) @@ -859,7 +872,7 @@ func TestBuildTreeMutualRecursionSentinelFiltering(t *testing.T) { t.Run("document viewer builds successfully with mutual recursion", func(t *testing.T) { t.Parallel() // Build iterator for document#viewer - should detect recursion and wrap properly - it, err := BuildIteratorFromSchema(dsSchema, "document", "viewer") + it, err := buildIterator(t, dsSchema, "document", "viewer") require.NoError(err) require.NotNil(it) @@ -872,7 +885,7 @@ func TestBuildTreeMutualRecursionSentinelFiltering(t *testing.T) { t.Run("otherdocument viewer builds successfully with mutual recursion", func(t *testing.T) { t.Parallel() // Build iterator for otherdocument#viewer - should also handle mutual recursion - it, err := BuildIteratorFromSchema(dsSchema, "otherdocument", "viewer") + it, err := buildIterator(t, dsSchema, "otherdocument", "viewer") require.NoError(err) require.NotNil(it) @@ -890,7 +903,7 @@ func TestBuildTreeMutualRecursionSentinelFiltering(t *testing.T) { // its own sentinels. // Build the tree - it, err := BuildIteratorFromSchema(dsSchema, "document", "viewer") + it, err := buildIterator(t, dsSchema, "document", "viewer") require.NoError(err) require.NotNil(it) diff --git a/pkg/query/optimize.go b/pkg/query/optimize.go deleted file mode 100644 index 49c94eb21..000000000 --- a/pkg/query/optimize.go +++ /dev/null @@ -1,93 +0,0 @@ -package query - -// TypedOptimizerFunc is a function that transforms an iterator of a specific type T -// into a potentially optimized iterator. It returns the optimized iterator, a boolean -// indicating whether any optimization was performed, and an error if the optimization failed. -// -// The type parameter T constrains the function to operate only on specific iterator types, -// providing compile-time type safety when creating typed optimizers. -type TypedOptimizerFunc[T Iterator] func(it T) (Iterator, bool, error) - -// OptimizerFunc is a type-erased wrapper around TypedOptimizerFunc[T] that can be -// stored in a homogeneous list while maintaining type safety at runtime. -type OptimizerFunc func(it Iterator) (Iterator, bool, error) - -// WrapOptimizer wraps a typed TypedOptimizerFunc[T] into a type-erased OptimizerFunc. -// This allows optimizer functions for different concrete iterator types to be stored -// together in a heterogeneous list. -func WrapOptimizer[T Iterator](fn TypedOptimizerFunc[T]) OptimizerFunc { - return func(it Iterator) (Iterator, bool, error) { - if v, ok := it.(T); ok { - return fn(v) - } - return it, false, nil - } -} - -var StaticOptimizations = []OptimizerFunc{ - WrapOptimizer(PushdownCaveatEvaluation), -} - -// ApplyOptimizations recursively applies a list of optimizer functions to an iterator -// tree, transforming it into an optimized form. -// -// The function operates bottom-up, optimizing leafs and subiterators first, and replacing the -// subtrees up to the top, which it then returns. -// -// Parameters: -// - it: The iterator tree to optimize -// - fns: A list of optimizer functions to apply -// -// Returns: -// - The optimized iterator (which may be the same as the input if no optimizations applied) -// - A boolean indicating whether any changes were made -// - An error if any optimization failed -func ApplyOptimizations(it Iterator, fns []OptimizerFunc) (Iterator, bool, error) { - var err error - origSubs := it.Subiterators() - changed := false - if len(origSubs) != 0 { - // Make a copy of the subiterators slice to avoid mutating the original iterator - subs := make([]Iterator, len(origSubs)) - copy(subs, origSubs) - - subChanged := false - for i, subit := range subs { - newit, ok, err := ApplyOptimizations(subit, fns) - if err != nil { - return nil, false, err - } - if ok { - subs[i] = newit - subChanged = true - } - } - if subChanged { - changed = true - it, err = it.ReplaceSubiterators(subs) - if err != nil { - return nil, false, err - } - } - } - - // Apply each optimizer to the current iterator - // If any optimizer transforms the iterator, recursively optimize the new tree - for _, fn := range fns { - newit, fnChanged, err := fn(it) - if err != nil { - return nil, false, err - } - if fnChanged { - // The iterator was transformed - recursively optimize the new tree - // to ensure all optimizations are fully applied - optimizedIt, _, err := ApplyOptimizations(newit, fns) - if err != nil { - return nil, false, err - } - // Return true for changed since we did transform the iterator - return optimizedIt, true, nil - } - } - return it, changed, nil -} diff --git a/pkg/query/optimize_caveat.go b/pkg/query/optimize_caveat.go deleted file mode 100644 index 4c6ec896e..000000000 --- a/pkg/query/optimize_caveat.go +++ /dev/null @@ -1,93 +0,0 @@ -package query - -import ( - core "github.com/authzed/spicedb/pkg/proto/core/v1" - "github.com/authzed/spicedb/pkg/spiceerrors" -) - -// PushdownCaveatEvaluation pushes caveat evaluation down through certain composite iterators -// to allow earlier filtering and better performance. -// -// This optimization transforms: -// -// Caveat(Union[A, B]) -> Union[Caveat(A), B] (if only A contains the caveat) -// Caveat(Union[A, B]) -> Union[Caveat(A), Caveat(B)] (if both contain the caveat) -// -// The pushdown does NOT occur through IntersectionArrow iterators, as they have special -// semantics that require caveat evaluation to happen after the intersection. -func PushdownCaveatEvaluation(c *CaveatIterator) (Iterator, bool, error) { - // Don't push through IntersectionArrow - if _, ok := c.subiterator.(*IntersectionArrowIterator); ok { - return c, false, nil - } - - // Don't push down if the subiterator is already a CaveatIterator - // This prevents infinite recursion - if _, ok := c.subiterator.(*CaveatIterator); ok { - return c, false, nil - } - - // Get the subiterators of the child - subs := c.subiterator.Subiterators() - if len(subs) == 0 { - // No subiterators to push down into (e.g., leaf iterator) - return c, false, nil - } - - // Find which subiterators contain relations with this caveat - newSubs := make([]Iterator, len(subs)) - changed := false - for i, sub := range subs { - if containsCaveat(sub, c.caveat) { - // Wrap this subiterator with the caveat - newSubs[i] = NewCaveatIterator(sub, c.caveat) - changed = true - } else { - // Leave unchanged - newSubs[i] = sub - } - } - - if !changed { - return c, false, nil - } - - // Replace the subiterators in the child iterator - newChild, err := c.subiterator.ReplaceSubiterators(newSubs) - if err != nil { - return nil, false, err - } - - // Return the child without the caveat wrapper - return newChild, true, nil -} - -// containsCaveat checks if an iterator tree contains a DatastoreIterator -// that references the given caveat. -func containsCaveat(it Iterator, caveat *core.ContextualizedCaveat) bool { - found := false - _, err := Walk(it, func(node Iterator) (Iterator, error) { - if rel, ok := node.(*DatastoreIterator); ok { - if relationContainsCaveat(rel, caveat) { - found = true - } - } - return node, nil - }) - if err != nil { - spiceerrors.MustPanicf("should never error -- callback contains no errors, but linters must always check") - } - - return found -} - -// relationContainsCaveat checks if a DatastoreIterator's base relation -// has a caveat that matches the given caveat name. -func relationContainsCaveat(rel *DatastoreIterator, caveat *core.ContextualizedCaveat) bool { - if rel.base == nil || caveat == nil { - return false - } - - // Check if the relation has this caveat - return rel.base.Caveat() == caveat.CaveatName -} diff --git a/pkg/query/optimize_caveat_test.go b/pkg/query/optimize_caveat_test.go deleted file mode 100644 index f9c374a33..000000000 --- a/pkg/query/optimize_caveat_test.go +++ /dev/null @@ -1,340 +0,0 @@ -package query - -import ( - "testing" - - "github.com/stretchr/testify/require" - - core "github.com/authzed/spicedb/pkg/proto/core/v1" - "github.com/authzed/spicedb/pkg/schema/v2" -) - -// createTestCaveatForPushdown creates a test ContextualizedCaveat -func createTestCaveatForPushdown(name string) *core.ContextualizedCaveat { - return &core.ContextualizedCaveat{ - CaveatName: name, - Context: nil, - } -} - -// createTestDatastoreIterator creates a DatastoreIterator with a caveat -func createTestDatastoreIterator(caveatName string) *DatastoreIterator { - // Create a BaseRelation with the caveat - baseRelation := schema.NewTestBaseRelationWithFeatures("document", "viewer", "user", "", caveatName, false) - return NewDatastoreIterator(baseRelation) -} - -// createTestDatastoreIteratorNoCaveat creates a DatastoreIterator without a caveat -func createTestDatastoreIteratorNoCaveat() *DatastoreIterator { - baseRelation := schema.NewTestBaseRelationWithFeatures("document", "viewer", "user", "", "", false) - return NewDatastoreIterator(baseRelation) -} - -func TestPushdownCaveatEvaluation(t *testing.T) { - t.Parallel() - - t.Run("pushes caveat through union when both sides have caveat", func(t *testing.T) { - t.Parallel() - - caveat := createTestCaveatForPushdown("test_caveat") - - // Create Union[Relation(with caveat), Relation(with caveat)] - rel1 := createTestDatastoreIterator("test_caveat") - rel2 := createTestDatastoreIterator("test_caveat") - union := NewUnionIterator(rel1, rel2) - - // Wrap in caveat: Caveat(Union[Rel1, Rel2]) - caveatIterator := NewCaveatIterator(union, caveat) - - // Apply optimization - result, changed, err := ApplyOptimizations(caveatIterator, []OptimizerFunc{ - WrapOptimizer[*CaveatIterator](PushdownCaveatEvaluation), - }) - require.NoError(t, err) - require.True(t, changed) - - // Should become Union[Caveat(Rel1), Caveat(Rel2)] - require.IsType(t, &UnionIterator{}, result, "Expected result to be a Union") - resultUnion := result.(*UnionIterator) - require.Len(t, resultUnion.subIts, 2) - - // Both should be wrapped in caveats - require.IsType(t, &CaveatIterator{}, resultUnion.subIts[0], "First subiterator should be a CaveatIterator") - require.IsType(t, &CaveatIterator{}, resultUnion.subIts[1], "Second subiterator should be a CaveatIterator") - }) - - t.Run("pushes caveat through union only on side with caveat", func(t *testing.T) { - t.Parallel() - - caveat := createTestCaveatForPushdown("test_caveat") - - // Create Union[Relation(with caveat), Relation(no caveat)] - rel1 := createTestDatastoreIterator("test_caveat") - rel2 := createTestDatastoreIteratorNoCaveat() - union := NewUnionIterator(rel1, rel2) - - // Wrap in caveat: Caveat(Union[Rel1, Rel2]) - caveatIterator := NewCaveatIterator(union, caveat) - - // Apply optimization - result, changed, err := ApplyOptimizations(caveatIterator, []OptimizerFunc{ - WrapOptimizer[*CaveatIterator](PushdownCaveatEvaluation), - }) - require.NoError(t, err) - require.True(t, changed) - - // Should become Union[Caveat(Rel1), Rel2] - require.IsType(t, &UnionIterator{}, result, "Expected result to be a Union") - resultUnion := result.(*UnionIterator) - require.Len(t, resultUnion.subIts, 2) - - // First should be wrapped, second should not - require.IsType(t, &CaveatIterator{}, resultUnion.subIts[0], "First subiterator should be a CaveatIterator") - require.IsType(t, &DatastoreIterator{}, resultUnion.subIts[1], "Second subiterator should be a DatastoreIterator (not wrapped)") - - // Verify the caveat wraps the correct relation - caveat1 := resultUnion.subIts[0].(*CaveatIterator) - rel2Result := resultUnion.subIts[1].(*DatastoreIterator) - require.IsType(t, &DatastoreIterator{}, caveat1.subiterator) - caveat1Sub := caveat1.subiterator.(*DatastoreIterator) - require.Equal(t, rel1, caveat1Sub) - require.Equal(t, rel2, rel2Result) - }) - - t.Run("does not push caveat through intersection arrow", func(t *testing.T) { - t.Parallel() - - caveat := createTestCaveatForPushdown("test_caveat") - - // Create an IntersectionArrow with a relation that has the caveat - rel := createTestDatastoreIterator("test_caveat") - relNoCaveat := createTestDatastoreIteratorNoCaveat() - intersectionArrow := NewIntersectionArrowIterator(rel, relNoCaveat) - - // Wrap in caveat - caveatIterator := NewCaveatIterator(intersectionArrow, caveat) - - // Apply optimization - result, changed, err := ApplyOptimizations(caveatIterator, []OptimizerFunc{ - WrapOptimizer[*CaveatIterator](PushdownCaveatEvaluation), - }) - require.NoError(t, err) - require.False(t, changed, "Should not optimize through IntersectionArrow") - - // Should remain as Caveat(IntersectionArrow) - require.IsType(t, &CaveatIterator{}, result, "Expected result to still be a CaveatIterator") - resultCaveat := result.(*CaveatIterator) - require.IsType(t, &IntersectionArrowIterator{}, resultCaveat.subiterator, "Subiterator should still be IntersectionArrow") - }) - - t.Run("does not push when no subiterators have caveat", func(t *testing.T) { - t.Parallel() - - caveat := createTestCaveatForPushdown("test_caveat") - - // Create Union[Relation(no caveat), Relation(no caveat)] - rel1 := createTestDatastoreIteratorNoCaveat() - rel2 := createTestDatastoreIteratorNoCaveat() - union := NewUnionIterator(rel1, rel2) - - // Wrap in caveat: Caveat(Union[Rel1, Rel2]) - caveatIterator := NewCaveatIterator(union, caveat) - - // Apply optimization - result, changed, err := ApplyOptimizations(caveatIterator, []OptimizerFunc{ - WrapOptimizer[*CaveatIterator](PushdownCaveatEvaluation), - }) - require.NoError(t, err) - require.False(t, changed) - - // Should remain unchanged - require.IsType(t, &CaveatIterator{}, result) - resultCaveat := result.(*CaveatIterator) - require.Equal(t, caveatIterator, resultCaveat) - }) - - t.Run("does not push through leaf iterator", func(t *testing.T) { - t.Parallel() - - caveat := createTestCaveatForPushdown("test_caveat") - - // Create Caveat(Relation) - leaf has no subiterators - rel := createTestDatastoreIterator("test_caveat") - caveatIterator := NewCaveatIterator(rel, caveat) - - // Apply optimization - result, changed, err := ApplyOptimizations(caveatIterator, []OptimizerFunc{ - WrapOptimizer[*CaveatIterator](PushdownCaveatEvaluation), - }) - require.NoError(t, err) - require.False(t, changed) - - // Should remain unchanged - require.IsType(t, &CaveatIterator{}, result) - resultCaveat := result.(*CaveatIterator) - require.Equal(t, caveatIterator, resultCaveat) - }) - - t.Run("pushes through nested union", func(t *testing.T) { - t.Parallel() - - caveat := createTestCaveatForPushdown("test_caveat") - - // Create Caveat(Union[Union[Rel1, Rel2], Rel3]) - rel1 := createTestDatastoreIterator("test_caveat") - rel2 := createTestDatastoreIteratorNoCaveat() - innerUnion := NewUnionIterator(rel1, rel2) - - rel3 := createTestDatastoreIterator("test_caveat") - outerUnion := NewUnionIterator(innerUnion, rel3) - - caveatIterator := NewCaveatIterator(outerUnion, caveat) - - // Apply optimization - result, changed, err := ApplyOptimizations(caveatIterator, []OptimizerFunc{ - WrapOptimizer[*CaveatIterator](PushdownCaveatEvaluation), - }) - require.NoError(t, err) - require.True(t, changed) - - // Due to recursive optimization, this will become: - // Union[Union[Caveat(Rel1), Rel2], Caveat(Rel3)] - // The outer caveat pushes down to wrap innerUnion and rel3 - // Then the caveat on innerUnion recursively pushes down to only wrap rel1 - require.IsType(t, &UnionIterator{}, result) - resultUnion := result.(*UnionIterator) - require.Len(t, resultUnion.subIts, 2) - - // First should be Union[Caveat(Rel1), Rel2] (caveat pushed down further) - require.IsType(t, &UnionIterator{}, resultUnion.subIts[0], "First subiterator should be a Union (caveat pushed down)") - innerResultUnion := resultUnion.subIts[0].(*UnionIterator) - require.Len(t, innerResultUnion.subIts, 2) - require.IsType(t, &CaveatIterator{}, innerResultUnion.subIts[0], "First element of inner union should be Caveat(Rel1)") - require.IsType(t, &DatastoreIterator{}, innerResultUnion.subIts[1], "Second element of inner union should be Rel2 (no caveat)") - - // Second should be Caveat(Rel3) - require.IsType(t, &CaveatIterator{}, resultUnion.subIts[1]) - caveat2 := resultUnion.subIts[1].(*CaveatIterator) - require.IsType(t, &DatastoreIterator{}, caveat2.subiterator, "Second subiterator should be Caveat(Relation)") - }) - - t.Run("works with intersection of relations", func(t *testing.T) { - t.Parallel() - - caveat := createTestCaveatForPushdown("test_caveat") - - // Create Caveat(Intersection[Rel1(with caveat), Rel2(no caveat)]) - rel1 := createTestDatastoreIterator("test_caveat") - rel2 := createTestDatastoreIteratorNoCaveat() - intersection := NewIntersectionIterator(rel1, rel2) - - caveatIterator := NewCaveatIterator(intersection, caveat) - - // Apply optimization - result, changed, err := ApplyOptimizations(caveatIterator, []OptimizerFunc{ - WrapOptimizer[*CaveatIterator](PushdownCaveatEvaluation), - }) - require.NoError(t, err) - require.True(t, changed) - - // Should become Intersection[Caveat(Rel1), Rel2] - require.IsType(t, &IntersectionIterator{}, result) - resultIntersection := result.(*IntersectionIterator) - require.Len(t, resultIntersection.subIts, 2) - - // First should be wrapped, second should not - require.IsType(t, &CaveatIterator{}, resultIntersection.subIts[0], "First subiterator should be a CaveatIterator") - require.IsType(t, &DatastoreIterator{}, resultIntersection.subIts[1], "Second subiterator should be a DatastoreIterator") - }) -} - -func TestContainsCaveat(t *testing.T) { - t.Parallel() - - caveat := createTestCaveatForPushdown("test_caveat") - - t.Run("detects caveat in relation iterator", func(t *testing.T) { - t.Parallel() - - rel := createTestDatastoreIterator("test_caveat") - require.True(t, containsCaveat(rel, caveat)) - }) - - t.Run("does not detect when caveat name differs", func(t *testing.T) { - t.Parallel() - - rel := createTestDatastoreIterator("other_caveat") - require.False(t, containsCaveat(rel, caveat)) - }) - - t.Run("does not detect when no caveat", func(t *testing.T) { - t.Parallel() - - rel := createTestDatastoreIteratorNoCaveat() - require.False(t, containsCaveat(rel, caveat)) - }) - - t.Run("detects caveat in nested structure", func(t *testing.T) { - t.Parallel() - - rel1 := createTestDatastoreIteratorNoCaveat() - rel2 := createTestDatastoreIterator("test_caveat") - union := NewUnionIterator(rel1, rel2) - - require.True(t, containsCaveat(union, caveat)) - }) - - t.Run("does not detect caveat in structure without it", func(t *testing.T) { - t.Parallel() - - rel1 := createTestDatastoreIteratorNoCaveat() - rel2 := createTestDatastoreIteratorNoCaveat() - union := NewUnionIterator(rel1, rel2) - - require.False(t, containsCaveat(union, caveat)) - }) - - t.Run("handles nil caveat in relationContainsCaveat", func(t *testing.T) { - t.Parallel() - - rel := createTestDatastoreIterator("test_caveat") - require.False(t, relationContainsCaveat(rel, nil)) - }) - - t.Run("handles relation with nil base in relationContainsCaveat", func(t *testing.T) { - t.Parallel() - - caveat := createTestCaveatForPushdown("test_caveat") - // Create a DatastoreIterator with nil base - rel := &DatastoreIterator{base: nil} - require.False(t, relationContainsCaveat(rel, caveat)) - }) -} - -func TestPushdownCaveatEvaluationEdgeCases(t *testing.T) { - t.Parallel() - - t.Run("does not push through nested CaveatIterator", func(t *testing.T) { - t.Parallel() - - caveat := createTestCaveatForPushdown("test_caveat") - - // Create Caveat(Caveat(Relation)) - rel := createTestDatastoreIterator("test_caveat") - innerCaveat := NewCaveatIterator(rel, caveat) - outerCaveat := NewCaveatIterator(innerCaveat, caveat) - - // Apply optimization - result, changed, err := ApplyOptimizations(outerCaveat, []OptimizerFunc{ - WrapOptimizer[*CaveatIterator](PushdownCaveatEvaluation), - }) - require.NoError(t, err) - require.False(t, changed, "Should not push through nested CaveatIterator to prevent infinite recursion") - - // Should remain unchanged - resultCaveat, ok := result.(*CaveatIterator) - require.True(t, ok) - _, ok = resultCaveat.subiterator.(*CaveatIterator) - require.True(t, ok, "Subiterator should still be a CaveatIterator") - }) -} diff --git a/pkg/query/optimize_test.go b/pkg/query/optimize_test.go deleted file mode 100644 index 23d941138..000000000 --- a/pkg/query/optimize_test.go +++ /dev/null @@ -1,67 +0,0 @@ -package query - -import ( - "testing" - - "github.com/stretchr/testify/require" -) - -// newNonEmptyFixedIterator creates a FixedIterator with at least one path for testing -func newNonEmptyFixedIterator() *FixedIterator { - return NewFixedIterator(Path{ - Resource: Object{ObjectType: "doc", ObjectID: "test"}, - Relation: "viewer", - Subject: ObjectAndRelation{ - ObjectType: "user", - ObjectID: "alice", - Relation: "...", - }, - }) -} - -func TestWrapOptimizer(t *testing.T) { - t.Parallel() - - t.Run("matches correct type", func(t *testing.T) { - t.Parallel() - - // Create a typed optimizer that only works on Union - typedOptimizer := func(u *UnionIterator) (Iterator, bool, error) { - if len(u.subIts) == 1 { - return u.subIts[0], true, nil - } - return u, false, nil - } - - // Wrap it and use in ApplyOptimizations - wrapped := WrapOptimizer[*UnionIterator](typedOptimizer) - - // Test with a Union - should match and optimize - fixed := newNonEmptyFixedIterator() - union := NewUnionIterator(fixed) - - result, changed, err := ApplyOptimizations(union, []OptimizerFunc{wrapped}) - require.NoError(t, err) - require.True(t, changed) - require.Equal(t, fixed, result) - }) - - t.Run("does not match wrong type", func(t *testing.T) { - t.Parallel() - - // Create a typed optimizer that only works on Union - typedOptimizer := func(u *UnionIterator) (Iterator, bool, error) { - return u, true, nil // Would return true if called - } - - // Wrap it and use in ApplyOptimizations - wrapped := WrapOptimizer[*UnionIterator](typedOptimizer) - - // Test with an Intersection - should not match - intersection := NewIntersectionIterator() - result, changed, err := ApplyOptimizations(intersection, []OptimizerFunc{wrapped}) - require.NoError(t, err) - require.False(t, changed) - require.Equal(t, intersection, result) - }) -} diff --git a/pkg/query/queryopt/registry.go b/pkg/query/queryopt/registry.go index 405419ad2..2262f7a35 100644 --- a/pkg/query/queryopt/registry.go +++ b/pkg/query/queryopt/registry.go @@ -1,7 +1,9 @@ package queryopt import ( + "cmp" "fmt" + "slices" "github.com/authzed/spicedb/pkg/query" ) @@ -23,13 +25,63 @@ func GetOptimization(name string) (Optimizer, error) { return v, nil } +// StandardOptimzations is the default set of optimization names applied when +// no custom selection is required. var StandardOptimzations = []string{ "simple-caveat-pushdown", } +// Optimizer describes a single named outline optimization. type Optimizer struct { Name string Description string Mutation query.OutlineMutation - Priority int + // Priority controls the order in which optimizations are applied. + // Higher values run first. + Priority int +} + +// ApplyOptimizations looks up each optimizer by name, sorts them by descending +// Priority (higher priority runs first), and applies their mutations to the +// outline via query.MutateOutline. Nodes that survive mutation unchanged keep +// their existing IDs. Newly synthesized nodes (ID==0) receive fresh IDs via +// query.FillMissingNodeIDs. The returned CanonicalOutline is ready to compile. +func ApplyOptimizations(co query.CanonicalOutline, names []string) (query.CanonicalOutline, error) { + // Look up each named optimizer. + opts := make([]Optimizer, 0, len(names)) + for _, name := range names { + opt, err := GetOptimization(name) + if err != nil { + return query.CanonicalOutline{}, err + } + opts = append(opts, opt) + } + + // Sort by descending priority so higher-priority mutations run first. + slices.SortFunc(opts, func(a, b Optimizer) int { + return cmp.Compare(b.Priority, a.Priority) + }) + + // Collect mutations in priority order. + mutations := make([]query.OutlineMutation, len(opts)) + for i, opt := range opts { + mutations[i] = opt.Mutation + } + + // Apply all mutations bottom-up. + mutated := query.MutateOutline(co.Root, mutations) + + // Extend the CanonicalKeys map with entries for any newly created nodes. + // We copy first so the original CanonicalOutline's map is not mutated. + extendedKeys := make(map[query.OutlineNodeID]query.CanonicalKey, len(co.CanonicalKeys)) + for id, key := range co.CanonicalKeys { + extendedKeys[id] = key + } + filled := query.FillMissingNodeIDs(mutated, extendedKeys) + + return query.CanonicalOutline{ + Root: filled, + CanonicalKeys: extendedKeys, + Hints: co.Hints, + }, nil }