Skip to content

Commit cd8f1ce

Browse files
authored
feat: add recursive direction strategies, and fix IS BFS (#2891)
1 parent 034d32d commit cd8f1ce

File tree

9 files changed

+1253
-262
lines changed

9 files changed

+1253
-262
lines changed
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
package benchmarks
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"testing"
7+
8+
"github.com/stretchr/testify/require"
9+
10+
"github.com/authzed/spicedb/internal/datastore/common"
11+
"github.com/authzed/spicedb/internal/datastore/memdb"
12+
"github.com/authzed/spicedb/pkg/datastore"
13+
"github.com/authzed/spicedb/pkg/query"
14+
"github.com/authzed/spicedb/pkg/schema/v2"
15+
"github.com/authzed/spicedb/pkg/schemadsl/compiler"
16+
"github.com/authzed/spicedb/pkg/schemadsl/input"
17+
"github.com/authzed/spicedb/pkg/tuple"
18+
)
19+
20+
// BenchmarkCheckDeepArrow benchmarks permission checking through a deep recursive chain.
21+
// This recreates the testharness scenario with:
22+
// - A 30+ level deep parent chain: document:target -> document:1 -> ... -> document:29
23+
// - document:29#view@user:slow
24+
// - Checking if user:slow has viewer permission on document:target
25+
//
26+
// The permission viewer = view + parent->viewer creates a recursive traversal through
27+
// all 30+ levels to find the view relationship at the end of the chain.
28+
func BenchmarkCheckDeepArrow(b *testing.B) {
29+
// Create an in-memory datastore
30+
rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC)
31+
require.NoError(b, err)
32+
33+
ctx := context.Background()
34+
35+
schemaText := `
36+
definition user {}
37+
38+
definition document {
39+
relation parent: document
40+
relation view: user
41+
permission viewer = view + parent->viewer
42+
}
43+
`
44+
45+
// Compile the schema
46+
compiled, err := compiler.Compile(compiler.InputSchema{
47+
Source: input.Source("benchmark"),
48+
SchemaString: schemaText,
49+
}, compiler.AllowUnprefixedObjectType())
50+
require.NoError(b, err)
51+
52+
// Write the schema
53+
_, err = rawDS.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error {
54+
return rwt.LegacyWriteNamespaces(ctx, compiled.ObjectDefinitions...)
55+
})
56+
require.NoError(b, err)
57+
58+
// Build relationships for the deep arrow scenario
59+
// Create a chain: document:target -> document:1 -> document:2 -> ... -> document:30 -> document:31
60+
// Plus: document:29#view@user:slow
61+
relationships := make([]tuple.Relationship, 0, 33)
62+
63+
// document:target#parent@document:1
64+
relationships = append(relationships, tuple.MustParse("document:target#parent@document:1"))
65+
66+
// Chain: document:1 through document:30
67+
for i := 1; i <= 30; i++ {
68+
rel := fmt.Sprintf("document:%d#parent@document:%d", i, i+1)
69+
relationships = append(relationships, tuple.MustParse(rel))
70+
}
71+
72+
// The view relationship at the end of the chain
73+
relationships = append(relationships, tuple.MustParse("document:29#view@user:slow"))
74+
75+
// Write all relationships to the datastore
76+
revision, err := common.WriteRelationships(ctx, rawDS, tuple.UpdateOperationCreate, relationships...)
77+
require.NoError(b, err)
78+
79+
// Build schema for querying
80+
dsSchema, err := schema.BuildSchemaFromDefinitions(compiled.ObjectDefinitions, nil)
81+
require.NoError(b, err)
82+
83+
// Create the iterator tree for the viewer permission using BuildIteratorFromSchema
84+
viewerIterator, err := query.BuildIteratorFromSchema(dsSchema, "document", "viewer")
85+
require.NoError(b, err)
86+
87+
// Create query context
88+
queryCtx := query.NewLocalContext(ctx,
89+
query.WithReader(rawDS.SnapshotReader(revision)),
90+
query.WithMaxRecursionDepth(50),
91+
)
92+
93+
// The resource we're checking: document:target
94+
resources := query.NewObjects("document", "target")
95+
96+
// The subject we're checking: user:slow
97+
subject := query.NewObject("user", "slow").WithEllipses()
98+
99+
// Reset the timer - everything before this is setup
100+
b.ResetTimer()
101+
102+
// Run the benchmark
103+
for b.Loop() {
104+
// Check if user:slow can view document:target
105+
// This will traverse the entire 30+ level chain
106+
seq, err := queryCtx.Check(viewerIterator, resources, subject)
107+
require.NoError(b, err)
108+
109+
// Collect all results (should find user:slow at the end of the chain)
110+
paths, err := query.CollectAll(seq)
111+
require.NoError(b, err)
112+
113+
// Verify we found the expected result
114+
require.Len(b, paths, 1)
115+
require.Equal(b, "slow", paths[0].Subject.ObjectID)
116+
}
117+
}

