Skip to content

Commit 11ba410

Browse files
authored
feat: implement a proper BFS when doing IterSubjects/IterResources for recursive nodes (#2838)
1 parent dbd2687 commit 11ba410

File tree

3 files changed

+690
-8
lines changed

3 files changed

+690
-8
lines changed

pkg/query/recursive.go

Lines changed: 194 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ import (
44
"fmt"
55

66
"github.com/google/uuid"
7+
8+
"github.com/authzed/spicedb/pkg/tuple"
79
)
810

911
const defaultMaxRecursionDepth = 50
@@ -37,18 +39,14 @@ func (r *RecursiveIterator) CheckImpl(ctx *Context, resources []Object, subject
3739
})
3840
}
3941

40-
// IterSubjectsImpl implements iterative deepening for IterSubjects operations
42+
// IterSubjectsImpl implements BFS traversal for IterSubjects operations
4143
func (r *RecursiveIterator) IterSubjectsImpl(ctx *Context, resource Object) (PathSeq, error) {
42-
return r.iterativeDeepening(ctx, func(ctx *Context, tree Iterator) (PathSeq, error) {
43-
return ctx.IterSubjects(tree, resource)
44-
})
44+
return r.breadthFirstIterSubjects(ctx, resource)
4545
}
4646

47-
// IterResourcesImpl implements iterative deepening for IterResources operations
47+
// IterResourcesImpl implements BFS traversal for IterResources operations
4848
func (r *RecursiveIterator) IterResourcesImpl(ctx *Context, subject ObjectAndRelation) (PathSeq, error) {
49-
return r.iterativeDeepening(ctx, func(ctx *Context, tree Iterator) (PathSeq, error) {
50-
return ctx.IterResources(tree, subject)
51-
})
49+
return r.breadthFirstIterResources(ctx, subject)
5250
}
5351

