|
4 | 4 | "fmt" |
5 | 5 |
|
6 | 6 | "github.com/google/uuid" |
| 7 | + |
| 8 | + "github.com/authzed/spicedb/pkg/tuple" |
7 | 9 | ) |
8 | 10 |
|
9 | 11 | const defaultMaxRecursionDepth = 50 |
@@ -37,18 +39,14 @@ func (r *RecursiveIterator) CheckImpl(ctx *Context, resources []Object, subject |
37 | 39 | }) |
38 | 40 | } |
39 | 41 |
|
40 | | -// IterSubjectsImpl implements iterative deepening for IterSubjects operations |
| 42 | +// IterSubjectsImpl implements BFS traversal for IterSubjects operations |
41 | 43 | func (r *RecursiveIterator) IterSubjectsImpl(ctx *Context, resource Object) (PathSeq, error) { |
42 | | - return r.iterativeDeepening(ctx, func(ctx *Context, tree Iterator) (PathSeq, error) { |
43 | | - return ctx.IterSubjects(tree, resource) |
44 | | - }) |
| 44 | + return r.breadthFirstIterSubjects(ctx, resource) |
45 | 45 | } |
46 | 46 |
|
47 | | -// IterResourcesImpl implements iterative deepening for IterResources operations |
| 47 | +// IterResourcesImpl implements BFS traversal for IterResources operations |
48 | 48 | func (r *RecursiveIterator) IterResourcesImpl(ctx *Context, subject ObjectAndRelation) (PathSeq, error) { |
49 | | - return r.iterativeDeepening(ctx, func(ctx *Context, tree Iterator) (PathSeq, error) { |
50 | | - return ctx.IterResources(tree, subject) |
51 | | - }) |
| 49 | + return r.breadthFirstIterResources(ctx, subject) |
52 | 50 | } |
53 | 51 |
|
54 | 52 | // iterativeDeepening executes the core iterative deepening algorithm |
@@ -215,3 +213,191 @@ func (r *RecursiveIterator) ReplaceSubiterators(newSubs []Iterator) (Iterator, e |
215 | 213 | func (r *RecursiveIterator) ID() string { |
216 | 214 | return r.id |
217 | 215 | } |
| 216 | + |
| 217 | +// breadthFirstIterSubjects implements BFS traversal for IterSubjects operations. |
| 218 | +func (r *RecursiveIterator) breadthFirstIterSubjects(ctx *Context, resource Object) (PathSeq, error) { |
| 219 | + ctx.TraceStep(r, "BFS IterSubjects starting with resource %s:%s", resource.ObjectType, resource.ObjectID) |
| 220 | + |
| 221 | + return breadthFirstIter( |
| 222 | + ctx, |
| 223 | + r, |
| 224 | + resource, |
| 225 | + // Key function: get unique key for a node |
| 226 | + func(node Object) string { |
| 227 | + return node.Key() |
| 228 | + }, |
| 229 | + // Execute: iterate subjects for a frontier object |
| 230 | + func(depth1Tree Iterator, frontierNode Object) (PathSeq, error) { |
| 231 | + return ctx.IterSubjects(depth1Tree, frontierNode) |
| 232 | + }, |
| 233 | + // Extract recursive node from path |
| 234 | + func(path Path) (Object, bool) { |
| 235 | + if r.isRecursiveSubject(path.Subject) { |
| 236 | + return GetObject(path.Subject), true |
| 237 | + } |
| 238 | + return Object{}, false |
| 239 | + }, |
| 240 | + ) |
| 241 | +} |
| 242 | + |
| 243 | +// breadthFirstIterResources implements BFS traversal for IterResources operations. |
| 244 | +func (r *RecursiveIterator) breadthFirstIterResources(ctx *Context, subject ObjectAndRelation) (PathSeq, error) { |
| 245 | + ctx.TraceStep(r, "BFS IterResources starting with subject %s:%s#%s", |
| 246 | + subject.ObjectType, subject.ObjectID, subject.Relation) |
| 247 | + |
| 248 | + return breadthFirstIter( |
| 249 | + ctx, |
| 250 | + r, |
| 251 | + subject, |
| 252 | + ObjectAndRelationKey, // No need for a closure, just call directly! |
| 253 | + // Execute: iterate resources for a frontier subject |
| 254 | + func(depth1Tree Iterator, frontierNode ObjectAndRelation) (PathSeq, error) { |
| 255 | + return ctx.IterResources(depth1Tree, frontierNode) |
| 256 | + }, |
| 257 | + // Extract recursive node from path |
| 258 | + func(path Path) (ObjectAndRelation, bool) { |
| 259 | + if r.isRecursiveResource(path.Resource) { |
| 260 | + return path.Resource.WithEllipses(), true |
| 261 | + } |
| 262 | + return ObjectAndRelation{}, false |
| 263 | + }, |
| 264 | + ) |
| 265 | +} |
| 266 | + |
| 267 | +// breadthFirstIter implements the core BFS algorithm for recursive iteration. |
| 268 | +// It is a generic function that works with both Object and ObjectAndRelation types. |
| 269 | +func breadthFirstIter[T any]( |
| 270 | + ctx *Context, |
| 271 | + r *RecursiveIterator, |
| 272 | + startNode T, |
| 273 | + keyFn func(node T) string, |
| 274 | + executeFn func(depth1Tree Iterator, frontierNode T) (PathSeq, error), |
| 275 | + extractNodeFn func(Path) (node T, isRecursive bool), |
| 276 | +) (PathSeq, error) { |
| 277 | + maxDepth := ctx.MaxRecursionDepth |
| 278 | + if maxDepth == 0 { |
| 279 | + maxDepth = defaultMaxRecursionDepth |
| 280 | + } |
| 281 | + |
| 282 | + // Build depth-1 tree once (one level of recursive expansion) |
| 283 | + depth1Tree, err := r.buildTreeAtDepth(1) |
| 284 | + if err != nil { |
| 285 | + return nil, err |
| 286 | + } |
| 287 | + |
| 288 | + return func(yield func(Path, error) bool) { |
| 289 | + // Track seen paths globally by endpoints (for cross-ply deduplication) |
| 290 | + pathsByEndpoint := make(map[string]Path) |
| 291 | + |
| 292 | + // Track seen recursive nodes to prevent cycles |
| 293 | + seenRecursiveNodes := make(map[string]bool) |
| 294 | + seenRecursiveNodes[keyFn(startNode)] = true |
| 295 | + |
| 296 | + // Initialize frontier with starting node |
| 297 | + currentFrontier := []T{startNode} |
| 298 | + |
| 299 | + for ply := 0; ply < maxDepth && len(currentFrontier) > 0; ply++ { |
| 300 | + ctx.TraceStep(r, "Ply %d: exploring %d frontier nodes", ply, len(currentFrontier)) |
| 301 | + |
| 302 | + // Collect paths from this ply by endpoint |
| 303 | + plyPaths := make(map[string]Path) |
| 304 | + var nextFrontier []T |
| 305 | + |
| 306 | + for _, frontierNode := range currentFrontier { |
| 307 | + // Execute depth-1 tree on this node |
| 308 | + pathSeq, err := executeFn(depth1Tree, frontierNode) |
| 309 | + if err != nil { |
| 310 | + yield(Path{}, fmt.Errorf("execution failed at ply %d: %w", ply, err)) |
| 311 | + return |
| 312 | + } |
| 313 | + |
| 314 | + for path, err := range pathSeq { |
| 315 | + if err != nil { |
| 316 | + yield(Path{}, err) |
| 317 | + return |
| 318 | + } |
| 319 | + |
| 320 | + // Merge paths by endpoint with OR semantics |
| 321 | + endpointKey := path.EndpointsKey() |
| 322 | + if existing, found := plyPaths[endpointKey]; found { |
| 323 | + merged, err := existing.MergeOr(path) |
| 324 | + if err != nil { |
| 325 | + yield(Path{}, fmt.Errorf("failed to merge paths: %w", err)) |
| 326 | + return |
| 327 | + } |
| 328 | + plyPaths[endpointKey] = merged |
| 329 | + } else { |
| 330 | + plyPaths[endpointKey] = path |
| 331 | + } |
| 332 | + |
| 333 | + // Extract recursive nodes for next ply |
| 334 | + if node, isRecursive := extractNodeFn(path); isRecursive { |
| 335 | + nodeKey := keyFn(node) |
| 336 | + if !seenRecursiveNodes[nodeKey] { |
| 337 | + seenRecursiveNodes[nodeKey] = true |
| 338 | + nextFrontier = append(nextFrontier, node) |
| 339 | + ctx.TraceStep(r, "Found recursive node: %s", nodeKey) |
| 340 | + } |
| 341 | + } |
| 342 | + } |
| 343 | + } |
| 344 | + |
| 345 | + // Yield new paths and update global map |
| 346 | + newPathCount := 0 |
| 347 | + for endpointKey, path := range plyPaths { |
| 348 | + if existing, found := pathsByEndpoint[endpointKey]; found { |
| 349 | + // Endpoint already seen in previous ply - merge but don't re-yield |
| 350 | + merged, err := existing.MergeOr(path) |
| 351 | + if err != nil { |
| 352 | + yield(Path{}, fmt.Errorf("failed to merge paths globally: %w", err)) |
| 353 | + return |
| 354 | + } |
| 355 | + pathsByEndpoint[endpointKey] = merged |
| 356 | + } else { |
| 357 | + // New endpoint - add to global map and yield |
| 358 | + pathsByEndpoint[endpointKey] = path |
| 359 | + newPathCount++ |
| 360 | + if !yield(path, nil) { |
| 361 | + return |
| 362 | + } |
| 363 | + } |
| 364 | + } |
| 365 | + |
| 366 | + ctx.TraceStep(r, "Ply %d: found %d unique paths (%d new), %d nodes for next ply", |
| 367 | + ply, len(plyPaths), newPathCount, len(nextFrontier)) |
| 368 | + |
| 369 | + currentFrontier = nextFrontier |
| 370 | + } |
| 371 | + |
| 372 | + if len(currentFrontier) == 0 { |
| 373 | + ctx.TraceStep(r, "BFS completed (no more recursive nodes)") |
| 374 | + } else { |
| 375 | + ctx.TraceStep(r, "BFS terminated at max depth %d", maxDepth) |
| 376 | + } |
| 377 | + }, nil |
| 378 | +} |
| 379 | + |
| 380 | +// isRecursiveSubject checks if a subject represents a recursive node that should be explored further. |
| 381 | +func (r *RecursiveIterator) isRecursiveSubject(subject ObjectAndRelation) bool { |
| 382 | + // Must match the definition type |
| 383 | + if subject.ObjectType != r.definitionName { |
| 384 | + return false |
| 385 | + } |
| 386 | + |
| 387 | + // Must match the relation or be ellipsis/empty |
| 388 | + // Empty relation means the subject reference doesn't specify a relation |
| 389 | + // Ellipsis means "any relation on this object" |
| 390 | + if subject.Relation != r.relationName && |
| 391 | + subject.Relation != "" && |
| 392 | + subject.Relation != tuple.Ellipsis { |
| 393 | + return false |
| 394 | + } |
| 395 | + |
| 396 | + return true |
| 397 | +} |
| 398 | + |
| 399 | +// isRecursiveResource checks if a resource represents a recursive node that should be explored further. |
| 400 | +func (r *RecursiveIterator) isRecursiveResource(resource Object) bool { |
| 401 | + // Resources don't have relations, just check type |
| 402 | + return resource.ObjectType == r.definitionName |
| 403 | +} |
0 commit comments