pkg/query/context.go

Lines changed: 109 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,25 @@ package query
33
import (
44
"context"
55
"fmt"
6+
"io"
67
"maps"
78
"strings"
89
"sync"
910
"time"
1011

1112
"github.com/authzed/spicedb/internal/caveats"
1213
"github.com/authzed/spicedb/pkg/datastore"
14+
"github.com/authzed/spicedb/pkg/datastore/options"
1315
"github.com/authzed/spicedb/pkg/spiceerrors"
16+
"github.com/authzed/spicedb/pkg/tuple"
1417
)
1518

1619
// TraceLogger is used for debugging iterator execution
1720
type TraceLogger struct {
1821
traces []string
1922
depth int
2023
stack []Iterator // Stack of iterator pointers for proper indentation context
24+
writer io.Writer // Optional writer to output traces in real-time
2125
}
2226

2327
// NewTraceLogger creates a new trace logger
@@ -26,6 +30,28 @@ func NewTraceLogger() *TraceLogger {
2630
traces: make([]string, 0),
2731
depth: 0,
2832
stack: make([]Iterator, 0),
33+
writer: nil,
34+
}
35+
}
36+
37+
// NewTraceLoggerWithWriter creates a new trace logger with an optional writer
38+
// for real-time trace output
39+
func NewTraceLoggerWithWriter(w io.Writer) *TraceLogger {
40+
return &TraceLogger{
41+
traces: make([]string, 0),
42+
depth: 0,
43+
stack: make([]Iterator, 0),
44+
writer: w,
45+
}
46+
}
47+
48+
// appendTrace appends a trace line to the traces slice and optionally writes it
49+
// to the writer if one is configured
50+
func (t *TraceLogger) appendTrace(line string) {
51+
t.traces = append(t.traces, line)
52+
if t.writer != nil {
53+
// Write the line with a newline
54+
fmt.Fprintln(t.writer, line)
2955
}
3056
}
3157

