diff --git a/.github/workflows/benchmark.yaml b/.github/workflows/benchmark.yaml index f2bfe0c47..dfe2e88cd 100644 --- a/.github/workflows/benchmark.yaml +++ b/.github/workflows/benchmark.yaml @@ -4,7 +4,7 @@ on: # yamllint disable-line rule:truthy push: branches: - "main" - # NOTE! Do NOT add any other "on", because this workflow has permission to write to the repo! +# NOTE! Do NOT add any other "on", because this workflow has permission to write to the repo! permissions: # permission to update benchmark contents in gh-pages branch diff --git a/docs/spicedb.md b/docs/spicedb.md index bdb1cfc17..8225aeaf3 100644 --- a/docs/spicedb.md +++ b/docs/spicedb.md @@ -91,6 +91,7 @@ spicedb datastore gc [flags] --datastore-disable-watch-support disable watch support (only enable if you absolutely do not need watch) --datastore-engine string type of datastore to initialize ("cockroachdb", "mysql", "postgres", "spanner") (default "memory") --datastore-experimental-column-optimization enable experimental column optimization (default true) + --datastore-experimental-schema-mode string experimental schema mode ("read-legacy-write-legacy", "read-legacy-write-both", "read-new-write-both", "read-new-write-new") (default "read-legacy-write-legacy") --datastore-follower-read-delay-duration duration amount of time to subtract from non-sync revision timestamps to ensure they are sufficiently in the past to enable follower reads (CockroachDB and Spanner drivers only) or read replicas (Postgres and MySQL drivers only) (default 4.8s) --datastore-gc-interval duration amount of time between passes of garbage collection (Postgres driver only) (default 3m0s) --datastore-gc-max-operation-time duration maximum amount of time a garbage collection pass can operate before timing out (Postgres driver only) (default 1m0s) @@ -257,6 +258,7 @@ spicedb datastore repair [flags] --datastore-disable-watch-support disable watch support (only enable if you absolutely do not need watch) --datastore-engine 
string type of datastore to initialize ("cockroachdb", "mysql", "postgres", "spanner") (default "memory") --datastore-experimental-column-optimization enable experimental column optimization (default true) + --datastore-experimental-schema-mode string experimental schema mode ("read-legacy-write-legacy", "read-legacy-write-both", "read-new-write-both", "read-new-write-new") (default "read-legacy-write-legacy") --datastore-follower-read-delay-duration duration amount of time to subtract from non-sync revision timestamps to ensure they are sufficiently in the past to enable follower reads (CockroachDB and Spanner drivers only) or read replicas (Postgres and MySQL drivers only) (default 4.8s) --datastore-gc-interval duration amount of time between passes of garbage collection (Postgres driver only) (default 3m0s) --datastore-gc-max-operation-time duration maximum amount of time a garbage collection pass can operate before timing out (Postgres driver only) (default 1m0s) @@ -444,6 +446,7 @@ spicedb serve [flags] --datastore-disable-watch-support disable watch support (only enable if you absolutely do not need watch) --datastore-engine string type of datastore to initialize ("cockroachdb", "mysql", "postgres", "spanner") (default "memory") --datastore-experimental-column-optimization enable experimental column optimization (default true) + --datastore-experimental-schema-mode string experimental schema mode ("read-legacy-write-legacy", "read-legacy-write-both", "read-new-write-both", "read-new-write-new") (default "read-legacy-write-legacy") --datastore-follower-read-delay-duration duration amount of time to subtract from non-sync revision timestamps to ensure they are sufficiently in the past to enable follower reads (CockroachDB and Spanner drivers only) or read replicas (Postgres and MySQL drivers only) (default 4.8s) --datastore-gc-interval duration amount of time between passes of garbage collection (Postgres driver only) (default 3m0s) 
--datastore-gc-max-operation-time duration maximum amount of time a garbage collection pass can operate before timing out (Postgres driver only) (default 1m0s) diff --git a/internal/caveats/run_test.go b/internal/caveats/run_test.go index 30f35b42c..d04975f6b 100644 --- a/internal/caveats/run_test.go +++ b/internal/caveats/run_test.go @@ -467,10 +467,10 @@ func TestRunCaveatExpressions(t *testing.T) { third } `, nil, req) - headRevision, err := ds.HeadRevision(t.Context()) + headRevision, _, err := ds.HeadRevision(t.Context()) req.NoError(err) - reader := ds.SnapshotReader(headRevision) + reader := ds.SnapshotReader(headRevision, datastore.NoSchemaHashForTesting) for _, debugOption := range []RunCaveatExpressionDebugOption{ RunCaveatExpressionNoDebugging, @@ -519,10 +519,10 @@ func TestRunCaveatWithMissingMap(t *testing.T) { } `, nil, req) - headRevision, err := ds.HeadRevision(t.Context()) + headRevision, _, err := ds.HeadRevision(t.Context()) req.NoError(err) - reader := ds.SnapshotReader(headRevision) + reader := ds.SnapshotReader(headRevision, datastore.NoSchemaHashForTesting) result, err := RunSingleCaveatExpression( t.Context(), @@ -549,10 +549,10 @@ func TestRunCaveatWithEmptyMap(t *testing.T) { } `, nil, req) - headRevision, err := ds.HeadRevision(t.Context()) + headRevision, _, err := ds.HeadRevision(t.Context()) req.NoError(err) - reader := ds.SnapshotReader(headRevision) + reader := ds.SnapshotReader(headRevision, datastore.NoSchemaHashForTesting) _, err = RunSingleCaveatExpression( t.Context(), @@ -585,10 +585,10 @@ func TestRunCaveatMultipleTimes(t *testing.T) { } `, nil, req) - headRevision, err := ds.HeadRevision(t.Context()) + headRevision, _, err := ds.HeadRevision(t.Context()) req.NoError(err) - reader := ds.SnapshotReader(headRevision) + reader := ds.SnapshotReader(headRevision, datastore.NoSchemaHashForTesting) runner := NewCaveatRunner(types.Default.TypeSet) // Run the first caveat. 
@@ -646,10 +646,10 @@ func TestRunCaveatWithMissingDefinition(t *testing.T) { } `, nil, req) - headRevision, err := ds.HeadRevision(t.Context()) + headRevision, _, err := ds.HeadRevision(t.Context()) req.NoError(err) - reader := ds.SnapshotReader(headRevision) + reader := ds.SnapshotReader(headRevision, datastore.NoSchemaHashForTesting) // Try to run a caveat that doesn't exist _, err = RunSingleCaveatExpression( @@ -679,10 +679,10 @@ func TestCaveatRunnerPopulateCaveatDefinitionsForExpr(t *testing.T) { } `, nil, req) - headRevision, err := ds.HeadRevision(t.Context()) + headRevision, _, err := ds.HeadRevision(t.Context()) req.NoError(err) - reader := ds.SnapshotReader(headRevision) + reader := ds.SnapshotReader(headRevision, datastore.NoSchemaHashForTesting) runner := NewCaveatRunner(types.Default.TypeSet) // Test populating definitions for complex expression @@ -721,10 +721,10 @@ func TestCaveatRunnerEmptyExpression(t *testing.T) { } `, nil, req) - headRevision, err := ds.HeadRevision(t.Context()) + headRevision, _, err := ds.HeadRevision(t.Context()) req.NoError(err) - reader := ds.SnapshotReader(headRevision) + reader := ds.SnapshotReader(headRevision, datastore.NoSchemaHashForTesting) runner := NewCaveatRunner(types.Default.TypeSet) // Test with an expression that has no caveats (empty operation) @@ -799,10 +799,10 @@ func TestUnknownCaveatOperation(t *testing.T) { } `, nil, req) - headRevision, err := ds.HeadRevision(t.Context()) + headRevision, _, err := ds.HeadRevision(t.Context()) req.NoError(err) - reader := ds.SnapshotReader(headRevision) + reader := ds.SnapshotReader(headRevision, datastore.NoSchemaHashForTesting) runner := NewCaveatRunner(types.Default.TypeSet) // Create an expression with an unknown operation diff --git a/internal/datastore/benchmark/driver_bench_test.go b/internal/datastore/benchmark/driver_bench_test.go index c12fcfaef..e5fca5747 100644 --- a/internal/datastore/benchmark/driver_bench_test.go +++ 
b/internal/datastore/benchmark/driver_bench_test.go @@ -95,7 +95,7 @@ func BenchmarkDatastoreDriver(b *testing.B) { // Sleep to give the datastore time to stabilize after all the writes time.Sleep(1 * time.Second) - headRev, err := ds.HeadRevision(ctx) + headRev, _, err := ds.HeadRevision(ctx) require.NoError(b, err) b.StartTimer() @@ -104,7 +104,7 @@ func BenchmarkDatastoreDriver(b *testing.B) { b.Run("SnapshotRead", func(b *testing.B) { for n := 0; n < b.N; n++ { randDocNum := rand.Intn(numDocuments) //nolint:gosec - iter, err := ds.SnapshotReader(headRev).QueryRelationships(ctx, datastore.RelationshipsFilter{ + iter, err := ds.SnapshotReader(headRev, datastore.NoSchemaHashForTesting).QueryRelationships(ctx, datastore.RelationshipsFilter{ OptionalResourceType: testfixtures.DocumentNS.Name, OptionalResourceIds: []string{strconv.Itoa(randDocNum)}, OptionalResourceRelation: "viewer", @@ -120,7 +120,7 @@ func BenchmarkDatastoreDriver(b *testing.B) { }) b.Run("SnapshotReadOnlyNamespace", func(b *testing.B) { for n := 0; n < b.N; n++ { - iter, err := ds.SnapshotReader(headRev).QueryRelationships(ctx, datastore.RelationshipsFilter{ + iter, err := ds.SnapshotReader(headRev, datastore.NoSchemaHashForTesting).QueryRelationships(ctx, datastore.RelationshipsFilter{ OptionalResourceType: testfixtures.DocumentNS.Name, }, options.WithQueryShape(queryshape.FindResourceOfType)) require.NoError(b, err) @@ -136,7 +136,7 @@ func BenchmarkDatastoreDriver(b *testing.B) { order := order b.Run(orderName, func(b *testing.B) { for n := 0; n < b.N; n++ { - iter, err := ds.SnapshotReader(headRev).QueryRelationships(ctx, datastore.RelationshipsFilter{ + iter, err := ds.SnapshotReader(headRev, datastore.NoSchemaHashForTesting).QueryRelationships(ctx, datastore.RelationshipsFilter{ OptionalResourceType: testfixtures.DocumentNS.Name, }, options.WithSort(order), options.WithQueryShape(queryshape.FindResourceOfType)) require.NoError(b, err) @@ -154,7 +154,7 @@ func BenchmarkDatastoreDriver(b 
*testing.B) { order := order b.Run(orderName, func(b *testing.B) { for n := 0; n < b.N; n++ { - iter, err := ds.SnapshotReader(headRev).QueryRelationships(ctx, datastore.RelationshipsFilter{ + iter, err := ds.SnapshotReader(headRev, datastore.NoSchemaHashForTesting).QueryRelationships(ctx, datastore.RelationshipsFilter{ OptionalResourceType: testfixtures.DocumentNS.Name, OptionalResourceRelation: "viewer", }, options.WithSort(order), options.WithQueryShape(queryshape.Varying)) @@ -174,7 +174,7 @@ func BenchmarkDatastoreDriver(b *testing.B) { b.Run(orderName, func(b *testing.B) { for n := 0; n < b.N; n++ { randDocNum := rand.Intn(numDocuments) //nolint:gosec - iter, err := ds.SnapshotReader(headRev).QueryRelationships(ctx, datastore.RelationshipsFilter{ + iter, err := ds.SnapshotReader(headRev, datastore.NoSchemaHashForTesting).QueryRelationships(ctx, datastore.RelationshipsFilter{ OptionalResourceType: testfixtures.DocumentNS.Name, OptionalResourceIds: []string{strconv.Itoa(randDocNum)}, OptionalResourceRelation: "viewer", @@ -191,7 +191,7 @@ func BenchmarkDatastoreDriver(b *testing.B) { }) b.Run("SnapshotReverseRead", func(b *testing.B) { for n := 0; n < b.N; n++ { - iter, err := ds.SnapshotReader(headRev).ReverseQueryRelationships(ctx, datastore.SubjectsFilter{ + iter, err := ds.SnapshotReader(headRev, datastore.NoSchemaHashForTesting).ReverseQueryRelationships(ctx, datastore.SubjectsFilter{ SubjectType: testfixtures.UserNS.Name, }, options.WithSortForReverse(options.ByResource), options.WithQueryShapeForReverse(queryshape.Varying)) require.NoError(b, err) diff --git a/internal/datastore/common/chunkbytes.go b/internal/datastore/common/chunkbytes.go index af74dd05e..b9aac99ab 100644 --- a/internal/datastore/common/chunkbytes.go +++ b/internal/datastore/common/chunkbytes.go @@ -81,48 +81,70 @@ type SQLByteChunkerConfig[T any] struct { AliveValue T } +// WithExecutor returns a copy of the config with the specified executor. 
+func (c SQLByteChunkerConfig[T]) WithExecutor(executor ChunkedBytesExecutor) SQLByteChunkerConfig[T] { + c.Executor = executor + return c +} + +// WithTableName returns a copy of the config with the specified table name. +func (c SQLByteChunkerConfig[T]) WithTableName(tableName string) SQLByteChunkerConfig[T] { + c.TableName = tableName + return c +} + // SQLByteChunker provides methods for reading and writing byte data // that is chunked across multiple rows in a SQL table. type SQLByteChunker[T any] struct { config SQLByteChunkerConfig[T] } -// MustNewSQLByteChunker creates a new SQLByteChunker with the specified configuration. -// Panics if the configuration is invalid. -func MustNewSQLByteChunker[T any](config SQLByteChunkerConfig[T]) *SQLByteChunker[T] { +// NewSQLByteChunker creates a new SQLByteChunker with the specified configuration. +// Returns an error if the configuration is invalid. +func NewSQLByteChunker[T any](config SQLByteChunkerConfig[T]) (*SQLByteChunker[T], error) { if config.MaxChunkSize <= 0 { - panic("maxChunkSize must be greater than 0") + return nil, errors.New("maxChunkSize must be greater than 0") } if config.TableName == "" { - panic("tableName cannot be empty") + return nil, errors.New("tableName cannot be empty") } if config.NameColumn == "" { - panic("nameColumn cannot be empty") + return nil, errors.New("nameColumn cannot be empty") } if config.ChunkIndexColumn == "" { - panic("chunkIndexColumn cannot be empty") + return nil, errors.New("chunkIndexColumn cannot be empty") } if config.ChunkDataColumn == "" { - panic("chunkDataColumn cannot be empty") + return nil, errors.New("chunkDataColumn cannot be empty") } if config.PlaceholderFormat == nil { - panic("placeholderFormat cannot be nil") + return nil, errors.New("placeholderFormat cannot be nil") } if config.Executor == nil { - panic("executor cannot be nil") + return nil, errors.New("executor cannot be nil") } if config.WriteMode == WriteModeInsertWithTombstones { if 
config.CreatedAtColumn == "" { - panic("createdAtColumn is required when using WriteModeInsertWithTombstones") + return nil, errors.New("createdAtColumn is required when using WriteModeInsertWithTombstones") } if config.DeletedAtColumn == "" { - panic("deletedAtColumn is required when using WriteModeInsertWithTombstones") + return nil, errors.New("deletedAtColumn is required when using WriteModeInsertWithTombstones") } } return &SQLByteChunker[T]{ config: config, + }, nil +} + +// MustNewSQLByteChunker creates a new SQLByteChunker with the specified configuration. +// Panics if the configuration is invalid. +func MustNewSQLByteChunker[T any](config SQLByteChunkerConfig[T]) *SQLByteChunker[T] { + chunker, err := NewSQLByteChunker(config) + if err != nil { + panic(err) } + return chunker } // WriteChunkedBytes writes chunked byte data to the database within a transaction. diff --git a/internal/datastore/common/chunkbytes_test.go b/internal/datastore/common/chunkbytes_test.go index 341b5d3a3..14e455061 100644 --- a/internal/datastore/common/chunkbytes_test.go +++ b/internal/datastore/common/chunkbytes_test.go @@ -57,6 +57,7 @@ type fakeExecutor struct { readResult map[int][]byte readErr error transaction *fakeTransaction + onRead func() // Optional callback invoked on each read } func (m *fakeExecutor) BeginTransaction(ctx context.Context) (ChunkedBytesTransaction, error) { @@ -64,6 +65,9 @@ func (m *fakeExecutor) BeginTransaction(ctx context.Context) (ChunkedBytesTransa } func (m *fakeExecutor) ExecuteRead(ctx context.Context, builder sq.SelectBuilder) (map[int][]byte, error) { + if m.onRead != nil { + m.onRead() + } if m.readErr != nil { return nil, m.readErr } diff --git a/internal/datastore/common/hashcache.go b/internal/datastore/common/hashcache.go new file mode 100644 index 000000000..2cc33118e --- /dev/null +++ b/internal/datastore/common/hashcache.go @@ -0,0 +1,228 @@ +package common + +import ( + "context" + "fmt" + "sync/atomic" + + lru 
"github.com/hashicorp/golang-lru/v2" + "golang.org/x/sync/singleflight" + + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/datastore/options" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/spiceerrors" +) + +const defaultMaxCacheEntries = 10 + +// bypassSentinels contains all schema hash values that intentionally bypass the cache. +// These are special sentinel values used in specific contexts where caching is not appropriate. +// Using a map for O(1) lookup instead of slice iteration. +var bypassSentinels = map[datastore.SchemaHash]bool{ + datastore.NoSchemaHashInTransaction: true, + datastore.NoSchemaHashForTesting: true, + datastore.NoSchemaHashForWatch: true, + datastore.NoSchemaHashForLegacyCursor: true, +} + +// isBypassSentinel returns true if the given schema hash is a bypass sentinel value. +func isBypassSentinel(schemaHash datastore.SchemaHash) bool { + return bypassSentinels[schemaHash] +} + +// latestSchemaEntry holds the most recent schema entry for fast-path lookups. +type latestSchemaEntry struct { + revision datastore.Revision + hash datastore.SchemaHash + schema *core.StoredSchema +} + +// SchemaHashCache is a thread-safe LRU cache for schemas indexed by hash. +// It maintains an atomic pointer to the latest schema for fast-path reads when +// multiple requests access the same (most recent) schema. +// +// The LRU cache itself is thread-safe, so no locks are needed for cache operations. +// The atomic latest entry provides lock-free fast-path for the common case. +type SchemaHashCache struct { + cache *lru.Cache[string, *core.StoredSchema] // Thread-safe LRU + latest atomic.Pointer[latestSchemaEntry] // Fast path for latest schema + singleflight singleflight.Group +} + +// NewSchemaHashCache creates a new hash-based schema cache. 
+func NewSchemaHashCache(opts options.SchemaCacheOptions) (*SchemaHashCache, error) { + maxEntries := int(opts.MaximumCacheEntries) + if maxEntries == 0 { + maxEntries = defaultMaxCacheEntries + } + + cache, err := lru.New[string, *core.StoredSchema](maxEntries) + if err != nil { + return nil, fmt.Errorf("failed to create LRU cache: %w", err) + } + + return &SchemaHashCache{ + cache: cache, + }, nil +} + +// Get retrieves a schema from the cache by revision and hash. +// Fast path: If the hash matches the atomic latest entry, return immediately. +// Slow path: Check the LRU cache (which is thread-safe). +// Returns (nil, nil) if the hash is a bypass sentinel (NoSchemaHashInTransaction or NoSchemaHashForTesting). +// Returns error if hash is empty string (indicates a bug where schema hash wasn't properly provided). +func (c *SchemaHashCache) Get(rev datastore.Revision, schemaHash datastore.SchemaHash) (*core.StoredSchema, error) { + if c == nil { + return nil, nil + } + + // Check for bypass sentinels - these intentionally skip the cache + if isBypassSentinel(schemaHash) { + return nil, nil + } + + // Empty hash indicates a bug - schema hash should always be provided or use a sentinel + if schemaHash == "" { + return nil, spiceerrors.MustBugf("empty schema hash passed to cache.Get() - use NoSchemaHashInTransaction or NoSchemaHashForTesting, or provide a real hash") + } + + // Fast path: Check atomic latest entry + if latest := c.latest.Load(); latest != nil && latest.hash == schemaHash { + schemaCacheHits.Inc() + return latest.schema, nil + } + + // Slow path: Check LRU cache (thread-safe, no lock needed) + schema, ok := c.cache.Get(string(schemaHash)) + if !ok { + schemaCacheMisses.Inc() + return nil, nil + } + + schemaCacheHits.Inc() + return schema, nil +} + +// Set stores a schema in the cache by revision and hash. 
+// Adds to the LRU cache (thread-safe) and updates the atomic latest entry +// if the revision is newer or if revision is NoRevision (from transactions). +// No-ops if hash is a bypass sentinel (NoSchemaHashInTransaction or NoSchemaHashForTesting). +// Returns error if hash is empty string or if revision is nil (indicates a bug where they weren't properly provided). +func (c *SchemaHashCache) Set(rev datastore.Revision, schemaHash datastore.SchemaHash, schema *core.StoredSchema) error { + if c == nil { + return nil + } + + // Check for bypass sentinels - these intentionally skip the cache + if isBypassSentinel(schemaHash) { + return nil + } + + // Empty hash indicates a bug - schema hash should always be provided or use a sentinel + if schemaHash == "" { + return spiceerrors.MustBugf("empty schema hash passed to cache.Set() - use NoSchemaHashInTransaction, NoSchemaHashForTesting, NoSchemaHashForWatch, or provide a real hash") + } + + // Nil revision indicates a bug - should use NoRevision for transaction cases + if rev == nil { + return spiceerrors.MustBugf("nil revision passed to cache.Set() - use datastore.NoRevision for transaction cases") + } + + // Add to LRU cache (thread-safe, no lock needed) + c.cache.Add(string(schemaHash), schema) + + // Update atomic latest if this is newer or if no revision check (txn case) + shouldUpdateLatest := rev == datastore.NoRevision + + if !shouldUpdateLatest { + // Check if this revision is newer than the current latest + if latest := c.latest.Load(); latest != nil { + // If we have a latest entry with a revision, compare + if latest.revision != nil && latest.revision != datastore.NoRevision { + shouldUpdateLatest = rev.GreaterThan(latest.revision) + } else { + // Current latest has no revision, so update with this one + shouldUpdateLatest = true + } + } else { + // No latest entry yet, so set it + shouldUpdateLatest = true + } + } + + if shouldUpdateLatest { + c.latest.Store(&latestSchemaEntry{ + revision: rev, + hash: 
schemaHash, + schema: schema, + }) + } + + return nil +} + +// GetOrLoad retrieves a schema from the cache, or loads it using the provided loader. +// Uses singleflight to deduplicate concurrent loads of the same hash. +// Bypasses cache for sentinel values (NoSchemaHashInTransaction or NoSchemaHashForTesting). +// Returns error if hash is empty string (indicates a bug where schema hash wasn't properly provided). +func (c *SchemaHashCache) GetOrLoad( + ctx context.Context, + rev datastore.Revision, + schemaHash datastore.SchemaHash, + loader func(ctx context.Context) (*core.StoredSchema, error), +) (*core.StoredSchema, error) { + // Check for bypass sentinels - load directly without caching + if c == nil || isBypassSentinel(schemaHash) { + schema, err := loader(ctx) + if err != nil { + return nil, err + } + return schema, nil + } + + // Empty hash indicates a bug - schema hash should always be provided or use a sentinel + if schemaHash == "" { + return nil, spiceerrors.MustBugf("empty schema hash passed to cache.GetOrLoad() - use NoSchemaHashInTransaction, NoSchemaHashForTesting, NoSchemaHashForWatch, or provide a real hash") + } + + // Try cache first + schema, err := c.Get(rev, schemaHash) + if err != nil { + return nil, err + } + if schema != nil { + return schema, nil + } + + // Load with singleflight to prevent duplicate loads + result, err, _ := c.singleflight.Do(string(schemaHash), func() (any, error) { + // Check cache again in case another goroutine loaded it + schema, err := c.Get(rev, schemaHash) + if err != nil { + return nil, err + } + if schema != nil { + return schema, nil + } + + // Load from datastore + schema, err = loader(ctx) + if err != nil { + return nil, err + } + + // Cache the result + if err := c.Set(rev, schemaHash, schema); err != nil { + return nil, err + } + + return schema, nil + }) + + if err != nil { + return nil, err + } + + return result.(*core.StoredSchema), nil +} diff --git a/internal/datastore/common/hashcache_bench_test.go 
b/internal/datastore/common/hashcache_bench_test.go new file mode 100644 index 000000000..68a5c09a2 --- /dev/null +++ b/internal/datastore/common/hashcache_bench_test.go @@ -0,0 +1,383 @@ +package common + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + "testing" + + lru "github.com/hashicorp/golang-lru/v2" + "golang.org/x/sync/singleflight" + + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/datastore/options" + core "github.com/authzed/spicedb/pkg/proto/core/v1" +) + +// AtomicSchemaHashCache is an atomic pointer-based implementation for comparison +type AtomicSchemaHashCache struct { + cache atomic.Pointer[lru.Cache[string, *core.StoredSchema]] // GUARDED_BY(mu) + singleflight singleflight.Group + mu sync.Mutex // Only for writes +} + +func NewAtomicSchemaHashCache(opts options.SchemaCacheOptions) (*AtomicSchemaHashCache, error) { + maxEntries := int(opts.MaximumCacheEntries) + if maxEntries == 0 { + maxEntries = defaultMaxCacheEntries + } + + cache, err := lru.New[string, *core.StoredSchema](maxEntries) + if err != nil { + return nil, fmt.Errorf("failed to create LRU cache: %w", err) + } + + c := &AtomicSchemaHashCache{} + c.cache.Store(cache) + return c, nil +} + +func (c *AtomicSchemaHashCache) Get(schemaHash string) *core.StoredSchema { + if c == nil || schemaHash == "" { + return nil + } + + cache := c.cache.Load() + if cache == nil { + return nil + } + + schema, ok := cache.Get(schemaHash) + if !ok { + schemaCacheMisses.Inc() + return nil + } + + schemaCacheHits.Inc() + return schema +} + +func (c *AtomicSchemaHashCache) Set(schemaHash string, schema *core.StoredSchema) { + if c == nil || schemaHash == "" { + return + } + + c.mu.Lock() + defer c.mu.Unlock() + + cache := c.cache.Load() + if cache != nil { + cache.Add(schemaHash, schema) + } +} + +func (c *AtomicSchemaHashCache) GetOrLoad( + ctx context.Context, + schemaHash string, + loader func(ctx context.Context) (*core.StoredSchema, error), +) (*core.StoredSchema, 
error) { + if c == nil || schemaHash == "" { + schema, err := loader(ctx) + if err != nil { + return nil, err + } + return schema, nil + } + + if schema := c.Get(schemaHash); schema != nil { + return schema, nil + } + + result, err, _ := c.singleflight.Do(schemaHash, func() (any, error) { + if schema := c.Get(schemaHash); schema != nil { + return schema, nil + } + + schema, err := loader(ctx) + if err != nil { + return nil, err + } + + c.Set(schemaHash, schema) + + return schema, nil + }) + + if err != nil { + return nil, err + } + + return result.(*core.StoredSchema), nil +} + +// Benchmarks + +func createTestSchema(id string) *core.StoredSchema { + return &core.StoredSchema{ + Version: 1, + VersionOneof: &core.StoredSchema_V1{ + V1: &core.StoredSchema_V1StoredSchema{ + SchemaText: "definition test_" + id + " {}", + }, + }, + } +} + +// BenchmarkHashCache_Get_Mutex tests read performance with mutex-based cache +func BenchmarkHashCache_Get_Mutex(b *testing.B) { + cache, _ := NewSchemaHashCache(options.SchemaCacheOptions{ + MaximumCacheEntries: 100, + }) + + // Pre-populate cache + for i := 0; i < 10; i++ { + hash := fmt.Sprintf("hash%d", i) + _ = cache.Set(datastore.NoRevision, datastore.SchemaHash(hash), createTestSchema(hash)) + } + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + i := 0 + for pb.Next() { + hash := fmt.Sprintf("hash%d", i%10) + _, _ = cache.Get(datastore.NoRevision, datastore.SchemaHash(hash)) + i++ + } + }) +} + +// BenchmarkHashCache_Get_Atomic tests read performance with atomic-based cache +func BenchmarkHashCache_Get_Atomic(b *testing.B) { + cache, _ := NewAtomicSchemaHashCache(options.SchemaCacheOptions{ + MaximumCacheEntries: 100, + }) + + // Pre-populate cache + for i := 0; i < 10; i++ { + hash := fmt.Sprintf("hash%d", i) + cache.Set(hash, createTestSchema(hash)) + } + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + i := 0 + for pb.Next() { + hash := fmt.Sprintf("hash%d", i%10) + _ = cache.Get(hash) + i++ + } + }) +} + 
+// BenchmarkHashCache_Set_Mutex tests write performance with mutex-based cache +func BenchmarkHashCache_Set_Mutex(b *testing.B) { + cache, _ := NewSchemaHashCache(options.SchemaCacheOptions{ + MaximumCacheEntries: 100, + }) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + hash := fmt.Sprintf("hash%d", i%100) + _ = cache.Set(datastore.NoRevision, datastore.SchemaHash(hash), createTestSchema(hash)) + } +} + +// BenchmarkHashCache_Set_Atomic tests write performance with atomic-based cache +func BenchmarkHashCache_Set_Atomic(b *testing.B) { + cache, _ := NewAtomicSchemaHashCache(options.SchemaCacheOptions{ + MaximumCacheEntries: 100, + }) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + hash := fmt.Sprintf("hash%d", i%100) + cache.Set(hash, createTestSchema(hash)) + } +} + +// BenchmarkHashCache_Mixed_Mutex tests mixed read/write with mutex (90% reads, 10% writes) +func BenchmarkHashCache_Mixed_Mutex(b *testing.B) { + cache, _ := NewSchemaHashCache(options.SchemaCacheOptions{ + MaximumCacheEntries: 100, + }) + + // Pre-populate cache + for i := 0; i < 50; i++ { + hash := fmt.Sprintf("hash%d", i) + _ = cache.Set(datastore.NoRevision, datastore.SchemaHash(hash), createTestSchema(hash)) + } + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + i := 0 + for pb.Next() { + hash := fmt.Sprintf("hash%d", i%100) + if i%10 == 0 { + // 10% writes + _ = cache.Set(datastore.NoRevision, datastore.SchemaHash(hash), createTestSchema(hash)) + } else { + // 90% reads + _, _ = cache.Get(datastore.NoRevision, datastore.SchemaHash(hash)) + } + i++ + } + }) +} + +// BenchmarkHashCache_Mixed_Atomic tests mixed read/write with atomic (90% reads, 10% writes) +func BenchmarkHashCache_Mixed_Atomic(b *testing.B) { + cache, _ := NewAtomicSchemaHashCache(options.SchemaCacheOptions{ + MaximumCacheEntries: 100, + }) + + // Pre-populate cache + for i := 0; i < 50; i++ { + hash := fmt.Sprintf("hash%d", i) + cache.Set(hash, createTestSchema(hash)) + } + + b.ResetTimer() + b.RunParallel(func(pb 
*testing.PB) { + i := 0 + for pb.Next() { + hash := fmt.Sprintf("hash%d", i%100) + if i%10 == 0 { + // 10% writes + cache.Set(hash, createTestSchema(hash)) + } else { + // 90% reads + _ = cache.Get(hash) + } + i++ + } + }) +} + +// BenchmarkHashCache_GetOrLoad_Mutex tests GetOrLoad with mutex +func BenchmarkHashCache_GetOrLoad_Mutex(b *testing.B) { + cache, _ := NewSchemaHashCache(options.SchemaCacheOptions{ + MaximumCacheEntries: 100, + }) + + ctx := context.Background() + loader := func(ctx context.Context) (*core.StoredSchema, error) { + return createTestSchema("loaded"), nil + } + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + i := 0 + for pb.Next() { + hash := fmt.Sprintf("hash%d", i%10) + _, _ = cache.GetOrLoad(ctx, datastore.NoRevision, datastore.SchemaHash(hash), loader) + i++ + } + }) +} + +// BenchmarkHashCache_GetOrLoad_Atomic tests GetOrLoad with atomic +func BenchmarkHashCache_GetOrLoad_Atomic(b *testing.B) { + cache, _ := NewAtomicSchemaHashCache(options.SchemaCacheOptions{ + MaximumCacheEntries: 100, + }) + + ctx := context.Background() + loader := func(ctx context.Context) (*core.StoredSchema, error) { + return createTestSchema("loaded"), nil + } + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + i := 0 + for pb.Next() { + hash := fmt.Sprintf("hash%d", i%10) + _, _ = cache.GetOrLoad(ctx, hash, loader) + i++ + } + }) +} + +// BenchmarkHashCache_HighContention_Mutex tests high contention scenario with mutex +func BenchmarkHashCache_HighContention_Mutex(b *testing.B) { + cache, _ := NewSchemaHashCache(options.SchemaCacheOptions{ + MaximumCacheEntries: 100, + }) + + // Pre-populate with single entry + _ = cache.Set(datastore.NoRevision, datastore.SchemaHash("shared"), createTestSchema("shared")) + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + // Everyone hits the same cache entry + _, _ = cache.Get(datastore.NoRevision, datastore.SchemaHash("shared")) + } + }) +} + +// 
BenchmarkHashCache_HighContention_Atomic tests high contention scenario with atomic +func BenchmarkHashCache_HighContention_Atomic(b *testing.B) { + cache, _ := NewAtomicSchemaHashCache(options.SchemaCacheOptions{ + MaximumCacheEntries: 100, + }) + + // Pre-populate with single entry + cache.Set("shared", createTestSchema("shared")) + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + // Everyone hits the same cache entry + _ = cache.Get("shared") + } + }) +} + +// BenchmarkHashCache_LowContention_Mutex tests low contention with mutex +func BenchmarkHashCache_LowContention_Mutex(b *testing.B) { + cache, _ := NewSchemaHashCache(options.SchemaCacheOptions{ + MaximumCacheEntries: 100, + }) + + // Pre-populate with many entries + for i := 0; i < 100; i++ { + hash := fmt.Sprintf("hash%d", i) + _ = cache.Set(datastore.NoRevision, datastore.SchemaHash(hash), createTestSchema(hash)) + } + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + i := 0 + for pb.Next() { + // Each goroutine tends to hit different entries + hash := fmt.Sprintf("hash%d", i%100) + _, _ = cache.Get(datastore.NoRevision, datastore.SchemaHash(hash)) + i++ + } + }) +} + +// BenchmarkHashCache_LowContention_Atomic tests low contention with atomic +func BenchmarkHashCache_LowContention_Atomic(b *testing.B) { + cache, _ := NewAtomicSchemaHashCache(options.SchemaCacheOptions{ + MaximumCacheEntries: 100, + }) + + // Pre-populate with many entries + for i := 0; i < 100; i++ { + hash := fmt.Sprintf("hash%d", i) + cache.Set(hash, createTestSchema(hash)) + } + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + i := 0 + for pb.Next() { + // Each goroutine tends to hit different entries + hash := fmt.Sprintf("hash%d", i%100) + _ = cache.Get(hash) + i++ + } + }) +} diff --git a/internal/datastore/common/hashcache_metrics.go b/internal/datastore/common/hashcache_metrics.go new file mode 100644 index 000000000..9f8ee9da5 --- /dev/null +++ 
b/internal/datastore/common/hashcache_metrics.go @@ -0,0 +1,22 @@ +package common + +import ( + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" +) + +var ( + // schemaCacheLookups tracks the total number of schema cache lookups + schemaCacheLookups = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: "spicedb", + Subsystem: "datastore", + Name: "schema_cache_lookups_total", + Help: "total number of schema cache Get() calls", + }, []string{"result"}) // result: "hit" or "miss" + + // schemaCacheHits tracks cache hits + schemaCacheHits = schemaCacheLookups.WithLabelValues("hit") + + // schemaCacheMisses tracks cache misses + schemaCacheMisses = schemaCacheLookups.WithLabelValues("miss") +) diff --git a/internal/datastore/common/hashcache_test.go b/internal/datastore/common/hashcache_test.go new file mode 100644 index 000000000..984e50328 --- /dev/null +++ b/internal/datastore/common/hashcache_test.go @@ -0,0 +1,382 @@ +package common + +import ( + "context" + "fmt" + "sync" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/datastore/options" + core "github.com/authzed/spicedb/pkg/proto/core/v1" +) + +func makeTestSchema(text string) *core.StoredSchema { + return &core.StoredSchema{ + Version: 1, + VersionOneof: &core.StoredSchema_V1{ + V1: &core.StoredSchema_V1StoredSchema{ + SchemaText: text, + }, + }, + } +} + +func TestSchemaHashCache_NilCache(t *testing.T) { + // Cache should use default settings when MaximumCacheEntries is 0 + cache, err := NewSchemaHashCache(options.SchemaCacheOptions{ + MaximumCacheEntries: 0, + }) + require.NoError(t, err) + require.NotNil(t, cache) + + // Operations on nil cache should be safe no-ops + var nilCache *SchemaHashCache + schema, err := nilCache.Get(datastore.NoRevision, datastore.SchemaHash("hash1")) + require.NoError(t, err) + require.Nil(t, schema) + + err = 
nilCache.Set(datastore.NoRevision, datastore.SchemaHash("hash1"), makeTestSchema("definition user {}")) + require.NoError(t, err) + schema, err = nilCache.Get(datastore.NoRevision, datastore.SchemaHash("hash1")) + require.NoError(t, err) + require.Nil(t, schema) + + schema, err = nilCache.GetOrLoad(context.Background(), datastore.NoRevision, datastore.SchemaHash("hash1"), func(ctx context.Context) (*core.StoredSchema, error) { + return makeTestSchema("definition user {}"), nil + }) + require.NoError(t, err) + require.NotNil(t, schema) +} + +func TestSchemaHashCache_BasicGetSet(t *testing.T) { + cache, err := NewSchemaHashCache(options.SchemaCacheOptions{ + MaximumCacheEntries: 10, + }) + require.NoError(t, err) + require.NotNil(t, cache) + + // Cache miss + retrieved, err := cache.Get(datastore.NoRevision, datastore.SchemaHash("hash1")) + require.NoError(t, err) + require.Nil(t, retrieved) + + // Set and get + schema := makeTestSchema("definition user {}") + err = cache.Set(datastore.NoRevision, datastore.SchemaHash("hash1"), schema) + require.NoError(t, err) + + retrieved, err = cache.Get(datastore.NoRevision, datastore.SchemaHash("hash1")) + require.NoError(t, err) + require.NotNil(t, retrieved) + require.Equal(t, schema.GetV1().SchemaText, retrieved.GetV1().SchemaText) +} + +func TestSchemaHashCache_EmptyHash(t *testing.T) { + cache, err := NewSchemaHashCache(options.SchemaCacheOptions{ + MaximumCacheEntries: 10, + }) + require.NoError(t, err) + require.NotNil(t, cache) + + // Empty hash should panic (via MustBugf) in tests + require.Panics(t, func() { + _ = cache.Set(datastore.NoRevision, datastore.SchemaHash(""), makeTestSchema("definition user {}")) + }, "empty hash should panic") + + require.Panics(t, func() { + _, _ = cache.Get(datastore.NoRevision, datastore.SchemaHash("")) + }, "empty hash should panic") +} + +func TestSchemaHashCache_NilRevision(t *testing.T) { + cache, err := NewSchemaHashCache(options.SchemaCacheOptions{ + MaximumCacheEntries: 10, + }) 
+ require.NoError(t, err) + require.NotNil(t, cache) + + // Nil revision should panic (via MustBugf) in tests + require.Panics(t, func() { + _ = cache.Set(nil, datastore.SchemaHash("hash1"), makeTestSchema("definition user {}")) + }, "nil revision should panic") +} + +func TestSchemaHashCache_LRUEviction(t *testing.T) { + cache, err := NewSchemaHashCache(options.SchemaCacheOptions{ + MaximumCacheEntries: 3, // Small cache for testing eviction + }) + require.NoError(t, err) + require.NotNil(t, cache) + + // Fill cache with 3 entries + require.NoError(t, cache.Set(datastore.NoRevision, datastore.SchemaHash("hash0"), makeTestSchema("definition 0"))) + require.NoError(t, cache.Set(datastore.NoRevision, datastore.SchemaHash("hash1"), makeTestSchema("definition 1"))) + require.NoError(t, cache.Set(datastore.NoRevision, datastore.SchemaHash("hash2"), makeTestSchema("definition 2"))) + + // All should be present initially + schema, err := cache.Get(datastore.NoRevision, datastore.SchemaHash("hash0")) + require.NoError(t, err) + require.NotNil(t, schema) + schema, err = cache.Get(datastore.NoRevision, datastore.SchemaHash("hash1")) + require.NoError(t, err) + require.NotNil(t, schema) + schema, err = cache.Get(datastore.NoRevision, datastore.SchemaHash("hash2")) + require.NoError(t, err) + require.NotNil(t, schema) + + // Add one more - this will evict the least recently used + // Since we just accessed hash2 (via the "latest" fast path), and hash0 and hash1 + // via LRU.Get(), the LRU order is: hash1, hash0, hash2 + // Adding hash3 will evict hash2 from the LRU + require.NoError(t, cache.Set(datastore.NoRevision, datastore.SchemaHash("hash3"), makeTestSchema("definition 3"))) + + // hash3 is now the newest + schema, err = cache.Get(datastore.NoRevision, datastore.SchemaHash("hash3")) + require.NoError(t, err) + require.NotNil(t, schema) + + // hash0 and hash1 should still be in the LRU + schema, err = cache.Get(datastore.NoRevision, datastore.SchemaHash("hash0")) + 
require.NoError(t, err) + require.NotNil(t, schema) + schema, err = cache.Get(datastore.NoRevision, datastore.SchemaHash("hash1")) + require.NoError(t, err) + require.NotNil(t, schema) + + // hash2 was evicted from LRU and is no longer in "latest" (hash3 is), so it should return nil + schema, err = cache.Get(datastore.NoRevision, datastore.SchemaHash("hash2")) + require.NoError(t, err) + require.Nil(t, schema) +} + +func TestSchemaHashCache_GetOrLoad(t *testing.T) { + cache, err := NewSchemaHashCache(options.SchemaCacheOptions{ + MaximumCacheEntries: 10, + }) + require.NoError(t, err) + require.NotNil(t, cache) + + loadCalls := 0 + loader := func(ctx context.Context) (*core.StoredSchema, error) { + loadCalls++ + return makeTestSchema("loaded definition"), nil + } + + // First call should load + schema, err := cache.GetOrLoad(context.Background(), datastore.NoRevision, datastore.SchemaHash("hash1"), loader) + require.NoError(t, err) + require.NotNil(t, schema) + require.Equal(t, "loaded definition", schema.GetV1().SchemaText) + require.Equal(t, 1, loadCalls) + + // Second call should hit cache + schema, err = cache.GetOrLoad(context.Background(), datastore.NoRevision, datastore.SchemaHash("hash1"), loader) + require.NoError(t, err) + require.NotNil(t, schema) + require.Equal(t, "loaded definition", schema.GetV1().SchemaText) + require.Equal(t, 1, loadCalls) // Should not call loader again +} + +func TestSchemaHashCache_GetOrLoadEmptyHash(t *testing.T) { + cache, err := NewSchemaHashCache(options.SchemaCacheOptions{ + MaximumCacheEntries: 10, + }) + require.NoError(t, err) + require.NotNil(t, cache) + + loader := func(ctx context.Context) (*core.StoredSchema, error) { + return makeTestSchema("loaded definition"), nil + } + + // Empty hash should panic (via MustBugf) in tests + require.Panics(t, func() { + _, _ = cache.GetOrLoad(context.Background(), datastore.NoRevision, datastore.SchemaHash(""), loader) + }, "empty hash should panic") +} + +func 
TestSchemaHashCache_Singleflight(t *testing.T) { + cache, err := NewSchemaHashCache(options.SchemaCacheOptions{ + MaximumCacheEntries: 10, + }) + require.NoError(t, err) + require.NotNil(t, cache) + + loadCalls := 0 + loadStarted := make(chan struct{}) + loadContinue := make(chan struct{}) + + loader := func(ctx context.Context) (*core.StoredSchema, error) { + loadCalls++ + close(loadStarted) + <-loadContinue + return makeTestSchema("loaded definition"), nil + } + + // Start multiple concurrent loads + const numGoroutines = 10 + var wg sync.WaitGroup + wg.Add(numGoroutines) + + results := make(chan error, numGoroutines) + for i := 0; i < numGoroutines; i++ { + go func() { + defer wg.Done() + schema, err := cache.GetOrLoad(context.Background(), datastore.NoRevision, datastore.SchemaHash("hash1"), loader) + if err != nil { + results <- err + return + } + if schema == nil { + results <- fmt.Errorf("schema is nil") + return + } + if schema.GetV1().SchemaText != "loaded definition" { + results <- fmt.Errorf("unexpected schema text: %s", schema.GetV1().SchemaText) + return + } + results <- nil + }() + } + + // Wait for first load to start + <-loadStarted + + // Let the load complete + close(loadContinue) + + // Wait for all goroutines to finish + wg.Wait() + close(results) + + // Check all results + for err := range results { + require.NoError(t, err) + } + + // Should only have called loader once due to singleflight + require.Equal(t, 1, loadCalls) +} + +func TestSchemaHashCache_LoadError(t *testing.T) { + cache, err := NewSchemaHashCache(options.SchemaCacheOptions{ + MaximumCacheEntries: 10, + }) + require.NoError(t, err) + require.NotNil(t, cache) + + expectedErr := fmt.Errorf("load failed") + loader := func(ctx context.Context) (*core.StoredSchema, error) { + return nil, expectedErr + } + + // Error should be propagated + schema, err := cache.GetOrLoad(context.Background(), datastore.NoRevision, datastore.SchemaHash("hash1"), loader) + require.Error(t, err) + 
require.Equal(t, expectedErr, err) + require.Nil(t, schema) + + // Failed load should not be cached + cached, err := cache.Get(datastore.NoRevision, datastore.SchemaHash("hash1")) + require.NoError(t, err) + require.Nil(t, cached) +} + +func TestSchemaHashCache_DefaultMaxEntries(t *testing.T) { + cache, err := NewSchemaHashCache(options.SchemaCacheOptions{ + MaximumCacheEntries: 0, // Should default to 100 + }) + require.NoError(t, err) + require.NotNil(t, cache) + + // Add more than 100 entries to test default + for i := 0; i < 101; i++ { + require.NoError(t, cache.Set(datastore.NoRevision, datastore.SchemaHash(fmt.Sprintf("hash%d", i)), makeTestSchema(fmt.Sprintf("definition %d", i)))) + } + + // First entry should be evicted + schema, err := cache.Get(datastore.NoRevision, datastore.SchemaHash("hash0")) + require.NoError(t, err) + require.Nil(t, schema) + + // Last entry should be present + schema, err = cache.Get(datastore.NoRevision, datastore.SchemaHash("hash100")) + require.NoError(t, err) + require.NotNil(t, schema) +} + +func TestBypassSentinels_AllIncluded(t *testing.T) { + // This test ensures that all sentinel values defined in the datastore package + // are included in the bypassSentinels slice. If a new sentinel is added to the + // datastore package, it must also be added to the bypassSentinels slice in hashcache.go. 
+ + allSentinels := []datastore.SchemaHash{ + datastore.NoSchemaHashInTransaction, + datastore.NoSchemaHashForTesting, + datastore.NoSchemaHashForWatch, + datastore.NoSchemaHashForLegacyCursor, + } + + // Verify each sentinel is in the bypassSentinels slice + for _, sentinel := range allSentinels { + require.True(t, isBypassSentinel(sentinel), + "sentinel %q is not in bypassSentinels slice", sentinel) + } + + // Verify bypassSentinels doesn't contain any extra values + require.Len(t, bypassSentinels, len(allSentinels), + "bypassSentinels should contain exactly %d entries", len(allSentinels)) +} + +func TestBypassSentinels_CacheBehavior(t *testing.T) { + cache, err := NewSchemaHashCache(options.SchemaCacheOptions{ + MaximumCacheEntries: 10, + }) + require.NoError(t, err) + + schema := makeTestSchema("definition user {}") + + // Test that each sentinel bypasses the cache in Get() + for sentinel := range bypassSentinels { + result, err := cache.Get(datastore.NoRevision, sentinel) + require.NoError(t, err) + require.Nil(t, result, "Get with sentinel %q should return nil", sentinel) + } + + // Test that each sentinel bypasses the cache in Set() + for sentinel := range bypassSentinels { + err := cache.Set(datastore.NoRevision, sentinel, schema) + require.NoError(t, err, "Set with sentinel %q should not error", sentinel) + + // Verify it wasn't actually cached + result, err := cache.Get(datastore.NoRevision, sentinel) + require.NoError(t, err) + require.Nil(t, result, "sentinel %q should not be cached", sentinel) + } + + // Test that each sentinel bypasses the cache in GetOrLoad() + for sentinel := range bypassSentinels { + loadCalled := false + result, err := cache.GetOrLoad(context.Background(), datastore.NoRevision, sentinel, + func(ctx context.Context) (*core.StoredSchema, error) { + loadCalled = true + return schema, nil + }) + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, loadCalled, "loader should be called for sentinel %q", sentinel) + + // 
Verify it wasn't cached - loader should be called again + loadCalled = false + result, err = cache.GetOrLoad(context.Background(), datastore.NoRevision, sentinel, + func(ctx context.Context) (*core.StoredSchema, error) { + loadCalled = true + return schema, nil + }) + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, loadCalled, "loader should be called again for sentinel %q (not cached)", sentinel) + } +} diff --git a/internal/datastore/common/sqlschema.go b/internal/datastore/common/sqlschema.go new file mode 100644 index 000000000..ab0e9bf1f --- /dev/null +++ b/internal/datastore/common/sqlschema.go @@ -0,0 +1,268 @@ +package common + +import ( + "context" + "errors" + "fmt" + "math" + "strings" + + "github.com/authzed/spicedb/internal/datastore/schema" + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/datastore/options" + core "github.com/authzed/spicedb/pkg/proto/core/v1" +) + +const ( + // UnifiedSchemaName is the name used to store the unified schema in the schema table. + UnifiedSchemaName = "unified_schema" + + // LiveDeletedTxnID is the transaction ID value used to indicate a row has not been deleted yet. + // This is the maximum value for uint64, used by databases with tombstone-based deletion. + LiveDeletedTxnID = uint64(math.MaxInt64) +) + +// SQLSingleStoreSchemaReaderWriter implements both SingleStoreSchemaReader and SingleStoreSchemaWriter +// using the SQL byte chunking system. This provides a common implementation that SQL-based datastores +// can use. 
+// +// The type parameter T represents the transaction ID type used by the datastore: +// - uint64 for datastores with numeric transaction IDs (MySQL, Postgres) +// - any for datastores without transaction IDs (CRDB, Spanner in delete-and-insert mode) +type SQLSingleStoreSchemaReaderWriter[T any] struct { + chunker *SQLByteChunker[T] + transactionIDProvider func(ctx context.Context) T +} + +// NewSQLSingleStoreSchemaReaderWriter creates a new SQL-based single-store schema reader/writer. +// The transactionIDProvider function should return the appropriate transaction ID for the current context. +// For datastores that don't use transaction IDs, this should return the zero value. +func NewSQLSingleStoreSchemaReaderWriter[T any]( + chunker *SQLByteChunker[T], + transactionIDProvider func(ctx context.Context) T, +) *SQLSingleStoreSchemaReaderWriter[T] { + return &SQLSingleStoreSchemaReaderWriter[T]{ + chunker: chunker, + transactionIDProvider: transactionIDProvider, + } +} + +// ReadStoredSchema reads the stored schema from the unified schema table. +func (s *SQLSingleStoreSchemaReaderWriter[T]) ReadStoredSchema(ctx context.Context) (*core.StoredSchema, error) { + data, err := s.chunker.ReadChunkedBytes(ctx, UnifiedSchemaName) + if err != nil { + if isNoChunksFoundError(err) { + return nil, datastore.ErrSchemaNotFound + } + return nil, fmt.Errorf("failed to read schema: %w", err) + } + + storedSchema, err := schema.UnmarshalStoredSchema(data) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal schema: %w", err) + } + + return storedSchema, nil +} + +// WriteStoredSchema writes the stored schema to the unified schema table. 
+func (s *SQLSingleStoreSchemaReaderWriter[T]) WriteStoredSchema(ctx context.Context, storedSchema *core.StoredSchema) error { + data, err := schema.MarshalStoredSchema(storedSchema) + if err != nil { + return fmt.Errorf("failed to marshal schema: %w", err) + } + + transactionID := s.transactionIDProvider(ctx) + if err := s.chunker.WriteChunkedBytes(ctx, UnifiedSchemaName, data, transactionID); err != nil { + return fmt.Errorf("failed to write schema: %w", err) + } + + return nil +} + +// isNoChunksFoundError checks if the error indicates no chunks were found. +func isNoChunksFoundError(err error) bool { + if err == nil { + return false + } + errMsg := err.Error() + // Check both the error message and wrapped errors + return strings.Contains(errMsg, "no chunks found") || strings.Contains(errMsg, "failed to reassemble chunks") +} + +var ( + _ datastore.SingleStoreSchemaReader = (*SQLSingleStoreSchemaReaderWriter[uint64])(nil) + _ datastore.SingleStoreSchemaWriter = (*SQLSingleStoreSchemaReaderWriter[uint64])(nil) + _ datastore.SingleStoreSchemaReader = (*SQLSingleStoreSchemaReaderWriter[any])(nil) + _ datastore.SingleStoreSchemaWriter = (*SQLSingleStoreSchemaReaderWriter[any])(nil) +) + +// NoTransactionID is a helper function that returns a zero-value transaction ID. +// This can be used for datastores that don't track transaction IDs. +func NoTransactionID[T any](_ context.Context) T { + var zero T + return zero +} + +// StaticTransactionID returns a function that always returns the given transaction ID. +// This is useful for testing or for datastores that use a constant value. +func StaticTransactionID[T any](id T) func(context.Context) T { + return func(_ context.Context) T { + return id + } +} + +// NewSQLSingleStoreSchemaReaderWriterForTransactionIDs is a convenience function for creating a schema reader/writer +// that uses uint64 transaction IDs for datastores with explicit transaction ID tracking (e.g., MySQL, Postgres). 
+func NewSQLSingleStoreSchemaReaderWriterForTransactionIDs( + chunker *SQLByteChunker[uint64], + transactionIDProvider func(ctx context.Context) uint64, +) *SQLSingleStoreSchemaReaderWriter[uint64] { + return NewSQLSingleStoreSchemaReaderWriter(chunker, transactionIDProvider) +} + +// NewSQLSingleStoreSchemaReaderWriterWithBuiltInMVCC is a convenience function for creating a schema reader/writer +// for datastores with built-in MVCC that don't require explicit transaction ID tracking (e.g., CRDB, Spanner). +func NewSQLSingleStoreSchemaReaderWriterWithBuiltInMVCC( + chunker *SQLByteChunker[any], +) *SQLSingleStoreSchemaReaderWriter[any] { + return NewSQLSingleStoreSchemaReaderWriter(chunker, NoTransactionID[any]) +} + +// ValidateStoredSchema validates that a stored schema is well-formed. +func ValidateStoredSchema(storedSchema *core.StoredSchema) error { + if storedSchema == nil { + return errors.New("stored schema is nil") + } + + if storedSchema.Version == 0 { + return errors.New("stored schema version is 0") + } + + v1 := storedSchema.GetV1() + if v1 == nil { + return fmt.Errorf("unsupported schema version: %d", storedSchema.Version) + } + + if v1.SchemaText == "" { + return errors.New("schema text is empty") + } + + if v1.SchemaHash == "" { + return errors.New("schema hash is empty") + } + + return nil +} + +// SQLSchemaReaderWriter provides a cached implementation for reading and writing schemas +// across SQL-based datastores. It caches the chunker configuration and creates temporary +// chunkers with the appropriate executors for each operation. 
+// +// Type parameters: +// - T: the transaction ID type (uint64 for Postgres/MySQL, any for CRDB/Spanner) +// - R: the revision type (must implement datastore.Revision) +type SQLSchemaReaderWriter[T any, R datastore.Revision] struct { + chunkerConfig SQLByteChunkerConfig[T] + cacheOptions options.SchemaCacheOptions + cache *SchemaHashCache +} + +// NewSQLSchemaReaderWriter creates a new SQLSchemaReaderWriter with the given chunker configuration and cache options. +// The configuration is cached and reused for all read/write operations. +func NewSQLSchemaReaderWriter[T any, R datastore.Revision]( + chunkerConfig SQLByteChunkerConfig[T], + cacheOptions options.SchemaCacheOptions, +) (*SQLSchemaReaderWriter[T, R], error) { + cache, err := NewSchemaHashCache(cacheOptions) + if err != nil { + return nil, fmt.Errorf("failed to create schema cache: %w", err) + } + + return &SQLSchemaReaderWriter[T, R]{ + chunkerConfig: chunkerConfig, + cacheOptions: cacheOptions, + cache: cache, + }, nil +} + +// Close cleans up resources used by the schema reader/writer. +// No-op for hash-based cache. +func (s *SQLSchemaReaderWriter[T, R]) Close() { + // No resources to clean up for hash-based cache +} + +// ReadSchema reads the stored schema using the provided executor. +// The executor determines how the read operation is performed (e.g., with revision awareness). +// The revision and schemaHash parameters are used for cache lookup. If hash is a bypass sentinel +// (NoSchemaHashInTransaction, NoSchemaHashForTesting, or NoSchemaHashForWatch), the cache is bypassed +// for reads (but the result is still loaded). 
+func (s *SQLSchemaReaderWriter[T, R]) ReadSchema(ctx context.Context, executor ChunkedBytesExecutor, rev datastore.Revision, schemaHash datastore.SchemaHash) (*core.StoredSchema, error) { + // Use GetOrLoad pattern - it handles both cache lookup and loading + loader := func(ctx context.Context) (*core.StoredSchema, error) { + return s.readSchemaFromDatastore(ctx, executor) + } + + return s.cache.GetOrLoad(ctx, rev, schemaHash, loader) +} + +// readSchemaFromDatastore reads the schema directly from the datastore without cache. +func (s *SQLSchemaReaderWriter[T, R]) readSchemaFromDatastore(ctx context.Context, executor ChunkedBytesExecutor) (*core.StoredSchema, error) { + // Create a temporary chunker with the provided executor + chunker, err := NewSQLByteChunker(s.chunkerConfig.WithExecutor(executor)) + if err != nil { + return nil, fmt.Errorf("failed to create SQL byte chunker: %w", err) + } + + // Read and reassemble the schema chunks + data, err := chunker.ReadChunkedBytes(ctx, UnifiedSchemaName) + if err != nil { + if isNoChunksFoundError(err) { + return nil, datastore.ErrSchemaNotFound + } + return nil, fmt.Errorf("failed to read schema: %w", err) + } + + // Unmarshal the stored schema + storedSchema := &core.StoredSchema{} + if err := storedSchema.UnmarshalVT(data); err != nil { + return nil, fmt.Errorf("failed to unmarshal schema: %w", err) + } + + return storedSchema, nil +} + +// WriteSchema writes the stored schema using the provided executor and transaction ID provider. +// The executor determines how the write operation is performed (e.g., within a transaction). +// The transactionIDProvider returns the transaction ID to use for tombstone-based datastores. 
+func (s *SQLSchemaReaderWriter[T, R]) WriteSchema(ctx context.Context, schema *core.StoredSchema, executor ChunkedBytesExecutor, transactionIDProvider func(ctx context.Context) T) error { + if schema == nil { + return errors.New("stored schema cannot be nil") + } + + if schema.Version == 0 { + return errors.New("stored schema version cannot be 0") + } + + // Create a temporary chunker with the provided executor + chunker, err := NewSQLByteChunker(s.chunkerConfig.WithExecutor(executor)) + if err != nil { + return fmt.Errorf("failed to create SQL byte chunker: %w", err) + } + + // Marshal the schema + data, err := schema.MarshalVT() + if err != nil { + return fmt.Errorf("failed to marshal schema: %w", err) + } + + // Get the transaction ID (if applicable) + transactionID := transactionIDProvider(ctx) + + // Write the schema chunks + if err := chunker.WriteChunkedBytes(ctx, UnifiedSchemaName, data, transactionID); err != nil { + return fmt.Errorf("failed to write schema: %w", err) + } + + return nil +} diff --git a/internal/datastore/common/sqlschema_test.go b/internal/datastore/common/sqlschema_test.go new file mode 100644 index 000000000..136c42374 --- /dev/null +++ b/internal/datastore/common/sqlschema_test.go @@ -0,0 +1,531 @@ +package common + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + "testing" + + sq "github.com/Masterminds/squirrel" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" + + "github.com/authzed/spicedb/internal/datastore/schema" + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/datastore/options" + core "github.com/authzed/spicedb/pkg/proto/core/v1" +) + +func TestSQLSingleStoreSchemaReaderWriter_WriteAndRead(t *testing.T) { + tests := []struct { + name string + schemaText string + namespaces map[string]*core.NamespaceDefinition + caveats map[string]*core.CaveatDefinition + useTransaction bool + }{ + { + name: "simple schema", + schemaText: "definition user {}", + namespaces: 
map[string]*core.NamespaceDefinition{ + "user": {Name: "user"}, + }, + caveats: map[string]*core.CaveatDefinition{}, + useTransaction: true, + }, + { + name: "schema with caveat", + schemaText: "caveat is_allowed(allowed bool) { allowed }\ndefinition user {}", + namespaces: map[string]*core.NamespaceDefinition{ + "user": {Name: "user"}, + }, + caveats: map[string]*core.CaveatDefinition{ + "is_allowed": {Name: "is_allowed"}, + }, + useTransaction: true, + }, + { + name: "empty schema", + schemaText: "", + namespaces: map[string]*core.NamespaceDefinition{}, + caveats: map[string]*core.CaveatDefinition{}, + useTransaction: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Build stored schema + storedSchema := &core.StoredSchema{ + Version: 1, + VersionOneof: &core.StoredSchema_V1{ + V1: &core.StoredSchema_V1StoredSchema{ + SchemaText: tt.schemaText, + SchemaHash: "test-hash", + NamespaceDefinitions: tt.namespaces, + CaveatDefinitions: tt.caveats, + }, + }, + } + + // Marshal the schema to simulate what would be written + expectedData, err := proto.Marshal(storedSchema) + require.NoError(t, err) + + // Create chunker with fake executor + txn := &fakeTransaction{} + executor := &fakeExecutor{ + transaction: txn, + readResult: map[int][]byte{0: expectedData}, // Return the expected data + } + + chunker := MustNewSQLByteChunker(SQLByteChunkerConfig[uint64]{ + TableName: "schema", + NameColumn: "name", + ChunkIndexColumn: "chunk_index", + ChunkDataColumn: "chunk_data", + MaxChunkSize: 1024 * 64, + PlaceholderFormat: sq.Question, + Executor: executor, + WriteMode: WriteModeInsertWithTombstones, + CreatedAtColumn: "created_at", + DeletedAtColumn: "deleted_at", + AliveValue: 999, + }) + + transactionIDProvider := func(ctx context.Context) uint64 { + if tt.useTransaction { + return 100 + } + return 0 + } + + readerWriter := NewSQLSingleStoreSchemaReaderWriter(chunker, transactionIDProvider) + + // Write schema + ctx := 
context.Background() + err = readerWriter.WriteStoredSchema(ctx, storedSchema) + require.NoError(t, err) + + // Verify write was called + require.NotEmpty(t, txn.capturedSQL) + + // Read schema back + readSchema, err := readerWriter.ReadStoredSchema(ctx) + require.NoError(t, err) + require.NotNil(t, readSchema) + + // Verify + require.Equal(t, storedSchema.Version, readSchema.Version) + require.NotNil(t, readSchema.GetV1()) + require.Equal(t, tt.schemaText, readSchema.GetV1().SchemaText) + require.Equal(t, "test-hash", readSchema.GetV1().SchemaHash) + require.Len(t, readSchema.GetV1().NamespaceDefinitions, len(tt.namespaces)) + require.Len(t, readSchema.GetV1().CaveatDefinitions, len(tt.caveats)) + }) + } +} + +func TestSQLSingleStoreSchemaReaderWriter_ReadNotFound(t *testing.T) { + // Create chunker with fake executor that returns empty result + executor := &fakeExecutor{ + readResult: map[int][]byte{}, // No chunks found + } + chunker := MustNewSQLByteChunker(SQLByteChunkerConfig[uint64]{ + TableName: "schema", + NameColumn: "name", + ChunkIndexColumn: "chunk_index", + ChunkDataColumn: "chunk_data", + MaxChunkSize: 1024 * 64, + PlaceholderFormat: sq.Question, + Executor: executor, + WriteMode: WriteModeDeleteAndInsert, + }) + + readerWriter := NewSQLSingleStoreSchemaReaderWriter(chunker, StaticTransactionID[uint64](0)) + + // Try to read non-existent schema + ctx := context.Background() + _, err := readerWriter.ReadStoredSchema(ctx) + require.Error(t, err) + require.ErrorIs(t, err, datastore.ErrSchemaNotFound) +} + +func TestSQLSingleStoreSchemaReaderWriter_WithBuiltInMVCC(t *testing.T) { + // Test with "any" type parameter for datastores with built-in MVCC (like CRDB/Spanner) + storedSchema := &core.StoredSchema{ + Version: 1, + VersionOneof: &core.StoredSchema_V1{ + V1: &core.StoredSchema_V1StoredSchema{ + SchemaText: "definition user {}", + SchemaHash: "test-hash", + NamespaceDefinitions: map[string]*core.NamespaceDefinition{"user": {Name: "user"}}, + 
CaveatDefinitions: map[string]*core.CaveatDefinition{}, + }, + }, + } + + expectedData, err := storedSchema.MarshalVT() + require.NoError(t, err) + + txn := &fakeTransaction{} + executor := &fakeExecutor{ + transaction: txn, + readResult: map[int][]byte{0: expectedData}, + } + + chunker := MustNewSQLByteChunker(SQLByteChunkerConfig[any]{ + TableName: "schema", + NameColumn: "name", + ChunkIndexColumn: "chunk_index", + ChunkDataColumn: "chunk_data", + MaxChunkSize: 1024 * 64, + PlaceholderFormat: sq.Question, + Executor: executor, + WriteMode: WriteModeDeleteAndInsert, + }) + + readerWriter := NewSQLSingleStoreSchemaReaderWriterWithBuiltInMVCC(chunker) + + // Write schema + ctx := context.Background() + err = readerWriter.WriteStoredSchema(ctx, storedSchema) + require.NoError(t, err) + + // Read schema back + readSchema, err := readerWriter.ReadStoredSchema(ctx) + require.NoError(t, err) + require.NotNil(t, readSchema) + require.Equal(t, "definition user {}", readSchema.GetV1().SchemaText) +} + +func TestValidateStoredSchema(t *testing.T) { + tests := []struct { + name string + schema *core.StoredSchema + expectError bool + errorMsg string + }{ + { + name: "nil schema", + schema: nil, + expectError: true, + errorMsg: "stored schema is nil", + }, + { + name: "zero version", + schema: &core.StoredSchema{ + Version: 0, + }, + expectError: true, + errorMsg: "stored schema version is 0", + }, + { + name: "missing v1", + schema: &core.StoredSchema{ + Version: 1, + }, + expectError: true, + errorMsg: "unsupported schema version", + }, + { + name: "empty schema text", + schema: &core.StoredSchema{ + Version: 1, + VersionOneof: &core.StoredSchema_V1{ + V1: &core.StoredSchema_V1StoredSchema{ + SchemaText: "", + SchemaHash: "hash", + }, + }, + }, + expectError: true, + errorMsg: "schema text is empty", + }, + { + name: "empty schema hash", + schema: &core.StoredSchema{ + Version: 1, + VersionOneof: &core.StoredSchema_V1{ + V1: &core.StoredSchema_V1StoredSchema{ + SchemaText: 
"definition user {}", + SchemaHash: "", + }, + }, + }, + expectError: true, + errorMsg: "schema hash is empty", + }, + { + name: "valid schema", + schema: &core.StoredSchema{ + Version: 1, + VersionOneof: &core.StoredSchema_V1{ + V1: &core.StoredSchema_V1StoredSchema{ + SchemaText: "definition user {}", + SchemaHash: "hash", + NamespaceDefinitions: map[string]*core.NamespaceDefinition{}, + CaveatDefinitions: map[string]*core.CaveatDefinition{}, + }, + }, + }, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateStoredSchema(tt.schema) + if tt.expectError { + require.Error(t, err) + require.Contains(t, err.Error(), tt.errorMsg) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestBuildAndMarshal(t *testing.T) { + definitions := []datastore.SchemaDefinition{ + &core.NamespaceDefinition{Name: "user"}, + &core.CaveatDefinition{Name: "is_allowed"}, + } + schemaText := "definition user {}\ncaveat is_allowed(allowed bool) { allowed }" + + // Build stored schema + storedSchema, err := schema.BuildStoredSchemaFromDefinitions(definitions, schemaText) + require.NoError(t, err) + require.NotNil(t, storedSchema) + require.Equal(t, uint32(1), storedSchema.Version) + + v1 := storedSchema.GetV1() + require.NotNil(t, v1) + require.Equal(t, schemaText, v1.SchemaText) + require.Len(t, v1.NamespaceDefinitions, 1) + require.Len(t, v1.CaveatDefinitions, 1) + + // Marshal + data, err := schema.MarshalStoredSchema(storedSchema) + require.NoError(t, err) + require.NotEmpty(t, data) + + // Unmarshal + unmarshaled, err := schema.UnmarshalStoredSchema(data) + require.NoError(t, err) + require.NotNil(t, unmarshaled) + require.True(t, storedSchema.EqualVT(unmarshaled)) +} + +func TestHelperFunctions(t *testing.T) { + // Test NoTransactionID + ctx := context.Background() + result := NoTransactionID[uint64](ctx) + require.Equal(t, uint64(0), result) + + // Test StaticTransactionID + staticFunc := 
StaticTransactionID[uint64](12345) + require.Equal(t, uint64(12345), staticFunc(ctx)) + + // Test with any type + anyResult := NoTransactionID[any](ctx) + require.Nil(t, anyResult) +} + +func TestSQLSchemaReaderWriter_Singleflight(t *testing.T) { + t.Run("uses singleflight for revisioned reads", func(t *testing.T) { + // Create a schema to return + storedSchema := &core.StoredSchema{ + Version: 1, + VersionOneof: &core.StoredSchema_V1{ + V1: &core.StoredSchema_V1StoredSchema{ + SchemaText: "definition user {}", + SchemaHash: "hash1", + }, + }, + } + data, err := storedSchema.MarshalVT() + require.NoError(t, err) + + // Create executor that counts reads (use atomic for thread safety) + var readCount atomic.Int32 + executor := &fakeExecutor{ + readResult: map[int][]byte{0: data}, + onRead: func() { + readCount.Add(1) + }, + } + + // Create config and schema reader/writer (cache disabled to test singleflight only) + config := SQLByteChunkerConfig[uint64]{ + TableName: "schema", + NameColumn: "name", + ChunkIndexColumn: "chunk_index", + ChunkDataColumn: "chunk_data", + MaxChunkSize: 1024 * 64, + PlaceholderFormat: sq.Question, + Executor: executor, + WriteMode: WriteModeDeleteAndInsert, + } + + schemaRW, err := NewSQLSchemaReaderWriter[uint64, testRevision]( + config, + options.SchemaCacheOptions{ + MaximumCacheEntries: 10, + }, + ) + require.NoError(t, err) + + // Make 10 concurrent reads for the same hash + const numReads = 10 + + type result struct { + schema *core.StoredSchema + err error + } + results := make(chan result, numReads) + + // Use WaitGroup to ensure all goroutines start at the same time + var ready sync.WaitGroup + ready.Add(numReads) + var start sync.WaitGroup + start.Add(1) + + for i := 0; i < numReads; i++ { + go func() { + ready.Done() + start.Wait() // Wait for all goroutines to be ready + schema, err := schemaRW.ReadSchema(context.Background(), executor, testRevision{id: 1}, datastore.SchemaHash("test-hash")) + results <- result{schema: schema, 
err: err} + }() + } + + // Wait for all goroutines to be ready, then release them all at once + ready.Wait() + start.Done() + + // Collect results + for i := 0; i < numReads; i++ { + res := <-results + require.NoError(t, res.err) + require.NotNil(t, res.schema) + require.Equal(t, "hash1", res.schema.GetV1().SchemaHash) + } + + // CRITICAL: Should only have 1 actual datastore read due to singleflight + // Allow up to 2 reads due to timing variations + require.LessOrEqual(t, readCount.Load(), int32(2), "singleflight should deduplicate most concurrent revisioned reads") + }) + + t.Run("does not use singleflight for transactional reads", func(t *testing.T) { + // Create a schema to return + storedSchema := &core.StoredSchema{ + Version: 1, + VersionOneof: &core.StoredSchema_V1{ + V1: &core.StoredSchema_V1StoredSchema{ + SchemaText: "definition user {}", + SchemaHash: "hash1", + }, + }, + } + data, err := storedSchema.MarshalVT() + require.NoError(t, err) + + // Create executor that counts reads (use atomic for thread safety) + var readCount atomic.Int32 + executor := &fakeExecutor{ + readResult: map[int][]byte{0: data}, + onRead: func() { + readCount.Add(1) + }, + } + + // Create config and schema reader/writer + config := SQLByteChunkerConfig[uint64]{ + TableName: "schema", + NameColumn: "name", + ChunkIndexColumn: "chunk_index", + ChunkDataColumn: "chunk_data", + MaxChunkSize: 1024 * 64, + PlaceholderFormat: sq.Question, + Executor: executor, + WriteMode: WriteModeDeleteAndInsert, + } + + schemaRW, err := NewSQLSchemaReaderWriter[uint64, testRevision]( + config, + options.SchemaCacheOptions{ + MaximumCacheEntries: 0, + }, + ) + require.NoError(t, err) + + // Make 5 concurrent reads with nil revision (transactional reads) + const numReads = 5 + + type result struct { + schema *core.StoredSchema + err error + } + results := make(chan result, numReads) + + for i := 0; i < numReads; i++ { + go func() { + schema, err := schemaRW.ReadSchema(context.Background(), executor, 
testRevision{id: 0}, datastore.NoSchemaHashInTransaction) + results <- result{schema: schema, err: err} + }() + } + + // Collect results + for i := 0; i < numReads; i++ { + res := <-results + require.NoError(t, res.err) + require.NotNil(t, res.schema) + } + + // CRITICAL: Should have 5 actual datastore reads (no singleflight for nil revision) + require.Equal(t, int32(numReads), readCount.Load(), "transactional reads should NOT use singleflight") + }) +} + +// testRevision is a simple test revision type +type testRevision struct { + id int64 +} + +func (r testRevision) Equal(other datastore.Revision) bool { + otherTest, ok := other.(testRevision) + if !ok { + return false + } + return r.id == otherTest.id +} + +func (r testRevision) GreaterThan(other datastore.Revision) bool { + otherTest, ok := other.(testRevision) + if !ok { + return false + } + return r.id > otherTest.id +} + +func (r testRevision) LessThan(other datastore.Revision) bool { + otherTest, ok := other.(testRevision) + if !ok { + return false + } + return r.id < otherTest.id +} + +func (r testRevision) String() string { + return fmt.Sprintf("test-%d", r.id) +} + +func (r testRevision) Key() string { + return r.String() +} + +func (r testRevision) ByteSortable() bool { + return false +} diff --git a/internal/datastore/context.go b/internal/datastore/context.go index 110870097..2a9b91728 100644 --- a/internal/datastore/context.go +++ b/internal/datastore/context.go @@ -7,6 +7,7 @@ import ( "github.com/authzed/spicedb/pkg/datastore/options" "github.com/authzed/spicedb/pkg/datastore/test" core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/spiceerrors" ) // NewSeparatingContextDatastoreProxy severs any timeouts in the context being @@ -50,7 +51,7 @@ func (p *ctxProxy) IsStrictReadModeEnabled() bool { return false } -func (p *ctxProxy) OptimizedRevision(ctx context.Context) (datastore.Revision, error) { +func (p *ctxProxy) OptimizedRevision(ctx context.Context) 
(datastore.Revision, datastore.SchemaHash, error) { return p.delegate.OptimizedRevision(context.WithoutCancel(ctx)) } @@ -58,7 +59,7 @@ func (p *ctxProxy) CheckRevision(ctx context.Context, revision datastore.Revisio return p.delegate.CheckRevision(context.WithoutCancel(ctx), revision) } -func (p *ctxProxy) HeadRevision(ctx context.Context) (datastore.Revision, error) { +func (p *ctxProxy) HeadRevision(ctx context.Context) (datastore.Revision, datastore.SchemaHash, error) { return p.delegate.HeadRevision(context.WithoutCancel(ctx)) } @@ -88,8 +89,8 @@ func (p *ctxProxy) ReadyState(ctx context.Context) (datastore.ReadyState, error) func (p *ctxProxy) Close() error { return p.delegate.Close() } -func (p *ctxProxy) SnapshotReader(rev datastore.Revision) datastore.Reader { - delegateReader := p.delegate.SnapshotReader(rev) +func (p *ctxProxy) SnapshotReader(rev datastore.Revision, schemaHash datastore.SchemaHash) datastore.Reader { + delegateReader := p.delegate.SnapshotReader(rev, schemaHash) return &ctxReader{delegateReader} } @@ -148,7 +149,18 @@ func (r *ctxReader) SchemaReader() (datastore.SchemaReader, error) { return r.delegate.SchemaReader() } +func (r *ctxReader) ReadStoredSchema(ctx context.Context) (*core.StoredSchema, error) { + singleStoreReader, ok := r.delegate.(datastore.SingleStoreSchemaReader) + if !ok { + return nil, spiceerrors.MustBugf("delegate reader does not implement SingleStoreSchemaReader") + } + return singleStoreReader.ReadStoredSchema(context.WithoutCancel(ctx)) +} + var ( - _ datastore.Datastore = (*ctxProxy)(nil) - _ datastore.Reader = (*ctxReader)(nil) + _ datastore.Datastore = (*ctxProxy)(nil) + _ datastore.Reader = (*ctxReader)(nil) + _ datastore.LegacySchemaReader = (*ctxReader)(nil) + _ datastore.SingleStoreSchemaReader = (*ctxReader)(nil) + _ datastore.DualSchemaReader = (*ctxReader)(nil) ) diff --git a/internal/datastore/crdb/caveat.go b/internal/datastore/crdb/caveat.go index 5857a5702..0fc0fb914 100644 --- 
a/internal/datastore/crdb/caveat.go +++ b/internal/datastore/crdb/caveat.go @@ -155,6 +155,7 @@ func (rwt *crdbReadWriteTXN) LegacyWriteCaveats(ctx context.Context, caveats []* if _, err := rwt.tx.Exec(ctx, sql, args...); err != nil { return fmt.Errorf(errWriteCaveat, err) } + return nil } @@ -170,5 +171,6 @@ func (rwt *crdbReadWriteTXN) LegacyDeleteCaveats(ctx context.Context, names []st if _, err := rwt.tx.Exec(ctx, sql, args...); err != nil { return fmt.Errorf(errDeleteCaveats, err) } + return nil } diff --git a/internal/datastore/crdb/crdb.go b/internal/datastore/crdb/crdb.go index c47bd3e55..dd0533c80 100644 --- a/internal/datastore/crdb/crdb.go +++ b/internal/datastore/crdb/crdb.go @@ -61,6 +61,16 @@ const ( queryTransactionNow = "SHOW COMMIT TIMESTAMP" queryShowZoneConfig = "SHOW ZONE CONFIGURATION FOR RANGE default;" + // Query to get the current HLC timestamp along with the latest schema_hash + querySelectNowWithHash = ` + WITH current_ts AS ( + SELECT cluster_logical_timestamp() as ts + ) + SELECT + current_ts.ts, + COALESCE((SELECT hash FROM schema_revision WHERE name = 'current' ORDER BY timestamp DESC LIMIT 1), ''::bytea) + FROM current_ts;` + spicedbTransactionKey = "$spicedb_transaction_key" ) @@ -197,9 +207,14 @@ func newCRDBDatastore(ctx context.Context, url string, options ...Option) (datas supportsIntegrity: config.withIntegrity, gcWindow: config.gcWindow, watchEnabled: !config.watchDisabled, + schemaMode: config.schemaMode, schema: *schema.Schema(config.columnOptimizationOption, config.withIntegrity, false), } - ds.SetNowFunc(ds.headRevisionInternal) + // Wrap headRevisionInternal to match RemoteNowFunction signature + ds.SetNowFunc(func(ctx context.Context) (datastore.Revision, error) { + rev, _, err := ds.headRevisionInternal(ctx) + return rev, err + }) // this ctx and cancel is tied to the lifetime of the datastore ds.ctx, ds.cancel = context.WithCancel(context.Background()) @@ -214,6 +229,13 @@ func newCRDBDatastore(ctx context.Context, url 
string, options ...Option) (datas return nil, common.RedactAndLogSensitiveConnString(ctx, errUnableToInstantiate, err, url) } + // Initialize schema reader/writer + ds.schemaReaderWriter, err = common.NewSQLSchemaReaderWriter[any, revisions.HLCRevision](BaseSchemaChunkerConfig, config.schemaCacheOptions) + if err != nil { + ds.cancel() + return nil, err + } + err = ds.registerPrometheusCollectors(config.enablePrometheusStats) if err != nil { ds.cancel() @@ -243,6 +265,11 @@ func newCRDBDatastore(ctx context.Context, url string, options ...Option) (datas }) } + // Warm the schema cache on startup + if err := warmSchemaCache(initCtx, ds); err != nil { + log.Warn().Err(err).Msg("failed to warm schema cache on startup") + } + return ds, nil } @@ -290,10 +317,14 @@ type crdbDatastore struct { supportsIntegrity bool watchEnabled bool + // SQLSchemaReaderWriter for schema operations + schemaReaderWriter *common.SQLSchemaReaderWriter[any, revisions.HLCRevision] + schemaMode options.SchemaMode + uniqueID atomic.Pointer[string] } -func (cds *crdbDatastore) SnapshotReader(rev datastore.Revision) datastore.Reader { +func (cds *crdbDatastore) SnapshotReader(rev datastore.Revision, hash datastore.SchemaHash) datastore.Reader { executor := common.QueryRelationshipsExecutor{ Executor: pgxcommon.NewPGXQueryRelationshipsExecutor(cds.readPool, cds), } @@ -306,6 +337,10 @@ func (cds *crdbDatastore) SnapshotReader(rev datastore.Revision) datastore.Reade filterMaximumIDCount: cds.filterMaximumIDCount, withIntegrity: cds.supportsIntegrity, atSpecificRevision: rev.String(), + schemaMode: cds.schemaMode, + snapshotRevision: rev, + schemaHash: string(hash), + schemaReaderWriter: cds.schemaReaderWriter, } } @@ -340,13 +375,17 @@ func (cds *crdbDatastore) ReadWriteTx( filterMaximumIDCount: cds.filterMaximumIDCount, withIntegrity: cds.supportsIntegrity, atSpecificRevision: "", // No AS OF SYSTEM TIME for writes + schemaMode: cds.schemaMode, + snapshotRevision: datastore.NoRevision, // Revision 
not known until commit + schemaHash: string(datastore.NoSchemaHashInTransaction), // Bypass cache for transaction reads + schemaReaderWriter: cds.schemaReaderWriter, } rwt := &crdbReadWriteTXN{ - reader, - tx, - 0, - false, + crdbReader: reader, + tx: tx, + relCountChange: 0, + hasNonExpiredDeletionChange: false, } if err := f(ctx, rwt); err != nil { @@ -491,20 +530,49 @@ func (cds *crdbDatastore) Close() error { return errors.Join(errs...) } -func (cds *crdbDatastore) HeadRevision(ctx context.Context) (datastore.Revision, error) { - return cds.headRevisionInternal(ctx) +// SchemaHashReaderForTesting returns a test-only interface for reading the schema hash directly from schema_revision table. +func (cds *crdbDatastore) SchemaHashReaderForTesting() interface { + ReadSchemaHash(ctx context.Context) (string, error) +} { + return &crdbSchemaHashReaderForTesting{query: cds.readPool} } -func (cds *crdbDatastore) headRevisionInternal(ctx context.Context) (datastore.Revision, error) { - var hlcNow datastore.Revision +// SchemaModeForTesting returns the current schema mode for testing purposes. 
+func (cds *crdbDatastore) SchemaModeForTesting() (options.SchemaMode, error) { + return cds.schemaMode, nil +} - var fnErr error - hlcNow, fnErr = readCRDBNow(ctx, cds.readPool) +type crdbSchemaHashReaderForTesting struct { + query *pool.RetryPool +} + +func (r *crdbSchemaHashReaderForTesting) ReadSchemaHash(ctx context.Context) (string, error) { + var hashBytes []byte + + err := r.query.QueryRowFunc(ctx, func(ctx context.Context, row pgx.Row) error { + return row.Scan(&hashBytes) + }, "SELECT hash FROM schema_revision WHERE name = 'current'") + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return "", datastore.ErrSchemaNotFound + } + return "", fmt.Errorf("failed to query schema hash: %w", err) + } + + return string(hashBytes), nil +} + +func (cds *crdbDatastore) HeadRevision(ctx context.Context) (datastore.Revision, datastore.SchemaHash, error) { + return cds.headRevisionInternal(ctx) +} + +func (cds *crdbDatastore) headRevisionInternal(ctx context.Context) (datastore.Revision, datastore.SchemaHash, error) { + hlcNow, schemaHash, fnErr := readCRDBNow(ctx, cds.readPool) if fnErr != nil { - return datastore.NoRevision, fmt.Errorf(errRevision, fnErr) + return datastore.NoRevision, "", fmt.Errorf(errRevision, fnErr) } - return hlcNow, fnErr + return hlcNow, schemaHash, fnErr } func (cds *crdbDatastore) OfflineFeatures() (*datastore.Features, error) { @@ -571,7 +639,7 @@ func (cds *crdbDatastore) features(ctx context.Context) (*datastore.Features, er features.IntegrityData.Status = datastore.FeatureSupported } - head, err := cds.HeadRevision(ctx) + head, _, err := cds.HeadRevision(ctx) if err != nil { return nil, err } @@ -618,18 +686,24 @@ func (cds *crdbDatastore) readTransactionCommitRev(ctx context.Context, reader p return revisions.NewForHLC(hlcNow) } -func readCRDBNow(ctx context.Context, reader pgxcommon.DBFuncQuerier) (datastore.Revision, error) { +func readCRDBNow(ctx context.Context, reader pgxcommon.DBFuncQuerier) (datastore.Revision, 
datastore.SchemaHash, error) { ctx, span := tracer.Start(ctx, "readCRDBNow") defer span.End() var hlcNow decimal.Decimal + var schemaHash []byte if err := reader.QueryRowFunc(ctx, func(ctx context.Context, row pgx.Row) error { - return row.Scan(&hlcNow) - }, querySelectNow); err != nil { - return datastore.NoRevision, fmt.Errorf("unable to read timestamp: %w", err) + return row.Scan(&hlcNow, &schemaHash) + }, querySelectNowWithHash); err != nil { + return datastore.NoRevision, "", fmt.Errorf("unable to read timestamp and schema hash: %w", err) } - return revisions.NewForHLC(hlcNow) + rev, err := revisions.NewForHLC(hlcNow) + if err != nil { + return datastore.NoRevision, "", err + } + + return rev, datastore.SchemaHash(schemaHash), nil } func readClusterTTLNanos(ctx context.Context, conn pgxcommon.DBFuncQuerier) (int64, error) { @@ -654,6 +728,39 @@ func readClusterTTLNanos(ctx context.Context, conn pgxcommon.DBFuncQuerier) (int return gcSeconds * 1_000_000_000, nil } +// warmSchemaCache attempts to warm the schema cache by loading the current schema. +// This is called during datastore initialization to avoid cold-start latency on first requests. 
+func warmSchemaCache(ctx context.Context, ds *crdbDatastore) error { + // Get the current revision and schema hash + rev, schemaHash, err := ds.HeadRevision(ctx) + if err != nil { + return fmt.Errorf("failed to get head revision: %w", err) + } + + // If there's no schema hash, there's no schema to warm + if schemaHash == "" { + log.Ctx(ctx).Debug().Msg("no schema hash found, skipping cache warming") + return nil + } + + // Create a simple executor for schema reading (no transaction, no revision filtering needed for warmup) + executor := newCRDBChunkedBytesExecutor(ds.readPool) + + // Load the schema to populate the cache + _, err = ds.schemaReaderWriter.ReadSchema(ctx, executor, rev, schemaHash) + if err != nil { + if errors.Is(err, datastore.ErrSchemaNotFound) { + // Schema not found is not an error during warming - just means no schema yet + log.Ctx(ctx).Debug().Msg("no schema found, skipping cache warming") + return nil + } + return fmt.Errorf("failed to read schema: %w", err) + } + + log.Ctx(ctx).Info().Str("schema_hash", string(schemaHash)).Msg("schema cache warmed successfully") + return nil +} + func (cds *crdbDatastore) registerPrometheusCollectors(enablePrometheusStats bool) error { if !enablePrometheusStats { return nil @@ -665,9 +772,14 @@ func (cds *crdbDatastore) registerPrometheusCollectors(enablePrometheusStats boo }) if err := prometheus.Register(readCollector); err != nil { - return fmt.Errorf("failed to register prometheus read collector: %w", err) + // Ignore AlreadyRegisteredError which can happen in tests + var alreadyRegistered prometheus.AlreadyRegisteredError + if !errors.As(err, &alreadyRegistered) { + return fmt.Errorf("failed to register prometheus read collector: %w", err) + } + } else { + cds.collectors = append(cds.collectors, readCollector) } - cds.collectors = append(cds.collectors, readCollector) writeCollector := pgxpoolprometheus.NewCollector(cds.writePool, map[string]string{ "db_name": "spicedb", @@ -675,9 +787,14 @@ func (cds 
*crdbDatastore) registerPrometheusCollectors(enablePrometheusStats boo }) if err := prometheus.Register(writeCollector); err != nil { - return fmt.Errorf("failed to register prometheus write collector: %w", err) + // Ignore AlreadyRegisteredError which can happen in tests + var alreadyRegistered prometheus.AlreadyRegisteredError + if !errors.As(err, &alreadyRegistered) { + return fmt.Errorf("failed to register prometheus write collector: %w", err) + } + } else { + cds.collectors = append(cds.collectors, writeCollector) } - cds.collectors = append(cds.collectors, writeCollector) return nil } diff --git a/internal/datastore/crdb/crdb_test.go b/internal/datastore/crdb/crdb_test.go index e2b43de72..41e8e9d7a 100644 --- a/internal/datastore/crdb/crdb_test.go +++ b/internal/datastore/crdb/crdb_test.go @@ -173,10 +173,10 @@ func TestCRDBDatastoreWithFollowerReads(t *testing.T) { // Revisions should be at least the follower read delay amount in the past for start := time.Now(); time.Since(start) < 50*time.Millisecond; { - testRevision, err := ds.OptimizedRevision(ctx) + testRevision, _, err := ds.OptimizedRevision(ctx) require.NoError(t, err) - nowRevision, err := ds.HeadRevision(ctx) + nowRevision, _, err := ds.HeadRevision(ctx) require.NoError(t, err) diff := nowRevision.(revisions.HLCRevision).TimestampNanoSec() - testRevision.(revisions.HLCRevision).TimestampNanoSec() @@ -314,6 +314,10 @@ func TestWatchFeatureDetection(t *testing.T) { require.NoError(t, err) require.NoError(t, crdbmigrations.CRDBMigrations.Run(ctx, migrationDriver, migrate.Head, migrate.LiveRun)) + // Grant SELECT on schema_revision to unprivileged user (needed for HeadRevision) + _, err = adminConn.Exec(ctx, `GRANT SELECT ON TABLE testspicedb.schema_revision TO unprivileged;`) + require.NoError(t, err) + tt.postInit(ctx, adminConn) ds, err := NewCRDBDatastore(ctx, connStrings[unprivileged], WithAcquireTimeout(5*time.Second)) @@ -328,7 +332,7 @@ func TestWatchFeatureDetection(t *testing.T) { 
require.Contains(t, features.Watch.Reason, tt.expectMessage) if features.Watch.Status != datastore.FeatureSupported { - headRevision, err := ds.HeadRevision(ctx) + headRevision, _, err := ds.HeadRevision(ctx) require.NoError(t, err) _, errChan := ds.Watch(ctx, headRevision, datastore.WatchJustRelationships()) @@ -528,10 +532,10 @@ func RelationshipIntegrityInfoTest(t *testing.T, tester test.DatastoreTester) { require.NoError(err) // Read the relationship back and ensure the integrity information is present. - headRev, err := ds.HeadRevision(ctx) + headRev, _, err := ds.HeadRevision(ctx) require.NoError(err) - reader := ds.SnapshotReader(headRev) + reader := ds.SnapshotReader(headRev, datastore.NoSchemaHashForTesting) iter, err := reader.QueryRelationships(ctx, datastore.RelationshipsFilter{ OptionalResourceType: "document", OptionalResourceIds: []string{"foo"}, @@ -591,10 +595,10 @@ func BulkRelationshipIntegrityInfoTest(t *testing.T, tester test.DatastoreTester require.NoError(err) // Read the relationship back and ensure the integrity information is present. 
- headRev, err := ds.HeadRevision(ctx) + headRev, _, err := ds.HeadRevision(ctx) require.NoError(err) - reader := ds.SnapshotReader(headRev) + reader := ds.SnapshotReader(headRev, datastore.NoSchemaHashForTesting) iter, err := reader.QueryRelationships(ctx, datastore.RelationshipsFilter{ OptionalResourceType: "document", OptionalResourceIds: []string{"foo"}, @@ -960,3 +964,34 @@ func TestRegisterPrometheusCollectors(t *testing.T) { require.NotNil(t, poolReadMetric) require.Equal(t, float64(readMaxConns), poolReadMetric.GetGauge().GetValue()) //nolint:testifylint // we expect exact values } + +func TestCRDBDatastoreUnifiedSchemaAllModes(t *testing.T) { + t.Parallel() + b := testdatastore.RunCRDBForTesting(t, "", crdbTestVersion()) + + test.UnifiedSchemaAllModesTest(t, func(schemaMode options.SchemaMode) test.DatastoreTester { + return test.DatastoreTesterFunc(func(revisionQuantization, gcInterval, gcWindow time.Duration, watchBufferLength uint16) (datastore.Datastore, error) { + ctx := context.Background() + ds := b.NewDatastore(t, func(engine, uri string) datastore.Datastore { + ds, err := NewCRDBDatastore( + ctx, + uri, + GCWindow(gcWindow), + RevisionQuantization(revisionQuantization), + WatchBufferLength(watchBufferLength), + OverlapStrategy(overlapStrategyPrefix), + DebugAnalyzeBeforeStatistics(), + WithAcquireTimeout(5*time.Second), + WithSchemaMode(schemaMode), + ) + require.NoError(t, err) + t.Cleanup(func() { + _ = ds.Close() + }) + return indexcheck.WrapWithIndexCheckingDatastoreProxyIfApplicable(ds) + }) + + return ds, nil + }) + }) +} diff --git a/internal/datastore/crdb/migrations/zz_migration.0010_add_schema_tables.go b/internal/datastore/crdb/migrations/zz_migration.0010_add_schema_tables.go new file mode 100644 index 000000000..e5dc680da --- /dev/null +++ b/internal/datastore/crdb/migrations/zz_migration.0010_add_schema_tables.go @@ -0,0 +1,41 @@ +package migrations + +import ( + "context" + + "github.com/jackc/pgx/v5" +) + +const ( + 
createSchemaTable = `CREATE TABLE schema ( + name VARCHAR NOT NULL, + chunk_index INT NOT NULL, + chunk_data BYTEA NOT NULL, + timestamp TIMESTAMP WITHOUT TIME ZONE DEFAULT now() NOT NULL, + CONSTRAINT pk_schema PRIMARY KEY (name, chunk_index) + );` + + createSchemaRevisionTable = `CREATE TABLE schema_revision ( + name VARCHAR NOT NULL DEFAULT 'current', + hash BYTEA NOT NULL, + timestamp TIMESTAMP WITHOUT TIME ZONE DEFAULT now() NOT NULL, + CONSTRAINT pk_schema_revision PRIMARY KEY (name) + );` +) + +func init() { + err := CRDBMigrations.Register("add-schema-tables", "add-expiration-support", addSchemaTablesFunc, noAtomicMigration) + if err != nil { + panic("failed to register migration: " + err.Error()) + } +} + +func addSchemaTablesFunc(ctx context.Context, conn *pgx.Conn) error { + if _, err := conn.Exec(ctx, createSchemaTable); err != nil { + return err + } + if _, err := conn.Exec(ctx, createSchemaRevisionTable); err != nil { + return err + } + return nil +} diff --git a/internal/datastore/crdb/migrations/zz_migration.0011_populate_schema_tables.go b/internal/datastore/crdb/migrations/zz_migration.0011_populate_schema_tables.go new file mode 100644 index 000000000..7cb50142c --- /dev/null +++ b/internal/datastore/crdb/migrations/zz_migration.0011_populate_schema_tables.go @@ -0,0 +1,167 @@ +package migrations + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "sort" + + "github.com/jackc/pgx/v5" + + core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/schemadsl/compiler" + "github.com/authzed/spicedb/pkg/schemadsl/generator" +) + +const ( + schemaChunkSize = 1024 * 1024 // 1MB chunks + currentSchemaVersion = 1 + unifiedSchemaName = "unified_schema" + schemaRevisionName = "current" +) + +func init() { + if err := CRDBMigrations.Register("populate-schema-tables", "add-schema-tables", noNonAtomicMigration, func(ctx context.Context, tx pgx.Tx) error { + // Read all existing namespaces + rows, err := tx.Query(ctx, 
` + SELECT namespace, serialized_config + FROM namespace_config + `) + if err != nil { + return fmt.Errorf("failed to query namespaces: %w", err) + } + defer rows.Close() + + namespaces := make(map[string]*core.NamespaceDefinition) + for rows.Next() { + var name string + var config []byte + if err := rows.Scan(&name, &config); err != nil { + return fmt.Errorf("failed to scan namespace: %w", err) + } + + var ns core.NamespaceDefinition + if err := ns.UnmarshalVT(config); err != nil { + return fmt.Errorf("failed to unmarshal namespace %s: %w", name, err) + } + namespaces[name] = &ns + } + if err := rows.Err(); err != nil { + return fmt.Errorf("error iterating namespaces: %w", err) + } + + // Read all existing caveats + rows, err = tx.Query(ctx, ` + SELECT name, definition + FROM caveat + `) + if err != nil { + return fmt.Errorf("failed to query caveats: %w", err) + } + defer rows.Close() + + caveats := make(map[string]*core.CaveatDefinition) + for rows.Next() { + var name string + var definition []byte + if err := rows.Scan(&name, &definition); err != nil { + return fmt.Errorf("failed to scan caveat: %w", err) + } + + var caveat core.CaveatDefinition + if err := caveat.UnmarshalVT(definition); err != nil { + return fmt.Errorf("failed to unmarshal caveat %s: %w", name, err) + } + caveats[name] = &caveat + } + if err := rows.Err(); err != nil { + return fmt.Errorf("error iterating caveats: %w", err) + } + + // If there are no namespaces or caveats, skip migration + if len(namespaces) == 0 && len(caveats) == 0 { + return nil + } + + // Generate canonical schema for hash computation + allDefs := make([]compiler.SchemaDefinition, 0, len(namespaces)+len(caveats)) + for _, ns := range namespaces { + allDefs = append(allDefs, ns) + } + for _, caveat := range caveats { + allDefs = append(allDefs, caveat) + } + + // Sort alphabetically for canonical ordering + sort.Slice(allDefs, func(i, j int) bool { + return allDefs[i].GetName() < allDefs[j].GetName() + }) + + // Generate 
canonical schema text + canonicalSchemaText, _, err := generator.GenerateSchema(allDefs) + if err != nil { + return fmt.Errorf("failed to generate canonical schema: %w", err) + } + + // Compute SHA256 hash + hashBytes := sha256.Sum256([]byte(canonicalSchemaText)) + schemaHash := hashBytes[:] + + // Generate user-facing schema text + schemaText, _, err := generator.GenerateSchema(allDefs) + if err != nil { + return fmt.Errorf("failed to generate schema: %w", err) + } + + // Create stored schema proto + storedSchema := &core.StoredSchema{ + Version: currentSchemaVersion, + VersionOneof: &core.StoredSchema_V1{ + V1: &core.StoredSchema_V1StoredSchema{ + SchemaText: schemaText, + SchemaHash: hex.EncodeToString(schemaHash), + NamespaceDefinitions: namespaces, + CaveatDefinitions: caveats, + }, + }, + } + + // Marshal schema + schemaData, err := storedSchema.MarshalVT() + if err != nil { + return fmt.Errorf("failed to marshal schema: %w", err) + } + + // Insert schema chunks + for chunkIndex := 0; chunkIndex*schemaChunkSize < len(schemaData); chunkIndex++ { + start := chunkIndex * schemaChunkSize + end := start + schemaChunkSize + if end > len(schemaData) { + end = len(schemaData) + } + chunk := schemaData[start:end] + + _, err = tx.Exec(ctx, ` + INSERT INTO schema (name, chunk_index, chunk_data) + VALUES ($1, $2, $3) + `, unifiedSchemaName, chunkIndex, chunk) + if err != nil { + return fmt.Errorf("failed to insert schema chunk %d: %w", chunkIndex, err) + } + } + + // Insert schema hash + _, err = tx.Exec(ctx, ` + INSERT INTO schema_revision (name, hash) + VALUES ($1, $2) + `, schemaRevisionName, schemaHash) + if err != nil { + return fmt.Errorf("failed to insert schema hash: %w", err) + } + + return nil + }); err != nil { + panic("failed to register migration: " + err.Error()) + } +} diff --git a/internal/datastore/crdb/options.go b/internal/datastore/crdb/options.go index 9e745fc21..28c3f138b 100644 --- a/internal/datastore/crdb/options.go +++ 
b/internal/datastore/crdb/options.go @@ -9,6 +9,7 @@ import ( "github.com/authzed/spicedb/internal/datastore/common" pgxcommon "github.com/authzed/spicedb/internal/datastore/postgres/common" log "github.com/authzed/spicedb/internal/logging" + "github.com/authzed/spicedb/pkg/datastore/options" ) type crdbOptions struct { @@ -36,6 +37,8 @@ type crdbOptions struct { includeQueryParametersInTraces bool watchDisabled bool acquireTimeout time.Duration + schemaMode options.SchemaMode + schemaCacheOptions options.SchemaCacheOptions } const ( @@ -404,3 +407,13 @@ func WithWatchDisabled(isDisabled bool) Option { func WithAcquireTimeout(timeout time.Duration) Option { return func(po *crdbOptions) { po.acquireTimeout = timeout } } + +// WithSchemaMode sets the experimental schema mode for the datastore. +func WithSchemaMode(mode options.SchemaMode) Option { + return func(po *crdbOptions) { po.schemaMode = mode } +} + +// WithSchemaCacheOptions sets the schema cache options for the datastore. +func WithSchemaCacheOptions(cacheOptions options.SchemaCacheOptions) Option { + return func(po *crdbOptions) { po.schemaCacheOptions = cacheOptions } +} diff --git a/internal/datastore/crdb/reader.go b/internal/datastore/crdb/reader.go index 5bf2c9b1b..b46a9b972 100644 --- a/internal/datastore/crdb/reader.go +++ b/internal/datastore/crdb/reader.go @@ -15,7 +15,7 @@ import ( "github.com/authzed/spicedb/internal/datastore/crdb/schema" pgxcommon "github.com/authzed/spicedb/internal/datastore/postgres/common" "github.com/authzed/spicedb/internal/datastore/revisions" - schemautil "github.com/authzed/spicedb/internal/datastore/schema" + schemaadapter "github.com/authzed/spicedb/internal/datastore/schema" "github.com/authzed/spicedb/pkg/datastore" "github.com/authzed/spicedb/pkg/datastore/options" core "github.com/authzed/spicedb/pkg/proto/core/v1" @@ -50,6 +50,10 @@ type crdbReader struct { filterMaximumIDCount uint16 withIntegrity bool atSpecificRevision string + schemaMode options.SchemaMode 
+ snapshotRevision datastore.Revision + schemaHash string + schemaReaderWriter *common.SQLSchemaReaderWriter[any, revisions.HLCRevision] } const ( @@ -431,7 +435,62 @@ func (cr *crdbReader) addOverlapKey(namespace string) { // SchemaReader returns a SchemaReader for reading schema information. func (cr *crdbReader) SchemaReader() (datastore.SchemaReader, error) { - return schemautil.NewLegacySchemaReaderAdapter(cr), nil + // Wrap the reader with an unexported schema reader + reader := &crdbSchemaReader{r: cr} + return schemaadapter.NewSchemaReader(reader, cr.schemaMode, cr.snapshotRevision), nil } -var _ datastore.Reader = &crdbReader{} +// crdbSchemaReader wraps a crdbReader and implements DualSchemaReader. +// This prevents direct access to schema read methods from the reader. +type crdbSchemaReader struct { + r *crdbReader +} + +// ReadStoredSchema implements datastore.SingleStoreSchemaReader with revision-aware reading +func (sr *crdbSchemaReader) ReadStoredSchema(ctx context.Context) (*core.StoredSchema, error) { + // Create a revision-aware executor that applies AS OF SYSTEM TIME + executor := &revisionAwareExecutor{ + query: sr.r.query, + addFromToQuery: sr.r.addFromToQuery, + assertAsOfSysTime: sr.r.assertHasExpectedAsOfSystemTime, + } + + // Use the shared schema reader/writer to read the schema with the hash + return sr.r.schemaReaderWriter.ReadSchema(ctx, executor, sr.r.snapshotRevision, datastore.SchemaHash(sr.r.schemaHash)) +} + +// LegacyLookupNamespacesWithNames delegates to the underlying reader +func (sr *crdbSchemaReader) LegacyLookupNamespacesWithNames(ctx context.Context, nsNames []string) ([]datastore.RevisionedDefinition[*core.NamespaceDefinition], error) { + return sr.r.LegacyLookupNamespacesWithNames(ctx, nsNames) +} + +// LegacyReadCaveatByName delegates to the underlying reader +func (sr *crdbSchemaReader) LegacyReadCaveatByName(ctx context.Context, name string) (*core.CaveatDefinition, datastore.Revision, error) { + return 
sr.r.LegacyReadCaveatByName(ctx, name) +} + +// LegacyListAllCaveats delegates to the underlying reader +func (sr *crdbSchemaReader) LegacyListAllCaveats(ctx context.Context) ([]datastore.RevisionedCaveat, error) { + return sr.r.LegacyListAllCaveats(ctx) +} + +// LegacyLookupCaveatsWithNames delegates to the underlying reader +func (sr *crdbSchemaReader) LegacyLookupCaveatsWithNames(ctx context.Context, names []string) ([]datastore.RevisionedCaveat, error) { + return sr.r.LegacyLookupCaveatsWithNames(ctx, names) +} + +// LegacyReadNamespaceByName delegates to the underlying reader +func (sr *crdbSchemaReader) LegacyReadNamespaceByName(ctx context.Context, nsName string) (*core.NamespaceDefinition, datastore.Revision, error) { + return sr.r.LegacyReadNamespaceByName(ctx, nsName) +} + +// LegacyListAllNamespaces delegates to the underlying reader +func (sr *crdbSchemaReader) LegacyListAllNamespaces(ctx context.Context) ([]datastore.RevisionedNamespace, error) { + return sr.r.LegacyListAllNamespaces(ctx) +} + +var ( + _ datastore.Reader = &crdbReader{} + _ datastore.LegacySchemaReader = &crdbReader{} + _ datastore.DualSchemaReader = &crdbSchemaReader{} +) diff --git a/internal/datastore/crdb/readwrite.go b/internal/datastore/crdb/readwrite.go index 5255bab04..8d1844c0c 100644 --- a/internal/datastore/crdb/readwrite.go +++ b/internal/datastore/crdb/readwrite.go @@ -3,8 +3,11 @@ package crdb import ( "cmp" "context" + "crypto/sha256" + "encoding/hex" "errors" "fmt" + "sort" sq "github.com/Masterminds/squirrel" "github.com/ccoveille/go-safecast/v2" @@ -16,11 +19,13 @@ import ( "github.com/authzed/spicedb/internal/datastore/crdb/schema" pgxcommon "github.com/authzed/spicedb/internal/datastore/postgres/common" "github.com/authzed/spicedb/internal/datastore/revisions" - schemautil "github.com/authzed/spicedb/internal/datastore/schema" + schemaadapter "github.com/authzed/spicedb/internal/datastore/schema" log "github.com/authzed/spicedb/internal/logging" 
"github.com/authzed/spicedb/pkg/datastore" "github.com/authzed/spicedb/pkg/datastore/options" core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/schemadsl/compiler" + "github.com/authzed/spicedb/pkg/schemadsl/generator" "github.com/authzed/spicedb/pkg/spiceerrors" "github.com/authzed/spicedb/pkg/tuple" ) @@ -565,7 +570,171 @@ func (rwt *crdbReadWriteTXN) LegacyDeleteNamespaces(ctx context.Context, nsNames } func (rwt *crdbReadWriteTXN) SchemaWriter() (datastore.SchemaWriter, error) { - return schemautil.NewLegacySchemaWriterAdapter(rwt, rwt), nil + // Wrap the transaction with an unexported schema writer + writer := &crdbSchemaWriter{rwt: rwt} + return schemaadapter.NewSchemaWriter(writer, writer, rwt.schemaMode), nil +} + +// crdbSchemaWriter wraps a crdbReadWriteTXN and implements DualSchemaWriter. +// This prevents direct access to schema write methods from the transaction. +type crdbSchemaWriter struct { + rwt *crdbReadWriteTXN +} + +// WriteStoredSchema implements datastore.SingleStoreSchemaWriter by writing within the current transaction +func (w *crdbSchemaWriter) WriteStoredSchema(ctx context.Context, schema *core.StoredSchema) error { + // Create a transaction-aware executor that uses the current transaction + executor := newTransactionAwareExecutor(w.rwt.tx) + + // Use the shared schema reader/writer to write the schema + // CRDB uses delete-and-insert mode so no transaction ID provider is needed + if err := w.rwt.schemaReaderWriter.WriteSchema(ctx, schema, executor, common.NoTransactionID[any]); err != nil { + return err + } + + // Write the schema hash to the schema_revision table for fast lookups + if err := w.writeSchemaHash(ctx, schema); err != nil { + return fmt.Errorf("failed to write schema hash: %w", err) + } + + return nil +} + +// writeSchemaHash writes the schema hash to the schema_revision table +func (w *crdbSchemaWriter) writeSchemaHash(ctx context.Context, schema *core.StoredSchema) error { + v1 := 
schema.GetV1() + if v1 == nil { + return fmt.Errorf("unsupported schema version: %d", schema.Version) + } + + // CRDB uses UPSERT (INSERT ON CONFLICT DO UPDATE) for schema_revision + sql, args, err := psql.Insert("schema_revision"). + Columns("name", "hash", "timestamp"). + Values("current", []byte(v1.SchemaHash), sq.Expr("now()")). + Suffix("ON CONFLICT (name) DO UPDATE SET hash = EXCLUDED.hash, timestamp = EXCLUDED.timestamp"). + ToSql() + if err != nil { + return fmt.Errorf("failed to build upsert query: %w", err) + } + + if _, err := w.rwt.tx.Exec(ctx, sql, args...); err != nil { + return fmt.Errorf("failed to upsert hash: %w", err) + } + + return nil +} + +// ReadStoredSchema implements datastore.SingleStoreSchemaReader to satisfy DualSchemaReader interface requirements +func (w *crdbSchemaWriter) ReadStoredSchema(ctx context.Context) (*core.StoredSchema, error) { + // Create a revision-aware executor that applies AS OF SYSTEM TIME + // Within a transaction, we don't use AS OF SYSTEM TIME, so pass empty atSpecificRevision + executor := &revisionAwareExecutor{ + query: w.rwt.query, + addFromToQuery: w.rwt.addFromToQuery, + assertAsOfSysTime: w.rwt.assertHasExpectedAsOfSystemTime, + } + + // Use the shared schema reader/writer to read the schema + // Pass empty string to bypass cache (transaction read) + return w.rwt.schemaReaderWriter.ReadSchema(ctx, executor, nil, datastore.NoSchemaHashInTransaction) +} + +// LegacyWriteNamespaces delegates to the underlying transaction +func (w *crdbSchemaWriter) LegacyWriteNamespaces(ctx context.Context, newConfigs ...*core.NamespaceDefinition) error { + return w.rwt.LegacyWriteNamespaces(ctx, newConfigs...) 
+} + +// LegacyDeleteNamespaces delegates to the underlying transaction +func (w *crdbSchemaWriter) LegacyDeleteNamespaces(ctx context.Context, nsNames []string, delOption datastore.DeleteNamespacesRelationshipsOption) error { + return w.rwt.LegacyDeleteNamespaces(ctx, nsNames, delOption) +} + +// LegacyLookupNamespacesWithNames delegates to the underlying transaction +func (w *crdbSchemaWriter) LegacyLookupNamespacesWithNames(ctx context.Context, nsNames []string) ([]datastore.RevisionedDefinition[*core.NamespaceDefinition], error) { + return w.rwt.LegacyLookupNamespacesWithNames(ctx, nsNames) +} + +// LegacyWriteCaveats delegates to the underlying transaction +func (w *crdbSchemaWriter) LegacyWriteCaveats(ctx context.Context, caveats []*core.CaveatDefinition) error { + return w.rwt.LegacyWriteCaveats(ctx, caveats) +} + +// LegacyDeleteCaveats delegates to the underlying transaction +func (w *crdbSchemaWriter) LegacyDeleteCaveats(ctx context.Context, names []string) error { + return w.rwt.LegacyDeleteCaveats(ctx, names) +} + +// WriteLegacySchemaHashFromDefinitions implements datastore.LegacySchemaHashWriter +func (w *crdbSchemaWriter) WriteLegacySchemaHashFromDefinitions(ctx context.Context, namespaces []datastore.RevisionedNamespace, caveats []datastore.RevisionedCaveat) error { + return w.rwt.writeLegacySchemaHashFromDefinitions(ctx, namespaces, caveats) +} + +// writeLegacySchemaHashFromDefinitions writes the schema hash computed from the given definitions +func (rwt *crdbReadWriteTXN) writeLegacySchemaHashFromDefinitions(ctx context.Context, namespaces []datastore.RevisionedNamespace, caveats []datastore.RevisionedCaveat) error { + // Build schema definitions list + definitions := make([]compiler.SchemaDefinition, 0, len(namespaces)+len(caveats)) + for _, ns := range namespaces { + definitions = append(definitions, ns.Definition) + } + for _, caveat := range caveats { + definitions = append(definitions, caveat.Definition) + } + + // Sort definitions by name 
for consistent ordering + sort.Slice(definitions, func(i, j int) bool { + return definitions[i].GetName() < definitions[j].GetName() + }) + + // Generate schema text from definitions + schemaText, _, err := generator.GenerateSchema(definitions) + if err != nil { + return fmt.Errorf("failed to generate schema: %w", err) + } + + // Compute schema hash (SHA256) + hash := sha256.Sum256([]byte(schemaText)) + schemaHash := hex.EncodeToString(hash[:]) + + // CRDB uses UPSERT (INSERT ON CONFLICT DO UPDATE) for schema_revision + sql, args, err := psql.Insert("schema_revision"). + Columns("name", "hash", "timestamp"). + Values("current", []byte(schemaHash), sq.Expr("now()")). + Suffix("ON CONFLICT (name) DO UPDATE SET hash = EXCLUDED.hash, timestamp = EXCLUDED.timestamp"). + ToSql() + if err != nil { + return fmt.Errorf("failed to build upsert query: %w", err) + } + + if _, err := rwt.tx.Exec(ctx, sql, args...); err != nil { + return fmt.Errorf("failed to upsert hash: %w", err) + } + + return nil +} + +// LegacyReadCaveatByName delegates to the underlying transaction +func (w *crdbSchemaWriter) LegacyReadCaveatByName(ctx context.Context, name string) (*core.CaveatDefinition, datastore.Revision, error) { + return w.rwt.LegacyReadCaveatByName(ctx, name) +} + +// LegacyListAllCaveats delegates to the underlying transaction +func (w *crdbSchemaWriter) LegacyListAllCaveats(ctx context.Context) ([]datastore.RevisionedCaveat, error) { + return w.rwt.LegacyListAllCaveats(ctx) +} + +// LegacyLookupCaveatsWithNames delegates to the underlying transaction +func (w *crdbSchemaWriter) LegacyLookupCaveatsWithNames(ctx context.Context, names []string) ([]datastore.RevisionedCaveat, error) { + return w.rwt.LegacyLookupCaveatsWithNames(ctx, names) +} + +// LegacyReadNamespaceByName delegates to the underlying transaction +func (w *crdbSchemaWriter) LegacyReadNamespaceByName(ctx context.Context, nsName string) (*core.NamespaceDefinition, datastore.Revision, error) { + return 
w.rwt.LegacyReadNamespaceByName(ctx, nsName) +} + +// LegacyListAllNamespaces delegates to the underlying transaction +func (w *crdbSchemaWriter) LegacyListAllNamespaces(ctx context.Context) ([]datastore.RevisionedNamespace, error) { + return w.rwt.LegacyListAllNamespaces(ctx) } var copyCols = []string{ @@ -605,4 +774,9 @@ func (rwt *crdbReadWriteTXN) BulkLoad(ctx context.Context, iter datastore.BulkWr return pgxcommon.BulkLoad(ctx, rwt.tx, rwt.schema.RelationshipTableName, copyCols, iter) } -var _ datastore.ReadWriteTransaction = &crdbReadWriteTXN{} +var ( + _ datastore.ReadWriteTransaction = &crdbReadWriteTXN{} + _ datastore.LegacySchemaWriter = &crdbReadWriteTXN{} + _ datastore.DualSchemaWriter = &crdbSchemaWriter{} + _ datastore.DualSchemaReader = &crdbSchemaWriter{} +) diff --git a/internal/datastore/crdb/schema_chunker.go b/internal/datastore/crdb/schema_chunker.go new file mode 100644 index 000000000..f9a98c0c8 --- /dev/null +++ b/internal/datastore/crdb/schema_chunker.go @@ -0,0 +1,211 @@ +package crdb + +import ( + "context" + "errors" + + sq "github.com/Masterminds/squirrel" + "github.com/jackc/pgx/v5" + + "github.com/authzed/spicedb/internal/datastore/common" + "github.com/authzed/spicedb/internal/datastore/crdb/pool" + pgxcommon "github.com/authzed/spicedb/internal/datastore/postgres/common" +) + +const ( + // CockroachDB has no practical limit on BYTEA column size (similar to Postgres), + // but we use 1MB chunks for reasonable memory usage and query performance. + crdbMaxChunkSize = 1024 * 1024 // 1MB +) + +// BaseSchemaChunkerConfig provides the base configuration for CRDB schema chunking. +// CRDB uses delete-and-insert write mode since it handles MVCC automatically. 
+var BaseSchemaChunkerConfig = common.SQLByteChunkerConfig[any]{ + TableName: "schema", + NameColumn: "name", + ChunkIndexColumn: "chunk_index", + ChunkDataColumn: "chunk_data", + MaxChunkSize: crdbMaxChunkSize, + PlaceholderFormat: sq.Dollar, + WriteMode: common.WriteModeDeleteAndInsert, +} + +// crdbChunkedBytesExecutor implements common.ChunkedBytesExecutor for CockroachDB. +type crdbChunkedBytesExecutor struct { + pool *pool.RetryPool +} + +func newCRDBChunkedBytesExecutor(pool *pool.RetryPool) *crdbChunkedBytesExecutor { + return &crdbChunkedBytesExecutor{pool: pool} +} + +func (e *crdbChunkedBytesExecutor) BeginTransaction(ctx context.Context) (common.ChunkedBytesTransaction, error) { + // For CRDB, we'll use BeginFunc which provides automatic retry logic + return &crdbChunkedBytesTransaction{pool: e.pool, ctx: ctx}, nil +} + +func (e *crdbChunkedBytesExecutor) ExecuteRead(ctx context.Context, builder sq.SelectBuilder) (map[int][]byte, error) { + sql, args, err := builder.ToSql() + if err != nil { + return nil, err + } + + result := make(map[int][]byte) + err = e.pool.QueryFunc(ctx, func(ctx context.Context, rows pgx.Rows) error { + for rows.Next() { + var chunkIndex int + var chunkData []byte + if err := rows.Scan(&chunkIndex, &chunkData); err != nil { + return err + } + result[chunkIndex] = chunkData + } + return rows.Err() + }, sql, args...) + if err != nil { + return nil, err + } + + return result, nil +} + +// crdbChunkedBytesTransaction implements common.ChunkedBytesTransaction for CockroachDB. +type crdbChunkedBytesTransaction struct { + pool *pool.RetryPool + ctx context.Context +} + +func (t *crdbChunkedBytesTransaction) ExecuteWrite(ctx context.Context, builder sq.InsertBuilder) error { + sql, args, err := builder.ToSql() + if err != nil { + return err + } + + return t.pool.BeginFunc(ctx, func(tx pgx.Tx) error { + _, err := tx.Exec(ctx, sql, args...) 
+ return err + }) +} + +func (t *crdbChunkedBytesTransaction) ExecuteDelete(ctx context.Context, builder sq.DeleteBuilder) error { + sql, args, err := builder.ToSql() + if err != nil { + return err + } + + return t.pool.BeginFunc(ctx, func(tx pgx.Tx) error { + _, err := tx.Exec(ctx, sql, args...) + return err + }) +} + +func (t *crdbChunkedBytesTransaction) ExecuteUpdate(ctx context.Context, builder sq.UpdateBuilder) error { + sql, args, err := builder.ToSql() + if err != nil { + return err + } + + return t.pool.BeginFunc(ctx, func(tx pgx.Tx) error { + _, err := tx.Exec(ctx, sql, args...) + return err + }) +} + +// GetSchemaChunker returns a SQLByteChunker for the schema table. +// This is exported for testing purposes. +func (cds *crdbDatastore) GetSchemaChunker() *common.SQLByteChunker[any] { + executor := newCRDBChunkedBytesExecutor(cds.readPool) + return common.MustNewSQLByteChunker(BaseSchemaChunkerConfig.WithExecutor(executor)) +} + +// revisionAwareExecutor wraps the reader's query infrastructure to provide revision-aware chunk reading +type revisionAwareExecutor struct { + query pgxcommon.DBFuncQuerier + addFromToQuery func(sq.SelectBuilder, string, string) sq.SelectBuilder + assertAsOfSysTime func(string) +} + +func (e *revisionAwareExecutor) BeginTransaction(ctx context.Context) (common.ChunkedBytesTransaction, error) { + // We don't support transactions for reading + return nil, errors.New("transactions not supported for revision-aware reads") +} + +func (e *revisionAwareExecutor) ExecuteRead(ctx context.Context, builder sq.SelectBuilder) (map[int][]byte, error) { + // Modify the builder to add AS OF SYSTEM TIME + builder = e.addFromToQuery(builder, "schema", "") + + sql, args, err := builder.ToSql() + if err != nil { + return nil, err + } + e.assertAsOfSysTime(sql) + + // Execute using the reader's query function + result := make(map[int][]byte) + err = e.query.QueryFunc(ctx, func(ctx context.Context, rows pgx.Rows) error { + for rows.Next() { + var 
chunkIndex int + var chunkData []byte + if err := rows.Scan(&chunkIndex, &chunkData); err != nil { + return err + } + result[chunkIndex] = chunkData + } + return rows.Err() + }, sql, args...) + + return result, err +} + +// transactionAwareExecutor wraps an existing pgx.Tx to provide transaction-aware chunk writing +type transactionAwareExecutor struct { + tx pgx.Tx +} + +func newTransactionAwareExecutor(tx pgx.Tx) *transactionAwareExecutor { + return &transactionAwareExecutor{tx: tx} +} + +func (e *transactionAwareExecutor) BeginTransaction(ctx context.Context) (common.ChunkedBytesTransaction, error) { + // Return a transaction wrapper that uses the existing transaction + return &transactionAwareTransaction{tx: e.tx}, nil +} + +func (e *transactionAwareExecutor) ExecuteRead(ctx context.Context, builder sq.SelectBuilder) (map[int][]byte, error) { + return nil, errors.New("read operations not supported on transaction-aware executor") +} + +// transactionAwareTransaction implements common.ChunkedBytesTransaction using an existing pgx.Tx +type transactionAwareTransaction struct { + tx pgx.Tx +} + +func (t *transactionAwareTransaction) ExecuteWrite(ctx context.Context, builder sq.InsertBuilder) error { + sql, args, err := builder.ToSql() + if err != nil { + return err + } + + _, err = t.tx.Exec(ctx, sql, args...) + return err +} + +func (t *transactionAwareTransaction) ExecuteDelete(ctx context.Context, builder sq.DeleteBuilder) error { + sql, args, err := builder.ToSql() + if err != nil { + return err + } + + _, err = t.tx.Exec(ctx, sql, args...) + return err +} + +func (t *transactionAwareTransaction) ExecuteUpdate(ctx context.Context, builder sq.UpdateBuilder) error { + sql, args, err := builder.ToSql() + if err != nil { + return err + } + + _, err = t.tx.Exec(ctx, sql, args...) 
+ return err +} diff --git a/internal/datastore/dsfortesting/dsfortesting.go b/internal/datastore/dsfortesting/dsfortesting.go index f9ea13e51..a705ef2ad 100644 --- a/internal/datastore/dsfortesting/dsfortesting.go +++ b/internal/datastore/dsfortesting/dsfortesting.go @@ -13,6 +13,7 @@ import ( "github.com/authzed/spicedb/internal/datastore/memdb" "github.com/authzed/spicedb/pkg/datastore" "github.com/authzed/spicedb/pkg/datastore/options" + core "github.com/authzed/spicedb/pkg/proto/core/v1" "github.com/authzed/spicedb/pkg/tuple" ) @@ -43,14 +44,50 @@ type validatingDatastore struct { datastore.Datastore } -func (vds validatingDatastore) SnapshotReader(rev datastore.Revision) datastore.Reader { - return validatingReader{vds.Datastore.SnapshotReader(rev)} +func (vds validatingDatastore) SnapshotReader(rev datastore.Revision, hash datastore.SchemaHash) datastore.Reader { + return validatingReader{vds.Datastore.SnapshotReader(rev, hash)} +} + +// SchemaHashReaderForTesting delegates to the underlying datastore if it implements the test interface +func (vds validatingDatastore) SchemaHashReaderForTesting() interface { + ReadSchemaHash(ctx context.Context) (string, error) +} { + type schemaHashReaderProvider interface { + SchemaHashReaderForTesting() interface { + ReadSchemaHash(ctx context.Context) (string, error) + } + } + + if hashReader, ok := vds.Datastore.(schemaHashReaderProvider); ok { + return hashReader.SchemaHashReaderForTesting() + } + return nil +} + +// SchemaModeForTesting delegates to the underlying datastore if it implements the test interface +func (vds validatingDatastore) SchemaModeForTesting() (options.SchemaMode, error) { + type schemaModeProvider interface { + SchemaModeForTesting() (options.SchemaMode, error) + } + + if provider, ok := vds.Datastore.(schemaModeProvider); ok { + return provider.SchemaModeForTesting() + } + return options.SchemaModeReadLegacyWriteLegacy, errors.New("delegate datastore does not implement SchemaModeForTesting()") } 
type validatingReader struct { datastore.Reader } +func (vr validatingReader) ReadStoredSchema(ctx context.Context) (*core.StoredSchema, error) { + singleStoreReader, ok := vr.Reader.(datastore.SingleStoreSchemaReader) + if !ok { + return nil, errors.New("validating reader delegate does not implement SingleStoreSchemaReader") + } + return singleStoreReader.ReadStoredSchema(ctx) +} + func (vr validatingReader) QueryRelationships( ctx context.Context, filter datastore.RelationshipsFilter, @@ -161,3 +198,9 @@ func (vr validatingReader) QueryRelationships( } }, nil } + +var ( + _ datastore.Datastore = validatingDatastore{} + _ datastore.Reader = validatingReader{} + _ datastore.SingleStoreSchemaReader = validatingReader{} +) diff --git a/internal/datastore/memdb/caveat.go b/internal/datastore/memdb/caveat.go index f56062611..e0e0ee17e 100644 --- a/internal/datastore/memdb/caveat.go +++ b/internal/datastore/memdb/caveat.go @@ -108,14 +108,18 @@ func (r *memdbReader) LegacyLookupCaveatsWithNames(ctx context.Context, caveatNa return toReturn, nil } -func (rwt *memdbReadWriteTx) LegacyWriteCaveats(_ context.Context, caveats []*core.CaveatDefinition) error { +func (rwt *memdbReadWriteTx) LegacyWriteCaveats(ctx context.Context, caveats []*core.CaveatDefinition) error { rwt.mustLock() defer rwt.Unlock() tx, err := rwt.txSource() if err != nil { return err } - return rwt.writeCaveat(tx, caveats) + if err := rwt.writeCaveat(tx, caveats); err != nil { + return err + } + + return nil } func (rwt *memdbReadWriteTx) writeCaveat(tx *memdb.Txn, caveats []*core.CaveatDefinition) error { @@ -140,7 +144,7 @@ func (rwt *memdbReadWriteTx) writeCaveat(tx *memdb.Txn, caveats []*core.CaveatDe return nil } -func (rwt *memdbReadWriteTx) LegacyDeleteCaveats(_ context.Context, names []string) error { +func (rwt *memdbReadWriteTx) LegacyDeleteCaveats(ctx context.Context, names []string) error { rwt.mustLock() defer rwt.Unlock() tx, err := rwt.txSource() @@ -152,5 +156,6 @@ func (rwt 
*memdbReadWriteTx) LegacyDeleteCaveats(_ context.Context, names []stri return err } } + return nil } diff --git a/internal/datastore/memdb/memdb.go b/internal/datastore/memdb/memdb.go index 610c14201..fe7c27cbd 100644 --- a/internal/datastore/memdb/memdb.go +++ b/internal/datastore/memdb/memdb.go @@ -2,6 +2,8 @@ package memdb import ( "context" + "crypto/sha256" + "encoding/hex" "errors" "fmt" "math" @@ -17,6 +19,8 @@ import ( "github.com/authzed/spicedb/pkg/datastore" "github.com/authzed/spicedb/pkg/datastore/options" corev1 "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/schemadsl/compiler" + "github.com/authzed/spicedb/pkg/schemadsl/generator" "github.com/authzed/spicedb/pkg/spiceerrors" "github.com/authzed/spicedb/pkg/tuple" ) @@ -25,6 +29,11 @@ const ( Engine = "memory" defaultWatchBufferLength = 128 numAttempts = 10 + + // noHashSupported is a sentinel value indicating that schema hashing is not supported + // by the memdb datastore. memdb uses in-memory schema storage and doesn't use SQL-based + // schema hashing like other datastores. 
+ noHashSupported datastore.SchemaHash = "__memdb_no_hash_support__" ) var ( @@ -69,8 +78,9 @@ func NewMemdbDatastore( db: db, revisions: []snapshot{ { - revision: nowRevision(), - db: db, + revision: nowRevision(), + schemaHash: noHashSupported, + db: db, }, }, @@ -96,11 +106,15 @@ type memdbDatastore struct { watchBufferLength uint16 watchBufferWriteTimeout time.Duration uniqueID string + + // Unified schema storage + storedSchema *corev1.StoredSchema // GUARDED_BY(RWMutex) } type snapshot struct { - revision revisions.TimestampRevision - db *memdb.MemDB + revision revisions.TimestampRevision + schemaHash datastore.SchemaHash + db *memdb.MemDB } func (mdb *memdbDatastore) MetricsID() (string, error) { @@ -111,20 +125,34 @@ func (mdb *memdbDatastore) UniqueID(_ context.Context) (string, error) { return mdb.uniqueID, nil } -func (mdb *memdbDatastore) SnapshotReader(dr datastore.Revision) datastore.Reader { +func (mdb *memdbDatastore) getCurrentSchemaHashNoLock() datastore.SchemaHash { + // Read the current schema hash from the schemarevision table + txn := mdb.db.Txn(false) + defer txn.Abort() + + raw, err := txn.First(tableSchemaRevision, indexID, "current") + if err != nil || raw == nil { + return noHashSupported + } + + schemaRev := raw.(*schemaRevisionData) + return datastore.SchemaHash(schemaRev.hash) +} + +func (mdb *memdbDatastore) SnapshotReader(dr datastore.Revision, hash datastore.SchemaHash) datastore.Reader { mdb.RLock() defer mdb.RUnlock() if err := mdb.checkNotClosed(); err != nil { - return &memdbReader{nil, nil, err, time.Now()} + return &memdbReader{nil, nil, err, time.Now(), "", mdb} } if len(mdb.revisions) == 0 { - return &memdbReader{nil, nil, errors.New("memdb datastore is not ready"), time.Now()} + return &memdbReader{nil, nil, errors.New("memdb datastore is not ready"), time.Now(), "", mdb} } if err := mdb.checkRevisionLocalCallerMustLock(dr); err != nil { - return &memdbReader{nil, nil, err, time.Now()} + return &memdbReader{nil, nil, err, 
time.Now(), "", mdb} } revIndex := sort.Search(len(mdb.revisions), func(i int) bool { @@ -138,7 +166,7 @@ func (mdb *memdbDatastore) SnapshotReader(dr datastore.Revision) datastore.Reade rev := mdb.revisions[revIndex] if rev.db == nil { - return &memdbReader{nil, nil, errors.New("memdb datastore is already closed"), time.Now()} + return &memdbReader{nil, nil, errors.New("memdb datastore is already closed"), time.Now(), "", mdb} } roTxn := rev.db.Txn(false) @@ -146,7 +174,7 @@ func (mdb *memdbDatastore) SnapshotReader(dr datastore.Revision) datastore.Reade return roTxn, nil } - return &memdbReader{noopTryLocker{}, txSrc, nil, time.Now()} + return &memdbReader{noopTryLocker{}, txSrc, nil, time.Now(), string(hash), mdb} } func (mdb *memdbDatastore) SupportsIntegrity() bool { @@ -191,7 +219,7 @@ func (mdb *memdbDatastore) ReadWriteTx( } newRevision := mdb.newRevisionID() - rwt := &memdbReadWriteTx{memdbReader{&sync.Mutex{}, txSrc, nil, time.Now()}, newRevision} + rwt := &memdbReadWriteTx{memdbReader{&sync.Mutex{}, txSrc, nil, time.Now(), string(datastore.NoSchemaHashInTransaction), mdb}, newRevision} if err := f(ctx, rwt); err != nil { mdb.Lock() if tx != nil { @@ -323,7 +351,11 @@ func (mdb *memdbDatastore) ReadWriteTx( // Create a snapshot and add it to the revisions slice snap := mdb.db.Snapshot() - mdb.revisions = append(mdb.revisions, snapshot{newRevision, snap}) + + // Get the current schema hash + schemaHash := mdb.getCurrentSchemaHashNoLock() + + mdb.revisions = append(mdb.revisions, snapshot{newRevision, schemaHash, snap}) return newRevision, nil } @@ -368,8 +400,9 @@ func (mdb *memdbDatastore) Close() error { if db := mdb.db; db != nil { mdb.revisions = []snapshot{ { - revision: nowRevision(), - db: db, + revision: nowRevision(), + schemaHash: noHashSupported, + db: db, }, } } else { @@ -381,6 +414,47 @@ func (mdb *memdbDatastore) Close() error { return nil } +// SchemaHashReaderForTesting returns a test-only interface for reading the schema hash directly 
from schema_revision table. +func (mdb *memdbDatastore) SchemaHashReaderForTesting() interface { + ReadSchemaHash(ctx context.Context) (string, error) +} { + return &memdbSchemaHashReaderForTesting{db: mdb} +} + +// SchemaModeForTesting returns the current schema mode for testing purposes. +// MemDB always operates in a mode equivalent to ReadNewWriteNew. +func (mdb *memdbDatastore) SchemaModeForTesting() (options.SchemaMode, error) { + return options.SchemaModeReadNewWriteNew, nil +} + +type memdbSchemaHashReaderForTesting struct { + db *memdbDatastore +} + +func (r *memdbSchemaHashReaderForTesting) ReadSchemaHash(ctx context.Context) (string, error) { + r.db.RLock() + defer r.db.RUnlock() + + tx := r.db.db.Txn(false) + defer tx.Abort() + + raw, err := tx.First(tableSchemaRevision, indexID, "current") + if err != nil { + return "", fmt.Errorf("failed to query schema hash: %w", err) + } + + if raw == nil { + return "", datastore.ErrSchemaNotFound + } + + revisionData, ok := raw.(*schemaRevisionData) + if !ok { + return "", errors.New("invalid schema revision data type") + } + + return string(revisionData.hash), nil +} + // This code assumes that the RWMutex has been acquired. func (mdb *memdbDatastore) checkNotClosed() error { if mdb.db == nil { @@ -389,4 +463,113 @@ func (mdb *memdbDatastore) checkNotClosed() error { return nil } +// readStoredSchemaInternal is an internal method for readers/transactions to access the stored schema. +// This should NOT be called directly - use readers/transactions instead. +func (mdb *memdbDatastore) readStoredSchemaInternal() (*corev1.StoredSchema, error) { + mdb.RLock() + defer mdb.RUnlock() + + if err := mdb.checkNotClosed(); err != nil { + return nil, err + } + + if mdb.storedSchema == nil { + return nil, datastore.ErrSchemaNotFound + } + + // Return a copy to prevent external mutations + return mdb.storedSchema.CloneVT(), nil +} + +// writeStoredSchemaNoLock writes the stored schema using the provided transaction. 
+// This is called from within an existing transaction, so it doesn't acquire locks or commit. +func (mdb *memdbDatastore) writeStoredSchemaNoLock(tx *memdb.Txn, schema *corev1.StoredSchema) error { + // Store a copy to prevent external mutations + mdb.storedSchema = schema.CloneVT() + + // Write the schema hash to the schema revision table for fast lookups + if err := mdb.writeSchemaHashNoLock(tx, schema); err != nil { + return fmt.Errorf("failed to write schema hash: %w", err) + } + + return nil +} + +// writeSchemaHashNoLock writes the schema hash to the in-memory schema revision table using the provided transaction. +func (mdb *memdbDatastore) writeSchemaHashNoLock(tx *memdb.Txn, schema *corev1.StoredSchema) error { + v1 := schema.GetV1() + if v1 == nil { + return fmt.Errorf("unsupported schema version: %d", schema.Version) + } + + // Delete existing hash (if any) + if existing, err := tx.First(tableSchemaRevision, indexID, "current"); err == nil && existing != nil { + if err := tx.Delete(tableSchemaRevision, existing); err != nil { + return fmt.Errorf("failed to delete old hash: %w", err) + } + } + + // Insert new hash + revisionData := &schemaRevisionData{ + name: "current", + hash: []byte(v1.SchemaHash), + } + + if err := tx.Insert(tableSchemaRevision, revisionData); err != nil { + return fmt.Errorf("failed to insert hash: %w", err) + } + + // Note: Don't commit here - the caller will commit the transaction + return nil +} + +// writeLegacySchemaHashFromDefinitionsInternal writes the schema hash computed from the given definitions +// writeLegacySchemaHashFromDefinitionsNoLock writes the schema hash using the provided transaction. +// This is called from within an existing transaction, so it doesn't acquire locks. 
+func (mdb *memdbDatastore) writeLegacySchemaHashFromDefinitionsNoLock(ctx context.Context, tx *memdb.Txn, namespaces []datastore.RevisionedNamespace, caveats []datastore.RevisionedCaveat) error { + // Build schema definitions list + definitions := make([]compiler.SchemaDefinition, 0, len(namespaces)+len(caveats)) + for _, ns := range namespaces { + definitions = append(definitions, ns.Definition) + } + for _, caveat := range caveats { + definitions = append(definitions, caveat.Definition) + } + + // Sort definitions by name for consistent ordering + sort.Slice(definitions, func(i, j int) bool { + return definitions[i].GetName() < definitions[j].GetName() + }) + + // Generate schema text from definitions + schemaText, _, err := generator.GenerateSchema(definitions) + if err != nil { + return fmt.Errorf("failed to generate schema: %w", err) + } + + // Compute schema hash (SHA256) + hash := sha256.Sum256([]byte(schemaText)) + schemaHash := hex.EncodeToString(hash[:]) + + // Delete existing hash (if any) + if existing, err := tx.First(tableSchemaRevision, indexID, "current"); err == nil && existing != nil { + if err := tx.Delete(tableSchemaRevision, existing); err != nil { + return fmt.Errorf("failed to delete old hash: %w", err) + } + } + + // Insert new hash + revisionData := &schemaRevisionData{ + name: "current", + hash: []byte(schemaHash), + } + + if err := tx.Insert(tableSchemaRevision, revisionData); err != nil { + return fmt.Errorf("failed to insert hash: %w", err) + } + + // Note: Don't commit here - the caller will commit the transaction + return nil +} + var _ datastore.Datastore = &memdbDatastore{} diff --git a/internal/datastore/memdb/memdb_test.go b/internal/datastore/memdb/memdb_test.go index f21c440c5..67fbf3e7b 100644 --- a/internal/datastore/memdb/memdb_test.go +++ b/internal/datastore/memdb/memdb_test.go @@ -121,7 +121,7 @@ func TestAnythingAfterCloseDoesNotPanic(t *testing.T) { ds, err := NewMemdbDatastore(0, 1*time.Hour, 1*time.Hour) 
require.NoError(err) - lowestRevision, err := ds.HeadRevision(t.Context()) + lowestRevision, _, err := ds.HeadRevision(t.Context()) require.NoError(err) err = ds.Close() @@ -142,10 +142,10 @@ func TestAnythingAfterCloseDoesNotPanic(t *testing.T) { err = ds.CheckRevision(t.Context(), lowestRevision) require.ErrorIs(err, ErrMemDBIsClosed) - _, err = ds.OptimizedRevision(t.Context()) + _, _, err = ds.OptimizedRevision(t.Context()) require.ErrorIs(err, ErrMemDBIsClosed) - reader := ds.SnapshotReader(datastore.NoRevision) + reader := ds.SnapshotReader(datastore.NoRevision, datastore.NoSchemaHashForTesting) _, err = reader.CountRelationships(t.Context(), "blah") require.ErrorIs(err, ErrMemDBIsClosed) } @@ -168,7 +168,7 @@ func BenchmarkQueryRelationships(b *testing.B) { }) require.NoError(err) - reader := ds.SnapshotReader(rev) + reader := ds.SnapshotReader(rev, datastore.NoSchemaHashForTesting) b.ResetTimer() for i := 0; i < b.N; i++ { diff --git a/internal/datastore/memdb/readonly.go b/internal/datastore/memdb/readonly.go index c6a3a951a..e994f3b81 100644 --- a/internal/datastore/memdb/readonly.go +++ b/internal/datastore/memdb/readonly.go @@ -23,9 +23,11 @@ type txFactory func() (*memdb.Txn, error) type memdbReader struct { TryLocker - txSource txFactory - initErr error - now time.Time + txSource txFactory + initErr error + now time.Time + schemaHash string + datastore *memdbDatastore } func (r *memdbReader) CountRelationships(ctx context.Context, name string) (int, error) { @@ -593,7 +595,17 @@ func (r *memdbReader) SchemaReader() (datastore.SchemaReader, error) { return schemautil.NewLegacySchemaReaderAdapter(r), nil } -var _ datastore.Reader = &memdbReader{} +// ReadStoredSchema implements datastore.SingleStoreSchemaReader +func (r *memdbReader) ReadStoredSchema(ctx context.Context) (*core.StoredSchema, error) { + return r.datastore.readStoredSchemaInternal() +} + +var ( + _ datastore.Reader = &memdbReader{} + _ datastore.LegacySchemaReader = &memdbReader{} + _ 
datastore.SingleStoreSchemaReader = &memdbReader{} + _ datastore.DualSchemaReader = &memdbReader{} +) type TryLocker interface { TryLock() bool diff --git a/internal/datastore/memdb/readwrite.go b/internal/datastore/memdb/readwrite.go index 5f35674f4..0c991b840 100644 --- a/internal/datastore/memdb/readwrite.go +++ b/internal/datastore/memdb/readwrite.go @@ -281,7 +281,7 @@ func (rwt *memdbReadWriteTx) StoreCounterValue(ctx context.Context, name string, return tx.Insert(tableCounters, counter) } -func (rwt *memdbReadWriteTx) LegacyWriteNamespaces(_ context.Context, newConfigs ...*core.NamespaceDefinition) error { +func (rwt *memdbReadWriteTx) LegacyWriteNamespaces(ctx context.Context, newConfigs ...*core.NamespaceDefinition) error { rwt.mustLock() defer rwt.Unlock() @@ -307,7 +307,7 @@ func (rwt *memdbReadWriteTx) LegacyWriteNamespaces(_ context.Context, newConfigs return nil } -func (rwt *memdbReadWriteTx) LegacyDeleteNamespaces(_ context.Context, nsNames []string, delOption datastore.DeleteNamespacesRelationshipsOption) error { +func (rwt *memdbReadWriteTx) LegacyDeleteNamespaces(ctx context.Context, nsNames []string, delOption datastore.DeleteNamespacesRelationshipsOption) error { if len(nsNames) == 0 { return nil } @@ -348,7 +348,30 @@ func (rwt *memdbReadWriteTx) LegacyDeleteNamespaces(_ context.Context, nsNames [ } func (rwt *memdbReadWriteTx) SchemaWriter() (datastore.SchemaWriter, error) { - return schemautil.NewLegacySchemaWriterAdapter(rwt, rwt), nil + // MemDB supports both legacy and unified schema storage. + // Use write-to-both mode to ensure both are updated. 
+ return schemautil.NewSchemaWriter(rwt, rwt, options.SchemaModeReadNewWriteBoth), nil +} + +// WriteStoredSchema implements datastore.SingleStoreSchemaWriter +func (rwt *memdbReadWriteTx) WriteStoredSchema(ctx context.Context, schema *core.StoredSchema) error { + // Called from within a transaction - use the transaction's db handle directly + tx, err := rwt.txSource() + if err != nil { + return err + } + return rwt.datastore.writeStoredSchemaNoLock(tx, schema) +} + +// WriteLegacySchemaHashFromDefinitions implements datastore.LegacySchemaHashWriter +func (rwt *memdbReadWriteTx) WriteLegacySchemaHashFromDefinitions(ctx context.Context, namespaces []datastore.RevisionedNamespace, caveats []datastore.RevisionedCaveat) error { + // Called from within a transaction - use the transaction's db handle directly + // without trying to acquire additional locks + tx, err := rwt.txSource() + if err != nil { + return err + } + return rwt.datastore.writeLegacySchemaHashFromDefinitionsNoLock(ctx, tx, namespaces, caveats) } func (rwt *memdbReadWriteTx) BulkLoad(ctx context.Context, iter datastore.BulkWriteRelationshipSource) (uint64, error) { @@ -404,4 +427,9 @@ func relationshipFilterFilterFunc(filter *v1.RelationshipFilter) func(any) bool } } -var _ datastore.ReadWriteTransaction = &memdbReadWriteTx{} +var ( + _ datastore.ReadWriteTransaction = &memdbReadWriteTx{} + _ datastore.LegacySchemaWriter = &memdbReadWriteTx{} + _ datastore.SingleStoreSchemaWriter = &memdbReadWriteTx{} + _ datastore.DualSchemaWriter = &memdbReadWriteTx{} +) diff --git a/internal/datastore/memdb/revisions.go b/internal/datastore/memdb/revisions.go index be797714a..4adc717cb 100644 --- a/internal/datastore/memdb/revisions.go +++ b/internal/datastore/memdb/revisions.go @@ -38,21 +38,23 @@ func (mdb *memdbDatastore) newRevisionID() revisions.TimestampRevision { return created } -func (mdb *memdbDatastore) HeadRevision(_ context.Context) (datastore.Revision, error) { +func (mdb *memdbDatastore) HeadRevision(_ 
context.Context) (datastore.Revision, datastore.SchemaHash, error) { mdb.RLock() defer mdb.RUnlock() if err := mdb.checkNotClosed(); err != nil { - return nil, err + return nil, noHashSupported, err } - return mdb.headRevisionNoLock(), nil + rev, hash := mdb.headRevisionWithHashNoLock() + return rev, hash, nil } func (mdb *memdbDatastore) SquashRevisionsForTesting() { mdb.revisions = []snapshot{ { - revision: nowRevision(), - db: mdb.db, + revision: nowRevision(), + schemaHash: noHashSupported, + db: mdb.db, }, } } @@ -61,15 +63,21 @@ func (mdb *memdbDatastore) headRevisionNoLock() revisions.TimestampRevision { return mdb.revisions[len(mdb.revisions)-1].revision } -func (mdb *memdbDatastore) OptimizedRevision(_ context.Context) (datastore.Revision, error) { +func (mdb *memdbDatastore) headRevisionWithHashNoLock() (revisions.TimestampRevision, datastore.SchemaHash) { + snap := mdb.revisions[len(mdb.revisions)-1] + return snap.revision, snap.schemaHash +} + +func (mdb *memdbDatastore) OptimizedRevision(_ context.Context) (datastore.Revision, datastore.SchemaHash, error) { mdb.RLock() defer mdb.RUnlock() if err := mdb.checkNotClosed(); err != nil { - return nil, err + return nil, noHashSupported, err } now := nowRevision() - return revisions.NewForTimestamp(now.TimestampNanoSec() - now.TimestampNanoSec()%mdb.quantizationPeriod), nil + // Note: Memory datastore doesn't cache schemas, so we don't return a hash for OptimizedRevision + return revisions.NewForTimestamp(now.TimestampNanoSec() - now.TimestampNanoSec()%mdb.quantizationPeriod), noHashSupported, nil } func (mdb *memdbDatastore) CheckRevision(_ context.Context, dr datastore.Revision) error { diff --git a/internal/datastore/memdb/revisions_test.go b/internal/datastore/memdb/revisions_test.go index cd12239aa..a6cb9a5ab 100644 --- a/internal/datastore/memdb/revisions_test.go +++ b/internal/datastore/memdb/revisions_test.go @@ -11,7 +11,7 @@ func TestHeadRevision(t *testing.T) { ds, err := NewMemdbDatastore(0, 0, 
500*time.Millisecond) require.NoError(t, err) - older, err := ds.HeadRevision(t.Context()) + older, _, err := ds.HeadRevision(t.Context()) require.NoError(t, err) err = ds.CheckRevision(t.Context(), older) require.NoError(t, err) @@ -19,7 +19,7 @@ func TestHeadRevision(t *testing.T) { time.Sleep(550 * time.Millisecond) // GC window elapsed, last revision is returned even if outside GC window - newer, err := ds.HeadRevision(t.Context()) + newer, _, err := ds.HeadRevision(t.Context()) require.NoError(t, err) err = ds.CheckRevision(t.Context(), newer) require.NoError(t, err) diff --git a/internal/datastore/memdb/schema.go b/internal/datastore/memdb/schema.go index 7905d4854..32e9ec668 100644 --- a/internal/datastore/memdb/schema.go +++ b/internal/datastore/memdb/schema.go @@ -26,6 +26,9 @@ const ( tableChangelog = "changelog" indexRevision = "id" + + tableSchema = "schema" + tableSchemaRevision = "schemarevision" ) type namespace struct { @@ -45,6 +48,16 @@ type counter struct { updated datastore.Revision } +type schemaData struct { + name string + data []byte +} + +type schemaRevisionData struct { + name string + hash []byte +} + type relationship struct { namespace string resourceID string @@ -228,5 +241,25 @@ var schema = &memdb.DBSchema{ }, }, }, + tableSchema: { + Name: tableSchema, + Indexes: map[string]*memdb.IndexSchema{ + indexID: { + Name: indexID, + Unique: true, + Indexer: &memdb.StringFieldIndex{Field: "name"}, + }, + }, + }, + tableSchemaRevision: { + Name: tableSchemaRevision, + Indexes: map[string]*memdb.IndexSchema{ + indexID: { + Name: indexID, + Unique: true, + Indexer: &memdb.StringFieldIndex{Field: "name"}, + }, + }, + }, }, } diff --git a/internal/datastore/memdb/schema_test.go b/internal/datastore/memdb/schema_test.go new file mode 100644 index 000000000..d7423ac7a --- /dev/null +++ b/internal/datastore/memdb/schema_test.go @@ -0,0 +1,213 @@ +package memdb + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func 
TestSchemaTableReadWrite(t *testing.T) { + ds, err := NewMemdbDatastore(0, 0, 0) + require.NoError(t, err) + defer ds.Close() + + memdbDS := ds.(*memdbDatastore) + + // Lock for direct access + memdbDS.Lock() + defer memdbDS.Unlock() + + // Get a write transaction + tx := memdbDS.db.Txn(true) + defer tx.Abort() + + // Test writing schema data + testSchema := &schemaData{ + name: "test-schema", + data: []byte("test schema definition data"), + } + + err = tx.Insert(tableSchema, testSchema) + require.NoError(t, err, "failed to insert schema") + + // Test reading schema data + found, err := tx.First(tableSchema, indexID, "test-schema") + require.NoError(t, err) + require.NotNil(t, found, "schema not found") + + readSchema := found.(*schemaData) + require.Equal(t, testSchema.name, readSchema.name) + require.Equal(t, testSchema.data, readSchema.data) + + // Commit the transaction + tx.Commit() +} + +func TestSchemaRevisionTableReadWrite(t *testing.T) { + ds, err := NewMemdbDatastore(0, 0, 0) + require.NoError(t, err) + defer ds.Close() + + memdbDS := ds.(*memdbDatastore) + + // Lock for direct access + memdbDS.Lock() + defer memdbDS.Unlock() + + // Get a write transaction + tx := memdbDS.db.Txn(true) + defer tx.Abort() + + // Test writing schema revision data + revisionData := &schemaRevisionData{ + name: "current", + hash: []byte("schema-hash-12345"), + } + + err = tx.Insert(tableSchemaRevision, revisionData) + require.NoError(t, err, "failed to insert schema revision") + + // Test reading schema revision data + found, err := tx.First(tableSchemaRevision, indexID, "current") + require.NoError(t, err) + require.NotNil(t, found, "schema revision not found") + + readRevision := found.(*schemaRevisionData) + require.Equal(t, revisionData.name, readRevision.name) + require.Equal(t, revisionData.hash, readRevision.hash) + + // Commit the transaction + tx.Commit() +} + +func TestSchemaTableUpdate(t *testing.T) { + ds, err := NewMemdbDatastore(0, 0, 0) + require.NoError(t, err) 
+ defer ds.Close() + + memdbDS := ds.(*memdbDatastore) + + // Lock for direct access + memdbDS.Lock() + defer memdbDS.Unlock() + + // Get a write transaction + tx := memdbDS.db.Txn(true) + defer tx.Abort() + + // Insert initial schema + initialSchema := &schemaData{ + name: "test-schema-update", + data: []byte("initial data"), + } + + err = tx.Insert(tableSchema, initialSchema) + require.NoError(t, err, "failed to insert initial schema") + + // Update the schema by deleting and inserting + err = tx.Delete(tableSchema, initialSchema) + require.NoError(t, err, "failed to delete old schema") + + updatedSchema := &schemaData{ + name: "test-schema-update", + data: []byte("updated data that is much longer"), + } + + err = tx.Insert(tableSchema, updatedSchema) + require.NoError(t, err, "failed to insert updated schema") + + // Read back and verify + found, err := tx.First(tableSchema, indexID, "test-schema-update") + require.NoError(t, err) + require.NotNil(t, found, "updated schema not found") + + readSchema := found.(*schemaData) + require.Equal(t, updatedSchema.data, readSchema.data) + + // Commit the transaction + tx.Commit() +} + +func TestSchemaTableDelete(t *testing.T) { + ds, err := NewMemdbDatastore(0, 0, 0) + require.NoError(t, err) + defer ds.Close() + + memdbDS := ds.(*memdbDatastore) + + // Lock for direct access + memdbDS.Lock() + defer memdbDS.Unlock() + + // Get a write transaction + tx := memdbDS.db.Txn(true) + defer tx.Abort() + + // Insert schema + testSchema := &schemaData{ + name: "test-schema-delete", + data: []byte("data to be deleted"), + } + + err = tx.Insert(tableSchema, testSchema) + require.NoError(t, err, "failed to insert schema") + + // Verify it exists + found, err := tx.First(tableSchema, indexID, "test-schema-delete") + require.NoError(t, err) + require.NotNil(t, found, "schema not found before delete") + + // Delete the schema + err = tx.Delete(tableSchema, testSchema) + require.NoError(t, err, "failed to delete schema") + + // Verify 
it's gone + found, err = tx.First(tableSchema, indexID, "test-schema-delete") + require.NoError(t, err) + require.Nil(t, found, "schema should be deleted") + + // Commit the transaction + tx.Commit() +} + +func TestSchemaTableMultipleSchemas(t *testing.T) { + ds, err := NewMemdbDatastore(0, 0, 0) + require.NoError(t, err) + defer ds.Close() + + memdbDS := ds.(*memdbDatastore) + + // Lock for direct access + memdbDS.Lock() + defer memdbDS.Unlock() + + // Get a write transaction + tx := memdbDS.db.Txn(true) + defer tx.Abort() + + // Insert multiple schemas + schemas := []*schemaData{ + {name: "schema1", data: []byte("data1")}, + {name: "schema2", data: []byte("data2")}, + {name: "schema3", data: []byte("data3")}, + } + + for _, s := range schemas { + err = tx.Insert(tableSchema, s) + require.NoError(t, err, "failed to insert schema %s", s.name) + } + + // Read them all back + for _, s := range schemas { + found, err := tx.First(tableSchema, indexID, s.name) + require.NoError(t, err) + require.NotNil(t, found, "schema %s not found", s.name) + + readSchema := found.(*schemaData) + require.Equal(t, s.name, readSchema.name) + require.Equal(t, s.data, readSchema.data) + } + + // Commit the transaction + tx.Commit() +} diff --git a/internal/datastore/memdb/stats.go b/internal/datastore/memdb/stats.go index 3a7263877..bd55757ff 100644 --- a/internal/datastore/memdb/stats.go +++ b/internal/datastore/memdb/stats.go @@ -8,7 +8,7 @@ import ( ) func (mdb *memdbDatastore) Statistics(ctx context.Context) (datastore.Stats, error) { - head, err := mdb.HeadRevision(ctx) + head, _, err := mdb.HeadRevision(ctx) if err != nil { return datastore.Stats{}, fmt.Errorf("unable to compute head revision: %w", err) } @@ -18,7 +18,7 @@ func (mdb *memdbDatastore) Statistics(ctx context.Context) (datastore.Stats, err return datastore.Stats{}, fmt.Errorf("unable to count relationships: %w", err) } - objTypes, err := mdb.SnapshotReader(head).LegacyListAllNamespaces(ctx) + objTypes, err := 
mdb.SnapshotReader(head, noHashSupported).LegacyListAllNamespaces(ctx) if err != nil { return datastore.Stats{}, fmt.Errorf("unable to list object types: %w", err) } diff --git a/internal/datastore/mysql/caveat.go b/internal/datastore/mysql/caveat.go index bb48ad509..f44c9ab01 100644 --- a/internal/datastore/mysql/caveat.go +++ b/internal/datastore/mysql/caveat.go @@ -148,7 +148,11 @@ func (rwt *mysqlReadWriteTXN) LegacyWriteCaveats(ctx context.Context, caveats [] } func (rwt *mysqlReadWriteTXN) LegacyDeleteCaveats(ctx context.Context, names []string) error { - return rwt.deleteCaveatsFromNames(ctx, names) + if err := rwt.deleteCaveatsFromNames(ctx, names); err != nil { + return err + } + + return nil } func (rwt *mysqlReadWriteTXN) deleteCaveatsFromNames(ctx context.Context, names []string) error { diff --git a/internal/datastore/mysql/datastore.go b/internal/datastore/mysql/datastore.go index 977985a05..24afbc21a 100644 --- a/internal/datastore/mysql/datastore.go +++ b/internal/datastore/mysql/datastore.go @@ -27,7 +27,7 @@ import ( log "github.com/authzed/spicedb/internal/logging" "github.com/authzed/spicedb/internal/sharederrors" "github.com/authzed/spicedb/pkg/datastore" - "github.com/authzed/spicedb/pkg/datastore/options" + dsoptions "github.com/authzed/spicedb/pkg/datastore/options" ) const ( @@ -74,6 +74,12 @@ const ( errMysqlDuplicateEntry = 1062 ) +type contextKey string + +const ( + ctxKeyTransactionID contextKey = "mysql_transaction_id" +) + var ( tracer = otel.Tracer("spicedb/internal/datastore/mysql") @@ -226,6 +232,16 @@ func newMySQLDatastore(ctx context.Context, uri string, replicaIndex int, option followerReadDelayNanos, ) + revisionQueryWithHash := fmt.Sprintf( + querySelectRevisionWithHash, + colID, + driver.RelationTupleTransaction(), + colTimestamp, + quantizationPeriodNanos, + followerReadDelayNanos, + driver.SchemaRevision(), + ) + validTransactionQuery := fmt.Sprintf( queryValidTransaction, colID, @@ -254,39 +270,48 @@ func 
newMySQLDatastore(ctx context.Context, uri string, replicaIndex int, option ) store := &mysqlDatastore{ - MigrationValidator: common.NewMigrationValidator(headMigration, config.allowedMigrations), - db: db, - driver: driver, - collectors: collectors, - url: uri, - revisionQuantization: config.revisionQuantization, - gcWindow: config.gcWindow, - gcInterval: config.gcInterval, - gcTimeout: config.gcMaxOperationTime, - gcCtx: gcCtx, - cancelGc: cancelGc, - watchEnabled: !config.watchDisabled, - watchBufferLength: config.watchBufferLength, - watchChangeBufferMaximumSize: config.watchChangeBufferMaximumSize, - watchBufferWriteTimeout: config.watchBufferWriteTimeout, - optimizedRevisionQuery: revisionQuery, - validTransactionQuery: validTransactionQuery, - createTxn: createTxn, - createBaseTxn: createBaseTxn, - QueryBuilder: queryBuilder, - readTxOptions: &sql.TxOptions{Isolation: sql.LevelSerializable, ReadOnly: true}, - maxRetries: config.maxRetries, - analyzeBeforeStats: config.analyzeBeforeStats, - schema: *schema, + MigrationValidator: common.NewMigrationValidator(headMigration, config.allowedMigrations), + db: db, + driver: driver, + collectors: collectors, + url: uri, + revisionQuantization: config.revisionQuantization, + gcWindow: config.gcWindow, + gcInterval: config.gcInterval, + gcTimeout: config.gcMaxOperationTime, + gcCtx: gcCtx, + cancelGc: cancelGc, + watchEnabled: !config.watchDisabled, + watchBufferLength: config.watchBufferLength, + watchChangeBufferMaximumSize: config.watchChangeBufferMaximumSize, + watchBufferWriteTimeout: config.watchBufferWriteTimeout, + optimizedRevisionQuery: revisionQuery, + optimizedRevisionQueryWithHash: revisionQueryWithHash, + validTransactionQuery: validTransactionQuery, + createTxn: createTxn, + createBaseTxn: createBaseTxn, + QueryBuilder: queryBuilder, + readTxOptions: &sql.TxOptions{Isolation: sql.LevelSerializable, ReadOnly: true}, + maxRetries: config.maxRetries, + analyzeBeforeStats: config.analyzeBeforeStats, + 
schema: *schema, + schemaMode: config.schemaMode, CachedOptimizedRevisions: revisions.NewCachedOptimizedRevisions( maxRevisionStaleness, ), CommonDecoder: revisions.CommonDecoder{ Kind: revisions.TransactionID, }, + schemaTableName: driver.Schema(), filterMaximumIDCount: config.filterMaximumIDCount, } + // Initialize schema reader/writer + store.schemaReaderWriter, err = common.NewSQLSchemaReaderWriter[uint64, revisions.TransactionIDRevision](BaseSchemaChunkerConfig.WithTableName(driver.Schema()), config.schemaCacheOptions) + if err != nil { + return nil, err + } + store.SetOptimizedRevisionFunc(store.optimizedRevisionFunc) ctx, cancel := context.WithTimeout(context.Background(), seedingTimeout) @@ -296,6 +321,8 @@ func newMySQLDatastore(ctx context.Context, uri string, replicaIndex int, option return nil, err } + // Hash-based cache doesn't need watchers or warming + // Start a goroutine for garbage collection. if isPrimary { if store.gcInterval > 0*time.Minute && config.gcEnabled { @@ -314,6 +341,13 @@ func newMySQLDatastore(ctx context.Context, uri string, replicaIndex int, option } } + // Warm the schema cache on startup + warmCtx, cancelWarm := context.WithTimeout(context.Background(), seedingTimeout) + defer cancelWarm() + if err := warmSchemaCache(warmCtx, store); err != nil { + log.Warn().Err(err).Msg("failed to warm schema cache on startup") + } + return store, nil } @@ -321,7 +355,7 @@ func (mds *mysqlDatastore) MetricsID() (string, error) { return common.MetricsIDFromURL(mds.url) } -func (mds *mysqlDatastore) SnapshotReader(rev datastore.Revision) datastore.Reader { +func (mds *mysqlDatastore) SnapshotReader(rev datastore.Revision, hash datastore.SchemaHash) datastore.Reader { createTxFunc := func(ctx context.Context) (*sql.Tx, txCleanupFunc, error) { tx, err := mds.db.BeginTx(ctx, mds.readTxOptions) if err != nil { @@ -336,12 +370,17 @@ func (mds *mysqlDatastore) SnapshotReader(rev datastore.Revision) datastore.Read } return &mysqlReader{ - 
mds.QueryBuilder, - createTxFunc, - executor, - buildLivingObjectFilterForRevision(rev), - mds.filterMaximumIDCount, - mds.schema, + QueryBuilder: mds.QueryBuilder, + txSource: createTxFunc, + executor: executor, + aliveFilter: buildLivingObjectFilterForRevision(rev), + filterMaximumIDCount: mds.filterMaximumIDCount, + schema: mds.schema, + schemaMode: mds.schemaMode, + snapshotRevision: rev, + schemaHash: string(hash), + schemaTableName: mds.driver.Schema(), + schemaReaderWriter: mds.schemaReaderWriter, } } @@ -352,9 +391,9 @@ func noCleanup() error { return nil } func (mds *mysqlDatastore) ReadWriteTx( ctx context.Context, fn datastore.TxUserFunc, - opts ...options.RWTOptionsOption, + opts ...dsoptions.RWTOptionsOption, ) (datastore.Revision, error) { - config := options.NewRWTOptionsWithOptions(opts...) + config := dsoptions.NewRWTOptionsWithOptions(opts...) var err error for i := uint8(0); i <= mds.maxRetries; i++ { @@ -379,20 +418,29 @@ func (mds *mysqlDatastore) ReadWriteTx( } rwt := &mysqlReadWriteTXN{ - &mysqlReader{ - mds.QueryBuilder, - longLivedTx, - executor, - currentlyLivingObjects, - mds.filterMaximumIDCount, - mds.schema, + mysqlReader: &mysqlReader{ + QueryBuilder: mds.QueryBuilder, + txSource: longLivedTx, + executor: executor, + aliveFilter: currentlyLivingObjects, + filterMaximumIDCount: mds.filterMaximumIDCount, + schema: mds.schema, + schemaMode: mds.schemaMode, + snapshotRevision: datastore.NoRevision, // snapshotRevision (not yet known in RWT) + schemaHash: string(datastore.NoSchemaHashInTransaction), // Bypass cache for transaction reads + schemaTableName: mds.driver.Schema(), + schemaReaderWriter: mds.schemaReaderWriter, }, - mds.driver.RelationTuple(), - tx, - newTxnID, + tupleTableName: mds.driver.RelationTuple(), + schemaTableName: mds.driver.Schema(), + schemaRevisionTableName: mds.driver.SchemaRevision(), + tx: tx, + newTxnID: newTxnID, } - return fn(ctx, rwt) + // Add transaction ID to context for schema operations + ctxWithTxn := 
context.WithValue(ctx, ctxKeyTransactionID, newTxnID) + return fn(ctxWithTxn, rwt) }); err != nil { if !config.DisableRetries && isErrorRetryable(err) { continue @@ -495,9 +543,11 @@ type mysqlDatastore struct { maxRetries uint8 filterMaximumIDCount uint16 schema common.SchemaInformation + schemaMode dsoptions.SchemaMode - optimizedRevisionQuery string - validTransactionQuery string + optimizedRevisionQuery string + optimizedRevisionQueryWithHash string + validTransactionQuery string gcGroup *errgroup.Group gcCtx context.Context @@ -507,6 +557,10 @@ type mysqlDatastore struct { createTxn sq.InsertBuilder createBaseTxn string + // SQLSchemaReaderWriter for schema operations + schemaReaderWriter *common.SQLSchemaReaderWriter[uint64, revisions.TransactionIDRevision] + schemaTableName string + uniqueID atomic.Pointer[string] *QueryBuilder @@ -528,6 +582,51 @@ func (mds *mysqlDatastore) Close() error { return mds.db.Close() } +// SchemaHashReaderForTesting returns a test-only interface for reading the schema hash directly from schema_revision table. +func (mds *mysqlDatastore) SchemaHashReaderForTesting() interface { + ReadSchemaHash(ctx context.Context) (string, error) +} { + return &mysqlSchemaHashReaderForTesting{ + db: mds.db, + schemaRevisionTableName: mds.driver.SchemaRevision(), + } +} + +// SchemaModeForTesting returns the current schema mode for testing purposes. +func (mds *mysqlDatastore) SchemaModeForTesting() (dsoptions.SchemaMode, error) { + return mds.schemaMode, nil +} + +type mysqlSchemaHashReaderForTesting struct { + db *sql.DB + schemaRevisionTableName string +} + +func (r *mysqlSchemaHashReaderForTesting) ReadSchemaHash(ctx context.Context) (string, error) { + query, args, err := sb.Select("hash"). + From(r.schemaRevisionTableName). + Where(sq.Eq{ + "name": "current", + "deleted_transaction": liveDeletedTxnID, + }). 
+ ToSql() + if err != nil { + return "", fmt.Errorf("failed to build query: %w", err) + } + + var hashBytes []byte + + err = r.db.QueryRowContext(ctx, query, args...).Scan(&hashBytes) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return "", datastore.ErrSchemaNotFound + } + return "", fmt.Errorf("failed to query schema hash: %w", err) + } + + return string(hashBytes), nil +} + // ReadyState returns whether the datastore is ready to accept data. Datastores that require // database schema creation will return false until the migrations have been run to create // the necessary tables. @@ -592,7 +691,7 @@ func (mds *mysqlDatastore) OfflineFeatures() (*datastore.Features, error) { // isSeeded determines if the backing database has been seeded func (mds *mysqlDatastore) isSeeded(ctx context.Context) (bool, error) { - headRevision, err := mds.HeadRevision(ctx) + headRevision, _, err := mds.HeadRevision(ctx) if err != nil { return false, err } @@ -690,6 +789,39 @@ func (debugLogger) Print(v ...any) { log.Logger.Debug().CallerSkipFrame(1).Str("datastore", "mysql").Msg(fmt.Sprint(v...)) } +// warmSchemaCache attempts to warm the schema cache by loading the current schema. +// This is called during datastore initialization to avoid cold-start latency on first requests. 
+func warmSchemaCache(ctx context.Context, ds *mysqlDatastore) error { + // Get the current revision and schema hash + rev, schemaHash, err := ds.HeadRevision(ctx) + if err != nil { + return fmt.Errorf("failed to get head revision: %w", err) + } + + // If there's no schema hash, there's no schema to warm + if schemaHash == "" { + log.Ctx(ctx).Debug().Msg("no schema hash found, skipping cache warming") + return nil + } + + // Create a simple executor for schema reading (no transaction, no revision filtering needed for warmup) + executor := newMySQLChunkedBytesExecutor(ds.db) + + // Load the schema to populate the cache + _, err = ds.schemaReaderWriter.ReadSchema(ctx, executor, rev, schemaHash) + if err != nil { + if errors.Is(err, datastore.ErrSchemaNotFound) { + // Schema not found is not an error during warming - just means no schema yet + log.Ctx(ctx).Debug().Msg("no schema found, skipping cache warming") + return nil + } + return fmt.Errorf("failed to read schema: %w", err) + } + + log.Ctx(ctx).Info().Str("schema_hash", string(schemaHash)).Msg("schema cache warmed successfully") + return nil +} + func registerAndReturnPrometheusCollectors(replicaIndex int, isPrimary bool, connector driver.Connector, enablePrometheusStats bool) (*sql.DB, []prometheus.Collector, error) { if !enablePrometheusStats { return sql.OpenDB(connector), nil, nil diff --git a/internal/datastore/mysql/datastore_test.go b/internal/datastore/mysql/datastore_test.go index 16f8e6574..5b8375aba 100644 --- a/internal/datastore/mysql/datastore_test.go +++ b/internal/datastore/mysql/datastore_test.go @@ -22,6 +22,7 @@ import ( "github.com/authzed/spicedb/internal/testfixtures" testdatastore "github.com/authzed/spicedb/internal/testserver/datastore" "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/datastore/options" "github.com/authzed/spicedb/pkg/datastore/test" "github.com/authzed/spicedb/pkg/migrate" "github.com/authzed/spicedb/pkg/namespace" @@ -145,6 +146,8 @@ func 
additionalMySQLTests(t *testing.T, b testdatastore.RunningEngineForTest) { t.Run("ChunkedGarbageCollection", createDatastoreTest(b, ChunkedGarbageCollectionTest, defaultOptions...)) t.Run("EmptyGarbageCollection", createDatastoreTest(b, EmptyGarbageCollectionTest, defaultOptions...)) t.Run("NoRelationshipsGarbageCollection", createDatastoreTest(b, NoRelationshipsGarbageCollectionTest, defaultOptions...)) + schemaGCOptions := append([]Option{WithSchemaMode(options.SchemaModeReadNewWriteNew)}, defaultOptions...) + t.Run("SchemaGarbageCollection", createDatastoreTest(b, SchemaGarbageCollectionTest, schemaGCOptions...)) t.Run("QuantizedRevisions", func(t *testing.T) { QuantizedRevisionTest(t, b) }) @@ -370,6 +373,129 @@ func GarbageCollectionTest(t *testing.T, ds datastore.Datastore) { tRequire.RelationshipExists(ctx, crel3, relLastWriteAt) } +func SchemaGarbageCollectionTest(t *testing.T, ds datastore.Datastore) { + req := require.New(t) + + ctx := context.Background() + r, err := ds.ReadyState(ctx) + req.NoError(err) + req.True(r.IsReady) + + mds := ds.(*mysqlDatastore) + + mgg, err := mds.BuildGarbageCollector(ctx) + req.NoError(err) + defer mgg.Close() + + // Helper to count rows in schema tables + countSchemaRows := func() (schemaRows, schemaRevisionRows int64) { + //nolint:gosec // Table name comes from driver method, not user input + sql := fmt.Sprintf("SELECT COUNT(*) FROM %s", mds.driver.Schema()) + err := mds.db.QueryRowContext(ctx, sql).Scan(&schemaRows) + req.NoError(err) + + sql = fmt.Sprintf("SELECT COUNT(*) FROM %s", mds.driver.SchemaRevision()) + err = mds.db.QueryRowContext(ctx, sql).Scan(&schemaRevisionRows) + req.NoError(err) + + return schemaRows, schemaRevisionRows + } + + // Write schema version 1 + schemaText1 := `definition resource { + relation reader: user + } + definition user {}` + rev1, err := ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error { + schemaWriter, err := rwt.SchemaWriter() + if err != nil { 
+ return err + } + return schemaWriter.WriteSchema(ctx, []datastore.SchemaDefinition{}, schemaText1, nil) + }) + req.NoError(err) + + schemaRows, schemaRevisionRows := countSchemaRows() + req.Positive(schemaRows, "Should have schema rows after first write") + req.Positive(schemaRevisionRows, "Should have schema_revision rows after first write") + + // Write schema version 2 + schemaText2 := `definition resource { + relation reader: user + relation writer: user + } + definition user {}` + rev2, err := ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error { + schemaWriter, err := rwt.SchemaWriter() + if err != nil { + return err + } + return schemaWriter.WriteSchema(ctx, []datastore.SchemaDefinition{}, schemaText2, nil) + }) + req.NoError(err) + + schemaRows2, schemaRevisionRows2 := countSchemaRows() + req.Greater(schemaRows2, schemaRows, "Should have more schema rows after second write") + req.Greater(schemaRevisionRows2, schemaRevisionRows, "Should have more schema_revision rows after second write") + + // Write schema version 3 + schemaText3 := `definition resource { + relation reader: user + relation writer: user + relation admin: user + } + definition user {}` + rev3, err := ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error { + schemaWriter, err := rwt.SchemaWriter() + if err != nil { + return err + } + return schemaWriter.WriteSchema(ctx, []datastore.SchemaDefinition{}, schemaText3, nil) + }) + req.NoError(err) + + schemaRows3, schemaRevisionRows3 := countSchemaRows() + req.Greater(schemaRows3, schemaRows2, "Should have more schema rows after third write") + req.Greater(schemaRevisionRows3, schemaRevisionRows2, "Should have more schema_revision rows after third write") + + // Run GC at rev1 - should not remove any schema rows since they're all still in use + removed, err := mgg.DeleteBeforeTx(ctx, rev1) + req.NoError(err) + req.Zero(removed.Relationships) + + schemaRowsAfterGC1, 
schemaRevisionRowsAfterGC1 := countSchemaRows() + req.Equal(schemaRows3, schemaRowsAfterGC1, "No schema rows should be removed at rev1") + req.Equal(schemaRevisionRows3, schemaRevisionRowsAfterGC1, "No schema_revision rows should be removed at rev1") + + // Run GC at rev2 - should remove schema rows from rev1 + _, err = mgg.DeleteBeforeTx(ctx, rev2) + req.NoError(err) + + schemaRowsAfterGC2, schemaRevisionRowsAfterGC2 := countSchemaRows() + req.Less(schemaRowsAfterGC2, schemaRows3, "Schema rows from rev1 should be removed") + req.Less(schemaRevisionRowsAfterGC2, schemaRevisionRows3, "Schema_revision rows from rev1 should be removed") + + // Run GC at rev3 - should remove schema rows from rev2 + _, err = mgg.DeleteBeforeTx(ctx, rev3) + req.NoError(err) + + schemaRowsAfterGC3, schemaRevisionRowsAfterGC3 := countSchemaRows() + req.Less(schemaRowsAfterGC3, schemaRowsAfterGC2, "Schema rows from rev2 should be removed") + req.Less(schemaRevisionRowsAfterGC3, schemaRevisionRowsAfterGC2, "Schema_revision rows from rev2 should be removed") + + // Verify we can still read the latest schema + headRev, _, err := ds.HeadRevision(ctx) + req.NoError(err) + + reader := ds.SnapshotReader(headRev, datastore.NoSchemaHashForTesting) + schemaReader, err := reader.SchemaReader() + req.NoError(err) + schemaText, err := schemaReader.SchemaText() + req.NoError(err) + req.NotEmpty(schemaText, "Schema text should not be empty") + req.Contains(schemaText, "relation admin", "Schema should contain the admin relation") +} + func GarbageCollectionByTimeTest(t *testing.T, ds datastore.Datastore) { req := require.New(t) @@ -766,7 +892,7 @@ func TransactionTimestampsTest(t *testing.T, ds datastore.Datastore) { // Let's make sure both Now() and transactionCreated() have timezones aligned req.Less(ts.Sub(startTimeUTC), 5*time.Minute) - revision, err := ds.OptimizedRevision(ctx) + revision, _, err := ds.OptimizedRevision(ctx) req.NoError(err) req.Equal(revisions.NewForTransactionID(txID), revision) } 
@@ -844,6 +970,34 @@ func TestMySQLWithAWSIAMCredentialsProvider(t *testing.T) { require.ErrorContains(t, err, ":1234: connect: connection refused") } +func TestMySQLDatastoreUnifiedSchemaAllModes(t *testing.T) { + t.Parallel() + b := testdatastore.RunMySQLForTesting(t, "") + + test.UnifiedSchemaAllModesTest(t, func(schemaMode options.SchemaMode) test.DatastoreTester { + return test.DatastoreTesterFunc(func(revisionQuantization, gcInterval, gcWindow time.Duration, watchBufferLength uint16) (datastore.Datastore, error) { + ctx := context.Background() + ds := b.NewDatastore(t, func(engine, uri string) datastore.Datastore { + ds, err := NewMySQLDatastore( + ctx, + uri, + GCWindow(gcWindow), + RevisionQuantization(revisionQuantization), + WatchBufferLength(watchBufferLength), + WithSchemaMode(schemaMode), + ) + require.NoError(t, err) + t.Cleanup(func() { + _ = ds.Close() + }) + return indexcheck.WrapWithIndexCheckingDatastoreProxyIfApplicable(ds) + }) + + return ds, nil + }) + }) +} + func datastoreDB(t *testing.T, migrate bool) *sql.DB { var databaseURI string testdatastore.RunMySQLForTestingWithOptions(t, testdatastore.MySQLTesterOptions{MigrateForNewDatastore: migrate}, "").NewDatastore(t, func(engine, uri string) datastore.Datastore { diff --git a/internal/datastore/mysql/gc.go b/internal/datastore/mysql/gc.go index 42c809597..7d95f8ca9 100644 --- a/internal/datastore/mysql/gc.go +++ b/internal/datastore/mysql/gc.go @@ -132,6 +132,18 @@ func (mcc *mysqlGarbageCollector) DeleteBeforeTx( // Delete any namespace rows with deleted_transaction <= the transaction ID. removed.Namespaces, err = mcc.batchDelete(ctx, mcc.mds.driver.Namespace(), sq.LtOrEq{colDeletedTxn: txID}) + if err != nil { + return removed, err + } + + // Delete any schema rows with deleted_transaction <= the transaction ID. 
+ _, err = mcc.batchDelete(ctx, mcc.mds.driver.Schema(), sq.LtOrEq{colDeletedTxn: txID}) + if err != nil { + return removed, err + } + + // Delete any schema_revision rows with deleted_transaction <= the transaction ID. + _, err = mcc.batchDelete(ctx, mcc.mds.driver.SchemaRevision(), sq.LtOrEq{colDeletedTxn: txID}) return removed, err } diff --git a/internal/datastore/mysql/migrations/tables.go b/internal/datastore/mysql/migrations/tables.go index b9aaa528f..a7c95307d 100644 --- a/internal/datastore/mysql/migrations/tables.go +++ b/internal/datastore/mysql/migrations/tables.go @@ -1,13 +1,15 @@ package migrations const ( - tableNamespaceDefault = "namespace_config" - tableTransactionDefault = "relation_tuple_transaction" - tableTupleDefault = "relation_tuple" - tableMigrationVersion = "mysql_migration_version" - tableMetadataDefault = "mysql_metadata" - tableCaveatDefault = "caveat" - tableRelationshipCounters = "relationship_counters" + tableNamespaceDefault = "namespace_config" + tableTransactionDefault = "relation_tuple_transaction" + tableTupleDefault = "relation_tuple" + tableMigrationVersion = "mysql_migration_version" + tableMetadataDefault = "mysql_metadata" + tableCaveatDefault = "caveat" + tableRelationshipCounters = "relationship_counters" + tableSchemaDefault = "schema_data" + tableSchemaRevisionDefault = "schema_revision" ) type tables struct { @@ -18,6 +20,8 @@ type tables struct { tableMetadata string tableCaveat string tableRelationshipCounters string + tableSchema string + tableSchemaRevision string } func newTables(prefix string) *tables { @@ -29,6 +33,8 @@ func newTables(prefix string) *tables { tableMetadata: prefix + tableMetadataDefault, tableCaveat: prefix + tableCaveatDefault, tableRelationshipCounters: prefix + tableRelationshipCounters, + tableSchema: prefix + tableSchemaDefault, + tableSchemaRevision: prefix + tableSchemaRevisionDefault, } } @@ -62,3 +68,11 @@ func (tn *tables) Caveat() string { func (tn *tables) RelationshipCounters() 
string { return tn.tableRelationshipCounters } + +func (tn *tables) Schema() string { + return tn.tableSchema +} + +func (tn *tables) SchemaRevision() string { + return tn.tableSchemaRevision +} diff --git a/internal/datastore/mysql/migrations/zz_migration.0011_add_schema_tables.go b/internal/datastore/mysql/migrations/zz_migration.0011_add_schema_tables.go new file mode 100644 index 000000000..f5a281ccd --- /dev/null +++ b/internal/datastore/mysql/migrations/zz_migration.0011_add_schema_tables.go @@ -0,0 +1,35 @@ +package migrations + +import "fmt" + +func createSchemaTable(t *tables) string { + return fmt.Sprintf(`CREATE TABLE %s ( + name VARCHAR(700) NOT NULL, + chunk_index INT NOT NULL, + chunk_data LONGBLOB NOT NULL, + created_transaction BIGINT NOT NULL, + deleted_transaction BIGINT NOT NULL DEFAULT '9223372036854775807', + CONSTRAINT pk_schema PRIMARY KEY (name, chunk_index, created_transaction)) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;`, + t.Schema(), + ) +} + +func createSchemaRevisionTable(t *tables) string { + return fmt.Sprintf(`CREATE TABLE %s ( + name VARCHAR(700) NOT NULL DEFAULT 'current', + hash BLOB NOT NULL, + created_transaction BIGINT NOT NULL, + deleted_transaction BIGINT NOT NULL DEFAULT '9223372036854775807', + CONSTRAINT pk_schema_revision PRIMARY KEY (name, created_transaction)) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;`, + t.SchemaRevision(), + ) +} + +func init() { + mustRegisterMigration("add_schema_tables", "add_expiration_to_relation_tuple", noNonatomicMigration, + newStatementBatch( + createSchemaTable, + createSchemaRevisionTable, + ).execute, + ) +} diff --git a/internal/datastore/mysql/migrations/zz_migration.0012_populate_schema_tables.go b/internal/datastore/mysql/migrations/zz_migration.0012_populate_schema_tables.go new file mode 100644 index 000000000..9fa33139f --- /dev/null +++ b/internal/datastore/mysql/migrations/zz_migration.0012_populate_schema_tables.go @@ -0,0 +1,188 @@ +package migrations + +import ( + "context" + 
"crypto/sha256" + "encoding/hex" + "fmt" + "sort" + + core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/schemadsl/compiler" + "github.com/authzed/spicedb/pkg/schemadsl/generator" +) + +const ( + schemaChunkSize = 1024 * 1024 // 1MB chunks + currentSchemaVersion = 1 + unifiedSchemaName = "unified_schema" + schemaRevisionName = "current" +) + +func init() { + mustRegisterMigration("populate_schema_tables", "add_schema_tables", noNonatomicMigration, populateSchemaTablesFunc) +} + +func populateSchemaTablesFunc(ctx context.Context, wrapper TxWrapper) error { + tx := wrapper.tx + + // Read all existing namespaces (not deleted ones) + //nolint:gosec // Table name is from internal schema configuration, not user input + query := fmt.Sprintf(` + SELECT nc1.namespace, nc1.serialized_config + FROM %[1]s nc1 + INNER JOIN ( + SELECT namespace, MAX(created_transaction) as max_created + FROM %[1]s + WHERE deleted_transaction = 9223372036854775807 + GROUP BY namespace + ) nc2 ON nc1.namespace = nc2.namespace AND nc1.created_transaction = nc2.max_created + WHERE nc1.deleted_transaction = 9223372036854775807 + `, wrapper.tables.Namespace()) + + rows, err := tx.QueryContext(ctx, query) + if err != nil { + return fmt.Errorf("failed to query namespaces: %w", err) + } + defer rows.Close() + + namespaces := make(map[string]*core.NamespaceDefinition) + for rows.Next() { + var name string + var config []byte + if err := rows.Scan(&name, &config); err != nil { + return fmt.Errorf("failed to scan namespace: %w", err) + } + + var ns core.NamespaceDefinition + if err := ns.UnmarshalVT(config); err != nil { + return fmt.Errorf("failed to unmarshal namespace %s: %w", name, err) + } + namespaces[name] = &ns + } + if err := rows.Err(); err != nil { + return fmt.Errorf("error iterating namespaces: %w", err) + } + + // Read all existing caveats (not deleted ones) + query = fmt.Sprintf(` + SELECT c1.name, c1.definition + FROM %[1]s c1 + INNER JOIN ( + SELECT name, 
MAX(created_transaction) as max_created + FROM %[1]s + WHERE deleted_transaction = 9223372036854775807 + GROUP BY name + ) c2 ON c1.name = c2.name AND c1.created_transaction = c2.max_created + WHERE c1.deleted_transaction = 9223372036854775807 + `, wrapper.tables.Caveat()) + + rows, err = tx.QueryContext(ctx, query) + if err != nil { + return fmt.Errorf("failed to query caveats: %w", err) + } + defer rows.Close() + + caveats := make(map[string]*core.CaveatDefinition) + for rows.Next() { + var name string + var definition []byte + if err := rows.Scan(&name, &definition); err != nil { + return fmt.Errorf("failed to scan caveat: %w", err) + } + + var caveat core.CaveatDefinition + if err := caveat.UnmarshalVT(definition); err != nil { + return fmt.Errorf("failed to unmarshal caveat %s: %w", name, err) + } + caveats[name] = &caveat + } + if err := rows.Err(); err != nil { + return fmt.Errorf("error iterating caveats: %w", err) + } + + // If there are no namespaces or caveats, skip migration + if len(namespaces) == 0 && len(caveats) == 0 { + return nil + } + + // Generate canonical schema for hash computation + allDefs := make([]compiler.SchemaDefinition, 0, len(namespaces)+len(caveats)) + for _, ns := range namespaces { + allDefs = append(allDefs, ns) + } + for _, caveat := range caveats { + allDefs = append(allDefs, caveat) + } + + // Sort alphabetically for canonical ordering + sort.Slice(allDefs, func(i, j int) bool { + return allDefs[i].GetName() < allDefs[j].GetName() + }) + + // Generate canonical schema text + canonicalSchemaText, _, err := generator.GenerateSchema(allDefs) + if err != nil { + return fmt.Errorf("failed to generate canonical schema: %w", err) + } + + // Compute SHA256 hash + hashBytes := sha256.Sum256([]byte(canonicalSchemaText)) + schemaHash := hashBytes[:] + + // Generate user-facing schema text + schemaText, _, err := generator.GenerateSchema(allDefs) + if err != nil { + return fmt.Errorf("failed to generate schema: %w", err) + } + + // Create 
stored schema proto + storedSchema := &core.StoredSchema{ + Version: currentSchemaVersion, + VersionOneof: &core.StoredSchema_V1{ + V1: &core.StoredSchema_V1StoredSchema{ + SchemaText: schemaText, + SchemaHash: hex.EncodeToString(schemaHash), + NamespaceDefinitions: namespaces, + CaveatDefinitions: caveats, + }, + }, + } + + // Marshal schema + schemaData, err := storedSchema.MarshalVT() + if err != nil { + return fmt.Errorf("failed to marshal schema: %w", err) + } + + // Insert schema chunks + for chunkIndex := 0; chunkIndex*schemaChunkSize < len(schemaData); chunkIndex++ { + start := chunkIndex * schemaChunkSize + end := start + schemaChunkSize + if end > len(schemaData) { + end = len(schemaData) + } + chunk := schemaData[start:end] + + query = fmt.Sprintf(` + INSERT INTO %s (name, chunk_index, chunk_data) + VALUES (?, ?, ?) + `, wrapper.tables.Schema()) + _, err = tx.ExecContext(ctx, query, unifiedSchemaName, chunkIndex, chunk) + if err != nil { + return fmt.Errorf("failed to insert schema chunk %d: %w", chunkIndex, err) + } + } + + // Insert schema hash + query = fmt.Sprintf(` + INSERT INTO %s (name, hash) + VALUES (?, ?) 
+ `, wrapper.tables.SchemaRevision()) + _, err = tx.ExecContext(ctx, query, schemaRevisionName, schemaHash) + if err != nil { + return fmt.Errorf("failed to insert schema hash: %w", err) + } + + return nil +} diff --git a/internal/datastore/mysql/options.go b/internal/datastore/mysql/options.go index 63f6e195c..00fd34992 100644 --- a/internal/datastore/mysql/options.go +++ b/internal/datastore/mysql/options.go @@ -6,6 +6,7 @@ import ( "github.com/authzed/spicedb/internal/datastore/common" log "github.com/authzed/spicedb/internal/logging" + dsoptions "github.com/authzed/spicedb/pkg/datastore/options" ) const ( @@ -56,6 +57,8 @@ type mysqlOptions struct { allowedMigrations []string columnOptimizationOption common.ColumnOptimizationOption watchDisabled bool + schemaMode dsoptions.SchemaMode + schemaCacheOptions dsoptions.SchemaCacheOptions } // Option provides the facility to configure how clients within the @@ -313,3 +316,17 @@ func WithWatchDisabled(isDisabled bool) Option { mo.watchDisabled = isDisabled } } + +// WithSchemaMode sets the experimental schema mode for the datastore. +func WithSchemaMode(mode dsoptions.SchemaMode) Option { + return func(mo *mysqlOptions) { + mo.schemaMode = mode + } +} + +// WithSchemaCacheOptions sets the schema cache options for the datastore. 
+func WithSchemaCacheOptions(cacheOptions dsoptions.SchemaCacheOptions) Option { + return func(mo *mysqlOptions) { + mo.schemaCacheOptions = cacheOptions + } +} diff --git a/internal/datastore/mysql/query_builder.go b/internal/datastore/mysql/query_builder.go index 4a84885f5..5fec307c3 100644 --- a/internal/datastore/mysql/query_builder.go +++ b/internal/datastore/mysql/query_builder.go @@ -9,8 +9,9 @@ import ( // QueryBuilder captures all parameterizable queries used // by the MySQL datastore implementation type QueryBuilder struct { - GetLastRevision sq.SelectBuilder - LoadRevisionRange sq.SelectBuilder + GetLastRevision sq.SelectBuilder + GetLastRevisionWithHash sq.SelectBuilder + LoadRevisionRange sq.SelectBuilder WriteNamespaceQuery sq.InsertBuilder ReadNamespaceQuery sq.SelectBuilder @@ -43,6 +44,7 @@ func NewQueryBuilder(driver *migrations.MySQLDriver) *QueryBuilder { // transaction builders builder.GetLastRevision = getLastRevision(driver.RelationTupleTransaction()) + builder.GetLastRevisionWithHash = getLastRevisionWithHash(driver.RelationTupleTransaction(), driver.SchemaRevision()) builder.LoadRevisionRange = loadRevisionRange(driver.RelationTupleTransaction()) // namespace builders @@ -99,6 +101,15 @@ func getLastRevision(tableTransaction string) sq.SelectBuilder { return sb.Select("MAX(id)").From(tableTransaction).Limit(1) } +func getLastRevisionWithHash(tableTransaction string, tableSchemaRevision string) sq.SelectBuilder { + // Get the latest transaction ID and schema hash as separate subqueries + // to avoid MySQL's only_full_group_by restriction + return sb.Select( + "(SELECT MAX(id) FROM "+tableTransaction+")", + "COALESCE((SELECT hash FROM "+tableSchemaRevision+" WHERE name = 'current' ORDER BY created_transaction DESC LIMIT 1), '')"). 
+ Limit(1) +} + func loadRevisionRange(tableTransaction string) sq.SelectBuilder { return sb.Select(colID, colMetadata).From(tableTransaction) } diff --git a/internal/datastore/mysql/reader.go b/internal/datastore/mysql/reader.go index 33e056d9d..6dad752e5 100644 --- a/internal/datastore/mysql/reader.go +++ b/internal/datastore/mysql/reader.go @@ -10,9 +10,9 @@ import ( "github.com/authzed/spicedb/internal/datastore/common" "github.com/authzed/spicedb/internal/datastore/revisions" - schemautil "github.com/authzed/spicedb/internal/datastore/schema" + schemaadapter "github.com/authzed/spicedb/internal/datastore/schema" "github.com/authzed/spicedb/pkg/datastore" - "github.com/authzed/spicedb/pkg/datastore/options" + dsoptions "github.com/authzed/spicedb/pkg/datastore/options" core "github.com/authzed/spicedb/pkg/proto/core/v1" ) @@ -28,6 +28,11 @@ type mysqlReader struct { aliveFilter queryFilterer filterMaximumIDCount uint16 schema common.SchemaInformation + schemaMode dsoptions.SchemaMode + snapshotRevision datastore.Revision + schemaHash string + schemaTableName string + schemaReaderWriter *common.SQLSchemaReaderWriter[uint64, revisions.TransactionIDRevision] } type queryFilterer func(original sq.SelectBuilder) sq.SelectBuilder @@ -164,7 +169,7 @@ func (mr *mysqlReader) lookupCounters(ctx context.Context, optionalName string) func (mr *mysqlReader) QueryRelationships( ctx context.Context, filter datastore.RelationshipsFilter, - opts ...options.QueryOptionsOption, + opts ...dsoptions.QueryOptionsOption, ) (iter datastore.RelationshipIterator, err error) { qBuilder, err := common.NewSchemaQueryFiltererForRelationshipsSelect(mr.schema, mr.filterMaximumIDCount). WithAdditionalFilter(mr.aliveFilter). 
@@ -179,7 +184,7 @@ func (mr *mysqlReader) QueryRelationships( func (mr *mysqlReader) ReverseQueryRelationships( ctx context.Context, subjectsFilter datastore.SubjectsFilter, - opts ...options.ReverseQueryOptionsOption, + opts ...dsoptions.ReverseQueryOptionsOption, ) (iter datastore.RelationshipIterator, err error) { qBuilder, err := common.NewSchemaQueryFiltererForRelationshipsSelect(mr.schema, mr.filterMaximumIDCount). WithAdditionalFilter(mr.aliveFilter). @@ -188,7 +193,7 @@ func (mr *mysqlReader) ReverseQueryRelationships( return nil, err } - queryOpts := options.NewReverseQueryOptionsWithOptions(opts...) + queryOpts := dsoptions.NewReverseQueryOptionsWithOptions(opts...) if queryOpts.ResRelation != nil { qBuilder = qBuilder. @@ -199,13 +204,13 @@ func (mr *mysqlReader) ReverseQueryRelationships( return mr.executor.ExecuteQuery( ctx, qBuilder, - options.WithLimit(queryOpts.LimitForReverse), - options.WithAfter(queryOpts.AfterForReverse), - options.WithSort(queryOpts.SortForReverse), - options.WithSkipCaveats(queryOpts.SkipCaveatsForReverse), - options.WithSkipExpiration(queryOpts.SkipExpirationForReverse), - options.WithQueryShape(queryOpts.QueryShapeForReverse), - options.WithSQLExplainCallbackForTest(queryOpts.SQLExplainCallbackForTestForReverse), + dsoptions.WithLimit(queryOpts.LimitForReverse), + dsoptions.WithAfter(queryOpts.AfterForReverse), + dsoptions.WithSort(queryOpts.SortForReverse), + dsoptions.WithSkipCaveats(queryOpts.SkipCaveatsForReverse), + dsoptions.WithSkipExpiration(queryOpts.SkipExpirationForReverse), + dsoptions.WithQueryShape(queryOpts.QueryShapeForReverse), + dsoptions.WithSQLExplainCallbackForTest(queryOpts.SQLExplainCallbackForTestForReverse), ) } @@ -337,7 +342,61 @@ func loadAllNamespaces(ctx context.Context, tx *sql.Tx, queryBuilder sq.SelectBu // SchemaReader returns a SchemaReader for reading schema information. 
func (mr *mysqlReader) SchemaReader() (datastore.SchemaReader, error) { - return schemautil.NewLegacySchemaReaderAdapter(mr), nil + // Wrap the reader with an unexported schema reader + reader := &mysqlSchemaReader{r: mr} + return schemaadapter.NewSchemaReader(reader, mr.schemaMode, mr.snapshotRevision), nil } -var _ datastore.Reader = &mysqlReader{} +// mysqlSchemaReader wraps a mysqlReader and implements DualSchemaReader. +// This prevents direct access to schema read methods from the reader. +type mysqlSchemaReader struct { + r *mysqlReader +} + +// ReadStoredSchema implements datastore.SingleStoreSchemaReader with revision-aware reading +func (sr *mysqlSchemaReader) ReadStoredSchema(ctx context.Context) (*core.StoredSchema, error) { + // Create a revision-aware executor that applies alive filter + executor := &mysqlRevisionAwareExecutor{ + txSource: sr.r.txSource, + aliveFilter: sr.r.aliveFilter, + } + + // Use the shared schema reader/writer to read the schema with the hash + return sr.r.schemaReaderWriter.ReadSchema(ctx, executor, sr.r.snapshotRevision, datastore.SchemaHash(sr.r.schemaHash)) +} + +// LegacyReadCaveatByName delegates to the underlying reader +func (sr *mysqlSchemaReader) LegacyReadCaveatByName(ctx context.Context, name string) (*core.CaveatDefinition, datastore.Revision, error) { + return sr.r.LegacyReadCaveatByName(ctx, name) +} + +// LegacyListAllCaveats delegates to the underlying reader +func (sr *mysqlSchemaReader) LegacyListAllCaveats(ctx context.Context) ([]datastore.RevisionedCaveat, error) { + return sr.r.LegacyListAllCaveats(ctx) +} + +// LegacyLookupCaveatsWithNames delegates to the underlying reader +func (sr *mysqlSchemaReader) LegacyLookupCaveatsWithNames(ctx context.Context, names []string) ([]datastore.RevisionedCaveat, error) { + return sr.r.LegacyLookupCaveatsWithNames(ctx, names) +} + +// LegacyReadNamespaceByName delegates to the underlying reader +func (sr *mysqlSchemaReader) LegacyReadNamespaceByName(ctx context.Context, 
nsName string) (*core.NamespaceDefinition, datastore.Revision, error) { + return sr.r.LegacyReadNamespaceByName(ctx, nsName) +} + +// LegacyListAllNamespaces delegates to the underlying reader +func (sr *mysqlSchemaReader) LegacyListAllNamespaces(ctx context.Context) ([]datastore.RevisionedNamespace, error) { + return sr.r.LegacyListAllNamespaces(ctx) +} + +// LegacyLookupNamespacesWithNames delegates to the underlying reader +func (sr *mysqlSchemaReader) LegacyLookupNamespacesWithNames(ctx context.Context, nsNames []string) ([]datastore.RevisionedDefinition[*core.NamespaceDefinition], error) { + return sr.r.LegacyLookupNamespacesWithNames(ctx, nsNames) +} + +var ( + _ datastore.Reader = &mysqlReader{} + _ datastore.LegacySchemaReader = &mysqlReader{} + _ datastore.DualSchemaReader = &mysqlSchemaReader{} +) diff --git a/internal/datastore/mysql/readwrite.go b/internal/datastore/mysql/readwrite.go index d5d042bd8..79f9de755 100644 --- a/internal/datastore/mysql/readwrite.go +++ b/internal/datastore/mysql/readwrite.go @@ -4,12 +4,15 @@ import ( "bytes" "cmp" "context" + "crypto/sha256" "database/sql" "database/sql/driver" + "encoding/hex" "encoding/json" "errors" "fmt" "regexp" + "sort" "strings" "time" @@ -21,11 +24,13 @@ import ( "github.com/authzed/spicedb/internal/datastore/common" "github.com/authzed/spicedb/internal/datastore/revisions" - schemautil "github.com/authzed/spicedb/internal/datastore/schema" + schemaadapter "github.com/authzed/spicedb/internal/datastore/schema" log "github.com/authzed/spicedb/internal/logging" "github.com/authzed/spicedb/pkg/datastore" "github.com/authzed/spicedb/pkg/datastore/options" core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/schemadsl/compiler" + "github.com/authzed/spicedb/pkg/schemadsl/generator" "github.com/authzed/spicedb/pkg/spiceerrors" "github.com/authzed/spicedb/pkg/tuple" ) @@ -48,9 +53,11 @@ var ( type mysqlReadWriteTXN struct { *mysqlReader - tupleTableName string - tx *sql.Tx 
- newTxnID uint64 + tupleTableName string + schemaTableName string + schemaRevisionTableName string + tx *sql.Tx + newTxnID uint64 } // structpbWrapper is used to marshall maps into MySQLs JSON data type @@ -515,7 +522,202 @@ func (rwt *mysqlReadWriteTXN) LegacyDeleteNamespaces(ctx context.Context, nsName } func (rwt *mysqlReadWriteTXN) SchemaWriter() (datastore.SchemaWriter, error) { - return schemautil.NewLegacySchemaWriterAdapter(rwt, rwt), nil + // Wrap the transaction with an unexported schema writer + writer := &mysqlSchemaWriter{rwt: rwt} + return schemaadapter.NewSchemaWriter(writer, writer, rwt.schemaMode), nil +} + +// mysqlSchemaWriter wraps a mysqlReadWriteTXN and implements DualSchemaWriter. +// This prevents direct access to schema write methods from the transaction. +type mysqlSchemaWriter struct { + rwt *mysqlReadWriteTXN +} + +// WriteStoredSchema implements datastore.SingleStoreSchemaWriter by writing within the current transaction +func (w *mysqlSchemaWriter) WriteStoredSchema(ctx context.Context, schema *core.StoredSchema) error { + // Create a transaction-aware executor that uses the current transaction + executor := newMySQLTransactionAwareExecutor(w.rwt.tx) + + // Use the shared schema reader/writer to write the schema with the newTxnID as transaction ID + if err := w.rwt.schemaReaderWriter.WriteSchema(ctx, schema, executor, func(ctx context.Context) uint64 { + return w.rwt.newTxnID + }); err != nil { + return err + } + + // Write the schema hash to the schema_revision table for fast lookups + if err := w.writeSchemaHash(ctx, schema); err != nil { + return fmt.Errorf("failed to write schema hash: %w", err) + } + + return nil +} + +// writeSchemaHash writes the schema hash to the schema_revision table +func (w *mysqlSchemaWriter) writeSchemaHash(ctx context.Context, schema *core.StoredSchema) error { + v1 := schema.GetV1() + if v1 == nil { + return fmt.Errorf("unsupported schema version: %d", schema.Version) + } + + // Mark existing hash rows 
as deleted + sql, args, err := sb.Update(w.rwt.schemaRevisionTableName). + Set("deleted_transaction", w.rwt.newTxnID). + Where(sq.Eq{ + "name": "current", + "deleted_transaction": liveDeletedTxnID, + }). + ToSql() + if err != nil { + return fmt.Errorf("failed to build delete query: %w", err) + } + + if _, err := w.rwt.tx.ExecContext(ctx, sql, args...); err != nil { + return fmt.Errorf("failed to delete old hash: %w", err) + } + + // Insert new hash row (INSERT IGNORE handles WriteBoth mode duplicates) + sql, args, err = sb.Insert(w.rwt.schemaRevisionTableName). + Options("IGNORE"). + Columns("name", "hash", "created_transaction", "deleted_transaction"). + Values("current", []byte(v1.SchemaHash), w.rwt.newTxnID, liveDeletedTxnID). + ToSql() + if err != nil { + return fmt.Errorf("failed to build insert query: %w", err) + } + + if _, err := w.rwt.tx.ExecContext(ctx, sql, args...); err != nil { + return fmt.Errorf("failed to insert hash: %w", err) + } + + return nil +} + +// ReadStoredSchema implements datastore.SingleStoreSchemaReader to satisfy DualSchemaReader interface requirements +func (w *mysqlSchemaWriter) ReadStoredSchema(ctx context.Context) (*core.StoredSchema, error) { + // Create a revision-aware executor that applies alive filter + executor := &mysqlRevisionAwareExecutor{ + txSource: w.rwt.txSource, + aliveFilter: w.rwt.aliveFilter, + } + + // Use the shared schema reader/writer to read the schema + // Pass empty string for transaction reads to bypass cache + return w.rwt.schemaReaderWriter.ReadSchema(ctx, executor, nil, datastore.NoSchemaHashInTransaction) +} + +// LegacyWriteNamespaces delegates to the underlying transaction +func (w *mysqlSchemaWriter) LegacyWriteNamespaces(ctx context.Context, newConfigs ...*core.NamespaceDefinition) error { + return w.rwt.LegacyWriteNamespaces(ctx, newConfigs...) 
+} + +// LegacyDeleteNamespaces delegates to the underlying transaction +func (w *mysqlSchemaWriter) LegacyDeleteNamespaces(ctx context.Context, nsNames []string, delOption datastore.DeleteNamespacesRelationshipsOption) error { + return w.rwt.LegacyDeleteNamespaces(ctx, nsNames, delOption) +} + +// LegacyWriteCaveats delegates to the underlying transaction +func (w *mysqlSchemaWriter) LegacyWriteCaveats(ctx context.Context, caveats []*core.CaveatDefinition) error { + return w.rwt.LegacyWriteCaveats(ctx, caveats) +} + +// LegacyDeleteCaveats delegates to the underlying transaction +func (w *mysqlSchemaWriter) LegacyDeleteCaveats(ctx context.Context, names []string) error { + return w.rwt.LegacyDeleteCaveats(ctx, names) +} + +// LegacyReadCaveatByName delegates to the underlying transaction +func (w *mysqlSchemaWriter) LegacyReadCaveatByName(ctx context.Context, name string) (*core.CaveatDefinition, datastore.Revision, error) { + return w.rwt.LegacyReadCaveatByName(ctx, name) +} + +// LegacyListAllCaveats delegates to the underlying transaction +func (w *mysqlSchemaWriter) LegacyListAllCaveats(ctx context.Context) ([]datastore.RevisionedCaveat, error) { + return w.rwt.LegacyListAllCaveats(ctx) +} + +// LegacyLookupCaveatsWithNames delegates to the underlying transaction +func (w *mysqlSchemaWriter) LegacyLookupCaveatsWithNames(ctx context.Context, names []string) ([]datastore.RevisionedCaveat, error) { + return w.rwt.LegacyLookupCaveatsWithNames(ctx, names) +} + +// LegacyReadNamespaceByName delegates to the underlying transaction +func (w *mysqlSchemaWriter) LegacyReadNamespaceByName(ctx context.Context, nsName string) (*core.NamespaceDefinition, datastore.Revision, error) { + return w.rwt.LegacyReadNamespaceByName(ctx, nsName) +} + +// LegacyListAllNamespaces delegates to the underlying transaction +func (w *mysqlSchemaWriter) LegacyListAllNamespaces(ctx context.Context) ([]datastore.RevisionedNamespace, error) { + return w.rwt.LegacyListAllNamespaces(ctx) +} + +// 
LegacyLookupNamespacesWithNames delegates to the underlying transaction +func (w *mysqlSchemaWriter) LegacyLookupNamespacesWithNames(ctx context.Context, nsNames []string) ([]datastore.RevisionedDefinition[*core.NamespaceDefinition], error) { + return w.rwt.LegacyLookupNamespacesWithNames(ctx, nsNames) +} + +// WriteLegacySchemaHashFromDefinitions implements datastore.LegacySchemaHashWriter +func (w *mysqlSchemaWriter) WriteLegacySchemaHashFromDefinitions(ctx context.Context, namespaces []datastore.RevisionedNamespace, caveats []datastore.RevisionedCaveat) error { + return w.rwt.writeLegacySchemaHashFromDefinitions(ctx, namespaces, caveats) +} + +// writeLegacySchemaHashFromDefinitions writes the schema hash computed from the given definitions +func (rwt *mysqlReadWriteTXN) writeLegacySchemaHashFromDefinitions(ctx context.Context, namespaces []datastore.RevisionedNamespace, caveats []datastore.RevisionedCaveat) error { + // Build schema definitions list + definitions := make([]compiler.SchemaDefinition, 0, len(namespaces)+len(caveats)) + for _, ns := range namespaces { + definitions = append(definitions, ns.Definition) + } + for _, caveat := range caveats { + definitions = append(definitions, caveat.Definition) + } + + // Sort definitions by name for consistent ordering + sort.Slice(definitions, func(i, j int) bool { + return definitions[i].GetName() < definitions[j].GetName() + }) + + // Generate schema text from definitions + schemaText, _, err := generator.GenerateSchema(definitions) + if err != nil { + return fmt.Errorf("failed to generate schema: %w", err) + } + + // Compute schema hash (SHA256) + hash := sha256.Sum256([]byte(schemaText)) + schemaHash := hex.EncodeToString(hash[:]) + + // Mark existing hash rows as deleted + sql, args, err := sb.Update(rwt.schemaRevisionTableName). + Set("deleted_transaction", rwt.newTxnID). + Where(sq.Eq{ + "name": "current", + "deleted_transaction": liveDeletedTxnID, + }). 
+ ToSql() + if err != nil { + return fmt.Errorf("failed to build delete query: %w", err) + } + + if _, err := rwt.tx.ExecContext(ctx, sql, args...); err != nil { + return fmt.Errorf("failed to delete old hash: %w", err) + } + + // Insert new hash row (INSERT IGNORE handles WriteBoth mode duplicates) + sql, args, err = sb.Insert(rwt.schemaRevisionTableName). + Options("IGNORE"). + Columns("name", "hash", "created_transaction", "deleted_transaction"). + Values("current", []byte(schemaHash), rwt.newTxnID, liveDeletedTxnID). + ToSql() + if err != nil { + return fmt.Errorf("failed to build insert query: %w", err) + } + + if _, err := rwt.tx.ExecContext(ctx, sql, args...); err != nil { + return fmt.Errorf("failed to insert hash: %w", err) + } + + return nil } func (rwt *mysqlReadWriteTXN) BulkLoad(ctx context.Context, iter datastore.BulkWriteRelationshipSource) (uint64, error) { @@ -644,4 +846,9 @@ func exactRelationshipClause(r tuple.Relationship) sq.Eq { } } -var _ datastore.ReadWriteTransaction = &mysqlReadWriteTXN{} +var ( + _ datastore.ReadWriteTransaction = &mysqlReadWriteTXN{} + _ datastore.LegacySchemaWriter = &mysqlReadWriteTXN{} + _ datastore.DualSchemaWriter = &mysqlSchemaWriter{} + _ datastore.DualSchemaReader = &mysqlSchemaWriter{} +) diff --git a/internal/datastore/mysql/revisions.go b/internal/datastore/mysql/revisions.go index 9bfdc4096..6f9e81ad7 100644 --- a/internal/datastore/mysql/revisions.go +++ b/internal/datastore/mysql/revisions.go @@ -44,6 +44,25 @@ const ( )) as revision, %[4]d - CAST(UNIX_TIMESTAMP(UTC_TIMESTAMP(6)) * 1000000000 AS UNSIGNED INTEGER) %% %[4]d as validForNanos;` + // querySelectRevisionWithHash is like querySelectRevision but also loads the schema hash + // + // %[1] Name of id column + // %[2] Relationship tuple transaction table + // %[3] Name of timestamp column + // %[4] Quantization period (in nanoseconds) + // %[5] Follower read delay (in nanoseconds) + // %[6] Schema revision table + querySelectRevisionWithHash = `SELECT 
COALESCE(( + SELECT MIN(%[1]s) + FROM %[2]s + WHERE %[3]s >= FROM_UNIXTIME(FLOOR((UNIX_TIMESTAMP(UTC_TIMESTAMP(6)) * 1000000000 - %[5]d) / %[4]d) * %[4]d / 1000000000) + ), ( + SELECT MAX(%[1]s) + FROM %[2]s + )) as revision, + %[4]d - CAST(UNIX_TIMESTAMP(UTC_TIMESTAMP(6)) * 1000000000 AS UNSIGNED INTEGER) %% %[4]d as validForNanos, + COALESCE((SELECT hash FROM %[6]s WHERE name = 'current' ORDER BY created_transaction DESC LIMIT 1), '') as schema_hash;` + // queryValidTransaction will return a single row with two values, one boolean // for whether the specified transaction ID is newer than the garbage collection // window, and one boolean for whether the transaction ID represents a transaction @@ -70,26 +89,27 @@ const ( ) as unknown;` ) -func (mds *mysqlDatastore) optimizedRevisionFunc(ctx context.Context) (datastore.Revision, time.Duration, error) { +func (mds *mysqlDatastore) optimizedRevisionFunc(ctx context.Context) (datastore.Revision, time.Duration, datastore.SchemaHash, error) { var rev uint64 var validForNanos time.Duration - if err := mds.db.QueryRowContext(ctx, mds.optimizedRevisionQuery). - Scan(&rev, &validForNanos); err != nil { - return datastore.NoRevision, 0, fmt.Errorf(errRevision, err) + var schemaHash []byte + if err := mds.db.QueryRowContext(ctx, mds.optimizedRevisionQueryWithHash). 
+ Scan(&rev, &validForNanos, &schemaHash); err != nil { + return datastore.NoRevision, 0, "", fmt.Errorf(errRevision, err) } - return revisions.NewForTransactionID(rev), validForNanos, nil + return revisions.NewForTransactionID(rev), validForNanos, datastore.SchemaHash(schemaHash), nil } -func (mds *mysqlDatastore) HeadRevision(ctx context.Context) (datastore.Revision, error) { - revision, err := mds.loadRevision(ctx) +func (mds *mysqlDatastore) HeadRevision(ctx context.Context) (datastore.Revision, datastore.SchemaHash, error) { + revision, schemaHash, err := mds.loadRevision(ctx) if err != nil { - return datastore.NoRevision, err + return datastore.NoRevision, "", err } if revision == 0 { - return datastore.NoRevision, nil + return datastore.NoRevision, "", nil } - return revisions.NewForTransactionID(revision), nil + return revisions.NewForTransactionID(revision), schemaHash, nil } func (mds *mysqlDatastore) CheckRevision(ctx context.Context, revision datastore.Revision) error { @@ -118,30 +138,31 @@ func (mds *mysqlDatastore) CheckRevision(ctx context.Context, revision datastore return nil } -func (mds *mysqlDatastore) loadRevision(ctx context.Context) (uint64, error) { +func (mds *mysqlDatastore) loadRevision(ctx context.Context) (uint64, datastore.SchemaHash, error) { // slightly changed to support no revisions at all, needed for runtime seeding of first transaction ctx, span := tracer.Start(ctx, "loadRevision") defer span.End() - query, args, err := mds.GetLastRevision.ToSql() + query, args, err := mds.GetLastRevisionWithHash.ToSql() if err != nil { - return 0, fmt.Errorf(errRevision, err) + return 0, "", fmt.Errorf(errRevision, err) } var revision *uint64 - err = mds.db.QueryRowContext(ctx, query, args...).Scan(&revision) + var schemaHash []byte + err = mds.db.QueryRowContext(ctx, query, args...).Scan(&revision, &schemaHash) if err != nil { if errors.Is(err, sql.ErrNoRows) { - return 0, nil + return 0, "", nil } - return 0, fmt.Errorf(errRevision, err) + 
return 0, "", fmt.Errorf(errRevision, err) } if revision == nil { - return 0, nil + return 0, "", nil } - return *revision, nil + return *revision, datastore.SchemaHash(schemaHash), nil } func (mds *mysqlDatastore) checkValidTransaction(ctx context.Context, revisionTx uint64) (bool, bool, error) { diff --git a/internal/datastore/mysql/schema_chunker.go b/internal/datastore/mysql/schema_chunker.go new file mode 100644 index 000000000..6fe5f8a93 --- /dev/null +++ b/internal/datastore/mysql/schema_chunker.go @@ -0,0 +1,237 @@ +package mysql + +import ( + "context" + "database/sql" + "errors" + + sq "github.com/Masterminds/squirrel" + + "github.com/authzed/spicedb/internal/datastore/common" +) + +const ( + // MySQL LONGBLOB can store up to 4GB, but we use 64KB chunks for safety + // and to avoid issues with max_allowed_packet settings. + mysqlMaxChunkSize = 64 * 1024 // 64KB +) + +// BaseSchemaChunkerConfig provides the base configuration for MySQL schema chunking. +// MySQL uses smaller chunks (64KB), question mark placeholders, and tombstone-based write mode. +var BaseSchemaChunkerConfig = common.SQLByteChunkerConfig[uint64]{ + TableName: "schema", + NameColumn: "name", + ChunkIndexColumn: "chunk_index", + ChunkDataColumn: "chunk_data", + MaxChunkSize: mysqlMaxChunkSize, + PlaceholderFormat: sq.Question, + WriteMode: common.WriteModeInsertWithTombstones, + CreatedAtColumn: "created_transaction", + DeletedAtColumn: "deleted_transaction", + AliveValue: liveDeletedTxnID, +} + +// mysqlChunkedBytesExecutor implements common.ChunkedBytesExecutor for MySQL. 
+type mysqlChunkedBytesExecutor struct { + db *sql.DB +} + +func newMySQLChunkedBytesExecutor(db *sql.DB) *mysqlChunkedBytesExecutor { + return &mysqlChunkedBytesExecutor{db: db} +} + +func (e *mysqlChunkedBytesExecutor) BeginTransaction(ctx context.Context) (common.ChunkedBytesTransaction, error) { + tx, err := e.db.BeginTx(ctx, nil) + if err != nil { + return nil, err + } + return &mysqlChunkedBytesTransaction{tx: tx}, nil +} + +func (e *mysqlChunkedBytesExecutor) ExecuteRead(ctx context.Context, builder sq.SelectBuilder) (map[int][]byte, error) { + sql, args, err := builder.ToSql() + if err != nil { + return nil, err + } + + rows, err := e.db.QueryContext(ctx, sql, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + result := make(map[int][]byte) + for rows.Next() { + var chunkIndex int + var chunkData []byte + if err := rows.Scan(&chunkIndex, &chunkData); err != nil { + return nil, err + } + result[chunkIndex] = chunkData + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return result, nil +} + +// mysqlChunkedBytesTransaction implements common.ChunkedBytesTransaction for MySQL. +type mysqlChunkedBytesTransaction struct { + tx *sql.Tx +} + +func (t *mysqlChunkedBytesTransaction) ExecuteWrite(ctx context.Context, builder sq.InsertBuilder) error { + sql, args, err := builder.ToSql() + if err != nil { + return err + } + + _, err = t.tx.ExecContext(ctx, sql, args...) + if err != nil { + return err + } + + return t.tx.Commit() +} + +func (t *mysqlChunkedBytesTransaction) ExecuteDelete(ctx context.Context, builder sq.DeleteBuilder) error { + sql, args, err := builder.ToSql() + if err != nil { + return err + } + + _, err = t.tx.ExecContext(ctx, sql, args...) 
+ if err != nil { + return err + } + + return t.tx.Commit() +} + +func (t *mysqlChunkedBytesTransaction) ExecuteUpdate(ctx context.Context, builder sq.UpdateBuilder) error { + sql, args, err := builder.ToSql() + if err != nil { + return err + } + + _, err = t.tx.ExecContext(ctx, sql, args...) + if err != nil { + return err + } + + return t.tx.Commit() +} + +// GetSchemaChunker returns a SQLByteChunker for the schema table. +// This is exported for testing purposes. +func (mds *mysqlDatastore) GetSchemaChunker() *common.SQLByteChunker[uint64] { + executor := newMySQLChunkedBytesExecutor(mds.db) + return common.MustNewSQLByteChunker( + BaseSchemaChunkerConfig. + WithTableName(mds.schemaTableName). + WithExecutor(executor), + ) +} + +// mysqlRevisionAwareExecutor wraps the reader's query infrastructure to provide revision-aware chunk reading +type mysqlRevisionAwareExecutor struct { + txSource txFactory + aliveFilter func(sq.SelectBuilder) sq.SelectBuilder +} + +func (e *mysqlRevisionAwareExecutor) BeginTransaction(ctx context.Context) (common.ChunkedBytesTransaction, error) { + // We don't support transactions for reading + return nil, errors.New("transactions not supported for revision-aware reads") +} + +func (e *mysqlRevisionAwareExecutor) ExecuteRead(ctx context.Context, builder sq.SelectBuilder) (map[int][]byte, error) { + // Apply the alive filter to get chunks that were alive at this revision + builder = e.aliveFilter(builder) + + sql, args, err := builder.ToSql() + if err != nil { + return nil, err + } + + // Get a transaction to execute the query + tx, txCleanup, err := e.txSource(ctx) + if err != nil { + return nil, err + } + defer common.LogOnError(ctx, txCleanup) + + // Execute the query + rows, err := tx.QueryContext(ctx, sql, args...) 
+ if err != nil { + return nil, err + } + defer rows.Close() + + result := make(map[int][]byte) + for rows.Next() { + var chunkIndex int + var chunkData []byte + if err := rows.Scan(&chunkIndex, &chunkData); err != nil { + return nil, err + } + result[chunkIndex] = chunkData + } + + return result, rows.Err() +} + +// mysqlTransactionAwareExecutor wraps an existing sql.Tx to provide transaction-aware chunk writing +type mysqlTransactionAwareExecutor struct { + tx *sql.Tx +} + +func newMySQLTransactionAwareExecutor(tx *sql.Tx) *mysqlTransactionAwareExecutor { + return &mysqlTransactionAwareExecutor{tx: tx} +} + +func (e *mysqlTransactionAwareExecutor) BeginTransaction(ctx context.Context) (common.ChunkedBytesTransaction, error) { + // Return a transaction wrapper that uses the existing transaction + return &mysqlTransactionAwareTransaction{tx: e.tx}, nil +} + +func (e *mysqlTransactionAwareExecutor) ExecuteRead(ctx context.Context, builder sq.SelectBuilder) (map[int][]byte, error) { + return nil, errors.New("read operations not supported on transaction-aware executor") +} + +// mysqlTransactionAwareTransaction implements common.ChunkedBytesTransaction using an existing sql.Tx +// without committing after each operation (unlike mysqlChunkedBytesTransaction) +type mysqlTransactionAwareTransaction struct { + tx *sql.Tx +} + +func (t *mysqlTransactionAwareTransaction) ExecuteWrite(ctx context.Context, builder sq.InsertBuilder) error { + sql, args, err := builder.ToSql() + if err != nil { + return err + } + + _, err = t.tx.ExecContext(ctx, sql, args...) + return err +} + +func (t *mysqlTransactionAwareTransaction) ExecuteDelete(ctx context.Context, builder sq.DeleteBuilder) error { + sql, args, err := builder.ToSql() + if err != nil { + return err + } + + _, err = t.tx.ExecContext(ctx, sql, args...) 
+ return err +} + +func (t *mysqlTransactionAwareTransaction) ExecuteUpdate(ctx context.Context, builder sq.UpdateBuilder) error { + sql, args, err := builder.ToSql() + if err != nil { + return err + } + + _, err = t.tx.ExecContext(ctx, sql, args...) + return err +} diff --git a/internal/datastore/mysql/watch.go b/internal/datastore/mysql/watch.go index 553e3ccd9..9d8e30b39 100644 --- a/internal/datastore/mysql/watch.go +++ b/internal/datastore/mysql/watch.go @@ -129,7 +129,7 @@ func (mds *mysqlDatastore) loadChanges( afterRevision uint64, options datastore.WatchOptions, ) (changes []datastore.RevisionChanges, newRevision uint64, err error) { - newRevision, err = mds.loadRevision(ctx) + newRevision, _, err = mds.loadRevision(ctx) if err != nil { return changes, newRevision, err } diff --git a/internal/datastore/postgres/caveat.go b/internal/datastore/postgres/caveat.go index 535632096..b23d2ca2b 100644 --- a/internal/datastore/postgres/caveat.go +++ b/internal/datastore/postgres/caveat.go @@ -144,12 +144,17 @@ func (rwt *pgReadWriteTXN) LegacyWriteCaveats(ctx context.Context, caveats []*co if _, err := rwt.tx.Exec(ctx, sql, args...); err != nil { return fmt.Errorf(errWriteCaveats, err) } + return nil } func (rwt *pgReadWriteTXN) LegacyDeleteCaveats(ctx context.Context, names []string) error { // mark current caveats as deleted - return rwt.deleteCaveatsFromNames(ctx, names) + if err := rwt.deleteCaveatsFromNames(ctx, names); err != nil { + return err + } + + return nil } func (rwt *pgReadWriteTXN) deleteCaveatsFromNames(ctx context.Context, names []string) error { diff --git a/internal/datastore/postgres/gc.go b/internal/datastore/postgres/gc.go index 8f0f50d7b..e55b942c2 100644 --- a/internal/datastore/postgres/gc.go +++ b/internal/datastore/postgres/gc.go @@ -196,6 +196,32 @@ func (pgg *pgGarbageCollector) deleteBeforeTx(ctx context.Context, conn exec, tx return removed, fmt.Errorf("failed to GC namespaces table: %w", err) } + // Delete any schema rows with 
deleted_xid < minTxAlive. + _, err = pgg.batchDelete( + ctx, + conn, + schema.TableSchema, + gcPKCols, + sq.Lt{schema.ColDeletedXid: minTxAlive}, + nil, + ) + if err != nil { + return removed, fmt.Errorf("failed to GC schema table: %w", err) + } + + // Delete any schema_revision rows with deleted_xid < minTxAlive. + _, err = pgg.batchDelete( + ctx, + conn, + schema.TableSchemaRevision, + gcPKCols, + sq.Lt{schema.ColDeletedXid: minTxAlive}, + nil, + ) + if err != nil { + return removed, fmt.Errorf("failed to GC schema_revision table: %w", err) + } + return removed, err } diff --git a/internal/datastore/postgres/migrations/zz_migration.0024_add_schema_tables.go b/internal/datastore/postgres/migrations/zz_migration.0024_add_schema_tables.go new file mode 100644 index 000000000..17ab14883 --- /dev/null +++ b/internal/datastore/postgres/migrations/zz_migration.0024_add_schema_tables.go @@ -0,0 +1,41 @@ +package migrations + +import ( + "context" + + "github.com/jackc/pgx/v5" +) + +var schemaTablesStatements = []string{ + `CREATE TABLE schema ( + name VARCHAR NOT NULL, + chunk_index INT NOT NULL, + chunk_data BYTEA NOT NULL, + created_xid xid8 NOT NULL DEFAULT (pg_current_xact_id()), + deleted_xid xid8 NOT NULL DEFAULT ('9223372036854775807'), + CONSTRAINT pk_schema PRIMARY KEY (name, chunk_index, created_xid));`, + `CREATE INDEX ix_schema_gc ON schema (deleted_xid DESC) WHERE deleted_xid < '9223372036854775807'::xid8;`, + `CREATE TABLE schema_revision ( + name VARCHAR NOT NULL DEFAULT 'current', + hash BYTEA NOT NULL, + created_xid xid8 NOT NULL DEFAULT (pg_current_xact_id()), + deleted_xid xid8 NOT NULL DEFAULT ('9223372036854775807'), + CONSTRAINT pk_schema_revision PRIMARY KEY (name, created_xid));`, + `CREATE INDEX ix_schema_revision_gc ON schema_revision (deleted_xid DESC) WHERE deleted_xid < '9223372036854775807'::xid8;`, +} + +func init() { + if err := DatabaseMigrations.Register("add-schema-tables", "add-index-for-transaction-gc", + noNonatomicMigration, + 
func(ctx context.Context, tx pgx.Tx) error { + for _, stmt := range schemaTablesStatements { + if _, err := tx.Exec(ctx, stmt); err != nil { + return err + } + } + + return nil + }); err != nil { + panic("failed to register migration: " + err.Error()) + } +} diff --git a/internal/datastore/postgres/migrations/zz_migration.0025_populate_schema_tables.go b/internal/datastore/postgres/migrations/zz_migration.0025_populate_schema_tables.go new file mode 100644 index 000000000..632e3a03a --- /dev/null +++ b/internal/datastore/postgres/migrations/zz_migration.0025_populate_schema_tables.go @@ -0,0 +1,175 @@ +package migrations + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "sort" + + "github.com/jackc/pgx/v5" + + core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/schemadsl/compiler" + "github.com/authzed/spicedb/pkg/schemadsl/generator" +) + +const ( + schemaChunkSize = 1024 * 1024 // 1MB chunks + currentSchemaVersion = 1 + unifiedSchemaName = "unified_schema" + schemaRevisionName = "current" +) + +func init() { + if err := DatabaseMigrations.Register("populate-schema-tables", "add-schema-tables", + noNonatomicMigration, + func(ctx context.Context, tx pgx.Tx) error { + // Read all existing namespaces + rows, err := tx.Query(ctx, ` + SELECT DISTINCT ON (namespace) + namespace, serialized_config + FROM namespace_config + WHERE deleted_xid = '9223372036854775807'::xid8 + ORDER BY namespace, created_xid DESC + `) + if err != nil { + return fmt.Errorf("failed to query namespaces: %w", err) + } + defer rows.Close() + + namespaces := make(map[string]*core.NamespaceDefinition) + for rows.Next() { + var name string + var config []byte + if err := rows.Scan(&name, &config); err != nil { + return fmt.Errorf("failed to scan namespace: %w", err) + } + + var ns core.NamespaceDefinition + if err := ns.UnmarshalVT(config); err != nil { + return fmt.Errorf("failed to unmarshal namespace %s: %w", name, err) + } + namespaces[name] = 
&ns + } + if err := rows.Err(); err != nil { + return fmt.Errorf("error iterating namespaces: %w", err) + } + + // Read all existing caveats + rows, err = tx.Query(ctx, ` + SELECT DISTINCT ON (name) + name, definition + FROM caveat + WHERE deleted_xid = '9223372036854775807'::xid8 + ORDER BY name, created_xid DESC + `) + if err != nil { + return fmt.Errorf("failed to query caveats: %w", err) + } + defer rows.Close() + + caveats := make(map[string]*core.CaveatDefinition) + for rows.Next() { + var name string + var definition []byte + if err := rows.Scan(&name, &definition); err != nil { + return fmt.Errorf("failed to scan caveat: %w", err) + } + + var caveat core.CaveatDefinition + if err := caveat.UnmarshalVT(definition); err != nil { + return fmt.Errorf("failed to unmarshal caveat %s: %w", name, err) + } + caveats[name] = &caveat + } + if err := rows.Err(); err != nil { + return fmt.Errorf("error iterating caveats: %w", err) + } + + // If there are no namespaces or caveats, skip migration + if len(namespaces) == 0 && len(caveats) == 0 { + return nil + } + + // Generate canonical schema for hash computation + allDefs := make([]compiler.SchemaDefinition, 0, len(namespaces)+len(caveats)) + for _, ns := range namespaces { + allDefs = append(allDefs, ns) + } + for _, caveat := range caveats { + allDefs = append(allDefs, caveat) + } + + // Sort alphabetically for canonical ordering + sort.Slice(allDefs, func(i, j int) bool { + return allDefs[i].GetName() < allDefs[j].GetName() + }) + + // Generate canonical schema text + canonicalSchemaText, _, err := generator.GenerateSchema(allDefs) + if err != nil { + return fmt.Errorf("failed to generate canonical schema: %w", err) + } + + // Compute SHA256 hash + hashBytes := sha256.Sum256([]byte(canonicalSchemaText)) + schemaHash := hashBytes[:] + + // Generate user-facing schema text (with original ordering) + schemaText, _, err := generator.GenerateSchema(allDefs) + if err != nil { + return fmt.Errorf("failed to generate schema: 
%w", err) + } + + // Create stored schema proto + storedSchema := &core.StoredSchema{ + Version: currentSchemaVersion, + VersionOneof: &core.StoredSchema_V1{ + V1: &core.StoredSchema_V1StoredSchema{ + SchemaText: schemaText, + SchemaHash: hex.EncodeToString(schemaHash), + NamespaceDefinitions: namespaces, + CaveatDefinitions: caveats, + }, + }, + } + + // Marshal schema + schemaData, err := storedSchema.MarshalVT() + if err != nil { + return fmt.Errorf("failed to marshal schema: %w", err) + } + + // Insert schema chunks + for chunkIndex := 0; chunkIndex*schemaChunkSize < len(schemaData); chunkIndex++ { + start := chunkIndex * schemaChunkSize + end := start + schemaChunkSize + if end > len(schemaData) { + end = len(schemaData) + } + chunk := schemaData[start:end] + + _, err = tx.Exec(ctx, ` + INSERT INTO schema (name, chunk_index, chunk_data) + VALUES ($1, $2, $3) + `, unifiedSchemaName, chunkIndex, chunk) + if err != nil { + return fmt.Errorf("failed to insert schema chunk %d: %w", chunkIndex, err) + } + } + + // Insert schema hash + _, err = tx.Exec(ctx, ` + INSERT INTO schema_revision (name, hash) + VALUES ($1, $2) + `, schemaRevisionName, schemaHash) + if err != nil { + return fmt.Errorf("failed to insert schema hash: %w", err) + } + + return nil + }); err != nil { + panic("failed to register migration: " + err.Error()) + } +} diff --git a/internal/datastore/postgres/options.go b/internal/datastore/postgres/options.go index 1e2e1a812..f1a956b98 100644 --- a/internal/datastore/postgres/options.go +++ b/internal/datastore/postgres/options.go @@ -7,6 +7,7 @@ import ( "github.com/authzed/spicedb/internal/datastore/common" pgxcommon "github.com/authzed/spicedb/internal/datastore/postgres/common" log "github.com/authzed/spicedb/internal/logging" + dsoptions "github.com/authzed/spicedb/pkg/datastore/options" ) type postgresOptions struct { @@ -40,6 +41,9 @@ type postgresOptions struct { migrationPhase string allowedMigrations []string + schemaMode dsoptions.SchemaMode 
+ schemaCacheOptions dsoptions.SchemaCacheOptions + logger *tracingLogger queryInterceptor pgxcommon.QueryInterceptor @@ -445,3 +449,13 @@ func WithWatchDisabled(isDisabled bool) Option { func WithRelaxedIsolationLevel(isEnabled bool) Option { return func(po *postgresOptions) { po.relaxedIsolationLevel = isEnabled } } + +// WithSchemaMode sets the experimental schema mode for the datastore. +func WithSchemaMode(mode dsoptions.SchemaMode) Option { + return func(po *postgresOptions) { po.schemaMode = mode } +} + +// WithSchemaCacheOptions sets the schema cache options for the datastore. +func WithSchemaCacheOptions(cacheOptions dsoptions.SchemaCacheOptions) Option { + return func(po *postgresOptions) { po.schemaCacheOptions = cacheOptions } +} diff --git a/internal/datastore/postgres/postgres.go b/internal/datastore/postgres/postgres.go index 982a45d45..a99bad6ca 100644 --- a/internal/datastore/postgres/postgres.go +++ b/internal/datastore/postgres/postgres.go @@ -34,7 +34,7 @@ import ( log "github.com/authzed/spicedb/internal/logging" "github.com/authzed/spicedb/internal/sharederrors" "github.com/authzed/spicedb/pkg/datastore" - "github.com/authzed/spicedb/pkg/datastore/options" + dsoptions "github.com/authzed/spicedb/pkg/datastore/options" "github.com/authzed/spicedb/pkg/spiceerrors" ) @@ -72,6 +72,12 @@ const ( var livingTupleConstraints = []string{"uq_relation_tuple_living_xid", "pk_relation_tuple"} +type contextKey string + +const ( + ctxKeyTransactionID contextKey = "postgres_transaction_id" +) + func init() { dbsql.Register(tracingDriverName, sqlmw.Driver(stdlib.GetDefaultDriver(), new(traceInterceptor))) } @@ -306,34 +312,43 @@ func newPostgresDatastore( CachedOptimizedRevisions: revisions.NewCachedOptimizedRevisions( maxRevisionStaleness, ), - MigrationValidator: common.NewMigrationValidator(headMigration, config.allowedMigrations), - dburl: pgURL, - readPool: pgxcommon.MustNewInterceptorPooler(readPool, config.queryInterceptor), - writePool: nil, /* 
disabled by default */ - collectors: collectors, - watchBufferLength: config.watchBufferLength, - watchChangeBufferMaximumSize: config.watchChangeBufferMaximumSize, - watchBufferWriteTimeout: config.watchBufferWriteTimeout, - optimizedRevisionQuery: revisionQuery, - validTransactionQuery: validTransactionQuery, - revisionHeartbeatQuery: revisionHeartbeatQuery, - gcWindow: config.gcWindow, - gcInterval: config.gcInterval, - gcTimeout: config.gcMaxOperationTime, - analyzeBeforeStatistics: config.analyzeBeforeStatistics, - watchEnabled: watchEnabled, - workerCtx: gcCtx, - cancelGc: cancelGc, - readTxOptions: pgx.TxOptions{IsoLevel: pgx.RepeatableRead, AccessMode: pgx.ReadOnly}, - maxRetries: config.maxRetries, - credentialsProvider: credentialsProvider, - isPrimary: isPrimary, - inStrictReadMode: config.readStrictMode, - filterMaximumIDCount: config.filterMaximumIDCount, - schema: *schema.Schema(config.columnOptimizationOption, false), - quantizationPeriodNanos: quantizationPeriodNanos, - isolationLevel: isolationLevel, + MigrationValidator: common.NewMigrationValidator(headMigration, config.allowedMigrations), + dburl: pgURL, + readPool: pgxcommon.MustNewInterceptorPooler(readPool, config.queryInterceptor), + writePool: nil, /* disabled by default */ + collectors: collectors, + watchBufferLength: config.watchBufferLength, + watchChangeBufferMaximumSize: config.watchChangeBufferMaximumSize, + watchBufferWriteTimeout: config.watchBufferWriteTimeout, + optimizedRevisionQuery: revisionQuery, + validTransactionQuery: validTransactionQuery, + revisionHeartbeatQuery: revisionHeartbeatQuery, + gcWindow: config.gcWindow, + gcInterval: config.gcInterval, + gcTimeout: config.gcMaxOperationTime, + analyzeBeforeStatistics: config.analyzeBeforeStatistics, + watchEnabled: watchEnabled, + workerCtx: gcCtx, + cancelGc: cancelGc, + readTxOptions: pgx.TxOptions{IsoLevel: pgx.RepeatableRead, AccessMode: pgx.ReadOnly}, + maxRetries: config.maxRetries, + credentialsProvider: 
credentialsProvider, + isPrimary: isPrimary, + inStrictReadMode: config.readStrictMode, + filterMaximumIDCount: config.filterMaximumIDCount, + schema: *schema.Schema(config.columnOptimizationOption, false), + includeQueryParametersInTraces: config.includeQueryParametersInTraces, + schemaMode: config.schemaMode, + quantizationPeriodNanos: quantizationPeriodNanos, + isolationLevel: isolationLevel, + } + + // Create schema reader/writer + schemaReaderWriter, err := common.NewSQLSchemaReaderWriter[uint64, postgresRevision](BaseSchemaChunkerConfig, config.schemaCacheOptions) + if err != nil { + return nil, fmt.Errorf("failed to create schema reader/writer: %w", err) } + datastore.schemaReaderWriter = schemaReaderWriter if isPrimary && config.readStrictMode { return nil, spiceerrors.MustBugf("strict read mode is not supported on primary instances") @@ -369,6 +384,11 @@ func newPostgresDatastore( } } + // Warm the schema cache on startup + if err := warmSchemaCache(initializationContext, datastore); err != nil { + log.Warn().Err(err).Msg("failed to warm schema cache on startup") + } + return datastore, nil } @@ -396,13 +416,17 @@ type pgDatastore struct { inStrictReadMode bool schema common.SchemaInformation includeQueryParametersInTraces bool + schemaMode dsoptions.SchemaMode credentialsProvider datastore.CredentialsProvider uniqueID atomic.Pointer[string] - workerGroup *errgroup.Group - workerCtx context.Context - cancelGc context.CancelFunc + workerGroup *errgroup.Group + workerCtx context.Context + cancelGc context.CancelFunc + + // SQLSchemaReaderWriter for schema operations + schemaReaderWriter *common.SQLSchemaReaderWriter[uint64, postgresRevision] gcHasRun atomic.Bool filterMaximumIDCount uint16 quantizationPeriodNanos int64 @@ -417,7 +441,7 @@ func (pgd *pgDatastore) IsStrictReadModeEnabled() bool { return pgd.inStrictReadMode } -func (pgd *pgDatastore) SnapshotReader(revRaw datastore.Revision) datastore.Reader { +func (pgd *pgDatastore) SnapshotReader(revRaw 
datastore.Revision, schemaHash datastore.SchemaHash) datastore.Reader { rev := revRaw.(postgresRevision) queryFuncs := pgxcommon.QuerierFuncsFor(pgd.readPool) @@ -430,11 +454,15 @@ func (pgd *pgDatastore) SnapshotReader(revRaw datastore.Revision) datastore.Read } return &pgReader{ - queryFuncs, - executor, - buildLivingObjectFilterForRevision(rev), - pgd.filterMaximumIDCount, - pgd.schema, + query: queryFuncs, + executor: executor, + aliveFilter: buildLivingObjectFilterForRevision(rev), + filterMaximumIDCount: pgd.filterMaximumIDCount, + schema: pgd.schema, + schemaMode: pgd.schemaMode, + snapshotRevision: rev, + schemaReaderWriter: pgd.schemaReaderWriter, + schemaHash: string(schemaHash), } } @@ -443,13 +471,13 @@ func (pgd *pgDatastore) SnapshotReader(revRaw datastore.Revision) datastore.Read func (pgd *pgDatastore) ReadWriteTx( ctx context.Context, fn datastore.TxUserFunc, - opts ...options.RWTOptionsOption, + opts ...dsoptions.RWTOptionsOption, ) (datastore.Revision, error) { if !pgd.isPrimary { return datastore.NoRevision, spiceerrors.MustBugf("read-write transaction not supported on read-only datastore") } - config := options.NewRWTOptionsWithOptions(opts...) + config := dsoptions.NewRWTOptionsWithOptions(opts...) 
var err error for i := uint8(0); i <= pgd.maxRetries; i++ { @@ -474,18 +502,24 @@ func (pgd *pgDatastore) ReadWriteTx( } rwt := &pgReadWriteTXN{ - &pgReader{ - queryFuncs, - executor, - currentlyLivingObjects, - pgd.filterMaximumIDCount, - pgd.schema, + pgReader: &pgReader{ + query: queryFuncs, + executor: executor, + aliveFilter: currentlyLivingObjects, + filterMaximumIDCount: pgd.filterMaximumIDCount, + schema: pgd.schema, + schemaMode: pgd.schemaMode, + snapshotRevision: datastore.NoRevision, // snapshotRevision (not yet known in RWT) + schemaReaderWriter: pgd.schemaReaderWriter, + schemaHash: string(datastore.NoSchemaHashInTransaction), // Transaction reads bypass cache }, - tx, - newXID, + tx: tx, + newXID: newXID, } - return fn(ctx, rwt) + // Add transaction ID to context for schema operations + ctxWithTxn := context.WithValue(ctx, ctxKeyTransactionID, newXID.Uint64) + return fn(ctxWithTxn, rwt) })) if err != nil { if !config.DisableRetries && errorRetryable(err) { @@ -657,6 +691,50 @@ func (pgd *pgDatastore) Close() error { return nil } +// SchemaHashReaderForTesting returns a test-only interface for reading the schema hash directly from schema_revision table. +func (pgd *pgDatastore) SchemaHashReaderForTesting() interface { + ReadSchemaHash(ctx context.Context) (string, error) +} { + queryFuncs := pgxcommon.QuerierFuncsFor(pgd.readPool) + return &pgSchemaHashReaderForTesting{query: queryFuncs} +} + +// SchemaModeForTesting returns the current schema mode for testing purposes. +func (pgd *pgDatastore) SchemaModeForTesting() (dsoptions.SchemaMode, error) { + return pgd.schemaMode, nil +} + +type pgSchemaHashReaderForTesting struct { + query pgxcommon.DBFuncQuerier +} + +func (r *pgSchemaHashReaderForTesting) ReadSchemaHash(ctx context.Context) (string, error) { + sql, args, err := psql.Select("hash"). + From("schema_revision"). + Where(sq.Eq{ + "name": "current", + "deleted_xid": liveDeletedTxnID, + }). 
+ ToSql() + if err != nil { + return "", fmt.Errorf("failed to build query: %w", err) + } + + var hashBytes []byte + + err = r.query.QueryRowFunc(ctx, func(ctx context.Context, row pgx.Row) error { + return row.Scan(&hashBytes) + }, sql, args...) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return "", datastore.ErrSchemaNotFound + } + return "", fmt.Errorf("failed to query schema hash: %w", err) + } + + return string(hashBytes), nil +} + func errorRetryable(err error) bool { if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { return false @@ -806,6 +884,39 @@ func currentlyLivingObjects(original sq.SelectBuilder) sq.SelectBuilder { var _ datastore.Datastore = &pgDatastore{} +// warmSchemaCache attempts to warm the schema cache by loading the current schema. +// This is called during datastore initialization to avoid cold-start latency on first requests. +func warmSchemaCache(ctx context.Context, ds *pgDatastore) error { + // Get the current revision and schema hash + rev, schemaHash, err := ds.HeadRevision(ctx) + if err != nil { + return fmt.Errorf("failed to get head revision: %w", err) + } + + // If there's no schema hash, there's no schema to warm + if schemaHash == "" { + log.Ctx(ctx).Debug().Msg("no schema hash found, skipping cache warming") + return nil + } + + // Create a simple executor for schema reading (no transaction, no revision filtering needed for warmup) + executor := newPostgresChunkedBytesExecutor(ds.readPool.(*pgxpool.Pool)) + + // Load the schema to populate the cache + _, err = ds.schemaReaderWriter.ReadSchema(ctx, executor, rev, schemaHash) + if err != nil { + if errors.Is(err, datastore.ErrSchemaNotFound) { + // Schema not found is not an error during warming - just means no schema yet + log.Ctx(ctx).Debug().Msg("no schema found, skipping cache warming") + return nil + } + return fmt.Errorf("failed to read schema: %w", err) + } + + log.Ctx(ctx).Info().Str("schema_hash", 
string(schemaHash)).Msg("schema cache warmed successfully") + return nil +} + func registerAndReturnPrometheusCollectors(replicaIndex int, isPrimary bool, readPool, writePool *pgxpool.Pool, enablePrometheusStats bool) ([]prometheus.Collector, error) { collectors := []prometheus.Collector{} if !enablePrometheusStats { @@ -823,9 +934,14 @@ func registerAndReturnPrometheusCollectors(replicaIndex int, isPrimary bool, rea "pool_usage": "read", }) if err := prometheus.Register(readCollector); err != nil { - return collectors, err + // Ignore AlreadyRegisteredError which can happen in tests + var alreadyRegistered prometheus.AlreadyRegisteredError + if !errors.As(err, &alreadyRegistered) { + return collectors, err + } + } else { + collectors = append(collectors, readCollector) } - collectors = append(collectors, readCollector) if isPrimary { writeCollector := pgxpoolprometheus.NewCollector(writePool, map[string]string{ @@ -834,9 +950,14 @@ func registerAndReturnPrometheusCollectors(replicaIndex int, isPrimary bool, rea }) if err := prometheus.Register(writeCollector); err != nil { - return collectors, nil + // Ignore AlreadyRegisteredError which can happen in tests + var alreadyRegistered prometheus.AlreadyRegisteredError + if !errors.As(err, &alreadyRegistered) { + return collectors, err + } + } else { + collectors = append(collectors, writeCollector) } - collectors = append(collectors, writeCollector) gcCollectors, err := common.RegisterGCMetrics() if err != nil { diff --git a/internal/datastore/postgres/postgres_shared_test.go b/internal/datastore/postgres/postgres_shared_test.go index f090f6fa5..091f56556 100644 --- a/internal/datastore/postgres/postgres_shared_test.go +++ b/internal/datastore/postgres/postgres_shared_test.go @@ -6,6 +6,7 @@ import ( "context" "fmt" "math/rand" + "os" "strings" "sync" "testing" @@ -44,6 +45,15 @@ const ( veryLargeGCInterval = 90000 * time.Second ) +func postgresTestVersion() string { + ver := os.Getenv("POSTGRES_TEST_VERSION") + if 
ver != "" { + return ver + } + + return pgversion.LatestTestedPostgresVersion +} + // Implement the interface for testing datastores func (pgd *pgDatastore) ExampleRetryableError() error { return &pgconn.PgError{ @@ -622,6 +632,128 @@ func GarbageCollectionTest(t *testing.T, ds datastore.Datastore) { require.Zero(removed.Namespaces) } +func SchemaGarbageCollectionTest(t *testing.T, ds datastore.Datastore) { + require := require.New(t) + + ctx := context.Background() + r, err := ds.ReadyState(ctx) + require.NoError(err) + require.True(r.IsReady) + + pds := ds.(*pgDatastore) + + pgg, err := pds.BuildGarbageCollector(ctx) + require.NoError(err) + defer pgg.Close() + + // Helper to count rows in schema tables + countSchemaRows := func() (schemaRows, schemaRevisionRows int64) { + sql := "SELECT COUNT(*) FROM schema" + err := pds.readPool.QueryRow(ctx, sql).Scan(&schemaRows) + require.NoError(err) + + sql = "SELECT COUNT(*) FROM schema_revision" + err = pds.readPool.QueryRow(ctx, sql).Scan(&schemaRevisionRows) + require.NoError(err) + + return schemaRows, schemaRevisionRows + } + + // Write schema version 1 + schemaText1 := `definition resource { + relation reader: user + } + definition user {}` + rev1, err := ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error { + schemaWriter, err := rwt.SchemaWriter() + if err != nil { + return err + } + return schemaWriter.WriteSchema(ctx, []datastore.SchemaDefinition{}, schemaText1, nil) + }) + require.NoError(err) + + schemaRows, schemaRevisionRows := countSchemaRows() + require.Positive(schemaRows, "Should have schema rows after first write") + require.Positive(schemaRevisionRows, "Should have schema_revision rows after first write") + + // Write schema version 2 + schemaText2 := `definition resource { + relation reader: user + relation writer: user + } + definition user {}` + rev2, err := ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error { + schemaWriter, 
err := rwt.SchemaWriter() + if err != nil { + return err + } + return schemaWriter.WriteSchema(ctx, []datastore.SchemaDefinition{}, schemaText2, nil) + }) + require.NoError(err) + + schemaRows2, schemaRevisionRows2 := countSchemaRows() + require.Greater(schemaRows2, schemaRows, "Should have more schema rows after second write") + require.Greater(schemaRevisionRows2, schemaRevisionRows, "Should have more schema_revision rows after second write") + + // Write schema version 3 + schemaText3 := `definition resource { + relation reader: user + relation writer: user + relation admin: user + } + definition user {}` + rev3, err := ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error { + schemaWriter, err := rwt.SchemaWriter() + if err != nil { + return err + } + return schemaWriter.WriteSchema(ctx, []datastore.SchemaDefinition{}, schemaText3, nil) + }) + require.NoError(err) + + schemaRows3, schemaRevisionRows3 := countSchemaRows() + require.Greater(schemaRows3, schemaRows2, "Should have more schema rows after third write") + require.Greater(schemaRevisionRows3, schemaRevisionRows2, "Should have more schema_revision rows after third write") + + // Run GC at rev1 - should not remove any schema rows since they're all still in use + removed, err := pgg.DeleteBeforeTx(ctx, rev1) + require.NoError(err) + require.Zero(removed.Relationships) + + schemaRowsAfterGC1, schemaRevisionRowsAfterGC1 := countSchemaRows() + require.Equal(schemaRows3, schemaRowsAfterGC1, "No schema rows should be removed at rev1") + require.Equal(schemaRevisionRows3, schemaRevisionRowsAfterGC1, "No schema_revision rows should be removed at rev1") + + // Run GC at rev2 - should remove schema rows from rev1 + _, err = pgg.DeleteBeforeTx(ctx, rev2) + require.NoError(err) + + schemaRowsAfterGC2, schemaRevisionRowsAfterGC2 := countSchemaRows() + require.Less(schemaRowsAfterGC2, schemaRows3, "Schema rows from rev1 should be removed") + require.Less(schemaRevisionRowsAfterGC2, 
schemaRevisionRows3, "Schema_revision rows from rev1 should be removed") + + // Run GC at rev3 - should remove schema rows from rev2 + _, err = pgg.DeleteBeforeTx(ctx, rev3) + require.NoError(err) + + schemaRowsAfterGC3, schemaRevisionRowsAfterGC3 := countSchemaRows() + require.Less(schemaRowsAfterGC3, schemaRowsAfterGC2, "Schema rows from rev2 should be removed") + require.Less(schemaRevisionRowsAfterGC3, schemaRevisionRowsAfterGC2, "Schema_revision rows from rev2 should be removed") + + // Verify we can still read the latest schema + headRev, _, err := ds.HeadRevision(ctx) + require.NoError(err) + + reader := ds.SnapshotReader(headRev, datastore.NoSchemaHashForTesting) + schemaReader, err := reader.SchemaReader() + require.NoError(err) + schemaText, err := schemaReader.SchemaText() + require.NoError(err) + require.NotEmpty(schemaText, "Schema text should not be empty") + require.Contains(schemaText, "relation admin", "Schema should contain the admin relation") +} + func TransactionTimestampsTest(t *testing.T, ds datastore.Datastore) { require := require.New(t) @@ -1044,7 +1176,7 @@ func assertRevisionLowerAndHigher(ctx context.Context, t *testing.T, ds datastor var snapshot pgSnapshot pgDS, ok := ds.(*pgDatastore) require.True(t, ok) - rev, _, err := pgDS.optimizedRevisionFunc(ctx) + rev, _, _, err := pgDS.optimizedRevisionFunc(ctx) require.NoError(t, err) pgRev, ok := rev.(postgresRevision) @@ -1123,10 +1255,10 @@ func ConcurrentRevisionHeadTest(t *testing.T, ds datastore.Datastore) { require.False(commitFirstRev.Equal(commitLastRev)) // Ensure a call to HeadRevision now reflects both sets of data applied. 
- headRev, err := ds.HeadRevision(ctx) + headRev, _, err := ds.HeadRevision(ctx) require.NoError(err) - reader := ds.SnapshotReader(headRev) + reader := ds.SnapshotReader(headRev, datastore.NoSchemaHashForTesting) it, err := reader.QueryRelationships(ctx, datastore.RelationshipsFilter{ OptionalResourceType: "resource", }) @@ -1260,7 +1392,7 @@ func OverlappingRevisionWatchTest(t *testing.T, ds datastore.Datastore) { require.NoError(err) require.True(r.IsReady) - rev, err := ds.HeadRevision(ctx) + rev, _, err := ds.HeadRevision(ctx) require.NoError(err) pds := ds.(*pgDatastore) @@ -1495,7 +1627,7 @@ func BenchmarkPostgresQuery(b *testing.B) { require := require.New(b) for i := 0; i < b.N; i++ { - iter, err := ds.SnapshotReader(revision).QueryRelationships(context.Background(), datastore.RelationshipsFilter{ + iter, err := ds.SnapshotReader(revision, datastore.NoSchemaHashForTesting).QueryRelationships(context.Background(), datastore.RelationshipsFilter{ OptionalResourceType: testfixtures.DocumentNS.Name, }, options.WithQueryShape(queryshape.FindResourceOfType)) require.NoError(err) @@ -1676,7 +1808,7 @@ func GCQueriesServedByExpectedIndexes(t *testing.T, _ testdatastore.RunningEngin // Get the head revision. ctx := context.Background() - revision, err := ds.HeadRevision(ctx) + revision, _, err := ds.HeadRevision(ctx) require.NoError(err) casted := datastore.UnwrapAs[*pgDatastore](ds) @@ -1700,6 +1832,9 @@ func GCQueriesServedByExpectedIndexes(t *testing.T, _ testdatastore.RunningEngin case strings.HasPrefix(explanation, "Delete on namespace_config"): fallthrough + case strings.HasPrefix(explanation, "Delete on schema"): + fallthrough + case strings.HasPrefix(explanation, "Delete on relation_tuple"): require.Contains(explanation, "Index Scan") @@ -1802,7 +1937,7 @@ func StrictReadModeFallbackTest(t *testing.T, primaryDS datastore.Datastore, unw require.NoError(err) // Get the HEAD revision. 
- lowestRevision, err := primaryDS.HeadRevision(ctx) + lowestRevision, _, err := primaryDS.HeadRevision(ctx) require.NoError(err) // Wrap the replica DS. @@ -1810,7 +1945,7 @@ func StrictReadModeFallbackTest(t *testing.T, primaryDS datastore.Datastore, unw require.NoError(err) // Perform a read at the head revision, which should succeed. - reader := replicaDS.SnapshotReader(lowestRevision) + reader := replicaDS.SnapshotReader(lowestRevision, datastore.NoSchemaHashForTesting) it, err := reader.QueryRelationships(ctx, datastore.RelationshipsFilter{ OptionalResourceType: "resource", }) @@ -1832,7 +1967,7 @@ func StrictReadModeFallbackTest(t *testing.T, primaryDS datastore.Datastore, unw } limit := uint64(50) - it, err = replicaDS.SnapshotReader(badRev).QueryRelationships(ctx, datastore.RelationshipsFilter{ + it, err = replicaDS.SnapshotReader(badRev, datastore.NoSchemaHashForTesting).QueryRelationships(ctx, datastore.RelationshipsFilter{ OptionalResourceType: "resource", }, options.WithLimit(&limit)) require.NoError(err) @@ -1856,11 +1991,11 @@ func StrictReadModeTest(t *testing.T, primaryDS datastore.Datastore, replicaDS d require.NoError(err) // Get the HEAD revision. - lowestRevision, err := primaryDS.HeadRevision(ctx) + lowestRevision, _, err := primaryDS.HeadRevision(ctx) require.NoError(err) // Perform a read at the head revision, which should succeed. 
- reader := replicaDS.SnapshotReader(lowestRevision) + reader := replicaDS.SnapshotReader(lowestRevision, datastore.NoSchemaHashForTesting) it, err := reader.QueryRelationships(ctx, datastore.RelationshipsFilter{ OptionalResourceType: "resource", }) @@ -1881,7 +2016,7 @@ func StrictReadModeTest(t *testing.T, primaryDS datastore.Datastore, replicaDS d }, } - it, err = replicaDS.SnapshotReader(badRev).QueryRelationships(ctx, datastore.RelationshipsFilter{ + it, err = replicaDS.SnapshotReader(badRev, datastore.NoSchemaHashForTesting).QueryRelationships(ctx, datastore.RelationshipsFilter{ OptionalResourceType: "resource", }) require.NoError(err) @@ -1899,7 +2034,7 @@ func NullCaveatWatchTest(t *testing.T, ds datastore.Datastore) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - lowestRevision, err := ds.HeadRevision(ctx) + lowestRevision, _, err := ds.HeadRevision(ctx) require.NoError(err) // Run the watch API. @@ -1970,7 +2105,7 @@ func RevisionTimestampAndTransactionIDTest(t *testing.T, ds datastore.Datastore) ctx, cancel := context.WithCancel(context.Background()) defer cancel() - lowestRevision, err := ds.HeadRevision(ctx) + lowestRevision, _, err := ds.HeadRevision(ctx) require.NoError(err) // Run the watch API. @@ -2034,7 +2169,7 @@ func ContinuousCheckpointTest(t *testing.T, ds datastore.Datastore) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - lowestRevision, err := ds.HeadRevision(ctx) + lowestRevision, _, err := ds.HeadRevision(ctx) require.NoError(err) // Run the watch API. 
@@ -2092,9 +2227,9 @@ func ExceedInsertQuerySizeTest(t *testing.T, ds datastore.Datastore) { require.Error(err) require.ErrorContains(err, "exceeds the maximum size supported by this datastore") - headRev, err := ds.HeadRevision(ctx) + headRev, _, err := ds.HeadRevision(ctx) require.NoError(err) - iter, err := ds.SnapshotReader(headRev).QueryRelationships(context.Background(), datastore.RelationshipsFilter{ + iter, err := ds.SnapshotReader(headRev, datastore.NoSchemaHashForTesting).QueryRelationships(context.Background(), datastore.RelationshipsFilter{ OptionalResourceType: "resource", }) require.NoError(err) @@ -2105,4 +2240,32 @@ func ExceedInsertQuerySizeTest(t *testing.T, ds datastore.Datastore) { require.Equal(0, count, "expected to have 0 relationships, but found %d", count) } +func TestPostgresDatastoreUnifiedSchemaAllModes(t *testing.T) { + t.Parallel() + b := testdatastore.RunPostgresForTesting(t, "", "head", postgresTestVersion(), false) + + test.UnifiedSchemaAllModesTest(t, func(schemaMode options.SchemaMode) test.DatastoreTester { + return test.DatastoreTesterFunc(func(revisionQuantization, gcInterval, gcWindow time.Duration, watchBufferLength uint16) (datastore.Datastore, error) { + ctx := context.Background() + ds := b.NewDatastore(t, func(engine, uri string) datastore.Datastore { + ds, err := NewPostgresDatastore( + ctx, + uri, + GCWindow(gcWindow), + RevisionQuantization(revisionQuantization), + WatchBufferLength(watchBufferLength), + WithSchemaMode(schemaMode), + ) + require.NoError(t, err) + t.Cleanup(func() { + _ = ds.Close() + }) + return ds + }) + + return ds, nil + }) + }) +} + const waitForChangesTimeout = 10 * time.Second diff --git a/internal/datastore/postgres/postgres_test.go b/internal/datastore/postgres/postgres_test.go index 78c23b161..51bfed609 100644 --- a/internal/datastore/postgres/postgres_test.go +++ b/internal/datastore/postgres/postgres_test.go @@ -4,7 +4,6 @@ package postgres import ( "fmt" - "os" "testing" "time" @@ -13,19 
+12,10 @@ import ( "github.com/stretchr/testify/require" "github.com/authzed/spicedb/internal/datastore/postgres/common" - "github.com/authzed/spicedb/internal/datastore/postgres/version" testdatastore "github.com/authzed/spicedb/internal/testserver/datastore" + "github.com/authzed/spicedb/pkg/datastore/options" ) -func postgresTestVersion() string { - ver := os.Getenv("POSTGRES_TEST_VERSION") - if ver != "" { - return ver - } - - return version.LatestTestedPostgresVersion -} - var postgresConfig = postgresTestConfig{"head", "", postgresTestVersion(), false} func TestPostgresDatastore(t *testing.T) { @@ -81,6 +71,17 @@ func TestPostgresDatastoreGC(t *testing.T) { MigrationPhase(config.migrationPhase), WithRevisionHeartbeat(false), )) + + t.Run("SchemaGarbageCollection", createDatastoreTest( + b, + SchemaGarbageCollectionTest, + RevisionQuantization(0), + GCWindow(1*time.Millisecond), + GCInterval(veryLargeGCInterval), + WatchBufferLength(1), + MigrationPhase(config.migrationPhase), + WithSchemaMode(options.SchemaModeReadNewWriteNew), + )) }) } diff --git a/internal/datastore/postgres/reader.go b/internal/datastore/postgres/reader.go index ce40b06c8..c25e455e8 100644 --- a/internal/datastore/postgres/reader.go +++ b/internal/datastore/postgres/reader.go @@ -11,9 +11,9 @@ import ( "github.com/authzed/spicedb/internal/datastore/common" pgxcommon "github.com/authzed/spicedb/internal/datastore/postgres/common" "github.com/authzed/spicedb/internal/datastore/postgres/schema" - schemautil "github.com/authzed/spicedb/internal/datastore/schema" + schemaadapter "github.com/authzed/spicedb/internal/datastore/schema" "github.com/authzed/spicedb/pkg/datastore" - "github.com/authzed/spicedb/pkg/datastore/options" + dsoptions "github.com/authzed/spicedb/pkg/datastore/options" core "github.com/authzed/spicedb/pkg/proto/core/v1" ) @@ -23,6 +23,10 @@ type pgReader struct { aliveFilter queryFilterer filterMaximumIDCount uint16 schema common.SchemaInformation + schemaMode 
dsoptions.SchemaMode + snapshotRevision datastore.Revision + schemaReaderWriter *common.SQLSchemaReaderWriter[uint64, postgresRevision] + schemaHash string } type queryFilterer func(original sq.SelectBuilder) sq.SelectBuilder @@ -149,7 +153,7 @@ func (r *pgReader) lookupCounters(ctx context.Context, optionalName string) ([]d func (r *pgReader) QueryRelationships( ctx context.Context, filter datastore.RelationshipsFilter, - opts ...options.QueryOptionsOption, + opts ...dsoptions.QueryOptionsOption, ) (iter datastore.RelationshipIterator, err error) { qBuilder, err := common.NewSchemaQueryFiltererForRelationshipsSelect(r.schema, r.filterMaximumIDCount). WithAdditionalFilter(r.aliveFilter). @@ -158,7 +162,7 @@ func (r *pgReader) QueryRelationships( return nil, err } - builtOpts := options.NewQueryOptionsWithOptions(opts...) + builtOpts := dsoptions.NewQueryOptionsWithOptions(opts...) indexingHint := schema.IndexingHintForQueryShape(r.schema, builtOpts.QueryShape) qBuilder = qBuilder.WithIndexingHint(indexingHint) @@ -168,7 +172,7 @@ func (r *pgReader) QueryRelationships( func (r *pgReader) ReverseQueryRelationships( ctx context.Context, subjectsFilter datastore.SubjectsFilter, - opts ...options.ReverseQueryOptionsOption, + opts ...dsoptions.ReverseQueryOptionsOption, ) (iter datastore.RelationshipIterator, err error) { qBuilder, err := common.NewSchemaQueryFiltererForRelationshipsSelect(r.schema, r.filterMaximumIDCount). WithAdditionalFilter(r.aliveFilter). @@ -177,7 +181,7 @@ func (r *pgReader) ReverseQueryRelationships( return nil, err } - queryOpts := options.NewReverseQueryOptionsWithOptions(opts...) + queryOpts := dsoptions.NewReverseQueryOptionsWithOptions(opts...) if queryOpts.ResRelation != nil { qBuilder = qBuilder. 
@@ -190,13 +194,13 @@ func (r *pgReader) ReverseQueryRelationships( return r.executor.ExecuteQuery(ctx, qBuilder, - options.WithLimit(queryOpts.LimitForReverse), - options.WithAfter(queryOpts.AfterForReverse), - options.WithSort(queryOpts.SortForReverse), - options.WithSkipCaveats(queryOpts.SkipCaveatsForReverse), - options.WithSkipExpiration(queryOpts.SkipExpirationForReverse), - options.WithQueryShape(queryOpts.QueryShapeForReverse), - options.WithSQLExplainCallbackForTest(queryOpts.SQLExplainCallbackForTestForReverse), + dsoptions.WithLimit(queryOpts.LimitForReverse), + dsoptions.WithAfter(queryOpts.AfterForReverse), + dsoptions.WithSort(queryOpts.SortForReverse), + dsoptions.WithSkipCaveats(queryOpts.SkipCaveatsForReverse), + dsoptions.WithSkipExpiration(queryOpts.SkipExpirationForReverse), + dsoptions.WithQueryShape(queryOpts.QueryShapeForReverse), + dsoptions.WithSQLExplainCallbackForTest(queryOpts.SQLExplainCallbackForTestForReverse), ) } @@ -307,7 +311,63 @@ func revisionForVersion(version xid8) postgresRevision { // SchemaReader returns a SchemaReader for reading schema information. func (r *pgReader) SchemaReader() (datastore.SchemaReader, error) { - return schemautil.NewLegacySchemaReaderAdapter(r), nil + // Wrap the reader with an unexported schema reader + reader := &pgSchemaReader{r: r} + return schemaadapter.NewSchemaReader(reader, r.schemaMode, r.snapshotRevision), nil } -var _ datastore.Reader = &pgReader{} +// pgSchemaReader wraps a pgReader and implements DualSchemaReader. +// This prevents direct access to schema read methods from the reader. 
+type pgSchemaReader struct { + r *pgReader +} + +// ReadStoredSchema implements datastore.SingleStoreSchemaReader with revision-aware reading +func (sr *pgSchemaReader) ReadStoredSchema(ctx context.Context) (*core.StoredSchema, error) { + // Create a revision-aware executor that applies alive filter for schema table + // The schema table uses XID8 columns (created_xid, deleted_xid) just like relationship tables, + // so we use the exact same pg_visible_in_snapshot() logic + executor := &pgRevisionAwareExecutor{ + query: sr.r.query, + aliveFilter: sr.r.aliveFilter, + } + + // Use the shared schema reader/writer to read the schema with hash-based caching + return sr.r.schemaReaderWriter.ReadSchema(ctx, executor, sr.r.snapshotRevision, datastore.SchemaHash(sr.r.schemaHash)) +} + +// LegacyLookupNamespacesWithNames delegates to the underlying reader +func (sr *pgSchemaReader) LegacyLookupNamespacesWithNames(ctx context.Context, nsNames []string) ([]datastore.RevisionedDefinition[*core.NamespaceDefinition], error) { + return sr.r.LegacyLookupNamespacesWithNames(ctx, nsNames) +} + +// LegacyReadCaveatByName delegates to the underlying reader +func (sr *pgSchemaReader) LegacyReadCaveatByName(ctx context.Context, name string) (*core.CaveatDefinition, datastore.Revision, error) { + return sr.r.LegacyReadCaveatByName(ctx, name) +} + +// LegacyListAllCaveats delegates to the underlying reader +func (sr *pgSchemaReader) LegacyListAllCaveats(ctx context.Context) ([]datastore.RevisionedCaveat, error) { + return sr.r.LegacyListAllCaveats(ctx) +} + +// LegacyLookupCaveatsWithNames delegates to the underlying reader +func (sr *pgSchemaReader) LegacyLookupCaveatsWithNames(ctx context.Context, names []string) ([]datastore.RevisionedCaveat, error) { + return sr.r.LegacyLookupCaveatsWithNames(ctx, names) +} + +// LegacyReadNamespaceByName delegates to the underlying reader +func (sr *pgSchemaReader) LegacyReadNamespaceByName(ctx context.Context, nsName string) 
(*core.NamespaceDefinition, datastore.Revision, error) { + return sr.r.LegacyReadNamespaceByName(ctx, nsName) +} + +// LegacyListAllNamespaces delegates to the underlying reader +func (sr *pgSchemaReader) LegacyListAllNamespaces(ctx context.Context) ([]datastore.RevisionedNamespace, error) { + return sr.r.LegacyListAllNamespaces(ctx) +} + +var ( + _ datastore.Reader = &pgReader{} + _ datastore.LegacySchemaReader = &pgReader{} + _ datastore.DualSchemaReader = &pgSchemaReader{} +) diff --git a/internal/datastore/postgres/readwrite.go b/internal/datastore/postgres/readwrite.go index e327e2d3f..ea1f4b996 100644 --- a/internal/datastore/postgres/readwrite.go +++ b/internal/datastore/postgres/readwrite.go @@ -3,8 +3,11 @@ package postgres import ( "cmp" "context" + "crypto/sha256" + "encoding/hex" "errors" "fmt" + "sort" sq "github.com/Masterminds/squirrel" "github.com/ccoveille/go-safecast/v2" @@ -15,12 +18,14 @@ import ( "github.com/authzed/spicedb/internal/datastore/common" pgxcommon "github.com/authzed/spicedb/internal/datastore/postgres/common" "github.com/authzed/spicedb/internal/datastore/postgres/schema" - schemautil "github.com/authzed/spicedb/internal/datastore/schema" + schemaadapter "github.com/authzed/spicedb/internal/datastore/schema" "github.com/authzed/spicedb/pkg/datastore" "github.com/authzed/spicedb/pkg/datastore/options" "github.com/authzed/spicedb/pkg/genutil/mapz" core "github.com/authzed/spicedb/pkg/proto/core/v1" typedschema "github.com/authzed/spicedb/pkg/schema" + "github.com/authzed/spicedb/pkg/schemadsl/compiler" + "github.com/authzed/spicedb/pkg/schemadsl/generator" "github.com/authzed/spicedb/pkg/spiceerrors" "github.com/authzed/spicedb/pkg/tuple" ) @@ -666,7 +671,202 @@ func (rwt *pgReadWriteTXN) LegacyDeleteNamespaces(ctx context.Context, nsNames [ } func (rwt *pgReadWriteTXN) SchemaWriter() (datastore.SchemaWriter, error) { - return schemautil.NewLegacySchemaWriterAdapter(rwt, rwt), nil + // Wrap the transaction with an unexported schema 
writer + writer := &pgSchemaWriter{rwt: rwt} + return schemaadapter.NewSchemaWriter(writer, writer, rwt.schemaMode), nil +} + +// pgSchemaWriter wraps a pgReadWriteTXN and implements DualSchemaWriter. +// This prevents direct access to schema write methods from the transaction. +type pgSchemaWriter struct { + rwt *pgReadWriteTXN +} + +// WriteStoredSchema implements datastore.SingleStoreSchemaWriter by writing within the current transaction +func (w *pgSchemaWriter) WriteStoredSchema(ctx context.Context, schema *core.StoredSchema) error { + // Create a transaction-aware executor that uses the current transaction + executor := newPGTransactionAwareExecutor(w.rwt.tx) + + // Use the shared schema reader/writer to write the schema with the newXID as transaction ID + if err := w.rwt.schemaReaderWriter.WriteSchema(ctx, schema, executor, func(ctx context.Context) uint64 { + return w.rwt.newXID.Uint64 + }); err != nil { + return err + } + + // Write the schema hash to the schema_revision table for fast lookups + if err := w.writeSchemaHash(ctx, schema); err != nil { + return fmt.Errorf("failed to write schema hash: %w", err) + } + + return nil +} + +// writeSchemaHash writes the schema hash to the schema_revision table +func (w *pgSchemaWriter) writeSchemaHash(ctx context.Context, schema *core.StoredSchema) error { + v1 := schema.GetV1() + if v1 == nil { + return fmt.Errorf("unsupported schema version: %d", schema.Version) + } + + // Mark existing hash rows as deleted + sql, args, err := psql.Update("schema_revision"). + Set("deleted_xid", w.rwt.newXID.Uint64). + Where(sq.Eq{ + "name": "current", + "deleted_xid": liveDeletedTxnID, + }). 
+ ToSql() + if err != nil { + return fmt.Errorf("failed to build delete query: %w", err) + } + + if _, err := w.rwt.tx.Exec(ctx, sql, args...); err != nil { + return fmt.Errorf("failed to delete old hash: %w", err) + } + + // Insert new hash row (ON CONFLICT DO NOTHING handles WriteBoth mode) + sql, args, err = psql.Insert("schema_revision"). + Columns("name", "hash", "created_xid", "deleted_xid"). + Values("current", []byte(v1.SchemaHash), w.rwt.newXID.Uint64, liveDeletedTxnID). + Suffix("ON CONFLICT (name, created_xid) DO NOTHING"). + ToSql() + if err != nil { + return fmt.Errorf("failed to build insert query: %w", err) + } + + if _, err := w.rwt.tx.Exec(ctx, sql, args...); err != nil { + return fmt.Errorf("failed to insert hash: %w", err) + } + + return nil +} + +// ReadStoredSchema implements datastore.SingleStoreSchemaReader to satisfy DualSchemaReader interface requirements +func (w *pgSchemaWriter) ReadStoredSchema(ctx context.Context) (*core.StoredSchema, error) { + // Create a revision-aware executor that applies alive filter for schema table + executor := &pgRevisionAwareExecutor{ + query: w.rwt.query, + aliveFilter: w.rwt.aliveFilter, + } + + // Use the shared schema reader/writer to read the schema + // Pass empty string for transaction reads to bypass cache reads (but still load) + return w.rwt.schemaReaderWriter.ReadSchema(ctx, executor, nil, datastore.NoSchemaHashInTransaction) +} + +// LegacyWriteNamespaces delegates to the underlying transaction +func (w *pgSchemaWriter) LegacyWriteNamespaces(ctx context.Context, newConfigs ...*core.NamespaceDefinition) error { + return w.rwt.LegacyWriteNamespaces(ctx, newConfigs...) 
+} + +// LegacyDeleteNamespaces delegates to the underlying transaction +func (w *pgSchemaWriter) LegacyDeleteNamespaces(ctx context.Context, nsNames []string, delOption datastore.DeleteNamespacesRelationshipsOption) error { + return w.rwt.LegacyDeleteNamespaces(ctx, nsNames, delOption) +} + +// LegacyLookupNamespacesWithNames delegates to the underlying transaction +func (w *pgSchemaWriter) LegacyLookupNamespacesWithNames(ctx context.Context, nsNames []string) ([]datastore.RevisionedDefinition[*core.NamespaceDefinition], error) { + return w.rwt.LegacyLookupNamespacesWithNames(ctx, nsNames) +} + +// LegacyWriteCaveats delegates to the underlying transaction +func (w *pgSchemaWriter) LegacyWriteCaveats(ctx context.Context, caveats []*core.CaveatDefinition) error { + return w.rwt.LegacyWriteCaveats(ctx, caveats) +} + +// LegacyDeleteCaveats delegates to the underlying transaction +func (w *pgSchemaWriter) LegacyDeleteCaveats(ctx context.Context, names []string) error { + return w.rwt.LegacyDeleteCaveats(ctx, names) +} + +// LegacyReadCaveatByName delegates to the underlying transaction +func (w *pgSchemaWriter) LegacyReadCaveatByName(ctx context.Context, name string) (*core.CaveatDefinition, datastore.Revision, error) { + return w.rwt.LegacyReadCaveatByName(ctx, name) +} + +// LegacyListAllCaveats delegates to the underlying transaction +func (w *pgSchemaWriter) LegacyListAllCaveats(ctx context.Context) ([]datastore.RevisionedCaveat, error) { + return w.rwt.LegacyListAllCaveats(ctx) +} + +// LegacyLookupCaveatsWithNames delegates to the underlying transaction +func (w *pgSchemaWriter) LegacyLookupCaveatsWithNames(ctx context.Context, names []string) ([]datastore.RevisionedCaveat, error) { + return w.rwt.LegacyLookupCaveatsWithNames(ctx, names) +} + +// LegacyReadNamespaceByName delegates to the underlying transaction +func (w *pgSchemaWriter) LegacyReadNamespaceByName(ctx context.Context, nsName string) (*core.NamespaceDefinition, datastore.Revision, error) { + 
return w.rwt.LegacyReadNamespaceByName(ctx, nsName) +} + +// LegacyListAllNamespaces delegates to the underlying transaction +func (w *pgSchemaWriter) LegacyListAllNamespaces(ctx context.Context) ([]datastore.RevisionedNamespace, error) { + return w.rwt.LegacyListAllNamespaces(ctx) +} + +// WriteLegacySchemaHashFromDefinitions implements datastore.LegacySchemaHashWriter +func (w *pgSchemaWriter) WriteLegacySchemaHashFromDefinitions(ctx context.Context, namespaces []datastore.RevisionedNamespace, caveats []datastore.RevisionedCaveat) error { + return w.rwt.writeLegacySchemaHashFromDefinitions(ctx, namespaces, caveats) +} + +// writeLegacySchemaHashFromDefinitions writes the schema hash computed from the given definitions +func (rwt *pgReadWriteTXN) writeLegacySchemaHashFromDefinitions(ctx context.Context, namespaces []datastore.RevisionedNamespace, caveats []datastore.RevisionedCaveat) error { + // Build schema definitions list + definitions := make([]compiler.SchemaDefinition, 0, len(namespaces)+len(caveats)) + for _, ns := range namespaces { + definitions = append(definitions, ns.Definition) + } + for _, caveat := range caveats { + definitions = append(definitions, caveat.Definition) + } + + // Sort definitions by name for consistent ordering + sort.Slice(definitions, func(i, j int) bool { + return definitions[i].GetName() < definitions[j].GetName() + }) + + // Generate schema text from definitions + schemaText, _, err := generator.GenerateSchema(definitions) + if err != nil { + return fmt.Errorf("failed to generate schema: %w", err) + } + + // Compute schema hash (SHA256) + hash := sha256.Sum256([]byte(schemaText)) + schemaHash := hex.EncodeToString(hash[:]) + + // Mark existing hash rows as deleted + sql, args, err := psql.Update("schema_revision"). + Set("deleted_xid", rwt.newXID.Uint64). + Where(sq.Eq{ + "name": "current", + "deleted_xid": liveDeletedTxnID, + }). 
+ ToSql() + if err != nil { + return fmt.Errorf("failed to build delete query: %w", err) + } + + if _, err := rwt.tx.Exec(ctx, sql, args...); err != nil { + return fmt.Errorf("failed to delete old hash: %w", err) + } + + // Insert new hash row (ON CONFLICT DO NOTHING handles WriteBoth mode) + sql, args, err = psql.Insert("schema_revision"). + Columns("name", "hash", "created_xid", "deleted_xid"). + Values("current", []byte(schemaHash), rwt.newXID.Uint64, liveDeletedTxnID). + Suffix("ON CONFLICT (name, created_xid) DO NOTHING"). + ToSql() + if err != nil { + return fmt.Errorf("failed to build insert query: %w", err) + } + + if _, err := rwt.tx.Exec(ctx, sql, args...); err != nil { + return fmt.Errorf("failed to insert hash: %w", err) + } + + return nil } func (rwt *pgReadWriteTXN) RegisterCounter(ctx context.Context, name string, filter *core.RelationshipFilter) error { @@ -820,4 +1020,9 @@ func exactRelationshipDifferentCaveatAndExpirationClause(r tuple.Relationship) s } } -var _ datastore.ReadWriteTransaction = &pgReadWriteTXN{} +var ( + _ datastore.ReadWriteTransaction = &pgReadWriteTXN{} + _ datastore.LegacySchemaWriter = &pgReadWriteTXN{} + _ datastore.DualSchemaWriter = &pgSchemaWriter{} + _ datastore.DualSchemaReader = &pgSchemaWriter{} +) diff --git a/internal/datastore/postgres/revisions.go b/internal/datastore/postgres/revisions.go index 18a41f644..8ac7b4eb3 100644 --- a/internal/datastore/postgres/revisions.go +++ b/internal/datastore/postgres/revisions.go @@ -101,50 +101,84 @@ const ( LIMIT 1;` queryCurrentSnapshot = `SELECT pg_current_snapshot();` + // queryCurrentSnapshotWithHash gets current snapshot along with the latest schema_hash + queryCurrentSnapshotWithHash = ` + WITH current_xid AS ( + SELECT pg_current_xact_id() as xid, pg_current_snapshot() as snapshot + ) + SELECT + current_xid.snapshot, + COALESCE((SELECT hash FROM schema_revision WHERE created_xid <= current_xid.xid AND deleted_xid > current_xid.xid ORDER BY created_xid DESC LIMIT 1), 
''::bytea) + FROM current_xid;` + queryCurrentTransactionID = `SELECT pg_current_xact_id()::text::integer;` queryLatestXID = `SELECT max(xid)::text::integer FROM relation_tuple_transaction;` ) -func (pgd *pgDatastore) optimizedRevisionFunc(ctx context.Context) (datastore.Revision, time.Duration, error) { +func (pgd *pgDatastore) optimizedRevisionFunc(ctx context.Context) (datastore.Revision, time.Duration, datastore.SchemaHash, error) { var revision xid8 var snapshot pgSnapshot var validForNanos time.Duration - if err := pgd.readPool.QueryRow(ctx, pgd.optimizedRevisionQuery). - Scan(&revision, &snapshot, &validForNanos); err != nil { - return datastore.NoRevision, 0, fmt.Errorf(errRevision, err) + var schemaHash []byte + + // Build query that also fetches schema hash + modifiedQuery := pgd.buildOptimizedRevisionQueryWithHash() + + if err := pgd.readPool.QueryRow(ctx, modifiedQuery). + Scan(&revision, &snapshot, &validForNanos, &schemaHash); err != nil { + return datastore.NoRevision, 0, "", fmt.Errorf(errRevision, err) } snapshot = snapshot.markComplete(revision.Uint64) - return postgresRevision{snapshot: snapshot, optionalTxID: revision}, validForNanos, nil + return postgresRevision{snapshot: snapshot, optionalTxID: revision}, validForNanos, datastore.SchemaHash(schemaHash), nil } -func (pgd *pgDatastore) HeadRevision(ctx context.Context) (datastore.Revision, error) { +func (pgd *pgDatastore) HeadRevision(ctx context.Context) (datastore.Revision, datastore.SchemaHash, error) { ctx, span := tracer.Start(ctx, "HeadRevision") defer span.End() - result, err := pgd.getHeadRevision(ctx, pgd.readPool) + result, hash, err := pgd.getHeadRevisionWithHash(ctx, pgd.readPool) if err != nil { - return nil, err + return nil, "", err } if result == nil { - return datastore.NoRevision, nil + return datastore.NoRevision, "", nil } - return *result, nil + return *result, hash, nil } -func (pgd *pgDatastore) getHeadRevision(ctx context.Context, querier common.Querier) 
(*postgresRevision, error) { +func (pgd *pgDatastore) getHeadRevisionWithHash(ctx context.Context, querier common.Querier) (*postgresRevision, datastore.SchemaHash, error) { var snapshot pgSnapshot - if err := querier.QueryRow(ctx, queryCurrentSnapshot).Scan(&snapshot); err != nil { + var schemaHash []byte + + if err := querier.QueryRow(ctx, queryCurrentSnapshotWithHash).Scan(&snapshot, &schemaHash); err != nil { if errors.Is(err, pgx.ErrNoRows) { - return nil, nil + return nil, "", nil } - return nil, fmt.Errorf(errRevision, err) + return nil, "", fmt.Errorf(errRevision, err) } - return &postgresRevision{snapshot: snapshot}, nil + return &postgresRevision{snapshot: snapshot}, datastore.SchemaHash(schemaHash), nil +} + +// buildOptimizedRevisionQueryWithHash creates a modified version of the optimized revision query +// that also fetches the schema_hash +func (pgd *pgDatastore) buildOptimizedRevisionQueryWithHash() string { + // The base query structure is: + // WITH selected AS (SELECT (...) 
as xid) + // SELECT selected.xid, COALESCE(...), validity_calc FROM selected; + // + // We need to add the schema_hash as a 4th column + baseQuery := pgd.optimizedRevisionQuery + + // Find the "FROM selected;" at the end and insert the schema_hash fetch before it + return baseQuery[:len(baseQuery)-len("FROM selected;")] + + `, + COALESCE((SELECT hash FROM schema_revision WHERE created_xid <= selected.xid AND deleted_xid > selected.xid ORDER BY created_xid DESC LIMIT 1), ''::bytea) + FROM selected;` } func (pgd *pgDatastore) CheckRevision(ctx context.Context, revisionRaw datastore.Revision) error { @@ -337,6 +371,10 @@ func (pr postgresRevision) Equal(rhsRaw datastore.Revision) bool { return ok && pr.snapshot.Equal(rhs.snapshot) } +func (pr postgresRevision) Key() string { + return pr.String() +} + func (pr postgresRevision) GreaterThan(rhsRaw datastore.Revision) bool { if rhsRaw == datastore.NoRevision { return true diff --git a/internal/datastore/postgres/schema/schema.go b/internal/datastore/postgres/schema/schema.go index 674ab37d0..94bdedce1 100644 --- a/internal/datastore/postgres/schema/schema.go +++ b/internal/datastore/postgres/schema/schema.go @@ -12,6 +12,8 @@ const ( TableTuple = "relation_tuple" TableCaveat = "caveat" TableRelationshipCounter = "relationship_counter" + TableSchema = "schema" + TableSchemaRevision = "schema_revision" ColXID = "xid" ColTimestamp = "timestamp" diff --git a/internal/datastore/postgres/schema_chunker.go b/internal/datastore/postgres/schema_chunker.go new file mode 100644 index 000000000..b73a8939e --- /dev/null +++ b/internal/datastore/postgres/schema_chunker.go @@ -0,0 +1,177 @@ +package postgres + +import ( + "context" + "errors" + + sq "github.com/Masterminds/squirrel" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/authzed/spicedb/internal/datastore/common" + pgxcommon "github.com/authzed/spicedb/internal/datastore/postgres/common" +) + +const ( + // PostgreSQL has no practical limit on 
BYTEA column size (up to 1GB per cell), + // but we use 1MB chunks for reasonable memory usage and query performance. + postgresMaxChunkSize = 1024 * 1024 // 1MB +) + +// BaseSchemaChunkerConfig provides the base configuration for Postgres schema chunking. +// Postgres uses tombstone-based write mode with XID8 columns (matching relationship tables). +var BaseSchemaChunkerConfig = common.SQLByteChunkerConfig[uint64]{ + TableName: "schema", + NameColumn: "name", + ChunkIndexColumn: "chunk_index", + ChunkDataColumn: "chunk_data", + MaxChunkSize: postgresMaxChunkSize, + PlaceholderFormat: sq.Dollar, + WriteMode: common.WriteModeInsertWithTombstones, + CreatedAtColumn: "created_xid", + DeletedAtColumn: "deleted_xid", + AliveValue: liveDeletedTxnID, +} + +// postgresChunkedBytesExecutor implements common.ChunkedBytesExecutor for PostgreSQL. +type postgresChunkedBytesExecutor struct { + pool *pgxpool.Pool +} + +func newPostgresChunkedBytesExecutor(pool *pgxpool.Pool) *postgresChunkedBytesExecutor { + return &postgresChunkedBytesExecutor{pool: pool} +} + +func (e *postgresChunkedBytesExecutor) BeginTransaction(ctx context.Context) (common.ChunkedBytesTransaction, error) { + tx, err := e.pool.Begin(ctx) + if err != nil { + return nil, err + } + return &postgresChunkedBytesTransaction{tx: tx}, nil +} + +func (e *postgresChunkedBytesExecutor) ExecuteRead(ctx context.Context, builder sq.SelectBuilder) (map[int][]byte, error) { + sql, args, err := builder.ToSql() + if err != nil { + return nil, err + } + + rows, err := e.pool.Query(ctx, sql, args...) 
+ if err != nil { + return nil, err + } + defer rows.Close() + + result := make(map[int][]byte) + for rows.Next() { + var chunkIndex int + var chunkData []byte + if err := rows.Scan(&chunkIndex, &chunkData); err != nil { + return nil, err + } + result[chunkIndex] = chunkData + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return result, nil +} + +// postgresChunkedBytesTransaction implements common.ChunkedBytesTransaction for PostgreSQL. +type postgresChunkedBytesTransaction struct { + tx pgx.Tx +} + +func (t *postgresChunkedBytesTransaction) ExecuteWrite(ctx context.Context, builder sq.InsertBuilder) error { + sql, args, err := builder.ToSql() + if err != nil { + return err + } + + _, err = t.tx.Exec(ctx, sql, args...) + return err +} + +func (t *postgresChunkedBytesTransaction) ExecuteDelete(ctx context.Context, builder sq.DeleteBuilder) error { + sql, args, err := builder.ToSql() + if err != nil { + return err + } + + _, err = t.tx.Exec(ctx, sql, args...) + return err +} + +func (t *postgresChunkedBytesTransaction) ExecuteUpdate(ctx context.Context, builder sq.UpdateBuilder) error { + sql, args, err := builder.ToSql() + if err != nil { + return err + } + + _, err = t.tx.Exec(ctx, sql, args...) + return err +} + +// GetSchemaChunker returns a SQLByteChunker for the schema table. +// This is exported for testing purposes. 
+func (pgd *pgDatastore) GetSchemaChunker() *common.SQLByteChunker[uint64] { + executor := newPostgresChunkedBytesExecutor(pgd.readPool.(*pgxpool.Pool)) + return common.MustNewSQLByteChunker(BaseSchemaChunkerConfig.WithExecutor(executor)) +} + +// pgRevisionAwareExecutor wraps the reader's query infrastructure to provide revision-aware chunk reading +type pgRevisionAwareExecutor struct { + query pgxcommon.DBFuncQuerier + aliveFilter func(sq.SelectBuilder) sq.SelectBuilder +} + +func (e *pgRevisionAwareExecutor) BeginTransaction(ctx context.Context) (common.ChunkedBytesTransaction, error) { + // We don't support transactions for reading + return nil, errors.New("transactions not supported for revision-aware reads") +} + +func (e *pgRevisionAwareExecutor) ExecuteRead(ctx context.Context, builder sq.SelectBuilder) (map[int][]byte, error) { + // Apply the alive filter to get chunks that were alive at this revision + builder = e.aliveFilter(builder) + + sql, args, err := builder.ToSql() + if err != nil { + return nil, err + } + + // Execute using the reader's query function + result := make(map[int][]byte) + err = e.query.QueryFunc(ctx, func(ctx context.Context, rows pgx.Rows) error { + for rows.Next() { + var chunkIndex int + var chunkData []byte + if err := rows.Scan(&chunkIndex, &chunkData); err != nil { + return err + } + result[chunkIndex] = chunkData + } + return rows.Err() + }, sql, args...) 
+ + return result, err +} + +// pgTransactionAwareExecutor wraps an existing pgx.Tx to provide transaction-aware chunk writing +type pgTransactionAwareExecutor struct { + tx pgx.Tx +} + +func newPGTransactionAwareExecutor(tx pgx.Tx) *pgTransactionAwareExecutor { + return &pgTransactionAwareExecutor{tx: tx} +} + +func (e *pgTransactionAwareExecutor) BeginTransaction(ctx context.Context) (common.ChunkedBytesTransaction, error) { + // Return a transaction wrapper that uses the existing transaction + return &postgresChunkedBytesTransaction{tx: e.tx}, nil +} + +func (e *pgTransactionAwareExecutor) ExecuteRead(ctx context.Context, builder sq.SelectBuilder) (map[int][]byte, error) { + return nil, errors.New("read operations not supported on transaction-aware executor") +} diff --git a/internal/datastore/proxy/checkingreplicated.go b/internal/datastore/proxy/checkingreplicated.go index bb086bd53..8b797a5bb 100644 --- a/internal/datastore/proxy/checkingreplicated.go +++ b/internal/datastore/proxy/checkingreplicated.go @@ -94,7 +94,7 @@ type checkingReplicatedDatastore struct { // SnapshotReader creates a read-only handle that reads the datastore at the specified revision. // Any errors establishing the reader will be returned by subsequent calls. 
-func (rd *checkingReplicatedDatastore) SnapshotReader(revision datastore.Revision) datastore.Reader { +func (rd *checkingReplicatedDatastore) SnapshotReader(revision datastore.Revision, schemaHash datastore.SchemaHash) datastore.Reader { replica := selectReplica(rd.replicas, &rd.lastReplica) replicaID, err := replica.MetricsID() if err != nil { @@ -103,9 +103,10 @@ func (rd *checkingReplicatedDatastore) SnapshotReader(revision datastore.Revisio } readReplicatedSelectedReplicaCount.WithLabelValues(replicaID).Inc() return &checkingStableReader{ - rev: revision, - replica: replica, - primary: rd.Datastore, + rev: revision, + schemaHash: schemaHash, + replica: replica, + primary: rd.Datastore, } } @@ -113,9 +114,10 @@ func (rd *checkingReplicatedDatastore) SnapshotReader(revision datastore.Revisio // reading from it. If the replica does not have the requested revision, the primary will be used // instead. Only supported for a stable replica within each pool. type checkingStableReader struct { - rev datastore.Revision - replica datastore.ReadOnlyDatastore - primary datastore.Datastore + rev datastore.Revision + schemaHash datastore.SchemaHash + replica datastore.ReadOnlyDatastore + primary datastore.Datastore // chosePrimaryForTest is used for testing to determine if the primary was used for the read. 
chosePrimaryForTest bool @@ -222,6 +224,18 @@ func (rr *checkingStableReader) SchemaReader() (datastore.SchemaReader, error) { return rr.chosenReader.SchemaReader() } +func (rr *checkingStableReader) ReadStoredSchema(ctx context.Context) (*core.StoredSchema, error) { + if err := rr.determineSource(ctx); err != nil { + return nil, err + } + + singleStoreReader, ok := rr.chosenReader.(datastore.SingleStoreSchemaReader) + if !ok { + return nil, errors.New("chosen reader does not implement SingleStoreSchemaReader") + } + return singleStoreReader.ReadStoredSchema(ctx) +} + // determineSource will choose the replica or primary to read from based on the revision, by checking // if the replica contains the revision. If the replica does not contain the revision, the primary // will be used instead. @@ -236,7 +250,7 @@ func (rr *checkingStableReader) determineSource(ctx context.Context) error { if errors.As(err, &irr) { if irr.Reason() == datastore.CouldNotDetermineRevision { log.Trace().Str("revision", rr.rev.String()).Err(err).Msg("replica does not contain the requested revision, using primary") - rr.chosenReader = rr.primary.SnapshotReader(rr.rev) + rr.chosenReader = rr.primary.SnapshotReader(rr.rev, rr.schemaHash) rr.chosePrimaryForTest = true return } @@ -253,9 +267,17 @@ func (rr *checkingStableReader) determineSource(ctx context.Context) error { } checkingReplicatedReplicaReaderCount.WithLabelValues(metricsID).Inc() - rr.chosenReader = rr.replica.SnapshotReader(rr.rev) + rr.chosenReader = rr.replica.SnapshotReader(rr.rev, rr.schemaHash) rr.chosePrimaryForTest = false }) return finalError } + +var ( + _ datastore.Datastore = (*checkingReplicatedDatastore)(nil) + _ datastore.Reader = (*checkingStableReader)(nil) + _ datastore.LegacySchemaReader = (*checkingStableReader)(nil) + _ datastore.SingleStoreSchemaReader = (*checkingStableReader)(nil) + _ datastore.DualSchemaReader = (*checkingStableReader)(nil) +) diff --git a/internal/datastore/proxy/checkingreplicated_test.go 
b/internal/datastore/proxy/checkingreplicated_test.go index b950ad085..2d56a8441 100644 --- a/internal/datastore/proxy/checkingreplicated_test.go +++ b/internal/datastore/proxy/checkingreplicated_test.go @@ -25,7 +25,7 @@ func TestCheckingReplicatedReaderFallsbackToPrimaryOnCheckRevisionFailure(t *tes require.NoError(t, err) // Try at revision 1, which should use the replica. - reader := replicated.SnapshotReader(revisionparsing.MustParseRevisionForTest("1")) + reader := replicated.SnapshotReader(revisionparsing.MustParseRevisionForTest("1"), datastore.NoSchemaHashForTesting) ns, err := reader.LegacyListAllNamespaces(t.Context()) require.NoError(t, err) require.Empty(t, ns) @@ -33,7 +33,7 @@ func TestCheckingReplicatedReaderFallsbackToPrimaryOnCheckRevisionFailure(t *tes require.False(t, reader.(*checkingStableReader).chosePrimaryForTest) // Try at revision 2, which should use the primary. - reader = replicated.SnapshotReader(revisionparsing.MustParseRevisionForTest("2")) + reader = replicated.SnapshotReader(revisionparsing.MustParseRevisionForTest("2"), datastore.NoSchemaHashForTesting) ns, err = reader.LegacyListAllNamespaces(t.Context()) require.NoError(t, err) require.Empty(t, ns) @@ -48,7 +48,7 @@ func TestCheckingReplicatedReaderFallsbackToPrimaryOnRevisionNotAvailableError(t replicated, err := NewCheckingReplicatedDatastore(primary, replica) require.NoError(t, err) - reader := replicated.SnapshotReader(revisionparsing.MustParseRevisionForTest("3")) + reader := replicated.SnapshotReader(revisionparsing.MustParseRevisionForTest("3"), datastore.NoSchemaHashForTesting) ns, err := reader.LegacyLookupNamespacesWithNames(t.Context(), []string{"ns1"}) require.NoError(t, err) require.Len(t, ns, 1) @@ -72,7 +72,7 @@ func TestReplicatedReaderReturnsExpectedError(t *testing.T) { } // Try at revision 1, which should use the replica. 
- reader := ds.SnapshotReader(revisionparsing.MustParseRevisionForTest("1")) + reader := ds.SnapshotReader(revisionparsing.MustParseRevisionForTest("1"), datastore.NoSchemaHashForTesting) _, _, err := reader.LegacyReadNamespaceByName(t.Context(), "expecterror") require.Error(t, err) require.ErrorContains(t, err, "raising an expected error") @@ -90,7 +90,7 @@ func (f fakeDatastore) MetricsID() (string, error) { return "fake", nil } -func (f fakeDatastore) SnapshotReader(revision datastore.Revision) datastore.Reader { +func (f fakeDatastore) SnapshotReader(revision datastore.Revision, _ datastore.SchemaHash) datastore.Reader { return fakeSnapshotReader{ revision: revision, state: f.state, @@ -102,12 +102,12 @@ func (f fakeDatastore) ReadWriteTx(_ context.Context, _ datastore.TxUserFunc, _ return nil, nil } -func (f fakeDatastore) OptimizedRevision(_ context.Context) (datastore.Revision, error) { - return nil, nil +func (f fakeDatastore) OptimizedRevision(_ context.Context) (datastore.Revision, datastore.SchemaHash, error) { + return nil, datastore.NoSchemaHashForTesting, nil } -func (f fakeDatastore) HeadRevision(_ context.Context) (datastore.Revision, error) { - return nil, nil +func (f fakeDatastore) HeadRevision(_ context.Context) (datastore.Revision, datastore.SchemaHash, error) { + return nil, datastore.NoSchemaHashForTesting, nil } func (f fakeDatastore) CheckRevision(_ context.Context, rev datastore.Revision) error { diff --git a/internal/datastore/proxy/counting.go b/internal/datastore/proxy/counting.go index 60be581d0..fda9839b3 100644 --- a/internal/datastore/proxy/counting.go +++ b/internal/datastore/proxy/counting.go @@ -2,6 +2,7 @@ package proxy import ( "context" + "errors" "sync/atomic" "github.com/prometheus/client_golang/prometheus" @@ -157,19 +158,19 @@ func (p *countingProxy) UniqueID(ctx context.Context) (string, error) { return p.delegate.UniqueID(ctx) } -func (p *countingProxy) SnapshotReader(rev datastore.Revision) datastore.Reader { - 
delegateReader := p.delegate.SnapshotReader(rev) +func (p *countingProxy) SnapshotReader(rev datastore.Revision, schemaHash datastore.SchemaHash) datastore.Reader { + delegateReader := p.delegate.SnapshotReader(rev, schemaHash) return &countingReader{ delegate: delegateReader, counts: p.counts, } } -func (p *countingProxy) OptimizedRevision(ctx context.Context) (datastore.Revision, error) { +func (p *countingProxy) OptimizedRevision(ctx context.Context) (datastore.Revision, datastore.SchemaHash, error) { return p.delegate.OptimizedRevision(ctx) } -func (p *countingProxy) HeadRevision(ctx context.Context) (datastore.Revision, error) { +func (p *countingProxy) HeadRevision(ctx context.Context) (datastore.Revision, datastore.SchemaHash, error) { return p.delegate.HeadRevision(ctx) } @@ -263,15 +264,38 @@ func (r *countingReader) LookupCounters(ctx context.Context) ([]datastore.Relati return r.delegate.LookupCounters(ctx) } -// SchemaReader returns a wrapped version of the countingReader that exercises -// the legacy methods when the new methods are invoked. +// SchemaReader returns a schema reader that respects the underlying schema mode. +// For new unified schema mode, it passes through directly. For legacy mode, +// it wraps the proxy to ensure counting is maintained. 
func (r *countingReader) SchemaReader() (datastore.SchemaReader, error) { + underlyingSchemaReader, err := r.delegate.SchemaReader() + if err != nil { + return nil, err + } + + // If using new unified schema mode, pass through directly + if _, isLegacy := underlyingSchemaReader.(*schemautil.LegacySchemaReaderAdapter); !isLegacy { + return underlyingSchemaReader, nil + } + + // For legacy mode, wrap to maintain counting return schemautil.NewLegacySchemaReaderAdapter(r), nil } +func (r *countingReader) ReadStoredSchema(ctx context.Context) (*core.StoredSchema, error) { + singleStoreReader, ok := r.delegate.(datastore.SingleStoreSchemaReader) + if !ok { + return nil, errors.New("delegate reader does not implement SingleStoreSchemaReader") + } + return singleStoreReader.ReadStoredSchema(ctx) +} + // Type assertions var ( - _ datastore.ReadOnlyDatastore = (*countingProxy)(nil) - _ datastore.Datastore = (*countingDatastoreProxy)(nil) - _ datastore.Reader = (*countingReader)(nil) + _ datastore.ReadOnlyDatastore = (*countingProxy)(nil) + _ datastore.Datastore = (*countingDatastoreProxy)(nil) + _ datastore.Reader = (*countingReader)(nil) + _ datastore.LegacySchemaReader = (*countingReader)(nil) + _ datastore.SingleStoreSchemaReader = (*countingReader)(nil) + _ datastore.DualSchemaReader = (*countingReader)(nil) ) diff --git a/internal/datastore/proxy/counting_test.go b/internal/datastore/proxy/counting_test.go index 50f659e41..b3854d203 100644 --- a/internal/datastore/proxy/counting_test.go +++ b/internal/datastore/proxy/counting_test.go @@ -37,7 +37,7 @@ func TestCountingProxyBasicCounting(t *testing.T) { require.Equal(uint64(0), counts.LegacyListAllNamespaces()) require.Equal(uint64(0), counts.LegacyLookupNamespacesWithNames()) - r := ds.SnapshotReader(datastore.NoRevision) + r := ds.SnapshotReader(datastore.NoRevision, datastore.NoSchemaHashForTesting) // Call each method once _, err := r.QueryRelationships(ctx, datastore.RelationshipsFilter{}) @@ -72,7 +72,7 @@ func 
TestCountingProxyMultipleCalls(t *testing.T) { ds, counts := NewCountingDatastoreProxy(delegate) ctx := context.Background() - r := ds.SnapshotReader(datastore.NoRevision) + r := ds.SnapshotReader(datastore.NoRevision, datastore.NoSchemaHashForTesting) require.Equal(uint64(0), counts.QueryRelationships()) @@ -103,7 +103,7 @@ func TestCountingProxyCaveatMethodsNotCounted(t *testing.T) { ds, counts := NewCountingDatastoreProxy(delegate) ctx := context.Background() - r := ds.SnapshotReader(datastore.NoRevision) + r := ds.SnapshotReader(datastore.NoRevision, datastore.NoSchemaHashForTesting) // Call caveat methods _, _, err := r.LegacyReadCaveatByName(ctx, "test") @@ -135,7 +135,7 @@ func TestCountingProxyCounterMethodsNotCounted(t *testing.T) { ds, counts := NewCountingDatastoreProxy(delegate) ctx := context.Background() - r := ds.SnapshotReader(datastore.NoRevision) + r := ds.SnapshotReader(datastore.NoRevision, datastore.NoSchemaHashForTesting) // Call counter methods _, err := r.CountRelationships(ctx, "counter1") @@ -173,7 +173,7 @@ func TestCountingProxyThreadSafety(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - r := ds.SnapshotReader(datastore.NoRevision) + r := ds.SnapshotReader(datastore.NoRevision, datastore.NoSchemaHashForTesting) for range callsPerGoroutine { _, err := r.QueryRelationships(ctx, datastore.RelationshipsFilter{}) assert.NoError(t, err) @@ -218,13 +218,13 @@ func TestCountingProxyPassthrough(t *testing.T) { require.Equal("mockds", uniqueID) // Test HeadRevision - delegate.On("HeadRevision", mock.Anything).Return(datastore.NoRevision, nil).Once() - _, err = ds.HeadRevision(ctx) + delegate.On("HeadRevision", mock.Anything).Return(datastore.NoRevision, datastore.NoSchemaHashForTesting, nil).Once() + _, _, err = ds.HeadRevision(ctx) require.NoError(err) // Test OptimizedRevision - delegate.On("OptimizedRevision", mock.Anything).Return(datastore.NoRevision, nil).Once() - _, err = ds.OptimizedRevision(ctx) + delegate.On("OptimizedRevision", 
mock.Anything).Return(datastore.NoRevision, datastore.NoSchemaHashForTesting, nil).Once() + _, _, err = ds.OptimizedRevision(ctx) require.NoError(err) // Test CheckRevision @@ -279,19 +279,19 @@ func TestCountingProxyMultipleReaders(t *testing.T) { reader2 := &proxy_test.MockReader{} // First snapshot reader - delegate.On("SnapshotReader", datastore.NoRevision).Return(reader1).Once() + delegate.On("SnapshotReader", datastore.NoRevision, datastore.NoSchemaHashForTesting).Return(reader1).Once() reader1.On("QueryRelationships", mock.Anything, mock.Anything).Return(nil, nil) // Second snapshot reader - delegate.On("SnapshotReader", mock.Anything).Return(reader2) + delegate.On("SnapshotReader", mock.Anything, mock.Anything).Return(reader2) reader2.On("QueryRelationships", mock.Anything, mock.Anything).Return(nil, nil) ds, counts := NewCountingDatastoreProxy(delegate) ctx := context.Background() // Create two readers and make calls on each - r1 := ds.SnapshotReader(datastore.NoRevision) - r2 := ds.SnapshotReader(datastore.NoRevision) + r1 := ds.SnapshotReader(datastore.NoRevision, datastore.NoSchemaHashForTesting) + r2 := ds.SnapshotReader(datastore.NoRevision, datastore.NoSchemaHashForTesting) _, err := r1.QueryRelationships(ctx, datastore.RelationshipsFilter{}) require.NoError(err) @@ -318,7 +318,7 @@ func TestWriteMethodCounts(t *testing.T) { ds, counts := NewCountingDatastoreProxy(delegate) ctx := context.Background() - r := ds.SnapshotReader(datastore.NoRevision) + r := ds.SnapshotReader(datastore.NoRevision, datastore.NoSchemaHashForTesting) // Make some calls _, err := r.QueryRelationships(ctx, datastore.RelationshipsFilter{}) diff --git a/internal/datastore/proxy/indexcheck/fakedatastore_test.go b/internal/datastore/proxy/indexcheck/fakedatastore_test.go index 6143c13a3..ce75ea379 100644 --- a/internal/datastore/proxy/indexcheck/fakedatastore_test.go +++ b/internal/datastore/proxy/indexcheck/fakedatastore_test.go @@ -26,7 +26,7 @@ func (f fakeDatastore) 
UniqueID(_ context.Context) (string, error) { return "fake", nil } -func (f fakeDatastore) SnapshotReader(revision datastore.Revision) datastore.Reader { +func (f fakeDatastore) SnapshotReader(revision datastore.Revision, _ datastore.SchemaHash) datastore.Reader { return fakeSnapshotReader{ revision: revision, indexesUsed: f.indexesUsed, @@ -39,12 +39,12 @@ func (f fakeDatastore) ReadWriteTx(ctx context.Context, fn datastore.TxUserFunc, }) } -func (f fakeDatastore) OptimizedRevision(_ context.Context) (datastore.Revision, error) { - return nil, nil +func (f fakeDatastore) OptimizedRevision(_ context.Context) (datastore.Revision, datastore.SchemaHash, error) { + return nil, datastore.NoSchemaHashForTesting, nil } -func (f fakeDatastore) HeadRevision(_ context.Context) (datastore.Revision, error) { - return nil, nil +func (f fakeDatastore) HeadRevision(_ context.Context) (datastore.Revision, datastore.SchemaHash, error) { + return nil, datastore.NoSchemaHashForTesting, nil } func (f fakeDatastore) CheckRevision(_ context.Context, rev datastore.Revision) error { diff --git a/internal/datastore/proxy/indexcheck/indexcheck.go b/internal/datastore/proxy/indexcheck/indexcheck.go index 4846ca8ba..4f4889a98 100644 --- a/internal/datastore/proxy/indexcheck/indexcheck.go +++ b/internal/datastore/proxy/indexcheck/indexcheck.go @@ -2,6 +2,7 @@ package indexcheck import ( "context" + "errors" "fmt" v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" @@ -33,8 +34,8 @@ func WrapWithIndexCheckingDatastoreProxyIfApplicable(ds datastore.Datastore) dat type indexcheckingProxy struct{ delegate datastore.SQLDatastore } -func (p *indexcheckingProxy) SnapshotReader(rev datastore.Revision) datastore.Reader { - delegateReader := p.delegate.SnapshotReader(rev) +func (p *indexcheckingProxy) SnapshotReader(rev datastore.Revision, schemaHash datastore.SchemaHash) datastore.Reader { + delegateReader := p.delegate.SnapshotReader(rev, schemaHash) return &indexcheckingReader{p.delegate, 
delegateReader} } @@ -56,7 +57,7 @@ func (p *indexcheckingProxy) UniqueID(ctx context.Context) (string, error) { return p.delegate.UniqueID(ctx) } -func (p *indexcheckingProxy) OptimizedRevision(ctx context.Context) (datastore.Revision, error) { +func (p *indexcheckingProxy) OptimizedRevision(ctx context.Context) (datastore.Revision, datastore.SchemaHash, error) { return p.delegate.OptimizedRevision(ctx) } @@ -64,7 +65,7 @@ func (p *indexcheckingProxy) CheckRevision(ctx context.Context, revision datasto return p.delegate.CheckRevision(ctx, revision) } -func (p *indexcheckingProxy) HeadRevision(ctx context.Context) (datastore.Revision, error) { +func (p *indexcheckingProxy) HeadRevision(ctx context.Context) (datastore.Revision, datastore.SchemaHash, error) { return p.delegate.HeadRevision(ctx) } @@ -98,6 +99,36 @@ func (p *indexcheckingProxy) ReadyState(ctx context.Context) (datastore.ReadySta func (p *indexcheckingProxy) Close() error { return p.delegate.Close() } +// SchemaHashReaderForTesting returns a test-only interface for reading the schema hash. +// This delegates to the underlying datastore if it supports the test interface. +func (p *indexcheckingProxy) SchemaHashReaderForTesting() interface { + ReadSchemaHash(ctx context.Context) (string, error) +} { + type schemaHashReaderProvider interface { + SchemaHashReaderForTesting() interface { + ReadSchemaHash(ctx context.Context) (string, error) + } + } + + if provider, ok := p.delegate.(schemaHashReaderProvider); ok { + return provider.SchemaHashReaderForTesting() + } + return nil +} + +// SchemaModeForTesting returns the current schema mode for testing purposes. +// This delegates to the underlying datastore if it supports the test interface. 
+func (p *indexcheckingProxy) SchemaModeForTesting() (options.SchemaMode, error) { + type schemaModeProvider interface { + SchemaModeForTesting() (options.SchemaMode, error) + } + + if provider, ok := p.delegate.(schemaModeProvider); ok { + return provider.SchemaModeForTesting() + } + return options.SchemaModeReadLegacyWriteLegacy, errors.New("delegate datastore does not implement SchemaModeForTesting()") +} + type indexcheckingReader struct { parent datastore.SQLDatastore delegate datastore.Reader @@ -187,6 +218,14 @@ func (r *indexcheckingReader) SchemaReader() (datastore.SchemaReader, error) { return r.delegate.SchemaReader() } +func (r *indexcheckingReader) ReadStoredSchema(ctx context.Context) (*core.StoredSchema, error) { + singleStoreReader, ok := r.delegate.(datastore.SingleStoreSchemaReader) + if !ok { + return nil, errors.New("delegate reader does not implement SingleStoreSchemaReader") + } + return singleStoreReader.ReadStoredSchema(ctx) +} + type indexcheckingRWT struct { *indexcheckingReader delegate datastore.ReadWriteTransaction @@ -228,6 +267,14 @@ func (rwt *indexcheckingRWT) SchemaWriter() (datastore.SchemaWriter, error) { return rwt.delegate.SchemaWriter() } +func (rwt *indexcheckingRWT) WriteStoredSchema(ctx context.Context, schema *core.StoredSchema) error { + singleStoreWriter, ok := rwt.delegate.(datastore.SingleStoreSchemaWriter) + if !ok { + return errors.New("delegate transaction does not implement SingleStoreSchemaWriter") + } + return singleStoreWriter.WriteStoredSchema(ctx, schema) +} + func (rwt *indexcheckingRWT) DeleteRelationships(ctx context.Context, filter *v1.RelationshipFilter, options ...options.DeleteOptionsOption) (uint64, bool, error) { return rwt.delegate.DeleteRelationships(ctx, filter, options...) 
} @@ -237,7 +284,13 @@ func (rwt *indexcheckingRWT) BulkLoad(ctx context.Context, iter datastore.BulkWr } var ( - _ datastore.Datastore = (*indexcheckingProxy)(nil) - _ datastore.Reader = (*indexcheckingReader)(nil) - _ datastore.ReadWriteTransaction = (*indexcheckingRWT)(nil) + _ datastore.Datastore = (*indexcheckingProxy)(nil) + _ datastore.Reader = (*indexcheckingReader)(nil) + _ datastore.LegacySchemaReader = (*indexcheckingReader)(nil) + _ datastore.SingleStoreSchemaReader = (*indexcheckingReader)(nil) + _ datastore.DualSchemaReader = (*indexcheckingReader)(nil) + _ datastore.ReadWriteTransaction = (*indexcheckingRWT)(nil) + _ datastore.LegacySchemaWriter = (*indexcheckingRWT)(nil) + _ datastore.SingleStoreSchemaWriter = (*indexcheckingRWT)(nil) + _ datastore.DualSchemaWriter = (*indexcheckingRWT)(nil) ) diff --git a/internal/datastore/proxy/indexcheck/indexcheck_test.go b/internal/datastore/proxy/indexcheck/indexcheck_test.go index 91d385658..3521805e8 100644 --- a/internal/datastore/proxy/indexcheck/indexcheck_test.go +++ b/internal/datastore/proxy/indexcheck/indexcheck_test.go @@ -21,10 +21,10 @@ func TestIndexCheckingMissingIndex(t *testing.T) { wrapped := WrapWithIndexCheckingDatastoreProxyIfApplicable(ds) require.NotNil(t, wrapped.(*indexcheckingProxy).delegate) - headRev, err := ds.HeadRevision(t.Context()) + headRev, _, err := ds.HeadRevision(t.Context()) require.NoError(t, err) - reader := wrapped.SnapshotReader(headRev) + reader := wrapped.SnapshotReader(headRev, datastore.NoSchemaHashForTesting) it, err := reader.QueryRelationships(t.Context(), datastore.RelationshipsFilter{ OptionalResourceType: "document", OptionalResourceIds: []string{"somedoc"}, @@ -49,10 +49,10 @@ func TestIndexCheckingFoundIndex(t *testing.T) { wrapped := WrapWithIndexCheckingDatastoreProxyIfApplicable(ds) require.NotNil(t, wrapped.(*indexcheckingProxy).delegate) - headRev, err := ds.HeadRevision(t.Context()) + headRev, _, err := ds.HeadRevision(t.Context()) require.NoError(t, 
err) - reader := wrapped.SnapshotReader(headRev) + reader := wrapped.SnapshotReader(headRev, datastore.NoSchemaHashForTesting) it, err := reader.QueryRelationships(t.Context(), datastore.RelationshipsFilter{ OptionalResourceType: "document", OptionalResourceIds: []string{"somedoc"}, @@ -97,7 +97,7 @@ func TestIndexCheckingProxyMethods(t *testing.T) { }) t.Run("OptimizedRevision", func(t *testing.T) { - rev, err := proxy.OptimizedRevision(t.Context()) + rev, _, err := proxy.OptimizedRevision(t.Context()) require.NoError(t, err) require.Nil(t, rev) }) @@ -108,7 +108,7 @@ func TestIndexCheckingProxyMethods(t *testing.T) { }) t.Run("HeadRevision", func(t *testing.T) { - rev, err := proxy.HeadRevision(t.Context()) + rev, _, err := proxy.HeadRevision(t.Context()) require.NoError(t, err) require.Nil(t, rev) }) @@ -163,7 +163,7 @@ func TestIndexCheckingProxyMethods(t *testing.T) { func TestIndexCheckingReaderMethods(t *testing.T) { ds := fakeDatastore{} proxy := newIndexCheckingDatastoreProxy(ds) - reader := proxy.SnapshotReader(nil) + reader := proxy.SnapshotReader(nil, datastore.NoSchemaHashForTesting) t.Run("CountRelationships", func(t *testing.T) { count, err := reader.CountRelationships(t.Context(), "test") diff --git a/internal/datastore/proxy/observable.go b/internal/datastore/proxy/observable.go index 9e3ef5934..5962ee9fc 100644 --- a/internal/datastore/proxy/observable.go +++ b/internal/datastore/proxy/observable.go @@ -2,6 +2,7 @@ package proxy import ( "context" + "errors" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" @@ -79,8 +80,8 @@ func (p *observableProxy) UniqueID(ctx context.Context) (string, error) { return p.delegate.UniqueID(ctx) } -func (p *observableProxy) SnapshotReader(rev datastore.Revision) datastore.Reader { - delegateReader := p.delegate.SnapshotReader(rev) +func (p *observableProxy) SnapshotReader(rev datastore.Revision, schemaHash datastore.SchemaHash) datastore.Reader { + 
delegateReader := p.delegate.SnapshotReader(rev, schemaHash) return &observableReader{delegateReader} } @@ -94,7 +95,7 @@ func (p *observableProxy) ReadWriteTx( }, opts...) } -func (p *observableProxy) OptimizedRevision(ctx context.Context) (datastore.Revision, error) { +func (p *observableProxy) OptimizedRevision(ctx context.Context) (datastore.Revision, datastore.SchemaHash, error) { ctx, closer := observe(ctx, "OptimizedRevision", "") defer closer() @@ -110,7 +111,7 @@ func (p *observableProxy) CheckRevision(ctx context.Context, revision datastore. return p.delegate.CheckRevision(ctx, revision) } -func (p *observableProxy) HeadRevision(ctx context.Context) (datastore.Revision, error) { +func (p *observableProxy) HeadRevision(ctx context.Context) (datastore.Revision, datastore.SchemaHash, error) { ctx, closer := observe(ctx, "HeadRevision", "") defer closer() @@ -156,6 +157,34 @@ func (p *observableProxy) ReadyState(ctx context.Context) (datastore.ReadyState, func (p *observableProxy) Close() error { return p.delegate.Close() } +// SchemaHashReaderForTesting delegates to the underlying datastore if it implements the test interface +func (p *observableProxy) SchemaHashReaderForTesting() interface { + ReadSchemaHash(ctx context.Context) (string, error) +} { + type schemaHashReaderProvider interface { + SchemaHashReaderForTesting() interface { + ReadSchemaHash(ctx context.Context) (string, error) + } + } + + if hashReader, ok := p.delegate.(schemaHashReaderProvider); ok { + return hashReader.SchemaHashReaderForTesting() + } + return nil +} + +// SchemaModeForTesting delegates to the underlying datastore if it implements the test interface +func (p *observableProxy) SchemaModeForTesting() (options.SchemaMode, error) { + type schemaModeProvider interface { + SchemaModeForTesting() (options.SchemaMode, error) + } + + if provider, ok := p.delegate.(schemaModeProvider); ok { + return provider.SchemaModeForTesting() + } + return options.SchemaModeReadLegacyWriteLegacy, 
errors.New("delegate datastore does not implement SchemaModeForTesting()") +} + type observableReader struct{ delegate datastore.Reader } func (r *observableReader) CountRelationships(ctx context.Context, name string) (int, error) { @@ -278,12 +307,32 @@ func (r *observableReader) ReverseQueryRelationships(ctx context.Context, subjec }, nil } -// SchemaReader returns a wrapped version of the proxy that exercises the -// legacy methods to implement the new methods. +// SchemaReader returns a schema reader that respects the underlying schema mode. +// For new unified schema mode, it passes through directly. For legacy mode, +// it wraps the proxy to ensure observability is maintained. func (r *observableReader) SchemaReader() (datastore.SchemaReader, error) { + underlyingSchemaReader, err := r.delegate.SchemaReader() + if err != nil { + return nil, err + } + + // If using new unified schema mode, pass through directly + if _, isLegacy := underlyingSchemaReader.(*schemautil.LegacySchemaReaderAdapter); !isLegacy { + return underlyingSchemaReader, nil + } + + // For legacy mode, wrap to maintain observability return schemautil.NewLegacySchemaReaderAdapter(r), nil } +func (r *observableReader) ReadStoredSchema(ctx context.Context) (*core.StoredSchema, error) { + singleStoreReader, ok := r.delegate.(datastore.SingleStoreSchemaReader) + if !ok { + return nil, errors.New("delegate reader does not implement SingleStoreSchemaReader") + } + return singleStoreReader.ReadStoredSchema(ctx) +} + type observableRWT struct { *observableReader delegate datastore.ReadWriteTransaction @@ -377,6 +426,14 @@ func (rwt *observableRWT) SchemaWriter() (datastore.SchemaWriter, error) { return rwt.delegate.SchemaWriter() } +func (rwt *observableRWT) WriteStoredSchema(ctx context.Context, schema *core.StoredSchema) error { + singleStoreWriter, ok := rwt.delegate.(datastore.SingleStoreSchemaWriter) + if !ok { + return errors.New("delegate transaction does not implement SingleStoreSchemaWriter") 
+ } + return singleStoreWriter.WriteStoredSchema(ctx, schema) +} + func (rwt *observableRWT) DeleteRelationships(ctx context.Context, filter *v1.RelationshipFilter, options ...options.DeleteOptionsOption) (uint64, bool, error) { ctx, closer := observe(ctx, "DeleteRelationships", "", trace.WithAttributes( filterToAttributes(filter)..., @@ -415,7 +472,13 @@ func observe(ctx context.Context, name string, queryShape string, opts ...trace. } var ( - _ datastore.Datastore = (*observableProxy)(nil) - _ datastore.Reader = (*observableReader)(nil) - _ datastore.ReadWriteTransaction = (*observableRWT)(nil) + _ datastore.Datastore = (*observableProxy)(nil) + _ datastore.Reader = (*observableReader)(nil) + _ datastore.LegacySchemaReader = (*observableReader)(nil) + _ datastore.SingleStoreSchemaReader = (*observableReader)(nil) + _ datastore.DualSchemaReader = (*observableReader)(nil) + _ datastore.ReadWriteTransaction = (*observableRWT)(nil) + _ datastore.LegacySchemaWriter = (*observableRWT)(nil) + _ datastore.SingleStoreSchemaWriter = (*observableRWT)(nil) + _ datastore.DualSchemaWriter = (*observableRWT)(nil) ) diff --git a/internal/datastore/proxy/proxy_test/mock.go b/internal/datastore/proxy/proxy_test/mock.go index 268a93b70..8a648a6c4 100644 --- a/internal/datastore/proxy/proxy_test/mock.go +++ b/internal/datastore/proxy/proxy_test/mock.go @@ -28,7 +28,7 @@ func (dm *MockDatastore) UniqueID(_ context.Context) (string, error) { return dm.CurrentUniqueID, nil } -func (dm *MockDatastore) SnapshotReader(rev datastore.Revision) datastore.Reader { +func (dm *MockDatastore) SnapshotReader(rev datastore.Revision, schemaHash datastore.SchemaHash) datastore.Reader { args := dm.Called(rev) return args.Get(0).(datastore.Reader) } @@ -52,14 +52,14 @@ func (dm *MockDatastore) ReadWriteTx( return args.Get(1).(datastore.Revision), args.Error(2) } -func (dm *MockDatastore) OptimizedRevision(_ context.Context) (datastore.Revision, error) { +func (dm *MockDatastore) OptimizedRevision(_ 
context.Context) (datastore.Revision, datastore.SchemaHash, error) { args := dm.Called() - return args.Get(0).(datastore.Revision), args.Error(1) + return args.Get(0).(datastore.Revision), args.Get(1).(datastore.SchemaHash), args.Error(2) } -func (dm *MockDatastore) HeadRevision(_ context.Context) (datastore.Revision, error) { +func (dm *MockDatastore) HeadRevision(_ context.Context) (datastore.Revision, datastore.SchemaHash, error) { args := dm.Called() - return args.Get(0).(datastore.Revision), args.Error(1) + return args.Get(0).(datastore.Revision), args.Get(1).(datastore.SchemaHash), args.Error(2) } func (dm *MockDatastore) CheckRevision(_ context.Context, revision datastore.Revision) error { @@ -210,6 +210,15 @@ func (dm *MockReader) SchemaReader() (datastore.SchemaReader, error) { return sr, args.Error(1) } +func (dm *MockReader) ReadStoredSchema(ctx context.Context) (*core.StoredSchema, error) { + args := dm.Called() + var schema *core.StoredSchema + if args.Get(0) != nil { + schema = args.Get(0).(*core.StoredSchema) + } + return schema, args.Error(1) +} + type MockReadWriteTransaction struct { mock.Mock } @@ -388,7 +397,8 @@ func (dm *MockReadWriteTransaction) SchemaWriter() (datastore.SchemaWriter, erro } var ( - _ datastore.Datastore = &MockDatastore{} - _ datastore.Reader = &MockReader{} - _ datastore.ReadWriteTransaction = &MockReadWriteTransaction{} + _ datastore.Datastore = &MockDatastore{} + _ datastore.Reader = &MockReader{} + _ datastore.SingleStoreSchemaReader = &MockReader{} + _ datastore.ReadWriteTransaction = &MockReadWriteTransaction{} ) diff --git a/internal/datastore/proxy/readonly_test.go b/internal/datastore/proxy/readonly_test.go index 69ba1b8fb..2d12e36ea 100644 --- a/internal/datastore/proxy/readonly_test.go +++ b/internal/datastore/proxy/readonly_test.go @@ -20,7 +20,7 @@ func newReadOnlyMock() (*proxy_test.MockDatastore, *proxy_test.MockReader) { readerMock := &proxy_test.MockReader{} dsMock.On("ReadWriteTx").Panic("read-only proxy 
should never open a read-write transaction").Maybe() - dsMock.On("SnapshotReader", mock.Anything).Return(readerMock).Maybe() + dsMock.On("SnapshotReader", mock.Anything, mock.Anything).Return(readerMock).Maybe() return dsMock, readerMock } @@ -74,9 +74,9 @@ func TestOptimizedRevisionPassthrough(t *testing.T) { ds := NewReadonlyDatastore(delegate) ctx := t.Context() - delegate.On("OptimizedRevision").Return(expectedRevision, nil).Times(1) + delegate.On("OptimizedRevision").Return(expectedRevision, datastore.NoSchemaHashForTesting, nil).Times(1) - revision, err := ds.OptimizedRevision(ctx) + revision, _, err := ds.OptimizedRevision(ctx) require.NoError(err) require.Equal(expectedRevision, revision) delegate.AssertExpectations(t) @@ -89,9 +89,9 @@ func TestHeadRevisionPassthrough(t *testing.T) { ds := NewReadonlyDatastore(delegate) ctx := t.Context() - delegate.On("HeadRevision").Return(expectedRevision, nil).Times(1) + delegate.On("HeadRevision").Return(expectedRevision, datastore.NoSchemaHashForTesting, nil).Times(1) - revision, err := ds.HeadRevision(ctx) + revision, _, err := ds.HeadRevision(ctx) require.NoError(err) require.Equal(expectedRevision, revision) delegate.AssertExpectations(t) @@ -134,7 +134,7 @@ func TestSnapshotReaderPassthrough(t *testing.T) { reader.On("LegacyReadNamespaceByName", "fake").Return(nil, expectedRevision, nil).Times(1) - _, rev, err := ds.SnapshotReader(expectedRevision).LegacyReadNamespaceByName(ctx, "fake") + _, rev, err := ds.SnapshotReader(expectedRevision, datastore.NoSchemaHashForTesting).LegacyReadNamespaceByName(ctx, "fake") require.NoError(err) require.True(expectedRevision.Equal(rev)) delegate.AssertExpectations(t) diff --git a/internal/datastore/proxy/relationshipintegrity.go b/internal/datastore/proxy/relationshipintegrity.go index 7c76833b0..516f3d8d8 100644 --- a/internal/datastore/proxy/relationshipintegrity.go +++ b/internal/datastore/proxy/relationshipintegrity.go @@ -166,10 +166,10 @@ func (r 
*relationshipIntegrityProxy) UniqueID(ctx context.Context) (string, erro return r.ds.UniqueID(ctx) } -func (r *relationshipIntegrityProxy) SnapshotReader(rev datastore.Revision) datastore.Reader { +func (r *relationshipIntegrityProxy) SnapshotReader(rev datastore.Revision, schemaHash datastore.SchemaHash) datastore.Reader { return relationshipIntegrityReader{ parent: r, - wrapped: r.ds.SnapshotReader(rev), + wrapped: r.ds.SnapshotReader(rev, schemaHash), } } @@ -198,11 +198,11 @@ func (r *relationshipIntegrityProxy) OfflineFeatures() (*datastore.Features, err return r.ds.OfflineFeatures() } -func (r *relationshipIntegrityProxy) HeadRevision(ctx context.Context) (datastore.Revision, error) { +func (r *relationshipIntegrityProxy) HeadRevision(ctx context.Context) (datastore.Revision, datastore.SchemaHash, error) { return r.ds.HeadRevision(ctx) } -func (r *relationshipIntegrityProxy) OptimizedRevision(ctx context.Context) (datastore.Revision, error) { +func (r *relationshipIntegrityProxy) OptimizedRevision(ctx context.Context) (datastore.Revision, datastore.SchemaHash, error) { return r.ds.OptimizedRevision(ctx) } @@ -294,6 +294,65 @@ func (r *relationshipIntegrityProxy) Watch(ctx context.Context, afterRevision da return checkedResultsChan, checkedErrChan } +// SchemaHashReaderForTesting returns a test-only interface for reading the schema hash. +// This delegates to the underlying datastore if it supports the test interface. 
+func (r *relationshipIntegrityProxy) SchemaHashReaderForTesting() interface { + ReadSchemaHash(ctx context.Context) (string, error) +} { + // Try direct method call using reflection/interface check + type provider interface { + SchemaHashReaderForTesting() interface { + ReadSchemaHash(ctx context.Context) (string, error) + } + } + + // Check delegate directly + if p, ok := r.ds.(provider); ok { + result := p.SchemaHashReaderForTesting() + if result != nil { + return result + } + } + + // Try unwrapping if delegate is itself a proxy + type unwrapper interface { + Unwrap() datastore.Datastore + } + if u, ok := r.ds.(unwrapper); ok { + unwrapped := u.Unwrap() + if p, ok := unwrapped.(provider); ok { + return p.SchemaHashReaderForTesting() + } + } + + return nil +} + +// SchemaModeForTesting returns the current schema mode for testing purposes. +// This delegates to the underlying datastore if it supports the test interface. +func (r *relationshipIntegrityProxy) SchemaModeForTesting() (options.SchemaMode, error) { + type provider interface { + SchemaModeForTesting() (options.SchemaMode, error) + } + + // Check delegate directly + if p, ok := r.ds.(provider); ok { + return p.SchemaModeForTesting() + } + + // Try unwrapping if delegate is itself a proxy + type unwrapper interface { + Unwrap() datastore.Datastore + } + if u, ok := r.ds.(unwrapper); ok { + if p, ok := u.Unwrap().(provider); ok { + return p.SchemaModeForTesting() + } + } + + return options.SchemaModeReadLegacyWriteLegacy, errors.New("delegate datastore does not implement SchemaModeForTesting()") +} + func (r *relationshipIntegrityProxy) Unwrap() datastore.Datastore { return r.ds } @@ -391,12 +450,28 @@ func (r relationshipIntegrityReader) SchemaReader() (datastore.SchemaReader, err return r.wrapped.SchemaReader() } +func (r relationshipIntegrityReader) ReadStoredSchema(ctx context.Context) (*corev1.StoredSchema, error) { + singleStoreReader, ok := r.wrapped.(datastore.SingleStoreSchemaReader) + if !ok { + 
return nil, errors.New("wrapped reader does not implement SingleStoreSchemaReader") + } + return singleStoreReader.ReadStoredSchema(ctx) +} + type relationshipIntegrityTx struct { datastore.ReadWriteTransaction parent *relationshipIntegrityProxy } +func (r *relationshipIntegrityTx) WriteStoredSchema(ctx context.Context, schema *corev1.StoredSchema) error { + singleStoreWriter, ok := r.ReadWriteTransaction.(datastore.SingleStoreSchemaWriter) + if !ok { + return errors.New("wrapped transaction does not implement SingleStoreSchemaWriter") + } + return singleStoreWriter.WriteStoredSchema(ctx, schema) +} + func (r *relationshipIntegrityTx) WriteRelationships( ctx context.Context, mutations []tuple.RelationshipUpdate, @@ -471,3 +546,15 @@ func (w integrityAddingBulkLoadInterator) Next(ctx context.Context) (*tuple.Rela return rel, nil } + +var ( + _ datastore.Datastore = (*relationshipIntegrityProxy)(nil) + _ datastore.Reader = (*relationshipIntegrityReader)(nil) + _ datastore.LegacySchemaReader = (*relationshipIntegrityReader)(nil) + _ datastore.SingleStoreSchemaReader = (*relationshipIntegrityReader)(nil) + _ datastore.DualSchemaReader = (*relationshipIntegrityReader)(nil) + _ datastore.ReadWriteTransaction = (*relationshipIntegrityTx)(nil) + _ datastore.LegacySchemaWriter = (*relationshipIntegrityTx)(nil) + _ datastore.SingleStoreSchemaWriter = (*relationshipIntegrityTx)(nil) + _ datastore.DualSchemaWriter = (*relationshipIntegrityTx)(nil) +) diff --git a/internal/datastore/proxy/relationshipintegrity_test.go b/internal/datastore/proxy/relationshipintegrity_test.go index 525924111..c484e3d5b 100644 --- a/internal/datastore/proxy/relationshipintegrity_test.go +++ b/internal/datastore/proxy/relationshipintegrity_test.go @@ -92,10 +92,10 @@ func TestReadWithMissingIntegrity(t *testing.T) { pds, err := NewRelationshipIntegrityProxy(ds, DefaultKeyForTesting, nil) require.NoError(t, err) - headRev, err := pds.HeadRevision(t.Context()) + headRev, _, err := 
pds.HeadRevision(t.Context()) require.NoError(t, err) - reader := pds.SnapshotReader(headRev) + reader := pds.SnapshotReader(headRev, datastore.NoSchemaHashForTesting) iter, err := reader.QueryRelationships( t.Context(), datastore.RelationshipsFilter{OptionalResourceType: "resource"}, @@ -141,10 +141,10 @@ func TestBasicIntegrityFailureDueToInvalidHashVersion(t *testing.T) { require.NoError(t, err) // Read them back and ensure the read fails. - headRev, err := pds.HeadRevision(t.Context()) + headRev, _, err := pds.HeadRevision(t.Context()) require.NoError(t, err) - reader := pds.SnapshotReader(headRev) + reader := pds.SnapshotReader(headRev, datastore.NoSchemaHashForTesting) iter, err := reader.QueryRelationships( t.Context(), datastore.RelationshipsFilter{OptionalResourceType: "resource"}, @@ -190,10 +190,10 @@ func TestBasicIntegrityFailureDueToInvalidHashSignature(t *testing.T) { require.NoError(t, err) // Read them back and ensure the read fails. - headRev, err := pds.HeadRevision(t.Context()) + headRev, _, err := pds.HeadRevision(t.Context()) require.NoError(t, err) - reader := pds.SnapshotReader(headRev) + reader := pds.SnapshotReader(headRev, datastore.NoSchemaHashForTesting) iter, err := reader.QueryRelationships( t.Context(), datastore.RelationshipsFilter{OptionalResourceType: "resource"}, @@ -229,10 +229,10 @@ func TestBasicIntegrityFailureDueToWriteWithExpiredKey(t *testing.T) { require.NoError(t, err) // Read them back and ensure the read fails. 
- headRev, err := pds.HeadRevision(t.Context()) + headRev, _, err := pds.HeadRevision(t.Context()) require.NoError(t, err) - reader := pds.SnapshotReader(headRev) + reader := pds.SnapshotReader(headRev, datastore.NoSchemaHashForTesting) iter, err := reader.QueryRelationships( t.Context(), datastore.RelationshipsFilter{OptionalResourceType: "resource"}, @@ -250,7 +250,7 @@ func TestWatchIntegrityFailureDueToInvalidHashSignature(t *testing.T) { ds, err := dsfortesting.NewMemDBDatastoreForTesting(t, 0, 5*time.Second, 1*time.Hour) require.NoError(t, err) - headRev, err := ds.HeadRevision(t.Context()) + headRev, _, err := ds.HeadRevision(t.Context()) require.NoError(t, err) pds, err := NewRelationshipIntegrityProxy(ds, DefaultKeyForTesting, nil) @@ -310,16 +310,16 @@ func BenchmarkQueryRelsWithIntegrity(b *testing.B) { }) require.NoError(b, err) - headRev, err := pds.HeadRevision(b.Context()) + headRev, _, err := pds.HeadRevision(b.Context()) require.NoError(b, err) b.ResetTimer() for i := 0; i < b.N; i++ { var reader datastore.Reader if withIntegrity { - reader = pds.SnapshotReader(headRev) + reader = pds.SnapshotReader(headRev, datastore.NoSchemaHashForTesting) } else { - reader = ds.SnapshotReader(headRev) + reader = ds.SnapshotReader(headRev, datastore.NoSchemaHashForTesting) } iter, err := reader.QueryRelationships( b.Context(), diff --git a/internal/datastore/proxy/schemacaching/standardcache.go b/internal/datastore/proxy/schemacaching/standardcache.go index 04e72af87..5c6ee7696 100644 --- a/internal/datastore/proxy/schemacaching/standardcache.go +++ b/internal/datastore/proxy/schemacaching/standardcache.go @@ -29,8 +29,8 @@ func (p *definitionCachingProxy) Close() error { return p.Datastore.Close() } -func (p *definitionCachingProxy) SnapshotReader(rev datastore.Revision) datastore.Reader { - delegateReader := p.Datastore.SnapshotReader(rev) +func (p *definitionCachingProxy) SnapshotReader(rev datastore.Revision, hash datastore.SchemaHash) datastore.Reader { + 
delegateReader := p.Datastore.SnapshotReader(rev, hash) return &definitionCachingReader{delegateReader, rev, p} } @@ -100,13 +100,35 @@ func (r *definitionCachingReader) LegacyLookupCaveatsWithNames( estimatedCaveatDefinitionSize) } -// SchemaReader returns a reference to this reader, but wrapped in such a way -// that the legacy methods are used to drive the new methods, which ensures -// that caching logic stays in place. +// SchemaReader returns a schema reader that respects the underlying datastore's +// schema mode configuration. For new unified schema mode, it passes through directly +// to leverage the hash-based cache. For legacy mode, it wraps the proxy to use +// the per-definition caching methods. func (r *definitionCachingReader) SchemaReader() (datastore.SchemaReader, error) { + // Get the underlying reader's schema reader to determine its configured mode + underlyingSchemaReader, err := r.Reader.SchemaReader() + if err != nil { + return nil, err + } + + // If using new unified schema mode, pass through directly + // The hash-based schema cache handles caching efficiently for unified schemas + if _, isLegacy := underlyingSchemaReader.(*schemautil.LegacySchemaReaderAdapter); !isLegacy { + return underlyingSchemaReader, nil + } + + // For legacy mode, wrap the proxy to ensure per-definition caching is used return schemautil.NewLegacySchemaReaderAdapter(r), nil } +func (r *definitionCachingReader) ReadStoredSchema(ctx context.Context) (*core.StoredSchema, error) { + singleStoreReader, ok := r.Reader.(datastore.SingleStoreSchemaReader) + if !ok { + return nil, errors.New("delegate reader does not implement SingleStoreSchemaReader") + } + return singleStoreReader.ReadStoredSchema(ctx) +} + func listAndCache[T schemaDefinition]( ctx context.Context, r *definitionCachingReader, @@ -302,10 +324,20 @@ func (rwt *definitionCachingRWT) SchemaWriter() (datastore.SchemaWriter, error) return schemautil.NewLegacySchemaWriterAdapter(rwt, rwt.ReadWriteTransaction), nil 
} -// SchemaReader returns a wrapper around the definitionCachingRWT that ensures -// that the caching logic in this proxy is exercised when a handle on the -// SchemaReader is requested. +// SchemaReader returns a schema reader for the transaction. For new unified schema mode, +// it passes through directly. For legacy mode, it wraps the transaction to use caching. func (rwt *definitionCachingRWT) SchemaReader() (datastore.SchemaReader, error) { + underlyingSchemaReader, err := rwt.ReadWriteTransaction.SchemaReader() + if err != nil { + return nil, err + } + + // If using new unified schema mode, pass through directly + if _, isLegacy := underlyingSchemaReader.(*schemautil.LegacySchemaReaderAdapter); !isLegacy { + return underlyingSchemaReader, nil + } + + // For legacy mode, wrap to use transaction-local caching return schemautil.NewLegacySchemaReaderAdapter(rwt), nil } @@ -321,8 +353,12 @@ func (c *cacheEntry) Size() int64 { } var ( - _ datastore.Datastore = &definitionCachingProxy{} - _ datastore.Reader = &definitionCachingReader{} + _ datastore.Datastore = &definitionCachingProxy{} + _ datastore.Reader = &definitionCachingReader{} + _ datastore.LegacySchemaReader = &definitionCachingReader{} + _ datastore.SingleStoreSchemaReader = &definitionCachingReader{} + _ datastore.DualSchemaReader = &definitionCachingReader{} + _ datastore.ReadWriteTransaction = &definitionCachingRWT{} ) func estimatedNamespaceDefinitionSize(sizevt int) int64 { diff --git a/internal/datastore/proxy/schemacaching/standardcaching_test.go b/internal/datastore/proxy/schemacaching/standardcaching_test.go index 628cffd96..40e137f03 100644 --- a/internal/datastore/proxy/schemacaching/standardcaching_test.go +++ b/internal/datastore/proxy/schemacaching/standardcaching_test.go @@ -17,6 +17,7 @@ import ( "github.com/authzed/spicedb/internal/datastore/memdb" "github.com/authzed/spicedb/internal/datastore/proxy/proxy_test" "github.com/authzed/spicedb/internal/datastore/revisions" + schemaadapter 
"github.com/authzed/spicedb/internal/datastore/schema" "github.com/authzed/spicedb/pkg/caveats" caveattypes "github.com/authzed/spicedb/pkg/caveats/types" "github.com/authzed/spicedb/pkg/datastore" @@ -43,6 +44,11 @@ const ( caveatB = "caveat_b" ) +// setupSchemaReaderMock configures a MockReader to return a legacy schema reader adapter +func setupSchemaReaderMock(reader *proxy_test.MockReader) { + reader.On("SchemaReader").Return(schemaadapter.NewLegacySchemaReaderAdapter(reader), nil) +} + // TestNilUnmarshal asserts that if we get a nil NamespaceDefinition from a // datastore implementation, the process of inserting it into the cache and // back does not break anything. @@ -173,35 +179,35 @@ func TestOldSnapshotCaching(t *testing.T) { require := require.New(t) ds := NewCachingDatastoreProxy(dsMock, dptc, 1*time.Hour, JustInTimeCaching, 100*time.Millisecond) - _, updatedOneA, err := tester.readSingleFunc(t.Context(), ds.SnapshotReader(one), nsA) + _, updatedOneA, err := tester.readSingleFunc(t.Context(), ds.SnapshotReader(one, datastore.NoSchemaHashForTesting), nsA) require.NoError(err) require.True(old.Equal(updatedOneA)) - _, updatedOneAAgain, err := tester.readSingleFunc(t.Context(), ds.SnapshotReader(one), nsA) + _, updatedOneAAgain, err := tester.readSingleFunc(t.Context(), ds.SnapshotReader(one, datastore.NoSchemaHashForTesting), nsA) require.NoError(err) require.True(old.Equal(updatedOneAAgain)) - _, updatedOneB, err := tester.readSingleFunc(t.Context(), ds.SnapshotReader(one), nsB) + _, updatedOneB, err := tester.readSingleFunc(t.Context(), ds.SnapshotReader(one, datastore.NoSchemaHashForTesting), nsB) require.NoError(err) require.True(zero.Equal(updatedOneB)) - _, updatedOneBAgain, err := tester.readSingleFunc(t.Context(), ds.SnapshotReader(one), nsB) + _, updatedOneBAgain, err := tester.readSingleFunc(t.Context(), ds.SnapshotReader(one, datastore.NoSchemaHashForTesting), nsB) require.NoError(err) require.True(zero.Equal(updatedOneBAgain)) - _, 
updatedTwoA, err := tester.readSingleFunc(t.Context(), ds.SnapshotReader(two), nsA) + _, updatedTwoA, err := tester.readSingleFunc(t.Context(), ds.SnapshotReader(two, datastore.NoSchemaHashForTesting), nsA) require.NoError(err) require.True(zero.Equal(updatedTwoA)) - _, updatedTwoAAgain, err := tester.readSingleFunc(t.Context(), ds.SnapshotReader(two), nsA) + _, updatedTwoAAgain, err := tester.readSingleFunc(t.Context(), ds.SnapshotReader(two, datastore.NoSchemaHashForTesting), nsA) require.NoError(err) require.True(zero.Equal(updatedTwoAAgain)) - _, updatedTwoB, err := tester.readSingleFunc(t.Context(), ds.SnapshotReader(two), nsB) + _, updatedTwoB, err := tester.readSingleFunc(t.Context(), ds.SnapshotReader(two, datastore.NoSchemaHashForTesting), nsB) require.NoError(err) require.True(one.Equal(updatedTwoB)) - _, updatedTwoBAgain, err := tester.readSingleFunc(t.Context(), ds.SnapshotReader(two), nsB) + _, updatedTwoBAgain, err := tester.readSingleFunc(t.Context(), ds.SnapshotReader(two, datastore.NoSchemaHashForTesting), nsB) require.NoError(err) require.True(one.Equal(updatedTwoBAgain)) @@ -217,6 +223,7 @@ func TestSnapshotCaching(t *testing.T) { oneReader := &proxy_test.MockReader{} dsMock.On("SnapshotReader", one).Return(oneReader) + oneReader.On("SchemaReader").Return(schemaadapter.NewLegacySchemaReaderAdapter(oneReader), nil) oneReader.On("LegacyReadNamespaceByName", nsA).Return(nil, old, nil).Once() oneReader.On("LegacyReadNamespaceByName", nsB).Return(nil, zero, nil).Once() oneReader.On("LegacyReadCaveatByName", caveatA).Return(nil, old, nil).Once() @@ -224,6 +231,7 @@ func TestSnapshotCaching(t *testing.T) { twoReader := &proxy_test.MockReader{} dsMock.On("SnapshotReader", two).Return(twoReader) + twoReader.On("SchemaReader").Return(schemaadapter.NewLegacySchemaReaderAdapter(twoReader), nil) twoReader.On("LegacyReadNamespaceByName", nsA).Return(nil, zero, nil).Once() twoReader.On("LegacyReadNamespaceByName", nsB).Return(nil, one, nil).Once() 
twoReader.On("LegacyReadCaveatByName", caveatA).Return(nil, zero, nil).Once() @@ -238,7 +246,7 @@ func TestSnapshotCaching(t *testing.T) { ds := NewCachingDatastoreProxy(dsMock, dptc, 1*time.Hour, JustInTimeCaching, 100*time.Millisecond) // Get a handle on the reader for A - dsForA := ds.SnapshotReader(one) + dsForA := ds.SnapshotReader(one, datastore.NoSchemaHashForTesting) schemaReaderForA, err := dsForA.SchemaReader() require.NoError(err) @@ -269,7 +277,7 @@ func TestSnapshotCaching(t *testing.T) { require.True(old.Equal(revCaveatDef.LastWrittenRevision)) // Get a handle on the reader for B - dsForB := ds.SnapshotReader(one) + dsForB := ds.SnapshotReader(one, datastore.NoSchemaHashForTesting) schemaReaderForB, err := dsForB.SchemaReader() require.NoError(err) @@ -300,7 +308,7 @@ func TestSnapshotCaching(t *testing.T) { require.True(zero.Equal(revCaveatDef.LastWrittenRevision)) // Get a handle on the second reader for A - dsForA = ds.SnapshotReader(two) + dsForA = ds.SnapshotReader(two, datastore.NoSchemaHashForTesting) schemaReaderForA, err = dsForA.SchemaReader() require.NoError(err) @@ -331,7 +339,7 @@ func TestSnapshotCaching(t *testing.T) { require.True(zero.Equal(revCaveatDef.LastWrittenRevision)) // Get a handle on the second reader for B - dsForB = ds.SnapshotReader(two) + dsForB = ds.SnapshotReader(two, datastore.NoSchemaHashForTesting) schemaReaderForB, err = dsForB.SchemaReader() require.NoError(err) @@ -410,6 +418,7 @@ func TestRWTCaching(t *testing.T) { require := require.New(t) dsMock.On("ReadWriteTx", nilOpts).Return(rwtMock, one, nil).Once() + rwtMock.On("SchemaReader").Return(schemaadapter.NewLegacySchemaReaderAdapter(rwtMock), nil) rwtMock.On("LegacyReadNamespaceByName", nsA).Return(nil, zero, nil).Once() rwtMock.On("LegacyReadCaveatByName", caveatA).Return(nil, zero, nil).Once() @@ -524,7 +533,7 @@ func TestOldSingleFlight(t *testing.T) { ds := NewCachingDatastoreProxy(dsMock, nil, 1*time.Hour, JustInTimeCaching, 100*time.Millisecond) 
readNamespace := func() error { - _, updatedAt, err := tester.readSingleFunc(t.Context(), ds.SnapshotReader(one), nsA) + _, updatedAt, err := tester.readSingleFunc(t.Context(), ds.SnapshotReader(one, datastore.NoSchemaHashForTesting), nsA) require.NoError(err) require.True(old.Equal(updatedAt)) return err @@ -548,6 +557,7 @@ func TestSingleFlight(t *testing.T) { oneReader := &proxy_test.MockReader{} dsMock.On("SnapshotReader", one).Return(oneReader) + oneReader.On("SchemaReader").Return(schemaadapter.NewLegacySchemaReaderAdapter(oneReader), nil) oneReader. On("LegacyReadNamespaceByName", nsA). WaitUntil(time.After(50*time.Millisecond)). @@ -562,7 +572,7 @@ func TestSingleFlight(t *testing.T) { assert := assert.New(t) ds := NewCachingDatastoreProxy(dsMock, nil, 1*time.Hour, JustInTimeCaching, 100*time.Millisecond) - snapshotReader := ds.SnapshotReader(one) + snapshotReader := ds.SnapshotReader(one, datastore.NoSchemaHashForTesting) schemaReader, err := snapshotReader.SchemaReader() if !assert.NoError(err) { //nolint:testifylint // you can't use require within a goroutine; the linter is wrong. 
return @@ -662,10 +672,10 @@ func TestOldSnapshotCachingRealDatastore(t *testing.T) { require.NoError(t, err) } - headRev, err := ds.HeadRevision(ctx) + headRev, _, err := ds.HeadRevision(ctx) require.NoError(t, err) - reader := ds.SnapshotReader(headRev) + reader := ds.SnapshotReader(headRev, datastore.NoSchemaHashForTesting) ns, _, _ := reader.LegacyReadNamespaceByName(ctx, tc.namespaceName) testutil.RequireProtoEqual(t, tc.nsDef, ns, "found different namespaces") @@ -739,10 +749,10 @@ func TestSnapshotCachingRealDatastore(t *testing.T) { require.NoError(t, err) } - headRev, err := ds.HeadRevision(ctx) + headRev, _, err := ds.HeadRevision(ctx) require.NoError(t, err) - reader := ds.SnapshotReader(headRev) + reader := ds.SnapshotReader(headRev, datastore.NoSchemaHashForTesting) schemaReader, err := reader.SchemaReader() require.NoError(t, err) @@ -771,6 +781,10 @@ type singleflightReader struct { proxy_test.MockReader } +func (r *singleflightReader) SchemaReader() (datastore.SchemaReader, error) { + return schemaadapter.NewLegacySchemaReaderAdapter(r), nil +} + func (r *singleflightReader) LegacyReadNamespaceByName(ctx context.Context, namespace string) (ns *core.NamespaceDefinition, lastWritten datastore.Revision, err error) { // NOTE: the sleep is here to ensure that the context can be cancelled before this executes. 
time.Sleep(10 * time.Millisecond) @@ -829,12 +843,12 @@ func TestOldSingleFlightCancelled(t *testing.T) { var d2 datastore.SchemaDefinition g.Add(2) go func() { - _, _, _ = tester.readSingleFunc(ctx1, ds.SnapshotReader(one), nsA) + _, _, _ = tester.readSingleFunc(ctx1, ds.SnapshotReader(one, datastore.NoSchemaHashForTesting), nsA) g.Done() }() go func() { time.Sleep(5 * time.Millisecond) - d2, _, _ = tester.readSingleFunc(ctx2, ds.SnapshotReader(one), nsA) + d2, _, _ = tester.readSingleFunc(ctx2, ds.SnapshotReader(one, datastore.NoSchemaHashForTesting), nsA) g.Done() }() cancel1() @@ -859,7 +873,7 @@ func TestSingleFlightCancelled(t *testing.T) { dsMock.On("SnapshotReader", one).Return(&singleflightReader{MockReader: proxy_test.MockReader{}}) ds := NewCachingDatastoreProxy(dsMock, nil, 1*time.Hour, JustInTimeCaching, 100*time.Millisecond) - snapshotReader := ds.SnapshotReader(one) + snapshotReader := ds.SnapshotReader(one, datastore.NoSchemaHashForTesting) schemaReader, err := snapshotReader.SchemaReader() if !assert.NoError(t, err) { //nolint:testifylint // you can't use require within a goroutine; the linter is wrong. 
return @@ -918,7 +932,7 @@ func TestOldMixedCaching(t *testing.T) { require := require.New(t) ds := NewCachingDatastoreProxy(dsMock, dptc, 1*time.Hour, JustInTimeCaching, 100*time.Millisecond) - dsReader := ds.SnapshotReader(one) + dsReader := ds.SnapshotReader(one, datastore.NoSchemaHashForTesting) // Lookup name A _, _, err := tester.readSingleFunc(t.Context(), dsReader, nsA) @@ -975,6 +989,7 @@ func TestMixedCaching(t *testing.T) { reader := &proxy_test.MockReader{} reader.Test(t) + setupSchemaReaderMock(reader) reader.On("LegacyReadNamespaceByName", nsA).Return(nsDefA, old, nil).Once() reader.On("LegacyReadCaveatByName", caveatA).Return(caveatDefA, old, nil).Once() // NOTE: the mocks here only expect the Bs because the caching layer is going to @@ -1003,7 +1018,7 @@ func TestMixedCaching(t *testing.T) { require := require.New(t) ds := NewCachingDatastoreProxy(dsMock, dptc, 1*time.Hour, JustInTimeCaching, 100*time.Millisecond) - dsReader := ds.SnapshotReader(one) + dsReader := ds.SnapshotReader(one, datastore.NoSchemaHashForTesting) schemaReader, err := dsReader.SchemaReader() require.NoError(err) @@ -1059,9 +1074,9 @@ func TestOldInvalidNamespaceInCache(t *testing.T) { ds.Close() }) - headRevision, err := ds.HeadRevision(ctx) + headRevision, _, err := ds.HeadRevision(ctx) require.NoError(err) - dsReader := ds.SnapshotReader(headRevision) + dsReader := ds.SnapshotReader(headRevision, datastore.NoSchemaHashForTesting) namespace, _, err := dsReader.LegacyReadNamespaceByName(ctx, invalidNamespace) require.Nil(namespace) @@ -1094,9 +1109,9 @@ func TestInvalidNamespaceInCache(t *testing.T) { ds.Close() }) - headRevision, err := ds.HeadRevision(ctx) + headRevision, _, err := ds.HeadRevision(ctx) require.NoError(err) - dsReader := ds.SnapshotReader(headRevision) + dsReader := ds.SnapshotReader(headRevision, datastore.NoSchemaHashForTesting) schemaReader, err := dsReader.SchemaReader() require.NoError(err) @@ -1142,7 +1157,7 @@ func 
TestOldMixedInvalidNamespacesInCache(t *testing.T) { }) require.NoError(err) - dsReader := ds.SnapshotReader(revision) + dsReader := ds.SnapshotReader(revision, datastore.NoSchemaHashForTesting) namespace, _, err := dsReader.LegacyReadNamespaceByName(ctx, invalidNamespace) require.Nil(namespace) @@ -1189,7 +1204,7 @@ func TestMixedInvalidNamespacesInCache(t *testing.T) { }) require.NoError(err) - dsReader := ds.SnapshotReader(revision) + dsReader := ds.SnapshotReader(revision, datastore.NoSchemaHashForTesting) schemaReader, err := dsReader.SchemaReader() require.NoError(err) diff --git a/internal/datastore/proxy/schemacaching/watchingcache.go b/internal/datastore/proxy/schemacaching/watchingcache.go index 20987905f..08ddd6e33 100644 --- a/internal/datastore/proxy/schemacaching/watchingcache.go +++ b/internal/datastore/proxy/schemacaching/watchingcache.go @@ -96,10 +96,12 @@ func createWatchingCacheProxy(delegate datastore.Datastore, c cache.Cache[cache. "namespace", datastore.NewNamespaceNotFoundErr, func(ctx context.Context, name string, revision datastore.Revision) (*core.NamespaceDefinition, datastore.Revision, error) { - return fallbackCache.SnapshotReader(revision).LegacyReadNamespaceByName(ctx, name) + // Fallback operation - load schema on demand + return fallbackCache.SnapshotReader(revision, datastore.NoSchemaHashForWatch).LegacyReadNamespaceByName(ctx, name) }, func(ctx context.Context, names []string, revision datastore.Revision) ([]datastore.RevisionedDefinition[*core.NamespaceDefinition], error) { - return fallbackCache.SnapshotReader(revision).LegacyLookupNamespacesWithNames(ctx, names) + // Fallback operation - load schema on demand + return fallbackCache.SnapshotReader(revision, datastore.NoSchemaHashForWatch).LegacyLookupNamespacesWithNames(ctx, names) }, definitionsReadCachedCounter, definitionsReadTotalCounter, @@ -109,10 +111,12 @@ func createWatchingCacheProxy(delegate datastore.Datastore, c cache.Cache[cache. 
"caveat", datastore.NewCaveatNameNotFoundErr, func(ctx context.Context, name string, revision datastore.Revision) (*core.CaveatDefinition, datastore.Revision, error) { - return fallbackCache.SnapshotReader(revision).LegacyReadCaveatByName(ctx, name) + // Fallback operation - load schema on demand + return fallbackCache.SnapshotReader(revision, datastore.NoSchemaHashForWatch).LegacyReadCaveatByName(ctx, name) }, func(ctx context.Context, names []string, revision datastore.Revision) ([]datastore.RevisionedDefinition[*core.CaveatDefinition], error) { - return fallbackCache.SnapshotReader(revision).LegacyLookupCaveatsWithNames(ctx, names) + // Fallback operation - load schema on demand + return fallbackCache.SnapshotReader(revision, datastore.NoSchemaHashForWatch).LegacyLookupCaveatsWithNames(ctx, names) }, definitionsReadCachedCounter, definitionsReadTotalCounter, @@ -122,8 +126,8 @@ func createWatchingCacheProxy(delegate datastore.Datastore, c cache.Cache[cache. return proxy } -func (p *watchingCachingProxy) SnapshotReader(rev datastore.Revision) datastore.Reader { - delegateReader := p.Datastore.SnapshotReader(rev) +func (p *watchingCachingProxy) SnapshotReader(rev datastore.Revision, hash datastore.SchemaHash) datastore.Reader { + delegateReader := p.Datastore.SnapshotReader(rev, hash) return &watchingCachingReader{delegateReader, rev, p} } @@ -149,7 +153,7 @@ func (p *watchingCachingProxy) Start(ctx context.Context) error { func (p *watchingCachingProxy) startSync(ctx context.Context) error { log.Info().Msg("starting watching cache") - headRev, err := p.HeadRevision(context.Background()) + headRev, _, err := p.HeadRevision(context.Background()) if err != nil { p.namespaceCache.setFallbackMode() p.caveatCache.setFallbackMode() @@ -193,7 +197,8 @@ func (p *watchingCachingProxy) startSync(ctx context.Context) error { p.caveatCache.reset() log.Debug().Str("revision", headRev.String()).Msg("starting watching cache watch operation") - reader := 
p.Datastore.SnapshotReader(headRev) + // Watch cache rebuild - load schema on demand + reader := p.Datastore.SnapshotReader(headRev, datastore.NoSchemaHashForWatch) // Populate the cache with all definitions at the head revision. log.Info().Str("revision", headRev.String()).Msg("prepopulating namespace watching cache") @@ -633,5 +638,32 @@ func (w *watchingCachingReader) LegacyLookupCaveatsWithNames( } func (w *watchingCachingReader) SchemaReader() (datastore.SchemaReader, error) { + underlyingSchemaReader, err := w.Reader.SchemaReader() + if err != nil { + return nil, err + } + + // If using new unified schema mode, pass through directly + if _, isLegacy := underlyingSchemaReader.(*schemautil.LegacySchemaReaderAdapter); !isLegacy { + return underlyingSchemaReader, nil + } + + // For legacy mode, wrap to use the watching cache return schemautil.NewLegacySchemaReaderAdapter(w), nil } + +func (w *watchingCachingReader) ReadStoredSchema(ctx context.Context) (*core.StoredSchema, error) { + singleStoreReader, ok := w.Reader.(datastore.SingleStoreSchemaReader) + if !ok { + return nil, errors.New("delegate reader does not implement SingleStoreSchemaReader") + } + return singleStoreReader.ReadStoredSchema(ctx) +} + +var ( + _ datastore.Datastore = (*watchingCachingProxy)(nil) + _ datastore.Reader = (*watchingCachingReader)(nil) + _ datastore.LegacySchemaReader = (*watchingCachingReader)(nil) + _ datastore.SingleStoreSchemaReader = (*watchingCachingReader)(nil) + _ datastore.DualSchemaReader = (*watchingCachingReader)(nil) +) diff --git a/internal/datastore/proxy/schemacaching/watchingcache_test.go b/internal/datastore/proxy/schemacaching/watchingcache_test.go index fb8489560..9d4b6302e 100644 --- a/internal/datastore/proxy/schemacaching/watchingcache_test.go +++ b/internal/datastore/proxy/schemacaching/watchingcache_test.go @@ -40,12 +40,12 @@ func TestOldWatchingCacheBasicOperation(t *testing.T) { }) // Ensure no namespaces are found. 
- _, _, err := wcache.SnapshotReader(rev("1")).LegacyReadNamespaceByName(t.Context(), "somenamespace") + _, _, err := wcache.SnapshotReader(rev("1"), datastore.NoSchemaHashForTesting).LegacyReadNamespaceByName(t.Context(), "somenamespace") require.ErrorAs(t, err, &datastore.NamespaceNotFoundError{}) require.False(t, wcache.namespaceCache.inFallbackMode) // Ensure a re-read also returns not found, even before a checkpoint is received. - _, _, err = wcache.SnapshotReader(rev("1")).LegacyReadNamespaceByName(t.Context(), "somenamespace") + _, _, err = wcache.SnapshotReader(rev("1"), datastore.NoSchemaHashForTesting).LegacyReadNamespaceByName(t.Context(), "somenamespace") require.ErrorAs(t, err, &datastore.NamespaceNotFoundError{}) // Send a checkpoint for revision 1. @@ -55,7 +55,7 @@ func TestOldWatchingCacheBasicOperation(t *testing.T) { fakeDS.updateNamespace("somenamespace", &corev1.NamespaceDefinition{Name: "somenamespace"}, rev("2")) // Ensure that reading at rev 2 returns found. - nsDef, _, err := wcache.SnapshotReader(rev("2")).LegacyReadNamespaceByName(t.Context(), "somenamespace") + nsDef, _, err := wcache.SnapshotReader(rev("2"), datastore.NoSchemaHashForTesting).LegacyReadNamespaceByName(t.Context(), "somenamespace") require.NoError(t, err) require.Equal(t, "somenamespace", nsDef.Name) @@ -63,7 +63,7 @@ func TestOldWatchingCacheBasicOperation(t *testing.T) { fakeDS.disableReads() // Ensure that reading at rev 3 returns an error, as with reads disabled the cache should not be hit. - _, _, err = wcache.SnapshotReader(rev("3")).LegacyReadNamespaceByName(t.Context(), "somenamespace") + _, _, err = wcache.SnapshotReader(rev("3"), datastore.NoSchemaHashForTesting).LegacyReadNamespaceByName(t.Context(), "somenamespace") require.Error(t, err) require.ErrorContains(t, err, "reads are disabled") @@ -72,7 +72,7 @@ func TestOldWatchingCacheBasicOperation(t *testing.T) { // Ensure that reading at rev 3 returns found, even though the cache should not yet be there. 
This will // require a datastore fallback read because the cache is not yet checkedpointed to that revision. - nsDef, _, err = wcache.SnapshotReader(rev("3")).LegacyReadNamespaceByName(t.Context(), "somenamespace") + nsDef, _, err = wcache.SnapshotReader(rev("3"), datastore.NoSchemaHashForTesting).LegacyReadNamespaceByName(t.Context(), "somenamespace") require.NoError(t, err) require.Equal(t, "somenamespace", nsDef.Name) @@ -84,12 +84,12 @@ func TestOldWatchingCacheBasicOperation(t *testing.T) { fakeDS.disableReads() // Read again, which should now be via the cache. - nsDef, _, err = wcache.SnapshotReader(rev("3.0000000005")).LegacyReadNamespaceByName(t.Context(), "somenamespace") + nsDef, _, err = wcache.SnapshotReader(rev("3.0000000005"), datastore.NoSchemaHashForTesting).LegacyReadNamespaceByName(t.Context(), "somenamespace") require.NoError(t, err) require.Equal(t, "somenamespace", nsDef.Name) // Read via a lookup. - nsDefs, err := wcache.SnapshotReader(rev("3.0000000005")).LegacyLookupNamespacesWithNames(t.Context(), []string{"somenamespace"}) + nsDefs, err := wcache.SnapshotReader(rev("3.0000000005"), datastore.NoSchemaHashForTesting).LegacyLookupNamespacesWithNames(t.Context(), []string{"somenamespace"}) require.NoError(t, err) require.Equal(t, "somenamespace", nsDefs[0].Definition.Name) @@ -97,17 +97,17 @@ func TestOldWatchingCacheBasicOperation(t *testing.T) { fakeDS.updateNamespace("somenamespace", nil, rev("5")) // Re-read at an earlier revision. - nsDef, _, err = wcache.SnapshotReader(rev("3.0000000005")).LegacyReadNamespaceByName(t.Context(), "somenamespace") + nsDef, _, err = wcache.SnapshotReader(rev("3.0000000005"), datastore.NoSchemaHashForTesting).LegacyReadNamespaceByName(t.Context(), "somenamespace") require.NoError(t, err) require.Equal(t, "somenamespace", nsDef.Name) // Read at revision 5. 
- _, _, err = wcache.SnapshotReader(rev("5")).LegacyReadNamespaceByName(t.Context(), "somenamespace") + _, _, err = wcache.SnapshotReader(rev("5"), datastore.NoSchemaHashForTesting).LegacyReadNamespaceByName(t.Context(), "somenamespace") require.Error(t, err) require.ErrorAs(t, err, &datastore.NamespaceNotFoundError{}, "missing not found in: %v", err) // Lookup at revision 5. - nsDefs, err = wcache.SnapshotReader(rev("5")).LegacyLookupNamespacesWithNames(t.Context(), []string{"somenamespace"}) + nsDefs, err = wcache.SnapshotReader(rev("5"), datastore.NoSchemaHashForTesting).LegacyLookupNamespacesWithNames(t.Context(), []string{"somenamespace"}) require.NoError(t, err) require.Empty(t, nsDefs) @@ -115,12 +115,12 @@ func TestOldWatchingCacheBasicOperation(t *testing.T) { fakeDS.updateCaveat("somecaveat", &corev1.CaveatDefinition{Name: "somecaveat"}, rev("6")) // Read at revision 6. - caveatDef, _, err := wcache.SnapshotReader(rev("6")).LegacyReadCaveatByName(t.Context(), "somecaveat") + caveatDef, _, err := wcache.SnapshotReader(rev("6"), datastore.NoSchemaHashForTesting).LegacyReadCaveatByName(t.Context(), "somecaveat") require.NoError(t, err) require.Equal(t, "somecaveat", caveatDef.Name) // Attempt to read at revision 1, which should require a read. 
- _, _, err = wcache.SnapshotReader(rev("1")).LegacyReadCaveatByName(t.Context(), "somecaveat") + _, _, err = wcache.SnapshotReader(rev("1"), datastore.NoSchemaHashForTesting).LegacyReadCaveatByName(t.Context(), "somecaveat") require.ErrorContains(t, err, "reads are disabled") } @@ -143,7 +143,7 @@ func TestWatchingCacheBasicOperation(t *testing.T) { wcache.Close() }) - reader := wcache.SnapshotReader(rev("1")) + reader := wcache.SnapshotReader(rev("1"), datastore.NoSchemaHashForTesting) schemaReader, err := reader.SchemaReader() require.NoError(t, err) @@ -165,7 +165,7 @@ func TestWatchingCacheBasicOperation(t *testing.T) { fakeDS.updateNamespace("somenamespace", &corev1.NamespaceDefinition{Name: "somenamespace"}, rev("2")) // Get a handle on the schemaReader at that revision - reader = wcache.SnapshotReader(rev("2")) + reader = wcache.SnapshotReader(rev("2"), datastore.NoSchemaHashForTesting) schemaReader, err = reader.SchemaReader() require.NoError(t, err) @@ -179,7 +179,7 @@ func TestWatchingCacheBasicOperation(t *testing.T) { fakeDS.disableReads() // Get a handle on the schemaReader at that revision - reader = wcache.SnapshotReader(rev("3")) + reader = wcache.SnapshotReader(rev("3"), datastore.NoSchemaHashForTesting) schemaReader, err = reader.SchemaReader() require.NoError(t, err) @@ -205,7 +205,7 @@ func TestWatchingCacheBasicOperation(t *testing.T) { fakeDS.disableReads() // Get a handle on the schemaReader at that revision - reader = wcache.SnapshotReader(rev("3.0000000005")) + reader = wcache.SnapshotReader(rev("3.0000000005"), datastore.NoSchemaHashForTesting) schemaReader, err = reader.SchemaReader() require.NoError(t, err) @@ -229,7 +229,7 @@ func TestWatchingCacheBasicOperation(t *testing.T) { require.Equal(t, "somenamespace", nsRevDef.Definition.Name) // Get a handle on the schemaReader at rev 5 - reader = wcache.SnapshotReader(rev("5")) + reader = wcache.SnapshotReader(rev("5"), datastore.NoSchemaHashForTesting) schemaReader, err = 
reader.SchemaReader() require.NoError(t, err) @@ -247,7 +247,7 @@ func TestWatchingCacheBasicOperation(t *testing.T) { fakeDS.updateCaveat("somecaveat", &corev1.CaveatDefinition{Name: "somecaveat"}, rev("6")) // Get a handle on the schemaReader at rev 6 - reader = wcache.SnapshotReader(rev("6")) + reader = wcache.SnapshotReader(rev("6"), datastore.NoSchemaHashForTesting) schemaReader, err = reader.SchemaReader() require.NoError(t, err) @@ -263,7 +263,7 @@ func TestWatchingCacheBasicOperation(t *testing.T) { require.Equal(t, "somecaveat", caveatDefMap["somecaveat"].GetName()) // Get a handle on the schemaReader at rev 1 - reader = wcache.SnapshotReader(rev("1")) + reader = wcache.SnapshotReader(rev("1"), datastore.NoSchemaHashForTesting) schemaReader, err = reader.SchemaReader() require.NoError(t, err) @@ -306,7 +306,7 @@ func TestOldWatchingCacheParallelOperations(t *testing.T) { defer wg.Done() // Read somenamespace (which should not be found) - _, _, err := wcache.SnapshotReader(rev("1")).LegacyReadNamespaceByName(t.Context(), "somenamespace") + _, _, err := wcache.SnapshotReader(rev("1"), datastore.NoSchemaHashForTesting).LegacyReadNamespaceByName(t.Context(), "somenamespace") firstErrs <- err firstFallbackModes <- wcache.namespaceCache.inFallbackMode @@ -314,7 +314,7 @@ func TestOldWatchingCacheParallelOperations(t *testing.T) { fakeDS.updateNamespace("somenamespace", &corev1.NamespaceDefinition{Name: "somenamespace"}, rev("2")) // Read again (which should be found now) - nsDef, _, err := wcache.SnapshotReader(rev("2")).LegacyReadNamespaceByName(t.Context(), "somenamespace") + nsDef, _, err := wcache.SnapshotReader(rev("2"), datastore.NoSchemaHashForTesting).LegacyReadNamespaceByName(t.Context(), "somenamespace") firstErrs <- err firstNsDefNames <- nsDef.Name })() @@ -323,12 +323,12 @@ func TestOldWatchingCacheParallelOperations(t *testing.T) { defer wg.Done() // Read anothernamespace (which should not be found) - _, _, err := 
wcache.SnapshotReader(rev("1")).LegacyReadNamespaceByName(t.Context(), "anothernamespace") + _, _, err := wcache.SnapshotReader(rev("1"), datastore.NoSchemaHashForTesting).LegacyReadNamespaceByName(t.Context(), "anothernamespace") secondErrs <- err secondFallbackModes <- wcache.namespaceCache.inFallbackMode // Read again (which should still not be found) - _, _, err = wcache.SnapshotReader(rev("3")).LegacyReadNamespaceByName(t.Context(), "anothernamespace") + _, _, err = wcache.SnapshotReader(rev("3"), datastore.NoSchemaHashForTesting).LegacyReadNamespaceByName(t.Context(), "anothernamespace") secondErrs <- err secondFallbackModes <- wcache.namespaceCache.inFallbackMode })() @@ -392,7 +392,7 @@ func TestWatchingCacheParallelOperations(t *testing.T) { defer wg.Done() // Read somenamespace (which should not be found) - schemaReader, err := wcache.SnapshotReader(rev("1")).SchemaReader() + schemaReader, err := wcache.SnapshotReader(rev("1"), datastore.NoSchemaHashForTesting).SchemaReader() assert.NoError(t, err) _, found, err := schemaReader.LookupTypeDefByName(t.Context(), "somenamespace") assert.NoError(t, err) @@ -403,7 +403,7 @@ func TestWatchingCacheParallelOperations(t *testing.T) { fakeDS.updateNamespace("somenamespace", &corev1.NamespaceDefinition{Name: "somenamespace"}, rev("2")) // Read again (which should be found now) - schemaReader, err = wcache.SnapshotReader(rev("2")).SchemaReader() + schemaReader, err = wcache.SnapshotReader(rev("2"), datastore.NoSchemaHashForTesting).SchemaReader() assert.NoError(t, err) nsRevDef, _, err := schemaReader.LookupTypeDefByName(t.Context(), "somenamespace") assert.NoError(t, err, "expected namespace read from rev 2 to succeed") @@ -414,7 +414,7 @@ func TestWatchingCacheParallelOperations(t *testing.T) { defer wg.Done() // Read anothernamespace (which should not be found) - schemaReader, err := wcache.SnapshotReader(rev("1")).SchemaReader() + schemaReader, err := wcache.SnapshotReader(rev("1"), 
datastore.NoSchemaHashForTesting).SchemaReader() if !assert.NoError(t, err) { //nolint:testifylint // you can't use require within a goroutine; the linter is wrong. return } @@ -426,7 +426,7 @@ func TestWatchingCacheParallelOperations(t *testing.T) { assert.False(t, wcache.namespaceCache.inFallbackMode) // Read again (which should still not be found) - schemaReader, err = wcache.SnapshotReader(rev("3")).SchemaReader() + schemaReader, err = wcache.SnapshotReader(rev("3"), datastore.NoSchemaHashForTesting).SchemaReader() if !assert.NoError(t, err) { //nolint:testifylint // you can't use require within a goroutine; the linter is wrong. return } @@ -480,10 +480,10 @@ func TestWatchingCacheParallelReaderWriter(t *testing.T) { go func() { // Start a loop to read a namespace a bunch of times. for i := 0; i < 1000; i++ { - headRevision, err := fakeDS.HeadRevision(t.Context()) + headRevision, _, err := fakeDS.HeadRevision(t.Context()) assert.NoError(t, err) - schemaReader, err := wcache.SnapshotReader(headRevision).SchemaReader() + schemaReader, err := wcache.SnapshotReader(headRevision, datastore.NoSchemaHashForTesting).SchemaReader() assert.NoError(t, err) nsRevDef, _, err := schemaReader.LookupTypeDefByName(t.Context(), "somenamespace") assert.NoError(t, err) @@ -539,10 +539,10 @@ func TestOldWatchingCacheParallelReaderWriter(t *testing.T) { go (func() { // Start a loop to read a namespace a bunch of times. 
for i := 0; i < 1000; i++ { - headRevision, err := fakeDS.HeadRevision(t.Context()) + headRevision, _, err := fakeDS.HeadRevision(t.Context()) headRevisionErrors <- err - nsDef, _, err := wcache.SnapshotReader(headRevision).LegacyReadNamespaceByName(t.Context(), "somenamespace") + nsDef, _, err := wcache.SnapshotReader(headRevision, datastore.NoSchemaHashForTesting).LegacyReadNamespaceByName(t.Context(), "somenamespace") snapshotReaderErrors <- err namespaceNames <- nsDef.Name } @@ -593,7 +593,7 @@ func TestWatchingCacheFallbackToStandardCache(t *testing.T) { // Ensure the namespace is not found, but is cached in the fallback caching layer. r := rev("1") - schemaReader, err := wcache.SnapshotReader(r).SchemaReader() + schemaReader, err := wcache.SnapshotReader(r, datastore.NoSchemaHashForTesting).SchemaReader() require.NoError(t, err) _, found, err := schemaReader.LookupTypeDefByName(t.Context(), "somenamespace") require.NoError(t, err) @@ -642,7 +642,7 @@ func TestOldWatchingCacheFallbackToStandardCache(t *testing.T) { // Ensure the namespace is not found, but is cached in the fallback caching layer. r := rev("1") - _, _, err = wcache.SnapshotReader(r).LegacyReadNamespaceByName(t.Context(), "somenamespace") + _, _, err = wcache.SnapshotReader(r, datastore.NoSchemaHashForTesting).LegacyReadNamespaceByName(t.Context(), "somenamespace") require.ErrorAs(t, err, &datastore.NamespaceNotFoundError{}) require.False(t, wcache.namespaceCache.inFallbackMode) @@ -654,7 +654,7 @@ func TestOldWatchingCacheFallbackToStandardCache(t *testing.T) { // Disable reading and ensure it still works, via the fallback cache. 
fakeDS.readsDisabled = true - _, _, err = wcache.SnapshotReader(rev("1")).LegacyReadNamespaceByName(t.Context(), "somenamespace") + _, _, err = wcache.SnapshotReader(rev("1"), datastore.NoSchemaHashForTesting).LegacyReadNamespaceByName(t.Context(), "somenamespace") require.ErrorAs(t, err, &datastore.NamespaceNotFoundError{}) require.False(t, wcache.namespaceCache.inFallbackMode) } @@ -700,7 +700,7 @@ func TestOldWatchingCachePrepopulated(t *testing.T) { }) // Ensure the namespace is found. - def, _, err := wcache.SnapshotReader(rev("4")).LegacyReadNamespaceByName(t.Context(), "somenamespace") + def, _, err := wcache.SnapshotReader(rev("4"), datastore.NoSchemaHashForTesting).LegacyReadNamespaceByName(t.Context(), "somenamespace") require.NoError(t, err) require.Equal(t, "somenamespace", def.Name) } @@ -746,7 +746,7 @@ func TestWatchingCachePrepopulated(t *testing.T) { }) // Ensure the namespace is found. - schemaReader, err := wcache.SnapshotReader(rev("4")).SchemaReader() + schemaReader, err := wcache.SnapshotReader(rev("4"), datastore.NoSchemaHashForTesting).SchemaReader() require.NoError(t, err) revDef, _, err := schemaReader.LookupTypeDefByName(t.Context(), "somenamespace") require.NoError(t, err) @@ -921,15 +921,15 @@ func (fds *fakeDatastore) enableReads() { fds.readsDisabled = false } -func (fds *fakeDatastore) SnapshotReader(rev datastore.Revision) datastore.Reader { +func (fds *fakeDatastore) SnapshotReader(rev datastore.Revision, _ datastore.SchemaHash) datastore.Reader { return &fakeSnapshotReader{fds, rev} } -func (fds *fakeDatastore) HeadRevision(context.Context) (datastore.Revision, error) { +func (fds *fakeDatastore) HeadRevision(context.Context) (datastore.Revision, datastore.SchemaHash, error) { fds.lock.RLock() defer fds.lock.RUnlock() - return fds.headRevision, nil + return fds.headRevision, datastore.NoSchemaHashForTesting, nil } func (*fakeDatastore) ReadWriteTx(context.Context, datastore.TxUserFunc, ...options.RWTOptionsOption) 
(datastore.Revision, error) { @@ -952,8 +952,8 @@ func (*fakeDatastore) OfflineFeatures() (*datastore.Features, error) { return nil, fmt.Errorf("not implemented") } -func (*fakeDatastore) OptimizedRevision(context.Context) (datastore.Revision, error) { - return nil, fmt.Errorf("not implemented") +func (*fakeDatastore) OptimizedRevision(context.Context) (datastore.Revision, datastore.SchemaHash, error) { + return nil, "", fmt.Errorf("not implemented") } func (*fakeDatastore) ReadyState(context.Context) (datastore.ReadyState, error) { diff --git a/internal/datastore/proxy/singleflight.go b/internal/datastore/proxy/singleflight.go index 410ef0093..2f73535ee 100644 --- a/internal/datastore/proxy/singleflight.go +++ b/internal/datastore/proxy/singleflight.go @@ -16,7 +16,6 @@ func NewSingleflightDatastoreProxy(d datastore.Datastore) datastore.Datastore { } type singleflightProxy struct { - headRevGroup singleflight.Group[string, datastore.Revision] checkRevGroup singleflight.Group[string, string] statsGroup singleflight.Group[string, datastore.Stats] delegate datastore.Datastore @@ -32,15 +31,15 @@ func (p *singleflightProxy) UniqueID(ctx context.Context) (string, error) { return p.delegate.UniqueID(ctx) } -func (p *singleflightProxy) SnapshotReader(rev datastore.Revision) datastore.Reader { - return p.delegate.SnapshotReader(rev) +func (p *singleflightProxy) SnapshotReader(rev datastore.Revision, schemaHash datastore.SchemaHash) datastore.Reader { + return p.delegate.SnapshotReader(rev, schemaHash) } func (p *singleflightProxy) ReadWriteTx(ctx context.Context, f datastore.TxUserFunc, opts ...options.RWTOptionsOption) (datastore.Revision, error) { return p.delegate.ReadWriteTx(ctx, f, opts...) 
} -func (p *singleflightProxy) OptimizedRevision(ctx context.Context) (datastore.Revision, error) { +func (p *singleflightProxy) OptimizedRevision(ctx context.Context) (datastore.Revision, datastore.SchemaHash, error) { // NOTE: Optimized revisions are singleflighted by the underlying datastore via the // CachedOptimizedRevisions struct. return p.delegate.OptimizedRevision(ctx) @@ -53,11 +52,9 @@ func (p *singleflightProxy) CheckRevision(ctx context.Context, revision datastor return err } -func (p *singleflightProxy) HeadRevision(ctx context.Context) (datastore.Revision, error) { - rev, _, err := p.headRevGroup.Do(ctx, "", func(ctx context.Context) (datastore.Revision, error) { - return p.delegate.HeadRevision(ctx) - }) - return rev, err +func (p *singleflightProxy) HeadRevision(ctx context.Context) (datastore.Revision, datastore.SchemaHash, error) { + // Singleflight doesn't support tuple returns, so just delegate directly + return p.delegate.HeadRevision(ctx) } func (p *singleflightProxy) RevisionFromString(serialized string) (datastore.Revision, error) { diff --git a/internal/datastore/proxy/strictreplicated.go b/internal/datastore/proxy/strictreplicated.go index f4c2d5fa8..3c46b0f49 100644 --- a/internal/datastore/proxy/strictreplicated.go +++ b/internal/datastore/proxy/strictreplicated.go @@ -74,7 +74,7 @@ type strictReplicatedDatastore struct { // SnapshotReader creates a read-only handle that reads the datastore at the specified revision. // Any errors establishing the reader will be returned by subsequent calls. 
-func (rd *strictReplicatedDatastore) SnapshotReader(revision datastore.Revision) datastore.Reader { +func (rd *strictReplicatedDatastore) SnapshotReader(revision datastore.Revision, schemaHash datastore.SchemaHash) datastore.Reader { replica := selectReplica(rd.replicas, &rd.lastReplica) replicaID, err := replica.MetricsID() if err != nil { @@ -83,10 +83,11 @@ func (rd *strictReplicatedDatastore) SnapshotReader(revision datastore.Revision) } return &strictReadReplicatedReader{ - rev: revision, - replica: replica, - replicaID: replicaID, - primary: rd.Datastore, + rev: revision, + schemaHash: schemaHash, + replica: replica, + replicaID: replicaID, + primary: rd.Datastore, } } @@ -98,38 +99,39 @@ func (rd *strictReplicatedDatastore) SnapshotReader(revision datastore.Revision) // read mode enabled, to ensure the query will fail with a RevisionUnavailableError if the revision is // not available. type strictReadReplicatedReader struct { - rev datastore.Revision - replica datastore.ReadOnlyDatastore - replicaID string - primary datastore.Datastore + rev datastore.Revision + schemaHash datastore.SchemaHash + replica datastore.ReadOnlyDatastore + replicaID string + primary datastore.Datastore } func (rr *strictReadReplicatedReader) LegacyReadCaveatByName(ctx context.Context, name string) (*core.CaveatDefinition, datastore.Revision, error) { - sr := rr.replica.SnapshotReader(rr.rev) + sr := rr.replica.SnapshotReader(rr.rev, rr.schemaHash) caveat, lastWritten, err := sr.LegacyReadCaveatByName(ctx, name) if err != nil && errors.As(err, &common.RevisionUnavailableError{}) { log.Trace().Str("caveat", name).Str("revision", rr.rev.String()).Msg("replica does not contain the requested revision, using primary") - return rr.primary.SnapshotReader(rr.rev).LegacyReadCaveatByName(ctx, name) + return rr.primary.SnapshotReader(rr.rev, rr.schemaHash).LegacyReadCaveatByName(ctx, name) } return caveat, lastWritten, err } func (rr *strictReadReplicatedReader) LegacyListAllCaveats(ctx 
context.Context) ([]datastore.RevisionedCaveat, error) { - sr := rr.replica.SnapshotReader(rr.rev) + sr := rr.replica.SnapshotReader(rr.rev, rr.schemaHash) caveats, err := sr.LegacyListAllCaveats(ctx) if err != nil && errors.As(err, &common.RevisionUnavailableError{}) { log.Trace().Str("revision", rr.rev.String()).Msg("replica does not contain the requested revision, using primary") - return rr.primary.SnapshotReader(rr.rev).LegacyListAllCaveats(ctx) + return rr.primary.SnapshotReader(rr.rev, rr.schemaHash).LegacyListAllCaveats(ctx) } return caveats, err } func (rr *strictReadReplicatedReader) LegacyLookupCaveatsWithNames(ctx context.Context, names []string) ([]datastore.RevisionedCaveat, error) { - sr := rr.replica.SnapshotReader(rr.rev) + sr := rr.replica.SnapshotReader(rr.rev, rr.schemaHash) caveats, err := sr.LegacyLookupCaveatsWithNames(ctx, names) if err != nil && errors.As(err, &common.RevisionUnavailableError{}) { log.Trace().Str("revision", rr.rev.String()).Msg("replica does not contain the requested revision, using primary") - return rr.primary.SnapshotReader(rr.rev).LegacyLookupCaveatsWithNames(ctx, names) + return rr.primary.SnapshotReader(rr.rev, rr.schemaHash).LegacyLookupCaveatsWithNames(ctx, names) } return caveats, err } @@ -149,7 +151,7 @@ func queryRelationships[F any, O any]( ) (datastore.RelationshipIterator, error) { strictReadReplicatedTotalQueryCount.Inc() - sr := rr.replica.SnapshotReader(rr.rev) + sr := rr.replica.SnapshotReader(rr.rev, rr.schemaHash) it, err := handler(sr)(ctx, filter, options...) // Check for a RevisionUnavailableError, which indicates the replica does not contain the requested // revision. In this case, use the primary instead. 
This may not be returned on this call from @@ -158,7 +160,7 @@ func queryRelationships[F any, O any]( if errors.As(err, &common.RevisionUnavailableError{}) { log.Trace().Str("revision", rr.rev.String()).Msg("replica does not contain the requested revision, using primary") strictReadReplicatedFallbackQueryCount.WithLabelValues(rr.replicaID).Inc() - return handler(rr.primary.SnapshotReader(rr.rev))(ctx, filter, options...) + return handler(rr.primary.SnapshotReader(rr.rev, rr.schemaHash))(ctx, filter, options...) } return nil, err } @@ -175,7 +177,7 @@ func queryRelationships[F any, O any]( if errors.As(err, &common.RevisionUnavailableError{}) { log.Trace().Str("revision", rr.rev.String()).Msg("replica does not contain the requested revision, using primary") strictReadReplicatedFallbackQueryCount.WithLabelValues(rr.replicaID).Inc() - return handler(rr.primary.SnapshotReader(rr.rev))(ctx, filter, options...) + return handler(rr.primary.SnapshotReader(rr.rev, rr.schemaHash))(ctx, filter, options...) 
} return nil, err } @@ -212,51 +214,51 @@ func (rr *strictReadReplicatedReader) ReverseQueryRelationships( } func (rr *strictReadReplicatedReader) LegacyReadNamespaceByName(ctx context.Context, nsName string) (ns *core.NamespaceDefinition, lastWritten datastore.Revision, err error) { - sr := rr.replica.SnapshotReader(rr.rev) + sr := rr.replica.SnapshotReader(rr.rev, rr.schemaHash) namespace, lastWritten, err := sr.LegacyReadNamespaceByName(ctx, nsName) if err != nil && errors.As(err, &common.RevisionUnavailableError{}) { log.Trace().Str("namespace", nsName).Str("revision", rr.rev.String()).Msg("replica does not contain the requested revision, using primary") - return rr.primary.SnapshotReader(rr.rev).LegacyReadNamespaceByName(ctx, nsName) + return rr.primary.SnapshotReader(rr.rev, rr.schemaHash).LegacyReadNamespaceByName(ctx, nsName) } return namespace, lastWritten, err } func (rr *strictReadReplicatedReader) LegacyListAllNamespaces(ctx context.Context) ([]datastore.RevisionedNamespace, error) { - sr := rr.replica.SnapshotReader(rr.rev) + sr := rr.replica.SnapshotReader(rr.rev, rr.schemaHash) namespaces, err := sr.LegacyListAllNamespaces(ctx) if err != nil && errors.As(err, &common.RevisionUnavailableError{}) { log.Trace().Str("revision", rr.rev.String()).Msg("replica does not contain the requested revision, using primary") - return rr.primary.SnapshotReader(rr.rev).LegacyListAllNamespaces(ctx) + return rr.primary.SnapshotReader(rr.rev, rr.schemaHash).LegacyListAllNamespaces(ctx) } return namespaces, err } func (rr *strictReadReplicatedReader) LegacyLookupNamespacesWithNames(ctx context.Context, nsNames []string) ([]datastore.RevisionedNamespace, error) { - sr := rr.replica.SnapshotReader(rr.rev) + sr := rr.replica.SnapshotReader(rr.rev, rr.schemaHash) namespaces, err := sr.LegacyLookupNamespacesWithNames(ctx, nsNames) if err != nil && errors.As(err, &common.RevisionUnavailableError{}) { log.Trace().Str("revision", rr.rev.String()).Msg("replica does not contain the 
requested revision, using primary") - return rr.primary.SnapshotReader(rr.rev).LegacyLookupNamespacesWithNames(ctx, nsNames) + return rr.primary.SnapshotReader(rr.rev, rr.schemaHash).LegacyLookupNamespacesWithNames(ctx, nsNames) } return namespaces, err } func (rr *strictReadReplicatedReader) CountRelationships(ctx context.Context, filter string) (int, error) { - sr := rr.replica.SnapshotReader(rr.rev) + sr := rr.replica.SnapshotReader(rr.rev, rr.schemaHash) count, err := sr.CountRelationships(ctx, filter) if err != nil && errors.As(err, &common.RevisionUnavailableError{}) { log.Trace().Str("revision", rr.rev.String()).Msg("replica does not contain the requested revision, using primary") - return rr.primary.SnapshotReader(rr.rev).CountRelationships(ctx, filter) + return rr.primary.SnapshotReader(rr.rev, rr.schemaHash).CountRelationships(ctx, filter) } return count, err } func (rr *strictReadReplicatedReader) LookupCounters(ctx context.Context) ([]datastore.RelationshipCounter, error) { - sr := rr.replica.SnapshotReader(rr.rev) + sr := rr.replica.SnapshotReader(rr.rev, rr.schemaHash) counters, err := sr.LookupCounters(ctx) if err != nil && errors.As(err, &common.RevisionUnavailableError{}) { log.Trace().Str("revision", rr.rev.String()).Msg("replica does not contain the requested revision, using primary") - return rr.primary.SnapshotReader(rr.rev).LookupCounters(ctx) + return rr.primary.SnapshotReader(rr.rev, rr.schemaHash).LookupCounters(ctx) } return counters, err } @@ -264,5 +266,32 @@ func (rr *strictReadReplicatedReader) LookupCounters(ctx context.Context) ([]dat // SchemaReader returns the SchemaReader instance of the replica's SchemaReader at the // associated revision. 
func (rr *strictReadReplicatedReader) SchemaReader() (datastore.SchemaReader, error) { - return rr.replica.SnapshotReader(rr.rev).SchemaReader() + return rr.replica.SnapshotReader(rr.rev, rr.schemaHash).SchemaReader() } + +func (rr *strictReadReplicatedReader) ReadStoredSchema(ctx context.Context) (*core.StoredSchema, error) { + sr := rr.replica.SnapshotReader(rr.rev, rr.schemaHash) + singleStoreReader, ok := sr.(datastore.SingleStoreSchemaReader) + if !ok { + return nil, errors.New("replica reader does not implement SingleStoreSchemaReader") + } + + schema, err := singleStoreReader.ReadStoredSchema(ctx) + if err != nil && errors.As(err, &common.RevisionUnavailableError{}) { + log.Trace().Str("revision", rr.rev.String()).Msg("replica does not contain the requested revision, using primary") + pr := rr.primary.SnapshotReader(rr.rev, rr.schemaHash) + primarySingleStore, ok := pr.(datastore.SingleStoreSchemaReader) + if !ok { + return nil, errors.New("primary reader does not implement SingleStoreSchemaReader") + } + return primarySingleStore.ReadStoredSchema(ctx) + } + return schema, err +} + +var ( + _ datastore.Reader = (*strictReadReplicatedReader)(nil) + _ datastore.LegacySchemaReader = (*strictReadReplicatedReader)(nil) + _ datastore.SingleStoreSchemaReader = (*strictReadReplicatedReader)(nil) + _ datastore.DualSchemaReader = (*strictReadReplicatedReader)(nil) +) diff --git a/internal/datastore/proxy/strictreplicated_test.go b/internal/datastore/proxy/strictreplicated_test.go index 4ec25b60b..93816ce96 100644 --- a/internal/datastore/proxy/strictreplicated_test.go +++ b/internal/datastore/proxy/strictreplicated_test.go @@ -26,7 +26,7 @@ func TestStrictReplicatedQueryFallsbackToPrimaryOnRevisionNotAvailableError(t *t require.NoError(t, err) // Query the replicated, which should fallback to the primary. 
- reader := replicated.SnapshotReader(revisionparsing.MustParseRevisionForTest("3")) + reader := replicated.SnapshotReader(revisionparsing.MustParseRevisionForTest("3"), datastore.NoSchemaHashForTesting) iter, err := reader.QueryRelationships(t.Context(), datastore.RelationshipsFilter{ OptionalResourceType: "resource", }) @@ -46,7 +46,7 @@ func TestStrictReplicatedQueryFallsbackToPrimaryOnRevisionNotAvailableError(t *t require.Len(t, revfound, 2) // Query the replica directly, which should error. - reader = replica.SnapshotReader(revisionparsing.MustParseRevisionForTest("3")) + reader = replica.SnapshotReader(revisionparsing.MustParseRevisionForTest("3"), datastore.NoSchemaHashForTesting) iter, err = reader.QueryRelationships(t.Context(), datastore.RelationshipsFilter{ OptionalResourceType: "resource", }) @@ -66,7 +66,7 @@ func TestStrictReplicatedQueryFallsbackToPrimaryOnRevisionNotAvailableError(t *t require.ErrorContains(t, err, "revision not available") // Query the replica for a different revision, which should work. - reader = replica.SnapshotReader(revisionparsing.MustParseRevisionForTest("1")) + reader = replica.SnapshotReader(revisionparsing.MustParseRevisionForTest("1"), datastore.NoSchemaHashForTesting) iter, err = reader.QueryRelationships(t.Context(), datastore.RelationshipsFilter{ OptionalResourceType: "resource", }) @@ -95,7 +95,7 @@ func TestStrictReplicatedQueryNonFallbackError(t *testing.T) { require.NoError(t, err) // Query the replicated, which should return the error. 
- reader := replicated.SnapshotReader(revisionparsing.MustParseRevisionForTest("3")) + reader := replicated.SnapshotReader(revisionparsing.MustParseRevisionForTest("3"), datastore.NoSchemaHashForTesting) _, err = reader.QueryRelationships(t.Context(), datastore.RelationshipsFilter{ OptionalResourceType: "resource", }) diff --git a/internal/datastore/revisions/hlcrevision.go b/internal/datastore/revisions/hlcrevision.go index 8699450c0..1025ca459 100644 --- a/internal/datastore/revisions/hlcrevision.go +++ b/internal/datastore/revisions/hlcrevision.go @@ -134,6 +134,10 @@ func (hlc HLCRevision) String() string { return strconv.FormatInt(hlc.time, 10) + "." + strings.Repeat("0", logicalClockLength-len(logicalClockString)) + logicalClockString } +func (hlc HLCRevision) Key() string { + return hlc.String() +} + func (hlc HLCRevision) TimestampNanoSec() int64 { return hlc.time } diff --git a/internal/datastore/revisions/hlcrevision_test.go b/internal/datastore/revisions/hlcrevision_test.go index 90b009714..d894b85f9 100644 --- a/internal/datastore/revisions/hlcrevision_test.go +++ b/internal/datastore/revisions/hlcrevision_test.go @@ -295,3 +295,43 @@ func TestFailsIfLogicalClockExceedsMaxUin32(t *testing.T) { _, _ = HLCRevisionFromString("0.9999999999") }) } + +func TestHLCRevisionKey(t *testing.T) { + testCases := []struct { + name string + revision string + }{ + { + name: "simple timestamp", + revision: "1.0000000000", + }, + { + name: "with logical clock", + revision: "1703283409994227985.0000000004", + }, + { + name: "large logical clock", + revision: "1703283409994227985.0010000000", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + rev, err := HLCRevisionFromString(tc.revision) + require.NoError(t, err) + + // Key should be deterministic + key1 := rev.Key() + key2 := rev.Key() + require.Equal(t, key1, key2, "Key() should be deterministic") + + // Key should equal String() for HLC revisions + require.Equal(t, rev.String(), key1, 
"Key() should match String() for HLC revisions") + + // Equal revisions should have equal keys + rev2, err := HLCRevisionFromString(tc.revision) + require.NoError(t, err) + require.Equal(t, rev.Key(), rev2.Key(), "Equal revisions should have equal keys") + }) + } +} diff --git a/internal/datastore/revisions/optimized.go b/internal/datastore/revisions/optimized.go index 3fee36938..de918d20a 100644 --- a/internal/datastore/revisions/optimized.go +++ b/internal/datastore/revisions/optimized.go @@ -21,8 +21,8 @@ var tracer = otel.Tracer("spicedb/internal/datastore/common/revisions") // OptimizedRevisionFunction instructs the datastore to compute its own current // optimized revision given the specific quantization, and return for how long -// it will remain valid. -type OptimizedRevisionFunction func(context.Context) (rev datastore.Revision, validFor time.Duration, err error) +// it will remain valid, along with the schema hash. +type OptimizedRevisionFunction func(context.Context) (rev datastore.Revision, validFor time.Duration, schemaHash datastore.SchemaHash, err error) // NewCachedOptimizedRevisions returns a CachedOptimizedRevisions for the given configuration func NewCachedOptimizedRevisions(maxRevisionStaleness time.Duration) *CachedOptimizedRevisions { @@ -38,7 +38,7 @@ func (cor *CachedOptimizedRevisions) SetOptimizedRevisionFunc(revisionFunc Optim cor.optimizedFunc = revisionFunc } -func (cor *CachedOptimizedRevisions) OptimizedRevision(ctx context.Context) (datastore.Revision, error) { +func (cor *CachedOptimizedRevisions) OptimizedRevision(ctx context.Context) (datastore.Revision, datastore.SchemaHash, error) { span := trace.SpanFromContext(ctx) localNow := cor.clockFn.Now() @@ -58,15 +58,15 @@ func (cor *CachedOptimizedRevisions) OptimizedRevision(ctx context.Context) (dat cor.RUnlock() log.Ctx(ctx).Debug().Time("now", localNow).Time("valid", candidate.validThrough).Msg("returning cached revision") 
span.AddEvent(otelconv.EventDatastoreRevisionsCacheReturned) - return candidate.revision, nil + return candidate.revision, candidate.schemaHash, nil } } cor.RUnlock() - newQuantizedRevision, err, _ := cor.updateGroup.Do("", func() (any, error) { + result, err, _ := cor.updateGroup.Do("", func() (any, error) { log.Ctx(ctx).Debug().Time("now", localNow).Msg("computing new revision") - optimized, validFor, err := cor.optimizedFunc(ctx) + optimized, validFor, schemaHash, err := cor.optimizedFunc(ctx) if err != nil { return nil, fmt.Errorf("unable to compute optimized revision: %w", err) } @@ -85,17 +85,24 @@ func (cor *CachedOptimizedRevisions) OptimizedRevision(ctx context.Context) (dat } cor.candidates = cor.candidates[numToDrop:] - cor.candidates = append(cor.candidates, validRevision{optimized, rvt}) + cor.candidates = append(cor.candidates, validRevision{optimized, rvt, schemaHash}) cor.Unlock() span.AddEvent(otelconv.EventDatastoreRevisionsComputed) log.Ctx(ctx).Debug().Time("now", localNow).Time("valid", rvt).Stringer("validFor", validFor).Msg("setting valid through") - return optimized, nil + return struct { + rev datastore.Revision + hash datastore.SchemaHash + }{optimized, schemaHash}, nil }) if err != nil { - return datastore.NoRevision, err + return datastore.NoRevision, "", err } - return newQuantizedRevision.(datastore.Revision), err + r := result.(struct { + rev datastore.Revision + hash datastore.SchemaHash + }) + return r.rev, r.hash, nil } // CachedOptimizedRevisions does caching and deduplication for requests for optimized revisions. 
@@ -116,4 +123,5 @@ type CachedOptimizedRevisions struct { type validRevision struct { revision datastore.Revision validThrough time.Time + schemaHash datastore.SchemaHash } diff --git a/internal/datastore/revisions/optimized_test.go b/internal/datastore/revisions/optimized_test.go index 5adcd184d..22d733b78 100644 --- a/internal/datastore/revisions/optimized_test.go +++ b/internal/datastore/revisions/optimized_test.go @@ -20,9 +20,9 @@ type trackingRevisionFunction struct { mock.Mock } -func (m *trackingRevisionFunction) optimizedRevisionFunc(_ context.Context) (datastore.Revision, time.Duration, error) { +func (m *trackingRevisionFunction) optimizedRevisionFunc(_ context.Context) (datastore.Revision, time.Duration, datastore.SchemaHash, error) { args := m.Called() - return args.Get(0).(datastore.Revision), args.Get(1).(time.Duration), args.Error(2) + return args.Get(0).(datastore.Revision), args.Get(1).(time.Duration), datastore.NoSchemaHashForTesting, args.Error(2) } var ( @@ -128,7 +128,7 @@ func TestOptimizedRevisionCache(t *testing.T) { } require.Eventually(func() bool { - revision, err := or.OptimizedRevision(ctx) + revision, _, err := or.OptimizedRevision(ctx) require.NoError(err) printableRevSet := slicez.Map(expectedRevSet, func(val datastore.Revision) string { return val.String() @@ -166,7 +166,7 @@ func TestOptimizedRevisionCacheSingleFlight(t *testing.T) { g := errgroup.Group{} for i := 0; i < 10; i++ { g.Go(func() error { - revision, err := or.OptimizedRevision(ctx) + revision, _, err := or.OptimizedRevision(ctx) if err != nil { return err } @@ -188,20 +188,20 @@ func BenchmarkOptimizedRevisions(b *testing.B) { quantization := 1 * time.Millisecond or := NewCachedOptimizedRevisions(quantization) - or.SetOptimizedRevisionFunc(func(ctx context.Context) (datastore.Revision, time.Duration, error) { + or.SetOptimizedRevisionFunc(func(ctx context.Context) (datastore.Revision, time.Duration, datastore.SchemaHash, error) { nowNS := time.Now().UnixNano() 
validForNS := nowNS % quantization.Nanoseconds() roundedNS := nowNS - validForNS // This should be non-negative. uintRoundedNs := safecast.RequireConvert[uint64](b, roundedNS) rev := NewForTransactionID(uintRoundedNs) - return rev, time.Duration(validForNS) * time.Nanosecond, nil + return rev, time.Duration(validForNS) * time.Nanosecond, datastore.NoSchemaHashForTesting, nil }) ctx := b.Context() b.RunParallel(func(p *testing.PB) { for p.Next() { - if _, err := or.OptimizedRevision(ctx); err != nil { + if _, _, err := or.OptimizedRevision(ctx); err != nil { b.FailNow() } } @@ -223,7 +223,7 @@ func TestSingleFlightError(t *testing.T) { ctx, cancel := context.WithTimeout(t.Context(), 1*time.Second) defer cancel() - _, err := or.OptimizedRevision(ctx) + _, _, err := or.OptimizedRevision(ctx) req.Error(err) mock.AssertExpectations(t) } diff --git a/internal/datastore/revisions/remoteclock.go b/internal/datastore/revisions/remoteclock.go index ef793c86d..800578b9c 100644 --- a/internal/datastore/revisions/remoteclock.go +++ b/internal/datastore/revisions/remoteclock.go @@ -48,19 +48,19 @@ func NewRemoteClockRevisions(gcWindow, maxRevisionStaleness, followerReadDelay, return revisions } -func (rcr *RemoteClockRevisions) optimizedRevisionFunc(ctx context.Context) (datastore.Revision, time.Duration, error) { +func (rcr *RemoteClockRevisions) optimizedRevisionFunc(ctx context.Context) (datastore.Revision, time.Duration, datastore.SchemaHash, error) { nowRev, err := rcr.nowFunc(ctx) if err != nil { - return datastore.NoRevision, 0, err + return datastore.NoRevision, 0, "", err } if nowRev == datastore.NoRevision { - return datastore.NoRevision, 0, datastore.NewInvalidRevisionErr(nowRev, datastore.CouldNotDetermineRevision) + return datastore.NoRevision, 0, "", datastore.NewInvalidRevisionErr(nowRev, datastore.CouldNotDetermineRevision) } nowTS, ok := nowRev.(WithTimestampRevision) if !ok { - return datastore.NoRevision, 0, spiceerrors.MustBugf("expected with-timestamp 
revision, got %T", nowRev) + return datastore.NoRevision, 0, "", spiceerrors.MustBugf("expected with-timestamp revision, got %T", nowRev) } delayedNow := nowTS.TimestampNanoSec() - rcr.followerReadDelayNanos @@ -77,7 +77,9 @@ func (rcr *RemoteClockRevisions) optimizedRevisionFunc(ctx context.Context) (dat Int64("totalSkew", nowTS.TimestampNanoSec()-quantized). Msg("revision skews") - return nowTS.ConstructForTimestamp(quantized), time.Duration(validForNanos) * time.Nanosecond, nil + // Remote clock systems don't have schema hash from the revision function + // The schema hash will be loaded separately when needed + return nowTS.ConstructForTimestamp(quantized), time.Duration(validForNanos) * time.Nanosecond, "", nil } // SetNowFunc sets the function used to determine the head revision diff --git a/internal/datastore/revisions/remoteclock_test.go b/internal/datastore/revisions/remoteclock_test.go index e5088890b..1d88a0c93 100644 --- a/internal/datastore/revisions/remoteclock_test.go +++ b/internal/datastore/revisions/remoteclock_test.go @@ -88,7 +88,7 @@ func TestRemoteClockOptimizedRevisions(t *testing.T) { remoteClock.Set(time.Unix(timeAndExpected.unixTime, 0)) expected := NewForTimestamp(timeAndExpected.expected * 1_000_000_000) - optimized, err := rcr.OptimizedRevision(t.Context()) + optimized, _, err := rcr.OptimizedRevision(t.Context()) require.NoError(err) require.True( expected.Equal(optimized), @@ -165,7 +165,7 @@ func TestRemoteClockStalenessBeyondGC(t *testing.T) { remoteClock.Set(time.Unix(currentTime, 0)) // Call optimized revision. - optimized, err := rcr.OptimizedRevision(t.Context()) + optimized, _, err := rcr.OptimizedRevision(t.Context()) require.NoError(t, err) // Ensure the optimized revision is not past the GC window. @@ -175,7 +175,7 @@ func TestRemoteClockStalenessBeyondGC(t *testing.T) { // Set the current time to 100001 to ensure the optimized revision is past the GC window. 
remoteClock.Set(time.Unix(100001, 0)) - newOptimized, err := rcr.OptimizedRevision(t.Context()) + newOptimized, _, err := rcr.OptimizedRevision(t.Context()) require.NoError(t, err) // Ensure the new optimized revision is not past the GC window. diff --git a/internal/datastore/revisions/timestamprevision.go b/internal/datastore/revisions/timestamprevision.go index fc2a25013..e3afbde4b 100644 --- a/internal/datastore/revisions/timestamprevision.go +++ b/internal/datastore/revisions/timestamprevision.go @@ -69,6 +69,10 @@ func (ir TimestampRevision) String() string { return strconv.FormatInt(int64(ir), 10) } +func (ir TimestampRevision) Key() string { + return ir.String() +} + func (ir TimestampRevision) Time() time.Time { return time.Unix(0, int64(ir)) } diff --git a/internal/datastore/revisions/timestamprevision_test.go b/internal/datastore/revisions/timestamprevision_test.go index a0662c9b7..fbe49f6e8 100644 --- a/internal/datastore/revisions/timestamprevision_test.go +++ b/internal/datastore/revisions/timestamprevision_test.go @@ -15,3 +15,45 @@ func TestZeroTimestampRevision(t *testing.T) { require.False(t, TimestampRevision(1).Equal(zeroTimestampRevision)) require.True(t, TimestampRevision(1).GreaterThan(zeroTimestampRevision)) } + +func TestTimestampRevisionKey(t *testing.T) { + testCases := []struct { + name string + timestamp int64 + }{ + { + name: "zero", + timestamp: 0, + }, + { + name: "small timestamp", + timestamp: 1000000000, + }, + { + name: "large timestamp", + timestamp: 1703283409994227985, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + rev := NewForTimestamp(tc.timestamp) + + // Key should be deterministic + key1 := rev.Key() + key2 := rev.Key() + require.Equal(t, key1, key2, "Key() should be deterministic") + + // Key should equal String() for timestamp revisions + require.Equal(t, rev.String(), key1, "Key() should match String() for timestamp revisions") + + // Equal revisions should have equal keys + rev2 := 
NewForTimestamp(tc.timestamp) + require.Equal(t, rev.Key(), rev2.Key(), "Equal revisions should have equal keys") + + // Different revisions should have different keys + rev3 := NewForTimestamp(tc.timestamp + 1) + require.NotEqual(t, rev.Key(), rev3.Key(), "Different revisions should have different keys") + }) + } +} diff --git a/internal/datastore/revisions/txidrevision.go b/internal/datastore/revisions/txidrevision.go index 31d837ff2..28414cc3a 100644 --- a/internal/datastore/revisions/txidrevision.go +++ b/internal/datastore/revisions/txidrevision.go @@ -63,6 +63,10 @@ func (ir TransactionIDRevision) String() string { return strconv.FormatUint(uint64(ir), 10) } +func (ir TransactionIDRevision) Key() string { + return ir.String() +} + func (ir TransactionIDRevision) WithInexactFloat64() float64 { return float64(ir) } diff --git a/internal/datastore/revisions/txidrevision_test.go b/internal/datastore/revisions/txidrevision_test.go index a8252635a..7953bc01f 100644 --- a/internal/datastore/revisions/txidrevision_test.go +++ b/internal/datastore/revisions/txidrevision_test.go @@ -15,3 +15,45 @@ func TestZeroTransactionIDRevision(t *testing.T) { require.False(t, TransactionIDRevision(1).Equal(zeroTransactionIDRevision)) require.True(t, TransactionIDRevision(1).GreaterThan(zeroTransactionIDRevision)) } + +func TestTransactionIDRevisionKey(t *testing.T) { + testCases := []struct { + name string + txID uint64 + }{ + { + name: "zero", + txID: 0, + }, + { + name: "small number", + txID: 42, + }, + { + name: "large number", + txID: 1234567890123456789, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + rev := NewForTransactionID(tc.txID) + + // Key should be deterministic + key1 := rev.Key() + key2 := rev.Key() + require.Equal(t, key1, key2, "Key() should be deterministic") + + // Key should equal String() for transaction ID revisions + require.Equal(t, rev.String(), key1, "Key() should match String() for transaction ID revisions") + + // 
Equal revisions should have equal keys + rev2 := NewForTransactionID(tc.txID) + require.Equal(t, rev.Key(), rev2.Key(), "Equal revisions should have equal keys") + + // Different revisions should have different keys + rev3 := NewForTransactionID(tc.txID + 1) + require.NotEqual(t, rev.Key(), rev3.Key(), "Different revisions should have different keys") + }) + } +} diff --git a/internal/datastore/schema/schema.go b/internal/datastore/schema/schema.go index 47b7d39ed..3fe9f44ab 100644 --- a/internal/datastore/schema/schema.go +++ b/internal/datastore/schema/schema.go @@ -10,6 +10,7 @@ import ( "github.com/authzed/spicedb/pkg/caveats/types" "github.com/authzed/spicedb/pkg/datastore" + dsoptions "github.com/authzed/spicedb/pkg/datastore/options" "github.com/authzed/spicedb/pkg/genutil/mapz" core "github.com/authzed/spicedb/pkg/proto/core/v1" "github.com/authzed/spicedb/pkg/schemadsl/compiler" @@ -373,3 +374,137 @@ func (l *LegacySchemaWriterAdapter) AddDefinitionsForTesting(ctx context.Context } var _ datastore.SchemaWriter = (*LegacySchemaWriterAdapter)(nil) + +// NewSchemaReader creates a new schema reader based on the experimental schema mode. +// The reader parameter must implement DualSchemaReader, which provides both legacy and single-store methods. +// The snapshotRevision is the revision at which the reader is reading. +// Based on the schema mode, this function returns the appropriate reader implementation. 
+func NewSchemaReader(reader datastore.DualSchemaReader, schemaMode dsoptions.SchemaMode, snapshotRevision datastore.Revision) datastore.SchemaReader { + switch schemaMode { + case dsoptions.SchemaModeReadNewWriteBoth, dsoptions.SchemaModeReadNewWriteNew: + // Use unified schema storage for reading + return newSingleStoreSchemaReader(reader, snapshotRevision) + + case dsoptions.SchemaModeReadLegacyWriteLegacy, dsoptions.SchemaModeReadLegacyWriteBoth: + // Use legacy schema storage for reading + return NewLegacySchemaReaderAdapter(reader) + + default: + panic(fmt.Sprintf("unsupported schema mode: %v", schemaMode)) + } +} + +// NewSchemaWriter creates a new schema writer based on the experimental schema mode. +// The writer parameter must implement DualSchemaWriter, which provides both legacy and single-store methods. +// Based on the schema mode, this function returns the appropriate writer implementation: +// - SchemaModeReadLegacyWriteLegacy: writes only to legacy storage +// - SchemaModeReadLegacyWriteBoth or SchemaModeReadNewWriteBoth: writes to both legacy and unified storage +// - SchemaModeReadNewWriteNew: writes only to unified storage +func NewSchemaWriter(writer datastore.DualSchemaWriter, reader datastore.DualSchemaReader, schemaMode dsoptions.SchemaMode) datastore.SchemaWriter { + switch schemaMode { + case dsoptions.SchemaModeReadNewWriteNew: + // Use unified schema storage only + return newSingleStoreSchemaWriter(writer, reader) + + case dsoptions.SchemaModeReadLegacyWriteBoth, dsoptions.SchemaModeReadNewWriteBoth: + // Write to both legacy and unified storage + return newDualSchemaWriter(writer, reader) + + case dsoptions.SchemaModeReadLegacyWriteLegacy: + // Use legacy schema storage only + return NewLegacySchemaWriterAdapter(writer, reader) + + default: + panic(fmt.Sprintf("unsupported schema mode: %v", schemaMode)) + } +} + +// dualSchemaWriter writes to both legacy and unified schema storage. 
+type dualSchemaWriter struct { + legacyWriter *LegacySchemaWriterAdapter + unifiedWriter *singleStoreSchemaWriter +} + +// newDualSchemaWriter creates a new writer that writes to both legacy and unified storage. +func newDualSchemaWriter(writer datastore.DualSchemaWriter, reader datastore.DualSchemaReader) *dualSchemaWriter { + return &dualSchemaWriter{ + legacyWriter: NewLegacySchemaWriterAdapter(writer, reader), + unifiedWriter: newSingleStoreSchemaWriter(writer, reader), + } +} + +// WriteSchema writes the schema to both legacy and unified storage. +func (d *dualSchemaWriter) WriteSchema(ctx context.Context, definitions []datastore.SchemaDefinition, schemaString string, caveatTypeSet *types.TypeSet) error { + // Write to legacy storage first + if err := d.legacyWriter.WriteSchema(ctx, definitions, schemaString, caveatTypeSet); err != nil { + return fmt.Errorf("failed to write to legacy storage: %w", err) + } + + // Write the legacy schema hash since we're in "write to both" mode. + // This is done after the legacy write completes, using the final set of definitions. + // This avoids the problem of reading back buffered writes that aren't yet visible. 
+ if hashWriter, ok := d.legacyWriter.legacyWriter.(datastore.LegacySchemaHashWriter); ok { + // Separate namespaces and caveats from definitions + var namespaces []*core.NamespaceDefinition + var caveats []*core.CaveatDefinition + + for _, def := range definitions { + switch typedDef := def.(type) { + case *core.NamespaceDefinition: + namespaces = append(namespaces, typedDef) + case *core.CaveatDefinition: + caveats = append(caveats, typedDef) + } + } + + // Build the final list of namespaces + finalNamespaces := make([]datastore.RevisionedNamespace, 0, len(namespaces)) + for _, ns := range namespaces { + finalNamespaces = append(finalNamespaces, datastore.RevisionedNamespace{ + Definition: ns, + // Revision doesn't matter for hash computation + LastWrittenRevision: datastore.NoRevision, + }) + } + + // Build the final list of caveats + finalCaveats := make([]datastore.RevisionedCaveat, 0, len(caveats)) + for _, caveat := range caveats { + finalCaveats = append(finalCaveats, datastore.RevisionedCaveat{ + Definition: caveat, + // Revision doesn't matter for hash computation + LastWrittenRevision: datastore.NoRevision, + }) + } + + if err := hashWriter.WriteLegacySchemaHashFromDefinitions(ctx, finalNamespaces, finalCaveats); err != nil { + return fmt.Errorf("failed to write legacy schema hash: %w", err) + } + } + + // Write to unified storage + if err := d.unifiedWriter.WriteSchema(ctx, definitions, schemaString, caveatTypeSet); err != nil { + return fmt.Errorf("failed to write to unified storage: %w", err) + } + + return nil +} + +// AddDefinitionsForTesting adds or overwrites schema definitions in both storages. 
+func (d *dualSchemaWriter) AddDefinitionsForTesting(ctx context.Context, tb testing.TB, definitions ...datastore.SchemaDefinition) error { + tb.Helper() + + // Add to legacy storage first + if err := d.legacyWriter.AddDefinitionsForTesting(ctx, tb, definitions...); err != nil { + return fmt.Errorf("failed to add to legacy storage: %w", err) + } + + // Add to unified storage + if err := d.unifiedWriter.AddDefinitionsForTesting(ctx, tb, definitions...); err != nil { + return fmt.Errorf("failed to add to unified storage: %w", err) + } + + return nil +} + +var _ datastore.SchemaWriter = (*dualSchemaWriter)(nil) diff --git a/internal/datastore/schema/singlestore.go b/internal/datastore/schema/singlestore.go new file mode 100644 index 000000000..095ddca1d --- /dev/null +++ b/internal/datastore/schema/singlestore.go @@ -0,0 +1,492 @@ +package schema + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "sort" + "testing" + + "github.com/authzed/spicedb/pkg/caveats/types" + "github.com/authzed/spicedb/pkg/datastore" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/schemadsl/compiler" + "github.com/authzed/spicedb/pkg/schemadsl/generator" + "github.com/authzed/spicedb/pkg/spiceerrors" +) + +const ( + // currentSchemaVersion is the current version of the StoredSchema proto. + currentSchemaVersion = 1 +) + +// singleStoreSchemaReader implements SchemaReader using the unified schema storage. +type singleStoreSchemaReader struct { + reader datastore.SingleStoreSchemaReader + snapshotRevision datastore.Revision +} + +// newSingleStoreSchemaReader creates a new schema reader that uses unified schema storage. +func newSingleStoreSchemaReader(reader datastore.SingleStoreSchemaReader, snapshotRevision datastore.Revision) *singleStoreSchemaReader { + return &singleStoreSchemaReader{ + reader: reader, + snapshotRevision: snapshotRevision, + } +} + +// SchemaText returns the schema text from the unified schema storage. 
+func (s *singleStoreSchemaReader) SchemaText() (string, error) { + ctx := context.Background() + + // Read the stored schema + storedSchema, err := s.reader.ReadStoredSchema(ctx) + if err != nil { + if errors.Is(err, datastore.ErrSchemaNotFound) { + return "", datastore.NewSchemaNotDefinedErr() + } + return "", err + } + + // Extract schema text based on version + switch { + case storedSchema.GetV1() != nil: + return storedSchema.GetV1().SchemaText, nil + default: + return "", fmt.Errorf("unsupported schema version: %d", storedSchema.Version) + } +} + +// LookupTypeDefByName looks up a type definition by name from the unified schema. +func (s *singleStoreSchemaReader) LookupTypeDefByName(ctx context.Context, name string) (datastore.RevisionedTypeDefinition, bool, error) { + storedSchema, err := s.reader.ReadStoredSchema(ctx) + if err != nil { + if errors.Is(err, datastore.ErrSchemaNotFound) { + return datastore.RevisionedTypeDefinition{}, false, nil + } + return datastore.RevisionedTypeDefinition{}, false, err + } + + v1 := storedSchema.GetV1() + if v1 == nil { + return datastore.RevisionedTypeDefinition{}, false, fmt.Errorf("unsupported schema version: %d", storedSchema.Version) + } + + ns, found := v1.NamespaceDefinitions[name] + if !found { + return datastore.RevisionedTypeDefinition{}, false, nil + } + + return datastore.RevisionedTypeDefinition{ + Definition: ns, + // In unified schema mode, all definitions share the schema's revision + LastWrittenRevision: s.snapshotRevision, + }, true, nil +} + +// LookupCaveatDefByName looks up a caveat definition by name from the unified schema. 
+func (s *singleStoreSchemaReader) LookupCaveatDefByName(ctx context.Context, name string) (datastore.RevisionedCaveat, bool, error) { + storedSchema, err := s.reader.ReadStoredSchema(ctx) + if err != nil { + if errors.Is(err, datastore.ErrSchemaNotFound) { + return datastore.RevisionedCaveat{}, false, nil + } + return datastore.RevisionedCaveat{}, false, err + } + + v1 := storedSchema.GetV1() + if v1 == nil { + return datastore.RevisionedCaveat{}, false, fmt.Errorf("unsupported schema version: %d", storedSchema.Version) + } + + caveat, found := v1.CaveatDefinitions[name] + if !found { + return datastore.RevisionedCaveat{}, false, nil + } + + return datastore.RevisionedCaveat{ + Definition: caveat, + // In unified schema mode, all definitions share the schema's revision + LastWrittenRevision: s.snapshotRevision, + }, true, nil +} + +// ListAllTypeDefinitions lists all type definitions from the unified schema. +func (s *singleStoreSchemaReader) ListAllTypeDefinitions(ctx context.Context) ([]datastore.RevisionedTypeDefinition, error) { + storedSchema, err := s.reader.ReadStoredSchema(ctx) + if err != nil { + if errors.Is(err, datastore.ErrSchemaNotFound) { + return nil, nil + } + return nil, err + } + + v1 := storedSchema.GetV1() + if v1 == nil { + return nil, fmt.Errorf("unsupported schema version: %d", storedSchema.Version) + } + + results := make([]datastore.RevisionedTypeDefinition, 0, len(v1.NamespaceDefinitions)) + for _, ns := range v1.NamespaceDefinitions { + results = append(results, datastore.RevisionedTypeDefinition{ + Definition: ns, + LastWrittenRevision: s.snapshotRevision, + }) + } + + return results, nil +} + +// ListAllCaveatDefinitions lists all caveat definitions from the unified schema. 
+func (s *singleStoreSchemaReader) ListAllCaveatDefinitions(ctx context.Context) ([]datastore.RevisionedCaveat, error) { + storedSchema, err := s.reader.ReadStoredSchema(ctx) + if err != nil { + if errors.Is(err, datastore.ErrSchemaNotFound) { + return nil, nil + } + return nil, err + } + + v1 := storedSchema.GetV1() + if v1 == nil { + return nil, fmt.Errorf("unsupported schema version: %d", storedSchema.Version) + } + + results := make([]datastore.RevisionedCaveat, 0, len(v1.CaveatDefinitions)) + for _, caveat := range v1.CaveatDefinitions { + results = append(results, datastore.RevisionedCaveat{ + Definition: caveat, + LastWrittenRevision: s.snapshotRevision, + }) + } + + return results, nil +} + +// ListAllSchemaDefinitions lists all schema definitions from the unified schema. +func (s *singleStoreSchemaReader) ListAllSchemaDefinitions(ctx context.Context) (map[string]datastore.SchemaDefinition, error) { + storedSchema, err := s.reader.ReadStoredSchema(ctx) + if err != nil { + if errors.Is(err, datastore.ErrSchemaNotFound) { + return make(map[string]datastore.SchemaDefinition), nil + } + return nil, err + } + + v1 := storedSchema.GetV1() + if v1 == nil { + return nil, fmt.Errorf("unsupported schema version: %d", storedSchema.Version) + } + + result := make(map[string]datastore.SchemaDefinition, len(v1.NamespaceDefinitions)+len(v1.CaveatDefinitions)) + for _, ns := range v1.NamespaceDefinitions { + result[ns.Name] = ns + } + for _, caveat := range v1.CaveatDefinitions { + result[caveat.Name] = caveat + } + + return result, nil +} + +// LookupSchemaDefinitionsByNames looks up schema definitions by name from the unified schema. 
+func (s *singleStoreSchemaReader) LookupSchemaDefinitionsByNames(ctx context.Context, names []string) (map[string]datastore.SchemaDefinition, error) { + storedSchema, err := s.reader.ReadStoredSchema(ctx) + if err != nil { + if errors.Is(err, datastore.ErrSchemaNotFound) { + return make(map[string]datastore.SchemaDefinition), nil + } + return nil, err + } + + v1 := storedSchema.GetV1() + if v1 == nil { + return nil, fmt.Errorf("unsupported schema version: %d", storedSchema.Version) + } + + result := make(map[string]datastore.SchemaDefinition) + for _, name := range names { + if ns, found := v1.NamespaceDefinitions[name]; found { + result[name] = ns + } else if caveat, found := v1.CaveatDefinitions[name]; found { + result[name] = caveat + } + } + + return result, nil +} + +// LookupTypeDefinitionsByNames looks up type definitions by name from the unified schema. +func (s *singleStoreSchemaReader) LookupTypeDefinitionsByNames(ctx context.Context, names []string) (map[string]datastore.SchemaDefinition, error) { + storedSchema, err := s.reader.ReadStoredSchema(ctx) + if err != nil { + if errors.Is(err, datastore.ErrSchemaNotFound) { + return make(map[string]datastore.SchemaDefinition), nil + } + return nil, err + } + + v1 := storedSchema.GetV1() + if v1 == nil { + return nil, fmt.Errorf("unsupported schema version: %d", storedSchema.Version) + } + + result := make(map[string]datastore.SchemaDefinition) + for _, name := range names { + if ns, found := v1.NamespaceDefinitions[name]; found { + result[name] = ns + } + } + + return result, nil +} + +// LookupCaveatDefinitionsByNames looks up caveat definitions by name from the unified schema. 
+func (s *singleStoreSchemaReader) LookupCaveatDefinitionsByNames(ctx context.Context, names []string) (map[string]datastore.SchemaDefinition, error) { + storedSchema, err := s.reader.ReadStoredSchema(ctx) + if err != nil { + if errors.Is(err, datastore.ErrSchemaNotFound) { + return make(map[string]datastore.SchemaDefinition), nil + } + return nil, err + } + + v1 := storedSchema.GetV1() + if v1 == nil { + return nil, fmt.Errorf("unsupported schema version: %d", storedSchema.Version) + } + + result := make(map[string]datastore.SchemaDefinition) + for _, name := range names { + if caveat, found := v1.CaveatDefinitions[name]; found { + result[name] = caveat + } + } + + return result, nil +} + +var _ datastore.SchemaReader = (*singleStoreSchemaReader)(nil) + +// singleStoreSchemaWriter implements SchemaWriter using the unified schema storage. +type singleStoreSchemaWriter struct { + writer datastore.SingleStoreSchemaWriter + reader datastore.SingleStoreSchemaReader +} + +// newSingleStoreSchemaWriter creates a new schema writer that uses unified schema storage. +func newSingleStoreSchemaWriter(writer datastore.SingleStoreSchemaWriter, reader datastore.SingleStoreSchemaReader) *singleStoreSchemaWriter { + return &singleStoreSchemaWriter{ + writer: writer, + reader: reader, + } +} + +// WriteSchema writes the schema to unified storage. 
+func (s *singleStoreSchemaWriter) WriteSchema(ctx context.Context, definitions []datastore.SchemaDefinition, schemaString string, caveatTypeSet *types.TypeSet) error { + // Build namespace and caveat maps + namespaces := make(map[string]*core.NamespaceDefinition) + caveats := make(map[string]*core.CaveatDefinition) + + for _, def := range definitions { + switch typedDef := def.(type) { + case *core.NamespaceDefinition: + namespaces[typedDef.Name] = typedDef + case *core.CaveatDefinition: + caveats[typedDef.Name] = typedDef + default: + return spiceerrors.MustBugf("unknown definition type: %T", def) + } + } + + // Generate canonical schema hash by sorting definitions alphabetically + // This ensures the hash is deterministic regardless of definition order + // Create a sortable slice - each definition implements both interfaces + sortedDefs := make([]datastore.SchemaDefinition, len(definitions)) + copy(sortedDefs, definitions) + sort.Slice(sortedDefs, func(i, j int) bool { + return sortedDefs[i].GetName() < sortedDefs[j].GetName() + }) + + // Convert to compiler.SchemaDefinition for generator + canonicalDefs := make([]compiler.SchemaDefinition, len(sortedDefs)) + for i, def := range sortedDefs { + canonicalDefs[i] = def.(compiler.SchemaDefinition) + } + + canonicalSchemaText, _, err := generator.GenerateSchema(canonicalDefs) + if err != nil { + return fmt.Errorf("failed to generate canonical schema: %w", err) + } + + // Compute SHA256 hash of the canonical schema text + hashBytes := sha256.Sum256([]byte(canonicalSchemaText)) + schemaHash := hex.EncodeToString(hashBytes[:]) + + // Create the stored schema proto + storedSchema := &core.StoredSchema{ + Version: currentSchemaVersion, + VersionOneof: &core.StoredSchema_V1{ + V1: &core.StoredSchema_V1StoredSchema{ + SchemaText: schemaString, + SchemaHash: schemaHash, + NamespaceDefinitions: namespaces, + CaveatDefinitions: caveats, + }, + }, + } + + // Write to storage + return s.writer.WriteStoredSchema(ctx, 
storedSchema) +} + +// AddDefinitionsForTesting adds or overwrites schema definitions for testing. +func (s *singleStoreSchemaWriter) AddDefinitionsForTesting(ctx context.Context, tb testing.TB, definitions ...datastore.SchemaDefinition) error { + tb.Helper() + + // Read existing schema + existingSchema, err := s.reader.ReadStoredSchema(ctx) + if err != nil && !errors.Is(err, datastore.ErrSchemaNotFound) { + return err + } + + // Start with empty maps if no existing schema + var namespaces map[string]*core.NamespaceDefinition + var caveats map[string]*core.CaveatDefinition + var schemaText string + + if existingSchema != nil && existingSchema.GetV1() != nil { + v1 := existingSchema.GetV1() + namespaces = make(map[string]*core.NamespaceDefinition, len(v1.NamespaceDefinitions)) + caveats = make(map[string]*core.CaveatDefinition, len(v1.CaveatDefinitions)) + for k, v := range v1.NamespaceDefinitions { + namespaces[k] = v + } + for k, v := range v1.CaveatDefinitions { + caveats[k] = v + } + schemaText = v1.SchemaText + } else { + namespaces = make(map[string]*core.NamespaceDefinition) + caveats = make(map[string]*core.CaveatDefinition) + } + + // Add or overwrite definitions + for _, def := range definitions { + switch typedDef := def.(type) { + case *core.NamespaceDefinition: + namespaces[typedDef.Name] = typedDef + case *core.CaveatDefinition: + caveats[typedDef.Name] = typedDef + default: + return spiceerrors.MustBugf("unknown definition type: %T", def) + } + } + + // Regenerate schema text if needed + if schemaText == "" { + allDefs := make([]compiler.SchemaDefinition, 0, len(namespaces)+len(caveats)) + for _, ns := range namespaces { + allDefs = append(allDefs, ns) + } + for _, caveat := range caveats { + allDefs = append(allDefs, caveat) + } + + caveatTypeSet := types.Default.TypeSet + newSchemaText, _, err := generator.GenerateSchemaWithCaveatTypeSet(allDefs, caveatTypeSet) + if err != nil { + return fmt.Errorf("failed to generate schema text: %w", err) + } + 
schemaText = newSchemaText
+	}
+
+	// Generate canonical schema hash by sorting all definitions alphabetically
+	allDefs := make([]compiler.SchemaDefinition, 0, len(namespaces)+len(caveats))
+	for _, ns := range namespaces {
+		allDefs = append(allDefs, ns)
+	}
+	for _, caveat := range caveats {
+		allDefs = append(allDefs, caveat)
+	}
+
+	sort.Slice(allDefs, func(i, j int) bool {
+		return allDefs[i].GetName() < allDefs[j].GetName()
+	})
+
+	canonicalSchemaText, _, err := generator.GenerateSchema(allDefs)
+	if err != nil {
+		return fmt.Errorf("failed to generate canonical schema: %w", err)
+	}
+
+	schemaHash := fmt.Sprintf("%x", sha256.Sum256([]byte(canonicalSchemaText)))
+	storedSchema := &core.StoredSchema{
+		Version: currentSchemaVersion,
+		VersionOneof: &core.StoredSchema_V1{
+			V1: &core.StoredSchema_V1StoredSchema{
+				SchemaText:           schemaText,
+				SchemaHash:           schemaHash,
+				NamespaceDefinitions: namespaces,
+				CaveatDefinitions:    caveats,
+			},
+		},
+	}
+
+	// Write to storage
+	return s.writer.WriteStoredSchema(ctx, storedSchema)
+}
+
+var _ datastore.SchemaWriter = (*singleStoreSchemaWriter)(nil)
+
+// BuildStoredSchemaFromDefinitions is a helper function to build a StoredSchema proto from schema definitions.
+func BuildStoredSchemaFromDefinitions(definitions []datastore.SchemaDefinition, schemaString string) (*core.StoredSchema, error) {
+	// Build namespace and caveat maps
+	namespaces := make(map[string]*core.NamespaceDefinition)
+	caveats := make(map[string]*core.CaveatDefinition)
+
+	for _, def := range definitions {
+		switch typedDef := def.(type) {
+		case *core.NamespaceDefinition:
+			namespaces[typedDef.Name] = typedDef
+		case *core.CaveatDefinition:
+			caveats[typedDef.Name] = typedDef
+		default:
+			return nil, spiceerrors.MustBugf("unknown definition type: %T", def)
+		}
+	}
+
+	// Generate schema hash from the schema string
+	schemaHash := fmt.Sprintf("%x", sha256.Sum256([]byte(schemaString)))
+
+	// Create the stored schema proto
+	return &core.StoredSchema{
+		Version: currentSchemaVersion,
+		VersionOneof: &core.StoredSchema_V1{
+			V1: &core.StoredSchema_V1StoredSchema{
+				SchemaText:           schemaString,
+				SchemaHash:           schemaHash,
+				NamespaceDefinitions: namespaces,
+				CaveatDefinitions:    caveats,
+			},
+		},
+	}, nil
+}
+
+// UnmarshalStoredSchema unmarshals a StoredSchema from bytes.
+func UnmarshalStoredSchema(data []byte) (*core.StoredSchema, error) {
+	var stored core.StoredSchema
+	if err := stored.UnmarshalVT(data); err != nil {
+		return nil, fmt.Errorf("failed to unmarshal schema: %w", err)
+	}
+	return &stored, nil
+}
+
+// MarshalStoredSchema marshals a StoredSchema to bytes.
+func MarshalStoredSchema(schema *core.StoredSchema) ([]byte, error) { + data, err := schema.MarshalVT() + if err != nil { + return nil, fmt.Errorf("failed to marshal schema: %w", err) + } + return data, nil +} diff --git a/internal/datastore/spanner/caveat.go b/internal/datastore/spanner/caveat.go index b810a88e2..060522f19 100644 --- a/internal/datastore/spanner/caveat.go +++ b/internal/datastore/spanner/caveat.go @@ -89,7 +89,7 @@ func (sr spannerReader) listCaveats(ctx context.Context, caveatNames []string) ( return caveats, nil } -func (rwt spannerReadWriteTXN) LegacyWriteCaveats(_ context.Context, caveats []*core.CaveatDefinition) error { +func (rwt spannerReadWriteTXN) LegacyWriteCaveats(ctx context.Context, caveats []*core.CaveatDefinition) error { names := map[string]struct{}{} mutations := make([]*spanner.Mutation, 0, len(caveats)) for _, caveat := range caveats { @@ -107,15 +107,28 @@ func (rwt spannerReadWriteTXN) LegacyWriteCaveats(_ context.Context, caveats []* []string{colName, colCaveatDefinition, colCaveatTS}, []any{caveat.Name, serialized, spanner.CommitTimestamp}, )) + + // Track the buffered caveat write so we can return it from List methods + // without attempting to read from Spanner (which doesn't see buffered writes) + rwt.bufferedCaveats[caveat.Name] = caveat + // Remove from deleted set in case it was previously deleted in this transaction + delete(rwt.deletedCaveats, caveat.Name) + } + + if err := rwt.spannerRWT.BufferWrite(mutations); err != nil { + return err } - return rwt.spannerRWT.BufferWrite(mutations) + return nil } -func (rwt spannerReadWriteTXN) LegacyDeleteCaveats(_ context.Context, names []string) error { +func (rwt spannerReadWriteTXN) LegacyDeleteCaveats(ctx context.Context, names []string) error { keys := make([]spanner.Key, 0, len(names)) for _, n := range names { keys = append(keys, spanner.Key{n}) + // Remove from buffered caveats and mark as deleted so List methods won't return it + delete(rwt.bufferedCaveats, n) + 
rwt.deletedCaveats[n] = struct{}{} } err := rwt.spannerRWT.BufferWrite([]*spanner.Mutation{ spanner.Delete(tableCaveat, spanner.KeySetFromKeys(keys...)), @@ -124,7 +137,7 @@ func (rwt spannerReadWriteTXN) LegacyDeleteCaveats(_ context.Context, names []st return fmt.Errorf(errUnableToDeleteCaveat, err) } - return err + return nil } func ContextualizedCaveatFrom(name spanner.NullString, context spanner.NullJSON) (*core.ContextualizedCaveat, error) { diff --git a/internal/datastore/spanner/migrations/driver.go b/internal/datastore/spanner/migrations/driver.go index f3df460ae..e3bdd96cc 100644 --- a/internal/datastore/spanner/migrations/driver.go +++ b/internal/datastore/spanner/migrations/driver.go @@ -101,6 +101,10 @@ func (smd *SpannerMigrationDriver) RunTx(ctx context.Context, f migrate.TxMigrat } func (smd *SpannerMigrationDriver) WriteVersion(_ context.Context, rwt *spanner.ReadWriteTransaction, version, replaced string) error { + // Use mutations (BufferWrite) for version tracking. Mutations are applied at commit time + // and don't participate in DML sequence number tracking, avoiding conflicts with upTx + // functions that may use DML. Spanner allows DML followed by mutations in the same + // transaction, so this is safe regardless of what upTx does. 
return rwt.BufferWrite([]*spanner.Mutation{ spanner.Delete(tableSchemaVersion, spanner.KeySetFromKeys(spanner.Key{replaced})), spanner.Insert(tableSchemaVersion, []string{colVersionNum}, []any{version}), diff --git a/internal/datastore/spanner/migrations/zz_migration.0012_add_schema_tables.go b/internal/datastore/spanner/migrations/zz_migration.0012_add_schema_tables.go new file mode 100644 index 000000000..2f44accf1 --- /dev/null +++ b/internal/datastore/spanner/migrations/zz_migration.0012_add_schema_tables.go @@ -0,0 +1,40 @@ +package migrations + +import ( + "context" + + "cloud.google.com/go/spanner/admin/database/apiv1/databasepb" +) + +const ( + createSchemaTable = `CREATE TABLE schema ( + name STRING(1024) NOT NULL, + chunk_index INT64 NOT NULL, + chunk_data BYTES(MAX) NOT NULL, + timestamp TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp=true) + ) PRIMARY KEY (name, chunk_index)` + + createSchemaRevisionTable = `CREATE TABLE schema_revision ( + name STRING(1024) NOT NULL, + schema_hash BYTES(MAX) NOT NULL, + timestamp TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp=true) + ) PRIMARY KEY (name)` +) + +func init() { + if err := SpannerMigrations.Register("add-schema-tables", "add-expiration-support", func(ctx context.Context, w Wrapper) error { + updateOp, err := w.adminClient.UpdateDatabaseDdl(ctx, &databasepb.UpdateDatabaseDdlRequest{ + Database: w.client.DatabaseName(), + Statements: []string{ + createSchemaTable, + createSchemaRevisionTable, + }, + }) + if err != nil { + return err + } + return updateOp.Wait(ctx) + }, nil); err != nil { + panic("failed to register migration: " + err.Error()) + } +} diff --git a/internal/datastore/spanner/migrations/zz_migration.0013_populate_schema_tables.go b/internal/datastore/spanner/migrations/zz_migration.0013_populate_schema_tables.go new file mode 100644 index 000000000..c868e67f2 --- /dev/null +++ b/internal/datastore/spanner/migrations/zz_migration.0013_populate_schema_tables.go @@ -0,0 +1,169 @@ +package 
migrations + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "sort" + + "cloud.google.com/go/spanner" + + core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/schemadsl/compiler" + "github.com/authzed/spicedb/pkg/schemadsl/generator" +) + +const ( + schemaChunkSize = 1024 * 1024 // 1MB chunks + currentSchemaVersion = 1 + unifiedSchemaName = "unified_schema" + schemaRevisionName = "current" +) + +func init() { + if err := SpannerMigrations.Register("populate-schema-tables", "add-schema-tables", + nil, // No DDL changes needed + func(ctx context.Context, rwt *spanner.ReadWriteTransaction) error { + // Read all existing namespaces + stmt := spanner.Statement{ + SQL: `SELECT namespace, serialized_config + FROM namespace_config + WHERE timestamp = (SELECT MAX(timestamp) FROM namespace_config nc WHERE nc.namespace = namespace_config.namespace)`, + } + + iter := rwt.Query(ctx, stmt) + defer iter.Stop() + + namespaces := make(map[string]*core.NamespaceDefinition) + err := iter.Do(func(row *spanner.Row) error { + var name string + var config []byte + if err := row.Columns(&name, &config); err != nil { + return fmt.Errorf("failed to scan namespace: %w", err) + } + + var ns core.NamespaceDefinition + if err := ns.UnmarshalVT(config); err != nil { + return fmt.Errorf("failed to unmarshal namespace %s: %w", name, err) + } + namespaces[name] = &ns + return nil + }) + if err != nil { + return fmt.Errorf("failed to query namespaces: %w", err) + } + + // Read all existing caveats + stmt = spanner.Statement{ + SQL: `SELECT name, definition + FROM caveat + WHERE timestamp = (SELECT MAX(timestamp) FROM caveat c WHERE c.name = caveat.name)`, + } + + iter = rwt.Query(ctx, stmt) + defer iter.Stop() + + caveats := make(map[string]*core.CaveatDefinition) + err = iter.Do(func(row *spanner.Row) error { + var name string + var definition []byte + if err := row.Columns(&name, &definition); err != nil { + return fmt.Errorf("failed to scan 
caveat: %w", err) + } + + var caveat core.CaveatDefinition + if err := caveat.UnmarshalVT(definition); err != nil { + return fmt.Errorf("failed to unmarshal caveat %s: %w", name, err) + } + caveats[name] = &caveat + return nil + }) + if err != nil { + return fmt.Errorf("failed to query caveats: %w", err) + } + + // If there are no namespaces or caveats, skip migration + if len(namespaces) == 0 && len(caveats) == 0 { + return nil + } + + // Generate canonical schema for hash computation + allDefs := make([]compiler.SchemaDefinition, 0, len(namespaces)+len(caveats)) + for _, ns := range namespaces { + allDefs = append(allDefs, ns) + } + for _, caveat := range caveats { + allDefs = append(allDefs, caveat) + } + + // Sort alphabetically for canonical ordering + sort.Slice(allDefs, func(i, j int) bool { + return allDefs[i].GetName() < allDefs[j].GetName() + }) + + // Generate canonical schema text + canonicalSchemaText, _, err := generator.GenerateSchema(allDefs) + if err != nil { + return fmt.Errorf("failed to generate canonical schema: %w", err) + } + + // Compute SHA256 hash + hashBytes := sha256.Sum256([]byte(canonicalSchemaText)) + schemaHash := hashBytes[:] + + // Generate user-facing schema text + schemaText, _, err := generator.GenerateSchema(allDefs) + if err != nil { + return fmt.Errorf("failed to generate schema: %w", err) + } + + // Create stored schema proto + storedSchema := &core.StoredSchema{ + Version: currentSchemaVersion, + VersionOneof: &core.StoredSchema_V1{ + V1: &core.StoredSchema_V1StoredSchema{ + SchemaText: schemaText, + SchemaHash: hex.EncodeToString(schemaHash), + NamespaceDefinitions: namespaces, + CaveatDefinitions: caveats, + }, + }, + } + + // Marshal schema + schemaData, err := storedSchema.MarshalVT() + if err != nil { + return fmt.Errorf("failed to marshal schema: %w", err) + } + + // Insert schema chunks using mutations + var mutations []*spanner.Mutation + for chunkIndex := 0; chunkIndex*schemaChunkSize < len(schemaData); 
chunkIndex++ { + start := chunkIndex * schemaChunkSize + end := start + schemaChunkSize + if end > len(schemaData) { + end = len(schemaData) + } + chunk := schemaData[start:end] + + mutations = append(mutations, spanner.Insert( + "schema", + []string{"name", "chunk_index", "chunk_data", "timestamp"}, + []any{unifiedSchemaName, chunkIndex, chunk, spanner.CommitTimestamp}, + )) + } + + // Insert schema hash + mutations = append(mutations, spanner.Insert( + "schema_revision", + []string{"name", "schema_hash", "timestamp"}, + []any{schemaRevisionName, schemaHash, spanner.CommitTimestamp}, + )) + + // Apply mutations + return rwt.BufferWrite(mutations) + }); err != nil { + panic("failed to register migration: " + err.Error()) + } +} diff --git a/internal/datastore/spanner/options.go b/internal/datastore/spanner/options.go index d32a98525..a224d84fd 100644 --- a/internal/datastore/spanner/options.go +++ b/internal/datastore/spanner/options.go @@ -8,6 +8,7 @@ import ( "github.com/authzed/spicedb/internal/datastore/common" log "github.com/authzed/spicedb/internal/logging" + dsoptions "github.com/authzed/spicedb/pkg/datastore/options" ) // DatastoreMetricsOption is an option for configuring the metrics that are emitted @@ -58,6 +59,8 @@ type spannerOptions struct { columnOptimizationOption common.ColumnOptimizationOption watchDisabled bool datastoreMetricsOption DatastoreMetricsOption + schemaMode dsoptions.SchemaMode + schemaCacheOptions dsoptions.SchemaCacheOptions } type migrationPhase uint8 @@ -292,3 +295,17 @@ func WithWatchDisabled(isDisabled bool) Option { po.watchDisabled = isDisabled } } + +// WithSchemaMode sets the experimental schema mode for the datastore. +func WithSchemaMode(mode dsoptions.SchemaMode) Option { + return func(po *spannerOptions) { + po.schemaMode = mode + } +} + +// WithSchemaCacheOptions sets the schema cache options for the datastore. 
+func WithSchemaCacheOptions(cacheOptions dsoptions.SchemaCacheOptions) Option { + return func(po *spannerOptions) { + po.schemaCacheOptions = cacheOptions + } +} diff --git a/internal/datastore/spanner/reader.go b/internal/datastore/spanner/reader.go index 958709f13..bafd1326a 100644 --- a/internal/datastore/spanner/reader.go +++ b/internal/datastore/spanner/reader.go @@ -13,10 +13,10 @@ import ( "github.com/authzed/spicedb/internal/datastore/common" "github.com/authzed/spicedb/internal/datastore/revisions" - schemautil "github.com/authzed/spicedb/internal/datastore/schema" + schemaadapter "github.com/authzed/spicedb/internal/datastore/schema" "github.com/authzed/spicedb/internal/telemetry/otelconv" "github.com/authzed/spicedb/pkg/datastore" - "github.com/authzed/spicedb/pkg/datastore/options" + dsoptions "github.com/authzed/spicedb/pkg/datastore/options" core "github.com/authzed/spicedb/pkg/proto/core/v1" "github.com/authzed/spicedb/pkg/tuple" ) @@ -37,9 +37,13 @@ type spannerReader struct { txSource txFactory filterMaximumIDCount uint16 schema common.SchemaInformation + schemaMode dsoptions.SchemaMode + snapshotRevision datastore.Revision + schemaHash string + schemaReaderWriter *common.SQLSchemaReaderWriter[any, revisions.TimestampRevision] } -func (sr spannerReader) CountRelationships(ctx context.Context, name string) (int, error) { +func (sr *spannerReader) CountRelationships(ctx context.Context, name string) (int, error) { // Ensure the counter exists. 
counters, err := sr.lookupCounters(ctx, name) if err != nil { @@ -79,11 +83,11 @@ func (sr spannerReader) CountRelationships(ctx context.Context, name string) (in const noFilterOnCounterName = "" -func (sr spannerReader) LookupCounters(ctx context.Context) ([]datastore.RelationshipCounter, error) { +func (sr *spannerReader) LookupCounters(ctx context.Context) ([]datastore.RelationshipCounter, error) { return sr.lookupCounters(ctx, noFilterOnCounterName) } -func (sr spannerReader) lookupCounters(ctx context.Context, optionalFilterName string) ([]datastore.RelationshipCounter, error) { +func (sr *spannerReader) lookupCounters(ctx context.Context, optionalFilterName string) ([]datastore.RelationshipCounter, error) { key := spanner.AllKeys() if optionalFilterName != noFilterOnCounterName { key = spanner.Key{optionalFilterName} @@ -132,27 +136,27 @@ func (sr spannerReader) lookupCounters(ctx context.Context, optionalFilterName s return counters, nil } -func (sr spannerReader) QueryRelationships( +func (sr *spannerReader) QueryRelationships( ctx context.Context, filter datastore.RelationshipsFilter, - opts ...options.QueryOptionsOption, + opts ...dsoptions.QueryOptionsOption, ) (iter datastore.RelationshipIterator, err error) { qBuilder, err := common.NewSchemaQueryFiltererForRelationshipsSelect(sr.schema, sr.filterMaximumIDCount).FilterWithRelationshipsFilter(filter) if err != nil { return nil, err } - builtOpts := options.NewQueryOptionsWithOptions(opts...) + builtOpts := dsoptions.NewQueryOptionsWithOptions(opts...) indexingHint := IndexingHintForQueryShape(sr.schema, builtOpts.QueryShape) qBuilder = qBuilder.WithIndexingHint(indexingHint) return sr.executor.ExecuteQuery(ctx, qBuilder, opts...) 
} -func (sr spannerReader) ReverseQueryRelationships( +func (sr *spannerReader) ReverseQueryRelationships( ctx context.Context, subjectsFilter datastore.SubjectsFilter, - opts ...options.ReverseQueryOptionsOption, + opts ...dsoptions.ReverseQueryOptionsOption, ) (iter datastore.RelationshipIterator, err error) { qBuilder, err := common.NewSchemaQueryFiltererForRelationshipsSelect(sr.schema, sr.filterMaximumIDCount). FilterWithSubjectsSelectors(subjectsFilter.AsSelector()) @@ -160,7 +164,7 @@ func (sr spannerReader) ReverseQueryRelationships( return nil, err } - queryOpts := options.NewReverseQueryOptionsWithOptions(opts...) + queryOpts := dsoptions.NewReverseQueryOptionsWithOptions(opts...) if queryOpts.ResRelation != nil { qBuilder = qBuilder. @@ -173,13 +177,13 @@ func (sr spannerReader) ReverseQueryRelationships( return sr.executor.ExecuteQuery(ctx, qBuilder, - options.WithLimit(queryOpts.LimitForReverse), - options.WithAfter(queryOpts.AfterForReverse), - options.WithSort(queryOpts.SortForReverse), - options.WithSkipCaveats(queryOpts.SkipCaveatsForReverse), - options.WithSkipExpiration(queryOpts.SkipExpirationForReverse), - options.WithQueryShape(queryOpts.QueryShapeForReverse), - options.WithSQLExplainCallbackForTest(queryOpts.SQLExplainCallbackForTestForReverse), + dsoptions.WithLimit(queryOpts.LimitForReverse), + dsoptions.WithAfter(queryOpts.AfterForReverse), + dsoptions.WithSort(queryOpts.SortForReverse), + dsoptions.WithSkipCaveats(queryOpts.SkipCaveatsForReverse), + dsoptions.WithSkipExpiration(queryOpts.SkipExpirationForReverse), + dsoptions.WithQueryShape(queryOpts.QueryShapeForReverse), + dsoptions.WithSQLExplainCallbackForTest(queryOpts.SQLExplainCallbackForTestForReverse), ) } @@ -289,7 +293,7 @@ func queryExecutor(txSource txFactory) common.ExecuteReadRelsQueryFunc { } } -func (sr spannerReader) LegacyReadNamespaceByName(ctx context.Context, nsName string) (*core.NamespaceDefinition, datastore.Revision, error) { +func (sr *spannerReader) 
LegacyReadNamespaceByName(ctx context.Context, nsName string) (*core.NamespaceDefinition, datastore.Revision, error) { nsKey := spanner.Key{nsName} row, err := sr.txSource().ReadRow( ctx, @@ -318,7 +322,7 @@ func (sr spannerReader) LegacyReadNamespaceByName(ctx context.Context, nsName st return ns, revisions.NewForTime(updated), nil } -func (sr spannerReader) LegacyListAllNamespaces(ctx context.Context) ([]datastore.RevisionedNamespace, error) { +func (sr *spannerReader) LegacyListAllNamespaces(ctx context.Context) ([]datastore.RevisionedNamespace, error) { iter := sr.txSource().Read( ctx, tableNamespace, @@ -335,7 +339,7 @@ func (sr spannerReader) LegacyListAllNamespaces(ctx context.Context) ([]datastor return allNamespaces, nil } -func (sr spannerReader) LegacyLookupNamespacesWithNames(ctx context.Context, nsNames []string) ([]datastore.RevisionedNamespace, error) { +func (sr *spannerReader) LegacyLookupNamespacesWithNames(ctx context.Context, nsNames []string) ([]datastore.RevisionedNamespace, error) { if len(nsNames) == 0 { return nil, nil } @@ -403,7 +407,58 @@ var queryTuplesForDelete = sql.Select( // SchemaReader returns a SchemaReader for reading schema information. func (sr *spannerReader) SchemaReader() (datastore.SchemaReader, error) { - return schemautil.NewLegacySchemaReaderAdapter(sr), nil + // Wrap the reader with an unexported schema reader + reader := &spannerSchemaReader{r: sr} + return schemaadapter.NewSchemaReader(reader, sr.schemaMode, sr.snapshotRevision), nil } -var _ datastore.Reader = (*spannerReader)(nil) +// spannerSchemaReader wraps a spannerReader and implements DualSchemaReader. +// This prevents direct access to schema read methods from the reader. 
+type spannerSchemaReader struct { + r *spannerReader +} + +// ReadStoredSchema implements datastore.SingleStoreSchemaReader +func (ssr *spannerSchemaReader) ReadStoredSchema(ctx context.Context) (*core.StoredSchema, error) { + // Create a read-only executor for reading schema chunks + executor := &spannerSchemaReadExecutor{txSource: ssr.r.txSource} + + // Use the shared schema reader/writer to read the schema with the hash + return ssr.r.schemaReaderWriter.ReadSchema(ctx, executor, ssr.r.snapshotRevision, datastore.SchemaHash(ssr.r.schemaHash)) +} + +// LegacyLookupNamespacesWithNames delegates to the underlying reader +func (ssr *spannerSchemaReader) LegacyLookupNamespacesWithNames(ctx context.Context, nsNames []string) ([]datastore.RevisionedDefinition[*core.NamespaceDefinition], error) { + return ssr.r.LegacyLookupNamespacesWithNames(ctx, nsNames) +} + +// LegacyReadCaveatByName delegates to the underlying reader +func (ssr *spannerSchemaReader) LegacyReadCaveatByName(ctx context.Context, name string) (*core.CaveatDefinition, datastore.Revision, error) { + return ssr.r.LegacyReadCaveatByName(ctx, name) +} + +// LegacyListAllCaveats delegates to the underlying reader +func (ssr *spannerSchemaReader) LegacyListAllCaveats(ctx context.Context) ([]datastore.RevisionedCaveat, error) { + return ssr.r.LegacyListAllCaveats(ctx) +} + +// LegacyLookupCaveatsWithNames delegates to the underlying reader +func (ssr *spannerSchemaReader) LegacyLookupCaveatsWithNames(ctx context.Context, names []string) ([]datastore.RevisionedCaveat, error) { + return ssr.r.LegacyLookupCaveatsWithNames(ctx, names) +} + +// LegacyReadNamespaceByName delegates to the underlying reader +func (ssr *spannerSchemaReader) LegacyReadNamespaceByName(ctx context.Context, nsName string) (*core.NamespaceDefinition, datastore.Revision, error) { + return ssr.r.LegacyReadNamespaceByName(ctx, nsName) +} + +// LegacyListAllNamespaces delegates to the underlying reader +func (ssr *spannerSchemaReader) 
LegacyListAllNamespaces(ctx context.Context) ([]datastore.RevisionedNamespace, error) { + return ssr.r.LegacyListAllNamespaces(ctx) +} + +var ( + _ datastore.Reader = (*spannerReader)(nil) + _ datastore.LegacySchemaReader = (*spannerReader)(nil) + _ datastore.DualSchemaReader = (*spannerSchemaReader)(nil) +) diff --git a/internal/datastore/spanner/readwrite.go b/internal/datastore/spanner/readwrite.go index 64339f7d2..141312bf9 100644 --- a/internal/datastore/spanner/readwrite.go +++ b/internal/datastore/spanner/readwrite.go @@ -3,7 +3,10 @@ package spanner import ( "cmp" "context" + "crypto/sha256" + "encoding/hex" "fmt" + "sort" "cloud.google.com/go/spanner" sq "github.com/Masterminds/squirrel" @@ -13,12 +16,14 @@ import ( "github.com/authzed/spicedb/internal/datastore/common" "github.com/authzed/spicedb/internal/datastore/revisions" - schemautil "github.com/authzed/spicedb/internal/datastore/schema" + schemaadapter "github.com/authzed/spicedb/internal/datastore/schema" log "github.com/authzed/spicedb/internal/logging" "github.com/authzed/spicedb/pkg/datastore" - "github.com/authzed/spicedb/pkg/datastore/options" + dsoptions "github.com/authzed/spicedb/pkg/datastore/options" "github.com/authzed/spicedb/pkg/genutil/mapz" core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/schemadsl/compiler" + "github.com/authzed/spicedb/pkg/schemadsl/generator" "github.com/authzed/spicedb/pkg/spiceerrors" "github.com/authzed/spicedb/pkg/tuple" ) @@ -26,11 +31,105 @@ import ( type spannerReadWriteTXN struct { spannerReader spannerRWT *spanner.ReadWriteTransaction + + // IMPORTANT: Spanner Read-Write Transaction Limitation + // ===================================================== + // In Cloud Spanner, reads within a read-write transaction do NOT see the effects of + // buffered writes (mutations) performed earlier in that same transaction. This is because + // writes are buffered locally at the client and are not sent to the server until commit. 
+ // This is a fundamental Spanner design, not an emulator limitation. + // + // To work around this, we track all schema writes and deletes in memory maps below. + // When List methods are called, we merge buffered writes with committed data read from + // Spanner, ensuring the legacy schema writer can correctly compute diffs without attempting + // to read buffered writes from Spanner. + + // bufferedNamespaces tracks namespaces written in this transaction + bufferedNamespaces map[string]*core.NamespaceDefinition + + // deletedNamespaces tracks namespaces deleted in this transaction + deletedNamespaces map[string]struct{} + + // bufferedCaveats tracks caveats written in this transaction + bufferedCaveats map[string]*core.CaveatDefinition + + // deletedCaveats tracks caveats deleted in this transaction + deletedCaveats map[string]struct{} } const inLimit = 10_000 // https://cloud.google.com/spanner/quotas#query-limits -func (rwt spannerReadWriteTXN) RegisterCounter(ctx context.Context, name string, filter *core.RelationshipFilter) error { +// LegacyListAllNamespaces reads namespaces from Spanner and merges them with any buffered writes. +// This is necessary because in Spanner, buffered writes in a read-write transaction are not visible +// to reads in the same transaction. The buffered map contains namespaces written in this transaction, +// and the deleted map tracks namespaces deleted in this transaction. 
+func (rwt *spannerReadWriteTXN) LegacyListAllNamespaces(ctx context.Context) ([]datastore.RevisionedNamespace, error) { + // First, read from Spanner (this will get committed data, not buffered writes) + existing, err := rwt.spannerReader.LegacyListAllNamespaces(ctx) + if err != nil { + return nil, err + } + + // Build a map of existing namespaces by name, excluding deleted ones + merged := make(map[string]datastore.RevisionedNamespace) + for _, ns := range existing { + if _, deleted := rwt.deletedNamespaces[ns.Definition.Name]; !deleted { + merged[ns.Definition.Name] = ns + } + } + + // Overlay buffered writes (these override anything read from Spanner) + for name, def := range rwt.bufferedNamespaces { + merged[name] = datastore.RevisionedNamespace{ + Definition: def, + LastWrittenRevision: datastore.NoRevision, // Will be set on commit + } + } + + // Convert map back to slice + result := make([]datastore.RevisionedNamespace, 0, len(merged)) + for _, ns := range merged { + result = append(result, ns) + } + + return result, nil +} + +// LegacyListAllCaveats reads caveats from Spanner and merges them with any buffered writes. +// See LegacyListAllNamespaces for the rationale. 
+func (rwt *spannerReadWriteTXN) LegacyListAllCaveats(ctx context.Context) ([]datastore.RevisionedCaveat, error) { + // First, read from Spanner (this will get committed data, not buffered writes) + existing, err := rwt.spannerReader.LegacyListAllCaveats(ctx) + if err != nil { + return nil, err + } + + // Build a map of existing caveats by name, excluding deleted ones + merged := make(map[string]datastore.RevisionedCaveat) + for _, caveat := range existing { + if _, deleted := rwt.deletedCaveats[caveat.Definition.Name]; !deleted { + merged[caveat.Definition.Name] = caveat + } + } + + // Overlay buffered writes (these override anything read from Spanner) + for name, def := range rwt.bufferedCaveats { + merged[name] = datastore.RevisionedCaveat{ + Definition: def, + LastWrittenRevision: datastore.NoRevision, // Will be set on commit + } + } + + // Convert map back to slice + result := make([]datastore.RevisionedCaveat, 0, len(merged)) + for _, caveat := range merged { + result = append(result, caveat) + } + + return result, nil +} + +func (rwt *spannerReadWriteTXN) RegisterCounter(ctx context.Context, name string, filter *core.RelationshipFilter) error { // Ensure the counter doesn't already exist. counters, err := rwt.lookupCounters(ctx, name) if err != nil { @@ -60,7 +159,7 @@ func (rwt spannerReadWriteTXN) RegisterCounter(ctx context.Context, name string, return nil } -func (rwt spannerReadWriteTXN) UnregisterCounter(ctx context.Context, name string) error { +func (rwt *spannerReadWriteTXN) UnregisterCounter(ctx context.Context, name string) error { // Ensure the counter exists. 
counters, err := rwt.lookupCounters(ctx, name) if err != nil { @@ -82,7 +181,7 @@ func (rwt spannerReadWriteTXN) UnregisterCounter(ctx context.Context, name strin return nil } -func (rwt spannerReadWriteTXN) StoreCounterValue(ctx context.Context, name string, value int, computedAtRevision datastore.Revision) error { +func (rwt *spannerReadWriteTXN) StoreCounterValue(ctx context.Context, name string, value int, computedAtRevision datastore.Revision) error { // Ensure the counter exists. counters, err := rwt.lookupCounters(ctx, name) if err != nil { @@ -108,7 +207,7 @@ func (rwt spannerReadWriteTXN) StoreCounterValue(ctx context.Context, name strin return nil } -func (rwt spannerReadWriteTXN) WriteRelationships(ctx context.Context, mutations []tuple.RelationshipUpdate) error { +func (rwt *spannerReadWriteTXN) WriteRelationships(ctx context.Context, mutations []tuple.RelationshipUpdate) error { var rowCountChange int64 for _, mutation := range mutations { txnMut, countChange, err := spannerMutation(ctx, mutation.Operation, mutation.Relationship) @@ -149,7 +248,7 @@ func spannerMutation( return txnMut, countChange, err } -func (rwt spannerReadWriteTXN) DeleteRelationships(ctx context.Context, filter *v1.RelationshipFilter, opts ...options.DeleteOptionsOption) (uint64, bool, error) { +func (rwt *spannerReadWriteTXN) DeleteRelationships(ctx context.Context, filter *v1.RelationshipFilter, opts ...dsoptions.DeleteOptionsOption) (uint64, bool, error) { numDeleted, limitReached, err := deleteWithFilter(ctx, rwt.spannerRWT, filter, opts...) 
if err != nil { return 0, false, fmt.Errorf(errUnableToDeleteRelationships, err) @@ -158,8 +257,8 @@ func (rwt spannerReadWriteTXN) DeleteRelationships(ctx context.Context, filter * return numDeleted, limitReached, nil } -func deleteWithFilter(ctx context.Context, rwt *spanner.ReadWriteTransaction, filter *v1.RelationshipFilter, opts ...options.DeleteOptionsOption) (uint64, bool, error) { - delOpts := options.NewDeleteOptionsWithOptionsAndDefaults(opts...) +func deleteWithFilter(ctx context.Context, rwt *spanner.ReadWriteTransaction, filter *v1.RelationshipFilter, opts ...dsoptions.DeleteOptionsOption) (uint64, bool, error) { + delOpts := dsoptions.NewDeleteOptionsWithOptionsAndDefaults(opts...) var delLimit uint64 if delOpts.DeleteLimit != nil && *delOpts.DeleteLimit > 0 { delLimit = *delOpts.DeleteLimit @@ -356,7 +455,7 @@ func caveatVals(r tuple.Relationship) []any { return vals } -func (rwt spannerReadWriteTXN) LegacyWriteNamespaces(_ context.Context, newConfigs ...*core.NamespaceDefinition) error { +func (rwt *spannerReadWriteTXN) LegacyWriteNamespaces(ctx context.Context, newConfigs ...*core.NamespaceDefinition) error { mutations := make([]*spanner.Mutation, 0, len(newConfigs)) for _, newConfig := range newConfigs { serialized, err := newConfig.MarshalVT() @@ -369,12 +468,22 @@ func (rwt spannerReadWriteTXN) LegacyWriteNamespaces(_ context.Context, newConfi []string{colNamespaceName, colNamespaceConfig, colTimestamp}, []any{newConfig.Name, serialized, spanner.CommitTimestamp}, )) + + // Track the buffered namespace write so we can return it from List methods + // without attempting to read from Spanner (which doesn't see buffered writes) + rwt.bufferedNamespaces[newConfig.Name] = newConfig + // Remove from deleted set in case it was previously deleted in this transaction + delete(rwt.deletedNamespaces, newConfig.Name) + } + + if err := rwt.spannerRWT.BufferWrite(mutations); err != nil { + return err } - return rwt.spannerRWT.BufferWrite(mutations) + return 
nil } -func (rwt spannerReadWriteTXN) LegacyDeleteNamespaces(ctx context.Context, nsNames []string, delOption datastore.DeleteNamespacesRelationshipsOption) error { +func (rwt *spannerReadWriteTXN) LegacyDeleteNamespaces(ctx context.Context, nsNames []string, delOption datastore.DeleteNamespacesRelationshipsOption) error { if len(nsNames) == 0 { return nil } @@ -407,16 +516,175 @@ func (rwt spannerReadWriteTXN) LegacyDeleteNamespaces(ctx context.Context, nsNam if err != nil { return fmt.Errorf(errUnableToDeleteConfig, err) } + + // Remove from buffered namespaces and mark as deleted so List methods won't return it + delete(rwt.bufferedNamespaces, nsName) + rwt.deletedNamespaces[nsName] = struct{}{} } return nil } -func (rwt spannerReadWriteTXN) SchemaWriter() (datastore.SchemaWriter, error) { - return schemautil.NewLegacySchemaWriterAdapter(rwt, rwt), nil +func (rwt *spannerReadWriteTXN) SchemaWriter() (datastore.SchemaWriter, error) { + // Wrap the transaction with an unexported schema writer + writer := &spannerSchemaWriter{rwt: rwt} + return schemaadapter.NewSchemaWriter(writer, writer, rwt.schemaMode), nil +} + +// spannerSchemaWriter wraps a spannerReadWriteTXN and implements DualSchemaWriter. +// This prevents direct access to schema write methods from the transaction. 
+type spannerSchemaWriter struct { + rwt *spannerReadWriteTXN +} + +// WriteStoredSchema implements datastore.SingleStoreSchemaWriter +func (w *spannerSchemaWriter) WriteStoredSchema(ctx context.Context, schema *core.StoredSchema) error { + // Create an executor that uses the current transaction + executor := newSpannerChunkedBytesExecutor(w.rwt.spannerRWT) + + // Use the shared schema reader/writer to write the schema + // Spanner uses delete-and-insert mode so no transaction ID provider is needed + noTxID := func(ctx context.Context) any { return common.NoTransactionID[any](ctx) } + if err := w.rwt.schemaReaderWriter.WriteSchema(ctx, schema, executor, noTxID); err != nil { + return err + } + + // Write the schema hash to the schema_revision table for fast lookups + if err := w.writeSchemaHash(ctx, schema); err != nil { + return fmt.Errorf("failed to write schema hash: %w", err) + } + + return nil } -func (rwt spannerReadWriteTXN) BulkLoad(ctx context.Context, iter datastore.BulkWriteRelationshipSource) (uint64, error) { +// writeSchemaHash writes the schema hash to the schema_revision table +func (w *spannerSchemaWriter) writeSchemaHash(ctx context.Context, schema *core.StoredSchema) error { + v1 := schema.GetV1() + if v1 == nil { + return fmt.Errorf("unsupported schema version: %d", schema.Version) + } + + // Use InsertOrUpdate mutation to upsert the schema hash + mutation := spanner.InsertOrUpdate( + tableSchemaRevision, + []string{"name", "schema_hash", "timestamp"}, + []any{"current", []byte(v1.SchemaHash), spanner.CommitTimestamp}, + ) + + if err := w.rwt.spannerRWT.BufferWrite([]*spanner.Mutation{mutation}); err != nil { + return fmt.Errorf("failed to buffer schema hash write: %w", err) + } + + return nil +} + +// writeSchemaHashFromDefinitions writes the schema hash computed from the given definitions +func (rwt *spannerReadWriteTXN) writeSchemaHashFromDefinitions(ctx context.Context, namespaces []datastore.RevisionedNamespace, caveats 
[]datastore.RevisionedCaveat) error { + // Build schema definitions list + definitions := make([]compiler.SchemaDefinition, 0, len(namespaces)+len(caveats)) + for _, ns := range namespaces { + definitions = append(definitions, ns.Definition) + } + for _, caveat := range caveats { + definitions = append(definitions, caveat.Definition) + } + + // Sort definitions by name for consistent ordering + sort.Slice(definitions, func(i, j int) bool { + return definitions[i].GetName() < definitions[j].GetName() + }) + + // Generate schema text from definitions + schemaText, _, err := generator.GenerateSchema(definitions) + if err != nil { + return fmt.Errorf("failed to generate schema: %w", err) + } + + // Compute schema hash (SHA256) + hash := sha256.Sum256([]byte(schemaText)) + schemaHash := hex.EncodeToString(hash[:]) + + // Use InsertOrUpdate mutation to upsert the schema hash + mutation := spanner.InsertOrUpdate( + tableSchemaRevision, + []string{"name", "schema_hash", "timestamp"}, + []any{"current", []byte(schemaHash), spanner.CommitTimestamp}, + ) + + if err := rwt.spannerRWT.BufferWrite([]*spanner.Mutation{mutation}); err != nil { + return fmt.Errorf("failed to buffer schema hash write: %w", err) + } + + return nil +} + +// ReadStoredSchema implements datastore.SingleStoreSchemaReader to satisfy DualSchemaReader interface requirements +func (w *spannerSchemaWriter) ReadStoredSchema(ctx context.Context) (*core.StoredSchema, error) { + // Create an executor that uses the current write transaction for reads + // This ensures we read within the same transaction, avoiding "transaction already committed" errors + executor := newSpannerChunkedBytesExecutor(w.rwt.spannerRWT) + + // Use the shared schema reader/writer to read the schema + // Pass empty string for transaction reads to bypass cache + return w.rwt.schemaReaderWriter.ReadSchema(ctx, executor, nil, datastore.NoSchemaHashInTransaction) +} + +// LegacyWriteNamespaces delegates to the underlying transaction +func (w 
*spannerSchemaWriter) LegacyWriteNamespaces(ctx context.Context, newConfigs ...*core.NamespaceDefinition) error { + return w.rwt.LegacyWriteNamespaces(ctx, newConfigs...) +} + +// LegacyDeleteNamespaces delegates to the underlying transaction +func (w *spannerSchemaWriter) LegacyDeleteNamespaces(ctx context.Context, nsNames []string, delOption datastore.DeleteNamespacesRelationshipsOption) error { + return w.rwt.LegacyDeleteNamespaces(ctx, nsNames, delOption) +} + +// LegacyLookupNamespacesWithNames delegates to the underlying transaction +func (w *spannerSchemaWriter) LegacyLookupNamespacesWithNames(ctx context.Context, nsNames []string) ([]datastore.RevisionedDefinition[*core.NamespaceDefinition], error) { + return w.rwt.LegacyLookupNamespacesWithNames(ctx, nsNames) +} + +// LegacyWriteCaveats delegates to the underlying transaction +func (w *spannerSchemaWriter) LegacyWriteCaveats(ctx context.Context, caveats []*core.CaveatDefinition) error { + return w.rwt.LegacyWriteCaveats(ctx, caveats) +} + +// LegacyDeleteCaveats delegates to the underlying transaction +func (w *spannerSchemaWriter) LegacyDeleteCaveats(ctx context.Context, names []string) error { + return w.rwt.LegacyDeleteCaveats(ctx, names) +} + +// LegacyReadCaveatByName delegates to the underlying transaction +func (w *spannerSchemaWriter) LegacyReadCaveatByName(ctx context.Context, name string) (*core.CaveatDefinition, datastore.Revision, error) { + return w.rwt.LegacyReadCaveatByName(ctx, name) +} + +// LegacyListAllCaveats delegates to the underlying transaction +func (w *spannerSchemaWriter) LegacyListAllCaveats(ctx context.Context) ([]datastore.RevisionedCaveat, error) { + return w.rwt.LegacyListAllCaveats(ctx) +} + +// LegacyLookupCaveatsWithNames delegates to the underlying transaction +func (w *spannerSchemaWriter) LegacyLookupCaveatsWithNames(ctx context.Context, names []string) ([]datastore.RevisionedCaveat, error) { + return w.rwt.LegacyLookupCaveatsWithNames(ctx, names) +} + +// 
LegacyReadNamespaceByName delegates to the underlying transaction +func (w *spannerSchemaWriter) LegacyReadNamespaceByName(ctx context.Context, nsName string) (*core.NamespaceDefinition, datastore.Revision, error) { + return w.rwt.LegacyReadNamespaceByName(ctx, nsName) +} + +// LegacyListAllNamespaces delegates to the underlying transaction +func (w *spannerSchemaWriter) LegacyListAllNamespaces(ctx context.Context) ([]datastore.RevisionedNamespace, error) { + return w.rwt.LegacyListAllNamespaces(ctx) +} + +// WriteLegacySchemaHashFromDefinitions implements datastore.LegacySchemaHashWriter +func (w *spannerSchemaWriter) WriteLegacySchemaHashFromDefinitions(ctx context.Context, namespaces []datastore.RevisionedNamespace, caveats []datastore.RevisionedCaveat) error { + return w.rwt.writeSchemaHashFromDefinitions(ctx, namespaces, caveats) +} + +func (rwt *spannerReadWriteTXN) BulkLoad(ctx context.Context, iter datastore.BulkWriteRelationshipSource) (uint64, error) { var numLoaded uint64 var rel *tuple.Relationship var err error @@ -439,4 +707,9 @@ func (rwt spannerReadWriteTXN) BulkLoad(ctx context.Context, iter datastore.Bulk return numLoaded, nil } -var _ datastore.ReadWriteTransaction = (*spannerReadWriteTXN)(nil) +var ( + _ datastore.ReadWriteTransaction = (*spannerReadWriteTXN)(nil) + _ datastore.LegacySchemaWriter = (*spannerReadWriteTXN)(nil) + _ datastore.DualSchemaWriter = (*spannerSchemaWriter)(nil) + _ datastore.DualSchemaReader = (*spannerSchemaWriter)(nil) +) diff --git a/internal/datastore/spanner/revisions.go b/internal/datastore/spanner/revisions.go index b365a420c..2a438123b 100644 --- a/internal/datastore/spanner/revisions.go +++ b/internal/datastore/spanner/revisions.go @@ -14,29 +14,35 @@ import ( var ( ParseRevisionString = revisions.RevisionParser(revisions.Timestamp) nowStmt = spanner.NewStatement("SELECT CURRENT_TIMESTAMP()") + nowWithHashStmt = spanner.NewStatement(` + SELECT + CURRENT_TIMESTAMP(), + COALESCE((SELECT schema_hash FROM 
schema_revision WHERE name = 'current' ORDER BY timestamp DESC LIMIT 1), b'') + `) ) -func (sd *spannerDatastore) headRevisionInternal(ctx context.Context) (datastore.Revision, error) { - now, err := sd.now(ctx) +func (sd *spannerDatastore) headRevisionInternal(ctx context.Context) (datastore.Revision, datastore.SchemaHash, error) { + now, schemaHash, err := sd.nowWithHash(ctx) if err != nil { - return datastore.NoRevision, fmt.Errorf(errRevision, err) + return datastore.NoRevision, "", fmt.Errorf(errRevision, err) } - return revisions.NewForTime(now), nil + return revisions.NewForTime(now), schemaHash, nil } -func (sd *spannerDatastore) HeadRevision(ctx context.Context) (datastore.Revision, error) { +func (sd *spannerDatastore) HeadRevision(ctx context.Context) (datastore.Revision, datastore.SchemaHash, error) { return sd.headRevisionInternal(ctx) } -func (sd *spannerDatastore) now(ctx context.Context) (time.Time, error) { +func (sd *spannerDatastore) nowWithHash(ctx context.Context) (time.Time, datastore.SchemaHash, error) { var timestamp time.Time - if err := sd.client.Single().Query(ctx, nowStmt).Do(func(r *spanner.Row) error { - return r.Columns(×tamp) + var schemaHash []byte + if err := sd.client.Single().Query(ctx, nowWithHashStmt).Do(func(r *spanner.Row) error { + return r.Columns(×tamp, &schemaHash) }); err != nil { - return timestamp, fmt.Errorf(errRevision, err) + return timestamp, "", fmt.Errorf(errRevision, err) } - return timestamp, nil + return timestamp, datastore.SchemaHash(schemaHash), nil } func (sd *spannerDatastore) staleHeadRevision(ctx context.Context) (datastore.Revision, error) { diff --git a/internal/datastore/spanner/schema.go b/internal/datastore/spanner/schema.go index 8b5dabc39..069696e69 100644 --- a/internal/datastore/spanner/schema.go +++ b/internal/datastore/spanner/schema.go @@ -35,6 +35,13 @@ const ( tableTransactionMetadata = "transaction_metadata" colTransactionTag = "transaction_tag" colMetadata = "metadata" + + tableSchema = 
"schema" + tableSchemaRevision = "schema_revision" + colSchemaName = "name" + colSchemaChunkIndex = "chunk_index" + colSchemaChunkData = "chunk_data" + colSchemaHash = "schema_hash" ) var allRelationshipCols = []string{ diff --git a/internal/datastore/spanner/schema_chunker.go b/internal/datastore/spanner/schema_chunker.go new file mode 100644 index 000000000..ed5693f1d --- /dev/null +++ b/internal/datastore/spanner/schema_chunker.go @@ -0,0 +1,167 @@ +package spanner + +import ( + "context" + "errors" + "fmt" + + "cloud.google.com/go/spanner" + sq "github.com/Masterminds/squirrel" + + "github.com/authzed/spicedb/internal/datastore/common" +) + +const ( + // Spanner has no practical limit on BYTES column size, + // but we use 1MB chunks for reasonable memory usage and query performance. + spannerMaxChunkSize = 1024 * 1024 // 1MB +) + +// BaseSchemaChunkerConfig provides the base configuration for Spanner schema chunking. +// Spanner uses @p-style placeholders and delete-and-insert write mode. +var BaseSchemaChunkerConfig = common.SQLByteChunkerConfig[any]{ + TableName: tableSchema, + NameColumn: colSchemaName, + ChunkIndexColumn: colSchemaChunkIndex, + ChunkDataColumn: colSchemaChunkData, + MaxChunkSize: spannerMaxChunkSize, + PlaceholderFormat: sq.AtP, + WriteMode: common.WriteModeDeleteAndInsert, +} + +// spannerChunkedBytesExecutor implements common.ChunkedBytesExecutor for Spanner's mutation-based API. +type spannerChunkedBytesExecutor struct { + rwt *spanner.ReadWriteTransaction +} + +// newSpannerChunkedBytesExecutor creates a new executor for Spanner chunk operations. +func newSpannerChunkedBytesExecutor(rwt *spanner.ReadWriteTransaction) *spannerChunkedBytesExecutor { + return &spannerChunkedBytesExecutor{rwt: rwt} +} + +// BeginTransaction returns a transaction wrapper for Spanner operations. 
+func (e *spannerChunkedBytesExecutor) BeginTransaction(ctx context.Context) (common.ChunkedBytesTransaction, error) { + return &spannerChunkedBytesTransaction{rwt: e.rwt}, nil +} + +// ExecuteRead executes a SELECT query and returns chunk data. +func (e *spannerChunkedBytesExecutor) ExecuteRead(ctx context.Context, builder sq.SelectBuilder) (map[int][]byte, error) { + sql, args, err := builder.ToSql() + if err != nil { + return nil, fmt.Errorf("failed to build query: %w", err) + } + + iter := e.rwt.Query(ctx, statementFromSQL(sql, args)) + defer iter.Stop() + + chunks := make(map[int][]byte) + if err := iter.Do(func(row *spanner.Row) error { + var chunkIndex int64 + var chunkData []byte + if err := row.Columns(&chunkIndex, &chunkData); err != nil { + return err + } + chunks[int(chunkIndex)] = chunkData + return nil + }); err != nil { + return nil, err + } + + return chunks, nil +} + +// spannerChunkedBytesTransaction implements common.ChunkedBytesTransaction for Spanner. +type spannerChunkedBytesTransaction struct { + rwt *spanner.ReadWriteTransaction +} + +// ExecuteWrite converts an INSERT builder to Spanner mutations. +func (t *spannerChunkedBytesTransaction) ExecuteWrite(ctx context.Context, builder sq.InsertBuilder) error { + sql, args, err := builder.ToSql() + if err != nil { + return fmt.Errorf("failed to build insert: %w", err) + } + + // Convert the INSERT statement to a Spanner mutation. + // This assumes the specific format from the chunker (validated in tests). + mutation, err := t.convertInsertToMutation(sql, args) + if err != nil { + return err + } + + return t.rwt.BufferWrite([]*spanner.Mutation{mutation}) +} + +// ExecuteDelete converts a DELETE builder to Spanner mutations. 
+func (t *spannerChunkedBytesTransaction) ExecuteDelete(ctx context.Context, builder sq.DeleteBuilder) error { + // For schema table, we can just delete all keys + // The chunker only deletes from the schema table by name + mutation := spanner.Delete(tableSchema, spanner.AllKeys()) + return t.rwt.BufferWrite([]*spanner.Mutation{mutation}) +} + +// ExecuteUpdate converts an UPDATE builder to Spanner mutations. +func (t *spannerChunkedBytesTransaction) ExecuteUpdate(ctx context.Context, builder sq.UpdateBuilder) error { + return errors.New("ExecuteUpdate not implemented for Spanner chunked bytes") +} + +// convertInsertToMutation converts an INSERT SQL statement to a Spanner mutation. +// This assumes the chunker generates INSERT statements in the expected format: +// INSERT INTO schema (name, chunk_index, chunk_data) VALUES (@p1, @p2, @p3) +// The format is validated in schema_chunker_test.go. +func (t *spannerChunkedBytesTransaction) convertInsertToMutation(sql string, args []any) (*spanner.Mutation, error) { + // We assume the chunker provides exactly 3 args in the correct order: + // [name, chunk_index, chunk_data] + if len(args) != 3 { + return nil, fmt.Errorf("expected 3 args from chunker, got %d", len(args)) + } + + // Add timestamp column + cols := []string{colSchemaName, colSchemaChunkIndex, colSchemaChunkData, colTimestamp} + vals := []any{args[0], args[1], args[2], spanner.CommitTimestamp} + + return spanner.Insert(tableSchema, cols, vals), nil +} + +// spannerSchemaReadExecutor implements common.ChunkedBytesExecutor for read-only operations. +type spannerSchemaReadExecutor struct { + txSource txFactory +} + +// BeginTransaction returns nil since read operations don't need transactions. +func (e *spannerSchemaReadExecutor) BeginTransaction(ctx context.Context) (common.ChunkedBytesTransaction, error) { + return nil, errors.New("BeginTransaction not supported for read-only executor") +} + +// ExecuteRead executes a SELECT query and returns chunk data. 
+func (e *spannerSchemaReadExecutor) ExecuteRead(ctx context.Context, builder sq.SelectBuilder) (map[int][]byte, error) { + sql, args, err := builder.ToSql() + if err != nil { + return nil, fmt.Errorf("failed to build query: %w", err) + } + + tx := e.txSource() + iter := tx.Query(ctx, statementFromSQL(sql, args)) + defer iter.Stop() + + chunks := make(map[int][]byte) + if err := iter.Do(func(row *spanner.Row) error { + var chunkIndex int64 + var chunkData []byte + if err := row.Columns(&chunkIndex, &chunkData); err != nil { + return err + } + chunks[int(chunkIndex)] = chunkData + return nil + }); err != nil { + return nil, err + } + + return chunks, nil +} + +var ( + _ common.ChunkedBytesExecutor = (*spannerChunkedBytesExecutor)(nil) + _ common.ChunkedBytesTransaction = (*spannerChunkedBytesTransaction)(nil) + _ common.ChunkedBytesExecutor = (*spannerSchemaReadExecutor)(nil) +) diff --git a/internal/datastore/spanner/schema_chunker_test.go b/internal/datastore/spanner/schema_chunker_test.go new file mode 100644 index 000000000..1b8344aec --- /dev/null +++ b/internal/datastore/spanner/schema_chunker_test.go @@ -0,0 +1,104 @@ +package spanner + +import ( + "testing" + + sq "github.com/Masterminds/squirrel" + "github.com/stretchr/testify/require" +) + +// TestChunkerInsertSQLFormat validates that the SQL chunker generates INSERT statements +// in the expected format that convertInsertToMutation assumes. +func TestChunkerInsertSQLFormat(t *testing.T) { + // This test validates the format assumption in convertInsertToMutation. + // The chunker should generate: INSERT INTO schema (name, chunk_index, chunk_data) VALUES (@p1, @p2, @p3) + + builder := sq.Insert(tableSchema). + Columns(colSchemaName, colSchemaChunkIndex, colSchemaChunkData). + Values("test_name", 0, []byte("test_data")). 
+ PlaceholderFormat(sq.AtP) + + sql, args, err := builder.ToSql() + require.NoError(t, err) + + // Validate the SQL format + expectedSQL := "INSERT INTO schema (name,chunk_index,chunk_data) VALUES (@p1,@p2,@p3)" + require.Equal(t, expectedSQL, sql, "SQL format has changed - convertInsertToMutation needs updating") + + // Validate args order and count + require.Len(t, args, 3, "Expected exactly 3 args") + require.Equal(t, "test_name", args[0], "First arg should be name") + require.Equal(t, 0, args[1], "Second arg should be chunk_index") + require.Equal(t, []byte("test_data"), args[2], "Third arg should be chunk_data") +} + +// TestChunkerDeleteSQLFormat validates that the SQL chunker generates DELETE statements +// in the expected format. +func TestChunkerDeleteSQLFormat(t *testing.T) { + // This test validates the format for DELETE operations. + // The chunker should generate: DELETE FROM schema WHERE name = @p1 + + builder := sq.Delete(tableSchema). + Where(sq.Eq{colSchemaName: "test_name"}). + PlaceholderFormat(sq.AtP) + + sql, args, err := builder.ToSql() + require.NoError(t, err) + + // Validate the SQL format + expectedSQL := "DELETE FROM schema WHERE name = @p1" + require.Equal(t, expectedSQL, sql, "DELETE SQL format has changed") + + // Validate args + require.Len(t, args, 1, "Expected exactly 1 arg") + require.Equal(t, "test_name", args[0], "First arg should be name") +} + +// TestConvertInsertToMutation validates the convertInsertToMutation function. 
+func TestConvertInsertToMutation(t *testing.T) { + txn := &spannerChunkedBytesTransaction{} + + tests := []struct { + name string + sql string + args []any + expectError bool + errorMsg string + }{ + { + name: "valid insert with 3 args", + sql: "INSERT INTO schema (name,chunk_index,chunk_data) VALUES (@p1,@p2,@p3)", + args: []any{"test_name", 0, []byte("data")}, + expectError: false, + }, + { + name: "invalid - too few args", + sql: "INSERT INTO schema (name,chunk_index) VALUES (@p1,@p2)", + args: []any{"test_name", 0}, + expectError: true, + errorMsg: "expected 3 args", + }, + { + name: "invalid - too many args", + sql: "INSERT INTO schema (name,chunk_index,chunk_data,extra) VALUES (@p1,@p2,@p3,@p4)", + args: []any{"test_name", 0, []byte("data"), "extra"}, + expectError: true, + errorMsg: "expected 3 args", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mutation, err := txn.convertInsertToMutation(tt.sql, tt.args) + + if tt.expectError { + require.Error(t, err) + require.Contains(t, err.Error(), tt.errorMsg) + require.Nil(t, mutation) + } else { + require.NoError(t, err) + require.NotNil(t, mutation) + } + }) + } +} diff --git a/internal/datastore/spanner/spanner.go b/internal/datastore/spanner/spanner.go index c535bc7c2..d9e346ae9 100644 --- a/internal/datastore/spanner/spanner.go +++ b/internal/datastore/spanner/spanner.go @@ -2,6 +2,7 @@ package spanner import ( "context" + "errors" "fmt" "log/slog" "os" @@ -22,6 +23,7 @@ import ( "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" + "google.golang.org/api/iterator" "google.golang.org/api/option" "google.golang.org/grpc" "google.golang.org/grpc/codes" @@ -32,7 +34,8 @@ import ( log "github.com/authzed/spicedb/internal/logging" "github.com/authzed/spicedb/internal/telemetry/otelconv" "github.com/authzed/spicedb/pkg/datastore" - "github.com/authzed/spicedb/pkg/datastore/options" + dsoptions 
"github.com/authzed/spicedb/pkg/datastore/options" + core "github.com/authzed/spicedb/pkg/proto/core/v1" "github.com/authzed/spicedb/pkg/tuple" ) @@ -90,11 +93,13 @@ type spannerDatastore struct { watchBufferWriteTimeout time.Duration watchEnabled bool - client *spanner.Client - config spannerOptions - database string - schema common.SchemaInformation + client *spanner.Client + config spannerOptions + database string + schema common.SchemaInformation + schemaMode dsoptions.SchemaMode + schemaReaderWriter *common.SQLSchemaReaderWriter[any, revisions.TimestampRevision] cachedEstimatedBytesPerRelationship atomic.Uint64 tableSizesStatsTable string @@ -245,13 +250,26 @@ func NewSpannerDatastore(ctx context.Context, database string, opts ...Option) ( tableSizesStatsTable: tableSizesStatsTable, filterMaximumIDCount: config.filterMaximumIDCount, schema: *schema, + schemaMode: config.schemaMode, } + + // Initialize schema reader/writer + ds.schemaReaderWriter, err = common.NewSQLSchemaReaderWriter[any, revisions.TimestampRevision](BaseSchemaChunkerConfig, config.schemaCacheOptions) + if err != nil { + return nil, err + } + // Optimized revision and revision checking use a stale read for the // current timestamp. // TODO: Still investigating whether a stale read can be used for // HeadRevision for FullConsistency queries. ds.SetNowFunc(ds.staleHeadRevision) + // Warm the schema cache on startup + if err := warmSchemaCache(ctx, ds); err != nil { + log.Warn().Err(err).Msg("failed to warm schema cache on startup") + } + return ds, nil } @@ -286,14 +304,50 @@ func (t *traceableRTX) Query(ctx context.Context, statement spanner.Statement) * return t.delegate.Query(ctx, statement) } -func (sd *spannerDatastore) SnapshotReader(revisionRaw datastore.Revision) datastore.Reader { +// warmSchemaCache attempts to warm the schema cache by loading the current schema. +// This is called during datastore initialization to avoid cold-start latency on first requests. 
+func warmSchemaCache(ctx context.Context, ds *spannerDatastore) error { + // Get the current revision and schema hash + rev, schemaHash, err := ds.HeadRevision(ctx) + if err != nil { + return fmt.Errorf("failed to get head revision: %w", err) + } + + // If there's no schema hash, there's no schema to warm + if schemaHash == "" { + log.Ctx(ctx).Debug().Msg("no schema hash found, skipping cache warming") + return nil + } + + // Create a simple executor for schema reading using a single-use read transaction + txSource := func() readTX { + return &traceableRTX{delegate: ds.client.Single()} + } + executor := &spannerSchemaReadExecutor{txSource: txSource} + + // Load the schema to populate the cache + _, err = ds.schemaReaderWriter.ReadSchema(ctx, executor, rev, schemaHash) + if err != nil { + if errors.Is(err, datastore.ErrSchemaNotFound) { + // Schema not found is not an error during warming - just means no schema yet + log.Ctx(ctx).Debug().Msg("no schema found, skipping cache warming") + return nil + } + return fmt.Errorf("failed to read schema: %w", err) + } + + log.Ctx(ctx).Info().Str("schema_hash", string(schemaHash)).Msg("schema cache warmed successfully") + return nil +} + +func (sd *spannerDatastore) SnapshotReader(revisionRaw datastore.Revision, hash datastore.SchemaHash) datastore.Reader { r := revisionRaw.(revisions.TimestampRevision) txSource := func() readTX { return &traceableRTX{delegate: sd.client.Single().WithTimestampBound(spanner.ReadTimestamp(r.Time()))} } executor := common.QueryRelationshipsExecutor{Executor: queryExecutor(txSource)} - return &spannerReader{executor, txSource, sd.filterMaximumIDCount, sd.schema} + return &spannerReader{executor, txSource, sd.filterMaximumIDCount, sd.schema, sd.schemaMode, revisionRaw, string(hash), sd.schemaReaderWriter} } func (sd *spannerDatastore) MetricsID() (string, error) { @@ -326,8 +380,8 @@ func (sd *spannerDatastore) readTransactionMetadata(ctx context.Context, transac return metadata, nil } -func (sd 
*spannerDatastore) ReadWriteTx(ctx context.Context, fn datastore.TxUserFunc, opts ...options.RWTOptionsOption) (datastore.Revision, error) { - config := options.NewRWTOptionsWithOptions(opts...) +func (sd *spannerDatastore) ReadWriteTx(ctx context.Context, fn datastore.TxUserFunc, opts ...dsoptions.RWTOptionsOption) (datastore.Revision, error) { + config := dsoptions.NewRWTOptionsWithOptions(opts...) ctx, span := tracer.Start(ctx, "ReadWriteTx") defer span.End() @@ -358,8 +412,12 @@ func (sd *spannerDatastore) ReadWriteTx(ctx context.Context, fn datastore.TxUser executor := common.QueryRelationshipsExecutor{Executor: queryExecutor(txSource)} rwt := &spannerReadWriteTXN{ - spannerReader{executor, txSource, sd.filterMaximumIDCount, sd.schema}, - spannerRWT, + spannerReader: spannerReader{executor, txSource, sd.filterMaximumIDCount, sd.schema, sd.schemaMode, datastore.NoRevision, string(datastore.NoSchemaHashInTransaction), sd.schemaReaderWriter}, + spannerRWT: spannerRWT, + bufferedNamespaces: make(map[string]*core.NamespaceDefinition), + deletedNamespaces: make(map[string]struct{}), + bufferedCaveats: make(map[string]*core.CaveatDefinition), + deletedCaveats: make(map[string]struct{}), } err := func() error { innerCtx, innerSpan := tracer.Start(ctx, "TxUserFunc") @@ -427,6 +485,50 @@ func (sd *spannerDatastore) Close() error { return nil } +// SchemaHashReaderForTesting returns a test-only interface for reading the schema hash directly from schema_revision table. +func (sd *spannerDatastore) SchemaHashReaderForTesting() interface { + ReadSchemaHash(ctx context.Context) (string, error) +} { + return &spannerSchemaHashReaderForTesting{client: sd.client} +} + +// SchemaModeForTesting returns the current schema mode for testing purposes. 
+func (sd *spannerDatastore) SchemaModeForTesting() (dsoptions.SchemaMode, error) { + return sd.schemaMode, nil +} + +type spannerSchemaHashReaderForTesting struct { + client *spanner.Client +} + +func (r *spannerSchemaHashReaderForTesting) ReadSchemaHash(ctx context.Context) (string, error) { + txn := r.client.Single() + defer txn.Close() + + iter := txn.Query(ctx, spanner.Statement{ + SQL: "SELECT schema_hash FROM schema_revision WHERE name = @name", + Params: map[string]any{ + "name": "current", + }, + }) + defer iter.Stop() + + row, err := iter.Next() + if err != nil { + if errors.Is(err, iterator.Done) { + return "", datastore.ErrSchemaNotFound + } + return "", fmt.Errorf("failed to query schema hash: %w", err) + } + + var hashBytes []byte + if err := row.Columns(&hashBytes); err != nil { + return "", fmt.Errorf("failed to scan schema hash: %w", err) + } + + return string(hashBytes), nil +} + func statementFromSQL(sql string, args []any) spanner.Statement { params := make(map[string]any, len(args)) for index, arg := range args { diff --git a/internal/datastore/spanner/spanner_test.go b/internal/datastore/spanner/spanner_test.go index 3a4872484..d5587d9b1 100644 --- a/internal/datastore/spanner/spanner_test.go +++ b/internal/datastore/spanner/spanner_test.go @@ -16,6 +16,7 @@ import ( testdatastore "github.com/authzed/spicedb/internal/testserver/datastore" "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/datastore/options" "github.com/authzed/spicedb/pkg/datastore/test" "github.com/authzed/spicedb/pkg/tuple" ) @@ -146,3 +147,29 @@ func FakeStatsTest(t *testing.T, ds datastore.Datastore) { require.NoError(t, err) require.Equal(t, uint64(3), stats.EstimatedRelationshipCount) } + +func TestSpannerDatastoreUnifiedSchemaAllModes(t *testing.T) { + ctx := context.Background() + b := testdatastore.RunSpannerForTesting(t, "", "head") + + test.UnifiedSchemaAllModesTest(t, func(schemaMode options.SchemaMode) test.DatastoreTester { + return 
test.DatastoreTesterFunc(func(revisionQuantization, gcInterval, gcWindow time.Duration, watchBufferLength uint16) (datastore.Datastore, error) { + ds := b.NewDatastore(t, func(engine, uri string) datastore.Datastore { + ds, err := NewSpannerDatastore( + ctx, + uri, + RevisionQuantization(revisionQuantization), + WatchBufferLength(watchBufferLength), + WithSchemaMode(schemaMode), + ) + require.NoError(t, err) + t.Cleanup(func() { + _ = ds.Close() + }) + return ds + }) + + return ds, nil + }) + }) +} diff --git a/internal/dispatch/graph/graph.go b/internal/dispatch/graph/graph.go index 98b93078b..c3de883a7 100644 --- a/internal/dispatch/graph/graph.go +++ b/internal/dispatch/graph/graph.go @@ -203,8 +203,8 @@ type localDispatcher struct { lookupResourcesHandler3 *graph.CursoredLookupResources3 } -func (ld *localDispatcher) loadNamespace(ctx context.Context, nsName string, revision datastore.Revision) (*core.NamespaceDefinition, error) { - ds := datastoremw.MustFromContext(ctx).SnapshotReader(revision) +func (ld *localDispatcher) loadNamespace(ctx context.Context, nsName string, revision datastore.Revision, schemaHash datastore.SchemaHash) (*core.NamespaceDefinition, error) { + ds := datastoremw.MustFromContext(ctx).SnapshotReader(revision, schemaHash) // Load namespace and relation from the datastore schemaReader, err := ds.SchemaReader() @@ -286,7 +286,7 @@ func (ld *localDispatcher) DispatchCheck(ctx context.Context, req *v1.DispatchCh return &v1.DispatchCheckResponse{Metadata: emptyMetadata}, rewriteError(ctx, err) } - ns, err := ld.loadNamespace(ctx, req.ResourceRelation.Namespace, revision) + ns, err := ld.loadNamespace(ctx, req.ResourceRelation.Namespace, revision, datastore.SchemaHash(req.Metadata.SchemaHash)) if err != nil { return &v1.DispatchCheckResponse{Metadata: emptyMetadata}, rewriteError(ctx, err) } @@ -352,7 +352,7 @@ func (ld *localDispatcher) DispatchExpand(ctx context.Context, req *v1.DispatchE return &v1.DispatchExpandResponse{Metadata: 
emptyMetadata}, err } - ns, err := ld.loadNamespace(ctx, req.ResourceAndRelation.Namespace, revision) + ns, err := ld.loadNamespace(ctx, req.ResourceAndRelation.Namespace, revision, datastore.SchemaHash(req.Metadata.SchemaHash)) if err != nil { return &v1.DispatchExpandResponse{Metadata: emptyMetadata}, err } diff --git a/internal/dispatch/graph/lookupresources2_test.go b/internal/dispatch/graph/lookupresources2_test.go index bad38fa05..916f9b9dd 100644 --- a/internal/dispatch/graph/lookupresources2_test.go +++ b/internal/dispatch/graph/lookupresources2_test.go @@ -1413,8 +1413,8 @@ type disallowedWrapper struct { disallowedQueries []tuple.RelationReference } -func (dw disallowedWrapper) SnapshotReader(rev datastore.Revision) datastore.Reader { - return disallowedReader{dw.Datastore.SnapshotReader(rev), dw.disallowedQueries} +func (dw disallowedWrapper) SnapshotReader(rev datastore.Revision, hash datastore.SchemaHash) datastore.Reader { + return disallowedReader{dw.Datastore.SnapshotReader(rev, hash), dw.disallowedQueries} } type disallowedReader struct { diff --git a/internal/dispatch/keys/keys.go b/internal/dispatch/keys/keys.go index 47c8d0449..b64810609 100644 --- a/internal/dispatch/keys/keys.go +++ b/internal/dispatch/keys/keys.go @@ -5,6 +5,7 @@ import ( datastoremw "github.com/authzed/spicedb/internal/middleware/datastore" "github.com/authzed/spicedb/internal/namespace" + "github.com/authzed/spicedb/pkg/datastore" v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" ) @@ -106,7 +107,7 @@ func (c *CanonicalKeyHandler) CheckCacheKey(ctx context.Context, req *v1.Dispatc if err != nil { return emptyDispatchCacheKey, err } - r := ds.SnapshotReader(revision) + r := ds.SnapshotReader(revision, datastore.SchemaHash(req.Metadata.SchemaHash)) _, relation, err := namespace.ReadNamespaceAndRelation( ctx, diff --git a/internal/graph/check.go b/internal/graph/check.go index 4adc2f29f..9c3e577ce 100644 --- a/internal/graph/check.go +++ b/internal/graph/check.go @@ -333,7 
+333,7 @@ func (cc *ConcurrentChecker) checkDirect(ctx context.Context, crc currentRequest } }() log.Ctx(ctx).Trace().Object("direct", crc.parentReq).Send() - ds := datastoremw.MustFromContext(ctx).SnapshotReader(crc.parentReq.Revision) + ds := datastoremw.MustFromContext(ctx).SnapshotReader(crc.parentReq.Revision, datastore.SchemaHash(crc.parentReq.Metadata.SchemaHash)) directSubjectsAndWildcardsWithoutCaveats := 0 directSubjectsAndWildcardsWithoutExpiration := 0 @@ -673,7 +673,7 @@ func (cc *ConcurrentChecker) checkComputedUserset(ctx context.Context, crc curre // for TTU-based computed usersets, as directly computed ones reference relations within // the same namespace as the caller, and thus must be fully typed checked. if cu.Object == core.ComputedUserset_TUPLE_USERSET_OBJECT { - ds := datastoremw.MustFromContext(ctx).SnapshotReader(crc.parentReq.Revision) + ds := datastoremw.MustFromContext(ctx).SnapshotReader(crc.parentReq.Revision, datastore.SchemaHash(crc.parentReq.Metadata.SchemaHash)) err := namespace.CheckNamespaceAndRelation(ctx, targetRR.Namespace, targetRR.Relation, true, ds) if err != nil { if errors.As(err, &namespace.RelationNotFoundError{}) { @@ -822,7 +822,7 @@ func checkIntersectionTupleToUserset( // Query for the subjects over which to walk the TTU. 
log.Ctx(ctx).Trace().Object("intersectionttu", crc.parentReq).Send() - ds := datastoremw.MustFromContext(ctx).SnapshotReader(crc.parentReq.Revision) + ds := datastoremw.MustFromContext(ctx).SnapshotReader(crc.parentReq.Revision, datastore.SchemaHash(crc.parentReq.Metadata.SchemaHash)) queryOpts, err := queryOptionsForArrowRelation(ctx, ds, crc.parentReq.ResourceRelation.Namespace, ttu.GetTupleset().GetRelation()) if err != nil { return checkResultError(NewCheckFailureErr(err), emptyMetadata) @@ -988,7 +988,7 @@ func checkTupleToUserset[T relation]( defer span.End() log.Ctx(ctx).Trace().Object("ttu", crc.parentReq).Send() - ds := datastoremw.MustFromContext(ctx).SnapshotReader(crc.parentReq.Revision) + ds := datastoremw.MustFromContext(ctx).SnapshotReader(crc.parentReq.Revision, datastore.SchemaHash(crc.parentReq.Metadata.SchemaHash)) queryOpts, err := queryOptionsForArrowRelation(ctx, ds, crc.parentReq.ResourceRelation.Namespace, ttu.GetTupleset().GetRelation()) if err != nil { diff --git a/internal/graph/check_isolated_test.go b/internal/graph/check_isolated_test.go index f33ce38f6..e962f1b27 100644 --- a/internal/graph/check_isolated_test.go +++ b/internal/graph/check_isolated_test.go @@ -9,6 +9,7 @@ import ( "github.com/authzed/spicedb/internal/datastore/memdb" "github.com/authzed/spicedb/internal/graph" "github.com/authzed/spicedb/internal/testfixtures" + "github.com/authzed/spicedb/pkg/datastore" ) func TestTraitsForArrowRelation(t *testing.T) { @@ -132,7 +133,7 @@ func TestTraitsForArrowRelation(t *testing.T) { require.NoError(err) ds, revision := testfixtures.DatastoreFromSchemaAndTestRelationships(rawDS, tc.schema, nil, require) - reader := ds.SnapshotReader(revision) + reader := ds.SnapshotReader(revision, datastore.NoSchemaHashForTesting) traits, err := graph.TraitsForArrowRelation(t.Context(), reader, tc.namespaceName, tc.relationName) if tc.expectedError != "" { diff --git a/internal/graph/computed/computecheck.go 
b/internal/graph/computed/computecheck.go index be493dc74..aca5a1e3b 100644 --- a/internal/graph/computed/computecheck.go +++ b/internal/graph/computed/computecheck.go @@ -45,6 +45,7 @@ type CheckParameters struct { Subject tuple.ObjectAndRelation CaveatContext map[string]any AtRevision datastore.Revision + SchemaHash datastore.SchemaHash MaximumDepth uint32 DebugOption DebugOption CheckHints []*v1.CheckHint @@ -178,7 +179,7 @@ func computeCaveatedCheckResult(ctx context.Context, runner *cexpr.CaveatRunner, } ds := datastoremw.MustFromContext(ctx) - reader := ds.SnapshotReader(params.AtRevision) + reader := ds.SnapshotReader(params.AtRevision, params.SchemaHash) caveatResult, err := runner.RunCaveatExpression(ctx, result.Expression, params.CaveatContext, reader, cexpr.RunCaveatExpressionNoDebugging) if err != nil { diff --git a/internal/graph/expand.go b/internal/graph/expand.go index 11905e5f8..d0630dd7b 100644 --- a/internal/graph/expand.go +++ b/internal/graph/expand.go @@ -58,7 +58,7 @@ func (ce *ConcurrentExpander) expandDirect( ) ReduceableExpandFunc { log.Ctx(ctx).Trace().Object("direct", req).Send() return func(ctx context.Context, resultChan chan<- ExpandResult) { - ds := datastoremw.MustFromContext(ctx).SnapshotReader(req.Revision) + ds := datastoremw.MustFromContext(ctx).SnapshotReader(req.Revision, datastore.SchemaHash(req.Metadata.SchemaHash)) it, err := ds.QueryRelationships(ctx, datastore.RelationshipsFilter{ OptionalResourceType: req.ResourceAndRelation.Namespace, OptionalResourceIds: []string{req.ResourceAndRelation.ObjectId}, @@ -243,7 +243,7 @@ func (ce *ConcurrentExpander) expandComputedUserset(ctx context.Context, req Val } // Check if the target relation exists. If not, return nothing. 
- ds := datastoremw.MustFromContext(ctx).SnapshotReader(req.Revision) + ds := datastoremw.MustFromContext(ctx).SnapshotReader(req.Revision, datastore.SchemaHash(req.Metadata.SchemaHash)) err := namespace.CheckNamespaceAndRelation(ctx, start.ObjectType, cu.Relation, true, ds) if err != nil { if errors.As(err, &namespace.RelationNotFoundError{}) { @@ -277,7 +277,7 @@ func expandTupleToUserset[T relation]( expandFunc expandFunc, ) ReduceableExpandFunc { return func(ctx context.Context, resultChan chan<- ExpandResult) { - ds := datastoremw.MustFromContext(ctx).SnapshotReader(req.Revision) + ds := datastoremw.MustFromContext(ctx).SnapshotReader(req.Revision, datastore.SchemaHash(req.Metadata.SchemaHash)) it, err := ds.QueryRelationships(ctx, datastore.RelationshipsFilter{ OptionalResourceType: req.ResourceAndRelation.Namespace, OptionalResourceIds: []string{req.ResourceAndRelation.ObjectId}, diff --git a/internal/graph/lookupresources2.go b/internal/graph/lookupresources2.go index 648f6a95f..6b7c8133e 100644 --- a/internal/graph/lookupresources2.go +++ b/internal/graph/lookupresources2.go @@ -127,7 +127,7 @@ func (crr *CursoredLookupResources2) afterSameType( // Load the type system and reachability graph to find the entrypoints for the reachability. 
ds := datastoremw.MustFromContext(ctx) - reader := ds.SnapshotReader(req.Revision) + reader := ds.SnapshotReader(req.Revision, datastore.SchemaHash(req.Metadata.SchemaHash)) ts := schema.NewTypeSystem(schema.ResolverForDatastoreReader(reader)) vdef, err := ts.GetValidatedDefinition(ctx, req.ResourceRelation.Namespace) if err != nil { @@ -594,6 +594,7 @@ func (crr *CursoredLookupResources2) redispatchOrReport( Subject: tuple.FromCoreObjectAndRelation(parentRequest.TerminalSubject), CaveatContext: parentRequest.Context.AsMap(), AtRevision: parentRequest.Revision, + SchemaHash: datastore.SchemaHash(parentRequest.Metadata.SchemaHash), MaximumDepth: parentRequest.Metadata.DepthRemaining - 1, DebugOption: computed.NoDebugging, CheckHints: checkHints, diff --git a/internal/graph/lookupresources3.go b/internal/graph/lookupresources3.go index 066265877..da4f3314d 100644 --- a/internal/graph/lookupresources3.go +++ b/internal/graph/lookupresources3.go @@ -267,7 +267,7 @@ func (crr *CursoredLookupResources3) LookupResources3(req ValidatedLookupResourc // Build refs for the lookup resources operation. The lr3refs holds references to shared // interfaces used by various suboperations of the lookup resources operation. 
ds := datastoremw.MustFromContext(stream.Context()) - reader := ds.SnapshotReader(req.Revision) + reader := ds.SnapshotReader(req.Revision, datastore.SchemaHash(req.Metadata.SchemaHash)) ts := schema.NewTypeSystem(schema.ResolverForDatastoreReader(reader)) caveatRunner := caveats.NewCaveatRunner(crr.caveatTypeSet) @@ -1008,6 +1008,7 @@ func (crr *CursoredLookupResources3) filterSubjectsByCheck( Subject: tuple.FromCoreObjectAndRelation(refs.req.TerminalSubject), CaveatContext: refs.req.Context.AsMap(), AtRevision: refs.req.Revision, + SchemaHash: datastore.SchemaHash(refs.req.Metadata.SchemaHash), MaximumDepth: refs.req.Metadata.DepthRemaining - 1, DebugOption: computed.NoDebugging, CheckHints: checkHints, diff --git a/internal/graph/lookupsubjects.go b/internal/graph/lookupsubjects.go index 3fbc31540..2bb8d2e75 100644 --- a/internal/graph/lookupsubjects.go +++ b/internal/graph/lookupsubjects.go @@ -68,7 +68,7 @@ func (cl *ConcurrentLookupSubjects) LookupSubjects( } ds := datastoremw.MustFromContext(ctx) - reader := ds.SnapshotReader(req.Revision) + reader := ds.SnapshotReader(req.Revision, datastore.SchemaHash(req.Metadata.SchemaHash)) _, relation, err := namespace.ReadNamespaceAndRelation( ctx, req.ResourceRelation.Namespace, @@ -194,7 +194,7 @@ func (cl *ConcurrentLookupSubjects) lookupViaComputed( ts *schema.TypeSystem, cu *core.ComputedUserset, ) error { - ds := datastoremw.MustFromContext(ctx).SnapshotReader(parentRequest.Revision) + ds := datastoremw.MustFromContext(ctx).SnapshotReader(parentRequest.Revision, datastore.SchemaHash(parentRequest.Metadata.SchemaHash)) if err := namespace.CheckNamespaceAndRelation(ctx, parentRequest.ResourceRelation.Namespace, cu.Relation, true, ds); err != nil { if errors.As(err, &namespace.RelationNotFoundError{}) { return nil @@ -250,7 +250,7 @@ func lookupViaIntersectionTupleToUserset( ts *schema.TypeSystem, ttu *core.FunctionedTupleToUserset, ) error { - ds := 
datastoremw.MustFromContext(ctx).SnapshotReader(parentRequest.Revision) + ds := datastoremw.MustFromContext(ctx).SnapshotReader(parentRequest.Revision, datastore.SchemaHash(parentRequest.Metadata.SchemaHash)) opts, err := cl.queryOptionsForRelation(ctx, ts, parentRequest.ResourceRelation.Namespace, ttu.GetTupleset().GetRelation()) if err != nil { return err @@ -425,7 +425,7 @@ func lookupViaTupleToUserset[T relation]( toDispatchByTuplesetType := datasets.NewSubjectByTypeSet() relationshipsBySubjectONR := mapz.NewMultiMap[tuple.ObjectAndRelation, tuple.Relationship]() - ds := datastoremw.MustFromContext(ctx).SnapshotReader(parentRequest.Revision) + ds := datastoremw.MustFromContext(ctx).SnapshotReader(parentRequest.Revision, datastore.SchemaHash(parentRequest.Metadata.SchemaHash)) opts, err := cl.queryOptionsForRelation(ctx, ts, parentRequest.ResourceRelation.Namespace, ttu.GetTupleset().GetRelation()) if err != nil { return err diff --git a/internal/graph/lr2streams.go b/internal/graph/lr2streams.go index f27638156..e8ca0a3a4 100644 --- a/internal/graph/lr2streams.go +++ b/internal/graph/lr2streams.go @@ -14,6 +14,7 @@ import ( "github.com/authzed/spicedb/internal/taskrunner" "github.com/authzed/spicedb/internal/telemetry/otelconv" caveattypes "github.com/authzed/spicedb/pkg/caveats/types" + "github.com/authzed/spicedb/pkg/datastore" "github.com/authzed/spicedb/pkg/genutil/mapz" core "github.com/authzed/spicedb/pkg/proto/core/v1" v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" @@ -137,6 +138,7 @@ func (rdc *checkAndDispatchRunner) runChecker(ctx context.Context, startingIndex Subject: tuple.FromCoreObjectAndRelation(rdc.parentRequest.TerminalSubject), CaveatContext: rdc.parentRequest.Context.AsMap(), AtRevision: rdc.parentRequest.Revision, + SchemaHash: datastore.SchemaHash(rdc.parentRequest.Metadata.SchemaHash), MaximumDepth: rdc.parentRequest.Metadata.DepthRemaining - 1, DebugOption: computed.NoDebugging, CheckHints: checkHints, diff --git 
a/internal/middleware/datastore/counting_test.go b/internal/middleware/datastore/counting_test.go index c6a73ddbd..7f67c14b8 100644 --- a/internal/middleware/datastore/counting_test.go +++ b/internal/middleware/datastore/counting_test.go @@ -34,7 +34,7 @@ func TestUnaryCountingInterceptor(t *testing.T) { ds := MustFromContext(ctx) // Make some calls to trigger counting - reader := ds.SnapshotReader(datastore.NoRevision) + reader := ds.SnapshotReader(datastore.NoRevision, datastore.NoSchemaHashForTesting) _, _ = reader.QueryRelationships(ctx, datastore.RelationshipsFilter{}) _, _ = reader.QueryRelationships(ctx, datastore.RelationshipsFilter{}) @@ -82,7 +82,7 @@ func TestStreamCountingInterceptor(t *testing.T) { ds := MustFromContext(ss.Context()) // Make some calls to trigger counting - reader := ds.SnapshotReader(datastore.NoRevision) + reader := ds.SnapshotReader(datastore.NoRevision, datastore.NoSchemaHashForTesting) _, _ = reader.ReverseQueryRelationships(ss.Context(), datastore.SubjectsFilter{SubjectType: "user"}) return nil @@ -117,7 +117,7 @@ func TestUnaryCountingInterceptor_HandlerError(t *testing.T) { handler := func(ctx context.Context, req any) (any, error) { // Make a call before erroring ds := MustFromContext(ctx) - reader := ds.SnapshotReader(datastore.NoRevision) + reader := ds.SnapshotReader(datastore.NoRevision, datastore.NoSchemaHashForTesting) _, _, _ = reader.LegacyReadNamespaceByName(ctx, "test") return nil, &testError{} @@ -156,7 +156,7 @@ func TestStreamCountingInterceptor_HandlerError(t *testing.T) { handler := func(srv any, ss grpc.ServerStream) error { // Make a call before erroring ds := MustFromContext(ss.Context()) - reader := ds.SnapshotReader(datastore.NoRevision) + reader := ds.SnapshotReader(datastore.NoRevision, datastore.NoSchemaHashForTesting) _, _ = reader.LegacyListAllNamespaces(ss.Context()) return &testError{} diff --git a/internal/middleware/pertoken/pertoken_test.go b/internal/middleware/pertoken/pertoken_test.go index 
4914aa8ae..177b70b12 100644 --- a/internal/middleware/pertoken/pertoken_test.go +++ b/internal/middleware/pertoken/pertoken_test.go @@ -54,12 +54,12 @@ func (t testServer) Ping(ctx context.Context, req *testpb.PingRequest) (*testpb. } } - headRev, err := ds.HeadRevision(ctx) + headRev, _, err := ds.HeadRevision(ctx) if err != nil { return nil, err } - reader := ds.SnapshotReader(headRev) + reader := ds.SnapshotReader(headRev, datastore.NoSchemaHashForTesting) if reader == nil { return nil, errors.New("no snapshot reader available") } diff --git a/internal/namespace/aliasing_test.go b/internal/namespace/aliasing_test.go index 3eb7a3db7..3e64fb36b 100644 --- a/internal/namespace/aliasing_test.go +++ b/internal/namespace/aliasing_test.go @@ -7,6 +7,7 @@ import ( "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/datastore/memdb" + "github.com/authzed/spicedb/pkg/datastore" ns "github.com/authzed/spicedb/pkg/namespace" core "github.com/authzed/spicedb/pkg/proto/core/v1" "github.com/authzed/spicedb/pkg/schema" @@ -197,10 +198,10 @@ func TestAliasing(t *testing.T) { ds, err := dsfortesting.NewMemDBDatastoreForTesting(t, 0, 0, memdb.DisableGC) require.NoError(err) - lastRevision, err := ds.HeadRevision(t.Context()) + lastRevision, _, err := ds.HeadRevision(t.Context()) require.NoError(err) - ts := schema.NewTypeSystem(schema.ResolverForDatastoreReader(ds.SnapshotReader(lastRevision))) + ts := schema.NewTypeSystem(schema.ResolverForDatastoreReader(ds.SnapshotReader(lastRevision, datastore.NoSchemaHashForTesting))) def, err := schema.NewDefinition(ts, tc.toCheck) require.NoError(err) diff --git a/internal/namespace/annotate_test.go b/internal/namespace/annotate_test.go index 614ed56a2..2170774e9 100644 --- a/internal/namespace/annotate_test.go +++ b/internal/namespace/annotate_test.go @@ -7,6 +7,7 @@ import ( "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/datastore/memdb" + 
"github.com/authzed/spicedb/pkg/datastore" "github.com/authzed/spicedb/pkg/schema" "github.com/authzed/spicedb/pkg/schemadsl/compiler" "github.com/authzed/spicedb/pkg/schemadsl/input" @@ -32,10 +33,10 @@ func TestAnnotateNamespace(t *testing.T) { }, compiler.AllowUnprefixedObjectType()) require.NoError(err) - lastRevision, err := ds.HeadRevision(t.Context()) + lastRevision, _, err := ds.HeadRevision(t.Context()) require.NoError(err) - ts := schema.NewTypeSystem(schema.ResolverForDatastoreReader(ds.SnapshotReader(lastRevision))) + ts := schema.NewTypeSystem(schema.ResolverForDatastoreReader(ds.SnapshotReader(lastRevision, datastore.NoSchemaHashForTesting))) def, err := schema.NewDefinition(ts, compiled.ObjectDefinitions[0]) require.NoError(err) diff --git a/internal/namespace/canonicalization_test.go b/internal/namespace/canonicalization_test.go index 3496a5cf5..da70e315c 100644 --- a/internal/namespace/canonicalization_test.go +++ b/internal/namespace/canonicalization_test.go @@ -8,6 +8,7 @@ import ( "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/datastore/memdb" + "github.com/authzed/spicedb/pkg/datastore" ns "github.com/authzed/spicedb/pkg/namespace" core "github.com/authzed/spicedb/pkg/proto/core/v1" "github.com/authzed/spicedb/pkg/schema" @@ -527,10 +528,10 @@ func TestCanonicalization(t *testing.T) { ctx := t.Context() - lastRevision, err := ds.HeadRevision(t.Context()) + lastRevision, _, err := ds.HeadRevision(t.Context()) require.NoError(err) - ts := schema.NewTypeSystem(schema.ResolverForDatastoreReader(ds.SnapshotReader(lastRevision))) + ts := schema.NewTypeSystem(schema.ResolverForDatastoreReader(ds.SnapshotReader(lastRevision, datastore.NoSchemaHashForTesting))) def, err := schema.NewDefinition(ts, tc.toCheck) require.NoError(err) @@ -663,10 +664,10 @@ func TestCanonicalizationComparison(t *testing.T) { }, compiler.AllowUnprefixedObjectType()) require.NoError(err) - lastRevision, err := 
ds.HeadRevision(t.Context()) + lastRevision, _, err := ds.HeadRevision(t.Context()) require.NoError(err) - ts := schema.NewTypeSystem(schema.ResolverForDatastoreReader(ds.SnapshotReader(lastRevision))) + ts := schema.NewTypeSystem(schema.ResolverForDatastoreReader(ds.SnapshotReader(lastRevision, datastore.NoSchemaHashForTesting))) def, err := schema.NewDefinition(ts, compiled.ObjectDefinitions[0]) require.NoError(err) diff --git a/internal/namespace/util_test.go b/internal/namespace/util_test.go index f3a9df4bb..ded58883e 100644 --- a/internal/namespace/util_test.go +++ b/internal/namespace/util_test.go @@ -10,6 +10,7 @@ import ( "github.com/authzed/spicedb/internal/datastore/memdb" "github.com/authzed/spicedb/internal/namespace" "github.com/authzed/spicedb/internal/testfixtures" + "github.com/authzed/spicedb/pkg/datastore" ns "github.com/authzed/spicedb/pkg/namespace" core "github.com/authzed/spicedb/pkg/proto/core/v1" ) @@ -167,10 +168,10 @@ func TestCheckNamespaceAndRelations(t *testing.T) { ds, _ := testfixtures.DatastoreFromSchemaAndTestRelationships(rawDS, tc.schema, nil, req) - rev, err := ds.HeadRevision(t.Context()) + rev, _, err := ds.HeadRevision(t.Context()) require.NoError(t, err) - reader := ds.SnapshotReader(rev) + reader := ds.SnapshotReader(rev, datastore.NoSchemaHashForTesting) err = namespace.CheckNamespaceAndRelations(t.Context(), tc.checks, reader) if tc.expectedError == "" { diff --git a/internal/relationships/validation_test.go b/internal/relationships/validation_test.go index 788099a32..95517bdff 100644 --- a/internal/relationships/validation_test.go +++ b/internal/relationships/validation_test.go @@ -9,6 +9,7 @@ import ( "github.com/authzed/spicedb/internal/datastore/memdb" "github.com/authzed/spicedb/internal/testfixtures" caveattypes "github.com/authzed/spicedb/pkg/caveats/types" + "github.com/authzed/spicedb/pkg/datastore" core "github.com/authzed/spicedb/pkg/proto/core/v1" "github.com/authzed/spicedb/pkg/tuple" ) @@ -335,7 +336,7 @@ 
func TestValidateRelationshipOperations(t *testing.T) { req.NoError(err) uds, rev := testfixtures.DatastoreFromSchemaAndTestRelationships(ds, tc.schema, nil, req) - reader := uds.SnapshotReader(rev) + reader := uds.SnapshotReader(rev, datastore.NoSchemaHashForTesting) op := tuple.Create if tc.operation == core.RelationTupleUpdate_DELETE { diff --git a/internal/services/integrationtesting/consistency_datastore_test.go b/internal/services/integrationtesting/consistency_datastore_test.go index 2153b2e1b..bce951b4a 100644 --- a/internal/services/integrationtesting/consistency_datastore_test.go +++ b/internal/services/integrationtesting/consistency_datastore_test.go @@ -58,7 +58,7 @@ func TestConsistencyPerDatastore(t *testing.T) { t.Cleanup(func() { dispatcher.Close() }) accessibilitySet := consistencytestutil.BuildAccessibilitySet(t, cad.Ctx, cad.Populated, cad.DataStore) - headRevision, err := cad.DataStore.HeadRevision(cad.Ctx) + headRevision, _, err := cad.DataStore.HeadRevision(cad.Ctx) require.NoError(t, err) // Run the assertions within each file. diff --git a/internal/services/integrationtesting/consistency_test.go b/internal/services/integrationtesting/consistency_test.go index 283507181..daa5cb0f6 100644 --- a/internal/services/integrationtesting/consistency_test.go +++ b/internal/services/integrationtesting/consistency_test.go @@ -85,10 +85,10 @@ func TestConsistency(t *testing.T) { cad := consistencytestutil.LoadDataAndCreateClusterForTesting(t, "testconfigs/self.yaml.skip", testTimedelta, options...) // Validate the type system for each namespace. 
- headRevision, err := cad.DataStore.HeadRevision(cad.Ctx) + headRevision, _, err := cad.DataStore.HeadRevision(cad.Ctx) require.NoError(t, err) - ts := schema.NewTypeSystem(schema.ResolverForDatastoreReader(cad.DataStore.SnapshotReader(headRevision))) + ts := schema.NewTypeSystem(schema.ResolverForDatastoreReader(cad.DataStore.SnapshotReader(headRevision, datastore.NoSchemaHashForTesting))) for _, nsDef := range cad.Populated.NamespaceDefinitions { _, err := ts.GetValidatedDefinition(cad.Ctx, nsDef.Name) @@ -248,10 +248,10 @@ func runConsistencyTestSuiteForFile(t *testing.T, filePath string, useCachingDis cad := consistencytestutil.LoadDataAndCreateClusterForTesting(t, filePath, testTimedelta, options...) // Validate the type system for each namespace. - headRevision, err := cad.DataStore.HeadRevision(cad.Ctx) + headRevision, _, err := cad.DataStore.HeadRevision(cad.Ctx) require.NoError(t, err) - ts := schema.NewTypeSystem(schema.ResolverForDatastoreReader(cad.DataStore.SnapshotReader(headRevision))) + ts := schema.NewTypeSystem(schema.ResolverForDatastoreReader(cad.DataStore.SnapshotReader(headRevision, datastore.NoSchemaHashForTesting))) for _, nsDef := range cad.Populated.NamespaceDefinitions { _, err := ts.GetValidatedDefinition(cad.Ctx, nsDef.Name) @@ -1096,10 +1096,10 @@ func validateDevelopmentExpectedRels(t *testing.T, devContext *development.DevCo // validateReachableSubjectTypes validates that the reachable subject types are those expected. 
func validateReachableSubjectTypes(t *testing.T, vctx validationContext) { testForEachResource(t, vctx, "validate_reachable_subject_types", func(t *testing.T, resource tuple.ObjectAndRelation) { - headRev, err := vctx.clusterAndData.DataStore.HeadRevision(t.Context()) + headRev, _, err := vctx.clusterAndData.DataStore.HeadRevision(t.Context()) require.NoError(t, err) - reader := vctx.clusterAndData.DataStore.SnapshotReader(headRev) + reader := vctx.clusterAndData.DataStore.SnapshotReader(headRev, datastore.NoSchemaHashForTesting) ts := schema.NewTypeSystem(schema.ResolverForDatastoreReader(reader)) reachableSubjectTypes, err := ts.GetFullRecursiveSubjectTypesForRelation(t.Context(), resource.ObjectType, resource.Relation) diff --git a/internal/services/integrationtesting/consistencytestutil/accessibilityset.go b/internal/services/integrationtesting/consistencytestutil/accessibilityset.go index 1316a1f8b..87248265f 100644 --- a/internal/services/integrationtesting/consistencytestutil/accessibilityset.go +++ b/internal/services/integrationtesting/consistencytestutil/accessibilityset.go @@ -107,7 +107,7 @@ func BuildAccessibilitySet(t *testing.T, ctx context.Context, populated *validat // NOTE: We only conduct checks here for the *defined* subjects from the relationships, // rather than every possible subject, as the latter would make the consistency test suite // VERY slow, due to the combinatorial size of all possible subjects. 
- headRevision, err := ds.HeadRevision(ctx) + headRevision, _, err := ds.HeadRevision(ctx) require.NoError(t, err) params, err := graph.NewDefaultDispatcherParametersForTesting() diff --git a/internal/services/integrationtesting/consistencytestutil/clusteranddata.go b/internal/services/integrationtesting/consistencytestutil/clusteranddata.go index 762a49a64..5b5a4116e 100644 --- a/internal/services/integrationtesting/consistencytestutil/clusteranddata.go +++ b/internal/services/integrationtesting/consistencytestutil/clusteranddata.go @@ -56,7 +56,7 @@ func BuildDataAndCreateClusterForTesting(t *testing.T, consistencyTestFilePath s dsCtx := datastoremw.ContextWithHandle(t.Context()) require.NoError(datastoremw.SetInContext(dsCtx, ds)) - res := schema.ResolverForDatastoreReader(ds.SnapshotReader(revision)) + res := schema.ResolverForDatastoreReader(ds.SnapshotReader(revision, datastore.NoSchemaHashForTesting)) ts := schema.NewTypeSystem(res) // Validate the type system for each namespace. 
diff --git a/internal/services/integrationtesting/query_plan_consistency_test.go b/internal/services/integrationtesting/query_plan_consistency_test.go index b6ce03d57..d1809a7a1 100644 --- a/internal/services/integrationtesting/query_plan_consistency_test.go +++ b/internal/services/integrationtesting/query_plan_consistency_test.go @@ -48,7 +48,7 @@ type queryPlanConsistencyHandle struct { func (q *queryPlanConsistencyHandle) buildContext(t *testing.T) *query.Context { return query.NewLocalContext(t.Context(), - query.WithReader(q.ds.SnapshotReader(q.revision)), + query.WithReader(q.ds.SnapshotReader(q.revision, datastore.NoSchemaHashForTesting)), query.WithCaveatRunner(caveats.NewCaveatRunner(caveattypes.Default.TypeSet)), query.WithTraceLogger(query.NewTraceLogger())) // Enable tracing for debugging } @@ -61,7 +61,7 @@ func runQueryPlanConsistencyForFile(t *testing.T, filePath string) { populated, _, err := validationfile.PopulateFromFiles(t.Context(), ds, caveattypes.Default.TypeSet, []string{filePath}) require.NoError(err) - headRevision, err := ds.HeadRevision(t.Context()) + headRevision, _, err := ds.HeadRevision(t.Context()) require.NoError(err) schemaView, err := schema.BuildSchemaFromDefinitions(populated.NamespaceDefinitions, populated.CaveatDefinitions) diff --git a/internal/services/steelthreadtesting/steelthread_test.go b/internal/services/steelthreadtesting/steelthread_test.go index aad978d72..5d4541364 100644 --- a/internal/services/steelthreadtesting/steelthread_test.go +++ b/internal/services/steelthreadtesting/steelthread_test.go @@ -23,6 +23,7 @@ import ( caveattypes "github.com/authzed/spicedb/pkg/caveats/types" dsconfig "github.com/authzed/spicedb/pkg/cmd/datastore" "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/datastore/options" "github.com/authzed/spicedb/pkg/validationfile" ) @@ -42,22 +43,35 @@ func TestNonMemdbSteelThreads(t *testing.T) { t.Skip("Skipping non-memdb steelthread tests in regenerate mode") } + 
schemaModes := []struct { + name string + mode options.SchemaMode + }{ + {"LegacySchema", options.SchemaModeReadLegacyWriteLegacy}, + {"NewSchema", options.SchemaModeReadNewWriteNew}, + } + for _, engineID := range datastore.SortedEngineIDs() { t.Run(engineID, func(t *testing.T) { rde := testdatastore.RunDatastoreEngine(t, engineID) - for _, tc := range steelThreadTestCases { - t.Run(tc.name, func(t *testing.T) { - ds := rde.NewDatastore(t, config.DatastoreConfigInitFunc(t, - dsconfig.WithWatchBufferLength(0), - dsconfig.WithGCWindow(time.Duration(90_000_000_000_000)), - dsconfig.WithRevisionQuantization(10), - dsconfig.WithMaxRetries(50), - dsconfig.WithExperimentalColumnOptimization(true), - dsconfig.WithWriteAcquisitionTimeout(5*time.Second))) - - ds = indexcheck.WrapWithIndexCheckingDatastoreProxyIfApplicable(ds) - runSteelThreadTest(t, tc, ds) + for _, sm := range schemaModes { + t.Run(sm.name, func(t *testing.T) { + for _, tc := range steelThreadTestCases { + t.Run(tc.name, func(t *testing.T) { + ds := rde.NewDatastore(t, config.DatastoreConfigInitFunc(t, + dsconfig.WithWatchBufferLength(0), + dsconfig.WithGCWindow(time.Duration(90_000_000_000_000)), + dsconfig.WithRevisionQuantization(10), + dsconfig.WithMaxRetries(50), + dsconfig.WithExperimentalColumnOptimization(true), + dsconfig.WithWriteAcquisitionTimeout(5*time.Second), + dsconfig.WithExperimentalSchemaMode(sm.mode))) + + ds = indexcheck.WrapWithIndexCheckingDatastoreProxyIfApplicable(ds) + runSteelThreadTest(t, tc, ds) + }) + } }) } }) @@ -119,3 +133,74 @@ func runSteelThreadTest(t *testing.T, tc steelThreadTestCase, ds datastore.Datas }) } } + +// Benchmarks to compare legacy vs new schema mode performance +// These benchmarks use PostgreSQL as the target datastore + +func BenchmarkSteelThreadSchemaMode(b *testing.B) { + // Skip if not in steelthread mode + if os.Getenv("RUN_STEELTHREAD_BENCHMARKS") != "true" { + b.Skip("Set RUN_STEELTHREAD_BENCHMARKS=true to run schema mode benchmarks") + } + + 
schemaModes := []struct { + name string + mode options.SchemaMode + }{ + {"LegacySchema", options.SchemaModeReadLegacyWriteLegacy}, + {"NewSchema", options.SchemaModeReadNewWriteNew}, + } + + for _, sm := range schemaModes { + b.Run(sm.name, func(b *testing.B) { + // Use PostgreSQL for benchmarking + rde := testdatastore.RunDatastoreEngine(b, "postgres") + + for _, tc := range steelThreadTestCases { + b.Run(tc.name, func(b *testing.B) { + ds := rde.NewDatastore(b, config.DatastoreConfigInitFunc(b, + dsconfig.WithWatchBufferLength(0), + dsconfig.WithGCWindow(time.Duration(90_000_000_000_000)), + dsconfig.WithRevisionQuantization(10), + dsconfig.WithMaxRetries(50), + dsconfig.WithExperimentalColumnOptimization(true), + dsconfig.WithWriteAcquisitionTimeout(5*time.Second), + dsconfig.WithExperimentalSchemaMode(sm.mode))) + + ctx := context.Background() + clientConn, cleanup, _, _ := testserver.NewTestServerWithConfigAndDatastore(require.New(b), 0, 0, false, + testserver.DefaultTestServerConfig, + ds, + func(ds datastore.Datastore, require *require.Assertions) (datastore.Datastore, datastore.Revision) { + // Load in the data once + _, rev, err := validationfile.PopulateFromFiles(ctx, ds, caveattypes.Default.TypeSet, []string{"testdata/" + tc.datafile}) + require.NoError(err) + return ds, rev + }) + b.Cleanup(cleanup) + + clients := stClients{ + PermissionsClient: v1.NewPermissionsServiceClient(clientConn), + SchemaClient: v1.NewSchemaServiceClient(clientConn), + } + + // Benchmark each operation + for _, operationInfo := range tc.operations { + b.Run(operationInfo.name, func(b *testing.B) { + handler, ok := operations[operationInfo.operationName] + require.True(b, ok, "operation not found: %s", operationInfo.name) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := handler(operationInfo.arguments, clients) + if err != nil { + b.Fatal(err) + } + } + }) + } + }) + } + }) + } +} diff --git a/internal/services/v1/bulkcheck.go b/internal/services/v1/bulkcheck.go 
index 0d103d007..cdd199d33 100644 --- a/internal/services/v1/bulkcheck.go +++ b/internal/services/v1/bulkcheck.go @@ -48,7 +48,7 @@ const maxBulkCheckCount = 10000 func (bc *bulkChecker) checkBulkPermissions(ctx context.Context, req *v1.CheckBulkPermissionsRequest) (*v1.CheckBulkPermissionsResponse, error) { telemetry.LogicalChecks.Add(float64(len(req.Items))) - atRevision, checkedAt, err := consistency.RevisionFromContext(ctx) + atRevision, schemaHash, checkedAt, err := consistency.RevisionAndSchemaHashFromContext(ctx) if err != nil { return nil, err } @@ -77,6 +77,7 @@ func (bc *bulkChecker) checkBulkPermissions(ctx context.Context, req *v1.CheckBu // the dispatching system already internally supports this kind of batching for performance. groupedItems, err := groupItems(ctx, groupingParameters{ atRevision: atRevision, + schemaHash: schemaHash, maxCaveatContextSize: bc.maxCaveatContextSize, maximumAPIDepth: bc.maxAPIDepth, withTracing: req.WithTracing, @@ -163,7 +164,7 @@ func (bc *bulkChecker) checkBulkPermissions(ctx context.Context, req *v1.CheckBu bulkResponseMutex.Lock() defer bulkResponseMutex.Unlock() - ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision) + ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision, schemaHash) schemaText := "" if len(debugInfos) > 0 { @@ -248,7 +249,7 @@ func (bc *bulkChecker) checkBulkPermissions(ctx context.Context, req *v1.CheckBu tr.Add(func(ctx context.Context) error { startTime := time.Now() - ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision) + ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision, schemaHash) // Ensure the check namespaces and relations are valid. 
err := namespace.CheckNamespaceAndRelations(ctx, diff --git a/internal/services/v1/experimental.go b/internal/services/v1/experimental.go index ef7bb84bd..2be19b568 100644 --- a/internal/services/v1/experimental.go +++ b/internal/services/v1/experimental.go @@ -327,34 +327,39 @@ func (es *experimentalServer) BulkExportRelationships( ctx := resp.Context() perfinsights.SetInContext(ctx, perfinsights.NoLabels) - atRevision, _, err := consistency.RevisionFromContext(ctx) + atRevision, schemaHash, _, err := consistency.RevisionAndSchemaHashFromContext(ctx) if err != nil { return shared.RewriteErrorWithoutConfig(ctx, err) } - return BulkExport(ctx, datastoremw.MustFromContext(ctx), es.maxBatchSize, req, atRevision, resp.Send) + return BulkExport(ctx, datastoremw.MustFromContext(ctx), es.maxBatchSize, req, atRevision, schemaHash, resp.Send) } // BulkExport implements the BulkExportRelationships API functionality. Given a datastore.Datastore, it will // export stream via the sender all relationships matched by the incoming request. -// If no cursor is provided, it will fallback to the provided revision. -func BulkExport(ctx context.Context, ds datastore.ReadOnlyDatastore, batchSize uint64, req *v1.BulkExportRelationshipsRequest, fallbackRevision datastore.Revision, sender func(response *v1.BulkExportRelationshipsResponse) error) error { +// If no cursor is provided, it will fallback to the provided revision and schema hash. 
+func BulkExport(ctx context.Context, ds datastore.ReadOnlyDatastore, batchSize uint64, req *v1.BulkExportRelationshipsRequest, fallbackRevision datastore.Revision, fallbackSchemaHash datastore.SchemaHash, sender func(response *v1.BulkExportRelationshipsResponse) error) error { if req.OptionalLimit > 0 && uint64(req.OptionalLimit) > batchSize { return shared.RewriteErrorWithoutConfig(ctx, NewExceedsMaximumLimitErr(uint64(req.OptionalLimit), batchSize)) } atRevision := fallbackRevision + schemaHash := fallbackSchemaHash var curNamespace string var cur dsoptions.Cursor if req.OptionalCursor != nil { var err error - atRevision, curNamespace, cur, err = decodeCursor(ds, req.OptionalCursor) + atRevision, schemaHash, curNamespace, cur, err = decodeCursor(ds, req.OptionalCursor) if err != nil { return shared.RewriteErrorWithoutConfig(ctx, err) } + // If cursor has empty schema hash (legacy cursor), skip cache + if schemaHash == "" { + schemaHash = datastore.NoSchemaHashForLegacyCursor + } } - reader := ds.SnapshotReader(atRevision) + reader := ds.SnapshotReader(atRevision, schemaHash) namespaces, err := reader.LegacyListAllNamespaces(ctx) if err != nil { @@ -569,7 +574,7 @@ func (es *experimentalServer) ExperimentalReflectSchema(ctx context.Context, req func (es *experimentalServer) ExperimentalDiffSchema(ctx context.Context, req *v1.ExperimentalDiffSchemaRequest) (*v1.ExperimentalDiffSchemaResponse, error) { perfinsights.SetInContext(ctx, perfinsights.NoLabels) - atRevision, _, err := consistency.RevisionFromContext(ctx) + atRevision, _, _, err := consistency.RevisionAndSchemaHashFromContext(ctx) if err != nil { return nil, err } @@ -596,12 +601,12 @@ func (es *experimentalServer) ExperimentalComputablePermissions(ctx context.Cont } }) - atRevision, revisionReadAt, err := consistency.RevisionFromContext(ctx) + atRevision, schemaHash, revisionReadAt, err := consistency.RevisionAndSchemaHashFromContext(ctx) if err != nil { return nil, shared.RewriteErrorWithoutConfig(ctx, 
err) } - ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision) + ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision, schemaHash) ts := schema.NewTypeSystem(schema.ResolverForDatastoreReader(ds)) vdef, err := ts.GetValidatedDefinition(ctx, req.DefinitionName) if err != nil { @@ -679,12 +684,12 @@ func (es *experimentalServer) ExperimentalDependentRelations(ctx context.Context } }) - atRevision, revisionReadAt, err := consistency.RevisionFromContext(ctx) + atRevision, schemaHash, revisionReadAt, err := consistency.RevisionAndSchemaHashFromContext(ctx) if err != nil { return nil, shared.RewriteErrorWithoutConfig(ctx, err) } - ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision) + ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision, schemaHash) ts := schema.NewTypeSystem(schema.ResolverForDatastoreReader(ds)) vdef, err := ts.GetValidatedDefinition(ctx, req.DefinitionName) if err != nil { @@ -804,12 +809,12 @@ func (es *experimentalServer) ExperimentalCountRelationships(ctx context.Context } ds := datastoremw.MustFromContext(ctx) - headRev, err := ds.HeadRevision(ctx) + headRev, schemaHash, err := ds.HeadRevision(ctx) if err != nil { return nil, shared.RewriteErrorWithoutConfig(ctx, err) } - snapshotReader := ds.SnapshotReader(headRev) + snapshotReader := ds.SnapshotReader(headRev, schemaHash) count, err := snapshotReader.CountRelationships(ctx, req.Name) if err != nil { return nil, shared.RewriteErrorWithoutConfig(ctx, err) @@ -859,30 +864,33 @@ func queryForEach( return cursor, nil } -func decodeCursor(ds datastore.ReadOnlyDatastore, encoded *v1.Cursor) (datastore.Revision, string, dsoptions.Cursor, error) { +func decodeCursor(ds datastore.ReadOnlyDatastore, encoded *v1.Cursor) (datastore.Revision, datastore.SchemaHash, string, dsoptions.Cursor, error) { decoded, err := cursor.Decode(encoded) if err != nil { - return datastore.NoRevision, "", nil, err + return datastore.NoRevision, "", "", nil, err } if 
decoded.GetV1() == nil { - return datastore.NoRevision, "", nil, errors.New("malformed cursor: no V1 in OneOf") + return datastore.NoRevision, "", "", nil, errors.New("malformed cursor: no V1 in OneOf") } if len(decoded.GetV1().Sections) != 2 { - return datastore.NoRevision, "", nil, errors.New("malformed cursor: wrong number of components") + return datastore.NoRevision, "", "", nil, errors.New("malformed cursor: wrong number of components") } atRevision, err := ds.RevisionFromString(decoded.GetV1().Revision) if err != nil { - return datastore.NoRevision, "", nil, err + return datastore.NoRevision, "", "", nil, err } cur, err := tuple.Parse(decoded.GetV1().GetSections()[1]) if err != nil { - return datastore.NoRevision, "", nil, fmt.Errorf("malformed cursor: invalid encoded relation tuple: %w", err) + return datastore.NoRevision, "", "", nil, fmt.Errorf("malformed cursor: invalid encoded relation tuple: %w", err) } - // Returns the current namespace and the cursor. - return atRevision, decoded.GetV1().GetSections()[0], dsoptions.ToCursor(cur), nil + // Extract schema hash from cursor (could be empty for legacy cursors) + schemaHash := datastore.SchemaHash(decoded.GetV1().SchemaHash) + + // Returns the current namespace, schema hash, and cursor. 
+ return atRevision, schemaHash, decoded.GetV1().GetSections()[0], dsoptions.ToCursor(cur), nil } diff --git a/internal/services/v1/grouping.go b/internal/services/v1/grouping.go index 99b681d2e..0cdb92baa 100644 --- a/internal/services/v1/grouping.go +++ b/internal/services/v1/grouping.go @@ -17,6 +17,7 @@ type groupedCheckParameters struct { type groupingParameters struct { atRevision datastore.Revision + schemaHash datastore.SchemaHash maximumAPIDepth uint32 maxCaveatContextSize int withTracing bool @@ -65,6 +66,7 @@ func checkParametersFromCheckBulkPermissionsRequestItem( ResourceType: tuple.RR(bc.Resource.ObjectType, bc.Permission), Subject: tuple.ONR(bc.Subject.Object.ObjectType, bc.Subject.Object.ObjectId, normalizeSubjectRelation(bc.Subject)), CaveatContext: caveatContext, + SchemaHash: params.schemaHash, AtRevision: params.atRevision, MaximumDepth: params.maximumAPIDepth, DebugOption: debugOption, diff --git a/internal/services/v1/permissions.go b/internal/services/v1/permissions.go index 8c8f3e73b..fdefcf185 100644 --- a/internal/services/v1/permissions.go +++ b/internal/services/v1/permissions.go @@ -74,12 +74,12 @@ func (ps *permissionServer) CheckPermission(ctx context.Context, req *v1.CheckPe return ps.checkPermissionWithQueryPlan(ctx, req) } - atRevision, checkedAt, err := consistency.RevisionFromContext(ctx) + atRevision, schemaHash, checkedAt, err := consistency.RevisionAndSchemaHashFromContext(ctx) if err != nil { return nil, ps.rewriteError(ctx, err) } - ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision) + ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision, schemaHash) caveatContext, err := GetCaveatContext(ctx, req.Context, ps.config.MaxCaveatContextSize) if err != nil { @@ -122,6 +122,7 @@ func (ps *permissionServer) CheckPermission(ctx context.Context, req *v1.CheckPe Subject: tuple.ONR(req.Subject.Object.ObjectType, req.Subject.Object.ObjectId, normalizeSubjectRelation(req.Subject)), CaveatContext: caveatContext, 
AtRevision: atRevision, + SchemaHash: schemaHash, MaximumDepth: ps.config.MaximumAPIDepth, DebugOption: debugOption, }, @@ -241,12 +242,12 @@ func (ps *permissionServer) ExpandPermissionTree(ctx context.Context, req *v1.Ex telemetry.LogicalChecks.Inc() - atRevision, expandedAt, err := consistency.RevisionFromContext(ctx) + atRevision, schemaHash, expandedAt, err := consistency.RevisionAndSchemaHashFromContext(ctx) if err != nil { return nil, ps.rewriteError(ctx, err) } - ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision) + ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision, schemaHash) err = namespace.CheckNamespaceAndRelation(ctx, req.Resource.ObjectType, req.Permission, false, ds) if err != nil { @@ -479,12 +480,12 @@ func (ps *permissionServer) lookupResources3(req *v1.LookupResourcesRequest, res ctx := resp.Context() - atRevision, revisionReadAt, err := consistency.RevisionFromContext(ctx) + atRevision, schemaHash, revisionReadAt, err := consistency.RevisionAndSchemaHashFromContext(ctx) if err != nil { return ps.rewriteError(ctx, err) } - ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision) + ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision, schemaHash) if err := namespace.CheckNamespaceAndRelations(ctx, []namespace.TypeAndRelationToCheck{ @@ -554,7 +555,7 @@ func (ps *permissionServer) lookupResources3(req *v1.LookupResourcesRequest, res if len(item.AfterResponseCursorSections) > 0 { currentCursor = item.AfterResponseCursorSections - ec, err := cursor.EncodeFromDispatchCursorSections(currentCursor, lrRequestHash, atRevision, map[string]string{ + ec, err := cursor.EncodeFromDispatchCursorSections(currentCursor, lrRequestHash, atRevision, schemaHash, map[string]string{ lrv3CursorFlag: "1", }) if err != nil { @@ -625,12 +626,12 @@ func (ps *permissionServer) lookupResources2(req *v1.LookupResourcesRequest, res ctx := resp.Context() - atRevision, revisionReadAt, err := consistency.RevisionFromContext(ctx) + 
atRevision, schemaHash, revisionReadAt, err := consistency.RevisionAndSchemaHashFromContext(ctx) if err != nil { return ps.rewriteError(ctx, err) } - ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision) + ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision, schemaHash) if err := namespace.CheckNamespaceAndRelations(ctx, []namespace.TypeAndRelationToCheck{ @@ -700,7 +701,7 @@ func (ps *permissionServer) lookupResources2(req *v1.LookupResourcesRequest, res alreadyPublishedPermissionedResourceIds[found.ResourceId] = struct{}{} } - encodedCursor, err := cursor.EncodeFromDispatchCursor(result.AfterResponseCursor, lrRequestHash, atRevision, map[string]string{ + encodedCursor, err := cursor.EncodeFromDispatchCursor(result.AfterResponseCursor, lrRequestHash, atRevision, schemaHash, map[string]string{ lrv2CursorFlag: "1", }) if err != nil { @@ -776,12 +777,12 @@ func (ps *permissionServer) LookupSubjects(req *v1.LookupSubjectsRequest, resp v return ps.rewriteError(ctx, status.Errorf(codes.Unimplemented, "concrete limit is not yet supported")) } - atRevision, revisionReadAt, err := consistency.RevisionFromContext(ctx) + atRevision, schemaHash, revisionReadAt, err := consistency.RevisionAndSchemaHashFromContext(ctx) if err != nil { return ps.rewriteError(ctx, err) } - ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision) + ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision, schemaHash) caveatContext, err := GetCaveatContext(ctx, req.Context, ps.config.MaxCaveatContextSize) if err != nil { @@ -1136,34 +1137,39 @@ func (ps *permissionServer) ExportBulkRelationships( return labelsForFilter(req.OptionalRelationshipFilter) }) - atRevision, _, err := consistency.RevisionFromContext(ctx) + atRevision, schemaHash, _, err := consistency.RevisionAndSchemaHashFromContext(ctx) if err != nil { return shared.RewriteErrorWithoutConfig(ctx, err) } - return ExportBulk(ctx, datastoremw.MustFromContext(ctx), 
uint64(ps.config.MaxBulkExportRelationshipsLimit), req, atRevision, resp.Send) + return ExportBulk(ctx, datastoremw.MustFromContext(ctx), uint64(ps.config.MaxBulkExportRelationshipsLimit), req, atRevision, schemaHash, resp.Send) } // ExportBulk implements the ExportBulkRelationships API functionality. Given a datastore.Datastore, it will // export stream via the sender all relationships matched by the incoming request. -// If no cursor is provided, it will fallback to the provided revision. -func ExportBulk(ctx context.Context, ds datastore.Datastore, batchSize uint64, req *v1.ExportBulkRelationshipsRequest, fallbackRevision datastore.Revision, sender func(response *v1.ExportBulkRelationshipsResponse) error) error { +// If no cursor is provided, it will fallback to the provided revision and schema hash. +func ExportBulk(ctx context.Context, ds datastore.Datastore, batchSize uint64, req *v1.ExportBulkRelationshipsRequest, fallbackRevision datastore.Revision, fallbackSchemaHash datastore.SchemaHash, sender func(response *v1.ExportBulkRelationshipsResponse) error) error { if req.OptionalLimit > 0 && uint64(req.OptionalLimit) > batchSize { return shared.RewriteErrorWithoutConfig(ctx, NewExceedsMaximumLimitErr(uint64(req.OptionalLimit), batchSize)) } atRevision := fallbackRevision + schemaHash := fallbackSchemaHash var curNamespace string var cur dsoptions.Cursor if req.OptionalCursor != nil { var err error - atRevision, curNamespace, cur, err = decodeCursor(ds, req.OptionalCursor) + atRevision, schemaHash, curNamespace, cur, err = decodeCursor(ds, req.OptionalCursor) if err != nil { return shared.RewriteErrorWithoutConfig(ctx, err) } + // If cursor has empty schema hash (legacy cursor), skip cache + if schemaHash == "" { + schemaHash = datastore.NoSchemaHashForLegacyCursor + } } - reader := ds.SnapshotReader(atRevision) + reader := ds.SnapshotReader(atRevision, schemaHash) namespaces, err := reader.LegacyListAllNamespaces(ctx) if err != nil { diff --git 
a/internal/services/v1/permissions_queryplan.go b/internal/services/v1/permissions_queryplan.go index e5c92fc90..9648da03a 100644 --- a/internal/services/v1/permissions_queryplan.go +++ b/internal/services/v1/permissions_queryplan.go @@ -16,13 +16,13 @@ import ( // checkPermissionWithQueryPlan executes a permission check using the query plan API. // This builds an iterator tree from the schema and executes it against the datastore. func (ps *permissionServer) checkPermissionWithQueryPlan(ctx context.Context, req *v1.CheckPermissionRequest) (*v1.CheckPermissionResponse, error) { - atRevision, checkedAt, err := consistency.RevisionFromContext(ctx) + atRevision, schemaHash, checkedAt, err := consistency.RevisionAndSchemaHashFromContext(ctx) if err != nil { return nil, ps.rewriteError(ctx, err) } ds := datastoremw.MustFromContext(ctx) - reader := ds.SnapshotReader(atRevision) + reader := ds.SnapshotReader(atRevision, schemaHash) // Load all namespace and caveat definitions to build the schema // TODO: Better schema caching diff --git a/internal/services/v1/reflectionutil.go b/internal/services/v1/reflectionutil.go index 0d2b7cc69..b40f82ad8 100644 --- a/internal/services/v1/reflectionutil.go +++ b/internal/services/v1/reflectionutil.go @@ -16,12 +16,12 @@ import ( func loadCurrentSchema(ctx context.Context) (*diff.DiffableSchema, datastore.Revision, error) { ds := datastoremw.MustFromContext(ctx) - atRevision, _, err := consistency.RevisionFromContext(ctx) + atRevision, schemaHash, _, err := consistency.RevisionAndSchemaHashFromContext(ctx) if err != nil { return nil, nil, err } - reader := ds.SnapshotReader(atRevision) + reader := ds.SnapshotReader(atRevision, schemaHash) namespacesAndRevs, err := reader.LegacyListAllNamespaces(ctx) if err != nil { diff --git a/internal/services/v1/relationships.go b/internal/services/v1/relationships.go index 030742cd9..540d73b28 100644 --- a/internal/services/v1/relationships.go +++ b/internal/services/v1/relationships.go @@ -195,12 
+195,12 @@ func (ps *permissionServer) ReadRelationships(req *v1.ReadRelationshipsRequest, } ctx := resp.Context() - atRevision, revisionReadAt, err := consistency.RevisionFromContext(ctx) + atRevision, schemaHash, revisionReadAt, err := consistency.RevisionAndSchemaHashFromContext(ctx) if err != nil { return ps.rewriteError(ctx, err) } - ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision) + ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision, schemaHash) if err := validateRelationshipsFilter(ctx, req.RelationshipFilter, ds); err != nil { return ps.rewriteError(ctx, err) } @@ -298,7 +298,7 @@ func (ps *permissionServer) ReadRelationships(req *v1.ReadRelationshipsRequest, } dispatchCursor.Sections[0] = tuple.StringWithoutCaveatOrExpiration(rel) - encodedCursor, err := cursor.EncodeFromDispatchCursor(dispatchCursor, rrRequestHash, atRevision, nil) + encodedCursor, err := cursor.EncodeFromDispatchCursor(dispatchCursor, rrRequestHash, atRevision, schemaHash, nil) if err != nil { return ps.rewriteError(ctx, err) } diff --git a/internal/services/v1/relationships_test.go b/internal/services/v1/relationships_test.go index 4398e064b..acd19cbe6 100644 --- a/internal/services/v1/relationships_test.go +++ b/internal/services/v1/relationships_test.go @@ -1365,7 +1365,7 @@ func TestDeleteRelationshipsBeyondLimitPartial(t *testing.T) { for i := 0; i < 10; i++ { iterations++ - headRev, err := ds.HeadRevision(t.Context()) + headRev, _, err := ds.HeadRevision(t.Context()) require.NoError(err) beforeDelete := readOfType(require, "document", client, zedtoken.MustNewFromRevisionForTesting(headRev)) @@ -1379,7 +1379,7 @@ func TestDeleteRelationshipsBeyondLimitPartial(t *testing.T) { }) require.NoError(err) - headRev, err = ds.HeadRevision(context.Background()) + headRev, _, err = ds.HeadRevision(context.Background()) require.NoError(err) afterDelete := readOfType(require, "document", client, zedtoken.MustNewFromRevisionForTesting(headRev)) diff --git 
a/internal/services/v1/schema.go b/internal/services/v1/schema.go index 75aa488ba..6b76888dc 100644 --- a/internal/services/v1/schema.go +++ b/internal/services/v1/schema.go @@ -82,12 +82,12 @@ func (ss *schemaServer) ReadSchema(ctx context.Context, _ *v1.ReadSchemaRequest) // Schema is always read from the head revision. ds := datastoremw.MustFromContext(ctx) - headRevision, err := ds.HeadRevision(ctx) + headRevision, schemaHash, err := ds.HeadRevision(ctx) if err != nil { return nil, ss.rewriteError(ctx, err) } - reader := ds.SnapshotReader(headRevision) + reader := ds.SnapshotReader(headRevision, schemaHash) schemaReader, err := reader.SchemaReader() if err != nil { @@ -233,7 +233,7 @@ func (ss *schemaServer) ReflectSchema(ctx context.Context, req *v1.ReflectSchema func (ss *schemaServer) DiffSchema(ctx context.Context, req *v1.DiffSchemaRequest) (*v1.DiffSchemaResponse, error) { perfinsights.SetInContext(ctx, perfinsights.NoLabels) - atRevision, _, err := consistency.RevisionFromContext(ctx) + atRevision, _, _, err := consistency.RevisionAndSchemaHashFromContext(ctx) if err != nil { return nil, err } @@ -260,12 +260,12 @@ func (ss *schemaServer) ComputablePermissions(ctx context.Context, req *v1.Compu } }) - atRevision, revisionReadAt, err := consistency.RevisionFromContext(ctx) + atRevision, schemaHash, revisionReadAt, err := consistency.RevisionAndSchemaHashFromContext(ctx) if err != nil { return nil, shared.RewriteErrorWithoutConfig(ctx, err) } - ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision) + ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision, schemaHash) ts := schema.NewTypeSystem(schema.ResolverForDatastoreReader(ds)) vdef, err := ts.GetValidatedDefinition(ctx, req.DefinitionName) if err != nil { @@ -347,12 +347,12 @@ func (ss *schemaServer) DependentRelations(ctx context.Context, req *v1.Dependen } }) - atRevision, revisionReadAt, err := consistency.RevisionFromContext(ctx) + atRevision, schemaHash, revisionReadAt, err 
:= consistency.RevisionAndSchemaHashFromContext(ctx) if err != nil { return nil, shared.RewriteErrorWithoutConfig(ctx, err) } - ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision) + ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision, schemaHash) ts := schema.NewTypeSystem(schema.ResolverForDatastoreReader(ds)) vdef, err := ts.GetValidatedDefinition(ctx, req.DefinitionName) if err != nil { diff --git a/internal/services/v1/schema_test.go b/internal/services/v1/schema_test.go index ed6337e84..2b8e16e37 100644 --- a/internal/services/v1/schema_test.go +++ b/internal/services/v1/schema_test.go @@ -13,6 +13,7 @@ import ( "github.com/authzed/spicedb/internal/datastore/memdb" tf "github.com/authzed/spicedb/internal/testfixtures" "github.com/authzed/spicedb/internal/testserver" + "github.com/authzed/spicedb/pkg/datastore" core "github.com/authzed/spicedb/pkg/proto/core/v1" "github.com/authzed/spicedb/pkg/spiceerrors" "github.com/authzed/spicedb/pkg/testutil" @@ -550,10 +551,10 @@ func TestSchemaUnchangedNamespaces(t *testing.T) { require.NoError(t, err) // Ensure the `user` definition was not modified. 
- rev, err := ds.HeadRevision(t.Context()) + rev, _, err := ds.HeadRevision(t.Context()) require.NoError(t, err) - reader := ds.SnapshotReader(rev) + reader := ds.SnapshotReader(rev, datastore.NoSchemaHashForTesting) _, userRevision, err := reader.LegacyReadNamespaceByName(t.Context(), "user") require.NoError(t, err) diff --git a/internal/services/v1/watch.go b/internal/services/v1/watch.go index 38148a864..18f9bb20c 100644 --- a/internal/services/v1/watch.go +++ b/internal/services/v1/watch.go @@ -55,6 +55,7 @@ func (ws *watchServer) Watch(req *v1.WatchRequest, stream v1.WatchService_WatchS ds := datastoremw.MustFromContext(ctx) var afterRevision datastore.Revision + var schemaHash datastore.SchemaHash if req.OptionalStartCursor != nil && req.OptionalStartCursor.Token != "" { decodedRevision, tokenStatus, err := zedtoken.DecodeRevision(req.OptionalStartCursor, ds) if err != nil { @@ -71,15 +72,17 @@ func (ws *watchServer) Watch(req *v1.WatchRequest, stream v1.WatchService_WatchS } afterRevision = decodedRevision + // Schema hash not available from decoded token - use sentinel to load on demand + schemaHash = datastore.NoSchemaHashForWatch } else { var err error - afterRevision, err = ds.OptimizedRevision(ctx) + afterRevision, schemaHash, err = ds.OptimizedRevision(ctx) if err != nil { return status.Errorf(codes.Unavailable, "failed to start watch: %s", err) } } - reader := ds.SnapshotReader(afterRevision) + reader := ds.SnapshotReader(afterRevision, schemaHash) filters, err := buildRelationshipFilters(req, stream, reader, ws, ctx) if err != nil { diff --git a/internal/testfixtures/datastore.go b/internal/testfixtures/datastore.go index fa3265907..e80d7b243 100644 --- a/internal/testfixtures/datastore.go +++ b/internal/testfixtures/datastore.go @@ -147,7 +147,7 @@ var StandardCaveatedRelationships = []string{ // EmptyDatastore returns an empty datastore for testing. 
func EmptyDatastore(ds datastore.Datastore, require *require.Assertions) (datastore.Datastore, datastore.Revision) { - rev, err := ds.HeadRevision(context.Background()) + rev, _, err := ds.HeadRevision(context.Background()) require.NoError(err) return ds, rev } @@ -299,7 +299,7 @@ func (tc RelationshipChecker) ExactRelationshipIterator(ctx context.Context, rel dsFilter, err := datastore.RelationshipsFilterFromPublicFilter(filter) tc.Require.NoError(err) - iter, err := tc.DS.SnapshotReader(rev).QueryRelationships(ctx, dsFilter, options.WithQueryShape(queryshape.Varying)) + iter, err := tc.DS.SnapshotReader(rev, datastore.NoSchemaHashForTesting).QueryRelationships(ctx, dsFilter, options.WithQueryShape(queryshape.Varying)) tc.Require.NoError(err) return iter } diff --git a/internal/testfixtures/validating.go b/internal/testfixtures/validating.go index 7b8317eb0..2c380d681 100644 --- a/internal/testfixtures/validating.go +++ b/internal/testfixtures/validating.go @@ -25,8 +25,8 @@ func NewValidatingDatastore(delegate datastore.Datastore) datastore.Datastore { return validatingDatastore{Datastore: delegate} } -func (vd validatingDatastore) SnapshotReader(revision datastore.Revision) datastore.Reader { - return validatingSnapshotReader{vd.Datastore.SnapshotReader(revision)} +func (vd validatingDatastore) SnapshotReader(revision datastore.Revision, schemaHash datastore.SchemaHash) datastore.Reader { + return validatingSnapshotReader{vd.Datastore.SnapshotReader(revision, schemaHash)} } func (vd validatingDatastore) ReadWriteTx( diff --git a/internal/testserver/datastore/postgres.go b/internal/testserver/datastore/postgres.go index 3f3fdfd49..7c4e5229c 100644 --- a/internal/testserver/datastore/postgres.go +++ b/internal/testserver/datastore/postgres.go @@ -3,6 +3,7 @@ package datastore import ( "context" "fmt" + "sync" "testing" "time" @@ -38,8 +39,9 @@ type postgresTester struct { creds string targetMigration string pgbouncerProxy *container - pool *dockertest.Pool + pool 
*dockertest.Pool // GUARDED_BY(poolMutex) useContainerHostname bool + poolMutex sync.Mutex // protects concurrent access to pool.Retry() } // RunPostgresForTesting returns a RunningEngineForTest for postgres @@ -220,6 +222,9 @@ func (b *postgresTester) runPgbouncerForTesting(t testing.TB, pool *dockertest.P func (b *postgresTester) initializeHostConnection(t testing.TB) (conn *pgx.Conn) { hostname, port := b.getHostHostnameAndPort() uri := fmt.Sprintf("postgresql://%s@%s:%s/?sslmode=disable", b.creds, hostname, port) + + // Lock to prevent concurrent access to pool.Retry() which has internal state + b.poolMutex.Lock() err := b.pool.Retry(func() error { var err error ctx, cancelConnect := context.WithTimeout(context.Background(), dockerBootTimeout) @@ -230,6 +235,8 @@ func (b *postgresTester) initializeHostConnection(t testing.TB) (conn *pgx.Conn) } return nil }) + b.poolMutex.Unlock() + require.NoError(t, err) return conn } diff --git a/magefiles/lint.go b/magefiles/lint.go index 0308c58ff..80bfb5ff9 100644 --- a/magefiles/lint.go +++ b/magefiles/lint.go @@ -107,6 +107,9 @@ func (Lint) Analyzers() error { // Skip our dispatch codec logic that explicitly calls MarshalVT with proto.Marshal as a fallback // Skip our internal telemetry reporter which uses a prometheus proto definition that we don't control "-protomarshalcheck.skip-pkg=github.com/authzed/spicedb/pkg/proto/dispatch/v1,github.com/authzed/spicedb/internal/telemetry", + "-testsentinelcheck", + // Skip test utility packages, development packages, and internal datastore common (which defines bypass sentinels) + "-testsentinelcheck.skip-pkg=github.com/authzed/spicedb/pkg/datastore/test,github.com/authzed/spicedb/internal/services/integrationtesting/consistencytestutil,github.com/authzed/spicedb/internal/testfixtures,github.com/authzed/spicedb/pkg/development,github.com/authzed/spicedb/internal/datastore/common", "github.com/authzed/spicedb/...", ) } diff --git a/pkg/cmd/datastore/datastore.go 
b/pkg/cmd/datastore/datastore.go index 51d09f3ac..9c4620845 100644 --- a/pkg/cmd/datastore/datastore.go +++ b/pkg/cmd/datastore/datastore.go @@ -23,6 +23,7 @@ import ( "github.com/authzed/spicedb/internal/sharederrors" caveattypes "github.com/authzed/spicedb/pkg/caveats/types" "github.com/authzed/spicedb/pkg/datastore" + dsoptions "github.com/authzed/spicedb/pkg/datastore/options" "github.com/authzed/spicedb/pkg/validationfile" ) @@ -182,8 +183,13 @@ type Config struct { AllowedMigrations []string `debugmap:"visible"` // Experimental - ExperimentalColumnOptimization bool `debugmap:"visible"` - EnableRevisionHeartbeat bool `debugmap:"visible"` + ExperimentalColumnOptimization bool `debugmap:"visible"` + EnableRevisionHeartbeat bool `debugmap:"visible"` + ExperimentalSchemaMode dsoptions.SchemaMode `debugmap:"visible"` + + // Internal - used for flag parsing + experimentalSchemaModeString string `debugmap:"hidden"` + schemaCacheOptions dsoptions.SchemaCacheOptions `debugmap:"hidden"` } //go:generate go run github.com/ecordell/optgen -sensitive-field-name-matches uri,secure -output zz_generated.relintegritykey.options.go . 
RelIntegrityKey @@ -336,6 +342,7 @@ func RegisterDatastoreFlagsWithPrefix(flagSet *pflag.FlagSet, prefix string, opt } flagSet.BoolVar(&opts.ExperimentalColumnOptimization, flagName("datastore-experimental-column-optimization"), true, "enable experimental column optimization") + flagSet.StringVar(&opts.experimentalSchemaModeString, flagName("datastore-experimental-schema-mode"), "read-legacy-write-legacy", `experimental schema mode ("read-legacy-write-legacy", "read-legacy-write-both", "read-new-write-both", "read-new-write-new")`) return nil } @@ -387,6 +394,14 @@ func DefaultDatastoreConfig() *Config { IncludeQueryParametersInTraces: false, WriteAcquisitionTimeout: 30 * time.Millisecond, CaveatTypeSet: caveattypes.Default.TypeSet, + experimentalSchemaModeString: "read-legacy-write-legacy", + } +} + +// WithSchemaCacheOptions sets the schema cache options for the datastore. +func WithSchemaCacheOptions(cacheOptions dsoptions.SchemaCacheOptions) ConfigOption { + return func(c *Config) { + c.schemaCacheOptions = cacheOptions } } @@ -397,6 +412,15 @@ func NewDatastore(ctx context.Context, options ...ConfigOption) (datastore.Datas o(opts) } + // Parse the experimental schema mode string if provided + if opts.experimentalSchemaModeString != "" { + mode, err := dsoptions.ParseSchemaMode(opts.experimentalSchemaModeString) + if err != nil { + return nil, err + } + opts.ExperimentalSchemaMode = mode + } + if (opts.Engine == PostgresEngine || opts.Engine == MySQLEngine) && opts.FollowerReadDelay == DefaultFollowerReadDelay { // Set the default follower read delay for postgres and mysql to 0 - // this should only be set if read replicas are used. 
@@ -423,12 +447,12 @@ func NewDatastore(ctx context.Context, options ...ConfigOption) (datastore.Datas ctx, cancel := context.WithTimeout(ctx, opts.BootstrapTimeout) defer cancel() - revision, err := ds.HeadRevision(ctx) + revision, schemaHash, err := ds.HeadRevision(ctx) if err != nil { return nil, fmt.Errorf("unable to determine datastore state before applying bootstrap data: %w", err) } - nsDefs, err := ds.SnapshotReader(revision).LegacyListAllNamespaces(ctx) + nsDefs, err := ds.SnapshotReader(revision, schemaHash).LegacyListAllNamespaces(ctx) if err != nil { return nil, fmt.Errorf("unable to determine datastore state before applying bootstrap data: %w", err) } @@ -578,6 +602,8 @@ func newCRDBDatastore(ctx context.Context, opts Config) (datastore.Datastore, er crdb.WithColumnOptimization(opts.ExperimentalColumnOptimization), crdb.IncludeQueryParametersInTraces(opts.IncludeQueryParametersInTraces), crdb.WithWatchDisabled(opts.DisableWatchSupport), + crdb.WithSchemaMode(opts.ExperimentalSchemaMode), + crdb.WithSchemaCacheOptions(opts.schemaCacheOptions), ) } @@ -626,6 +652,8 @@ func commonPostgresDatastoreOptions(opts Config) ([]postgres.Option, error) { postgres.WithColumnOptimization(opts.ExperimentalColumnOptimization), postgres.WatchChangeBufferMaximumSize(watchChangeBufferMaximumSize), postgres.IncludeQueryParametersInTraces(opts.IncludeQueryParametersInTraces), + postgres.WithSchemaMode(opts.ExperimentalSchemaMode), + postgres.WithSchemaCacheOptions(opts.schemaCacheOptions), }, nil } @@ -726,6 +754,8 @@ func newSpannerDatastore(ctx context.Context, opts Config) (datastore.Datastore, spanner.FilterMaximumIDCount(opts.FilterMaximumIDCount), spanner.WithColumnOptimization(opts.ExperimentalColumnOptimization), spanner.WithWatchDisabled(opts.DisableWatchSupport), + spanner.WithSchemaMode(opts.ExperimentalSchemaMode), + spanner.WithSchemaCacheOptions(opts.schemaCacheOptions), ) } @@ -779,6 +809,8 @@ func commonMySQLDatastoreOptions(opts Config) ([]mysql.Option, 
error) { mysql.FilterMaximumIDCount(opts.FilterMaximumIDCount), mysql.AllowedMigrations(opts.AllowedMigrations), mysql.WithColumnOptimization(opts.ExperimentalColumnOptimization), + mysql.WithSchemaMode(opts.ExperimentalSchemaMode), + mysql.WithSchemaCacheOptions(opts.schemaCacheOptions), }, nil } diff --git a/pkg/cmd/datastore/datastore_test.go b/pkg/cmd/datastore/datastore_test.go index 00cf657b9..044940846 100644 --- a/pkg/cmd/datastore/datastore_test.go +++ b/pkg/cmd/datastore/datastore_test.go @@ -6,6 +6,8 @@ import ( "github.com/spf13/pflag" "github.com/stretchr/testify/require" + + pkgdatastore "github.com/authzed/spicedb/pkg/datastore" ) func TestDefaults(t *testing.T) { @@ -27,10 +29,10 @@ func TestLoadDatastoreFromFileContents(t *testing.T) { ds.Close() }) - revision, err := ds.HeadRevision(ctx) + revision, _, err := ds.HeadRevision(ctx) require.NoError(t, err) - namespaces, err := ds.SnapshotReader(revision).LegacyListAllNamespaces(ctx) + namespaces, err := ds.SnapshotReader(revision, pkgdatastore.NoSchemaHashForTesting).LegacyListAllNamespaces(ctx) require.NoError(t, err) require.Len(t, namespaces, 1) require.Equal(t, "user", namespaces[0].Definition.Name) @@ -51,10 +53,10 @@ func TestLoadDatastoreFromFile(t *testing.T) { ds.Close() }) - revision, err := ds.HeadRevision(ctx) + revision, _, err := ds.HeadRevision(ctx) require.NoError(t, err) - namespaces, err := ds.SnapshotReader(revision).LegacyListAllNamespaces(ctx) + namespaces, err := ds.SnapshotReader(revision, pkgdatastore.NoSchemaHashForTesting).LegacyListAllNamespaces(ctx) require.NoError(t, err) require.Len(t, namespaces, 1) require.Equal(t, "user", namespaces[0].Definition.Name) @@ -93,10 +95,10 @@ relationships: |- ds.Close() }) - revision, err := ds.HeadRevision(ctx) + revision, _, err := ds.HeadRevision(ctx) require.NoError(t, err) - namespaces, err := ds.SnapshotReader(revision).LegacyListAllNamespaces(ctx) + namespaces, err := ds.SnapshotReader(revision, 
pkgdatastore.NoSchemaHashForTesting).LegacyListAllNamespaces(ctx) require.NoError(t, err) require.Len(t, namespaces, 2) require.Equal(t, "organization", namespaces[0].Definition.Name) @@ -115,10 +117,10 @@ func TestLoadDatastoreFromFileAndContents(t *testing.T) { WithEngine(MemoryEngine)) require.NoError(t, err) - revision, err := ds.HeadRevision(ctx) + revision, _, err := ds.HeadRevision(ctx) require.NoError(t, err) - namespaces, err := ds.SnapshotReader(revision).LegacyListAllNamespaces(ctx) + namespaces, err := ds.SnapshotReader(revision, pkgdatastore.NoSchemaHashForTesting).LegacyListAllNamespaces(ctx) require.NoError(t, err) require.Len(t, namespaces, 2) namespaceNames := []string{namespaces[0].Definition.Name, namespaces[1].Definition.Name} diff --git a/pkg/cmd/datastore/zz_generated.options.go b/pkg/cmd/datastore/zz_generated.options.go index a4b2e7bbf..926a5aa4b 100644 --- a/pkg/cmd/datastore/zz_generated.options.go +++ b/pkg/cmd/datastore/zz_generated.options.go @@ -4,6 +4,7 @@ package datastore import ( "fmt" types "github.com/authzed/spicedb/pkg/caveats/types" + options "github.com/authzed/spicedb/pkg/datastore/options" defaults "github.com/creasty/defaults" "time" ) @@ -88,6 +89,7 @@ func (c *Config) ToOption() ConfigOption { to.AllowedMigrations = c.AllowedMigrations to.ExperimentalColumnOptimization = c.ExperimentalColumnOptimization to.EnableRevisionHeartbeat = c.EnableRevisionHeartbeat + to.ExperimentalSchemaMode = c.ExperimentalSchemaMode } } @@ -218,6 +220,7 @@ func (c *Config) DebugMap() map[string]any { } debugMap["ExperimentalColumnOptimization"] = c.ExperimentalColumnOptimization debugMap["EnableRevisionHeartbeat"] = c.EnableRevisionHeartbeat + debugMap["ExperimentalSchemaMode"] = c.ExperimentalSchemaMode return debugMap } @@ -691,3 +694,10 @@ func WithEnableRevisionHeartbeat(enableRevisionHeartbeat bool) ConfigOption { c.EnableRevisionHeartbeat = enableRevisionHeartbeat } } + +// WithExperimentalSchemaMode returns an option that can set 
ExperimentalSchemaMode on a Config +func WithExperimentalSchemaMode(experimentalSchemaMode options.SchemaMode) ConfigOption { + return func(c *Config) { + c.ExperimentalSchemaMode = experimentalSchemaMode + } +} diff --git a/pkg/cmd/server/server.go b/pkg/cmd/server/server.go index 1599035aa..f67e5f067 100644 --- a/pkg/cmd/server/server.go +++ b/pkg/cmd/server/server.go @@ -45,6 +45,7 @@ import ( datastorecfg "github.com/authzed/spicedb/pkg/cmd/datastore" "github.com/authzed/spicedb/pkg/cmd/util" "github.com/authzed/spicedb/pkg/datastore" + dsoptions "github.com/authzed/spicedb/pkg/datastore/options" "github.com/authzed/spicedb/pkg/middleware/consistency" "github.com/authzed/spicedb/pkg/middleware/requestid" "github.com/authzed/spicedb/pkg/spiceerrors" @@ -114,6 +115,7 @@ type Config struct { DispatchCacheConfig CacheConfig `debugmap:"visible"` ClusterDispatchCacheConfig CacheConfig `debugmap:"visible"` LR3ResourceChunkCacheConfig CacheConfig `debugmap:"visible"` + SchemaCacheConfig CacheConfig `debugmap:"visible"` // API Behavior DisableV1SchemaAPI bool `debugmap:"visible"` @@ -195,12 +197,6 @@ func (c *Config) Complete(ctx context.Context) (RunnableServer, error) { log.Ctx(ctx).Trace().Msg("using preconfigured auth function") } - nscc, err := CompleteCache[cache.StringKey, schemacaching.CacheEntry](&c.NamespaceCacheConfig) - if err != nil { - return nil, fmt.Errorf("failed to create namespace cache: %w", err) - } - log.Ctx(ctx).Info().EmbedObject(nscc).Msg("configured namespace cache") - ds := c.Datastore if ds == nil { var err error @@ -210,6 +206,9 @@ func (c *Config) Complete(ctx context.Context) (RunnableServer, error) { // are at most the number of elements returned from a datastore query datastorecfg.WithFilterMaximumIDCount(c.DispatchChunkSize), datastorecfg.WithEnableRevisionHeartbeat(c.EnableRevisionHeartbeat), + datastorecfg.WithSchemaCacheOptions(dsoptions.SchemaCacheOptions{ + MaximumCacheEntries: 100, // Default cache size + }), ) if err != nil { 
return nil, spiceerrors.NewTerminationErrorBuilder(fmt.Errorf("failed to create datastore: %w", err)). @@ -219,6 +218,12 @@ func (c *Config) Complete(ctx context.Context) (RunnableServer, error) { } } + nscc, err := CompleteCache[cache.StringKey, schemacaching.CacheEntry](&c.NamespaceCacheConfig) + if err != nil { + return nil, fmt.Errorf("failed to create namespace cache: %w", err) + } + log.Ctx(ctx).Info().EmbedObject(nscc).Msg("configured namespace cache") + cachingMode := schemacaching.JustInTimeCaching if c.EnableExperimentalWatchableSchemaCache { cachingMode = schemacaching.WatchIfSupported diff --git a/pkg/cmd/server/zz_generated.options.go b/pkg/cmd/server/zz_generated.options.go index f9fc5a91c..250f4518a 100644 --- a/pkg/cmd/server/zz_generated.options.go +++ b/pkg/cmd/server/zz_generated.options.go @@ -79,6 +79,7 @@ func (c *Config) ToOption() ConfigOption { to.DispatchCacheConfig = c.DispatchCacheConfig to.ClusterDispatchCacheConfig = c.ClusterDispatchCacheConfig to.LR3ResourceChunkCacheConfig = c.LR3ResourceChunkCacheConfig + to.SchemaCacheConfig = c.SchemaCacheConfig to.DisableV1SchemaAPI = c.DisableV1SchemaAPI to.V1SchemaAdditiveOnly = c.V1SchemaAdditiveOnly to.MaximumUpdatesPerWrite = c.MaximumUpdatesPerWrite @@ -208,6 +209,7 @@ func (c *Config) DebugMap() map[string]any { debugMap["DispatchCacheConfig"] = c.DispatchCacheConfig debugMap["ClusterDispatchCacheConfig"] = c.ClusterDispatchCacheConfig debugMap["LR3ResourceChunkCacheConfig"] = c.LR3ResourceChunkCacheConfig + debugMap["SchemaCacheConfig"] = c.SchemaCacheConfig debugMap["DisableV1SchemaAPI"] = c.DisableV1SchemaAPI debugMap["V1SchemaAdditiveOnly"] = c.V1SchemaAdditiveOnly debugMap["MaximumUpdatesPerWrite"] = c.MaximumUpdatesPerWrite @@ -617,6 +619,13 @@ func WithLR3ResourceChunkCacheConfig(lR3ResourceChunkCacheConfig CacheConfig) Co } } +// WithSchemaCacheConfig returns an option that can set SchemaCacheConfig on a Config +func WithSchemaCacheConfig(schemaCacheConfig CacheConfig) 
ConfigOption { + return func(c *Config) { + c.SchemaCacheConfig = schemaCacheConfig + } +} + // WithDisableV1SchemaAPI returns an option that can set DisableV1SchemaAPI on a Config func WithDisableV1SchemaAPI(disableV1SchemaAPI bool) ConfigOption { return func(c *Config) { diff --git a/pkg/cursor/cursor.go b/pkg/cursor/cursor.go index 9ab4f7086..93d2cfe62 100644 --- a/pkg/cursor/cursor.go +++ b/pkg/cursor/cursor.go @@ -50,7 +50,7 @@ func Decode(encoded *v1.Cursor) (*impl.DecodedCursor, error) { // consumption, including the provided call context to ensure the API cursor reflects the calling // API method. The call hash should contain all the parameters of the calling API function, // as well as its revision and name. -func EncodeFromDispatchCursor(dispatchCursor *dispatch.Cursor, callAndParameterHash string, revision datastore.Revision, flags map[string]string) (*v1.Cursor, error) { +func EncodeFromDispatchCursor(dispatchCursor *dispatch.Cursor, callAndParameterHash string, revision datastore.Revision, schemaHash datastore.SchemaHash, flags map[string]string) (*v1.Cursor, error) { if dispatchCursor == nil { return nil, spiceerrors.MustBugf("got nil dispatch cursor") } @@ -62,13 +62,14 @@ func EncodeFromDispatchCursor(dispatchCursor *dispatch.Cursor, callAndParameterH DispatchVersion: dispatchCursor.DispatchVersion, Sections: dispatchCursor.Sections, CallAndParametersHash: callAndParameterHash, + SchemaHash: []byte(schemaHash), Flags: flags, }, }, }) } -func EncodeFromDispatchCursorSections(dispatchCursorSections []string, callAndParameterHash string, revision datastore.Revision, flags map[string]string) (*v1.Cursor, error) { +func EncodeFromDispatchCursorSections(dispatchCursorSections []string, callAndParameterHash string, revision datastore.Revision, schemaHash datastore.SchemaHash, flags map[string]string) (*v1.Cursor, error) { return Encode(&impl.DecodedCursor{ VersionOneof: &impl.DecodedCursor_V1{ V1: &impl.V1Cursor{ @@ -76,6 +77,7 @@ func 
EncodeFromDispatchCursorSections(dispatchCursorSections []string, callAndPa DispatchVersion: 1, Sections: dispatchCursorSections, CallAndParametersHash: callAndParameterHash, + SchemaHash: []byte(schemaHash), Flags: flags, }, }, diff --git a/pkg/cursor/cursor_test.go b/pkg/cursor/cursor_test.go index c68ed3dc7..a72a79d5e 100644 --- a/pkg/cursor/cursor_test.go +++ b/pkg/cursor/cursor_test.go @@ -50,7 +50,7 @@ func TestEncodeDecode(t *testing.T) { require := require.New(t) encoded, err := EncodeFromDispatchCursor(&dispatch.Cursor{ Sections: tc.sections, - }, tc.hash, tc.revision, map[string]string{"some": "flag"}) + }, tc.hash, tc.revision, datastore.NoSchemaHashForTesting, map[string]string{"some": "flag"}) require.NoError(err) require.NotNil(encoded) diff --git a/pkg/datastore/datastore.go b/pkg/datastore/datastore.go index c708a88ae..36e92e701 100644 --- a/pkg/datastore/datastore.go +++ b/pkg/datastore/datastore.go @@ -677,15 +677,15 @@ type ReadOnlyDatastore interface { // SnapshotReader creates a read-only handle that reads the datastore at the specified revision. // Any errors establishing the reader will be returned by subsequent calls. - SnapshotReader(Revision) Reader + SnapshotReader(Revision, SchemaHash) Reader // OptimizedRevision gets a revision that will likely already be replicated // and will likely be shared amongst many queries. - OptimizedRevision(ctx context.Context) (Revision, error) + OptimizedRevision(ctx context.Context) (Revision, SchemaHash, error) // HeadRevision gets a revision that is guaranteed to be at least as fresh as // right now. - HeadRevision(ctx context.Context) (Revision, error) + HeadRevision(ctx context.Context) (Revision, SchemaHash, error) // CheckRevision checks the specified revision to make sure it's valid and // hasn't been garbage collected. @@ -956,6 +956,10 @@ type Revision interface { // ByteSortable returns true if the string representation of the Revision is byte sortable, false otherwise. 
ByteSortable() bool + + // Key returns a unique string key for this revision suitable for use in maps and caches. + // The key should be deterministic and consistent for equal revisions. + Key() string } type nilRevision struct{} @@ -980,7 +984,33 @@ func (nilRevision) String() string { return "nil" } +func (nilRevision) Key() string { + return "nil" +} + // NoRevision is a zero type for the revision that will make changing the // revision type in the future a bit easier if necessary. Implementations // should use any time they want to signal an empty/error revision. var NoRevision Revision = nilRevision{} + +// SchemaHash represents a unique identifier for a schema version. +type SchemaHash string + +// NoSchemaHashInTransaction is a sentinel value indicating no schema hash should be used +// for cache lookups within a transaction. This prevents caching schemas at unstable revisions. +// This is a non-empty sentinel value to help catch bugs where empty strings are incorrectly used. +const NoSchemaHashInTransaction SchemaHash = "__transaction_bypass__" + +// NoSchemaHashForTesting is a sentinel value used in test code to bypass schema hash requirements. +// This is a non-empty sentinel value to help catch bugs where empty strings are incorrectly used. +const NoSchemaHashForTesting SchemaHash = "__testing_bypass__" + +// NoSchemaHashForWatch is a sentinel value used in watch operations where the schema hash is not +// immediately available (e.g., when resuming from a cursor). The schema will be loaded when needed. +// This is a non-empty sentinel value to help catch bugs where empty strings are incorrectly used. +const NoSchemaHashForWatch SchemaHash = "__watch_load_on_demand__" + +// NoSchemaHashForLegacyCursor is a sentinel value used when decoding legacy cursors that don't +// contain a schema hash field. The schema will be loaded on demand from the fallback value. +// This is a non-empty sentinel value to help catch bugs where empty strings are incorrectly used. 
+const NoSchemaHashForLegacyCursor SchemaHash = "__legacy_cursor_load_on_demand__" diff --git a/pkg/datastore/datastore_test.go b/pkg/datastore/datastore_test.go index aa3982b5b..b6f2ff8e2 100644 --- a/pkg/datastore/datastore_test.go +++ b/pkg/datastore/datastore_test.go @@ -617,7 +617,7 @@ func (f fakeDatastore) MetricsID() (string, error) { return "fake", nil } -func (f fakeDatastore) SnapshotReader(_ Revision) Reader { +func (f fakeDatastore) SnapshotReader(_ Revision, _ SchemaHash) Reader { return nil } @@ -625,12 +625,12 @@ func (f fakeDatastore) ReadWriteTx(_ context.Context, _ TxUserFunc, _ ...options return nil, nil } -func (f fakeDatastore) OptimizedRevision(_ context.Context) (Revision, error) { - return nil, nil +func (f fakeDatastore) OptimizedRevision(_ context.Context) (Revision, SchemaHash, error) { + return nil, NoSchemaHashForTesting, nil } -func (f fakeDatastore) HeadRevision(_ context.Context) (Revision, error) { - return nil, nil +func (f fakeDatastore) HeadRevision(_ context.Context) (Revision, SchemaHash, error) { + return nil, NoSchemaHashForTesting, nil } func (f fakeDatastore) CheckRevision(_ context.Context, _ Revision) error { diff --git a/pkg/datastore/errors.go b/pkg/datastore/errors.go index e19782dc8..299dd8252 100644 --- a/pkg/datastore/errors.go +++ b/pkg/datastore/errors.go @@ -306,4 +306,5 @@ var ( ErrClosedIterator = errors.New("unable to iterate: iterator closed") ErrCursorsWithoutSorting = errors.New("cursors are disabled on unsorted results") ErrCursorEmpty = errors.New("cursors are only available after the first result") + ErrSchemaNotFound = errors.New("schema not found") ) diff --git a/pkg/datastore/mocks/mock_datastore.go b/pkg/datastore/mocks/mock_datastore.go index 3f954dcd0..c730b6fde 100644 --- a/pkg/datastore/mocks/mock_datastore.go +++ b/pkg/datastore/mocks/mock_datastore.go @@ -750,12 +750,13 @@ func (mr *MockReadOnlyDatastoreMockRecorder) Features(ctx any) *gomock.Call { } // HeadRevision mocks base method. 
-func (m *MockReadOnlyDatastore) HeadRevision(ctx context.Context) (datastore.Revision, error) { +func (m *MockReadOnlyDatastore) HeadRevision(ctx context.Context) (datastore.Revision, datastore.SchemaHash, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "HeadRevision", ctx) ret0, _ := ret[0].(datastore.Revision) - ret1, _ := ret[1].(error) - return ret0, ret1 + ret1, _ := ret[1].(datastore.SchemaHash) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 } // HeadRevision indicates an expected call of HeadRevision. @@ -795,12 +796,13 @@ func (mr *MockReadOnlyDatastoreMockRecorder) OfflineFeatures() *gomock.Call { } // OptimizedRevision mocks base method. -func (m *MockReadOnlyDatastore) OptimizedRevision(ctx context.Context) (datastore.Revision, error) { +func (m *MockReadOnlyDatastore) OptimizedRevision(ctx context.Context) (datastore.Revision, datastore.SchemaHash, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "OptimizedRevision", ctx) ret0, _ := ret[0].(datastore.Revision) - ret1, _ := ret[1].(error) - return ret0, ret1 + ret1, _ := ret[1].(datastore.SchemaHash) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 } // OptimizedRevision indicates an expected call of OptimizedRevision. @@ -840,17 +842,17 @@ func (mr *MockReadOnlyDatastoreMockRecorder) RevisionFromString(serialized any) } // SnapshotReader mocks base method. -func (m *MockReadOnlyDatastore) SnapshotReader(arg0 datastore.Revision) datastore.Reader { +func (m *MockReadOnlyDatastore) SnapshotReader(arg0 datastore.Revision, arg1 datastore.SchemaHash) datastore.Reader { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SnapshotReader", arg0) + ret := m.ctrl.Call(m, "SnapshotReader", arg0, arg1) ret0, _ := ret[0].(datastore.Reader) return ret0 } // SnapshotReader indicates an expected call of SnapshotReader. 
-func (mr *MockReadOnlyDatastoreMockRecorder) SnapshotReader(arg0 any) *gomock.Call { +func (mr *MockReadOnlyDatastoreMockRecorder) SnapshotReader(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SnapshotReader", reflect.TypeOf((*MockReadOnlyDatastore)(nil).SnapshotReader), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SnapshotReader", reflect.TypeOf((*MockReadOnlyDatastore)(nil).SnapshotReader), arg0, arg1) } // Statistics mocks base method. @@ -966,12 +968,13 @@ func (mr *MockDatastoreMockRecorder) Features(ctx any) *gomock.Call { } // HeadRevision mocks base method. -func (m *MockDatastore) HeadRevision(ctx context.Context) (datastore.Revision, error) { +func (m *MockDatastore) HeadRevision(ctx context.Context) (datastore.Revision, datastore.SchemaHash, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "HeadRevision", ctx) ret0, _ := ret[0].(datastore.Revision) - ret1, _ := ret[1].(error) - return ret0, ret1 + ret1, _ := ret[1].(datastore.SchemaHash) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 } // HeadRevision indicates an expected call of HeadRevision. @@ -1011,12 +1014,13 @@ func (mr *MockDatastoreMockRecorder) OfflineFeatures() *gomock.Call { } // OptimizedRevision mocks base method. -func (m *MockDatastore) OptimizedRevision(ctx context.Context) (datastore.Revision, error) { +func (m *MockDatastore) OptimizedRevision(ctx context.Context) (datastore.Revision, datastore.SchemaHash, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "OptimizedRevision", ctx) ret0, _ := ret[0].(datastore.Revision) - ret1, _ := ret[1].(error) - return ret0, ret1 + ret1, _ := ret[1].(datastore.SchemaHash) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 } // OptimizedRevision indicates an expected call of OptimizedRevision. @@ -1076,17 +1080,17 @@ func (mr *MockDatastoreMockRecorder) RevisionFromString(serialized any) *gomock. } // SnapshotReader mocks base method. 
-func (m *MockDatastore) SnapshotReader(arg0 datastore.Revision) datastore.Reader { +func (m *MockDatastore) SnapshotReader(arg0 datastore.Revision, arg1 datastore.SchemaHash) datastore.Reader { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SnapshotReader", arg0) + ret := m.ctrl.Call(m, "SnapshotReader", arg0, arg1) ret0, _ := ret[0].(datastore.Reader) return ret0 } // SnapshotReader indicates an expected call of SnapshotReader. -func (mr *MockDatastoreMockRecorder) SnapshotReader(arg0 any) *gomock.Call { +func (mr *MockDatastoreMockRecorder) SnapshotReader(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SnapshotReader", reflect.TypeOf((*MockDatastore)(nil).SnapshotReader), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SnapshotReader", reflect.TypeOf((*MockDatastore)(nil).SnapshotReader), arg0, arg1) } // Statistics mocks base method. @@ -1287,12 +1291,13 @@ func (mr *MockSQLDatastoreMockRecorder) Features(ctx any) *gomock.Call { } // HeadRevision mocks base method. -func (m *MockSQLDatastore) HeadRevision(ctx context.Context) (datastore.Revision, error) { +func (m *MockSQLDatastore) HeadRevision(ctx context.Context) (datastore.Revision, datastore.SchemaHash, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "HeadRevision", ctx) ret0, _ := ret[0].(datastore.Revision) - ret1, _ := ret[1].(error) - return ret0, ret1 + ret1, _ := ret[1].(datastore.SchemaHash) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 } // HeadRevision indicates an expected call of HeadRevision. @@ -1332,12 +1337,13 @@ func (mr *MockSQLDatastoreMockRecorder) OfflineFeatures() *gomock.Call { } // OptimizedRevision mocks base method. 
-func (m *MockSQLDatastore) OptimizedRevision(ctx context.Context) (datastore.Revision, error) { +func (m *MockSQLDatastore) OptimizedRevision(ctx context.Context) (datastore.Revision, datastore.SchemaHash, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "OptimizedRevision", ctx) ret0, _ := ret[0].(datastore.Revision) - ret1, _ := ret[1].(error) - return ret0, ret1 + ret1, _ := ret[1].(datastore.SchemaHash) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 } // OptimizedRevision indicates an expected call of OptimizedRevision. @@ -1426,17 +1432,17 @@ func (mr *MockSQLDatastoreMockRecorder) RevisionFromString(serialized any) *gomo } // SnapshotReader mocks base method. -func (m *MockSQLDatastore) SnapshotReader(arg0 datastore.Revision) datastore.Reader { +func (m *MockSQLDatastore) SnapshotReader(arg0 datastore.Revision, arg1 datastore.SchemaHash) datastore.Reader { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SnapshotReader", arg0) + ret := m.ctrl.Call(m, "SnapshotReader", arg0, arg1) ret0, _ := ret[0].(datastore.Reader) return ret0 } // SnapshotReader indicates an expected call of SnapshotReader. -func (mr *MockSQLDatastoreMockRecorder) SnapshotReader(arg0 any) *gomock.Call { +func (mr *MockSQLDatastoreMockRecorder) SnapshotReader(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SnapshotReader", reflect.TypeOf((*MockSQLDatastore)(nil).SnapshotReader), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SnapshotReader", reflect.TypeOf((*MockSQLDatastore)(nil).SnapshotReader), arg0, arg1) } // Statistics mocks base method. @@ -1552,12 +1558,13 @@ func (mr *MockStrictReadDatastoreMockRecorder) Features(ctx any) *gomock.Call { } // HeadRevision mocks base method. 
-func (m *MockStrictReadDatastore) HeadRevision(ctx context.Context) (datastore.Revision, error) { +func (m *MockStrictReadDatastore) HeadRevision(ctx context.Context) (datastore.Revision, datastore.SchemaHash, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "HeadRevision", ctx) ret0, _ := ret[0].(datastore.Revision) - ret1, _ := ret[1].(error) - return ret0, ret1 + ret1, _ := ret[1].(datastore.SchemaHash) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 } // HeadRevision indicates an expected call of HeadRevision. @@ -1611,12 +1618,13 @@ func (mr *MockStrictReadDatastoreMockRecorder) OfflineFeatures() *gomock.Call { } // OptimizedRevision mocks base method. -func (m *MockStrictReadDatastore) OptimizedRevision(ctx context.Context) (datastore.Revision, error) { +func (m *MockStrictReadDatastore) OptimizedRevision(ctx context.Context) (datastore.Revision, datastore.SchemaHash, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "OptimizedRevision", ctx) ret0, _ := ret[0].(datastore.Revision) - ret1, _ := ret[1].(error) - return ret0, ret1 + ret1, _ := ret[1].(datastore.SchemaHash) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 } // OptimizedRevision indicates an expected call of OptimizedRevision. @@ -1676,17 +1684,17 @@ func (mr *MockStrictReadDatastoreMockRecorder) RevisionFromString(serialized any } // SnapshotReader mocks base method. -func (m *MockStrictReadDatastore) SnapshotReader(arg0 datastore.Revision) datastore.Reader { +func (m *MockStrictReadDatastore) SnapshotReader(arg0 datastore.Revision, arg1 datastore.SchemaHash) datastore.Reader { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SnapshotReader", arg0) + ret := m.ctrl.Call(m, "SnapshotReader", arg0, arg1) ret0, _ := ret[0].(datastore.Reader) return ret0 } // SnapshotReader indicates an expected call of SnapshotReader. 
-func (mr *MockStrictReadDatastoreMockRecorder) SnapshotReader(arg0 any) *gomock.Call { +func (mr *MockStrictReadDatastoreMockRecorder) SnapshotReader(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SnapshotReader", reflect.TypeOf((*MockStrictReadDatastore)(nil).SnapshotReader), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SnapshotReader", reflect.TypeOf((*MockStrictReadDatastore)(nil).SnapshotReader), arg0, arg1) } // Statistics mocks base method. @@ -1802,12 +1810,13 @@ func (mr *MockStartableDatastoreMockRecorder) Features(ctx any) *gomock.Call { } // HeadRevision mocks base method. -func (m *MockStartableDatastore) HeadRevision(ctx context.Context) (datastore.Revision, error) { +func (m *MockStartableDatastore) HeadRevision(ctx context.Context) (datastore.Revision, datastore.SchemaHash, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "HeadRevision", ctx) ret0, _ := ret[0].(datastore.Revision) - ret1, _ := ret[1].(error) - return ret0, ret1 + ret1, _ := ret[1].(datastore.SchemaHash) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 } // HeadRevision indicates an expected call of HeadRevision. @@ -1847,12 +1856,13 @@ func (mr *MockStartableDatastoreMockRecorder) OfflineFeatures() *gomock.Call { } // OptimizedRevision mocks base method. -func (m *MockStartableDatastore) OptimizedRevision(ctx context.Context) (datastore.Revision, error) { +func (m *MockStartableDatastore) OptimizedRevision(ctx context.Context) (datastore.Revision, datastore.SchemaHash, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "OptimizedRevision", ctx) ret0, _ := ret[0].(datastore.Revision) - ret1, _ := ret[1].(error) - return ret0, ret1 + ret1, _ := ret[1].(datastore.SchemaHash) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 } // OptimizedRevision indicates an expected call of OptimizedRevision. 
@@ -1912,17 +1922,17 @@ func (mr *MockStartableDatastoreMockRecorder) RevisionFromString(serialized any) } // SnapshotReader mocks base method. -func (m *MockStartableDatastore) SnapshotReader(arg0 datastore.Revision) datastore.Reader { +func (m *MockStartableDatastore) SnapshotReader(arg0 datastore.Revision, arg1 datastore.SchemaHash) datastore.Reader { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SnapshotReader", arg0) + ret := m.ctrl.Call(m, "SnapshotReader", arg0, arg1) ret0, _ := ret[0].(datastore.Reader) return ret0 } // SnapshotReader indicates an expected call of SnapshotReader. -func (mr *MockStartableDatastoreMockRecorder) SnapshotReader(arg0 any) *gomock.Call { +func (mr *MockStartableDatastoreMockRecorder) SnapshotReader(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SnapshotReader", reflect.TypeOf((*MockStartableDatastore)(nil).SnapshotReader), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SnapshotReader", reflect.TypeOf((*MockStartableDatastore)(nil).SnapshotReader), arg0, arg1) } // Start mocks base method. @@ -2052,12 +2062,13 @@ func (mr *MockRepairableDatastoreMockRecorder) Features(ctx any) *gomock.Call { } // HeadRevision mocks base method. -func (m *MockRepairableDatastore) HeadRevision(ctx context.Context) (datastore.Revision, error) { +func (m *MockRepairableDatastore) HeadRevision(ctx context.Context) (datastore.Revision, datastore.SchemaHash, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "HeadRevision", ctx) ret0, _ := ret[0].(datastore.Revision) - ret1, _ := ret[1].(error) - return ret0, ret1 + ret1, _ := ret[1].(datastore.SchemaHash) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 } // HeadRevision indicates an expected call of HeadRevision. @@ -2097,12 +2108,13 @@ func (mr *MockRepairableDatastoreMockRecorder) OfflineFeatures() *gomock.Call { } // OptimizedRevision mocks base method. 
-func (m *MockRepairableDatastore) OptimizedRevision(ctx context.Context) (datastore.Revision, error) { +func (m *MockRepairableDatastore) OptimizedRevision(ctx context.Context) (datastore.Revision, datastore.SchemaHash, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "OptimizedRevision", ctx) ret0, _ := ret[0].(datastore.Revision) - ret1, _ := ret[1].(error) - return ret0, ret1 + ret1, _ := ret[1].(datastore.SchemaHash) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 } // OptimizedRevision indicates an expected call of OptimizedRevision. @@ -2190,17 +2202,17 @@ func (mr *MockRepairableDatastoreMockRecorder) RevisionFromString(serialized any } // SnapshotReader mocks base method. -func (m *MockRepairableDatastore) SnapshotReader(arg0 datastore.Revision) datastore.Reader { +func (m *MockRepairableDatastore) SnapshotReader(arg0 datastore.Revision, arg1 datastore.SchemaHash) datastore.Reader { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SnapshotReader", arg0) + ret := m.ctrl.Call(m, "SnapshotReader", arg0, arg1) ret0, _ := ret[0].(datastore.Reader) return ret0 } // SnapshotReader indicates an expected call of SnapshotReader. -func (mr *MockRepairableDatastoreMockRecorder) SnapshotReader(arg0 any) *gomock.Call { +func (mr *MockRepairableDatastoreMockRecorder) SnapshotReader(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SnapshotReader", reflect.TypeOf((*MockRepairableDatastore)(nil).SnapshotReader), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SnapshotReader", reflect.TypeOf((*MockRepairableDatastore)(nil).SnapshotReader), arg0, arg1) } // Statistics mocks base method. @@ -2352,6 +2364,20 @@ func (mr *MockRevisionMockRecorder) GreaterThan(arg0 any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GreaterThan", reflect.TypeOf((*MockRevision)(nil).GreaterThan), arg0) } +// Key mocks base method. 
+func (m *MockRevision) Key() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Key") + ret0, _ := ret[0].(string) + return ret0 +} + +// Key indicates an expected call of Key. +func (mr *MockRevisionMockRecorder) Key() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Key", reflect.TypeOf((*MockRevision)(nil).Key)) +} + // LessThan mocks base method. func (m *MockRevision) LessThan(arg0 datastore.Revision) bool { m.ctrl.T.Helper() diff --git a/pkg/datastore/options/options.go b/pkg/datastore/options/options.go index 2a1b3d5a1..313fa8110 100644 --- a/pkg/datastore/options/options.go +++ b/pkg/datastore/options/options.go @@ -2,6 +2,7 @@ package options import ( "context" + "fmt" "google.golang.org/protobuf/types/known/structpb" @@ -10,6 +11,13 @@ import ( "github.com/authzed/spicedb/pkg/tuple" ) +// SchemaCacheOptions configures the schema cache behavior. +type SchemaCacheOptions struct { + // MaximumCacheEntries is the maximum number of schema entries to cache. + // If 0, defaults to 100. + MaximumCacheEntries uint32 +} + //go:generate go run github.com/ecordell/optgen -output zz_generated.query_options.go . QueryOptions ReverseQueryOptions RWTOptions //go:generate go run github.com/ecordell/optgen -output zz_generated.delete_options.go . DeleteOptions @@ -127,3 +135,52 @@ var ( // LimitOne is a constant *uint64 that can be used with WithLimit requests. LimitOne = &one ) + +// SchemaMode represents the experimental schema mode for datastore operations. 
+type SchemaMode uint8 + +const ( + // SchemaModeReadLegacyWriteLegacy uses legacy schema reader and writer + SchemaModeReadLegacyWriteLegacy SchemaMode = iota + + // SchemaModeReadLegacyWriteBoth uses legacy schema reader and writes to both legacy and unified schema + SchemaModeReadLegacyWriteBoth + + // SchemaModeReadNewWriteBoth uses unified schema reader and writes to both legacy and unified schema + SchemaModeReadNewWriteBoth + + // SchemaModeReadNewWriteNew uses unified schema reader and writer only + SchemaModeReadNewWriteNew +) + +var schemaModeNames = map[string]SchemaMode{ + "read-legacy-write-legacy": SchemaModeReadLegacyWriteLegacy, + "read-legacy-write-both": SchemaModeReadLegacyWriteBoth, + "read-new-write-both": SchemaModeReadNewWriteBoth, + "read-new-write-new": SchemaModeReadNewWriteNew, +} + +var schemaModeStrings = map[SchemaMode]string{ + SchemaModeReadLegacyWriteLegacy: "read-legacy-write-legacy", + SchemaModeReadLegacyWriteBoth: "read-legacy-write-both", + SchemaModeReadNewWriteBoth: "read-new-write-both", + SchemaModeReadNewWriteNew: "read-new-write-new", +} + +// ParseSchemaMode converts a string to a SchemaMode. Returns an error if the string is invalid. +func ParseSchemaMode(s string) (SchemaMode, error) { + mode, ok := schemaModeNames[s] + if !ok { + return SchemaModeReadLegacyWriteLegacy, fmt.Errorf("invalid schema mode %q, must be one of: read-legacy-write-legacy, read-legacy-write-both, read-new-write-both, read-new-write-new", s) + } + return mode, nil +} + +// String returns the string representation of the SchemaMode. 
+func (s SchemaMode) String() string { + str, ok := schemaModeStrings[s] + if !ok { + return "unknown" + } + return str +} diff --git a/pkg/datastore/singlestoreschema.go b/pkg/datastore/singlestoreschema.go new file mode 100644 index 000000000..9589a10da --- /dev/null +++ b/pkg/datastore/singlestoreschema.go @@ -0,0 +1,43 @@ +package datastore + +import ( + "context" + + core "github.com/authzed/spicedb/pkg/proto/core/v1" +) + +// SingleStoreSchemaReader defines methods for reading schema from the unified single-store schema table. +type SingleStoreSchemaReader interface { + // ReadStoredSchema reads the stored schema from the unified schema table. + // Returns ErrSchemaNotFound if no schema has been written. + ReadStoredSchema(ctx context.Context) (*core.StoredSchema, error) +} + +// SingleStoreSchemaWriter defines methods for writing schema to the unified single-store schema table. +type SingleStoreSchemaWriter interface { + // WriteStoredSchema writes the stored schema to the unified schema table. + WriteStoredSchema(ctx context.Context, schema *core.StoredSchema) error +} + +// LegacySchemaHashWriter defines an optional method for writing just the schema hash. +// This is used by datastores to write the schema hash during legacy schema writes +// without reading back buffered writes. +type LegacySchemaHashWriter interface { + // WriteLegacySchemaHashFromDefinitions writes the schema hash computed from the given definitions. + // This is called by the legacy schema adapter after buffering writes but before commit. + WriteLegacySchemaHashFromDefinitions(ctx context.Context, namespaces []RevisionedNamespace, caveats []RevisionedCaveat) error +} + +// DualSchemaReader combines both legacy and single-store schema reading interfaces. +// Datastores should implement this interface to support both schema storage modes. 
+type DualSchemaReader interface { + LegacySchemaReader + SingleStoreSchemaReader +} + +// DualSchemaWriter combines both legacy and single-store schema writing interfaces. +// Datastores should implement this interface to support both schema storage modes. +type DualSchemaWriter interface { + LegacySchemaWriter + SingleStoreSchemaWriter +} diff --git a/pkg/datastore/test/basic.go b/pkg/datastore/test/basic.go index a34e36496..875743cc9 100644 --- a/pkg/datastore/test/basic.go +++ b/pkg/datastore/test/basic.go @@ -23,7 +23,7 @@ func UseAfterCloseTest(t *testing.T, tester DatastoreTester) { require.NoError(err) // Attempt to use and ensure an error is returned. - _, err = ds.HeadRevision(t.Context()) + _, _, err = ds.HeadRevision(t.Context()) require.Error(err) } @@ -35,7 +35,7 @@ func DeleteAllDataTest(t *testing.T, tester DatastoreTester) { ctx := t.Context() // Ensure at least a few relationships and namespaces exist. - reader := ds.SnapshotReader(revision) + reader := ds.SnapshotReader(revision, datastore.NoSchemaHashForTesting) nsDefs, err := reader.LegacyListAllNamespaces(ctx) require.NoError(t, err) require.NotEmpty(t, nsDefs, "no namespace definitions provided") @@ -61,10 +61,10 @@ func DeleteAllDataTest(t *testing.T, tester DatastoreTester) { require.NoError(t, err) // Ensure there are no relationships or namespaces. 
- headRev, err := ds.HeadRevision(ctx) + headRev, _, err := ds.HeadRevision(ctx) require.NoError(t, err) - reader = ds.SnapshotReader(headRev) + reader = ds.SnapshotReader(headRev, datastore.NoSchemaHashForTesting) afterNSDefs, err := reader.LegacyListAllNamespaces(ctx) require.NoError(t, err) require.Empty(t, afterNSDefs, "namespace definitions still exist") diff --git a/pkg/datastore/test/bulk.go b/pkg/datastore/test/bulk.go index 42073196e..beab8d917 100644 --- a/pkg/datastore/test/bulk.go +++ b/pkg/datastore/test/bulk.go @@ -51,10 +51,10 @@ func BulkUploadTest(t *testing.T, tester DatastoreTester) { tRequire := testfixtures.RelationshipChecker{Require: require, DS: ds} - head, err := ds.HeadRevision(ctx) + head, _, err := ds.HeadRevision(ctx) require.NoError(err) - iter, err := ds.SnapshotReader(head).QueryRelationships(ctx, datastore.RelationshipsFilter{ + iter, err := ds.SnapshotReader(head, datastore.NoSchemaHashForTesting).QueryRelationships(ctx, datastore.RelationshipsFilter{ OptionalResourceType: testfixtures.DocumentNS.Name, }, options.WithQueryShape(queryshape.FindResourceOfType)) require.NoError(err) @@ -149,7 +149,7 @@ func BulkUploadWithCaveats(t *testing.T, tester DatastoreTester) { }) require.NoError(err) - iter, err := ds.SnapshotReader(lastRevision).QueryRelationships(ctx, datastore.RelationshipsFilter{ + iter, err := ds.SnapshotReader(lastRevision, datastore.NoSchemaHashForTesting).QueryRelationships(ctx, datastore.RelationshipsFilter{ OptionalResourceType: testfixtures.DocumentNS.Name, }, options.WithQueryShape(queryshape.FindResourceOfType)) require.NoError(err) @@ -191,7 +191,7 @@ func BulkUploadWithExpiration(t *testing.T, tester DatastoreTester) { }) require.NoError(err) - iter, err := ds.SnapshotReader(lastRevision).QueryRelationships(ctx, datastore.RelationshipsFilter{ + iter, err := ds.SnapshotReader(lastRevision, datastore.NoSchemaHashForTesting).QueryRelationships(ctx, datastore.RelationshipsFilter{ OptionalResourceType: 
testfixtures.DocumentNS.Name, }, options.WithQueryShape(queryshape.FindResourceOfType)) require.NoError(err) @@ -228,7 +228,7 @@ func BulkUploadEditCaveat(t *testing.T, tester DatastoreTester) { }) require.NoError(err) - iter, err := ds.SnapshotReader(lastRevision).QueryRelationships(ctx, datastore.RelationshipsFilter{ + iter, err := ds.SnapshotReader(lastRevision, datastore.NoSchemaHashForTesting).QueryRelationships(ctx, datastore.RelationshipsFilter{ OptionalResourceType: testfixtures.DocumentNS.Name, }, options.WithQueryShape(queryshape.FindResourceOfType)) require.NoError(err) @@ -253,7 +253,7 @@ func BulkUploadEditCaveat(t *testing.T, tester DatastoreTester) { }) require.NoError(err) - iter, err = ds.SnapshotReader(lastRevision).QueryRelationships(ctx, datastore.RelationshipsFilter{ + iter, err = ds.SnapshotReader(lastRevision, datastore.NoSchemaHashForTesting).QueryRelationships(ctx, datastore.RelationshipsFilter{ OptionalResourceType: testfixtures.DocumentNS.Name, }, options.WithQueryShape(queryshape.FindResourceOfType)) require.NoError(err) diff --git a/pkg/datastore/test/caveat.go b/pkg/datastore/test/caveat.go index 0820d507e..e7efeefca 100644 --- a/pkg/datastore/test/caveat.go +++ b/pkg/datastore/test/caveat.go @@ -33,10 +33,10 @@ func CaveatNotFoundTest(t *testing.T, tester DatastoreTester) { ctx := t.Context() - startRevision, err := ds.HeadRevision(ctx) + startRevision, _, err := ds.HeadRevision(ctx) require.NoError(err) - _, _, err = ds.SnapshotReader(startRevision).LegacyReadCaveatByName(ctx, "unknown") + _, _, err = ds.SnapshotReader(startRevision, datastore.NoSchemaHashForTesting).LegacyReadCaveatByName(ctx, "unknown") require.ErrorAs(err, &datastore.CaveatNameNotFoundError{}) } @@ -65,7 +65,7 @@ func WriteReadDeleteCaveatTest(t *testing.T, tester DatastoreTester) { req.NoError(err) // The caveat can be looked up by name - cr := ds.SnapshotReader(rev) + cr := ds.SnapshotReader(rev, datastore.NoSchemaHashForTesting) cv, _, err := 
cr.LegacyReadCaveatByName(ctx, coreCaveat.Name) req.NoError(err) @@ -118,7 +118,7 @@ func WriteReadDeleteCaveatTest(t *testing.T, tester DatastoreTester) { return tx.LegacyDeleteCaveats(ctx, []string{coreCaveat.Name}) }) req.NoError(err) - cr = ds.SnapshotReader(rev) + cr = ds.SnapshotReader(rev, datastore.NoSchemaHashForTesting) _, _, err = cr.LegacyReadCaveatByName(ctx, coreCaveat.Name) req.ErrorAs(err, &datastore.CaveatNameNotFoundError{}) @@ -187,7 +187,7 @@ func WriteCaveatedRelationshipTest(t *testing.T, tester DatastoreTester) { rel.OptionalCaveat.CaveatName = "rando" rev, err = common.WriteRelationships(ctx, sds, tuple.UpdateOperationDelete, rel) req.NoError(err) - iter, err := ds.SnapshotReader(rev).QueryRelationships(t.Context(), datastore.RelationshipsFilter{ + iter, err := ds.SnapshotReader(rev, datastore.NoSchemaHashForTesting).QueryRelationships(t.Context(), datastore.RelationshipsFilter{ OptionalResourceType: rel.Resource.ObjectType, }, options.WithQueryShape(queryshape.FindResourceOfType)) req.NoError(err) @@ -227,7 +227,7 @@ func CaveatedRelationshipFilterTest(t *testing.T, tester DatastoreTester) { req.NoError(err) // filter by first caveat - iter, err := ds.SnapshotReader(rev).QueryRelationships(ctx, datastore.RelationshipsFilter{ + iter, err := ds.SnapshotReader(rev, datastore.NoSchemaHashForTesting).QueryRelationships(ctx, datastore.RelationshipsFilter{ OptionalResourceType: rel.Resource.ObjectType, OptionalCaveatNameFilter: datastore.WithCaveatName(coreCaveat.Name), }, options.WithQueryShape(queryshape.Varying)) @@ -235,7 +235,7 @@ func CaveatedRelationshipFilterTest(t *testing.T, tester DatastoreTester) { expectRel(req, iter, rel) // filter by second caveat - iter, err = ds.SnapshotReader(rev).QueryRelationships(ctx, datastore.RelationshipsFilter{ + iter, err = ds.SnapshotReader(rev, datastore.NoSchemaHashForTesting).QueryRelationships(ctx, datastore.RelationshipsFilter{ OptionalResourceType: anotherTpl.Resource.ObjectType, 
OptionalCaveatNameFilter: datastore.WithCaveatName(anotherCoreCaveat.Name), }, options.WithQueryShape(queryshape.Varying)) @@ -243,7 +243,7 @@ func CaveatedRelationshipFilterTest(t *testing.T, tester DatastoreTester) { expectRel(req, iter, anotherTpl) // filter by caveat required and ensure not found. - iter, err = ds.SnapshotReader(rev).QueryRelationships(ctx, datastore.RelationshipsFilter{ + iter, err = ds.SnapshotReader(rev, datastore.NoSchemaHashForTesting).QueryRelationships(ctx, datastore.RelationshipsFilter{ OptionalResourceType: anotherTpl.Resource.ObjectType, OptionalResourceIds: []string{anotherTpl.Resource.ObjectID}, OptionalCaveatNameFilter: datastore.WithNoCaveat(), @@ -273,13 +273,13 @@ func CaveatSnapshotReadsTest(t *testing.T, tester DatastoreTester) { req.NoError(err) // check most recent revision - cr := ds.SnapshotReader(newRev) + cr := ds.SnapshotReader(newRev, datastore.NoSchemaHashForTesting) cv, _, err := cr.LegacyReadCaveatByName(ctx, coreCaveat.Name) req.NoError(err) req.Equal(newExpression, cv.SerializedExpression) // check previous revision - cr = ds.SnapshotReader(oldRev) + cr = ds.SnapshotReader(oldRev, datastore.NoSchemaHashForTesting) cv, _, err = cr.LegacyReadCaveatByName(ctx, coreCaveat.Name) req.NoError(err) req.Equal(oldExpression, cv.SerializedExpression) @@ -303,7 +303,7 @@ func CaveatedRelationshipWatchTest(t *testing.T, tester DatastoreTester) { // test relationship with caveat and context relWithContext := createTestCaveatedRel(t, "document:a#parent@folder:company#...", coreCaveat.Name) - revBeforeWrite, err := ds.HeadRevision(ctx) + revBeforeWrite, _, err := ds.HeadRevision(ctx) require.NoError(t, err) writeRev, err := common.WriteRelationships(ctx, ds, tuple.UpdateOperationCreate, relWithContext) @@ -319,7 +319,7 @@ func CaveatedRelationshipWatchTest(t *testing.T, tester DatastoreTester) { req.NoError(err) tupleWithEmptyContext.OptionalCaveat.Context = strct - secondRevBeforeWrite, err := ds.HeadRevision(ctx) + 
secondRevBeforeWrite, _, err := ds.HeadRevision(ctx) require.NoError(t, err) secondWriteRev, err := common.WriteRelationships(ctx, ds, tuple.UpdateOperationCreate, tupleWithEmptyContext) @@ -332,7 +332,7 @@ func CaveatedRelationshipWatchTest(t *testing.T, tester DatastoreTester) { tupleWithNilContext := createTestCaveatedRel(t, "document:c#parent@folder:company#...", coreCaveat.Name) tupleWithNilContext.OptionalCaveat.Context = nil - thirdRevBeforeWrite, err := ds.HeadRevision(ctx) + thirdRevBeforeWrite, _, err := ds.HeadRevision(ctx) require.NoError(t, err) thirdWriteRev, err := common.WriteRelationships(ctx, ds, tuple.UpdateOperationCreate, tupleWithNilContext) @@ -377,7 +377,7 @@ func expectNoRel(req *require.Assertions, iter datastore.RelationshipIterator) { } func assertRelCorrectlyStored(req *require.Assertions, ds datastore.Datastore, rev datastore.Revision, expected tuple.Relationship) { - iter, err := ds.SnapshotReader(rev).QueryRelationships(context.Background(), datastore.RelationshipsFilter{ + iter, err := ds.SnapshotReader(rev, datastore.NoSchemaHashForTesting).QueryRelationships(context.Background(), datastore.RelationshipsFilter{ OptionalResourceType: expected.Resource.ObjectType, }, options.WithQueryShape(queryshape.FindResourceOfType)) req.NoError(err) diff --git a/pkg/datastore/test/counters.go b/pkg/datastore/test/counters.go index 9bd451e16..d51859998 100644 --- a/pkg/datastore/test/counters.go +++ b/pkg/datastore/test/counters.go @@ -43,7 +43,7 @@ func RelationshipCounterOverExpiredTest(t *testing.T, tester DatastoreTester) { require.NoError(t, err) // Check the count using the filter. 
- reader := ds.SnapshotReader(updatedRev) + reader := ds.SnapshotReader(updatedRev, datastore.NoSchemaHashForTesting) expectedCount := 0 iter, err := reader.QueryRelationships(t.Context(), datastore.RelationshipsFilter{ @@ -114,7 +114,7 @@ func RelationshipCountersTest(t *testing.T, tester DatastoreTester) { ds, rev := testfixtures.StandardDatastoreWithData(rawDS, require.New(t)) // Try calling count without the filter being registered. - reader := ds.SnapshotReader(rev) + reader := ds.SnapshotReader(rev, datastore.NoSchemaHashForTesting) _, err = reader.CountRelationships(t.Context(), "somefilter") require.Error(t, err) @@ -149,7 +149,7 @@ func RelationshipCountersTest(t *testing.T, tester DatastoreTester) { require.NoError(t, err) // Check the count using the filter. - reader = ds.SnapshotReader(updatedRev) + reader = ds.SnapshotReader(updatedRev, datastore.NoSchemaHashForTesting) expectedCount := 0 iter, err := reader.QueryRelationships(t.Context(), datastore.RelationshipsFilter{ @@ -204,7 +204,7 @@ func RelationshipCountersTest(t *testing.T, tester DatastoreTester) { require.Equal(t, expectedCount, count) // Call the filter at the unregistered revision. - reader = ds.SnapshotReader(unregisterRev) + reader = ds.SnapshotReader(unregisterRev, datastore.NoSchemaHashForTesting) _, err = reader.CountRelationships(t.Context(), "document") require.Error(t, err) require.Contains(t, err.Error(), "counter with name `document` not found") @@ -229,7 +229,7 @@ func RelationshipCountersWithOddFilterTest(t *testing.T, tester DatastoreTester) require.NoError(t, err) // Check the count using the filter. 
- reader := ds.SnapshotReader(updatedRev) + reader := ds.SnapshotReader(updatedRev, datastore.NoSchemaHashForTesting) expectedCount := 0 iter, err := reader.QueryRelationships(t.Context(), datastore.RelationshipsFilter{ @@ -258,7 +258,7 @@ func UpdateRelationshipCounterTest(t *testing.T, tester DatastoreTester) { ds, rev := testfixtures.StandardDatastoreWithData(rawDS, require.New(t)) - reader := ds.SnapshotReader(rev) + reader := ds.SnapshotReader(rev, datastore.NoSchemaHashForTesting) filters, err := reader.LookupCounters(t.Context()) require.NoError(t, err) require.Empty(t, filters) @@ -283,7 +283,7 @@ func UpdateRelationshipCounterTest(t *testing.T, tester DatastoreTester) { require.NoError(t, err) // Read the filters. - reader = ds.SnapshotReader(updatedRev) + reader = ds.SnapshotReader(updatedRev, datastore.NoSchemaHashForTesting) filters, err = reader.LookupCounters(t.Context()) require.NoError(t, err) @@ -301,7 +301,7 @@ func UpdateRelationshipCounterTest(t *testing.T, tester DatastoreTester) { require.NoError(t, err) // Read the filters. - reader = ds.SnapshotReader(currentRev) + reader = ds.SnapshotReader(currentRev, datastore.NoSchemaHashForTesting) filters, err = reader.LookupCounters(t.Context()) require.NoError(t, err) @@ -324,7 +324,7 @@ func UpdateRelationshipCounterTest(t *testing.T, tester DatastoreTester) { require.NoError(t, err) // Read the filters. 
- reader = ds.SnapshotReader(newFilterRev) + reader = ds.SnapshotReader(newFilterRev, datastore.NoSchemaHashForTesting) filters, err = reader.LookupCounters(t.Context()) require.NoError(t, err) diff --git a/pkg/datastore/test/datastore.go b/pkg/datastore/test/datastore.go index b55825168..9569ae14d 100644 --- a/pkg/datastore/test/datastore.go +++ b/pkg/datastore/test/datastore.go @@ -218,6 +218,17 @@ func AllWithExceptions(t *testing.T, tester DatastoreTester, except Categories, t.Run("TestRelationshipCounterOverExpired", runner(tester, RelationshipCounterOverExpiredTest)) t.Run("TestRegisterRelationshipCountersInParallel", runner(tester, RegisterRelationshipCountersInParallelTest)) t.Run("TestRelationshipCountersWithOddFilter", runner(tester, RelationshipCountersWithOddFilterTest)) + + t.Run("TestUnifiedSchema", runner(tester, UnifiedSchemaTest)) + t.Run("TestUnifiedSchemaUpdate", runner(tester, UnifiedSchemaUpdateTest)) + t.Run("TestUnifiedSchemaRevision", runner(tester, UnifiedSchemaRevisionTest)) + t.Run("TestUnifiedSchemaWithCaveats", runner(tester, UnifiedSchemaWithCaveatsTest)) + t.Run("TestUnifiedSchemaEmpty", runner(tester, UnifiedSchemaEmptyTest)) + t.Run("TestUnifiedSchemaLookup", runner(tester, UnifiedSchemaLookupTest)) + t.Run("TestUnifiedSchemaLookupByNames", runner(tester, UnifiedSchemaLookupByNamesTest)) + t.Run("TestUnifiedSchemaValidation", runner(tester, UnifiedSchemaValidationTest)) + t.Run("TestUnifiedSchemaMultipleIterations", runner(tester, UnifiedSchemaMultipleIterationsTest)) + t.Run("TestUnifiedSchemaHash", runner(tester, UnifiedSchemaHashTest)) } func OnlyGCTests(t *testing.T, tester DatastoreTester, concurrent bool) { diff --git a/pkg/datastore/test/namespace.go b/pkg/datastore/test/namespace.go index afb181ac8..a80914b50 100644 --- a/pkg/datastore/test/namespace.go +++ b/pkg/datastore/test/namespace.go @@ -41,10 +41,10 @@ func NamespaceNotFoundTest(t *testing.T, tester DatastoreTester) { ctx := t.Context() - startRevision, err := 
ds.HeadRevision(ctx) + startRevision, _, err := ds.HeadRevision(ctx) require.NoError(err) - _, _, err = ds.SnapshotReader(startRevision).LegacyReadNamespaceByName(ctx, "unknown") + _, _, err = ds.SnapshotReader(startRevision, datastore.NoSchemaHashForTesting).LegacyReadNamespaceByName(ctx, "unknown") require.ErrorAs(err, &datastore.NamespaceNotFoundError{}) } @@ -58,10 +58,10 @@ func NamespaceWriteTest(t *testing.T, tester DatastoreTester) { ctx := t.Context() - startRevision, err := ds.HeadRevision(ctx) + startRevision, _, err := ds.HeadRevision(ctx) require.NoError(err) - nsDefs, err := ds.SnapshotReader(startRevision).LegacyListAllNamespaces(ctx) + nsDefs, err := ds.SnapshotReader(startRevision, datastore.NoSchemaHashForTesting).LegacyListAllNamespaces(ctx) require.NoError(err) require.Empty(nsDefs) @@ -71,7 +71,7 @@ func NamespaceWriteTest(t *testing.T, tester DatastoreTester) { require.NoError(err) require.True(writtenRev.GreaterThan(startRevision)) - nsDefs, err = ds.SnapshotReader(writtenRev).LegacyListAllNamespaces(ctx) + nsDefs, err = ds.SnapshotReader(writtenRev, datastore.NoSchemaHashForTesting).LegacyListAllNamespaces(ctx) require.NoError(err) require.Len(nsDefs, 1) require.Equal(testUserNS.Name, nsDefs[0].Definition.Name) @@ -82,18 +82,18 @@ func NamespaceWriteTest(t *testing.T, tester DatastoreTester) { require.NoError(err) require.True(secondWritten.GreaterThan(writtenRev)) - nsDefs, err = ds.SnapshotReader(secondWritten).LegacyListAllNamespaces(ctx) + nsDefs, err = ds.SnapshotReader(secondWritten, datastore.NoSchemaHashForTesting).LegacyListAllNamespaces(ctx) require.NoError(err) require.Len(nsDefs, 2) - _, _, err = ds.SnapshotReader(writtenRev).LegacyReadNamespaceByName(ctx, testNamespace.Name) + _, _, err = ds.SnapshotReader(writtenRev, datastore.NoSchemaHashForTesting).LegacyReadNamespaceByName(ctx, testNamespace.Name) require.Error(err) - nsDefs, err = ds.SnapshotReader(writtenRev).LegacyListAllNamespaces(ctx) + nsDefs, err = 
ds.SnapshotReader(writtenRev, datastore.NoSchemaHashForTesting).LegacyListAllNamespaces(ctx) require.NoError(err) require.Len(nsDefs, 1) - found, createdRev, err := ds.SnapshotReader(secondWritten).LegacyReadNamespaceByName(ctx, testNamespace.Name) + found, createdRev, err := ds.SnapshotReader(secondWritten, datastore.NoSchemaHashForTesting).LegacyReadNamespaceByName(ctx, testNamespace.Name) require.NoError(err) require.False(createdRev.GreaterThan(secondWritten)) require.True(createdRev.GreaterThan(startRevision)) @@ -105,36 +105,36 @@ func NamespaceWriteTest(t *testing.T, tester DatastoreTester) { }) require.NoError(err) - checkUpdated, createdRev, err := ds.SnapshotReader(updatedRevision).LegacyReadNamespaceByName(ctx, testNamespace.Name) + checkUpdated, createdRev, err := ds.SnapshotReader(updatedRevision, datastore.NoSchemaHashForTesting).LegacyReadNamespaceByName(ctx, testNamespace.Name) require.NoError(err) require.False(createdRev.GreaterThan(updatedRevision)) require.True(createdRev.GreaterThan(startRevision)) foundUpdated := cmp.Diff(updatedNamespace, checkUpdated, protocmp.Transform()) require.Empty(foundUpdated) - checkOld, createdRev, err := ds.SnapshotReader(writtenRev).LegacyReadNamespaceByName(ctx, testUserNamespace) + checkOld, createdRev, err := ds.SnapshotReader(writtenRev, datastore.NoSchemaHashForTesting).LegacyReadNamespaceByName(ctx, testUserNamespace) require.NoError(err) require.False(createdRev.GreaterThan(writtenRev)) require.True(createdRev.GreaterThan(startRevision)) require.Empty(cmp.Diff(testUserNS, checkOld, protocmp.Transform())) - checkOldList, err := ds.SnapshotReader(writtenRev).LegacyListAllNamespaces(ctx) + checkOldList, err := ds.SnapshotReader(writtenRev, datastore.NoSchemaHashForTesting).LegacyListAllNamespaces(ctx) require.NoError(err) require.Len(checkOldList, 1) require.Equal(testUserNS.Name, checkOldList[0].Definition.Name) require.Empty(cmp.Diff(testUserNS, checkOldList[0].Definition, protocmp.Transform())) - 
checkLookup, err := ds.SnapshotReader(secondWritten).LegacyLookupNamespacesWithNames(ctx, []string{testNamespace.Name}) + checkLookup, err := ds.SnapshotReader(secondWritten, datastore.NoSchemaHashForTesting).LegacyLookupNamespacesWithNames(ctx, []string{testNamespace.Name}) require.NoError(err) require.Len(checkLookup, 1) require.Equal(testNamespace.Name, checkLookup[0].Definition.Name) require.Empty(cmp.Diff(testNamespace, checkLookup[0].Definition, protocmp.Transform())) - checkLookupMultiple, err := ds.SnapshotReader(secondWritten).LegacyLookupNamespacesWithNames(ctx, []string{testNamespace.Name, testUserNS.Name}) + checkLookupMultiple, err := ds.SnapshotReader(secondWritten, datastore.NoSchemaHashForTesting).LegacyLookupNamespacesWithNames(ctx, []string{testNamespace.Name, testUserNS.Name}) require.NoError(err) require.Len(checkLookupMultiple, 2) - emptyLookup, err := ds.SnapshotReader(secondWritten).LegacyLookupNamespacesWithNames(ctx, []string{"anothername"}) + emptyLookup, err := ds.SnapshotReader(secondWritten, datastore.NoSchemaHashForTesting).LegacyLookupNamespacesWithNames(ctx, []string{"anothername"}) require.NoError(err) require.Empty(emptyLookup) } @@ -167,24 +167,24 @@ func NamespaceDeleteTest(t *testing.T, tester DatastoreTester) { require.NoError(err) require.True(deletedRev.GreaterThan(revision)) - _, _, err = ds.SnapshotReader(deletedRev).LegacyReadNamespaceByName(ctx, testfixtures.DocumentNS.Name) + _, _, err = ds.SnapshotReader(deletedRev, datastore.NoSchemaHashForTesting).LegacyReadNamespaceByName(ctx, testfixtures.DocumentNS.Name) require.ErrorAs(err, &datastore.NamespaceNotFoundError{}) - found, nsCreatedRev, err := ds.SnapshotReader(deletedRev).LegacyReadNamespaceByName(ctx, testfixtures.FolderNS.Name) + found, nsCreatedRev, err := ds.SnapshotReader(deletedRev, datastore.NoSchemaHashForTesting).LegacyReadNamespaceByName(ctx, testfixtures.FolderNS.Name) require.NoError(err) require.NotNil(found) 
require.True(nsCreatedRev.LessThan(deletedRev)) - allNamespaces, err := ds.SnapshotReader(deletedRev).LegacyListAllNamespaces(ctx) + allNamespaces, err := ds.SnapshotReader(deletedRev, datastore.NoSchemaHashForTesting).LegacyListAllNamespaces(ctx) require.NoError(err) for _, ns := range allNamespaces { require.NotEqual(testfixtures.DocumentNS.Name, ns.Definition.Name, "deleted namespace '%s' should not be in namespace list", ns.Definition.Name) } - deletedRevision, err := ds.HeadRevision(ctx) + deletedRevision, _, err := ds.HeadRevision(ctx) require.NoError(err) - iter, err := ds.SnapshotReader(deletedRevision).QueryRelationships(ctx, datastore.RelationshipsFilter{ + iter, err := ds.SnapshotReader(deletedRevision, datastore.NoSchemaHashForTesting).QueryRelationships(ctx, datastore.RelationshipsFilter{ OptionalResourceType: testfixtures.DocumentNS.Name, }, options.WithQueryShape(queryshape.FindResourceOfType)) require.NoError(err) @@ -214,15 +214,15 @@ func NamespaceDeleteNoRelationshipsTest(t *testing.T, tester DatastoreTester) { require.NoError(err) require.True(deletedRev.GreaterThan(revision)) - _, _, err = ds.SnapshotReader(deletedRev).LegacyReadNamespaceByName(ctx, testfixtures.DocumentNS.Name) + _, _, err = ds.SnapshotReader(deletedRev, datastore.NoSchemaHashForTesting).LegacyReadNamespaceByName(ctx, testfixtures.DocumentNS.Name) require.ErrorAs(err, &datastore.NamespaceNotFoundError{}) - found, nsCreatedRev, err := ds.SnapshotReader(deletedRev).LegacyReadNamespaceByName(ctx, testfixtures.FolderNS.Name) + found, nsCreatedRev, err := ds.SnapshotReader(deletedRev, datastore.NoSchemaHashForTesting).LegacyReadNamespaceByName(ctx, testfixtures.FolderNS.Name) require.NoError(err) require.NotNil(found) require.True(nsCreatedRev.LessThan(deletedRev)) - allNamespaces, err := ds.SnapshotReader(deletedRev).LegacyListAllNamespaces(ctx) + allNamespaces, err := ds.SnapshotReader(deletedRev, datastore.NoSchemaHashForTesting).LegacyListAllNamespaces(ctx) require.NoError(err) 
for _, ns := range allNamespaces { require.NotEqual(testfixtures.DocumentNS.Name, ns.Definition.Name, "deleted namespace '%s' should not be in namespace list", ns.Definition.Name) @@ -236,7 +236,7 @@ func NamespaceMultiDeleteTest(t *testing.T, tester DatastoreTester) { ds, revision := testfixtures.StandardDatastoreWithData(rawDS, require.New(t)) ctx := t.Context() - namespaces, err := ds.SnapshotReader(revision).LegacyListAllNamespaces(ctx) + namespaces, err := ds.SnapshotReader(revision, datastore.NoSchemaHashForTesting).LegacyListAllNamespaces(ctx) require.NoError(t, err) nsNames := make([]string, 0, len(namespaces)) @@ -249,7 +249,7 @@ func NamespaceMultiDeleteTest(t *testing.T, tester DatastoreTester) { }) require.NoError(t, err) - namespacesAfterDel, err := ds.SnapshotReader(deletedRev).LegacyListAllNamespaces(ctx) + namespacesAfterDel, err := ds.SnapshotReader(deletedRev, datastore.NoSchemaHashForTesting).LegacyListAllNamespaces(ctx) require.NoError(t, err) require.Empty(t, namespacesAfterDel) } @@ -270,7 +270,7 @@ func EmptyNamespaceDeleteTest(t *testing.T, tester DatastoreTester) { require.NoError(err) require.True(deletedRev.GreaterThan(revision)) - _, _, err = ds.SnapshotReader(deletedRev).LegacyReadNamespaceByName(ctx, testfixtures.UserNS.Name) + _, _, err = ds.SnapshotReader(deletedRev, datastore.NoSchemaHashForTesting).LegacyReadNamespaceByName(ctx, testfixtures.UserNS.Name) require.ErrorAs(err, &datastore.NamespaceNotFoundError{}) } @@ -360,13 +360,13 @@ definition document { // Read the namespace definition back from the datastore and compare. 
nsConfig := compiled.ObjectDefinitions[0] - readNsDef, _, err := ds.SnapshotReader(updatedRevision).LegacyReadNamespaceByName(ctx, nsConfig.Name) + readNsDef, _, err := ds.SnapshotReader(updatedRevision, datastore.NoSchemaHashForTesting).LegacyReadNamespaceByName(ctx, nsConfig.Name) require.NoError(err) testutil.RequireProtoEqual(t, nsConfig, readNsDef, "found changed namespace definition") // Read the caveat back from the datastore and compare. caveatDef := compiled.CaveatDefinitions[0] - readCaveatDef, _, err := ds.SnapshotReader(updatedRevision).LegacyReadCaveatByName(ctx, caveatDef.Name) + readCaveatDef, _, err := ds.SnapshotReader(updatedRevision, datastore.NoSchemaHashForTesting).LegacyReadCaveatByName(ctx, caveatDef.Name) require.NoError(err) testutil.RequireProtoEqual(t, caveatDef, readCaveatDef, "found changed caveat definition") diff --git a/pkg/datastore/test/pagination.go b/pkg/datastore/test/pagination.go index 98f597cd3..b8f2e0bc7 100644 --- a/pkg/datastore/test/pagination.go +++ b/pkg/datastore/test/pagination.go @@ -48,7 +48,7 @@ func OrderingTest(t *testing.T, tester DatastoreTester) { expected := sortedStandardData(tc.resourceType, tc.ordering) // Check the snapshot reader order - iter, err := ds.SnapshotReader(rev).QueryRelationships(ctx, datastore.RelationshipsFilter{ + iter, err := ds.SnapshotReader(rev, datastore.NoSchemaHashForTesting).QueryRelationships(ctx, datastore.RelationshipsFilter{ OptionalResourceType: tc.resourceType, }, options.WithSort(tc.ordering), options.WithQueryShape(queryshape.FindResourceOfType)) @@ -284,7 +284,7 @@ func ReverseQueryFilteredOverMultipleValuesCursorTest(t *testing.T, tester Datas // Issue a reverse query call with a limit. 
for _, sortBy := range []options.SortOrder{options.ByResource, options.BySubject} { t.Run(fmt.Sprintf("SortBy-%d", sortBy), func(t *testing.T) { - reader := ds.SnapshotReader(rev) + reader := ds.SnapshotReader(rev, datastore.NoSchemaHashForTesting) var limit uint64 = 2 var cursor options.Cursor @@ -351,7 +351,7 @@ func ReverseQueryCursorTest(t *testing.T, tester DatastoreTester) { // Issue a reverse query call with a limit. for _, sortBy := range []options.SortOrder{options.ByResource, options.BySubject} { t.Run(fmt.Sprintf("SortBy-%d", sortBy), func(t *testing.T) { - reader := ds.SnapshotReader(rev) + reader := ds.SnapshotReader(rev, datastore.NoSchemaHashForTesting) var limit uint64 = 2 var cursor options.Cursor @@ -396,7 +396,7 @@ func foreachTxType( snapshotRev datastore.Revision, fn func(reader datastore.Reader), ) { - reader := ds.SnapshotReader(snapshotRev) + reader := ds.SnapshotReader(snapshotRev, datastore.NoSchemaHashForTesting) fn(reader) _, _ = ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error { diff --git a/pkg/datastore/test/relationships.go b/pkg/datastore/test/relationships.go index 9a161f7c8..9f1244d3d 100644 --- a/pkg/datastore/test/relationships.go +++ b/pkg/datastore/test/relationships.go @@ -80,7 +80,7 @@ func SimpleTest(t *testing.T, tester DatastoreTester) { _, err = common.WriteRelationships(ctx, ds, tuple.UpdateOperationCreate, testRels...) 
require.Error(t, err) - dsReader := ds.SnapshotReader(lastRevision) + dsReader := ds.SnapshotReader(lastRevision, datastore.NoSchemaHashForTesting) for _, relToFind := range testRels { relSubject := relToFind.Subject @@ -229,7 +229,7 @@ func SimpleTest(t *testing.T, tester DatastoreTester) { // Verify that it does not show up at the new revision tRequire.NoRelationshipExists(ctx, testRels[0], deletedAt) - alreadyDeletedIter, err := ds.SnapshotReader(deletedAt).QueryRelationships( + alreadyDeletedIter, err := ds.SnapshotReader(deletedAt, datastore.NoSchemaHashForTesting).QueryRelationships( ctx, datastore.RelationshipsFilter{ OptionalResourceType: testRels[0].Resource.ObjectType, @@ -289,9 +289,9 @@ func ObjectIDsTest(t *testing.T, tester DatastoreTester) { require.NoError(err) // Read it back - rev, err := ds.HeadRevision(ctx) + rev, _, err := ds.HeadRevision(ctx) require.NoError(err) - iter, err := ds.SnapshotReader(rev).QueryRelationships(ctx, datastore.RelationshipsFilter{ + iter, err := ds.SnapshotReader(rev, datastore.NoSchemaHashForTesting).QueryRelationships(ctx, datastore.RelationshipsFilter{ OptionalResourceType: testResourceNamespace, OptionalResourceIds: []string{tc}, }, options.WithQueryShape(queryshape.Varying)) @@ -1108,7 +1108,7 @@ func DeleteRelationshipsWithVariousFiltersTest(t *testing.T, tester DatastoreTes require.NoError(err) } - writtenRev, err := ds.HeadRevision(ctx) + writtenRev, _, err := ds.HeadRevision(ctx) require.NoError(err) var delLimit *uint64 @@ -1125,13 +1125,13 @@ func DeleteRelationshipsWithVariousFiltersTest(t *testing.T, tester DatastoreTes require.NoError(err) // Read the updated relationships and ensure no matching relationships are found. 
- headRev, err := ds.HeadRevision(ctx) + headRev, _, err := ds.HeadRevision(ctx) require.NoError(err) filter, err := datastore.RelationshipsFilterFromPublicFilter(tc.filter) require.NoError(err) - reader := ds.SnapshotReader(headRev) + reader := ds.SnapshotReader(headRev, datastore.NoSchemaHashForTesting) iter, err := reader.QueryRelationships(ctx, filter, options.WithQueryShape(queryshape.Varying)) require.NoError(err) @@ -1164,7 +1164,7 @@ func DeleteRelationshipsWithVariousFiltersTest(t *testing.T, tester DatastoreTes // Ensure the initial relationships are still present at the previous revision. allInitialRelationships := mapz.NewSet[string]() - olderReader := ds.SnapshotReader(writtenRev) + olderReader := ds.SnapshotReader(writtenRev, datastore.NoSchemaHashForTesting) for _, resourceType := range resourceTypes.AsSlice() { iter, err := olderReader.QueryRelationships(ctx, datastore.RelationshipsFilter{ OptionalResourceType: resourceType, @@ -1775,10 +1775,10 @@ func QueryRelationshipsWithVariousFiltersTest(t *testing.T, tester DatastoreTest require.NoError(err) } - headRev, err := ds.HeadRevision(ctx) + headRev, _, err := ds.HeadRevision(ctx) require.NoError(err) - reader := ds.SnapshotReader(headRev) + reader := ds.SnapshotReader(headRev, datastore.NoSchemaHashForTesting) iter, err := reader.QueryRelationships(ctx, tc.filter, options.WithSkipCaveats(tc.withoutCaveats), options.WithSkipExpiration(tc.withoutExpiration), options.WithQueryShape(queryshape.Varying)) require.NoError(err) @@ -2214,7 +2214,7 @@ func BulkDeleteRelationshipsTest(t *testing.T, tester DatastoreTester) { // Ensure the relationships were removed. 
t.Log(time.Now(), "starting check") - reader := ds.SnapshotReader(deletedRev) + reader := ds.SnapshotReader(deletedRev, datastore.NoSchemaHashForTesting) iter, err := reader.QueryRelationships(ctx, datastore.RelationshipsFilter{ OptionalResourceType: testResourceNamespace, OptionalResourceRelation: testReaderRelation, @@ -2243,10 +2243,10 @@ func ensureNotReverseRelationships(ctx context.Context, require *require.Asserti } func ensureReverseRelationshipsStatus(ctx context.Context, require *require.Assertions, ds datastore.Datastore, rels []tuple.Relationship, mustExist bool) { - headRev, err := ds.HeadRevision(ctx) + headRev, _, err := ds.HeadRevision(ctx) require.NoError(err) - reader := ds.SnapshotReader(headRev) + reader := ds.SnapshotReader(headRev, datastore.NoSchemaHashForTesting) for _, rel := range rels { filter := datastore.SubjectRelationFilter{ @@ -2289,10 +2289,10 @@ func ensureNotRelationships(ctx context.Context, require *require.Assertions, ds } func ensureRelationshipsStatus(ctx context.Context, require *require.Assertions, ds datastore.Datastore, rels []tuple.Relationship, mustExist bool) { - headRev, err := ds.HeadRevision(ctx) + headRev, _, err := ds.HeadRevision(ctx) require.NoError(err) - reader := ds.SnapshotReader(headRev) + reader := ds.SnapshotReader(headRev, datastore.NoSchemaHashForTesting) for _, rel := range rels { iter, err := reader.QueryRelationships(ctx, datastore.RelationshipsFilter{ @@ -2325,10 +2325,10 @@ func ensureRelationshipsStatus(ctx context.Context, require *require.Assertions, } func ensureRelationshipWithFilter(ctx context.Context, require *require.Assertions, ds datastore.Datastore, filter datastore.RelationshipsFilter, rel tuple.Relationship) { - headRev, err := ds.HeadRevision(ctx) + headRev, _, err := ds.HeadRevision(ctx) require.NoError(err) - reader := ds.SnapshotReader(headRev) + reader := ds.SnapshotReader(headRev, datastore.NoSchemaHashForTesting) iter, err := reader.QueryRelationships(ctx, filter, 
options.WithQueryShape(queryshape.Varying)) require.NoError(err) @@ -2342,10 +2342,10 @@ func ensureRelationshipWithFilter(ctx context.Context, require *require.Assertio } func ensureNoRelationshipWithFilter(ctx context.Context, require *require.Assertions, ds datastore.Datastore, filter datastore.RelationshipsFilter) { - headRev, err := ds.HeadRevision(ctx) + headRev, _, err := ds.HeadRevision(ctx) require.NoError(err) - reader := ds.SnapshotReader(headRev) + reader := ds.SnapshotReader(headRev, datastore.NoSchemaHashForTesting) iter, err := reader.QueryRelationships(ctx, filter, options.WithQueryShape(queryshape.Varying)) require.NoError(err) @@ -2357,10 +2357,10 @@ func ensureNoRelationshipWithFilter(ctx context.Context, require *require.Assert } func countRels(ctx context.Context, require *require.Assertions, ds datastore.Datastore, resourceType string) int { - headRev, err := ds.HeadRevision(ctx) + headRev, _, err := ds.HeadRevision(ctx) require.NoError(err) - reader := ds.SnapshotReader(headRev) + reader := ds.SnapshotReader(headRev, datastore.NoSchemaHashForTesting) iter, err := reader.QueryRelationships(ctx, datastore.RelationshipsFilter{ OptionalResourceType: resourceType, diff --git a/pkg/datastore/test/revisions.go b/pkg/datastore/test/revisions.go index 4d7b6f756..369422772 100644 --- a/pkg/datastore/test/revisions.go +++ b/pkg/datastore/test/revisions.go @@ -39,7 +39,7 @@ func RevisionQuantizationTest(t *testing.T, tester DatastoreTester) { require.NoError(err) ctx := t.Context() - veryFirstRevision, err := ds.OptimizedRevision(ctx) + veryFirstRevision, _, err := ds.OptimizedRevision(ctx) require.NoError(err) postSetupRevision := setupDatastore(ds, require) @@ -55,7 +55,7 @@ func RevisionQuantizationTest(t *testing.T, tester DatastoreTester) { require.True(writtenAt.GreaterThan(postSetupRevision)) // Get the new now revision - nowRevision, err := ds.HeadRevision(ctx) + nowRevision, _, err := ds.HeadRevision(ctx) require.NoError(err) // Let the 
quantization window expire @@ -63,7 +63,7 @@ func RevisionQuantizationTest(t *testing.T, tester DatastoreTester) { // Now we should ONLY get revisions later than the now revision for start := time.Now(); time.Since(start) < 10*time.Millisecond; { - testRevision, err := ds.OptimizedRevision(ctx) + testRevision, _, err := ds.OptimizedRevision(ctx) require.NoError(err) require.True(nowRevision.LessThan(testRevision) || nowRevision.Equal(testRevision)) } @@ -171,7 +171,7 @@ func RevisionGCTest(t *testing.T, tester DatastoreTester) { require.NoError(ds.CheckRevision(ctx, previousRev), "expected latest write revision to be within GC window") - head, err := ds.HeadRevision(ctx) + head, _, err := ds.HeadRevision(ctx) require.NoError(err) require.NoError(ds.CheckRevision(ctx, head), "expected head revision to be valid in GC Window") @@ -195,23 +195,23 @@ func RevisionGCTest(t *testing.T, tester DatastoreTester) { // require.Error(ds.CheckRevision(ctx, head), "expected head revision to be valid if out of GC window") // // latest state of the system is invalid if head revision is out of GC window - // _, _, err = ds.SnapshotReader(head).LegacyReadNamespaceByName(ctx, "foo/bar") + // _, _, err = ds.SnapshotReader(head, datastore.NoSchemaHashForTesting).LegacyReadNamespaceByName(ctx, "foo/bar") // require.Error(err, "expected previously written schema to exist at out-of-GC window head") // check freshly fetched head revision is valid after GC window elapsed - head, err = ds.HeadRevision(ctx) + head, _, err = ds.HeadRevision(ctx) require.NoError(err) // check that we can read a caveat whose revision has been garbage collectged - _, _, err = ds.SnapshotReader(head).LegacyReadCaveatByName(ctx, testCaveat.Name) + _, _, err = ds.SnapshotReader(head, datastore.NoSchemaHashForTesting).LegacyReadCaveatByName(ctx, testCaveat.Name) require.NoError(err, "expected previously written caveat should exist at head") // check that we can read the namespace which had its revision garbage 
collected - _, _, err = ds.SnapshotReader(head).LegacyReadNamespaceByName(ctx, "foo/createdtxgc") + _, _, err = ds.SnapshotReader(head, datastore.NoSchemaHashForTesting).LegacyReadNamespaceByName(ctx, "foo/createdtxgc") require.NoError(err, "expected previously written namespace should exist at head") // state of the system is also consistent at a recent call to head - _, _, err = ds.SnapshotReader(head).LegacyReadNamespaceByName(ctx, "foo/bar") + _, _, err = ds.SnapshotReader(head, datastore.NoSchemaHashForTesting).LegacyReadNamespaceByName(ctx, "foo/bar") require.NoError(err, "expected previously written schema to exist at head") // and that recent call to head revision is also valid, even after a GC window cycle without writes elapsed @@ -242,7 +242,7 @@ func CheckRevisionsTest(t *testing.T, tester DatastoreTester) { require.NoError(err) require.NoError(ds.CheckRevision(ctx, writtenRev), "expected written revision to be valid in GC Window") - head, err := ds.HeadRevision(ctx) + head, _, err := ds.HeadRevision(ctx) require.NoError(err) // Check the head revision is valid @@ -259,7 +259,7 @@ func CheckRevisionsTest(t *testing.T, tester DatastoreTester) { require.NoError(ds.CheckRevision(ctx, head), "expected previous revision to be valid in GC Window") // Get the updated head revision. - head, err = ds.HeadRevision(ctx) + head, _, err = ds.HeadRevision(ctx) require.NoError(err) // Check the new head revision is valid. 
@@ -277,7 +277,7 @@ func SequentialRevisionsTest(t *testing.T, tester DatastoreTester) { var previous datastore.Revision for range 50 { - head, err := ds.HeadRevision(ctx) + head, _, err := ds.HeadRevision(ctx) require.NoError(err) require.NoError(ds.CheckRevision(ctx, head), "expected head revision to be valid in GC Window") @@ -301,7 +301,7 @@ func ConcurrentRevisionsTest(t *testing.T, tester DatastoreTester) { var wg sync.WaitGroup wg.Add(10) - startingRev, err := ds.HeadRevision(ctx) + startingRev, _, err := ds.HeadRevision(ctx) require.NoError(err) errCh := make(chan error, 10*5) @@ -311,7 +311,7 @@ func ConcurrentRevisionsTest(t *testing.T, tester DatastoreTester) { defer wg.Done() for i := 0; i < 5; i++ { - head, err := ds.HeadRevision(ctx) + head, _, err := ds.HeadRevision(ctx) if err != nil { errCh <- fmt.Errorf("HeadRevision error: %w", err) continue diff --git a/pkg/datastore/test/unifiedschema.go b/pkg/datastore/test/unifiedschema.go new file mode 100644 index 000000000..20db09cf5 --- /dev/null +++ b/pkg/datastore/test/unifiedschema.go @@ -0,0 +1,1019 @@ +package test + +import ( + "context" + "crypto/rand" + "crypto/sha256" + "encoding/hex" + "fmt" + "sort" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/datastore/options" + "github.com/authzed/spicedb/pkg/diff" + ns "github.com/authzed/spicedb/pkg/namespace" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/schemadsl/compiler" + "github.com/authzed/spicedb/pkg/schemadsl/generator" +) + +var ( + testSchemaDefinitions = []compiler.SchemaDefinition{ + ns.Namespace("user"), + ns.Namespace("document", + ns.MustRelation("viewer", nil, ns.AllowedRelation("user", "...")), + ns.MustRelation("editor", nil, ns.AllowedRelation("user", "...")), + ), + } + + updatedSchemaDefinitions = []compiler.SchemaDefinition{ + ns.Namespace("user"), + ns.Namespace("document", + ns.MustRelation("viewer", 
nil, ns.AllowedRelation("user", "...")), + ns.MustRelation("editor", nil, ns.AllowedRelation("user", "...")), + ns.MustRelation("owner", nil, ns.AllowedRelation("user", "...")), + ), + } +) + +// computeExpectedSchemaHash computes the expected schema hash by sorting definitions +// by name (matching datastore behavior) and then hashing the generated schema text. +func computeExpectedSchemaHash(t *testing.T, definitions []compiler.SchemaDefinition) string { + // Sort definitions by name for consistent ordering (matches datastore behavior) + sortedDefs := make([]compiler.SchemaDefinition, len(definitions)) + copy(sortedDefs, definitions) + sort.Slice(sortedDefs, func(i, j int) bool { + return sortedDefs[i].GetName() < sortedDefs[j].GetName() + }) + + // Generate schema text from sorted definitions + schemaText, _, err := generator.GenerateSchema(sortedDefs) + require.NoError(t, err) + + // Compute SHA256 hash + hashBytes := sha256.Sum256([]byte(schemaText)) + return hex.EncodeToString(hashBytes[:]) +} + +// requireSchemasEqual compares two schema texts semantically using the diff engine. +// This allows schemas to be equivalent even if definitions are in different order. 
+func requireSchemasEqual(t *testing.T, expected, actual string) { + require := require.New(t) + + // Compile both schemas + expectedCompiled, err := compiler.Compile(compiler.InputSchema{ + Source: "expected", + SchemaString: expected, + }, compiler.AllowUnprefixedObjectType()) + require.NoError(err, "failed to compile expected schema") + + actualCompiled, err := compiler.Compile(compiler.InputSchema{ + Source: "actual", + SchemaString: actual, + }, compiler.AllowUnprefixedObjectType()) + require.NoError(err, "failed to compile actual schema") + + // Create diffable schemas + expectedDiffable := diff.NewDiffableSchemaFromCompiledSchema(expectedCompiled) + actualDiffable := diff.NewDiffableSchemaFromCompiledSchema(actualCompiled) + + // Compute diff + schemaDiff, err := diff.DiffSchemas(expectedDiffable, actualDiffable, nil) + require.NoError(err, "failed to diff schemas") + + // Check that there are no differences + require.Empty(schemaDiff.AddedNamespaces, "unexpected added namespaces") + require.Empty(schemaDiff.RemovedNamespaces, "unexpected removed namespaces") + require.Empty(schemaDiff.AddedCaveats, "unexpected added caveats") + require.Empty(schemaDiff.RemovedCaveats, "unexpected removed caveats") + require.Empty(schemaDiff.ChangedNamespaces, "unexpected changed namespaces") + require.Empty(schemaDiff.ChangedCaveats, "unexpected changed caveats") +} + +// UnifiedSchemaTest tests basic unified schema storage functionality +// Note: This test assumes the datastore's readers and transactions support DualSchema interfaces. +// The specific schema mode (read legacy, write both, etc.) is configured at datastore +// initialization time by the specific datastore tests. 
+func UnifiedSchemaTest(t *testing.T, tester DatastoreTester) { + require := require.New(t) + ctx := context.Background() + + ds, err := tester.New(0, veryLargeGCInterval, veryLargeGCWindow, 1) + require.NoError(err) + defer ds.Close() + + // Check if datastore readers and transactions support schema operations + headRev, _, err := ds.HeadRevision(ctx) + require.NoError(err) + + reader := ds.SnapshotReader(headRev, datastore.NoSchemaHashForTesting) + _, err = reader.SchemaReader() + require.NoError(err, "datastore reader must provide SchemaReader") + + // Check if transaction supports schema writer + var schemaWriterErr error + _, _ = ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error { + _, schemaWriterErr = rwt.SchemaWriter() + return nil + }) + require.NoError(schemaWriterErr, "datastore transaction must provide SchemaWriter") + + // Get starting revision + startRevision, _, err := ds.HeadRevision(ctx) + require.NoError(err) + + // Generate schema text + schemaText, _, err := generator.GenerateSchema(testSchemaDefinitions) + require.NoError(err) + + // Convert to datastore.SchemaDefinition by casting each element + defs := make([]datastore.SchemaDefinition, 0, len(testSchemaDefinitions)) + for _, def := range testSchemaDefinitions { + defs = append(defs, def.(datastore.SchemaDefinition)) + } + + // Write schema + writtenRev, err := ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error { + schemaWriter, err := rwt.SchemaWriter() + if err != nil { + return err + } + return schemaWriter.WriteSchema(ctx, defs, schemaText, nil) + }) + require.NoError(err) + require.True(startRevision.LessThan(writtenRev)) + + // Read schema using SchemaReader + reader = ds.SnapshotReader(writtenRev, datastore.NoSchemaHashForTesting) + schemaReader, err := reader.SchemaReader() + require.NoError(err) + + // Verify schema text (using semantic comparison to allow for different definition order) + readSchemaText, err 
:= schemaReader.SchemaText() + require.NoError(err) + requireSchemasEqual(t, schemaText, readSchemaText) + + // Verify namespace definitions + typeDefs, err := schemaReader.ListAllTypeDefinitions(ctx) + require.NoError(err) + require.Len(typeDefs, 2) + + userFound := false + docFound := false + for _, def := range typeDefs { + switch def.Definition.Name { + case "user": + userFound = true + require.NotNil(def.LastWrittenRevision) + case "document": + docFound = true + require.NotNil(def.LastWrittenRevision) + } + } + require.True(userFound, "user namespace should be found") + require.True(docFound, "document namespace should be found") + + // Lookup individual namespace + docDef, found, err := schemaReader.LookupTypeDefByName(ctx, "document") + require.NoError(err) + require.True(found) + require.Equal("document", docDef.Definition.Name) +} + +// UnifiedSchemaUpdateTest tests updating schemas +func UnifiedSchemaUpdateTest(t *testing.T, tester DatastoreTester) { + require := require.New(t) + ctx := context.Background() + + ds, err := tester.New(0, veryLargeGCInterval, veryLargeGCWindow, 1) + require.NoError(err) + defer ds.Close() + + // Check if datastore supports SchemaReader + headRev, _, err := ds.HeadRevision(ctx) + require.NoError(err) + + reader := ds.SnapshotReader(headRev, datastore.NoSchemaHashForTesting) + _, err = reader.SchemaReader() + require.NoError(err, "datastore reader must provide SchemaReader") + + // Write initial schema + schemaText, _, err := generator.GenerateSchema(testSchemaDefinitions) + require.NoError(err) + + defs := make([]datastore.SchemaDefinition, 0, len(testSchemaDefinitions)) + for _, def := range testSchemaDefinitions { + defs = append(defs, def.(datastore.SchemaDefinition)) + } + + firstRev, err := ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error { + schemaWriter, err := rwt.SchemaWriter() + if err != nil { + return err + } + return schemaWriter.WriteSchema(ctx, defs, schemaText, nil) + }) 
+ require.NoError(err) + + // Update schema with additional relation + updatedText, _, err := generator.GenerateSchema(updatedSchemaDefinitions) + require.NoError(err) + + updatedDefs := make([]datastore.SchemaDefinition, 0, len(updatedSchemaDefinitions)) + for _, def := range updatedSchemaDefinitions { + updatedDefs = append(updatedDefs, def.(datastore.SchemaDefinition)) + } + + secondRev, err := ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error { + schemaWriter, err := rwt.SchemaWriter() + if err != nil { + return err + } + return schemaWriter.WriteSchema(ctx, updatedDefs, updatedText, nil) + }) + require.NoError(err) + require.True(secondRev.GreaterThan(firstRev)) + + // Read at first revision - should see old schema + reader1 := ds.SnapshotReader(firstRev, datastore.NoSchemaHashForTesting) + schemaReader1, err := reader1.SchemaReader() + require.NoError(err) + + typeDefs1, err := schemaReader1.ListAllTypeDefinitions(ctx) + require.NoError(err) + require.Len(typeDefs1, 2) + + docDef1, found, err := schemaReader1.LookupTypeDefByName(ctx, "document") + require.NoError(err) + require.True(found) + require.Len(docDef1.Definition.Relation, 2, "should have 2 relations at first revision") + + // Read at second revision - should see updated schema + reader2 := ds.SnapshotReader(secondRev, datastore.NoSchemaHashForTesting) + schemaReader2, err := reader2.SchemaReader() + require.NoError(err) + + docDef2, found, err := schemaReader2.LookupTypeDefByName(ctx, "document") + require.NoError(err) + require.True(found) + require.Len(docDef2.Definition.Relation, 3, "should have 3 relations at second revision") + + // Verify owner relation exists in updated schema + ownerFound := false + for _, rel := range docDef2.Definition.Relation { + if rel.Name == "owner" { + ownerFound = true + break + } + } + require.True(ownerFound, "owner relation should exist in updated schema") +} + +// UnifiedSchemaRevisionTest verifies that schema revisions are 
tracked correctly +func UnifiedSchemaRevisionTest(t *testing.T, tester DatastoreTester) { + require := require.New(t) + ctx := context.Background() + + ds, err := tester.New(0, veryLargeGCInterval, veryLargeGCWindow, 1) + require.NoError(err) + defer ds.Close() + + // Check if datastore supports SchemaReader + headRev, _, err := ds.HeadRevision(ctx) + require.NoError(err) + + reader := ds.SnapshotReader(headRev, datastore.NoSchemaHashForTesting) + _, err = reader.SchemaReader() + require.NoError(err, "datastore reader must provide SchemaReader") + + // Write schema + schemaText, _, err := generator.GenerateSchema(testSchemaDefinitions) + require.NoError(err) + + defs := make([]datastore.SchemaDefinition, 0, len(testSchemaDefinitions)) + for _, def := range testSchemaDefinitions { + defs = append(defs, def.(datastore.SchemaDefinition)) + } + + writtenRev, err := ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error { + schemaWriter, err := rwt.SchemaWriter() + if err != nil { + return err + } + return schemaWriter.WriteSchema(ctx, defs, schemaText, nil) + }) + require.NoError(err) + + // Read schema and verify all definitions have consistent revisions + reader = ds.SnapshotReader(writtenRev, datastore.NoSchemaHashForTesting) + schemaReader, err := reader.SchemaReader() + require.NoError(err) + + typeDefs, err := schemaReader.ListAllTypeDefinitions(ctx) + require.NoError(err) + require.Len(typeDefs, 2) + + // All definitions should have valid revisions + for _, def := range typeDefs { + require.NotNil(def.LastWrittenRevision, + "definition revision should be set") + } +} + +// UnifiedSchemaWithCaveatsTest tests unified schema with caveats +func UnifiedSchemaWithCaveatsTest(t *testing.T, tester DatastoreTester) { + require := require.New(t) + ctx := context.Background() + + ds, err := tester.New(0, veryLargeGCInterval, veryLargeGCWindow, 1) + require.NoError(err) + defer ds.Close() + + // Check if datastore supports SchemaReader + 
headRev, _, err := ds.HeadRevision(ctx) + require.NoError(err) + + reader := ds.SnapshotReader(headRev, datastore.NoSchemaHashForTesting) + _, err = reader.SchemaReader() + require.NoError(err, "datastore reader must provide SchemaReader") + + // Define schema with caveat + schemaTextWithCaveat := `caveat is_allowed(allowed bool) { + allowed +} + +definition user {} + +definition document { + relation viewer: user with is_allowed +}` + + compiled, err := compiler.Compile(compiler.InputSchema{ + Source: "schema", + SchemaString: schemaTextWithCaveat, + }, compiler.AllowUnprefixedObjectType()) + require.NoError(err) + + defs := make([]datastore.SchemaDefinition, 0, len(compiled.ObjectDefinitions)+len(compiled.CaveatDefinitions)) + for _, objDef := range compiled.ObjectDefinitions { + defs = append(defs, objDef) + } + for _, caveatDef := range compiled.CaveatDefinitions { + defs = append(defs, caveatDef) + } + + // Write schema + writtenRev, err := ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error { + schemaWriter, err := rwt.SchemaWriter() + if err != nil { + return err + } + return schemaWriter.WriteSchema(ctx, defs, schemaTextWithCaveat, nil) + }) + require.NoError(err) + + // Read schema + reader = ds.SnapshotReader(writtenRev, datastore.NoSchemaHashForTesting) + schemaReader, err := reader.SchemaReader() + require.NoError(err) + + // Verify caveat + caveats, err := schemaReader.ListAllCaveatDefinitions(ctx) + require.NoError(err) + require.Len(caveats, 1) + require.Equal("is_allowed", caveats[0].Definition.Name) + require.NotNil(caveats[0].LastWrittenRevision) + + // Lookup caveat by name + caveat, found, err := schemaReader.LookupCaveatDefByName(ctx, "is_allowed") + require.NoError(err) + require.True(found) + require.Equal("is_allowed", caveat.Definition.Name) + + // Verify namespace definitions + typeDefs, err := schemaReader.ListAllTypeDefinitions(ctx) + require.NoError(err) + require.Len(typeDefs, 2) +} + +// 
UnifiedSchemaEmptyTest tests reading when no schema exists +func UnifiedSchemaEmptyTest(t *testing.T, tester DatastoreTester) { + require := require.New(t) + ctx := context.Background() + + ds, err := tester.New(0, veryLargeGCInterval, veryLargeGCWindow, 1) + require.NoError(err) + defer ds.Close() + + // Check if datastore supports SchemaReader + headRev, _, err := ds.HeadRevision(ctx) + require.NoError(err) + + reader := ds.SnapshotReader(headRev, datastore.NoSchemaHashForTesting) + _, err = reader.SchemaReader() + require.NoError(err, "datastore reader must provide SchemaReader") + + startRevision, _, err := ds.HeadRevision(ctx) + require.NoError(err) + + reader = ds.SnapshotReader(startRevision, datastore.NoSchemaHashForTesting) + schemaReader, err := reader.SchemaReader() + require.NoError(err) + + // Schema text should be empty or return error - we don't check the result + // since different datastores may behave differently with no schema + _, _ = schemaReader.SchemaText() + + // Should have no type definitions initially + typeDefs, err := schemaReader.ListAllTypeDefinitions(ctx) + require.NoError(err) + require.Empty(typeDefs) + + // Should have no caveat definitions initially + caveats, err := schemaReader.ListAllCaveatDefinitions(ctx) + require.NoError(err) + require.Empty(caveats) +} + +// UnifiedSchemaLookupTest tests lookup operations +func UnifiedSchemaLookupTest(t *testing.T, tester DatastoreTester) { + require := require.New(t) + ctx := context.Background() + + ds, err := tester.New(0, veryLargeGCInterval, veryLargeGCWindow, 1) + require.NoError(err) + defer ds.Close() + + // Check if datastore supports SchemaReader + headRev, _, err := ds.HeadRevision(ctx) + require.NoError(err) + + reader := ds.SnapshotReader(headRev, datastore.NoSchemaHashForTesting) + _, err = reader.SchemaReader() + require.NoError(err, "datastore reader must provide SchemaReader") + + // Write schema + schemaText, _, err := generator.GenerateSchema(testSchemaDefinitions) + 
require.NoError(err) + + defs := make([]datastore.SchemaDefinition, 0, len(testSchemaDefinitions)) + for _, def := range testSchemaDefinitions { + defs = append(defs, def.(datastore.SchemaDefinition)) + } + + writtenRev, err := ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error { + schemaWriter, err := rwt.SchemaWriter() + if err != nil { + return err + } + return schemaWriter.WriteSchema(ctx, defs, schemaText, nil) + }) + require.NoError(err) + + reader = ds.SnapshotReader(writtenRev, datastore.NoSchemaHashForTesting) + schemaReader, err := reader.SchemaReader() + require.NoError(err) + + // Test ListAllSchemaDefinitions + allDefs, err := schemaReader.ListAllSchemaDefinitions(ctx) + require.NoError(err) + require.Len(allDefs, 2) + require.Contains(allDefs, "user") + require.Contains(allDefs, "document") + + // Test looking up non-existent namespace + _, found, err := schemaReader.LookupTypeDefByName(ctx, "nonexistent") + require.NoError(err) + require.False(found) + + // Test looking up non-existent caveat + _, found, err = schemaReader.LookupCaveatDefByName(ctx, "nonexistent") + require.NoError(err) + require.False(found) +} + +// UnifiedSchemaValidationTest tests that stored schemas are validated +func UnifiedSchemaValidationTest(t *testing.T, tester DatastoreTester) { + require := require.New(t) + ctx := context.Background() + + ds, err := tester.New(0, veryLargeGCInterval, veryLargeGCWindow, 1) + require.NoError(err) + defer ds.Close() + + // Check if transaction supports SchemaWriter + var schemaWriterErr error + _, _ = ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error { + _, schemaWriterErr = rwt.SchemaWriter() + return nil + }) + require.NoError(schemaWriterErr, "datastore transaction must provide SchemaWriter") + + // Write a simple valid schema to verify the writer works + schemaText := "definition user {}" + compiled, err := compiler.Compile(compiler.InputSchema{ + Source: 
"schema", + SchemaString: schemaText, + }, compiler.AllowUnprefixedObjectType()) + require.NoError(err) + + defs := make([]datastore.SchemaDefinition, 0, len(compiled.ObjectDefinitions)) + for _, objDef := range compiled.ObjectDefinitions { + defs = append(defs, objDef) + } + + _, err = ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error { + schemaWriter, err := rwt.SchemaWriter() + if err != nil { + return err + } + return schemaWriter.WriteSchema(ctx, defs, schemaText, nil) + }) + require.NoError(err) +} + +// UnifiedSchemaMultipleIterationsTest tests writing and reading randomly generated +// schema data across multiple iterations, verifying that older revisions remain readable. +// This test uses schema diff for comparison, so definition order doesn't matter. +func UnifiedSchemaMultipleIterationsTest(t *testing.T, tester DatastoreTester) { + require := require.New(t) + ctx := context.Background() + + ds, err := tester.New(0, veryLargeGCInterval, veryLargeGCWindow, 1) + require.NoError(err) + defer ds.Close() + + // Check if datastore supports SchemaReader + headRev, _, err := ds.HeadRevision(ctx) + require.NoError(err) + + reader := ds.SnapshotReader(headRev, datastore.NoSchemaHashForTesting) + _, err = reader.SchemaReader() + require.NoError(err, "datastore reader must provide SchemaReader") + + const numIterations = 5 + + // Store all revisions and their corresponding compiled schemas + type revisionData struct { + revision datastore.Revision + compiled *compiler.CompiledSchema + } + revisions := make([]revisionData, 0, numIterations) + + // Write schemas in a loop + for i := 0; i < numIterations; i++ { + // Generate random bytes for schema hash + randomBytes := make([]byte, 16) + _, err := rand.Read(randomBytes) + require.NoError(err) + + // Create a unique schema for this iteration + schemaText := fmt.Sprintf(`definition user {} + +definition resource_%d { + relation viewer: user + relation editor: user + permission view 
= viewer + editor +}`, i) + + // Compile the schema + compiled, err := compiler.Compile(compiler.InputSchema{ + Source: "schema", + SchemaString: schemaText, + }, compiler.AllowUnprefixedObjectType()) + require.NoError(err) + + // Convert to datastore.SchemaDefinition + defs := make([]datastore.SchemaDefinition, 0, len(compiled.ObjectDefinitions)) + for _, objDef := range compiled.ObjectDefinitions { + defs = append(defs, objDef) + } + + // Write schema + writtenRev, err := ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error { + schemaWriter, err := rwt.SchemaWriter() + if err != nil { + return err + } + return schemaWriter.WriteSchema(ctx, defs, schemaText, nil) + }) + require.NoError(err) + + // Store revision data + revisions = append(revisions, revisionData{ + revision: writtenRev, + compiled: compiled, + }) + + t.Logf("Iteration %d: Written schema at revision %v", i, writtenRev) + } + + // First, verify the latest schema can be read without AS OF SYSTEM TIME + latestRev, _, err := ds.HeadRevision(ctx) + require.NoError(err) + latestReader := ds.SnapshotReader(latestRev, datastore.NoSchemaHashForTesting) + latestSchemaReader, err := latestReader.SchemaReader() + require.NoError(err) + latestSchemaText, err := latestSchemaReader.SchemaText() + require.NoError(err) + require.NotEmpty(latestSchemaText, "Latest schema text should not be empty") + t.Logf("Latest schema (at %v) read successfully", latestRev) + + // Verify all revisions can be read back + for i, revData := range revisions { + t.Logf("Verifying iteration %d at revision %v", i, revData.revision) + + reader := ds.SnapshotReader(revData.revision, datastore.NoSchemaHashForTesting) + schemaReader, err := reader.SchemaReader() + require.NoError(err, "Failed to get schema reader for iteration %d", i) + + // Read the schema text and compile it + readSchemaText, err := schemaReader.SchemaText() + require.NoError(err, "Failed to read schema text for iteration %d", i) + 
require.NotEmpty(readSchemaText, "Schema text should not be empty for iteration %d", i) + + // Compile the read schema + readCompiled, err := compiler.Compile(compiler.InputSchema{ + Source: "schema", + SchemaString: readSchemaText, + }, compiler.AllowUnprefixedObjectType()) + require.NoError(err, "Failed to compile read schema for iteration %d", i) + + // Use schema diff to compare (order doesn't matter) + expectedDiffable := diff.NewDiffableSchemaFromCompiledSchema(revData.compiled) + actualDiffable := diff.NewDiffableSchemaFromCompiledSchema(readCompiled) + + schemaDiff, err := diff.DiffSchemas(expectedDiffable, actualDiffable, nil) + require.NoError(err, "Failed to diff schemas for iteration %d", i) + + // Verify no differences + require.Empty(schemaDiff.AddedNamespaces, "Unexpected added namespaces at iteration %d", i) + require.Empty(schemaDiff.RemovedNamespaces, "Unexpected removed namespaces at iteration %d", i) + require.Empty(schemaDiff.AddedCaveats, "Unexpected added caveats at iteration %d", i) + require.Empty(schemaDiff.RemovedCaveats, "Unexpected removed caveats at iteration %d", i) + require.Empty(schemaDiff.ChangedNamespaces, "Unexpected changed namespaces at iteration %d", i) + require.Empty(schemaDiff.ChangedCaveats, "Unexpected changed caveats at iteration %d", i) + + t.Logf("Iteration %d: Successfully verified schema at revision %v", i, revData.revision) + } +} + +// UnifiedSchemaLookupByNamesTest tests the LookupTypeDefinitionsByNames and LookupCaveatDefinitionsByNames methods +func UnifiedSchemaLookupByNamesTest(t *testing.T, tester DatastoreTester) { + require := require.New(t) + ctx := context.Background() + + ds, err := tester.New(0, veryLargeGCInterval, veryLargeGCWindow, 1) + require.NoError(err) + defer ds.Close() + + // Check if datastore supports SchemaReader + headRev, _, err := ds.HeadRevision(ctx) + require.NoError(err) + + reader := ds.SnapshotReader(headRev, datastore.NoSchemaHashForTesting) + _, err = reader.SchemaReader() + 
require.NoError(err, "datastore reader must provide SchemaReader") + + // Define schema with multiple namespaces and caveats + schemaText := `caveat is_admin(is_admin bool) { + is_admin +} + +caveat is_owner(is_owner bool) { + is_owner +} + +caveat has_permission(has_permission bool) { + has_permission +} + +definition user {} + +definition document { + relation viewer: user with is_admin + relation editor: user with is_owner + relation owner: user +} + +definition folder { + relation viewer: user + relation owner: user with has_permission +} + +definition organization { + relation member: user +}` + + compiled, err := compiler.Compile(compiler.InputSchema{ + Source: "schema", + SchemaString: schemaText, + }, compiler.AllowUnprefixedObjectType()) + require.NoError(err) + + defs := make([]datastore.SchemaDefinition, 0, len(compiled.ObjectDefinitions)+len(compiled.CaveatDefinitions)) + for _, objDef := range compiled.ObjectDefinitions { + defs = append(defs, objDef) + } + for _, caveatDef := range compiled.CaveatDefinitions { + defs = append(defs, caveatDef) + } + + // Write schema + writtenRev, err := ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error { + schemaWriter, err := rwt.SchemaWriter() + if err != nil { + return err + } + return schemaWriter.WriteSchema(ctx, defs, schemaText, nil) + }) + require.NoError(err) + + // Read schema + reader = ds.SnapshotReader(writtenRev, datastore.NoSchemaHashForTesting) + schemaReader, err := reader.SchemaReader() + require.NoError(err) + + // Test LookupTypeDefinitionsByNames + t.Run("lookup existing type definitions", func(t *testing.T) { + typeDefs, err := schemaReader.LookupTypeDefinitionsByNames(ctx, []string{"user", "document"}) + require.NoError(err) + require.Len(typeDefs, 2) + require.Contains(typeDefs, "user") + require.Contains(typeDefs, "document") + + userDef, ok := typeDefs["user"].(*core.NamespaceDefinition) + require.True(ok) + require.Equal("user", userDef.Name) + + docDef, 
ok := typeDefs["document"].(*core.NamespaceDefinition) + require.True(ok) + require.Equal("document", docDef.Name) + }) + + t.Run("lookup all type definitions", func(t *testing.T) { + typeDefs, err := schemaReader.LookupTypeDefinitionsByNames(ctx, []string{"user", "document", "folder", "organization"}) + require.NoError(err) + require.Len(typeDefs, 4) + require.Contains(typeDefs, "user") + require.Contains(typeDefs, "document") + require.Contains(typeDefs, "folder") + require.Contains(typeDefs, "organization") + }) + + t.Run("lookup non-existent type definitions", func(t *testing.T) { + typeDefs, err := schemaReader.LookupTypeDefinitionsByNames(ctx, []string{"nonexistent"}) + require.NoError(err) + require.Empty(typeDefs) + }) + + t.Run("lookup mixed existing and non-existent type definitions", func(t *testing.T) { + typeDefs, err := schemaReader.LookupTypeDefinitionsByNames(ctx, []string{"user", "nonexistent", "document"}) + require.NoError(err) + require.Len(typeDefs, 2) + require.Contains(typeDefs, "user") + require.Contains(typeDefs, "document") + require.NotContains(typeDefs, "nonexistent") + }) + + t.Run("lookup empty list of type definitions", func(t *testing.T) { + typeDefs, err := schemaReader.LookupTypeDefinitionsByNames(ctx, []string{}) + require.NoError(err) + require.Empty(typeDefs) + }) + + t.Run("type lookup does not return caveats", func(t *testing.T) { + typeDefs, err := schemaReader.LookupTypeDefinitionsByNames(ctx, []string{"is_admin"}) + require.NoError(err) + require.Empty(typeDefs) + }) + + // Test LookupCaveatDefinitionsByNames + t.Run("lookup existing caveat definitions", func(t *testing.T) { + caveatDefs, err := schemaReader.LookupCaveatDefinitionsByNames(ctx, []string{"is_admin", "is_owner"}) + require.NoError(err) + require.Len(caveatDefs, 2) + require.Contains(caveatDefs, "is_admin") + require.Contains(caveatDefs, "is_owner") + + isAdminDef, ok := caveatDefs["is_admin"].(*core.CaveatDefinition) + require.True(ok) + 
require.Equal("is_admin", isAdminDef.Name) + + isOwnerDef, ok := caveatDefs["is_owner"].(*core.CaveatDefinition) + require.True(ok) + require.Equal("is_owner", isOwnerDef.Name) + }) + + t.Run("lookup all caveat definitions", func(t *testing.T) { + caveatDefs, err := schemaReader.LookupCaveatDefinitionsByNames(ctx, []string{"is_admin", "is_owner", "has_permission"}) + require.NoError(err) + require.Len(caveatDefs, 3) + require.Contains(caveatDefs, "is_admin") + require.Contains(caveatDefs, "is_owner") + require.Contains(caveatDefs, "has_permission") + }) + + t.Run("lookup non-existent caveat definitions", func(t *testing.T) { + caveatDefs, err := schemaReader.LookupCaveatDefinitionsByNames(ctx, []string{"nonexistent"}) + require.NoError(err) + require.Empty(caveatDefs) + }) + + t.Run("lookup mixed existing and non-existent caveat definitions", func(t *testing.T) { + caveatDefs, err := schemaReader.LookupCaveatDefinitionsByNames(ctx, []string{"is_admin", "nonexistent", "has_permission"}) + require.NoError(err) + require.Len(caveatDefs, 2) + require.Contains(caveatDefs, "is_admin") + require.Contains(caveatDefs, "has_permission") + require.NotContains(caveatDefs, "nonexistent") + }) + + t.Run("lookup empty list of caveat definitions", func(t *testing.T) { + caveatDefs, err := schemaReader.LookupCaveatDefinitionsByNames(ctx, []string{}) + require.NoError(err) + require.Empty(caveatDefs) + }) + + t.Run("caveat lookup does not return type definitions", func(t *testing.T) { + caveatDefs, err := schemaReader.LookupCaveatDefinitionsByNames(ctx, []string{"user"}) + require.NoError(err) + require.Empty(caveatDefs) + }) + + // Test LookupSchemaDefinitionsByNames (mixed types and caveats) + t.Run("lookup both types and caveats", func(t *testing.T) { + allDefs, err := schemaReader.LookupSchemaDefinitionsByNames(ctx, []string{"user", "is_admin", "document", "is_owner"}) + require.NoError(err) + require.Len(allDefs, 4) + require.Contains(allDefs, "user") + 
require.Contains(allDefs, "is_admin") + require.Contains(allDefs, "document") + require.Contains(allDefs, "is_owner") + + // Verify correct types + _, ok := allDefs["user"].(*core.NamespaceDefinition) + require.True(ok, "user should be a NamespaceDefinition") + + _, ok = allDefs["is_admin"].(*core.CaveatDefinition) + require.True(ok, "is_admin should be a CaveatDefinition") + }) +} + +// DatastoreTesterWithSchemaMode is a function that creates a datastore with a specific schema mode. +type DatastoreTesterWithSchemaMode func(schemaMode options.SchemaMode) DatastoreTester + +// UnifiedSchemaAllModesTest tests unified schema storage with all four schema modes: +// ReadLegacyWriteLegacy, ReadLegacyWriteBoth, ReadNewWriteBoth, and ReadNewWriteNew. +// This ensures that schema writes happen in the same transaction as the ReadWriteTx, +// allowing AS OF SYSTEM TIME (or equivalent MVCC) queries to work correctly. +func UnifiedSchemaAllModesTest(t *testing.T, testerFactory DatastoreTesterWithSchemaMode) { + testCases := []struct { + name string + schemaMode options.SchemaMode + }{ + { + name: "ReadLegacyWriteLegacy", + schemaMode: options.SchemaModeReadLegacyWriteLegacy, + }, + { + name: "ReadLegacyWriteBoth", + schemaMode: options.SchemaModeReadLegacyWriteBoth, + }, + { + name: "ReadNewWriteBoth", + schemaMode: options.SchemaModeReadNewWriteBoth, + }, + { + name: "ReadNewWriteNew", + schemaMode: options.SchemaModeReadNewWriteNew, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + tester := testerFactory(tc.schemaMode) + UnifiedSchemaMultipleIterationsTest(t, tester) + }) + } +} + +// UnifiedSchemaHashTest tests that the schema_revision table is correctly populated with the schema hash +func UnifiedSchemaHashTest(t *testing.T, tester DatastoreTester) { + require := require.New(t) + ctx := context.Background() + + ds, err := tester.New(0, veryLargeGCInterval, veryLargeGCWindow, 1) + require.NoError(err) + defer 
ds.Close() + + // Check if datastore supports SchemaReader + headRev, _, err := ds.HeadRevision(ctx) + require.NoError(err) + + reader := ds.SnapshotReader(headRev, datastore.NoSchemaHashForTesting) + _, err = reader.SchemaReader() + require.NoError(err, "datastore reader must provide SchemaReader") + + // Get the hash reader (test-only interface) + hashReader, hasHashReader := ds.(interface { + SchemaHashReaderForTesting() interface { + ReadSchemaHash(ctx context.Context) (string, error) + } + }) + require.True(hasHashReader, "datastore must implement SchemaHashReaderForTesting") + + // Get the reader implementation + readerImpl := hashReader.SchemaHashReaderForTesting() + require.NotNil(readerImpl, "SchemaHashReaderForTesting() must return non-nil reader") + + // Generate schema text + schemaText, _, err := generator.GenerateSchema(testSchemaDefinitions) + require.NoError(err) + + // Convert to datastore.SchemaDefinition + defs := make([]datastore.SchemaDefinition, 0, len(testSchemaDefinitions)) + for _, def := range testSchemaDefinitions { + defs = append(defs, def.(datastore.SchemaDefinition)) + } + + // Write schema + _, err = ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error { + schemaWriter, err := rwt.SchemaWriter() + if err != nil { + return err + } + return schemaWriter.WriteSchema(ctx, defs, schemaText, nil) + }) + require.NoError(err) + + // Get the schema mode from the datastore to determine expected behavior + schemaModeProvider, ok := ds.(interface { + SchemaModeForTesting() (options.SchemaMode, error) + }) + require.True(ok, "datastore must implement SchemaModeForTesting() for this test") + schemaMode, err := schemaModeProvider.SchemaModeForTesting() + require.NoError(err, "failed to get schema mode from datastore") + + // Determine if this mode should write the unified schema hash + // The hash is written in all modes that write to the unified schema + shouldWriteHash := schemaMode == 
options.SchemaModeReadLegacyWriteBoth || + schemaMode == options.SchemaModeReadNewWriteBoth || + schemaMode == options.SchemaModeReadNewWriteNew + + // Read the schema hash from schema_revision table + hash, hashErr := readerImpl.ReadSchemaHash(ctx) + + if shouldWriteHash { + // Hash MUST be present and correct + require.NoError(hashErr, "schema hash should be present in mode %s", schemaMode) + require.NotEmpty(hash, "schema hash should not be empty") + expectedHash := computeExpectedSchemaHash(t, testSchemaDefinitions) + require.Equal(expectedHash, hash, "schema hash should match computed hash of sorted schema text") + } else { + // Hash should NOT be present in ReadLegacyWriteLegacy mode + require.ErrorIs(hashErr, datastore.ErrSchemaNotFound, + "expected ErrSchemaNotFound in mode %s, got: %v", schemaMode, hashErr) + } + + // Update the schema + updatedSchemaText, _, err := generator.GenerateSchema(updatedSchemaDefinitions) + require.NoError(err) + + updatedDefs := make([]datastore.SchemaDefinition, 0, len(updatedSchemaDefinitions)) + for _, def := range updatedSchemaDefinitions { + updatedDefs = append(updatedDefs, def.(datastore.SchemaDefinition)) + } + + _, err = ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error { + schemaWriter, err := rwt.SchemaWriter() + if err != nil { + return err + } + return schemaWriter.WriteSchema(ctx, updatedDefs, updatedSchemaText, nil) + }) + require.NoError(err) + + // Read the updated schema hash and verify based on mode + updatedHash, updatedHashErr := hashReader.SchemaHashReaderForTesting().ReadSchemaHash(ctx) + + if shouldWriteHash { + // Hash MUST be present, correct, and different from the first hash + require.NoError(updatedHashErr, "updated schema hash should be present in mode %s", schemaMode) + require.NotEmpty(updatedHash, "updated schema hash should not be empty") + require.NotEqual(hash, updatedHash, "schema hash should change after update") + + // Verify the updated hash is 
correct + expectedUpdatedHash := computeExpectedSchemaHash(t, updatedSchemaDefinitions) + require.Equal(expectedUpdatedHash, updatedHash, "updated schema hash should match computed hash of sorted updated schema text") + } else { + // Hash should still NOT be present in ReadLegacyWriteLegacy mode + require.ErrorIs(updatedHashErr, datastore.ErrSchemaNotFound, + "expected ErrSchemaNotFound after update in mode %s, got: %v", schemaMode, updatedHashErr) + } +} diff --git a/pkg/datastore/test/watch.go b/pkg/datastore/test/watch.go index be3f9781f..e1c657372 100644 --- a/pkg/datastore/test/watch.go +++ b/pkg/datastore/test/watch.go @@ -66,7 +66,7 @@ func WatchTest(t *testing.T, tester DatastoreTester) { setupDatastore(ds, require) - lowestRevision, err := ds.HeadRevision(t.Context()) + lowestRevision, _, err := ds.HeadRevision(t.Context()) require.NoError(err) opts := datastore.WatchOptions{ @@ -290,7 +290,7 @@ func WatchWithTouchTest(t *testing.T, tester DatastoreTester) { ctx, cancel := context.WithCancel(t.Context()) defer cancel() - lowestRevision, err := ds.HeadRevision(ctx) + lowestRevision, _, err := ds.HeadRevision(ctx) require.NoError(err) // TOUCH a relationship and ensure watch sees it. 
@@ -395,7 +395,7 @@ func WatchWithExpirationTest(t *testing.T, tester DatastoreTester) { ctx, cancel := context.WithCancel(t.Context()) defer cancel() - lowestRevision, err := ds.HeadRevision(ctx) + lowestRevision, _, err := ds.HeadRevision(ctx) require.NoError(err) changes, errchan := ds.Watch(ctx, lowestRevision, datastore.WatchJustRelationships()) @@ -440,7 +440,7 @@ func WatchWithMetadataTest(t *testing.T, tester DatastoreTester) { ctx, cancel := context.WithCancel(t.Context()) defer cancel() - lowestRevision, err := ds.HeadRevision(ctx) + lowestRevision, _, err := ds.HeadRevision(ctx) require.NoError(err) changes, errchan := ds.Watch(ctx, lowestRevision, datastore.WatchJustRelationships()) @@ -479,7 +479,7 @@ func WatchWithDeleteTest(t *testing.T, tester DatastoreTester) { ctx, cancel := context.WithCancel(t.Context()) defer cancel() - lowestRevision, err := ds.HeadRevision(ctx) + lowestRevision, _, err := ds.HeadRevision(ctx) require.NoError(err) // TOUCH a relationship and ensure watch sees it. 
@@ -571,7 +571,7 @@ func WatchSchemaTest(t *testing.T, tester DatastoreTester) { ctx, cancel := context.WithCancel(t.Context()) defer cancel() - lowestRevision, err := ds.HeadRevision(ctx) + lowestRevision, _, err := ds.HeadRevision(ctx) require.NoError(err) changes, errchan := ds.Watch(ctx, lowestRevision, datastore.WatchJustSchema()) @@ -663,7 +663,7 @@ func WatchAllTest(t *testing.T, tester DatastoreTester) { ctx, cancel := context.WithCancel(t.Context()) defer cancel() - lowestRevision, err := ds.HeadRevision(ctx) + lowestRevision, _, err := ds.HeadRevision(ctx) require.NoError(err) changes, errchan := ds.Watch(ctx, lowestRevision, datastore.WatchOptions{ @@ -795,7 +795,7 @@ func WatchCheckpointsTest(t *testing.T, tester DatastoreTester) { ctx, cancel := context.WithCancel(t.Context()) defer cancel() - lowestRevision, err := ds.HeadRevision(ctx) + lowestRevision, _, err := ds.HeadRevision(ctx) require.NoError(err) changes, errchan := ds.Watch(ctx, lowestRevision, datastore.WatchOptions{ @@ -835,7 +835,7 @@ func WatchEmissionStrategyTest(t *testing.T, tester DatastoreTester) { expectsWatchError := (features.WatchEmitsImmediately.Status != datastore.FeatureSupported) - lowestRevision, err := ds.HeadRevision(ctx) + lowestRevision, _, err := ds.HeadRevision(ctx) require.NoError(err) changes, errchan := ds.Watch(ctx, lowestRevision, datastore.WatchOptions{ diff --git a/pkg/development/check.go b/pkg/development/check.go index 292a1ea2d..325bded83 100644 --- a/pkg/development/check.go +++ b/pkg/development/check.go @@ -6,6 +6,7 @@ import ( "github.com/authzed/spicedb/internal/graph/computed" v1 "github.com/authzed/spicedb/internal/services/v1" caveattypes "github.com/authzed/spicedb/pkg/caveats/types" + "github.com/authzed/spicedb/pkg/datastore" v1dispatch "github.com/authzed/spicedb/pkg/proto/dispatch/v1" "github.com/authzed/spicedb/pkg/tuple" ) @@ -33,6 +34,7 @@ func RunCheck(devContext *DevContext, resource tuple.ObjectAndRelation, subject Subject: subject, 
CaveatContext: caveatContext, AtRevision: devContext.Revision, + SchemaHash: datastore.NoSchemaHashForTesting, // DevContext uses memdb which doesn't support hashing MaximumDepth: maxDispatchDepth, DebugOption: computed.TraceDebuggingEnabled, }, @@ -43,7 +45,8 @@ func RunCheck(devContext *DevContext, resource tuple.ObjectAndRelation, subject return CheckResult{v1dispatch.ResourceCheckResult_NOT_MEMBER, nil, nil, nil}, err } - reader := devContext.Datastore.SnapshotReader(devContext.Revision) + // DevContext uses memdb which doesn't support schema hashing + reader := devContext.Datastore.SnapshotReader(devContext.Revision, datastore.NoSchemaHashForTesting) converted, err := v1.ConvertCheckDispatchDebugInformation(ctx, caveattypes.Default.TypeSet, caveatContext, meta.DebugInfo, reader) if err != nil { return CheckResult{v1dispatch.ResourceCheckResult_NOT_MEMBER, nil, nil, nil}, err diff --git a/pkg/middleware/consistency/consistency.go b/pkg/middleware/consistency/consistency.go index 6a87c8d99..3198a612b 100644 --- a/pkg/middleware/consistency/consistency.go +++ b/pkg/middleware/consistency/consistency.go @@ -61,7 +61,14 @@ var revisionKey ctxKeyType = struct{}{} var errInvalidZedToken = status.Error(codes.InvalidArgument, "invalid revision requested") type revisionHandle struct { - revision datastore.Revision + revision datastore.Revision + schemaHash datastore.SchemaHash +} + +// setRevisionAndHash sets both revision and schema hash for this handle. +func (h *revisionHandle) setRevisionAndHash(rev datastore.Revision, hash datastore.SchemaHash) { + h.revision = rev + h.schemaHash = hash } // ContextWithHandle adds a placeholder to a context that will later be @@ -94,6 +101,30 @@ func RevisionFromContext(ctx context.Context) (datastore.Revision, *v1.ZedToken, return nil, nil, status.Error(codes.Internal, "consistency middleware did not inject revision") } +// RevisionAndSchemaHashFromContext returns the revision and schema hash from the context. 
+func RevisionAndSchemaHashFromContext(ctx context.Context) (datastore.Revision, datastore.SchemaHash, *v1.ZedToken, error) { + if c := ctx.Value(revisionKey); c != nil { + handle := c.(*revisionHandle) + rev := handle.revision + schemaHash := handle.schemaHash + if rev != nil { + ds := datastoremw.FromContext(ctx) + if ds == nil { + return nil, "", nil, spiceerrors.MustBugf("consistency middleware did not inject datastore") + } + + zedToken, err := zedtoken.NewFromRevision(ctx, rev, ds) + if err != nil { + return nil, "", nil, err + } + + return rev, schemaHash, zedToken, nil + } + } + + return nil, "", nil, status.Error(codes.Internal, "consistency middleware did not inject revision") +} + // AddRevisionToContext adds a revision to the given context, based on the consistency block found // in the given request (if applicable). func AddRevisionToContext(ctx context.Context, req any, ds datastore.Datastore, serviceLabel string, option MismatchingTokenOption) error { @@ -108,12 +139,14 @@ func AddRevisionToContext(ctx context.Context, req any, ds datastore.Datastore, // addRevisionToContextFromConsistency adds a revision to the given context, based on the consistency block found // in the given request (if applicable). 
func addRevisionToContextFromConsistency(ctx context.Context, req hasConsistency, ds datastore.Datastore, serviceLabel string, option MismatchingTokenOption) error { - handle := ctx.Value(revisionKey) - if handle == nil { + handleValue := ctx.Value(revisionKey) + if handleValue == nil { return nil } + handle := handleValue.(*revisionHandle) var revision datastore.Revision + var schemaHash datastore.SchemaHash consistency := req.GetConsistency() withOptionalCursor, hasOptionalCursor := req.(hasOptionalCursor) @@ -148,11 +181,12 @@ func addRevisionToContextFromConsistency(ctx context.Context, req hasConsistency ConsistencyCounter.WithLabelValues("minlatency", source, serviceLabel).Inc() } - databaseRev, err := ds.OptimizedRevision(ctx) + databaseRev, hash, err := ds.OptimizedRevision(ctx) if err != nil { return rewriteDatastoreError(err) } revision = databaseRev + schemaHash = hash case consistency.GetFullyConsistent(): // Fully Consistent: Use the datastore's synchronized revision. @@ -160,11 +194,12 @@ func addRevisionToContextFromConsistency(ctx context.Context, req hasConsistency ConsistencyCounter.WithLabelValues("full", "request", serviceLabel).Inc() } - databaseRev, err := ds.HeadRevision(ctx) + databaseRev, hash, err := ds.HeadRevision(ctx) if err != nil { return rewriteDatastoreError(err) } revision = databaseRev + schemaHash = hash case consistency.GetAtLeastAsFresh() != nil: // At least as fresh as: Pick one of the datastore's revision and that specified, which @@ -211,7 +246,10 @@ func addRevisionToContextFromConsistency(ctx context.Context, req hasConsistency return status.Errorf(codes.Internal, "missing handling of consistency case in %v", consistency) } - handle.(*revisionHandle).revision = revision + // Note: For some consistency modes (AtLeastAsFresh, AtExactSnapshot), schema hash may be empty. + // This will cause cache lookups to fail, which is expected until schema hash is properly threaded + // through these code paths. 
+ handle.setRevisionAndHash(revision, schemaHash) return nil } @@ -276,7 +314,7 @@ func (s *recvWrapper) RecvMsg(m any) error { // recent one. The boolean return value will be true if the provided ZedToken is the most recent, false otherwise. func pickBestRevision(ctx context.Context, requested *v1.ZedToken, ds datastore.Datastore, option MismatchingTokenOption) (datastore.Revision, bool, error) { // Calculate a revision as we see fit - databaseRev, err := ds.OptimizedRevision(ctx) + databaseRev, _, err := ds.OptimizedRevision(ctx) if err != nil { return datastore.NoRevision, false, err } @@ -291,7 +329,7 @@ func pickBestRevision(ctx context.Context, requested *v1.ZedToken, ds datastore. switch option { case TreatMismatchingTokensAsFullConsistency: log.Warn().Str("zedtoken", requested.Token).Msg("ZedToken specified references a different datastore instance and SpiceDB is configured to treat this as a full consistency request") - headRev, err := ds.HeadRevision(ctx) + headRev, _, err := ds.HeadRevision(ctx) if err != nil { return datastore.NoRevision, false, err } diff --git a/pkg/middleware/consistency/consistency_test.go b/pkg/middleware/consistency/consistency_test.go index a332a0058..26f46033c 100644 --- a/pkg/middleware/consistency/consistency_test.go +++ b/pkg/middleware/consistency/consistency_test.go @@ -31,7 +31,7 @@ func TestAddRevisionToContextNoneSupplied(t *testing.T) { require := require.New(t) ds := &proxy_test.MockDatastore{} - ds.On("OptimizedRevision").Return(optimized, nil).Once() + ds.On("OptimizedRevision").Return(optimized, datastore.NoSchemaHashForTesting, nil).Once() updated := ContextWithHandle(t.Context()) updated = datastoremw.ContextWithDatastore(updated, ds) @@ -50,7 +50,7 @@ func TestAddRevisionToContextMinimizeLatency(t *testing.T) { require := require.New(t) ds := &proxy_test.MockDatastore{} - ds.On("OptimizedRevision").Return(optimized, nil).Once() + ds.On("OptimizedRevision").Return(optimized, datastore.NoSchemaHashForTesting, 
nil).Once() updated := ContextWithHandle(t.Context()) updated = datastoremw.ContextWithDatastore(updated, ds) @@ -75,7 +75,7 @@ func TestAddRevisionToContextFullyConsistent(t *testing.T) { require := require.New(t) ds := &proxy_test.MockDatastore{} - ds.On("HeadRevision").Return(head, nil).Once() + ds.On("HeadRevision").Return(head, datastore.NoSchemaHashForTesting, nil).Once() updated := ContextWithHandle(t.Context()) updated = datastoremw.ContextWithDatastore(updated, ds) @@ -100,7 +100,7 @@ func TestAddRevisionToContextAtLeastAsFresh(t *testing.T) { require := require.New(t) ds := &proxy_test.MockDatastore{} - ds.On("OptimizedRevision").Return(optimized, nil).Once() + ds.On("OptimizedRevision").Return(optimized, datastore.NoSchemaHashForTesting, nil).Once() ds.On("RevisionFromString", exact.String()).Return(exact, nil).Once() updated := ContextWithHandle(t.Context()) @@ -190,7 +190,7 @@ func TestAddRevisionToContextWithCursor(t *testing.T) { ds.On("RevisionFromString", optimized.String()).Return(optimized, nil).Once() // cursor is at `optimized` - cursor, err := cursor.EncodeFromDispatchCursor(&dispatch.Cursor{}, "somehash", optimized, nil) + cursor, err := cursor.EncodeFromDispatchCursor(&dispatch.Cursor{}, "somehash", optimized, datastore.NoSchemaHashForTesting, nil) require.NoError(err) // revision in context is at `exact` @@ -229,7 +229,7 @@ func TestAddRevisionToContextAtMalformedExactSnapshot(t *testing.T) { func TestAddRevisionToContextMalformedAtLeastAsFreshSnapshot(t *testing.T) { ds := &proxy_test.MockDatastore{} - ds.On("OptimizedRevision").Return(optimized, nil).Once() + ds.On("OptimizedRevision").Return(optimized, datastore.NoSchemaHashForTesting, nil).Once() err := AddRevisionToContext(ContextWithHandle(t.Context()), &v1.LookupResourcesRequest{ Consistency: &v1.Consistency{ @@ -327,7 +327,7 @@ func TestAtLeastAsFreshWithMismatchedTokenExpectError(t *testing.T) { require := require.New(t) ds := &proxy_test.MockDatastore{} - 
ds.On("OptimizedRevision").Return(optimized, nil).Once() + ds.On("OptimizedRevision").Return(optimized, datastore.NoSchemaHashForTesting, nil).Once() ds.On("RevisionFromString", optimized.String()).Return(optimized, nil).Once() // revision in context is at `exact` @@ -355,7 +355,7 @@ func TestAtLeastAsFreshWithMismatchedTokenExpectMinLatency(t *testing.T) { require := require.New(t) ds := &proxy_test.MockDatastore{} - ds.On("OptimizedRevision").Return(optimized, nil).Once() + ds.On("OptimizedRevision").Return(optimized, datastore.NoSchemaHashForTesting, nil).Once() ds.On("RevisionFromString", optimized.String()).Return(optimized, nil).Once() // revision in context is at `exact` @@ -388,8 +388,8 @@ func TestAtLeastAsFreshWithMismatchedTokenExpectFullConsistency(t *testing.T) { require := require.New(t) ds := &proxy_test.MockDatastore{} - ds.On("HeadRevision").Return(head, nil).Once() - ds.On("OptimizedRevision").Return(optimized, nil).Once() + ds.On("HeadRevision").Return(head, datastore.NoSchemaHashForTesting, nil).Once() + ds.On("OptimizedRevision").Return(optimized, datastore.NoSchemaHashForTesting, nil).Once() ds.On("RevisionFromString", optimized.String()).Return(optimized, nil).Once() // revision in context is at `exact` @@ -422,7 +422,7 @@ func TestAddRevisionToContextAtLeastAsFreshMatchingIDs(t *testing.T) { require := require.New(t) ds := &proxy_test.MockDatastore{} - ds.On("OptimizedRevision").Return(optimized, nil).Once() + ds.On("OptimizedRevision").Return(optimized, datastore.NoSchemaHashForTesting, nil).Once() ds.On("RevisionFromString", exact.String()).Return(exact, nil).Once() ds.CurrentUniqueID = "foo" diff --git a/pkg/middleware/consistency/forcefull.go b/pkg/middleware/consistency/forcefull.go index 44e06e154..9bea06a45 100644 --- a/pkg/middleware/consistency/forcefull.go +++ b/pkg/middleware/consistency/forcefull.go @@ -44,8 +44,8 @@ func ForceFullConsistencyStreamServerInterceptor(serviceLabel string) grpc.Strea } func 
setFullConsistencyRevisionToContext(ctx context.Context, req any, ds datastore.Datastore, serviceLabel string, _ MismatchingTokenOption) error { - handle := ctx.Value(revisionKey) - if handle == nil { + handleValue := ctx.Value(revisionKey) + if handleValue == nil { return nil } @@ -54,11 +54,11 @@ func setFullConsistencyRevisionToContext(ctx context.Context, req any, ds datast ConsistencyCounter.WithLabelValues("full", "request", serviceLabel).Inc() } - databaseRev, err := ds.HeadRevision(ctx) + databaseRev, schemaHash, err := ds.HeadRevision(ctx) if err != nil { return rewriteDatastoreError(err) } - handle.(*revisionHandle).revision = databaseRev + handleValue.(*revisionHandle).setRevisionAndHash(databaseRev, schemaHash) } return nil diff --git a/pkg/proto/core/v1/core.pb.go b/pkg/proto/core/v1/core.pb.go index 3edaeb144..219ada1db 100644 --- a/pkg/proto/core/v1/core.pb.go +++ b/pkg/proto/core/v1/core.pb.go @@ -2592,6 +2592,81 @@ func (x *SubjectFilter) GetOptionalRelation() *SubjectFilter_RelationFilter { return nil } +// StoredSchema represents a stored schema in SpiceDB under the new, unified schema format. 
+type StoredSchema struct { + state protoimpl.MessageState `protogen:"open.v1"` + Version uint32 `protobuf:"varint,1,opt,name=version,proto3" json:"version,omitempty"` + // Types that are valid to be assigned to VersionOneof: + // + // *StoredSchema_V1 + VersionOneof isStoredSchema_VersionOneof `protobuf_oneof:"version_oneof"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *StoredSchema) Reset() { + *x = StoredSchema{} + mi := &file_core_v1_core_proto_msgTypes[33] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *StoredSchema) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StoredSchema) ProtoMessage() {} + +func (x *StoredSchema) ProtoReflect() protoreflect.Message { + mi := &file_core_v1_core_proto_msgTypes[33] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StoredSchema.ProtoReflect.Descriptor instead. 
+func (*StoredSchema) Descriptor() ([]byte, []int) { + return file_core_v1_core_proto_rawDescGZIP(), []int{33} +} + +func (x *StoredSchema) GetVersion() uint32 { + if x != nil { + return x.Version + } + return 0 +} + +func (x *StoredSchema) GetVersionOneof() isStoredSchema_VersionOneof { + if x != nil { + return x.VersionOneof + } + return nil +} + +func (x *StoredSchema) GetV1() *StoredSchema_V1StoredSchema { + if x != nil { + if x, ok := x.VersionOneof.(*StoredSchema_V1); ok { + return x.V1 + } + } + return nil +} + +type isStoredSchema_VersionOneof interface { + isStoredSchema_VersionOneof() +} + +type StoredSchema_V1 struct { + V1 *StoredSchema_V1StoredSchema `protobuf:"bytes,2,opt,name=v1,proto3,oneof"` +} + +func (*StoredSchema_V1) isStoredSchema_VersionOneof() {} + type AllowedRelation_PublicWildcard struct { state protoimpl.MessageState `protogen:"open.v1"` unknownFields protoimpl.UnknownFields @@ -2600,7 +2675,7 @@ type AllowedRelation_PublicWildcard struct { func (x *AllowedRelation_PublicWildcard) Reset() { *x = AllowedRelation_PublicWildcard{} - mi := &file_core_v1_core_proto_msgTypes[36] + mi := &file_core_v1_core_proto_msgTypes[37] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2612,7 +2687,7 @@ func (x *AllowedRelation_PublicWildcard) String() string { func (*AllowedRelation_PublicWildcard) ProtoMessage() {} func (x *AllowedRelation_PublicWildcard) ProtoReflect() protoreflect.Message { - mi := &file_core_v1_core_proto_msgTypes[36] + mi := &file_core_v1_core_proto_msgTypes[37] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2653,7 +2728,7 @@ type SetOperation_Child struct { func (x *SetOperation_Child) Reset() { *x = SetOperation_Child{} - mi := &file_core_v1_core_proto_msgTypes[37] + mi := &file_core_v1_core_proto_msgTypes[38] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2665,7 +2740,7 @@ func (x *SetOperation_Child) 
String() string { func (*SetOperation_Child) ProtoMessage() {} func (x *SetOperation_Child) ProtoReflect() protoreflect.Message { - mi := &file_core_v1_core_proto_msgTypes[37] + mi := &file_core_v1_core_proto_msgTypes[38] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2821,7 +2896,7 @@ type SetOperation_Child_This struct { func (x *SetOperation_Child_This) Reset() { *x = SetOperation_Child_This{} - mi := &file_core_v1_core_proto_msgTypes[38] + mi := &file_core_v1_core_proto_msgTypes[39] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2833,7 +2908,7 @@ func (x *SetOperation_Child_This) String() string { func (*SetOperation_Child_This) ProtoMessage() {} func (x *SetOperation_Child_This) ProtoReflect() protoreflect.Message { - mi := &file_core_v1_core_proto_msgTypes[38] + mi := &file_core_v1_core_proto_msgTypes[39] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2857,7 +2932,7 @@ type SetOperation_Child_Nil struct { func (x *SetOperation_Child_Nil) Reset() { *x = SetOperation_Child_Nil{} - mi := &file_core_v1_core_proto_msgTypes[39] + mi := &file_core_v1_core_proto_msgTypes[40] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2869,7 +2944,7 @@ func (x *SetOperation_Child_Nil) String() string { func (*SetOperation_Child_Nil) ProtoMessage() {} func (x *SetOperation_Child_Nil) ProtoReflect() protoreflect.Message { - mi := &file_core_v1_core_proto_msgTypes[39] + mi := &file_core_v1_core_proto_msgTypes[40] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2894,7 +2969,7 @@ type SetOperation_Child_Self struct { func (x *SetOperation_Child_Self) Reset() { *x = SetOperation_Child_Self{} - mi := &file_core_v1_core_proto_msgTypes[40] + mi := &file_core_v1_core_proto_msgTypes[41] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) 
ms.StoreMessageInfo(mi) } @@ -2906,7 +2981,7 @@ func (x *SetOperation_Child_Self) String() string { func (*SetOperation_Child_Self) ProtoMessage() {} func (x *SetOperation_Child_Self) ProtoReflect() protoreflect.Message { - mi := &file_core_v1_core_proto_msgTypes[40] + mi := &file_core_v1_core_proto_msgTypes[41] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2931,7 +3006,7 @@ type TupleToUserset_Tupleset struct { func (x *TupleToUserset_Tupleset) Reset() { *x = TupleToUserset_Tupleset{} - mi := &file_core_v1_core_proto_msgTypes[41] + mi := &file_core_v1_core_proto_msgTypes[42] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2943,7 +3018,7 @@ func (x *TupleToUserset_Tupleset) String() string { func (*TupleToUserset_Tupleset) ProtoMessage() {} func (x *TupleToUserset_Tupleset) ProtoReflect() protoreflect.Message { - mi := &file_core_v1_core_proto_msgTypes[41] + mi := &file_core_v1_core_proto_msgTypes[42] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2975,7 +3050,7 @@ type FunctionedTupleToUserset_Tupleset struct { func (x *FunctionedTupleToUserset_Tupleset) Reset() { *x = FunctionedTupleToUserset_Tupleset{} - mi := &file_core_v1_core_proto_msgTypes[42] + mi := &file_core_v1_core_proto_msgTypes[43] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2987,7 +3062,7 @@ func (x *FunctionedTupleToUserset_Tupleset) String() string { func (*FunctionedTupleToUserset_Tupleset) ProtoMessage() {} func (x *FunctionedTupleToUserset_Tupleset) ProtoReflect() protoreflect.Message { - mi := &file_core_v1_core_proto_msgTypes[42] + mi := &file_core_v1_core_proto_msgTypes[43] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3019,7 +3094,7 @@ type SubjectFilter_RelationFilter struct { func (x *SubjectFilter_RelationFilter) Reset() { *x = 
SubjectFilter_RelationFilter{} - mi := &file_core_v1_core_proto_msgTypes[43] + mi := &file_core_v1_core_proto_msgTypes[44] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3031,7 +3106,7 @@ func (x *SubjectFilter_RelationFilter) String() string { func (*SubjectFilter_RelationFilter) ProtoMessage() {} func (x *SubjectFilter_RelationFilter) ProtoReflect() protoreflect.Message { - mi := &file_core_v1_core_proto_msgTypes[43] + mi := &file_core_v1_core_proto_msgTypes[44] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3054,6 +3129,80 @@ func (x *SubjectFilter_RelationFilter) GetRelation() string { return "" } +type StoredSchema_V1StoredSchema struct { + state protoimpl.MessageState `protogen:"open.v1"` + // schema_text is the text of the schema that was given to SpiceDB by the API caller. + SchemaText string `protobuf:"bytes,1,opt,name=schema_text,json=schemaText,proto3" json:"schema_text,omitempty"` + // schema_hash is the hash of the schema, for change detection. + SchemaHash string `protobuf:"bytes,2,opt,name=schema_hash,json=schemaHash,proto3" json:"schema_hash,omitempty"` + // namespace_definitions is a map of namespace name to NamespaceDefinition. + // Entries must have full metadata filled out. + NamespaceDefinitions map[string]*NamespaceDefinition `protobuf:"bytes,3,rep,name=namespace_definitions,json=namespaceDefinitions,proto3" json:"namespace_definitions,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` + // caveat_definitions is a map of caveat name to CaveatDefinition. + // Entries must have full metadata filled out. 
+ CaveatDefinitions map[string]*CaveatDefinition `protobuf:"bytes,4,rep,name=caveat_definitions,json=caveatDefinitions,proto3" json:"caveat_definitions,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *StoredSchema_V1StoredSchema) Reset() { + *x = StoredSchema_V1StoredSchema{} + mi := &file_core_v1_core_proto_msgTypes[45] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *StoredSchema_V1StoredSchema) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StoredSchema_V1StoredSchema) ProtoMessage() {} + +func (x *StoredSchema_V1StoredSchema) ProtoReflect() protoreflect.Message { + mi := &file_core_v1_core_proto_msgTypes[45] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StoredSchema_V1StoredSchema.ProtoReflect.Descriptor instead. 
+func (*StoredSchema_V1StoredSchema) Descriptor() ([]byte, []int) { + return file_core_v1_core_proto_rawDescGZIP(), []int{33, 0} +} + +func (x *StoredSchema_V1StoredSchema) GetSchemaText() string { + if x != nil { + return x.SchemaText + } + return "" +} + +func (x *StoredSchema_V1StoredSchema) GetSchemaHash() string { + if x != nil { + return x.SchemaHash + } + return "" +} + +func (x *StoredSchema_V1StoredSchema) GetNamespaceDefinitions() map[string]*NamespaceDefinition { + if x != nil { + return x.NamespaceDefinitions + } + return nil +} + +func (x *StoredSchema_V1StoredSchema) GetCaveatDefinitions() map[string]*CaveatDefinition { + if x != nil { + return x.CaveatDefinitions + } + return nil +} + var File_core_v1_core_proto protoreflect.FileDescriptor const file_core_v1_core_proto_rawDesc = "" + @@ -3259,7 +3408,24 @@ const file_core_v1_core_proto_rawDesc = "" + "\x13optional_subject_id\x18\x02 \x01(\tB*\xfaB'r%(\x80\b2 ^(([a-zA-Z0-9/_|\\-=+]{1,})|\\*)?$R\x11optionalSubjectId\x12R\n" + "\x11optional_relation\x18\x03 \x01(\v2%.core.v1.SubjectFilter.RelationFilterR\x10optionalRelation\x1aX\n" + "\x0eRelationFilter\x12F\n" + - "\brelation\x18\x01 \x01(\tB*\xfaB'r%(@2!^([a-z][a-z0-9_]{1,62}[a-z0-9])?$R\brelationB\x8a\x01\n" + + "\brelation\x18\x01 \x01(\tB*\xfaB'r%(@2!^([a-z][a-z0-9_]{1,62}[a-z0-9])?$R\brelation\"\xef\x04\n" + + "\fStoredSchema\x12\x18\n" + + "\aversion\x18\x01 \x01(\rR\aversion\x126\n" + + "\x02v1\x18\x02 \x01(\v2$.core.v1.StoredSchema.V1StoredSchemaH\x00R\x02v1\x1a\xfb\x03\n" + + "\x0eV1StoredSchema\x12\x1f\n" + + "\vschema_text\x18\x01 \x01(\tR\n" + + "schemaText\x12\x1f\n" + + "\vschema_hash\x18\x02 \x01(\tR\n" + + "schemaHash\x12s\n" + + "\x15namespace_definitions\x18\x03 \x03(\v2>.core.v1.StoredSchema.V1StoredSchema.NamespaceDefinitionsEntryR\x14namespaceDefinitions\x12j\n" + + "\x12caveat_definitions\x18\x04 \x03(\v2;.core.v1.StoredSchema.V1StoredSchema.CaveatDefinitionsEntryR\x11caveatDefinitions\x1ae\n" + + 
"\x19NamespaceDefinitionsEntry\x12\x10\n" + + "\x03key\x18\x01 \x01(\tR\x03key\x122\n" + + "\x05value\x18\x02 \x01(\v2\x1c.core.v1.NamespaceDefinitionR\x05value:\x028\x01\x1a_\n" + + "\x16CaveatDefinitionsEntry\x12\x10\n" + + "\x03key\x18\x01 \x01(\tR\x03key\x12/\n" + + "\x05value\x18\x02 \x01(\v2\x19.core.v1.CaveatDefinitionR\x05value:\x028\x01B\x0f\n" + + "\rversion_oneofB\x8a\x01\n" + "\vcom.core.v1B\tCoreProtoP\x01Z3github.com/authzed/spicedb/pkg/proto/core/v1;corev1\xa2\x02\x03CXX\xaa\x02\aCore.V1\xca\x02\aCore\\V1\xe2\x02\x13Core\\V1\\GPBMetadata\xea\x02\bCore::V1b\x06proto3" var ( @@ -3275,7 +3441,7 @@ func file_core_v1_core_proto_rawDescGZIP() []byte { } var file_core_v1_core_proto_enumTypes = make([]protoimpl.EnumInfo, 7) -var file_core_v1_core_proto_msgTypes = make([]protoimpl.MessageInfo, 44) +var file_core_v1_core_proto_msgTypes = make([]protoimpl.MessageInfo, 48) var file_core_v1_core_proto_goTypes = []any{ (RelationTupleUpdate_Operation)(0), // 0: core.v1.RelationTupleUpdate.Operation (SetOperationUserset_Operation)(0), // 1: core.v1.SetOperationUserset.Operation @@ -3317,30 +3483,34 @@ var file_core_v1_core_proto_goTypes = []any{ (*CaveatOperation)(nil), // 37: core.v1.CaveatOperation (*RelationshipFilter)(nil), // 38: core.v1.RelationshipFilter (*SubjectFilter)(nil), // 39: core.v1.SubjectFilter - nil, // 40: core.v1.CaveatDefinition.ParameterTypesEntry - nil, // 41: core.v1.ReachabilityGraph.EntrypointsBySubjectTypeEntry - nil, // 42: core.v1.ReachabilityGraph.EntrypointsBySubjectRelationEntry - (*AllowedRelation_PublicWildcard)(nil), // 43: core.v1.AllowedRelation.PublicWildcard - (*SetOperation_Child)(nil), // 44: core.v1.SetOperation.Child - (*SetOperation_Child_This)(nil), // 45: core.v1.SetOperation.Child.This - (*SetOperation_Child_Nil)(nil), // 46: core.v1.SetOperation.Child.Nil - (*SetOperation_Child_Self)(nil), // 47: core.v1.SetOperation.Child.Self - (*TupleToUserset_Tupleset)(nil), // 48: core.v1.TupleToUserset.Tupleset - 
(*FunctionedTupleToUserset_Tupleset)(nil), // 49: core.v1.FunctionedTupleToUserset.Tupleset - (*SubjectFilter_RelationFilter)(nil), // 50: core.v1.SubjectFilter.RelationFilter - (*timestamppb.Timestamp)(nil), // 51: google.protobuf.Timestamp - (*structpb.Struct)(nil), // 52: google.protobuf.Struct - (*anypb.Any)(nil), // 53: google.protobuf.Any + (*StoredSchema)(nil), // 40: core.v1.StoredSchema + nil, // 41: core.v1.CaveatDefinition.ParameterTypesEntry + nil, // 42: core.v1.ReachabilityGraph.EntrypointsBySubjectTypeEntry + nil, // 43: core.v1.ReachabilityGraph.EntrypointsBySubjectRelationEntry + (*AllowedRelation_PublicWildcard)(nil), // 44: core.v1.AllowedRelation.PublicWildcard + (*SetOperation_Child)(nil), // 45: core.v1.SetOperation.Child + (*SetOperation_Child_This)(nil), // 46: core.v1.SetOperation.Child.This + (*SetOperation_Child_Nil)(nil), // 47: core.v1.SetOperation.Child.Nil + (*SetOperation_Child_Self)(nil), // 48: core.v1.SetOperation.Child.Self + (*TupleToUserset_Tupleset)(nil), // 49: core.v1.TupleToUserset.Tupleset + (*FunctionedTupleToUserset_Tupleset)(nil), // 50: core.v1.FunctionedTupleToUserset.Tupleset + (*SubjectFilter_RelationFilter)(nil), // 51: core.v1.SubjectFilter.RelationFilter + (*StoredSchema_V1StoredSchema)(nil), // 52: core.v1.StoredSchema.V1StoredSchema + nil, // 53: core.v1.StoredSchema.V1StoredSchema.NamespaceDefinitionsEntry + nil, // 54: core.v1.StoredSchema.V1StoredSchema.CaveatDefinitionsEntry + (*timestamppb.Timestamp)(nil), // 55: google.protobuf.Timestamp + (*structpb.Struct)(nil), // 56: google.protobuf.Struct + (*anypb.Any)(nil), // 57: google.protobuf.Any } var file_core_v1_core_proto_depIdxs = []int32{ 12, // 0: core.v1.RelationTuple.resource_and_relation:type_name -> core.v1.ObjectAndRelation 12, // 1: core.v1.RelationTuple.subject:type_name -> core.v1.ObjectAndRelation 9, // 2: core.v1.RelationTuple.caveat:type_name -> core.v1.ContextualizedCaveat 8, // 3: core.v1.RelationTuple.integrity:type_name -> 
core.v1.RelationshipIntegrity - 51, // 4: core.v1.RelationTuple.optional_expiration_time:type_name -> google.protobuf.Timestamp - 51, // 5: core.v1.RelationshipIntegrity.hashed_at:type_name -> google.protobuf.Timestamp - 52, // 6: core.v1.ContextualizedCaveat.context:type_name -> google.protobuf.Struct - 40, // 7: core.v1.CaveatDefinition.parameter_types:type_name -> core.v1.CaveatDefinition.ParameterTypesEntry + 55, // 4: core.v1.RelationTuple.optional_expiration_time:type_name -> google.protobuf.Timestamp + 55, // 5: core.v1.RelationshipIntegrity.hashed_at:type_name -> google.protobuf.Timestamp + 56, // 6: core.v1.ContextualizedCaveat.context:type_name -> google.protobuf.Struct + 41, // 7: core.v1.CaveatDefinition.parameter_types:type_name -> core.v1.CaveatDefinition.ParameterTypesEntry 20, // 8: core.v1.CaveatDefinition.metadata:type_name -> core.v1.Metadata 35, // 9: core.v1.CaveatDefinition.source_position:type_name -> core.v1.SourcePosition 11, // 10: core.v1.CaveatTypeReference.child_types:type_name -> core.v1.CaveatTypeReference @@ -3355,7 +3525,7 @@ var file_core_v1_core_proto_depIdxs = []int32{ 12, // 19: core.v1.DirectSubject.subject:type_name -> core.v1.ObjectAndRelation 36, // 20: core.v1.DirectSubject.caveat_expression:type_name -> core.v1.CaveatExpression 18, // 21: core.v1.DirectSubjects.subjects:type_name -> core.v1.DirectSubject - 53, // 22: core.v1.Metadata.metadata_message:type_name -> google.protobuf.Any + 57, // 22: core.v1.Metadata.metadata_message:type_name -> google.protobuf.Any 22, // 23: core.v1.NamespaceDefinition.relation:type_name -> core.v1.Relation 20, // 24: core.v1.NamespaceDefinition.metadata:type_name -> core.v1.Metadata 35, // 25: core.v1.NamespaceDefinition.source_position:type_name -> core.v1.SourcePosition @@ -3363,15 +3533,15 @@ var file_core_v1_core_proto_depIdxs = []int32{ 26, // 27: core.v1.Relation.type_information:type_name -> core.v1.TypeInformation 20, // 28: core.v1.Relation.metadata:type_name -> core.v1.Metadata 35, 
// 29: core.v1.Relation.source_position:type_name -> core.v1.SourcePosition - 41, // 30: core.v1.ReachabilityGraph.entrypoints_by_subject_type:type_name -> core.v1.ReachabilityGraph.EntrypointsBySubjectTypeEntry - 42, // 31: core.v1.ReachabilityGraph.entrypoints_by_subject_relation:type_name -> core.v1.ReachabilityGraph.EntrypointsBySubjectRelationEntry + 42, // 30: core.v1.ReachabilityGraph.entrypoints_by_subject_type:type_name -> core.v1.ReachabilityGraph.EntrypointsBySubjectTypeEntry + 43, // 31: core.v1.ReachabilityGraph.entrypoints_by_subject_relation:type_name -> core.v1.ReachabilityGraph.EntrypointsBySubjectRelationEntry 25, // 32: core.v1.ReachabilityEntrypoints.entrypoints:type_name -> core.v1.ReachabilityEntrypoint 13, // 33: core.v1.ReachabilityEntrypoints.subject_relation:type_name -> core.v1.RelationReference 2, // 34: core.v1.ReachabilityEntrypoint.kind:type_name -> core.v1.ReachabilityEntrypoint.ReachabilityEntrypointKind 13, // 35: core.v1.ReachabilityEntrypoint.target_relation:type_name -> core.v1.RelationReference 3, // 36: core.v1.ReachabilityEntrypoint.result_status:type_name -> core.v1.ReachabilityEntrypoint.EntrypointResultStatus 27, // 37: core.v1.TypeInformation.allowed_direct_relations:type_name -> core.v1.AllowedRelation - 43, // 38: core.v1.AllowedRelation.public_wildcard:type_name -> core.v1.AllowedRelation.PublicWildcard + 44, // 38: core.v1.AllowedRelation.public_wildcard:type_name -> core.v1.AllowedRelation.PublicWildcard 35, // 39: core.v1.AllowedRelation.source_position:type_name -> core.v1.SourcePosition 29, // 40: core.v1.AllowedRelation.required_caveat:type_name -> core.v1.AllowedCaveat 28, // 41: core.v1.AllowedRelation.required_expiration:type_name -> core.v1.ExpirationTrait @@ -3379,12 +3549,12 @@ var file_core_v1_core_proto_depIdxs = []int32{ 31, // 43: core.v1.UsersetRewrite.intersection:type_name -> core.v1.SetOperation 31, // 44: core.v1.UsersetRewrite.exclusion:type_name -> core.v1.SetOperation 35, // 45: 
core.v1.UsersetRewrite.source_position:type_name -> core.v1.SourcePosition - 44, // 46: core.v1.SetOperation.child:type_name -> core.v1.SetOperation.Child - 48, // 47: core.v1.TupleToUserset.tupleset:type_name -> core.v1.TupleToUserset.Tupleset + 45, // 46: core.v1.SetOperation.child:type_name -> core.v1.SetOperation.Child + 49, // 47: core.v1.TupleToUserset.tupleset:type_name -> core.v1.TupleToUserset.Tupleset 34, // 48: core.v1.TupleToUserset.computed_userset:type_name -> core.v1.ComputedUserset 35, // 49: core.v1.TupleToUserset.source_position:type_name -> core.v1.SourcePosition 4, // 50: core.v1.FunctionedTupleToUserset.function:type_name -> core.v1.FunctionedTupleToUserset.Function - 49, // 51: core.v1.FunctionedTupleToUserset.tupleset:type_name -> core.v1.FunctionedTupleToUserset.Tupleset + 50, // 51: core.v1.FunctionedTupleToUserset.tupleset:type_name -> core.v1.FunctionedTupleToUserset.Tupleset 34, // 52: core.v1.FunctionedTupleToUserset.computed_userset:type_name -> core.v1.ComputedUserset 35, // 53: core.v1.FunctionedTupleToUserset.source_position:type_name -> core.v1.SourcePosition 5, // 54: core.v1.ComputedUserset.object:type_name -> core.v1.ComputedUserset.Object @@ -3394,23 +3564,28 @@ var file_core_v1_core_proto_depIdxs = []int32{ 6, // 58: core.v1.CaveatOperation.op:type_name -> core.v1.CaveatOperation.Operation 36, // 59: core.v1.CaveatOperation.children:type_name -> core.v1.CaveatExpression 39, // 60: core.v1.RelationshipFilter.optional_subject_filter:type_name -> core.v1.SubjectFilter - 50, // 61: core.v1.SubjectFilter.optional_relation:type_name -> core.v1.SubjectFilter.RelationFilter - 11, // 62: core.v1.CaveatDefinition.ParameterTypesEntry.value:type_name -> core.v1.CaveatTypeReference - 24, // 63: core.v1.ReachabilityGraph.EntrypointsBySubjectTypeEntry.value:type_name -> core.v1.ReachabilityEntrypoints - 24, // 64: core.v1.ReachabilityGraph.EntrypointsBySubjectRelationEntry.value:type_name -> core.v1.ReachabilityEntrypoints - 45, // 65: 
core.v1.SetOperation.Child._this:type_name -> core.v1.SetOperation.Child.This - 34, // 66: core.v1.SetOperation.Child.computed_userset:type_name -> core.v1.ComputedUserset - 32, // 67: core.v1.SetOperation.Child.tuple_to_userset:type_name -> core.v1.TupleToUserset - 30, // 68: core.v1.SetOperation.Child.userset_rewrite:type_name -> core.v1.UsersetRewrite - 33, // 69: core.v1.SetOperation.Child.functioned_tuple_to_userset:type_name -> core.v1.FunctionedTupleToUserset - 46, // 70: core.v1.SetOperation.Child._nil:type_name -> core.v1.SetOperation.Child.Nil - 47, // 71: core.v1.SetOperation.Child._self:type_name -> core.v1.SetOperation.Child.Self - 35, // 72: core.v1.SetOperation.Child.source_position:type_name -> core.v1.SourcePosition - 73, // [73:73] is the sub-list for method output_type - 73, // [73:73] is the sub-list for method input_type - 73, // [73:73] is the sub-list for extension type_name - 73, // [73:73] is the sub-list for extension extendee - 0, // [0:73] is the sub-list for field type_name + 51, // 61: core.v1.SubjectFilter.optional_relation:type_name -> core.v1.SubjectFilter.RelationFilter + 52, // 62: core.v1.StoredSchema.v1:type_name -> core.v1.StoredSchema.V1StoredSchema + 11, // 63: core.v1.CaveatDefinition.ParameterTypesEntry.value:type_name -> core.v1.CaveatTypeReference + 24, // 64: core.v1.ReachabilityGraph.EntrypointsBySubjectTypeEntry.value:type_name -> core.v1.ReachabilityEntrypoints + 24, // 65: core.v1.ReachabilityGraph.EntrypointsBySubjectRelationEntry.value:type_name -> core.v1.ReachabilityEntrypoints + 46, // 66: core.v1.SetOperation.Child._this:type_name -> core.v1.SetOperation.Child.This + 34, // 67: core.v1.SetOperation.Child.computed_userset:type_name -> core.v1.ComputedUserset + 32, // 68: core.v1.SetOperation.Child.tuple_to_userset:type_name -> core.v1.TupleToUserset + 30, // 69: core.v1.SetOperation.Child.userset_rewrite:type_name -> core.v1.UsersetRewrite + 33, // 70: 
core.v1.SetOperation.Child.functioned_tuple_to_userset:type_name -> core.v1.FunctionedTupleToUserset + 47, // 71: core.v1.SetOperation.Child._nil:type_name -> core.v1.SetOperation.Child.Nil + 48, // 72: core.v1.SetOperation.Child._self:type_name -> core.v1.SetOperation.Child.Self + 35, // 73: core.v1.SetOperation.Child.source_position:type_name -> core.v1.SourcePosition + 53, // 74: core.v1.StoredSchema.V1StoredSchema.namespace_definitions:type_name -> core.v1.StoredSchema.V1StoredSchema.NamespaceDefinitionsEntry + 54, // 75: core.v1.StoredSchema.V1StoredSchema.caveat_definitions:type_name -> core.v1.StoredSchema.V1StoredSchema.CaveatDefinitionsEntry + 21, // 76: core.v1.StoredSchema.V1StoredSchema.NamespaceDefinitionsEntry.value:type_name -> core.v1.NamespaceDefinition + 10, // 77: core.v1.StoredSchema.V1StoredSchema.CaveatDefinitionsEntry.value:type_name -> core.v1.CaveatDefinition + 78, // [78:78] is the sub-list for method output_type + 78, // [78:78] is the sub-list for method input_type + 78, // [78:78] is the sub-list for extension type_name + 78, // [78:78] is the sub-list for extension extendee + 0, // [0:78] is the sub-list for field type_name } func init() { file_core_v1_core_proto_init() } @@ -3435,7 +3610,10 @@ func file_core_v1_core_proto_init() { (*CaveatExpression_Operation)(nil), (*CaveatExpression_Caveat)(nil), } - file_core_v1_core_proto_msgTypes[37].OneofWrappers = []any{ + file_core_v1_core_proto_msgTypes[33].OneofWrappers = []any{ + (*StoredSchema_V1)(nil), + } + file_core_v1_core_proto_msgTypes[38].OneofWrappers = []any{ (*SetOperation_Child_XThis)(nil), (*SetOperation_Child_ComputedUserset)(nil), (*SetOperation_Child_TupleToUserset)(nil), @@ -3450,7 +3628,7 @@ func file_core_v1_core_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_core_v1_core_proto_rawDesc), len(file_core_v1_core_proto_rawDesc)), NumEnums: 7, - NumMessages: 44, + NumMessages: 48, NumExtensions: 0, NumServices: 
0, }, diff --git a/pkg/proto/core/v1/core.pb.validate.go b/pkg/proto/core/v1/core.pb.validate.go index cd8de1d91..81449f0f4 100644 --- a/pkg/proto/core/v1/core.pb.validate.go +++ b/pkg/proto/core/v1/core.pb.validate.go @@ -5890,6 +5890,153 @@ var _SubjectFilter_SubjectType_Pattern = regexp.MustCompile("^([a-z][a-z0-9_]{1, var _SubjectFilter_OptionalSubjectId_Pattern = regexp.MustCompile("^(([a-zA-Z0-9/_|\\-=+]{1,})|\\*)?$") +// Validate checks the field values on StoredSchema with the rules defined in +// the proto definition for this message. If any rules are violated, the first +// error encountered is returned, or nil if there are no violations. +func (m *StoredSchema) Validate() error { + return m.validate(false) +} + +// ValidateAll checks the field values on StoredSchema with the rules defined +// in the proto definition for this message. If any rules are violated, the +// result is a list of violation errors wrapped in StoredSchemaMultiError, or +// nil if none found. +func (m *StoredSchema) ValidateAll() error { + return m.validate(true) +} + +func (m *StoredSchema) validate(all bool) error { + if m == nil { + return nil + } + + var errors []error + + // no validation rules for Version + + switch v := m.VersionOneof.(type) { + case *StoredSchema_V1: + if v == nil { + err := StoredSchemaValidationError{ + field: "VersionOneof", + reason: "oneof value cannot be a typed-nil", + } + if !all { + return err + } + errors = append(errors, err) + } + + if all { + switch v := interface{}(m.GetV1()).(type) { + case interface{ ValidateAll() error }: + if err := v.ValidateAll(); err != nil { + errors = append(errors, StoredSchemaValidationError{ + field: "V1", + reason: "embedded message failed validation", + cause: err, + }) + } + case interface{ Validate() error }: + if err := v.Validate(); err != nil { + errors = append(errors, StoredSchemaValidationError{ + field: "V1", + reason: "embedded message failed validation", + cause: err, + }) + } + } + } else if v, ok := 
interface{}(m.GetV1()).(interface{ Validate() error }); ok { + if err := v.Validate(); err != nil { + return StoredSchemaValidationError{ + field: "V1", + reason: "embedded message failed validation", + cause: err, + } + } + } + + default: + _ = v // ensures v is used + } + + if len(errors) > 0 { + return StoredSchemaMultiError(errors) + } + + return nil +} + +// StoredSchemaMultiError is an error wrapping multiple validation errors +// returned by StoredSchema.ValidateAll() if the designated constraints aren't met. +type StoredSchemaMultiError []error + +// Error returns a concatenation of all the error messages it wraps. +func (m StoredSchemaMultiError) Error() string { + msgs := make([]string, 0, len(m)) + for _, err := range m { + msgs = append(msgs, err.Error()) + } + return strings.Join(msgs, "; ") +} + +// AllErrors returns a list of validation violation errors. +func (m StoredSchemaMultiError) AllErrors() []error { return m } + +// StoredSchemaValidationError is the validation error returned by +// StoredSchema.Validate if the designated constraints aren't met. +type StoredSchemaValidationError struct { + field string + reason string + cause error + key bool +} + +// Field function returns field value. +func (e StoredSchemaValidationError) Field() string { return e.field } + +// Reason function returns reason value. +func (e StoredSchemaValidationError) Reason() string { return e.reason } + +// Cause function returns cause value. +func (e StoredSchemaValidationError) Cause() error { return e.cause } + +// Key function returns key value. +func (e StoredSchemaValidationError) Key() bool { return e.key } + +// ErrorName returns error name. 
+func (e StoredSchemaValidationError) ErrorName() string { return "StoredSchemaValidationError" } + +// Error satisfies the builtin error interface +func (e StoredSchemaValidationError) Error() string { + cause := "" + if e.cause != nil { + cause = fmt.Sprintf(" | caused by: %v", e.cause) + } + + key := "" + if e.key { + key = "key for " + } + + return fmt.Sprintf( + "invalid %sStoredSchema.%s: %s%s", + key, + e.field, + e.reason, + cause) +} + +var _ error = StoredSchemaValidationError{} + +var _ interface { + Field() string + Reason() string + Key() bool + Cause() error + ErrorName() string +} = StoredSchemaValidationError{} + // Validate checks the field values on AllowedRelation_PublicWildcard with the // rules defined in the proto definition for this message. If any rules are // violated, the first error encountered is returned, or nil if there are no violations. @@ -7165,3 +7312,202 @@ var _ interface { } = SubjectFilter_RelationFilterValidationError{} var _SubjectFilter_RelationFilter_Relation_Pattern = regexp.MustCompile("^([a-z][a-z0-9_]{1,62}[a-z0-9])?$") + +// Validate checks the field values on StoredSchema_V1StoredSchema with the +// rules defined in the proto definition for this message. If any rules are +// violated, the first error encountered is returned, or nil if there are no violations. +func (m *StoredSchema_V1StoredSchema) Validate() error { + return m.validate(false) +} + +// ValidateAll checks the field values on StoredSchema_V1StoredSchema with the +// rules defined in the proto definition for this message. If any rules are +// violated, the result is a list of violation errors wrapped in +// StoredSchema_V1StoredSchemaMultiError, or nil if none found. 
+func (m *StoredSchema_V1StoredSchema) ValidateAll() error { + return m.validate(true) +} + +func (m *StoredSchema_V1StoredSchema) validate(all bool) error { + if m == nil { + return nil + } + + var errors []error + + // no validation rules for SchemaText + + // no validation rules for SchemaHash + + { + sorted_keys := make([]string, len(m.GetNamespaceDefinitions())) + i := 0 + for key := range m.GetNamespaceDefinitions() { + sorted_keys[i] = key + i++ + } + sort.Slice(sorted_keys, func(i, j int) bool { return sorted_keys[i] < sorted_keys[j] }) + for _, key := range sorted_keys { + val := m.GetNamespaceDefinitions()[key] + _ = val + + // no validation rules for NamespaceDefinitions[key] + + if all { + switch v := interface{}(val).(type) { + case interface{ ValidateAll() error }: + if err := v.ValidateAll(); err != nil { + errors = append(errors, StoredSchema_V1StoredSchemaValidationError{ + field: fmt.Sprintf("NamespaceDefinitions[%v]", key), + reason: "embedded message failed validation", + cause: err, + }) + } + case interface{ Validate() error }: + if err := v.Validate(); err != nil { + errors = append(errors, StoredSchema_V1StoredSchemaValidationError{ + field: fmt.Sprintf("NamespaceDefinitions[%v]", key), + reason: "embedded message failed validation", + cause: err, + }) + } + } + } else if v, ok := interface{}(val).(interface{ Validate() error }); ok { + if err := v.Validate(); err != nil { + return StoredSchema_V1StoredSchemaValidationError{ + field: fmt.Sprintf("NamespaceDefinitions[%v]", key), + reason: "embedded message failed validation", + cause: err, + } + } + } + + } + } + + { + sorted_keys := make([]string, len(m.GetCaveatDefinitions())) + i := 0 + for key := range m.GetCaveatDefinitions() { + sorted_keys[i] = key + i++ + } + sort.Slice(sorted_keys, func(i, j int) bool { return sorted_keys[i] < sorted_keys[j] }) + for _, key := range sorted_keys { + val := m.GetCaveatDefinitions()[key] + _ = val + + // no validation rules for CaveatDefinitions[key] + 
+ if all { + switch v := interface{}(val).(type) { + case interface{ ValidateAll() error }: + if err := v.ValidateAll(); err != nil { + errors = append(errors, StoredSchema_V1StoredSchemaValidationError{ + field: fmt.Sprintf("CaveatDefinitions[%v]", key), + reason: "embedded message failed validation", + cause: err, + }) + } + case interface{ Validate() error }: + if err := v.Validate(); err != nil { + errors = append(errors, StoredSchema_V1StoredSchemaValidationError{ + field: fmt.Sprintf("CaveatDefinitions[%v]", key), + reason: "embedded message failed validation", + cause: err, + }) + } + } + } else if v, ok := interface{}(val).(interface{ Validate() error }); ok { + if err := v.Validate(); err != nil { + return StoredSchema_V1StoredSchemaValidationError{ + field: fmt.Sprintf("CaveatDefinitions[%v]", key), + reason: "embedded message failed validation", + cause: err, + } + } + } + + } + } + + if len(errors) > 0 { + return StoredSchema_V1StoredSchemaMultiError(errors) + } + + return nil +} + +// StoredSchema_V1StoredSchemaMultiError is an error wrapping multiple +// validation errors returned by StoredSchema_V1StoredSchema.ValidateAll() if +// the designated constraints aren't met. +type StoredSchema_V1StoredSchemaMultiError []error + +// Error returns a concatenation of all the error messages it wraps. +func (m StoredSchema_V1StoredSchemaMultiError) Error() string { + msgs := make([]string, 0, len(m)) + for _, err := range m { + msgs = append(msgs, err.Error()) + } + return strings.Join(msgs, "; ") +} + +// AllErrors returns a list of validation violation errors. +func (m StoredSchema_V1StoredSchemaMultiError) AllErrors() []error { return m } + +// StoredSchema_V1StoredSchemaValidationError is the validation error returned +// by StoredSchema_V1StoredSchema.Validate if the designated constraints +// aren't met. 
+type StoredSchema_V1StoredSchemaValidationError struct { + field string + reason string + cause error + key bool +} + +// Field function returns field value. +func (e StoredSchema_V1StoredSchemaValidationError) Field() string { return e.field } + +// Reason function returns reason value. +func (e StoredSchema_V1StoredSchemaValidationError) Reason() string { return e.reason } + +// Cause function returns cause value. +func (e StoredSchema_V1StoredSchemaValidationError) Cause() error { return e.cause } + +// Key function returns key value. +func (e StoredSchema_V1StoredSchemaValidationError) Key() bool { return e.key } + +// ErrorName returns error name. +func (e StoredSchema_V1StoredSchemaValidationError) ErrorName() string { + return "StoredSchema_V1StoredSchemaValidationError" +} + +// Error satisfies the builtin error interface +func (e StoredSchema_V1StoredSchemaValidationError) Error() string { + cause := "" + if e.cause != nil { + cause = fmt.Sprintf(" | caused by: %v", e.cause) + } + + key := "" + if e.key { + key = "key for " + } + + return fmt.Sprintf( + "invalid %sStoredSchema_V1StoredSchema.%s: %s%s", + key, + e.field, + e.reason, + cause) +} + +var _ error = StoredSchema_V1StoredSchemaValidationError{} + +var _ interface { + Field() string + Reason() string + Key() bool + Cause() error + ErrorName() string +} = StoredSchema_V1StoredSchemaValidationError{} diff --git a/pkg/proto/core/v1/core_vtproto.pb.go b/pkg/proto/core/v1/core_vtproto.pb.go index 1a90dbb78..ea58868c5 100644 --- a/pkg/proto/core/v1/core_vtproto.pb.go +++ b/pkg/proto/core/v1/core_vtproto.pb.go @@ -1023,6 +1023,69 @@ func (m *SubjectFilter) CloneMessageVT() proto.Message { return m.CloneVT() } +func (m *StoredSchema_V1StoredSchema) CloneVT() *StoredSchema_V1StoredSchema { + if m == nil { + return (*StoredSchema_V1StoredSchema)(nil) + } + r := new(StoredSchema_V1StoredSchema) + r.SchemaText = m.SchemaText + r.SchemaHash = m.SchemaHash + if rhs := m.NamespaceDefinitions; rhs != nil { + 
tmpContainer := make(map[string]*NamespaceDefinition, len(rhs)) + for k, v := range rhs { + tmpContainer[k] = v.CloneVT() + } + r.NamespaceDefinitions = tmpContainer + } + if rhs := m.CaveatDefinitions; rhs != nil { + tmpContainer := make(map[string]*CaveatDefinition, len(rhs)) + for k, v := range rhs { + tmpContainer[k] = v.CloneVT() + } + r.CaveatDefinitions = tmpContainer + } + if len(m.unknownFields) > 0 { + r.unknownFields = make([]byte, len(m.unknownFields)) + copy(r.unknownFields, m.unknownFields) + } + return r +} + +func (m *StoredSchema_V1StoredSchema) CloneMessageVT() proto.Message { + return m.CloneVT() +} + +func (m *StoredSchema) CloneVT() *StoredSchema { + if m == nil { + return (*StoredSchema)(nil) + } + r := new(StoredSchema) + r.Version = m.Version + if m.VersionOneof != nil { + r.VersionOneof = m.VersionOneof.(interface { + CloneVT() isStoredSchema_VersionOneof + }).CloneVT() + } + if len(m.unknownFields) > 0 { + r.unknownFields = make([]byte, len(m.unknownFields)) + copy(r.unknownFields, m.unknownFields) + } + return r +} + +func (m *StoredSchema) CloneMessageVT() proto.Message { + return m.CloneVT() +} + +func (m *StoredSchema_V1) CloneVT() isStoredSchema_VersionOneof { + if m == nil { + return (*StoredSchema_V1)(nil) + } + r := new(StoredSchema_V1) + r.V1 = m.V1.CloneVT() + return r +} + func (this *RelationTuple) EqualVT(that *RelationTuple) bool { if this == that { return true @@ -2581,6 +2644,124 @@ func (this *SubjectFilter) EqualMessageVT(thatMsg proto.Message) bool { } return this.EqualVT(that) } +func (this *StoredSchema_V1StoredSchema) EqualVT(that *StoredSchema_V1StoredSchema) bool { + if this == that { + return true + } else if this == nil || that == nil { + return false + } + if this.SchemaText != that.SchemaText { + return false + } + if this.SchemaHash != that.SchemaHash { + return false + } + if len(this.NamespaceDefinitions) != len(that.NamespaceDefinitions) { + return false + } + for i, vx := range this.NamespaceDefinitions { + 
vy, ok := that.NamespaceDefinitions[i] + if !ok { + return false + } + if p, q := vx, vy; p != q { + if p == nil { + p = &NamespaceDefinition{} + } + if q == nil { + q = &NamespaceDefinition{} + } + if !p.EqualVT(q) { + return false + } + } + } + if len(this.CaveatDefinitions) != len(that.CaveatDefinitions) { + return false + } + for i, vx := range this.CaveatDefinitions { + vy, ok := that.CaveatDefinitions[i] + if !ok { + return false + } + if p, q := vx, vy; p != q { + if p == nil { + p = &CaveatDefinition{} + } + if q == nil { + q = &CaveatDefinition{} + } + if !p.EqualVT(q) { + return false + } + } + } + return string(this.unknownFields) == string(that.unknownFields) +} + +func (this *StoredSchema_V1StoredSchema) EqualMessageVT(thatMsg proto.Message) bool { + that, ok := thatMsg.(*StoredSchema_V1StoredSchema) + if !ok { + return false + } + return this.EqualVT(that) +} +func (this *StoredSchema) EqualVT(that *StoredSchema) bool { + if this == that { + return true + } else if this == nil || that == nil { + return false + } + if this.VersionOneof == nil && that.VersionOneof != nil { + return false + } else if this.VersionOneof != nil { + if that.VersionOneof == nil { + return false + } + if !this.VersionOneof.(interface { + EqualVT(isStoredSchema_VersionOneof) bool + }).EqualVT(that.VersionOneof) { + return false + } + } + if this.Version != that.Version { + return false + } + return string(this.unknownFields) == string(that.unknownFields) +} + +func (this *StoredSchema) EqualMessageVT(thatMsg proto.Message) bool { + that, ok := thatMsg.(*StoredSchema) + if !ok { + return false + } + return this.EqualVT(that) +} +func (this *StoredSchema_V1) EqualVT(thatIface isStoredSchema_VersionOneof) bool { + that, ok := thatIface.(*StoredSchema_V1) + if !ok { + return false + } + if this == that { + return true + } + if this == nil && that != nil || this != nil && that == nil { + return false + } + if p, q := this.V1, that.V1; p != q { + if p == nil { + p = 
&StoredSchema_V1StoredSchema{} + } + if q == nil { + q = &StoredSchema_V1StoredSchema{} + } + if !p.EqualVT(q) { + return false + } + } + return true +} + func (m *RelationTuple) MarshalVT() (dAtA []byte, err error) { if m == nil { return nil, nil @@ -5154,85 +5335,246 @@ func (m *SubjectFilter) MarshalToSizedBufferVT(dAtA []byte) (int, error) { return len(dAtA) - i, nil } -func (m *RelationTuple) SizeVT() (n int) { +func (m *StoredSchema_V1StoredSchema) MarshalVT() (dAtA []byte, err error) { if m == nil { - return 0 - } - var l int - _ = l - if m.ResourceAndRelation != nil { - l = m.ResourceAndRelation.SizeVT() - n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) - } - if m.Subject != nil { - l = m.Subject.SizeVT() - n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) - } - if m.Caveat != nil { - l = m.Caveat.SizeVT() - n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) - } - if m.Integrity != nil { - l = m.Integrity.SizeVT() - n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + return nil, nil } - if m.OptionalExpirationTime != nil { - l = (*timestamppb1.Timestamp)(m.OptionalExpirationTime).SizeVT() - n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + size := m.SizeVT() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBufferVT(dAtA[:size]) + if err != nil { + return nil, err } - n += len(m.unknownFields) - return n + return dAtA[:n], nil } -func (m *RelationshipIntegrity) SizeVT() (n int) { +func (m *StoredSchema_V1StoredSchema) MarshalToVT(dAtA []byte) (int, error) { + size := m.SizeVT() + return m.MarshalToSizedBufferVT(dAtA[:size]) +} + +func (m *StoredSchema_V1StoredSchema) MarshalToSizedBufferVT(dAtA []byte) (int, error) { if m == nil { - return 0 + return 0, nil } + i := len(dAtA) + _ = i var l int _ = l - l = len(m.KeyId) - if l > 0 { - n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + if m.unknownFields != nil { + i -= len(m.unknownFields) + copy(dAtA[i:], m.unknownFields) } - l = len(m.Hash) - if l > 0 { - n += 1 + l + 
protohelpers.SizeOfVarint(uint64(l)) + if len(m.CaveatDefinitions) > 0 { + for k := range m.CaveatDefinitions { + v := m.CaveatDefinitions[k] + baseI := i + size, err := v.MarshalToSizedBufferVT(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) + i-- + dAtA[i] = 0x12 + i -= len(k) + copy(dAtA[i:], k) + i = protohelpers.EncodeVarint(dAtA, i, uint64(len(k))) + i-- + dAtA[i] = 0xa + i = protohelpers.EncodeVarint(dAtA, i, uint64(baseI-i)) + i-- + dAtA[i] = 0x22 + } } - if m.HashedAt != nil { - l = (*timestamppb1.Timestamp)(m.HashedAt).SizeVT() - n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + if len(m.NamespaceDefinitions) > 0 { + for k := range m.NamespaceDefinitions { + v := m.NamespaceDefinitions[k] + baseI := i + size, err := v.MarshalToSizedBufferVT(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) + i-- + dAtA[i] = 0x12 + i -= len(k) + copy(dAtA[i:], k) + i = protohelpers.EncodeVarint(dAtA, i, uint64(len(k))) + i-- + dAtA[i] = 0xa + i = protohelpers.EncodeVarint(dAtA, i, uint64(baseI-i)) + i-- + dAtA[i] = 0x1a + } } - n += len(m.unknownFields) - return n + if len(m.SchemaHash) > 0 { + i -= len(m.SchemaHash) + copy(dAtA[i:], m.SchemaHash) + i = protohelpers.EncodeVarint(dAtA, i, uint64(len(m.SchemaHash))) + i-- + dAtA[i] = 0x12 + } + if len(m.SchemaText) > 0 { + i -= len(m.SchemaText) + copy(dAtA[i:], m.SchemaText) + i = protohelpers.EncodeVarint(dAtA, i, uint64(len(m.SchemaText))) + i-- + dAtA[i] = 0xa + } + return len(dAtA) - i, nil } -func (m *ContextualizedCaveat) SizeVT() (n int) { +func (m *StoredSchema) MarshalVT() (dAtA []byte, err error) { if m == nil { - return 0 - } - var l int - _ = l - l = len(m.CaveatName) - if l > 0 { - n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + return nil, nil } - if m.Context != nil { - l = (*structpb1.Struct)(m.Context).SizeVT() - n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + size := 
m.SizeVT() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBufferVT(dAtA[:size]) + if err != nil { + return nil, err } - n += len(m.unknownFields) - return n + return dAtA[:n], nil } -func (m *CaveatDefinition) SizeVT() (n int) { +func (m *StoredSchema) MarshalToVT(dAtA []byte) (int, error) { + size := m.SizeVT() + return m.MarshalToSizedBufferVT(dAtA[:size]) +} + +func (m *StoredSchema) MarshalToSizedBufferVT(dAtA []byte) (int, error) { if m == nil { - return 0 + return 0, nil } + i := len(dAtA) + _ = i var l int _ = l - l = len(m.Name) - if l > 0 { - n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + if m.unknownFields != nil { + i -= len(m.unknownFields) + copy(dAtA[i:], m.unknownFields) + } + if vtmsg, ok := m.VersionOneof.(interface { + MarshalToSizedBufferVT([]byte) (int, error) + }); ok { + size, err := vtmsg.MarshalToSizedBufferVT(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + } + if m.Version != 0 { + i = protohelpers.EncodeVarint(dAtA, i, uint64(m.Version)) + i-- + dAtA[i] = 0x8 + } + return len(dAtA) - i, nil +} + +func (m *StoredSchema_V1) MarshalToVT(dAtA []byte) (int, error) { + size := m.SizeVT() + return m.MarshalToSizedBufferVT(dAtA[:size]) +} + +func (m *StoredSchema_V1) MarshalToSizedBufferVT(dAtA []byte) (int, error) { + i := len(dAtA) + if m.V1 != nil { + size, err := m.V1.MarshalToSizedBufferVT(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) + i-- + dAtA[i] = 0x12 + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0x12 + } + return len(dAtA) - i, nil +} +func (m *RelationTuple) SizeVT() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if m.ResourceAndRelation != nil { + l = m.ResourceAndRelation.SizeVT() + n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } + if m.Subject != nil { + l = m.Subject.SizeVT() + n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } + if m.Caveat != nil { + l = m.Caveat.SizeVT() + n += 1 + 
l + protohelpers.SizeOfVarint(uint64(l)) + } + if m.Integrity != nil { + l = m.Integrity.SizeVT() + n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } + if m.OptionalExpirationTime != nil { + l = (*timestamppb1.Timestamp)(m.OptionalExpirationTime).SizeVT() + n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } + n += len(m.unknownFields) + return n +} + +func (m *RelationshipIntegrity) SizeVT() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + l = len(m.KeyId) + if l > 0 { + n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } + l = len(m.Hash) + if l > 0 { + n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } + if m.HashedAt != nil { + l = (*timestamppb1.Timestamp)(m.HashedAt).SizeVT() + n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } + n += len(m.unknownFields) + return n +} + +func (m *ContextualizedCaveat) SizeVT() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + l = len(m.CaveatName) + if l > 0 { + n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } + if m.Context != nil { + l = (*structpb1.Struct)(m.Context).SizeVT() + n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } + n += len(m.unknownFields) + return n +} + +func (m *CaveatDefinition) SizeVT() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + l = len(m.Name) + if l > 0 { + n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) } l = len(m.SerializedExpression) if l > 0 { @@ -6194,6 +6536,80 @@ func (m *SubjectFilter) SizeVT() (n int) { return n } +func (m *StoredSchema_V1StoredSchema) SizeVT() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + l = len(m.SchemaText) + if l > 0 { + n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } + l = len(m.SchemaHash) + if l > 0 { + n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } + if len(m.NamespaceDefinitions) > 0 { + for k, v := range m.NamespaceDefinitions { + _ = k + _ = v + l = 0 + if v != nil { + l = v.SizeVT() + } + l += 1 + protohelpers.SizeOfVarint(uint64(l)) + mapEntrySize := 1 + 
len(k) + protohelpers.SizeOfVarint(uint64(len(k))) + l + n += mapEntrySize + 1 + protohelpers.SizeOfVarint(uint64(mapEntrySize)) + } + } + if len(m.CaveatDefinitions) > 0 { + for k, v := range m.CaveatDefinitions { + _ = k + _ = v + l = 0 + if v != nil { + l = v.SizeVT() + } + l += 1 + protohelpers.SizeOfVarint(uint64(l)) + mapEntrySize := 1 + len(k) + protohelpers.SizeOfVarint(uint64(len(k))) + l + n += mapEntrySize + 1 + protohelpers.SizeOfVarint(uint64(mapEntrySize)) + } + } + n += len(m.unknownFields) + return n +} + +func (m *StoredSchema) SizeVT() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if m.Version != 0 { + n += 1 + protohelpers.SizeOfVarint(uint64(m.Version)) + } + if vtmsg, ok := m.VersionOneof.(interface{ SizeVT() int }); ok { + n += vtmsg.SizeVT() + } + n += len(m.unknownFields) + return n +} + +func (m *StoredSchema_V1) SizeVT() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if m.V1 != nil { + l = m.V1.SizeVT() + n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } else { + n += 2 + } + return n +} func (m *RelationTuple) UnmarshalVT(dAtA []byte) error { l := len(dAtA) iNdEx := 0 @@ -12164,3 +12580,487 @@ func (m *SubjectFilter) UnmarshalVT(dAtA []byte) error { } return nil } +func (m *StoredSchema_V1StoredSchema) UnmarshalVT(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: StoredSchema_V1StoredSchema: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: StoredSchema_V1StoredSchema: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if 
wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field SchemaText", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return protohelpers.ErrInvalidLength + } + postIndex := iNdEx + intStringLen + if postIndex < 0 { + return protohelpers.ErrInvalidLength + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.SchemaText = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field SchemaHash", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return protohelpers.ErrInvalidLength + } + postIndex := iNdEx + intStringLen + if postIndex < 0 { + return protohelpers.ErrInvalidLength + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.SchemaHash = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 3: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field NamespaceDefinitions", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return protohelpers.ErrInvalidLength + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return protohelpers.ErrInvalidLength + } + if postIndex > l { + return io.ErrUnexpectedEOF + 
} + if m.NamespaceDefinitions == nil { + m.NamespaceDefinitions = make(map[string]*NamespaceDefinition) + } + var mapkey string + var mapvalue *NamespaceDefinition + for iNdEx < postIndex { + entryPreIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + if fieldNum == 1 { + var stringLenmapkey uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLenmapkey |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + intStringLenmapkey := int(stringLenmapkey) + if intStringLenmapkey < 0 { + return protohelpers.ErrInvalidLength + } + postStringIndexmapkey := iNdEx + intStringLenmapkey + if postStringIndexmapkey < 0 { + return protohelpers.ErrInvalidLength + } + if postStringIndexmapkey > l { + return io.ErrUnexpectedEOF + } + mapkey = string(dAtA[iNdEx:postStringIndexmapkey]) + iNdEx = postStringIndexmapkey + } else if fieldNum == 2 { + var mapmsglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + mapmsglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if mapmsglen < 0 { + return protohelpers.ErrInvalidLength + } + postmsgIndex := iNdEx + mapmsglen + if postmsgIndex < 0 { + return protohelpers.ErrInvalidLength + } + if postmsgIndex > l { + return io.ErrUnexpectedEOF + } + mapvalue = &NamespaceDefinition{} + if err := mapvalue.UnmarshalVT(dAtA[iNdEx:postmsgIndex]); err != nil { + return err + } + iNdEx = postmsgIndex + } else { + iNdEx = entryPreIndex + skippy, err := protohelpers.Skip(dAtA[iNdEx:]) + if err != nil { + return err + 
} + if (skippy < 0) || (iNdEx+skippy) < 0 { + return protohelpers.ErrInvalidLength + } + if (iNdEx + skippy) > postIndex { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + m.NamespaceDefinitions[mapkey] = mapvalue + iNdEx = postIndex + case 4: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field CaveatDefinitions", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return protohelpers.ErrInvalidLength + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return protohelpers.ErrInvalidLength + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + if m.CaveatDefinitions == nil { + m.CaveatDefinitions = make(map[string]*CaveatDefinition) + } + var mapkey string + var mapvalue *CaveatDefinition + for iNdEx < postIndex { + entryPreIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + if fieldNum == 1 { + var stringLenmapkey uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLenmapkey |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + intStringLenmapkey := int(stringLenmapkey) + if intStringLenmapkey < 0 { + return protohelpers.ErrInvalidLength + } + postStringIndexmapkey := iNdEx + intStringLenmapkey + if postStringIndexmapkey < 0 { + return protohelpers.ErrInvalidLength + } + if postStringIndexmapkey > l { + return io.ErrUnexpectedEOF + } + mapkey = 
string(dAtA[iNdEx:postStringIndexmapkey]) + iNdEx = postStringIndexmapkey + } else if fieldNum == 2 { + var mapmsglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + mapmsglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if mapmsglen < 0 { + return protohelpers.ErrInvalidLength + } + postmsgIndex := iNdEx + mapmsglen + if postmsgIndex < 0 { + return protohelpers.ErrInvalidLength + } + if postmsgIndex > l { + return io.ErrUnexpectedEOF + } + mapvalue = &CaveatDefinition{} + if err := mapvalue.UnmarshalVT(dAtA[iNdEx:postmsgIndex]); err != nil { + return err + } + iNdEx = postmsgIndex + } else { + iNdEx = entryPreIndex + skippy, err := protohelpers.Skip(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return protohelpers.ErrInvalidLength + } + if (iNdEx + skippy) > postIndex { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + m.CaveatDefinitions[mapkey] = mapvalue + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := protohelpers.Skip(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return protohelpers.ErrInvalidLength + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + m.unknownFields = append(m.unknownFields, dAtA[iNdEx:iNdEx+skippy]...) 
+ iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *StoredSchema) UnmarshalVT(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: StoredSchema: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: StoredSchema: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field Version", wireType) + } + m.Version = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.Version |= uint32(b&0x7F) << shift + if b < 0x80 { + break + } + } + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field V1", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return protohelpers.ErrInvalidLength + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return protohelpers.ErrInvalidLength + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + if oneof, ok := m.VersionOneof.(*StoredSchema_V1); ok { + if err := oneof.V1.UnmarshalVT(dAtA[iNdEx:postIndex]); err != nil { + return err + } + } else { + v := &StoredSchema_V1StoredSchema{} + if err := v.UnmarshalVT(dAtA[iNdEx:postIndex]); err != nil { + 
return err + } + m.VersionOneof = &StoredSchema_V1{V1: v} + } + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := protohelpers.Skip(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return protohelpers.ErrInvalidLength + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + m.unknownFields = append(m.unknownFields, dAtA[iNdEx:iNdEx+skippy]...) + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} diff --git a/pkg/proto/dispatch/v1/dispatch.pb.go b/pkg/proto/dispatch/v1/dispatch.pb.go index 6ac2f9529..2b52d56a0 100644 --- a/pkg/proto/dispatch/v1/dispatch.pb.go +++ b/pkg/proto/dispatch/v1/dispatch.pb.go @@ -1372,6 +1372,7 @@ type ResolverMeta struct { // Deprecated: Marked as deprecated in dispatch/v1/dispatch.proto. RequestId string `protobuf:"bytes,3,opt,name=request_id,json=requestId,proto3" json:"request_id,omitempty"` TraversalBloom []byte `protobuf:"bytes,4,opt,name=traversal_bloom,json=traversalBloom,proto3" json:"traversal_bloom,omitempty"` + SchemaHash []byte `protobuf:"bytes,5,opt,name=schema_hash,json=schemaHash,proto3" json:"schema_hash,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -1435,6 +1436,13 @@ func (x *ResolverMeta) GetTraversalBloom() []byte { return nil } +func (x *ResolverMeta) GetSchemaHash() []byte { + if x != nil { + return x.SchemaHash + } + return nil +} + type ResponseMeta struct { state protoimpl.MessageState `protogen:"open.v1"` DispatchCount uint32 `protobuf:"varint,1,opt,name=dispatch_count,json=dispatchCount,proto3" json:"dispatch_count,omitempty"` @@ -1762,14 +1770,16 @@ const file_dispatch_v1_dispatch_proto_rawDesc = "" + "\bmetadata\x18\x02 \x01(\v2\x19.dispatch.v1.ResponseMetaR\bmetadata\x1ah\n" + "\x1eFoundSubjectsByResourceIdEntry\x12\x10\n" + "\x03key\x18\x01 \x01(\tR\x03key\x120\n" + - "\x05value\x18\x02 \x01(\v2\x1a.dispatch.v1.FoundSubjectsR\x05value:\x028\x01\"\xb7\x01\n" + + 
"\x05value\x18\x02 \x01(\v2\x1a.dispatch.v1.FoundSubjectsR\x05value:\x028\x01\"\xd8\x01\n" + "\fResolverMeta\x12\x1f\n" + "\vat_revision\x18\x01 \x01(\tR\n" + "atRevision\x120\n" + "\x0fdepth_remaining\x18\x02 \x01(\rB\a\xfaB\x04*\x02 \x00R\x0edepthRemaining\x12!\n" + "\n" + "request_id\x18\x03 \x01(\tB\x02\x18\x01R\trequestId\x121\n" + - "\x0ftraversal_bloom\x18\x04 \x01(\fB\b\xfaB\x05z\x03\x18\x80\bR\x0etraversalBloom\"\xda\x01\n" + + "\x0ftraversal_bloom\x18\x04 \x01(\fB\b\xfaB\x05z\x03\x18\x80\bR\x0etraversalBloom\x12\x1f\n" + + "\vschema_hash\x18\x05 \x01(\fR\n" + + "schemaHash\"\xda\x01\n" + "\fResponseMeta\x12%\n" + "\x0edispatch_count\x18\x01 \x01(\rR\rdispatchCount\x12%\n" + "\x0edepth_required\x18\x02 \x01(\rR\rdepthRequired\x122\n" + diff --git a/pkg/proto/dispatch/v1/dispatch.pb.validate.go b/pkg/proto/dispatch/v1/dispatch.pb.validate.go index 265bb06cc..cf3028d01 100644 --- a/pkg/proto/dispatch/v1/dispatch.pb.validate.go +++ b/pkg/proto/dispatch/v1/dispatch.pb.validate.go @@ -3134,6 +3134,8 @@ func (m *ResolverMeta) validate(all bool) error { errors = append(errors, err) } + // no validation rules for SchemaHash + if len(errors) > 0 { return ResolverMetaMultiError(errors) } diff --git a/pkg/proto/dispatch/v1/dispatch_vtproto.pb.go b/pkg/proto/dispatch/v1/dispatch_vtproto.pb.go index e03bb3ac8..777ff0248 100644 --- a/pkg/proto/dispatch/v1/dispatch_vtproto.pb.go +++ b/pkg/proto/dispatch/v1/dispatch_vtproto.pb.go @@ -551,6 +551,11 @@ func (m *ResolverMeta) CloneVT() *ResolverMeta { copy(tmpBytes, rhs) r.TraversalBloom = tmpBytes } + if rhs := m.SchemaHash; rhs != nil { + tmpBytes := make([]byte, len(rhs)) + copy(tmpBytes, rhs) + r.SchemaHash = tmpBytes + } if len(m.unknownFields) > 0 { r.unknownFields = make([]byte, len(m.unknownFields)) copy(r.unknownFields, m.unknownFields) @@ -1363,6 +1368,9 @@ func (this *ResolverMeta) EqualVT(that *ResolverMeta) bool { if string(this.TraversalBloom) != string(that.TraversalBloom) { return false } + if 
string(this.SchemaHash) != string(that.SchemaHash) { + return false + } return string(this.unknownFields) == string(that.unknownFields) } @@ -2847,6 +2855,13 @@ func (m *ResolverMeta) MarshalToSizedBufferVT(dAtA []byte) (int, error) { i -= len(m.unknownFields) copy(dAtA[i:], m.unknownFields) } + if len(m.SchemaHash) > 0 { + i -= len(m.SchemaHash) + copy(dAtA[i:], m.SchemaHash) + i = protohelpers.EncodeVarint(dAtA, i, uint64(len(m.SchemaHash))) + i-- + dAtA[i] = 0x2a + } if len(m.TraversalBloom) > 0 { i -= len(m.TraversalBloom) copy(dAtA[i:], m.TraversalBloom) @@ -3663,6 +3678,10 @@ func (m *ResolverMeta) SizeVT() (n int) { if l > 0 { n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) } + l = len(m.SchemaHash) + if l > 0 { + n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } n += len(m.unknownFields) return n } @@ -7051,6 +7070,40 @@ func (m *ResolverMeta) UnmarshalVT(dAtA []byte) error { m.TraversalBloom = []byte{} } iNdEx = postIndex + case 5: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field SchemaHash", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return protohelpers.ErrInvalidLength + } + postIndex := iNdEx + byteLen + if postIndex < 0 { + return protohelpers.ErrInvalidLength + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.SchemaHash = append(m.SchemaHash[:0], dAtA[iNdEx:postIndex]...) 
+ if m.SchemaHash == nil { + m.SchemaHash = []byte{} + } + iNdEx = postIndex default: iNdEx = preIndex skippy, err := protohelpers.Skip(dAtA[iNdEx:]) diff --git a/pkg/proto/impl/v1/impl.pb.go b/pkg/proto/impl/v1/impl.pb.go index 8a2b2eec2..da3c9bba5 100644 --- a/pkg/proto/impl/v1/impl.pb.go +++ b/pkg/proto/impl/v1/impl.pb.go @@ -403,8 +403,11 @@ type V1Cursor struct { // datastore_unique_id is the unique ID for the datastore. Will be empty for legacy // cursors. DatastoreUniqueId string `protobuf:"bytes,6,opt,name=datastore_unique_id,json=datastoreUniqueId,proto3" json:"datastore_unique_id,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + // schema_hash is the hash of the schema at the time the cursor was created. Will be + // empty for legacy cursors. + SchemaHash []byte `protobuf:"bytes,7,opt,name=schema_hash,json=schemaHash,proto3" json:"schema_hash,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *V1Cursor) Reset() { @@ -479,6 +482,13 @@ func (x *V1Cursor) GetDatastoreUniqueId() string { return "" } +func (x *V1Cursor) GetSchemaHash() []byte { + if x != nil { + return x.SchemaHash + } + return nil +} + type DocComment struct { state protoimpl.MessageState `protogen:"open.v1"` Comment string `protobuf:"bytes,1,opt,name=comment,proto3" json:"comment,omitempty"` @@ -932,14 +942,16 @@ const file_impl_v1_impl_proto_rawDesc = "" + "\rversion_oneof\"E\n" + "\rDecodedCursor\x12#\n" + "\x02v1\x18\x01 \x01(\v2\x11.impl.v1.V1CursorH\x00R\x02v1B\x0f\n" + - "\rversion_oneof\"\xc4\x02\n" + + "\rversion_oneof\"\xe5\x02\n" + "\bV1Cursor\x12\x1a\n" + "\brevision\x18\x01 \x01(\tR\brevision\x12\x1a\n" + "\bsections\x18\x02 \x03(\tR\bsections\x127\n" + "\x18call_and_parameters_hash\x18\x03 \x01(\tR\x15callAndParametersHash\x12)\n" + "\x10dispatch_version\x18\x04 \x01(\rR\x0fdispatchVersion\x122\n" + "\x05flags\x18\x05 \x03(\v2\x1c.impl.v1.V1Cursor.FlagsEntryR\x05flags\x12.\n" + - 
"\x13datastore_unique_id\x18\x06 \x01(\tR\x11datastoreUniqueId\x1a8\n" + + "\x13datastore_unique_id\x18\x06 \x01(\tR\x11datastoreUniqueId\x12\x1f\n" + + "\vschema_hash\x18\a \x01(\fR\n" + + "schemaHash\x1a8\n" + "\n" + "FlagsEntry\x12\x10\n" + "\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n" + diff --git a/pkg/proto/impl/v1/impl.pb.validate.go b/pkg/proto/impl/v1/impl.pb.validate.go index 6811f6333..266a5b820 100644 --- a/pkg/proto/impl/v1/impl.pb.validate.go +++ b/pkg/proto/impl/v1/impl.pb.validate.go @@ -737,6 +737,8 @@ func (m *V1Cursor) validate(all bool) error { // no validation rules for DatastoreUniqueId + // no validation rules for SchemaHash + if len(errors) > 0 { return V1CursorMultiError(errors) } diff --git a/pkg/proto/impl/v1/impl_vtproto.pb.go b/pkg/proto/impl/v1/impl_vtproto.pb.go index d5323399c..5ec4b439a 100644 --- a/pkg/proto/impl/v1/impl_vtproto.pb.go +++ b/pkg/proto/impl/v1/impl_vtproto.pb.go @@ -256,6 +256,11 @@ func (m *V1Cursor) CloneVT() *V1Cursor { } r.Flags = tmpContainer } + if rhs := m.SchemaHash; rhs != nil { + tmpBytes := make([]byte, len(rhs)) + copy(tmpBytes, rhs) + r.SchemaHash = tmpBytes + } if len(m.unknownFields) > 0 { r.unknownFields = make([]byte, len(m.unknownFields)) copy(r.unknownFields, m.unknownFields) @@ -756,6 +761,9 @@ func (this *V1Cursor) EqualVT(that *V1Cursor) bool { if this.DatastoreUniqueId != that.DatastoreUniqueId { return false } + if string(this.SchemaHash) != string(that.SchemaHash) { + return false + } return string(this.unknownFields) == string(that.unknownFields) } @@ -1410,6 +1418,13 @@ func (m *V1Cursor) MarshalToSizedBufferVT(dAtA []byte) (int, error) { i -= len(m.unknownFields) copy(dAtA[i:], m.unknownFields) } + if len(m.SchemaHash) > 0 { + i -= len(m.SchemaHash) + copy(dAtA[i:], m.SchemaHash) + i = protohelpers.EncodeVarint(dAtA, i, uint64(len(m.SchemaHash))) + i-- + dAtA[i] = 0x3a + } if len(m.DatastoreUniqueId) > 0 { i -= len(m.DatastoreUniqueId) copy(dAtA[i:], m.DatastoreUniqueId) @@ -1931,6 +1946,10 
@@ func (m *V1Cursor) SizeVT() (n int) { if l > 0 { n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) } + l = len(m.SchemaHash) + if l > 0 { + n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } n += len(m.unknownFields) return n } @@ -3174,6 +3193,40 @@ func (m *V1Cursor) UnmarshalVT(dAtA []byte) error { } m.DatastoreUniqueId = string(dAtA[iNdEx:postIndex]) iNdEx = postIndex + case 7: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field SchemaHash", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return protohelpers.ErrInvalidLength + } + postIndex := iNdEx + byteLen + if postIndex < 0 { + return protohelpers.ErrInvalidLength + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.SchemaHash = append(m.SchemaHash[:0], dAtA[iNdEx:postIndex]...) 
+ if m.SchemaHash == nil { + m.SchemaHash = []byte{} + } + iNdEx = postIndex default: iNdEx = preIndex skippy, err := protohelpers.Skip(dAtA[iNdEx:]) diff --git a/pkg/query/analyze_test.go b/pkg/query/analyze_test.go index 5c43a2eb5..043c3e5da 100644 --- a/pkg/query/analyze_test.go +++ b/pkg/query/analyze_test.go @@ -140,7 +140,7 @@ func TestAnalysisIntegration(t *testing.T) { // Create a context with analysis enabled analyze := NewAnalyzeCollector() ctx := NewLocalContext(context.Background(), - WithReader(ds.SnapshotReader(datastore.NoRevision)), + WithReader(ds.SnapshotReader(datastore.NoRevision, datastore.NoSchemaHashForTesting)), WithAnalyze(analyze)) // Execute a Check operation @@ -294,7 +294,7 @@ func TestOptimizationImprovements(t *testing.T) { // Execute unoptimized tree analyzeUnoptimized := NewAnalyzeCollector() ctxUnoptimized := NewLocalContext(context.Background(), - WithReader(ds.SnapshotReader(datastore.NoRevision)), + WithReader(ds.SnapshotReader(datastore.NoRevision, datastore.NoSchemaHashForTesting)), WithAnalyze(analyzeUnoptimized)) resources := []Object{{ObjectType: "document", ObjectID: "doc1"}} @@ -317,7 +317,7 @@ func TestOptimizationImprovements(t *testing.T) { // Execute optimized tree analyzeOptimized := NewAnalyzeCollector() ctxOptimized := NewLocalContext(context.Background(), - WithReader(ds.SnapshotReader(datastore.NoRevision)), + WithReader(ds.SnapshotReader(datastore.NoRevision, datastore.NoSchemaHashForTesting)), WithAnalyze(analyzeOptimized)) pathSeq, err = ctxOptimized.Check(optimized, resources, subject) @@ -365,7 +365,7 @@ func TestOptimizationImprovements(t *testing.T) { // Execute unoptimized tree analyzeUnoptimized := NewAnalyzeCollector() ctxUnoptimized := NewLocalContext(context.Background(), - WithReader(ds.SnapshotReader(datastore.NoRevision)), + WithReader(ds.SnapshotReader(datastore.NoRevision, datastore.NoSchemaHashForTesting)), WithAnalyze(analyzeUnoptimized)) resources := []Object{{ObjectType: "document", ObjectID: 
"doc1"}} @@ -388,7 +388,7 @@ func TestOptimizationImprovements(t *testing.T) { // Execute optimized tree analyzeOptimized := NewAnalyzeCollector() ctxOptimized := NewLocalContext(context.Background(), - WithReader(ds.SnapshotReader(datastore.NoRevision)), + WithReader(ds.SnapshotReader(datastore.NoRevision, datastore.NoSchemaHashForTesting)), WithAnalyze(analyzeOptimized)) pathSeq, err = ctxOptimized.Check(optimized, resources, subject) @@ -460,7 +460,7 @@ func TestOptimizationImprovements(t *testing.T) { // Execute unoptimized tree analyzeUnoptimized := NewAnalyzeCollector() ctxUnoptimized := NewLocalContext(context.Background(), - WithReader(ds.SnapshotReader(datastore.NoRevision)), + WithReader(ds.SnapshotReader(datastore.NoRevision, datastore.NoSchemaHashForTesting)), WithAnalyze(analyzeUnoptimized)) resources := []Object{ @@ -486,7 +486,7 @@ func TestOptimizationImprovements(t *testing.T) { // Execute optimized tree analyzeOptimized := NewAnalyzeCollector() ctxOptimized := NewLocalContext(context.Background(), - WithReader(ds.SnapshotReader(datastore.NoRevision)), + WithReader(ds.SnapshotReader(datastore.NoRevision, datastore.NoSchemaHashForTesting)), WithAnalyze(analyzeOptimized)) pathSeq, err = ctxOptimized.Check(optimized, resources, subject) diff --git a/pkg/query/benchmarks/check_deep_arrow_benchmark_test.go b/pkg/query/benchmarks/check_deep_arrow_benchmark_test.go index bc9e72cda..94d9f783a 100644 --- a/pkg/query/benchmarks/check_deep_arrow_benchmark_test.go +++ b/pkg/query/benchmarks/check_deep_arrow_benchmark_test.go @@ -86,7 +86,7 @@ func BenchmarkCheckDeepArrow(b *testing.B) { // Create query context queryCtx := query.NewLocalContext(ctx, - query.WithReader(rawDS.SnapshotReader(revision)), + query.WithReader(rawDS.SnapshotReader(revision, datastore.NoSchemaHashForTesting)), query.WithMaxRecursionDepth(50), ) diff --git a/pkg/query/build_tree_test.go b/pkg/query/build_tree_test.go index 82c9f1e57..5e9a57b0f 100644 --- a/pkg/query/build_tree_test.go 
+++ b/pkg/query/build_tree_test.go @@ -8,6 +8,7 @@ import ( "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/datastore/memdb" "github.com/authzed/spicedb/internal/testfixtures" + "github.com/authzed/spicedb/pkg/datastore" "github.com/authzed/spicedb/pkg/namespace" corev1 "github.com/authzed/spicedb/pkg/proto/core/v1" "github.com/authzed/spicedb/pkg/schema/v2" @@ -31,7 +32,7 @@ func TestBuildTree(t *testing.T) { require.NoError(err) ctx := NewLocalContext(t.Context(), - WithReader(ds.SnapshotReader(revision))) + WithReader(ds.SnapshotReader(revision, datastore.NoSchemaHashForTesting))) relSeq, err := ctx.Check(it, NewObjects("document", "specialplan"), NewObject("user", "multiroleguy").WithEllipses()) require.NoError(err) @@ -61,7 +62,7 @@ func TestBuildTreeMultipleRelations(t *testing.T) { require.Contains(explain.String(), "Union", "edit permission should create a union iterator") ctx := NewLocalContext(t.Context(), - WithReader(ds.SnapshotReader(revision))) + WithReader(ds.SnapshotReader(revision, datastore.NoSchemaHashForTesting))) relSeq, err := ctx.Check(it, NewObjects("document", "specialplan"), NewObject("user", "multiroleguy").WithEllipses()) require.NoError(err) @@ -112,7 +113,7 @@ func TestBuildTreeSubRelations(t *testing.T) { require.NotEmpty(explain.String()) ctx := NewLocalContext(t.Context(), - WithReader(ds.SnapshotReader(revision))) + WithReader(ds.SnapshotReader(revision, datastore.NoSchemaHashForTesting))) // Just test that the iterator can be executed without error relSeq, err := ctx.Check(it, NewObjects("document", "companyplan"), NewObject("user", "legal").WithEllipses()) @@ -211,7 +212,7 @@ func TestBuildTreeIntersectionOperation(t *testing.T) { require.Contains(explain.String(), "Intersection", "should create intersection iterator") ctx := NewLocalContext(t.Context(), - WithReader(ds.SnapshotReader(revision))) + WithReader(ds.SnapshotReader(revision, datastore.NoSchemaHashForTesting))) // Test 
execution relSeq, err := ctx.Check(it, NewObjects("document", "specialplan"), NewObject("user", "multiroleguy").WithEllipses()) @@ -274,7 +275,7 @@ func TestBuildTreeExclusionEdgeCases(t *testing.T) { ds, revision := testfixtures.StandardDatastoreWithData(rawDS, require) ctx := NewLocalContext(t.Context(), - WithReader(ds.SnapshotReader(revision))) + WithReader(ds.SnapshotReader(revision, datastore.NoSchemaHashForTesting))) userDef := testfixtures.UserNS.CloneVT() @@ -534,7 +535,7 @@ func TestBuildTreeSingleRelationOptimization(t *testing.T) { require.Contains(explain.String(), "Relation", "should create relation iterator") ctx := NewLocalContext(t.Context(), - WithReader(ds.SnapshotReader(revision))) + WithReader(ds.SnapshotReader(revision, datastore.NoSchemaHashForTesting))) // Test execution relSeq, err := ctx.Check(it, NewObjects("document", "companyplan"), NewObject("user", "legal").WithEllipses()) @@ -554,7 +555,7 @@ func TestBuildTreeSubrelationHandling(t *testing.T) { ds, revision := testfixtures.StandardDatastoreWithData(rawDS, require) ctx := NewLocalContext(t.Context(), - WithReader(ds.SnapshotReader(revision))) + WithReader(ds.SnapshotReader(revision, datastore.NoSchemaHashForTesting))) userDef := testfixtures.UserNS.CloneVT() diff --git a/pkg/query/caveat_test.go b/pkg/query/caveat_test.go index 518c8e6ad..9e9a290d6 100644 --- a/pkg/query/caveat_test.go +++ b/pkg/query/caveat_test.go @@ -109,7 +109,7 @@ func TestCaveatIteratorNoCaveat(t *testing.T) { require.NoError(t, err) queryCtx := NewLocalContext(context.Background(), - WithReader(ds.SnapshotReader(rev)), + WithReader(ds.SnapshotReader(rev, datastore.NoSchemaHashForTesting)), WithCaveatContext(tc.caveatContext), WithCaveatRunner(caveats.NewCaveatRunner(types.NewTypeSet()))) @@ -201,7 +201,7 @@ func TestCaveatIteratorWithCaveat(t *testing.T) { require.NoError(t, err) queryCtx := NewLocalContext(context.Background(), - WithReader(ds.SnapshotReader(rev)), + WithReader(ds.SnapshotReader(rev, 
datastore.NoSchemaHashForTesting)), WithCaveatContext(tc.caveatContext), WithCaveatRunner(caveats.NewCaveatRunner(types.NewTypeSet()))) diff --git a/pkg/query/exclusion_test.go b/pkg/query/exclusion_test.go index 9279e4028..658559500 100644 --- a/pkg/query/exclusion_test.go +++ b/pkg/query/exclusion_test.go @@ -8,6 +8,7 @@ import ( "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/datastore/memdb" "github.com/authzed/spicedb/internal/testfixtures" + "github.com/authzed/spicedb/pkg/datastore" core "github.com/authzed/spicedb/pkg/proto/core/v1" "github.com/authzed/spicedb/pkg/tuple" ) @@ -22,7 +23,7 @@ func TestExclusionIterator(t *testing.T) { ds, revision := testfixtures.StandardDatastoreWithData(rawDS, require) ctx := NewLocalContext(t.Context(), - WithReader(ds.SnapshotReader(revision))) + WithReader(ds.SnapshotReader(revision, datastore.NoSchemaHashForTesting))) // Create test paths path1 := MustPathFromString("document:doc1#viewer@user:alice") @@ -258,7 +259,7 @@ func TestExclusionWithEmptyIterator(t *testing.T) { ds, revision := testfixtures.StandardDatastoreWithData(rawDS, require) ctx := NewLocalContext(t.Context(), - WithReader(ds.SnapshotReader(revision))) + WithReader(ds.SnapshotReader(revision, datastore.NoSchemaHashForTesting))) path1 := MustPathFromString("document:doc1#viewer@user:alice") @@ -304,7 +305,7 @@ func TestExclusionErrorHandling(t *testing.T) { ds, revision := testfixtures.StandardDatastoreWithData(rawDS, require) ctx := NewLocalContext(t.Context(), - WithReader(ds.SnapshotReader(revision))) + WithReader(ds.SnapshotReader(revision, datastore.NoSchemaHashForTesting))) path1 := MustPathFromString("document:doc1#viewer@user:alice") @@ -385,7 +386,7 @@ func TestExclusionWithComplexIteratorTypes(t *testing.T) { ds, revision := testfixtures.StandardDatastoreWithData(rawDS, require) ctx := NewLocalContext(t.Context(), - WithReader(ds.SnapshotReader(revision))) + WithReader(ds.SnapshotReader(revision, 
datastore.NoSchemaHashForTesting))) // Create test relations path1 := MustPathFromString("document:doc1#viewer@user:alice") @@ -571,7 +572,7 @@ func TestExclusion_CombinedCaveatLogic(t *testing.T) { ds, revision := testfixtures.StandardDatastoreWithData(rawDS, require) ctx := NewLocalContext(t.Context(), - WithReader(ds.SnapshotReader(revision))) + WithReader(ds.SnapshotReader(revision, datastore.NoSchemaHashForTesting))) // Helper to create paths with caveats createPathWithCaveat := func(relation, caveatName string) Path { diff --git a/pkg/query/intersection_arrow_test.go b/pkg/query/intersection_arrow_test.go index 7f128264c..562b63b06 100644 --- a/pkg/query/intersection_arrow_test.go +++ b/pkg/query/intersection_arrow_test.go @@ -36,11 +36,11 @@ func TestIntersectionArrowIterator(t *testing.T) { ds, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) require.NoError(err) - revision, err := ds.HeadRevision(context.Background()) + revision, _, err := ds.HeadRevision(context.Background()) require.NoError(err) ctx := NewLocalContext(context.Background(), - WithReader(ds.SnapshotReader(revision))) + WithReader(ds.SnapshotReader(revision, datastore.NoSchemaHashForTesting))) // Test: alice should have access because she's a member of ALL teams (team1 and team2) resources := []Object{NewObject("document", "doc1")} @@ -89,7 +89,7 @@ func TestIntersectionArrowIterator(t *testing.T) { require.NoError(err) ctx := NewLocalContext(context.Background(), - WithReader(ds.SnapshotReader(revision))) + WithReader(ds.SnapshotReader(revision, datastore.NoSchemaHashForTesting))) // Test: alice should NOT have access because she's not a member of ALL teams resources := []Object{NewObject("document", "doc1")} @@ -129,7 +129,7 @@ func TestIntersectionArrowIterator(t *testing.T) { require.NoError(err) ctx := NewLocalContext(context.Background(), - WithReader(ds.SnapshotReader(revision))) + WithReader(ds.SnapshotReader(revision, datastore.NoSchemaHashForTesting))) // Test: alice should 
have access because she's a member of the only team resources := []Object{NewObject("document", "doc1")} @@ -174,7 +174,7 @@ func TestIntersectionArrowIterator(t *testing.T) { require.NoError(err) ctx := NewLocalContext(context.Background(), - WithReader(ds.SnapshotReader(revision))) + WithReader(ds.SnapshotReader(revision, datastore.NoSchemaHashForTesting))) resources := []Object{NewObject("document", "doc1")} subject := ObjectAndRelation{ObjectType: "user", ObjectID: "alice"} @@ -217,7 +217,7 @@ func TestIntersectionArrowIterator(t *testing.T) { require.NoError(err) ctx := NewLocalContext(context.Background(), - WithReader(ds.SnapshotReader(revision))) + WithReader(ds.SnapshotReader(revision, datastore.NoSchemaHashForTesting))) resources := []Object{NewObject("document", "doc1")} subject := ObjectAndRelation{ObjectType: "user", ObjectID: "alice"} @@ -256,7 +256,7 @@ func TestIntersectionArrowIterator(t *testing.T) { require.NoError(err) ctx := NewLocalContext(context.Background(), - WithReader(ds.SnapshotReader(revision))) + WithReader(ds.SnapshotReader(revision, datastore.NoSchemaHashForTesting))) resources := []Object{} subject := ObjectAndRelation{ObjectType: "user", ObjectID: "alice"} @@ -285,7 +285,7 @@ func TestIntersectionArrowIteratorCaveatCombination(t *testing.T) { require.NoError(err) ctx := NewLocalContext(context.Background(), - WithReader(ds.SnapshotReader(revision))) + WithReader(ds.SnapshotReader(revision, datastore.NoSchemaHashForTesting))) t.Run("CombineTwoCaveats_AND_Logic", func(t *testing.T) { t.Parallel() @@ -512,7 +512,7 @@ func TestIntersectionArrowIteratorClone(t *testing.T) { require.NoError(err) ctx := NewLocalContext(context.Background(), - WithReader(ds.SnapshotReader(revision))) + WithReader(ds.SnapshotReader(revision, datastore.NoSchemaHashForTesting))) // Test that both iterators produce the same results resources := []Object{NewObject("document", "doc1")} diff --git a/pkg/query/quick_e2e_test.go b/pkg/query/quick_e2e_test.go index 
dab08f694..5570df083 100644 --- a/pkg/query/quick_e2e_test.go +++ b/pkg/query/quick_e2e_test.go @@ -8,6 +8,7 @@ import ( "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/datastore/memdb" "github.com/authzed/spicedb/internal/testfixtures" + "github.com/authzed/spicedb/pkg/datastore" corev1 "github.com/authzed/spicedb/pkg/proto/core/v1" "github.com/authzed/spicedb/pkg/schema/v2" ) @@ -40,7 +41,7 @@ func TestCheck(t *testing.T) { it.addSubIterator(edit) ctx := NewLocalContext(t.Context(), - WithReader(ds.SnapshotReader(revision))) + WithReader(ds.SnapshotReader(revision, datastore.NoSchemaHashForTesting))) relSeq, err := ctx.Check(it, NewObjects("document", "specialplan"), NewObject("user", "multiroleguy").WithEllipses()) require.NoError(err) @@ -68,7 +69,7 @@ func TestBaseIterSubjects(t *testing.T) { vande := NewRelationIterator(vandeRel.BaseRelations()[0]) ctx := NewLocalContext(t.Context(), - WithReader(ds.SnapshotReader(revision))) + WithReader(ds.SnapshotReader(revision, datastore.NoSchemaHashForTesting))) relSeq, err := ctx.IterSubjects(vande, NewObject("document", "specialplan"), NoObjectFilter()) require.NoError(err) @@ -101,7 +102,7 @@ func TestCheckArrow(t *testing.T) { it := NewArrow(folders, view) ctx := NewLocalContext(t.Context(), - WithReader(ds.SnapshotReader(revision))) + WithReader(ds.SnapshotReader(revision, datastore.NoSchemaHashForTesting))) relSeq, err := ctx.Check(it, NewObjects("document", "companyplan"), NewObject("user", "legal").WithEllipses()) require.NoError(err) diff --git a/pkg/query/recursive_benchmark_test.go b/pkg/query/recursive_benchmark_test.go index 657363e7e..979111950 100644 --- a/pkg/query/recursive_benchmark_test.go +++ b/pkg/query/recursive_benchmark_test.go @@ -43,7 +43,7 @@ func BenchmarkRecursiveShallowGraph(b *testing.B) { } ctx := NewLocalContext(context.Background(), - WithReader(ds.SnapshotReader(datastore.NoRevision)), + 
WithReader(ds.SnapshotReader(datastore.NoRevision, datastore.NoSchemaHashForTesting)), WithMaxRecursionDepth(50)) b.ResetTimer() @@ -96,7 +96,7 @@ func BenchmarkRecursiveWideGraph(b *testing.B) { } ctx := NewLocalContext(context.Background(), - WithReader(ds.SnapshotReader(datastore.NoRevision)), + WithReader(ds.SnapshotReader(datastore.NoRevision, datastore.NoSchemaHashForTesting)), WithMaxRecursionDepth(50)) b.ResetTimer() @@ -144,7 +144,7 @@ func BenchmarkRecursiveDeepGraph(b *testing.B) { } ctx := NewLocalContext(context.Background(), - WithReader(ds.SnapshotReader(datastore.NoRevision)), + WithReader(ds.SnapshotReader(datastore.NoRevision, datastore.NoSchemaHashForTesting)), WithMaxRecursionDepth(50)) b.ResetTimer() @@ -174,7 +174,7 @@ func BenchmarkRecursiveEmptyGraph(b *testing.B) { } ctx := NewLocalContext(context.Background(), - WithReader(ds.SnapshotReader(datastore.NoRevision)), + WithReader(ds.SnapshotReader(datastore.NoRevision, datastore.NoSchemaHashForTesting)), WithMaxRecursionDepth(50)) b.ResetTimer() @@ -233,7 +233,7 @@ func BenchmarkRecursiveSparseGraph(b *testing.B) { } ctx := NewLocalContext(context.Background(), - WithReader(ds.SnapshotReader(datastore.NoRevision)), + WithReader(ds.SnapshotReader(datastore.NoRevision, datastore.NoSchemaHashForTesting)), WithMaxRecursionDepth(50)) b.ResetTimer() @@ -281,7 +281,7 @@ func BenchmarkRecursiveCyclicGraph(b *testing.B) { } ctx := NewLocalContext(context.Background(), - WithReader(ds.SnapshotReader(datastore.NoRevision)), + WithReader(ds.SnapshotReader(datastore.NoRevision, datastore.NoSchemaHashForTesting)), WithMaxRecursionDepth(50)) b.ResetTimer() @@ -329,7 +329,7 @@ func BenchmarkRecursiveIterResources(b *testing.B) { } ctx := NewLocalContext(context.Background(), - WithReader(ds.SnapshotReader(datastore.NoRevision)), + WithReader(ds.SnapshotReader(datastore.NoRevision, datastore.NoSchemaHashForTesting)), WithMaxRecursionDepth(50)) b.ResetTimer() diff --git a/pkg/query/recursive_coverage_test.go 
b/pkg/query/recursive_coverage_test.go index 8e19d6b53..2b4162443 100644 --- a/pkg/query/recursive_coverage_test.go +++ b/pkg/query/recursive_coverage_test.go @@ -117,7 +117,7 @@ func TestBreadthFirstIterResources_MaxDepth(t *testing.T) { // Set a low max depth ctx := NewLocalContext(context.Background(), - WithReader(ds.SnapshotReader(datastore.NoRevision)), + WithReader(ds.SnapshotReader(datastore.NoRevision, datastore.NoSchemaHashForTesting)), WithMaxRecursionDepth(3)) seq, err := recursive.IterResourcesImpl(ctx, ObjectAndRelation{ObjectType: "user", ObjectID: "alice", Relation: "..."}, NoObjectFilter()) @@ -149,7 +149,7 @@ func TestBreadthFirstIterResources_ErrorHandling(t *testing.T) { require.NoError(err) ctx := NewLocalContext(context.Background(), - WithReader(ds.SnapshotReader(datastore.NoRevision))) + WithReader(ds.SnapshotReader(datastore.NoRevision, datastore.NoSchemaHashForTesting))) seq, err := recursive.IterResourcesImpl(ctx, ObjectAndRelation{ObjectType: "user", ObjectID: "alice", Relation: "..."}, NoObjectFilter()) require.NoError(err) @@ -174,7 +174,7 @@ func TestBreadthFirstIterResources_ErrorHandling(t *testing.T) { require.NoError(err) ctx := NewLocalContext(context.Background(), - WithReader(ds.SnapshotReader(datastore.NoRevision))) + WithReader(ds.SnapshotReader(datastore.NoRevision, datastore.NoSchemaHashForTesting))) seq, err := recursive.IterResourcesImpl(ctx, ObjectAndRelation{ObjectType: "user", ObjectID: "alice", Relation: "..."}, NoObjectFilter()) require.NoError(err) @@ -208,7 +208,7 @@ func TestBreadthFirstIterResources_MergeOrSemantics(t *testing.T) { require.NoError(err) ctx := NewLocalContext(context.Background(), - WithReader(ds.SnapshotReader(datastore.NoRevision)), + WithReader(ds.SnapshotReader(datastore.NoRevision, datastore.NoSchemaHashForTesting)), WithMaxRecursionDepth(5)) seq, err := recursive.IterResourcesImpl(ctx, ObjectAndRelation{ObjectType: "user", ObjectID: "alice", Relation: "..."}, NoObjectFilter()) @@ -236,7 
+236,7 @@ func TestIterativeDeepening_MaxDepth(t *testing.T) { maxDepth := 5 ctx := NewLocalContext(context.Background(), - WithReader(ds.SnapshotReader(datastore.NoRevision)), + WithReader(ds.SnapshotReader(datastore.NoRevision, datastore.NoSchemaHashForTesting)), WithMaxRecursionDepth(maxDepth)) seq, err := recursive.CheckImpl(ctx, []Object{{ObjectType: "folder", ObjectID: "folder1"}}, ObjectAndRelation{ObjectType: "user", ObjectID: "alice", Relation: "..."}) diff --git a/pkg/query/recursive_strategies_test.go b/pkg/query/recursive_strategies_test.go index c5db5e09a..e3ba7730f 100644 --- a/pkg/query/recursive_strategies_test.go +++ b/pkg/query/recursive_strategies_test.go @@ -42,7 +42,7 @@ func TestRecursiveCheckStrategies(t *testing.T) { require.NoError(t, err) queryCtx := NewLocalContext(context.Background(), - WithReader(ds.SnapshotReader(datastore.NoRevision))) + WithReader(ds.SnapshotReader(datastore.NoRevision, datastore.NoSchemaHashForTesting))) // Test all three strategies strategies := []struct { @@ -105,7 +105,7 @@ func TestRecursiveCheckStrategiesEmpty(t *testing.T) { recursive := NewRecursiveIterator(emptyFixed, "folder", "view") queryCtx := NewLocalContext(context.Background(), - WithReader(ds.SnapshotReader(datastore.NoRevision))) + WithReader(ds.SnapshotReader(datastore.NoRevision, datastore.NoSchemaHashForTesting))) strategies := []recursiveCheckStrategy{ recursiveCheckIterSubjects, @@ -159,7 +159,7 @@ func TestRecursiveCheckStrategiesMultipleResources(t *testing.T) { require.NoError(t, err) queryCtx := NewLocalContext(context.Background(), - WithReader(ds.SnapshotReader(datastore.NoRevision))) + WithReader(ds.SnapshotReader(datastore.NoRevision, datastore.NoSchemaHashForTesting))) strategies := []recursiveCheckStrategy{ recursiveCheckIterSubjects, diff --git a/pkg/query/recursive_test.go b/pkg/query/recursive_test.go index 724296caa..189d646dd 100644 --- a/pkg/query/recursive_test.go +++ b/pkg/query/recursive_test.go @@ -24,7 +24,7 @@ func 
TestRecursiveSentinel(t *testing.T) { require.NoError(t, err) ctx := NewLocalContext(context.Background(), - WithReader(ds.SnapshotReader(datastore.NoRevision))) + WithReader(ds.SnapshotReader(datastore.NoRevision, datastore.NoSchemaHashForTesting))) // CheckImpl should return empty seq, err := sentinel.CheckImpl(ctx, []Object{{ObjectType: "folder", ObjectID: "folder1"}}, ObjectAndRelation{ObjectType: "user", ObjectID: "tom", Relation: "..."}) @@ -63,7 +63,7 @@ func TestRecursiveIteratorEmptyBaseCase(t *testing.T) { require.NoError(t, err) ctx := NewLocalContext(context.Background(), - WithReader(ds.SnapshotReader(datastore.NoRevision))) + WithReader(ds.SnapshotReader(datastore.NoRevision, datastore.NoSchemaHashForTesting))) // Execute - should terminate immediately with empty result seq, err := recursive.CheckImpl(ctx, []Object{{ObjectType: "folder", ObjectID: "folder1"}}, ObjectAndRelation{ObjectType: "user", ObjectID: "tom", Relation: "..."}) @@ -171,7 +171,7 @@ func TestRecursiveIteratorExecutionError(t *testing.T) { require.NoError(t, err) ctx := NewLocalContext(context.Background(), - WithReader(ds.SnapshotReader(datastore.NoRevision))) + WithReader(ds.SnapshotReader(datastore.NoRevision, datastore.NoSchemaHashForTesting))) // Test CheckImpl with a faulty iterator seq, err := recursive.CheckImpl(ctx, []Object{{ObjectType: "folder", ObjectID: "folder1"}}, ObjectAndRelation{ObjectType: "user", ObjectID: "tom", Relation: "..."}) @@ -202,7 +202,7 @@ func TestRecursiveIteratorCollectionError(t *testing.T) { require.NoError(t, err) ctx := NewLocalContext(context.Background(), - WithReader(ds.SnapshotReader(datastore.NoRevision))) + WithReader(ds.SnapshotReader(datastore.NoRevision, datastore.NoSchemaHashForTesting))) // Test CheckImpl with a faulty iterator that fails on collection seq, err := recursive.CheckImpl(ctx, []Object{{ObjectType: "folder", ObjectID: "folder1"}}, ObjectAndRelation{ObjectType: "user", ObjectID: "tom", Relation: "..."}) @@ -227,7 +227,7 @@ 
func TestBFSEarlyTermination(t *testing.T) { require.NoError(t, err) ctx := NewLocalContext(context.Background(), - WithReader(ds.SnapshotReader(datastore.NoRevision)), + WithReader(ds.SnapshotReader(datastore.NoRevision, datastore.NoSchemaHashForTesting)), WithMaxRecursionDepth(50)) // High max depth // IterSubjects on a node with no children (sentinel returns empty) @@ -275,7 +275,7 @@ func TestBFSCycleDetection(t *testing.T) { require.NoError(t, err) ctx := NewLocalContext(context.Background(), - WithReader(ds.SnapshotReader(datastore.NoRevision)), + WithReader(ds.SnapshotReader(datastore.NoRevision, datastore.NoSchemaHashForTesting)), WithMaxRecursionDepth(10)) seq, err := recursive.IterSubjectsImpl(ctx, Object{ObjectType: "folder", ObjectID: "folder1"}, NoObjectFilter()) @@ -308,7 +308,7 @@ func TestBFSSelfReferential(t *testing.T) { require.NoError(t, err) ctx := NewLocalContext(context.Background(), - WithReader(ds.SnapshotReader(datastore.NoRevision)), + WithReader(ds.SnapshotReader(datastore.NoRevision, datastore.NoSchemaHashForTesting)), WithMaxRecursionDepth(10)) seq, err := recursive.IterSubjectsImpl(ctx, Object{ObjectType: "folder", ObjectID: "folder1"}, NoObjectFilter()) @@ -350,7 +350,7 @@ func TestBFSResourcesWithEllipses(t *testing.T) { require.NoError(t, err) ctx := NewLocalContext(context.Background(), - WithReader(ds.SnapshotReader(datastore.NoRevision)), + WithReader(ds.SnapshotReader(datastore.NoRevision, datastore.NoSchemaHashForTesting)), WithMaxRecursionDepth(5)) // Query IterResources - should find folder2 diff --git a/pkg/query/simplify_caveat_test.go b/pkg/query/simplify_caveat_test.go index 563e4289a..01d215fdd 100644 --- a/pkg/query/simplify_caveat_test.go +++ b/pkg/query/simplify_caveat_test.go @@ -46,7 +46,7 @@ func TestSimplifyLeafCaveat(t *testing.T) { }) require.NoError(err) - reader := ds.SnapshotReader(revision) + reader := ds.SnapshotReader(revision, datastore.NoSchemaHashForTesting) runner := 
internalcaveats.NewCaveatRunner(caveattypes.Default.TypeSet) // Create caveat expression without context @@ -133,7 +133,7 @@ func TestSimplifyAndOperation(t *testing.T) { }) require.NoError(err) - reader := ds.SnapshotReader(revision) + reader := ds.SnapshotReader(revision, datastore.NoSchemaHashForTesting) runner := internalcaveats.NewCaveatRunner(caveattypes.Default.TypeSet) // Create AND expression: caveat1 AND caveat2 @@ -237,7 +237,7 @@ func TestSimplifyOrOperation(t *testing.T) { }) require.NoError(err) - reader := ds.SnapshotReader(revision) + reader := ds.SnapshotReader(revision, datastore.NoSchemaHashForTesting) runner := internalcaveats.NewCaveatRunner(caveattypes.Default.TypeSet) // Create OR expression: caveat1 OR caveat2 @@ -353,7 +353,7 @@ func TestSimplifyNestedOperations(t *testing.T) { }) require.NoError(err) - reader := ds.SnapshotReader(revision) + reader := ds.SnapshotReader(revision, datastore.NoSchemaHashForTesting) runner := internalcaveats.NewCaveatRunner(caveattypes.Default.TypeSet) // Create nested expression: (caveat1 OR caveat2) AND caveat3 @@ -438,7 +438,7 @@ func TestSimplifyOrWithSameCaveatDifferentContexts(t *testing.T) { }) require.NoError(err) - reader := ds.SnapshotReader(revision) + reader := ds.SnapshotReader(revision, datastore.NoSchemaHashForTesting) runner := internalcaveats.NewCaveatRunner(caveattypes.Default.TypeSet) // Create OR expression: write_limit(limit=2) OR write_limit(limit=4) @@ -524,7 +524,7 @@ func TestSimplifyAndWithSameCaveatDifferentContexts(t *testing.T) { }) require.NoError(err) - reader := ds.SnapshotReader(revision) + reader := ds.SnapshotReader(revision, datastore.NoSchemaHashForTesting) runner := internalcaveats.NewCaveatRunner(caveattypes.Default.TypeSet) // Create AND expression: write_limit(limit=2) AND write_limit(limit=4) @@ -620,7 +620,7 @@ func TestSimplifyNotWithSameCaveatDifferentContexts(t *testing.T) { }) require.NoError(err) - reader := ds.SnapshotReader(revision) + reader := 
ds.SnapshotReader(revision, datastore.NoSchemaHashForTesting) runner := internalcaveats.NewCaveatRunner(caveattypes.Default.TypeSet) // Create NOT expression: NOT write_limit(limit=4) @@ -702,7 +702,7 @@ func TestSimplifyComplexNestedExpressions(t *testing.T) { }) require.NoError(err) - reader := ds.SnapshotReader(revision) + reader := ds.SnapshotReader(revision, datastore.NoSchemaHashForTesting) runner := internalcaveats.NewCaveatRunner(caveattypes.Default.TypeSet) t.Run("OrOfAnds_ComplexNesting", func(t *testing.T) { @@ -1169,7 +1169,7 @@ func TestSimplifyWithEmptyContext(t *testing.T) { }) require.NoError(err) - reader := ds.SnapshotReader(revision) + reader := ds.SnapshotReader(revision, datastore.NoSchemaHashForTesting) runner := internalcaveats.NewCaveatRunner(caveattypes.Default.TypeSet) // Create nested expression: (caveat1 OR caveat2) AND caveat3 @@ -1251,7 +1251,7 @@ func TestSimplifyNotConditional(t *testing.T) { }) require.NoError(err) - reader := ds.SnapshotReader(revision) + reader := ds.SnapshotReader(revision, datastore.NoSchemaHashForTesting) runner := internalcaveats.NewCaveatRunner(caveattypes.Default.TypeSet) // Create NOT expression: NOT limit_check(limit=10) @@ -1337,7 +1337,7 @@ func TestSimplifyDeeplyNestedCaveats(t *testing.T) { }) require.NoError(err) - reader := ds.SnapshotReader(revision) + reader := ds.SnapshotReader(revision, datastore.NoSchemaHashForTesting) runner := internalcaveats.NewCaveatRunner(caveattypes.Default.TypeSet) // Helper to create caveat expressions diff --git a/pkg/query/tracing_test.go b/pkg/query/tracing_test.go index 5c8866a57..ed752dfc0 100644 --- a/pkg/query/tracing_test.go +++ b/pkg/query/tracing_test.go @@ -23,7 +23,7 @@ func TestIteratorTracing(t *testing.T) { require.NoError(t, err) ctx := NewLocalContext(context.Background(), - WithReader(ds.SnapshotReader(revision)), + WithReader(ds.SnapshotReader(revision, datastore.NoSchemaHashForTesting)), WithTraceLogger(traceLogger), ) diff --git 
a/pkg/query/wildcard_multirelation_test.go b/pkg/query/wildcard_multirelation_test.go index f3653def2..57cada914 100644 --- a/pkg/query/wildcard_multirelation_test.go +++ b/pkg/query/wildcard_multirelation_test.go @@ -81,7 +81,7 @@ func TestIterSubjectsWildcardWithMultipleRelations(t *testing.T) { wildcardBranch := NewRelationIterator(viewerRel.BaseRelations()[1]) // user:* (wildcard) queryCtx := NewLocalContext(ctx, - WithReader(rawDS.SnapshotReader(revision)), + WithReader(rawDS.SnapshotReader(revision, datastore.NoSchemaHashForTesting)), WithTraceLogger(NewTraceLogger())) // Enable tracing for debugging subjects, err := queryCtx.IterSubjects(wildcardBranch, NewObject("document", "publicdoc"), NoObjectFilter()) require.NoError(err) @@ -112,7 +112,7 @@ func TestIterSubjectsWildcardWithMultipleRelations(t *testing.T) { union.addSubIterator(NewRelationIterator(viewerRel.BaseRelations()[1])) // user:* (wildcard) queryCtx := NewLocalContext(ctx, - WithReader(rawDS.SnapshotReader(revision)), + WithReader(rawDS.SnapshotReader(revision, datastore.NoSchemaHashForTesting)), WithTraceLogger(NewTraceLogger())) // Enable tracing for debugging subjects, err := queryCtx.IterSubjects(union, NewObject("document", "publicdoc"), NoObjectFilter()) require.NoError(err) diff --git a/pkg/query/wildcard_subjects_test.go b/pkg/query/wildcard_subjects_test.go index a0f651f81..88ad25eb7 100644 --- a/pkg/query/wildcard_subjects_test.go +++ b/pkg/query/wildcard_subjects_test.go @@ -71,7 +71,7 @@ func TestIterSubjectsWithWildcard(t *testing.T) { // The non-wildcard branch should only return concrete subjects, filtering out wildcards nonWildcardBranch := NewRelationIterator(viewerRel.BaseRelations()[0]) // user (non-wildcard) - queryCtx := NewLocalContext(ctx, WithReader(rawDS.SnapshotReader(revision))) + queryCtx := NewLocalContext(ctx, WithReader(rawDS.SnapshotReader(revision, datastore.NoSchemaHashForTesting))) subjects, err := queryCtx.IterSubjects(nonWildcardBranch, NewObject("resource", 
"first"), NoObjectFilter()) require.NoError(err) @@ -90,7 +90,7 @@ func TestIterSubjectsWithWildcard(t *testing.T) { // The wildcard branch should enumerate concrete subjects when a wildcard exists wildcardBranch := NewRelationIterator(viewerRel.BaseRelations()[1]) // user:* (wildcard) - queryCtx := NewLocalContext(ctx, WithReader(rawDS.SnapshotReader(revision))) + queryCtx := NewLocalContext(ctx, WithReader(rawDS.SnapshotReader(revision, datastore.NoSchemaHashForTesting))) subjects, err := queryCtx.IterSubjects(wildcardBranch, NewObject("resource", "first"), NoObjectFilter()) require.NoError(err) @@ -111,7 +111,7 @@ func TestIterSubjectsWithWildcard(t *testing.T) { union.addSubIterator(NewRelationIterator(viewerRel.BaseRelations()[0])) // user union.addSubIterator(NewRelationIterator(viewerRel.BaseRelations()[1])) // user:* - queryCtx := NewLocalContext(ctx, WithReader(rawDS.SnapshotReader(revision))) + queryCtx := NewLocalContext(ctx, WithReader(rawDS.SnapshotReader(revision, datastore.NoSchemaHashForTesting))) subjects, err := queryCtx.IterSubjects(union, NewObject("resource", "first"), NoObjectFilter()) require.NoError(err) @@ -178,7 +178,7 @@ func TestIterSubjectsWildcardWithoutWildcardRelationship(t *testing.T) { // The wildcard branch should return empty because there's no wildcard relationship wildcardBranch := NewRelationIterator(viewerRel.BaseRelations()[1]) // user:* (wildcard) - queryCtx := NewLocalContext(ctx, WithReader(rawDS.SnapshotReader(revision))) + queryCtx := NewLocalContext(ctx, WithReader(rawDS.SnapshotReader(revision, datastore.NoSchemaHashForTesting))) subjects, err := queryCtx.IterSubjects(wildcardBranch, NewObject("resource", "second"), NoObjectFilter()) require.NoError(err) @@ -194,7 +194,7 @@ func TestIterSubjectsWildcardWithoutWildcardRelationship(t *testing.T) { t.Parallel() nonWildcardBranch := NewRelationIterator(viewerRel.BaseRelations()[0]) // user (non-wildcard) - queryCtx := NewLocalContext(ctx, 
WithReader(rawDS.SnapshotReader(revision))) + queryCtx := NewLocalContext(ctx, WithReader(rawDS.SnapshotReader(revision, datastore.NoSchemaHashForTesting))) subjects, err := queryCtx.IterSubjects(nonWildcardBranch, NewObject("resource", "second"), NoObjectFilter()) require.NoError(err) diff --git a/pkg/services/v1/services.go b/pkg/services/v1/services.go index d9543dad1..52066dd2b 100644 --- a/pkg/services/v1/services.go +++ b/pkg/services/v1/services.go @@ -12,6 +12,6 @@ import ( // BulkExport implements the BulkExportRelationships API functionality. Given a datastore.Datastore, it will // export stream via the sender all relationships matched by the incoming request. // If no cursor is provided, it will fallback to the provided revision. -func BulkExport(ctx context.Context, ds datastore.ReadOnlyDatastore, batchSize uint64, req *v1.BulkExportRelationshipsRequest, fallbackRevision datastore.Revision, sender func(response *v1.BulkExportRelationshipsResponse) error) error { - return servicesv1.BulkExport(ctx, ds, batchSize, req, fallbackRevision, sender) +func BulkExport(ctx context.Context, ds datastore.ReadOnlyDatastore, batchSize uint64, req *v1.BulkExportRelationshipsRequest, fallbackRevision datastore.Revision, fallbackSchemaHash datastore.SchemaHash, sender func(response *v1.BulkExportRelationshipsResponse) error) error { + return servicesv1.BulkExport(ctx, ds, batchSize, req, fallbackRevision, fallbackSchemaHash, sender) } diff --git a/proto/internal/core/v1/core.proto b/proto/internal/core/v1/core.proto index f8173e872..cb01324ba 100644 --- a/proto/internal/core/v1/core.proto +++ b/proto/internal/core/v1/core.proto @@ -623,3 +623,28 @@ message SubjectFilter { RelationFilter optional_relation = 3; } + +// StoredSchema represents a stored schema in SpiceDB under the new, unified schema format. +message StoredSchema { + uint32 version = 1; + + message V1StoredSchema { + // schema_text is the text of the schema that was given to SpiceDB by the API caller. 
+ string schema_text = 1; + + // schema_hash is the hash of the schema, for change detection. + string schema_hash = 2; + + // namespace_definitions is a map of namespace name to NamespaceDefinition. + // Entries must have full metadata filled out. + map namespace_definitions = 3; + + // caveat_definitions is a map of caveat name to CaveatDefinition. + // Entries must have full metadata filled out. + map caveat_definitions = 4; + } + + oneof version_oneof { + V1StoredSchema v1 = 2; + } +} diff --git a/proto/internal/dispatch/v1/dispatch.proto b/proto/internal/dispatch/v1/dispatch.proto index 16e2a6c29..dc9dbcdb8 100644 --- a/proto/internal/dispatch/v1/dispatch.proto +++ b/proto/internal/dispatch/v1/dispatch.proto @@ -178,6 +178,7 @@ message ResolverMeta { uint32 depth_remaining = 2 [(validate.rules).uint32.gt = 0]; string request_id = 3 [deprecated = true]; bytes traversal_bloom = 4 [(validate.rules).bytes = {max_len: 1024}]; + bytes schema_hash = 5; } message ResponseMeta { diff --git a/proto/internal/impl/v1/impl.proto b/proto/internal/impl/v1/impl.proto index cd85ece0c..e6b6dbb09 100644 --- a/proto/internal/impl/v1/impl.proto +++ b/proto/internal/impl/v1/impl.proto @@ -71,6 +71,10 @@ message V1Cursor { // datastore_unique_id is the unique ID for the datastore. Will be empty for legacy // cursors. string datastore_unique_id = 6; + + // schema_hash is the hash of the schema at the time the cursor was created. Will be + // empty for legacy cursors. 
+ bytes schema_hash = 7; } message DocComment { diff --git a/tools/analyzers/cmd/analyzers/main.go b/tools/analyzers/cmd/analyzers/main.go index 8500b5e05..fb964e4d5 100644 --- a/tools/analyzers/cmd/analyzers/main.go +++ b/tools/analyzers/cmd/analyzers/main.go @@ -13,6 +13,7 @@ import ( "github.com/authzed/spicedb/tools/analyzers/paniccheck" "github.com/authzed/spicedb/tools/analyzers/protomarshalcheck" "github.com/authzed/spicedb/tools/analyzers/telemetryconvcheck" + "github.com/authzed/spicedb/tools/analyzers/testsentinelcheck" "github.com/authzed/spicedb/tools/analyzers/zerologmarshalcheck" ) @@ -29,5 +30,6 @@ func main() { protomarshalcheck.Analyzer(), zerologmarshalcheck.Analyzer(), telemetryconvcheck.Analyzer(), + testsentinelcheck.Analyzer(), ) } diff --git a/tools/analyzers/testsentinelcheck/testdata/src/github.com/authzed/spicedb/pkg/datastore/datastore.go b/tools/analyzers/testsentinelcheck/testdata/src/github.com/authzed/spicedb/pkg/datastore/datastore.go new file mode 100644 index 000000000..ecde2f83c --- /dev/null +++ b/tools/analyzers/testsentinelcheck/testdata/src/github.com/authzed/spicedb/pkg/datastore/datastore.go @@ -0,0 +1,4 @@ +package datastore + +// NoSchemaHashForTesting is a test-only sentinel value +const NoSchemaHashForTesting = "__test_only__" diff --git a/tools/analyzers/testsentinelcheck/testdata/src/testsentinel/bad.go b/tools/analyzers/testsentinelcheck/testdata/src/testsentinel/bad.go new file mode 100644 index 000000000..e3a799fa0 --- /dev/null +++ b/tools/analyzers/testsentinelcheck/testdata/src/testsentinel/bad.go @@ -0,0 +1,11 @@ +package testsentinel + +import ( + "github.com/authzed/spicedb/pkg/datastore" +) + +// This is NOT a test file, so usage of NoSchemaHashForTesting should be flagged +func SomeFunction() { + hash := datastore.NoSchemaHashForTesting // want "NoSchemaHashForTesting should only be used in test files" + _ = hash +} diff --git a/tools/analyzers/testsentinelcheck/testdata/src/testsentinel/good_test.go 
b/tools/analyzers/testsentinelcheck/testdata/src/testsentinel/good_test.go new file mode 100644 index 000000000..952f51167 --- /dev/null +++ b/tools/analyzers/testsentinelcheck/testdata/src/testsentinel/good_test.go @@ -0,0 +1,13 @@ +package testsentinel + +import ( + "testing" + + "github.com/authzed/spicedb/pkg/datastore" +) + +// This is a test file, so usage of NoSchemaHashForTesting is allowed +func TestSomething(t *testing.T) { + hash := datastore.NoSchemaHashForTesting + _ = hash +} diff --git a/tools/analyzers/testsentinelcheck/testsentinelcheck.go b/tools/analyzers/testsentinelcheck/testsentinelcheck.go new file mode 100644 index 000000000..ba2b561b2 --- /dev/null +++ b/tools/analyzers/testsentinelcheck/testsentinelcheck.go @@ -0,0 +1,91 @@ +package testsentinelcheck + +import ( + "flag" + "go/ast" + "path/filepath" + "strings" + + "github.com/samber/lo" + "golang.org/x/tools/go/analysis" + "golang.org/x/tools/go/analysis/passes/inspect" + "golang.org/x/tools/go/ast/inspector" +) + +func Analyzer() *analysis.Analyzer { + flagSet := flag.NewFlagSet("testsentinelcheck", flag.ExitOnError) + skipPkg := flagSet.String("skip-pkg", "", "package(s) to skip for linting") + + return &analysis.Analyzer{ + Name: "testsentinelcheck", + Doc: "reports usage of test-only sentinel values like NoSchemaHashForTesting outside of test files", + Run: func(pass *analysis.Pass) (any, error) { + // Check for a skipped package. 
+ if len(*skipPkg) > 0 { + skipped := lo.Map(strings.Split(*skipPkg, ","), func(skipped string, _ int) string { return strings.TrimSpace(skipped) }) + for _, s := range skipped { + if strings.Contains(pass.Pkg.Path(), s) { + return nil, nil + } + } + } + + inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector) + + nodeFilter := []ast.Node{ + (*ast.File)(nil), + (*ast.Ident)(nil), + } + + var currentFile string + var isTestFile bool + + inspect.WithStack(nodeFilter, func(n ast.Node, push bool, stack []ast.Node) bool { + switch s := n.(type) { + case *ast.File: + // Track the current file being analyzed + currentFile = pass.Fset.Position(s.Package).Filename + baseName := filepath.Base(currentFile) + isTestFile = strings.HasSuffix(baseName, "_test.go") + return true + + case *ast.Ident: + // Skip if we're in a test file + if isTestFile { + return false + } + + // Check if this identifier is NoSchemaHashForTesting + if s.Name == "NoSchemaHashForTesting" { + // Verify this is actually referencing the datastore constant + // by checking if it's a selector expression or qualified identifier + obj := pass.TypesInfo.ObjectOf(s) + if obj != nil { + // Skip if this is the definition of the constant itself + if obj.Pos() == s.Pos() { + return false + } + + if obj.Pkg() != nil { + pkgPath := obj.Pkg().Path() + // Check if it's from the datastore package + if strings.HasSuffix(pkgPath, "github.com/authzed/spicedb/pkg/datastore") { + pass.Reportf(s.Pos(), "NoSchemaHashForTesting should only be used in test files (file: %s)", filepath.Base(currentFile)) + } + } + } + } + + return false + + default: + return true + } + }) + + return nil, nil + }, + Requires: []*analysis.Analyzer{inspect.Analyzer}, + Flags: *flagSet, + } +} diff --git a/tools/analyzers/testsentinelcheck/testsentinelcheck_test.go b/tools/analyzers/testsentinelcheck/testsentinelcheck_test.go new file mode 100644 index 000000000..b2cd6ad76 --- /dev/null +++ 
b/tools/analyzers/testsentinelcheck/testsentinelcheck_test.go @@ -0,0 +1,14 @@ +package testsentinelcheck_test + +import ( + "testing" + + "golang.org/x/tools/go/analysis/analysistest" + + "github.com/authzed/spicedb/tools/analyzers/testsentinelcheck" +) + +func TestAnalyzer(t *testing.T) { + testdata := analysistest.TestData() + analysistest.Run(t, testdata, testsentinelcheck.Analyzer(), "testsentinel") +}