5452
// iterativeDeepening executes the core iterative deepening algorithm
@@ -215,3 +213,191 @@ func (r *RecursiveIterator) ReplaceSubiterators(newSubs []Iterator) (Iterator, e
215213
func (r *RecursiveIterator) ID() string {
216214
return r.id
217215
}
216+
217+
// breadthFirstIterSubjects implements BFS traversal for IterSubjects operations.
218+
func (r *RecursiveIterator) breadthFirstIterSubjects(ctx *Context, resource Object) (PathSeq, error) {
219+
ctx.TraceStep(r, "BFS IterSubjects starting with resource %s:%s", resource.ObjectType, resource.ObjectID)
220+
221+
return breadthFirstIter(
222+
ctx,
223+
r,
224+
resource,
225+
// Key function: get unique key for a node
226+
func(node Object) string {
227+
return node.Key()
228+
},
229+
// Execute: iterate subjects for a frontier object
230+
func(depth1Tree Iterator, frontierNode Object) (PathSeq, error) {
231+
return ctx.IterSubjects(depth1Tree, frontierNode)
232+
},
233+
// Extract recursive node from path
234+
func(path Path) (Object, bool) {
235+
if r.isRecursiveSubject(path.Subject) {
236+
return GetObject(path.Subject), true
237+
}
238+
return Object{}, false
239+
},
240+
)
241+
}
242+
243+
// breadthFirstIterResources implements BFS traversal for IterResources operations.
244+
func (r *RecursiveIterator) breadthFirstIterResources(ctx *Context, subject ObjectAndRelation) (PathSeq, error) {
245+
ctx.TraceStep(r, "BFS IterResources starting with subject %s:%s#%s",
246+
subject.ObjectType, subject.ObjectID, subject.Relation)
247+
248+
return breadthFirstIter(
249+
ctx,
250+
r,
251+
subject,
252+
ObjectAndRelationKey, // No need for a closure, just call directly!
253+
// Execute: iterate resources for a frontier subject
254+
func(depth1Tree Iterator, frontierNode ObjectAndRelation) (PathSeq, error) {
255+
return ctx.IterResources(depth1Tree, frontierNode)
256+
},
257+
// Extract recursive node from path
258+
func(path Path) (ObjectAndRelation, bool) {
259+
if r.isRecursiveResource(path.Resource) {
260+
return path.Resource.WithEllipses(), true
261+
}
262+
return ObjectAndRelation{}, false
263+
},
264+
)
265+
}
266+
267+
// breadthFirstIter implements the core BFS algorithm for recursive iteration.
268+
// It is a generic function that works with both Object and ObjectAndRelation types.
269+
func breadthFirstIter[T any](
270+
ctx *Context,
271+
r *RecursiveIterator,
272+
startNode T,
273+
keyFn func(node T) string,
274+
executeFn func(depth1Tree Iterator, frontierNode T) (PathSeq, error),
275+
extractNodeFn func(Path) (node T, isRecursive bool),
276+
) (PathSeq, error) {
277+
maxDepth := ctx.MaxRecursionDepth
278+
if maxDepth == 0 {
279+
maxDepth = defaultMaxRecursionDepth
280+
}
281+
282+
// Build depth-1 tree once (one level of recursive expansion)
283+
depth1Tree, err := r.buildTreeAtDepth(1)
284+
if err != nil {
285+
return nil, err
286+
}
287+
288+
return func(yield func(Path, error) bool) {
289+
// Track seen paths globally by endpoints (for cross-ply deduplication)
290+
pathsByEndpoint := make(map[string]Path)
291+
292+
// Track seen recursive nodes to prevent cycles
293+
seenRecursiveNodes := make(map[string]bool)
294+
seenRecursiveNodes[keyFn(startNode)] = true
295+
296+
// Initialize frontier with starting node
297+
currentFrontier := []T{startNode}
298+
299+
for ply := 0; ply < maxDepth && len(currentFrontier) > 0; ply++ {
300+
ctx.TraceStep(r, "Ply %d: exploring %d frontier nodes", ply, len(currentFrontier))
301+
302+
// Collect paths from this ply by endpoint
303+
plyPaths := make(map[string]Path)
304+
var nextFrontier []T
305+
306+
for _, frontierNode := range currentFrontier {
307+
// Execute depth-1 tree on this node
308+
pathSeq, err := executeFn(depth1Tree, frontierNode)
309+
if err != nil {
310+
yield(Path{}, fmt.Errorf("execution failed at ply %d: %w", ply, err))
311+
return
312+
}
313+
314+
for path, err := range pathSeq {
315+
if err != nil {
316+
yield(Path{}, err)
317+
return
318+
}
319+
320+
// Merge paths by endpoint with OR semantics
321+
endpointKey := path.EndpointsKey()
322+
if existing, found := plyPaths[endpointKey]; found {
323+
merged, err := existing.MergeOr(path)
324+
if err != nil {
325+
yield(Path{}, fmt.Errorf("failed to merge paths: %w", err))
326+
return
327+
}
328+
plyPaths[endpointKey] = merged
329+
} else {
330+
plyPaths[endpointKey] = path
331+
}
332+
333+
// Extract recursive nodes for next ply
334+
if node, isRecursive := extractNodeFn(path); isRecursive {
335+
nodeKey := keyFn(node)
336+
if !seenRecursiveNodes[nodeKey] {
337+
seenRecursiveNodes[nodeKey] = true
338+
nextFrontier = append(nextFrontier, node)
339+
ctx.TraceStep(r, "Found recursive node: %s", nodeKey)
340+
}
341+
}
342+
}
343+
}
344+
345+
// Yield new paths and update global map
346+
newPathCount := 0
347+
for endpointKey, path := range plyPaths {
348+
if existing, found := pathsByEndpoint[endpointKey]; found {
349+
// Endpoint already seen in previous ply - merge but don't re-yield
350+
merged, err := existing.MergeOr(path)
351+
if err != nil {
352+
yield(Path{}, fmt.Errorf("failed to merge paths globally: %w", err))
353+
return
354+
}
355+
pathsByEndpoint[endpointKey] = merged
356+
} else {
357+
// New endpoint - add to global map and yield
358+
pathsByEndpoint[endpointKey] = path
359+
newPathCount++
360+
if !yield(path, nil) {
361+
return
362+
}
363+
}
364+
}
365+
366+
ctx.TraceStep(r, "Ply %d: found %d unique paths (%d new), %d nodes for next ply",
367+
ply, len(plyPaths), newPathCount, len(nextFrontier))
368+
369+
currentFrontier = nextFrontier
370+
}
371+
372+
if len(currentFrontier) == 0 {
373+
ctx.TraceStep(r, "BFS completed (no more recursive nodes)")
374+
} else {
375+
ctx.TraceStep(r, "BFS terminated at max depth %d", maxDepth)
376+
}
377+
}, nil
378+
}
379+
380+
// isRecursiveSubject checks if a subject represents a recursive node that should be explored further.
381+
func (r *RecursiveIterator) isRecursiveSubject(subject ObjectAndRelation) bool {
382+
// Must match the definition type
383+
if subject.ObjectType != r.definitionName {
384+
return false
385+
}
386+
387+
// Must match the relation or be ellipsis/empty
388+
// Empty relation means the subject reference doesn't specify a relation
389+
// Ellipsis means "any relation on this object"
390+
if subject.Relation != r.relationName &&
391+
subject.Relation != "" &&
392+
subject.Relation != tuple.Ellipsis {
393+
return false
394+
}
395+
396+
return true
397+
}
398+
399+
// isRecursiveResource checks if a resource represents a recursive node that should be explored further.
400+
func (r *RecursiveIterator) isRecursiveResource(resource Object) bool {
401+
// Resources don't have relations, just check type
402+
return resource.ObjectType == r.definitionName
403+
}

0 commit comments

Comments
 (0)