@@ -53,14 +79,11 @@ func (t *TraceLogger) EnterIterator(it Iterator, traceString string) {
5379
indent := strings.Repeat(" ", t.depth)
5480
idPrefix := iteratorIDPrefix(it)
5581

56-
t.traces = append(
57-
t.traces,
58-
fmt.Sprintf("%s-> %s: %s",
59-
indent,
60-
idPrefix,
61-
traceString,
62-
),
63-
)
82+
t.appendTrace(fmt.Sprintf("%s-> %s: %s",
83+
indent,
84+
idPrefix,
85+
traceString,
86+
))
6487
t.depth++
6588
t.stack = append(t.stack, it) // Push iterator pointer onto stack
6689
}
@@ -129,7 +152,7 @@ func (t *TraceLogger) ExitIterator(it Iterator, paths []Path) {
129152
p.Resource.ObjectType, p.Resource.ObjectID, p.Relation,
130153
p.Subject.ObjectType, p.Subject.ObjectID, caveatInfo)
131154
}
132-
t.traces = append(t.traces, fmt.Sprintf("%s<- %s: returned %d paths: [%s]",
155+
t.appendTrace(fmt.Sprintf("%s<- %s: returned %d paths: [%s]",
133156
indent, idPrefix, len(paths), strings.Join(pathStrs, ", ")))
134157
}
135158

@@ -155,7 +178,7 @@ func (t *TraceLogger) LogStep(it Iterator, step string, data ...any) {
155178
indent := strings.Repeat(" ", indentLevel)
156179
idPrefix := iteratorIDPrefix(it)
157180
message := fmt.Sprintf(step, data...)
158-
t.traces = append(t.traces, fmt.Sprintf("%s %s: %s", indent, idPrefix, message))
181+
t.appendTrace(fmt.Sprintf("%s %s: %s", indent, idPrefix, message))
159182
}
160183

161184
// DumpTrace returns all traces as a string
@@ -258,6 +281,17 @@ type Context struct {
258281
TraceLogger *TraceLogger // For debugging iterator execution
259282
Analyze *AnalyzeCollector // Thread-safe collector for query analysis stats
260283
MaxRecursionDepth int // Maximum depth for recursive iterators (0 = use default of 10)
284+
285+
// Pagination options for IterSubjects and IterResources
286+
PaginationCursors map[string]*tuple.Relationship // Cursors for pagination, keyed by iterator ID
287+
PaginationLimit *uint64 // Limit for pagination (max number of results to return)
288+
PaginationSort options.SortOrder // Sort order for pagination
289+
290+
// recursiveFrontierCollectors holds frontier collections for BFS IterSubjects.
291+
// Key: RecursiveIterator.ID()
292+
// Value: collected Objects for the next frontier
293+
// A non-nil entry for an ID enables collection mode for that RecursiveIterator.
294+
recursiveFrontierCollectors map[string][]Object
261295
}
262296

263297
// NewLocalContext creates a new query execution context with a LocalExecutor.
@@ -306,6 +340,32 @@ func WithMaxRecursionDepth(depth int) ContextOption {
306340
return func(ctx *Context) { ctx.MaxRecursionDepth = depth }
307341
}
308342

343+
// WithPaginationLimit sets the pagination limit for the context.
344+
func WithPaginationLimit(limit uint64) ContextOption {
345+
return func(ctx *Context) { ctx.PaginationLimit = &limit }
346+
}
347+
348+
// WithPaginationSort sets the pagination sort order for the context.
349+
func WithPaginationSort(sort options.SortOrder) ContextOption {
350+
return func(ctx *Context) { ctx.PaginationSort = sort }
351+
}
352+
353+
// GetPaginationCursor retrieves the cursor for a specific iterator ID.
354+
func (ctx *Context) GetPaginationCursor(iteratorID string) *tuple.Relationship {
355+
if ctx.PaginationCursors == nil {
356+
return nil
357+
}
358+
return ctx.PaginationCursors[iteratorID]
359+
}
360+
361+
// SetPaginationCursor sets the cursor for a specific iterator ID.
362+
func (ctx *Context) SetPaginationCursor(iteratorID string, cursor *tuple.Relationship) {
363+
if ctx.PaginationCursors == nil {
364+
ctx.PaginationCursors = make(map[string]*tuple.Relationship)
365+
}
366+
ctx.PaginationCursors[iteratorID] = cursor
367+
}
368+
309369
func (ctx *Context) TraceStep(it Iterator, step string, data ...any) {
310370
if ctx.TraceLogger != nil {
311371
ctx.TraceLogger.LogStep(it, step, data...)
@@ -483,3 +543,42 @@ type Executor interface {
483543
// specified ObjectType. If filterResourceType.Type is empty, no filtering is applied.
484544
IterResources(ctx *Context, it Iterator, subject ObjectAndRelation, filterResourceType ObjectType) (PathSeq, error)
485545
}
546+
547+
// EnableFrontierCollection enables frontier collection for a RecursiveIterator.
548+
// Creates a non-nil entry in the map, which signals collection mode.
549+
func (ctx *Context) EnableFrontierCollection(iteratorID string) {
550+
if ctx.recursiveFrontierCollectors == nil {
551+
ctx.recursiveFrontierCollectors = make(map[string][]Object)
552+
}
553+
ctx.recursiveFrontierCollectors[iteratorID] = []Object{}
554+
}
555+
556+
// CollectFrontierObject appends an object to the frontier collection.
557+
// Only appends if collection mode is enabled (non-nil entry exists).
558+
func (ctx *Context) CollectFrontierObject(iteratorID string, obj Object) {
559+
if ctx.recursiveFrontierCollectors == nil {
560+
return
561+
}
562+
if collection, exists := ctx.recursiveFrontierCollectors[iteratorID]; exists {
563+
ctx.recursiveFrontierCollectors[iteratorID] = append(collection, obj)
564+
}
565+
}
566+
567+
// ExtractFrontierCollection retrieves and removes the collected frontier.
568+
func (ctx *Context) ExtractFrontierCollection(iteratorID string) []Object {
569+
if ctx.recursiveFrontierCollectors == nil {
570+
return nil
571+
}
572+
collection := ctx.recursiveFrontierCollectors[iteratorID]
573+
delete(ctx.recursiveFrontierCollectors, iteratorID)
574+
return collection
575+
}
576+
577+
// IsCollectingFrontier checks if collection mode is enabled (non-nil entry exists).
578+
func (ctx *Context) IsCollectingFrontier(iteratorID string) bool {
579+
if ctx.recursiveFrontierCollectors == nil {
580+
return false
581+
}
582+
_, exists := ctx.recursiveFrontierCollectors[iteratorID]
583+
return exists
584+
}

0 commit comments

Comments
 (0)