diff --git a/CHANGELOG.md b/CHANGELOG.md index 5ada4af36..1ce2bdfaf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,20 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). - feat(query planner): add recursive direction strategies, and fix IS BFS (https://github.com/authzed/spicedb/pull/2891) - feat(query planner): introduce query plan outlines and canonicalization (https://github.com/authzed/spicedb/pull/2901) - Schema v2: introduces support for PostOrder traversal in walk.go (https://github.com/authzed/spicedb/pull/2761) and improve PostOrder walker cycle detection (https://github.com/authzed/spicedb/pull/2902) +- Experimental: Add unified schema storage with ReadStoredSchema/WriteStoredSchema for improved schema read performance (https://github.com/authzed/spicedb/pull/2924) + + This feature stores the entire schema as a single serialized proto rather than reading individual namespace and caveat definitions separately, significantly improving schema read performance. + + Migration to unified schema storage is controlled by the `--experimental-schema-mode` flag, which supports a 4-phase rolling migration: + + 1. `read-legacy-write-legacy` (default) - No change; reads and writes use legacy per-definition storage. + 2. `read-legacy-write-both` - Reads from legacy storage, writes to both legacy and unified storage. This is the first migration step and backfills the unified schema table. + 3. `read-new-write-both` - Reads from unified storage, writes to both. Validates the new read path while maintaining backward compatibility. + 4. `read-new-write-new` - Reads and writes only unified storage. This is the final migration target. + + **With the SpiceDB Operator:** Configure the operator to roll through stages 1 through 4 in sequence. The operator handles the rolling update of SpiceDB instances at each stage. 
+ + **Without the operator:** Progress through the stages manually by updating the `--experimental-schema-mode` flag and performing a rolling restart at each stage. You can also take the system down briefly and move directly from stage 1 to stage 4, which runs the full migration in one step. ### Changed - Begin deprecation of library "github.com/dlmiddlecote/sqlstats" (https://github.com/authzed/spicedb/pull/2904). diff --git a/docs/spicedb.md b/docs/spicedb.md index 29e673b44..bb6a41fab 100644 --- a/docs/spicedb.md +++ b/docs/spicedb.md @@ -517,6 +517,7 @@ spicedb serve [flags] --experimental-dispatch-secondary-upstream-exprs stringToString map from request type to its associated CEL expression, which returns the secondary upstream(s) to be used for the request (default []) --experimental-lookup-resources-version lr3 if non-empty, the version of the experimental lookup resources API to use: lr3 or empty --experimental-query-plan check if non-empty, the version of the experimental query plan to use: check or empty + --experimental-schema-mode string schema storage mode for migration to unified schema: read-legacy-write-legacy, read-legacy-write-both, read-new-write-both, read-new-write-new (default "read-legacy-write-legacy") --grpc-addr string address to listen on to serve gRPC (default ":50051") --grpc-enabled enable gRPC gRPC server (default true) --grpc-log-requests-enabled enable logging of API request payloads @@ -561,6 +562,10 @@ spicedb serve [flags] --pprof-block-profile-rate int sets the block profile sampling rate (between 0 and 1) --pprof-mutex-profile-rate int sets the mutex profile sampling rate (between 0 and 1) --schema-prefixes-required require prefixes on all object definitions in schemas + --stored-schema-cache-enabled enable caching of stored schema (default true) + --stored-schema-cache-max-cost string upper bound (in bytes or as a percent of available memory) of the cache for stored schema (default "32MiB") + --stored-schema-cache-metrics 
enable metrics for the cache for stored schema (default true) + --stored-schema-cache-num-counters int number of counters for tracking access frequency in the cache for stored schema. A higher number means more accurate eviction decisions but more memory usage (default 1000) --streaming-api-response-delay-timeout duration maximum time that streaming APIs (LookupSubjects, LookupResources, ReadRelationships and ExportBulkRelationships) can be allowed to run but no response be sent to the client before the stream times out (default 30s) --telemetry-ca-override-path string path to a custom CA to use with the telemetry endpoint --telemetry-endpoint string endpoint to which telemetry is reported, empty string to disable (default "https://telemetry.authzed.com") diff --git a/internal/caveats/run_test.go b/internal/caveats/run_test.go index 00a96d8da..07349003a 100644 --- a/internal/caveats/run_test.go +++ b/internal/caveats/run_test.go @@ -471,7 +471,7 @@ func TestRunCaveatExpressions(t *testing.T) { req.NoError(err) dl := datalayer.NewDataLayer(ds) - sr, err := dl.SnapshotReader(headRevision).ReadSchema(t.Context()) + sr, err := dl.SnapshotReader(headRevision, datalayer.NoSchemaHashForTesting).ReadSchema(t.Context()) req.NoError(err) for _, debugOption := range []RunCaveatExpressionDebugOption{ @@ -524,7 +524,7 @@ func TestRunCaveatWithMissingMap(t *testing.T) { req.NoError(err) dl := datalayer.NewDataLayer(ds) - sr, err := dl.SnapshotReader(headRevision).ReadSchema(t.Context()) + sr, err := dl.SnapshotReader(headRevision, datalayer.NoSchemaHashForTesting).ReadSchema(t.Context()) req.NoError(err) result, err := RunSingleCaveatExpression( @@ -556,7 +556,7 @@ func TestRunCaveatWithEmptyMap(t *testing.T) { req.NoError(err) dl := datalayer.NewDataLayer(ds) - sr, err := dl.SnapshotReader(headRevision).ReadSchema(t.Context()) + sr, err := dl.SnapshotReader(headRevision, datalayer.NoSchemaHashForTesting).ReadSchema(t.Context()) req.NoError(err) _, err = 
RunSingleCaveatExpression( @@ -594,7 +594,7 @@ func TestRunCaveatMultipleTimes(t *testing.T) { req.NoError(err) dl := datalayer.NewDataLayer(ds) - sr, err := dl.SnapshotReader(headRevision).ReadSchema(t.Context()) + sr, err := dl.SnapshotReader(headRevision, datalayer.NoSchemaHashForTesting).ReadSchema(t.Context()) req.NoError(err) runner := NewCaveatRunner(types.Default.TypeSet) @@ -662,7 +662,7 @@ func TestRunCaveatWithMissingDefinition(t *testing.T) { req.NoError(err) dl := datalayer.NewDataLayer(ds) - sr, err := dl.SnapshotReader(headRevision).ReadSchema(t.Context()) + sr, err := dl.SnapshotReader(headRevision, datalayer.NoSchemaHashForTesting).ReadSchema(t.Context()) req.NoError(err) // Try to run a caveat that doesn't exist @@ -697,7 +697,7 @@ func TestCaveatRunnerPopulateCaveatDefinitionsForExpr(t *testing.T) { req.NoError(err) dl := datalayer.NewDataLayer(ds) - sr, err := dl.SnapshotReader(headRevision).ReadSchema(t.Context()) + sr, err := dl.SnapshotReader(headRevision, datalayer.NoSchemaHashForTesting).ReadSchema(t.Context()) req.NoError(err) runner := NewCaveatRunner(types.Default.TypeSet) @@ -742,7 +742,7 @@ func TestCaveatRunnerEmptyExpression(t *testing.T) { req.NoError(err) dl := datalayer.NewDataLayer(ds) - sr, err := dl.SnapshotReader(headRevision).ReadSchema(t.Context()) + sr, err := dl.SnapshotReader(headRevision, datalayer.NoSchemaHashForTesting).ReadSchema(t.Context()) req.NoError(err) runner := NewCaveatRunner(types.Default.TypeSet) @@ -823,7 +823,7 @@ func TestUnknownCaveatOperation(t *testing.T) { req.NoError(err) dl := datalayer.NewDataLayer(ds) - sr, err := dl.SnapshotReader(headRevision).ReadSchema(t.Context()) + sr, err := dl.SnapshotReader(headRevision, datalayer.NoSchemaHashForTesting).ReadSchema(t.Context()) req.NoError(err) runner := NewCaveatRunner(types.Default.TypeSet) diff --git a/internal/datastore/common/chunkbytes.go b/internal/datastore/common/chunkbytes.go index 65d088d12..c04644215 100644 --- 
a/internal/datastore/common/chunkbytes.go +++ b/internal/datastore/common/chunkbytes.go @@ -6,8 +6,17 @@ import ( "fmt" sq "github.com/Masterminds/squirrel" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + + "github.com/authzed/spicedb/internal/telemetry/otelconv" ) +var tracer = otel.Tracer("spicedb/internal/datastore/common") + +// ErrNoChunksFound is returned when no chunks are found for a given key. +var ErrNoChunksFound = errors.New("no chunks found") + // ChunkedBytesTransaction defines the interface for executing SQL queries within a transaction. type ChunkedBytesTransaction interface { // ExecuteWrite executes an INSERT query. @@ -81,48 +90,70 @@ type SQLByteChunkerConfig[T any] struct { AliveValue T } +// WithExecutor returns a copy of the config with the specified executor. +func (c SQLByteChunkerConfig[T]) WithExecutor(executor ChunkedBytesExecutor) SQLByteChunkerConfig[T] { + c.Executor = executor + return c +} + +// WithTableName returns a copy of the config with the specified table name. +func (c SQLByteChunkerConfig[T]) WithTableName(tableName string) SQLByteChunkerConfig[T] { + c.TableName = tableName + return c +} + // SQLByteChunker provides methods for reading and writing byte data // that is chunked across multiple rows in a SQL table. type SQLByteChunker[T any] struct { config SQLByteChunkerConfig[T] } -// MustNewSQLByteChunker creates a new SQLByteChunker with the specified configuration. -// Panics if the configuration is invalid. -func MustNewSQLByteChunker[T any](config SQLByteChunkerConfig[T]) *SQLByteChunker[T] { +// NewSQLByteChunker creates a new SQLByteChunker with the specified configuration. +// Returns an error if the configuration is invalid. 
+func NewSQLByteChunker[T any](config SQLByteChunkerConfig[T]) (*SQLByteChunker[T], error) { if config.MaxChunkSize <= 0 { - panic("maxChunkSize must be greater than 0") + return nil, errors.New("maxChunkSize must be greater than 0") } if config.TableName == "" { - panic("tableName cannot be empty") + return nil, errors.New("tableName cannot be empty") } if config.NameColumn == "" { - panic("nameColumn cannot be empty") + return nil, errors.New("nameColumn cannot be empty") } if config.ChunkIndexColumn == "" { - panic("chunkIndexColumn cannot be empty") + return nil, errors.New("chunkIndexColumn cannot be empty") } if config.ChunkDataColumn == "" { - panic("chunkDataColumn cannot be empty") + return nil, errors.New("chunkDataColumn cannot be empty") } if config.PlaceholderFormat == nil { - panic("placeholderFormat cannot be nil") + return nil, errors.New("placeholderFormat cannot be nil") } if config.Executor == nil { - panic("executor cannot be nil") + return nil, errors.New("executor cannot be nil") } if config.WriteMode == WriteModeInsertWithTombstones { if config.CreatedAtColumn == "" { - panic("createdAtColumn is required when using WriteModeInsertWithTombstones") + return nil, errors.New("createdAtColumn is required when using WriteModeInsertWithTombstones") } if config.DeletedAtColumn == "" { - panic("deletedAtColumn is required when using WriteModeInsertWithTombstones") + return nil, errors.New("deletedAtColumn is required when using WriteModeInsertWithTombstones") } } return &SQLByteChunker[T]{ config: config, + }, nil +} + +// MustNewSQLByteChunker creates a new SQLByteChunker with the specified configuration. +// Panics if the configuration is invalid. +func MustNewSQLByteChunker[T any](config SQLByteChunkerConfig[T]) *SQLByteChunker[T] { + chunker, err := NewSQLByteChunker(config) + if err != nil { + panic(err) } + return chunker } // WriteChunkedBytes writes chunked byte data to the database within a transaction. 
@@ -143,6 +174,13 @@ func (c *SQLByteChunker[T]) WriteChunkedBytes( return errors.New("name cannot be empty") } + ctx, span := tracer.Start(ctx, "WriteChunkedBytes") + defer span.End() + span.SetAttributes( + attribute.String(otelconv.AttrSchemaDefinitionName, name), + attribute.Int(otelconv.AttrSchemaDataSizeBytes, len(data)), + ) + // Begin transaction txn, err := c.config.Executor.BeginTransaction(ctx) if err != nil { @@ -186,6 +224,7 @@ func (c *SQLByteChunker[T]) WriteChunkedBytes( // Handle empty data case - insert a single empty chunk chunks = [][]byte{{}} } + span.SetAttributes(attribute.Int(otelconv.AttrSchemaChunkCount, len(chunks))) // Set up the columns - base columns plus created_at (if using tombstone mode) columns := []string{c.config.NameColumn, c.config.ChunkIndexColumn, c.config.ChunkDataColumn} @@ -230,6 +269,10 @@ func (c *SQLByteChunker[T]) DeleteChunkedBytes( return errors.New("name cannot be empty") } + ctx, span := tracer.Start(ctx, "DeleteChunkedBytes") + defer span.End() + span.SetAttributes(attribute.String(otelconv.AttrSchemaDefinitionName, name)) + // Begin transaction txn, err := c.config.Executor.BeginTransaction(ctx) if err != nil { @@ -279,6 +322,10 @@ func (c *SQLByteChunker[T]) ReadChunkedBytes( return nil, errors.New("name cannot be empty") } + ctx, span := tracer.Start(ctx, "ReadChunkedBytes") + defer span.End() + span.SetAttributes(attribute.String(otelconv.AttrSchemaDefinitionName, name)) + selectBuilder := sq.StatementBuilder. PlaceholderFormat(c.config.PlaceholderFormat). Select(c.config.ChunkIndexColumn, c.config.ChunkDataColumn). 
@@ -292,12 +339,18 @@ func (c *SQLByteChunker[T]) ReadChunkedBytes( return nil, fmt.Errorf("failed to read chunks: %w", err) } + span.SetAttributes( + attribute.Int(otelconv.AttrSchemaChunkCount, len(chunks)), + ) + // Reassemble the chunks data, err := c.reassembleChunks(chunks) if err != nil { return nil, fmt.Errorf("failed to reassemble chunks: %w", err) } + span.SetAttributes(attribute.Int(otelconv.AttrSchemaDataSizeBytes, len(data))) + return data, nil } @@ -305,7 +358,7 @@ func (c *SQLByteChunker[T]) ReadChunkedBytes( // into the original byte array. It validates that all chunks are present and in order. func (c *SQLByteChunker[T]) reassembleChunks(chunks map[int][]byte) ([]byte, error) { if len(chunks) == 0 { - return nil, errors.New("no chunks found") + return nil, ErrNoChunksFound } // Validate that we have all chunks from 0 to N-1 and calculate total size @@ -338,7 +391,10 @@ func (c *SQLByteChunker[T]) chunkData(data []byte) [][]byte { chunks := make([][]byte, 0, numChunks) for i := 0; i < len(data); i += c.config.MaxChunkSize { - end := min(i+c.config.MaxChunkSize, len(data)) + end := i + c.config.MaxChunkSize + if end > len(data) { + end = len(data) + } chunks = append(chunks, data[i:end]) } diff --git a/internal/datastore/common/chunkbytes_test.go b/internal/datastore/common/chunkbytes_test.go index 341b5d3a3..3cc54f2f1 100644 --- a/internal/datastore/common/chunkbytes_test.go +++ b/internal/datastore/common/chunkbytes_test.go @@ -57,6 +57,7 @@ type fakeExecutor struct { readResult map[int][]byte readErr error transaction *fakeTransaction + onRead func() // Optional callback invoked on each read } func (m *fakeExecutor) BeginTransaction(ctx context.Context) (ChunkedBytesTransaction, error) { @@ -64,6 +65,9 @@ func (m *fakeExecutor) BeginTransaction(ctx context.Context) (ChunkedBytesTransa } func (m *fakeExecutor) ExecuteRead(ctx context.Context, builder sq.SelectBuilder) (map[int][]byte, error) { + if m.onRead != nil { + m.onRead() + } if m.readErr != 
nil { return nil, m.readErr } @@ -392,6 +396,7 @@ func TestReadChunkedBytes(t *testing.T) { name string chunks map[int][]byte expectedData []byte + expectedErr error expectedError string }{ { @@ -418,9 +423,9 @@ func TestReadChunkedBytes(t *testing.T) { expectedData: []byte{}, }, { - name: "no chunks", - chunks: map[int][]byte{}, - expectedError: "no chunks found", + name: "no chunks", + chunks: map[int][]byte{}, + expectedErr: ErrNoChunksFound, }, { name: "missing chunk in sequence", @@ -450,10 +455,13 @@ func TestReadChunkedBytes(t *testing.T) { data, err := chunker.ReadChunkedBytes(context.Background(), "test-key") - if tt.expectedError != "" { + switch { + case tt.expectedErr != nil: + require.ErrorIs(t, err, tt.expectedErr) + case tt.expectedError != "": require.Error(t, err) require.Contains(t, err.Error(), tt.expectedError) - } else { + default: require.NoError(t, err) require.Equal(t, tt.expectedData, data) } @@ -677,6 +685,49 @@ func TestWriteChunkedBytes_LargeData_DeleteAndInsert(t *testing.T) { require.Len(t, insertArgs, 33) // 11 chunks * 3 values per chunk } +func TestWithExecutor(t *testing.T) { + executor1 := &fakeExecutor{} + executor2 := &fakeExecutor{} + + config := SQLByteChunkerConfig[uint64]{ + TableName: "test_table", + NameColumn: "name", + ChunkIndexColumn: "chunk_index", + ChunkDataColumn: "chunk_data", + MaxChunkSize: 1024, + PlaceholderFormat: sq.Question, + Executor: executor1, + WriteMode: WriteModeDeleteAndInsert, + } + + // WithExecutor should return a copy with the new executor. + newConfig := config.WithExecutor(executor2) + require.Equal(t, executor2, newConfig.Executor) + + // Original should be unchanged. 
+ require.Equal(t, executor1, config.Executor) +} + +func TestWithTableName(t *testing.T) { + config := SQLByteChunkerConfig[uint64]{ + TableName: "original_table", + NameColumn: "name", + ChunkIndexColumn: "chunk_index", + ChunkDataColumn: "chunk_data", + MaxChunkSize: 1024, + PlaceholderFormat: sq.Question, + Executor: &fakeExecutor{}, + WriteMode: WriteModeDeleteAndInsert, + } + + // WithTableName should return a copy with the new table name. + newConfig := config.WithTableName("new_table") + require.Equal(t, "new_table", newConfig.TableName) + + // Original should be unchanged. + require.Equal(t, "original_table", config.TableName) +} + func TestWriteChunkedBytes_LargeData_InsertWithTombstones(t *testing.T) { txn := &fakeTransaction{} executor := &fakeExecutor{transaction: txn} diff --git a/internal/datastore/common/sqlschema.go b/internal/datastore/common/sqlschema.go new file mode 100644 index 000000000..4e2f8d9fb --- /dev/null +++ b/internal/datastore/common/sqlschema.go @@ -0,0 +1,141 @@ +package common + +import ( + "context" + "errors" + "fmt" + "math" + + "github.com/authzed/spicedb/pkg/datastore" + core "github.com/authzed/spicedb/pkg/proto/core/v1" +) + +const ( + // UnifiedSchemaName is the name used to store the unified schema in the schema table. + UnifiedSchemaName = "unified_schema" + + // LiveDeletedTxnID is the transaction ID value used to indicate a row has not been deleted yet. + // This is the maximum value for int64, used by databases with tombstone-based deletion. + LiveDeletedTxnID = uint64(math.MaxInt64) +) + +// SQLSingleStoreSchemaReaderWriter implements both SingleStoreSchemaReader and SingleStoreSchemaWriter +// using the SQL byte chunking system. This provides a common implementation that SQL-based datastores +// can use.
+// +// The type parameter T represents the transaction ID type used by the datastore: +// - uint64 for datastores with numeric transaction IDs (MySQL, Postgres) +// - any for datastores without transaction IDs (CRDB, Spanner in delete-and-insert mode) +type SQLSingleStoreSchemaReaderWriter[T any] struct { + chunker *SQLByteChunker[T] + transactionIDProvider func(ctx context.Context) T +} + +// NewSQLSingleStoreSchemaReaderWriter creates a new SQL-based single-store schema reader/writer. +// The transactionIDProvider function should return the appropriate transaction ID for the current context. +// For datastores that don't use transaction IDs, this should return the zero value. +func NewSQLSingleStoreSchemaReaderWriter[T any]( + chunker *SQLByteChunker[T], + transactionIDProvider func(ctx context.Context) T, +) *SQLSingleStoreSchemaReaderWriter[T] { + return &SQLSingleStoreSchemaReaderWriter[T]{ + chunker: chunker, + transactionIDProvider: transactionIDProvider, + } +} + +// ReadStoredSchema reads the stored schema from the unified schema table. +func (s *SQLSingleStoreSchemaReaderWriter[T]) ReadStoredSchema(ctx context.Context) (*datastore.ReadOnlyStoredSchema, error) { + data, err := s.chunker.ReadChunkedBytes(ctx, UnifiedSchemaName) + if err != nil { + if isNoChunksFoundError(err) { + return nil, datastore.ErrSchemaNotFound + } + return nil, fmt.Errorf("failed to read schema: %w", err) + } + + storedSchema := &core.StoredSchema{} + if err := storedSchema.UnmarshalVT(data); err != nil { + return nil, fmt.Errorf("failed to unmarshal schema: %w", err) + } + + return datastore.NewReadOnlyStoredSchema(storedSchema), nil +} + +// WriteStoredSchema writes the stored schema to the unified schema table. 
+func (s *SQLSingleStoreSchemaReaderWriter[T]) WriteStoredSchema(ctx context.Context, storedSchema *core.StoredSchema) error { + data, err := storedSchema.MarshalVT() + if err != nil { + return fmt.Errorf("failed to marshal schema: %w", err) + } + + transactionID := s.transactionIDProvider(ctx) + if err := s.chunker.WriteChunkedBytes(ctx, UnifiedSchemaName, data, transactionID); err != nil { + return fmt.Errorf("failed to write schema: %w", err) + } + + return nil +} + +// isNoChunksFoundError checks if the error indicates no chunks were found. +func isNoChunksFoundError(err error) bool { + return errors.Is(err, ErrNoChunksFound) +} + +// NoTransactionID is a helper function that returns a zero-value transaction ID. +// This can be used for datastores that don't track transaction IDs. +func NoTransactionID[T any](_ context.Context) T { + var zero T + return zero +} + +// StaticTransactionID returns a function that always returns the given transaction ID. +// This is useful for testing or for datastores that use a constant value. +func StaticTransactionID[T any](id T) func(context.Context) T { + return func(_ context.Context) T { + return id + } +} + +// NewSQLSingleStoreSchemaReaderWriterForTransactionIDs is a convenience function for creating a schema reader/writer +// that uses uint64 transaction IDs for datastores with explicit transaction ID tracking (e.g., MySQL, Postgres). +func NewSQLSingleStoreSchemaReaderWriterForTransactionIDs( + chunker *SQLByteChunker[uint64], + transactionIDProvider func(ctx context.Context) uint64, +) *SQLSingleStoreSchemaReaderWriter[uint64] { + return NewSQLSingleStoreSchemaReaderWriter(chunker, transactionIDProvider) +} + +// NewSQLSingleStoreSchemaReaderWriterWithBuiltInMVCC is a convenience function for creating a schema reader/writer +// for datastores with built-in MVCC that don't require explicit transaction ID tracking (e.g., CRDB, Spanner). 
+func NewSQLSingleStoreSchemaReaderWriterWithBuiltInMVCC( + chunker *SQLByteChunker[any], +) *SQLSingleStoreSchemaReaderWriter[any] { + return NewSQLSingleStoreSchemaReaderWriter(chunker, NoTransactionID[any]) +} + +// ValidateStoredSchema validates that a stored schema is well-formed. +func ValidateStoredSchema(storedSchema *core.StoredSchema) error { + if storedSchema == nil { + return errors.New("stored schema is nil") + } + + if storedSchema.Version == 0 { + return errors.New("stored schema version is 0") + } + + v1 := storedSchema.GetV1() + if v1 == nil { + return fmt.Errorf("unsupported schema version: %d", storedSchema.Version) + } + + if v1.SchemaText == "" { + return errors.New("schema text is empty") + } + + if v1.SchemaHash == "" { + return errors.New("schema hash is empty") + } + + return nil +} diff --git a/internal/datastore/common/sqlschema_test.go b/internal/datastore/common/sqlschema_test.go new file mode 100644 index 000000000..6cf271852 --- /dev/null +++ b/internal/datastore/common/sqlschema_test.go @@ -0,0 +1,387 @@ +package common + +import ( + "context" + "testing" + + sq "github.com/Masterminds/squirrel" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" + + "github.com/authzed/spicedb/pkg/datastore" + core "github.com/authzed/spicedb/pkg/proto/core/v1" +) + +func TestSQLSingleStoreSchemaReaderWriter_WriteAndRead(t *testing.T) { + tests := []struct { + name string + schemaText string + namespaces map[string]*core.NamespaceDefinition + caveats map[string]*core.CaveatDefinition + useTransaction bool + }{ + { + name: "simple schema", + schemaText: "definition user {}", + namespaces: map[string]*core.NamespaceDefinition{ + "user": {Name: "user"}, + }, + caveats: map[string]*core.CaveatDefinition{}, + useTransaction: true, + }, + { + name: "schema with caveat", + schemaText: "caveat is_allowed(allowed bool) { allowed }\ndefinition user {}", + namespaces: map[string]*core.NamespaceDefinition{ + "user": {Name: "user"}, + }, 
+ caveats: map[string]*core.CaveatDefinition{ + "is_allowed": {Name: "is_allowed"}, + }, + useTransaction: true, + }, + { + name: "empty schema", + schemaText: "", + namespaces: map[string]*core.NamespaceDefinition{}, + caveats: map[string]*core.CaveatDefinition{}, + useTransaction: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Build stored schema + storedSchema := &core.StoredSchema{ + Version: 1, + VersionOneof: &core.StoredSchema_V1{ + V1: &core.StoredSchema_V1StoredSchema{ + SchemaText: tt.schemaText, + SchemaHash: "test-hash", + NamespaceDefinitions: tt.namespaces, + CaveatDefinitions: tt.caveats, + }, + }, + } + + // Marshal the schema to simulate what would be written + expectedData, err := proto.Marshal(storedSchema) + require.NoError(t, err) + + // Create chunker with fake executor + txn := &fakeTransaction{} + executor := &fakeExecutor{ + transaction: txn, + readResult: map[int][]byte{0: expectedData}, // Return the expected data + } + + chunker := MustNewSQLByteChunker(SQLByteChunkerConfig[uint64]{ + TableName: "schema", + NameColumn: "name", + ChunkIndexColumn: "chunk_index", + ChunkDataColumn: "chunk_data", + MaxChunkSize: 1024 * 64, + PlaceholderFormat: sq.Question, + Executor: executor, + WriteMode: WriteModeInsertWithTombstones, + CreatedAtColumn: "created_at", + DeletedAtColumn: "deleted_at", + AliveValue: 999, + }) + + transactionIDProvider := func(ctx context.Context) uint64 { + if tt.useTransaction { + return 100 + } + return 0 + } + + readerWriter := NewSQLSingleStoreSchemaReaderWriter(chunker, transactionIDProvider) + + // Write schema + ctx := context.Background() + err = readerWriter.WriteStoredSchema(ctx, storedSchema) + require.NoError(t, err) + + // Verify write was called + require.NotEmpty(t, txn.capturedSQL) + + // Read schema back + readSchema, err := readerWriter.ReadStoredSchema(ctx) + require.NoError(t, err) + require.NotNil(t, readSchema) + + // Verify + require.Equal(t, 
storedSchema.Version, readSchema.Get().Version) + require.NotNil(t, readSchema.Get().GetV1()) + require.Equal(t, tt.schemaText, readSchema.Get().GetV1().SchemaText) + require.Equal(t, "test-hash", readSchema.Get().GetV1().SchemaHash) + require.Len(t, readSchema.Get().GetV1().NamespaceDefinitions, len(tt.namespaces)) + require.Len(t, readSchema.Get().GetV1().CaveatDefinitions, len(tt.caveats)) + }) + } +} + +func TestSQLSingleStoreSchemaReaderWriter_ReadNotFound(t *testing.T) { + // Create chunker with fake executor that returns empty result + executor := &fakeExecutor{ + readResult: map[int][]byte{}, // No chunks found + } + chunker := MustNewSQLByteChunker(SQLByteChunkerConfig[uint64]{ + TableName: "schema", + NameColumn: "name", + ChunkIndexColumn: "chunk_index", + ChunkDataColumn: "chunk_data", + MaxChunkSize: 1024 * 64, + PlaceholderFormat: sq.Question, + Executor: executor, + WriteMode: WriteModeDeleteAndInsert, + }) + + readerWriter := NewSQLSingleStoreSchemaReaderWriter(chunker, StaticTransactionID[uint64](0)) + + // Try to read non-existent schema + ctx := context.Background() + _, err := readerWriter.ReadStoredSchema(ctx) + require.Error(t, err) + require.ErrorIs(t, err, datastore.ErrSchemaNotFound) +} + +func TestSQLSingleStoreSchemaReaderWriter_WithBuiltInMVCC(t *testing.T) { + // Test with "any" type parameter for datastores with built-in MVCC (like CRDB/Spanner) + storedSchema := &core.StoredSchema{ + Version: 1, + VersionOneof: &core.StoredSchema_V1{ + V1: &core.StoredSchema_V1StoredSchema{ + SchemaText: "definition user {}", + SchemaHash: "test-hash", + NamespaceDefinitions: map[string]*core.NamespaceDefinition{"user": {Name: "user"}}, + CaveatDefinitions: map[string]*core.CaveatDefinition{}, + }, + }, + } + + expectedData, err := storedSchema.MarshalVT() + require.NoError(t, err) + + txn := &fakeTransaction{} + executor := &fakeExecutor{ + transaction: txn, + readResult: map[int][]byte{0: expectedData}, + } + + chunker := 
MustNewSQLByteChunker(SQLByteChunkerConfig[any]{ + TableName: "schema", + NameColumn: "name", + ChunkIndexColumn: "chunk_index", + ChunkDataColumn: "chunk_data", + MaxChunkSize: 1024 * 64, + PlaceholderFormat: sq.Question, + Executor: executor, + WriteMode: WriteModeDeleteAndInsert, + }) + + readerWriter := NewSQLSingleStoreSchemaReaderWriterWithBuiltInMVCC(chunker) + + // Write schema + ctx := context.Background() + err = readerWriter.WriteStoredSchema(ctx, storedSchema) + require.NoError(t, err) + + // Read schema back + readSchema, err := readerWriter.ReadStoredSchema(ctx) + require.NoError(t, err) + require.NotNil(t, readSchema) + require.Equal(t, "definition user {}", readSchema.Get().GetV1().SchemaText) +} + +func TestValidateStoredSchema(t *testing.T) { + tests := []struct { + name string + schema *core.StoredSchema + expectError bool + errorMsg string + }{ + { + name: "nil schema", + schema: nil, + expectError: true, + errorMsg: "stored schema is nil", + }, + { + name: "zero version", + schema: &core.StoredSchema{ + Version: 0, + }, + expectError: true, + errorMsg: "stored schema version is 0", + }, + { + name: "missing v1", + schema: &core.StoredSchema{ + Version: 1, + }, + expectError: true, + errorMsg: "unsupported schema version", + }, + { + name: "empty schema text", + schema: &core.StoredSchema{ + Version: 1, + VersionOneof: &core.StoredSchema_V1{ + V1: &core.StoredSchema_V1StoredSchema{ + SchemaText: "", + SchemaHash: "hash", + }, + }, + }, + expectError: true, + errorMsg: "schema text is empty", + }, + { + name: "empty schema hash", + schema: &core.StoredSchema{ + Version: 1, + VersionOneof: &core.StoredSchema_V1{ + V1: &core.StoredSchema_V1StoredSchema{ + SchemaText: "definition user {}", + SchemaHash: "", + }, + }, + }, + expectError: true, + errorMsg: "schema hash is empty", + }, + { + name: "valid schema", + schema: &core.StoredSchema{ + Version: 1, + VersionOneof: &core.StoredSchema_V1{ + V1: &core.StoredSchema_V1StoredSchema{ + SchemaText: 
"definition user {}", + SchemaHash: "hash", + NamespaceDefinitions: map[string]*core.NamespaceDefinition{}, + CaveatDefinitions: map[string]*core.CaveatDefinition{}, + }, + }, + }, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateStoredSchema(tt.schema) + if tt.expectError { + require.Error(t, err) + require.Contains(t, err.Error(), tt.errorMsg) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestBuildAndMarshal(t *testing.T) { + schemaText := "definition user {}\ncaveat is_allowed(allowed bool) { allowed }" + + // Build stored schema directly + storedSchema := &core.StoredSchema{ + Version: 1, + VersionOneof: &core.StoredSchema_V1{ + V1: &core.StoredSchema_V1StoredSchema{ + SchemaText: schemaText, + SchemaHash: "test-hash", + NamespaceDefinitions: map[string]*core.NamespaceDefinition{ + "user": {Name: "user"}, + }, + CaveatDefinitions: map[string]*core.CaveatDefinition{ + "is_allowed": {Name: "is_allowed"}, + }, + }, + }, + } + + require.Equal(t, uint32(1), storedSchema.Version) + + v1 := storedSchema.GetV1() + require.NotNil(t, v1) + require.Equal(t, schemaText, v1.SchemaText) + require.Len(t, v1.NamespaceDefinitions, 1) + require.Len(t, v1.CaveatDefinitions, 1) + + // Marshal + data, err := storedSchema.MarshalVT() + require.NoError(t, err) + require.NotEmpty(t, data) + + // Unmarshal + unmarshaled := &core.StoredSchema{} + err = unmarshaled.UnmarshalVT(data) + require.NoError(t, err) + require.NotNil(t, unmarshaled) + require.True(t, storedSchema.EqualVT(unmarshaled)) +} + +func TestHelperFunctions(t *testing.T) { + // Test NoTransactionID + ctx := context.Background() + result := NoTransactionID[uint64](ctx) + require.Equal(t, uint64(0), result) + + // Test StaticTransactionID + staticFunc := StaticTransactionID[uint64](12345) + require.Equal(t, uint64(12345), staticFunc(ctx)) + + // Test with any type + anyResult := NoTransactionID[any](ctx) + require.Nil(t, anyResult) +} + +func 
TestNewSQLSingleStoreSchemaReaderWriterForTransactionIDs(t *testing.T) { + storedSchema := &core.StoredSchema{ + Version: 1, + VersionOneof: &core.StoredSchema_V1{ + V1: &core.StoredSchema_V1StoredSchema{ + SchemaText: "definition user {}", + SchemaHash: "test-hash", + NamespaceDefinitions: map[string]*core.NamespaceDefinition{"user": {Name: "user"}}, + CaveatDefinitions: map[string]*core.CaveatDefinition{}, + }, + }, + } + + expectedData, err := storedSchema.MarshalVT() + require.NoError(t, err) + + txn := &fakeTransaction{} + executor := &fakeExecutor{ + transaction: txn, + readResult: map[int][]byte{0: expectedData}, + } + + chunker := MustNewSQLByteChunker(SQLByteChunkerConfig[uint64]{ + TableName: "schema", + NameColumn: "name", + ChunkIndexColumn: "chunk_index", + ChunkDataColumn: "chunk_data", + MaxChunkSize: 1024 * 64, + PlaceholderFormat: sq.Question, + Executor: executor, + WriteMode: WriteModeInsertWithTombstones, + CreatedAtColumn: "created_txn", + DeletedAtColumn: "deleted_txn", + AliveValue: uint64(9223372036854775807), + }) + + readerWriter := NewSQLSingleStoreSchemaReaderWriterForTransactionIDs(chunker, StaticTransactionID[uint64](42)) + + ctx := context.Background() + err = readerWriter.WriteStoredSchema(ctx, storedSchema) + require.NoError(t, err) + + readSchema, err := readerWriter.ReadStoredSchema(ctx) + require.NoError(t, err) + require.NotNil(t, readSchema) + require.Equal(t, "definition user {}", readSchema.Get().GetV1().SchemaText) +} diff --git a/internal/datastore/crdb/migrations/zz_migration.0010_add_schema_tables.go b/internal/datastore/crdb/migrations/zz_migration.0010_add_schema_tables.go new file mode 100644 index 000000000..e5dc680da --- /dev/null +++ b/internal/datastore/crdb/migrations/zz_migration.0010_add_schema_tables.go @@ -0,0 +1,41 @@ +package migrations + +import ( + "context" + + "github.com/jackc/pgx/v5" +) + +const ( + createSchemaTable = `CREATE TABLE schema ( + name VARCHAR NOT NULL, + chunk_index INT NOT NULL, + 
chunk_data BYTEA NOT NULL, + timestamp TIMESTAMP WITHOUT TIME ZONE DEFAULT now() NOT NULL, + CONSTRAINT pk_schema PRIMARY KEY (name, chunk_index) + );` + + createSchemaRevisionTable = `CREATE TABLE schema_revision ( + name VARCHAR NOT NULL DEFAULT 'current', + hash BYTEA NOT NULL, + timestamp TIMESTAMP WITHOUT TIME ZONE DEFAULT now() NOT NULL, + CONSTRAINT pk_schema_revision PRIMARY KEY (name) + );` +) + +func init() { + err := CRDBMigrations.Register("add-schema-tables", "add-expiration-support", addSchemaTablesFunc, noAtomicMigration) + if err != nil { + panic("failed to register migration: " + err.Error()) + } +} + +func addSchemaTablesFunc(ctx context.Context, conn *pgx.Conn) error { + if _, err := conn.Exec(ctx, createSchemaTable); err != nil { + return err + } + if _, err := conn.Exec(ctx, createSchemaRevisionTable); err != nil { + return err + } + return nil +} diff --git a/internal/datastore/crdb/migrations/zz_migration.0011_populate_schema_tables.go b/internal/datastore/crdb/migrations/zz_migration.0011_populate_schema_tables.go new file mode 100644 index 000000000..7cb50142c --- /dev/null +++ b/internal/datastore/crdb/migrations/zz_migration.0011_populate_schema_tables.go @@ -0,0 +1,167 @@ +package migrations + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "sort" + + "github.com/jackc/pgx/v5" + + core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/schemadsl/compiler" + "github.com/authzed/spicedb/pkg/schemadsl/generator" +) + +const ( + schemaChunkSize = 1024 * 1024 // 1MB chunks + currentSchemaVersion = 1 + unifiedSchemaName = "unified_schema" + schemaRevisionName = "current" +) + +func init() { + if err := CRDBMigrations.Register("populate-schema-tables", "add-schema-tables", noNonAtomicMigration, func(ctx context.Context, tx pgx.Tx) error { + // Read all existing namespaces + rows, err := tx.Query(ctx, ` + SELECT namespace, serialized_config + FROM namespace_config + `) + if err != nil { + return 
fmt.Errorf("failed to query namespaces: %w", err) + } + defer rows.Close() + + namespaces := make(map[string]*core.NamespaceDefinition) + for rows.Next() { + var name string + var config []byte + if err := rows.Scan(&name, &config); err != nil { + return fmt.Errorf("failed to scan namespace: %w", err) + } + + var ns core.NamespaceDefinition + if err := ns.UnmarshalVT(config); err != nil { + return fmt.Errorf("failed to unmarshal namespace %s: %w", name, err) + } + namespaces[name] = &ns + } + if err := rows.Err(); err != nil { + return fmt.Errorf("error iterating namespaces: %w", err) + } + + // Read all existing caveats + rows, err = tx.Query(ctx, ` + SELECT name, definition + FROM caveat + `) + if err != nil { + return fmt.Errorf("failed to query caveats: %w", err) + } + defer rows.Close() + + caveats := make(map[string]*core.CaveatDefinition) + for rows.Next() { + var name string + var definition []byte + if err := rows.Scan(&name, &definition); err != nil { + return fmt.Errorf("failed to scan caveat: %w", err) + } + + var caveat core.CaveatDefinition + if err := caveat.UnmarshalVT(definition); err != nil { + return fmt.Errorf("failed to unmarshal caveat %s: %w", name, err) + } + caveats[name] = &caveat + } + if err := rows.Err(); err != nil { + return fmt.Errorf("error iterating caveats: %w", err) + } + + // If there are no namespaces or caveats, skip migration + if len(namespaces) == 0 && len(caveats) == 0 { + return nil + } + + // Generate canonical schema for hash computation + allDefs := make([]compiler.SchemaDefinition, 0, len(namespaces)+len(caveats)) + for _, ns := range namespaces { + allDefs = append(allDefs, ns) + } + for _, caveat := range caveats { + allDefs = append(allDefs, caveat) + } + + // Sort alphabetically for canonical ordering + sort.Slice(allDefs, func(i, j int) bool { + return allDefs[i].GetName() < allDefs[j].GetName() + }) + + // Generate canonical schema text + canonicalSchemaText, _, err := generator.GenerateSchema(allDefs) + if err 
!= nil { + return fmt.Errorf("failed to generate canonical schema: %w", err) + } + + // Compute SHA256 hash + hashBytes := sha256.Sum256([]byte(canonicalSchemaText)) + schemaHash := hashBytes[:] + + // Generate user-facing schema text + schemaText, _, err := generator.GenerateSchema(allDefs) + if err != nil { + return fmt.Errorf("failed to generate schema: %w", err) + } + + // Create stored schema proto + storedSchema := &core.StoredSchema{ + Version: currentSchemaVersion, + VersionOneof: &core.StoredSchema_V1{ + V1: &core.StoredSchema_V1StoredSchema{ + SchemaText: schemaText, + SchemaHash: hex.EncodeToString(schemaHash), + NamespaceDefinitions: namespaces, + CaveatDefinitions: caveats, + }, + }, + } + + // Marshal schema + schemaData, err := storedSchema.MarshalVT() + if err != nil { + return fmt.Errorf("failed to marshal schema: %w", err) + } + + // Insert schema chunks + for chunkIndex := 0; chunkIndex*schemaChunkSize < len(schemaData); chunkIndex++ { + start := chunkIndex * schemaChunkSize + end := start + schemaChunkSize + if end > len(schemaData) { + end = len(schemaData) + } + chunk := schemaData[start:end] + + _, err = tx.Exec(ctx, ` + INSERT INTO schema (name, chunk_index, chunk_data) + VALUES ($1, $2, $3) + `, unifiedSchemaName, chunkIndex, chunk) + if err != nil { + return fmt.Errorf("failed to insert schema chunk %d: %w", chunkIndex, err) + } + } + + // Insert schema hash + _, err = tx.Exec(ctx, ` + INSERT INTO schema_revision (name, hash) + VALUES ($1, $2) + `, schemaRevisionName, schemaHash) + if err != nil { + return fmt.Errorf("failed to insert schema hash: %w", err) + } + + return nil + }); err != nil { + panic("failed to register migration: " + err.Error()) + } +} diff --git a/internal/datastore/crdb/schema_chunker.go b/internal/datastore/crdb/schema_chunker.go new file mode 100644 index 000000000..564aaa26e --- /dev/null +++ b/internal/datastore/crdb/schema_chunker.go @@ -0,0 +1,123 @@ +package crdb + +import ( + "context" + "errors" + + sq 
"github.com/Masterminds/squirrel" + "github.com/jackc/pgx/v5" + + "github.com/authzed/spicedb/internal/datastore/common" + pgxcommon "github.com/authzed/spicedb/internal/datastore/postgres/common" +) + +const ( + // CockroachDB has no practical limit on BYTEA column size (similar to Postgres), + // but we use 1MB chunks for reasonable memory usage and query performance. + crdbMaxChunkSize = 1024 * 1024 // 1MB +) + +// BaseSchemaChunkerConfig provides the base configuration for CRDB schema chunking. +// CRDB uses delete-and-insert write mode since it handles MVCC automatically. +var BaseSchemaChunkerConfig = common.SQLByteChunkerConfig[any]{ + TableName: "schema", + NameColumn: "name", + ChunkIndexColumn: "chunk_index", + ChunkDataColumn: "chunk_data", + MaxChunkSize: crdbMaxChunkSize, + PlaceholderFormat: sq.Dollar, + WriteMode: common.WriteModeDeleteAndInsert, +} + +// revisionAwareExecutor wraps the reader's query infrastructure to provide revision-aware chunk reading +type revisionAwareExecutor struct { + query pgxcommon.DBFuncQuerier + addFromToQuery func(sq.SelectBuilder, string, string) sq.SelectBuilder + assertAsOfSysTime func(string) +} + +func (e *revisionAwareExecutor) BeginTransaction(ctx context.Context) (common.ChunkedBytesTransaction, error) { + // We don't support transactions for reading + return nil, errors.New("transactions not supported for revision-aware reads") +} + +func (e *revisionAwareExecutor) ExecuteRead(ctx context.Context, builder sq.SelectBuilder) (map[int][]byte, error) { + // Modify the builder to add AS OF SYSTEM TIME + builder = e.addFromToQuery(builder, "schema", "") + + sql, args, err := builder.ToSql() + if err != nil { + return nil, err + } + e.assertAsOfSysTime(sql) + + // Execute using the reader's query function + result := make(map[int][]byte) + err = e.query.QueryFunc(ctx, func(ctx context.Context, rows pgx.Rows) error { + defer rows.Close() + for rows.Next() { + var chunkIndex int + var chunkData []byte + if err := 
rows.Scan(&chunkIndex, &chunkData); err != nil { + return err + } + result[chunkIndex] = chunkData + } + return rows.Err() + }, sql, args...) + + return result, err +} + +// transactionAwareExecutor wraps an existing pgx.Tx to provide transaction-aware chunk writing +type transactionAwareExecutor struct { + tx pgx.Tx +} + +func newTransactionAwareExecutor(tx pgx.Tx) *transactionAwareExecutor { + return &transactionAwareExecutor{tx: tx} +} + +func (e *transactionAwareExecutor) BeginTransaction(ctx context.Context) (common.ChunkedBytesTransaction, error) { + // Return a transaction wrapper that uses the existing transaction + return &transactionAwareTransaction{tx: e.tx}, nil +} + +func (e *transactionAwareExecutor) ExecuteRead(ctx context.Context, builder sq.SelectBuilder) (map[int][]byte, error) { + return nil, errors.New("read operations not supported on transaction-aware executor") +} + +// transactionAwareTransaction implements common.ChunkedBytesTransaction using an existing pgx.Tx +type transactionAwareTransaction struct { + tx pgx.Tx +} + +func (t *transactionAwareTransaction) ExecuteWrite(ctx context.Context, builder sq.InsertBuilder) error { + sql, args, err := builder.ToSql() + if err != nil { + return err + } + + _, err = t.tx.Exec(ctx, sql, args...) + return err +} + +func (t *transactionAwareTransaction) ExecuteDelete(ctx context.Context, builder sq.DeleteBuilder) error { + sql, args, err := builder.ToSql() + if err != nil { + return err + } + + _, err = t.tx.Exec(ctx, sql, args...) + return err +} + +func (t *transactionAwareTransaction) ExecuteUpdate(ctx context.Context, builder sq.UpdateBuilder) error { + sql, args, err := builder.ToSql() + if err != nil { + return err + } + + _, err = t.tx.Exec(ctx, sql, args...) 
+ return err +} diff --git a/internal/datastore/crdb/schema_chunker_test.go b/internal/datastore/crdb/schema_chunker_test.go new file mode 100644 index 000000000..040eced82 --- /dev/null +++ b/internal/datastore/crdb/schema_chunker_test.go @@ -0,0 +1,145 @@ +package crdb + +import ( + "context" + "testing" + + sq "github.com/Masterminds/squirrel" + "github.com/stretchr/testify/require" + + "github.com/authzed/spicedb/internal/datastore/common" +) + +// fakeTransaction captures SQL and args via builder.ToSql() for verification. +type fakeTransaction struct { + capturedSQL []string + capturedArgs [][]any + deleteQueries []string +} + +func (f *fakeTransaction) ExecuteWrite(_ context.Context, builder sq.InsertBuilder) error { + sql, args, err := builder.ToSql() + if err != nil { + return err + } + f.capturedSQL = append(f.capturedSQL, sql) + f.capturedArgs = append(f.capturedArgs, args) + return nil +} + +func (f *fakeTransaction) ExecuteDelete(_ context.Context, builder sq.DeleteBuilder) error { + sql, args, err := builder.ToSql() + if err != nil { + return err + } + f.capturedSQL = append(f.capturedSQL, sql) + f.capturedArgs = append(f.capturedArgs, args) + f.deleteQueries = append(f.deleteQueries, sql) + return nil +} + +func (f *fakeTransaction) ExecuteUpdate(_ context.Context, builder sq.UpdateBuilder) error { + sql, args, err := builder.ToSql() + if err != nil { + return err + } + f.capturedSQL = append(f.capturedSQL, sql) + f.capturedArgs = append(f.capturedArgs, args) + return nil +} + +// fakeExecutor returns the fakeTransaction from BeginTransaction. 
+type fakeExecutor struct { + transaction *fakeTransaction + readResult map[int][]byte +} + +func (e *fakeExecutor) BeginTransaction(_ context.Context) (common.ChunkedBytesTransaction, error) { + return e.transaction, nil +} + +func (e *fakeExecutor) ExecuteRead(_ context.Context, _ sq.SelectBuilder) (map[int][]byte, error) { + return e.readResult, nil +} + +func TestBaseSchemaChunkerConfig(t *testing.T) { + require.Equal(t, "schema", BaseSchemaChunkerConfig.TableName) + require.Equal(t, "name", BaseSchemaChunkerConfig.NameColumn) + require.Equal(t, "chunk_index", BaseSchemaChunkerConfig.ChunkIndexColumn) + require.Equal(t, "chunk_data", BaseSchemaChunkerConfig.ChunkDataColumn) + require.Equal(t, 1024*1024, BaseSchemaChunkerConfig.MaxChunkSize) + require.Equal(t, sq.Dollar, BaseSchemaChunkerConfig.PlaceholderFormat) + require.Equal(t, common.WriteModeDeleteAndInsert, BaseSchemaChunkerConfig.WriteMode) +} + +func TestWrite(t *testing.T) { + txn := &fakeTransaction{} + executor := &fakeExecutor{transaction: txn} + + config := BaseSchemaChunkerConfig.WithExecutor(executor) + chunker := common.MustNewSQLByteChunker(config) + + err := chunker.WriteChunkedBytes(context.Background(), "test-key", []byte("hello"), nil) + require.NoError(t, err) + + // Should have DELETE + INSERT + require.Len(t, txn.capturedSQL, 2) + require.Len(t, txn.deleteQueries, 1) + + // DELETE uses $ placeholders + require.Contains(t, txn.capturedSQL[0], "DELETE FROM schema") + require.Contains(t, txn.capturedSQL[0], "$1") + + // INSERT uses $ placeholders + require.Contains(t, txn.capturedSQL[1], "INSERT INTO schema") + require.Contains(t, txn.capturedSQL[1], "$1") +} + +func TestDelete(t *testing.T) { + txn := &fakeTransaction{} + executor := &fakeExecutor{transaction: txn} + + config := BaseSchemaChunkerConfig.WithExecutor(executor) + chunker := common.MustNewSQLByteChunker(config) + + err := chunker.DeleteChunkedBytes(context.Background(), "test-key", nil) + require.NoError(t, err) + + 
require.Len(t, txn.capturedSQL, 1) + require.Contains(t, txn.capturedSQL[0], "DELETE FROM schema") +} + +func TestRead(t *testing.T) { + executor := &fakeExecutor{ + readResult: map[int][]byte{ + 0: []byte("hello"), + }, + } + + config := BaseSchemaChunkerConfig.WithExecutor(executor) + chunker := common.MustNewSQLByteChunker(config) + + data, err := chunker.ReadChunkedBytes(context.Background(), "test-key") + require.NoError(t, err) + require.Equal(t, []byte("hello"), data) +} + +func TestMultipleChunks(t *testing.T) { + txn := &fakeTransaction{} + executor := &fakeExecutor{transaction: txn} + + config := BaseSchemaChunkerConfig.WithExecutor(executor) + config.MaxChunkSize = 5 + chunker := common.MustNewSQLByteChunker(config) + + err := chunker.WriteChunkedBytes(context.Background(), "test-key", []byte("hello world!"), nil) + require.NoError(t, err) + + // DELETE + INSERT + require.Len(t, txn.capturedSQL, 2) + + // "hello world!" is 12 bytes, chunk size 5 => 3 chunks (5+5+2) + // Each chunk has 3 args (name, chunk_index, chunk_data) + insertArgs := txn.capturedArgs[1] + require.Len(t, insertArgs, 9) // 3 chunks * 3 values +} diff --git a/internal/datastore/crdb/storedschema.go b/internal/datastore/crdb/storedschema.go new file mode 100644 index 000000000..c3001c5f0 --- /dev/null +++ b/internal/datastore/crdb/storedschema.go @@ -0,0 +1,46 @@ +package crdb + +import ( + "context" + "fmt" + + sq "github.com/Masterminds/squirrel" + + "github.com/authzed/spicedb/internal/datastore/common" + "github.com/authzed/spicedb/pkg/datastore" + core "github.com/authzed/spicedb/pkg/proto/core/v1" +) + +// ReadStoredSchema reads the unified stored schema from the CRDB schema table. 
+func (cr *crdbReader) ReadStoredSchema(ctx context.Context) (*datastore.ReadOnlyStoredSchema, error) { + executor := &revisionAwareExecutor{ + query: cr.query, + addFromToQuery: func(builder sq.SelectBuilder, tableName string, indexHint string) sq.SelectBuilder { + return cr.addFromToQuery(builder, tableName, indexHint) + }, + assertAsOfSysTime: func(_ string) { + // No-op: the addFromToQuery already adds AS OF SYSTEM TIME + }, + } + + chunker, err := common.NewSQLByteChunker(BaseSchemaChunkerConfig.WithExecutor(executor)) + if err != nil { + return nil, fmt.Errorf("failed to create schema chunker: %w", err) + } + + rw := common.NewSQLSingleStoreSchemaReaderWriterWithBuiltInMVCC(chunker) + return rw.ReadStoredSchema(ctx) +} + +// WriteStoredSchema writes the unified stored schema to the CRDB schema table. +func (rwt *crdbReadWriteTXN) WriteStoredSchema(ctx context.Context, schema *core.StoredSchema) error { + executor := newTransactionAwareExecutor(rwt.tx) + + chunker, err := common.NewSQLByteChunker(BaseSchemaChunkerConfig.WithExecutor(executor)) + if err != nil { + return fmt.Errorf("failed to create schema chunker: %w", err) + } + + rw := common.NewSQLSingleStoreSchemaReaderWriterWithBuiltInMVCC(chunker) + return rw.WriteStoredSchema(ctx, schema) +} diff --git a/internal/datastore/memdb/schema.go b/internal/datastore/memdb/schema.go index 7905d4854..35d2769f7 100644 --- a/internal/datastore/memdb/schema.go +++ b/internal/datastore/memdb/schema.go @@ -26,6 +26,9 @@ const ( tableChangelog = "changelog" indexRevision = "id" + + tableSchema = "schema" + tableSchemaRevision = "schemarevision" ) type namespace struct { @@ -45,6 +48,11 @@ type counter struct { updated datastore.Revision } +type schemaData struct { + name string + data []byte +} + type relationship struct { namespace string resourceID string @@ -228,5 +236,25 @@ var schema = &memdb.DBSchema{ }, }, }, + tableSchema: { + Name: tableSchema, + Indexes: map[string]*memdb.IndexSchema{ + indexID: { + Name: 
indexID, + Unique: true, + Indexer: &memdb.StringFieldIndex{Field: "name"}, + }, + }, + }, + tableSchemaRevision: { + Name: tableSchemaRevision, + Indexes: map[string]*memdb.IndexSchema{ + indexID: { + Name: indexID, + Unique: true, + Indexer: &memdb.StringFieldIndex{Field: "name"}, + }, + }, + }, }, } diff --git a/internal/datastore/memdb/storedschema.go b/internal/datastore/memdb/storedschema.go new file mode 100644 index 000000000..4270ec090 --- /dev/null +++ b/internal/datastore/memdb/storedschema.go @@ -0,0 +1,68 @@ +package memdb + +import ( + "context" + "fmt" + + "github.com/authzed/spicedb/pkg/datastore" + core "github.com/authzed/spicedb/pkg/proto/core/v1" +) + +const unifiedSchemaName = "unified_schema" + +// ReadStoredSchema reads the unified stored schema from the in-memory database. +func (r *memdbReader) ReadStoredSchema(_ context.Context) (*datastore.ReadOnlyStoredSchema, error) { + r.mustLock() + defer r.Unlock() + + tx, err := r.txSource() + if err != nil { + return nil, err + } + + raw, err := tx.First(tableSchema, indexID, unifiedSchemaName) + if err != nil { + return nil, fmt.Errorf("failed to read schema: %w", err) + } + + if raw == nil { + return nil, datastore.ErrSchemaNotFound + } + + sd, ok := raw.(*schemaData) + if !ok { + return nil, fmt.Errorf("unexpected schema data type: %T", raw) + } + + storedSchema := &core.StoredSchema{} + if err := storedSchema.UnmarshalVT(sd.data); err != nil { + return nil, fmt.Errorf("failed to unmarshal schema: %w", err) + } + + return datastore.NewReadOnlyStoredSchema(storedSchema), nil +} + +// WriteStoredSchema writes the unified stored schema to the in-memory database. 
+func (rwt *memdbReadWriteTx) WriteStoredSchema(_ context.Context, schema *core.StoredSchema) error { + rwt.mustLock() + defer rwt.Unlock() + + tx, err := rwt.txSource() + if err != nil { + return err + } + + data, err := schema.MarshalVT() + if err != nil { + return fmt.Errorf("failed to marshal schema: %w", err) + } + + if err := tx.Insert(tableSchema, &schemaData{ + name: unifiedSchemaName, + data: data, + }); err != nil { + return fmt.Errorf("failed to write schema: %w", err) + } + + return nil +} diff --git a/internal/datastore/mysql/migrations/tables.go b/internal/datastore/mysql/migrations/tables.go index b9aaa528f..41b726394 100644 --- a/internal/datastore/mysql/migrations/tables.go +++ b/internal/datastore/mysql/migrations/tables.go @@ -8,6 +8,8 @@ const ( tableMetadataDefault = "mysql_metadata" tableCaveatDefault = "caveat" tableRelationshipCounters = "relationship_counters" + tableSchemaDefault = "stored_schema" + tableSchemaRevDefault = "stored_schema_revision" ) type tables struct { @@ -18,6 +20,8 @@ type tables struct { tableMetadata string tableCaveat string tableRelationshipCounters string + tableSchema string + tableSchemaRevision string } func newTables(prefix string) *tables { @@ -29,6 +33,8 @@ func newTables(prefix string) *tables { tableMetadata: prefix + tableMetadataDefault, tableCaveat: prefix + tableCaveatDefault, tableRelationshipCounters: prefix + tableRelationshipCounters, + tableSchema: prefix + tableSchemaDefault, + tableSchemaRevision: prefix + tableSchemaRevDefault, } } @@ -62,3 +68,11 @@ func (tn *tables) Caveat() string { func (tn *tables) RelationshipCounters() string { return tn.tableRelationshipCounters } + +func (tn *tables) Schema() string { + return tn.tableSchema +} + +func (tn *tables) SchemaRevision() string { + return tn.tableSchemaRevision +} diff --git a/internal/datastore/mysql/migrations/zz_migration.0011_add_schema_tables.go b/internal/datastore/mysql/migrations/zz_migration.0011_add_schema_tables.go new file mode 
100644 index 000000000..9e5f44628 --- /dev/null +++ b/internal/datastore/mysql/migrations/zz_migration.0011_add_schema_tables.go @@ -0,0 +1,35 @@ +package migrations + +import "fmt" + +func createSchemaTable(t *tables) string { + return fmt.Sprintf(`CREATE TABLE %s ( + name VARCHAR(700) NOT NULL, + chunk_index INT NOT NULL, + chunk_data LONGBLOB NOT NULL, + created_transaction BIGINT NOT NULL, + deleted_transaction BIGINT NOT NULL DEFAULT '9223372036854775807', + CONSTRAINT pk_stored_schema PRIMARY KEY (name, chunk_index, created_transaction)) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;`, + t.Schema(), + ) +} + +func createSchemaRevisionTable(t *tables) string { + return fmt.Sprintf(`CREATE TABLE %s ( + name VARCHAR(700) NOT NULL DEFAULT 'current', + hash BLOB NOT NULL, + created_transaction BIGINT NOT NULL, + deleted_transaction BIGINT NOT NULL DEFAULT '9223372036854775807', + CONSTRAINT pk_stored_schema_revision PRIMARY KEY (name, created_transaction)) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;`, + t.SchemaRevision(), + ) +} + +func init() { + mustRegisterMigration("add_schema_tables", "add_expiration_to_relation_tuple", noNonatomicMigration, + newStatementBatch( + createSchemaTable, + createSchemaRevisionTable, + ).execute, + ) +} diff --git a/internal/datastore/mysql/migrations/zz_migration.0012_populate_schema_tables.go b/internal/datastore/mysql/migrations/zz_migration.0012_populate_schema_tables.go new file mode 100644 index 000000000..9fa33139f --- /dev/null +++ b/internal/datastore/mysql/migrations/zz_migration.0012_populate_schema_tables.go @@ -0,0 +1,188 @@ +package migrations + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "sort" + + core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/schemadsl/compiler" + "github.com/authzed/spicedb/pkg/schemadsl/generator" +) + +const ( + schemaChunkSize = 1024 * 1024 // 1MB chunks + currentSchemaVersion = 1 + unifiedSchemaName = "unified_schema" + schemaRevisionName = 
"current" +) + +func init() { + mustRegisterMigration("populate_schema_tables", "add_schema_tables", noNonatomicMigration, populateSchemaTablesFunc) +} + +func populateSchemaTablesFunc(ctx context.Context, wrapper TxWrapper) error { + tx := wrapper.tx + + // Read all existing namespaces (not deleted ones) + //nolint:gosec // Table name is from internal schema configuration, not user input + query := fmt.Sprintf(` + SELECT nc1.namespace, nc1.serialized_config + FROM %[1]s nc1 + INNER JOIN ( + SELECT namespace, MAX(created_transaction) as max_created + FROM %[1]s + WHERE deleted_transaction = 9223372036854775807 + GROUP BY namespace + ) nc2 ON nc1.namespace = nc2.namespace AND nc1.created_transaction = nc2.max_created + WHERE nc1.deleted_transaction = 9223372036854775807 + `, wrapper.tables.Namespace()) + + rows, err := tx.QueryContext(ctx, query) + if err != nil { + return fmt.Errorf("failed to query namespaces: %w", err) + } + defer rows.Close() + + namespaces := make(map[string]*core.NamespaceDefinition) + for rows.Next() { + var name string + var config []byte + if err := rows.Scan(&name, &config); err != nil { + return fmt.Errorf("failed to scan namespace: %w", err) + } + + var ns core.NamespaceDefinition + if err := ns.UnmarshalVT(config); err != nil { + return fmt.Errorf("failed to unmarshal namespace %s: %w", name, err) + } + namespaces[name] = &ns + } + if err := rows.Err(); err != nil { + return fmt.Errorf("error iterating namespaces: %w", err) + } + + // Read all existing caveats (not deleted ones) + query = fmt.Sprintf(` + SELECT c1.name, c1.definition + FROM %[1]s c1 + INNER JOIN ( + SELECT name, MAX(created_transaction) as max_created + FROM %[1]s + WHERE deleted_transaction = 9223372036854775807 + GROUP BY name + ) c2 ON c1.name = c2.name AND c1.created_transaction = c2.max_created + WHERE c1.deleted_transaction = 9223372036854775807 + `, wrapper.tables.Caveat()) + + rows, err = tx.QueryContext(ctx, query) + if err != nil { + return fmt.Errorf("failed 
to query caveats: %w", err) + } + defer rows.Close() + + caveats := make(map[string]*core.CaveatDefinition) + for rows.Next() { + var name string + var definition []byte + if err := rows.Scan(&name, &definition); err != nil { + return fmt.Errorf("failed to scan caveat: %w", err) + } + + var caveat core.CaveatDefinition + if err := caveat.UnmarshalVT(definition); err != nil { + return fmt.Errorf("failed to unmarshal caveat %s: %w", name, err) + } + caveats[name] = &caveat + } + if err := rows.Err(); err != nil { + return fmt.Errorf("error iterating caveats: %w", err) + } + + // If there are no namespaces or caveats, skip migration + if len(namespaces) == 0 && len(caveats) == 0 { + return nil + } + + // Generate canonical schema for hash computation + allDefs := make([]compiler.SchemaDefinition, 0, len(namespaces)+len(caveats)) + for _, ns := range namespaces { + allDefs = append(allDefs, ns) + } + for _, caveat := range caveats { + allDefs = append(allDefs, caveat) + } + + // Sort alphabetically for canonical ordering + sort.Slice(allDefs, func(i, j int) bool { + return allDefs[i].GetName() < allDefs[j].GetName() + }) + + // Generate canonical schema text + canonicalSchemaText, _, err := generator.GenerateSchema(allDefs) + if err != nil { + return fmt.Errorf("failed to generate canonical schema: %w", err) + } + + // Compute SHA256 hash + hashBytes := sha256.Sum256([]byte(canonicalSchemaText)) + schemaHash := hashBytes[:] + + // Generate user-facing schema text + schemaText, _, err := generator.GenerateSchema(allDefs) + if err != nil { + return fmt.Errorf("failed to generate schema: %w", err) + } + + // Create stored schema proto + storedSchema := &core.StoredSchema{ + Version: currentSchemaVersion, + VersionOneof: &core.StoredSchema_V1{ + V1: &core.StoredSchema_V1StoredSchema{ + SchemaText: schemaText, + SchemaHash: hex.EncodeToString(schemaHash), + NamespaceDefinitions: namespaces, + CaveatDefinitions: caveats, + }, + }, + } + + // Marshal schema + schemaData, err 
:= storedSchema.MarshalVT() + if err != nil { + return fmt.Errorf("failed to marshal schema: %w", err) + } + + // Insert schema chunks + for chunkIndex := 0; chunkIndex*schemaChunkSize < len(schemaData); chunkIndex++ { + start := chunkIndex * schemaChunkSize + end := start + schemaChunkSize + if end > len(schemaData) { + end = len(schemaData) + } + chunk := schemaData[start:end] + + query = fmt.Sprintf(` + INSERT INTO %s (name, chunk_index, chunk_data) + VALUES (?, ?, ?) + `, wrapper.tables.Schema()) + _, err = tx.ExecContext(ctx, query, unifiedSchemaName, chunkIndex, chunk) + if err != nil { + return fmt.Errorf("failed to insert schema chunk %d: %w", chunkIndex, err) + } + } + + // Insert schema hash + query = fmt.Sprintf(` + INSERT INTO %s (name, hash) + VALUES (?, ?) + `, wrapper.tables.SchemaRevision()) + _, err = tx.ExecContext(ctx, query, schemaRevisionName, schemaHash) + if err != nil { + return fmt.Errorf("failed to insert schema hash: %w", err) + } + + return nil +} diff --git a/internal/datastore/mysql/schema_chunker.go b/internal/datastore/mysql/schema_chunker.go new file mode 100644 index 000000000..8de48e458 --- /dev/null +++ b/internal/datastore/mysql/schema_chunker.go @@ -0,0 +1,133 @@ +package mysql + +import ( + "context" + "database/sql" + "errors" + + sq "github.com/Masterminds/squirrel" + + "github.com/authzed/spicedb/internal/datastore/common" +) + +const ( + // MySQL LONGBLOB can store up to 4GB, but we use 64KB chunks for safety + // and to avoid issues with max_allowed_packet settings. + mysqlMaxChunkSize = 64 * 1024 // 64KB +) + +// BaseSchemaChunkerConfig provides the base configuration for MySQL schema chunking. +// MySQL uses smaller chunks (64KB), question mark placeholders, and tombstone-based write mode. 
+var BaseSchemaChunkerConfig = common.SQLByteChunkerConfig[uint64]{
+	TableName:         "stored_schema",
+	NameColumn:        "name",
+	ChunkIndexColumn:  "chunk_index",
+	ChunkDataColumn:   "chunk_data",
+	MaxChunkSize:      mysqlMaxChunkSize,
+	PlaceholderFormat: sq.Question,
+	WriteMode:         common.WriteModeInsertWithTombstones,
+	CreatedAtColumn:   "created_transaction",
+	DeletedAtColumn:   "deleted_transaction",
+	AliveValue:        liveDeletedTxnID,
+}
+
+// mysqlRevisionAwareExecutor wraps the reader's query infrastructure to provide revision-aware chunk reading
+type mysqlRevisionAwareExecutor struct {
+	txSource    txFactory
+	aliveFilter func(sq.SelectBuilder) sq.SelectBuilder
+}
+
+func (e *mysqlRevisionAwareExecutor) BeginTransaction(ctx context.Context) (common.ChunkedBytesTransaction, error) {
+	// Reads never need a transaction; writes go through the transaction-aware executor
+	return nil, errors.New("transactions not supported for revision-aware reads")
+}
+
+func (e *mysqlRevisionAwareExecutor) ExecuteRead(ctx context.Context, builder sq.SelectBuilder) (map[int][]byte, error) {
+	// Apply the alive filter to get chunks that were alive at this revision
+	builder = e.aliveFilter(builder)
+
+	// Named `query` rather than `sql` to avoid shadowing the database/sql package
+	query, args, err := builder.ToSql()
+	if err != nil {
+		return nil, err
+	}
+
+	// Get a transaction to execute the query
+	tx, txCleanup, err := e.txSource(ctx)
+	if err != nil {
+		return nil, err
+	}
+	defer common.LogOnError(ctx, txCleanup)
+
+	// Execute the query
+	rows, err := tx.QueryContext(ctx, query, args...)
+	if err != nil {
+		return nil, err
+	}
+	defer rows.Close()
+
+	result := make(map[int][]byte)
+	for rows.Next() {
+		var chunkIndex int
+		var chunkData []byte
+		if err := rows.Scan(&chunkIndex, &chunkData); err != nil {
+			return nil, err
+		}
+		result[chunkIndex] = chunkData
+	}
+
+	return result, rows.Err()
+}
+
+// mysqlTransactionAwareExecutor wraps an existing sql.Tx to provide transaction-aware chunk writing
+type mysqlTransactionAwareExecutor struct {
+	tx *sql.Tx
+}
+
+func newMySQLTransactionAwareExecutor(tx *sql.Tx) *mysqlTransactionAwareExecutor {
+	return &mysqlTransactionAwareExecutor{tx: tx}
+}
+
+func (e *mysqlTransactionAwareExecutor) BeginTransaction(ctx context.Context) (common.ChunkedBytesTransaction, error) {
+	// Return a transaction wrapper that uses the existing transaction
+	return &mysqlTransactionAwareTransaction{tx: e.tx}, nil
+}
+
+func (e *mysqlTransactionAwareExecutor) ExecuteRead(ctx context.Context, builder sq.SelectBuilder) (map[int][]byte, error) {
+	return nil, errors.New("read operations not supported on transaction-aware executor")
+}
+
+// mysqlTransactionAwareTransaction implements common.ChunkedBytesTransaction using an existing sql.Tx
+// without committing after each operation.
+type mysqlTransactionAwareTransaction struct {
+	tx *sql.Tx
+}
+
+func (t *mysqlTransactionAwareTransaction) ExecuteWrite(ctx context.Context, builder sq.InsertBuilder) error {
+	query, args, err := builder.ToSql()
+	if err != nil {
+		return err
+	}
+
+	_, err = t.tx.ExecContext(ctx, query, args...)
+	return err
+}
+
+func (t *mysqlTransactionAwareTransaction) ExecuteDelete(ctx context.Context, builder sq.DeleteBuilder) error {
+	query, args, err := builder.ToSql()
+	if err != nil {
+		return err
+	}
+
+	_, err = t.tx.ExecContext(ctx, query, args...)
+	return err
+}
+
+func (t *mysqlTransactionAwareTransaction) ExecuteUpdate(ctx context.Context, builder sq.UpdateBuilder) error {
+	query, args, err := builder.ToSql()
+	if err != nil {
+		return err
+	}
+
+	_, err = t.tx.ExecContext(ctx, query, args...)
+	return err
+}
diff --git a/internal/datastore/mysql/schema_chunker_test.go b/internal/datastore/mysql/schema_chunker_test.go
new file mode 100644
index 000000000..960b2e3cc
--- /dev/null
+++ b/internal/datastore/mysql/schema_chunker_test.go
@@ -0,0 +1,151 @@
+package mysql
+
+import (
+	"context"
+	"testing"
+
+	sq "github.com/Masterminds/squirrel"
+	"github.com/stretchr/testify/require"
+
+	"github.com/authzed/spicedb/internal/datastore/common"
+)
+
+// fakeTransaction captures SQL and args via builder.ToSql() for verification.
+type fakeTransaction struct {
+	capturedSQL   []string
+	capturedArgs  [][]any
+	updateQueries []string
+}
+
+func (f *fakeTransaction) ExecuteWrite(_ context.Context, builder sq.InsertBuilder) error {
+	sql, args, err := builder.ToSql()
+	if err != nil {
+		return err
+	}
+	f.capturedSQL = append(f.capturedSQL, sql)
+	f.capturedArgs = append(f.capturedArgs, args)
+	return nil
+}
+
+func (f *fakeTransaction) ExecuteDelete(_ context.Context, builder sq.DeleteBuilder) error {
+	sql, args, err := builder.ToSql()
+	if err != nil {
+		return err
+	}
+	f.capturedSQL = append(f.capturedSQL, sql)
+	f.capturedArgs = append(f.capturedArgs, args)
+	return nil
+}
+
+func (f *fakeTransaction) ExecuteUpdate(_ context.Context, builder sq.UpdateBuilder) error {
+	sql, args, err := builder.ToSql()
+	if err != nil {
+		return err
+	}
+	f.capturedSQL = append(f.capturedSQL, sql)
+	f.capturedArgs = append(f.capturedArgs, args)
+	f.updateQueries = append(f.updateQueries, sql)
+	return nil
+}
+
+// fakeExecutor returns the fakeTransaction from BeginTransaction.
+type fakeExecutor struct { + transaction *fakeTransaction + readResult map[int][]byte +} + +func (e *fakeExecutor) BeginTransaction(_ context.Context) (common.ChunkedBytesTransaction, error) { + return e.transaction, nil +} + +func (e *fakeExecutor) ExecuteRead(_ context.Context, _ sq.SelectBuilder) (map[int][]byte, error) { + return e.readResult, nil +} + +func TestBaseSchemaChunkerConfig(t *testing.T) { + require.Equal(t, "stored_schema", BaseSchemaChunkerConfig.TableName) + require.Equal(t, "name", BaseSchemaChunkerConfig.NameColumn) + require.Equal(t, "chunk_index", BaseSchemaChunkerConfig.ChunkIndexColumn) + require.Equal(t, "chunk_data", BaseSchemaChunkerConfig.ChunkDataColumn) + require.Equal(t, 64*1024, BaseSchemaChunkerConfig.MaxChunkSize) + require.Equal(t, sq.Question, BaseSchemaChunkerConfig.PlaceholderFormat) + require.Equal(t, common.WriteModeInsertWithTombstones, BaseSchemaChunkerConfig.WriteMode) +} + +func TestWrite(t *testing.T) { + txn := &fakeTransaction{} + executor := &fakeExecutor{transaction: txn} + + config := BaseSchemaChunkerConfig.WithExecutor(executor) + chunker := common.MustNewSQLByteChunker(config) + + createdAt := uint64(100) + err := chunker.WriteChunkedBytes(context.Background(), "test-key", []byte("hello"), createdAt) + require.NoError(t, err) + + // Should have UPDATE (tombstone) + INSERT + require.Len(t, txn.capturedSQL, 2) + require.Len(t, txn.updateQueries, 1) + + // UPDATE uses ? placeholders + require.Contains(t, txn.capturedSQL[0], "UPDATE stored_schema") + require.Contains(t, txn.capturedSQL[0], "SET deleted_transaction = ?") + + // INSERT uses ? 
placeholders + require.Contains(t, txn.capturedSQL[1], "INSERT INTO stored_schema") + require.Contains(t, txn.capturedSQL[1], "?") +} + +func TestDelete(t *testing.T) { + txn := &fakeTransaction{} + executor := &fakeExecutor{transaction: txn} + + config := BaseSchemaChunkerConfig.WithExecutor(executor) + chunker := common.MustNewSQLByteChunker(config) + + deletedAt := uint64(200) + err := chunker.DeleteChunkedBytes(context.Background(), "test-key", deletedAt) + require.NoError(t, err) + + // Should have UPDATE (tombstone) + require.Len(t, txn.capturedSQL, 1) + require.Len(t, txn.updateQueries, 1) + require.Contains(t, txn.capturedSQL[0], "UPDATE stored_schema") + require.Contains(t, txn.capturedSQL[0], "SET deleted_transaction = ?") +} + +func TestRead(t *testing.T) { + executor := &fakeExecutor{ + readResult: map[int][]byte{ + 0: []byte("hello"), + }, + } + + config := BaseSchemaChunkerConfig.WithExecutor(executor) + chunker := common.MustNewSQLByteChunker(config) + + data, err := chunker.ReadChunkedBytes(context.Background(), "test-key") + require.NoError(t, err) + require.Equal(t, []byte("hello"), data) +} + +func TestMultipleChunks(t *testing.T) { + txn := &fakeTransaction{} + executor := &fakeExecutor{transaction: txn} + + config := BaseSchemaChunkerConfig.WithExecutor(executor) + config.MaxChunkSize = 5 + chunker := common.MustNewSQLByteChunker(config) + + createdAt := uint64(100) + err := chunker.WriteChunkedBytes(context.Background(), "test-key", []byte("hello world!"), createdAt) + require.NoError(t, err) + + // UPDATE (tombstone) + INSERT + require.Len(t, txn.capturedSQL, 2) + + // "hello world!" 
is 12 bytes, chunk size 5 => 3 chunks (5+5+2) + // Each chunk has 4 args (name, chunk_index, chunk_data, created_transaction) + insertArgs := txn.capturedArgs[1] + require.Len(t, insertArgs, 12) // 3 chunks * 4 values +} diff --git a/internal/datastore/mysql/storedschema.go b/internal/datastore/mysql/storedschema.go new file mode 100644 index 000000000..ff3f84f53 --- /dev/null +++ b/internal/datastore/mysql/storedschema.go @@ -0,0 +1,41 @@ +package mysql + +import ( + "context" + "fmt" + + "github.com/authzed/spicedb/internal/datastore/common" + "github.com/authzed/spicedb/pkg/datastore" + core "github.com/authzed/spicedb/pkg/proto/core/v1" +) + +// ReadStoredSchema reads the unified stored schema from the MySQL schema table. +func (mr *mysqlReader) ReadStoredSchema(ctx context.Context) (*datastore.ReadOnlyStoredSchema, error) { + executor := &mysqlRevisionAwareExecutor{ + txSource: mr.txSource, + aliveFilter: mr.aliveFilter, + } + + chunker, err := common.NewSQLByteChunker(BaseSchemaChunkerConfig.WithExecutor(executor)) + if err != nil { + return nil, fmt.Errorf("failed to create schema chunker: %w", err) + } + + rw := common.NewSQLSingleStoreSchemaReaderWriterForTransactionIDs(chunker, common.NoTransactionID[uint64]) + return rw.ReadStoredSchema(ctx) +} + +// WriteStoredSchema writes the unified stored schema to the MySQL schema table. 
+func (rwt *mysqlReadWriteTXN) WriteStoredSchema(ctx context.Context, schema *core.StoredSchema) error { + executor := newMySQLTransactionAwareExecutor(rwt.tx) + + chunker, err := common.NewSQLByteChunker(BaseSchemaChunkerConfig.WithExecutor(executor)) + if err != nil { + return fmt.Errorf("failed to create schema chunker: %w", err) + } + + rw := common.NewSQLSingleStoreSchemaReaderWriterForTransactionIDs(chunker, func(_ context.Context) uint64 { + return rwt.newTxnID + }) + return rw.WriteStoredSchema(ctx, schema) +} diff --git a/internal/datastore/postgres/migrations/zz_migration.0024_add_schema_tables.go b/internal/datastore/postgres/migrations/zz_migration.0024_add_schema_tables.go new file mode 100644 index 000000000..17ab14883 --- /dev/null +++ b/internal/datastore/postgres/migrations/zz_migration.0024_add_schema_tables.go @@ -0,0 +1,41 @@ +package migrations + +import ( + "context" + + "github.com/jackc/pgx/v5" +) + +var schemaTablesStatements = []string{ + `CREATE TABLE schema ( + name VARCHAR NOT NULL, + chunk_index INT NOT NULL, + chunk_data BYTEA NOT NULL, + created_xid xid8 NOT NULL DEFAULT (pg_current_xact_id()), + deleted_xid xid8 NOT NULL DEFAULT ('9223372036854775807'), + CONSTRAINT pk_schema PRIMARY KEY (name, chunk_index, created_xid));`, + `CREATE INDEX ix_schema_gc ON schema (deleted_xid DESC) WHERE deleted_xid < '9223372036854775807'::xid8;`, + `CREATE TABLE schema_revision ( + name VARCHAR NOT NULL DEFAULT 'current', + hash BYTEA NOT NULL, + created_xid xid8 NOT NULL DEFAULT (pg_current_xact_id()), + deleted_xid xid8 NOT NULL DEFAULT ('9223372036854775807'), + CONSTRAINT pk_schema_revision PRIMARY KEY (name, created_xid));`, + `CREATE INDEX ix_schema_revision_gc ON schema_revision (deleted_xid DESC) WHERE deleted_xid < '9223372036854775807'::xid8;`, +} + +func init() { + if err := DatabaseMigrations.Register("add-schema-tables", "add-index-for-transaction-gc", + noNonatomicMigration, + func(ctx context.Context, tx pgx.Tx) error { + for _, stmt 
:= range schemaTablesStatements {
+			if _, err := tx.Exec(ctx, stmt); err != nil {
+				return err
+			}
+		}
+
+		return nil
+	}); err != nil {
+		panic("failed to register migration: " + err.Error())
+	}
+}
diff --git a/internal/datastore/postgres/migrations/zz_migration.0025_populate_schema_tables.go b/internal/datastore/postgres/migrations/zz_migration.0025_populate_schema_tables.go
new file mode 100644
index 000000000..632e3a03a
--- /dev/null
+++ b/internal/datastore/postgres/migrations/zz_migration.0025_populate_schema_tables.go
@@ -0,0 +1,175 @@
+package migrations
+
+import (
+	"context"
+	"crypto/sha256"
+	"encoding/hex"
+	"fmt"
+	"sort"
+
+	"github.com/jackc/pgx/v5"
+
+	core "github.com/authzed/spicedb/pkg/proto/core/v1"
+	"github.com/authzed/spicedb/pkg/schemadsl/compiler"
+	"github.com/authzed/spicedb/pkg/schemadsl/generator"
+)
+
+const (
+	schemaChunkSize      = 1024 * 1024 // 1MB chunks
+	currentSchemaVersion = 1
+	unifiedSchemaName    = "unified_schema"
+	schemaRevisionName   = "current"
+)
+
+// This migration backfills the unified schema tables from the existing
+// per-definition namespace_config and caveat rows.
+func init() {
+	if err := DatabaseMigrations.Register("populate-schema-tables", "add-schema-tables",
+		noNonatomicMigration,
+		func(ctx context.Context, tx pgx.Tx) error {
+			// Read the latest live revision of each namespace definition.
+			nsRows, err := tx.Query(ctx, `
+				SELECT DISTINCT ON (namespace)
+					namespace, serialized_config
+				FROM namespace_config
+				WHERE deleted_xid = '9223372036854775807'::xid8
+				ORDER BY namespace, created_xid DESC
+			`)
+			if err != nil {
+				return fmt.Errorf("failed to query namespaces: %w", err)
+			}
+			defer nsRows.Close()
+
+			namespaces := make(map[string]*core.NamespaceDefinition)
+			for nsRows.Next() {
+				var name string
+				var config []byte
+				if err := nsRows.Scan(&name, &config); err != nil {
+					return fmt.Errorf("failed to scan namespace: %w", err)
+				}
+
+				var ns core.NamespaceDefinition
+				if err := ns.UnmarshalVT(config); err != nil {
+					return fmt.Errorf("failed to unmarshal namespace %s: %w", name, err)
+				}
+				namespaces[name] = &ns
+			}
+			if err := nsRows.Err(); err != nil {
+				return fmt.Errorf("error iterating namespaces: %w", err)
+			}
+
+			// Read the latest live revision of each caveat definition.
+			caveatRows, err := tx.Query(ctx, `
+				SELECT DISTINCT ON (name)
+					name, definition
+				FROM caveat
+				WHERE deleted_xid = '9223372036854775807'::xid8
+				ORDER BY name, created_xid DESC
+			`)
+			if err != nil {
+				return fmt.Errorf("failed to query caveats: %w", err)
+			}
+			defer caveatRows.Close()
+
+			caveats := make(map[string]*core.CaveatDefinition)
+			for caveatRows.Next() {
+				var name string
+				var definition []byte
+				if err := caveatRows.Scan(&name, &definition); err != nil {
+					return fmt.Errorf("failed to scan caveat: %w", err)
+				}
+
+				var caveat core.CaveatDefinition
+				if err := caveat.UnmarshalVT(definition); err != nil {
+					return fmt.Errorf("failed to unmarshal caveat %s: %w", name, err)
+				}
+				caveats[name] = &caveat
+			}
+			if err := caveatRows.Err(); err != nil {
+				return fmt.Errorf("error iterating caveats: %w", err)
+			}
+
+			// If there are no namespaces or caveats, skip migration
+			if len(namespaces) == 0 && len(caveats) == 0 {
+				return nil
+			}
+
+			// Collect all definitions and sort alphabetically so the generated
+			// schema text (and therefore its hash) is canonical.
+			allDefs := make([]compiler.SchemaDefinition, 0, len(namespaces)+len(caveats))
+			for _, ns := range namespaces {
+				allDefs = append(allDefs, ns)
+			}
+			for _, caveat := range caveats {
+				allDefs = append(allDefs, caveat)
+			}
+
+			sort.Slice(allDefs, func(i, j int) bool {
+				return allDefs[i].GetName() < allDefs[j].GetName()
+			})
+
+			// Generate the canonical schema text once; it is both the stored
+			// schema text and the input to the schema hash. (It was previously
+			// generated twice from the same sorted definitions, producing
+			// identical output.)
+			schemaText, _, err := generator.GenerateSchema(allDefs)
+			if err != nil {
+				return fmt.Errorf("failed to generate schema: %w", err)
+			}
+
+			// Compute SHA256 hash of the canonical schema text.
+			hashBytes := sha256.Sum256([]byte(schemaText))
+			schemaHash := hashBytes[:]
+
+			// Create stored schema proto. The proto records the hash hex-encoded;
+			// schema_revision stores the raw bytes.
+			storedSchema := &core.StoredSchema{
+				Version: currentSchemaVersion,
+				VersionOneof: &core.StoredSchema_V1{
+					V1: &core.StoredSchema_V1StoredSchema{
+						SchemaText:           schemaText,
+						SchemaHash:           hex.EncodeToString(schemaHash),
+						NamespaceDefinitions: namespaces,
+						CaveatDefinitions:    caveats,
+					},
+				},
+			}
+
+			// Marshal schema
+			schemaData, err := storedSchema.MarshalVT()
+			if err != nil {
+				return fmt.Errorf("failed to marshal schema: %w", err)
+			}
+
+			// Insert schema chunks
+			for chunkIndex := 0; chunkIndex*schemaChunkSize < len(schemaData); chunkIndex++ {
+				start := chunkIndex * schemaChunkSize
+				end := start + schemaChunkSize
+				if end > len(schemaData) {
+					end = len(schemaData)
+				}
+				chunk := schemaData[start:end]
+
+				_, err = tx.Exec(ctx, `
+					INSERT INTO schema (name, chunk_index, chunk_data)
+					VALUES ($1, $2, $3)
+				`, unifiedSchemaName, chunkIndex, chunk)
+				if err != nil {
+					return fmt.Errorf("failed to insert schema chunk %d: %w", chunkIndex, err)
+				}
+			}
+
+			// Insert schema hash
+			_, err = tx.Exec(ctx, `
+				INSERT INTO schema_revision (name, hash)
+				VALUES ($1, $2)
+			`, schemaRevisionName, schemaHash)
+			if err != nil {
+				return fmt.Errorf("failed to insert schema hash: %w", err)
+			}
+
+			return nil
+		}); err != nil {
+		panic("failed to register migration: " + err.Error())
+	}
+}
diff --git a/internal/datastore/postgres/schema_chunker.go b/internal/datastore/postgres/schema_chunker.go
new file mode 100644
index 000000000..907d9760a
--- /dev/null
+++ b/internal/datastore/postgres/schema_chunker.go
@@ -0,0 +1,123 @@
+package postgres
+
+import (
+	"context"
+	"errors"
+
+	sq "github.com/Masterminds/squirrel"
+	"github.com/jackc/pgx/v5"
+
+	"github.com/authzed/spicedb/internal/datastore/common"
+	pgxcommon "github.com/authzed/spicedb/internal/datastore/postgres/common"
+)
+
+const (
+	// PostgreSQL has no practical limit on BYTEA column size (up to 1GB per cell),
+	// but we use 1MB chunks for reasonable memory usage and query performance.
+ postgresMaxChunkSize = 1024 * 1024 // 1MB +) + +// BaseSchemaChunkerConfig provides the base configuration for Postgres schema chunking. +// Postgres uses tombstone-based write mode with XID8 columns (matching relationship tables). +var BaseSchemaChunkerConfig = common.SQLByteChunkerConfig[uint64]{ + TableName: "schema", + NameColumn: "name", + ChunkIndexColumn: "chunk_index", + ChunkDataColumn: "chunk_data", + MaxChunkSize: postgresMaxChunkSize, + PlaceholderFormat: sq.Dollar, + WriteMode: common.WriteModeInsertWithTombstones, + CreatedAtColumn: "created_xid", + DeletedAtColumn: "deleted_xid", + AliveValue: liveDeletedTxnID, +} + +// postgresChunkedBytesTransaction implements common.ChunkedBytesTransaction for PostgreSQL. +type postgresChunkedBytesTransaction struct { + tx pgx.Tx +} + +func (t *postgresChunkedBytesTransaction) ExecuteWrite(ctx context.Context, builder sq.InsertBuilder) error { + sql, args, err := builder.ToSql() + if err != nil { + return err + } + + _, err = t.tx.Exec(ctx, sql, args...) + return err +} + +func (t *postgresChunkedBytesTransaction) ExecuteDelete(ctx context.Context, builder sq.DeleteBuilder) error { + sql, args, err := builder.ToSql() + if err != nil { + return err + } + + _, err = t.tx.Exec(ctx, sql, args...) + return err +} + +func (t *postgresChunkedBytesTransaction) ExecuteUpdate(ctx context.Context, builder sq.UpdateBuilder) error { + sql, args, err := builder.ToSql() + if err != nil { + return err + } + + _, err = t.tx.Exec(ctx, sql, args...) 
+ return err +} + +// pgRevisionAwareExecutor wraps the reader's query infrastructure to provide revision-aware chunk reading +type pgRevisionAwareExecutor struct { + query pgxcommon.DBFuncQuerier + aliveFilter func(sq.SelectBuilder) sq.SelectBuilder +} + +func (e *pgRevisionAwareExecutor) BeginTransaction(ctx context.Context) (common.ChunkedBytesTransaction, error) { + // We don't support transactions for reading + return nil, errors.New("transactions not supported for revision-aware reads") +} + +func (e *pgRevisionAwareExecutor) ExecuteRead(ctx context.Context, builder sq.SelectBuilder) (map[int][]byte, error) { + // Apply the alive filter to get chunks that were alive at this revision + builder = e.aliveFilter(builder) + + sql, args, err := builder.ToSql() + if err != nil { + return nil, err + } + + // Execute using the reader's query function + result := make(map[int][]byte) + err = e.query.QueryFunc(ctx, func(ctx context.Context, rows pgx.Rows) error { + for rows.Next() { + var chunkIndex int + var chunkData []byte + if err := rows.Scan(&chunkIndex, &chunkData); err != nil { + return err + } + result[chunkIndex] = chunkData + } + return rows.Err() + }, sql, args...) 
+ + return result, err +} + +// pgTransactionAwareExecutor wraps an existing pgx.Tx to provide transaction-aware chunk writing +type pgTransactionAwareExecutor struct { + tx pgx.Tx +} + +func newPGTransactionAwareExecutor(tx pgx.Tx) *pgTransactionAwareExecutor { + return &pgTransactionAwareExecutor{tx: tx} +} + +func (e *pgTransactionAwareExecutor) BeginTransaction(ctx context.Context) (common.ChunkedBytesTransaction, error) { + // Return a transaction wrapper that uses the existing transaction + return &postgresChunkedBytesTransaction{tx: e.tx}, nil +} + +func (e *pgTransactionAwareExecutor) ExecuteRead(ctx context.Context, builder sq.SelectBuilder) (map[int][]byte, error) { + return nil, errors.New("read operations not supported on transaction-aware executor") +} diff --git a/internal/datastore/postgres/schema_chunker_test.go b/internal/datastore/postgres/schema_chunker_test.go new file mode 100644 index 000000000..6be310e0b --- /dev/null +++ b/internal/datastore/postgres/schema_chunker_test.go @@ -0,0 +1,151 @@ +package postgres + +import ( + "context" + "testing" + + sq "github.com/Masterminds/squirrel" + "github.com/stretchr/testify/require" + + "github.com/authzed/spicedb/internal/datastore/common" +) + +// fakeTransaction captures SQL and args via builder.ToSql() for verification. 
+type fakeTransaction struct { + capturedSQL []string + capturedArgs [][]any + updateQueries []string +} + +func (f *fakeTransaction) ExecuteWrite(_ context.Context, builder sq.InsertBuilder) error { + sql, args, err := builder.ToSql() + if err != nil { + return err + } + f.capturedSQL = append(f.capturedSQL, sql) + f.capturedArgs = append(f.capturedArgs, args) + return nil +} + +func (f *fakeTransaction) ExecuteDelete(_ context.Context, builder sq.DeleteBuilder) error { + sql, args, err := builder.ToSql() + if err != nil { + return err + } + f.capturedSQL = append(f.capturedSQL, sql) + f.capturedArgs = append(f.capturedArgs, args) + return nil +} + +func (f *fakeTransaction) ExecuteUpdate(_ context.Context, builder sq.UpdateBuilder) error { + sql, args, err := builder.ToSql() + if err != nil { + return err + } + f.capturedSQL = append(f.capturedSQL, sql) + f.capturedArgs = append(f.capturedArgs, args) + f.updateQueries = append(f.updateQueries, sql) + return nil +} + +// fakeExecutor returns the fakeTransaction from BeginTransaction. 
+type fakeExecutor struct { + transaction *fakeTransaction + readResult map[int][]byte +} + +func (e *fakeExecutor) BeginTransaction(_ context.Context) (common.ChunkedBytesTransaction, error) { + return e.transaction, nil +} + +func (e *fakeExecutor) ExecuteRead(_ context.Context, _ sq.SelectBuilder) (map[int][]byte, error) { + return e.readResult, nil +} + +func TestBaseSchemaChunkerConfig(t *testing.T) { + require.Equal(t, "schema", BaseSchemaChunkerConfig.TableName) + require.Equal(t, "name", BaseSchemaChunkerConfig.NameColumn) + require.Equal(t, "chunk_index", BaseSchemaChunkerConfig.ChunkIndexColumn) + require.Equal(t, "chunk_data", BaseSchemaChunkerConfig.ChunkDataColumn) + require.Equal(t, 1024*1024, BaseSchemaChunkerConfig.MaxChunkSize) + require.Equal(t, sq.Dollar, BaseSchemaChunkerConfig.PlaceholderFormat) + require.Equal(t, common.WriteModeInsertWithTombstones, BaseSchemaChunkerConfig.WriteMode) +} + +func TestWrite(t *testing.T) { + txn := &fakeTransaction{} + executor := &fakeExecutor{transaction: txn} + + config := BaseSchemaChunkerConfig.WithExecutor(executor) + chunker := common.MustNewSQLByteChunker(config) + + createdAt := uint64(100) + err := chunker.WriteChunkedBytes(context.Background(), "test-key", []byte("hello"), createdAt) + require.NoError(t, err) + + // Should have UPDATE (tombstone) + INSERT + require.Len(t, txn.capturedSQL, 2) + require.Len(t, txn.updateQueries, 1) + + // UPDATE uses $ placeholders + require.Contains(t, txn.capturedSQL[0], "UPDATE schema") + require.Contains(t, txn.capturedSQL[0], "$1") + + // INSERT uses $ placeholders + require.Contains(t, txn.capturedSQL[1], "INSERT INTO schema") + require.Contains(t, txn.capturedSQL[1], "$") +} + +func TestDelete(t *testing.T) { + txn := &fakeTransaction{} + executor := &fakeExecutor{transaction: txn} + + config := BaseSchemaChunkerConfig.WithExecutor(executor) + chunker := common.MustNewSQLByteChunker(config) + + deletedAt := uint64(200) + err := 
chunker.DeleteChunkedBytes(context.Background(), "test-key", deletedAt) + require.NoError(t, err) + + // Should have UPDATE (tombstone) + require.Len(t, txn.capturedSQL, 1) + require.Len(t, txn.updateQueries, 1) + require.Contains(t, txn.capturedSQL[0], "UPDATE schema") + require.Contains(t, txn.capturedSQL[0], "$1") +} + +func TestRead(t *testing.T) { + executor := &fakeExecutor{ + readResult: map[int][]byte{ + 0: []byte("hello"), + }, + } + + config := BaseSchemaChunkerConfig.WithExecutor(executor) + chunker := common.MustNewSQLByteChunker(config) + + data, err := chunker.ReadChunkedBytes(context.Background(), "test-key") + require.NoError(t, err) + require.Equal(t, []byte("hello"), data) +} + +func TestMultipleChunks(t *testing.T) { + txn := &fakeTransaction{} + executor := &fakeExecutor{transaction: txn} + + config := BaseSchemaChunkerConfig.WithExecutor(executor) + config.MaxChunkSize = 5 + chunker := common.MustNewSQLByteChunker(config) + + createdAt := uint64(100) + err := chunker.WriteChunkedBytes(context.Background(), "test-key", []byte("hello world!"), createdAt) + require.NoError(t, err) + + // UPDATE (tombstone) + INSERT + require.Len(t, txn.capturedSQL, 2) + + // "hello world!" is 12 bytes, chunk size 5 => 3 chunks (5+5+2) + // Each chunk has 4 args (name, chunk_index, chunk_data, created_xid) + insertArgs := txn.capturedArgs[1] + require.Len(t, insertArgs, 12) // 3 chunks * 4 values +} diff --git a/internal/datastore/postgres/storedschema.go b/internal/datastore/postgres/storedschema.go new file mode 100644 index 000000000..464722cfa --- /dev/null +++ b/internal/datastore/postgres/storedschema.go @@ -0,0 +1,41 @@ +package postgres + +import ( + "context" + "fmt" + + "github.com/authzed/spicedb/internal/datastore/common" + "github.com/authzed/spicedb/pkg/datastore" + core "github.com/authzed/spicedb/pkg/proto/core/v1" +) + +// ReadStoredSchema reads the unified stored schema from the Postgres schema table. 
+func (r *pgReader) ReadStoredSchema(ctx context.Context) (*datastore.ReadOnlyStoredSchema, error) { + executor := &pgRevisionAwareExecutor{ + query: r.query, + aliveFilter: r.aliveFilter, + } + + chunker, err := common.NewSQLByteChunker(BaseSchemaChunkerConfig.WithExecutor(executor)) + if err != nil { + return nil, fmt.Errorf("failed to create schema chunker: %w", err) + } + + rw := common.NewSQLSingleStoreSchemaReaderWriterForTransactionIDs(chunker, common.NoTransactionID[uint64]) + return rw.ReadStoredSchema(ctx) +} + +// WriteStoredSchema writes the unified stored schema to the Postgres schema table. +func (rwt *pgReadWriteTXN) WriteStoredSchema(ctx context.Context, schema *core.StoredSchema) error { + executor := newPGTransactionAwareExecutor(rwt.tx) + + chunker, err := common.NewSQLByteChunker(BaseSchemaChunkerConfig.WithExecutor(executor)) + if err != nil { + return fmt.Errorf("failed to create schema chunker: %w", err) + } + + rw := common.NewSQLSingleStoreSchemaReaderWriterForTransactionIDs(chunker, func(_ context.Context) uint64 { + return rwt.newXID.Uint64 + }) + return rw.WriteStoredSchema(ctx, schema) +} diff --git a/internal/datastore/proxy/checkingreplicated.go b/internal/datastore/proxy/checkingreplicated.go index 0728b8d3f..39fe7c79c 100644 --- a/internal/datastore/proxy/checkingreplicated.go +++ b/internal/datastore/proxy/checkingreplicated.go @@ -212,6 +212,14 @@ func (rr *checkingStableReader) LookupCounters(ctx context.Context) ([]datastore return rr.chosenReader.LookupCounters(ctx) } +func (rr *checkingStableReader) ReadStoredSchema(ctx context.Context) (*datastore.ReadOnlyStoredSchema, error) { + if err := rr.determineSource(ctx); err != nil { + return nil, err + } + + return rr.chosenReader.ReadStoredSchema(ctx) +} + // determineSource will choose the replica or primary to read from based on the revision, by checking // if the replica contains the revision. If the replica does not contain the revision, the primary // will be used instead. 
diff --git a/internal/datastore/proxy/checkingreplicated_test.go b/internal/datastore/proxy/checkingreplicated_test.go index 3cd8259a5..72d301e53 100644 --- a/internal/datastore/proxy/checkingreplicated_test.go +++ b/internal/datastore/proxy/checkingreplicated_test.go @@ -241,6 +241,10 @@ func (fakeSnapshotReader) LookupCounters(ctx context.Context) ([]datastore.Relat return nil, fmt.Errorf("not implemented") } +func (fakeSnapshotReader) ReadStoredSchema(_ context.Context) (*datastore.ReadOnlyStoredSchema, error) { + return nil, nil +} + func fakeIterator(fsr fakeSnapshotReader, explainCallback options.SQLExplainCallbackForTest) datastore.RelationshipIterator { return func(yield func(tuple.Relationship, error) bool) { if fsr.state == "primary" { diff --git a/internal/datastore/proxy/indexcheck/fakedatastore_test.go b/internal/datastore/proxy/indexcheck/fakedatastore_test.go index 0c866b146..720e5fccc 100644 --- a/internal/datastore/proxy/indexcheck/fakedatastore_test.go +++ b/internal/datastore/proxy/indexcheck/fakedatastore_test.go @@ -155,6 +155,10 @@ func (fakeSnapshotReader) LookupCounters(ctx context.Context) ([]datastore.Relat return nil, fmt.Errorf("not implemented") } +func (fakeSnapshotReader) ReadStoredSchema(_ context.Context) (*datastore.ReadOnlyStoredSchema, error) { + return nil, nil +} + func fakeIterator(fsr fakeSnapshotReader, explainCallback options.SQLExplainCallbackForTest) datastore.RelationshipIterator { return func(yield func(tuple.Relationship, error) bool) { if explainCallback != nil { @@ -218,3 +222,7 @@ func (f *fakeRWT) DeleteRelationships(ctx context.Context, filter *v1.Relationsh func (f *fakeRWT) BulkLoad(ctx context.Context, iter datastore.BulkWriteRelationshipSource) (uint64, error) { return 0, nil } + +func (f *fakeRWT) WriteStoredSchema(ctx context.Context, schema *corev1.StoredSchema) error { + return nil +} diff --git a/internal/datastore/proxy/indexcheck/indexcheck.go b/internal/datastore/proxy/indexcheck/indexcheck.go index 
78761cf77..1b5826e36 100644 --- a/internal/datastore/proxy/indexcheck/indexcheck.go +++ b/internal/datastore/proxy/indexcheck/indexcheck.go @@ -135,6 +135,10 @@ func (r *indexcheckingReader) LegacyReadNamespaceByName(ctx context.Context, nsN return r.delegate.LegacyReadNamespaceByName(ctx, nsName) } +func (r *indexcheckingReader) ReadStoredSchema(ctx context.Context) (*datastore.ReadOnlyStoredSchema, error) { + return r.delegate.ReadStoredSchema(ctx) +} + func (r *indexcheckingReader) mustEnsureIndexes(ctx context.Context, sql string, args []any, shape queryshape.Shape, explain string, expectedIndexes options.SQLIndexInformation) error { // If no indexes are expected, there is nothing to check. if len(expectedIndexes.ExpectedIndexNames) == 0 { @@ -226,6 +230,10 @@ func (rwt *indexcheckingRWT) BulkLoad(ctx context.Context, iter datastore.BulkWr return rwt.delegate.BulkLoad(ctx, iter) } +func (rwt *indexcheckingRWT) WriteStoredSchema(ctx context.Context, schema *core.StoredSchema) error { + return rwt.delegate.WriteStoredSchema(ctx, schema) +} + var ( _ datastore.Datastore = (*indexcheckingProxy)(nil) _ datastore.Reader = (*indexcheckingReader)(nil) diff --git a/internal/datastore/proxy/observable.go b/internal/datastore/proxy/observable.go index aa87bfd60..70391c07f 100644 --- a/internal/datastore/proxy/observable.go +++ b/internal/datastore/proxy/observable.go @@ -277,6 +277,12 @@ func (r *observableReader) ReverseQueryRelationships(ctx context.Context, subjec }, nil } +func (r *observableReader) ReadStoredSchema(ctx context.Context) (*datastore.ReadOnlyStoredSchema, error) { + ctx, closer := observe(ctx, "ReadStoredSchema", "") + defer closer() + return r.delegate.ReadStoredSchema(ctx) +} + type observableRWT struct { *observableReader delegate datastore.ReadWriteTransaction @@ -382,6 +388,12 @@ func (rwt *observableRWT) BulkLoad(ctx context.Context, iter datastore.BulkWrite return rwt.delegate.BulkLoad(ctx, iter) } +func (rwt *observableRWT) WriteStoredSchema(ctx 
context.Context, schema *core.StoredSchema) error { + ctx, closer := observe(ctx, "WriteStoredSchema", "") + defer closer() + return rwt.delegate.WriteStoredSchema(ctx, schema) +} + // nolint:spancheck func observe(ctx context.Context, name string, queryShape string, opts ...trace.SpanStartOption) (context.Context, func()) { if queryShape == "" { diff --git a/internal/datastore/proxy/proxy_test/mock.go b/internal/datastore/proxy/proxy_test/mock.go index 14b3176ce..92525a86b 100644 --- a/internal/datastore/proxy/proxy_test/mock.go +++ b/internal/datastore/proxy/proxy_test/mock.go @@ -201,6 +201,15 @@ func (dm *MockReader) LegacyListAllCaveats(_ context.Context) ([]datastore.Revis return args.Get(0).([]datastore.RevisionedCaveat), args.Error(1) } +func (dm *MockReader) ReadStoredSchema(_ context.Context) (*datastore.ReadOnlyStoredSchema, error) { + args := dm.Called() + var schema *datastore.ReadOnlyStoredSchema + if args.Get(0) != nil { + schema = args.Get(0).(*datastore.ReadOnlyStoredSchema) + } + return schema, args.Error(1) +} + type MockReadWriteTransaction struct { mock.Mock } @@ -361,6 +370,20 @@ func (dm *MockReadWriteTransaction) StoreCounterValue(ctx context.Context, name return args.Error(0) } +func (dm *MockReadWriteTransaction) ReadStoredSchema(_ context.Context) (*datastore.ReadOnlyStoredSchema, error) { + args := dm.Called() + var schema *datastore.ReadOnlyStoredSchema + if args.Get(0) != nil { + schema = args.Get(0).(*datastore.ReadOnlyStoredSchema) + } + return schema, args.Error(1) +} + +func (dm *MockReadWriteTransaction) WriteStoredSchema(_ context.Context, schema *core.StoredSchema) error { + args := dm.Called(schema) + return args.Error(0) +} + var ( _ datastore.Datastore = &MockDatastore{} _ datastore.Reader = &MockReader{} diff --git a/internal/datastore/proxy/relationshipintegrity.go b/internal/datastore/proxy/relationshipintegrity.go index 6d1fe6d72..a9ae1f5f1 100644 --- a/internal/datastore/proxy/relationshipintegrity.go +++ 
b/internal/datastore/proxy/relationshipintegrity.go @@ -385,6 +385,10 @@ func (r relationshipIntegrityReader) LegacyReadNamespaceByName(ctx context.Conte return r.wrapped.LegacyReadNamespaceByName(ctx, nsName) } +func (r relationshipIntegrityReader) ReadStoredSchema(ctx context.Context) (*datastore.ReadOnlyStoredSchema, error) { + return r.wrapped.ReadStoredSchema(ctx) +} + type relationshipIntegrityTx struct { datastore.ReadWriteTransaction diff --git a/internal/datastore/proxy/schemacaching/watchingcache_test.go b/internal/datastore/proxy/schemacaching/watchingcache_test.go index c9be53699..f5240b4c2 100644 --- a/internal/datastore/proxy/schemacaching/watchingcache_test.go +++ b/internal/datastore/proxy/schemacaching/watchingcache_test.go @@ -1019,3 +1019,7 @@ func (*fakeSnapshotReader) QueryRelationships(context.Context, datastore.Relatio func (*fakeSnapshotReader) ReverseQueryRelationships(context.Context, datastore.SubjectsFilter, ...options.ReverseQueryOptionsOption) (datastore.RelationshipIterator, error) { return nil, fmt.Errorf("not implemented") } + +func (*fakeSnapshotReader) ReadStoredSchema(_ context.Context) (*datastore.ReadOnlyStoredSchema, error) { + return nil, nil +} diff --git a/internal/datastore/proxy/strictreplicated.go b/internal/datastore/proxy/strictreplicated.go index 6ec93007a..5d664a08c 100644 --- a/internal/datastore/proxy/strictreplicated.go +++ b/internal/datastore/proxy/strictreplicated.go @@ -260,3 +260,13 @@ func (rr *strictReadReplicatedReader) LookupCounters(ctx context.Context) ([]dat } return counters, err } + +func (rr *strictReadReplicatedReader) ReadStoredSchema(ctx context.Context) (*datastore.ReadOnlyStoredSchema, error) { + sr := rr.replica.SnapshotReader(rr.rev) + schema, err := sr.ReadStoredSchema(ctx) + if err != nil && errors.As(err, &common.RevisionUnavailableError{}) { + log.Trace().Str("revision", rr.rev.String()).Msg("replica does not contain the requested revision, using primary") + return 
rr.primary.SnapshotReader(rr.rev).ReadStoredSchema(ctx) + } + return schema, err +} diff --git a/internal/datastore/spanner/migrations/zz_migration.0012_add_schema_tables.go b/internal/datastore/spanner/migrations/zz_migration.0012_add_schema_tables.go new file mode 100644 index 000000000..2f44accf1 --- /dev/null +++ b/internal/datastore/spanner/migrations/zz_migration.0012_add_schema_tables.go @@ -0,0 +1,40 @@ +package migrations + +import ( + "context" + + "cloud.google.com/go/spanner/admin/database/apiv1/databasepb" +) + +const ( + createSchemaTable = `CREATE TABLE schema ( + name STRING(1024) NOT NULL, + chunk_index INT64 NOT NULL, + chunk_data BYTES(MAX) NOT NULL, + timestamp TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp=true) + ) PRIMARY KEY (name, chunk_index)` + + createSchemaRevisionTable = `CREATE TABLE schema_revision ( + name STRING(1024) NOT NULL, + schema_hash BYTES(MAX) NOT NULL, + timestamp TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp=true) + ) PRIMARY KEY (name)` +) + +func init() { + if err := SpannerMigrations.Register("add-schema-tables", "add-expiration-support", func(ctx context.Context, w Wrapper) error { + updateOp, err := w.adminClient.UpdateDatabaseDdl(ctx, &databasepb.UpdateDatabaseDdlRequest{ + Database: w.client.DatabaseName(), + Statements: []string{ + createSchemaTable, + createSchemaRevisionTable, + }, + }) + if err != nil { + return err + } + return updateOp.Wait(ctx) + }, nil); err != nil { + panic("failed to register migration: " + err.Error()) + } +} diff --git a/internal/datastore/spanner/migrations/zz_migration.0013_populate_schema_tables.go b/internal/datastore/spanner/migrations/zz_migration.0013_populate_schema_tables.go new file mode 100644 index 000000000..c868e67f2 --- /dev/null +++ b/internal/datastore/spanner/migrations/zz_migration.0013_populate_schema_tables.go @@ -0,0 +1,169 @@ +package migrations + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "sort" + + "cloud.google.com/go/spanner" 
+ + core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/schemadsl/compiler" + "github.com/authzed/spicedb/pkg/schemadsl/generator" +) + +const ( + schemaChunkSize = 1024 * 1024 // 1MB chunks + currentSchemaVersion = 1 + unifiedSchemaName = "unified_schema" + schemaRevisionName = "current" +) + +func init() { + if err := SpannerMigrations.Register("populate-schema-tables", "add-schema-tables", + nil, // No DDL changes needed + func(ctx context.Context, rwt *spanner.ReadWriteTransaction) error { + // Read all existing namespaces + stmt := spanner.Statement{ + SQL: `SELECT namespace, serialized_config + FROM namespace_config + WHERE timestamp = (SELECT MAX(timestamp) FROM namespace_config nc WHERE nc.namespace = namespace_config.namespace)`, + } + + iter := rwt.Query(ctx, stmt) + defer iter.Stop() + + namespaces := make(map[string]*core.NamespaceDefinition) + err := iter.Do(func(row *spanner.Row) error { + var name string + var config []byte + if err := row.Columns(&name, &config); err != nil { + return fmt.Errorf("failed to scan namespace: %w", err) + } + + var ns core.NamespaceDefinition + if err := ns.UnmarshalVT(config); err != nil { + return fmt.Errorf("failed to unmarshal namespace %s: %w", name, err) + } + namespaces[name] = &ns + return nil + }) + if err != nil { + return fmt.Errorf("failed to query namespaces: %w", err) + } + + // Read all existing caveats + stmt = spanner.Statement{ + SQL: `SELECT name, definition + FROM caveat + WHERE timestamp = (SELECT MAX(timestamp) FROM caveat c WHERE c.name = caveat.name)`, + } + + iter = rwt.Query(ctx, stmt) + defer iter.Stop() + + caveats := make(map[string]*core.CaveatDefinition) + err = iter.Do(func(row *spanner.Row) error { + var name string + var definition []byte + if err := row.Columns(&name, &definition); err != nil { + return fmt.Errorf("failed to scan caveat: %w", err) + } + + var caveat core.CaveatDefinition + if err := caveat.UnmarshalVT(definition); err != nil { + return 
fmt.Errorf("failed to unmarshal caveat %s: %w", name, err) + } + caveats[name] = &caveat + return nil + }) + if err != nil { + return fmt.Errorf("failed to query caveats: %w", err) + } + + // If there are no namespaces or caveats, skip migration + if len(namespaces) == 0 && len(caveats) == 0 { + return nil + } + + // Generate canonical schema for hash computation + allDefs := make([]compiler.SchemaDefinition, 0, len(namespaces)+len(caveats)) + for _, ns := range namespaces { + allDefs = append(allDefs, ns) + } + for _, caveat := range caveats { + allDefs = append(allDefs, caveat) + } + + // Sort alphabetically for canonical ordering + sort.Slice(allDefs, func(i, j int) bool { + return allDefs[i].GetName() < allDefs[j].GetName() + }) + + // Generate canonical schema text + canonicalSchemaText, _, err := generator.GenerateSchema(allDefs) + if err != nil { + return fmt.Errorf("failed to generate canonical schema: %w", err) + } + + // Compute SHA256 hash + hashBytes := sha256.Sum256([]byte(canonicalSchemaText)) + schemaHash := hashBytes[:] + + // Generate user-facing schema text + schemaText, _, err := generator.GenerateSchema(allDefs) + if err != nil { + return fmt.Errorf("failed to generate schema: %w", err) + } + + // Create stored schema proto + storedSchema := &core.StoredSchema{ + Version: currentSchemaVersion, + VersionOneof: &core.StoredSchema_V1{ + V1: &core.StoredSchema_V1StoredSchema{ + SchemaText: schemaText, + SchemaHash: hex.EncodeToString(schemaHash), + NamespaceDefinitions: namespaces, + CaveatDefinitions: caveats, + }, + }, + } + + // Marshal schema + schemaData, err := storedSchema.MarshalVT() + if err != nil { + return fmt.Errorf("failed to marshal schema: %w", err) + } + + // Insert schema chunks using mutations + var mutations []*spanner.Mutation + for chunkIndex := 0; chunkIndex*schemaChunkSize < len(schemaData); chunkIndex++ { + start := chunkIndex * schemaChunkSize + end := start + schemaChunkSize + if end > len(schemaData) { + end = 
len(schemaData) + } + chunk := schemaData[start:end] + + mutations = append(mutations, spanner.Insert( + "schema", + []string{"name", "chunk_index", "chunk_data", "timestamp"}, + []any{unifiedSchemaName, chunkIndex, chunk, spanner.CommitTimestamp}, + )) + } + + // Insert schema hash + mutations = append(mutations, spanner.Insert( + "schema_revision", + []string{"name", "schema_hash", "timestamp"}, + []any{schemaRevisionName, schemaHash, spanner.CommitTimestamp}, + )) + + // Apply mutations + return rwt.BufferWrite(mutations) + }); err != nil { + panic("failed to register migration: " + err.Error()) + } +} diff --git a/internal/datastore/spanner/schema.go b/internal/datastore/spanner/schema.go index 8b5dabc39..12247695e 100644 --- a/internal/datastore/spanner/schema.go +++ b/internal/datastore/spanner/schema.go @@ -35,6 +35,11 @@ const ( tableTransactionMetadata = "transaction_metadata" colTransactionTag = "transaction_tag" colMetadata = "metadata" + + tableSchema = "schema" + colSchemaName = "name" + colSchemaChunkIndex = "chunk_index" + colSchemaChunkData = "chunk_data" ) var allRelationshipCols = []string{ diff --git a/internal/datastore/spanner/schema_chunker.go b/internal/datastore/spanner/schema_chunker.go new file mode 100644 index 000000000..20106f0ba --- /dev/null +++ b/internal/datastore/spanner/schema_chunker.go @@ -0,0 +1,168 @@ +package spanner + +import ( + "context" + "fmt" + + "cloud.google.com/go/spanner" + sq "github.com/Masterminds/squirrel" + + "github.com/authzed/spicedb/internal/datastore/common" + "github.com/authzed/spicedb/pkg/spiceerrors" +) + +const ( + // Spanner has a practical limit of 10 MiB on BYTES column size, + // but we use 1MB chunks for reasonable memory usage and query performance. + spannerMaxChunkSize = 1024 * 1024 // 1MB +) + +// BaseSchemaChunkerConfig provides the base configuration for Spanner schema chunking. +// Spanner uses @p-style placeholders and delete-and-insert write mode. 
+var BaseSchemaChunkerConfig = common.SQLByteChunkerConfig[any]{ + TableName: tableSchema, + NameColumn: colSchemaName, + ChunkIndexColumn: colSchemaChunkIndex, + ChunkDataColumn: colSchemaChunkData, + MaxChunkSize: spannerMaxChunkSize, + PlaceholderFormat: sq.AtP, + WriteMode: common.WriteModeDeleteAndInsert, +} + +// spannerChunkedBytesExecutor implements common.ChunkedBytesExecutor for Spanner's mutation-based API. +type spannerChunkedBytesExecutor struct { + rwt *spanner.ReadWriteTransaction +} + +// newSpannerChunkedBytesExecutor creates a new executor for Spanner chunk operations. +func newSpannerChunkedBytesExecutor(rwt *spanner.ReadWriteTransaction) *spannerChunkedBytesExecutor { + return &spannerChunkedBytesExecutor{rwt: rwt} +} + +// BeginTransaction returns a transaction wrapper for Spanner operations. +func (e *spannerChunkedBytesExecutor) BeginTransaction(ctx context.Context) (common.ChunkedBytesTransaction, error) { + return &spannerChunkedBytesTransaction{rwt: e.rwt}, nil +} + +// ExecuteRead executes a SELECT query and returns chunk data. +func (e *spannerChunkedBytesExecutor) ExecuteRead(ctx context.Context, builder sq.SelectBuilder) (map[int][]byte, error) { + sql, args, err := builder.ToSql() + if err != nil { + return nil, fmt.Errorf("failed to build query: %w", err) + } + + iter := e.rwt.Query(ctx, statementFromSQL(sql, args)) + defer iter.Stop() + + chunks := make(map[int][]byte) + if err := iter.Do(func(row *spanner.Row) error { + var chunkIndex int64 + var chunkData []byte + if err := row.Columns(&chunkIndex, &chunkData); err != nil { + return err + } + chunks[int(chunkIndex)] = chunkData + return nil + }); err != nil { + return nil, err + } + + return chunks, nil +} + +// spannerChunkedBytesTransaction implements common.ChunkedBytesTransaction for Spanner. +type spannerChunkedBytesTransaction struct { + rwt *spanner.ReadWriteTransaction +} + +// ExecuteWrite converts an INSERT builder to Spanner mutations. 
+func (t *spannerChunkedBytesTransaction) ExecuteWrite(ctx context.Context, builder sq.InsertBuilder) error {
+	_, args, err := builder.ToSql()
+	if err != nil {
+		return fmt.Errorf("failed to build insert: %w", err)
+	}
+
+	// Convert the INSERT statement args to Spanner mutations.
+	// This assumes the specific format from the chunker (validated in tests).
+	mutations, err := t.convertInsertToMutations(args)
+	if err != nil {
+		return err
+	}
+
+	return t.rwt.BufferWrite(mutations)
+}
+
+// ExecuteDelete deletes all chunk rows from the schema table via a single Spanner mutation.
+func (t *spannerChunkedBytesTransaction) ExecuteDelete(ctx context.Context, builder sq.DeleteBuilder) error {
+	// NOTE(review): the chunker's DELETE builder filters by schema name, but AllKeys()
+	// removes every row in the table — safe only while a single schema row is stored; confirm.
+	mutation := spanner.Delete(tableSchema, spanner.AllKeys())
+	return t.rwt.BufferWrite([]*spanner.Mutation{mutation})
+}
+
+// ExecuteUpdate is not supported for Spanner chunked bytes and always returns a bug error.
+func (t *spannerChunkedBytesTransaction) ExecuteUpdate(ctx context.Context, builder sq.UpdateBuilder) error {
+	return spiceerrors.MustBugf("ExecuteUpdate not implemented for Spanner chunked bytes")
+}
+
+// convertInsertToMutations converts INSERT args to Spanner mutations.
+// The chunker generates args in groups of 3: [name, chunk_index, chunk_data] per row.
+// Multi-MB schemas produce multiple chunks, so we iterate over args in groups.
+func (t *spannerChunkedBytesTransaction) convertInsertToMutations(args []any) ([]*spanner.Mutation, error) {
+	const argsPerRow = 3
+	if len(args) == 0 || len(args)%argsPerRow != 0 {
+		return nil, fmt.Errorf("expected args in groups of %d from chunker, got %d", argsPerRow, len(args))
+	}
+
+	cols := []string{colSchemaName, colSchemaChunkIndex, colSchemaChunkData, colTimestamp}
+	mutations := make([]*spanner.Mutation, 0, len(args)/argsPerRow)
+	for i := 0; i < len(args); i += argsPerRow {
+		vals := []any{args[i], args[i+1], args[i+2], spanner.CommitTimestamp}
+		mutations = append(mutations, spanner.Insert(tableSchema, cols, vals))
+	}
+
+	return mutations, nil
+}
+
+// spannerSchemaReadExecutor implements common.ChunkedBytesExecutor for read-only operations.
+type spannerSchemaReadExecutor struct {
+	txSource txFactory
+}
+
+// BeginTransaction is not supported for the read-only executor and always returns an error.
+func (e *spannerSchemaReadExecutor) BeginTransaction(ctx context.Context) (common.ChunkedBytesTransaction, error) {
+	return nil, spiceerrors.MustBugf("BeginTransaction not supported for read-only executor")
+}
+
+// ExecuteRead executes a SELECT query and returns chunk data.
+func (e *spannerSchemaReadExecutor) ExecuteRead(ctx context.Context, builder sq.SelectBuilder) (map[int][]byte, error) { + sql, args, err := builder.ToSql() + if err != nil { + return nil, fmt.Errorf("failed to build query: %w", err) + } + + tx := e.txSource() + iter := tx.Query(ctx, statementFromSQL(sql, args)) + defer iter.Stop() + + chunks := make(map[int][]byte) + if err := iter.Do(func(row *spanner.Row) error { + var chunkIndex int64 + var chunkData []byte + if err := row.Columns(&chunkIndex, &chunkData); err != nil { + return err + } + chunks[int(chunkIndex)] = chunkData + return nil + }); err != nil { + return nil, err + } + + return chunks, nil +} + +var ( + _ common.ChunkedBytesExecutor = (*spannerChunkedBytesExecutor)(nil) + _ common.ChunkedBytesTransaction = (*spannerChunkedBytesTransaction)(nil) + _ common.ChunkedBytesExecutor = (*spannerSchemaReadExecutor)(nil) +) diff --git a/internal/datastore/spanner/schema_chunker_test.go b/internal/datastore/spanner/schema_chunker_test.go new file mode 100644 index 000000000..c4af47b2f --- /dev/null +++ b/internal/datastore/spanner/schema_chunker_test.go @@ -0,0 +1,201 @@ +package spanner + +import ( + "context" + "testing" + + "cloud.google.com/go/spanner" + sq "github.com/Masterminds/squirrel" + "github.com/stretchr/testify/require" + + "github.com/authzed/spicedb/internal/datastore/common" +) + +// fakeTransaction captures SQL and args via builder.ToSql() for verification. 
+type fakeTransaction struct { + capturedSQL []string + capturedArgs [][]any + deleteQueries []string +} + +func (f *fakeTransaction) ExecuteWrite(_ context.Context, builder sq.InsertBuilder) error { + sql, args, err := builder.ToSql() + if err != nil { + return err + } + f.capturedSQL = append(f.capturedSQL, sql) + f.capturedArgs = append(f.capturedArgs, args) + return nil +} + +func (f *fakeTransaction) ExecuteDelete(_ context.Context, builder sq.DeleteBuilder) error { + sql, args, err := builder.ToSql() + if err != nil { + return err + } + f.capturedSQL = append(f.capturedSQL, sql) + f.capturedArgs = append(f.capturedArgs, args) + f.deleteQueries = append(f.deleteQueries, sql) + return nil +} + +func (f *fakeTransaction) ExecuteUpdate(_ context.Context, builder sq.UpdateBuilder) error { + sql, args, err := builder.ToSql() + if err != nil { + return err + } + f.capturedSQL = append(f.capturedSQL, sql) + f.capturedArgs = append(f.capturedArgs, args) + return nil +} + +// fakeExecutor returns the fakeTransaction from BeginTransaction. 
+type fakeExecutor struct { + transaction *fakeTransaction + readResult map[int][]byte +} + +func (e *fakeExecutor) BeginTransaction(_ context.Context) (common.ChunkedBytesTransaction, error) { + return e.transaction, nil +} + +func (e *fakeExecutor) ExecuteRead(_ context.Context, _ sq.SelectBuilder) (map[int][]byte, error) { + return e.readResult, nil +} + +func TestBaseSchemaChunkerConfig(t *testing.T) { + require.Equal(t, tableSchema, BaseSchemaChunkerConfig.TableName) + require.Equal(t, colSchemaName, BaseSchemaChunkerConfig.NameColumn) + require.Equal(t, colSchemaChunkIndex, BaseSchemaChunkerConfig.ChunkIndexColumn) + require.Equal(t, colSchemaChunkData, BaseSchemaChunkerConfig.ChunkDataColumn) + require.Equal(t, 1024*1024, BaseSchemaChunkerConfig.MaxChunkSize) + require.Equal(t, sq.AtP, BaseSchemaChunkerConfig.PlaceholderFormat) + require.Equal(t, common.WriteModeDeleteAndInsert, BaseSchemaChunkerConfig.WriteMode) +} + +func TestWrite(t *testing.T) { + txn := &fakeTransaction{} + executor := &fakeExecutor{transaction: txn} + + config := BaseSchemaChunkerConfig.WithExecutor(executor) + chunker := common.MustNewSQLByteChunker(config) + + err := chunker.WriteChunkedBytes(context.Background(), "test-key", []byte("hello"), nil) + require.NoError(t, err) + + // Should have DELETE + INSERT + require.Len(t, txn.capturedSQL, 2) + require.Len(t, txn.deleteQueries, 1) + + // DELETE uses @p placeholders + require.Contains(t, txn.capturedSQL[0], "DELETE FROM "+tableSchema) + require.Contains(t, txn.capturedSQL[0], "@p1") + + // INSERT uses @p placeholders + require.Contains(t, txn.capturedSQL[1], "INSERT INTO "+tableSchema) + require.Contains(t, txn.capturedSQL[1], "@p1") +} + +func TestDelete(t *testing.T) { + txn := &fakeTransaction{} + executor := &fakeExecutor{transaction: txn} + + config := BaseSchemaChunkerConfig.WithExecutor(executor) + chunker := common.MustNewSQLByteChunker(config) + + err := chunker.DeleteChunkedBytes(context.Background(), "test-key", nil) + 
require.NoError(t, err) + + require.Len(t, txn.capturedSQL, 1) + require.Contains(t, txn.capturedSQL[0], "DELETE FROM "+tableSchema) +} + +func TestRead(t *testing.T) { + executor := &fakeExecutor{ + readResult: map[int][]byte{ + 0: []byte("hello"), + }, + } + + config := BaseSchemaChunkerConfig.WithExecutor(executor) + chunker := common.MustNewSQLByteChunker(config) + + data, err := chunker.ReadChunkedBytes(context.Background(), "test-key") + require.NoError(t, err) + require.Equal(t, []byte("hello"), data) +} + +func TestMultipleChunks(t *testing.T) { + txn := &fakeTransaction{} + executor := &fakeExecutor{transaction: txn} + + config := BaseSchemaChunkerConfig.WithExecutor(executor) + config.MaxChunkSize = 5 + chunker := common.MustNewSQLByteChunker(config) + + err := chunker.WriteChunkedBytes(context.Background(), "test-key", []byte("hello world!"), nil) + require.NoError(t, err) + + // DELETE + INSERT + require.Len(t, txn.capturedSQL, 2) + + // "hello world!" is 12 bytes, chunk size 5 => 3 chunks (5+5+2) + // Each chunk has 3 args (name, chunk_index, chunk_data) for delete-and-insert mode + insertArgs := txn.capturedArgs[1] + require.Len(t, insertArgs, 9) // 3 chunks * 3 values +} + +func TestConvertInsertToMutations_SingleRow(t *testing.T) { + txn := &spannerChunkedBytesTransaction{} + + args := []any{"schema_name", int64(0), []byte("chunk_data_0")} + mutations, err := txn.convertInsertToMutations(args) + require.NoError(t, err) + require.Len(t, mutations, 1) +} + +func TestConvertInsertToMutations_MultipleRows(t *testing.T) { + txn := &spannerChunkedBytesTransaction{} + + args := []any{ + "schema_name", int64(0), []byte("chunk_0"), + "schema_name", int64(1), []byte("chunk_1"), + "schema_name", int64(2), []byte("chunk_2"), + } + mutations, err := txn.convertInsertToMutations(args) + require.NoError(t, err) + require.Len(t, mutations, 3) +} + +func TestConvertInsertToMutations_EmptyArgs(t *testing.T) { + txn := &spannerChunkedBytesTransaction{} + + _, err := 
txn.convertInsertToMutations([]any{}) + require.Error(t, err) + require.Contains(t, err.Error(), "expected args in groups of 3") +} + +func TestConvertInsertToMutations_InvalidArgCount(t *testing.T) { + txn := &spannerChunkedBytesTransaction{} + + _, err := txn.convertInsertToMutations([]any{"a", "b"}) + require.Error(t, err) + require.Contains(t, err.Error(), "expected args in groups of 3") +} + +// Verify the returned mutations are spanner.Mutation (non-nil). +func TestConvertInsertToMutations_MutationsAreValid(t *testing.T) { + txn := &spannerChunkedBytesTransaction{} + + args := []any{ + "name1", int64(0), []byte("data0"), + "name1", int64(1), []byte("data1"), + } + mutations, err := txn.convertInsertToMutations(args) + require.NoError(t, err) + require.Len(t, mutations, 2) + + for _, m := range mutations { + require.IsType(t, &spanner.Mutation{}, m) + } +} diff --git a/internal/datastore/spanner/storedschema.go b/internal/datastore/spanner/storedschema.go new file mode 100644 index 000000000..82ca73c65 --- /dev/null +++ b/internal/datastore/spanner/storedschema.go @@ -0,0 +1,38 @@ +package spanner + +import ( + "context" + "fmt" + + "github.com/authzed/spicedb/internal/datastore/common" + "github.com/authzed/spicedb/pkg/datastore" + core "github.com/authzed/spicedb/pkg/proto/core/v1" +) + +// ReadStoredSchema reads the unified stored schema from the Spanner schema table. +func (sr spannerReader) ReadStoredSchema(ctx context.Context) (*datastore.ReadOnlyStoredSchema, error) { + executor := &spannerSchemaReadExecutor{ + txSource: sr.txSource, + } + + chunker, err := common.NewSQLByteChunker(BaseSchemaChunkerConfig.WithExecutor(executor)) + if err != nil { + return nil, fmt.Errorf("failed to create schema chunker: %w", err) + } + + rw := common.NewSQLSingleStoreSchemaReaderWriterWithBuiltInMVCC(chunker) + return rw.ReadStoredSchema(ctx) +} + +// WriteStoredSchema writes the unified stored schema to the Spanner schema table. 
+func (rwt spannerReadWriteTXN) WriteStoredSchema(ctx context.Context, schema *core.StoredSchema) error { + executor := newSpannerChunkedBytesExecutor(rwt.spannerRWT) + + chunker, err := common.NewSQLByteChunker(BaseSchemaChunkerConfig.WithExecutor(executor)) + if err != nil { + return fmt.Errorf("failed to create schema chunker: %w", err) + } + + rw := common.NewSQLSingleStoreSchemaReaderWriterWithBuiltInMVCC(chunker) + return rw.WriteStoredSchema(ctx, schema) +} diff --git a/internal/dispatch/caching/cachingdispatch_test.go b/internal/dispatch/caching/cachingdispatch_test.go index 1cc55db7a..60edcc610 100644 --- a/internal/dispatch/caching/cachingdispatch_test.go +++ b/internal/dispatch/caching/cachingdispatch_test.go @@ -12,6 +12,7 @@ import ( "github.com/stretchr/testify/require" "github.com/authzed/spicedb/internal/dispatch" + "github.com/authzed/spicedb/pkg/datalayer" core "github.com/authzed/spicedb/pkg/proto/core/v1" v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" "github.com/authzed/spicedb/pkg/tuple" @@ -105,6 +106,7 @@ func TestMaxDepthCaching(t *testing.T) { Metadata: &v1.ResolverMeta{ AtRevision: step.atRevision.String(), DepthRemaining: step.depthRemaining, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, }).Return(&v1.DispatchCheckResponse{ ResultsByResourceId: map[string]*v1.ResourceCheckResult{ @@ -136,6 +138,7 @@ func TestMaxDepthCaching(t *testing.T) { Metadata: &v1.ResolverMeta{ AtRevision: step.atRevision.String(), DepthRemaining: step.depthRemaining, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, }) require.NoError(err) @@ -196,6 +199,7 @@ func TestConcurrentDebugInfoAccess(t *testing.T) { Metadata: &v1.ResolverMeta{ AtRevision: decimal.Zero.String(), DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, Debug: v1.DispatchCheckRequest_ENABLE_BASIC_DEBUGGING, } diff --git a/internal/dispatch/combined/combined_test.go b/internal/dispatch/combined/combined_test.go index 2629c5581..32008e0a1 
100644 --- a/internal/dispatch/combined/combined_test.go +++ b/internal/dispatch/combined/combined_test.go @@ -53,6 +53,7 @@ func TestCombinedRecursiveCall(t *testing.T) { Metadata: &dispatchv1.ResolverMeta{ AtRevision: revision.String(), DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, }) require.Error(t, err) diff --git a/internal/dispatch/graph/check_test.go b/internal/dispatch/graph/check_test.go index 53de00587..e01712154 100644 --- a/internal/dispatch/graph/check_test.go +++ b/internal/dispatch/graph/check_test.go @@ -135,6 +135,7 @@ func TestSimpleCheck(t *testing.T) { Metadata: &v1.ResolverMeta{ AtRevision: revision.String(), DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, }) @@ -184,6 +185,7 @@ func TestMaxDepth(t *testing.T) { Metadata: &v1.ResolverMeta{ AtRevision: revision.String(), DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, }) @@ -279,6 +281,7 @@ func TestCheckMetadata(t *testing.T) { Metadata: &v1.ResolverMeta{ AtRevision: revision.String(), DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, }) @@ -1442,6 +1445,7 @@ func TestCheckPermissionOverSchema(t *testing.T) { Metadata: &v1.ResolverMeta{ AtRevision: revision.String(), DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, ResultsSetting: v1.DispatchCheckRequest_ALLOW_SINGLE_RESULT, }) @@ -1570,6 +1574,7 @@ func TestCheckDebugging(t *testing.T) { Metadata: &v1.ResolverMeta{ AtRevision: revision.String(), DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, Debug: v1.DispatchCheckRequest_ENABLE_BASIC_DEBUGGING, }) @@ -1951,6 +1956,7 @@ func TestCheckWithHints(t *testing.T) { Metadata: &v1.ResolverMeta{ AtRevision: revision.String(), DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, CheckHints: tc.hints, }) @@ -2003,6 +2009,7 @@ func TestCheckHintsPartialApplication(t *testing.T) { Metadata: 
&v1.ResolverMeta{ AtRevision: revision.String(), DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, CheckHints: []*v1.CheckHint{ hints.CheckHintForComputedUserset("document", "anotherdoc", "viewer", ONR("user", "tom", graph.Ellipsis), &v1.ResourceCheckResult{ @@ -2058,6 +2065,7 @@ func TestCheckHintsPartialApplicationOverArrow(t *testing.T) { Metadata: &v1.ResolverMeta{ AtRevision: revision.String(), DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, CheckHints: []*v1.CheckHint{ hints.CheckHintForArrow("document", "anotherdoc", "org", "member", ONR("user", "tom", graph.Ellipsis), &v1.ResourceCheckResult{ diff --git a/internal/dispatch/graph/dispatch_test.go b/internal/dispatch/graph/dispatch_test.go index 26394f2ee..e6a92ef12 100644 --- a/internal/dispatch/graph/dispatch_test.go +++ b/internal/dispatch/graph/dispatch_test.go @@ -9,6 +9,7 @@ import ( "github.com/authzed/spicedb/internal/dispatch" "github.com/authzed/spicedb/internal/graph" + "github.com/authzed/spicedb/pkg/datalayer" v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" "github.com/authzed/spicedb/pkg/tuple" ) @@ -47,6 +48,7 @@ func TestDispatchChunking(t *testing.T) { Metadata: &v1.ResolverMeta{ AtRevision: revision.String(), DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, }) @@ -70,6 +72,7 @@ func TestDispatchChunking(t *testing.T) { Metadata: &v1.ResolverMeta{ AtRevision: revision.String(), DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, OptionalLimit: veryLargeLimit, }, stream) @@ -94,6 +97,7 @@ func TestDispatchChunking(t *testing.T) { Metadata: &v1.ResolverMeta{ AtRevision: revision.String(), DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, }, stream) diff --git a/internal/dispatch/graph/expand_test.go b/internal/dispatch/graph/expand_test.go index 254bc7a19..1e8e6a518 100644 --- a/internal/dispatch/graph/expand_test.go +++ 
b/internal/dispatch/graph/expand_test.go @@ -176,6 +176,7 @@ func TestExpand(t *testing.T) { Metadata: &v1.ResolverMeta{ AtRevision: revision.String(), DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, ExpansionMode: tc.expansionMode, }) @@ -304,6 +305,7 @@ func TestMaxDepthExpand(t *testing.T) { Metadata: &v1.ResolverMeta{ AtRevision: revision.String(), DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, ExpansionMode: v1.DispatchExpandRequest_SHALLOW, }) @@ -914,6 +916,7 @@ func TestExpandOverSchema(t *testing.T) { Metadata: &v1.ResolverMeta{ AtRevision: revision.String(), DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, ExpansionMode: tc.expansionMode, }) diff --git a/internal/dispatch/graph/graph.go b/internal/dispatch/graph/graph.go index 58188b52d..201b95966 100644 --- a/internal/dispatch/graph/graph.go +++ b/internal/dispatch/graph/graph.go @@ -203,8 +203,8 @@ type localDispatcher struct { lookupResourcesHandler3 *graph.CursoredLookupResources3 } -func (ld *localDispatcher) loadNamespace(ctx context.Context, nsName string, revision datastore.Revision) (*core.NamespaceDefinition, error) { - reader := datalayer.MustFromContext(ctx).SnapshotReader(revision) +func (ld *localDispatcher) loadNamespace(ctx context.Context, nsName string, revision datastore.Revision, schemaHash datalayer.SchemaHash) (*core.NamespaceDefinition, error) { + reader := datalayer.MustFromContext(ctx).SnapshotReader(revision, schemaHash) // Load namespace and relation from the datastore schemaReader, err := reader.ReadSchema(ctx) @@ -286,7 +286,7 @@ func (ld *localDispatcher) DispatchCheck(ctx context.Context, req *v1.DispatchCh return &v1.DispatchCheckResponse{Metadata: emptyMetadata}, rewriteError(ctx, err) } - ns, err := ld.loadNamespace(ctx, req.ResourceRelation.Namespace, revision) + ns, err := ld.loadNamespace(ctx, req.ResourceRelation.Namespace, revision, 
datalayer.SchemaHash(req.Metadata.GetSchemaHash())) if err != nil { return &v1.DispatchCheckResponse{Metadata: emptyMetadata}, rewriteError(ctx, err) } @@ -352,7 +352,7 @@ func (ld *localDispatcher) DispatchExpand(ctx context.Context, req *v1.DispatchE return &v1.DispatchExpandResponse{Metadata: emptyMetadata}, err } - ns, err := ld.loadNamespace(ctx, req.ResourceAndRelation.Namespace, revision) + ns, err := ld.loadNamespace(ctx, req.ResourceAndRelation.Namespace, revision, datalayer.SchemaHash(req.Metadata.GetSchemaHash())) if err != nil { return &v1.DispatchExpandResponse{Metadata: emptyMetadata}, err } diff --git a/internal/dispatch/graph/lookupresources2_test.go b/internal/dispatch/graph/lookupresources2_test.go index 63346d028..3f04f75f5 100644 --- a/internal/dispatch/graph/lookupresources2_test.go +++ b/internal/dispatch/graph/lookupresources2_test.go @@ -115,6 +115,7 @@ func TestSimpleLookupResources2(t *testing.T) { Metadata: &v1.ResolverMeta{ AtRevision: revision.String(), DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, OptionalLimit: veryLargeLimit, }, stream) @@ -141,6 +142,7 @@ func TestSimpleLookupResources2(t *testing.T) { Metadata: &v1.ResolverMeta{ AtRevision: revision.String(), DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, OptionalLimit: veryLargeLimit, }, stream) @@ -199,6 +201,7 @@ func TestSimpleLookupResourcesWithCursor2(t *testing.T) { Metadata: &v1.ResolverMeta{ AtRevision: revision.String(), DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, OptionalLimit: 1, }, stream) @@ -222,6 +225,7 @@ func TestSimpleLookupResourcesWithCursor2(t *testing.T) { Metadata: &v1.ResolverMeta{ AtRevision: revision.String(), DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, OptionalCursor: cursor, OptionalLimit: 2, @@ -259,6 +263,7 @@ func TestLookupResourcesCursorStability2(t *testing.T) { Metadata: &v1.ResolverMeta{ AtRevision: 
revision.String(), DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, OptionalLimit: 2, }, stream) @@ -279,6 +284,7 @@ func TestLookupResourcesCursorStability2(t *testing.T) { Metadata: &v1.ResolverMeta{ AtRevision: revision.String(), DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, OptionalLimit: 2, }, stream) @@ -334,6 +340,7 @@ func TestMaxDepthLookup2(t *testing.T) { Metadata: &v1.ResolverMeta{ AtRevision: revision.String(), DepthRemaining: 0, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, }, stream) @@ -791,6 +798,7 @@ func TestLookupResources2OverSchemaWithCursors(t *testing.T) { Metadata: &v1.ResolverMeta{ AtRevision: revision.String(), DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, OptionalLimit: uintPageSize, OptionalCursor: currentCursor, @@ -860,6 +868,7 @@ func TestLookupResources2ImmediateTimeout(t *testing.T) { Metadata: &v1.ResolverMeta{ AtRevision: revision.String(), DepthRemaining: 10, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, }, stream) @@ -896,6 +905,7 @@ func TestLookupResources2WithError(t *testing.T) { Metadata: &v1.ResolverMeta{ AtRevision: revision.String(), DepthRemaining: 10, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, }, stream) @@ -1376,6 +1386,7 @@ func TestLookupResources2EnsureCheckHints(t *testing.T) { Metadata: &v1.ResolverMeta{ AtRevision: revision.String(), DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, }, stream) if tc.expectedError != "" { diff --git a/internal/dispatch/graph/lookupresources3_test.go b/internal/dispatch/graph/lookupresources3_test.go index 8b21c006c..4779fca72 100644 --- a/internal/dispatch/graph/lookupresources3_test.go +++ b/internal/dispatch/graph/lookupresources3_test.go @@ -114,6 +114,7 @@ func TestSimpleLookupResources3(t *testing.T) { Metadata: &v1.ResolverMeta{ AtRevision: revision.String(), DepthRemaining: 50, + SchemaHash: 
[]byte(datalayer.NoSchemaHashForTesting), }, OptionalLimit: veryLargeLimit, }, stream) @@ -137,6 +138,7 @@ func TestSimpleLookupResources3(t *testing.T) { Metadata: &v1.ResolverMeta{ AtRevision: revision.String(), DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, OptionalLimit: veryLargeLimit, }, stream) @@ -192,6 +194,7 @@ func TestSimpleLookupResourcesWithCursor3(t *testing.T) { Metadata: &v1.ResolverMeta{ AtRevision: revision.String(), DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, OptionalLimit: 1, }, stream) @@ -215,6 +218,7 @@ func TestSimpleLookupResourcesWithCursor3(t *testing.T) { Metadata: &v1.ResolverMeta{ AtRevision: revision.String(), DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, OptionalCursor: cursor, OptionalLimit: 2, @@ -275,6 +279,7 @@ func TestMaxDepthLookup3(t *testing.T) { Metadata: &v1.ResolverMeta{ AtRevision: revision.String(), DepthRemaining: 0, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, }, stream) @@ -732,6 +737,7 @@ func TestLookupResources3OverSchemaWithCursors(t *testing.T) { Metadata: &v1.ResolverMeta{ AtRevision: revision.String(), DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, OptionalLimit: uintPageSize, OptionalCursor: currentCursor, @@ -809,6 +815,7 @@ func TestLookupResources3WithError(t *testing.T) { Metadata: &v1.ResolverMeta{ AtRevision: revision.String(), DepthRemaining: 10, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, }, stream) require.Error(err) @@ -1287,6 +1294,7 @@ func TestLookupResources3EnsureCheckHints(t *testing.T) { Metadata: &v1.ResolverMeta{ AtRevision: revision.String(), DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, }, stream) if tc.expectedError != "" { diff --git a/internal/dispatch/graph/lookupsubjects_test.go b/internal/dispatch/graph/lookupsubjects_test.go index ff31ab901..094b36d26 100644 --- 
a/internal/dispatch/graph/lookupsubjects_test.go +++ b/internal/dispatch/graph/lookupsubjects_test.go @@ -147,6 +147,7 @@ func TestSimpleLookupSubjects(t *testing.T) { Metadata: &v1.ResolverMeta{ AtRevision: revision.String(), DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, }, stream) @@ -180,6 +181,7 @@ func TestSimpleLookupSubjects(t *testing.T) { Metadata: &v1.ResolverMeta{ AtRevision: revision.String(), DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, }) @@ -221,6 +223,7 @@ func TestLookupSubjectsMaxDepth(t *testing.T) { Metadata: &v1.ResolverMeta{ AtRevision: revision.String(), DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, }, stream) require.Error(err) @@ -270,6 +273,7 @@ func TestLookupSubjectsDispatchCount(t *testing.T) { Metadata: &v1.ResolverMeta{ AtRevision: revision.String(), DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, }, stream) @@ -1024,6 +1028,7 @@ func TestLookupSubjectsOverSchema(t *testing.T) { Metadata: &v1.ResolverMeta{ AtRevision: revision.String(), DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, }, stream) require.NoError(err) diff --git a/internal/dispatch/keys/computed_test.go b/internal/dispatch/keys/computed_test.go index dd55b4c34..2f6395262 100644 --- a/internal/dispatch/keys/computed_test.go +++ b/internal/dispatch/keys/computed_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/require" "google.golang.org/protobuf/types/known/structpb" + "github.com/authzed/spicedb/pkg/datalayer" "github.com/authzed/spicedb/pkg/genutil/mapz" core "github.com/authzed/spicedb/pkg/proto/core/v1" v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" @@ -44,6 +45,7 @@ func TestStableCacheKeys(t *testing.T) { Subject: ONR("user", "tom", "..."), Metadata: &v1.ResolverMeta{ AtRevision: "1234", + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, }, computeBothHashes) }, @@ -58,6 +60,7 @@ 
func TestStableCacheKeys(t *testing.T) { Subject: ONR("user", "tom", "..."), Metadata: &v1.ResolverMeta{ AtRevision: "1234", + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, }, computeBothHashes) }, @@ -72,6 +75,7 @@ func TestStableCacheKeys(t *testing.T) { Subject: ONR("user", "sarah", "..."), Metadata: &v1.ResolverMeta{ AtRevision: "123456", + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, }, computeBothHashes) }, @@ -86,6 +90,7 @@ func TestStableCacheKeys(t *testing.T) { Subject: ONR("user", "tom", "..."), Metadata: &v1.ResolverMeta{ AtRevision: "1234", + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, }, "view") return key @@ -99,6 +104,7 @@ func TestStableCacheKeys(t *testing.T) { ResourceAndRelation: ONR("document", "foo", "view"), Metadata: &v1.ResolverMeta{ AtRevision: "1234", + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, }, computeBothHashes) }, @@ -111,6 +117,7 @@ func TestStableCacheKeys(t *testing.T) { ResourceAndRelation: ONR("document", "foo2", "view"), Metadata: &v1.ResolverMeta{ AtRevision: "1234", + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, }, computeBothHashes) }, @@ -123,6 +130,7 @@ func TestStableCacheKeys(t *testing.T) { ResourceAndRelation: ONR("document", "foo2", "view"), Metadata: &v1.ResolverMeta{ AtRevision: "1235", + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, }, computeBothHashes) }, @@ -137,6 +145,7 @@ func TestStableCacheKeys(t *testing.T) { ResourceIds: []string{"mariah", "tom"}, Metadata: &v1.ResolverMeta{ AtRevision: "1234", + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, }, computeBothHashes) }, @@ -152,6 +161,7 @@ func TestStableCacheKeys(t *testing.T) { TerminalSubject: ONR("user", "mariah", "..."), Metadata: &v1.ResolverMeta{ AtRevision: "1234", + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, }, computeBothHashes) }, @@ -167,6 +177,7 @@ func TestStableCacheKeys(t *testing.T) { TerminalSubject: ONR("user", "mariah", "..."), Metadata: 
&v1.ResolverMeta{ AtRevision: "1234", + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, OptionalLimit: 0, }, computeBothHashes) @@ -183,6 +194,7 @@ func TestStableCacheKeys(t *testing.T) { TerminalSubject: ONR("user", "mariah", "..."), Metadata: &v1.ResolverMeta{ AtRevision: "1234", + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, OptionalLimit: 42, }, computeBothHashes) @@ -199,6 +211,7 @@ func TestStableCacheKeys(t *testing.T) { TerminalSubject: ONR("user", "mariah", "..."), Metadata: &v1.ResolverMeta{ AtRevision: "1234", + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, Context: nil, }, computeBothHashes) @@ -215,6 +228,7 @@ func TestStableCacheKeys(t *testing.T) { TerminalSubject: ONR("user", "mariah", "..."), Metadata: &v1.ResolverMeta{ AtRevision: "1234", + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, Context: func() *structpb.Struct { v, _ := structpb.NewStruct(map[string]any{}) @@ -234,6 +248,7 @@ func TestStableCacheKeys(t *testing.T) { TerminalSubject: ONR("user", "mariah", "..."), Metadata: &v1.ResolverMeta{ AtRevision: "1234", + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, Context: func() *structpb.Struct { v, _ := structpb.NewStruct(map[string]any{ @@ -256,6 +271,7 @@ func TestStableCacheKeys(t *testing.T) { TerminalSubject: ONR("user", "mariah", "..."), Metadata: &v1.ResolverMeta{ AtRevision: "1234", + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, Context: func() *structpb.Struct { v, _ := structpb.NewStruct(map[string]any{ @@ -278,6 +294,7 @@ func TestStableCacheKeys(t *testing.T) { TerminalSubject: ONR("user", "mariah", "..."), Metadata: &v1.ResolverMeta{ AtRevision: "1234", + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, Context: func() *structpb.Struct { v, _ := structpb.NewStruct(map[string]any{ @@ -299,6 +316,7 @@ func TestStableCacheKeys(t *testing.T) { TerminalSubject: ONR("user", "mariah", "..."), Metadata: &v1.ResolverMeta{ AtRevision: "1234", + SchemaHash: 
[]byte(datalayer.NoSchemaHashForTesting), }, Context: func() *structpb.Struct { v, _ := structpb.NewStruct(map[string]any{ @@ -324,6 +342,7 @@ func TestStableCacheKeys(t *testing.T) { TerminalSubject: ONR("user", "mariah", "..."), Metadata: &v1.ResolverMeta{ AtRevision: "1234", + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, OptionalCursor: &v1.Cursor{}, }, computeBothHashes) @@ -340,6 +359,7 @@ func TestStableCacheKeys(t *testing.T) { TerminalSubject: ONR("user", "mariah", "..."), Metadata: &v1.ResolverMeta{ AtRevision: "1234", + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, OptionalCursor: &v1.Cursor{ Sections: []string{"foo"}, @@ -358,6 +378,7 @@ func TestStableCacheKeys(t *testing.T) { TerminalSubject: ONR("user", "mariah", "..."), Metadata: &v1.ResolverMeta{ AtRevision: "1234", + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, OptionalCursor: &v1.Cursor{ Sections: []string{"foo", "bar"}, @@ -376,6 +397,7 @@ func TestStableCacheKeys(t *testing.T) { TerminalSubject: ONR("user", "sarah", "..."), Metadata: &v1.ResolverMeta{ AtRevision: "1234", + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, OptionalCursor: &v1.Cursor{ Sections: []string{"foo", "bar"}, @@ -394,6 +416,7 @@ func TestStableCacheKeys(t *testing.T) { TerminalSubject: ONR("user", "sarah", "..."), Metadata: &v1.ResolverMeta{ AtRevision: "1234", + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, OptionalCursor: &v1.Cursor{ Sections: []string{"foo", "bar"}, @@ -578,6 +601,7 @@ func TestCacheKeyNoOverlap(t *testing.T) { t.Run(revision, func(t *testing.T) { metadata := &v1.ResolverMeta{ AtRevision: revision, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), } for prefix, f := range generatorFuncs { @@ -609,6 +633,7 @@ func TestComputeOnlyStableHash(t *testing.T) { Subject: ONR("user", "tom", "..."), Metadata: &v1.ResolverMeta{ AtRevision: "1234", + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, }, computeOnlyStableHash) @@ -623,6 +648,7 @@ func 
TestComputeContextHash(t *testing.T) { TerminalSubject: ONR("user", "mariah", "..."), Metadata: &v1.ResolverMeta{ AtRevision: "1234", + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, Context: func() *structpb.Struct { v, _ := structpb.NewStruct(map[string]any{ diff --git a/internal/dispatch/keys/keys.go b/internal/dispatch/keys/keys.go index aac781208..8ab6a9f3e 100644 --- a/internal/dispatch/keys/keys.go +++ b/internal/dispatch/keys/keys.go @@ -106,7 +106,7 @@ func (c *CanonicalKeyHandler) CheckCacheKey(ctx context.Context, req *v1.Dispatc if err != nil { return emptyDispatchCacheKey, err } - r := dl.SnapshotReader(revision) + r := dl.SnapshotReader(revision, datalayer.SchemaHash(req.Metadata.GetSchemaHash())) sr, err := r.ReadSchema(ctx) if err != nil { diff --git a/internal/dispatch/remote/cluster_benchmark_test.go b/internal/dispatch/remote/cluster_benchmark_test.go index 742478d24..3ee452fdd 100644 --- a/internal/dispatch/remote/cluster_benchmark_test.go +++ b/internal/dispatch/remote/cluster_benchmark_test.go @@ -9,6 +9,7 @@ import ( "google.golang.org/grpc" "github.com/authzed/spicedb/internal/dispatch/keys" + "github.com/authzed/spicedb/pkg/datalayer" corev1 "github.com/authzed/spicedb/pkg/proto/core/v1" v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" ) @@ -65,8 +66,11 @@ func BenchmarkSecondaryDispatching(b *testing.B) { _, err = dispatcher.DispatchCheck(b.Context(), &v1.DispatchCheckRequest{ ResourceRelation: &corev1.RelationReference{Namespace: "sometype", Relation: "somerel"}, ResourceIds: []string{"foo"}, - Metadata: &v1.ResolverMeta{DepthRemaining: 50}, - Subject: &corev1.ObjectAndRelation{Namespace: "foo", ObjectId: "bar", Relation: "..."}, + Metadata: &v1.ResolverMeta{ + DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), + }, + Subject: &corev1.ObjectAndRelation{Namespace: "foo", ObjectId: "bar", Relation: "..."}, }) require.NoError(b, err) } diff --git a/internal/dispatch/remote/cluster_test.go 
b/internal/dispatch/remote/cluster_test.go index fe4757e0f..bf63ec3e9 100644 --- a/internal/dispatch/remote/cluster_test.go +++ b/internal/dispatch/remote/cluster_test.go @@ -23,6 +23,7 @@ import ( "github.com/authzed/spicedb/internal/dispatch" "github.com/authzed/spicedb/internal/dispatch/keys" "github.com/authzed/spicedb/internal/grpchelpers" + "github.com/authzed/spicedb/pkg/datalayer" corev1 "github.com/authzed/spicedb/pkg/proto/core/v1" v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" "github.com/authzed/spicedb/pkg/spiceerrors" @@ -208,8 +209,11 @@ func TestDispatchTimeout(t *testing.T) { resp, err := dispatcher.DispatchCheck(t.Context(), &v1.DispatchCheckRequest{ ResourceRelation: &corev1.RelationReference{Namespace: "sometype", Relation: "somerel"}, ResourceIds: []string{"foo"}, - Metadata: &v1.ResolverMeta{DepthRemaining: 50}, - Subject: &corev1.ObjectAndRelation{Namespace: "foo", ObjectId: "bar", Relation: "..."}, + Metadata: &v1.ResolverMeta{ + DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), + }, + Subject: &corev1.ObjectAndRelation{Namespace: "foo", ObjectId: "bar", Relation: "..."}, }) if tc.sleepTime > tc.timeout { require.Error(t, err) @@ -225,8 +229,11 @@ func TestDispatchTimeout(t *testing.T) { err = dispatcher.DispatchLookupSubjects(&v1.DispatchLookupSubjectsRequest{ ResourceRelation: &corev1.RelationReference{Namespace: "sometype", Relation: "somerel"}, ResourceIds: []string{"foo"}, - Metadata: &v1.ResolverMeta{DepthRemaining: 50}, - SubjectRelation: &corev1.RelationReference{Namespace: "sometype", Relation: "somerel"}, + Metadata: &v1.ResolverMeta{ + DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), + }, + SubjectRelation: &corev1.RelationReference{Namespace: "sometype", Relation: "somerel"}, }, stream) if tc.sleepTime > tc.timeout { require.Error(t, err) @@ -258,8 +265,11 @@ func TestCheckSecondaryDispatch(t *testing.T) { Relation: "somerelation", }, ResourceIds: []string{"foo"}, - 
Metadata: &v1.ResolverMeta{DepthRemaining: 50}, - Subject: &corev1.ObjectAndRelation{Namespace: "foo", ObjectId: "bar", Relation: "..."}, + Metadata: &v1.ResolverMeta{ + DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), + }, + Subject: &corev1.ObjectAndRelation{Namespace: "foo", ObjectId: "bar", Relation: "..."}, }, 0 * time.Millisecond, 1, @@ -273,8 +283,11 @@ func TestCheckSecondaryDispatch(t *testing.T) { Relation: "somerelation", }, ResourceIds: []string{"foo"}, - Metadata: &v1.ResolverMeta{DepthRemaining: 50}, - Subject: &corev1.ObjectAndRelation{Namespace: "foo", ObjectId: "bar", Relation: "..."}, + Metadata: &v1.ResolverMeta{ + DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), + }, + Subject: &corev1.ObjectAndRelation{Namespace: "foo", ObjectId: "bar", Relation: "..."}, }, 1 * time.Second, 2, @@ -288,8 +301,11 @@ func TestCheckSecondaryDispatch(t *testing.T) { Relation: "somerelation", }, ResourceIds: []string{"foo"}, - Metadata: &v1.ResolverMeta{DepthRemaining: 50}, - Subject: &corev1.ObjectAndRelation{Namespace: "foo", ObjectId: "bar", Relation: "..."}, + Metadata: &v1.ResolverMeta{ + DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), + }, + Subject: &corev1.ObjectAndRelation{Namespace: "foo", ObjectId: "bar", Relation: "..."}, }, 1 * time.Second, 1, @@ -303,8 +319,11 @@ func TestCheckSecondaryDispatch(t *testing.T) { Relation: "somerelation", }, ResourceIds: []string{"foo"}, - Metadata: &v1.ResolverMeta{DepthRemaining: 50}, - Subject: &corev1.ObjectAndRelation{Namespace: "foo", ObjectId: "bar", Relation: "..."}, + Metadata: &v1.ResolverMeta{ + DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), + }, + Subject: &corev1.ObjectAndRelation{Namespace: "foo", ObjectId: "bar", Relation: "..."}, }, 1 * time.Second, 2, @@ -318,8 +337,11 @@ func TestCheckSecondaryDispatch(t *testing.T) { Relation: "somerelation", }, ResourceIds: []string{"foo"}, - Metadata: 
&v1.ResolverMeta{DepthRemaining: 50}, - Subject: &corev1.ObjectAndRelation{Namespace: "foo", ObjectId: "bar", Relation: "..."}, + Metadata: &v1.ResolverMeta{ + DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), + }, + Subject: &corev1.ObjectAndRelation{Namespace: "foo", ObjectId: "bar", Relation: "..."}, }, 1 * time.Second, 1, @@ -377,7 +399,10 @@ func TestLRSecondaryDispatch(t *testing.T) { ObjectId: "bar", Relation: "...", }, - Metadata: &v1.ResolverMeta{DepthRemaining: 50}, + Metadata: &v1.ResolverMeta{ + DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), + }, }, 1, false, @@ -400,7 +425,10 @@ func TestLRSecondaryDispatch(t *testing.T) { ObjectId: "bar", Relation: "...", }, - Metadata: &v1.ResolverMeta{DepthRemaining: 50}, + Metadata: &v1.ResolverMeta{ + DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), + }, }, 2, false, @@ -427,7 +455,10 @@ func TestLRSecondaryDispatch(t *testing.T) { Sections: []string{"somethingelse"}, DispatchVersion: 1, }, - Metadata: &v1.ResolverMeta{DepthRemaining: 50}, + Metadata: &v1.ResolverMeta{ + DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), + }, }, 2, // Falls back to the default secondary. false, @@ -454,7 +485,10 @@ func TestLRSecondaryDispatch(t *testing.T) { Sections: []string{secondaryCursorPrefix + "tertiary"}, DispatchVersion: 1, }, - Metadata: &v1.ResolverMeta{DepthRemaining: 50}, + Metadata: &v1.ResolverMeta{ + DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), + }, }, 3, false, @@ -481,7 +515,10 @@ func TestLRSecondaryDispatch(t *testing.T) { Sections: []string{secondaryCursorPrefix + "error"}, DispatchVersion: 1, }, - Metadata: &v1.ResolverMeta{DepthRemaining: 50}, + Metadata: &v1.ResolverMeta{ + DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), + }, }, 1, true, // since the secondary was in the cursor, if it errors, the operation fails. 
@@ -504,7 +541,10 @@ func TestLRSecondaryDispatch(t *testing.T) { ObjectId: "bar", Relation: "...", }, - Metadata: &v1.ResolverMeta{DepthRemaining: 50}, + Metadata: &v1.ResolverMeta{ + DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), + }, }, 1, false, @@ -531,7 +571,10 @@ func TestLRSecondaryDispatch(t *testing.T) { Sections: []string{secondaryCursorPrefix + "unknown"}, DispatchVersion: 1, }, - Metadata: &v1.ResolverMeta{DepthRemaining: 50}, + Metadata: &v1.ResolverMeta{ + DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), + }, }, 0, true, @@ -558,7 +601,10 @@ func TestLRSecondaryDispatch(t *testing.T) { Sections: []string{secondaryCursorPrefix + "secondary"}, DispatchVersion: 1, }, - Metadata: &v1.ResolverMeta{DepthRemaining: 50}, + Metadata: &v1.ResolverMeta{ + DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), + }, }, 2, false, @@ -585,7 +631,10 @@ func TestLRSecondaryDispatch(t *testing.T) { Sections: []string{secondaryCursorPrefix + "secondary"}, DispatchVersion: 1, }, - Metadata: &v1.ResolverMeta{DepthRemaining: 50}, + Metadata: &v1.ResolverMeta{ + DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), + }, }, 2, false, @@ -665,7 +714,10 @@ func TestLRDispatchFallbackToPrimary(t *testing.T) { ObjectId: "bar", Relation: "...", }, - Metadata: &v1.ResolverMeta{DepthRemaining: 50}, + Metadata: &v1.ResolverMeta{ + DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), + }, }, stream) require.NoError(t, err) @@ -695,7 +747,10 @@ func TestLSSecondaryDispatch(t *testing.T) { Relation: "somerelation", }, ResourceIds: []string{"foo"}, - Metadata: &v1.ResolverMeta{DepthRemaining: 50}, + Metadata: &v1.ResolverMeta{ + DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), + }, }, 1, false, @@ -713,7 +768,10 @@ func TestLSSecondaryDispatch(t *testing.T) { Relation: "somerelation", }, ResourceIds: []string{"foo"}, - Metadata: 
&v1.ResolverMeta{DepthRemaining: 50}, + Metadata: &v1.ResolverMeta{ + DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), + }, }, 2, false, @@ -788,7 +846,10 @@ func TestLSDispatchFallbackToPrimary(t *testing.T) { Relation: "somerelation", }, ResourceIds: []string{"foo"}, - Metadata: &v1.ResolverMeta{DepthRemaining: 50}, + Metadata: &v1.ResolverMeta{ + DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), + }, }, stream) require.NoError(t, err) @@ -820,8 +881,11 @@ func TestCheckUsesDelayByDefaultForPrimary(t *testing.T) { resp, err := dispatcher.DispatchCheck(t.Context(), &v1.DispatchCheckRequest{ ResourceRelation: &corev1.RelationReference{Namespace: "sometype", Relation: "somerel"}, ResourceIds: []string{"foo"}, - Metadata: &v1.ResolverMeta{DepthRemaining: 50}, - Subject: &corev1.ObjectAndRelation{Namespace: "foo", ObjectId: "bar", Relation: "..."}, + Metadata: &v1.ResolverMeta{ + DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), + }, + Subject: &corev1.ObjectAndRelation{Namespace: "foo", ObjectId: "bar", Relation: "..."}, }) require.NoError(t, err) require.Equal(t, uint32(2), resp.Metadata.DispatchCount) @@ -863,7 +927,10 @@ func TestStreamingDispatchDelayByDefaultForPrimary(t *testing.T) { Relation: "somerelation", }, ResourceIds: []string{"foo"}, - Metadata: &v1.ResolverMeta{DepthRemaining: 50}, + Metadata: &v1.ResolverMeta{ + DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), + }, }, stream) require.NoError(t, err) @@ -943,8 +1010,11 @@ func TestCheckUsesMaximumDelayByDefaultForPrimary(t *testing.T) { resp, err := dispatcher.DispatchCheck(t.Context(), &v1.DispatchCheckRequest{ ResourceRelation: &corev1.RelationReference{Namespace: "sometype", Relation: "somerel"}, ResourceIds: []string{"foo"}, - Metadata: &v1.ResolverMeta{DepthRemaining: 50}, - Subject: &corev1.ObjectAndRelation{Namespace: "foo", ObjectId: "bar", Relation: "..."}, + Metadata: &v1.ResolverMeta{ + 
DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), + }, + Subject: &corev1.ObjectAndRelation{Namespace: "foo", ObjectId: "bar", Relation: "..."}, }) require.NoError(t, err) require.Equal(t, uint32(1), resp.Metadata.DispatchCount) @@ -1080,8 +1150,11 @@ func TestCheckToUnsupportedRemovesHedgingDelay(t *testing.T) { resp, err := dispatcher.DispatchCheck(t.Context(), &v1.DispatchCheckRequest{ ResourceRelation: &corev1.RelationReference{Namespace: "sometype", Relation: "somerel"}, ResourceIds: []string{"foo"}, - Metadata: &v1.ResolverMeta{DepthRemaining: 50}, - Subject: &corev1.ObjectAndRelation{Namespace: "foo", ObjectId: "bar", Relation: "..."}, + Metadata: &v1.ResolverMeta{ + DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), + }, + Subject: &corev1.ObjectAndRelation{Namespace: "foo", ObjectId: "bar", Relation: "..."}, }) require.NoError(t, err) require.Equal(t, uint32(1), resp.Metadata.DispatchCount) @@ -1096,8 +1169,11 @@ func TestCheckToUnsupportedRemovesHedgingDelay(t *testing.T) { resp, err = dispatcher.DispatchCheck(t.Context(), &v1.DispatchCheckRequest{ ResourceRelation: &corev1.RelationReference{Namespace: "sometype", Relation: "somerel"}, ResourceIds: []string{"foo"}, - Metadata: &v1.ResolverMeta{DepthRemaining: 50}, - Subject: &corev1.ObjectAndRelation{Namespace: "foo", ObjectId: "bar", Relation: "..."}, + Metadata: &v1.ResolverMeta{ + DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), + }, + Subject: &corev1.ObjectAndRelation{Namespace: "foo", ObjectId: "bar", Relation: "..."}, }) endTime := time.Now() require.NoError(t, err) @@ -1218,7 +1294,10 @@ func TestPrimaryDispatcherErrorReturned(t *testing.T) { ObjectId: "bar", Relation: "...", }, - Metadata: &v1.ResolverMeta{DepthRemaining: 50}, + Metadata: &v1.ResolverMeta{ + DepthRemaining: 50, + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), + }, }, stream) // Should get the primary dispatcher error when no dispatcher returns 
results diff --git a/internal/dispatch/singleflight/singleflight_test.go b/internal/dispatch/singleflight/singleflight_test.go index bc9f40bac..9158f943d 100644 --- a/internal/dispatch/singleflight/singleflight_test.go +++ b/internal/dispatch/singleflight/singleflight_test.go @@ -16,6 +16,7 @@ import ( "github.com/authzed/spicedb/internal/dispatch" "github.com/authzed/spicedb/internal/dispatch/keys" + "github.com/authzed/spicedb/pkg/datalayer" v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" "github.com/authzed/spicedb/pkg/tuple" ) @@ -37,6 +38,7 @@ func TestSingleFlightDispatcher(t *testing.T) { Metadata: &v1.ResolverMeta{ AtRevision: "1234", TraversalBloom: v1.MustNewTraversalBloomFilter(defaultBloomFilterSize), + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, } @@ -92,6 +94,7 @@ func TestSingleFlightDispatcherDetectsLoop(t *testing.T) { Metadata: &v1.ResolverMeta{ AtRevision: "1234", TraversalBloom: v1.MustNewTraversalBloomFilter(defaultBloomFilterSize), + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, } @@ -148,6 +151,7 @@ func TestSingleFlightDispatcherDetectsLoopThroughDelegate(t *testing.T) { Metadata: &v1.ResolverMeta{ AtRevision: "1234", TraversalBloom: v1.MustNewTraversalBloomFilter(defaultBloomFilterSize), + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, } @@ -189,6 +193,7 @@ func TestSingleFlightDispatcherCancelation(t *testing.T) { Metadata: &v1.ResolverMeta{ AtRevision: "1234", TraversalBloom: v1.MustNewTraversalBloomFilter(defaultBloomFilterSize), + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, } @@ -241,6 +246,7 @@ func TestSingleFlightDispatcherExpand(t *testing.T) { Metadata: &v1.ResolverMeta{ AtRevision: "1234", TraversalBloom: v1.MustNewTraversalBloomFilter(defaultBloomFilterSize), + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, } @@ -286,6 +292,7 @@ func TestSingleFlightDispatcherCheckBypassesIfMissingBloomFiler(t *testing.T) { Subject: tuple.ONRStringToCore("user", "tom", "..."), 
Metadata: &v1.ResolverMeta{ AtRevision: "1234", + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, } @@ -309,6 +316,7 @@ func TestSingleFlightDispatcherExpandBypassesIfMissingBloomFiler(t *testing.T) { ResourceAndRelation: tuple.ONRStringToCore("document", "foo", "view"), Metadata: &v1.ResolverMeta{ AtRevision: "1234", + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, } diff --git a/internal/graph/check.go b/internal/graph/check.go index ae3d63dd2..6d9a9938e 100644 --- a/internal/graph/check.go +++ b/internal/graph/check.go @@ -326,7 +326,7 @@ func (cc *ConcurrentChecker) checkDirect(ctx context.Context, crc currentRequest } }() log.Ctx(ctx).Trace().Object("direct", crc.parentReq).Send() - dl := datalayer.MustFromContext(ctx).SnapshotReader(crc.parentReq.Revision) + dl := datalayer.MustFromContext(ctx).SnapshotReader(crc.parentReq.Revision, datalayer.SchemaHash(crc.parentReq.Metadata.GetSchemaHash())) directSubjectsAndWildcardsWithoutCaveats := 0 directSubjectsAndWildcardsWithoutExpiration := 0 @@ -658,7 +658,7 @@ func (cc *ConcurrentChecker) checkComputedUserset(ctx context.Context, crc curre // for TTU-based computed usersets, as directly computed ones reference relations within // the same namespace as the caller, and thus must be fully typed checked. if cu.Object == core.ComputedUserset_TUPLE_USERSET_OBJECT { - dl := datalayer.MustFromContext(ctx).SnapshotReader(crc.parentReq.Revision) + dl := datalayer.MustFromContext(ctx).SnapshotReader(crc.parentReq.Revision, datalayer.SchemaHash(crc.parentReq.Metadata.GetSchemaHash())) sr, err := dl.ReadSchema(ctx) if err != nil { return checkResultError(err, emptyMetadata) @@ -811,7 +811,7 @@ func checkIntersectionTupleToUserset( // Query for the subjects over which to walk the TTU. 
log.Ctx(ctx).Trace().Object("intersectionttu", crc.parentReq).Send() - dl := datalayer.MustFromContext(ctx).SnapshotReader(crc.parentReq.Revision) + dl := datalayer.MustFromContext(ctx).SnapshotReader(crc.parentReq.Revision, datalayer.SchemaHash(crc.parentReq.Metadata.GetSchemaHash())) queryOpts, err := queryOptionsForArrowRelation(ctx, dl, crc.parentReq.ResourceRelation.Namespace, ttu.GetTupleset().GetRelation()) if err != nil { return checkResultError(NewCheckFailureErr(err), emptyMetadata) @@ -977,7 +977,7 @@ func checkTupleToUserset[T relation]( defer span.End() log.Ctx(ctx).Trace().Object("ttu", crc.parentReq).Send() - dl := datalayer.MustFromContext(ctx).SnapshotReader(crc.parentReq.Revision) + dl := datalayer.MustFromContext(ctx).SnapshotReader(crc.parentReq.Revision, datalayer.SchemaHash(crc.parentReq.Metadata.GetSchemaHash())) queryOpts, err := queryOptionsForArrowRelation(ctx, dl, crc.parentReq.ResourceRelation.Namespace, ttu.GetTupleset().GetRelation()) if err != nil { diff --git a/internal/graph/check_isolated_test.go b/internal/graph/check_isolated_test.go index 0f452ffa2..de153cfa3 100644 --- a/internal/graph/check_isolated_test.go +++ b/internal/graph/check_isolated_test.go @@ -134,7 +134,7 @@ func TestTraitsForArrowRelation(t *testing.T) { ds, revision := testfixtures.DatastoreFromSchemaAndTestRelationships(rawDS, tc.schema, nil, require) dl := datalayer.NewDataLayer(ds) - reader := dl.SnapshotReader(revision) + reader := dl.SnapshotReader(revision, datalayer.NoSchemaHashForTesting) traits, err := graph.TraitsForArrowRelation(t.Context(), reader, tc.namespaceName, tc.relationName) if tc.expectedError != "" { diff --git a/internal/graph/computed/computecheck.go b/internal/graph/computed/computecheck.go index be29a0bca..42f0bac7f 100644 --- a/internal/graph/computed/computecheck.go +++ b/internal/graph/computed/computecheck.go @@ -48,6 +48,7 @@ type CheckParameters struct { MaximumDepth uint32 DebugOption DebugOption CheckHints []*v1.CheckHint + 
SchemaHash datalayer.SchemaHash } // ComputeCheck computes a check result for the given resource and subject, computing any @@ -128,6 +129,7 @@ func computeCheck(ctx context.Context, AtRevision: params.AtRevision.String(), DepthRemaining: params.MaximumDepth, TraversalBloom: bf, + SchemaHash: []byte(params.SchemaHash), }, Debug: debugging, CheckHints: params.CheckHints, @@ -178,7 +180,7 @@ func computeCaveatedCheckResult(ctx context.Context, runner *cexpr.CaveatRunner, } dl := datalayer.MustFromContext(ctx) - reader := dl.SnapshotReader(params.AtRevision) + reader := dl.SnapshotReader(params.AtRevision, params.SchemaHash) sr, err := reader.ReadSchema(ctx) if err != nil { return nil, err diff --git a/internal/graph/computed/computecheck_test.go b/internal/graph/computed/computecheck_test.go index 76f318182..ed0faa886 100644 --- a/internal/graph/computed/computecheck_test.go +++ b/internal/graph/computed/computecheck_test.go @@ -828,6 +828,7 @@ func TestComputeCheckWithCaveats(t *testing.T) { AtRevision: revision, MaximumDepth: 50, DebugOption: computed.BasicDebuggingEnabled, + SchemaHash: datalayer.NoSchemaHashForTesting, }, rel.Resource.ObjectID, 100, @@ -872,6 +873,7 @@ func TestComputeCheckError(t *testing.T) { AtRevision: datastore.NoRevision, MaximumDepth: 50, DebugOption: computed.BasicDebuggingEnabled, + SchemaHash: datalayer.NoSchemaHashForTesting, }, "id", 100, @@ -920,6 +922,7 @@ func TestComputeBulkCheck(t *testing.T) { AtRevision: revision, MaximumDepth: 50, DebugOption: computed.NoDebugging, + SchemaHash: datalayer.NoSchemaHashForTesting, }, []string{"direct", "first", "second", "third"}, 100, diff --git a/internal/graph/expand.go b/internal/graph/expand.go index 8117b0b05..4c7d1714f 100644 --- a/internal/graph/expand.go +++ b/internal/graph/expand.go @@ -58,7 +58,7 @@ func (ce *ConcurrentExpander) expandDirect( ) ReduceableExpandFunc { log.Ctx(ctx).Trace().Object("direct", req).Send() return func(ctx context.Context, resultChan chan<- ExpandResult) { - 
dl := datalayer.MustFromContext(ctx).SnapshotReader(req.Revision) + dl := datalayer.MustFromContext(ctx).SnapshotReader(req.Revision, datalayer.SchemaHash(req.Metadata.GetSchemaHash())) it, err := dl.QueryRelationships(ctx, datastore.RelationshipsFilter{ OptionalResourceType: req.ResourceAndRelation.Namespace, OptionalResourceIds: []string{req.ResourceAndRelation.ObjectId}, @@ -243,7 +243,7 @@ func (ce *ConcurrentExpander) expandComputedUserset(ctx context.Context, req Val } // Check if the target relation exists. If not, return nothing. - dl := datalayer.MustFromContext(ctx).SnapshotReader(req.Revision) + dl := datalayer.MustFromContext(ctx).SnapshotReader(req.Revision, datalayer.SchemaHash(req.Metadata.GetSchemaHash())) sr, err := dl.ReadSchema(ctx) if err != nil { return expandError(err) @@ -281,7 +281,7 @@ func expandTupleToUserset[T relation]( expandFunc expandFunc, ) ReduceableExpandFunc { return func(ctx context.Context, resultChan chan<- ExpandResult) { - dl := datalayer.MustFromContext(ctx).SnapshotReader(req.Revision) + dl := datalayer.MustFromContext(ctx).SnapshotReader(req.Revision, datalayer.SchemaHash(req.Metadata.GetSchemaHash())) it, err := dl.QueryRelationships(ctx, datastore.RelationshipsFilter{ OptionalResourceType: req.ResourceAndRelation.Namespace, OptionalResourceIds: []string{req.ResourceAndRelation.ObjectId}, diff --git a/internal/graph/graph.go b/internal/graph/graph.go index 40a2bb9fe..df2ca57d7 100644 --- a/internal/graph/graph.go +++ b/internal/graph/graph.go @@ -50,6 +50,7 @@ func decrementDepth(md *v1.ResolverMeta) *v1.ResolverMeta { AtRevision: md.AtRevision, DepthRemaining: md.DepthRemaining - 1, TraversalBloom: md.TraversalBloom, + SchemaHash: md.SchemaHash, } } diff --git a/internal/graph/lookupresources2.go b/internal/graph/lookupresources2.go index 76bde632a..0e3d2b227 100644 --- a/internal/graph/lookupresources2.go +++ b/internal/graph/lookupresources2.go @@ -127,7 +127,7 @@ func (crr *CursoredLookupResources2) afterSameType( // 
Load the type system and reachability graph to find the entrypoints for the reachability. dl := datalayer.MustFromContext(ctx) - reader := dl.SnapshotReader(req.Revision) + reader := dl.SnapshotReader(req.Revision, datalayer.SchemaHash(req.Metadata.GetSchemaHash())) sr, err := reader.ReadSchema(ctx) if err != nil { return err @@ -605,6 +605,7 @@ func (crr *CursoredLookupResources2) redispatchOrReport( MaximumDepth: parentRequest.Metadata.DepthRemaining - 1, DebugOption: computed.NoDebugging, CheckHints: checkHints, + SchemaHash: datalayer.SchemaHash(parentRequest.Metadata.GetSchemaHash()), }, resourceIDs, crr.dispatchChunkSize) if err != nil { return err @@ -694,6 +695,7 @@ func (crr *CursoredLookupResources2) redispatchOrReport( Metadata: &v1.ResolverMeta{ AtRevision: parentRequest.Revision.String(), DepthRemaining: parentRequest.Metadata.DepthRemaining - 1, + SchemaHash: parentRequest.Metadata.SchemaHash, }, OptionalCursor: ci.currentCursor, OptionalLimit: parentRequest.OptionalLimit, diff --git a/internal/graph/lookupresources3.go b/internal/graph/lookupresources3.go index 593566e90..b63fbff04 100644 --- a/internal/graph/lookupresources3.go +++ b/internal/graph/lookupresources3.go @@ -267,7 +267,7 @@ func (crr *CursoredLookupResources3) LookupResources3(req ValidatedLookupResourc // Build refs for the lookup resources operation. The lr3refs holds references to shared // interfaces used by various suboperations of the lookup resources operation. 
dl := datalayer.MustFromContext(stream.Context()) - reader := dl.SnapshotReader(req.Revision) + reader := dl.SnapshotReader(req.Revision, datalayer.SchemaHash(req.Metadata.GetSchemaHash())) sr, err := reader.ReadSchema(ctx) if err != nil { return err @@ -964,6 +964,7 @@ func (crr *CursoredLookupResources3) dispatchIter( Metadata: &v1.ResolverMeta{ AtRevision: refs.req.Revision.String(), DepthRemaining: refs.req.Metadata.DepthRemaining - 1, + SchemaHash: refs.req.Metadata.SchemaHash, }, OptionalCursor: currentCursor, OptionalLimit: refs.req.OptionalLimit, @@ -1021,6 +1022,7 @@ func (crr *CursoredLookupResources3) filterSubjectsByCheck( MaximumDepth: refs.req.Metadata.DepthRemaining - 1, DebugOption: computed.NoDebugging, CheckHints: checkHints, + SchemaHash: datalayer.SchemaHash(refs.req.Metadata.GetSchemaHash()), }, resourceIDsToCheck, crr.dispatchChunkSize) if err != nil { return nil, err diff --git a/internal/graph/lookupsubjects.go b/internal/graph/lookupsubjects.go index eda5752c0..8fdfe992a 100644 --- a/internal/graph/lookupsubjects.go +++ b/internal/graph/lookupsubjects.go @@ -68,7 +68,7 @@ func (cl *ConcurrentLookupSubjects) LookupSubjects( } dl := datalayer.MustFromContext(ctx) - reader := dl.SnapshotReader(req.Revision) + reader := dl.SnapshotReader(req.Revision, datalayer.SchemaHash(req.Metadata.GetSchemaHash())) sr, err := reader.ReadSchema(ctx) if err != nil { return err @@ -197,7 +197,7 @@ func (cl *ConcurrentLookupSubjects) lookupViaComputed( parentStream dispatch.LookupSubjectsStream, cu *core.ComputedUserset, ) error { - dl := datalayer.MustFromContext(ctx).SnapshotReader(parentRequest.Revision) + dl := datalayer.MustFromContext(ctx).SnapshotReader(parentRequest.Revision, datalayer.SchemaHash(parentRequest.Metadata.GetSchemaHash())) sr, err := dl.ReadSchema(ctx) if err != nil { return err @@ -231,6 +231,7 @@ func (cl *ConcurrentLookupSubjects) lookupViaComputed( Metadata: &v1.ResolverMeta{ AtRevision: parentRequest.Revision.String(), DepthRemaining: 
parentRequest.Metadata.DepthRemaining - 1, + SchemaHash: parentRequest.Metadata.SchemaHash, }, }, stream) } @@ -257,7 +258,7 @@ func lookupViaIntersectionTupleToUserset( ts *schema.TypeSystem, ttu *core.FunctionedTupleToUserset, ) error { - dl := datalayer.MustFromContext(ctx).SnapshotReader(parentRequest.Revision) + dl := datalayer.MustFromContext(ctx).SnapshotReader(parentRequest.Revision, datalayer.SchemaHash(parentRequest.Metadata.GetSchemaHash())) sr, err := dl.ReadSchema(ctx) if err != nil { return err @@ -341,6 +342,7 @@ func lookupViaIntersectionTupleToUserset( Metadata: &v1.ResolverMeta{ AtRevision: parentRequest.Revision.String(), DepthRemaining: parentRequest.Metadata.DepthRemaining - 1, + SchemaHash: parentRequest.Metadata.SchemaHash, }, }, collectingStream) if err != nil { @@ -436,7 +438,7 @@ func lookupViaTupleToUserset[T relation]( toDispatchByTuplesetType := datasets.NewSubjectByTypeSet() relationshipsBySubjectONR := mapz.NewMultiMap[tuple.ObjectAndRelation, tuple.Relationship]() - dl := datalayer.MustFromContext(ctx).SnapshotReader(parentRequest.Revision) + dl := datalayer.MustFromContext(ctx).SnapshotReader(parentRequest.Revision, datalayer.SchemaHash(parentRequest.Metadata.GetSchemaHash())) sr, err := dl.ReadSchema(ctx) if err != nil { return err @@ -704,6 +706,7 @@ func (cl *ConcurrentLookupSubjects) dispatchTo( Metadata: &v1.ResolverMeta{ AtRevision: parentRequest.Revision.String(), DepthRemaining: parentRequest.Metadata.DepthRemaining - 1, + SchemaHash: parentRequest.Metadata.SchemaHash, }, }, stream) }) diff --git a/internal/graph/lr2streams.go b/internal/graph/lr2streams.go index 914f245d6..35f65e9b2 100644 --- a/internal/graph/lr2streams.go +++ b/internal/graph/lr2streams.go @@ -14,6 +14,7 @@ import ( "github.com/authzed/spicedb/internal/taskrunner" "github.com/authzed/spicedb/internal/telemetry/otelconv" caveattypes "github.com/authzed/spicedb/pkg/caveats/types" + "github.com/authzed/spicedb/pkg/datalayer" 
"github.com/authzed/spicedb/pkg/genutil/mapz" core "github.com/authzed/spicedb/pkg/proto/core/v1" v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" @@ -140,6 +141,7 @@ func (rdc *checkAndDispatchRunner) runChecker(ctx context.Context, startingIndex MaximumDepth: rdc.parentRequest.Metadata.DepthRemaining - 1, DebugOption: computed.NoDebugging, CheckHints: checkHints, + SchemaHash: datalayer.SchemaHash(rdc.parentRequest.Metadata.GetSchemaHash()), }, resourceIDsToCheck, rdc.dispatchChunkSize) if err != nil { return err @@ -242,6 +244,7 @@ func (rdc *checkAndDispatchRunner) runDispatch( Metadata: &v1.ResolverMeta{ AtRevision: rdc.parentRequest.Revision.String(), DepthRemaining: rdc.parentRequest.Metadata.DepthRemaining - 1, + SchemaHash: rdc.parentRequest.Metadata.SchemaHash, }, OptionalCursor: updatedCi.currentCursor, OptionalLimit: rdc.ci.limits.currentLimit, diff --git a/internal/middleware/pertoken/pertoken_test.go b/internal/middleware/pertoken/pertoken_test.go index f7005faad..4df1855a0 100644 --- a/internal/middleware/pertoken/pertoken_test.go +++ b/internal/middleware/pertoken/pertoken_test.go @@ -54,12 +54,12 @@ func (t testServer) Ping(ctx context.Context, req *testpb.PingRequest) (*testpb. 
} } - headRev, err := dl.HeadRevision(ctx) + headRev, schemaHash, err := dl.HeadRevision(ctx) if err != nil { return nil, err } - reader := dl.SnapshotReader(headRev) + reader := dl.SnapshotReader(headRev, schemaHash) if reader == nil { return nil, errors.New("no snapshot reader available") } diff --git a/internal/namespace/util_test.go b/internal/namespace/util_test.go index 705c80ae5..3ca8c307e 100644 --- a/internal/namespace/util_test.go +++ b/internal/namespace/util_test.go @@ -171,7 +171,7 @@ func TestCheckNamespaceAndRelations(t *testing.T) { require.NoError(t, err) dl := datalayer.NewDataLayer(ds) - sr, err := dl.SnapshotReader(rev).ReadSchema(t.Context()) + sr, err := dl.SnapshotReader(rev, datalayer.NoSchemaHashForTesting).ReadSchema(t.Context()) require.NoError(t, err) err = namespace.CheckNamespaceAndRelations(t.Context(), tc.checks, sr) diff --git a/internal/relationships/validation_test.go b/internal/relationships/validation_test.go index b8521ffb3..e41f898e0 100644 --- a/internal/relationships/validation_test.go +++ b/internal/relationships/validation_test.go @@ -336,7 +336,7 @@ func TestValidateRelationshipOperations(t *testing.T) { uds, rev := testfixtures.DatastoreFromSchemaAndTestRelationships(ds, tc.schema, nil, req) dl := datalayer.NewDataLayer(uds) - sr, err := dl.SnapshotReader(rev).ReadSchema(t.Context()) + sr, err := dl.SnapshotReader(rev, datalayer.NoSchemaHashForTesting).ReadSchema(t.Context()) req.NoError(err) op := tuple.Create diff --git a/internal/services/integrationtesting/consistency_test.go b/internal/services/integrationtesting/consistency_test.go index a28c765c4..ed52d46c0 100644 --- a/internal/services/integrationtesting/consistency_test.go +++ b/internal/services/integrationtesting/consistency_test.go @@ -21,6 +21,7 @@ import ( "github.com/authzed/spicedb/internal/dispatch" "github.com/authzed/spicedb/internal/services/integrationtesting/consistencytestutil" "github.com/authzed/spicedb/pkg/cmd/server" + 
"github.com/authzed/spicedb/pkg/datalayer" "github.com/authzed/spicedb/pkg/datastore" "github.com/authzed/spicedb/pkg/development" "github.com/authzed/spicedb/pkg/genutil/mapz" @@ -456,6 +457,7 @@ func validateExpansionSubjects(t *testing.T, vctx validationContext) { AtRevision: vctx.revision.String(), DepthRemaining: 100, TraversalBloom: dispatchv1.MustNewTraversalBloomFilter(100), + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, ExpansionMode: dispatchv1.DispatchExpandRequest_RECURSIVE, }) diff --git a/internal/services/integrationtesting/consistencytestutil/accessibilityset.go b/internal/services/integrationtesting/consistencytestutil/accessibilityset.go index 1316a1f8b..a0231c84d 100644 --- a/internal/services/integrationtesting/consistencytestutil/accessibilityset.go +++ b/internal/services/integrationtesting/consistencytestutil/accessibilityset.go @@ -13,6 +13,7 @@ import ( "github.com/authzed/spicedb/internal/dispatch/graph" "github.com/authzed/spicedb/internal/graph/computed" caveattypes "github.com/authzed/spicedb/pkg/caveats/types" + "github.com/authzed/spicedb/pkg/datalayer" "github.com/authzed/spicedb/pkg/datastore" "github.com/authzed/spicedb/pkg/genutil/mapz" dispatchv1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" @@ -141,6 +142,7 @@ func BuildAccessibilitySet(t *testing.T, ctx context.Context, populated *validat AtRevision: headRevision.String(), DepthRemaining: 50, TraversalBloom: dispatchv1.MustNewTraversalBloomFilter(50), + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), }, }) require.NoError(t, err) @@ -173,6 +175,7 @@ func BuildAccessibilitySet(t *testing.T, ctx context.Context, populated *validat CaveatContext: nil, AtRevision: headRevision, MaximumDepth: 50, + SchemaHash: datalayer.NoSchemaHashForTesting, }, possibleResourceID, 100, @@ -413,6 +416,7 @@ func isAccessibleViaWildcardOnly( AtRevision: revision.String(), DepthRemaining: 100, TraversalBloom: dispatchv1.MustNewTraversalBloomFilter(100), + SchemaHash: 
[]byte(datalayer.NoSchemaHashForTesting), }, ExpansionMode: dispatchv1.DispatchExpandRequest_RECURSIVE, }) diff --git a/internal/services/integrationtesting/query_plan_consistency_test.go b/internal/services/integrationtesting/query_plan_consistency_test.go index e0f04a8a6..0f1e4540e 100644 --- a/internal/services/integrationtesting/query_plan_consistency_test.go +++ b/internal/services/integrationtesting/query_plan_consistency_test.go @@ -48,7 +48,7 @@ type queryPlanConsistencyHandle struct { func (q *queryPlanConsistencyHandle) buildContext(t *testing.T) *query.Context { return query.NewLocalContext(t.Context(), - query.WithRevisionedReader(datalayer.NewDataLayer(q.ds).SnapshotReader(q.revision)), + query.WithRevisionedReader(datalayer.NewDataLayer(q.ds).SnapshotReader(q.revision, datalayer.NoSchemaHashForTesting)), query.WithCaveatRunner(caveats.NewCaveatRunner(caveattypes.Default.TypeSet)), query.WithTraceLogger(query.NewTraceLogger())) // Enable tracing for debugging } diff --git a/internal/services/v1/bulkcheck.go b/internal/services/v1/bulkcheck.go index d19eec4ea..63106dcf5 100644 --- a/internal/services/v1/bulkcheck.go +++ b/internal/services/v1/bulkcheck.go @@ -48,7 +48,7 @@ const maxBulkCheckCount = 10000 func (bc *bulkChecker) checkBulkPermissions(ctx context.Context, req *v1.CheckBulkPermissionsRequest) (*v1.CheckBulkPermissionsResponse, error) { telemetry.LogicalChecks.Add(float64(len(req.Items))) - atRevision, checkedAt, err := consistency.RevisionFromContext(ctx) + atRevision, schemaHash, checkedAt, err := consistency.RevisionFromContext(ctx) if err != nil { return nil, err } @@ -77,6 +77,7 @@ func (bc *bulkChecker) checkBulkPermissions(ctx context.Context, req *v1.CheckBu // the dispatching system already internally supports this kind of batching for performance. 
groupedItems, err := groupItems(ctx, groupingParameters{ atRevision: atRevision, + schemaHash: schemaHash, maxCaveatContextSize: bc.maxCaveatContextSize, maximumAPIDepth: bc.maxAPIDepth, withTracing: req.WithTracing, @@ -163,7 +164,7 @@ func (bc *bulkChecker) checkBulkPermissions(ctx context.Context, req *v1.CheckBu bulkResponseMutex.Lock() defer bulkResponseMutex.Unlock() - dl := datalayer.MustFromContext(ctx).SnapshotReader(atRevision) + dl := datalayer.MustFromContext(ctx).SnapshotReader(atRevision, schemaHash) sr, err := dl.ReadSchema(ctx) if err != nil { @@ -251,7 +252,7 @@ func (bc *bulkChecker) checkBulkPermissions(ctx context.Context, req *v1.CheckBu tr.Add(func(ctx context.Context) error { startTime := time.Now() - dl := datalayer.MustFromContext(ctx).SnapshotReader(atRevision) + dl := datalayer.MustFromContext(ctx).SnapshotReader(atRevision, schemaHash) sr, err := dl.ReadSchema(ctx) if err != nil { diff --git a/internal/services/v1/experimental.go b/internal/services/v1/experimental.go index fdef5988f..83d41decf 100644 --- a/internal/services/v1/experimental.go +++ b/internal/services/v1/experimental.go @@ -329,34 +329,38 @@ func (es *experimentalServer) BulkExportRelationships( ctx := resp.Context() perfinsights.SetInContext(ctx, perfinsights.NoLabels) - atRevision, _, err := consistency.RevisionFromContext(ctx) + atRevision, schemaHash, _, err := consistency.RevisionFromContext(ctx) if err != nil { return shared.RewriteErrorWithoutConfig(ctx, err) } - return BulkExport(ctx, datalayer.MustFromContext(ctx), es.maxBatchSize, req, atRevision, resp.Send) + return BulkExport(ctx, datalayer.MustFromContext(ctx), es.maxBatchSize, req, atRevision, schemaHash, resp.Send) } // BulkExport implements the BulkExportRelationships API functionality. Given a datalayer.DataLayer, it will // export stream via the sender all relationships matched by the incoming request. // If no cursor is provided, it will fallback to the provided revision. 
-func BulkExport(ctx context.Context, dl datalayer.DataLayer, batchSize uint64, req *v1.BulkExportRelationshipsRequest, fallbackRevision datastore.Revision, sender func(response *v1.BulkExportRelationshipsResponse) error) error { +func BulkExport(ctx context.Context, dl datalayer.DataLayer, batchSize uint64, req *v1.BulkExportRelationshipsRequest, fallbackRevision datastore.Revision, fallbackSchemaHash datalayer.SchemaHash, sender func(response *v1.BulkExportRelationshipsResponse) error) error { if req.OptionalLimit > 0 && uint64(req.OptionalLimit) > batchSize { return shared.RewriteErrorWithoutConfig(ctx, NewExceedsMaximumLimitErr(uint64(req.OptionalLimit), batchSize)) } atRevision := fallbackRevision + schemaHash := fallbackSchemaHash var curNamespace string var cur dsoptions.Cursor if req.OptionalCursor != nil { - var err error - atRevision, curNamespace, cur, err = decodeCursor(dl, req.OptionalCursor) + dc, err := decodeBulkExportCursor(dl, req.OptionalCursor) if err != nil { return shared.RewriteErrorWithoutConfig(ctx, err) } + atRevision = dc.revision + curNamespace = dc.namespace + cur = dc.cursor + schemaHash = dc.schemaHash } - reader := dl.SnapshotReader(atRevision) + reader := dl.SnapshotReader(atRevision, schemaHash) sr, err := reader.ReadSchema(ctx) if err != nil { @@ -576,7 +580,7 @@ func (es *experimentalServer) ExperimentalReflectSchema(ctx context.Context, req func (es *experimentalServer) ExperimentalDiffSchema(ctx context.Context, req *v1.ExperimentalDiffSchemaRequest) (*v1.ExperimentalDiffSchemaResponse, error) { perfinsights.SetInContext(ctx, perfinsights.NoLabels) - atRevision, _, err := consistency.RevisionFromContext(ctx) + atRevision, _, _, err := consistency.RevisionFromContext(ctx) if err != nil { return nil, err } @@ -603,12 +607,12 @@ func (es *experimentalServer) ExperimentalComputablePermissions(ctx context.Cont } }) - atRevision, revisionReadAt, err := consistency.RevisionFromContext(ctx) + atRevision, schemaHash, revisionReadAt, err 
:= consistency.RevisionFromContext(ctx) if err != nil { return nil, shared.RewriteErrorWithoutConfig(ctx, err) } - dl := datalayer.MustFromContext(ctx).SnapshotReader(atRevision) + dl := datalayer.MustFromContext(ctx).SnapshotReader(atRevision, schemaHash) sr, err := dl.ReadSchema(ctx) if err != nil { return nil, shared.RewriteErrorWithoutConfig(ctx, err) @@ -690,12 +694,12 @@ func (es *experimentalServer) ExperimentalDependentRelations(ctx context.Context } }) - atRevision, revisionReadAt, err := consistency.RevisionFromContext(ctx) + atRevision, schemaHash, revisionReadAt, err := consistency.RevisionFromContext(ctx) if err != nil { return nil, shared.RewriteErrorWithoutConfig(ctx, err) } - dl := datalayer.MustFromContext(ctx).SnapshotReader(atRevision) + dl := datalayer.MustFromContext(ctx).SnapshotReader(atRevision, schemaHash) sr, err := dl.ReadSchema(ctx) if err != nil { return nil, shared.RewriteErrorWithoutConfig(ctx, err) @@ -824,12 +828,12 @@ func (es *experimentalServer) ExperimentalCountRelationships(ctx context.Context } dl := datalayer.MustFromContext(ctx) - headRev, err := dl.HeadRevision(ctx) + headRev, headSchemaHash, err := dl.HeadRevision(ctx) if err != nil { return nil, shared.RewriteErrorWithoutConfig(ctx, err) } - snapshotReader := dl.SnapshotReader(headRev) + snapshotReader := dl.SnapshotReader(headRev, headSchemaHash) count, err := snapshotReader.CountRelationships(ctx, req.Name) if err != nil { return nil, shared.RewriteErrorWithoutConfig(ctx, err) @@ -879,30 +883,47 @@ func queryForEach( return cursor, nil } -func decodeCursor(dl datalayer.DataLayer, encoded *v1.Cursor) (datastore.Revision, string, dsoptions.Cursor, error) { +// decodedCursor holds the decoded components of a bulk export cursor. 
+type decodedCursor struct { + revision datastore.Revision + namespace string + cursor dsoptions.Cursor + schemaHash datalayer.SchemaHash +} + +func decodeBulkExportCursor(dl datalayer.DataLayer, encoded *v1.Cursor) (decodedCursor, error) { decoded, err := cursor.Decode(encoded) if err != nil { - return datastore.NoRevision, "", nil, err + return decodedCursor{}, err } if decoded.GetV1() == nil { - return datastore.NoRevision, "", nil, errors.New("malformed cursor: no V1 in OneOf") + return decodedCursor{}, errors.New("malformed cursor: no V1 in OneOf") } if len(decoded.GetV1().Sections) != 2 { - return datastore.NoRevision, "", nil, errors.New("malformed cursor: wrong number of components") + return decodedCursor{}, errors.New("malformed cursor: wrong number of components") } atRevision, err := dl.RevisionFromString(decoded.GetV1().Revision) if err != nil { - return datastore.NoRevision, "", nil, err + return decodedCursor{}, err } cur, err := tuple.Parse(decoded.GetV1().GetSections()[1]) if err != nil { - return datastore.NoRevision, "", nil, fmt.Errorf("malformed cursor: invalid encoded relation tuple: %w", err) + return decodedCursor{}, fmt.Errorf("malformed cursor: invalid encoded relation tuple: %w", err) } - // Returns the current namespace and the cursor. 
- return atRevision, decoded.GetV1().GetSections()[0], dsoptions.ToCursor(cur), nil + schemaHash := datalayer.NoSchemaHashForLegacyCursor + if len(decoded.GetV1().SchemaHash) > 0 { + schemaHash = datalayer.SchemaHash(decoded.GetV1().SchemaHash) + } + + return decodedCursor{ + revision: atRevision, + namespace: decoded.GetV1().GetSections()[0], + cursor: dsoptions.ToCursor(cur), + schemaHash: schemaHash, + }, nil } diff --git a/internal/services/v1/grouping.go b/internal/services/v1/grouping.go index 99b681d2e..2101a0479 100644 --- a/internal/services/v1/grouping.go +++ b/internal/services/v1/grouping.go @@ -6,6 +6,7 @@ import ( v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" "github.com/authzed/spicedb/internal/graph/computed" + "github.com/authzed/spicedb/pkg/datalayer" "github.com/authzed/spicedb/pkg/datastore" "github.com/authzed/spicedb/pkg/tuple" ) @@ -17,6 +18,7 @@ type groupedCheckParameters struct { type groupingParameters struct { atRevision datastore.Revision + schemaHash datalayer.SchemaHash maximumAPIDepth uint32 maxCaveatContextSize int withTracing bool @@ -68,5 +70,6 @@ func checkParametersFromCheckBulkPermissionsRequestItem( AtRevision: params.atRevision, MaximumDepth: params.maximumAPIDepth, DebugOption: debugOption, + SchemaHash: params.schemaHash, } } diff --git a/internal/services/v1/permissions.go b/internal/services/v1/permissions.go index 852841d90..5a6ab2aeb 100644 --- a/internal/services/v1/permissions.go +++ b/internal/services/v1/permissions.go @@ -74,12 +74,12 @@ func (ps *permissionServer) CheckPermission(ctx context.Context, req *v1.CheckPe return ps.checkPermissionWithQueryPlan(ctx, req) } - atRevision, checkedAt, err := consistency.RevisionFromContext(ctx) + atRevision, schemaHash, checkedAt, err := consistency.RevisionFromContext(ctx) if err != nil { return nil, ps.rewriteError(ctx, err) } - dl := datalayer.MustFromContext(ctx).SnapshotReader(atRevision) + dl := datalayer.MustFromContext(ctx).SnapshotReader(atRevision, 
schemaHash) caveatContext, err := GetCaveatContext(ctx, req.Context, ps.config.MaxCaveatContextSize) if err != nil { @@ -129,6 +129,7 @@ func (ps *permissionServer) CheckPermission(ctx context.Context, req *v1.CheckPe AtRevision: atRevision, MaximumDepth: ps.config.MaximumAPIDepth, DebugOption: debugOption, + SchemaHash: schemaHash, }, req.Resource.ObjectId, ps.config.DispatchChunkSize, @@ -246,12 +247,12 @@ func (ps *permissionServer) ExpandPermissionTree(ctx context.Context, req *v1.Ex telemetry.LogicalChecks.Inc() - atRevision, expandedAt, err := consistency.RevisionFromContext(ctx) + atRevision, schemaHash, expandedAt, err := consistency.RevisionFromContext(ctx) if err != nil { return nil, ps.rewriteError(ctx, err) } - dl := datalayer.MustFromContext(ctx).SnapshotReader(atRevision) + dl := datalayer.MustFromContext(ctx).SnapshotReader(atRevision, schemaHash) sr, err := dl.ReadSchema(ctx) if err != nil { @@ -273,6 +274,7 @@ func (ps *permissionServer) ExpandPermissionTree(ctx context.Context, req *v1.Ex AtRevision: atRevision.String(), DepthRemaining: ps.config.MaximumAPIDepth, TraversalBloom: bf, + SchemaHash: []byte(schemaHash), }, ResourceAndRelation: &core.ObjectAndRelation{ Namespace: req.Resource.ObjectType, @@ -489,12 +491,12 @@ func (ps *permissionServer) lookupResources3(req *v1.LookupResourcesRequest, res ctx := resp.Context() - atRevision, revisionReadAt, err := consistency.RevisionFromContext(ctx) + atRevision, schemaHash, revisionReadAt, err := consistency.RevisionFromContext(ctx) if err != nil { return ps.rewriteError(ctx, err) } - dl := datalayer.MustFromContext(ctx).SnapshotReader(atRevision) + dl := datalayer.MustFromContext(ctx).SnapshotReader(atRevision, schemaHash) sr, err := dl.ReadSchema(ctx) if err != nil { @@ -569,7 +571,7 @@ func (ps *permissionServer) lookupResources3(req *v1.LookupResourcesRequest, res if len(item.AfterResponseCursorSections) > 0 { currentCursor = item.AfterResponseCursorSections - ec, err := 
cursor.EncodeFromDispatchCursorSections(currentCursor, lrRequestHash, atRevision, map[string]string{ + ec, err := cursor.EncodeFromDispatchCursorSections(currentCursor, lrRequestHash, atRevision, schemaHash, map[string]string{ lrv3CursorFlag: "1", }) if err != nil { @@ -606,6 +608,7 @@ func (ps *permissionServer) lookupResources3(req *v1.LookupResourcesRequest, res AtRevision: atRevision.String(), DepthRemaining: ps.config.MaximumAPIDepth, TraversalBloom: bf, + SchemaHash: []byte(schemaHash), }, ResourceRelation: &core.RelationReference{ Namespace: req.ResourceObjectType, @@ -640,12 +643,12 @@ func (ps *permissionServer) lookupResources2(req *v1.LookupResourcesRequest, res ctx := resp.Context() - atRevision, revisionReadAt, err := consistency.RevisionFromContext(ctx) + atRevision, schemaHash, revisionReadAt, err := consistency.RevisionFromContext(ctx) if err != nil { return ps.rewriteError(ctx, err) } - dl := datalayer.MustFromContext(ctx).SnapshotReader(atRevision) + dl := datalayer.MustFromContext(ctx).SnapshotReader(atRevision, schemaHash) sr, err := dl.ReadSchema(ctx) if err != nil { @@ -720,7 +723,7 @@ func (ps *permissionServer) lookupResources2(req *v1.LookupResourcesRequest, res alreadyPublishedPermissionedResourceIds[found.ResourceId] = struct{}{} } - encodedCursor, err := cursor.EncodeFromDispatchCursor(result.AfterResponseCursor, lrRequestHash, atRevision, map[string]string{ + encodedCursor, err := cursor.EncodeFromDispatchCursor(result.AfterResponseCursor, lrRequestHash, atRevision, schemaHash, map[string]string{ lrv2CursorFlag: "1", }) if err != nil { @@ -753,6 +756,7 @@ func (ps *permissionServer) lookupResources2(req *v1.LookupResourcesRequest, res AtRevision: atRevision.String(), DepthRemaining: ps.config.MaximumAPIDepth, TraversalBloom: bf, + SchemaHash: []byte(schemaHash), }, ResourceRelation: &core.RelationReference{ Namespace: req.ResourceObjectType, @@ -796,12 +800,12 @@ func (ps *permissionServer) LookupSubjects(req *v1.LookupSubjectsRequest, 
resp v return ps.rewriteError(ctx, status.Errorf(codes.Unimplemented, "concrete limit is not yet supported")) } - atRevision, revisionReadAt, err := consistency.RevisionFromContext(ctx) + atRevision, schemaHash, revisionReadAt, err := consistency.RevisionFromContext(ctx) if err != nil { return ps.rewriteError(ctx, err) } - dl := datalayer.MustFromContext(ctx).SnapshotReader(atRevision) + dl := datalayer.MustFromContext(ctx).SnapshotReader(atRevision, schemaHash) caveatContext, err := GetCaveatContext(ctx, req.Context, ps.config.MaxCaveatContextSize) if err != nil { @@ -906,6 +910,7 @@ func (ps *permissionServer) LookupSubjects(req *v1.LookupSubjectsRequest, resp v AtRevision: atRevision.String(), DepthRemaining: ps.config.MaximumAPIDepth, TraversalBloom: bf, + SchemaHash: []byte(schemaHash), }, ResourceRelation: &core.RelationReference{ Namespace: req.Resource.ObjectType, @@ -1155,34 +1160,38 @@ func (ps *permissionServer) ExportBulkRelationships( return labelsForFilter(req.OptionalRelationshipFilter) }) - atRevision, _, err := consistency.RevisionFromContext(ctx) + atRevision, schemaHash, _, err := consistency.RevisionFromContext(ctx) if err != nil { return shared.RewriteErrorWithoutConfig(ctx, err) } - return ExportBulk(ctx, datalayer.MustFromContext(ctx), uint64(ps.config.MaxBulkExportRelationshipsLimit), req, atRevision, resp.Send) + return ExportBulk(ctx, datalayer.MustFromContext(ctx), uint64(ps.config.MaxBulkExportRelationshipsLimit), req, atRevision, schemaHash, resp.Send) } // ExportBulk implements the ExportBulkRelationships API functionality. Given a datalayer.DataLayer, it will // export stream via the sender all relationships matched by the incoming request. // If no cursor is provided, it will fallback to the provided revision. 
-func ExportBulk(ctx context.Context, dl datalayer.DataLayer, batchSize uint64, req *v1.ExportBulkRelationshipsRequest, fallbackRevision datastore.Revision, sender func(response *v1.ExportBulkRelationshipsResponse) error) error { +func ExportBulk(ctx context.Context, dl datalayer.DataLayer, batchSize uint64, req *v1.ExportBulkRelationshipsRequest, fallbackRevision datastore.Revision, fallbackSchemaHash datalayer.SchemaHash, sender func(response *v1.ExportBulkRelationshipsResponse) error) error { if req.OptionalLimit > 0 && uint64(req.OptionalLimit) > batchSize { return shared.RewriteErrorWithoutConfig(ctx, NewExceedsMaximumLimitErr(uint64(req.OptionalLimit), batchSize)) } atRevision := fallbackRevision + schemaHash := fallbackSchemaHash var curNamespace string var cur dsoptions.Cursor if req.OptionalCursor != nil { - var err error - atRevision, curNamespace, cur, err = decodeCursor(dl, req.OptionalCursor) + dc, err := decodeBulkExportCursor(dl, req.OptionalCursor) if err != nil { return shared.RewriteErrorWithoutConfig(ctx, err) } + atRevision = dc.revision + curNamespace = dc.namespace + cur = dc.cursor + schemaHash = dc.schemaHash } - reader := dl.SnapshotReader(atRevision) + reader := dl.SnapshotReader(atRevision, schemaHash) readerSchema, err := reader.ReadSchema(ctx) if err != nil { diff --git a/internal/services/v1/permissions_queryplan.go b/internal/services/v1/permissions_queryplan.go index 077dd67fe..cb9ad8d5e 100644 --- a/internal/services/v1/permissions_queryplan.go +++ b/internal/services/v1/permissions_queryplan.go @@ -16,13 +16,13 @@ import ( // checkPermissionWithQueryPlan executes a permission check using the query plan API. // This builds an iterator tree from the schema and executes it against the datastore. 
func (ps *permissionServer) checkPermissionWithQueryPlan(ctx context.Context, req *v1.CheckPermissionRequest) (*v1.CheckPermissionResponse, error) { - atRevision, checkedAt, err := consistency.RevisionFromContext(ctx) + atRevision, schemaHash, checkedAt, err := consistency.RevisionFromContext(ctx) if err != nil { return nil, ps.rewriteError(ctx, err) } dl := datalayer.MustFromContext(ctx) - reader := dl.SnapshotReader(atRevision) + reader := dl.SnapshotReader(atRevision, schemaHash) // Load all namespace and caveat definitions to build the schema // TODO: Better schema caching diff --git a/internal/services/v1/reflectionutil.go b/internal/services/v1/reflectionutil.go index fda797a4f..5336fae74 100644 --- a/internal/services/v1/reflectionutil.go +++ b/internal/services/v1/reflectionutil.go @@ -16,12 +16,12 @@ import ( func loadCurrentSchema(ctx context.Context) (*diff.DiffableSchema, datastore.Revision, error) { dl := datalayer.MustFromContext(ctx) - atRevision, _, err := consistency.RevisionFromContext(ctx) + atRevision, schemaHash, _, err := consistency.RevisionFromContext(ctx) if err != nil { return nil, nil, err } - reader := dl.SnapshotReader(atRevision) + reader := dl.SnapshotReader(atRevision, schemaHash) sr, err := reader.ReadSchema(ctx) if err != nil { diff --git a/internal/services/v1/relationships.go b/internal/services/v1/relationships.go index 4e58def31..6ebfca703 100644 --- a/internal/services/v1/relationships.go +++ b/internal/services/v1/relationships.go @@ -203,12 +203,12 @@ func (ps *permissionServer) ReadRelationships(req *v1.ReadRelationshipsRequest, } ctx := resp.Context() - atRevision, revisionReadAt, err := consistency.RevisionFromContext(ctx) + atRevision, schemaHash, revisionReadAt, err := consistency.RevisionFromContext(ctx) if err != nil { return ps.rewriteError(ctx, err) } - dl := datalayer.MustFromContext(ctx).SnapshotReader(atRevision) + dl := datalayer.MustFromContext(ctx).SnapshotReader(atRevision, schemaHash) sr, err := 
dl.ReadSchema(ctx) if err != nil { return ps.rewriteError(ctx, err) @@ -310,7 +310,7 @@ func (ps *permissionServer) ReadRelationships(req *v1.ReadRelationshipsRequest, } dispatchCursor.Sections[0] = tuple.StringWithoutCaveatOrExpiration(rel) - encodedCursor, err := cursor.EncodeFromDispatchCursor(dispatchCursor, rrRequestHash, atRevision, nil) + encodedCursor, err := cursor.EncodeFromDispatchCursor(dispatchCursor, rrRequestHash, atRevision, schemaHash, nil) if err != nil { return ps.rewriteError(ctx, err) } diff --git a/internal/services/v1/schema.go b/internal/services/v1/schema.go index 97da83600..9b1aadbc2 100644 --- a/internal/services/v1/schema.go +++ b/internal/services/v1/schema.go @@ -92,12 +92,12 @@ func (ss *schemaServer) ReadSchema(ctx context.Context, _ *v1.ReadSchemaRequest) // Schema is always read from the head revision. dl := datalayer.MustFromContext(ctx) - headRevision, err := dl.HeadRevision(ctx) + headRevision, headSchemaHash, err := dl.HeadRevision(ctx) if err != nil { return nil, ss.rewriteError(ctx, err) } - reader := dl.SnapshotReader(headRevision) + reader := dl.SnapshotReader(headRevision, headSchemaHash) sr, err := reader.ReadSchema(ctx) if err != nil { @@ -245,7 +245,7 @@ func (ss *schemaServer) ReflectSchema(ctx context.Context, req *v1.ReflectSchema func (ss *schemaServer) DiffSchema(ctx context.Context, req *v1.DiffSchemaRequest) (*v1.DiffSchemaResponse, error) { perfinsights.SetInContext(ctx, perfinsights.NoLabels) - atRevision, _, err := consistency.RevisionFromContext(ctx) + atRevision, _, _, err := consistency.RevisionFromContext(ctx) if err != nil { return nil, err } @@ -272,12 +272,12 @@ func (ss *schemaServer) ComputablePermissions(ctx context.Context, req *v1.Compu } }) - atRevision, revisionReadAt, err := consistency.RevisionFromContext(ctx) + atRevision, schemaHash, revisionReadAt, err := consistency.RevisionFromContext(ctx) if err != nil { return nil, shared.RewriteErrorWithoutConfig(ctx, err) } - dl := 
datalayer.MustFromContext(ctx).SnapshotReader(atRevision) + dl := datalayer.MustFromContext(ctx).SnapshotReader(atRevision, schemaHash) sr, err := dl.ReadSchema(ctx) if err != nil { return nil, shared.RewriteErrorWithoutConfig(ctx, err) @@ -359,12 +359,12 @@ func (ss *schemaServer) DependentRelations(ctx context.Context, req *v1.Dependen } }) - atRevision, revisionReadAt, err := consistency.RevisionFromContext(ctx) + atRevision, schemaHash, revisionReadAt, err := consistency.RevisionFromContext(ctx) if err != nil { return nil, shared.RewriteErrorWithoutConfig(ctx, err) } - dl := datalayer.MustFromContext(ctx).SnapshotReader(atRevision) + dl := datalayer.MustFromContext(ctx).SnapshotReader(atRevision, schemaHash) sr2, err := dl.ReadSchema(ctx) if err != nil { return nil, shared.RewriteErrorWithoutConfig(ctx, err) diff --git a/internal/services/v1/watch.go b/internal/services/v1/watch.go index 3b2aa4316..4de4246ec 100644 --- a/internal/services/v1/watch.go +++ b/internal/services/v1/watch.go @@ -81,13 +81,13 @@ func (ws *watchServer) Watch(req *v1.WatchRequest, stream v1.WatchService_WatchS afterRevision = decodedRevision } else { var err error - afterRevision, err = dl.OptimizedRevision(ctx) + afterRevision, _, err = dl.OptimizedRevision(ctx) if err != nil { return status.Errorf(codes.Unavailable, "failed to start watch: %s", err) } } - reader := dl.SnapshotReader(afterRevision) + reader := dl.SnapshotReader(afterRevision, datalayer.NoSchemaHashForWatch) sr, err := reader.ReadSchema(ctx) if err != nil { return status.Errorf(codes.Internal, "failed to read schema: %s", err) diff --git a/internal/telemetry/otelconv/otelconv.go b/internal/telemetry/otelconv/otelconv.go index 6dc7d7f80..3a6e7ae1e 100644 --- a/internal/telemetry/otelconv/otelconv.go +++ b/internal/telemetry/otelconv/otelconv.go @@ -119,4 +119,11 @@ const ( AttrTestKey = "spicedb.internal.test.key" AttrTestNumber = "spicedb.internal.test.number" + + AttrSchemaReadFromCache = 
"spicedb.internal.schema.read_from_cache" + AttrSchemaChunkCount = "spicedb.internal.schema.chunk_count" + AttrSchemaDataSizeBytes = "spicedb.internal.schema.data_size_bytes" + AttrSchemaHash = "spicedb.internal.schema.hash" + AttrSchemaCacheBypassed = "spicedb.internal.schema.cache_bypassed" + AttrSchemaDefinitionName = "spicedb.internal.schema.definition_name" ) diff --git a/internal/testfixtures/validating.go b/internal/testfixtures/validating.go index 346f3bffb..23edf5582 100644 --- a/internal/testfixtures/validating.go +++ b/internal/testfixtures/validating.go @@ -187,6 +187,10 @@ func (vsr validatingSnapshotReader) LegacyListAllCaveats(ctx context.Context) ([ return read, err } +func (vsr validatingSnapshotReader) ReadStoredSchema(ctx context.Context) (*datastore.ReadOnlyStoredSchema, error) { + return vsr.delegate.ReadStoredSchema(ctx) +} + type validatingReadWriteTransaction struct { validatingSnapshotReader delegate datastore.ReadWriteTransaction @@ -257,6 +261,14 @@ func (vrwt validatingReadWriteTransaction) BulkLoad(ctx context.Context, source return vrwt.delegate.BulkLoad(ctx, source) } +func (vrwt validatingReadWriteTransaction) ReadStoredSchema(ctx context.Context) (*datastore.ReadOnlyStoredSchema, error) { + return vrwt.delegate.ReadStoredSchema(ctx) +} + +func (vrwt validatingReadWriteTransaction) WriteStoredSchema(ctx context.Context, schema *core.StoredSchema) error { + return vrwt.delegate.WriteStoredSchema(ctx, schema) +} + // validateUpdatesToWrite performs basic validation on relationship updates going into datastores. 
func validateUpdatesToWrite(updates ...tuple.RelationshipUpdate) error { for _, update := range updates { diff --git a/pkg/cache/cache_ristretto.go b/pkg/cache/cache_ristretto.go index 0eec09d77..dd9b97982 100644 --- a/pkg/cache/cache_ristretto.go +++ b/pkg/cache/cache_ristretto.go @@ -12,6 +12,10 @@ import ( "github.com/authzed/spicedb/internal/dispatch/keys" ) +type keyStringer interface { + KeyString() string +} + func ristrettoConfig(config *Config) *ristretto.Config { return &ristretto.Config{ NumCounters: config.NumCounters, @@ -20,9 +24,8 @@ func ristrettoConfig(config *Config) *ristretto.Config { KeyToHash: func(key any) (uint64, uint64) { dispatchCacheKey, ok := key.(keys.DispatchCacheKey) if !ok { - stringValue, ok := key.(StringKey) - if ok { - return z.KeyToHash(string(stringValue)) + if ks, ok := key.(keyStringer); ok { + return z.KeyToHash(ks.KeyString()) } return z.KeyToHash(key) diff --git a/pkg/cmd/serve.go b/pkg/cmd/serve.go index 406f88387..104ecfe92 100644 --- a/pkg/cmd/serve.go +++ b/pkg/cmd/serve.go @@ -58,6 +58,15 @@ var ( MaxCost: "50MiB", CacheKindForTesting: "", } + + storedSchemaCacheDefaults = &server.CacheConfig{ + Name: "stored_schema", + Enabled: true, + Metrics: true, + NumCounters: 1_000, + MaxCost: "32MiB", + CacheKindForTesting: "", + } ) func BoldBlue(name string) string { @@ -188,6 +197,7 @@ func RegisterServeFlags(cmd *cobra.Command, config *server.Config) error { return fmt.Errorf("failed to mark flag as deprecated: %w", err) } experimentalFlags.BoolVar(&config.EnableExperimentalWatchableSchemaCache, "enable-experimental-watchable-schema-cache", false, "enables the experimental schema cache, which uses the Watch API to keep the schema up to date") + experimentalFlags.StringVar(&config.ExperimentalSchemaMode, "experimental-schema-mode", "read-legacy-write-legacy", "schema storage mode for migration to unified schema: read-legacy-write-legacy, read-legacy-write-both, read-new-write-both, read-new-write-new") // TODO: these two 
could reasonably be put in either the Dispatch group or the Experimental group. Is there a preference? experimentalFlags.StringToStringVar(&config.DispatchSecondaryUpstreamAddrs, "experimental-dispatch-secondary-upstream-addrs", nil, "secondary upstream addresses for dispatches, each with a name") experimentalFlags.StringToStringVar(&config.DispatchSecondaryUpstreamExprs, "experimental-dispatch-secondary-upstream-exprs", nil, "map from request type to its associated CEL expression, which returns the secondary upstream(s) to be used for the request") @@ -204,6 +214,11 @@ func RegisterServeFlags(cmd *cobra.Command, config *server.Config) error { return fmt.Errorf("could not register lookup resources chunk cache flags: %w", err) } + err = server.RegisterCacheFlags(experimentalFlags, "stored-schema-cache", "stored schema", &config.StoredSchemaCacheConfig, storedSchemaCacheDefaults) + if err != nil { + return fmt.Errorf("could not register stored schema cache flags: %w", err) + } + tracingFlags := nfs.FlagSet(BoldBlue("Tracing")) // Flags for tracing // NOTE: cobraotel.New takes service name as an arg rather than command name. 
diff --git a/pkg/cmd/serve_test.go b/pkg/cmd/serve_test.go index ec48e4d59..117f6614a 100644 --- a/pkg/cmd/serve_test.go +++ b/pkg/cmd/serve_test.go @@ -28,6 +28,7 @@ func RunServeTest(t *testing.T, args []string, assertConfig func(t *testing.T, m config.DispatchCacheConfig.Metrics = false config.ClusterDispatchCacheConfig.Metrics = false config.NamespaceCacheConfig.Metrics = false + config.StoredSchemaCacheConfig.Metrics = false cmd.SetArgs(args) diff --git a/pkg/cmd/server/defaults.go b/pkg/cmd/server/defaults.go index c68f67b4f..febda4da7 100644 --- a/pkg/cmd/server/defaults.go +++ b/pkg/cmd/server/defaults.go @@ -176,11 +176,10 @@ const ( DefaultMiddlewareServerVersion = "serverversion" DefaultMiddlewareMemoryProtection = "memoryprotection" - DefaultInternalMiddlewareDispatch = "dispatch" - DefaultInternalMiddlewareDatastore = "datastore" - DefaultInternalMiddlewareDatastoreCounting = "datastore-counting" - DefaultInternalMiddlewareConsistency = "consistency" - DefaultInternalMiddlewareServerSpecific = "servicespecific" + DefaultInternalMiddlewareDispatch = "dispatch" + DefaultInternalMiddlewareDatastore = "datastore" + DefaultInternalMiddlewareConsistency = "consistency" + DefaultInternalMiddlewareServerSpecific = "servicespecific" ) //go:generate go run github.com/ecordell/optgen -output zz_generated.middlewareoption.go . MiddlewareOption @@ -225,8 +224,8 @@ func (m MiddlewareOption) WithDatastoreMiddleware(middleware Middleware) Middlew return m } -func (m MiddlewareOption) WithDatastore(ds datastore.Datastore) MiddlewareOption { - dl := datalayer.NewDataLayer(ds) +func (m MiddlewareOption) WithDatastore(ds datastore.Datastore, dlOpts ...datalayer.DataLayerOption) MiddlewareOption { + dl := datalayer.NewDataLayer(ds, dlOpts...) unary := NewUnaryMiddleware(). WithName(DefaultInternalMiddlewareDatastore). WithInternal(true). 
@@ -337,12 +336,6 @@ func DefaultUnaryMiddleware(opts MiddlewareOption) (*MiddlewareChain[grpc.UnaryS *opts.unaryDatastoreMiddleware, - NewUnaryMiddleware(). - WithName(DefaultInternalMiddlewareDatastoreCounting). - WithInternal(true). - WithInterceptor(datalayer.UnaryCountingInterceptor(nil)). - Done(), - NewUnaryMiddleware(). WithName(DefaultInternalMiddlewareConsistency). WithInterceptor(consistencymw.UnaryServerInterceptor(opts.MiddlewareServiceLabel, opts.MismatchingZedTokenOption)). @@ -416,12 +409,6 @@ func DefaultStreamingMiddleware(opts MiddlewareOption) (*MiddlewareChain[grpc.St *opts.streamDatastoreMiddleware, - NewStreamMiddleware(). - WithName(DefaultInternalMiddlewareDatastoreCounting). - WithInternal(true). - WithInterceptor(datalayer.StreamCountingInterceptor(nil)). - Done(), - NewStreamMiddleware(). WithName(DefaultInternalMiddlewareConsistency). WithInterceptor(consistencymw.StreamServerInterceptor(opts.MiddlewareServiceLabel, opts.MismatchingZedTokenOption)). @@ -450,10 +437,10 @@ func determineEventsToLog(opts MiddlewareOption) grpclog.Option { } // DefaultDispatchMiddleware generates the default middleware chain used for the internal dispatch SpiceDB gRPC API -func DefaultDispatchMiddleware(logger zerolog.Logger, authFunc grpcauth.AuthFunc, ds datastore.Datastore, disableGRPCLatencyHistogram bool, memoryUsageProvider memoryprotection.MemoryUsageProvider) ([]grpc.UnaryServerInterceptor, []grpc.StreamServerInterceptor) { +func DefaultDispatchMiddleware(logger zerolog.Logger, authFunc grpcauth.AuthFunc, ds datastore.Datastore, disableGRPCLatencyHistogram bool, memoryUsageProvider memoryprotection.MemoryUsageProvider, dlOpts ...datalayer.DataLayerOption) ([]grpc.UnaryServerInterceptor, []grpc.StreamServerInterceptor) { grpcMetricsUnaryInterceptor, grpcMetricsStreamingInterceptor := GRPCMetrics(disableGRPCLatencyHistogram) dispatchMemoryProtection := memoryprotection.New(memoryUsageProvider, "dispatch-middleware") - dl := datalayer.NewDataLayer(ds) 
+ dl := datalayer.NewDataLayer(ds, dlOpts...) return []grpc.UnaryServerInterceptor{ requestid.UnaryServerInterceptor(requestid.GenerateIfMissing(true)), diff --git a/pkg/cmd/server/server.go b/pkg/cmd/server/server.go index 1599035aa..c4852d9d3 100644 --- a/pkg/cmd/server/server.go +++ b/pkg/cmd/server/server.go @@ -44,6 +44,7 @@ import ( "github.com/authzed/spicedb/pkg/cache" datastorecfg "github.com/authzed/spicedb/pkg/cmd/datastore" "github.com/authzed/spicedb/pkg/cmd/util" + "github.com/authzed/spicedb/pkg/datalayer" "github.com/authzed/spicedb/pkg/datastore" "github.com/authzed/spicedb/pkg/middleware/consistency" "github.com/authzed/spicedb/pkg/middleware/requestid" @@ -86,8 +87,12 @@ type Config struct { SchemaWatchHeartbeat time.Duration `debugmap:"visible"` NamespaceCacheConfig CacheConfig `debugmap:"visible"` + // Stored schema hash cache + StoredSchemaCacheConfig CacheConfig `debugmap:"visible"` + // Schema options - SchemaPrefixesRequired bool `debugmap:"visible"` + SchemaPrefixesRequired bool `debugmap:"visible"` + ExperimentalSchemaMode string `debugmap:"visible"` // Dispatch options DispatchServer util.GRPCServerConfig `debugmap:"visible"` @@ -224,6 +229,13 @@ func (c *Config) Complete(ctx context.Context) (RunnableServer, error) { cachingMode = schemacaching.WatchIfSupported } + storedSchemaCache, err := CompleteCache[datalayer.SchemaCacheKey, *datastore.ReadOnlyStoredSchema](&c.StoredSchemaCacheConfig) + if err != nil { + return nil, fmt.Errorf("failed to create stored schema cache: %w", err) + } + log.Ctx(ctx).Info().EmbedObject(storedSchemaCache).Msg("configured stored schema cache") + closeables.AddWithoutError(storedSchemaCache.Close) + ds = proxy.NewObservableDatastoreProxy(ds) ds = proxy.NewSingleflightDatastoreProxy(ds) ds = schemacaching.NewCachingDatastoreProxy(ds, nscc, c.DatastoreConfig.GCWindow, cachingMode, c.SchemaWatchHeartbeat) @@ -371,6 +383,17 @@ func (c *Config) Complete(ctx context.Context) (RunnableServer, error) { 
memoryUsageProvider := c.BuildMemoryUsageProvider() + // Parse schema mode for datalayer construction + var dlOpts []datalayer.DataLayerOption + dlOpts = append(dlOpts, datalayer.WithSchemaCache(storedSchemaCache)) + if c.ExperimentalSchemaMode != "" { + schemaMode, smErr := datalayer.ParseSchemaMode(c.ExperimentalSchemaMode) + if smErr != nil { + return nil, smErr + } + dlOpts = append(dlOpts, datalayer.WithSchemaMode(schemaMode)) + } + opts := MiddlewareOption{ Logger: log.Logger, AuthFunc: c.GRPCAuthFunc, @@ -383,7 +406,7 @@ func (c *Config) Complete(ctx context.Context) (RunnableServer, error) { MismatchingZedTokenOption: mismatchZedTokenOption, MemoryUsageProvider: memoryUsageProvider, } - opts = opts.WithDatastore(ds) + opts = opts.WithDatastore(ds, dlOpts...) // Build OTel stats handler options (shared by both gRPC servers) // Always disable health check tracing to reduce trace volume @@ -391,7 +414,7 @@ func (c *Config) Complete(ctx context.Context) (RunnableServer, error) { otelgrpc.WithFilter(filters.Not(filters.HealthCheck())), } - dispatchGrpcServer, err := c.buildDispatchServer(memoryUsageProvider, ds, cachingClusterDispatch, statsHandlerOpts) + dispatchGrpcServer, err := c.buildDispatchServer(memoryUsageProvider, ds, cachingClusterDispatch, statsHandlerOpts, dlOpts) if err != nil { return nil, err } @@ -549,12 +572,12 @@ func (c *Config) BuildMemoryUsageProvider() memoryprotection.MemoryUsageProvider return &memoryprotection.HarcodedMemoryUsageProvider{AcceptAllRequests: true} } -func (c *Config) buildDispatchServer(memoryUsageProvider memoryprotection.MemoryUsageProvider, ds datastore.Datastore, cachingClusterDispatch dispatch.Dispatcher, otelOpts []otelgrpc.Option) (util.RunnableGRPCServer, error) { +func (c *Config) buildDispatchServer(memoryUsageProvider memoryprotection.MemoryUsageProvider, ds datastore.Datastore, cachingClusterDispatch dispatch.Dispatcher, otelOpts []otelgrpc.Option, dlOpts []datalayer.DataLayerOption) (util.RunnableGRPCServer, 
error) { if len(c.DispatchUnaryMiddleware) == 0 && len(c.DispatchStreamingMiddleware) == 0 { if c.GRPCAuthFunc == nil { - c.DispatchUnaryMiddleware, c.DispatchStreamingMiddleware = DefaultDispatchMiddleware(log.Logger, auth.MustRequirePresharedKey(c.PresharedSecureKey), ds, c.DisableGRPCLatencyHistogram, memoryUsageProvider) + c.DispatchUnaryMiddleware, c.DispatchStreamingMiddleware = DefaultDispatchMiddleware(log.Logger, auth.MustRequirePresharedKey(c.PresharedSecureKey), ds, c.DisableGRPCLatencyHistogram, memoryUsageProvider, dlOpts...) } else { - c.DispatchUnaryMiddleware, c.DispatchStreamingMiddleware = DefaultDispatchMiddleware(log.Logger, c.GRPCAuthFunc, ds, c.DisableGRPCLatencyHistogram, memoryUsageProvider) + c.DispatchUnaryMiddleware, c.DispatchStreamingMiddleware = DefaultDispatchMiddleware(log.Logger, c.GRPCAuthFunc, ds, c.DisableGRPCLatencyHistogram, memoryUsageProvider, dlOpts...) } } diff --git a/pkg/cmd/server/server_test.go b/pkg/cmd/server/server_test.go index 5dd76f0f2..e1e9b6019 100644 --- a/pkg/cmd/server/server_test.go +++ b/pkg/cmd/server/server_test.go @@ -629,7 +629,7 @@ func TestBuildDispatchServer(t *testing.T) { sampler := memoryprotection.NewNoopMemoryUsageProvider() - srv, err := tc.config.buildDispatchServer(sampler, mockDatastore, mockDispatcher, nil) + srv, err := tc.config.buildDispatchServer(sampler, mockDatastore, mockDispatcher, nil, nil) require.NoError(t, err) require.NotNil(t, srv) require.Len(t, tc.config.DispatchUnaryMiddleware, tc.expectedDispatchUnnaryMiddleware) diff --git a/pkg/cmd/server/zz_generated.options.go b/pkg/cmd/server/zz_generated.options.go index f9fc5a91c..1ffc27c4e 100644 --- a/pkg/cmd/server/zz_generated.options.go +++ b/pkg/cmd/server/zz_generated.options.go @@ -56,7 +56,9 @@ func (c *Config) ToOption() ConfigOption { to.EnableExperimentalWatchableSchemaCache = c.EnableExperimentalWatchableSchemaCache to.SchemaWatchHeartbeat = c.SchemaWatchHeartbeat to.NamespaceCacheConfig = c.NamespaceCacheConfig + 
to.StoredSchemaCacheConfig = c.StoredSchemaCacheConfig to.SchemaPrefixesRequired = c.SchemaPrefixesRequired + to.ExperimentalSchemaMode = c.ExperimentalSchemaMode to.DispatchServer = c.DispatchServer to.DispatchMaxDepth = c.DispatchMaxDepth to.GlobalDispatchConcurrencyLimit = c.GlobalDispatchConcurrencyLimit @@ -158,7 +160,13 @@ func (c *Config) DebugMap() map[string]any { debugMap["EnableExperimentalWatchableSchemaCache"] = c.EnableExperimentalWatchableSchemaCache debugMap["SchemaWatchHeartbeat"] = c.SchemaWatchHeartbeat debugMap["NamespaceCacheConfig"] = c.NamespaceCacheConfig + debugMap["StoredSchemaCacheConfig"] = c.StoredSchemaCacheConfig debugMap["SchemaPrefixesRequired"] = c.SchemaPrefixesRequired + if c.ExperimentalSchemaMode == "" { + debugMap["ExperimentalSchemaMode"] = "(empty)" + } else { + debugMap["ExperimentalSchemaMode"] = c.ExperimentalSchemaMode + } debugMap["DispatchServer"] = c.DispatchServer debugMap["DispatchMaxDepth"] = c.DispatchMaxDepth debugMap["GlobalDispatchConcurrencyLimit"] = c.GlobalDispatchConcurrencyLimit @@ -435,6 +443,13 @@ func WithNamespaceCacheConfig(namespaceCacheConfig CacheConfig) ConfigOption { } } +// WithStoredSchemaCacheConfig returns an option that can set StoredSchemaCacheConfig on a Config +func WithStoredSchemaCacheConfig(storedSchemaCacheConfig CacheConfig) ConfigOption { + return func(c *Config) { + c.StoredSchemaCacheConfig = storedSchemaCacheConfig + } +} + // WithSchemaPrefixesRequired returns an option that can set SchemaPrefixesRequired on a Config func WithSchemaPrefixesRequired(schemaPrefixesRequired bool) ConfigOption { return func(c *Config) { @@ -442,6 +457,13 @@ func WithSchemaPrefixesRequired(schemaPrefixesRequired bool) ConfigOption { } } +// WithExperimentalSchemaMode returns an option that can set ExperimentalSchemaMode on a Config +func WithExperimentalSchemaMode(experimentalSchemaMode string) ConfigOption { + return func(c *Config) { + c.ExperimentalSchemaMode = experimentalSchemaMode + } +} + // 
WithDispatchServer returns an option that can set DispatchServer on a Config func WithDispatchServer(dispatchServer util.GRPCServerConfig) ConfigOption { return func(c *Config) { diff --git a/pkg/cursor/cursor.go b/pkg/cursor/cursor.go index 9ab4f7086..2d9025930 100644 --- a/pkg/cursor/cursor.go +++ b/pkg/cursor/cursor.go @@ -8,6 +8,7 @@ import ( v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + "github.com/authzed/spicedb/pkg/datalayer" "github.com/authzed/spicedb/pkg/datastore" dispatch "github.com/authzed/spicedb/pkg/proto/dispatch/v1" impl "github.com/authzed/spicedb/pkg/proto/impl/v1" @@ -50,7 +51,7 @@ func Decode(encoded *v1.Cursor) (*impl.DecodedCursor, error) { // consumption, including the provided call context to ensure the API cursor reflects the calling // API method. The call hash should contain all the parameters of the calling API function, // as well as its revision and name. -func EncodeFromDispatchCursor(dispatchCursor *dispatch.Cursor, callAndParameterHash string, revision datastore.Revision, flags map[string]string) (*v1.Cursor, error) { +func EncodeFromDispatchCursor(dispatchCursor *dispatch.Cursor, callAndParameterHash string, revision datastore.Revision, schemaHash datalayer.SchemaHash, flags map[string]string) (*v1.Cursor, error) { if dispatchCursor == nil { return nil, spiceerrors.MustBugf("got nil dispatch cursor") } @@ -62,13 +63,14 @@ func EncodeFromDispatchCursor(dispatchCursor *dispatch.Cursor, callAndParameterH DispatchVersion: dispatchCursor.DispatchVersion, Sections: dispatchCursor.Sections, CallAndParametersHash: callAndParameterHash, + SchemaHash: []byte(schemaHash), Flags: flags, }, }, }) } -func EncodeFromDispatchCursorSections(dispatchCursorSections []string, callAndParameterHash string, revision datastore.Revision, flags map[string]string) (*v1.Cursor, error) { +func EncodeFromDispatchCursorSections(dispatchCursorSections []string, callAndParameterHash string, revision datastore.Revision, schemaHash 
datalayer.SchemaHash, flags map[string]string) (*v1.Cursor, error) { return Encode(&impl.DecodedCursor{ VersionOneof: &impl.DecodedCursor_V1{ V1: &impl.V1Cursor{ @@ -76,6 +78,7 @@ func EncodeFromDispatchCursorSections(dispatchCursorSections []string, callAndPa DispatchVersion: 1, Sections: dispatchCursorSections, CallAndParametersHash: callAndParameterHash, + SchemaHash: []byte(schemaHash), Flags: flags, }, }, @@ -123,38 +126,42 @@ func DecodeToDispatchCursor(encoded *v1.Cursor, callAndParameterHash string) (*d }, v1decoded.Flags, nil } -// DecodeToDispatchRevision decodes an encoded API cursor into an internal dispatch revision. +// DecodeToDispatchRevisionAndSchemaHash decodes an encoded API cursor into an internal dispatch revision and schema hash. // NOTE: this method does *not* verify the caller's method signature. -func DecodeToDispatchRevision(ctx context.Context, encoded *v1.Cursor, ds revisionDecoder) (datastore.Revision, zedtoken.TokenStatus, error) { +func DecodeToDispatchRevisionAndSchemaHash(ctx context.Context, encoded *v1.Cursor, ds revisionDecoder) (datastore.Revision, datalayer.SchemaHash, zedtoken.TokenStatus, error) { decoded, err := Decode(encoded) if err != nil { - return nil, zedtoken.StatusUnknown, err + return nil, "", zedtoken.StatusUnknown, err } v1decoded := decoded.GetV1() if v1decoded == nil { - return nil, zedtoken.StatusUnknown, ErrNilCursor + return nil, "", zedtoken.StatusUnknown, ErrNilCursor } datastoreUniqueID, err := ds.UniqueID(ctx) if err != nil { - return nil, zedtoken.StatusUnknown, fmt.Errorf(errEncodeError, err) + return nil, "", zedtoken.StatusUnknown, fmt.Errorf(errEncodeError, err) } parsed, err := ds.RevisionFromString(v1decoded.Revision) if err != nil { - return datastore.NoRevision, zedtoken.StatusUnknown, fmt.Errorf(errDecodeError, err) + return datastore.NoRevision, "", zedtoken.StatusUnknown, fmt.Errorf(errDecodeError, err) } if v1decoded.DatastoreUniqueId == "" { - return parsed, 
zedtoken.StatusLegacyEmptyDatastoreID, nil + return parsed, datalayer.NoSchemaHashForLegacyCursor, zedtoken.StatusLegacyEmptyDatastoreID, nil } if v1decoded.DatastoreUniqueId != datastoreUniqueID { - return parsed, zedtoken.StatusMismatchedDatastoreID, nil + return parsed, datalayer.NoSchemaHashForLegacyCursor, zedtoken.StatusMismatchedDatastoreID, nil } - return parsed, zedtoken.StatusValid, nil + schemaHash := datalayer.NoSchemaHashForLegacyCursor + if len(v1decoded.GetSchemaHash()) > 0 { + schemaHash = datalayer.SchemaHash(v1decoded.GetSchemaHash()) + } + return parsed, schemaHash, zedtoken.StatusValid, nil } type revisionDecoder interface { diff --git a/pkg/cursor/cursor_test.go b/pkg/cursor/cursor_test.go index 575b3ebaf..259550a15 100644 --- a/pkg/cursor/cursor_test.go +++ b/pkg/cursor/cursor_test.go @@ -10,8 +10,11 @@ import ( v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" "github.com/authzed/spicedb/internal/datastore/revisions" + "github.com/authzed/spicedb/pkg/datalayer" "github.com/authzed/spicedb/pkg/datastore" dispatch "github.com/authzed/spicedb/pkg/proto/dispatch/v1" + impl "github.com/authzed/spicedb/pkg/proto/impl/v1" + "github.com/authzed/spicedb/pkg/zedtoken" ) var ( @@ -49,7 +52,7 @@ func TestEncodeDecode(t *testing.T) { require := require.New(t) encoded, err := EncodeFromDispatchCursor(&dispatch.Cursor{ Sections: tc.sections, - }, tc.hash, tc.revision, map[string]string{"some": "flag"}) + }, tc.hash, tc.revision, datalayer.NoSchemaHashForLegacyCursor, map[string]string{"some": "flag"}) require.NoError(err) require.NotNil(encoded) @@ -60,7 +63,7 @@ func TestEncodeDecode(t *testing.T) { require.Equal(tc.sections, decoded.Sections) - decodedRev, _, err := DecodeToDispatchRevision(context.Background(), encoded, revisions.CommonDecoder{ + decodedRev, _, _, err := DecodeToDispatchRevisionAndSchemaHash(context.Background(), encoded, revisions.CommonDecoder{ Kind: revisions.TransactionID, }) require.NoError(err) @@ -137,7 +140,7 @@ func 
TestDecode(t *testing.T) { require.NotNil(decoded) require.Equal(testCase.expectedSections, decoded.Sections) - decodedRev, _, err := DecodeToDispatchRevision(context.Background(), &v1.Cursor{ + decodedRev, _, _, err := DecodeToDispatchRevisionAndSchemaHash(context.Background(), &v1.Cursor{ Token: testCase.token, }, revisions.CommonDecoder{ Kind: revisions.TransactionID, @@ -153,3 +156,63 @@ func TestDecode(t *testing.T) { }) } } + +func TestDecodeToDispatchRevisionAndSchemaHashWithDatastoreID(t *testing.T) { + require := require.New(t) + + // Encode a cursor that includes both a DatastoreUniqueId and a SchemaHash. + encoded, err := Encode(&impl.DecodedCursor{ + VersionOneof: &impl.DecodedCursor_V1{ + V1: &impl.V1Cursor{ + Revision: revision1.String(), + DispatchVersion: 1, + Sections: []string{"a", "b"}, + CallAndParametersHash: "testhash", + DatastoreUniqueId: "testdsid", + SchemaHash: []byte("myschema123"), + }, + }, + }) + require.NoError(err) + + decodedRev, schemaHash, status, err := DecodeToDispatchRevisionAndSchemaHash( + context.Background(), + encoded, + revisions.CommonDecoder{ + Kind: revisions.TransactionID, + DatastoreUniqueID: "testdsid", + }, + ) + require.NoError(err) + require.Equal(zedtoken.StatusValid, status) + require.True(revision1.Equal(decodedRev)) + require.Equal(datalayer.SchemaHash("myschema123"), schemaHash) +} + +func TestDecodeToDispatchRevisionAndSchemaHashMismatchedDatastoreID(t *testing.T) { + require := require.New(t) + + encoded, err := Encode(&impl.DecodedCursor{ + VersionOneof: &impl.DecodedCursor_V1{ + V1: &impl.V1Cursor{ + Revision: revision1.String(), + DispatchVersion: 1, + DatastoreUniqueId: "otherid", + SchemaHash: []byte("myschema123"), + }, + }, + }) + require.NoError(err) + + _, schemaHash, status, err := DecodeToDispatchRevisionAndSchemaHash( + context.Background(), + encoded, + revisions.CommonDecoder{ + Kind: revisions.TransactionID, + DatastoreUniqueID: "testdsid", + }, + ) + require.NoError(err) + 
require.Equal(zedtoken.StatusMismatchedDatastoreID, status) + require.Equal(datalayer.NoSchemaHashForLegacyCursor, schemaHash) +} diff --git a/pkg/datalayer/counting.go b/pkg/datalayer/counting.go deleted file mode 100644 index b27c372e3..000000000 --- a/pkg/datalayer/counting.go +++ /dev/null @@ -1,204 +0,0 @@ -package datalayer - -import ( - "context" - "sync/atomic" - - middleware "github.com/grpc-ecosystem/go-grpc-middleware/v2" - "google.golang.org/grpc" - - v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" - - "github.com/authzed/spicedb/pkg/datastore" - "github.com/authzed/spicedb/pkg/datastore/options" - "github.com/authzed/spicedb/pkg/tuple" -) - -// MethodCounts holds per-request counters for tracked methods. -type MethodCounts struct { - queryRelationships atomic.Uint64 - reverseQueryRelationships atomic.Uint64 -} - -// QueryRelationships returns the count of QueryRelationships calls. -func (m *MethodCounts) QueryRelationships() uint64 { - return m.queryRelationships.Load() -} - -// ReverseQueryRelationships returns the count of ReverseQueryRelationships calls. -func (m *MethodCounts) ReverseQueryRelationships() uint64 { - return m.reverseQueryRelationships.Load() -} - -// MethodCountsExporter is a function that exports method counts (e.g. to Prometheus). -// This is provided by callers to avoid duplicate metric registration. -type MethodCountsExporter func(counts *MethodCounts) - -// NewCountingDataLayer wraps a DataLayer with per-request counting. 
-func NewCountingDataLayer(dl DataLayer) (DataLayer, *MethodCounts) { - counts := &MethodCounts{} - return &countingDataLayer{ - delegate: dl, - counts: counts, - }, counts -} - -type countingDataLayer struct { - delegate DataLayer - counts *MethodCounts -} - -func (c *countingDataLayer) SnapshotReader(rev datastore.Revision) RevisionedReader { - return &countingRevisionedReader{ - delegate: c.delegate.SnapshotReader(rev), - counts: c.counts, - } -} - -func (c *countingDataLayer) ReadWriteTx(ctx context.Context, fn TxUserFunc, opts ...options.RWTOptionsOption) (datastore.Revision, error) { - return c.delegate.ReadWriteTx(ctx, func(ctx context.Context, rwt ReadWriteTransaction) error { - return fn(ctx, &countingReadWriteTransaction{ - ReadWriteTransaction: rwt, - counts: c.counts, - }) - }, opts...) -} - -func (c *countingDataLayer) OptimizedRevision(ctx context.Context) (datastore.Revision, error) { - return c.delegate.OptimizedRevision(ctx) -} - -func (c *countingDataLayer) HeadRevision(ctx context.Context) (datastore.Revision, error) { - return c.delegate.HeadRevision(ctx) -} - -func (c *countingDataLayer) CheckRevision(ctx context.Context, revision datastore.Revision) error { - return c.delegate.CheckRevision(ctx, revision) -} - -func (c *countingDataLayer) RevisionFromString(serialized string) (datastore.Revision, error) { - return c.delegate.RevisionFromString(serialized) -} - -func (c *countingDataLayer) Watch(ctx context.Context, afterRevision datastore.Revision, opts datastore.WatchOptions) (<-chan datastore.RevisionChanges, <-chan error) { - return c.delegate.Watch(ctx, afterRevision, opts) -} - -func (c *countingDataLayer) ReadyState(ctx context.Context) (datastore.ReadyState, error) { - return c.delegate.ReadyState(ctx) -} - -func (c *countingDataLayer) Features(ctx context.Context) (*datastore.Features, error) { - return c.delegate.Features(ctx) -} - -func (c *countingDataLayer) OfflineFeatures() (*datastore.Features, error) { - return 
c.delegate.OfflineFeatures() -} - -func (c *countingDataLayer) Statistics(ctx context.Context) (datastore.Stats, error) { - return c.delegate.Statistics(ctx) -} - -func (c *countingDataLayer) UniqueID(ctx context.Context) (string, error) { - return c.delegate.UniqueID(ctx) -} - -func (c *countingDataLayer) MetricsID() (string, error) { - return c.delegate.MetricsID() -} - -func (c *countingDataLayer) Close() error { - return c.delegate.Close() -} - -func (c *countingDataLayer) unwrapDatastore() datastore.Datastore { - return UnwrapDatastore(c.delegate) -} - -type countingRevisionedReader struct { - delegate RevisionedReader - counts *MethodCounts -} - -func (r *countingRevisionedReader) ReadSchema(ctx context.Context) (SchemaReader, error) { - return r.delegate.ReadSchema(ctx) -} - -func (r *countingRevisionedReader) QueryRelationships(ctx context.Context, filter datastore.RelationshipsFilter, opts ...options.QueryOptionsOption) (datastore.RelationshipIterator, error) { - r.counts.queryRelationships.Add(1) - return r.delegate.QueryRelationships(ctx, filter, opts...) -} - -func (r *countingRevisionedReader) ReverseQueryRelationships(ctx context.Context, subjectsFilter datastore.SubjectsFilter, opts ...options.ReverseQueryOptionsOption) (datastore.RelationshipIterator, error) { - r.counts.reverseQueryRelationships.Add(1) - return r.delegate.ReverseQueryRelationships(ctx, subjectsFilter, opts...) 
-} - -func (r *countingRevisionedReader) CountRelationships(ctx context.Context, name string) (int, error) { - return r.delegate.CountRelationships(ctx, name) -} - -func (r *countingRevisionedReader) LookupCounters(ctx context.Context) ([]datastore.RelationshipCounter, error) { - return r.delegate.LookupCounters(ctx) -} - -type countingReadWriteTransaction struct { - ReadWriteTransaction - counts *MethodCounts -} - -func (t *countingReadWriteTransaction) QueryRelationships(ctx context.Context, filter datastore.RelationshipsFilter, opts ...options.QueryOptionsOption) (datastore.RelationshipIterator, error) { - t.counts.queryRelationships.Add(1) - return t.ReadWriteTransaction.QueryRelationships(ctx, filter, opts...) -} - -func (t *countingReadWriteTransaction) ReverseQueryRelationships(ctx context.Context, subjectsFilter datastore.SubjectsFilter, opts ...options.ReverseQueryOptionsOption) (datastore.RelationshipIterator, error) { - t.counts.reverseQueryRelationships.Add(1) - return t.ReadWriteTransaction.ReverseQueryRelationships(ctx, subjectsFilter, opts...) -} - -func (t *countingReadWriteTransaction) WriteRelationships(ctx context.Context, mutations []tuple.RelationshipUpdate) error { - return t.ReadWriteTransaction.WriteRelationships(ctx, mutations) -} - -func (t *countingReadWriteTransaction) DeleteRelationships(ctx context.Context, filter *v1.RelationshipFilter, opts ...options.DeleteOptionsOption) (uint64, bool, error) { - return t.ReadWriteTransaction.DeleteRelationships(ctx, filter, opts...) -} - -// UnaryCountingInterceptor wraps the datalayer with counting for each unary request. -// The exporter function is called after each request to export counts (e.g. to Prometheus). 
-func UnaryCountingInterceptor(exporter MethodCountsExporter) grpc.UnaryServerInterceptor { - return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { - dl := MustFromContext(ctx) - countingDL, counts := NewCountingDataLayer(dl) - if err := SetInContext(ctx, countingDL); err != nil { - return nil, err - } - resp, err := handler(ctx, req) - if exporter != nil { - exporter(counts) - } - return resp, err - } -} - -// StreamCountingInterceptor wraps the datalayer with counting for each stream request. -// The exporter function is called after each stream to export counts (e.g. to Prometheus). -func StreamCountingInterceptor(exporter MethodCountsExporter) grpc.StreamServerInterceptor { - return func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { - ctx := ss.Context() - dl := MustFromContext(ctx) - countingDL, counts := NewCountingDataLayer(dl) - if err := SetInContext(ctx, countingDL); err != nil { - return err - } - wrapped := middleware.WrapServerStream(ss) - wrapped.WrappedContext = ctx - err := handler(srv, wrapped) - if exporter != nil { - exporter(counts) - } - return err - } -} diff --git a/pkg/datalayer/datalayer.go b/pkg/datalayer/datalayer.go index 2e3c0de5a..35e9d4725 100644 --- a/pkg/datalayer/datalayer.go +++ b/pkg/datalayer/datalayer.go @@ -28,10 +28,10 @@ type LegacySchemaWriter interface { // It abstracts the underlying datastore, hiding Legacy* methods and // providing clean access to schema, relationships, and metadata. 
type DataLayer interface { - SnapshotReader(datastore.Revision) RevisionedReader + SnapshotReader(datastore.Revision, SchemaHash) RevisionedReader ReadWriteTx(context.Context, TxUserFunc, ...options.RWTOptionsOption) (datastore.Revision, error) - OptimizedRevision(ctx context.Context) (datastore.Revision, error) - HeadRevision(ctx context.Context) (datastore.Revision, error) + OptimizedRevision(ctx context.Context) (datastore.Revision, SchemaHash, error) + HeadRevision(ctx context.Context) (datastore.Revision, SchemaHash, error) CheckRevision(ctx context.Context, revision datastore.Revision) error RevisionFromString(serialized string) (datastore.Revision, error) Watch(ctx context.Context, afterRevision datastore.Revision, options datastore.WatchOptions) (<-chan datastore.RevisionChanges, <-chan error) diff --git a/pkg/datalayer/datalayer_test.go b/pkg/datalayer/datalayer_test.go new file mode 100644 index 000000000..8906bf9ea --- /dev/null +++ b/pkg/datalayer/datalayer_test.go @@ -0,0 +1,1093 @@ +package datalayer + +import ( + "context" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/authzed/spicedb/internal/datastore/memdb" + "github.com/authzed/spicedb/pkg/caveats" + "github.com/authzed/spicedb/pkg/datastore" + ns "github.com/authzed/spicedb/pkg/namespace" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/schemadsl/compiler" + "github.com/authzed/spicedb/pkg/schemadsl/generator" + "github.com/authzed/spicedb/pkg/schemadsl/input" +) + +var testDefinitions = []compiler.SchemaDefinition{ + ns.Namespace("user"), + ns.Namespace("document", + ns.MustRelation("viewer", nil, ns.AllowedRelation("user", "...")), + ), +} + +func newTestDatastore(t *testing.T) datastore.Datastore { + t.Helper() + ds, err := memdb.NewMemdbDatastore(0, 1*time.Second, 90000*time.Second) + require.NoError(t, err) + t.Cleanup(func() { ds.Close() }) + return ds +} + +func testSchemaDefinitions(t 
*testing.T) ([]datastore.SchemaDefinition, string) { + t.Helper() + schemaText, _, err := generator.GenerateSchema(testDefinitions) + require.NoError(t, err) + + defs := make([]datastore.SchemaDefinition, 0, len(testDefinitions)) + for _, def := range testDefinitions { + defs = append(defs, def.(datastore.SchemaDefinition)) + } + return defs, schemaText +} + +// hasLegacyData checks whether legacy schema storage has namespace data at the given revision. +func hasLegacyData(t *testing.T, ds datastore.Datastore, rev datastore.Revision) bool { + t.Helper() + ctx := t.Context() + nsDefs, err := ds.SnapshotReader(rev).LegacyListAllNamespaces(ctx) + require.NoError(t, err) + return len(nsDefs) > 0 +} + +// hasUnifiedData checks whether unified schema storage has data at the given revision. +func hasUnifiedData(t *testing.T, ds datastore.Datastore, rev datastore.Revision) bool { + t.Helper() + ctx := t.Context() + _, err := ds.SnapshotReader(rev).ReadStoredSchema(ctx) + if err != nil { + require.ErrorIs(t, err, datastore.ErrSchemaNotFound) + return false + } + return true +} + +func TestWriteSchemaRouting(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + mode SchemaMode + expectLegacy bool + expectUnified bool + }{ + { + name: "ReadLegacyWriteLegacy", + mode: SchemaModeReadLegacyWriteLegacy, + expectLegacy: true, + expectUnified: false, + }, + { + name: "ReadLegacyWriteBoth", + mode: SchemaModeReadLegacyWriteBoth, + expectLegacy: true, + expectUnified: true, + }, + { + name: "ReadNewWriteBoth", + mode: SchemaModeReadNewWriteBoth, + expectLegacy: true, + expectUnified: true, + }, + { + name: "ReadNewWriteNew", + mode: SchemaModeReadNewWriteNew, + expectLegacy: false, + expectUnified: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + require := require.New(t) + + ds := newTestDatastore(t) + dl := NewDataLayer(ds, WithSchemaMode(tc.mode)) + + defs, schemaText := testSchemaDefinitions(t) + ctx := 
t.Context() + + rev, err := dl.ReadWriteTx(ctx, func(ctx context.Context, rwt ReadWriteTransaction) error { + return rwt.WriteSchema(ctx, defs, schemaText, nil) + }) + require.NoError(err) + + require.Equal(tc.expectLegacy, hasLegacyData(t, ds, rev), + "legacy store populated = %v, want %v", hasLegacyData(t, ds, rev), tc.expectLegacy) + require.Equal(tc.expectUnified, hasUnifiedData(t, ds, rev), + "unified store populated = %v, want %v", hasUnifiedData(t, ds, rev), tc.expectUnified) + }) + } +} + +func TestReadSchemaRouting(t *testing.T) { + t.Parallel() + + modes := []struct { + name string + mode SchemaMode + }{ + {"ReadLegacyWriteLegacy", SchemaModeReadLegacyWriteLegacy}, + {"ReadLegacyWriteBoth", SchemaModeReadLegacyWriteBoth}, + {"ReadNewWriteBoth", SchemaModeReadNewWriteBoth}, + {"ReadNewWriteNew", SchemaModeReadNewWriteNew}, + } + + for _, tc := range modes { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + require := require.New(t) + + ds := newTestDatastore(t) + dl := NewDataLayer(ds, WithSchemaMode(tc.mode)) + + defs, schemaText := testSchemaDefinitions(t) + ctx := t.Context() + + // Write through the datalayer so both stores are populated if needed. + rev, err := dl.ReadWriteTx(ctx, func(ctx context.Context, rwt ReadWriteTransaction) error { + return rwt.WriteSchema(ctx, defs, schemaText, nil) + }) + require.NoError(err) + + // Read back through the datalayer. + reader := dl.SnapshotReader(rev, NoSchemaHashForTesting) + schemaReader, err := reader.ReadSchema(ctx) + require.NoError(err) + + // Verify we can read back the schema. + readText, err := schemaReader.SchemaText(ctx) + require.NoError(err) + require.NotEmpty(readText) + + typeDefs, err := schemaReader.ListAllTypeDefinitions(ctx) + require.NoError(err) + require.Len(typeDefs, 2) + + // Both legacy and unified readers should return a real LastWrittenRevision. 
+ for _, td := range typeDefs { + require.NotNil(td.LastWrittenRevision, + "LastWrittenRevision should be set for %s", td.Definition.Name) + require.NotEqual(datastore.NoRevision, td.LastWrittenRevision, + "LastWrittenRevision should not be NoRevision for %s", td.Definition.Name) + } + }) + } +} + +func TestReadSchemaWithinTransaction(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + mode SchemaMode + }{ + {"ReadLegacyWriteLegacy", SchemaModeReadLegacyWriteLegacy}, + {"ReadLegacyWriteBoth", SchemaModeReadLegacyWriteBoth}, + {"ReadNewWriteBoth", SchemaModeReadNewWriteBoth}, + {"ReadNewWriteNew", SchemaModeReadNewWriteNew}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + require := require.New(t) + + ds := newTestDatastore(t) + dl := NewDataLayer(ds, WithSchemaMode(tc.mode)) + + defs, schemaText := testSchemaDefinitions(t) + ctx := t.Context() + + _, err := dl.ReadWriteTx(ctx, func(ctx context.Context, rwt ReadWriteTransaction) error { + // Write schema. + if err := rwt.WriteSchema(ctx, defs, schemaText, nil); err != nil { + return err + } + + // Read schema back within the same transaction. + schemaReader, err := rwt.ReadSchema(ctx) + if err != nil { + return err + } + + readText, err := schemaReader.SchemaText(ctx) + if err != nil { + return err + } + require.NotEmpty(readText) + + typeDefs, err := schemaReader.ListAllTypeDefinitions(ctx) + if err != nil { + return err + } + require.Len(typeDefs, 2) + + return nil + }) + require.NoError(err) + }) + } +} + +func TestWriteSchemaDualWriteConsistency(t *testing.T) { + t.Parallel() + + // For modes that write to both stores, verify the data is consistent between them. 
+ dualWriteModes := []struct { + name string + mode SchemaMode + }{ + {"ReadLegacyWriteBoth", SchemaModeReadLegacyWriteBoth}, + {"ReadNewWriteBoth", SchemaModeReadNewWriteBoth}, + } + + for _, tc := range dualWriteModes { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + require := require.New(t) + + ds := newTestDatastore(t) + dl := NewDataLayer(ds, WithSchemaMode(tc.mode)) + + defs, schemaText := testSchemaDefinitions(t) + ctx := t.Context() + + rev, err := dl.ReadWriteTx(ctx, func(ctx context.Context, rwt ReadWriteTransaction) error { + return rwt.WriteSchema(ctx, defs, schemaText, nil) + }) + require.NoError(err) + + // Read from legacy store directly. + legacyNsDefs, err := ds.SnapshotReader(rev).LegacyListAllNamespaces(ctx) + require.NoError(err) + legacyNames := make(map[string]bool, len(legacyNsDefs)) + for _, ns := range legacyNsDefs { + legacyNames[ns.Definition.Name] = true + } + + // Read from unified store directly. + storedSchema, err := ds.SnapshotReader(rev).ReadStoredSchema(ctx) + require.NoError(err) + v1 := storedSchema.Get().GetV1() + require.NotNil(v1) + + // Both stores should have the same namespace definitions. 
+ require.Len(v1.NamespaceDefinitions, len(legacyNsDefs), + "both stores should have the same number of namespace definitions") + for name := range v1.NamespaceDefinitions { + require.True(legacyNames[name], + "namespace %q in unified store but not in legacy store", name) + } + }) + } +} + +func TestSchemaLookupOperationsPerMode(t *testing.T) { + t.Parallel() + + modes := []struct { + name string + mode SchemaMode + }{ + {"ReadLegacyWriteLegacy", SchemaModeReadLegacyWriteLegacy}, + {"ReadLegacyWriteBoth", SchemaModeReadLegacyWriteBoth}, + {"ReadNewWriteBoth", SchemaModeReadNewWriteBoth}, + {"ReadNewWriteNew", SchemaModeReadNewWriteNew}, + } + + for _, tc := range modes { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + require := require.New(t) + + ds := newTestDatastore(t) + dl := NewDataLayer(ds, WithSchemaMode(tc.mode)) + + defs, schemaText := testSchemaDefinitions(t) + ctx := t.Context() + + rev, err := dl.ReadWriteTx(ctx, func(ctx context.Context, rwt ReadWriteTransaction) error { + return rwt.WriteSchema(ctx, defs, schemaText, nil) + }) + require.NoError(err) + + reader := dl.SnapshotReader(rev, NoSchemaHashForTesting) + schemaReader, err := reader.ReadSchema(ctx) + require.NoError(err) + + // LookupTypeDefByName: existing. + docDef, found, err := schemaReader.LookupTypeDefByName(ctx, "document") + require.NoError(err) + require.True(found) + require.Equal("document", docDef.Definition.Name) + + // LookupTypeDefByName: missing. + _, found, err = schemaReader.LookupTypeDefByName(ctx, "nonexistent") + require.NoError(err) + require.False(found) + + // LookupCaveatDefByName: missing (no caveats in this schema). + _, found, err = schemaReader.LookupCaveatDefByName(ctx, "nonexistent") + require.NoError(err) + require.False(found) + + // ListAllSchemaDefinitions. 
+ allDefs, err := schemaReader.ListAllSchemaDefinitions(ctx) + require.NoError(err) + require.Len(allDefs, 2) + require.Contains(allDefs, "user") + require.Contains(allDefs, "document") + + // LookupSchemaDefinitionsByNames. + lookedUp, err := schemaReader.LookupSchemaDefinitionsByNames(ctx, []string{"user", "nonexistent"}) + require.NoError(err) + require.Len(lookedUp, 1) + require.Contains(lookedUp, "user") + + // LookupTypeDefinitionsByNames. + typeDefs, err := schemaReader.LookupTypeDefinitionsByNames(ctx, []string{"user", "document"}) + require.NoError(err) + require.Len(typeDefs, 2) + + // LookupCaveatDefinitionsByNames: empty since no caveats. + caveatDefs, err := schemaReader.LookupCaveatDefinitionsByNames(ctx, []string{"anything"}) + require.NoError(err) + require.Empty(caveatDefs) + }) + } +} + +func TestWriteSchemaWithCaveatsPerMode(t *testing.T) { + t.Parallel() + + modes := []struct { + name string + mode SchemaMode + }{ + {"ReadLegacyWriteLegacy", SchemaModeReadLegacyWriteLegacy}, + {"ReadLegacyWriteBoth", SchemaModeReadLegacyWriteBoth}, + {"ReadNewWriteBoth", SchemaModeReadNewWriteBoth}, + {"ReadNewWriteNew", SchemaModeReadNewWriteNew}, + } + + for _, tc := range modes { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + require := require.New(t) + + ds := newTestDatastore(t) + dl := NewDataLayer(ds, WithSchemaMode(tc.mode)) + + schemaText := `caveat test_caveat(allowed bool) { + allowed +} + +definition user {} + +definition document { + relation viewer: user with test_caveat +}` + + compiled, err := compiler.Compile(compiler.InputSchema{ + Source: input.Source("schema"), + SchemaString: schemaText, + }, compiler.AllowUnprefixedObjectType()) + require.NoError(err) + + defs := make([]datastore.SchemaDefinition, 0, len(compiled.ObjectDefinitions)+len(compiled.CaveatDefinitions)) + for _, caveatDef := range compiled.CaveatDefinitions { + defs = append(defs, caveatDef) + } + for _, objDef := range compiled.ObjectDefinitions { + defs = append(defs, 
objDef) + } + + ctx := t.Context() + + rev, err := dl.ReadWriteTx(ctx, func(ctx context.Context, rwt ReadWriteTransaction) error { + return rwt.WriteSchema(ctx, defs, schemaText, nil) + }) + require.NoError(err) + + // Verify through the datalayer. + reader := dl.SnapshotReader(rev, NoSchemaHashForTesting) + schemaReader, err := reader.ReadSchema(ctx) + require.NoError(err) + + caveats, err := schemaReader.ListAllCaveatDefinitions(ctx) + require.NoError(err) + require.Len(caveats, 1) + require.Equal("test_caveat", caveats[0].Definition.Name) + + caveat, found, err := schemaReader.LookupCaveatDefByName(ctx, "test_caveat") + require.NoError(err) + require.True(found) + require.Equal("test_caveat", caveat.Definition.Name) + }) + } +} + +func TestWriteSchemaStableTextPerMode(t *testing.T) { + t.Parallel() + + // Schema text with definitions intentionally in non-alphabetical order. + outOfOrderSchemaText := `caveat zebra_caveat(flag bool) { + flag +} + +caveat alpha_caveat(allowed bool) { + allowed +} + +definition zresource { + relation viewer: auser with alpha_caveat +} + +definition auser {}` + + modes := []struct { + name string + mode SchemaMode + readsFromNew bool + }{ + {"ReadLegacyWriteLegacy", SchemaModeReadLegacyWriteLegacy, false}, + {"ReadLegacyWriteBoth", SchemaModeReadLegacyWriteBoth, false}, + {"ReadNewWriteBoth", SchemaModeReadNewWriteBoth, true}, + {"ReadNewWriteNew", SchemaModeReadNewWriteNew, true}, + } + + for _, tc := range modes { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + require := require.New(t) + + ds := newTestDatastore(t) + dl := NewDataLayer(ds, WithSchemaMode(tc.mode)) + + compiled, err := compiler.Compile(compiler.InputSchema{ + Source: input.Source("schema"), + SchemaString: outOfOrderSchemaText, + }, compiler.AllowUnprefixedObjectType()) + require.NoError(err) + + // Build definitions in the compiled order (which may not be alphabetical). 
+ defs := make([]datastore.SchemaDefinition, 0, + len(compiled.CaveatDefinitions)+len(compiled.ObjectDefinitions)) + for _, cd := range compiled.CaveatDefinitions { + defs = append(defs, cd) + } + for _, od := range compiled.ObjectDefinitions { + defs = append(defs, od) + } + + ctx := t.Context() + + rev, err := dl.ReadWriteTx(ctx, func(ctx context.Context, rwt ReadWriteTransaction) error { + return rwt.WriteSchema(ctx, defs, outOfOrderSchemaText, nil) + }) + require.NoError(err) + + reader := dl.SnapshotReader(rev, NoSchemaHashForTesting) + schemaReader, err := reader.ReadSchema(ctx) + require.NoError(err) + + readText, err := schemaReader.SchemaText(ctx) + require.NoError(err) + + if tc.readsFromNew { + // Read-new modes should return the exact written text, preserving the + // original ordering even though definitions were out of alphabetical order. + require.Equal(outOfOrderSchemaText, readText, + "read-new mode should return the original written text verbatim") + } else { + // Read-legacy modes regenerate text from sorted definitions, so the + // output should be sorted alphabetically. + require.NotEmpty(readText) + require.Contains(readText, "alpha_caveat") + require.Contains(readText, "zebra_caveat") + require.Contains(readText, "auser") + require.Contains(readText, "zresource") + + // Read a second time and verify the text is stable. 
+ schemaReader2, err := reader.ReadSchema(ctx) + require.NoError(err) + readText2, err := schemaReader2.SchemaText(ctx) + require.NoError(err) + require.Equal(readText, readText2, + "legacy text should be stable across multiple reads") + } + }) + } +} + +func TestWriteSchemaRejectsCaveatAndNamespaceWithSameName(t *testing.T) { + t.Parallel() + + env := caveats.NewEnvironmentWithDefaultTypeSet() + caveatDef := ns.MustCaveatDefinition(env, "samename", "1 == 1") + namespaceDef := ns.Namespace("samename") + + defs := []datastore.SchemaDefinition{namespaceDef, caveatDef} + + // WriteSchemaViaStoredSchema uses MustBugf for duplicate names, which panics in tests. + require.Panics(t, func() { + _ = WriteSchemaViaStoredSchema(t.Context(), nil, defs, "", nil) + }) +} + +func TestParseSchemaMode(t *testing.T) { + t.Parallel() + + tests := []struct { + input string + expected SchemaMode + wantErr bool + }{ + {"read-legacy-write-legacy", SchemaModeReadLegacyWriteLegacy, false}, + {"read-legacy-write-both", SchemaModeReadLegacyWriteBoth, false}, + {"read-new-write-both", SchemaModeReadNewWriteBoth, false}, + {"read-new-write-new", SchemaModeReadNewWriteNew, false}, + {"invalid", SchemaModeReadLegacyWriteLegacy, true}, + {"", SchemaModeReadLegacyWriteLegacy, true}, + } + + for _, tc := range tests { + t.Run(tc.input, func(t *testing.T) { + t.Parallel() + mode, err := ParseSchemaMode(tc.input) + if tc.wantErr { + require.Error(t, err) + require.Contains(t, err.Error(), "invalid schema mode") + } else { + require.NoError(t, err) + require.Equal(t, tc.expected, mode) + } + }) + } +} + +func TestSchemaModeString(t *testing.T) { + t.Parallel() + + tests := []struct { + mode SchemaMode + expected string + }{ + {SchemaModeReadLegacyWriteLegacy, "read-legacy-write-legacy"}, + {SchemaModeReadLegacyWriteBoth, "read-legacy-write-both"}, + {SchemaModeReadNewWriteBoth, "read-new-write-both"}, + {SchemaModeReadNewWriteNew, "read-new-write-new"}, + {SchemaMode(255), "unknown"}, + } + + for _, 
tc := range tests { + t.Run(tc.expected, func(t *testing.T) { + t.Parallel() + require.Equal(t, tc.expected, tc.mode.String()) + }) + } +} + +func TestSchemaModeProperties(t *testing.T) { + t.Parallel() + + tests := []struct { + mode SchemaMode + readsFromNew bool + writesToLegacy bool + writesToNew bool + }{ + {SchemaModeReadLegacyWriteLegacy, false, true, false}, + {SchemaModeReadLegacyWriteBoth, false, true, true}, + {SchemaModeReadNewWriteBoth, true, true, true}, + {SchemaModeReadNewWriteNew, true, false, true}, + } + + for _, tc := range tests { + t.Run(tc.mode.String(), func(t *testing.T) { + t.Parallel() + require.Equal(t, tc.readsFromNew, tc.mode.ReadsFromNew()) + require.Equal(t, tc.writesToLegacy, tc.mode.WritesToLegacy()) + require.Equal(t, tc.writesToNew, tc.mode.WritesToNew()) + }) + } +} + +func TestParseSchemaModeRoundTrip(t *testing.T) { + t.Parallel() + + modes := []SchemaMode{ + SchemaModeReadLegacyWriteLegacy, + SchemaModeReadLegacyWriteBoth, + SchemaModeReadNewWriteBoth, + SchemaModeReadNewWriteNew, + } + + for _, mode := range modes { + t.Run(mode.String(), func(t *testing.T) { + t.Parallel() + parsed, err := ParseSchemaMode(mode.String()) + require.NoError(t, err) + require.Equal(t, mode, parsed) + }) + } +} + +func TestStoredSchemaReaderAdapterEmptySchema(t *testing.T) { + t.Parallel() + + // Test the storedSchemaReaderAdapter when schema is empty (no definitions). + adapter := &storedSchemaReaderAdapter{ + storedSchema: datastore.NewReadOnlyStoredSchema(&core.StoredSchema{ + Version: 1, + VersionOneof: &core.StoredSchema_V1{ + V1: &core.StoredSchema_V1StoredSchema{}, + }, + }), + lastWrittenRevision: datastore.NoRevision, + } + + ctx := t.Context() + + // SchemaText should return error for empty schema. + _, err := adapter.SchemaText(ctx) + require.Error(t, err) + + // Lookups should return not-found. 
+ _, found, err := adapter.LookupTypeDefByName(ctx, "user") + require.NoError(t, err) + require.False(t, found) + + _, found, err = adapter.LookupCaveatDefByName(ctx, "somecaveat") + require.NoError(t, err) + require.False(t, found) + + // Lists should return empty. + typeDefs, err := adapter.ListAllTypeDefinitions(ctx) + require.NoError(t, err) + require.Empty(t, typeDefs) + + caveatDefs, err := adapter.ListAllCaveatDefinitions(ctx) + require.NoError(t, err) + require.Empty(t, caveatDefs) + + allDefs, err := adapter.ListAllSchemaDefinitions(ctx) + require.NoError(t, err) + require.Empty(t, allDefs) + + // Lookups by name should return empty maps. + looked, err := adapter.LookupSchemaDefinitionsByNames(ctx, []string{"user"}) + require.NoError(t, err) + require.Empty(t, looked) + + typeLooked, err := adapter.LookupTypeDefinitionsByNames(ctx, []string{"user"}) + require.NoError(t, err) + require.Empty(t, typeLooked) + + caveatLooked, err := adapter.LookupCaveatDefinitionsByNames(ctx, []string{"somecaveat"}) + require.NoError(t, err) + require.Empty(t, caveatLooked) +} + +func TestStoredSchemaReaderAdapterV1Nil(t *testing.T) { + t.Parallel() + + // Test the v1() fallback path when VersionOneof is nil. + adapter := &storedSchemaReaderAdapter{ + storedSchema: datastore.NewReadOnlyStoredSchema(&core.StoredSchema{ + Version: 1, + }), + lastWrittenRevision: datastore.NoRevision, + } + + // v1() should return a zero-valued V1StoredSchema, so no panic. + _, err := adapter.SchemaText(t.Context()) + require.Error(t, err) // empty schema, no definitions +} + +func TestUnwrapDatastore(t *testing.T) { + t.Parallel() + + ds := newTestDatastore(t) + dl := NewDataLayer(ds) + + // Should unwrap to the original datastore. 
+ unwrapped := UnwrapDatastore(dl) + require.NotNil(t, unwrapped) + require.Equal(t, ds, unwrapped) +} + +func TestComputeSchemaHash(t *testing.T) { + t.Parallel() + + defs, _ := testSchemaDefinitions(t) + + toCompilerDefs := func(defs []datastore.SchemaDefinition) []compiler.SchemaDefinition { + result := make([]compiler.SchemaDefinition, 0, len(defs)) + for _, def := range defs { + result = append(result, def.(compiler.SchemaDefinition)) + } + return result + } + + hash1, err := generator.ComputeSchemaHash(toCompilerDefs(defs)) + require.NoError(t, err) + require.NotEmpty(t, hash1) + + // Same definitions should produce same hash. + hash2, err := generator.ComputeSchemaHash(toCompilerDefs(defs)) + require.NoError(t, err) + require.Equal(t, hash1, hash2) + + // Different definitions should produce different hash. + differentDefs := []datastore.SchemaDefinition{ + ns.Namespace("different"), + } + hash3, err := generator.ComputeSchemaHash(toCompilerDefs(differentDefs)) + require.NoError(t, err) + require.NotEqual(t, hash1, hash3) +} + +func TestWriteSchemaDeletesRemovedDefinitions(t *testing.T) { + t.Parallel() + + // Test that writing a schema with fewer definitions removes the old ones. + modes := []struct { + name string + mode SchemaMode + }{ + {"ReadLegacyWriteLegacy", SchemaModeReadLegacyWriteLegacy}, + {"ReadNewWriteNew", SchemaModeReadNewWriteNew}, + } + + for _, tc := range modes { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + require := require.New(t) + + ds := newTestDatastore(t) + dl := NewDataLayer(ds, WithSchemaMode(tc.mode)) + ctx := t.Context() + + // Write initial schema with two definitions. + defs, schemaText := testSchemaDefinitions(t) + _, err := dl.ReadWriteTx(ctx, func(ctx context.Context, rwt ReadWriteTransaction) error { + return rwt.WriteSchema(ctx, defs, schemaText, nil) + }) + require.NoError(err) + + // Write schema with only one definition. 
+ smallerDefs := []datastore.SchemaDefinition{ns.Namespace("user")} + smallSchemaText, _, err := generator.GenerateSchema([]compiler.SchemaDefinition{ns.Namespace("user")}) + require.NoError(err) + + rev, err := dl.ReadWriteTx(ctx, func(ctx context.Context, rwt ReadWriteTransaction) error { + return rwt.WriteSchema(ctx, smallerDefs, smallSchemaText, nil) + }) + require.NoError(err) + + // Read back and verify "document" is gone. + reader := dl.SnapshotReader(rev, NoSchemaHashForTesting) + schemaReader, err := reader.ReadSchema(ctx) + require.NoError(err) + + typeDefs, err := schemaReader.ListAllTypeDefinitions(ctx) + require.NoError(err) + require.Len(typeDefs, 1) + require.Equal("user", typeDefs[0].Definition.Name) + }) + } +} + +// testSchemaCache is a simple in-memory cache satisfying SchemaCache for tests. +type testSchemaCache struct { + mu sync.Mutex + items map[SchemaCacheKey]*datastore.ReadOnlyStoredSchema // GUARDED_BY(mu) +} + +func newTestSchemaCache() *testSchemaCache { + return &testSchemaCache{items: make(map[SchemaCacheKey]*datastore.ReadOnlyStoredSchema)} +} + +func (c *testSchemaCache) Get(key SchemaCacheKey) (*datastore.ReadOnlyStoredSchema, bool) { + c.mu.Lock() + defer c.mu.Unlock() + v, ok := c.items[key] + return v, ok +} + +func (c *testSchemaCache) Set(key SchemaCacheKey, entry *datastore.ReadOnlyStoredSchema, _ int64) bool { + c.mu.Lock() + defer c.mu.Unlock() + c.items[key] = entry + return true +} + +func (c *testSchemaCache) Wait() {} + +// countingDatastore wraps a datastore.Datastore and counts ReadStoredSchema calls. 
+type countingDatastore struct {
+	datastore.Datastore
+	readStoredSchemaCalls atomic.Int32
+}
+
+func (c *countingDatastore) SnapshotReader(rev datastore.Revision) datastore.Reader {
+	return &countingReader{Reader: c.Datastore.SnapshotReader(rev), counter: &c.readStoredSchemaCalls}
+}
+
+type countingReader struct {
+	datastore.Reader
+	counter *atomic.Int32
+}
+
+func (r *countingReader) ReadStoredSchema(ctx context.Context) (*datastore.ReadOnlyStoredSchema, error) {
+	r.counter.Add(1)
+	return r.Reader.ReadStoredSchema(ctx)
+}
+
+func TestHashCacheIntegration_CachesReadStoredSchema(t *testing.T) {
+	t.Parallel()
+	require := require.New(t)
+
+	rawDS := newTestDatastore(t)
+	ds := &countingDatastore{Datastore: rawDS}
+	dl := NewDataLayer(ds, WithSchemaMode(SchemaModeReadNewWriteNew), WithSchemaCache(newTestSchemaCache()))
+
+	defs, schemaText := testSchemaDefinitions(t)
+	ctx := t.Context()
+
+	// Write schema to populate unified storage.
+	rev, err := dl.ReadWriteTx(ctx, func(ctx context.Context, rwt ReadWriteTransaction) error {
+		return rwt.WriteSchema(ctx, defs, schemaText, nil)
+	})
+	require.NoError(err)
+
+	// Compute the expected hash to use for cached reads.
+	compDefs := make([]compiler.SchemaDefinition, 0, len(defs))
+	for _, def := range defs {
+		compDefs = append(compDefs, def.(compiler.SchemaDefinition))
+	}
+	hash, err := generator.ComputeSchemaHash(compDefs)
+	require.NoError(err)
+	schemaHash := SchemaHash(hash)
+
+	// Reset the call counter, since the WriteSchema above may itself have triggered ReadStoredSchema calls.
+	ds.readStoredSchemaCalls.Store(0)
+
+	// First read with the schema hash should hit the cache (populated by WriteSchema).
+	reader := dl.SnapshotReader(rev, schemaHash)
+	schemaReader, err := reader.ReadSchema(ctx)
+	require.NoError(err)
+
+	readText, err := schemaReader.SchemaText(ctx)
+	require.NoError(err)
+	require.NotEmpty(readText)
+
+	callsAfterFirst := ds.readStoredSchemaCalls.Load()
+	require.Equal(int32(0), callsAfterFirst, "should serve from cache without hitting datastore")
+
+	// Second read should also hit cache.
+	reader2 := dl.SnapshotReader(rev, schemaHash)
+	schemaReader2, err := reader2.ReadSchema(ctx)
+	require.NoError(err)
+
+	readText2, err := schemaReader2.SchemaText(ctx)
+	require.NoError(err)
+	require.Equal(readText, readText2)
+	require.Equal(int32(0), ds.readStoredSchemaCalls.Load(), "second read should also be cached")
+}
+
+func TestHashCacheIntegration_BypassSentinelSkipsCache(t *testing.T) {
+	t.Parallel()
+	require := require.New(t)
+
+	rawDS := newTestDatastore(t)
+	ds := &countingDatastore{Datastore: rawDS}
+	dl := NewDataLayer(ds, WithSchemaMode(SchemaModeReadNewWriteNew), WithSchemaCache(newTestSchemaCache()))
+
+	defs, schemaText := testSchemaDefinitions(t)
+	ctx := t.Context()
+
+	rev, err := dl.ReadWriteTx(ctx, func(ctx context.Context, rwt ReadWriteTransaction) error {
+		return rwt.WriteSchema(ctx, defs, schemaText, nil)
+	})
+	require.NoError(err)
+
+	ds.readStoredSchemaCalls.Store(0)
+
+	// Read with the testing sentinel should bypass the cache and hit the datastore exactly once.
+	reader := dl.SnapshotReader(rev, NoSchemaHashForTesting)
+	schemaReader, err := reader.ReadSchema(ctx)
+	require.NoError(err)
+
+	readText, err := schemaReader.SchemaText(ctx)
+	require.NoError(err)
+	require.NotEmpty(readText)
+	require.Equal(int32(1), ds.readStoredSchemaCalls.Load(), "sentinel should bypass cache")
+}
+
+func TestHashCacheIntegration_WriteUpdatesCache(t *testing.T) {
+	t.Parallel()
+	require := require.New(t)
+
+	rawDS := newTestDatastore(t)
+	ds := &countingDatastore{Datastore: rawDS}
+	dl := NewDataLayer(ds, WithSchemaMode(SchemaModeReadNewWriteNew), WithSchemaCache(newTestSchemaCache()))
+
+	ctx := t.Context()
+
+	// Write initial schema.
+	defs1, schemaText1 := testSchemaDefinitions(t)
+	rev1, err := dl.ReadWriteTx(ctx, func(ctx context.Context, rwt ReadWriteTransaction) error {
+		return rwt.WriteSchema(ctx, defs1, schemaText1, nil)
+	})
+	require.NoError(err)
+
+	// Write a different schema so rev2 corresponds to a distinct schema hash.
+	defs2 := []datastore.SchemaDefinition{ns.Namespace("newtype")}
+	schemaText2, _, err := generator.GenerateSchema([]compiler.SchemaDefinition{ns.Namespace("newtype")})
+	require.NoError(err)
+
+	rev2, err := dl.ReadWriteTx(ctx, func(ctx context.Context, rwt ReadWriteTransaction) error {
+		return rwt.WriteSchema(ctx, defs2, schemaText2, nil)
+	})
+	require.NoError(err)
+
+	// Compute hash for the second schema.
+	hash2, err := generator.ComputeSchemaHash([]compiler.SchemaDefinition{ns.Namespace("newtype")})
+	require.NoError(err)
+
+	ds.readStoredSchemaCalls.Store(0)
+
+	// Read with the new hash should be served from cache (populated by WriteSchema).
+	reader := dl.SnapshotReader(rev2, SchemaHash(hash2))
+	schemaReader, err := reader.ReadSchema(ctx)
+	require.NoError(err)
+
+	typeDefs, err := schemaReader.ListAllTypeDefinitions(ctx)
+	require.NoError(err)
+	require.Len(typeDefs, 1)
+	require.Equal("newtype", typeDefs[0].Definition.Name)
+	require.Equal(int32(0), ds.readStoredSchemaCalls.Load(), "new hash should be cached from write")
+
+	// Read with the first schema's hash should also be cached.
+	compDefs1 := make([]compiler.SchemaDefinition, 0, len(defs1))
+	for _, def := range defs1 {
+		compDefs1 = append(compDefs1, def.(compiler.SchemaDefinition))
+	}
+	hash1, err := generator.ComputeSchemaHash(compDefs1)
+	require.NoError(err)
+
+	reader1 := dl.SnapshotReader(rev1, SchemaHash(hash1))
+	_, err = reader1.ReadSchema(ctx)
+	require.NoError(err)
+	require.Equal(int32(0), ds.readStoredSchemaCalls.Load(), "old hash should also be cached from first write")
+}
+
+func TestHashCacheIntegration_CacheAcrossPhaseTransitions(t *testing.T) {
+	t.Parallel()
+
+	rawDS := newTestDatastore(t)
+	ds := &countingDatastore{Datastore: rawDS}
+	cache := newTestSchemaCache()
+	ctx := t.Context()
+
+	// Phase 1: Write schema with cache in ReadLegacyWriteBoth (populates unified storage).
+	dl1 := NewDataLayer(ds, WithSchemaMode(SchemaModeReadLegacyWriteBoth), WithSchemaCache(cache))
+	defs1, schemaText1 := testSchemaDefinitions(t)
+
+	rev1, err := dl1.ReadWriteTx(ctx, func(ctx context.Context, rwt ReadWriteTransaction) error {
+		return rwt.WriteSchema(ctx, defs1, schemaText1, nil)
+	})
+	require.NoError(t, err)
+
+	// Compute the hash for defs1.
+	compDefs1 := make([]compiler.SchemaDefinition, 0, len(defs1))
+	for _, def := range defs1 {
+		compDefs1 = append(compDefs1, def.(compiler.SchemaDefinition))
+	}
+	hash1, err := generator.ComputeSchemaHash(compDefs1)
+	require.NoError(t, err)
+
+	// Phase 2: Create a new DataLayer in ReadNewWriteNew with the SAME cache.
+	// The cache entry from phase 1's write should still be usable.
+	dl2 := NewDataLayer(ds, WithSchemaMode(SchemaModeReadNewWriteNew), WithSchemaCache(cache))
+
+	ds.readStoredSchemaCalls.Store(0)
+
+	reader := dl2.SnapshotReader(rev1, SchemaHash(hash1))
+	schemaReader, err := reader.ReadSchema(ctx)
+	require.NoError(t, err)
+
+	readText, err := schemaReader.SchemaText(ctx)
+	require.NoError(t, err)
+	require.NotEmpty(t, readText)
+	require.Equal(t, int32(0), ds.readStoredSchemaCalls.Load(),
+		"cache from phase 1 write should serve reads in phase 2 without hitting datastore")
+
+	// Write a different schema in phase 2.
+	defs2 := []datastore.SchemaDefinition{ns.Namespace("newtype")}
+	schemaText2, _, err := generator.GenerateSchema([]compiler.SchemaDefinition{ns.Namespace("newtype")})
+	require.NoError(t, err)
+
+	rev2, err := dl2.ReadWriteTx(ctx, func(ctx context.Context, rwt ReadWriteTransaction) error {
+		return rwt.WriteSchema(ctx, defs2, schemaText2, nil)
+	})
+	require.NoError(t, err)
+
+	hash2, err := generator.ComputeSchemaHash([]compiler.SchemaDefinition{ns.Namespace("newtype")})
+	require.NoError(t, err)
+
+	ds.readStoredSchemaCalls.Store(0)
+
+	// Read with new hash should be cached from the phase 2 write.
+	reader2 := dl2.SnapshotReader(rev2, SchemaHash(hash2))
+	schemaReader2, err := reader2.ReadSchema(ctx)
+	require.NoError(t, err)
+
+	typeDefs, err := schemaReader2.ListAllTypeDefinitions(ctx)
+	require.NoError(t, err)
+	require.Len(t, typeDefs, 1)
+	require.Equal(t, "newtype", typeDefs[0].Definition.Name)
+	require.Equal(t, int32(0), ds.readStoredSchemaCalls.Load(),
+		"phase 2 write should populate cache for new hash")
+
+	// Old hash from phase 1 should still be cached.
+	reader3 := dl2.SnapshotReader(rev1, SchemaHash(hash1))
+	_, err = reader3.ReadSchema(ctx)
+	require.NoError(t, err)
+	require.Equal(t, int32(0), ds.readStoredSchemaCalls.Load(),
+		"old hash from phase 1 should still be in cache")
+}
+
+func TestHashCacheIntegration_NoCacheStillWorks(t *testing.T) {
+	t.Parallel()
+	require := require.New(t)
+
+	// Without a cache configured, the DataLayer falls back to its no-op cache and reads go to the datastore.
+	ds := newTestDatastore(t)
+	dl := NewDataLayer(ds, WithSchemaMode(SchemaModeReadNewWriteNew))
+
+	defs, schemaText := testSchemaDefinitions(t)
+	ctx := t.Context()
+
+	rev, err := dl.ReadWriteTx(ctx, func(ctx context.Context, rwt ReadWriteTransaction) error {
+		return rwt.WriteSchema(ctx, defs, schemaText, nil)
+	})
+	require.NoError(err)
+
+	reader := dl.SnapshotReader(rev, NoSchemaHashForTesting)
+	schemaReader, err := reader.ReadSchema(ctx)
+	require.NoError(err)
+
+	readText, err := schemaReader.SchemaText(ctx)
+	require.NoError(err)
+	require.NotEmpty(readText)
+}
diff --git a/pkg/datalayer/hashcache.go b/pkg/datalayer/hashcache.go
new file mode 100644
index 000000000..39970d43f
--- /dev/null
+++ b/pkg/datalayer/hashcache.go
@@ -0,0 +1,209 @@
+package datalayer
+
+import (
+	"context"
+	"sync/atomic"
+	"time"
+
+	"go.opentelemetry.io/otel"
+	"go.opentelemetry.io/otel/attribute"
+	"resenje.org/singleflight"
+
+	"github.com/authzed/spicedb/internal/telemetry/otelconv"
+	"github.com/authzed/spicedb/pkg/datastore"
+	"github.com/authzed/spicedb/pkg/spiceerrors"
+)
+
+var tracer = otel.Tracer("spicedb/pkg/datalayer")
+
+// singleflightTimeout is the maximum time to wait for a singleflight peer to
+// load a schema before falling back to a direct load. This prevents a possible
+// deadlock when all connections in a pool are held by goroutines waiting on the
+// singleflight while the singleflight leader is blocked waiting for a connection.
+//
+//nolint:revive // var instead of const to allow test overrides
+var singleflightTimeout = 1 * time.Second
+
+// SchemaCacheKey is the key type used for schema cache lookups.
+// It implements cache.KeyString so it can be used with the cache package.
+type SchemaCacheKey string
+
+// KeyString implements cache.KeyString.
+func (k SchemaCacheKey) KeyString() string { return string(k) }
+
+// SchemaCache defines the interface for the backing cache used by schemaHashCache.
+// This is satisfied by cache.Cache[SchemaCacheKey, *datastore.ReadOnlyStoredSchema].
+type SchemaCache interface {
+	Get(key SchemaCacheKey) (*datastore.ReadOnlyStoredSchema, bool)
+	Set(key SchemaCacheKey, entry *datastore.ReadOnlyStoredSchema, cost int64) bool
+	Wait()
+}
+
+// latestSchemaEntry holds the most recent schema entry for fast-path lookups.
+type latestSchemaEntry struct {
+	hash   SchemaHash
+	schema *datastore.ReadOnlyStoredSchema
+}
+
+// schemaHashCache is a thread-safe cache for schemas indexed by hash.
+// It maintains an atomic pointer to the latest schema for fast-path reads when
+// multiple requests access the same (most recent) schema.
+//
+// The underlying cache is thread-safe, so no locks are needed for cache operations.
+// The atomic latest entry provides lock-free fast-path for the common case.
+type schemaHashCache struct {
+	cache        SchemaCache
+	latest       atomic.Pointer[latestSchemaEntry] // Fast path for latest schema
+	singleflight singleflight.Group[string, *datastore.ReadOnlyStoredSchema]
+}
+
+// newSchemaHashCache creates a new hash-based schema cache wrapping the given cache.
+func newSchemaHashCache(c SchemaCache) *schemaHashCache {
+	return &schemaHashCache{
+		cache: c,
+	}
+}
+
+var _ storedSchemaCache = (*schemaHashCache)(nil)
+
+// get retrieves a schema from the cache by its schema hash.
+// Fast path: If the hash matches the atomic latest entry, return immediately.
+// Slow path: Check the cache.
+// Returns (nil, nil) if the hash is a bypass sentinel (any value for which SchemaHash.IsBypassSentinel reports true).
+// Returns error if hash is empty string (indicates a bug where schema hash wasn't properly provided).
+func (c *schemaHashCache) get(schemaHash SchemaHash) (*datastore.ReadOnlyStoredSchema, error) {
+	// Check for bypass sentinels - these intentionally skip the cache
+	if schemaHash.IsBypassSentinel() {
+		return nil, nil
+	}
+
+	// Empty hash indicates a bug - schema hash should always be provided or use a sentinel
+	if schemaHash == "" {
+		return nil, spiceerrors.MustBugf("empty schema hash passed to cache.Get() - use NoSchemaHashInTransaction or NoSchemaHashForTesting, or provide a real hash")
+	}
+
+	// Fast path: Check atomic latest entry
+	if latest := c.latest.Load(); latest != nil && latest.hash == schemaHash {
+		return latest.schema, nil
+	}
+
+	// Slow path: Check cache
+	schema, ok := c.cache.Get(SchemaCacheKey(schemaHash))
+	if !ok {
+		return nil, nil
+	}
+
+	return schema, nil
+}
+
+// Set stores a schema in the cache by hash.
+// Adds to the cache with a unit cost (NOTE(review): cost 1 counts entries, not bytes — confirm against the byte-denominated cache max-cost setting) and updates the atomic latest entry.
+// No-ops if hash is a bypass sentinel (any value for which SchemaHash.IsBypassSentinel reports true).
+// Returns error if hash is empty string (indicates a bug where it wasn't properly provided).
+func (c *schemaHashCache) Set(schemaHash SchemaHash, schema *datastore.ReadOnlyStoredSchema) error {
+	if schemaHash.IsBypassSentinel() {
+		return nil
+	}
+	if schemaHash == "" {
+		return spiceerrors.MustBugf("empty schema hash passed to cache.Set() - use NoSchemaHashInTransaction, NoSchemaHashForTesting, NoSchemaHashForWatch, or provide a real hash")
+	}
+
+	c.latest.Store(&latestSchemaEntry{
+		hash:   schemaHash,
+		schema: schema,
+	})
+
+	c.cache.Set(SchemaCacheKey(schemaHash), schema, 1)
+	return nil
+}
+
+// GetOrLoad retrieves a schema from the cache, or loads it using the provided loader.
+// Uses singleflight to deduplicate concurrent loads of the same hash.
+// Bypasses the cache for sentinel values (any value for which SchemaHash.IsBypassSentinel reports true).
+// Returns error if hash is empty string (indicates a bug where schema hash wasn't properly provided).
+func (c *schemaHashCache) GetOrLoad(
+	ctx context.Context,
+	rev datastore.Revision,
+	schemaHash SchemaHash,
+	loader func(ctx context.Context) (*datastore.ReadOnlyStoredSchema, error),
+) (*datastore.ReadOnlyStoredSchema, error) {
+	// Check for bypass sentinels - load directly without caching
+	if schemaHash.IsBypassSentinel() {
+		schema, err := loader(ctx)
+		if err != nil {
+			return nil, err
+		}
+		return schema, nil
+	}
+
+	// Empty hash indicates a bug - schema hash should always be provided or use a sentinel
+	if schemaHash == "" {
+		return nil, spiceerrors.MustBugf("empty schema hash passed to cache.GetOrLoad() - use NoSchemaHashInTransaction, NoSchemaHashForTesting, NoSchemaHashForWatch, or provide a real hash")
+	}
+
+	ctx, span := tracer.Start(ctx, "SchemaCache.GetOrLoad")
+	defer span.End()
+	span.SetAttributes(attribute.String(otelconv.AttrSchemaHash, string(schemaHash)))
+
+	// Try cache first
+	schema, err := c.get(schemaHash)
+	if err != nil {
+		return nil, err
+	}
+	if schema != nil {
+		span.SetAttributes(attribute.Bool(otelconv.AttrSchemaReadFromCache, true))
+		return schema, nil
+	}
+
+	// Load with singleflight to prevent duplicate loads. Use a short timeout to
+	// avoid deadlocks: if all connection pool slots are held by goroutines waiting
+	// on this singleflight, the leader can't acquire a connection to load the
+	// schema. The timeout lets waiters give up and fall back to loading directly,
+	// freeing connections for progress to be made.
+	sfCtx, cancel := context.WithTimeout(ctx, singleflightTimeout)
+	defer cancel()
+
+	result, _, err := c.singleflight.Do(sfCtx, string(schemaHash), func(ctx context.Context) (*datastore.ReadOnlyStoredSchema, error) {
+		// Check cache again in case another goroutine loaded it
+		schema, err := c.get(schemaHash)
+		if err != nil {
+			return nil, err
+		}
+		if schema != nil {
+			return schema, nil
+		}
+
+		// Load from datastore
+		schema, err = loader(ctx)
+		if err != nil {
+			return nil, err
+		}
+
+		// Cache the result
+		if err := c.Set(schemaHash, schema); err != nil {
+			return nil, err
+		}
+
+		return schema, nil
+	})
+	if err != nil {
+		// If the singleflight timed out but the caller's context is still valid,
+		// check the cache once more (the leader may have finished between our
+		// timeout and now), then fall back to loading directly. NOTE(review): the fallback result is returned without being stored via c.Set — confirm that leaving it uncached is intended.
+		if sfCtx.Err() != nil && ctx.Err() == nil {
+			schema, cacheErr := c.get(schemaHash)
+			if cacheErr != nil {
+				return nil, cacheErr
+			}
+			if schema != nil {
+				return schema, nil
+			}
+
+			return loader(ctx)
+		}
+		return nil, err
+	}
+
+	span.SetAttributes(attribute.Bool(otelconv.AttrSchemaReadFromCache, false))
+	return result, nil
+}
diff --git a/pkg/datalayer/hashcache_test.go b/pkg/datalayer/hashcache_test.go
new file mode 100644
index 000000000..13af64a9d
--- /dev/null
+++ b/pkg/datalayer/hashcache_test.go
@@ -0,0 +1,261 @@
+package datalayer
+
+import (
+	"context"
+	"fmt"
+	"sync"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/require"
+
+	"github.com/authzed/spicedb/pkg/datastore"
+	core "github.com/authzed/spicedb/pkg/proto/core/v1"
+)
+
+func makeTestSchema(text string) *datastore.ReadOnlyStoredSchema {
+	return datastore.NewReadOnlyStoredSchema(&core.StoredSchema{
+		Version: 1,
+		VersionOneof: &core.StoredSchema_V1{
+			V1: &core.StoredSchema_V1StoredSchema{
+				SchemaText: text,
+			},
+		},
+	})
+}
+
+func TestSchemaHashCache_BasicGetSet(t *testing.T) {
+	shc := newSchemaHashCache(newTestSchemaCache())
+
+	// Cache miss
+	retrieved, err := shc.get(SchemaHash("hash1"))
+	require.NoError(t, err)
+	require.Nil(t, retrieved)
+
+	// Set and get
+	schema := makeTestSchema("definition user {}")
+	err = shc.Set(SchemaHash("hash1"), schema)
+	require.NoError(t, err)
+
+	shc.cache.Wait()
+
+	retrieved, err = shc.get(SchemaHash("hash1"))
+	require.NoError(t, err)
+	require.NotNil(t, retrieved)
+	require.Equal(t, schema.Get().GetV1().SchemaText, retrieved.Get().GetV1().SchemaText)
+}
+
+func TestSchemaHashCache_EmptyHash(t *testing.T) {
+	shc := newSchemaHashCache(newTestSchemaCache())
+
+	require.Panics(t, func() {
+		_ = shc.Set(SchemaHash(""), makeTestSchema("definition user {}"))
+	}, "empty hash should panic")
+
+	require.Panics(t, func() {
+		_, _ = shc.get(SchemaHash(""))
+	}, "empty hash should panic")
+}
+
+func TestSchemaHashCache_GetOrLoad(t *testing.T) {
+	shc := newSchemaHashCache(newTestSchemaCache())
+
+	loadCalls := 0
+	loader := func(ctx context.Context) (*datastore.ReadOnlyStoredSchema, error) {
+		loadCalls++
+		return makeTestSchema("loaded definition"), nil
+	}
+
+	// First call should load
+	schema, err := shc.GetOrLoad(context.Background(), datastore.NoRevision, SchemaHash("hash1"), loader)
+	require.NoError(t, err)
+	require.NotNil(t, schema)
+	require.Equal(t, "loaded definition", schema.Get().GetV1().SchemaText)
+	require.Equal(t, 1, loadCalls)
+
+	// Second call should hit cache
+	schema, err = shc.GetOrLoad(context.Background(), datastore.NoRevision, SchemaHash("hash1"), loader)
+	require.NoError(t, err)
+	require.NotNil(t, schema)
+	require.Equal(t, 1, loadCalls)
+}
+
+func TestSchemaHashCache_GetOrLoadEmptyHash(t *testing.T) {
+	shc := newSchemaHashCache(newTestSchemaCache())
+
+	require.Panics(t, func() {
+		_, _ = shc.GetOrLoad(context.Background(), datastore.NoRevision, SchemaHash(""), func(ctx context.Context) (*datastore.ReadOnlyStoredSchema, error) {
+			return makeTestSchema("loaded definition"), nil
+		})
+	}, "empty hash should panic")
+}
+
+func TestSchemaHashCache_Singleflight(t *testing.T) {
+	shc := newSchemaHashCache(newTestSchemaCache())
+
+	loadCalls := 0
+	loadStarted := make(chan struct{})
+	loadContinue := make(chan struct{})
+
+	loader := func(_ context.Context) (*datastore.ReadOnlyStoredSchema, error) {
+		loadCalls++
+		close(loadStarted)
+		<-loadContinue
+		return makeTestSchema("loaded definition"), nil
+	}
+
+	const numGoroutines = 10
+	var wg sync.WaitGroup
+	wg.Add(numGoroutines)
+
+	results := make(chan error, numGoroutines)
+	for i := 0; i < numGoroutines; i++ {
+		go func() {
+			defer wg.Done()
+			schema, err := shc.GetOrLoad(context.Background(), datastore.NoRevision, SchemaHash("hash1"), loader)
+			if err != nil {
+				results <- err
+				return
+			}
+			if schema == nil {
+				results <- fmt.Errorf("schema is nil")
+				return
+			}
+			results <- nil
+		}()
+	}
+
+	<-loadStarted
+	close(loadContinue)
+	wg.Wait()
+	close(results)
+
+	for err := range results {
+		require.NoError(t, err)
+	}
+
+	require.Equal(t, 1, loadCalls)
+}
+
+func TestSchemaHashCache_LoadError(t *testing.T) {
+	shc := newSchemaHashCache(newTestSchemaCache())
+
+	expectedErr := fmt.Errorf("load failed")
+	schema, err := shc.GetOrLoad(context.Background(), datastore.NoRevision, SchemaHash("hash1"), func(ctx context.Context) (*datastore.ReadOnlyStoredSchema, error) {
+		return nil, expectedErr
+	})
+	require.Error(t, err)
+	require.Equal(t, expectedErr, err)
+	require.Nil(t, schema)
+}
+
+func TestSchemaHashCache_SetUpdatesLatest(t *testing.T) {
+	shc := newSchemaHashCache(newTestSchemaCache())
+
+	err := shc.Set(SchemaHash("hash1"), makeTestSchema("definition v1 {}"))
+	require.NoError(t, err)
+
+	latest := shc.latest.Load()
+	require.NotNil(t, latest)
+	require.Equal(t, SchemaHash("hash1"), latest.hash)
+
+	err = shc.Set(SchemaHash("hash2"), makeTestSchema("definition v2 {}"))
+	require.NoError(t, err)
+
+	latest = shc.latest.Load()
+	require.Equal(t, SchemaHash("hash2"), latest.hash)
+}
+
+func TestSchemaHashCache_SentinelBypass(t *testing.T) {
+	shc := newSchemaHashCache(newTestSchemaCache())
+
+	allSentinels := []SchemaHash{
+		NoSchemaHashInTransaction,
+		NoSchemaHashForTesting,
+		NoSchemaHashForWatch,
+		NoSchemaHashForLegacyCursor,
+		NoSchemaHashInDevelopment,
+	}
+
+	schema := makeTestSchema("definition user {}")
+
+	for _, sentinel := range allSentinels {
+		// Set should no-op
+		err := shc.Set(sentinel, schema)
+		require.NoError(t, err)
+
+		// GetOrLoad should always call loader
+		loadCalled := false
+		result, err := shc.GetOrLoad(context.Background(), datastore.NoRevision, sentinel, func(ctx context.Context) (*datastore.ReadOnlyStoredSchema, error) {
+			loadCalled = true
+			return schema, nil
+		})
+		require.NoError(t, err)
+		require.NotNil(t, result)
+		require.True(t, loadCalled)
+	}
+}
+
+func TestSchemaHashCache_SingleflightTimeoutFallback(t *testing.T) {
+	shc := newSchemaHashCache(newTestSchemaCache())
+
+	originalTimeout := singleflightTimeout
+	defer func() { singleflightTimeout = originalTimeout }()
+	singleflightTimeout = 100 * time.Millisecond
+
+	leaderStarted := make(chan struct{})
+	leaderContinue := make(chan struct{})
+
+	var loadCount atomic.Int32
+
+	slowLoader := func(_ context.Context) (*datastore.ReadOnlyStoredSchema, error) {
+		count := loadCount.Add(1)
+		if count == 1 {
+			close(leaderStarted)
+			<-leaderContinue
+		}
+		return makeTestSchema("loaded schema"), nil
+	}
+
+	hash := SchemaHash("timeout-hash")
+
+	var leaderWg sync.WaitGroup
+	var leaderErr error
+	leaderWg.Add(1)
+	go func() {
+		defer leaderWg.Done()
+		_, leaderErr = shc.GetOrLoad(context.Background(), datastore.NoRevision, hash, slowLoader)
+	}()
+
+	<-leaderStarted
+
+	// Waiter should fall back after timeout
+	schema, err := shc.GetOrLoad(context.Background(), datastore.NoRevision, hash, slowLoader)
+	require.NoError(t, err)
+	require.NotNil(t, schema)
+
+	require.GreaterOrEqual(t, int(loadCount.Load()), 2)
+
+	close(leaderContinue)
+	leaderWg.Wait()
+	require.NoError(t, leaderErr)
+}
+
+func TestNoopSchemaCache(t *testing.T) {
+	noop := noopSchemaCache{}
+
+	// Set is a no-op
+	err := noop.Set(SchemaHash("hash"), makeTestSchema("test"))
+	require.NoError(t, err)
+
+	// GetOrLoad always calls loader
+	loadCalled := false
+	schema, err := noop.GetOrLoad(context.Background(), datastore.NoRevision, SchemaHash("hash"), func(ctx context.Context) (*datastore.ReadOnlyStoredSchema, error) {
+		loadCalled = true
+		return makeTestSchema("loaded"), nil
+	})
+	require.NoError(t, err)
+	require.True(t, loadCalled)
+	require.Equal(t, "loaded", schema.Get().GetV1().SchemaText)
+}
diff --git a/pkg/datalayer/impl.go b/pkg/datalayer/impl.go
index 14e4ed3ff..ecd8cf184 100644
--- a/pkg/datalayer/impl.go
+++ b/pkg/datalayer/impl.go
@@ -9,35 +9,105 @@ import (
 	"github.com/authzed/spicedb/pkg/datastore"
 	"github.com/authzed/spicedb/pkg/datastore/options"
 	core "github.com/authzed/spicedb/pkg/proto/core/v1"
+	"github.com/authzed/spicedb/pkg/spiceerrors"
 	"github.com/authzed/spicedb/pkg/tuple"
 )
 
+// storedSchemaCache caches stored schemas by hash.
+type storedSchemaCache interface {
+	GetOrLoad(ctx context.Context, rev datastore.Revision, schemaHash SchemaHash,
+		loader func(ctx context.Context) (*datastore.ReadOnlyStoredSchema, error)) (*datastore.ReadOnlyStoredSchema, error)
+	Set(schemaHash SchemaHash, schema *datastore.ReadOnlyStoredSchema) error
+}
+
+// noopSchemaCache is a storedSchemaCache that always delegates to the loader.
+type noopSchemaCache struct{}
+
+func (noopSchemaCache) GetOrLoad(ctx context.Context, _ datastore.Revision, _ SchemaHash,
+	loader func(ctx context.Context) (*datastore.ReadOnlyStoredSchema, error),
+) (*datastore.ReadOnlyStoredSchema, error) {
+	return loader(ctx)
+}
+
+func (noopSchemaCache) Set(_ SchemaHash, _ *datastore.ReadOnlyStoredSchema) error {
+	return nil
+}
+
+// DataLayerOption configures a DataLayer.
+type DataLayerOption func(*defaultDataLayer)
+
+// WithSchemaMode sets the schema mode for the DataLayer.
+func WithSchemaMode(mode SchemaMode) DataLayerOption {
+	return func(d *defaultDataLayer) {
+		d.schemaMode = mode
+	}
+}
+
+// WithSchemaCache sets the backing schema cache for the DataLayer.
+// When set, ReadStoredSchema calls are cached and WriteStoredSchema updates the cache.
+func WithSchemaCache(cache SchemaCache) DataLayerOption {
+	return func(d *defaultDataLayer) {
+		d.cache = newSchemaHashCache(cache)
+	}
+}
+
 // NewDataLayer creates a new DataLayer wrapping a datastore.Datastore.
-func NewDataLayer(ds datastore.Datastore) DataLayer {
-	return &defaultDataLayer{ds: ds}
+func NewDataLayer(ds datastore.Datastore, opts ...DataLayerOption) DataLayer {
+	d := &defaultDataLayer{
+		ds:         ds,
+		schemaMode: SchemaModeReadLegacyWriteLegacy,
+		cache:      noopSchemaCache{},
+	}
+	for _, opt := range opts {
+		opt(d)
+	}
+	return d
 }
 
 // defaultDataLayer wraps a datastore.Datastore and implements DataLayer.
 type defaultDataLayer struct {
-	ds datastore.Datastore
+	ds         datastore.Datastore
+	schemaMode SchemaMode
+	cache      storedSchemaCache
 }
 
-func (d *defaultDataLayer) SnapshotReader(rev datastore.Revision) RevisionedReader {
-	return &revisionedReader{reader: d.ds.SnapshotReader(rev)}
+func (d *defaultDataLayer) SnapshotReader(rev datastore.Revision, schemaHash SchemaHash) RevisionedReader {
+	if schemaHash == "" {
+		_ = spiceerrors.MustBugf("empty string passed as SchemaHash; use a named sentinel")
+	}
+	return &revisionedReader{
+		reader:     d.ds.SnapshotReader(rev),
+		rev:        rev,
+		schemaMode: d.schemaMode,
+		schemaHash: schemaHash,
+		cache:      d.cache,
+	}
 }
 
 func (d *defaultDataLayer) ReadWriteTx(ctx context.Context, fn TxUserFunc, opts ...options.RWTOptionsOption) (datastore.Revision, error) {
 	return d.ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error {
-		return fn(ctx, &readWriteTransaction{rwt: rwt})
+		return fn(ctx, &readWriteTransaction{rwt: rwt, schemaMode: d.schemaMode, cache: d.cache})
 	}, opts...)
 }
 
-func (d *defaultDataLayer) OptimizedRevision(ctx context.Context) (datastore.Revision, error) {
-	return d.ds.OptimizedRevision(ctx)
+func (d *defaultDataLayer) OptimizedRevision(ctx context.Context) (datastore.Revision, SchemaHash, error) {
+	rev, err := d.ds.OptimizedRevision(ctx)
+	if err != nil {
+		return datastore.NoRevision, NoSchemaHashInLegacyMode, err
+	}
+
+	// TODO: track schema hash from watch cache or initial read for unified modes
+	return rev, NoSchemaHashInLegacyMode, nil
 }
 
-func (d *defaultDataLayer) HeadRevision(ctx context.Context) (datastore.Revision, error) {
-	return d.ds.HeadRevision(ctx)
+func (d *defaultDataLayer) HeadRevision(ctx context.Context) (datastore.Revision, SchemaHash, error) {
+	rev, err := d.ds.HeadRevision(ctx)
+	if err != nil {
+		return datastore.NoRevision, NoSchemaHashInLegacyMode, err
+	}
+
+	// TODO: track schema hash from watch cache or initial read for unified modes
+	return rev, NoSchemaHashInLegacyMode, nil
 }
 
 func (d *defaultDataLayer) CheckRevision(ctx context.Context, revision datastore.Revision) error {
@@ -82,10 +152,17 @@ func (d *defaultDataLayer) Close() error {
 
 // revisionedReader wraps a datastore.Reader and implements RevisionedReader.
 type revisionedReader struct {
-	reader datastore.Reader
+	reader     datastore.Reader
+	rev        datastore.Revision
+	schemaMode SchemaMode
+	schemaHash SchemaHash
+	cache      storedSchemaCache
 }
 
 func (r *revisionedReader) ReadSchema(ctx context.Context) (SchemaReader, error) {
+	if r.schemaMode.ReadsFromNew() {
+		return newStoredSchemaReaderAdapter(r.reader, r.schemaHash, r.rev, r.cache)
+	}
 	return &legacySchemaReaderAdapter{legacyReader: r.reader}, nil
 }
 
@@ -107,10 +184,15 @@ func (r *revisionedReader) LookupCounters(ctx context.Context) ([]datastore.Rela
 
 // readWriteTransaction wraps a datastore.ReadWriteTransaction and implements ReadWriteTransaction.
 type readWriteTransaction struct {
-	rwt datastore.ReadWriteTransaction
+	rwt        datastore.ReadWriteTransaction
+	schemaMode SchemaMode
+	cache      storedSchemaCache
 }
 
 func (t *readWriteTransaction) ReadSchema(_ context.Context) (SchemaReader, error) {
+	if t.schemaMode.ReadsFromNew() {
+		return newStoredSchemaReaderAdapter(t.rwt, NoSchemaHashInTransaction, datastore.NoRevision, t.cache)
+	}
 	return &legacySchemaReaderAdapter{legacyReader: t.rwt}, nil
 }
 
@@ -143,7 +225,21 @@ func (t *readWriteTransaction) BulkLoad(ctx context.Context, iter datastore.Bulk
 }
 
 func (t *readWriteTransaction) WriteSchema(ctx context.Context, definitions []datastore.SchemaDefinition, schemaString string, caveatTypeSet *caveattypes.TypeSet) error {
-	return writeSchemaViaLegacy(ctx, t.rwt, t.rwt, definitions)
+	// Write to legacy (per-definition) storage when the mode's WritesToLegacy reports true
+	if t.schemaMode.WritesToLegacy() {
+		if err := writeSchemaViaLegacy(ctx, t.rwt, t.rwt, definitions); err != nil {
+			return err
+		}
+	}
+
+	// Write to unified (single stored-schema) storage when the mode's WritesToNew reports true
+	if t.schemaMode.WritesToNew() {
+		if err := WriteSchemaViaStoredSchema(ctx, t.rwt, definitions, schemaString, t.cache); err != nil {
+			return err
+		}
+	}
+
+	return nil
 }
 
 func (t *readWriteTransaction) LegacySchemaWriter() LegacySchemaWriter {
@@ -208,20 +304,36 @@ type readOnlyDatastoreAdapter struct {
 	ds datastore.ReadOnlyDatastore
 }
 
-func (r *readOnlyDatastoreAdapter) SnapshotReader(rev datastore.Revision) RevisionedReader {
-	return &revisionedReader{reader: r.ds.SnapshotReader(rev)}
+func (r *readOnlyDatastoreAdapter) SnapshotReader(rev datastore.Revision, schemaHash SchemaHash) RevisionedReader {
+	if schemaHash == "" {
+		_ = spiceerrors.MustBugf("empty string passed as SchemaHash; use a named sentinel")
+	}
+	return &revisionedReader{
+		reader:     r.ds.SnapshotReader(rev),
+		rev:        rev,
+		schemaMode: SchemaModeReadLegacyWriteLegacy,
+		schemaHash: schemaHash,
+	}
 }
 
 func (r *readOnlyDatastoreAdapter) ReadWriteTx(_ context.Context, _ TxUserFunc, _ ...options.RWTOptionsOption) (datastore.Revision, error) {
 	return datastore.NoRevision, datastore.NewReadonlyErr()
 }
 
-func (r *readOnlyDatastoreAdapter) OptimizedRevision(ctx context.Context) (datastore.Revision, error) {
-	return r.ds.OptimizedRevision(ctx)
+func (r *readOnlyDatastoreAdapter) OptimizedRevision(ctx context.Context) (datastore.Revision, SchemaHash, error) {
+	rev, err := r.ds.OptimizedRevision(ctx)
+	if err != nil {
+		return datastore.NoRevision, NoSchemaHashInLegacyMode, err
+	}
+	return rev, NoSchemaHashInLegacyMode, nil
 }
 
-func (r *readOnlyDatastoreAdapter) HeadRevision(ctx context.Context) (datastore.Revision, error) {
-	return r.ds.HeadRevision(ctx)
+func (r *readOnlyDatastoreAdapter) HeadRevision(ctx context.Context) (datastore.Revision, SchemaHash, error) {
+	rev, err := r.ds.HeadRevision(ctx)
+	if err != nil {
+		return datastore.NoRevision, NoSchemaHashInLegacyMode, err
+	}
+	return rev, NoSchemaHashInLegacyMode, nil
 }
 
 func (r *readOnlyDatastoreAdapter) CheckRevision(ctx context.Context, revision datastore.Revision) error {
diff --git a/pkg/datalayer/mocks/mock_datalayer.go b/pkg/datalayer/mocks/mock_datalayer.go
index eb4cdde53..3a4882373 100644
--- a/pkg/datalayer/mocks/mock_datalayer.go
+++ b/pkg/datalayer/mocks/mock_datalayer.go
@@ -176,12 +176,13 @@ func (mr *MockDataLayerMockRecorder) Features(ctx any) *gomock.Call {
 }
 
 // HeadRevision mocks base method.
-func (m *MockDataLayer) HeadRevision(ctx context.Context) (datastore.Revision, error) {
+func (m *MockDataLayer) HeadRevision(ctx context.Context) (datastore.Revision, datalayer.SchemaHash, error) {
 	m.ctrl.T.Helper()
 	ret := m.ctrl.Call(m, "HeadRevision", ctx)
 	ret0, _ := ret[0].(datastore.Revision)
-	ret1, _ := ret[1].(error)
-	return ret0, ret1
+	ret1, _ := ret[1].(datalayer.SchemaHash)
+	ret2, _ := ret[2].(error)
+	return ret0, ret1, ret2
 }
 
 // HeadRevision indicates an expected call of HeadRevision.
@@ -221,12 +222,13 @@ func (mr *MockDataLayerMockRecorder) OfflineFeatures() *gomock.Call { } // OptimizedRevision mocks base method. -func (m *MockDataLayer) OptimizedRevision(ctx context.Context) (datastore.Revision, error) { +func (m *MockDataLayer) OptimizedRevision(ctx context.Context) (datastore.Revision, datalayer.SchemaHash, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "OptimizedRevision", ctx) ret0, _ := ret[0].(datastore.Revision) - ret1, _ := ret[1].(error) - return ret0, ret1 + ret1, _ := ret[1].(datalayer.SchemaHash) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 } // OptimizedRevision indicates an expected call of OptimizedRevision. @@ -286,17 +288,17 @@ func (mr *MockDataLayerMockRecorder) RevisionFromString(serialized any) *gomock. } // SnapshotReader mocks base method. -func (m *MockDataLayer) SnapshotReader(arg0 datastore.Revision) datalayer.RevisionedReader { +func (m *MockDataLayer) SnapshotReader(arg0 datastore.Revision, arg1 datalayer.SchemaHash) datalayer.RevisionedReader { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SnapshotReader", arg0) + ret := m.ctrl.Call(m, "SnapshotReader", arg0, arg1) ret0, _ := ret[0].(datalayer.RevisionedReader) return ret0 } // SnapshotReader indicates an expected call of SnapshotReader. -func (mr *MockDataLayerMockRecorder) SnapshotReader(arg0 any) *gomock.Call { +func (mr *MockDataLayerMockRecorder) SnapshotReader(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SnapshotReader", reflect.TypeOf((*MockDataLayer)(nil).SnapshotReader), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SnapshotReader", reflect.TypeOf((*MockDataLayer)(nil).SnapshotReader), arg0, arg1) } // Statistics mocks base method. 
diff --git a/pkg/datalayer/readonly.go b/pkg/datalayer/readonly.go index a9e654474..1916b7d97 100644 --- a/pkg/datalayer/readonly.go +++ b/pkg/datalayer/readonly.go @@ -16,19 +16,19 @@ type readonlyDataLayer struct { delegate DataLayer } -func (r *readonlyDataLayer) SnapshotReader(rev datastore.Revision) RevisionedReader { - return r.delegate.SnapshotReader(rev) +func (r *readonlyDataLayer) SnapshotReader(rev datastore.Revision, schemaHash SchemaHash) RevisionedReader { + return r.delegate.SnapshotReader(rev, schemaHash) } func (r *readonlyDataLayer) ReadWriteTx(_ context.Context, _ TxUserFunc, _ ...options.RWTOptionsOption) (datastore.Revision, error) { return datastore.NoRevision, datastore.NewReadonlyErr() } -func (r *readonlyDataLayer) OptimizedRevision(ctx context.Context) (datastore.Revision, error) { +func (r *readonlyDataLayer) OptimizedRevision(ctx context.Context) (datastore.Revision, SchemaHash, error) { return r.delegate.OptimizedRevision(ctx) } -func (r *readonlyDataLayer) HeadRevision(ctx context.Context) (datastore.Revision, error) { +func (r *readonlyDataLayer) HeadRevision(ctx context.Context) (datastore.Revision, SchemaHash, error) { return r.delegate.HeadRevision(ctx) } diff --git a/pkg/datalayer/readonly_test.go b/pkg/datalayer/readonly_test.go new file mode 100644 index 000000000..b9dd9bdaa --- /dev/null +++ b/pkg/datalayer/readonly_test.go @@ -0,0 +1,251 @@ +package datalayer_test + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + "github.com/authzed/spicedb/pkg/datalayer" + mock_datalayer "github.com/authzed/spicedb/pkg/datalayer/mocks" + "github.com/authzed/spicedb/pkg/datastore" +) + +// fakeRevision is a simple datastore.Revision for unit tests. 
+type fakeRevision string + +func (r fakeRevision) String() string { return string(r) } +func (r fakeRevision) Equal(o datastore.Revision) bool { return r.String() == o.String() } +func (r fakeRevision) GreaterThan(datastore.Revision) bool { return false } +func (r fakeRevision) LessThan(datastore.Revision) bool { return false } +func (r fakeRevision) ByteSortable() bool { return false } + +func TestReadonlyDL_ReadWriteTx_ReturnsError(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + dl := mock_datalayer.NewMockDataLayer(ctrl) + rdl := datalayer.NewReadonlyDataLayer(dl) + + rev, err := rdl.ReadWriteTx(t.Context(), func(_ context.Context, _ datalayer.ReadWriteTransaction) error { + t.Fatal("should not be called") + return nil + }) + require.Error(t, err) + require.Equal(t, datastore.NoRevision, rev) + + var roErr datastore.ReadOnlyError + require.ErrorAs(t, err, &roErr) +} + +func TestReadonlyDL_SnapshotReader_DelegatesToUnderlying(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockReader := mock_datalayer.NewMockRevisionedReader(ctrl) + mockReader.EXPECT().QueryRelationships(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) + + dl := mock_datalayer.NewMockDataLayer(ctrl) + dl.EXPECT().SnapshotReader(gomock.Any(), datalayer.NoSchemaHashForTesting).Return(mockReader).Times(1) + + rdl := datalayer.NewReadonlyDataLayer(dl) + reader := rdl.SnapshotReader(fakeRevision("r1"), datalayer.NoSchemaHashForTesting) + require.NotNil(t, reader) + + _, _ = reader.QueryRelationships(t.Context(), datastore.RelationshipsFilter{}) +} + +func TestReadonlyDL_OptimizedRevision_Delegates(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + dl := mock_datalayer.NewMockDataLayer(ctrl) + dl.EXPECT().OptimizedRevision(gomock.Any()).Return(fakeRevision("opt"), datalayer.SchemaHash("hash"), nil).Times(1) + + rdl := datalayer.NewReadonlyDataLayer(dl) + rev, hash, err := 
rdl.OptimizedRevision(t.Context()) + require.NoError(t, err) + require.Equal(t, fakeRevision("opt"), rev) + require.Equal(t, datalayer.SchemaHash("hash"), hash) +} + +func TestReadonlyDL_HeadRevision_Delegates(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + dl := mock_datalayer.NewMockDataLayer(ctrl) + dl.EXPECT().HeadRevision(gomock.Any()).Return(fakeRevision("head"), datalayer.SchemaHash("hash"), nil).Times(1) + + rdl := datalayer.NewReadonlyDataLayer(dl) + rev, hash, err := rdl.HeadRevision(t.Context()) + require.NoError(t, err) + require.Equal(t, fakeRevision("head"), rev) + require.Equal(t, datalayer.SchemaHash("hash"), hash) +} + +func TestReadonlyDL_CheckRevision_Delegates(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + dl := mock_datalayer.NewMockDataLayer(ctrl) + dl.EXPECT().CheckRevision(gomock.Any(), gomock.Any()).Return(nil).Times(1) + + rdl := datalayer.NewReadonlyDataLayer(dl) + err := rdl.CheckRevision(t.Context(), fakeRevision("r1")) + require.NoError(t, err) +} + +func TestReadonlyDL_RevisionFromString_Delegates(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + dl := mock_datalayer.NewMockDataLayer(ctrl) + dl.EXPECT().RevisionFromString("abc").Return(fakeRevision("abc"), nil).Times(1) + + rdl := datalayer.NewReadonlyDataLayer(dl) + rev, err := rdl.RevisionFromString("abc") + require.NoError(t, err) + require.Equal(t, fakeRevision("abc"), rev) +} + +func TestReadonlyDL_Watch_Delegates(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + dl := mock_datalayer.NewMockDataLayer(ctrl) + dl.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) + + rdl := datalayer.NewReadonlyDataLayer(dl) + ch, errCh := rdl.Watch(t.Context(), fakeRevision("r1"), datastore.WatchOptions{}) + require.Nil(t, ch) + require.Nil(t, errCh) +} + +func TestReadonlyDL_ReadyState_Delegates(t 
*testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + dl := mock_datalayer.NewMockDataLayer(ctrl) + dl.EXPECT().ReadyState(gomock.Any()).Return(datastore.ReadyState{IsReady: true}, nil).Times(1) + + rdl := datalayer.NewReadonlyDataLayer(dl) + state, err := rdl.ReadyState(t.Context()) + require.NoError(t, err) + require.True(t, state.IsReady) +} + +func TestReadonlyDL_Features_Delegates(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + dl := mock_datalayer.NewMockDataLayer(ctrl) + dl.EXPECT().Features(gomock.Any()).Return(nil, nil).Times(1) + + rdl := datalayer.NewReadonlyDataLayer(dl) + f, err := rdl.Features(t.Context()) + require.NoError(t, err) + require.Nil(t, f) +} + +func TestReadonlyDL_OfflineFeatures_Delegates(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + dl := mock_datalayer.NewMockDataLayer(ctrl) + dl.EXPECT().OfflineFeatures().Return(nil, nil).Times(1) + + rdl := datalayer.NewReadonlyDataLayer(dl) + f, err := rdl.OfflineFeatures() + require.NoError(t, err) + require.Nil(t, f) +} + +func TestReadonlyDL_Statistics_Delegates(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + dl := mock_datalayer.NewMockDataLayer(ctrl) + dl.EXPECT().Statistics(gomock.Any()).Return(datastore.Stats{}, nil).Times(1) + + rdl := datalayer.NewReadonlyDataLayer(dl) + _, err := rdl.Statistics(t.Context()) + require.NoError(t, err) +} + +func TestReadonlyDL_UniqueID_Delegates(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + dl := mock_datalayer.NewMockDataLayer(ctrl) + dl.EXPECT().UniqueID(gomock.Any()).Return("uid", nil).Times(1) + + rdl := datalayer.NewReadonlyDataLayer(dl) + uid, err := rdl.UniqueID(t.Context()) + require.NoError(t, err) + require.Equal(t, "uid", uid) +} + +func TestReadonlyDL_MetricsID_Delegates(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + 
defer ctrl.Finish() + + dl := mock_datalayer.NewMockDataLayer(ctrl) + dl.EXPECT().MetricsID().Return("mid", nil).Times(1) + + rdl := datalayer.NewReadonlyDataLayer(dl) + mid, err := rdl.MetricsID() + require.NoError(t, err) + require.Equal(t, "mid", mid) +} + +func TestReadonlyDL_Close_Delegates(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + dl := mock_datalayer.NewMockDataLayer(ctrl) + dl.EXPECT().Close().Return(nil).Times(1) + + rdl := datalayer.NewReadonlyDataLayer(dl) + err := rdl.Close() + require.NoError(t, err) +} + +func TestReadonlyDL_Close_PropagatesError(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + dl := mock_datalayer.NewMockDataLayer(ctrl) + dl.EXPECT().Close().Return(errors.New("close err")).Times(1) + + rdl := datalayer.NewReadonlyDataLayer(dl) + err := rdl.Close() + require.ErrorContains(t, err, "close err") +} + +func TestReadonlyDL_UnwrapDatastore_ReturnsNilForMock(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + dl := mock_datalayer.NewMockDataLayer(ctrl) + rdl := datalayer.NewReadonlyDataLayer(dl) + + require.Nil(t, datalayer.UnwrapDatastore(rdl)) +} diff --git a/pkg/datalayer/schema_adapter.go b/pkg/datalayer/schema_adapter.go index ced174697..0f2a6dc5f 100644 --- a/pkg/datalayer/schema_adapter.go +++ b/pkg/datalayer/schema_adapter.go @@ -6,7 +6,11 @@ import ( "fmt" "maps" "slices" + "sort" + "go.opentelemetry.io/otel/attribute" + + "github.com/authzed/spicedb/internal/telemetry/otelconv" "github.com/authzed/spicedb/pkg/caveats/types" "github.com/authzed/spicedb/pkg/datastore" "github.com/authzed/spicedb/pkg/genutil/mapz" @@ -48,6 +52,14 @@ func (l *legacySchemaReaderAdapter) SchemaText(ctx context.Context) (string, err return "", err } + // Sort the definitions by name for deterministic output + sort.Slice(namespaces, func(i, j int) bool { + return namespaces[i].Definition.Name < namespaces[j].Definition.Name + }) 
+ sort.Slice(caveats, func(i, j int) bool { + return caveats[i].Definition.Name < caveats[j].Definition.Name + }) + // Build a list of all schema definitions caveatTypeSet := types.Default.TypeSet definitions := make([]compiler.SchemaDefinition, 0, len(caveats)+len(namespaces)) @@ -308,3 +320,236 @@ func writeSchemaViaLegacy(ctx context.Context, legacyWriter datastore.LegacySche return nil } + +// storedSchemaReaderAdapter implements SchemaReader by reading from the unified +// StoredSchema proto via ReadStoredSchema on the underlying datastore reader. +type storedSchemaReaderAdapter struct { + storedSchema *datastore.ReadOnlyStoredSchema + lastWrittenRevision datastore.Revision +} + +// storedSchemaReader is the interface required to read a stored schema. +type storedSchemaReader interface { + ReadStoredSchema(ctx context.Context) (*datastore.ReadOnlyStoredSchema, error) +} + +// newStoredSchemaReaderAdapter creates a storedSchemaReaderAdapter by loading +// the stored schema from the reader. The lastWrittenRevision is the revision at +// which this snapshot was taken; it is returned as LastWrittenRevision on each definition. +// The cache is used to cache and deduplicate ReadStoredSchema calls. 
+func newStoredSchemaReaderAdapter(reader storedSchemaReader, schemaHash SchemaHash, lastWrittenRevision datastore.Revision, cache storedSchemaCache, +) (*storedSchemaReaderAdapter, error) { + ctx, span := tracer.Start(context.Background(), "ReadStoredSchema") + defer span.End() + span.SetAttributes(attribute.String(otelconv.AttrSchemaHash, string(schemaHash))) + + storedSchema, err := cache.GetOrLoad(ctx, lastWrittenRevision, schemaHash, func(ctx context.Context) (*datastore.ReadOnlyStoredSchema, error) { + return reader.ReadStoredSchema(ctx) + }) + if err != nil { + if errors.Is(err, datastore.ErrSchemaNotFound) { + // No unified schema yet; return an adapter with no definitions + return &storedSchemaReaderAdapter{ + storedSchema: datastore.NewReadOnlyStoredSchema(&core.StoredSchema{ + Version: 1, + VersionOneof: &core.StoredSchema_V1{ + V1: &core.StoredSchema_V1StoredSchema{}, + }, + }), + lastWrittenRevision: lastWrittenRevision, + }, nil + } + return nil, fmt.Errorf("failed to read stored schema: %w", err) + } + return &storedSchemaReaderAdapter{storedSchema: storedSchema, lastWrittenRevision: lastWrittenRevision}, nil +} + +func (s *storedSchemaReaderAdapter) v1() *core.StoredSchema_V1StoredSchema { + if v1 := s.storedSchema.Get().GetV1(); v1 != nil { + return v1 + } + return &core.StoredSchema_V1StoredSchema{} +} + +func (s *storedSchemaReaderAdapter) SchemaText(_ context.Context) (string, error) { + v1 := s.v1() + if v1.SchemaText == "" { + if len(v1.NamespaceDefinitions) == 0 && len(v1.CaveatDefinitions) == 0 { + return "", datastore.NewSchemaNotDefinedErr() + } + } + return v1.SchemaText, nil +} + +func (s *storedSchemaReaderAdapter) LookupTypeDefByName(_ context.Context, name string) (datastore.RevisionedTypeDefinition, bool, error) { + v1 := s.v1() + ns, ok := v1.NamespaceDefinitions[name] + if !ok { + return datastore.RevisionedTypeDefinition{}, false, nil + } + return datastore.RevisionedTypeDefinition{ + Definition: ns, + LastWrittenRevision: 
s.lastWrittenRevision, + }, true, nil +} + +func (s *storedSchemaReaderAdapter) LookupCaveatDefByName(_ context.Context, name string) (datastore.RevisionedCaveat, bool, error) { + v1 := s.v1() + caveat, ok := v1.CaveatDefinitions[name] + if !ok { + return datastore.RevisionedCaveat{}, false, nil + } + return datastore.RevisionedCaveat{ + Definition: caveat, + LastWrittenRevision: s.lastWrittenRevision, + }, true, nil +} + +func (s *storedSchemaReaderAdapter) ListAllTypeDefinitions(_ context.Context) ([]datastore.RevisionedTypeDefinition, error) { + v1 := s.v1() + result := make([]datastore.RevisionedTypeDefinition, 0, len(v1.NamespaceDefinitions)) + for _, ns := range v1.NamespaceDefinitions { + result = append(result, datastore.RevisionedTypeDefinition{ + Definition: ns, + LastWrittenRevision: s.lastWrittenRevision, + }) + } + return result, nil +} + +func (s *storedSchemaReaderAdapter) ListAllCaveatDefinitions(_ context.Context) ([]datastore.RevisionedCaveat, error) { + v1 := s.v1() + result := make([]datastore.RevisionedCaveat, 0, len(v1.CaveatDefinitions)) + for _, caveat := range v1.CaveatDefinitions { + result = append(result, datastore.RevisionedCaveat{ + Definition: caveat, + LastWrittenRevision: s.lastWrittenRevision, + }) + } + return result, nil +} + +func (s *storedSchemaReaderAdapter) ListAllSchemaDefinitions(_ context.Context) (map[string]datastore.SchemaDefinition, error) { + v1 := s.v1() + result := make(map[string]datastore.SchemaDefinition, len(v1.NamespaceDefinitions)+len(v1.CaveatDefinitions)) + for name, ns := range v1.NamespaceDefinitions { + result[name] = ns + } + for name, caveat := range v1.CaveatDefinitions { + result[name] = caveat + } + return result, nil +} + +func (s *storedSchemaReaderAdapter) LookupSchemaDefinitionsByNames(_ context.Context, names []string) (map[string]datastore.SchemaDefinition, error) { + v1 := s.v1() + result := make(map[string]datastore.SchemaDefinition, len(names)) + for _, name := range names { + if ns, ok := 
v1.NamespaceDefinitions[name]; ok { + result[name] = ns + } else if caveat, ok := v1.CaveatDefinitions[name]; ok { + result[name] = caveat + } + } + return result, nil +} + +func (s *storedSchemaReaderAdapter) LookupTypeDefinitionsByNames(_ context.Context, names []string) (map[string]datastore.TypeDefinition, error) { + v1 := s.v1() + result := make(map[string]datastore.TypeDefinition, len(names)) + for _, name := range names { + if ns, ok := v1.NamespaceDefinitions[name]; ok { + result[name] = ns + } + } + return result, nil +} + +func (s *storedSchemaReaderAdapter) LookupCaveatDefinitionsByNames(_ context.Context, names []string) (map[string]datastore.CaveatDefinition, error) { + v1 := s.v1() + result := make(map[string]datastore.CaveatDefinition, len(names)) + for _, name := range names { + if caveat, ok := v1.CaveatDefinitions[name]; ok { + result[name] = caveat + } + } + return result, nil +} + +var _ SchemaReader = (*storedSchemaReaderAdapter)(nil) + +// WriteSchemaViaStoredSchema builds a StoredSchema proto and writes it via WriteStoredSchema. +// If cache is nil, a no-op cache is used. 
+func WriteSchemaViaStoredSchema(ctx context.Context, rwt datastore.ReadWriteTransaction, + definitions []datastore.SchemaDefinition, schemaString string, cache storedSchemaCache, +) error { + if cache == nil { + cache = noopSchemaCache{} + } + + ctx, span := tracer.Start(ctx, "WriteSchemaViaStoredSchema") + defer span.End() + + namespaceDefs := make(map[string]*core.NamespaceDefinition) + caveatDefs := make(map[string]*core.CaveatDefinition) + + for _, def := range definitions { + if _, existing := namespaceDefs[def.GetName()]; existing { + return spiceerrors.MustBugf("duplicate definition name: %s", def.GetName()) + } + if _, existing := caveatDefs[def.GetName()]; existing { + return spiceerrors.MustBugf("duplicate definition name: %s", def.GetName()) + } + + switch typedDef := def.(type) { + case *core.NamespaceDefinition: + namespaceDefs[typedDef.Name] = typedDef + case *core.CaveatDefinition: + caveatDefs[typedDef.Name] = typedDef + default: + return spiceerrors.MustBugf("unknown definition type: %T", def) + } + } + + // Compute schema hash from the definitions + compDefs := make([]compiler.SchemaDefinition, 0, len(definitions)) + for _, def := range definitions { + compDef, ok := def.(compiler.SchemaDefinition) + if !ok { + return fmt.Errorf("definition %q does not implement compiler.SchemaDefinition", def.GetName()) + } + compDefs = append(compDefs, compDef) + } + + schemaHash, err := generator.ComputeSchemaHash(compDefs) + if err != nil { + return fmt.Errorf("failed to compute schema hash: %w", err) + } + span.SetAttributes(attribute.String(otelconv.AttrSchemaHash, schemaHash)) + span.SetAttributes(attribute.Int(otelconv.AttrSchemaDataSizeBytes, len(schemaString))) + + storedSchema := &core.StoredSchema{ + Version: 1, + VersionOneof: &core.StoredSchema_V1{ + V1: &core.StoredSchema_V1StoredSchema{ + SchemaText: schemaString, + SchemaHash: schemaHash, + NamespaceDefinitions: namespaceDefs, + CaveatDefinitions: caveatDefs, + }, + }, + } + + if err := 
rwt.WriteStoredSchema(ctx, storedSchema); err != nil { + return err + } + + // Update cache after successful write + if v1 := storedSchema.GetV1(); v1 != nil && v1.SchemaHash != "" { + if err := cache.Set(SchemaHash(v1.SchemaHash), datastore.NewReadOnlyStoredSchema(storedSchema)); err != nil { + return err + } + } + + return nil +} diff --git a/pkg/datalayer/schema_adapter_test.go b/pkg/datalayer/schema_adapter_test.go new file mode 100644 index 000000000..eabaacb1e --- /dev/null +++ b/pkg/datalayer/schema_adapter_test.go @@ -0,0 +1,919 @@ +package datalayer + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/authzed/spicedb/pkg/datastore" + core "github.com/authzed/spicedb/pkg/proto/core/v1" +) + +// --- fake helpers --- + +// fakeRevision is a simple datastore.Revision for unit tests. +type fakeRevision string + +func (r fakeRevision) String() string { return string(r) } +func (r fakeRevision) Equal(o datastore.Revision) bool { return r.String() == o.String() } +func (r fakeRevision) GreaterThan(datastore.Revision) bool { return false } +func (r fakeRevision) LessThan(datastore.Revision) bool { return false } +func (r fakeRevision) ByteSortable() bool { return false } + +var testRevision fakeRevision = "rev1" + +// fakeLegacySchemaReader is a fake datastore.LegacySchemaReader with canned responses. 
+type fakeLegacySchemaReader struct { + namespaces []datastore.RevisionedNamespace + caveats []datastore.RevisionedCaveat + namespacesErr error + caveatsErr error + readNSByName map[string]*core.NamespaceDefinition + readNSRevision datastore.Revision + readNSErr error + readCavByName map[string]*core.CaveatDefinition + readCavRev datastore.Revision + readCavErr error + lookupNSResult []datastore.RevisionedNamespace + lookupNSErr error + lookupCavResult []datastore.RevisionedCaveat + lookupCavErr error +} + +func (m *fakeLegacySchemaReader) LegacyReadNamespaceByName(_ context.Context, name string) (*core.NamespaceDefinition, datastore.Revision, error) { + if m.readNSErr != nil { + return nil, nil, m.readNSErr + } + ns, ok := m.readNSByName[name] + if !ok { + return nil, nil, datastore.NewNamespaceNotFoundErr(name) + } + return ns, m.readNSRevision, nil +} + +func (m *fakeLegacySchemaReader) LegacyListAllNamespaces(_ context.Context) ([]datastore.RevisionedNamespace, error) { + return m.namespaces, m.namespacesErr +} + +func (m *fakeLegacySchemaReader) LegacyLookupNamespacesWithNames(_ context.Context, _ []string) ([]datastore.RevisionedNamespace, error) { + return m.lookupNSResult, m.lookupNSErr +} + +func (m *fakeLegacySchemaReader) LegacyReadCaveatByName(_ context.Context, name string) (*core.CaveatDefinition, datastore.Revision, error) { + if m.readCavErr != nil { + return nil, nil, m.readCavErr + } + cav, ok := m.readCavByName[name] + if !ok { + return nil, nil, datastore.NewCaveatNameNotFoundErr(name) + } + return cav, m.readCavRev, nil +} + +func (m *fakeLegacySchemaReader) LegacyListAllCaveats(_ context.Context) ([]datastore.RevisionedCaveat, error) { + return m.caveats, m.caveatsErr +} + +func (m *fakeLegacySchemaReader) LegacyLookupCaveatsWithNames(_ context.Context, _ []string) ([]datastore.RevisionedCaveat, error) { + return m.lookupCavResult, m.lookupCavErr +} + +// fakeLegacySchemaWriter is a fake datastore.LegacySchemaWriter that records calls. 
+type fakeLegacySchemaWriter struct { + fakeLegacySchemaReader + + writtenNS []*core.NamespaceDefinition + writtenCaveats []*core.CaveatDefinition + deletedNS []string + deletedCaveats []string + writeNSErr error + writeCavErr error + deleteNSErr error + deleteCavErr error +} + +func (m *fakeLegacySchemaWriter) LegacyWriteNamespaces(_ context.Context, nsDefs ...*core.NamespaceDefinition) error { + m.writtenNS = append(m.writtenNS, nsDefs...) + return m.writeNSErr +} + +func (m *fakeLegacySchemaWriter) LegacyWriteCaveats(_ context.Context, cavDefs []*core.CaveatDefinition) error { + m.writtenCaveats = append(m.writtenCaveats, cavDefs...) + return m.writeCavErr +} + +func (m *fakeLegacySchemaWriter) LegacyDeleteNamespaces(_ context.Context, names []string, _ datastore.DeleteNamespacesRelationshipsOption) error { + m.deletedNS = append(m.deletedNS, names...) + return m.deleteNSErr +} + +func (m *fakeLegacySchemaWriter) LegacyDeleteCaveats(_ context.Context, names []string) error { + m.deletedCaveats = append(m.deletedCaveats, names...) 
+ return m.deleteCavErr +} + +// --- legacySchemaReaderAdapter tests --- + +func TestLegacyAdapter_SchemaText_Empty(t *testing.T) { + t.Parallel() + reader := &fakeLegacySchemaReader{} + adapter := SchemaReaderFromLegacy(reader) + + _, err := adapter.SchemaText(t.Context()) + require.Error(t, err) +} + +func TestLegacyAdapter_SchemaText_ListNamespacesError(t *testing.T) { + t.Parallel() + reader := &fakeLegacySchemaReader{namespacesErr: errors.New("boom")} + adapter := SchemaReaderFromLegacy(reader) + + _, err := adapter.SchemaText(t.Context()) + require.ErrorContains(t, err, "boom") +} + +func TestLegacyAdapter_SchemaText_ListCaveatsError(t *testing.T) { + t.Parallel() + ns := &core.NamespaceDefinition{Name: "user"} + reader := &fakeLegacySchemaReader{ + namespaces: []datastore.RevisionedNamespace{ + {Definition: ns, LastWrittenRevision: testRevision}, + }, + caveatsErr: errors.New("caveat err"), + } + adapter := SchemaReaderFromLegacy(reader) + + _, err := adapter.SchemaText(t.Context()) + require.ErrorContains(t, err, "caveat err") +} + +func TestLegacyAdapter_SchemaText_WithDefinitions(t *testing.T) { + t.Parallel() + ns := &core.NamespaceDefinition{Name: "user"} + reader := &fakeLegacySchemaReader{ + namespaces: []datastore.RevisionedNamespace{ + {Definition: ns, LastWrittenRevision: testRevision}, + }, + } + adapter := SchemaReaderFromLegacy(reader) + + text, err := adapter.SchemaText(t.Context()) + require.NoError(t, err) + require.Contains(t, text, "user") +} + +func TestLegacyAdapter_LookupTypeDefByName_Found(t *testing.T) { + t.Parallel() + ns := &core.NamespaceDefinition{Name: "document"} + reader := &fakeLegacySchemaReader{ + readNSByName: map[string]*core.NamespaceDefinition{"document": ns}, + readNSRevision: testRevision, + } + adapter := SchemaReaderFromLegacy(reader) + + result, found, err := adapter.LookupTypeDefByName(t.Context(), "document") + require.NoError(t, err) + require.True(t, found) + require.Equal(t, "document", result.Definition.Name) 
+ require.Equal(t, testRevision, result.LastWrittenRevision) +} + +func TestLegacyAdapter_LookupTypeDefByName_NotFound(t *testing.T) { + t.Parallel() + reader := &fakeLegacySchemaReader{ + readNSByName: map[string]*core.NamespaceDefinition{}, + } + adapter := SchemaReaderFromLegacy(reader) + + _, found, err := adapter.LookupTypeDefByName(t.Context(), "missing") + require.NoError(t, err) + require.False(t, found) +} + +func TestLegacyAdapter_LookupTypeDefByName_OtherError(t *testing.T) { + t.Parallel() + reader := &fakeLegacySchemaReader{ + readNSErr: errors.New("db down"), + } + adapter := SchemaReaderFromLegacy(reader) + + _, _, err := adapter.LookupTypeDefByName(t.Context(), "anything") + require.ErrorContains(t, err, "db down") +} + +func TestLegacyAdapter_LookupCaveatDefByName_Found(t *testing.T) { + t.Parallel() + cav := &core.CaveatDefinition{Name: "mycaveat"} + reader := &fakeLegacySchemaReader{ + readCavByName: map[string]*core.CaveatDefinition{"mycaveat": cav}, + readCavRev: testRevision, + } + adapter := SchemaReaderFromLegacy(reader) + + result, found, err := adapter.LookupCaveatDefByName(t.Context(), "mycaveat") + require.NoError(t, err) + require.True(t, found) + require.Equal(t, "mycaveat", result.Definition.Name) + require.Equal(t, testRevision, result.LastWrittenRevision) +} + +func TestLegacyAdapter_LookupCaveatDefByName_NotFound(t *testing.T) { + t.Parallel() + reader := &fakeLegacySchemaReader{ + readCavByName: map[string]*core.CaveatDefinition{}, + } + adapter := SchemaReaderFromLegacy(reader) + + _, found, err := adapter.LookupCaveatDefByName(t.Context(), "missing") + require.NoError(t, err) + require.False(t, found) +} + +func TestLegacyAdapter_LookupCaveatDefByName_OtherError(t *testing.T) { + t.Parallel() + reader := &fakeLegacySchemaReader{ + readCavErr: errors.New("db error"), + } + adapter := SchemaReaderFromLegacy(reader) + + _, _, err := adapter.LookupCaveatDefByName(t.Context(), "anything") + require.ErrorContains(t, err, "db error") 
+} + +func TestLegacyAdapter_ListAllTypeDefinitions(t *testing.T) { + t.Parallel() + nsDefs := []datastore.RevisionedNamespace{ + {Definition: &core.NamespaceDefinition{Name: "user"}, LastWrittenRevision: testRevision}, + {Definition: &core.NamespaceDefinition{Name: "doc"}, LastWrittenRevision: testRevision}, + } + reader := &fakeLegacySchemaReader{namespaces: nsDefs} + adapter := SchemaReaderFromLegacy(reader) + + result, err := adapter.ListAllTypeDefinitions(t.Context()) + require.NoError(t, err) + require.Len(t, result, 2) +} + +func TestLegacyAdapter_ListAllTypeDefinitions_Error(t *testing.T) { + t.Parallel() + reader := &fakeLegacySchemaReader{namespacesErr: errors.New("fail")} + adapter := SchemaReaderFromLegacy(reader) + + _, err := adapter.ListAllTypeDefinitions(t.Context()) + require.ErrorContains(t, err, "fail") +} + +func TestLegacyAdapter_ListAllCaveatDefinitions(t *testing.T) { + t.Parallel() + cavDefs := []datastore.RevisionedCaveat{ + {Definition: &core.CaveatDefinition{Name: "cav1"}, LastWrittenRevision: testRevision}, + } + reader := &fakeLegacySchemaReader{caveats: cavDefs} + adapter := SchemaReaderFromLegacy(reader) + + result, err := adapter.ListAllCaveatDefinitions(t.Context()) + require.NoError(t, err) + require.Len(t, result, 1) + require.Equal(t, "cav1", result[0].Definition.Name) +} + +func TestLegacyAdapter_ListAllCaveatDefinitions_Error(t *testing.T) { + t.Parallel() + reader := &fakeLegacySchemaReader{caveatsErr: errors.New("fail")} + adapter := SchemaReaderFromLegacy(reader) + + _, err := adapter.ListAllCaveatDefinitions(t.Context()) + require.ErrorContains(t, err, "fail") +} + +func TestLegacyAdapter_ListAllSchemaDefinitions(t *testing.T) { + t.Parallel() + reader := &fakeLegacySchemaReader{ + namespaces: []datastore.RevisionedNamespace{ + {Definition: &core.NamespaceDefinition{Name: "user"}, LastWrittenRevision: testRevision}, + }, + caveats: []datastore.RevisionedCaveat{ + {Definition: &core.CaveatDefinition{Name: "cav1"}, 
LastWrittenRevision: testRevision}, + }, + } + adapter := SchemaReaderFromLegacy(reader) + + result, err := adapter.ListAllSchemaDefinitions(t.Context()) + require.NoError(t, err) + require.Len(t, result, 2) + require.Contains(t, result, "user") + require.Contains(t, result, "cav1") +} + +func TestLegacyAdapter_ListAllSchemaDefinitions_NSError(t *testing.T) { + t.Parallel() + reader := &fakeLegacySchemaReader{namespacesErr: errors.New("ns fail")} + adapter := SchemaReaderFromLegacy(reader) + + _, err := adapter.ListAllSchemaDefinitions(t.Context()) + require.ErrorContains(t, err, "ns fail") +} + +func TestLegacyAdapter_ListAllSchemaDefinitions_CaveatError(t *testing.T) { + t.Parallel() + reader := &fakeLegacySchemaReader{ + namespaces: []datastore.RevisionedNamespace{ + {Definition: &core.NamespaceDefinition{Name: "user"}, LastWrittenRevision: testRevision}, + }, + caveatsErr: errors.New("cav fail"), + } + adapter := SchemaReaderFromLegacy(reader) + + _, err := adapter.ListAllSchemaDefinitions(t.Context()) + require.ErrorContains(t, err, "cav fail") +} + +func TestLegacyAdapter_LookupSchemaDefinitionsByNames(t *testing.T) { + t.Parallel() + ns := &core.NamespaceDefinition{Name: "user"} + cav := &core.CaveatDefinition{Name: "mycav"} + reader := &fakeLegacySchemaReader{ + lookupNSResult: []datastore.RevisionedNamespace{ + {Definition: ns, LastWrittenRevision: testRevision}, + }, + lookupCavResult: []datastore.RevisionedCaveat{ + {Definition: cav, LastWrittenRevision: testRevision}, + }, + } + adapter := SchemaReaderFromLegacy(reader) + + result, err := adapter.LookupSchemaDefinitionsByNames(t.Context(), []string{"user", "mycav"}) + require.NoError(t, err) + require.Contains(t, result, "user") + require.Contains(t, result, "mycav") +} + +func TestLegacyAdapter_LookupSchemaDefinitionsByNames_NSError(t *testing.T) { + t.Parallel() + reader := &fakeLegacySchemaReader{lookupNSErr: errors.New("ns lookup fail")} + adapter := SchemaReaderFromLegacy(reader) + + _, err := 
adapter.LookupSchemaDefinitionsByNames(t.Context(), []string{"foo"}) + require.ErrorContains(t, err, "ns lookup fail") +} + +func TestLegacyAdapter_LookupSchemaDefinitionsByNames_CaveatError(t *testing.T) { + t.Parallel() + reader := &fakeLegacySchemaReader{ + lookupNSResult: []datastore.RevisionedNamespace{}, + lookupCavErr: errors.New("cav lookup fail"), + } + adapter := SchemaReaderFromLegacy(reader) + + _, err := adapter.LookupSchemaDefinitionsByNames(t.Context(), []string{"foo"}) + require.ErrorContains(t, err, "cav lookup fail") +} + +func TestLegacyAdapter_LookupTypeDefinitionsByNames(t *testing.T) { + t.Parallel() + reader := &fakeLegacySchemaReader{ + lookupNSResult: []datastore.RevisionedNamespace{ + {Definition: &core.NamespaceDefinition{Name: "user"}, LastWrittenRevision: testRevision}, + {Definition: &core.NamespaceDefinition{Name: "doc"}, LastWrittenRevision: testRevision}, + }, + } + adapter := SchemaReaderFromLegacy(reader) + + result, err := adapter.LookupTypeDefinitionsByNames(t.Context(), []string{"user", "doc"}) + require.NoError(t, err) + require.Len(t, result, 2) + require.Contains(t, result, "user") + require.Contains(t, result, "doc") +} + +func TestLegacyAdapter_LookupTypeDefinitionsByNames_Error(t *testing.T) { + t.Parallel() + reader := &fakeLegacySchemaReader{lookupNSErr: errors.New("err")} + adapter := SchemaReaderFromLegacy(reader) + + _, err := adapter.LookupTypeDefinitionsByNames(t.Context(), []string{"user"}) + require.ErrorContains(t, err, "err") +} + +func TestLegacyAdapter_LookupCaveatDefinitionsByNames(t *testing.T) { + t.Parallel() + reader := &fakeLegacySchemaReader{ + lookupCavResult: []datastore.RevisionedCaveat{ + {Definition: &core.CaveatDefinition{Name: "c1"}, LastWrittenRevision: testRevision}, + }, + } + adapter := SchemaReaderFromLegacy(reader) + + result, err := adapter.LookupCaveatDefinitionsByNames(t.Context(), []string{"c1", "missing"}) + require.NoError(t, err) + require.Len(t, result, 1) + require.Contains(t, 
result, "c1") +} + +func TestLegacyAdapter_LookupCaveatDefinitionsByNames_Error(t *testing.T) { + t.Parallel() + reader := &fakeLegacySchemaReader{lookupCavErr: errors.New("err")} + adapter := SchemaReaderFromLegacy(reader) + + _, err := adapter.LookupCaveatDefinitionsByNames(t.Context(), []string{"c1"}) + require.ErrorContains(t, err, "err") +} + +// --- storedSchemaReaderAdapter tests --- + +func newStoredAdapter(nsDefs map[string]*core.NamespaceDefinition, cavDefs map[string]*core.CaveatDefinition, schemaText string) *storedSchemaReaderAdapter { + return &storedSchemaReaderAdapter{ + storedSchema: datastore.NewReadOnlyStoredSchema(&core.StoredSchema{ + Version: 1, + VersionOneof: &core.StoredSchema_V1{ + V1: &core.StoredSchema_V1StoredSchema{ + SchemaText: schemaText, + NamespaceDefinitions: nsDefs, + CaveatDefinitions: cavDefs, + }, + }, + }), + lastWrittenRevision: testRevision, + } +} + +func TestStoredAdapter_SchemaText(t *testing.T) { + t.Parallel() + adapter := newStoredAdapter( + map[string]*core.NamespaceDefinition{"user": {Name: "user"}}, + nil, + "definition user {}", + ) + + text, err := adapter.SchemaText(t.Context()) + require.NoError(t, err) + require.Equal(t, "definition user {}", text) +} + +func TestStoredAdapter_SchemaText_EmptyWithDefinitions(t *testing.T) { + t.Parallel() + // SchemaText is empty string but definitions exist -> return empty string, no error. 
+ adapter := newStoredAdapter( + map[string]*core.NamespaceDefinition{"user": {Name: "user"}}, + nil, + "", + ) + + text, err := adapter.SchemaText(t.Context()) + require.NoError(t, err) + require.Empty(t, text) +} + +func TestStoredAdapter_SchemaText_EmptyNoDefinitions(t *testing.T) { + t.Parallel() + adapter := newStoredAdapter(nil, nil, "") + + _, err := adapter.SchemaText(t.Context()) + require.Error(t, err) +} + +func TestStoredAdapter_LookupTypeDefByName_Found(t *testing.T) { + t.Parallel() + ns := &core.NamespaceDefinition{Name: "doc"} + adapter := newStoredAdapter(map[string]*core.NamespaceDefinition{"doc": ns}, nil, "") + + result, found, err := adapter.LookupTypeDefByName(t.Context(), "doc") + require.NoError(t, err) + require.True(t, found) + require.Equal(t, "doc", result.Definition.Name) + require.Equal(t, testRevision, result.LastWrittenRevision) +} + +func TestStoredAdapter_LookupTypeDefByName_NotFound(t *testing.T) { + t.Parallel() + adapter := newStoredAdapter(nil, nil, "") + + _, found, err := adapter.LookupTypeDefByName(t.Context(), "missing") + require.NoError(t, err) + require.False(t, found) +} + +func TestStoredAdapter_LookupCaveatDefByName_Found(t *testing.T) { + t.Parallel() + cav := &core.CaveatDefinition{Name: "c1"} + adapter := newStoredAdapter(nil, map[string]*core.CaveatDefinition{"c1": cav}, "") + + result, found, err := adapter.LookupCaveatDefByName(t.Context(), "c1") + require.NoError(t, err) + require.True(t, found) + require.Equal(t, "c1", result.Definition.Name) + require.Equal(t, testRevision, result.LastWrittenRevision) +} + +func TestStoredAdapter_LookupCaveatDefByName_NotFound(t *testing.T) { + t.Parallel() + adapter := newStoredAdapter(nil, nil, "") + + _, found, err := adapter.LookupCaveatDefByName(t.Context(), "missing") + require.NoError(t, err) + require.False(t, found) +} + +func TestStoredAdapter_ListAllTypeDefinitions(t *testing.T) { + t.Parallel() + adapter := newStoredAdapter( + map[string]*core.NamespaceDefinition{ 
+ "user": {Name: "user"}, + "doc": {Name: "doc"}, + }, nil, "") + + result, err := adapter.ListAllTypeDefinitions(t.Context()) + require.NoError(t, err) + require.Len(t, result, 2) + for _, td := range result { + require.Equal(t, testRevision, td.LastWrittenRevision) + } +} + +func TestStoredAdapter_ListAllCaveatDefinitions(t *testing.T) { + t.Parallel() + adapter := newStoredAdapter(nil, + map[string]*core.CaveatDefinition{ + "c1": {Name: "c1"}, + }, "") + + result, err := adapter.ListAllCaveatDefinitions(t.Context()) + require.NoError(t, err) + require.Len(t, result, 1) + require.Equal(t, "c1", result[0].Definition.Name) + require.Equal(t, testRevision, result[0].LastWrittenRevision) +} + +func TestStoredAdapter_ListAllSchemaDefinitions(t *testing.T) { + t.Parallel() + adapter := newStoredAdapter( + map[string]*core.NamespaceDefinition{"user": {Name: "user"}}, + map[string]*core.CaveatDefinition{"c1": {Name: "c1"}}, + "", + ) + + result, err := adapter.ListAllSchemaDefinitions(t.Context()) + require.NoError(t, err) + require.Len(t, result, 2) + require.Contains(t, result, "user") + require.Contains(t, result, "c1") +} + +func TestStoredAdapter_LookupSchemaDefinitionsByNames(t *testing.T) { + t.Parallel() + adapter := newStoredAdapter( + map[string]*core.NamespaceDefinition{"user": {Name: "user"}}, + map[string]*core.CaveatDefinition{"c1": {Name: "c1"}}, + "", + ) + + result, err := adapter.LookupSchemaDefinitionsByNames(t.Context(), []string{"user", "c1", "missing"}) + require.NoError(t, err) + require.Len(t, result, 2) + require.Contains(t, result, "user") + require.Contains(t, result, "c1") +} + +func TestStoredAdapter_LookupSchemaDefinitionsByNames_NamespacePreferredOverCaveat(t *testing.T) { + t.Parallel() + // When both a namespace and caveat have the same name, the namespace should be found first. 
+ adapter := newStoredAdapter( + map[string]*core.NamespaceDefinition{"shared": {Name: "shared"}}, + map[string]*core.CaveatDefinition{"shared": {Name: "shared"}}, + "", + ) + + result, err := adapter.LookupSchemaDefinitionsByNames(t.Context(), []string{"shared"}) + require.NoError(t, err) + require.Len(t, result, 1) + // The namespace def should be returned since the code checks namespaces first. + _, isNS := result["shared"].(*core.NamespaceDefinition) + require.True(t, isNS) +} + +func TestStoredAdapter_LookupTypeDefinitionsByNames(t *testing.T) { + t.Parallel() + adapter := newStoredAdapter( + map[string]*core.NamespaceDefinition{ + "user": {Name: "user"}, + "doc": {Name: "doc"}, + }, + map[string]*core.CaveatDefinition{"c1": {Name: "c1"}}, // should be ignored + "", + ) + + result, err := adapter.LookupTypeDefinitionsByNames(t.Context(), []string{"user", "c1", "missing"}) + require.NoError(t, err) + require.Len(t, result, 1) + require.Contains(t, result, "user") +} + +func TestStoredAdapter_LookupCaveatDefinitionsByNames(t *testing.T) { + t.Parallel() + adapter := newStoredAdapter( + map[string]*core.NamespaceDefinition{"user": {Name: "user"}}, // should be ignored + map[string]*core.CaveatDefinition{"c1": {Name: "c1"}, "c2": {Name: "c2"}}, + "", + ) + + result, err := adapter.LookupCaveatDefinitionsByNames(t.Context(), []string{"c1", "user", "missing"}) + require.NoError(t, err) + require.Len(t, result, 1) + require.Contains(t, result, "c1") +} + +func TestStoredAdapter_V1NilFallback(t *testing.T) { + t.Parallel() + // When VersionOneof is nil, v1() should return an empty struct without panicking. 
+ adapter := &storedSchemaReaderAdapter{ + storedSchema: datastore.NewReadOnlyStoredSchema(&core.StoredSchema{Version: 1}), + lastWrittenRevision: testRevision, + } + + _, err := adapter.SchemaText(t.Context()) + require.Error(t, err) + + defs, err := adapter.ListAllTypeDefinitions(t.Context()) + require.NoError(t, err) + require.Empty(t, defs) +} + +// --- newStoredSchemaReaderAdapter tests --- + +type fakeStoredSchemaReader struct { + schema *datastore.ReadOnlyStoredSchema + err error +} + +func (m *fakeStoredSchemaReader) ReadStoredSchema(_ context.Context) (*datastore.ReadOnlyStoredSchema, error) { + return m.schema, m.err +} + +func TestNewStoredSchemaReaderAdapter_Success(t *testing.T) { + t.Parallel() + schema := datastore.NewReadOnlyStoredSchema(&core.StoredSchema{ + Version: 1, + VersionOneof: &core.StoredSchema_V1{ + V1: &core.StoredSchema_V1StoredSchema{ + SchemaText: "definition user {}", + NamespaceDefinitions: map[string]*core.NamespaceDefinition{ + "user": {Name: "user"}, + }, + }, + }, + }) + + reader := &fakeStoredSchemaReader{schema: schema} + adapter, err := newStoredSchemaReaderAdapter(reader, "somehash", testRevision, noopSchemaCache{}) + require.NoError(t, err) + + text, err := adapter.SchemaText(t.Context()) + require.NoError(t, err) + require.Equal(t, "definition user {}", text) +} + +func TestNewStoredSchemaReaderAdapter_SchemaNotFound(t *testing.T) { + t.Parallel() + reader := &fakeStoredSchemaReader{err: datastore.ErrSchemaNotFound} + adapter, err := newStoredSchemaReaderAdapter(reader, "somehash", testRevision, noopSchemaCache{}) + require.NoError(t, err) + + // Should return empty adapter + defs, err := adapter.ListAllTypeDefinitions(t.Context()) + require.NoError(t, err) + require.Empty(t, defs) + + _, err = adapter.SchemaText(t.Context()) + require.Error(t, err) // SchemaNotDefinedErr +} + +func TestNewStoredSchemaReaderAdapter_OtherError(t *testing.T) { + t.Parallel() + reader := &fakeStoredSchemaReader{err: errors.New("connection 
refused")} + _, err := newStoredSchemaReaderAdapter(reader, "somehash", testRevision, noopSchemaCache{}) + require.ErrorContains(t, err, "connection refused") +} + +// --- writeSchemaViaLegacy tests --- + +func TestWriteSchemaViaLegacy_NewDefinitions(t *testing.T) { + t.Parallel() + writer := &fakeLegacySchemaWriter{} + + ns := &core.NamespaceDefinition{Name: "user"} + cav := &core.CaveatDefinition{Name: "mycav"} + defs := []datastore.SchemaDefinition{ns, cav} + + err := writeSchemaViaLegacy(t.Context(), writer, &writer.fakeLegacySchemaReader, defs) + require.NoError(t, err) + + require.Len(t, writer.writtenNS, 1) + require.Equal(t, "user", writer.writtenNS[0].Name) + require.Len(t, writer.writtenCaveats, 1) + require.Equal(t, "mycav", writer.writtenCaveats[0].Name) +} + +func TestWriteSchemaViaLegacy_SkipsUnchanged(t *testing.T) { + t.Parallel() + ns := &core.NamespaceDefinition{Name: "user"} + writer := &fakeLegacySchemaWriter{ + fakeLegacySchemaReader: fakeLegacySchemaReader{ + namespaces: []datastore.RevisionedNamespace{ + {Definition: ns, LastWrittenRevision: testRevision}, + }, + }, + } + + // Write the same namespace — should skip it since it's unchanged. 
+ err := writeSchemaViaLegacy(t.Context(), writer, &writer.fakeLegacySchemaReader, []datastore.SchemaDefinition{ns}) + require.NoError(t, err) + require.Empty(t, writer.writtenNS) +} + +func TestWriteSchemaViaLegacy_WritesChanged(t *testing.T) { + t.Parallel() + existingNS := &core.NamespaceDefinition{Name: "user"} + updatedNS := &core.NamespaceDefinition{Name: "user", Metadata: &core.Metadata{}} + writer := &fakeLegacySchemaWriter{ + fakeLegacySchemaReader: fakeLegacySchemaReader{ + namespaces: []datastore.RevisionedNamespace{ + {Definition: existingNS, LastWrittenRevision: testRevision}, + }, + }, + } + + err := writeSchemaViaLegacy(t.Context(), writer, &writer.fakeLegacySchemaReader, []datastore.SchemaDefinition{updatedNS}) + require.NoError(t, err) + require.Len(t, writer.writtenNS, 1) + require.Equal(t, "user", writer.writtenNS[0].Name) +} + +func TestWriteSchemaViaLegacy_DeletesRemoved(t *testing.T) { + t.Parallel() + existingNS := &core.NamespaceDefinition{Name: "oldns"} + existingCav := &core.CaveatDefinition{Name: "oldcav"} + writer := &fakeLegacySchemaWriter{ + fakeLegacySchemaReader: fakeLegacySchemaReader{ + namespaces: []datastore.RevisionedNamespace{ + {Definition: existingNS, LastWrittenRevision: testRevision}, + }, + caveats: []datastore.RevisionedCaveat{ + {Definition: existingCav, LastWrittenRevision: testRevision}, + }, + }, + } + + // Write with no definitions -> should delete both. 
+ err := writeSchemaViaLegacy(t.Context(), writer, &writer.fakeLegacySchemaReader, nil) + require.NoError(t, err) + require.Contains(t, writer.deletedNS, "oldns") + require.Contains(t, writer.deletedCaveats, "oldcav") +} + +func TestWriteSchemaViaLegacy_DuplicateNameError(t *testing.T) { + t.Parallel() + writer := &fakeLegacySchemaWriter{} + + ns1 := &core.NamespaceDefinition{Name: "dup"} + ns2 := &core.NamespaceDefinition{Name: "dup"} + err := writeSchemaViaLegacy(t.Context(), writer, &writer.fakeLegacySchemaReader, []datastore.SchemaDefinition{ns1, ns2}) + require.ErrorContains(t, err, "duplicate definition name: dup") +} + +func TestWriteSchemaViaLegacy_ListNamespacesError(t *testing.T) { + t.Parallel() + writer := &fakeLegacySchemaWriter{ + fakeLegacySchemaReader: fakeLegacySchemaReader{ + namespacesErr: errors.New("list ns fail"), + }, + } + + ns := &core.NamespaceDefinition{Name: "user"} + err := writeSchemaViaLegacy(t.Context(), writer, &writer.fakeLegacySchemaReader, []datastore.SchemaDefinition{ns}) + require.ErrorContains(t, err, "list ns fail") +} + +func TestWriteSchemaViaLegacy_ListCaveatsError(t *testing.T) { + t.Parallel() + writer := &fakeLegacySchemaWriter{ + fakeLegacySchemaReader: fakeLegacySchemaReader{ + caveatsErr: errors.New("list cav fail"), + }, + } + + ns := &core.NamespaceDefinition{Name: "user"} + err := writeSchemaViaLegacy(t.Context(), writer, &writer.fakeLegacySchemaReader, []datastore.SchemaDefinition{ns}) + require.ErrorContains(t, err, "list cav fail") +} + +func TestWriteSchemaViaLegacy_WriteNamespacesError(t *testing.T) { + t.Parallel() + writer := &fakeLegacySchemaWriter{writeNSErr: errors.New("write ns fail")} + + ns := &core.NamespaceDefinition{Name: "user"} + err := writeSchemaViaLegacy(t.Context(), writer, &writer.fakeLegacySchemaReader, []datastore.SchemaDefinition{ns}) + require.ErrorContains(t, err, "write ns fail") +} + +func TestWriteSchemaViaLegacy_WriteCaveatsError(t *testing.T) { + t.Parallel() + writer := 
&fakeLegacySchemaWriter{writeCavErr: errors.New("write cav fail")} + + cav := &core.CaveatDefinition{Name: "mycav"} + err := writeSchemaViaLegacy(t.Context(), writer, &writer.fakeLegacySchemaReader, []datastore.SchemaDefinition{cav}) + require.ErrorContains(t, err, "write cav fail") +} + +func TestWriteSchemaViaLegacy_DeleteNamespacesError(t *testing.T) { + t.Parallel() + writer := &fakeLegacySchemaWriter{ + fakeLegacySchemaReader: fakeLegacySchemaReader{ + namespaces: []datastore.RevisionedNamespace{ + {Definition: &core.NamespaceDefinition{Name: "old"}, LastWrittenRevision: testRevision}, + }, + }, + deleteNSErr: errors.New("delete ns fail"), + } + + err := writeSchemaViaLegacy(t.Context(), writer, &writer.fakeLegacySchemaReader, nil) + require.ErrorContains(t, err, "delete ns fail") +} + +func TestWriteSchemaViaLegacy_DeleteCaveatsError(t *testing.T) { + t.Parallel() + writer := &fakeLegacySchemaWriter{ + fakeLegacySchemaReader: fakeLegacySchemaReader{ + caveats: []datastore.RevisionedCaveat{ + {Definition: &core.CaveatDefinition{Name: "old"}, LastWrittenRevision: testRevision}, + }, + }, + deleteCavErr: errors.New("delete cav fail"), + } + + err := writeSchemaViaLegacy(t.Context(), writer, &writer.fakeLegacySchemaReader, nil) + require.ErrorContains(t, err, "delete cav fail") +} + +// --- SchemaReaderFromLegacy interface compliance --- + +func TestSchemaReaderFromLegacy_ReturnsSchemaReader(t *testing.T) { + t.Parallel() + reader := &fakeLegacySchemaReader{} + adapter := SchemaReaderFromLegacy(reader) + require.NotNil(t, adapter) + + // Verify it satisfies the interface. 
+ _ = adapter +} + +// --- newStoredSchemaReaderAdapter cache tests --- + +type fakeSchemaCache struct { + loaded int + setCount int + cachedValue *datastore.ReadOnlyStoredSchema +} + +func (m *fakeSchemaCache) GetOrLoad(ctx context.Context, _ datastore.Revision, _ SchemaHash, + loader func(ctx context.Context) (*datastore.ReadOnlyStoredSchema, error), +) (*datastore.ReadOnlyStoredSchema, error) { + m.loaded++ + if m.cachedValue != nil { + return m.cachedValue, nil + } + return loader(ctx) +} + +func (m *fakeSchemaCache) Set(_ SchemaHash, schema *datastore.ReadOnlyStoredSchema) error { + m.setCount++ + m.cachedValue = schema + return nil +} + +func TestNewStoredSchemaReaderAdapter_UsesCache(t *testing.T) { + t.Parallel() + schema := datastore.NewReadOnlyStoredSchema(&core.StoredSchema{ + Version: 1, + VersionOneof: &core.StoredSchema_V1{ + V1: &core.StoredSchema_V1StoredSchema{ + SchemaText: "definition user {}", + NamespaceDefinitions: map[string]*core.NamespaceDefinition{ + "user": {Name: "user"}, + }, + }, + }, + }) + + cache := &fakeSchemaCache{cachedValue: schema} + reader := &fakeStoredSchemaReader{err: errors.New("should not be called")} + + adapter, err := newStoredSchemaReaderAdapter(reader, "hash1", testRevision, cache) + require.NoError(t, err) + require.Equal(t, 1, cache.loaded) + + text, err := adapter.SchemaText(t.Context()) + require.NoError(t, err) + require.Equal(t, "definition user {}", text) +} diff --git a/pkg/datalayer/schemahash.go b/pkg/datalayer/schemahash.go new file mode 100644 index 000000000..284c1743e --- /dev/null +++ b/pkg/datalayer/schemahash.go @@ -0,0 +1,48 @@ +package datalayer + +// SchemaHash is a string that uniquely identifies a specific version of a schema. +// It is used by caching layers to avoid redundant reads of the schema from the +// underlying datastore when the schema hasn't changed. 
+type SchemaHash string + +const ( + // NoSchemaHashInTransaction is a sentinel value used when reading within a + // read-write transaction where the schema revision is not yet stable. + NoSchemaHashInTransaction SchemaHash = "no-schema-hash-in-transaction" + + // NoSchemaHashInDevelopment is a sentinel value used when operating in + // development mode, where caching of schema is not desired. + NoSchemaHashInDevelopment SchemaHash = "no-schema-hash-in-development" + + // NoSchemaHashForTesting is a sentinel value used in tests, where schema + // caching is not needed. + NoSchemaHashForTesting SchemaHash = "no-schema-hash-for-testing" + + // NoSchemaHashForWatch is a sentinel value used when reading schema for + // watch operations, where the hash is not yet available. + NoSchemaHashForWatch SchemaHash = "no-schema-hash-for-watch" + + // NoSchemaHashForLegacyCursor is a sentinel value for decoding legacy cursors + // that do not contain a schema hash field. + NoSchemaHashForLegacyCursor SchemaHash = "no-schema-hash-for-legacy-cursor" + + // NoSchemaHashInLegacyMode is a sentinel value used when the DataLayer is + // operating in legacy schema mode, where no unified schema exists. + NoSchemaHashInLegacyMode SchemaHash = "no-schema-hash-in-legacy-mode" +) + +// IsBypassSentinel returns true if this SchemaHash is a sentinel value that +// should bypass any caching. +func (sh SchemaHash) IsBypassSentinel() bool { + switch sh { + case NoSchemaHashInTransaction, + NoSchemaHashInDevelopment, + NoSchemaHashForTesting, + NoSchemaHashForWatch, + NoSchemaHashForLegacyCursor, + NoSchemaHashInLegacyMode: + return true + default: + return false + } +} diff --git a/pkg/datalayer/schemamode.go b/pkg/datalayer/schemamode.go new file mode 100644 index 000000000..453081eed --- /dev/null +++ b/pkg/datalayer/schemamode.go @@ -0,0 +1,74 @@ +package datalayer + +import "fmt" + +// SchemaMode represents the experimental schema mode for datastore operations. 
+// It controls how schema is read from and written to the datastore, allowing +// a gradual migration from legacy per-definition storage to unified schema storage. +type SchemaMode uint8 + +const ( + // SchemaModeReadLegacyWriteLegacy uses legacy schema reader and writer. + // This is the default and backward-compatible mode. + SchemaModeReadLegacyWriteLegacy SchemaMode = iota + + // SchemaModeReadLegacyWriteBoth uses legacy schema reader and writes to both + // legacy and unified schema storage. Use this as the first migration step. + SchemaModeReadLegacyWriteBoth + + // SchemaModeReadNewWriteBoth uses unified schema reader and writes to both + // legacy and unified schema storage. Use this as the second migration step. + SchemaModeReadNewWriteBoth + + // SchemaModeReadNewWriteNew uses unified schema reader and writer only. + // This is the final migration target. + SchemaModeReadNewWriteNew +) + +var schemaModeNames = map[string]SchemaMode{ + "read-legacy-write-legacy": SchemaModeReadLegacyWriteLegacy, + "read-legacy-write-both": SchemaModeReadLegacyWriteBoth, + "read-new-write-both": SchemaModeReadNewWriteBoth, + "read-new-write-new": SchemaModeReadNewWriteNew, +} + +var schemaModeStrings = map[SchemaMode]string{ + SchemaModeReadLegacyWriteLegacy: "read-legacy-write-legacy", + SchemaModeReadLegacyWriteBoth: "read-legacy-write-both", + SchemaModeReadNewWriteBoth: "read-new-write-both", + SchemaModeReadNewWriteNew: "read-new-write-new", +} + +// ParseSchemaMode converts a string to a SchemaMode. Returns an error if the string is invalid. +func ParseSchemaMode(s string) (SchemaMode, error) { + mode, ok := schemaModeNames[s] + if !ok { + return SchemaModeReadLegacyWriteLegacy, fmt.Errorf( + "invalid schema mode %q, must be one of: read-legacy-write-legacy, read-legacy-write-both, read-new-write-both, read-new-write-new", s) + } + return mode, nil +} + +// String returns the string representation of the SchemaMode. 
+func (s SchemaMode) String() string { + str, ok := schemaModeStrings[s] + if !ok { + return "unknown" + } + return str +} + +// ReadsFromNew returns true if the mode reads from the unified schema storage. +func (s SchemaMode) ReadsFromNew() bool { + return s == SchemaModeReadNewWriteBoth || s == SchemaModeReadNewWriteNew +} + +// WritesToLegacy returns true if the mode writes to legacy schema storage. +func (s SchemaMode) WritesToLegacy() bool { + return s == SchemaModeReadLegacyWriteLegacy || s == SchemaModeReadLegacyWriteBoth || s == SchemaModeReadNewWriteBoth +} + +// WritesToNew returns true if the mode writes to unified schema storage. +func (s SchemaMode) WritesToNew() bool { + return s == SchemaModeReadLegacyWriteBoth || s == SchemaModeReadNewWriteBoth || s == SchemaModeReadNewWriteNew +} diff --git a/pkg/datastore/context.go b/pkg/datastore/context.go index c2af9105c..1c8b5b7cb 100644 --- a/pkg/datastore/context.go +++ b/pkg/datastore/context.go @@ -137,6 +137,10 @@ func (r *ctxReader) ReverseQueryRelationships(ctx context.Context, subjectsFilte return r.delegate.ReverseQueryRelationships(context.WithoutCancel(ctx), subjectsFilter, options...) } +func (r *ctxReader) ReadStoredSchema(ctx context.Context) (*ReadOnlyStoredSchema, error) { + return r.delegate.ReadStoredSchema(context.WithoutCancel(ctx)) +} + var ( _ Datastore = (*ctxProxy)(nil) _ Reader = (*ctxReader)(nil) diff --git a/pkg/datastore/datastore.go b/pkg/datastore/datastore.go index 05888c7bc..5c358d3ed 100644 --- a/pkg/datastore/datastore.go +++ b/pkg/datastore/datastore.go @@ -515,11 +515,36 @@ func (rd RevisionedDefinition[T]) GetLastWrittenRevision() Revision { // RevisionedNamespace is a revisioned version of a namespace definition. type RevisionedNamespace = RevisionedDefinition[*core.NamespaceDefinition] +// ReadOnlyStoredSchema wraps a *core.StoredSchema to indicate it is read-only +// and must not be modified, as it may be shared across multiple callers via caching. 
+type ReadOnlyStoredSchema struct { + schema *core.StoredSchema +} + +// NewReadOnlyStoredSchema wraps a StoredSchema as read-only. +// Returns nil if the provided schema is nil. +func NewReadOnlyStoredSchema(schema *core.StoredSchema) *ReadOnlyStoredSchema { + if schema == nil { + return nil + } + return &ReadOnlyStoredSchema{schema: schema} +} + +// Get returns the underlying StoredSchema. Callers must not modify the returned value. +func (r *ReadOnlyStoredSchema) Get() *core.StoredSchema { + return r.schema +} + // Reader is an interface for reading relationships from the datastore. type Reader interface { LegacySchemaReader CounterReader + // ReadStoredSchema reads the unified stored schema from the datastore. + // The returned ReadOnlyStoredSchema must not be modified, as it may be shared + // across callers via caching. + ReadStoredSchema(ctx context.Context) (*ReadOnlyStoredSchema, error) + // QueryRelationships reads relationships, starting from the resource side. QueryRelationships( ctx context.Context, @@ -552,6 +577,9 @@ type ReadWriteTransaction interface { options ...options.DeleteOptionsOption, ) (uint64, bool, error) + // WriteStoredSchema writes the unified stored schema to the datastore. + WriteStoredSchema(ctx context.Context, schema *core.StoredSchema) error + // BulkLoad takes a relationship source iterator, and writes all of the // relationships to the backing datastore in an optimized fashion. 
This // method can and will omit checks and otherwise cut corners in the diff --git a/pkg/datastore/errors.go b/pkg/datastore/errors.go index 985a7e3a3..9dc8fc42c 100644 --- a/pkg/datastore/errors.go +++ b/pkg/datastore/errors.go @@ -312,6 +312,7 @@ var ( ErrClosedIterator = errors.New("unable to iterate: iterator closed") ErrCursorsWithoutSorting = errors.New("cursors are disabled on unsorted results") ErrCursorEmpty = errors.New("cursors are only available after the first result") + ErrSchemaNotFound = errors.New("schema not found") ) // CreateRelationshipExistsError is returned when attempting to CREATE an already-existing relationship. diff --git a/pkg/datastore/mocks/mock_datastore.go b/pkg/datastore/mocks/mock_datastore.go index 3db748db9..cffbc4e3a 100644 --- a/pkg/datastore/mocks/mock_datastore.go +++ b/pkg/datastore/mocks/mock_datastore.go @@ -239,6 +239,21 @@ func (mr *MockReaderMockRecorder) QueryRelationships(ctx, filter any, arg2 ...an return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryRelationships", reflect.TypeOf((*MockReader)(nil).QueryRelationships), varargs...) } +// ReadStoredSchema mocks base method. +func (m *MockReader) ReadStoredSchema(ctx context.Context) (*datastore.ReadOnlyStoredSchema, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadStoredSchema", ctx) + ret0, _ := ret[0].(*datastore.ReadOnlyStoredSchema) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadStoredSchema indicates an expected call of ReadStoredSchema. +func (mr *MockReaderMockRecorder) ReadStoredSchema(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadStoredSchema", reflect.TypeOf((*MockReader)(nil).ReadStoredSchema), ctx) +} + // ReverseQueryRelationships mocks base method. 
func (m *MockReader) ReverseQueryRelationships(ctx context.Context, subjectsFilter datastore.SubjectsFilter, arg2 ...options.ReverseQueryOptionsOption) (datastore.RelationshipIterator, error) { m.ctrl.T.Helper() @@ -522,6 +537,21 @@ func (mr *MockReadWriteTransactionMockRecorder) QueryRelationships(ctx, filter a return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryRelationships", reflect.TypeOf((*MockReadWriteTransaction)(nil).QueryRelationships), varargs...) } +// ReadStoredSchema mocks base method. +func (m *MockReadWriteTransaction) ReadStoredSchema(ctx context.Context) (*datastore.ReadOnlyStoredSchema, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadStoredSchema", ctx) + ret0, _ := ret[0].(*datastore.ReadOnlyStoredSchema) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadStoredSchema indicates an expected call of ReadStoredSchema. +func (mr *MockReadWriteTransactionMockRecorder) ReadStoredSchema(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadStoredSchema", reflect.TypeOf((*MockReadWriteTransaction)(nil).ReadStoredSchema), ctx) +} + // RegisterCounter mocks base method. func (m *MockReadWriteTransaction) RegisterCounter(ctx context.Context, name string, filter *corev1.RelationshipFilter) error { m.ctrl.T.Helper() @@ -598,6 +628,20 @@ func (mr *MockReadWriteTransactionMockRecorder) WriteRelationships(ctx, mutation return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WriteRelationships", reflect.TypeOf((*MockReadWriteTransaction)(nil).WriteRelationships), ctx, mutations) } +// WriteStoredSchema mocks base method. +func (m *MockReadWriteTransaction) WriteStoredSchema(ctx context.Context, schema *corev1.StoredSchema) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WriteStoredSchema", ctx, schema) + ret0, _ := ret[0].(error) + return ret0 +} + +// WriteStoredSchema indicates an expected call of WriteStoredSchema. 
+func (mr *MockReadWriteTransactionMockRecorder) WriteStoredSchema(ctx, schema any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WriteStoredSchema", reflect.TypeOf((*MockReadWriteTransaction)(nil).WriteStoredSchema), ctx, schema) +} + // MockBulkWriteRelationshipSource is a mock of BulkWriteRelationshipSource interface. type MockBulkWriteRelationshipSource struct { ctrl *gomock.Controller diff --git a/pkg/datastore/test/datastore.go b/pkg/datastore/test/datastore.go index 105a2f475..b23faf29c 100644 --- a/pkg/datastore/test/datastore.go +++ b/pkg/datastore/test/datastore.go @@ -234,6 +234,18 @@ func AllWithExceptions(t *testing.T, tester DatastoreTester, except Categories, t.Run("TestRelationshipCounterOverExpired", runner(tester, RelationshipCounterOverExpiredTest)) t.Run("TestRegisterRelationshipCountersInParallel", runner(tester, RegisterRelationshipCountersInParallelTest)) t.Run("TestRelationshipCountersWithOddFilter", runner(tester, RelationshipCountersWithOddFilterTest)) + + t.Run("TestStoredSchemaNotFound", runner(tester, StoredSchemaNotFoundTest)) + t.Run("TestStoredSchemaWriteRead", runner(tester, StoredSchemaWriteReadTest)) + t.Run("TestStoredSchemaRevision", runner(tester, StoredSchemaRevisionTest)) + t.Run("TestStoredSchemaUpdate", runner(tester, StoredSchemaUpdateTest)) + t.Run("TestStoredSchemaMultipleRevisions", runner(tester, StoredSchemaMultipleRevisionsTest)) + if !except.Transaction() { + t.Run("TestStoredSchemaReadWithinTransaction", runner(tester, StoredSchemaReadWithinTransactionTest)) + } + t.Run("TestStoredSchemaStableText", runner(tester, StoredSchemaStableTextTest)) + t.Run("TestStoredSchemaLarge", runner(tester, StoredSchemaLargeTest)) + t.Run("TestStoredSchemaPhaseMigration", runner(tester, StoredSchemaPhaseMigrationTest)) } func OnlyGCTests(t *testing.T, tester DatastoreTester, concurrent bool) { diff --git a/pkg/datastore/test/revisions.go b/pkg/datastore/test/revisions.go index 
0305f7ebe..8f09415fb 100644 --- a/pkg/datastore/test/revisions.go +++ b/pkg/datastore/test/revisions.go @@ -12,6 +12,7 @@ import ( "github.com/stretchr/testify/require" "github.com/authzed/spicedb/internal/datastore/common" + "github.com/authzed/spicedb/pkg/datalayer" "github.com/authzed/spicedb/pkg/datastore" ns "github.com/authzed/spicedb/pkg/namespace" core "github.com/authzed/spicedb/pkg/proto/core/v1" @@ -89,6 +90,7 @@ func RevisionSerializationTest(t *testing.T, tester DatastoreTester) { AtRevision: revToTest.String(), DepthRemaining: 50, TraversalBloom: dispatch.MustNewTraversalBloomFilter(50), + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), } require.NoError(protovalidate.Validate(meta)) } diff --git a/pkg/datastore/test/storedschema.go b/pkg/datastore/test/storedschema.go new file mode 100644 index 000000000..b9890823a --- /dev/null +++ b/pkg/datastore/test/storedschema.go @@ -0,0 +1,698 @@ +package test + +import ( + "context" + "fmt" + "sort" + "testing" + + "github.com/stretchr/testify/require" + + caveattypes "github.com/authzed/spicedb/pkg/caveats/types" + "github.com/authzed/spicedb/pkg/datalayer" + "github.com/authzed/spicedb/pkg/datastore" + ns "github.com/authzed/spicedb/pkg/namespace" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/schemadsl/compiler" + "github.com/authzed/spicedb/pkg/schemadsl/generator" + "github.com/authzed/spicedb/pkg/schemadsl/input" + "github.com/authzed/spicedb/pkg/testutil" +) + +// toSchemaDefinitions converts compiler.SchemaDefinition slice to datastore.SchemaDefinition slice. +func toSchemaDefinitions(defs []compiler.SchemaDefinition) []datastore.SchemaDefinition { + result := make([]datastore.SchemaDefinition, len(defs)) + for i, def := range defs { + result[i] = def.(datastore.SchemaDefinition) + } + return result +} + +// writeSchema is a test helper that writes a schema via WriteSchemaViaStoredSchema. 
+func writeSchema(ctx context.Context, t *testing.T, ds datastore.Datastore, + definitions []compiler.SchemaDefinition, schemaText string, +) datastore.Revision { + t.Helper() + rev, err := ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error { + return datalayer.WriteSchemaViaStoredSchema(ctx, rwt, toSchemaDefinitions(definitions), schemaText, nil) + }) + require.NoError(t, err) + return rev +} + +// StoredSchemaNotFoundTest tests that reading a stored schema when none exists returns ErrSchemaNotFound. +func StoredSchemaNotFoundTest(t *testing.T, tester DatastoreTester) { + require := require.New(t) + + ds, err := tester.New(t, 0, veryLargeGCInterval, veryLargeGCWindow, 1) + require.NoError(err) + + ctx := t.Context() + + startRevision, err := ds.HeadRevision(ctx) + require.NoError(err) + + _, err = ds.SnapshotReader(startRevision).ReadStoredSchema(ctx) + require.ErrorIs(err, datastore.ErrSchemaNotFound) +} + +// StoredSchemaWriteReadTest tests basic write and read of a stored schema, +// including both namespace and caveat definitions. 
+func StoredSchemaWriteReadTest(t *testing.T, tester DatastoreTester) { + require := require.New(t) + + ds, err := tester.New(t, 0, veryLargeGCInterval, veryLargeGCWindow, 1) + require.NoError(err) + + ctx := t.Context() + + schemaString := `caveat is_allowed(allowed bool) { + allowed +} + +definition user {} + +definition document { + relation viewer: user with is_allowed +}` + + compiled, err := compiler.Compile(compiler.InputSchema{ + Source: input.Source("schema"), + SchemaString: schemaString, + }, compiler.AllowUnprefixedObjectType()) + require.NoError(err) + + allDefs := make([]compiler.SchemaDefinition, 0, len(compiled.ObjectDefinitions)+len(compiled.CaveatDefinitions)) + for _, caveatDef := range compiled.CaveatDefinitions { + allDefs = append(allDefs, caveatDef) + } + for _, objDef := range compiled.ObjectDefinitions { + allDefs = append(allDefs, objDef) + } + + writtenRev := writeSchema(ctx, t, ds, allDefs, schemaString) + + // Read it back at the written revision. + readSchema, err := ds.SnapshotReader(writtenRev).ReadStoredSchema(ctx) + require.NoError(err) + + // Verify the stored schema contents. + require.EqualValues(1, readSchema.Get().Version) + v1 := readSchema.Get().GetV1() + require.NotNil(v1) + require.Equal(schemaString, v1.SchemaText) + require.NotEmpty(v1.SchemaHash) + require.Len(v1.NamespaceDefinitions, 2) + require.Contains(v1.NamespaceDefinitions, "user") + require.Contains(v1.NamespaceDefinitions, "document") + require.Len(v1.CaveatDefinitions, 1) + require.Contains(v1.CaveatDefinitions, "is_allowed") + + // Verify the namespace definitions are faithfully round-tripped. + for _, objDef := range compiled.ObjectDefinitions { + readDef, ok := v1.NamespaceDefinitions[objDef.Name] + require.True(ok) + testutil.RequireProtoEqual(t, objDef, readDef, "namespace %s should round-trip", objDef.Name) + } + + // Verify the caveat definition round-trips. 
+ testutil.RequireProtoEqual(t, compiled.CaveatDefinitions[0], v1.CaveatDefinitions["is_allowed"], + "caveat definition should round-trip") +} + +// StoredSchemaRevisionTest writes three schema versions, then reads all of them +// back (including older revisions after newer ones have been written) to verify +// that each revision returns the correct content. +func StoredSchemaRevisionTest(t *testing.T, tester DatastoreTester) { + require := require.New(t) + + ds, err := tester.New(t, 0, veryLargeGCInterval, veryLargeGCWindow, 1) + require.NoError(err) + + ctx := t.Context() + + // Write first schema: just "user". + firstDefs := []compiler.SchemaDefinition{ + ns.Namespace("user"), + } + firstText, _, err := generator.GenerateSchema(firstDefs) + require.NoError(err) + firstRev := writeSchema(ctx, t, ds, firstDefs, firstText) + + // Write second schema: "user" + "document" with viewer. + secondDefs := []compiler.SchemaDefinition{ + ns.Namespace("user"), + ns.Namespace("document", + ns.MustRelation("viewer", nil, ns.AllowedRelation("user", "...")), + ), + } + secondText, _, err := generator.GenerateSchema(secondDefs) + require.NoError(err) + secondRev := writeSchema(ctx, t, ds, secondDefs, secondText) + require.True(secondRev.GreaterThan(firstRev)) + + // Write third schema: "user" + "document" with viewer+editor. + thirdDefs := []compiler.SchemaDefinition{ + ns.Namespace("user"), + ns.Namespace("document", + ns.MustRelation("viewer", nil, ns.AllowedRelation("user", "...")), + ns.MustRelation("editor", nil, ns.AllowedRelation("user", "...")), + ), + } + thirdText, _, err := generator.GenerateSchema(thirdDefs) + require.NoError(err) + thirdRev := writeSchema(ctx, t, ds, thirdDefs, thirdText) + require.True(thirdRev.GreaterThan(secondRev)) + + // Now read ALL revisions back, starting with the oldest, to verify that + // reading older revisions after newer writes still returns the old content. 
+ readFirst, err := ds.SnapshotReader(firstRev).ReadStoredSchema(ctx) + require.NoError(err) + require.NotEmpty(readFirst.Get().GetV1().SchemaHash) + require.Len(readFirst.Get().GetV1().NamespaceDefinitions, 1) + require.Contains(readFirst.Get().GetV1().NamespaceDefinitions, "user") + + readSecond, err := ds.SnapshotReader(secondRev).ReadStoredSchema(ctx) + require.NoError(err) + require.NotEmpty(readSecond.Get().GetV1().SchemaHash) + require.Len(readSecond.Get().GetV1().NamespaceDefinitions, 2) + require.Contains(readSecond.Get().GetV1().NamespaceDefinitions, "user") + require.Contains(readSecond.Get().GetV1().NamespaceDefinitions, "document") + require.Len(readSecond.Get().GetV1().NamespaceDefinitions["document"].Relation, 1, + "document should have only viewer at second revision") + + readThird, err := ds.SnapshotReader(thirdRev).ReadStoredSchema(ctx) + require.NoError(err) + require.NotEmpty(readThird.Get().GetV1().SchemaHash) + require.Len(readThird.Get().GetV1().NamespaceDefinitions, 2) + require.Len(readThird.Get().GetV1().NamespaceDefinitions["document"].Relation, 2, + "document should have viewer and editor at third revision") + + // Verify all three revisions have different hashes (different schemas). + require.NotEqual(readFirst.Get().GetV1().SchemaHash, readSecond.Get().GetV1().SchemaHash) + require.NotEqual(readSecond.Get().GetV1().SchemaHash, readThird.Get().GetV1().SchemaHash) +} + +// StoredSchemaUpdateTest tests that overwriting a stored schema replaces it completely. +func StoredSchemaUpdateTest(t *testing.T, tester DatastoreTester) { + require := require.New(t) + + ds, err := tester.New(t, 0, veryLargeGCInterval, veryLargeGCWindow, 1) + require.NoError(err) + + ctx := t.Context() + + // Write initial schema with 2 relations. 
+ initialDefs := []compiler.SchemaDefinition{ + ns.Namespace("user"), + ns.Namespace("document", + ns.MustRelation("viewer", nil, ns.AllowedRelation("user", "...")), + ns.MustRelation("editor", nil, ns.AllowedRelation("user", "...")), + ), + } + initialText, _, err := generator.GenerateSchema(initialDefs) + require.NoError(err) + firstRev := writeSchema(ctx, t, ds, initialDefs, initialText) + + // Update schema: add an "owner" relation. + updatedDefs := []compiler.SchemaDefinition{ + ns.Namespace("user"), + ns.Namespace("document", + ns.MustRelation("viewer", nil, ns.AllowedRelation("user", "...")), + ns.MustRelation("editor", nil, ns.AllowedRelation("user", "...")), + ns.MustRelation("owner", nil, ns.AllowedRelation("user", "...")), + ), + } + updatedText, _, err := generator.GenerateSchema(updatedDefs) + require.NoError(err) + secondRev := writeSchema(ctx, t, ds, updatedDefs, updatedText) + require.True(secondRev.GreaterThan(firstRev)) + + // At first revision: 2 relations. + readFirst, err := ds.SnapshotReader(firstRev).ReadStoredSchema(ctx) + require.NoError(err) + docFirst := readFirst.Get().GetV1().NamespaceDefinitions["document"] + require.Len(docFirst.Relation, 2) + + // At second revision: 3 relations. + readSecond, err := ds.SnapshotReader(secondRev).ReadStoredSchema(ctx) + require.NoError(err) + docSecond := readSecond.Get().GetV1().NamespaceDefinitions["document"] + require.Len(docSecond.Relation, 3) + + ownerFound := false + for _, rel := range docSecond.Relation { + if rel.Name == "owner" { + ownerFound = true + break + } + } + require.True(ownerFound, "owner relation should exist in updated schema") +} + +// StoredSchemaMultipleRevisionsTest writes multiple schema versions and verifies each revision +// can still be read independently. 
+func StoredSchemaMultipleRevisionsTest(t *testing.T, tester DatastoreTester) { + require := require.New(t) + + ds, err := tester.New(t, 0, veryLargeGCInterval, veryLargeGCWindow, 1) + require.NoError(err) + + ctx := t.Context() + + const numVersions = 5 + + type versionData struct { + revision datastore.Revision + schemaText string + numDefs int + } + versions := make([]versionData, 0, numVersions) + + for i := range numVersions { + defs := []compiler.SchemaDefinition{ + ns.Namespace("user"), + ns.Namespace(fmt.Sprintf("resource_%d", i), + ns.MustRelation("viewer", nil, ns.AllowedRelation("user", "...")), + ), + } + + schemaText, _, err := generator.GenerateSchema(defs) + require.NoError(err) + + rev := writeSchema(ctx, t, ds, defs, schemaText) + + versions = append(versions, versionData{ + revision: rev, + schemaText: schemaText, + numDefs: len(defs), + }) + } + + // Verify each revision independently. + for i, v := range versions { + readSchema, err := ds.SnapshotReader(v.revision).ReadStoredSchema(ctx) + require.NoError(err, "failed to read schema at version %d", i) + require.Equal(v.schemaText, readSchema.Get().GetV1().SchemaText, + "schema text mismatch at version %d", i) + require.Len(readSchema.Get().GetV1().NamespaceDefinitions, v.numDefs, + "namespace count mismatch at version %d", i) + } +} + +// StoredSchemaReadWithinTransactionTest tests that ReadStoredSchema works within a +// read-write transaction and sees the schema written in the same transaction. 
+func StoredSchemaReadWithinTransactionTest(t *testing.T, tester DatastoreTester) { + require := require.New(t) + + ds, err := tester.New(t, 0, veryLargeGCInterval, veryLargeGCWindow, 1) + require.NoError(err) + + ctx := t.Context() + + defs := []compiler.SchemaDefinition{ + ns.Namespace("user"), + } + schemaText, _, err := generator.GenerateSchema(defs) + require.NoError(err) + + _, err = ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error { + // Write within the transaction. + if err := datalayer.WriteSchemaViaStoredSchema(ctx, rwt, toSchemaDefinitions(defs), schemaText, nil); err != nil { + return err + } + + // Read within the same transaction. + readSchema, err := rwt.ReadStoredSchema(ctx) + if err != nil { + return err + } + + require.Equal(schemaText, readSchema.Get().GetV1().SchemaText) + return nil + }) + require.NoError(err) +} + +// StoredSchemaStableTextTest verifies that a schema with multiple namespaces and caveats +// produces stable text when definitions are sorted before generation, regardless of the +// initial ordering of definitions. +func StoredSchemaStableTextTest(t *testing.T, tester DatastoreTester) { + require := require.New(t) + + ds, err := tester.New(t, 0, veryLargeGCInterval, veryLargeGCWindow, 1) + require.NoError(err) + + ctx := t.Context() + + schemaString := `caveat beta_caveat(flag bool) { + flag +} + +caveat alpha_caveat(allowed bool) { + allowed +} + +definition zebra {} + +definition apple { + relation viewer: zebra with alpha_caveat +} + +definition middle { + relation editor: zebra with beta_caveat +}` + + compiled, err := compiler.Compile(compiler.InputSchema{ + Source: input.Source("schema"), + SchemaString: schemaString, + }, compiler.AllowUnprefixedObjectType()) + require.NoError(err) + + // Build definitions in unsorted order (caveats first, then objects, as-compiled). 
+ unsortedDefs := make([]compiler.SchemaDefinition, 0, + len(compiled.CaveatDefinitions)+len(compiled.ObjectDefinitions)) + for _, cd := range compiled.CaveatDefinitions { + unsortedDefs = append(unsortedDefs, cd) + } + for _, od := range compiled.ObjectDefinitions { + unsortedDefs = append(unsortedDefs, od) + } + + // sortDefinitions returns a new slice with caveats sorted by name first, + // then namespaces sorted by name. + sortDefinitions := func(defs []compiler.SchemaDefinition) []compiler.SchemaDefinition { + var caveatDefs, nsDefs []compiler.SchemaDefinition + for _, def := range defs { + switch def.(type) { + case *core.CaveatDefinition: + caveatDefs = append(caveatDefs, def) + case *core.NamespaceDefinition: + nsDefs = append(nsDefs, def) + } + } + sort.Slice(caveatDefs, func(i, j int) bool { + return caveatDefs[i].GetName() < caveatDefs[j].GetName() + }) + sort.Slice(nsDefs, func(i, j int) bool { + return nsDefs[i].GetName() < nsDefs[j].GetName() + }) + + sorted := make([]compiler.SchemaDefinition, 0, len(defs)) + sorted = append(sorted, caveatDefs...) + sorted = append(sorted, nsDefs...) + return sorted + } + + sortedDefs := sortDefinitions(unsortedDefs) + + // Generate text from sorted definitions. + sortedText, _, err := generator.GenerateSchema(sortedDefs) + require.NoError(err) + + // Write the stored schema with sorted text. + writtenRev := writeSchema(ctx, t, ds, sortedDefs, sortedText) + + // Read it back and verify text is preserved. + readSchema, err := ds.SnapshotReader(writtenRev).ReadStoredSchema(ctx) + require.NoError(err) + require.Equal(sortedText, readSchema.Get().GetV1().SchemaText) + + // Generate text again from the same definitions in a different initial order + // (reversed), sort, and verify the generated text is identical. 
+ reversedDefs := make([]compiler.SchemaDefinition, 0, len(unsortedDefs)) + for i := len(unsortedDefs) - 1; i >= 0; i-- { + reversedDefs = append(reversedDefs, unsortedDefs[i]) + } + + reSortedDefs := sortDefinitions(reversedDefs) + reSortedText, _, err := generator.GenerateSchema(reSortedDefs) + require.NoError(err) + require.Equal(sortedText, reSortedText, + "text generated from differently-ordered definitions should be identical after sorting") +} + +// StoredSchemaPhaseMigrationTest tests reading and writing schema through all four +// schema mode phases, simulating a live migration from legacy to unified storage. +// Each phase creates a new DataLayer on the same underlying datastore. For phases 2-4, +// it first reads the schema written by the previous phase to verify continuity, then +// writes an updated schema and verifies the change is reflected. +func StoredSchemaPhaseMigrationTest(t *testing.T, tester DatastoreTester) { + ds, err := tester.New(t, 0, veryLargeGCInterval, veryLargeGCWindow, 1) + require.NoError(t, err) + + ctx := t.Context() + + // Each phase writes a different schema so we can verify the update is visible. + // Caveats are included to verify they are properly handled across phases. 
+ phaseSchemas := []string{ + // Phase 1: base schema with a caveat + `caveat is_public(public bool) { + public +} + +definition user {} + +definition document { + relation viewer: user with is_public +}`, + // Phase 2: add editor relation and a second caveat + `caveat is_public(public bool) { + public +} + +caveat has_access(allowed bool) { + allowed +} + +definition user {} + +definition document { + relation viewer: user with is_public + relation editor: user with has_access +}`, + // Phase 3: add permission + `caveat is_public(public bool) { + public +} + +caveat has_access(allowed bool) { + allowed +} + +definition user {} + +definition document { + relation viewer: user with is_public + relation editor: user with has_access + permission view = viewer + editor +}`, + // Phase 4: add a new type, remove first caveat + `caveat has_access(allowed bool) { + allowed +} + +definition user {} + +definition group { + relation member: user +} + +definition document { + relation viewer: user | group#member + relation editor: user with has_access + permission view = viewer + editor +}`, + } + + phases := []struct { + name string + mode datalayer.SchemaMode + hasLegacy bool + hasUnified bool + expectedTypes []string + expectedCaveats []string + }{ + { + name: "Phase1_ReadLegacyWriteLegacy", + mode: datalayer.SchemaModeReadLegacyWriteLegacy, + hasLegacy: true, + hasUnified: false, + expectedTypes: []string{"document", "user"}, + expectedCaveats: []string{"is_public"}, + }, + { + name: "Phase2_ReadLegacyWriteBoth", + mode: datalayer.SchemaModeReadLegacyWriteBoth, + hasLegacy: true, + hasUnified: true, + expectedTypes: []string{"document", "user"}, + expectedCaveats: []string{"has_access", "is_public"}, + }, + { + name: "Phase3_ReadNewWriteBoth", + mode: datalayer.SchemaModeReadNewWriteBoth, + hasLegacy: true, + hasUnified: true, + expectedTypes: []string{"document", "user"}, + expectedCaveats: []string{"has_access", "is_public"}, + }, + { + name: "Phase4_ReadNewWriteNew", + 
mode: datalayer.SchemaModeReadNewWriteNew, + hasLegacy: false, + hasUnified: true, + expectedTypes: []string{"document", "group", "user"}, + expectedCaveats: []string{"has_access"}, + }, + } + + // Track the last revision and expected types/caveats so subsequent phases can verify reads. + var lastRev datastore.Revision + var prevExpectedTypes []string + var prevExpectedCaveats []string + + for i, phase := range phases { + schemaText := phaseSchemas[i] + compiled, err := compiler.Compile(compiler.InputSchema{ + Source: input.Source("schema"), + SchemaString: schemaText, + }, compiler.AllowUnprefixedObjectType()) + require.NoError(t, err) + + allDefs := make([]datastore.SchemaDefinition, 0, len(compiled.ObjectDefinitions)+len(compiled.CaveatDefinitions)) + for _, caveatDef := range compiled.CaveatDefinitions { + allDefs = append(allDefs, caveatDef) + } + for _, objDef := range compiled.ObjectDefinitions { + allDefs = append(allDefs, objDef) + } + + // Create a fresh DataLayer for this phase on the same underlying datastore. + dl := datalayer.NewDataLayer(ds, datalayer.WithSchemaMode(phase.mode)) + + // For phases 2+, verify we can read the schema written by the previous phase + // before writing anything new. This validates cross-phase read continuity. + if i > 0 { + t.Run(phase.name+"/ReadFromPreviousPhase", func(t *testing.T) { + verifySchemaTypes(t, ctx, dl, lastRev, prevExpectedTypes, prevExpectedCaveats) + }) + } + + // Write schema through the datalayer. + t.Run(phase.name+"/Write", func(t *testing.T) { + rev, err := dl.ReadWriteTx(ctx, func(ctx context.Context, rwt datalayer.ReadWriteTransaction) error { + return rwt.WriteSchema(ctx, allDefs, schemaText, caveattypes.Default.TypeSet) + }) + require.NoError(t, err) + lastRev = rev + + // Verify legacy storage state. 
+ legacyNsDefs, err := ds.SnapshotReader(rev).LegacyListAllNamespaces(ctx) + require.NoError(t, err) + if phase.hasLegacy { + require.NotEmpty(t, legacyNsDefs, "phase %d: expected legacy data", i+1) + } + + legacyCaveatDefs, err := ds.SnapshotReader(rev).LegacyListAllCaveats(ctx) + require.NoError(t, err) + if phase.hasLegacy { + require.NotEmpty(t, legacyCaveatDefs, "phase %d: expected legacy caveat data", i+1) + } + + // Verify unified storage state. + storedSchema, err := ds.SnapshotReader(rev).ReadStoredSchema(ctx) + if phase.hasUnified { + require.NoError(t, err) + require.NotNil(t, storedSchema) + require.NotEmpty(t, storedSchema.Get().GetV1().SchemaText) + } else { + require.ErrorIs(t, err, datastore.ErrSchemaNotFound) + } + }) + + // Read back through the datalayer after writing and verify the new schema is visible. + t.Run(phase.name+"/ReadAfterWrite", func(t *testing.T) { + verifySchemaTypes(t, ctx, dl, lastRev, phase.expectedTypes, phase.expectedCaveats) + }) + + prevExpectedTypes = phase.expectedTypes + prevExpectedCaveats = phase.expectedCaveats + } +} + +// StoredSchemaLargeTest generates a large schema (>2MB) with 4000 types, each having 20 relations, +// writes it, reads it back, and verifies the exact match. +func StoredSchemaLargeTest(t *testing.T, tester DatastoreTester) { + require := require.New(t) + + ds, err := tester.New(t, 0, veryLargeGCInterval, veryLargeGCWindow, 1) + require.NoError(err) + + ctx := t.Context() + + // Build a large schema with 4000 types, each with 20 relations. 
+ const numTypes = 4000 + const numRelations = 20 + + allDefs := make([]compiler.SchemaDefinition, 0, numTypes) + for i := range numTypes { + typeName := fmt.Sprintf("type%04d", i) + rels := make([]*core.Relation, 0, numRelations) + for j := range numRelations { + relName := fmt.Sprintf("rel_%d", j) + rels = append(rels, ns.MustRelation(relName, nil, ns.AllowedRelation(typeName, "..."))) + } + allDefs = append(allDefs, ns.Namespace(typeName, rels...)) + } + + schemaText, _, err := generator.GenerateSchema(allDefs) + require.NoError(err) + + // Verify it is indeed >2MB. + require.Greater(len(schemaText), 2*1024*1024, "generated schema should exceed 2MB") + + writtenRev := writeSchema(ctx, t, ds, allDefs, schemaText) + + // Read back and verify exact match. + readSchema, err := ds.SnapshotReader(writtenRev).ReadStoredSchema(ctx) + require.NoError(err) + require.Equal(schemaText, readSchema.Get().GetV1().SchemaText) + require.Len(readSchema.Get().GetV1().NamespaceDefinitions, numTypes) +} + +// verifySchemaTypes reads schema through the datalayer at the given revision +// and verifies the expected type definition and caveat definition names are present. 
+func verifySchemaTypes(t *testing.T, ctx context.Context, dl datalayer.DataLayer, rev datastore.Revision, expectedTypes []string, expectedCaveats []string) { + t.Helper() + + reader := dl.SnapshotReader(rev, datalayer.NoSchemaHashForTesting) + schemaReader, err := reader.ReadSchema(ctx) + require.NoError(t, err) + + readText, err := schemaReader.SchemaText(ctx) + require.NoError(t, err) + require.NotEmpty(t, readText) + + typeDefs, err := schemaReader.ListAllTypeDefinitions(ctx) + require.NoError(t, err) + require.Len(t, typeDefs, len(expectedTypes)) + + typeNames := make([]string, 0, len(typeDefs)) + for _, td := range typeDefs { + typeNames = append(typeNames, td.Definition.Name) + } + sort.Strings(typeNames) + require.Equal(t, expectedTypes, typeNames) + + caveatDefs, err := schemaReader.ListAllCaveatDefinitions(ctx) + require.NoError(t, err) + require.Len(t, caveatDefs, len(expectedCaveats)) + + caveatNames := make([]string, 0, len(caveatDefs)) + for _, cd := range caveatDefs { + caveatNames = append(caveatNames, cd.Definition.Name) + } + sort.Strings(caveatNames) + require.Equal(t, expectedCaveats, caveatNames) +} diff --git a/pkg/development/check.go b/pkg/development/check.go index 0cc9944c1..5791029c9 100644 --- a/pkg/development/check.go +++ b/pkg/development/check.go @@ -6,6 +6,7 @@ import ( "github.com/authzed/spicedb/internal/graph/computed" v1 "github.com/authzed/spicedb/internal/services/v1" caveattypes "github.com/authzed/spicedb/pkg/caveats/types" + "github.com/authzed/spicedb/pkg/datalayer" v1dispatch "github.com/authzed/spicedb/pkg/proto/dispatch/v1" "github.com/authzed/spicedb/pkg/tuple" ) @@ -35,6 +36,7 @@ func RunCheck(devContext *DevContext, resource tuple.ObjectAndRelation, subject AtRevision: devContext.Revision, MaximumDepth: maxDispatchDepth, DebugOption: computed.TraceDebuggingEnabled, + SchemaHash: datalayer.NoSchemaHashInDevelopment, }, resource.ObjectID, defaultWasmDispatchChunkSize, @@ -43,7 +45,7 @@ func RunCheck(devContext 
*DevContext, resource tuple.ObjectAndRelation, subject return CheckResult{v1dispatch.ResourceCheckResult_NOT_MEMBER, nil, nil, nil}, err } - reader := devContext.DataLayer.SnapshotReader(devContext.Revision) + reader := devContext.DataLayer.SnapshotReader(devContext.Revision, datalayer.NoSchemaHashInDevelopment) sr, srErr := reader.ReadSchema(ctx) if srErr != nil { return CheckResult{v1dispatch.ResourceCheckResult_NOT_MEMBER, nil, nil, nil}, srErr diff --git a/pkg/development/validation.go b/pkg/development/validation.go index 7021c1eff..be7e395bb 100644 --- a/pkg/development/validation.go +++ b/pkg/development/validation.go @@ -11,6 +11,7 @@ import ( "github.com/authzed/spicedb/internal/developmentmembership" log "github.com/authzed/spicedb/internal/logging" + "github.com/authzed/spicedb/pkg/datalayer" devinterface "github.com/authzed/spicedb/pkg/proto/developer/v1" v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" "github.com/authzed/spicedb/pkg/tuple" @@ -31,6 +32,7 @@ func RunValidation(devContext *DevContext, validation *blocks.ParsedExpectedRela AtRevision: devContext.Revision.String(), DepthRemaining: maxDispatchDepth, TraversalBloom: v1.MustNewTraversalBloomFilter(uint(maxDispatchDepth)), + SchemaHash: []byte(datalayer.NoSchemaHashInDevelopment), }, ExpansionMode: v1.DispatchExpandRequest_RECURSIVE, }) diff --git a/pkg/middleware/consistency/consistency.go b/pkg/middleware/consistency/consistency.go index c9447edf4..cc1d42098 100644 --- a/pkg/middleware/consistency/consistency.go +++ b/pkg/middleware/consistency/consistency.go @@ -61,7 +61,8 @@ var revisionKey ctxKeyType = struct{}{} var errInvalidZedToken = status.Error(codes.InvalidArgument, "invalid revision requested") type revisionHandle struct { - revision datastore.Revision + revision datastore.Revision + schemaHash datalayer.SchemaHash } // ContextWithHandle adds a placeholder to a context that will later be @@ -70,28 +71,28 @@ func ContextWithHandle(ctx context.Context) context.Context { return 
context.WithValue(ctx, revisionKey, &revisionHandle{}) } -// RevisionFromContext reads the selected revision out of a context.Context, computes a zedtoken -// from it, and returns an error if it has not been set on the context. -func RevisionFromContext(ctx context.Context) (datastore.Revision, *v1.ZedToken, error) { +// RevisionFromContext reads the selected revision and schema hash out of a context.Context, +// computes a zedtoken from it, and returns an error if it has not been set on the context. +func RevisionFromContext(ctx context.Context) (datastore.Revision, datalayer.SchemaHash, *v1.ZedToken, error) { if c := ctx.Value(revisionKey); c != nil { handle := c.(*revisionHandle) rev := handle.revision if rev != nil { dl := datalayer.FromContext(ctx) if dl == nil { - return nil, nil, spiceerrors.MustBugf("consistency middleware did not inject datastore") + return nil, "", nil, spiceerrors.MustBugf("consistency middleware did not inject datastore") } zedToken, err := zedtoken.NewFromRevision(ctx, rev, dl) if err != nil { - return nil, nil, err + return nil, "", nil, err } - return rev, zedToken, nil + return rev, handle.schemaHash, zedToken, nil } } - return nil, nil, status.Error(codes.Internal, "consistency middleware did not inject revision") + return nil, "", nil, status.Error(codes.Internal, "consistency middleware did not inject revision") } // AddRevisionToContext adds a revision to the given context, based on the consistency block found @@ -114,6 +115,7 @@ func addRevisionToContextFromConsistency(ctx context.Context, req hasConsistency } var revision datastore.Revision + var schemaHash datalayer.SchemaHash consistency := req.GetConsistency() withOptionalCursor, hasOptionalCursor := req.(hasOptionalCursor) @@ -125,7 +127,7 @@ func addRevisionToContextFromConsistency(ctx context.Context, req hasConsistency ConsistencyCounter.WithLabelValues("snapshot", "cursor", serviceLabel).Inc() } - requestedRev, _, err := cursor.DecodeToDispatchRevision(ctx, 
withOptionalCursor.GetOptionalCursor(), dl) + requestedRev, cursorSchemaHash, _, err := cursor.DecodeToDispatchRevisionAndSchemaHash(ctx, withOptionalCursor.GetOptionalCursor(), dl) if err != nil { return rewriteDatastoreError(err) } @@ -136,6 +138,7 @@ func addRevisionToContextFromConsistency(ctx context.Context, req hasConsistency } revision = requestedRev + schemaHash = cursorSchemaHash case consistency == nil || consistency.GetMinimizeLatency(): // Minimize Latency: Use the datastore's current revision, whatever it may be. @@ -148,11 +151,12 @@ func addRevisionToContextFromConsistency(ctx context.Context, req hasConsistency ConsistencyCounter.WithLabelValues("minlatency", source, serviceLabel).Inc() } - databaseRev, err := dl.OptimizedRevision(ctx) + databaseRev, hash, err := dl.OptimizedRevision(ctx) if err != nil { return rewriteDatastoreError(err) } revision = databaseRev + schemaHash = hash case consistency.GetFullyConsistent(): // Fully Consistent: Use the datastore's synchronized revision. @@ -160,16 +164,17 @@ func addRevisionToContextFromConsistency(ctx context.Context, req hasConsistency ConsistencyCounter.WithLabelValues("full", "request", serviceLabel).Inc() } - databaseRev, err := dl.HeadRevision(ctx) + databaseRev, hash, err := dl.HeadRevision(ctx) if err != nil { return rewriteDatastoreError(err) } revision = databaseRev + schemaHash = hash case consistency.GetAtLeastAsFresh() != nil: // At least as fresh as: Pick one of the datastore's revision and that specified, which // ever is later. 
- picked, pickedRequest, err := pickBestRevision(ctx, consistency.GetAtLeastAsFresh(), dl, option) + picked, hash, pickedRequest, err := pickBestRevision(ctx, consistency.GetAtLeastAsFresh(), dl, option) if err != nil { return rewriteDatastoreError(err) } @@ -184,6 +189,7 @@ func addRevisionToContextFromConsistency(ctx context.Context, req hasConsistency } revision = picked + schemaHash = hash case consistency.GetAtExactSnapshot() != nil: // Exact snapshot: Use the revision as encoded in the zed token. @@ -206,12 +212,15 @@ func addRevisionToContextFromConsistency(ctx context.Context, req hasConsistency } revision = requestedRev + schemaHash = datalayer.NoSchemaHashForLegacyCursor default: return status.Errorf(codes.Internal, "missing handling of consistency case in %v", consistency) } - handle.(*revisionHandle).revision = revision + rh := handle.(*revisionHandle) + rh.revision = revision + rh.schemaHash = schemaHash return nil } @@ -274,51 +283,51 @@ func (s *recvWrapper) RecvMsg(m any) error { // pickBestRevision compares the provided ZedToken with the optimized revision of the datastore, and returns the most // recent one. The boolean return value will be true if the provided ZedToken is the most recent, false otherwise. 
-func pickBestRevision(ctx context.Context, requested *v1.ZedToken, dl datalayer.DataLayer, option MismatchingTokenOption) (datastore.Revision, bool, error) { +func pickBestRevision(ctx context.Context, requested *v1.ZedToken, dl datalayer.DataLayer, option MismatchingTokenOption) (datastore.Revision, datalayer.SchemaHash, bool, error) { // Calculate a revision as we see fit - databaseRev, err := dl.OptimizedRevision(ctx) + databaseRev, hash, err := dl.OptimizedRevision(ctx) if err != nil { - return datastore.NoRevision, false, err + return datastore.NoRevision, "", false, err } if requested != nil { requestedRev, status, err := zedtoken.DecodeRevision(requested, dl) if err != nil { - return datastore.NoRevision, false, errInvalidZedToken + return datastore.NoRevision, "", false, errInvalidZedToken } if status == zedtoken.StatusMismatchedDatastoreID { switch option { case TreatMismatchingTokensAsFullConsistency: log.Warn().Str("zedtoken", requested.Token).Msg("ZedToken specified references a different datastore instance and SpiceDB is configured to treat this as a full consistency request") - headRev, err := dl.HeadRevision(ctx) + headRev, headHash, err := dl.HeadRevision(ctx) if err != nil { - return datastore.NoRevision, false, err + return datastore.NoRevision, "", false, err } - return headRev, false, nil + return headRev, headHash, false, nil case TreatMismatchingTokensAsMinLatency: log.Warn().Str("zedtoken", requested.Token).Msg("ZedToken specified references a different datastore instance and SpiceDB is configured to treat this as a min latency request") - return databaseRev, false, nil + return databaseRev, hash, false, nil case TreatMismatchingTokensAsError: log.Warn().Str("zedtoken", requested.Token).Msg("ZedToken specified references a different datastore instance and SpiceDB is configured to raise an error in this scenario") - return datastore.NoRevision, false, errors.New("ZedToken specified references a different datastore instance and SpiceDB is 
configured to raise an error in this scenario") + return datastore.NoRevision, "", false, errors.New("ZedToken specified references a different datastore instance and SpiceDB is configured to raise an error in this scenario") default: - return datastore.NoRevision, false, spiceerrors.MustBugf("unknown mismatching token option: %v", option) + return datastore.NoRevision, "", false, spiceerrors.MustBugf("unknown mismatching token option: %v", option) } } if databaseRev.GreaterThan(requestedRev) { - return databaseRev, false, nil + return databaseRev, hash, false, nil } - return requestedRev, true, nil + return requestedRev, datalayer.NoSchemaHashForLegacyCursor, true, nil } - return databaseRev, false, nil + return databaseRev, hash, false, nil } func rewriteDatastoreError(err error) error { diff --git a/pkg/middleware/consistency/consistency_test.go b/pkg/middleware/consistency/consistency_test.go index 312884882..882452396 100644 --- a/pkg/middleware/consistency/consistency_test.go +++ b/pkg/middleware/consistency/consistency_test.go @@ -17,6 +17,7 @@ import ( "github.com/authzed/spicedb/pkg/datalayer" "github.com/authzed/spicedb/pkg/datastore" dispatch "github.com/authzed/spicedb/pkg/proto/dispatch/v1" + impl "github.com/authzed/spicedb/pkg/proto/impl/v1" "github.com/authzed/spicedb/pkg/zedtoken" ) @@ -40,7 +41,7 @@ func TestAddRevisionToContextNoneSupplied(t *testing.T) { err := AddRevisionToContext(updated, &v1.ReadRelationshipsRequest{}, dl, "somelabel", TreatMismatchingTokensAsError) require.NoError(err) - rev, _, err := RevisionFromContext(updated) + rev, _, _, err := RevisionFromContext(updated) require.NoError(err) require.True(optimized.Equal(rev)) @@ -66,7 +67,7 @@ func TestAddRevisionToContextMinimizeLatency(t *testing.T) { }, dl, "somelabel", TreatMismatchingTokensAsError) require.NoError(err) - rev, _, err := RevisionFromContext(updated) + rev, _, _, err := RevisionFromContext(updated) require.NoError(err) require.True(optimized.Equal(rev)) @@ -92,7 
+93,7 @@ func TestAddRevisionToContextFullyConsistent(t *testing.T) { }, dl, "somelabel", TreatMismatchingTokensAsError) require.NoError(err) - rev, _, err := RevisionFromContext(updated) + rev, _, _, err := RevisionFromContext(updated) require.NoError(err) require.True(head.Equal(rev)) @@ -119,7 +120,7 @@ func TestAddRevisionToContextAtLeastAsFresh(t *testing.T) { }, dl, "somelabel", TreatMismatchingTokensAsError) require.NoError(err) - rev, _, err := RevisionFromContext(updated) + rev, _, _, err := RevisionFromContext(updated) require.NoError(err) require.True(exact.Equal(rev)) @@ -146,7 +147,7 @@ func TestAddRevisionToContextAtValidExactSnapshot(t *testing.T) { }, dl, "somelabel", TreatMismatchingTokensAsError) require.NoError(err) - rev, _, err := RevisionFromContext(updated) + rev, _, _, err := RevisionFromContext(updated) require.NoError(err) require.True(exact.Equal(rev)) @@ -185,7 +186,7 @@ func TestAddRevisionToContextNoConsistencyAPI(t *testing.T) { updated := ContextWithHandle(t.Context()) updated = datalayer.ContextWithDataLayer(updated, dl) - _, _, err := RevisionFromContext(updated) + _, _, _, err := RevisionFromContext(updated) require.Error(err) } @@ -198,7 +199,7 @@ func TestAddRevisionToContextWithCursor(t *testing.T) { dl := datalayer.NewDataLayer(ds) // cursor is at `optimized` - cursor, err := cursor.EncodeFromDispatchCursor(&dispatch.Cursor{}, "somehash", optimized, nil) + cursor, err := cursor.EncodeFromDispatchCursor(&dispatch.Cursor{}, "somehash", optimized, datalayer.NoSchemaHashForLegacyCursor, nil) require.NoError(err) // revision in context is at `exact` @@ -216,13 +217,57 @@ func TestAddRevisionToContextWithCursor(t *testing.T) { require.NoError(err) // ensure we get back `optimized` from the cursor - rev, _, err := RevisionFromContext(updated) + rev, _, _, err := RevisionFromContext(updated) require.NoError(err) require.True(optimized.Equal(rev)) ds.AssertExpectations(t) } +func TestAddRevisionToContextWithCursorAndSchemaHash(t 
*testing.T) { + require := require.New(t) + + ds := &proxy_test.MockDatastore{} + ds.On("CheckRevision", optimized).Return(nil).Times(1) + ds.On("RevisionFromString", optimized.String()).Return(optimized, nil).Once() + dl := datalayer.NewDataLayer(ds) + + // Encode a cursor with DatastoreUniqueId set so the schema hash roundtrips. + // The mock datastore returns "mockds" as its unique ID. + encodedCursor, err := cursor.Encode(&impl.DecodedCursor{ + VersionOneof: &impl.DecodedCursor_V1{ + V1: &impl.V1Cursor{ + Revision: optimized.String(), + DispatchVersion: 1, + CallAndParametersHash: "somehash", + DatastoreUniqueId: "mockds", + SchemaHash: []byte("myspecialschema"), + }, + }, + }) + require.NoError(err) + + updated := ContextWithHandle(t.Context()) + updated = datalayer.ContextWithDataLayer(updated, dl) + + err = AddRevisionToContext(updated, &v1.LookupResourcesRequest{ + Consistency: &v1.Consistency{ + Requirement: &v1.Consistency_AtExactSnapshot{ + AtExactSnapshot: zedtoken.MustNewFromRevisionForTesting(exact), + }, + }, + OptionalCursor: encodedCursor, + }, dl, "somelabel", TreatMismatchingTokensAsError) + require.NoError(err) + + rev, schemaHash, _, err := RevisionFromContext(updated) + require.NoError(err) + + require.True(optimized.Equal(rev)) + require.Equal(datalayer.SchemaHash("myspecialschema"), schemaHash) + ds.AssertExpectations(t) +} + func TestAddRevisionToContextAtMalformedExactSnapshot(t *testing.T) { err := AddRevisionToContext(ContextWithHandle(t.Context()), &v1.LookupResourcesRequest{ Consistency: &v1.Consistency{ @@ -253,7 +298,7 @@ func TestAddRevisionToContextMalformedAtLeastAsFreshSnapshot(t *testing.T) { func TestRevisionFromContextMissingConsistency(t *testing.T) { updated := ContextWithHandle(t.Context()) - _, _, err := RevisionFromContext(updated) + _, _, _, err := RevisionFromContext(updated) require.Error(t, err) grpcutil.RequireStatus(t, codes.Internal, err) require.ErrorContains(t, err, "consistency middleware did not inject 
revision") @@ -389,7 +434,7 @@ func TestAtLeastAsFreshWithMismatchedTokenExpectMinLatency(t *testing.T) { }, dl, "somelabel", TreatMismatchingTokensAsMinLatency) require.NoError(err) - rev, _, err := RevisionFromContext(updated) + rev, _, _, err := RevisionFromContext(updated) require.NoError(err) require.True(optimized.Equal(rev)) @@ -424,7 +469,7 @@ func TestAtLeastAsFreshWithMismatchedTokenExpectFullConsistency(t *testing.T) { }, dl, "somelabel", TreatMismatchingTokensAsFullConsistency) require.NoError(err) - rev, _, err := RevisionFromContext(updated) + rev, _, _, err := RevisionFromContext(updated) require.NoError(err) require.True(head.Equal(rev)) @@ -453,7 +498,7 @@ func TestAddRevisionToContextAtLeastAsFreshMatchingIDs(t *testing.T) { }, dl, "somelabel", TreatMismatchingTokensAsError) require.NoError(err) - rev, _, err := RevisionFromContext(updated) + rev, _, _, err := RevisionFromContext(updated) require.NoError(err) require.True(exact.Equal(rev)) diff --git a/pkg/middleware/consistency/forcefull.go b/pkg/middleware/consistency/forcefull.go index 68f2f5bb0..6441bcd9a 100644 --- a/pkg/middleware/consistency/forcefull.go +++ b/pkg/middleware/consistency/forcefull.go @@ -53,11 +53,13 @@ func setFullConsistencyRevisionToContext(ctx context.Context, req any, dl datala ConsistencyCounter.WithLabelValues("full", "request", serviceLabel).Inc() } - databaseRev, err := dl.HeadRevision(ctx) + databaseRev, hash, err := dl.HeadRevision(ctx) if err != nil { return rewriteDatastoreError(err) } - handle.(*revisionHandle).revision = databaseRev + rh := handle.(*revisionHandle) + rh.revision = databaseRev + rh.schemaHash = hash } return nil diff --git a/pkg/middleware/consistency/forcefull_test.go b/pkg/middleware/consistency/forcefull_test.go index 6135ed21f..a0b5b27f4 100644 --- a/pkg/middleware/consistency/forcefull_test.go +++ b/pkg/middleware/consistency/forcefull_test.go @@ -47,7 +47,7 @@ func TestSetFullConsistencyRevisionToContext(t *testing.T) { err := 
setFullConsistencyRevisionToContext(ctx, &requestWithConsistency{}, dl, "", TreatMismatchingTokensAsFullConsistency) require.NoError(t, err) - rev, _, err := RevisionFromContext(ctx) + rev, _, _, err := RevisionFromContext(ctx) require.Error(t, err) require.Nil(t, rev) }) @@ -60,7 +60,7 @@ func TestSetFullConsistencyRevisionToContext(t *testing.T) { dl := mock_datalayer.NewMockDataLayer(ctrl) mockRev := mocks.NewMockRevision(ctrl) mockRev.EXPECT().String().Return("a revision").Times(1) - dl.EXPECT().HeadRevision(gomock.Any()).Return(mockRev, nil).Times(1) + dl.EXPECT().HeadRevision(gomock.Any()).Return(mockRev, datalayer.NoSchemaHashInLegacyMode, nil).Times(1) dl.EXPECT().UniqueID(gomock.Any()).Return("uniqueid", nil).Times(1) ctx := ContextWithHandle(t.Context()) @@ -69,7 +69,7 @@ func TestSetFullConsistencyRevisionToContext(t *testing.T) { err := setFullConsistencyRevisionToContext(ctx, &requestWithConsistency{}, dl, "somelabel", TreatMismatchingTokensAsFullConsistency) require.NoError(t, err) - rev, _, err := RevisionFromContext(ctx) + rev, _, _, err := RevisionFromContext(ctx) require.NoError(t, err) require.Equal(t, mockRev, rev) }) @@ -82,7 +82,7 @@ func TestSetFullConsistencyRevisionToContext(t *testing.T) { dl := mock_datalayer.NewMockDataLayer(ctrl) mockRev := mocks.NewMockRevision(ctrl) mockRev.EXPECT().String().Return("a revision").Times(1) - dl.EXPECT().HeadRevision(gomock.Any()).Return(mockRev, nil).Times(1) + dl.EXPECT().HeadRevision(gomock.Any()).Return(mockRev, datalayer.NoSchemaHashInLegacyMode, nil).Times(1) dl.EXPECT().UniqueID(gomock.Any()).Return("uniqueid", nil).Times(1) ctx := ContextWithHandle(t.Context()) @@ -91,7 +91,7 @@ func TestSetFullConsistencyRevisionToContext(t *testing.T) { err := setFullConsistencyRevisionToContext(ctx, &requestWithConsistency{}, dl, "", TreatMismatchingTokensAsFullConsistency) require.NoError(t, err) - rev, _, err := RevisionFromContext(ctx) + rev, _, _, err := RevisionFromContext(ctx) require.NoError(t, err) 
require.Equal(t, mockRev, rev) }) @@ -115,7 +115,7 @@ func TestSetFullConsistencyRevisionToContext(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() dl := mock_datalayer.NewMockDataLayer(ctrl) - dl.EXPECT().HeadRevision(gomock.Any()).Return(nil, errors.New("some error")).Times(1) + dl.EXPECT().HeadRevision(gomock.Any()).Return(nil, datalayer.NoSchemaHashInLegacyMode, errors.New("some error")).Times(1) ctx := ContextWithHandle(t.Context()) ctx = datalayer.ContextWithDataLayer(ctx, dl) @@ -137,7 +137,7 @@ func TestForceFullConsistencyUnaryServerInterceptor(t *testing.T) { dl := mock_datalayer.NewMockDataLayer(ctrl) mockRev := mocks.NewMockRevision(ctrl) mockRev.EXPECT().String().Return("a revision").Times(1) - dl.EXPECT().HeadRevision(gomock.Any()).Return(mockRev, nil).Times(1) + dl.EXPECT().HeadRevision(gomock.Any()).Return(mockRev, datalayer.NoSchemaHashInLegacyMode, nil).Times(1) dl.EXPECT().UniqueID(gomock.Any()).Return("uniqueid", nil).Times(1) interceptor := ForceFullConsistencyUnaryServerInterceptor("somelabel") ctx := datalayer.ContextWithDataLayer(t.Context(), dl) @@ -157,7 +157,7 @@ func TestForceFullConsistencyUnaryServerInterceptor(t *testing.T) { require.NoError(t, err) require.Equal(t, "response", resp) - rev, _, err := RevisionFromContext(capturedCtx) + rev, _, _, err := RevisionFromContext(capturedCtx) require.NoError(t, err) require.Equal(t, mockRev, rev) }) @@ -168,7 +168,7 @@ func TestForceFullConsistencyUnaryServerInterceptor(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() dl := mock_datalayer.NewMockDataLayer(ctrl) - dl.EXPECT().HeadRevision(gomock.Any()).Return(nil, errors.New("some error")).Times(1) + dl.EXPECT().HeadRevision(gomock.Any()).Return(nil, datalayer.NoSchemaHashInLegacyMode, errors.New("some error")).Times(1) interceptor := ForceFullConsistencyUnaryServerInterceptor("somelabel") ctx := datalayer.ContextWithDataLayer(t.Context(), dl) @@ -242,7 +242,7 @@ func 
TestForceFullConsistencyStreamServerInterceptor(t *testing.T) { dl := mock_datalayer.NewMockDataLayer(ctrl) mockRev := mocks.NewMockRevision(ctrl) mockRev.EXPECT().String().Return("a revision").Times(1) - dl.EXPECT().HeadRevision(gomock.Any()).Return(mockRev, nil).Times(1) + dl.EXPECT().HeadRevision(gomock.Any()).Return(mockRev, datalayer.NoSchemaHashInLegacyMode, nil).Times(1) dl.EXPECT().UniqueID(gomock.Any()).Return("uniqueid", nil).Times(1) interceptor := ForceFullConsistencyStreamServerInterceptor("somelabel") @@ -269,7 +269,7 @@ func TestForceFullConsistencyStreamServerInterceptor(t *testing.T) { err = wrapper.RecvMsg(&requestWithConsistency{}) require.NoError(t, err) - rev, _, err := RevisionFromContext(wrapper.Context()) + rev, _, _, err := RevisionFromContext(wrapper.Context()) require.NoError(t, err) require.Equal(t, mockRev, rev) }) @@ -280,7 +280,7 @@ func TestForceFullConsistencyStreamServerInterceptor(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() dl := mock_datalayer.NewMockDataLayer(ctrl) - dl.EXPECT().HeadRevision(gomock.Any()).Return(nil, errors.New("some error")).Times(1) + dl.EXPECT().HeadRevision(gomock.Any()).Return(nil, datalayer.NoSchemaHashInLegacyMode, errors.New("some error")).Times(1) interceptor := ForceFullConsistencyStreamServerInterceptor("somelabel") ctx := datalayer.ContextWithDataLayer(t.Context(), dl) @@ -369,7 +369,7 @@ func TestForceFullConsistencyUnaryBypassWhitelist(t *testing.T) { require.NoError(t, err) require.Equal(t, "bypassed", resp) - rev, _, err := RevisionFromContext(capturedCtx) + rev, _, _, err := RevisionFromContext(capturedCtx) require.Error(t, err) require.Nil(t, rev) }) @@ -417,7 +417,7 @@ func TestSetFullConsistencyRevisionToContextWithReadonlyError(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() dl := mock_datalayer.NewMockDataLayer(ctrl) - dl.EXPECT().HeadRevision(gomock.Any()).Return(nil, datastore.NewReadonlyErr()).Times(1) + 
dl.EXPECT().HeadRevision(gomock.Any()).Return(nil, datalayer.NoSchemaHashInLegacyMode, datastore.NewReadonlyErr()).Times(1) ctx := ContextWithHandle(t.Context()) ctx = datalayer.ContextWithDataLayer(ctx, dl) diff --git a/pkg/middleware/datalayer/datastore.go b/pkg/middleware/datalayer/datastore.go index 5d98a6ff6..42fd006d9 100644 --- a/pkg/middleware/datalayer/datastore.go +++ b/pkg/middleware/datalayer/datastore.go @@ -3,8 +3,6 @@ package datalayer import ( "context" - "google.golang.org/grpc" - "github.com/authzed/spicedb/pkg/datalayer" ) @@ -24,15 +22,3 @@ func MustFromContext(ctx context.Context) datalayer.DataLayer { return dl } - -// UnaryCountingInterceptor wraps the datalayer with a counting proxy for unary requests. -// After each request completes, it exports the method call counts to Prometheus metrics. -func UnaryCountingInterceptor() grpc.UnaryServerInterceptor { - return datalayer.UnaryCountingInterceptor(nil) -} - -// StreamCountingInterceptor wraps the datalayer with a counting proxy for stream requests. -// After each stream completes, it exports the method call counts to Prometheus metrics. -func StreamCountingInterceptor() grpc.StreamServerInterceptor { - return datalayer.StreamCountingInterceptor(nil) -} diff --git a/pkg/proto/core/v1/core.pb.go b/pkg/proto/core/v1/core.pb.go index 4afc65200..4478de048 100644 --- a/pkg/proto/core/v1/core.pb.go +++ b/pkg/proto/core/v1/core.pb.go @@ -2592,6 +2592,81 @@ func (x *SubjectFilter) GetOptionalRelation() *SubjectFilter_RelationFilter { return nil } +// StoredSchema represents a stored schema in SpiceDB under the new, unified schema format. 
+type StoredSchema struct { + state protoimpl.MessageState `protogen:"open.v1"` + Version uint32 `protobuf:"varint,1,opt,name=version,proto3" json:"version,omitempty"` + // Types that are valid to be assigned to VersionOneof: + // + // *StoredSchema_V1 + VersionOneof isStoredSchema_VersionOneof `protobuf_oneof:"version_oneof"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *StoredSchema) Reset() { + *x = StoredSchema{} + mi := &file_core_v1_core_proto_msgTypes[33] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *StoredSchema) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StoredSchema) ProtoMessage() {} + +func (x *StoredSchema) ProtoReflect() protoreflect.Message { + mi := &file_core_v1_core_proto_msgTypes[33] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StoredSchema.ProtoReflect.Descriptor instead. 
+func (*StoredSchema) Descriptor() ([]byte, []int) { + return file_core_v1_core_proto_rawDescGZIP(), []int{33} +} + +func (x *StoredSchema) GetVersion() uint32 { + if x != nil { + return x.Version + } + return 0 +} + +func (x *StoredSchema) GetVersionOneof() isStoredSchema_VersionOneof { + if x != nil { + return x.VersionOneof + } + return nil +} + +func (x *StoredSchema) GetV1() *StoredSchema_V1StoredSchema { + if x != nil { + if x, ok := x.VersionOneof.(*StoredSchema_V1); ok { + return x.V1 + } + } + return nil +} + +type isStoredSchema_VersionOneof interface { + isStoredSchema_VersionOneof() +} + +type StoredSchema_V1 struct { + V1 *StoredSchema_V1StoredSchema `protobuf:"bytes,2,opt,name=v1,proto3,oneof"` +} + +func (*StoredSchema_V1) isStoredSchema_VersionOneof() {} + type AllowedRelation_PublicWildcard struct { state protoimpl.MessageState `protogen:"open.v1"` unknownFields protoimpl.UnknownFields @@ -2600,7 +2675,7 @@ type AllowedRelation_PublicWildcard struct { func (x *AllowedRelation_PublicWildcard) Reset() { *x = AllowedRelation_PublicWildcard{} - mi := &file_core_v1_core_proto_msgTypes[36] + mi := &file_core_v1_core_proto_msgTypes[37] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2612,7 +2687,7 @@ func (x *AllowedRelation_PublicWildcard) String() string { func (*AllowedRelation_PublicWildcard) ProtoMessage() {} func (x *AllowedRelation_PublicWildcard) ProtoReflect() protoreflect.Message { - mi := &file_core_v1_core_proto_msgTypes[36] + mi := &file_core_v1_core_proto_msgTypes[37] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2653,7 +2728,7 @@ type SetOperation_Child struct { func (x *SetOperation_Child) Reset() { *x = SetOperation_Child{} - mi := &file_core_v1_core_proto_msgTypes[37] + mi := &file_core_v1_core_proto_msgTypes[38] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2665,7 +2740,7 @@ func (x *SetOperation_Child) 
String() string { func (*SetOperation_Child) ProtoMessage() {} func (x *SetOperation_Child) ProtoReflect() protoreflect.Message { - mi := &file_core_v1_core_proto_msgTypes[37] + mi := &file_core_v1_core_proto_msgTypes[38] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2821,7 +2896,7 @@ type SetOperation_Child_This struct { func (x *SetOperation_Child_This) Reset() { *x = SetOperation_Child_This{} - mi := &file_core_v1_core_proto_msgTypes[38] + mi := &file_core_v1_core_proto_msgTypes[39] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2833,7 +2908,7 @@ func (x *SetOperation_Child_This) String() string { func (*SetOperation_Child_This) ProtoMessage() {} func (x *SetOperation_Child_This) ProtoReflect() protoreflect.Message { - mi := &file_core_v1_core_proto_msgTypes[38] + mi := &file_core_v1_core_proto_msgTypes[39] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2857,7 +2932,7 @@ type SetOperation_Child_Nil struct { func (x *SetOperation_Child_Nil) Reset() { *x = SetOperation_Child_Nil{} - mi := &file_core_v1_core_proto_msgTypes[39] + mi := &file_core_v1_core_proto_msgTypes[40] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2869,7 +2944,7 @@ func (x *SetOperation_Child_Nil) String() string { func (*SetOperation_Child_Nil) ProtoMessage() {} func (x *SetOperation_Child_Nil) ProtoReflect() protoreflect.Message { - mi := &file_core_v1_core_proto_msgTypes[39] + mi := &file_core_v1_core_proto_msgTypes[40] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2894,7 +2969,7 @@ type SetOperation_Child_Self struct { func (x *SetOperation_Child_Self) Reset() { *x = SetOperation_Child_Self{} - mi := &file_core_v1_core_proto_msgTypes[40] + mi := &file_core_v1_core_proto_msgTypes[41] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) 
ms.StoreMessageInfo(mi) } @@ -2906,7 +2981,7 @@ func (x *SetOperation_Child_Self) String() string { func (*SetOperation_Child_Self) ProtoMessage() {} func (x *SetOperation_Child_Self) ProtoReflect() protoreflect.Message { - mi := &file_core_v1_core_proto_msgTypes[40] + mi := &file_core_v1_core_proto_msgTypes[41] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2931,7 +3006,7 @@ type TupleToUserset_Tupleset struct { func (x *TupleToUserset_Tupleset) Reset() { *x = TupleToUserset_Tupleset{} - mi := &file_core_v1_core_proto_msgTypes[41] + mi := &file_core_v1_core_proto_msgTypes[42] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2943,7 +3018,7 @@ func (x *TupleToUserset_Tupleset) String() string { func (*TupleToUserset_Tupleset) ProtoMessage() {} func (x *TupleToUserset_Tupleset) ProtoReflect() protoreflect.Message { - mi := &file_core_v1_core_proto_msgTypes[41] + mi := &file_core_v1_core_proto_msgTypes[42] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2975,7 +3050,7 @@ type FunctionedTupleToUserset_Tupleset struct { func (x *FunctionedTupleToUserset_Tupleset) Reset() { *x = FunctionedTupleToUserset_Tupleset{} - mi := &file_core_v1_core_proto_msgTypes[42] + mi := &file_core_v1_core_proto_msgTypes[43] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2987,7 +3062,7 @@ func (x *FunctionedTupleToUserset_Tupleset) String() string { func (*FunctionedTupleToUserset_Tupleset) ProtoMessage() {} func (x *FunctionedTupleToUserset_Tupleset) ProtoReflect() protoreflect.Message { - mi := &file_core_v1_core_proto_msgTypes[42] + mi := &file_core_v1_core_proto_msgTypes[43] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3019,7 +3094,7 @@ type SubjectFilter_RelationFilter struct { func (x *SubjectFilter_RelationFilter) Reset() { *x = 
SubjectFilter_RelationFilter{} - mi := &file_core_v1_core_proto_msgTypes[43] + mi := &file_core_v1_core_proto_msgTypes[44] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3031,7 +3106,7 @@ func (x *SubjectFilter_RelationFilter) String() string { func (*SubjectFilter_RelationFilter) ProtoMessage() {} func (x *SubjectFilter_RelationFilter) ProtoReflect() protoreflect.Message { - mi := &file_core_v1_core_proto_msgTypes[43] + mi := &file_core_v1_core_proto_msgTypes[44] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3054,6 +3129,80 @@ func (x *SubjectFilter_RelationFilter) GetRelation() string { return "" } +type StoredSchema_V1StoredSchema struct { + state protoimpl.MessageState `protogen:"open.v1"` + // schema_text is the text of the schema that was given to SpiceDB by the API caller. + SchemaText string `protobuf:"bytes,1,opt,name=schema_text,json=schemaText,proto3" json:"schema_text,omitempty"` + // schema_hash is the hash of the schema, for change detection. + SchemaHash string `protobuf:"bytes,2,opt,name=schema_hash,json=schemaHash,proto3" json:"schema_hash,omitempty"` + // namespace_definitions is a map of namespace name to NamespaceDefinition. + // Entries must have full metadata filled out. + NamespaceDefinitions map[string]*NamespaceDefinition `protobuf:"bytes,3,rep,name=namespace_definitions,json=namespaceDefinitions,proto3" json:"namespace_definitions,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` + // caveat_definitions is a map of caveat name to CaveatDefinition. + // Entries must have full metadata filled out. 
+ CaveatDefinitions map[string]*CaveatDefinition `protobuf:"bytes,4,rep,name=caveat_definitions,json=caveatDefinitions,proto3" json:"caveat_definitions,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *StoredSchema_V1StoredSchema) Reset() { + *x = StoredSchema_V1StoredSchema{} + mi := &file_core_v1_core_proto_msgTypes[45] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *StoredSchema_V1StoredSchema) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StoredSchema_V1StoredSchema) ProtoMessage() {} + +func (x *StoredSchema_V1StoredSchema) ProtoReflect() protoreflect.Message { + mi := &file_core_v1_core_proto_msgTypes[45] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StoredSchema_V1StoredSchema.ProtoReflect.Descriptor instead. 
+func (*StoredSchema_V1StoredSchema) Descriptor() ([]byte, []int) { + return file_core_v1_core_proto_rawDescGZIP(), []int{33, 0} +} + +func (x *StoredSchema_V1StoredSchema) GetSchemaText() string { + if x != nil { + return x.SchemaText + } + return "" +} + +func (x *StoredSchema_V1StoredSchema) GetSchemaHash() string { + if x != nil { + return x.SchemaHash + } + return "" +} + +func (x *StoredSchema_V1StoredSchema) GetNamespaceDefinitions() map[string]*NamespaceDefinition { + if x != nil { + return x.NamespaceDefinitions + } + return nil +} + +func (x *StoredSchema_V1StoredSchema) GetCaveatDefinitions() map[string]*CaveatDefinition { + if x != nil { + return x.CaveatDefinitions + } + return nil +} + var File_core_v1_core_proto protoreflect.FileDescriptor const file_core_v1_core_proto_rawDesc = "" + @@ -3260,7 +3409,24 @@ const file_core_v1_core_proto_rawDesc = "" + "\x13optional_subject_id\x18\x02 \x01(\tB*\xbaH'r%(\x80\b2 ^(([a-zA-Z0-9/_|\\-=+]{1,})|\\*)?$R\x11optionalSubjectId\x12R\n" + "\x11optional_relation\x18\x03 \x01(\v2%.core.v1.SubjectFilter.RelationFilterR\x10optionalRelation\x1aX\n" + "\x0eRelationFilter\x12F\n" + - "\brelation\x18\x01 \x01(\tB*\xbaH'r%(@2!^([a-z][a-z0-9_]{1,62}[a-z0-9])?$R\brelationB\x8a\x01\n" + + "\brelation\x18\x01 \x01(\tB*\xbaH'r%(@2!^([a-z][a-z0-9_]{1,62}[a-z0-9])?$R\brelation\"\xef\x04\n" + + "\fStoredSchema\x12\x18\n" + + "\aversion\x18\x01 \x01(\rR\aversion\x126\n" + + "\x02v1\x18\x02 \x01(\v2$.core.v1.StoredSchema.V1StoredSchemaH\x00R\x02v1\x1a\xfb\x03\n" + + "\x0eV1StoredSchema\x12\x1f\n" + + "\vschema_text\x18\x01 \x01(\tR\n" + + "schemaText\x12\x1f\n" + + "\vschema_hash\x18\x02 \x01(\tR\n" + + "schemaHash\x12s\n" + + "\x15namespace_definitions\x18\x03 \x03(\v2>.core.v1.StoredSchema.V1StoredSchema.NamespaceDefinitionsEntryR\x14namespaceDefinitions\x12j\n" + + "\x12caveat_definitions\x18\x04 \x03(\v2;.core.v1.StoredSchema.V1StoredSchema.CaveatDefinitionsEntryR\x11caveatDefinitions\x1ae\n" + + 
"\x19NamespaceDefinitionsEntry\x12\x10\n" + + "\x03key\x18\x01 \x01(\tR\x03key\x122\n" + + "\x05value\x18\x02 \x01(\v2\x1c.core.v1.NamespaceDefinitionR\x05value:\x028\x01\x1a_\n" + + "\x16CaveatDefinitionsEntry\x12\x10\n" + + "\x03key\x18\x01 \x01(\tR\x03key\x12/\n" + + "\x05value\x18\x02 \x01(\v2\x19.core.v1.CaveatDefinitionR\x05value:\x028\x01B\x0f\n" + + "\rversion_oneofB\x8a\x01\n" + "\vcom.core.v1B\tCoreProtoP\x01Z3github.com/authzed/spicedb/pkg/proto/core/v1;corev1\xa2\x02\x03CXX\xaa\x02\aCore.V1\xca\x02\aCore\\V1\xe2\x02\x13Core\\V1\\GPBMetadata\xea\x02\bCore::V1b\x06proto3" var ( @@ -3276,7 +3442,7 @@ func file_core_v1_core_proto_rawDescGZIP() []byte { } var file_core_v1_core_proto_enumTypes = make([]protoimpl.EnumInfo, 7) -var file_core_v1_core_proto_msgTypes = make([]protoimpl.MessageInfo, 44) +var file_core_v1_core_proto_msgTypes = make([]protoimpl.MessageInfo, 48) var file_core_v1_core_proto_goTypes = []any{ (RelationTupleUpdate_Operation)(0), // 0: core.v1.RelationTupleUpdate.Operation (SetOperationUserset_Operation)(0), // 1: core.v1.SetOperationUserset.Operation @@ -3318,30 +3484,34 @@ var file_core_v1_core_proto_goTypes = []any{ (*CaveatOperation)(nil), // 37: core.v1.CaveatOperation (*RelationshipFilter)(nil), // 38: core.v1.RelationshipFilter (*SubjectFilter)(nil), // 39: core.v1.SubjectFilter - nil, // 40: core.v1.CaveatDefinition.ParameterTypesEntry - nil, // 41: core.v1.ReachabilityGraph.EntrypointsBySubjectTypeEntry - nil, // 42: core.v1.ReachabilityGraph.EntrypointsBySubjectRelationEntry - (*AllowedRelation_PublicWildcard)(nil), // 43: core.v1.AllowedRelation.PublicWildcard - (*SetOperation_Child)(nil), // 44: core.v1.SetOperation.Child - (*SetOperation_Child_This)(nil), // 45: core.v1.SetOperation.Child.This - (*SetOperation_Child_Nil)(nil), // 46: core.v1.SetOperation.Child.Nil - (*SetOperation_Child_Self)(nil), // 47: core.v1.SetOperation.Child.Self - (*TupleToUserset_Tupleset)(nil), // 48: core.v1.TupleToUserset.Tupleset - 
(*FunctionedTupleToUserset_Tupleset)(nil), // 49: core.v1.FunctionedTupleToUserset.Tupleset - (*SubjectFilter_RelationFilter)(nil), // 50: core.v1.SubjectFilter.RelationFilter - (*timestamppb.Timestamp)(nil), // 51: google.protobuf.Timestamp - (*structpb.Struct)(nil), // 52: google.protobuf.Struct - (*anypb.Any)(nil), // 53: google.protobuf.Any + (*StoredSchema)(nil), // 40: core.v1.StoredSchema + nil, // 41: core.v1.CaveatDefinition.ParameterTypesEntry + nil, // 42: core.v1.ReachabilityGraph.EntrypointsBySubjectTypeEntry + nil, // 43: core.v1.ReachabilityGraph.EntrypointsBySubjectRelationEntry + (*AllowedRelation_PublicWildcard)(nil), // 44: core.v1.AllowedRelation.PublicWildcard + (*SetOperation_Child)(nil), // 45: core.v1.SetOperation.Child + (*SetOperation_Child_This)(nil), // 46: core.v1.SetOperation.Child.This + (*SetOperation_Child_Nil)(nil), // 47: core.v1.SetOperation.Child.Nil + (*SetOperation_Child_Self)(nil), // 48: core.v1.SetOperation.Child.Self + (*TupleToUserset_Tupleset)(nil), // 49: core.v1.TupleToUserset.Tupleset + (*FunctionedTupleToUserset_Tupleset)(nil), // 50: core.v1.FunctionedTupleToUserset.Tupleset + (*SubjectFilter_RelationFilter)(nil), // 51: core.v1.SubjectFilter.RelationFilter + (*StoredSchema_V1StoredSchema)(nil), // 52: core.v1.StoredSchema.V1StoredSchema + nil, // 53: core.v1.StoredSchema.V1StoredSchema.NamespaceDefinitionsEntry + nil, // 54: core.v1.StoredSchema.V1StoredSchema.CaveatDefinitionsEntry + (*timestamppb.Timestamp)(nil), // 55: google.protobuf.Timestamp + (*structpb.Struct)(nil), // 56: google.protobuf.Struct + (*anypb.Any)(nil), // 57: google.protobuf.Any } var file_core_v1_core_proto_depIdxs = []int32{ 12, // 0: core.v1.RelationTuple.resource_and_relation:type_name -> core.v1.ObjectAndRelation 12, // 1: core.v1.RelationTuple.subject:type_name -> core.v1.ObjectAndRelation 9, // 2: core.v1.RelationTuple.caveat:type_name -> core.v1.ContextualizedCaveat 8, // 3: core.v1.RelationTuple.integrity:type_name -> 
core.v1.RelationshipIntegrity - 51, // 4: core.v1.RelationTuple.optional_expiration_time:type_name -> google.protobuf.Timestamp - 51, // 5: core.v1.RelationshipIntegrity.hashed_at:type_name -> google.protobuf.Timestamp - 52, // 6: core.v1.ContextualizedCaveat.context:type_name -> google.protobuf.Struct - 40, // 7: core.v1.CaveatDefinition.parameter_types:type_name -> core.v1.CaveatDefinition.ParameterTypesEntry + 55, // 4: core.v1.RelationTuple.optional_expiration_time:type_name -> google.protobuf.Timestamp + 55, // 5: core.v1.RelationshipIntegrity.hashed_at:type_name -> google.protobuf.Timestamp + 56, // 6: core.v1.ContextualizedCaveat.context:type_name -> google.protobuf.Struct + 41, // 7: core.v1.CaveatDefinition.parameter_types:type_name -> core.v1.CaveatDefinition.ParameterTypesEntry 20, // 8: core.v1.CaveatDefinition.metadata:type_name -> core.v1.Metadata 35, // 9: core.v1.CaveatDefinition.source_position:type_name -> core.v1.SourcePosition 11, // 10: core.v1.CaveatTypeReference.child_types:type_name -> core.v1.CaveatTypeReference @@ -3356,7 +3526,7 @@ var file_core_v1_core_proto_depIdxs = []int32{ 12, // 19: core.v1.DirectSubject.subject:type_name -> core.v1.ObjectAndRelation 36, // 20: core.v1.DirectSubject.caveat_expression:type_name -> core.v1.CaveatExpression 18, // 21: core.v1.DirectSubjects.subjects:type_name -> core.v1.DirectSubject - 53, // 22: core.v1.Metadata.metadata_message:type_name -> google.protobuf.Any + 57, // 22: core.v1.Metadata.metadata_message:type_name -> google.protobuf.Any 22, // 23: core.v1.NamespaceDefinition.relation:type_name -> core.v1.Relation 20, // 24: core.v1.NamespaceDefinition.metadata:type_name -> core.v1.Metadata 35, // 25: core.v1.NamespaceDefinition.source_position:type_name -> core.v1.SourcePosition @@ -3364,15 +3534,15 @@ var file_core_v1_core_proto_depIdxs = []int32{ 26, // 27: core.v1.Relation.type_information:type_name -> core.v1.TypeInformation 20, // 28: core.v1.Relation.metadata:type_name -> core.v1.Metadata 35, 
// 29: core.v1.Relation.source_position:type_name -> core.v1.SourcePosition - 41, // 30: core.v1.ReachabilityGraph.entrypoints_by_subject_type:type_name -> core.v1.ReachabilityGraph.EntrypointsBySubjectTypeEntry - 42, // 31: core.v1.ReachabilityGraph.entrypoints_by_subject_relation:type_name -> core.v1.ReachabilityGraph.EntrypointsBySubjectRelationEntry + 42, // 30: core.v1.ReachabilityGraph.entrypoints_by_subject_type:type_name -> core.v1.ReachabilityGraph.EntrypointsBySubjectTypeEntry + 43, // 31: core.v1.ReachabilityGraph.entrypoints_by_subject_relation:type_name -> core.v1.ReachabilityGraph.EntrypointsBySubjectRelationEntry 25, // 32: core.v1.ReachabilityEntrypoints.entrypoints:type_name -> core.v1.ReachabilityEntrypoint 13, // 33: core.v1.ReachabilityEntrypoints.subject_relation:type_name -> core.v1.RelationReference 2, // 34: core.v1.ReachabilityEntrypoint.kind:type_name -> core.v1.ReachabilityEntrypoint.ReachabilityEntrypointKind 13, // 35: core.v1.ReachabilityEntrypoint.target_relation:type_name -> core.v1.RelationReference 3, // 36: core.v1.ReachabilityEntrypoint.result_status:type_name -> core.v1.ReachabilityEntrypoint.EntrypointResultStatus 27, // 37: core.v1.TypeInformation.allowed_direct_relations:type_name -> core.v1.AllowedRelation - 43, // 38: core.v1.AllowedRelation.public_wildcard:type_name -> core.v1.AllowedRelation.PublicWildcard + 44, // 38: core.v1.AllowedRelation.public_wildcard:type_name -> core.v1.AllowedRelation.PublicWildcard 35, // 39: core.v1.AllowedRelation.source_position:type_name -> core.v1.SourcePosition 29, // 40: core.v1.AllowedRelation.required_caveat:type_name -> core.v1.AllowedCaveat 28, // 41: core.v1.AllowedRelation.required_expiration:type_name -> core.v1.ExpirationTrait @@ -3380,12 +3550,12 @@ var file_core_v1_core_proto_depIdxs = []int32{ 31, // 43: core.v1.UsersetRewrite.intersection:type_name -> core.v1.SetOperation 31, // 44: core.v1.UsersetRewrite.exclusion:type_name -> core.v1.SetOperation 35, // 45: 
core.v1.UsersetRewrite.source_position:type_name -> core.v1.SourcePosition - 44, // 46: core.v1.SetOperation.child:type_name -> core.v1.SetOperation.Child - 48, // 47: core.v1.TupleToUserset.tupleset:type_name -> core.v1.TupleToUserset.Tupleset + 45, // 46: core.v1.SetOperation.child:type_name -> core.v1.SetOperation.Child + 49, // 47: core.v1.TupleToUserset.tupleset:type_name -> core.v1.TupleToUserset.Tupleset 34, // 48: core.v1.TupleToUserset.computed_userset:type_name -> core.v1.ComputedUserset 35, // 49: core.v1.TupleToUserset.source_position:type_name -> core.v1.SourcePosition 4, // 50: core.v1.FunctionedTupleToUserset.function:type_name -> core.v1.FunctionedTupleToUserset.Function - 49, // 51: core.v1.FunctionedTupleToUserset.tupleset:type_name -> core.v1.FunctionedTupleToUserset.Tupleset + 50, // 51: core.v1.FunctionedTupleToUserset.tupleset:type_name -> core.v1.FunctionedTupleToUserset.Tupleset 34, // 52: core.v1.FunctionedTupleToUserset.computed_userset:type_name -> core.v1.ComputedUserset 35, // 53: core.v1.FunctionedTupleToUserset.source_position:type_name -> core.v1.SourcePosition 5, // 54: core.v1.ComputedUserset.object:type_name -> core.v1.ComputedUserset.Object @@ -3395,23 +3565,28 @@ var file_core_v1_core_proto_depIdxs = []int32{ 6, // 58: core.v1.CaveatOperation.op:type_name -> core.v1.CaveatOperation.Operation 36, // 59: core.v1.CaveatOperation.children:type_name -> core.v1.CaveatExpression 39, // 60: core.v1.RelationshipFilter.optional_subject_filter:type_name -> core.v1.SubjectFilter - 50, // 61: core.v1.SubjectFilter.optional_relation:type_name -> core.v1.SubjectFilter.RelationFilter - 11, // 62: core.v1.CaveatDefinition.ParameterTypesEntry.value:type_name -> core.v1.CaveatTypeReference - 24, // 63: core.v1.ReachabilityGraph.EntrypointsBySubjectTypeEntry.value:type_name -> core.v1.ReachabilityEntrypoints - 24, // 64: core.v1.ReachabilityGraph.EntrypointsBySubjectRelationEntry.value:type_name -> core.v1.ReachabilityEntrypoints - 45, // 65: 
core.v1.SetOperation.Child._this:type_name -> core.v1.SetOperation.Child.This - 34, // 66: core.v1.SetOperation.Child.computed_userset:type_name -> core.v1.ComputedUserset - 32, // 67: core.v1.SetOperation.Child.tuple_to_userset:type_name -> core.v1.TupleToUserset - 30, // 68: core.v1.SetOperation.Child.userset_rewrite:type_name -> core.v1.UsersetRewrite - 33, // 69: core.v1.SetOperation.Child.functioned_tuple_to_userset:type_name -> core.v1.FunctionedTupleToUserset - 46, // 70: core.v1.SetOperation.Child._nil:type_name -> core.v1.SetOperation.Child.Nil - 47, // 71: core.v1.SetOperation.Child._self:type_name -> core.v1.SetOperation.Child.Self - 35, // 72: core.v1.SetOperation.Child.source_position:type_name -> core.v1.SourcePosition - 73, // [73:73] is the sub-list for method output_type - 73, // [73:73] is the sub-list for method input_type - 73, // [73:73] is the sub-list for extension type_name - 73, // [73:73] is the sub-list for extension extendee - 0, // [0:73] is the sub-list for field type_name + 51, // 61: core.v1.SubjectFilter.optional_relation:type_name -> core.v1.SubjectFilter.RelationFilter + 52, // 62: core.v1.StoredSchema.v1:type_name -> core.v1.StoredSchema.V1StoredSchema + 11, // 63: core.v1.CaveatDefinition.ParameterTypesEntry.value:type_name -> core.v1.CaveatTypeReference + 24, // 64: core.v1.ReachabilityGraph.EntrypointsBySubjectTypeEntry.value:type_name -> core.v1.ReachabilityEntrypoints + 24, // 65: core.v1.ReachabilityGraph.EntrypointsBySubjectRelationEntry.value:type_name -> core.v1.ReachabilityEntrypoints + 46, // 66: core.v1.SetOperation.Child._this:type_name -> core.v1.SetOperation.Child.This + 34, // 67: core.v1.SetOperation.Child.computed_userset:type_name -> core.v1.ComputedUserset + 32, // 68: core.v1.SetOperation.Child.tuple_to_userset:type_name -> core.v1.TupleToUserset + 30, // 69: core.v1.SetOperation.Child.userset_rewrite:type_name -> core.v1.UsersetRewrite + 33, // 70: 
core.v1.SetOperation.Child.functioned_tuple_to_userset:type_name -> core.v1.FunctionedTupleToUserset + 47, // 71: core.v1.SetOperation.Child._nil:type_name -> core.v1.SetOperation.Child.Nil + 48, // 72: core.v1.SetOperation.Child._self:type_name -> core.v1.SetOperation.Child.Self + 35, // 73: core.v1.SetOperation.Child.source_position:type_name -> core.v1.SourcePosition + 53, // 74: core.v1.StoredSchema.V1StoredSchema.namespace_definitions:type_name -> core.v1.StoredSchema.V1StoredSchema.NamespaceDefinitionsEntry + 54, // 75: core.v1.StoredSchema.V1StoredSchema.caveat_definitions:type_name -> core.v1.StoredSchema.V1StoredSchema.CaveatDefinitionsEntry + 21, // 76: core.v1.StoredSchema.V1StoredSchema.NamespaceDefinitionsEntry.value:type_name -> core.v1.NamespaceDefinition + 10, // 77: core.v1.StoredSchema.V1StoredSchema.CaveatDefinitionsEntry.value:type_name -> core.v1.CaveatDefinition + 78, // [78:78] is the sub-list for method output_type + 78, // [78:78] is the sub-list for method input_type + 78, // [78:78] is the sub-list for extension type_name + 78, // [78:78] is the sub-list for extension extendee + 0, // [0:78] is the sub-list for field type_name } func init() { file_core_v1_core_proto_init() } @@ -3436,7 +3611,10 @@ func file_core_v1_core_proto_init() { (*CaveatExpression_Operation)(nil), (*CaveatExpression_Caveat)(nil), } - file_core_v1_core_proto_msgTypes[37].OneofWrappers = []any{ + file_core_v1_core_proto_msgTypes[33].OneofWrappers = []any{ + (*StoredSchema_V1)(nil), + } + file_core_v1_core_proto_msgTypes[38].OneofWrappers = []any{ (*SetOperation_Child_XThis)(nil), (*SetOperation_Child_ComputedUserset)(nil), (*SetOperation_Child_TupleToUserset)(nil), @@ -3451,7 +3629,7 @@ func file_core_v1_core_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_core_v1_core_proto_rawDesc), len(file_core_v1_core_proto_rawDesc)), NumEnums: 7, - NumMessages: 44, + NumMessages: 48, NumExtensions: 0, NumServices: 
0, }, diff --git a/pkg/proto/core/v1/core_vtproto.pb.go b/pkg/proto/core/v1/core_vtproto.pb.go index 1a90dbb78..ea58868c5 100644 --- a/pkg/proto/core/v1/core_vtproto.pb.go +++ b/pkg/proto/core/v1/core_vtproto.pb.go @@ -1023,6 +1023,69 @@ func (m *SubjectFilter) CloneMessageVT() proto.Message { return m.CloneVT() } +func (m *StoredSchema_V1StoredSchema) CloneVT() *StoredSchema_V1StoredSchema { + if m == nil { + return (*StoredSchema_V1StoredSchema)(nil) + } + r := new(StoredSchema_V1StoredSchema) + r.SchemaText = m.SchemaText + r.SchemaHash = m.SchemaHash + if rhs := m.NamespaceDefinitions; rhs != nil { + tmpContainer := make(map[string]*NamespaceDefinition, len(rhs)) + for k, v := range rhs { + tmpContainer[k] = v.CloneVT() + } + r.NamespaceDefinitions = tmpContainer + } + if rhs := m.CaveatDefinitions; rhs != nil { + tmpContainer := make(map[string]*CaveatDefinition, len(rhs)) + for k, v := range rhs { + tmpContainer[k] = v.CloneVT() + } + r.CaveatDefinitions = tmpContainer + } + if len(m.unknownFields) > 0 { + r.unknownFields = make([]byte, len(m.unknownFields)) + copy(r.unknownFields, m.unknownFields) + } + return r +} + +func (m *StoredSchema_V1StoredSchema) CloneMessageVT() proto.Message { + return m.CloneVT() +} + +func (m *StoredSchema) CloneVT() *StoredSchema { + if m == nil { + return (*StoredSchema)(nil) + } + r := new(StoredSchema) + r.Version = m.Version + if m.VersionOneof != nil { + r.VersionOneof = m.VersionOneof.(interface { + CloneVT() isStoredSchema_VersionOneof + }).CloneVT() + } + if len(m.unknownFields) > 0 { + r.unknownFields = make([]byte, len(m.unknownFields)) + copy(r.unknownFields, m.unknownFields) + } + return r +} + +func (m *StoredSchema) CloneMessageVT() proto.Message { + return m.CloneVT() +} + +func (m *StoredSchema_V1) CloneVT() isStoredSchema_VersionOneof { + if m == nil { + return (*StoredSchema_V1)(nil) + } + r := new(StoredSchema_V1) + r.V1 = m.V1.CloneVT() + return r +} + func (this *RelationTuple) EqualVT(that *RelationTuple) 
bool { if this == that { return true @@ -2581,6 +2644,124 @@ func (this *SubjectFilter) EqualMessageVT(thatMsg proto.Message) bool { } return this.EqualVT(that) } +func (this *StoredSchema_V1StoredSchema) EqualVT(that *StoredSchema_V1StoredSchema) bool { + if this == that { + return true + } else if this == nil || that == nil { + return false + } + if this.SchemaText != that.SchemaText { + return false + } + if this.SchemaHash != that.SchemaHash { + return false + } + if len(this.NamespaceDefinitions) != len(that.NamespaceDefinitions) { + return false + } + for i, vx := range this.NamespaceDefinitions { + vy, ok := that.NamespaceDefinitions[i] + if !ok { + return false + } + if p, q := vx, vy; p != q { + if p == nil { + p = &NamespaceDefinition{} + } + if q == nil { + q = &NamespaceDefinition{} + } + if !p.EqualVT(q) { + return false + } + } + } + if len(this.CaveatDefinitions) != len(that.CaveatDefinitions) { + return false + } + for i, vx := range this.CaveatDefinitions { + vy, ok := that.CaveatDefinitions[i] + if !ok { + return false + } + if p, q := vx, vy; p != q { + if p == nil { + p = &CaveatDefinition{} + } + if q == nil { + q = &CaveatDefinition{} + } + if !p.EqualVT(q) { + return false + } + } + } + return string(this.unknownFields) == string(that.unknownFields) +} + +func (this *StoredSchema_V1StoredSchema) EqualMessageVT(thatMsg proto.Message) bool { + that, ok := thatMsg.(*StoredSchema_V1StoredSchema) + if !ok { + return false + } + return this.EqualVT(that) +} +func (this *StoredSchema) EqualVT(that *StoredSchema) bool { + if this == that { + return true + } else if this == nil || that == nil { + return false + } + if this.VersionOneof == nil && that.VersionOneof != nil { + return false + } else if this.VersionOneof != nil { + if that.VersionOneof == nil { + return false + } + if !this.VersionOneof.(interface { + EqualVT(isStoredSchema_VersionOneof) bool + }).EqualVT(that.VersionOneof) { + return false + } + } + if this.Version != that.Version { + 
return false + } + return string(this.unknownFields) == string(that.unknownFields) +} + +func (this *StoredSchema) EqualMessageVT(thatMsg proto.Message) bool { + that, ok := thatMsg.(*StoredSchema) + if !ok { + return false + } + return this.EqualVT(that) +} +func (this *StoredSchema_V1) EqualVT(thatIface isStoredSchema_VersionOneof) bool { + that, ok := thatIface.(*StoredSchema_V1) + if !ok { + return false + } + if this == that { + return true + } + if this == nil && that != nil || this != nil && that == nil { + return false + } + if p, q := this.V1, that.V1; p != q { + if p == nil { + p = &StoredSchema_V1StoredSchema{} + } + if q == nil { + q = &StoredSchema_V1StoredSchema{} + } + if !p.EqualVT(q) { + return false + } + } + return true +} + func (m *RelationTuple) MarshalVT() (dAtA []byte, err error) { if m == nil { return nil, nil @@ -5154,85 +5335,246 @@ func (m *SubjectFilter) MarshalToSizedBufferVT(dAtA []byte) (int, error) { return len(dAtA) - i, nil } -func (m *RelationTuple) SizeVT() (n int) { +func (m *StoredSchema_V1StoredSchema) MarshalVT() (dAtA []byte, err error) { if m == nil { - return 0 - } - var l int - _ = l - if m.ResourceAndRelation != nil { - l = m.ResourceAndRelation.SizeVT() - n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) - } - if m.Subject != nil { - l = m.Subject.SizeVT() - n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) - } - if m.Caveat != nil { - l = m.Caveat.SizeVT() - n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) - } - if m.Integrity != nil { - l = m.Integrity.SizeVT() - n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + return nil, nil } - if m.OptionalExpirationTime != nil { - l = (*timestamppb1.Timestamp)(m.OptionalExpirationTime).SizeVT() - n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + size := m.SizeVT() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBufferVT(dAtA[:size]) + if err != nil { + return nil, err } - n += len(m.unknownFields) - return n + return dAtA[:n], nil } -func (m *RelationshipIntegrity) 
SizeVT() (n int) { +func (m *StoredSchema_V1StoredSchema) MarshalToVT(dAtA []byte) (int, error) { + size := m.SizeVT() + return m.MarshalToSizedBufferVT(dAtA[:size]) +} + +func (m *StoredSchema_V1StoredSchema) MarshalToSizedBufferVT(dAtA []byte) (int, error) { if m == nil { - return 0 + return 0, nil } + i := len(dAtA) + _ = i var l int _ = l - l = len(m.KeyId) - if l > 0 { - n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + if m.unknownFields != nil { + i -= len(m.unknownFields) + copy(dAtA[i:], m.unknownFields) } - l = len(m.Hash) - if l > 0 { - n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + if len(m.CaveatDefinitions) > 0 { + for k := range m.CaveatDefinitions { + v := m.CaveatDefinitions[k] + baseI := i + size, err := v.MarshalToSizedBufferVT(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) + i-- + dAtA[i] = 0x12 + i -= len(k) + copy(dAtA[i:], k) + i = protohelpers.EncodeVarint(dAtA, i, uint64(len(k))) + i-- + dAtA[i] = 0xa + i = protohelpers.EncodeVarint(dAtA, i, uint64(baseI-i)) + i-- + dAtA[i] = 0x22 + } } - if m.HashedAt != nil { - l = (*timestamppb1.Timestamp)(m.HashedAt).SizeVT() - n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + if len(m.NamespaceDefinitions) > 0 { + for k := range m.NamespaceDefinitions { + v := m.NamespaceDefinitions[k] + baseI := i + size, err := v.MarshalToSizedBufferVT(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) + i-- + dAtA[i] = 0x12 + i -= len(k) + copy(dAtA[i:], k) + i = protohelpers.EncodeVarint(dAtA, i, uint64(len(k))) + i-- + dAtA[i] = 0xa + i = protohelpers.EncodeVarint(dAtA, i, uint64(baseI-i)) + i-- + dAtA[i] = 0x1a + } } - n += len(m.unknownFields) - return n + if len(m.SchemaHash) > 0 { + i -= len(m.SchemaHash) + copy(dAtA[i:], m.SchemaHash) + i = protohelpers.EncodeVarint(dAtA, i, uint64(len(m.SchemaHash))) + i-- + dAtA[i] = 0x12 + } + if len(m.SchemaText) > 0 { + i -= 
len(m.SchemaText) + copy(dAtA[i:], m.SchemaText) + i = protohelpers.EncodeVarint(dAtA, i, uint64(len(m.SchemaText))) + i-- + dAtA[i] = 0xa + } + return len(dAtA) - i, nil } -func (m *ContextualizedCaveat) SizeVT() (n int) { +func (m *StoredSchema) MarshalVT() (dAtA []byte, err error) { if m == nil { - return 0 - } - var l int - _ = l - l = len(m.CaveatName) - if l > 0 { - n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + return nil, nil } - if m.Context != nil { - l = (*structpb1.Struct)(m.Context).SizeVT() - n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + size := m.SizeVT() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBufferVT(dAtA[:size]) + if err != nil { + return nil, err } - n += len(m.unknownFields) - return n + return dAtA[:n], nil } -func (m *CaveatDefinition) SizeVT() (n int) { +func (m *StoredSchema) MarshalToVT(dAtA []byte) (int, error) { + size := m.SizeVT() + return m.MarshalToSizedBufferVT(dAtA[:size]) +} + +func (m *StoredSchema) MarshalToSizedBufferVT(dAtA []byte) (int, error) { if m == nil { - return 0 + return 0, nil } + i := len(dAtA) + _ = i var l int _ = l - l = len(m.Name) - if l > 0 { - n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + if m.unknownFields != nil { + i -= len(m.unknownFields) + copy(dAtA[i:], m.unknownFields) + } + if vtmsg, ok := m.VersionOneof.(interface { + MarshalToSizedBufferVT([]byte) (int, error) + }); ok { + size, err := vtmsg.MarshalToSizedBufferVT(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + } + if m.Version != 0 { + i = protohelpers.EncodeVarint(dAtA, i, uint64(m.Version)) + i-- + dAtA[i] = 0x8 + } + return len(dAtA) - i, nil +} + +func (m *StoredSchema_V1) MarshalToVT(dAtA []byte) (int, error) { + size := m.SizeVT() + return m.MarshalToSizedBufferVT(dAtA[:size]) +} + +func (m *StoredSchema_V1) MarshalToSizedBufferVT(dAtA []byte) (int, error) { + i := len(dAtA) + if m.V1 != nil { + size, err := m.V1.MarshalToSizedBufferVT(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size 
+ i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) + i-- + dAtA[i] = 0x12 + } else { + i = protohelpers.EncodeVarint(dAtA, i, 0) + i-- + dAtA[i] = 0x12 + } + return len(dAtA) - i, nil +} +func (m *RelationTuple) SizeVT() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if m.ResourceAndRelation != nil { + l = m.ResourceAndRelation.SizeVT() + n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } + if m.Subject != nil { + l = m.Subject.SizeVT() + n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } + if m.Caveat != nil { + l = m.Caveat.SizeVT() + n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } + if m.Integrity != nil { + l = m.Integrity.SizeVT() + n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } + if m.OptionalExpirationTime != nil { + l = (*timestamppb1.Timestamp)(m.OptionalExpirationTime).SizeVT() + n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } + n += len(m.unknownFields) + return n +} + +func (m *RelationshipIntegrity) SizeVT() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + l = len(m.KeyId) + if l > 0 { + n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } + l = len(m.Hash) + if l > 0 { + n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } + if m.HashedAt != nil { + l = (*timestamppb1.Timestamp)(m.HashedAt).SizeVT() + n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } + n += len(m.unknownFields) + return n +} + +func (m *ContextualizedCaveat) SizeVT() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + l = len(m.CaveatName) + if l > 0 { + n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } + if m.Context != nil { + l = (*structpb1.Struct)(m.Context).SizeVT() + n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } + n += len(m.unknownFields) + return n +} + +func (m *CaveatDefinition) SizeVT() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + l = len(m.Name) + if l > 0 { + n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) } l = len(m.SerializedExpression) if l > 0 { @@ -6194,6 
+6536,80 @@ func (m *SubjectFilter) SizeVT() (n int) { return n } +func (m *StoredSchema_V1StoredSchema) SizeVT() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + l = len(m.SchemaText) + if l > 0 { + n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } + l = len(m.SchemaHash) + if l > 0 { + n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } + if len(m.NamespaceDefinitions) > 0 { + for k, v := range m.NamespaceDefinitions { + _ = k + _ = v + l = 0 + if v != nil { + l = v.SizeVT() + } + l += 1 + protohelpers.SizeOfVarint(uint64(l)) + mapEntrySize := 1 + len(k) + protohelpers.SizeOfVarint(uint64(len(k))) + l + n += mapEntrySize + 1 + protohelpers.SizeOfVarint(uint64(mapEntrySize)) + } + } + if len(m.CaveatDefinitions) > 0 { + for k, v := range m.CaveatDefinitions { + _ = k + _ = v + l = 0 + if v != nil { + l = v.SizeVT() + } + l += 1 + protohelpers.SizeOfVarint(uint64(l)) + mapEntrySize := 1 + len(k) + protohelpers.SizeOfVarint(uint64(len(k))) + l + n += mapEntrySize + 1 + protohelpers.SizeOfVarint(uint64(mapEntrySize)) + } + } + n += len(m.unknownFields) + return n +} + +func (m *StoredSchema) SizeVT() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if m.Version != 0 { + n += 1 + protohelpers.SizeOfVarint(uint64(m.Version)) + } + if vtmsg, ok := m.VersionOneof.(interface{ SizeVT() int }); ok { + n += vtmsg.SizeVT() + } + n += len(m.unknownFields) + return n +} + +func (m *StoredSchema_V1) SizeVT() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if m.V1 != nil { + l = m.V1.SizeVT() + n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } else { + n += 2 + } + return n +} func (m *RelationTuple) UnmarshalVT(dAtA []byte) error { l := len(dAtA) iNdEx := 0 @@ -12164,3 +12580,487 @@ func (m *SubjectFilter) UnmarshalVT(dAtA []byte) error { } return nil } +func (m *StoredSchema_V1StoredSchema) UnmarshalVT(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := 
uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: StoredSchema_V1StoredSchema: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: StoredSchema_V1StoredSchema: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field SchemaText", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return protohelpers.ErrInvalidLength + } + postIndex := iNdEx + intStringLen + if postIndex < 0 { + return protohelpers.ErrInvalidLength + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.SchemaText = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field SchemaHash", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return protohelpers.ErrInvalidLength + } + postIndex := iNdEx + intStringLen + if postIndex < 0 { + return protohelpers.ErrInvalidLength + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.SchemaHash = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 3: + if 
wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field NamespaceDefinitions", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return protohelpers.ErrInvalidLength + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return protohelpers.ErrInvalidLength + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + if m.NamespaceDefinitions == nil { + m.NamespaceDefinitions = make(map[string]*NamespaceDefinition) + } + var mapkey string + var mapvalue *NamespaceDefinition + for iNdEx < postIndex { + entryPreIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + if fieldNum == 1 { + var stringLenmapkey uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLenmapkey |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + intStringLenmapkey := int(stringLenmapkey) + if intStringLenmapkey < 0 { + return protohelpers.ErrInvalidLength + } + postStringIndexmapkey := iNdEx + intStringLenmapkey + if postStringIndexmapkey < 0 { + return protohelpers.ErrInvalidLength + } + if postStringIndexmapkey > l { + return io.ErrUnexpectedEOF + } + mapkey = string(dAtA[iNdEx:postStringIndexmapkey]) + iNdEx = postStringIndexmapkey + } else if fieldNum == 2 { + var mapmsglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := 
dAtA[iNdEx] + iNdEx++ + mapmsglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if mapmsglen < 0 { + return protohelpers.ErrInvalidLength + } + postmsgIndex := iNdEx + mapmsglen + if postmsgIndex < 0 { + return protohelpers.ErrInvalidLength + } + if postmsgIndex > l { + return io.ErrUnexpectedEOF + } + mapvalue = &NamespaceDefinition{} + if err := mapvalue.UnmarshalVT(dAtA[iNdEx:postmsgIndex]); err != nil { + return err + } + iNdEx = postmsgIndex + } else { + iNdEx = entryPreIndex + skippy, err := protohelpers.Skip(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return protohelpers.ErrInvalidLength + } + if (iNdEx + skippy) > postIndex { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + m.NamespaceDefinitions[mapkey] = mapvalue + iNdEx = postIndex + case 4: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field CaveatDefinitions", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return protohelpers.ErrInvalidLength + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return protohelpers.ErrInvalidLength + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + if m.CaveatDefinitions == nil { + m.CaveatDefinitions = make(map[string]*CaveatDefinition) + } + var mapkey string + var mapvalue *CaveatDefinition + for iNdEx < postIndex { + entryPreIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + if fieldNum == 1 { + var stringLenmapkey uint64 + for shift := uint(0); ; shift += 
7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLenmapkey |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + intStringLenmapkey := int(stringLenmapkey) + if intStringLenmapkey < 0 { + return protohelpers.ErrInvalidLength + } + postStringIndexmapkey := iNdEx + intStringLenmapkey + if postStringIndexmapkey < 0 { + return protohelpers.ErrInvalidLength + } + if postStringIndexmapkey > l { + return io.ErrUnexpectedEOF + } + mapkey = string(dAtA[iNdEx:postStringIndexmapkey]) + iNdEx = postStringIndexmapkey + } else if fieldNum == 2 { + var mapmsglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + mapmsglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if mapmsglen < 0 { + return protohelpers.ErrInvalidLength + } + postmsgIndex := iNdEx + mapmsglen + if postmsgIndex < 0 { + return protohelpers.ErrInvalidLength + } + if postmsgIndex > l { + return io.ErrUnexpectedEOF + } + mapvalue = &CaveatDefinition{} + if err := mapvalue.UnmarshalVT(dAtA[iNdEx:postmsgIndex]); err != nil { + return err + } + iNdEx = postmsgIndex + } else { + iNdEx = entryPreIndex + skippy, err := protohelpers.Skip(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return protohelpers.ErrInvalidLength + } + if (iNdEx + skippy) > postIndex { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + m.CaveatDefinitions[mapkey] = mapvalue + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := protohelpers.Skip(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return protohelpers.ErrInvalidLength + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + m.unknownFields = append(m.unknownFields, dAtA[iNdEx:iNdEx+skippy]...) 
+ iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *StoredSchema) UnmarshalVT(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: StoredSchema: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: StoredSchema: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field Version", wireType) + } + m.Version = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.Version |= uint32(b&0x7F) << shift + if b < 0x80 { + break + } + } + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field V1", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return protohelpers.ErrInvalidLength + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return protohelpers.ErrInvalidLength + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + if oneof, ok := m.VersionOneof.(*StoredSchema_V1); ok { + if err := oneof.V1.UnmarshalVT(dAtA[iNdEx:postIndex]); err != nil { + return err + } + } else { + v := &StoredSchema_V1StoredSchema{} + if err := v.UnmarshalVT(dAtA[iNdEx:postIndex]); err != nil { + 
return err + } + m.VersionOneof = &StoredSchema_V1{V1: v} + } + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := protohelpers.Skip(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return protohelpers.ErrInvalidLength + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + m.unknownFields = append(m.unknownFields, dAtA[iNdEx:iNdEx+skippy]...) + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} diff --git a/pkg/proto/dispatch/v1/02_resolvermeta_test.go b/pkg/proto/dispatch/v1/02_resolvermeta_test.go index d689703f1..da9b8efe5 100644 --- a/pkg/proto/dispatch/v1/02_resolvermeta_test.go +++ b/pkg/proto/dispatch/v1/02_resolvermeta_test.go @@ -5,6 +5,8 @@ import ( "github.com/bits-and-blooms/bloom/v3" "github.com/stretchr/testify/require" + + "github.com/authzed/spicedb/pkg/datalayer" ) func TestRecordTraversal(t *testing.T) { @@ -12,15 +14,23 @@ func TestRecordTraversal(t *testing.T) { _, err := rm.RecordTraversal("test") require.ErrorContains(t, err, "missing") - rm = &ResolverMeta{} + rm = &ResolverMeta{ + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), + } _, err = rm.RecordTraversal("test") require.ErrorContains(t, err, "missing") - rm = &ResolverMeta{TraversalBloom: []byte("")} + rm = &ResolverMeta{ + TraversalBloom: []byte(""), + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), + } _, err = rm.RecordTraversal("test") require.ErrorContains(t, err, "missing") - rm = &ResolverMeta{TraversalBloom: []byte("foo")} + rm = &ResolverMeta{ + TraversalBloom: []byte("foo"), + SchemaHash: []byte(datalayer.NoSchemaHashForTesting), + } _, err = rm.RecordTraversal("test") require.ErrorContains(t, err, "unmarshall") diff --git a/pkg/proto/dispatch/v1/dispatch.pb.go b/pkg/proto/dispatch/v1/dispatch.pb.go index 48c5a7f73..9641e20f1 100644 --- a/pkg/proto/dispatch/v1/dispatch.pb.go +++ b/pkg/proto/dispatch/v1/dispatch.pb.go @@ -1372,6 +1372,7 @@ type ResolverMeta 
struct { // Deprecated: Marked as deprecated in dispatch/v1/dispatch.proto. RequestId string `protobuf:"bytes,3,opt,name=request_id,json=requestId,proto3" json:"request_id,omitempty"` TraversalBloom []byte `protobuf:"bytes,4,opt,name=traversal_bloom,json=traversalBloom,proto3" json:"traversal_bloom,omitempty"` + SchemaHash []byte `protobuf:"bytes,5,opt,name=schema_hash,json=schemaHash,proto3" json:"schema_hash,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -1435,6 +1436,13 @@ func (x *ResolverMeta) GetTraversalBloom() []byte { return nil } +func (x *ResolverMeta) GetSchemaHash() []byte { + if x != nil { + return x.SchemaHash + } + return nil +} + type ResponseMeta struct { state protoimpl.MessageState `protogen:"open.v1"` DispatchCount uint32 `protobuf:"varint,1,opt,name=dispatch_count,json=dispatchCount,proto3" json:"dispatch_count,omitempty"` @@ -1762,14 +1770,16 @@ const file_dispatch_v1_dispatch_proto_rawDesc = "" + "\bmetadata\x18\x02 \x01(\v2\x19.dispatch.v1.ResponseMetaR\bmetadata\x1ah\n" + "\x1eFoundSubjectsByResourceIdEntry\x12\x10\n" + "\x03key\x18\x01 \x01(\tR\x03key\x120\n" + - "\x05value\x18\x02 \x01(\v2\x1a.dispatch.v1.FoundSubjectsR\x05value:\x028\x01\"\xb7\x01\n" + + "\x05value\x18\x02 \x01(\v2\x1a.dispatch.v1.FoundSubjectsR\x05value:\x028\x01\"\xd8\x01\n" + "\fResolverMeta\x12\x1f\n" + "\vat_revision\x18\x01 \x01(\tR\n" + "atRevision\x120\n" + "\x0fdepth_remaining\x18\x02 \x01(\rB\a\xbaH\x04*\x02 \x00R\x0edepthRemaining\x12!\n" + "\n" + "request_id\x18\x03 \x01(\tB\x02\x18\x01R\trequestId\x121\n" + - "\x0ftraversal_bloom\x18\x04 \x01(\fB\b\xbaH\x05z\x03\x18\x80\bR\x0etraversalBloom\"\xda\x01\n" + + "\x0ftraversal_bloom\x18\x04 \x01(\fB\b\xbaH\x05z\x03\x18\x80\bR\x0etraversalBloom\x12\x1f\n" + + "\vschema_hash\x18\x05 \x01(\fR\n" + + "schemaHash\"\xda\x01\n" + "\fResponseMeta\x12%\n" + "\x0edispatch_count\x18\x01 \x01(\rR\rdispatchCount\x12%\n" + "\x0edepth_required\x18\x02 \x01(\rR\rdepthRequired\x122\n" + diff 
--git a/pkg/proto/dispatch/v1/dispatch_vtproto.pb.go b/pkg/proto/dispatch/v1/dispatch_vtproto.pb.go index e03bb3ac8..777ff0248 100644 --- a/pkg/proto/dispatch/v1/dispatch_vtproto.pb.go +++ b/pkg/proto/dispatch/v1/dispatch_vtproto.pb.go @@ -551,6 +551,11 @@ func (m *ResolverMeta) CloneVT() *ResolverMeta { copy(tmpBytes, rhs) r.TraversalBloom = tmpBytes } + if rhs := m.SchemaHash; rhs != nil { + tmpBytes := make([]byte, len(rhs)) + copy(tmpBytes, rhs) + r.SchemaHash = tmpBytes + } if len(m.unknownFields) > 0 { r.unknownFields = make([]byte, len(m.unknownFields)) copy(r.unknownFields, m.unknownFields) @@ -1363,6 +1368,9 @@ func (this *ResolverMeta) EqualVT(that *ResolverMeta) bool { if string(this.TraversalBloom) != string(that.TraversalBloom) { return false } + if string(this.SchemaHash) != string(that.SchemaHash) { + return false + } return string(this.unknownFields) == string(that.unknownFields) } @@ -2847,6 +2855,13 @@ func (m *ResolverMeta) MarshalToSizedBufferVT(dAtA []byte) (int, error) { i -= len(m.unknownFields) copy(dAtA[i:], m.unknownFields) } + if len(m.SchemaHash) > 0 { + i -= len(m.SchemaHash) + copy(dAtA[i:], m.SchemaHash) + i = protohelpers.EncodeVarint(dAtA, i, uint64(len(m.SchemaHash))) + i-- + dAtA[i] = 0x2a + } if len(m.TraversalBloom) > 0 { i -= len(m.TraversalBloom) copy(dAtA[i:], m.TraversalBloom) @@ -3663,6 +3678,10 @@ func (m *ResolverMeta) SizeVT() (n int) { if l > 0 { n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) } + l = len(m.SchemaHash) + if l > 0 { + n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } n += len(m.unknownFields) return n } @@ -7051,6 +7070,40 @@ func (m *ResolverMeta) UnmarshalVT(dAtA []byte) error { m.TraversalBloom = []byte{} } iNdEx = postIndex + case 5: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field SchemaHash", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return 
io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return protohelpers.ErrInvalidLength + } + postIndex := iNdEx + byteLen + if postIndex < 0 { + return protohelpers.ErrInvalidLength + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.SchemaHash = append(m.SchemaHash[:0], dAtA[iNdEx:postIndex]...) + if m.SchemaHash == nil { + m.SchemaHash = []byte{} + } + iNdEx = postIndex default: iNdEx = preIndex skippy, err := protohelpers.Skip(dAtA[iNdEx:]) diff --git a/pkg/proto/impl/v1/impl.pb.go b/pkg/proto/impl/v1/impl.pb.go index 8a2b2eec2..da3c9bba5 100644 --- a/pkg/proto/impl/v1/impl.pb.go +++ b/pkg/proto/impl/v1/impl.pb.go @@ -403,8 +403,11 @@ type V1Cursor struct { // datastore_unique_id is the unique ID for the datastore. Will be empty for legacy // cursors. DatastoreUniqueId string `protobuf:"bytes,6,opt,name=datastore_unique_id,json=datastoreUniqueId,proto3" json:"datastore_unique_id,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + // schema_hash is the hash of the schema at the time the cursor was created. Will be + // empty for legacy cursors. 
+ SchemaHash []byte `protobuf:"bytes,7,opt,name=schema_hash,json=schemaHash,proto3" json:"schema_hash,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *V1Cursor) Reset() { @@ -479,6 +482,13 @@ func (x *V1Cursor) GetDatastoreUniqueId() string { return "" } +func (x *V1Cursor) GetSchemaHash() []byte { + if x != nil { + return x.SchemaHash + } + return nil +} + type DocComment struct { state protoimpl.MessageState `protogen:"open.v1"` Comment string `protobuf:"bytes,1,opt,name=comment,proto3" json:"comment,omitempty"` @@ -932,14 +942,16 @@ const file_impl_v1_impl_proto_rawDesc = "" + "\rversion_oneof\"E\n" + "\rDecodedCursor\x12#\n" + "\x02v1\x18\x01 \x01(\v2\x11.impl.v1.V1CursorH\x00R\x02v1B\x0f\n" + - "\rversion_oneof\"\xc4\x02\n" + + "\rversion_oneof\"\xe5\x02\n" + "\bV1Cursor\x12\x1a\n" + "\brevision\x18\x01 \x01(\tR\brevision\x12\x1a\n" + "\bsections\x18\x02 \x03(\tR\bsections\x127\n" + "\x18call_and_parameters_hash\x18\x03 \x01(\tR\x15callAndParametersHash\x12)\n" + "\x10dispatch_version\x18\x04 \x01(\rR\x0fdispatchVersion\x122\n" + "\x05flags\x18\x05 \x03(\v2\x1c.impl.v1.V1Cursor.FlagsEntryR\x05flags\x12.\n" + - "\x13datastore_unique_id\x18\x06 \x01(\tR\x11datastoreUniqueId\x1a8\n" + + "\x13datastore_unique_id\x18\x06 \x01(\tR\x11datastoreUniqueId\x12\x1f\n" + + "\vschema_hash\x18\a \x01(\fR\n" + + "schemaHash\x1a8\n" + "\n" + "FlagsEntry\x12\x10\n" + "\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n" + diff --git a/pkg/proto/impl/v1/impl_vtproto.pb.go b/pkg/proto/impl/v1/impl_vtproto.pb.go index d5323399c..5ec4b439a 100644 --- a/pkg/proto/impl/v1/impl_vtproto.pb.go +++ b/pkg/proto/impl/v1/impl_vtproto.pb.go @@ -256,6 +256,11 @@ func (m *V1Cursor) CloneVT() *V1Cursor { } r.Flags = tmpContainer } + if rhs := m.SchemaHash; rhs != nil { + tmpBytes := make([]byte, len(rhs)) + copy(tmpBytes, rhs) + r.SchemaHash = tmpBytes + } if len(m.unknownFields) > 0 { r.unknownFields = make([]byte, len(m.unknownFields)) copy(r.unknownFields, 
m.unknownFields) @@ -756,6 +761,9 @@ func (this *V1Cursor) EqualVT(that *V1Cursor) bool { if this.DatastoreUniqueId != that.DatastoreUniqueId { return false } + if string(this.SchemaHash) != string(that.SchemaHash) { + return false + } return string(this.unknownFields) == string(that.unknownFields) } @@ -1410,6 +1418,13 @@ func (m *V1Cursor) MarshalToSizedBufferVT(dAtA []byte) (int, error) { i -= len(m.unknownFields) copy(dAtA[i:], m.unknownFields) } + if len(m.SchemaHash) > 0 { + i -= len(m.SchemaHash) + copy(dAtA[i:], m.SchemaHash) + i = protohelpers.EncodeVarint(dAtA, i, uint64(len(m.SchemaHash))) + i-- + dAtA[i] = 0x3a + } if len(m.DatastoreUniqueId) > 0 { i -= len(m.DatastoreUniqueId) copy(dAtA[i:], m.DatastoreUniqueId) @@ -1931,6 +1946,10 @@ func (m *V1Cursor) SizeVT() (n int) { if l > 0 { n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) } + l = len(m.SchemaHash) + if l > 0 { + n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } n += len(m.unknownFields) return n } @@ -3174,6 +3193,40 @@ func (m *V1Cursor) UnmarshalVT(dAtA []byte) error { } m.DatastoreUniqueId = string(dAtA[iNdEx:postIndex]) iNdEx = postIndex + case 7: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field SchemaHash", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return protohelpers.ErrInvalidLength + } + postIndex := iNdEx + byteLen + if postIndex < 0 { + return protohelpers.ErrInvalidLength + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.SchemaHash = append(m.SchemaHash[:0], dAtA[iNdEx:postIndex]...) 
+ if m.SchemaHash == nil { + m.SchemaHash = []byte{} + } + iNdEx = postIndex default: iNdEx = preIndex skippy, err := protohelpers.Skip(dAtA[iNdEx:]) diff --git a/pkg/query/arrow_reversal_test.go b/pkg/query/arrow_reversal_test.go index a4480fd7d..1aa367bb3 100644 --- a/pkg/query/arrow_reversal_test.go +++ b/pkg/query/arrow_reversal_test.go @@ -112,7 +112,7 @@ func TestDoubleWideArrowAdvisedMatchesPlain(t *testing.T) { resources := NewObjects("file", "file0") subject := NewObject("user", "user42").WithEllipses() - readerOpt := WithRevisionedReader(datalayer.NewDataLayer(rawDS).SnapshotReader(revision)) + readerOpt := WithRevisionedReader(datalayer.NewDataLayer(rawDS).SnapshotReader(revision, datalayer.NoSchemaHashForTesting)) // ---- plain (LTR) ---- diff --git a/pkg/query/benchmarks/check_deep_arrow_benchmark_test.go b/pkg/query/benchmarks/check_deep_arrow_benchmark_test.go index 85bf334bc..118bbbbc9 100644 --- a/pkg/query/benchmarks/check_deep_arrow_benchmark_test.go +++ b/pkg/query/benchmarks/check_deep_arrow_benchmark_test.go @@ -88,7 +88,7 @@ func BenchmarkCheckDeepArrow(b *testing.B) { subject := query.NewObject("user", "slow").WithEllipses() // Base reader (no simulated latency). - reader := query.NewQueryDatastoreReader(datalayer.NewDataLayer(rawDS).SnapshotReader(revision)) + reader := query.NewQueryDatastoreReader(datalayer.NewDataLayer(rawDS).SnapshotReader(revision, datalayer.NoSchemaHashForTesting)) // Delay reader wrapping the base reader with simulated network latency. 
delayReader := query.NewDelayReader(networkDelay, reader) diff --git a/pkg/query/benchmarks/check_double_wide_arrow_benchmark_test.go b/pkg/query/benchmarks/check_double_wide_arrow_benchmark_test.go index 62cda3025..5365da123 100644 --- a/pkg/query/benchmarks/check_double_wide_arrow_benchmark_test.go +++ b/pkg/query/benchmarks/check_double_wide_arrow_benchmark_test.go @@ -141,7 +141,7 @@ func BenchmarkCheckDoubleWideArrow(b *testing.B) { subject := query.NewObject("user", "user181").WithEllipses() // Base reader (no simulated latency). - reader := query.NewQueryDatastoreReader(datalayer.NewDataLayer(rawDS).SnapshotReader(revision)) + reader := query.NewQueryDatastoreReader(datalayer.NewDataLayer(rawDS).SnapshotReader(revision, datalayer.NoSchemaHashForTesting)) // Delay reader wrapping the base reader with simulated network latency. delayReader := query.NewDelayReader(networkDelay, reader) diff --git a/pkg/query/benchmarks/check_wide_arrow_benchmark_test.go b/pkg/query/benchmarks/check_wide_arrow_benchmark_test.go index 2dab7353d..f9fccde32 100644 --- a/pkg/query/benchmarks/check_wide_arrow_benchmark_test.go +++ b/pkg/query/benchmarks/check_wide_arrow_benchmark_test.go @@ -135,7 +135,7 @@ func BenchmarkCheckWideArrow(b *testing.B) { subject := query.NewObject("user", "user15").WithEllipses() // Base reader (no simulated latency). - reader := query.NewQueryDatastoreReader(datalayer.NewDataLayer(rawDS).SnapshotReader(revision)) + reader := query.NewQueryDatastoreReader(datalayer.NewDataLayer(rawDS).SnapshotReader(revision, datalayer.NoSchemaHashForTesting)) // Delay reader wrapping the base reader with simulated network latency. 
delayReader := query.NewDelayReader(networkDelay, reader) diff --git a/pkg/query/build_tree_test.go b/pkg/query/build_tree_test.go index e7b616258..90e2b0e04 100644 --- a/pkg/query/build_tree_test.go +++ b/pkg/query/build_tree_test.go @@ -32,7 +32,7 @@ func TestBuildTree(t *testing.T) { require.NoError(err) ctx := NewLocalContext(t.Context(), - WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(revision))) + WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(revision, datalayer.NoSchemaHashForTesting))) relSeq, err := ctx.Check(it, NewObjects("document", "specialplan"), NewObject("user", "multiroleguy").WithEllipses()) require.NoError(err) @@ -62,7 +62,7 @@ func TestBuildTreeMultipleRelations(t *testing.T) { require.Contains(explain.String(), "Union", "edit permission should create a union iterator") ctx := NewLocalContext(t.Context(), - WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(revision))) + WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(revision, datalayer.NoSchemaHashForTesting))) relSeq, err := ctx.Check(it, NewObjects("document", "specialplan"), NewObject("user", "multiroleguy").WithEllipses()) require.NoError(err) @@ -113,7 +113,7 @@ func TestBuildTreeSubRelations(t *testing.T) { require.NotEmpty(explain.String()) ctx := NewLocalContext(t.Context(), - WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(revision))) + WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(revision, datalayer.NoSchemaHashForTesting))) // Just test that the iterator can be executed without error relSeq, err := ctx.Check(it, NewObjects("document", "companyplan"), NewObject("user", "legal").WithEllipses()) @@ -212,7 +212,7 @@ func TestBuildTreeIntersectionOperation(t *testing.T) { require.Contains(explain.String(), "Intersection", "should create intersection iterator") ctx := NewLocalContext(t.Context(), - WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(revision))) + 
WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(revision, datalayer.NoSchemaHashForTesting))) // Test execution relSeq, err := ctx.Check(it, NewObjects("document", "specialplan"), NewObject("user", "multiroleguy").WithEllipses()) @@ -275,7 +275,7 @@ func TestBuildTreeExclusionEdgeCases(t *testing.T) { ds, revision := testfixtures.StandardDatastoreWithData(rawDS, require) ctx := NewLocalContext(t.Context(), - WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(revision))) + WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(revision, datalayer.NoSchemaHashForTesting))) userDef := testfixtures.UserNS.CloneVT() @@ -535,7 +535,7 @@ func TestBuildTreeSingleRelationOptimization(t *testing.T) { require.Contains(explain.String(), "Datastore", "should create datastore iterator") ctx := NewLocalContext(t.Context(), - WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(revision))) + WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(revision, datalayer.NoSchemaHashForTesting))) // Test execution relSeq, err := ctx.Check(it, NewObjects("document", "companyplan"), NewObject("user", "legal").WithEllipses()) @@ -555,7 +555,7 @@ func TestBuildTreeSubrelationHandling(t *testing.T) { ds, revision := testfixtures.StandardDatastoreWithData(rawDS, require) ctx := NewLocalContext(t.Context(), - WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(revision))) + WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(revision, datalayer.NoSchemaHashForTesting))) userDef := testfixtures.UserNS.CloneVT() diff --git a/pkg/query/caveat_test.go b/pkg/query/caveat_test.go index 645a10b61..b4f560aea 100644 --- a/pkg/query/caveat_test.go +++ b/pkg/query/caveat_test.go @@ -110,7 +110,7 @@ func TestCaveatIteratorNoCaveat(t *testing.T) { require.NoError(t, err) queryCtx := NewLocalContext(context.Background(), - WithRevisionedReader(dl.SnapshotReader(rev)), + WithRevisionedReader(dl.SnapshotReader(rev, 
datalayer.NoSchemaHashForTesting)), WithCaveatContext(tc.caveatContext), WithCaveatRunner(caveats.NewCaveatRunner(types.NewTypeSet()))) @@ -203,7 +203,7 @@ func TestCaveatIteratorWithCaveat(t *testing.T) { require.NoError(t, err) queryCtx := NewLocalContext(context.Background(), - WithRevisionedReader(dl.SnapshotReader(rev)), + WithRevisionedReader(dl.SnapshotReader(rev, datalayer.NoSchemaHashForTesting)), WithCaveatContext(tc.caveatContext), WithCaveatRunner(caveats.NewCaveatRunner(types.NewTypeSet()))) diff --git a/pkg/query/exclusion_test.go b/pkg/query/exclusion_test.go index 6b3134661..524afa01e 100644 --- a/pkg/query/exclusion_test.go +++ b/pkg/query/exclusion_test.go @@ -23,7 +23,7 @@ func TestExclusionIterator(t *testing.T) { ds, revision := testfixtures.StandardDatastoreWithData(rawDS, require) ctx := NewLocalContext(t.Context(), - WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(revision))) + WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(revision, datalayer.NoSchemaHashForTesting))) // Create test paths path1 := MustPathFromString("document:doc1#viewer@user:alice") @@ -259,7 +259,7 @@ func TestExclusionWithEmptyIterator(t *testing.T) { ds, revision := testfixtures.StandardDatastoreWithData(rawDS, require) ctx := NewLocalContext(t.Context(), - WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(revision))) + WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(revision, datalayer.NoSchemaHashForTesting))) path1 := MustPathFromString("document:doc1#viewer@user:alice") @@ -305,7 +305,7 @@ func TestExclusionErrorHandling(t *testing.T) { ds, revision := testfixtures.StandardDatastoreWithData(rawDS, require) ctx := NewLocalContext(t.Context(), - WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(revision))) + WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(revision, datalayer.NoSchemaHashForTesting))) path1 := MustPathFromString("document:doc1#viewer@user:alice") @@ -386,7 +386,7 @@ 
func TestExclusionWithComplexIteratorTypes(t *testing.T) { ds, revision := testfixtures.StandardDatastoreWithData(rawDS, require) ctx := NewLocalContext(t.Context(), - WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(revision))) + WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(revision, datalayer.NoSchemaHashForTesting))) // Create test relations path1 := MustPathFromString("document:doc1#viewer@user:alice") @@ -574,7 +574,7 @@ func TestExclusion_CombinedCaveatLogic(t *testing.T) { ds, revision := testfixtures.StandardDatastoreWithData(rawDS, require) ctx := NewLocalContext(t.Context(), - WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(revision))) + WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(revision, datalayer.NoSchemaHashForTesting))) // Helper to create paths with caveats createPathWithCaveat := func(relation, caveatName string) Path { diff --git a/pkg/query/intersection_arrow_test.go b/pkg/query/intersection_arrow_test.go index 00c00bfa7..4102068ec 100644 --- a/pkg/query/intersection_arrow_test.go +++ b/pkg/query/intersection_arrow_test.go @@ -41,7 +41,7 @@ func TestIntersectionArrowIterator(t *testing.T) { require.NoError(err) ctx := NewLocalContext(context.Background(), - WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(revision))) + WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(revision, datalayer.NoSchemaHashForTesting))) // Test: alice should have access because she's a member of ALL teams (team1 and team2) resources := []Object{NewObject("document", "doc1")} @@ -90,7 +90,7 @@ func TestIntersectionArrowIterator(t *testing.T) { require.NoError(err) ctx := NewLocalContext(context.Background(), - WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(revision))) + WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(revision, datalayer.NoSchemaHashForTesting))) // Test: alice should NOT have access because she's not a member of ALL teams resources 
:= []Object{NewObject("document", "doc1")} @@ -130,7 +130,7 @@ func TestIntersectionArrowIterator(t *testing.T) { require.NoError(err) ctx := NewLocalContext(context.Background(), - WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(revision))) + WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(revision, datalayer.NoSchemaHashForTesting))) // Test: alice should have access because she's a member of the only team resources := []Object{NewObject("document", "doc1")} @@ -175,7 +175,7 @@ func TestIntersectionArrowIterator(t *testing.T) { require.NoError(err) ctx := NewLocalContext(context.Background(), - WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(revision))) + WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(revision, datalayer.NoSchemaHashForTesting))) resources := []Object{NewObject("document", "doc1")} subject := ObjectAndRelation{ObjectType: "user", ObjectID: "alice"} @@ -218,7 +218,7 @@ func TestIntersectionArrowIterator(t *testing.T) { require.NoError(err) ctx := NewLocalContext(context.Background(), - WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(revision))) + WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(revision, datalayer.NoSchemaHashForTesting))) resources := []Object{NewObject("document", "doc1")} subject := ObjectAndRelation{ObjectType: "user", ObjectID: "alice"} @@ -257,7 +257,7 @@ func TestIntersectionArrowIterator(t *testing.T) { require.NoError(err) ctx := NewLocalContext(context.Background(), - WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(revision))) + WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(revision, datalayer.NoSchemaHashForTesting))) resources := []Object{} subject := ObjectAndRelation{ObjectType: "user", ObjectID: "alice"} @@ -286,7 +286,7 @@ func TestIntersectionArrowIteratorCaveatCombination(t *testing.T) { require.NoError(err) ctx := NewLocalContext(context.Background(), - 
WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(revision))) + WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(revision, datalayer.NoSchemaHashForTesting))) t.Run("CombineTwoCaveats_AND_Logic", func(t *testing.T) { t.Parallel() @@ -513,7 +513,7 @@ func TestIntersectionArrowIteratorClone(t *testing.T) { require.NoError(err) ctx := NewLocalContext(context.Background(), - WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(revision))) + WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(revision, datalayer.NoSchemaHashForTesting))) // Test that both iterators produce the same results resources := []Object{NewObject("document", "doc1")} diff --git a/pkg/query/observer_analyze_test.go b/pkg/query/observer_analyze_test.go index 27827283b..06dd031d4 100644 --- a/pkg/query/observer_analyze_test.go +++ b/pkg/query/observer_analyze_test.go @@ -149,7 +149,7 @@ func TestAnalysisIntegration(t *testing.T) { // Create a context with analysis enabled analyze := NewAnalyzeObserver() ctx := NewLocalContext(context.Background(), - WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(datastore.NoRevision)), + WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(datastore.NoRevision, datalayer.NoSchemaHashForTesting)), WithObserver(analyze)) // Execute a Check operation diff --git a/pkg/query/observer_count_test.go b/pkg/query/observer_count_test.go index 91d2b9358..0aa0fe549 100644 --- a/pkg/query/observer_count_test.go +++ b/pkg/query/observer_count_test.go @@ -154,7 +154,7 @@ func TestCountObserverIntegration(t *testing.T) { // Create a context with count observer enabled countObs := NewCountObserver() ctx := NewLocalContext(context.Background(), - WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(datastore.NoRevision)), + WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(datastore.NoRevision, datalayer.NoSchemaHashForTesting)), WithObserver(countObs)) // Execute a Check operation diff 
--git a/pkg/query/quick_e2e_test.go b/pkg/query/quick_e2e_test.go index 02a6349f8..56bf9bc2c 100644 --- a/pkg/query/quick_e2e_test.go +++ b/pkg/query/quick_e2e_test.go @@ -39,7 +39,7 @@ func TestCheck(t *testing.T) { it := NewIntersectionIterator(vande, edit) ctx := NewLocalContext(t.Context(), - WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(revision))) + WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(revision, datalayer.NoSchemaHashForTesting))) relSeq, err := ctx.Check(it, NewObjects("document", "specialplan"), NewObject("user", "multiroleguy").WithEllipses()) require.NoError(err) @@ -67,7 +67,7 @@ func TestBaseIterSubjects(t *testing.T) { vande := NewDatastoreIterator(vandeRel.BaseRelations()[0]) ctx := NewLocalContext(t.Context(), - WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(revision))) + WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(revision, datalayer.NoSchemaHashForTesting))) relSeq, err := ctx.IterSubjects(vande, NewObject("document", "specialplan"), NoObjectFilter()) require.NoError(err) @@ -100,7 +100,7 @@ func TestCheckArrow(t *testing.T) { it := NewArrowIterator(folders, view) ctx := NewLocalContext(t.Context(), - WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(revision))) + WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(revision, datalayer.NoSchemaHashForTesting))) relSeq, err := ctx.Check(it, NewObjects("document", "companyplan"), NewObject("user", "legal").WithEllipses()) require.NoError(err) diff --git a/pkg/query/recursive_benchmark_test.go b/pkg/query/recursive_benchmark_test.go index faf344f73..009f435bc 100644 --- a/pkg/query/recursive_benchmark_test.go +++ b/pkg/query/recursive_benchmark_test.go @@ -44,7 +44,7 @@ func BenchmarkRecursiveShallowGraph(b *testing.B) { } ctx := NewLocalContext(context.Background(), - WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(datastore.NoRevision)), + 
WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(datastore.NoRevision, datalayer.NoSchemaHashForTesting)), WithMaxRecursionDepth(50)) b.ResetTimer() @@ -97,7 +97,7 @@ func BenchmarkRecursiveWideGraph(b *testing.B) { } ctx := NewLocalContext(context.Background(), - WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(datastore.NoRevision)), + WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(datastore.NoRevision, datalayer.NoSchemaHashForTesting)), WithMaxRecursionDepth(50)) b.ResetTimer() @@ -145,7 +145,7 @@ func BenchmarkRecursiveDeepGraph(b *testing.B) { } ctx := NewLocalContext(context.Background(), - WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(datastore.NoRevision)), + WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(datastore.NoRevision, datalayer.NoSchemaHashForTesting)), WithMaxRecursionDepth(50)) b.ResetTimer() @@ -175,7 +175,7 @@ func BenchmarkRecursiveEmptyGraph(b *testing.B) { } ctx := NewLocalContext(context.Background(), - WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(datastore.NoRevision)), + WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(datastore.NoRevision, datalayer.NoSchemaHashForTesting)), WithMaxRecursionDepth(50)) b.ResetTimer() @@ -234,7 +234,7 @@ func BenchmarkRecursiveSparseGraph(b *testing.B) { } ctx := NewLocalContext(context.Background(), - WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(datastore.NoRevision)), + WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(datastore.NoRevision, datalayer.NoSchemaHashForTesting)), WithMaxRecursionDepth(50)) b.ResetTimer() @@ -282,7 +282,7 @@ func BenchmarkRecursiveCyclicGraph(b *testing.B) { } ctx := NewLocalContext(context.Background(), - WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(datastore.NoRevision)), + WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(datastore.NoRevision, datalayer.NoSchemaHashForTesting)), 
WithMaxRecursionDepth(50)) b.ResetTimer() @@ -330,7 +330,7 @@ func BenchmarkRecursiveIterResources(b *testing.B) { } ctx := NewLocalContext(context.Background(), - WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(datastore.NoRevision)), + WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(datastore.NoRevision, datalayer.NoSchemaHashForTesting)), WithMaxRecursionDepth(50)) b.ResetTimer() diff --git a/pkg/query/recursive_coverage_test.go b/pkg/query/recursive_coverage_test.go index 68423781f..d9941098d 100644 --- a/pkg/query/recursive_coverage_test.go +++ b/pkg/query/recursive_coverage_test.go @@ -111,7 +111,7 @@ func TestBreadthFirstIterResources_MaxDepth(t *testing.T) { // Set a low max depth ctx := NewLocalContext(context.Background(), - WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(datastore.NoRevision)), + WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(datastore.NoRevision, datalayer.NoSchemaHashForTesting)), WithMaxRecursionDepth(3)) seq, err := recursive.IterResourcesImpl(ctx, ObjectAndRelation{ObjectType: "user", ObjectID: "alice", Relation: "..."}, NoObjectFilter()) @@ -143,7 +143,7 @@ func TestBreadthFirstIterResources_ErrorHandling(t *testing.T) { require.NoError(err) ctx := NewLocalContext(context.Background(), - WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(datastore.NoRevision))) + WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(datastore.NoRevision, datalayer.NoSchemaHashForTesting))) seq, err := recursive.IterResourcesImpl(ctx, ObjectAndRelation{ObjectType: "user", ObjectID: "alice", Relation: "..."}, NoObjectFilter()) require.NoError(err) @@ -168,7 +168,7 @@ func TestBreadthFirstIterResources_ErrorHandling(t *testing.T) { require.NoError(err) ctx := NewLocalContext(context.Background(), - WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(datastore.NoRevision))) + WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(datastore.NoRevision, 
datalayer.NoSchemaHashForTesting))) seq, err := recursive.IterResourcesImpl(ctx, ObjectAndRelation{ObjectType: "user", ObjectID: "alice", Relation: "..."}, NoObjectFilter()) require.NoError(err) @@ -202,7 +202,7 @@ func TestBreadthFirstIterResources_MergeOrSemantics(t *testing.T) { require.NoError(err) ctx := NewLocalContext(context.Background(), - WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(datastore.NoRevision)), + WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(datastore.NoRevision, datalayer.NoSchemaHashForTesting)), WithMaxRecursionDepth(5)) seq, err := recursive.IterResourcesImpl(ctx, ObjectAndRelation{ObjectType: "user", ObjectID: "alice", Relation: "..."}, NoObjectFilter()) @@ -230,7 +230,7 @@ func TestIterativeDeepening_MaxDepth(t *testing.T) { maxDepth := 5 ctx := NewLocalContext(context.Background(), - WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(datastore.NoRevision)), + WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(datastore.NoRevision, datalayer.NoSchemaHashForTesting)), WithMaxRecursionDepth(maxDepth)) seq, err := recursive.CheckImpl(ctx, []Object{{ObjectType: "folder", ObjectID: "folder1"}}, ObjectAndRelation{ObjectType: "user", ObjectID: "alice", Relation: "..."}) diff --git a/pkg/query/recursive_strategies_test.go b/pkg/query/recursive_strategies_test.go index 4382b4159..5f1158560 100644 --- a/pkg/query/recursive_strategies_test.go +++ b/pkg/query/recursive_strategies_test.go @@ -60,7 +60,7 @@ func TestRecursiveCheckStrategies(t *testing.T) { // Contexts contain mutable state (e.g., recursiveFrontierCollectors) // that must not be shared across concurrent goroutines. 
queryCtx := NewLocalContext(context.Background(), - WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(datastore.NoRevision))) + WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(datastore.NoRevision, datalayer.NoSchemaHashForTesting))) // Create recursive iterator with the specific strategy recursive := NewRecursiveIterator(union, "folder", "view") @@ -110,7 +110,7 @@ func TestRecursiveCheckStrategiesEmpty(t *testing.T) { recursive := NewRecursiveIterator(emptyFixed, "folder", "view") queryCtx := NewLocalContext(context.Background(), - WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(datastore.NoRevision))) + WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(datastore.NoRevision, datalayer.NoSchemaHashForTesting))) strategies := []recursiveCheckStrategy{ recursiveCheckIterSubjects, @@ -164,7 +164,7 @@ func TestRecursiveCheckStrategiesMultipleResources(t *testing.T) { require.NoError(t, err) queryCtx := NewLocalContext(context.Background(), - WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(datastore.NoRevision))) + WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(datastore.NoRevision, datalayer.NoSchemaHashForTesting))) strategies := []recursiveCheckStrategy{ recursiveCheckIterSubjects, diff --git a/pkg/query/recursive_test.go b/pkg/query/recursive_test.go index 1b0ddf4d4..23385afd7 100644 --- a/pkg/query/recursive_test.go +++ b/pkg/query/recursive_test.go @@ -27,7 +27,7 @@ func TestRecursiveSentinel(t *testing.T) { require.NoError(t, err) ctx := NewLocalContext(context.Background(), - WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(datastore.NoRevision))) + WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(datastore.NoRevision, datalayer.NoSchemaHashForTesting))) // CheckImpl should return empty seq, err := sentinel.CheckImpl(ctx, []Object{{ObjectType: "folder", ObjectID: "folder1"}}, ObjectAndRelation{ObjectType: "user", ObjectID: "tom", Relation: "..."}) 
@@ -63,7 +63,7 @@ func TestRecursiveIteratorEmptyBaseCase(t *testing.T) { require.NoError(t, err) ctx := NewLocalContext(context.Background(), - WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(datastore.NoRevision))) + WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(datastore.NoRevision, datalayer.NoSchemaHashForTesting))) // Execute - should terminate immediately with empty result seq, err := recursive.CheckImpl(ctx, []Object{{ObjectType: "folder", ObjectID: "folder1"}}, ObjectAndRelation{ObjectType: "user", ObjectID: "tom", Relation: "..."}) @@ -167,7 +167,7 @@ func TestRecursiveIteratorExecutionError(t *testing.T) { require.NoError(t, err) ctx := NewLocalContext(context.Background(), - WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(datastore.NoRevision))) + WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(datastore.NoRevision, datalayer.NoSchemaHashForTesting))) // Test CheckImpl with a faulty iterator seq, err := recursive.CheckImpl(ctx, []Object{{ObjectType: "folder", ObjectID: "folder1"}}, ObjectAndRelation{ObjectType: "user", ObjectID: "tom", Relation: "..."}) @@ -198,7 +198,7 @@ func TestRecursiveIteratorCollectionError(t *testing.T) { require.NoError(t, err) ctx := NewLocalContext(context.Background(), - WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(datastore.NoRevision))) + WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(datastore.NoRevision, datalayer.NoSchemaHashForTesting))) // Test CheckImpl with a faulty iterator that fails on collection seq, err := recursive.CheckImpl(ctx, []Object{{ObjectType: "folder", ObjectID: "folder1"}}, ObjectAndRelation{ObjectType: "user", ObjectID: "tom", Relation: "..."}) @@ -223,7 +223,7 @@ func TestBFSEarlyTermination(t *testing.T) { require.NoError(t, err) ctx := NewLocalContext(context.Background(), - WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(datastore.NoRevision)), + 
WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(datastore.NoRevision, datalayer.NoSchemaHashForTesting)), WithMaxRecursionDepth(50)) // High max depth // IterSubjects on a node with no children (sentinel returns empty) @@ -271,7 +271,7 @@ func TestBFSCycleDetection(t *testing.T) { require.NoError(t, err) ctx := NewLocalContext(context.Background(), - WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(datastore.NoRevision)), + WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(datastore.NoRevision, datalayer.NoSchemaHashForTesting)), WithMaxRecursionDepth(10)) seq, err := recursive.IterSubjectsImpl(ctx, Object{ObjectType: "folder", ObjectID: "folder1"}, NoObjectFilter()) @@ -304,7 +304,7 @@ func TestBFSSelfReferential(t *testing.T) { require.NoError(t, err) ctx := NewLocalContext(context.Background(), - WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(datastore.NoRevision)), + WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(datastore.NoRevision, datalayer.NoSchemaHashForTesting)), WithMaxRecursionDepth(10)) seq, err := recursive.IterSubjectsImpl(ctx, Object{ObjectType: "folder", ObjectID: "folder1"}, NoObjectFilter()) @@ -346,7 +346,7 @@ func TestBFSResourcesWithEllipses(t *testing.T) { require.NoError(t, err) ctx := NewLocalContext(context.Background(), - WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(datastore.NoRevision)), + WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(datastore.NoRevision, datalayer.NoSchemaHashForTesting)), WithMaxRecursionDepth(5)) // Query IterResources - should find folder2 diff --git a/pkg/query/simplify_caveat_test.go b/pkg/query/simplify_caveat_test.go index 4de367cdd..391ca1783 100644 --- a/pkg/query/simplify_caveat_test.go +++ b/pkg/query/simplify_caveat_test.go @@ -47,7 +47,7 @@ func TestSimplifyLeafCaveat(t *testing.T) { }) require.NoError(err) - sr := 
caveatDefinitionLookupAdapter{NewQueryDatastoreReader(datalayer.NewDataLayer(ds).SnapshotReader(revision))} + sr := caveatDefinitionLookupAdapter{NewQueryDatastoreReader(datalayer.NewDataLayer(ds).SnapshotReader(revision, datalayer.NoSchemaHashForTesting))} runner := internalcaveats.NewCaveatRunner(caveattypes.Default.TypeSet) // Create caveat expression without context @@ -134,7 +134,7 @@ func TestSimplifyAndOperation(t *testing.T) { }) require.NoError(err) - sr := caveatDefinitionLookupAdapter{NewQueryDatastoreReader(datalayer.NewDataLayer(ds).SnapshotReader(revision))} + sr := caveatDefinitionLookupAdapter{NewQueryDatastoreReader(datalayer.NewDataLayer(ds).SnapshotReader(revision, datalayer.NoSchemaHashForTesting))} runner := internalcaveats.NewCaveatRunner(caveattypes.Default.TypeSet) // Create AND expression: caveat1 AND caveat2 @@ -238,7 +238,7 @@ func TestSimplifyOrOperation(t *testing.T) { }) require.NoError(err) - sr := caveatDefinitionLookupAdapter{NewQueryDatastoreReader(datalayer.NewDataLayer(ds).SnapshotReader(revision))} + sr := caveatDefinitionLookupAdapter{NewQueryDatastoreReader(datalayer.NewDataLayer(ds).SnapshotReader(revision, datalayer.NoSchemaHashForTesting))} runner := internalcaveats.NewCaveatRunner(caveattypes.Default.TypeSet) // Create OR expression: caveat1 OR caveat2 @@ -354,7 +354,7 @@ func TestSimplifyNestedOperations(t *testing.T) { }) require.NoError(err) - sr := caveatDefinitionLookupAdapter{NewQueryDatastoreReader(datalayer.NewDataLayer(ds).SnapshotReader(revision))} + sr := caveatDefinitionLookupAdapter{NewQueryDatastoreReader(datalayer.NewDataLayer(ds).SnapshotReader(revision, datalayer.NoSchemaHashForTesting))} runner := internalcaveats.NewCaveatRunner(caveattypes.Default.TypeSet) // Create nested expression: (caveat1 OR caveat2) AND caveat3 @@ -439,7 +439,7 @@ func TestSimplifyOrWithSameCaveatDifferentContexts(t *testing.T) { }) require.NoError(err) - sr := 
caveatDefinitionLookupAdapter{NewQueryDatastoreReader(datalayer.NewDataLayer(ds).SnapshotReader(revision))} + sr := caveatDefinitionLookupAdapter{NewQueryDatastoreReader(datalayer.NewDataLayer(ds).SnapshotReader(revision, datalayer.NoSchemaHashForTesting))} runner := internalcaveats.NewCaveatRunner(caveattypes.Default.TypeSet) // Create OR expression: write_limit(limit=2) OR write_limit(limit=4) @@ -525,7 +525,7 @@ func TestSimplifyAndWithSameCaveatDifferentContexts(t *testing.T) { }) require.NoError(err) - sr := caveatDefinitionLookupAdapter{NewQueryDatastoreReader(datalayer.NewDataLayer(ds).SnapshotReader(revision))} + sr := caveatDefinitionLookupAdapter{NewQueryDatastoreReader(datalayer.NewDataLayer(ds).SnapshotReader(revision, datalayer.NoSchemaHashForTesting))} runner := internalcaveats.NewCaveatRunner(caveattypes.Default.TypeSet) // Create AND expression: write_limit(limit=2) AND write_limit(limit=4) @@ -621,7 +621,7 @@ func TestSimplifyNotWithSameCaveatDifferentContexts(t *testing.T) { }) require.NoError(err) - sr := caveatDefinitionLookupAdapter{NewQueryDatastoreReader(datalayer.NewDataLayer(ds).SnapshotReader(revision))} + sr := caveatDefinitionLookupAdapter{NewQueryDatastoreReader(datalayer.NewDataLayer(ds).SnapshotReader(revision, datalayer.NoSchemaHashForTesting))} runner := internalcaveats.NewCaveatRunner(caveattypes.Default.TypeSet) // Create NOT expression: NOT write_limit(limit=4) @@ -703,7 +703,7 @@ func TestSimplifyComplexNestedExpressions(t *testing.T) { }) require.NoError(err) - sr := caveatDefinitionLookupAdapter{NewQueryDatastoreReader(datalayer.NewDataLayer(ds).SnapshotReader(revision))} + sr := caveatDefinitionLookupAdapter{NewQueryDatastoreReader(datalayer.NewDataLayer(ds).SnapshotReader(revision, datalayer.NoSchemaHashForTesting))} runner := internalcaveats.NewCaveatRunner(caveattypes.Default.TypeSet) t.Run("OrOfAnds_ComplexNesting", func(t *testing.T) { @@ -1170,7 +1170,7 @@ func TestSimplifyWithEmptyContext(t *testing.T) { }) 
require.NoError(err) - sr := caveatDefinitionLookupAdapter{NewQueryDatastoreReader(datalayer.NewDataLayer(ds).SnapshotReader(revision))} + sr := caveatDefinitionLookupAdapter{NewQueryDatastoreReader(datalayer.NewDataLayer(ds).SnapshotReader(revision, datalayer.NoSchemaHashForTesting))} runner := internalcaveats.NewCaveatRunner(caveattypes.Default.TypeSet) // Create nested expression: (caveat1 OR caveat2) AND caveat3 @@ -1252,7 +1252,7 @@ func TestSimplifyNotConditional(t *testing.T) { }) require.NoError(err) - sr := caveatDefinitionLookupAdapter{NewQueryDatastoreReader(datalayer.NewDataLayer(ds).SnapshotReader(revision))} + sr := caveatDefinitionLookupAdapter{NewQueryDatastoreReader(datalayer.NewDataLayer(ds).SnapshotReader(revision, datalayer.NoSchemaHashForTesting))} runner := internalcaveats.NewCaveatRunner(caveattypes.Default.TypeSet) // Create NOT expression: NOT limit_check(limit=10) @@ -1338,7 +1338,7 @@ func TestSimplifyDeeplyNestedCaveats(t *testing.T) { }) require.NoError(err) - sr := caveatDefinitionLookupAdapter{NewQueryDatastoreReader(datalayer.NewDataLayer(ds).SnapshotReader(revision))} + sr := caveatDefinitionLookupAdapter{NewQueryDatastoreReader(datalayer.NewDataLayer(ds).SnapshotReader(revision, datalayer.NoSchemaHashForTesting))} runner := internalcaveats.NewCaveatRunner(caveattypes.Default.TypeSet) // Helper to create caveat expressions diff --git a/pkg/query/tracing_test.go b/pkg/query/tracing_test.go index dedc3b121..8ac4dd547 100644 --- a/pkg/query/tracing_test.go +++ b/pkg/query/tracing_test.go @@ -24,7 +24,7 @@ func TestIteratorTracing(t *testing.T) { require.NoError(t, err) ctx := NewLocalContext(context.Background(), - WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(revision)), + WithRevisionedReader(datalayer.NewDataLayer(ds).SnapshotReader(revision, datalayer.NoSchemaHashForTesting)), WithTraceLogger(traceLogger), ) diff --git a/pkg/query/wildcard_multirelation_test.go b/pkg/query/wildcard_multirelation_test.go index 
4ee41ada4..3f19039e6 100644 --- a/pkg/query/wildcard_multirelation_test.go +++ b/pkg/query/wildcard_multirelation_test.go @@ -82,7 +82,7 @@ func TestIterSubjectsWildcardWithMultipleRelations(t *testing.T) { wildcardBranch := NewDatastoreIterator(viewerRel.BaseRelations()[1]) // user:* (wildcard) queryCtx := NewLocalContext(ctx, - WithRevisionedReader(datalayer.NewDataLayer(rawDS).SnapshotReader(revision)), + WithRevisionedReader(datalayer.NewDataLayer(rawDS).SnapshotReader(revision, datalayer.NoSchemaHashForTesting)), WithTraceLogger(NewTraceLogger())) // Enable tracing for debugging subjects, err := queryCtx.IterSubjects(wildcardBranch, NewObject("document", "publicdoc"), NoObjectFilter()) require.NoError(err) @@ -114,7 +114,7 @@ func TestIterSubjectsWildcardWithMultipleRelations(t *testing.T) { ) queryCtx := NewLocalContext(ctx, - WithRevisionedReader(datalayer.NewDataLayer(rawDS).SnapshotReader(revision)), + WithRevisionedReader(datalayer.NewDataLayer(rawDS).SnapshotReader(revision, datalayer.NoSchemaHashForTesting)), WithTraceLogger(NewTraceLogger())) // Enable tracing for debugging subjects, err := queryCtx.IterSubjects(union, NewObject("document", "publicdoc"), NoObjectFilter()) require.NoError(err) diff --git a/pkg/query/wildcard_subjects_test.go b/pkg/query/wildcard_subjects_test.go index 39eeb276b..f4046c101 100644 --- a/pkg/query/wildcard_subjects_test.go +++ b/pkg/query/wildcard_subjects_test.go @@ -72,7 +72,7 @@ func TestIterSubjectsWithWildcard(t *testing.T) { // The non-wildcard branch should only return concrete subjects, filtering out wildcards nonWildcardBranch := NewDatastoreIterator(viewerRel.BaseRelations()[0]) // user (non-wildcard) - queryCtx := NewLocalContext(ctx, WithRevisionedReader(datalayer.NewDataLayer(rawDS).SnapshotReader(revision))) + queryCtx := NewLocalContext(ctx, WithRevisionedReader(datalayer.NewDataLayer(rawDS).SnapshotReader(revision, datalayer.NoSchemaHashForTesting))) subjects, err := queryCtx.IterSubjects(nonWildcardBranch, 
NewObject("resource", "first"), NoObjectFilter()) require.NoError(err) @@ -91,7 +91,7 @@ func TestIterSubjectsWithWildcard(t *testing.T) { // The wildcard branch should enumerate concrete subjects when a wildcard exists wildcardBranch := NewDatastoreIterator(viewerRel.BaseRelations()[1]) // user:* (wildcard) - queryCtx := NewLocalContext(ctx, WithRevisionedReader(datalayer.NewDataLayer(rawDS).SnapshotReader(revision))) + queryCtx := NewLocalContext(ctx, WithRevisionedReader(datalayer.NewDataLayer(rawDS).SnapshotReader(revision, datalayer.NoSchemaHashForTesting))) subjects, err := queryCtx.IterSubjects(wildcardBranch, NewObject("resource", "first"), NoObjectFilter()) require.NoError(err) @@ -113,7 +113,7 @@ func TestIterSubjectsWithWildcard(t *testing.T) { NewDatastoreIterator(viewerRel.BaseRelations()[1]), // user:* ) - queryCtx := NewLocalContext(ctx, WithRevisionedReader(datalayer.NewDataLayer(rawDS).SnapshotReader(revision))) + queryCtx := NewLocalContext(ctx, WithRevisionedReader(datalayer.NewDataLayer(rawDS).SnapshotReader(revision, datalayer.NoSchemaHashForTesting))) subjects, err := queryCtx.IterSubjects(union, NewObject("resource", "first"), NoObjectFilter()) require.NoError(err) @@ -180,7 +180,7 @@ func TestIterSubjectsWildcardWithoutWildcardRelationship(t *testing.T) { // The wildcard branch should return empty because there's no wildcard relationship wildcardBranch := NewDatastoreIterator(viewerRel.BaseRelations()[1]) // user:* (wildcard) - queryCtx := NewLocalContext(ctx, WithRevisionedReader(datalayer.NewDataLayer(rawDS).SnapshotReader(revision))) + queryCtx := NewLocalContext(ctx, WithRevisionedReader(datalayer.NewDataLayer(rawDS).SnapshotReader(revision, datalayer.NoSchemaHashForTesting))) subjects, err := queryCtx.IterSubjects(wildcardBranch, NewObject("resource", "second"), NoObjectFilter()) require.NoError(err) @@ -196,7 +196,7 @@ func TestIterSubjectsWildcardWithoutWildcardRelationship(t *testing.T) { t.Parallel() nonWildcardBranch := 
NewDatastoreIterator(viewerRel.BaseRelations()[0]) // user (non-wildcard) - queryCtx := NewLocalContext(ctx, WithRevisionedReader(datalayer.NewDataLayer(rawDS).SnapshotReader(revision))) + queryCtx := NewLocalContext(ctx, WithRevisionedReader(datalayer.NewDataLayer(rawDS).SnapshotReader(revision, datalayer.NoSchemaHashForTesting))) subjects, err := queryCtx.IterSubjects(nonWildcardBranch, NewObject("resource", "second"), NoObjectFilter()) require.NoError(err) diff --git a/pkg/schemadsl/generator/generator.go b/pkg/schemadsl/generator/generator.go index 11a555ddc..1a3e5eac2 100644 --- a/pkg/schemadsl/generator/generator.go +++ b/pkg/schemadsl/generator/generator.go @@ -3,6 +3,8 @@ package generator import ( "bufio" "context" + "crypto/sha256" + "encoding/hex" "fmt" "maps" "slices" @@ -33,6 +35,25 @@ func GenerateSchema(definitions []compiler.SchemaDefinition) (string, bool, erro return GenerateSchemaWithCaveatTypeSet(context.TODO(), definitions, caveattypes.Default.TypeSet) } +// ComputeSchemaHash computes a SHA256 hash of the given schema definitions. +// Definitions are sorted by name before generating the schema text for consistent ordering. +func ComputeSchemaHash(definitions []compiler.SchemaDefinition) (string, error) { + sorted := make([]compiler.SchemaDefinition, len(definitions)) + copy(sorted, definitions) + + sort.Slice(sorted, func(i, j int) bool { + return sorted[i].GetName() < sorted[j].GetName() + }) + + schemaText, _, err := GenerateSchema(sorted) + if err != nil { + return "", fmt.Errorf("failed to generate schema for hashing: %w", err) + } + + hash := sha256.Sum256([]byte(schemaText)) + return hex.EncodeToString(hash[:]), nil +} + // GenerateSchemaWithCaveatTypeSet generates a DSL view of the given schema. 
func GenerateSchemaWithCaveatTypeSet(ctx context.Context, definitions []compiler.SchemaDefinition, caveatTypeSet *caveattypes.TypeSet) (string, bool, error) { _, span := tracer.Start(ctx, "GenerateSchemaWithCaveatTypeSet") diff --git a/pkg/services/v1/services.go b/pkg/services/v1/services.go index d1ce37e60..a909128b9 100644 --- a/pkg/services/v1/services.go +++ b/pkg/services/v1/services.go @@ -15,5 +15,5 @@ import ( // If no cursor is provided, it will fallback to the provided revision. func BulkExport(ctx context.Context, ds datastore.ReadOnlyDatastore, batchSize uint64, req *v1.BulkExportRelationshipsRequest, fallbackRevision datastore.Revision, sender func(response *v1.BulkExportRelationshipsResponse) error) error { dl := datalayer.NewReadOnlyDataLayer(ds) - return servicesv1.BulkExport(ctx, dl, batchSize, req, fallbackRevision, sender) + return servicesv1.BulkExport(ctx, dl, batchSize, req, fallbackRevision, datalayer.NoSchemaHashForLegacyCursor, sender) } diff --git a/proto/internal/core/v1/core.proto b/proto/internal/core/v1/core.proto index f4792ff37..527c12505 100644 --- a/proto/internal/core/v1/core.proto +++ b/proto/internal/core/v1/core.proto @@ -619,3 +619,28 @@ message SubjectFilter { RelationFilter optional_relation = 3; } + +// StoredSchema represents a stored schema in SpiceDB under the new, unified schema format. +message StoredSchema { + uint32 version = 1; + + message V1StoredSchema { + // schema_text is the text of the schema that was given to SpiceDB by the API caller. + string schema_text = 1; + + // schema_hash is the hash of the schema, for change detection. + string schema_hash = 2; + + // namespace_definitions is a map of namespace name to NamespaceDefinition. + // Entries must have full metadata filled out. + map<string, NamespaceDefinition> namespace_definitions = 3; + + // caveat_definitions is a map of caveat name to CaveatDefinition. + // Entries must have full metadata filled out.
+ map<string, CaveatDefinition> caveat_definitions = 4; + } + + oneof version_oneof { + V1StoredSchema v1 = 2; + } +} diff --git a/proto/internal/dispatch/v1/dispatch.proto b/proto/internal/dispatch/v1/dispatch.proto index 57f6961b1..085180f09 100644 --- a/proto/internal/dispatch/v1/dispatch.proto +++ b/proto/internal/dispatch/v1/dispatch.proto @@ -178,6 +178,7 @@ message ResolverMeta { uint32 depth_remaining = 2 [(buf.validate.field).uint32.gt = 0]; string request_id = 3 [deprecated = true]; bytes traversal_bloom = 4 [(buf.validate.field).bytes = {max_len: 1024}]; + bytes schema_hash = 5; } message ResponseMeta { diff --git a/proto/internal/impl/v1/impl.proto b/proto/internal/impl/v1/impl.proto index cd85ece0c..e6b6dbb09 100644 --- a/proto/internal/impl/v1/impl.proto +++ b/proto/internal/impl/v1/impl.proto @@ -71,6 +71,10 @@ message V1Cursor { // datastore_unique_id is the unique ID for the datastore. Will be empty for legacy // cursors. string datastore_unique_id = 6; + + // schema_hash is the hash of the schema at the time the cursor was created. Will be + // empty for legacy cursors. + bytes schema_hash = 7; } message DocComment {