From 5c953dee55990d7328016ecfeae952d80e538323 Mon Sep 17 00:00:00 2001 From: ivanauth Date: Mon, 15 Dec 2025 15:05:16 -0500 Subject: [PATCH 1/7] fix: improve error messages for missing database tables When datastore operations fail due to missing database tables (typically because migrations haven't been run), users now see a clear error message: 'please ensure you have run spicedb datastore migrate' This improves the user experience by providing actionable guidance instead of raw database errors like 'relation X does not exist'. The improvement covers all reader operations across CRDB, PostgreSQL, and MySQL datastores including namespace, caveat, counter, and statistics queries. --- internal/datastore/common/errors.go | 29 +++++ internal/datastore/common/errors_test.go | 43 +++++++ internal/datastore/crdb/caveat.go | 7 ++ internal/datastore/crdb/reader.go | 15 +++ internal/datastore/crdb/stats.go | 12 ++ internal/datastore/mysql/caveat.go | 7 ++ internal/datastore/mysql/common/errors.go | 38 ++++++ .../datastore/mysql/common/errors_test.go | 112 ++++++++++++++++++ internal/datastore/mysql/reader.go | 16 +++ internal/datastore/mysql/stats.go | 16 +++ internal/datastore/postgres/caveat.go | 7 ++ internal/datastore/postgres/common/errors.go | 24 ++++ .../datastore/postgres/common/errors_test.go | 112 ++++++++++++++++++ internal/datastore/postgres/reader.go | 15 +++ internal/datastore/postgres/readwrite.go | 4 + internal/datastore/postgres/revisions.go | 3 + internal/datastore/postgres/stats.go | 15 +++ 17 files changed, 475 insertions(+) create mode 100644 internal/datastore/common/errors_test.go create mode 100644 internal/datastore/mysql/common/errors.go create mode 100644 internal/datastore/mysql/common/errors_test.go create mode 100644 internal/datastore/postgres/common/errors_test.go diff --git a/internal/datastore/common/errors.go b/internal/datastore/common/errors.go index 5344873c6..072bbf48c 100644 --- a/internal/datastore/common/errors.go +++ 
b/internal/datastore/common/errors.go @@ -176,3 +176,32 @@ type RevisionUnavailableError struct { func NewRevisionUnavailableError(err error) error { return RevisionUnavailableError{err} } + +// SchemaNotInitializedError is returned when a datastore operation fails because the +// required database tables do not exist. This typically means that migrations have not been run. +type SchemaNotInitializedError struct { + error +} + +func (err SchemaNotInitializedError) GRPCStatus() *status.Status { + return spiceerrors.WithCodeAndDetails( + err, + codes.FailedPrecondition, + spiceerrors.ForReason( + v1.ErrorReason_ERROR_REASON_UNSPECIFIED, + map[string]string{}, + ), + ) +} + +func (err SchemaNotInitializedError) Unwrap() error { + return err.error +} + +// NewSchemaNotInitializedError creates a new SchemaNotInitializedError with a helpful message +// instructing the user to run migrations. +func NewSchemaNotInitializedError(underlying error) error { + return SchemaNotInitializedError{ + fmt.Errorf("datastore error: the database schema has not been initialized; please run \"spicedb datastore migrate\": %w", underlying), + } +} diff --git a/internal/datastore/common/errors_test.go b/internal/datastore/common/errors_test.go new file mode 100644 index 000000000..6d586add8 --- /dev/null +++ b/internal/datastore/common/errors_test.go @@ -0,0 +1,43 @@ +package common + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" +) + +func TestSchemaNotInitializedError(t *testing.T) { + underlyingErr := fmt.Errorf("relation \"caveat\" does not exist (SQLSTATE 42P01)") + err := NewSchemaNotInitializedError(underlyingErr) + + t.Run("error message contains migration instructions", func(t *testing.T) { + require.Contains(t, err.Error(), "spicedb datastore migrate") + require.Contains(t, err.Error(), "database schema has not been initialized") + }) + + t.Run("unwrap returns underlying error", func(t *testing.T) { + var schemaErr 
SchemaNotInitializedError + require.ErrorAs(t, err, &schemaErr) + require.ErrorIs(t, schemaErr.Unwrap(), underlyingErr) + }) + + t.Run("grpc status is FailedPrecondition", func(t *testing.T) { + var schemaErr SchemaNotInitializedError + require.ErrorAs(t, err, &schemaErr) + status := schemaErr.GRPCStatus() + require.Equal(t, codes.FailedPrecondition, status.Code()) + }) + + t.Run("can be detected with errors.As", func(t *testing.T) { + var schemaErr SchemaNotInitializedError + require.ErrorAs(t, err, &schemaErr) + }) + + t.Run("wrapped error preserves chain", func(t *testing.T) { + wrappedErr := fmt.Errorf("outer: %w", err) + var schemaErr SchemaNotInitializedError + require.ErrorAs(t, wrappedErr, &schemaErr) + }) +} diff --git a/internal/datastore/crdb/caveat.go b/internal/datastore/crdb/caveat.go index 5857a5702..03f46783b 100644 --- a/internal/datastore/crdb/caveat.go +++ b/internal/datastore/crdb/caveat.go @@ -10,6 +10,7 @@ import ( "github.com/jackc/pgx/v5" "github.com/authzed/spicedb/internal/datastore/crdb/schema" + pgxcommon "github.com/authzed/spicedb/internal/datastore/postgres/common" "github.com/authzed/spicedb/internal/datastore/revisions" "github.com/authzed/spicedb/pkg/datastore" core "github.com/authzed/spicedb/pkg/proto/core/v1" @@ -53,6 +54,9 @@ func (cr *crdbReader) LegacyReadCaveatByName(ctx context.Context, name string) ( if errors.Is(err, pgx.ErrNoRows) { err = datastore.NewCaveatNameNotFoundErr(name) } + if wrappedErr := pgxcommon.WrapMissingTableError(err); wrappedErr != nil { + return nil, datastore.NoRevision, wrappedErr + } return nil, datastore.NoRevision, fmt.Errorf(errReadCaveat, name, err) } @@ -109,6 +113,9 @@ func (cr *crdbReader) lookupCaveats(ctx context.Context, caveatNames []string) ( return nil }, sql, args...) 
if err != nil { + if wrappedErr := pgxcommon.WrapMissingTableError(err); wrappedErr != nil { + return nil, wrappedErr + } return nil, fmt.Errorf(errListCaveats, err) } diff --git a/internal/datastore/crdb/reader.go b/internal/datastore/crdb/reader.go index 5bf2c9b1b..f6cd3f33e 100644 --- a/internal/datastore/crdb/reader.go +++ b/internal/datastore/crdb/reader.go @@ -129,6 +129,9 @@ func (cr *crdbReader) CountRelationships(ctx context.Context, name string) (int, return row.Scan(&count) }, sql, args...) if err != nil { + if wrappedErr := pgxcommon.WrapMissingTableError(err); wrappedErr != nil { + return 0, wrappedErr + } return 0, err } @@ -193,6 +196,9 @@ func (cr *crdbReader) lookupCounters(ctx context.Context, optionalFilterName str return nil }, sql, args...) if err != nil { + if wrappedErr := pgxcommon.WrapMissingTableError(err); wrappedErr != nil { + return nil, wrappedErr + } return nil, err } @@ -208,6 +214,9 @@ func (cr *crdbReader) LegacyReadNamespaceByName( if errors.As(err, &datastore.NamespaceNotFoundError{}) { return nil, datastore.NoRevision, err } + if wrappedErr := pgxcommon.WrapMissingTableError(err); wrappedErr != nil { + return nil, datastore.NoRevision, wrappedErr + } return nil, datastore.NoRevision, fmt.Errorf(errUnableToReadConfig, err) } @@ -221,6 +230,9 @@ func (cr *crdbReader) LegacyListAllNamespaces(ctx context.Context) ([]datastore. 
nsDefs, sql, err := loadAllNamespaces(ctx, cr.query, addFromToQuery) if err != nil { + if wrappedErr := pgxcommon.WrapMissingTableError(err); wrappedErr != nil { + return nil, wrappedErr + } return nil, fmt.Errorf(errUnableToListNamespaces, err) } cr.assertHasExpectedAsOfSystemTime(sql) @@ -233,6 +245,9 @@ func (cr *crdbReader) LegacyLookupNamespacesWithNames(ctx context.Context, nsNam } nsDefs, err := cr.lookupNamespaces(ctx, cr.query, nsNames) if err != nil { + if wrappedErr := pgxcommon.WrapMissingTableError(err); wrappedErr != nil { + return nil, wrappedErr + } return nil, fmt.Errorf(errUnableToListNamespaces, err) } return nsDefs, nil diff --git a/internal/datastore/crdb/stats.go b/internal/datastore/crdb/stats.go index 6ec65c00a..7e9954294 100644 --- a/internal/datastore/crdb/stats.go +++ b/internal/datastore/crdb/stats.go @@ -33,6 +33,9 @@ func (cds *crdbDatastore) UniqueID(ctx context.Context) (string, error) { if err := cds.readPool.QueryRowFunc(ctx, func(ctx context.Context, row pgx.Row) error { return row.Scan(&uniqueID) }, sql, args...); err != nil { + if wrappedErr := pgxcommon.WrapMissingTableError(err); wrappedErr != nil { + return "", wrappedErr + } return "", fmt.Errorf("unable to query unique ID: %w", err) } @@ -59,6 +62,9 @@ func (cds *crdbDatastore) Statistics(ctx context.Context) (datastore.Stats, erro return sb.From(tableName) }) if err != nil { + if wrappedErr := pgxcommon.WrapMissingTableError(err); wrappedErr != nil { + return wrappedErr + } return fmt.Errorf("unable to read namespaces: %w", err) } return nil @@ -69,6 +75,9 @@ func (cds *crdbDatastore) Statistics(ctx context.Context) (datastore.Stats, erro if cds.analyzeBeforeStatistics { if err := cds.readPool.BeginTxFunc(ctx, pgx.TxOptions{AccessMode: pgx.ReadOnly}, func(tx pgx.Tx) error { if _, err := tx.Exec(ctx, "ANALYZE "+cds.schema.RelationshipTableName); err != nil { + if wrappedErr := pgxcommon.WrapMissingTableError(err); wrappedErr != nil { + return wrappedErr + } return 
fmt.Errorf("unable to analyze tuple table: %w", err) } @@ -143,6 +152,9 @@ func (cds *crdbDatastore) Statistics(ctx context.Context) (datastore.Stats, erro log.Warn().Bool("has-rows", hasRows).Msg("unable to find row count in statistics query result") return nil }, "SHOW STATISTICS FOR TABLE "+cds.schema.RelationshipTableName); err != nil { + if wrappedErr := pgxcommon.WrapMissingTableError(err); wrappedErr != nil { + return datastore.Stats{}, wrappedErr + } return datastore.Stats{}, fmt.Errorf("unable to query unique estimated row count: %w", err) } diff --git a/internal/datastore/mysql/caveat.go b/internal/datastore/mysql/caveat.go index bb48ad509..f6d2d5cf4 100644 --- a/internal/datastore/mysql/caveat.go +++ b/internal/datastore/mysql/caveat.go @@ -9,6 +9,7 @@ import ( sq "github.com/Masterminds/squirrel" "github.com/authzed/spicedb/internal/datastore/common" + mysqlcommon "github.com/authzed/spicedb/internal/datastore/mysql/common" "github.com/authzed/spicedb/internal/datastore/revisions" "github.com/authzed/spicedb/pkg/datastore" core "github.com/authzed/spicedb/pkg/proto/core/v1" @@ -41,6 +42,9 @@ func (mr *mysqlReader) LegacyReadCaveatByName(ctx context.Context, name string) if errors.Is(err, sql.ErrNoRows) { return nil, datastore.NoRevision, datastore.NewCaveatNameNotFoundErr(name) } + if wrappedErr := mysqlcommon.WrapMissingTableError(err); wrappedErr != nil { + return nil, datastore.NoRevision, wrappedErr + } return nil, datastore.NoRevision, fmt.Errorf(errReadCaveat, err) } def := core.CaveatDefinition{} @@ -82,6 +86,9 @@ func (mr *mysqlReader) lookupCaveats(ctx context.Context, caveatNames []string) rows, err := tx.QueryContext(ctx, listSQL, listArgs...) 
if err != nil { + if wrappedErr := mysqlcommon.WrapMissingTableError(err); wrappedErr != nil { + return nil, wrappedErr + } return nil, fmt.Errorf(errListCaveats, err) } defer common.LogOnError(ctx, rows.Close) diff --git a/internal/datastore/mysql/common/errors.go b/internal/datastore/mysql/common/errors.go new file mode 100644 index 000000000..7eaf73393 --- /dev/null +++ b/internal/datastore/mysql/common/errors.go @@ -0,0 +1,38 @@ +package common + +import ( + "errors" + + "github.com/go-sql-driver/mysql" + + dscommon "github.com/authzed/spicedb/internal/datastore/common" +) + +const ( + // mysqlMissingTableErrorNumber is the MySQL error number for "table doesn't exist". + // This corresponds to MySQL error 1146 (ER_NO_SUCH_TABLE) with SQLSTATE 42S02. + mysqlMissingTableErrorNumber = 1146 +) + +// IsMissingTableError returns true if the error is a MySQL error indicating a missing table. +// This typically happens when migrations have not been run. +func IsMissingTableError(err error) bool { + var mysqlErr *mysql.MySQLError + return errors.As(err, &mysqlErr) && mysqlErr.Number == mysqlMissingTableErrorNumber +} + +// WrapMissingTableError checks if the error is a missing table error and wraps it with +// a helpful message instructing the user to run migrations. If it's not a missing table error, +// it returns nil. If it's already a SchemaNotInitializedError, it returns the original error +// to preserve the wrapped error through the call chain. 
+func WrapMissingTableError(err error) error { + // Don't double-wrap if already a SchemaNotInitializedError - return original to preserve it + var schemaErr dscommon.SchemaNotInitializedError + if errors.As(err, &schemaErr) { + return err + } + if IsMissingTableError(err) { + return dscommon.NewSchemaNotInitializedError(err) + } + return nil +} diff --git a/internal/datastore/mysql/common/errors_test.go b/internal/datastore/mysql/common/errors_test.go new file mode 100644 index 000000000..6f2498491 --- /dev/null +++ b/internal/datastore/mysql/common/errors_test.go @@ -0,0 +1,112 @@ +package common + +import ( + "fmt" + "testing" + + "github.com/go-sql-driver/mysql" + "github.com/stretchr/testify/require" + + dscommon "github.com/authzed/spicedb/internal/datastore/common" +) + +func TestIsMissingTableError(t *testing.T) { + t.Run("returns true for missing table error", func(t *testing.T) { + mysqlErr := &mysql.MySQLError{ + Number: mysqlMissingTableErrorNumber, + Message: "Table 'spicedb.caveat' doesn't exist", + } + require.True(t, IsMissingTableError(mysqlErr)) + }) + + t.Run("returns false for other mysql errors", func(t *testing.T) { + mysqlErr := &mysql.MySQLError{ + Number: 1062, // Duplicate entry error + Message: "Duplicate entry '1' for key 'PRIMARY'", + } + require.False(t, IsMissingTableError(mysqlErr)) + }) + + t.Run("returns false for non-mysql errors", func(t *testing.T) { + err := fmt.Errorf("some other error") + require.False(t, IsMissingTableError(err)) + }) + + t.Run("returns false for nil error", func(t *testing.T) { + require.False(t, IsMissingTableError(nil)) + }) + + t.Run("returns true for wrapped missing table error", func(t *testing.T) { + mysqlErr := &mysql.MySQLError{ + Number: mysqlMissingTableErrorNumber, + Message: "Table 'spicedb.caveat' doesn't exist", + } + wrappedErr := fmt.Errorf("query failed: %w", mysqlErr) + require.True(t, IsMissingTableError(wrappedErr)) + }) +} + +func TestWrapMissingTableError(t *testing.T) { + t.Run("wraps 
missing table error", func(t *testing.T) { + mysqlErr := &mysql.MySQLError{ + Number: mysqlMissingTableErrorNumber, + Message: "Table 'spicedb.caveat' doesn't exist", + } + wrapped := WrapMissingTableError(mysqlErr) + require.Error(t, wrapped) + + var schemaErr dscommon.SchemaNotInitializedError + require.ErrorAs(t, wrapped, &schemaErr) + require.Contains(t, wrapped.Error(), "spicedb datastore migrate") + }) + + t.Run("returns nil for non-missing-table errors", func(t *testing.T) { + mysqlErr := &mysql.MySQLError{ + Number: 1062, // Duplicate entry error + Message: "Duplicate entry '1' for key 'PRIMARY'", + } + require.NoError(t, WrapMissingTableError(mysqlErr)) + }) + + t.Run("returns nil for non-mysql errors", func(t *testing.T) { + err := fmt.Errorf("some other error") + require.NoError(t, WrapMissingTableError(err)) + }) + + t.Run("returns nil for nil error", func(t *testing.T) { + require.NoError(t, WrapMissingTableError(nil)) + }) + + t.Run("preserves original error in chain", func(t *testing.T) { + mysqlErr := &mysql.MySQLError{ + Number: mysqlMissingTableErrorNumber, + Message: "Table 'spicedb.caveat' doesn't exist", + } + wrapped := WrapMissingTableError(mysqlErr) + require.Error(t, wrapped) + + // The original mysql error should be accessible via unwrapping + var foundMySQLErr *mysql.MySQLError + require.ErrorAs(t, wrapped, &foundMySQLErr) + require.Equal(t, uint16(mysqlMissingTableErrorNumber), foundMySQLErr.Number) + }) + + t.Run("does not double-wrap already wrapped errors", func(t *testing.T) { + mysqlErr := &mysql.MySQLError{ + Number: mysqlMissingTableErrorNumber, + Message: "Table 'spicedb.caveat' doesn't exist", + } + // First wrap + wrapped := WrapMissingTableError(mysqlErr) + require.Error(t, wrapped) + + // Second wrap should return the already-wrapped error (preserving it through call chain) + doubleWrapped := WrapMissingTableError(wrapped) + require.Error(t, doubleWrapped) + require.Equal(t, wrapped, doubleWrapped) + + // Should still be 
detectable as SchemaNotInitializedError + var schemaErr dscommon.SchemaNotInitializedError + require.ErrorAs(t, doubleWrapped, &schemaErr) + }) +} diff --git a/internal/datastore/mysql/reader.go b/internal/datastore/mysql/reader.go index 33e056d9d..0c0bf7c45 100644 --- a/internal/datastore/mysql/reader.go +++ b/internal/datastore/mysql/reader.go @@ -9,6 +9,7 @@ import ( sq "github.com/Masterminds/squirrel" "github.com/authzed/spicedb/internal/datastore/common" + mysqlcommon "github.com/authzed/spicedb/internal/datastore/mysql/common" "github.com/authzed/spicedb/internal/datastore/revisions" schemautil "github.com/authzed/spicedb/internal/datastore/schema" "github.com/authzed/spicedb/pkg/datastore" @@ -76,6 +77,9 @@ func (mr *mysqlReader) CountRelationships(ctx context.Context, name string) (int var count int rows, err := tx.QueryContext(ctx, sql, args...) if err != nil { + if wrappedErr := mysqlcommon.WrapMissingTableError(err); wrappedErr != nil { + return 0, wrappedErr + } return 0, err } defer common.LogOnError(ctx, rows.Close) @@ -123,6 +127,9 @@ func (mr *mysqlReader) lookupCounters(ctx context.Context, optionalName string) rows, err := tx.QueryContext(ctx, sql, args...) 
if err != nil { + if wrappedErr := mysqlcommon.WrapMissingTableError(err); wrappedErr != nil { + return nil, wrappedErr + } return nil, err } defer common.LogOnError(ctx, rows.Close) @@ -223,6 +230,9 @@ func (mr *mysqlReader) LegacyReadNamespaceByName(ctx context.Context, nsName str case err == nil: return loaded, version, nil default: + if wrappedErr := mysqlcommon.WrapMissingTableError(err); wrappedErr != nil { + return nil, datastore.NoRevision, wrappedErr + } return nil, datastore.NoRevision, fmt.Errorf(errUnableToReadConfig, err) } } @@ -265,6 +275,9 @@ func (mr *mysqlReader) LegacyListAllNamespaces(ctx context.Context) ([]datastore nsDefs, err := loadAllNamespaces(ctx, tx, query) if err != nil { + if wrappedErr := mysqlcommon.WrapMissingTableError(err); wrappedErr != nil { + return nil, wrappedErr + } return nil, fmt.Errorf(errUnableToListNamespaces, err) } @@ -291,6 +304,9 @@ func (mr *mysqlReader) LegacyLookupNamespacesWithNames(ctx context.Context, nsNa nsDefs, err := loadAllNamespaces(ctx, tx, query) if err != nil { + if wrappedErr := mysqlcommon.WrapMissingTableError(err); wrappedErr != nil { + return nil, wrappedErr + } return nil, fmt.Errorf(errUnableToListNamespaces, err) } diff --git a/internal/datastore/mysql/stats.go b/internal/datastore/mysql/stats.go index dfbdc8618..a16ae51c4 100644 --- a/internal/datastore/mysql/stats.go +++ b/internal/datastore/mysql/stats.go @@ -9,6 +9,7 @@ import ( "github.com/ccoveille/go-safecast/v2" "github.com/authzed/spicedb/internal/datastore/common" + mysqlcommon "github.com/authzed/spicedb/internal/datastore/mysql/common" "github.com/authzed/spicedb/pkg/datastore" "github.com/authzed/spicedb/pkg/spiceerrors" ) @@ -26,6 +27,9 @@ func (mds *mysqlDatastore) Statistics(ctx context.Context) (datastore.Stats, err if mds.analyzeBeforeStats { _, err := mds.db.ExecContext(ctx, "ANALYZE TABLE "+mds.driver.RelationTuple()) if err != nil { + if wrappedErr := mysqlcommon.WrapMissingTableError(err); wrappedErr != nil { + return 
datastore.Stats{}, wrappedErr + } return datastore.Stats{}, fmt.Errorf("unable to run ANALYZE TABLE: %w", err) } } @@ -47,6 +51,9 @@ func (mds *mysqlDatastore) Statistics(ctx context.Context) (datastore.Stats, err var count sql.NullInt64 err = mds.db.QueryRowContext(ctx, query, args...).Scan(&count) if err != nil { + if wrappedErr := mysqlcommon.WrapMissingTableError(err); wrappedErr != nil { + return datastore.Stats{}, wrappedErr + } return datastore.Stats{}, err } @@ -59,6 +66,9 @@ func (mds *mysqlDatastore) Statistics(ctx context.Context) (datastore.Stats, err } err = mds.db.QueryRowContext(ctx, query, args...).Scan(&count) if err != nil { + if wrappedErr := mysqlcommon.WrapMissingTableError(err); wrappedErr != nil { + return datastore.Stats{}, wrappedErr + } return datastore.Stats{}, err } } @@ -73,6 +83,9 @@ func (mds *mysqlDatastore) Statistics(ctx context.Context) (datastore.Stats, err nsDefs, err := loadAllNamespaces(ctx, tx, nsQuery) if err != nil { + if wrappedErr := mysqlcommon.WrapMissingTableError(err); wrappedErr != nil { + return datastore.Stats{}, wrappedErr + } return datastore.Stats{}, fmt.Errorf("unable to load namespaces: %w", err) } @@ -97,6 +110,9 @@ func (mds *mysqlDatastore) UniqueID(ctx context.Context) (string, error) { var uniqueID string if err := mds.db.QueryRowContext(ctx, sql, args...).Scan(&uniqueID); err != nil { + if wrappedErr := mysqlcommon.WrapMissingTableError(err); wrappedErr != nil { + return "", wrappedErr + } return "", fmt.Errorf("unable to query unique ID: %w", err) } mds.uniqueID.Store(&uniqueID) diff --git a/internal/datastore/postgres/caveat.go b/internal/datastore/postgres/caveat.go index 535632096..483a874ea 100644 --- a/internal/datastore/postgres/caveat.go +++ b/internal/datastore/postgres/caveat.go @@ -8,6 +8,7 @@ import ( sq "github.com/Masterminds/squirrel" "github.com/jackc/pgx/v5" + pgxcommon "github.com/authzed/spicedb/internal/datastore/postgres/common" 
"github.com/authzed/spicedb/internal/datastore/postgres/schema" "github.com/authzed/spicedb/pkg/datastore" "github.com/authzed/spicedb/pkg/genutil/mapz" @@ -49,6 +50,9 @@ func (r *pgReader) LegacyReadCaveatByName(ctx context.Context, name string) (*co if errors.Is(err, pgx.ErrNoRows) { return nil, datastore.NoRevision, datastore.NewCaveatNameNotFoundErr(name) } + if wrappedErr := pgxcommon.WrapMissingTableError(err); wrappedErr != nil { + return nil, datastore.NoRevision, wrappedErr + } return nil, datastore.NoRevision, fmt.Errorf(errReadCaveat, err) } def := core.CaveatDefinition{} @@ -106,6 +110,9 @@ func (r *pgReader) lookupCaveats(ctx context.Context, caveatNames []string) ([]d return rows.Err() }, sql, args...) if err != nil { + if wrappedErr := pgxcommon.WrapMissingTableError(err); wrappedErr != nil { + return nil, wrappedErr + } return nil, fmt.Errorf(errListCaveats, err) } diff --git a/internal/datastore/postgres/common/errors.go b/internal/datastore/postgres/common/errors.go index 2728252aa..94fbb2e78 100644 --- a/internal/datastore/postgres/common/errors.go +++ b/internal/datastore/postgres/common/errors.go @@ -19,6 +19,7 @@ const ( pgReadOnlyTransaction = "25006" pgQueryCanceled = "57014" pgInvalidArgument = "22023" + pgMissingTable = "42P01" ) var ( @@ -106,3 +107,26 @@ func ConvertToWriteConstraintError(livingTupleConstraints []string, err error) e return nil } + +// IsMissingTableError returns true if the error is a Postgres error indicating a missing table. +// This typically happens when migrations have not been run. +func IsMissingTableError(err error) bool { + var pgerr *pgconn.PgError + return errors.As(err, &pgerr) && pgerr.Code == pgMissingTable +} + +// WrapMissingTableError checks if the error is a missing table error and wraps it with +// a helpful message instructing the user to run migrations. If it's not a missing table error, +// it returns nil. 
If it's already a SchemaNotInitializedError, it returns the original error +// to preserve the wrapped error through the call chain. +func WrapMissingTableError(err error) error { + // Don't double-wrap if already a SchemaNotInitializedError - return original to preserve it + var schemaErr dscommon.SchemaNotInitializedError + if errors.As(err, &schemaErr) { + return err + } + if IsMissingTableError(err) { + return dscommon.NewSchemaNotInitializedError(err) + } + return nil +} diff --git a/internal/datastore/postgres/common/errors_test.go b/internal/datastore/postgres/common/errors_test.go new file mode 100644 index 000000000..27ec49a88 --- /dev/null +++ b/internal/datastore/postgres/common/errors_test.go @@ -0,0 +1,112 @@ +package common + +import ( + "fmt" + "testing" + + "github.com/jackc/pgx/v5/pgconn" + "github.com/stretchr/testify/require" + + dscommon "github.com/authzed/spicedb/internal/datastore/common" +) + +func TestIsMissingTableError(t *testing.T) { + t.Run("returns true for missing table error", func(t *testing.T) { + pgErr := &pgconn.PgError{ + Code: pgMissingTable, + Message: "relation \"caveat\" does not exist", + } + require.True(t, IsMissingTableError(pgErr)) + }) + + t.Run("returns false for other postgres errors", func(t *testing.T) { + pgErr := &pgconn.PgError{ + Code: pgSerializationFailure, + Message: "could not serialize access", + } + require.False(t, IsMissingTableError(pgErr)) + }) + + t.Run("returns false for non-postgres errors", func(t *testing.T) { + err := fmt.Errorf("some other error") + require.False(t, IsMissingTableError(err)) + }) + + t.Run("returns false for nil error", func(t *testing.T) { + require.False(t, IsMissingTableError(nil)) + }) + + t.Run("returns true for wrapped missing table error", func(t *testing.T) { + pgErr := &pgconn.PgError{ + Code: pgMissingTable, + Message: "relation \"caveat\" does not exist", + } + wrappedErr := fmt.Errorf("query failed: %w", pgErr) + require.True(t, IsMissingTableError(wrappedErr)) + }) 
+} + +func TestWrapMissingTableError(t *testing.T) { + t.Run("wraps missing table error", func(t *testing.T) { + pgErr := &pgconn.PgError{ + Code: pgMissingTable, + Message: "relation \"caveat\" does not exist", + } + wrapped := WrapMissingTableError(pgErr) + require.Error(t, wrapped) + + var schemaErr dscommon.SchemaNotInitializedError + require.ErrorAs(t, wrapped, &schemaErr) + require.Contains(t, wrapped.Error(), "spicedb datastore migrate") + }) + + t.Run("returns nil for non-missing-table errors", func(t *testing.T) { + pgErr := &pgconn.PgError{ + Code: pgSerializationFailure, + Message: "could not serialize access", + } + require.NoError(t, WrapMissingTableError(pgErr)) + }) + + t.Run("returns nil for non-postgres errors", func(t *testing.T) { + err := fmt.Errorf("some other error") + require.NoError(t, WrapMissingTableError(err)) + }) + + t.Run("returns nil for nil error", func(t *testing.T) { + require.NoError(t, WrapMissingTableError(nil)) + }) + + t.Run("preserves original error in chain", func(t *testing.T) { + pgErr := &pgconn.PgError{ + Code: pgMissingTable, + Message: "relation \"caveat\" does not exist", + } + wrapped := WrapMissingTableError(pgErr) + require.Error(t, wrapped) + + // The original postgres error should be accessible via unwrapping + var foundPgErr *pgconn.PgError + require.ErrorAs(t, wrapped, &foundPgErr) + require.Equal(t, pgMissingTable, foundPgErr.Code) + }) + + t.Run("does not double-wrap already wrapped errors", func(t *testing.T) { + pgErr := &pgconn.PgError{ + Code: pgMissingTable, + Message: "relation \"caveat\" does not exist", + } + // First wrap + wrapped := WrapMissingTableError(pgErr) + require.Error(t, wrapped) + + // Second wrap should return the already-wrapped error (preserving it through call chain) + doubleWrapped := WrapMissingTableError(wrapped) + require.Error(t, doubleWrapped) + require.Equal(t, wrapped, doubleWrapped) + + // Should still be detectable as SchemaNotInitializedError + var schemaErr 
dscommon.SchemaNotInitializedError + require.ErrorAs(t, doubleWrapped, &schemaErr) + }) +} diff --git a/internal/datastore/postgres/reader.go b/internal/datastore/postgres/reader.go index ce40b06c8..4f75017d3 100644 --- a/internal/datastore/postgres/reader.go +++ b/internal/datastore/postgres/reader.go @@ -85,6 +85,9 @@ func (r *pgReader) CountRelationships(ctx context.Context, name string) (int, er return rows.Err() }, sql, args...) if err != nil { + if wrappedErr := pgxcommon.WrapMissingTableError(err); wrappedErr != nil { + return 0, wrappedErr + } return 0, err } @@ -140,6 +143,9 @@ func (r *pgReader) lookupCounters(ctx context.Context, optionalName string) ([]d return rows.Err() }, sql, args...) if err != nil { + if wrappedErr := pgxcommon.WrapMissingTableError(err); wrappedErr != nil { + return nil, wrappedErr + } return nil, fmt.Errorf("unable to query counters: %w", err) } @@ -208,6 +214,9 @@ func (r *pgReader) LegacyReadNamespaceByName(ctx context.Context, nsName string) case err == nil: return loaded, version, nil default: + if wrappedErr := pgxcommon.WrapMissingTableError(err); wrappedErr != nil { + return nil, datastore.NoRevision, wrappedErr + } return nil, datastore.NoRevision, fmt.Errorf(errUnableToReadConfig, err) } } @@ -233,6 +242,9 @@ func (r *pgReader) loadNamespace(ctx context.Context, namespace string, tx pgxco func (r *pgReader) LegacyListAllNamespaces(ctx context.Context) ([]datastore.RevisionedNamespace, error) { nsDefsWithRevisions, err := loadAllNamespaces(ctx, r.query, r.aliveFilter) if err != nil { + if wrappedErr := pgxcommon.WrapMissingTableError(err); wrappedErr != nil { + return nil, wrappedErr + } return nil, fmt.Errorf(errUnableToListNamespaces, err) } @@ -253,6 +265,9 @@ func (r *pgReader) LegacyLookupNamespacesWithNames(ctx context.Context, nsNames return r.aliveFilter(original).Where(clause) }) if err != nil { + if wrappedErr := pgxcommon.WrapMissingTableError(err); wrappedErr != nil { + return nil, wrappedErr + } return nil, 
fmt.Errorf(errUnableToListNamespaces, err) } diff --git a/internal/datastore/postgres/readwrite.go b/internal/datastore/postgres/readwrite.go index e327e2d3f..2ea09eb6c 100644 --- a/internal/datastore/postgres/readwrite.go +++ b/internal/datastore/postgres/readwrite.go @@ -419,6 +419,10 @@ func (rwt *pgReadWriteTXN) WriteRelationships(ctx context.Context, mutations []t } func handleWriteError(err error) error { + if wrappedErr := pgxcommon.WrapMissingTableError(err); wrappedErr != nil { + return wrappedErr + } + if pgxcommon.IsSerializationError(err) { return common.NewSerializationError(fmt.Errorf("unable to write relationships due to a serialization error: [%w]; this typically indicates that a number of write transactions are contending over the same relationships; either reduce the contention or scale this Postgres instance", err)) } diff --git a/internal/datastore/postgres/revisions.go b/internal/datastore/postgres/revisions.go index 18a41f644..5cfdef5e1 100644 --- a/internal/datastore/postgres/revisions.go +++ b/internal/datastore/postgres/revisions.go @@ -312,6 +312,9 @@ func createNewTransaction(ctx context.Context, tx pgx.Tx, metadata map[string]an cterr := tx.QueryRow(ctx, sql, args...).Scan(&newXID, &newSnapshot, &timestamp) if cterr != nil { + if wrappedErr := common.WrapMissingTableError(cterr); wrappedErr != nil { + return newXID, newSnapshot, timestamp, wrappedErr + } err = fmt.Errorf("error when trying to create a new transaction: %w", cterr) } return newXID, newSnapshot, timestamp, err diff --git a/internal/datastore/postgres/stats.go b/internal/datastore/postgres/stats.go index dc9e2e639..2f059de7d 100644 --- a/internal/datastore/postgres/stats.go +++ b/internal/datastore/postgres/stats.go @@ -40,6 +40,9 @@ func (pgd *pgDatastore) UniqueID(ctx context.Context) (string, error) { if err := pgx.BeginTxFunc(ctx, pgd.readPool, pgd.readTxOptions, func(tx pgx.Tx) error { return tx.QueryRow(ctx, idSQL, idArgs...).Scan(&uniqueID) }); err != nil { + if wrappedErr
:= pgxcommon.WrapMissingTableError(err); wrappedErr != nil { + return "", wrappedErr + } return "", fmt.Errorf("unable to query unique ID: %w", err) } @@ -71,22 +74,34 @@ func (pgd *pgDatastore) Statistics(ctx context.Context) (datastore.Stats, error) if err := pgx.BeginTxFunc(ctx, pgd.readPool, pgd.readTxOptions, func(tx pgx.Tx) error { if pgd.analyzeBeforeStatistics { if _, err := tx.Exec(ctx, "ANALYZE "+schema.TableTuple); err != nil { + if wrappedErr := pgxcommon.WrapMissingTableError(err); wrappedErr != nil { + return wrappedErr + } return fmt.Errorf("unable to analyze tuple table: %w", err) } } if err := tx.QueryRow(ctx, idSQL, idArgs...).Scan(&uniqueID); err != nil { + if wrappedErr := pgxcommon.WrapMissingTableError(err); wrappedErr != nil { + return wrappedErr + } return fmt.Errorf("unable to query unique ID: %w", err) } nsDefsWithRevisions, err := loadAllNamespaces(ctx, pgxcommon.QuerierFuncsFor(tx), aliveFilter) if err != nil { + if wrappedErr := pgxcommon.WrapMissingTableError(err); wrappedErr != nil { + return wrappedErr + } return fmt.Errorf("unable to load namespaces: %w", err) } nsDefs = nsDefsWithRevisions if err := tx.QueryRow(ctx, rowCountSQL, rowCountArgs...).Scan(&relCount); err != nil { + if wrappedErr := pgxcommon.WrapMissingTableError(err); wrappedErr != nil { + return wrappedErr + } return fmt.Errorf("unable to read relationship count: %w", err) } From 54d14ab8c9543480a0ccf6ed83f49fc68c659503 Mon Sep 17 00:00:00 2001 From: ivanauth Date: Mon, 5 Jan 2026 12:07:03 -0500 Subject: [PATCH 2/7] fix: address PR #2775 review feedback - Add TODO comment for future ERROR_REASON_DATASTORE_NOT_MIGRATED - Simplify error message format to: "%w. 
please run migrate" - Use IsMissingTableError + reassignment pattern (not early return) - Export PgMissingTable constant, remove duplicates from migration drivers - Add Spanner IsMissingTableError support in reader.go and caveat.go - Differentiate CRDB watch service error message for missing tables - Add ImportBulk error handling in bulk.go - Update tests for new error message format --- internal/datastore/common/errors.go | 4 +++- internal/datastore/common/errors_test.go | 14 ++++++++++++- internal/datastore/crdb/caveat.go | 9 ++++---- internal/datastore/crdb/crdb.go | 6 +++++- internal/datastore/crdb/migrations/driver.go | 4 +--- internal/datastore/postgres/common/bulk.go | 9 +++++++- internal/datastore/postgres/common/errors.go | 7 +++++-- .../datastore/postgres/common/errors_test.go | 12 +++++------ .../datastore/postgres/migrations/driver.go | 4 +--- internal/datastore/postgres/reader.go | 20 +++++++++--------- internal/datastore/postgres/revisions.go | 4 ++-- internal/datastore/spanner/caveat.go | 6 ++++++ internal/datastore/spanner/errors.go | 21 +++++++++++++++++++ internal/datastore/spanner/reader.go | 6 ++++++ 14 files changed, 92 insertions(+), 34 deletions(-) create mode 100644 internal/datastore/spanner/errors.go diff --git a/internal/datastore/common/errors.go b/internal/datastore/common/errors.go index 072bbf48c..8a532dec4 100644 --- a/internal/datastore/common/errors.go +++ b/internal/datastore/common/errors.go @@ -184,6 +184,8 @@ type SchemaNotInitializedError struct { } func (err SchemaNotInitializedError) GRPCStatus() *status.Status { + // TODO: Create ERROR_REASON_DATASTORE_NOT_MIGRATED in authzed/api and use it here + // See: https://github.com/authzed/spicedb/pull/2775 return spiceerrors.WithCodeAndDetails( err, codes.FailedPrecondition, @@ -202,6 +204,6 @@ func (err SchemaNotInitializedError) Unwrap() error { // instructing the user to run migrations. 
func NewSchemaNotInitializedError(underlying error) error { return SchemaNotInitializedError{ - fmt.Errorf("datastore error: the database schema has not been initialized; please run \"spicedb datastore migrate\": %w", underlying), + fmt.Errorf("%w. please run \"spicedb datastore migrate\"", underlying), } } diff --git a/internal/datastore/common/errors_test.go b/internal/datastore/common/errors_test.go index 6d586add8..e53f44672 100644 --- a/internal/datastore/common/errors_test.go +++ b/internal/datastore/common/errors_test.go @@ -14,7 +14,8 @@ func TestSchemaNotInitializedError(t *testing.T) { t.Run("error message contains migration instructions", func(t *testing.T) { require.Contains(t, err.Error(), "spicedb datastore migrate") - require.Contains(t, err.Error(), "database schema has not been initialized") + // The error message now includes the underlying error first, followed by the instruction + require.Contains(t, err.Error(), "relation \"caveat\" does not exist") }) t.Run("unwrap returns underlying error", func(t *testing.T) { @@ -40,4 +41,15 @@ func TestSchemaNotInitializedError(t *testing.T) { var schemaErr SchemaNotInitializedError require.ErrorAs(t, wrappedErr, &schemaErr) }) + + t.Run("grpc status extractable from wrapped error", func(t *testing.T) { + // This tests the scenario where SchemaNotInitializedError is wrapped + // by another fmt.Errorf (e.g., in crdb/caveat.go). The gRPC library + // uses errors.As to extract GRPCStatus from wrapped errors. 
+ wrappedErr := fmt.Errorf("outer context: %w", err) + var schemaErr SchemaNotInitializedError + require.ErrorAs(t, wrappedErr, &schemaErr) + status := schemaErr.GRPCStatus() + require.Equal(t, codes.FailedPrecondition, status.Code()) + }) } diff --git a/internal/datastore/crdb/caveat.go b/internal/datastore/crdb/caveat.go index 03f46783b..d1a1b29de 100644 --- a/internal/datastore/crdb/caveat.go +++ b/internal/datastore/crdb/caveat.go @@ -9,6 +9,7 @@ import ( sq "github.com/Masterminds/squirrel" "github.com/jackc/pgx/v5" + dscommon "github.com/authzed/spicedb/internal/datastore/common" "github.com/authzed/spicedb/internal/datastore/crdb/schema" pgxcommon "github.com/authzed/spicedb/internal/datastore/postgres/common" "github.com/authzed/spicedb/internal/datastore/revisions" @@ -54,8 +55,8 @@ func (cr *crdbReader) LegacyReadCaveatByName(ctx context.Context, name string) ( if errors.Is(err, pgx.ErrNoRows) { err = datastore.NewCaveatNameNotFoundErr(name) } - if wrappedErr := pgxcommon.WrapMissingTableError(err); wrappedErr != nil { - return nil, datastore.NoRevision, wrappedErr + if pgxcommon.IsMissingTableError(err) { + err = dscommon.NewSchemaNotInitializedError(err) } return nil, datastore.NoRevision, fmt.Errorf(errReadCaveat, name, err) } @@ -113,8 +114,8 @@ func (cr *crdbReader) lookupCaveats(ctx context.Context, caveatNames []string) ( return nil }, sql, args...) 
if err != nil { - if wrappedErr := pgxcommon.WrapMissingTableError(err); wrappedErr != nil { - return nil, wrappedErr + if pgxcommon.IsMissingTableError(err) { + err = dscommon.NewSchemaNotInitializedError(err) } return nil, fmt.Errorf(errListCaveats, err) } diff --git a/internal/datastore/crdb/crdb.go b/internal/datastore/crdb/crdb.go index c47bd3e55..a2e7a2abb 100644 --- a/internal/datastore/crdb/crdb.go +++ b/internal/datastore/crdb/crdb.go @@ -594,7 +594,11 @@ func (cds *crdbDatastore) features(ctx context.Context) (*datastore.Features, er } features.Watch.Status = datastore.FeatureUnsupported - features.Watch.Reason = "Range feeds must be enabled in CockroachDB and the user must have permission to create them in order to enable the Watch API: " + err.Error() + if pgxcommon.IsMissingTableError(err) { + features.Watch.Reason = "Database schema has not been initialized. Please run \"spicedb datastore migrate\": " + err.Error() + } else { + features.Watch.Reason = "Range feeds must be enabled in CockroachDB and the user must have permission to create them in order to enable the Watch API: " + err.Error() + } return nil }, fmt.Sprintf(cds.beginChangefeedQuery, cds.schema.RelationshipTableName, head, "-1s")) } else { diff --git a/internal/datastore/crdb/migrations/driver.go b/internal/datastore/crdb/migrations/driver.go index 74dea5970..db3b72986 100644 --- a/internal/datastore/crdb/migrations/driver.go +++ b/internal/datastore/crdb/migrations/driver.go @@ -15,8 +15,6 @@ import ( const ( errUnableToInstantiate = "unable to instantiate CRDBDriver: %w" - postgresMissingTableErrorCode = "42P01" - queryLoadVersion = "SELECT version_num from schema_version" queryWriteVersion = "UPDATE schema_version SET version_num=$1 WHERE version_num=$2" ) @@ -52,7 +50,7 @@ func (apd *CRDBDriver) Version(ctx context.Context) (string, error) { if err := apd.db.QueryRow(ctx, queryLoadVersion).Scan(&loaded); err != nil { var pgErr *pgconn.PgError - if errors.As(err, &pgErr) && pgErr.Code 
== postgresMissingTableErrorCode { + if errors.As(err, &pgErr) && pgErr.Code == pgxcommon.PgMissingTable { return "", nil } return "", fmt.Errorf("unable to load alembic revision: %w", err) diff --git a/internal/datastore/postgres/common/bulk.go b/internal/datastore/postgres/common/bulk.go index 0fbbf7bb2..13573e606 100644 --- a/internal/datastore/postgres/common/bulk.go +++ b/internal/datastore/postgres/common/bulk.go @@ -6,6 +6,7 @@ import ( "github.com/ccoveille/go-safecast/v2" "github.com/jackc/pgx/v5" + dscommon "github.com/authzed/spicedb/internal/datastore/common" "github.com/authzed/spicedb/pkg/datastore" "github.com/authzed/spicedb/pkg/spiceerrors" "github.com/authzed/spicedb/pkg/tuple" @@ -77,9 +78,15 @@ func BulkLoad( colNames: colNames, } copied, err := tx.CopyFrom(ctx, pgx.Identifier{tupleTableName}, colNames, adapter) + if err != nil { + if IsMissingTableError(err) { + return 0, dscommon.NewSchemaNotInitializedError(err) + } + return 0, err + } uintCopied, castErr := safecast.Convert[uint64](copied) if castErr != nil { return 0, spiceerrors.MustBugf("number copied was negative: %v", castErr) } - return uintCopied, err + return uintCopied, nil } diff --git a/internal/datastore/postgres/common/errors.go b/internal/datastore/postgres/common/errors.go index 94fbb2e78..e7e191706 100644 --- a/internal/datastore/postgres/common/errors.go +++ b/internal/datastore/postgres/common/errors.go @@ -19,7 +19,10 @@ const ( pgReadOnlyTransaction = "25006" pgQueryCanceled = "57014" pgInvalidArgument = "22023" - pgMissingTable = "42P01" + + // PgMissingTable is the Postgres error code for "relation does not exist". + // This is used to detect when migrations have not been run. + PgMissingTable = "42P01" ) var ( @@ -112,7 +115,7 @@ func ConvertToWriteConstraintError(livingTupleConstraints []string, err error) e // This typically happens when migrations have not been run. 
func IsMissingTableError(err error) bool { var pgerr *pgconn.PgError - return errors.As(err, &pgerr) && pgerr.Code == pgMissingTable + return errors.As(err, &pgerr) && pgerr.Code == PgMissingTable } // WrapMissingTableError checks if the error is a missing table error and wraps it with diff --git a/internal/datastore/postgres/common/errors_test.go b/internal/datastore/postgres/common/errors_test.go index 27ec49a88..8410ef255 100644 --- a/internal/datastore/postgres/common/errors_test.go +++ b/internal/datastore/postgres/common/errors_test.go @@ -13,7 +13,7 @@ import ( func TestIsMissingTableError(t *testing.T) { t.Run("returns true for missing table error", func(t *testing.T) { pgErr := &pgconn.PgError{ - Code: pgMissingTable, + Code: PgMissingTable, Message: "relation \"caveat\" does not exist", } require.True(t, IsMissingTableError(pgErr)) @@ -38,7 +38,7 @@ func TestIsMissingTableError(t *testing.T) { t.Run("returns true for wrapped missing table error", func(t *testing.T) { pgErr := &pgconn.PgError{ - Code: pgMissingTable, + Code: PgMissingTable, Message: "relation \"caveat\" does not exist", } wrappedErr := fmt.Errorf("query failed: %w", pgErr) @@ -49,7 +49,7 @@ func TestIsMissingTableError(t *testing.T) { func TestWrapMissingTableError(t *testing.T) { t.Run("wraps missing table error", func(t *testing.T) { pgErr := &pgconn.PgError{ - Code: pgMissingTable, + Code: PgMissingTable, Message: "relation \"caveat\" does not exist", } wrapped := WrapMissingTableError(pgErr) @@ -79,7 +79,7 @@ func TestWrapMissingTableError(t *testing.T) { t.Run("preserves original error in chain", func(t *testing.T) { pgErr := &pgconn.PgError{ - Code: pgMissingTable, + Code: PgMissingTable, Message: "relation \"caveat\" does not exist", } wrapped := WrapMissingTableError(pgErr) @@ -88,12 +88,12 @@ func TestWrapMissingTableError(t *testing.T) { // The original postgres error should be accessible via unwrapping var foundPgErr *pgconn.PgError require.ErrorAs(t, wrapped, &foundPgErr) - 
require.Equal(t, pgMissingTable, foundPgErr.Code) + require.Equal(t, PgMissingTable, foundPgErr.Code) }) t.Run("does not double-wrap already wrapped errors", func(t *testing.T) { pgErr := &pgconn.PgError{ - Code: pgMissingTable, + Code: PgMissingTable, Message: "relation \"caveat\" does not exist", } // First wrap diff --git a/internal/datastore/postgres/migrations/driver.go b/internal/datastore/postgres/migrations/driver.go index 1054ecde4..522ae790e 100644 --- a/internal/datastore/postgres/migrations/driver.go +++ b/internal/datastore/postgres/migrations/driver.go @@ -15,8 +15,6 @@ import ( "github.com/authzed/spicedb/pkg/migrate" ) -const postgresMissingTableErrorCode = "42P01" - var tracer = otel.Tracer("spicedb/internal/datastore/common") // AlembicPostgresDriver implements a schema migration facility for use in @@ -74,7 +72,7 @@ func (apd *AlembicPostgresDriver) Version(ctx context.Context) (string, error) { if err := apd.db.QueryRow(ctx, "SELECT version_num from alembic_version").Scan(&loaded); err != nil { var pgErr *pgconn.PgError - if errors.As(err, &pgErr) && pgErr.Code == postgresMissingTableErrorCode { + if errors.As(err, &pgErr) && pgErr.Code == pgxcommon.PgMissingTable { return "", nil } return "", fmt.Errorf("unable to load alembic revision: %w", err) diff --git a/internal/datastore/postgres/reader.go b/internal/datastore/postgres/reader.go index 4f75017d3..90d4d8c8d 100644 --- a/internal/datastore/postgres/reader.go +++ b/internal/datastore/postgres/reader.go @@ -85,8 +85,8 @@ func (r *pgReader) CountRelationships(ctx context.Context, name string) (int, er return rows.Err() }, sql, args...) if err != nil { - if wrappedErr := pgxcommon.WrapMissingTableError(err); wrappedErr != nil { - return 0, wrappedErr + if pgxcommon.IsMissingTableError(err) { + err = common.NewSchemaNotInitializedError(err) } return 0, err } @@ -143,8 +143,8 @@ func (r *pgReader) lookupCounters(ctx context.Context, optionalName string) ([]d return rows.Err() }, sql, args...) 
if err != nil { - if wrappedErr := pgxcommon.WrapMissingTableError(err); wrappedErr != nil { - return nil, wrappedErr + if pgxcommon.IsMissingTableError(err) { + err = common.NewSchemaNotInitializedError(err) } return nil, fmt.Errorf("unable to query counters: %w", err) } @@ -214,8 +214,8 @@ func (r *pgReader) LegacyReadNamespaceByName(ctx context.Context, nsName string) case err == nil: return loaded, version, nil default: - if wrappedErr := pgxcommon.WrapMissingTableError(err); wrappedErr != nil { - return nil, datastore.NoRevision, wrappedErr + if pgxcommon.IsMissingTableError(err) { + err = common.NewSchemaNotInitializedError(err) } return nil, datastore.NoRevision, fmt.Errorf(errUnableToReadConfig, err) } @@ -242,8 +242,8 @@ func (r *pgReader) loadNamespace(ctx context.Context, namespace string, tx pgxco func (r *pgReader) LegacyListAllNamespaces(ctx context.Context) ([]datastore.RevisionedNamespace, error) { nsDefsWithRevisions, err := loadAllNamespaces(ctx, r.query, r.aliveFilter) if err != nil { - if wrappedErr := pgxcommon.WrapMissingTableError(err); wrappedErr != nil { - return nil, wrappedErr + if pgxcommon.IsMissingTableError(err) { + err = common.NewSchemaNotInitializedError(err) } return nil, fmt.Errorf(errUnableToListNamespaces, err) } @@ -265,8 +265,8 @@ func (r *pgReader) LegacyLookupNamespacesWithNames(ctx context.Context, nsNames return r.aliveFilter(original).Where(clause) }) if err != nil { - if wrappedErr := pgxcommon.WrapMissingTableError(err); wrappedErr != nil { - return nil, wrappedErr + if pgxcommon.IsMissingTableError(err) { + err = common.NewSchemaNotInitializedError(err) } return nil, fmt.Errorf(errUnableToListNamespaces, err) } diff --git a/internal/datastore/postgres/revisions.go b/internal/datastore/postgres/revisions.go index 5cfdef5e1..861a8f97d 100644 --- a/internal/datastore/postgres/revisions.go +++ b/internal/datastore/postgres/revisions.go @@ -312,8 +312,8 @@ func createNewTransaction(ctx context.Context, tx pgx.Tx, metadata 
map[string]an cterr := tx.QueryRow(ctx, sql, args...).Scan(&newXID, &newSnapshot, &timestamp) if cterr != nil { - if wrappedErr := common.WrapMissingTableError(cterr); wrappedErr != nil { - return newXID, newSnapshot, timestamp, wrappedErr + if common.IsMissingTableError(cterr) { + cterr = dscommon.NewSchemaNotInitializedError(cterr) } err = fmt.Errorf("error when trying to create a new transaction: %w", cterr) } diff --git a/internal/datastore/spanner/caveat.go b/internal/datastore/spanner/caveat.go index b810a88e2..54d716c1f 100644 --- a/internal/datastore/spanner/caveat.go +++ b/internal/datastore/spanner/caveat.go @@ -18,6 +18,9 @@ func (sr spannerReader) LegacyReadCaveatByName(ctx context.Context, name string) caveatKey := spanner.Key{name} row, err := sr.txSource().ReadRow(ctx, tableCaveat, caveatKey, []string{colCaveatDefinition, colCaveatTS}) if err != nil { + if IsMissingTableError(err) { + return nil, datastore.NoRevision, common.NewSchemaNotInitializedError(err) + } if spanner.ErrCode(err) == codes.NotFound { return nil, datastore.NoRevision, datastore.NewCaveatNameNotFoundErr(name) } @@ -83,6 +86,9 @@ func (sr spannerReader) listCaveats(ctx context.Context, caveatNames []string) ( return nil }); err != nil { + if IsMissingTableError(err) { + return nil, common.NewSchemaNotInitializedError(err) + } return nil, fmt.Errorf(errUnableToListCaveats, err) } diff --git a/internal/datastore/spanner/errors.go b/internal/datastore/spanner/errors.go new file mode 100644 index 000000000..a379bb993 --- /dev/null +++ b/internal/datastore/spanner/errors.go @@ -0,0 +1,21 @@ +package spanner + +import ( + "strings" + + "cloud.google.com/go/spanner" + "google.golang.org/grpc/codes" +) + +// IsMissingTableError returns true if the error is a Spanner error indicating a missing table. +// This typically happens when migrations have not been run. 
+func IsMissingTableError(err error) bool { + if spanner.ErrCode(err) == codes.NotFound { + // Check if it's specifically about a missing table + errMsg := err.Error() + if strings.Contains(errMsg, "Table not found") { + return true + } + } + return false +} diff --git a/internal/datastore/spanner/reader.go b/internal/datastore/spanner/reader.go index 958709f13..29f8545e6 100644 --- a/internal/datastore/spanner/reader.go +++ b/internal/datastore/spanner/reader.go @@ -298,6 +298,9 @@ func (sr spannerReader) LegacyReadNamespaceByName(ctx context.Context, nsName st []string{colNamespaceConfig, colNamespaceTS}, ) if err != nil { + if IsMissingTableError(err) { + return nil, datastore.NoRevision, common.NewSchemaNotInitializedError(err) + } if spanner.ErrCode(err) == codes.NotFound { return nil, datastore.NoRevision, datastore.NewNamespaceNotFoundErr(nsName) } @@ -329,6 +332,9 @@ func (sr spannerReader) LegacyListAllNamespaces(ctx context.Context) ([]datastor allNamespaces, err := readAllNamespaces(iter, trace.SpanFromContext(ctx)) if err != nil { + if IsMissingTableError(err) { + return nil, common.NewSchemaNotInitializedError(err) + } return nil, fmt.Errorf(errUnableToListNamespaces, err) } From 39804a88a62a3ed894a76eaf3316658d3e9a9898 Mon Sep 17 00:00:00 2001 From: ivanauth Date: Mon, 5 Jan 2026 13:57:14 -0500 Subject: [PATCH 3/7] chore: update TODO to reference authzed/api#159 --- internal/datastore/common/errors.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/internal/datastore/common/errors.go b/internal/datastore/common/errors.go index 8a532dec4..b6bb7fbf8 100644 --- a/internal/datastore/common/errors.go +++ b/internal/datastore/common/errors.go @@ -184,8 +184,7 @@ type SchemaNotInitializedError struct { } func (err SchemaNotInitializedError) GRPCStatus() *status.Status { - // TODO: Create ERROR_REASON_DATASTORE_NOT_MIGRATED in authzed/api and use it here - // See: https://github.com/authzed/spicedb/pull/2775 + // TODO: Update to use 
ERROR_REASON_DATASTORE_NOT_MIGRATED once authzed/api#159 is merged return spiceerrors.WithCodeAndDetails( err, codes.FailedPrecondition, From b019d3113aec764bb2343a3be19926994214e6e5 Mon Sep 17 00:00:00 2001 From: ivanauth Date: Tue, 6 Jan 2026 12:14:41 -0500 Subject: [PATCH 4/7] feat: centralize migration error handling with gRPC readiness middleware Replaces scattered IsMissingTableError checks (15 files, 157 lines) with a single gRPC readiness middleware that blocks ALL requests until the datastore is migrated. Key improvements: - Single point of control vs copy-pasted checks everywhere - Impossible to miss code paths - all gRPC requests gated automatically - Clear error message: "Please run 'spicedb datastore migrate'" - Cached checks (500ms) with singleflight to prevent thundering herd - Health probes bypass the gate for Kubernetes compatibility Net: -81 lines, better coverage, consistent UX. --- internal/datastore/crdb/caveat.go | 8 - internal/datastore/crdb/crdb.go | 6 +- internal/datastore/crdb/reader.go | 15 - internal/datastore/crdb/stats.go | 12 - internal/datastore/mysql/caveat.go | 7 - internal/datastore/mysql/reader.go | 16 - internal/datastore/mysql/stats.go | 16 - internal/datastore/postgres/caveat.go | 7 - internal/datastore/postgres/common/bulk.go | 9 +- internal/datastore/postgres/reader.go | 15 - internal/datastore/postgres/readwrite.go | 4 - internal/datastore/postgres/revisions.go | 3 - internal/datastore/postgres/stats.go | 15 - internal/datastore/spanner/caveat.go | 6 - internal/datastore/spanner/errors.go | 19 + internal/datastore/spanner/reader.go | 6 - pkg/cmd/server/defaults.go | 15 + pkg/cmd/server/defaults_test.go | 2 + pkg/cmd/server/server.go | 1 + pkg/cmd/server/server_test.go | 4 +- .../server/zz_generated.middlewareoption.go | 11 + pkg/middleware/readiness/readiness.go | 195 +++++++++ pkg/middleware/readiness/readiness_test.go | 393 ++++++++++++++++++ 23 files changed, 640 insertions(+), 145 deletions(-) create mode 100644 
pkg/middleware/readiness/readiness.go create mode 100644 pkg/middleware/readiness/readiness_test.go diff --git a/internal/datastore/crdb/caveat.go b/internal/datastore/crdb/caveat.go index d1a1b29de..5857a5702 100644 --- a/internal/datastore/crdb/caveat.go +++ b/internal/datastore/crdb/caveat.go @@ -9,9 +9,7 @@ import ( sq "github.com/Masterminds/squirrel" "github.com/jackc/pgx/v5" - dscommon "github.com/authzed/spicedb/internal/datastore/common" "github.com/authzed/spicedb/internal/datastore/crdb/schema" - pgxcommon "github.com/authzed/spicedb/internal/datastore/postgres/common" "github.com/authzed/spicedb/internal/datastore/revisions" "github.com/authzed/spicedb/pkg/datastore" core "github.com/authzed/spicedb/pkg/proto/core/v1" @@ -55,9 +53,6 @@ func (cr *crdbReader) LegacyReadCaveatByName(ctx context.Context, name string) ( if errors.Is(err, pgx.ErrNoRows) { err = datastore.NewCaveatNameNotFoundErr(name) } - if pgxcommon.IsMissingTableError(err) { - err = dscommon.NewSchemaNotInitializedError(err) - } return nil, datastore.NoRevision, fmt.Errorf(errReadCaveat, name, err) } @@ -114,9 +109,6 @@ func (cr *crdbReader) lookupCaveats(ctx context.Context, caveatNames []string) ( return nil }, sql, args...) if err != nil { - if pgxcommon.IsMissingTableError(err) { - err = dscommon.NewSchemaNotInitializedError(err) - } return nil, fmt.Errorf(errListCaveats, err) } diff --git a/internal/datastore/crdb/crdb.go b/internal/datastore/crdb/crdb.go index a2e7a2abb..c47bd3e55 100644 --- a/internal/datastore/crdb/crdb.go +++ b/internal/datastore/crdb/crdb.go @@ -594,11 +594,7 @@ func (cds *crdbDatastore) features(ctx context.Context) (*datastore.Features, er } features.Watch.Status = datastore.FeatureUnsupported - if pgxcommon.IsMissingTableError(err) { - features.Watch.Reason = "Database schema has not been initialized. 
Please run \"spicedb datastore migrate\": " + err.Error() - } else { - features.Watch.Reason = "Range feeds must be enabled in CockroachDB and the user must have permission to create them in order to enable the Watch API: " + err.Error() - } + features.Watch.Reason = "Range feeds must be enabled in CockroachDB and the user must have permission to create them in order to enable the Watch API: " + err.Error() return nil }, fmt.Sprintf(cds.beginChangefeedQuery, cds.schema.RelationshipTableName, head, "-1s")) } else { diff --git a/internal/datastore/crdb/reader.go b/internal/datastore/crdb/reader.go index f6cd3f33e..5bf2c9b1b 100644 --- a/internal/datastore/crdb/reader.go +++ b/internal/datastore/crdb/reader.go @@ -129,9 +129,6 @@ func (cr *crdbReader) CountRelationships(ctx context.Context, name string) (int, return row.Scan(&count) }, sql, args...) if err != nil { - if wrappedErr := pgxcommon.WrapMissingTableError(err); wrappedErr != nil { - return 0, wrappedErr - } return 0, err } @@ -196,9 +193,6 @@ func (cr *crdbReader) lookupCounters(ctx context.Context, optionalFilterName str return nil }, sql, args...) if err != nil { - if wrappedErr := pgxcommon.WrapMissingTableError(err); wrappedErr != nil { - return nil, wrappedErr - } return nil, err } @@ -214,9 +208,6 @@ func (cr *crdbReader) LegacyReadNamespaceByName( if errors.As(err, &datastore.NamespaceNotFoundError{}) { return nil, datastore.NoRevision, err } - if wrappedErr := pgxcommon.WrapMissingTableError(err); wrappedErr != nil { - return nil, datastore.NoRevision, wrappedErr - } return nil, datastore.NoRevision, fmt.Errorf(errUnableToReadConfig, err) } @@ -230,9 +221,6 @@ func (cr *crdbReader) LegacyListAllNamespaces(ctx context.Context) ([]datastore. 
nsDefs, sql, err := loadAllNamespaces(ctx, cr.query, addFromToQuery) if err != nil { - if wrappedErr := pgxcommon.WrapMissingTableError(err); wrappedErr != nil { - return nil, wrappedErr - } return nil, fmt.Errorf(errUnableToListNamespaces, err) } cr.assertHasExpectedAsOfSystemTime(sql) @@ -245,9 +233,6 @@ func (cr *crdbReader) LegacyLookupNamespacesWithNames(ctx context.Context, nsNam } nsDefs, err := cr.lookupNamespaces(ctx, cr.query, nsNames) if err != nil { - if wrappedErr := pgxcommon.WrapMissingTableError(err); wrappedErr != nil { - return nil, wrappedErr - } return nil, fmt.Errorf(errUnableToListNamespaces, err) } return nsDefs, nil diff --git a/internal/datastore/crdb/stats.go b/internal/datastore/crdb/stats.go index 7e9954294..6ec65c00a 100644 --- a/internal/datastore/crdb/stats.go +++ b/internal/datastore/crdb/stats.go @@ -33,9 +33,6 @@ func (cds *crdbDatastore) UniqueID(ctx context.Context) (string, error) { if err := cds.readPool.QueryRowFunc(ctx, func(ctx context.Context, row pgx.Row) error { return row.Scan(&uniqueID) }, sql, args...); err != nil { - if wrappedErr := pgxcommon.WrapMissingTableError(err); wrappedErr != nil { - return "", wrappedErr - } return "", fmt.Errorf("unable to query unique ID: %w", err) } @@ -62,9 +59,6 @@ func (cds *crdbDatastore) Statistics(ctx context.Context) (datastore.Stats, erro return sb.From(tableName) }) if err != nil { - if wrappedErr := pgxcommon.WrapMissingTableError(err); wrappedErr != nil { - return wrappedErr - } return fmt.Errorf("unable to read namespaces: %w", err) } return nil @@ -75,9 +69,6 @@ func (cds *crdbDatastore) Statistics(ctx context.Context) (datastore.Stats, erro if cds.analyzeBeforeStatistics { if err := cds.readPool.BeginTxFunc(ctx, pgx.TxOptions{AccessMode: pgx.ReadOnly}, func(tx pgx.Tx) error { if _, err := tx.Exec(ctx, "ANALYZE "+cds.schema.RelationshipTableName); err != nil { - if wrappedErr := pgxcommon.WrapMissingTableError(err); wrappedErr != nil { - return wrappedErr - } return 
fmt.Errorf("unable to analyze tuple table: %w", err) } @@ -152,9 +143,6 @@ func (cds *crdbDatastore) Statistics(ctx context.Context) (datastore.Stats, erro log.Warn().Bool("has-rows", hasRows).Msg("unable to find row count in statistics query result") return nil }, "SHOW STATISTICS FOR TABLE "+cds.schema.RelationshipTableName); err != nil { - if wrappedErr := pgxcommon.WrapMissingTableError(err); wrappedErr != nil { - return datastore.Stats{}, wrappedErr - } return datastore.Stats{}, fmt.Errorf("unable to query unique estimated row count: %w", err) } diff --git a/internal/datastore/mysql/caveat.go b/internal/datastore/mysql/caveat.go index f6d2d5cf4..bb48ad509 100644 --- a/internal/datastore/mysql/caveat.go +++ b/internal/datastore/mysql/caveat.go @@ -9,7 +9,6 @@ import ( sq "github.com/Masterminds/squirrel" "github.com/authzed/spicedb/internal/datastore/common" - mysqlcommon "github.com/authzed/spicedb/internal/datastore/mysql/common" "github.com/authzed/spicedb/internal/datastore/revisions" "github.com/authzed/spicedb/pkg/datastore" core "github.com/authzed/spicedb/pkg/proto/core/v1" @@ -42,9 +41,6 @@ func (mr *mysqlReader) LegacyReadCaveatByName(ctx context.Context, name string) if errors.Is(err, sql.ErrNoRows) { return nil, datastore.NoRevision, datastore.NewCaveatNameNotFoundErr(name) } - if wrappedErr := mysqlcommon.WrapMissingTableError(err); wrappedErr != nil { - return nil, datastore.NoRevision, wrappedErr - } return nil, datastore.NoRevision, fmt.Errorf(errReadCaveat, err) } def := core.CaveatDefinition{} @@ -86,9 +82,6 @@ func (mr *mysqlReader) lookupCaveats(ctx context.Context, caveatNames []string) rows, err := tx.QueryContext(ctx, listSQL, listArgs...) 
if err != nil { - if wrappedErr := mysqlcommon.WrapMissingTableError(err); wrappedErr != nil { - return nil, wrappedErr - } return nil, fmt.Errorf(errListCaveats, err) } defer common.LogOnError(ctx, rows.Close) diff --git a/internal/datastore/mysql/reader.go b/internal/datastore/mysql/reader.go index 0c0bf7c45..33e056d9d 100644 --- a/internal/datastore/mysql/reader.go +++ b/internal/datastore/mysql/reader.go @@ -9,7 +9,6 @@ import ( sq "github.com/Masterminds/squirrel" "github.com/authzed/spicedb/internal/datastore/common" - mysqlcommon "github.com/authzed/spicedb/internal/datastore/mysql/common" "github.com/authzed/spicedb/internal/datastore/revisions" schemautil "github.com/authzed/spicedb/internal/datastore/schema" "github.com/authzed/spicedb/pkg/datastore" @@ -77,9 +76,6 @@ func (mr *mysqlReader) CountRelationships(ctx context.Context, name string) (int var count int rows, err := tx.QueryContext(ctx, sql, args...) if err != nil { - if wrappedErr := mysqlcommon.WrapMissingTableError(err); wrappedErr != nil { - return 0, wrappedErr - } return 0, err } defer common.LogOnError(ctx, rows.Close) @@ -127,9 +123,6 @@ func (mr *mysqlReader) lookupCounters(ctx context.Context, optionalName string) rows, err := tx.QueryContext(ctx, sql, args...) 
if err != nil { - if wrappedErr := mysqlcommon.WrapMissingTableError(err); wrappedErr != nil { - return nil, wrappedErr - } return nil, err } defer common.LogOnError(ctx, rows.Close) @@ -230,9 +223,6 @@ func (mr *mysqlReader) LegacyReadNamespaceByName(ctx context.Context, nsName str case err == nil: return loaded, version, nil default: - if wrappedErr := mysqlcommon.WrapMissingTableError(err); wrappedErr != nil { - return nil, datastore.NoRevision, wrappedErr - } return nil, datastore.NoRevision, fmt.Errorf(errUnableToReadConfig, err) } } @@ -275,9 +265,6 @@ func (mr *mysqlReader) LegacyListAllNamespaces(ctx context.Context) ([]datastore nsDefs, err := loadAllNamespaces(ctx, tx, query) if err != nil { - if wrappedErr := mysqlcommon.WrapMissingTableError(err); wrappedErr != nil { - return nil, wrappedErr - } return nil, fmt.Errorf(errUnableToListNamespaces, err) } @@ -304,9 +291,6 @@ func (mr *mysqlReader) LegacyLookupNamespacesWithNames(ctx context.Context, nsNa nsDefs, err := loadAllNamespaces(ctx, tx, query) if err != nil { - if wrappedErr := mysqlcommon.WrapMissingTableError(err); wrappedErr != nil { - return nil, wrappedErr - } return nil, fmt.Errorf(errUnableToListNamespaces, err) } diff --git a/internal/datastore/mysql/stats.go b/internal/datastore/mysql/stats.go index a16ae51c4..dfbdc8618 100644 --- a/internal/datastore/mysql/stats.go +++ b/internal/datastore/mysql/stats.go @@ -9,7 +9,6 @@ import ( "github.com/ccoveille/go-safecast/v2" "github.com/authzed/spicedb/internal/datastore/common" - mysqlcommon "github.com/authzed/spicedb/internal/datastore/mysql/common" "github.com/authzed/spicedb/pkg/datastore" "github.com/authzed/spicedb/pkg/spiceerrors" ) @@ -27,9 +26,6 @@ func (mds *mysqlDatastore) Statistics(ctx context.Context) (datastore.Stats, err if mds.analyzeBeforeStats { _, err := mds.db.ExecContext(ctx, "ANALYZE TABLE "+mds.driver.RelationTuple()) if err != nil { - if wrappedErr := mysqlcommon.WrapMissingTableError(err); wrappedErr != nil { - return 
datastore.Stats{}, wrappedErr - } return datastore.Stats{}, fmt.Errorf("unable to run ANALYZE TABLE: %w", err) } } @@ -51,9 +47,6 @@ func (mds *mysqlDatastore) Statistics(ctx context.Context) (datastore.Stats, err var count sql.NullInt64 err = mds.db.QueryRowContext(ctx, query, args...).Scan(&count) if err != nil { - if wrappedErr := mysqlcommon.WrapMissingTableError(err); wrappedErr != nil { - return datastore.Stats{}, wrappedErr - } return datastore.Stats{}, err } @@ -66,9 +59,6 @@ func (mds *mysqlDatastore) Statistics(ctx context.Context) (datastore.Stats, err } err = mds.db.QueryRowContext(ctx, query, args...).Scan(&count) if err != nil { - if wrappedErr := mysqlcommon.WrapMissingTableError(err); wrappedErr != nil { - return datastore.Stats{}, wrappedErr - } return datastore.Stats{}, err } } @@ -83,9 +73,6 @@ func (mds *mysqlDatastore) Statistics(ctx context.Context) (datastore.Stats, err nsDefs, err := loadAllNamespaces(ctx, tx, nsQuery) if err != nil { - if wrappedErr := mysqlcommon.WrapMissingTableError(err); wrappedErr != nil { - return datastore.Stats{}, wrappedErr - } return datastore.Stats{}, fmt.Errorf("unable to load namespaces: %w", err) } @@ -110,9 +97,6 @@ func (mds *mysqlDatastore) UniqueID(ctx context.Context) (string, error) { var uniqueID string if err := mds.db.QueryRowContext(ctx, sql, args...).Scan(&uniqueID); err != nil { - if wrappedErr := mysqlcommon.WrapMissingTableError(err); wrappedErr != nil { - return "", wrappedErr - } return "", fmt.Errorf("unable to query unique ID: %w", err) } mds.uniqueID.Store(&uniqueID) diff --git a/internal/datastore/postgres/caveat.go b/internal/datastore/postgres/caveat.go index 483a874ea..535632096 100644 --- a/internal/datastore/postgres/caveat.go +++ b/internal/datastore/postgres/caveat.go @@ -8,7 +8,6 @@ import ( sq "github.com/Masterminds/squirrel" "github.com/jackc/pgx/v5" - pgxcommon "github.com/authzed/spicedb/internal/datastore/postgres/common" 
"github.com/authzed/spicedb/internal/datastore/postgres/schema" "github.com/authzed/spicedb/pkg/datastore" "github.com/authzed/spicedb/pkg/genutil/mapz" @@ -50,9 +49,6 @@ func (r *pgReader) LegacyReadCaveatByName(ctx context.Context, name string) (*co if errors.Is(err, pgx.ErrNoRows) { return nil, datastore.NoRevision, datastore.NewCaveatNameNotFoundErr(name) } - if wrappedErr := pgxcommon.WrapMissingTableError(err); wrappedErr != nil { - return nil, datastore.NoRevision, wrappedErr - } return nil, datastore.NoRevision, fmt.Errorf(errReadCaveat, err) } def := core.CaveatDefinition{} @@ -110,9 +106,6 @@ func (r *pgReader) lookupCaveats(ctx context.Context, caveatNames []string) ([]d return rows.Err() }, sql, args...) if err != nil { - if wrappedErr := pgxcommon.WrapMissingTableError(err); wrappedErr != nil { - return nil, wrappedErr - } return nil, fmt.Errorf(errListCaveats, err) } diff --git a/internal/datastore/postgres/common/bulk.go b/internal/datastore/postgres/common/bulk.go index 13573e606..0fbbf7bb2 100644 --- a/internal/datastore/postgres/common/bulk.go +++ b/internal/datastore/postgres/common/bulk.go @@ -6,7 +6,6 @@ import ( "github.com/ccoveille/go-safecast/v2" "github.com/jackc/pgx/v5" - dscommon "github.com/authzed/spicedb/internal/datastore/common" "github.com/authzed/spicedb/pkg/datastore" "github.com/authzed/spicedb/pkg/spiceerrors" "github.com/authzed/spicedb/pkg/tuple" @@ -78,15 +77,9 @@ func BulkLoad( colNames: colNames, } copied, err := tx.CopyFrom(ctx, pgx.Identifier{tupleTableName}, colNames, adapter) - if err != nil { - if IsMissingTableError(err) { - return 0, dscommon.NewSchemaNotInitializedError(err) - } - return 0, err - } uintCopied, castErr := safecast.Convert[uint64](copied) if castErr != nil { return 0, spiceerrors.MustBugf("number copied was negative: %v", castErr) } - return uintCopied, nil + return uintCopied, err } diff --git a/internal/datastore/postgres/reader.go b/internal/datastore/postgres/reader.go index 90d4d8c8d..ce40b06c8 
100644 --- a/internal/datastore/postgres/reader.go +++ b/internal/datastore/postgres/reader.go @@ -85,9 +85,6 @@ func (r *pgReader) CountRelationships(ctx context.Context, name string) (int, er return rows.Err() }, sql, args...) if err != nil { - if pgxcommon.IsMissingTableError(err) { - err = common.NewSchemaNotInitializedError(err) - } return 0, err } @@ -143,9 +140,6 @@ func (r *pgReader) lookupCounters(ctx context.Context, optionalName string) ([]d return rows.Err() }, sql, args...) if err != nil { - if pgxcommon.IsMissingTableError(err) { - err = common.NewSchemaNotInitializedError(err) - } return nil, fmt.Errorf("unable to query counters: %w", err) } @@ -214,9 +208,6 @@ func (r *pgReader) LegacyReadNamespaceByName(ctx context.Context, nsName string) case err == nil: return loaded, version, nil default: - if pgxcommon.IsMissingTableError(err) { - err = common.NewSchemaNotInitializedError(err) - } return nil, datastore.NoRevision, fmt.Errorf(errUnableToReadConfig, err) } } @@ -242,9 +233,6 @@ func (r *pgReader) loadNamespace(ctx context.Context, namespace string, tx pgxco func (r *pgReader) LegacyListAllNamespaces(ctx context.Context) ([]datastore.RevisionedNamespace, error) { nsDefsWithRevisions, err := loadAllNamespaces(ctx, r.query, r.aliveFilter) if err != nil { - if pgxcommon.IsMissingTableError(err) { - err = common.NewSchemaNotInitializedError(err) - } return nil, fmt.Errorf(errUnableToListNamespaces, err) } @@ -265,9 +253,6 @@ func (r *pgReader) LegacyLookupNamespacesWithNames(ctx context.Context, nsNames return r.aliveFilter(original).Where(clause) }) if err != nil { - if pgxcommon.IsMissingTableError(err) { - err = common.NewSchemaNotInitializedError(err) - } return nil, fmt.Errorf(errUnableToListNamespaces, err) } diff --git a/internal/datastore/postgres/readwrite.go b/internal/datastore/postgres/readwrite.go index 2ea09eb6c..e327e2d3f 100644 --- a/internal/datastore/postgres/readwrite.go +++ b/internal/datastore/postgres/readwrite.go @@ -419,10 
+419,6 @@ func (rwt *pgReadWriteTXN) WriteRelationships(ctx context.Context, mutations []t } func handleWriteError(err error) error { - if wrappedErr := pgxcommon.WrapMissingTableError(err); wrappedErr != nil { - return wrappedErr - } - if pgxcommon.IsSerializationError(err) { return common.NewSerializationError(fmt.Errorf("unable to write relationships due to a serialization error: [%w]; this typically indicates that a number of write transactions are contending over the same relationships; either reduce the contention or scale this Postgres instance", err)) } diff --git a/internal/datastore/postgres/revisions.go b/internal/datastore/postgres/revisions.go index 861a8f97d..18a41f644 100644 --- a/internal/datastore/postgres/revisions.go +++ b/internal/datastore/postgres/revisions.go @@ -312,9 +312,6 @@ func createNewTransaction(ctx context.Context, tx pgx.Tx, metadata map[string]an cterr := tx.QueryRow(ctx, sql, args...).Scan(&newXID, &newSnapshot, &timestamp) if cterr != nil { - if common.IsMissingTableError(cterr) { - cterr = dscommon.NewSchemaNotInitializedError(cterr) - } err = fmt.Errorf("error when trying to create a new transaction: %w", cterr) } return newXID, newSnapshot, timestamp, err diff --git a/internal/datastore/postgres/stats.go b/internal/datastore/postgres/stats.go index 2f059de7d..dc9e2e639 100644 --- a/internal/datastore/postgres/stats.go +++ b/internal/datastore/postgres/stats.go @@ -40,9 +40,6 @@ func (pgd *pgDatastore) UniqueID(ctx context.Context) (string, error) { if err := pgx.BeginTxFunc(ctx, pgd.readPool, pgd.readTxOptions, func(tx pgx.Tx) error { return tx.QueryRow(ctx, idSQL, idArgs...).Scan(&uniqueID) }); err != nil { - if wrappedErr := pgxcommon.WrapMissingTableError(err); wrappedErr != nil { - return wrappedErr - } return "", fmt.Errorf("unable to query unique ID: %w", err) } @@ -74,34 +71,22 @@ func (pgd *pgDatastore) Statistics(ctx context.Context) (datastore.Stats, error) if err := pgx.BeginTxFunc(ctx, pgd.readPool,
pgd.readTxOptions, func(tx pgx.Tx) error { if pgd.analyzeBeforeStatistics { if _, err := tx.Exec(ctx, "ANALYZE "+schema.TableTuple); err != nil { - if wrappedErr := pgxcommon.WrapMissingTableError(err); wrappedErr != nil { - return wrappedErr - } return fmt.Errorf("unable to analyze tuple table: %w", err) } } if err := tx.QueryRow(ctx, idSQL, idArgs...).Scan(&uniqueID); err != nil { - if wrappedErr := pgxcommon.WrapMissingTableError(err); wrappedErr != nil { - return wrappedErr - } return fmt.Errorf("unable to query unique ID: %w", err) } nsDefsWithRevisions, err := loadAllNamespaces(ctx, pgxcommon.QuerierFuncsFor(tx), aliveFilter) if err != nil { - if wrappedErr := pgxcommon.WrapMissingTableError(err); wrappedErr != nil { - return wrappedErr - } return fmt.Errorf("unable to load namespaces: %w", err) } nsDefs = nsDefsWithRevisions if err := tx.QueryRow(ctx, rowCountSQL, rowCountArgs...).Scan(&relCount); err != nil { - if wrappedErr := pgxcommon.WrapMissingTableError(err); wrappedErr != nil { - return wrappedErr - } return fmt.Errorf("unable to read relationship count: %w", err) } diff --git a/internal/datastore/spanner/caveat.go b/internal/datastore/spanner/caveat.go index 54d716c1f..b810a88e2 100644 --- a/internal/datastore/spanner/caveat.go +++ b/internal/datastore/spanner/caveat.go @@ -18,9 +18,6 @@ func (sr spannerReader) LegacyReadCaveatByName(ctx context.Context, name string) caveatKey := spanner.Key{name} row, err := sr.txSource().ReadRow(ctx, tableCaveat, caveatKey, []string{colCaveatDefinition, colCaveatTS}) if err != nil { - if IsMissingTableError(err) { - return nil, datastore.NoRevision, common.NewSchemaNotInitializedError(err) - } if spanner.ErrCode(err) == codes.NotFound { return nil, datastore.NoRevision, datastore.NewCaveatNameNotFoundErr(name) } @@ -86,9 +83,6 @@ func (sr spannerReader) listCaveats(ctx context.Context, caveatNames []string) ( return nil }); err != nil { - if IsMissingTableError(err) { - return nil, 
common.NewSchemaNotInitializedError(err) - } return nil, fmt.Errorf(errUnableToListCaveats, err) } diff --git a/internal/datastore/spanner/errors.go b/internal/datastore/spanner/errors.go index a379bb993..6ab079449 100644 --- a/internal/datastore/spanner/errors.go +++ b/internal/datastore/spanner/errors.go @@ -1,10 +1,13 @@ package spanner import ( + "errors" "strings" "cloud.google.com/go/spanner" "google.golang.org/grpc/codes" + + "github.com/authzed/spicedb/internal/datastore/common" ) // IsMissingTableError returns true if the error is a Spanner error indicating a missing table. @@ -19,3 +22,19 @@ func IsMissingTableError(err error) bool { } return false } + +// WrapMissingTableError checks if the error is a missing table error and wraps it with +// a helpful message instructing the user to run migrations. If it's not a missing table error, +// it returns nil. If it's already a SchemaNotInitializedError, it returns the original error +// to preserve the wrapped error through the call chain. 
+func WrapMissingTableError(err error) error { + // Don't double-wrap if already a SchemaNotInitializedError - return original to preserve it + var schemaErr common.SchemaNotInitializedError + if errors.As(err, &schemaErr) { + return err + } + if IsMissingTableError(err) { + return common.NewSchemaNotInitializedError(err) + } + return nil +} diff --git a/internal/datastore/spanner/reader.go b/internal/datastore/spanner/reader.go index 29f8545e6..958709f13 100644 --- a/internal/datastore/spanner/reader.go +++ b/internal/datastore/spanner/reader.go @@ -298,9 +298,6 @@ func (sr spannerReader) LegacyReadNamespaceByName(ctx context.Context, nsName st []string{colNamespaceConfig, colNamespaceTS}, ) if err != nil { - if IsMissingTableError(err) { - return nil, datastore.NoRevision, common.NewSchemaNotInitializedError(err) - } if spanner.ErrCode(err) == codes.NotFound { return nil, datastore.NoRevision, datastore.NewNamespaceNotFoundErr(nsName) } @@ -332,9 +329,6 @@ func (sr spannerReader) LegacyListAllNamespaces(ctx context.Context) ([]datastor allNamespaces, err := readAllNamespaces(iter, trace.SpanFromContext(ctx)) if err != nil { - if IsMissingTableError(err) { - return nil, common.NewSchemaNotInitializedError(err) - } return nil, fmt.Errorf(errUnableToListNamespaces, err) } diff --git a/pkg/cmd/server/defaults.go b/pkg/cmd/server/defaults.go index 3d3ef0a37..b9cbee6e9 100644 --- a/pkg/cmd/server/defaults.go +++ b/pkg/cmd/server/defaults.go @@ -39,6 +39,7 @@ import ( "github.com/authzed/spicedb/internal/middleware/servicespecific" "github.com/authzed/spicedb/pkg/datastore" consistencymw "github.com/authzed/spicedb/pkg/middleware/consistency" + "github.com/authzed/spicedb/pkg/middleware/readiness" logmw "github.com/authzed/spicedb/pkg/middleware/logging" "github.com/authzed/spicedb/pkg/middleware/requestid" "github.com/authzed/spicedb/pkg/middleware/serverversion" @@ -173,6 +174,7 @@ const ( DefaultMiddlewareGRPCProm = "grpcprom" DefaultMiddlewareServerVersion = 
"serverversion" DefaultMiddlewareMemoryProtection = "memoryprotection" + DefaultMiddlewareReadiness = "readiness" DefaultInternalMiddlewareDispatch = "dispatch" DefaultInternalMiddlewareDatastore = "datastore" @@ -194,6 +196,7 @@ type MiddlewareOption struct { MismatchingZedTokenOption consistencymw.MismatchingTokenOption `debugmap:"visible"` MemoryUsageProvider memoryprotection.MemoryUsageProvider `debugmap:"hidden"` + ReadinessChecker readiness.ReadinessChecker `debugmap:"hidden"` unaryDatastoreMiddleware *ReferenceableMiddleware[grpc.UnaryServerInterceptor] `debugmap:"hidden"` streamDatastoreMiddleware *ReferenceableMiddleware[grpc.StreamServerInterceptor] `debugmap:"hidden"` @@ -307,6 +310,12 @@ func DefaultUnaryMiddleware(opts MiddlewareOption) (*MiddlewareChain[grpc.UnaryS WithInterceptor(grpcMetricsUnaryInterceptor). Done(), + NewUnaryMiddleware(). + WithName(DefaultMiddlewareReadiness). + WithInterceptor(readiness.NewGate(opts.ReadinessChecker).UnaryServerInterceptor()). + EnsureAlreadyExecuted(DefaultMiddlewareGRPCProm). // so that prom middleware reports blocked requests + Done(), + NewUnaryMiddleware(). WithName(DefaultMiddlewareMemoryProtection). WithInterceptor(selector.UnaryServerInterceptor( @@ -388,6 +397,12 @@ func DefaultStreamingMiddleware(opts MiddlewareOption) (*MiddlewareChain[grpc.St WithInterceptor(grpcMetricsStreamingInterceptor). Done(), + NewStreamMiddleware(). + WithName(DefaultMiddlewareReadiness). + WithInterceptor(readiness.NewGate(opts.ReadinessChecker).StreamServerInterceptor()). + EnsureInterceptorAlreadyExecuted(DefaultMiddlewareGRPCProm). // so that prom middleware reports blocked requests + Done(), + NewStreamMiddleware(). WithName(DefaultMiddlewareMemoryProtection). WithInterceptor(memoryProtectionStreamInterceptor.StreamServerInterceptor()). 
diff --git a/pkg/cmd/server/defaults_test.go b/pkg/cmd/server/defaults_test.go index d4e9d0ce3..9de8e1d37 100644 --- a/pkg/cmd/server/defaults_test.go +++ b/pkg/cmd/server/defaults_test.go @@ -35,6 +35,7 @@ func TestWithDatastore(t *testing.T) { "service", consistency.TreatMismatchingTokensAsError, memoryprotection.NewNoopMemoryUsageProvider(), + nil, // ReadinessChecker nil, nil, } @@ -78,6 +79,7 @@ func TestWithDatastoreMiddleware(t *testing.T) { "service", consistency.TreatMismatchingTokensAsError, memoryprotection.NewNoopMemoryUsageProvider(), + nil, // ReadinessChecker nil, nil, } diff --git a/pkg/cmd/server/server.go b/pkg/cmd/server/server.go index 966e30f63..f6cf78e16 100644 --- a/pkg/cmd/server/server.go +++ b/pkg/cmd/server/server.go @@ -423,6 +423,7 @@ func (c *Config) Complete(ctx context.Context) (RunnableServer, error) { MiddlewareServiceLabel: serverName, MismatchingZedTokenOption: mismatchZedTokenOption, MemoryUsageProvider: memoryUsageProvider, + ReadinessChecker: ds, } opts = opts.WithDatastore(ds) diff --git a/pkg/cmd/server/server_test.go b/pkg/cmd/server/server_test.go index 9226b849c..db87d25f6 100644 --- a/pkg/cmd/server/server_test.go +++ b/pkg/cmd/server/server_test.go @@ -456,7 +456,7 @@ func TestModifyUnaryMiddleware(t *testing.T) { }, }} - opt := MiddlewareOption{logging.Logger, nil, false, nil, false, false, false, "testing", consistency.TreatMismatchingTokensAsFullConsistency, memoryprotection.NewNoopMemoryUsageProvider(), nil, nil} + opt := MiddlewareOption{logging.Logger, nil, false, nil, false, false, false, "testing", consistency.TreatMismatchingTokensAsFullConsistency, memoryprotection.NewNoopMemoryUsageProvider(), nil, nil, nil} opt = opt.WithDatastore(nil) defaultMw, err := DefaultUnaryMiddleware(opt) @@ -484,7 +484,7 @@ func TestModifyStreamingMiddleware(t *testing.T) { }, }} - opt := MiddlewareOption{logging.Logger, nil, false, nil, false, false, false, "testing", consistency.TreatMismatchingTokensAsFullConsistency, 
memoryprotection.NewNoopMemoryUsageProvider(), nil, nil} + opt := MiddlewareOption{logging.Logger, nil, false, nil, false, false, false, "testing", consistency.TreatMismatchingTokensAsFullConsistency, memoryprotection.NewNoopMemoryUsageProvider(), nil, nil, nil} opt = opt.WithDatastore(nil) defaultMw, err := DefaultStreamingMiddleware(opt) diff --git a/pkg/cmd/server/zz_generated.middlewareoption.go b/pkg/cmd/server/zz_generated.middlewareoption.go index f31ca52e8..2e8a73841 100644 --- a/pkg/cmd/server/zz_generated.middlewareoption.go +++ b/pkg/cmd/server/zz_generated.middlewareoption.go @@ -5,6 +5,7 @@ import ( dispatch "github.com/authzed/spicedb/internal/dispatch" memoryprotection "github.com/authzed/spicedb/internal/middleware/memoryprotection" consistency "github.com/authzed/spicedb/pkg/middleware/consistency" + readiness "github.com/authzed/spicedb/pkg/middleware/readiness" defaults "github.com/creasty/defaults" auth "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/auth" zerolog "github.com/rs/zerolog" @@ -44,6 +45,9 @@ func (m *MiddlewareOption) ToOption() MiddlewareOptionOption { to.MiddlewareServiceLabel = m.MiddlewareServiceLabel to.MismatchingZedTokenOption = m.MismatchingZedTokenOption to.MemoryUsageProvider = m.MemoryUsageProvider + to.ReadinessChecker = m.ReadinessChecker + to.unaryDatastoreMiddleware = m.unaryDatastoreMiddleware + to.streamDatastoreMiddleware = m.streamDatastoreMiddleware } } @@ -169,3 +173,10 @@ func WithMemoryUsageProvider(memoryUsageProvider memoryprotection.MemoryUsagePro m.MemoryUsageProvider = memoryUsageProvider } } + +// WithReadinessChecker returns an option that can set ReadinessChecker on a MiddlewareOption +func WithReadinessChecker(readinessChecker readiness.ReadinessChecker) MiddlewareOptionOption { + return func(m *MiddlewareOption) { + m.ReadinessChecker = readinessChecker + } +} diff --git a/pkg/middleware/readiness/readiness.go b/pkg/middleware/readiness/readiness.go new file mode 100644 index 
000000000..8b835a116 --- /dev/null +++ b/pkg/middleware/readiness/readiness.go @@ -0,0 +1,195 @@ +package readiness + +import ( + "context" + "strings" + "sync" + "time" + + "golang.org/x/sync/singleflight" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/authzed/spicedb/pkg/datastore" +) + +const ( + // healthCheckPrefix is the gRPC method prefix for health checks. + // We bypass readiness checks for health endpoints so Kubernetes probes work. + healthCheckPrefix = "/grpc.health.v1.Health/" + + // readyCacheTTL is how long to cache a positive ready state. + readyCacheTTL = 500 * time.Millisecond + + // notReadyCacheTTL is how long to cache a negative ready state. + // Shorter than readyCacheTTL to allow faster recovery detection. + notReadyCacheTTL = 100 * time.Millisecond + + // readinessCheckTimeout is the maximum time to wait for a readiness check. + // We use a dedicated context with this timeout inside singleflight to ensure + // consistent behavior regardless of which request's context triggered the check. + readinessCheckTimeout = 5 * time.Second +) + +// ReadinessChecker is the interface for checking datastore readiness. +type ReadinessChecker interface { + ReadyState(ctx context.Context) (datastore.ReadyState, error) +} + +// Gate blocks gRPC requests until the datastore is ready. +// It caches the ready state briefly to avoid overwhelming the datastore +// with readiness checks on every request. +type Gate struct { + checker ReadinessChecker + + // singleflight prevents thundering herd when cache expires + sfGroup singleflight.Group + + mu sync.RWMutex + cachedReady bool // GUARDED_BY(mu) + cachedMessage string // GUARDED_BY(mu) + cacheTime time.Time // GUARDED_BY(mu) +} + +// NewGate creates a new readiness gate with the given checker. +// If checker is nil, the gate will pass through all requests without checking. 
+func NewGate(checker ReadinessChecker) *Gate { + return &Gate{checker: checker} +} + +// readinessResult holds the result of a readiness check for singleflight. +type readinessResult struct { + ready bool + message string +} + +// isMigrationIssue returns true if the not-ready message indicates +// the database schema hasn't been migrated. Other not-ready reasons +// (like connection pool warmup) are transient and shouldn't block requests. +func isMigrationIssue(msg string) bool { + return strings.Contains(msg, "not migrated") || strings.Contains(msg, "migration") +} + +// isReady checks if the datastore is ready, using a cached value if available. +// Uses singleflight to prevent thundering herd on cache expiry. +// Only blocks requests for migration-related issues; transient states like +// connection pool warmup are allowed through. +func (g *Gate) isReady(ctx context.Context) (bool, string) { + // If no checker is configured, pass through + if g.checker == nil { + return true, "" + } + + // Fast path: check cache with read lock + g.mu.RLock() + elapsed := time.Since(g.cacheTime) + ttl := readyCacheTTL + if !g.cachedReady { + ttl = notReadyCacheTTL + } + if elapsed < ttl { + ready, msg := g.cachedReady, g.cachedMessage + g.mu.RUnlock() + return ready, msg + } + g.mu.RUnlock() + + // Slow path: use singleflight to deduplicate concurrent checks + result, _, _ := g.sfGroup.Do("readiness", func() (any, error) { + // Double-check cache after acquiring singleflight + g.mu.RLock() + elapsed := time.Since(g.cacheTime) + ttl := readyCacheTTL + if !g.cachedReady { + ttl = notReadyCacheTTL + } + if elapsed < ttl { + ready, msg := g.cachedReady, g.cachedMessage + g.mu.RUnlock() + return readinessResult{ready: ready, message: msg}, nil + } + g.mu.RUnlock() + + // Use an independent context with timeout for the readiness check. 
+ // This ensures consistent behavior when multiple requests are coalesced + // by singleflight - we don't want the check to fail because the first + // request's context was cancelled. + checkCtx, cancel := context.WithTimeout(context.Background(), readinessCheckTimeout) + defer cancel() + + state, err := g.checker.ReadyState(checkCtx) + if err != nil { + // On error checking readiness, allow requests through. + // If the datastore is truly unavailable, requests will fail + // with appropriate errors from the datastore layer. + return readinessResult{ready: true, message: ""}, nil + } + + // Only block requests for migration-related issues. + // Transient states (connection pool warmup, etc.) should not block. + ready := state.IsReady || !isMigrationIssue(state.Message) + + // Update cache + g.mu.Lock() + g.cachedReady = ready + g.cachedMessage = state.Message + g.cacheTime = time.Now() + g.mu.Unlock() + + return readinessResult{ready: ready, message: state.Message}, nil + }) + + r := result.(readinessResult) + return r.ready, r.message +} + +// formatNotReadyError creates a user-friendly error message based on the readiness failure reason. +// TODO(authzed/api#159): Once ERROR_REASON_DATASTORE_NOT_MIGRATED is available in the API, +// use spiceerrors.WithCodeAndReason to include the structured error reason. +func formatNotReadyError(msg string) error { + // Check if this is a migration-related issue + if strings.Contains(msg, "not migrated") || strings.Contains(msg, "migration") { + return status.Errorf(codes.FailedPrecondition, + "SpiceDB datastore is not migrated. Please run 'spicedb datastore migrate'. Details: %s", msg) + } + // Generic not-ready message for other cases (connection issues, pool not ready, etc.) + return status.Errorf(codes.FailedPrecondition, + "SpiceDB datastore is not ready. Details: %s", msg) +} + +// UnaryServerInterceptor returns a gRPC unary interceptor that blocks +// requests until the datastore is ready. 
+func (g *Gate) UnaryServerInterceptor() grpc.UnaryServerInterceptor { + return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { + // Bypass health checks so Kubernetes probes work + if strings.HasPrefix(info.FullMethod, healthCheckPrefix) { + return handler(ctx, req) + } + + ready, msg := g.isReady(ctx) + if !ready { + return nil, formatNotReadyError(msg) + } + + return handler(ctx, req) + } +} + +// StreamServerInterceptor returns a gRPC stream interceptor that blocks +// streams until the datastore is ready. +func (g *Gate) StreamServerInterceptor() grpc.StreamServerInterceptor { + return func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + // Bypass health checks so Kubernetes probes work + if strings.HasPrefix(info.FullMethod, healthCheckPrefix) { + return handler(srv, ss) + } + + ready, msg := g.isReady(ss.Context()) + if !ready { + return formatNotReadyError(msg) + } + + return handler(srv, ss) + } +} diff --git a/pkg/middleware/readiness/readiness_test.go b/pkg/middleware/readiness/readiness_test.go new file mode 100644 index 000000000..2b6fda668 --- /dev/null +++ b/pkg/middleware/readiness/readiness_test.go @@ -0,0 +1,393 @@ +package readiness + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/authzed/spicedb/pkg/datastore" +) + +type mockChecker struct { + ready bool + message string + err error + callCount atomic.Int32 +} + +func (m *mockChecker) ReadyState(_ context.Context) (datastore.ReadyState, error) { + m.callCount.Add(1) + if m.err != nil { + return datastore.ReadyState{}, m.err + } + return datastore.ReadyState{ + IsReady: m.ready, + Message: m.message, + }, nil +} + +func TestGate_BlocksWhenNotReady(t *testing.T) { + checker := &mockChecker{ + ready: false, + 
message: "datastore is not migrated", + } + gate := NewGate(checker) + + interceptor := gate.UnaryServerInterceptor() + _, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{ + FullMethod: "/authzed.api.v1.PermissionsService/CheckPermission", + }, func(ctx context.Context, req any) (any, error) { + t.Fatal("handler should not be called when not ready") + return nil, nil + }) + + require.Error(t, err) + st, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, codes.FailedPrecondition, st.Code()) + require.Contains(t, st.Message(), "not migrated") + require.Contains(t, st.Message(), "spicedb datastore migrate") +} + +func TestGate_AllowsWhenReady(t *testing.T) { + checker := &mockChecker{ready: true} + gate := NewGate(checker) + + handlerCalled := false + interceptor := gate.UnaryServerInterceptor() + _, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{ + FullMethod: "/authzed.api.v1.PermissionsService/CheckPermission", + }, func(ctx context.Context, req any) (any, error) { + handlerCalled = true + return "response", nil + }) + + require.NoError(t, err) + require.True(t, handlerCalled) +} + +func TestGate_BypassesHealthCheck(t *testing.T) { + checker := &mockChecker{ + ready: false, + message: "not ready", + } + gate := NewGate(checker) + + handlerCalled := false + interceptor := gate.UnaryServerInterceptor() + _, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{ + FullMethod: "/grpc.health.v1.Health/Check", + }, func(ctx context.Context, req any) (any, error) { + handlerCalled = true + return "ok", nil + }) + + require.NoError(t, err) + require.True(t, handlerCalled) + // Checker should not be called for health checks + require.Equal(t, int32(0), checker.callCount.Load()) +} + +func TestGate_BypassesHealthWatch(t *testing.T) { + checker := &mockChecker{ + ready: false, + message: "not ready", + } + gate := NewGate(checker) + + handlerCalled := false + interceptor := gate.UnaryServerInterceptor() 
+ _, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{ + FullMethod: "/grpc.health.v1.Health/Watch", + }, func(ctx context.Context, req any) (any, error) { + handlerCalled = true + return "ok", nil + }) + + require.NoError(t, err) + require.True(t, handlerCalled) +} + +func TestGate_CachesReadyState(t *testing.T) { + checker := &mockChecker{ready: true} + gate := NewGate(checker) + + interceptor := gate.UnaryServerInterceptor() + info := &grpc.UnaryServerInfo{FullMethod: "/test/Method"} + handler := func(ctx context.Context, req any) (any, error) { + return nil, nil + } + + // First call should check readiness + _, _ = interceptor(context.Background(), nil, info, handler) + require.Equal(t, int32(1), checker.callCount.Load()) + + // Second call should use cache + _, _ = interceptor(context.Background(), nil, info, handler) + require.Equal(t, int32(1), checker.callCount.Load()) + + // Third call should use cache + _, _ = interceptor(context.Background(), nil, info, handler) + require.Equal(t, int32(1), checker.callCount.Load()) +} + +func TestGate_CacheExpires(t *testing.T) { + checker := &mockChecker{ready: true} + gate := NewGate(checker) + + // Override cache time to test expiry + gate.mu.Lock() + gate.cachedReady = true + gate.cacheTime = time.Now().Add(-2 * readyCacheTTL) // Expired + gate.mu.Unlock() + + interceptor := gate.UnaryServerInterceptor() + _, _ = interceptor(context.Background(), nil, &grpc.UnaryServerInfo{ + FullMethod: "/test/Method", + }, func(ctx context.Context, req any) (any, error) { + return nil, nil + }) + + // Should have called checker because cache expired + require.Equal(t, int32(1), checker.callCount.Load()) +} + +func TestGate_SingleflightPreventsThunderingHerd(t *testing.T) { + // Slow checker that takes 100ms + checker := &mockChecker{ready: true} + gate := NewGate(checker) + + interceptor := gate.UnaryServerInterceptor() + info := &grpc.UnaryServerInfo{FullMethod: "/test/Method"} + handler := func(ctx 
context.Context, req any) (any, error) { + return nil, nil + } + + // Launch 10 concurrent requests + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _, _ = interceptor(context.Background(), nil, info, handler) + }() + } + wg.Wait() + + // Singleflight should have deduplicated the calls + // We might get 1 or 2 calls depending on timing, but definitely not 10 + require.LessOrEqual(t, checker.callCount.Load(), int32(2)) +} + +func TestGate_PassesThroughOnCheckerError(t *testing.T) { + checker := &mockChecker{ + err: errors.New("connection refused"), + } + gate := NewGate(checker) + + handlerCalled := false + interceptor := gate.UnaryServerInterceptor() + _, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{ + FullMethod: "/test/Method", + }, func(ctx context.Context, req any) (any, error) { + handlerCalled = true + return "ok", nil + }) + + // Errors during readiness check should allow requests through. + // If the datastore is truly unavailable, requests will fail at the + // datastore layer with appropriate errors. 
+ require.NoError(t, err) + require.True(t, handlerCalled) +} + +func TestGate_StreamInterceptorBlocksWhenNotReady(t *testing.T) { + checker := &mockChecker{ + ready: false, + message: "not migrated", + } + gate := NewGate(checker) + + interceptor := gate.StreamServerInterceptor() + err := interceptor(nil, &mockServerStream{}, &grpc.StreamServerInfo{ + FullMethod: "/authzed.api.v1.WatchService/Watch", + }, func(srv any, stream grpc.ServerStream) error { + t.Fatal("handler should not be called when not ready") + return nil + }) + + require.Error(t, err) + st, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, codes.FailedPrecondition, st.Code()) +} + +func TestGate_StreamInterceptorAllowsWhenReady(t *testing.T) { + checker := &mockChecker{ready: true} + gate := NewGate(checker) + + handlerCalled := false + interceptor := gate.StreamServerInterceptor() + err := interceptor(nil, &mockServerStream{}, &grpc.StreamServerInfo{ + FullMethod: "/authzed.api.v1.WatchService/Watch", + }, func(srv any, stream grpc.ServerStream) error { + handlerCalled = true + return nil + }) + + require.NoError(t, err) + require.True(t, handlerCalled) +} + +// mockServerStream implements grpc.ServerStream for testing +type mockServerStream struct { + grpc.ServerStream + ctx context.Context +} + +func (m *mockServerStream) Context() context.Context { + if m.ctx != nil { + return m.ctx + } + return context.Background() +} + +func TestGate_NilCheckerPassesThrough(t *testing.T) { + gate := NewGate(nil) + + handlerCalled := false + interceptor := gate.UnaryServerInterceptor() + _, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{ + FullMethod: "/authzed.api.v1.PermissionsService/CheckPermission", + }, func(ctx context.Context, req any) (any, error) { + handlerCalled = true + return "response", nil + }) + + require.NoError(t, err) + require.True(t, handlerCalled) +} + +func TestGate_ErrorDoesNotCache(t *testing.T) { + checker := &mockChecker{err: 
errors.New("temporary connection error")} + gate := NewGate(checker) + + interceptor := gate.UnaryServerInterceptor() + info := &grpc.UnaryServerInfo{FullMethod: "/test/Method"} + + // First call: checker errors but request passes through + handlerCalled := false + _, err := interceptor(context.Background(), nil, info, func(ctx context.Context, req any) (any, error) { + handlerCalled = true + return "ok", nil + }) + require.NoError(t, err) + require.True(t, handlerCalled) + require.Equal(t, int32(1), checker.callCount.Load()) + + // Second call should retry checker (errors are not cached) + handlerCalled = false + _, err = interceptor(context.Background(), nil, info, func(ctx context.Context, req any) (any, error) { + handlerCalled = true + return "ok", nil + }) + require.NoError(t, err) + require.True(t, handlerCalled) + require.Equal(t, int32(2), checker.callCount.Load()) +} + +func TestGate_NegativeCacheReducesChecks(t *testing.T) { + checker := &mockChecker{ + ready: false, + message: "datastore is not migrated", + } + gate := NewGate(checker) + + interceptor := gate.UnaryServerInterceptor() + info := &grpc.UnaryServerInfo{FullMethod: "/test/Method"} + handler := func(ctx context.Context, req any) (any, error) { + return nil, nil + } + + // First call checks readiness + _, _ = interceptor(context.Background(), nil, info, handler) + require.Equal(t, int32(1), checker.callCount.Load()) + + // Second call should use negative cache (within notReadyCacheTTL) + _, _ = interceptor(context.Background(), nil, info, handler) + require.Equal(t, int32(1), checker.callCount.Load()) + + // Third call should use negative cache + _, _ = interceptor(context.Background(), nil, info, handler) + require.Equal(t, int32(1), checker.callCount.Load()) +} + +func TestGate_ErrorMessageContextAware(t *testing.T) { + t.Run("migration issue blocks with actionable message", func(t *testing.T) { + checker := &mockChecker{ + ready: false, + message: "datastore is not migrated: currently at 
revision \"\"", + } + gate := NewGate(checker) + + interceptor := gate.UnaryServerInterceptor() + _, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{ + FullMethod: "/test/Method", + }, func(ctx context.Context, req any) (any, error) { + t.Fatal("handler should not be called for migration issues") + return nil, nil + }) + + require.Error(t, err) + require.Contains(t, err.Error(), "spicedb datastore migrate") + require.Contains(t, err.Error(), "not migrated") + }) + + t.Run("connection pool issue passes through", func(t *testing.T) { + checker := &mockChecker{ + ready: false, + message: "spicedb does not have the required minimum connection count", + } + gate := NewGate(checker) + + handlerCalled := false + interceptor := gate.UnaryServerInterceptor() + _, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{ + FullMethod: "/test/Method", + }, func(ctx context.Context, req any) (any, error) { + handlerCalled = true + return "ok", nil + }) + + require.NoError(t, err) + require.True(t, handlerCalled) + }) + + t.Run("generic not ready passes through", func(t *testing.T) { + checker := &mockChecker{ + ready: false, + message: "some other issue", + } + gate := NewGate(checker) + + handlerCalled := false + interceptor := gate.UnaryServerInterceptor() + _, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{ + FullMethod: "/test/Method", + }, func(ctx context.Context, req any) (any, error) { + handlerCalled = true + return "ok", nil + }) + + require.NoError(t, err) + require.True(t, handlerCalled) + }) +} From 24d2ead949fde29d3ea2b61dc3b1c007ec70325b Mon Sep 17 00:00:00 2001 From: ivanauth Date: Thu, 29 Jan 2026 11:30:57 -0500 Subject: [PATCH 5/7] fix: correct import order in defaults.go Sort imports alphabetically to satisfy gci linter - logmw before readiness. 
Co-Authored-By: Claude Opus 4.5 --- pkg/cmd/server/defaults.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/cmd/server/defaults.go b/pkg/cmd/server/defaults.go index b9cbee6e9..34521141f 100644 --- a/pkg/cmd/server/defaults.go +++ b/pkg/cmd/server/defaults.go @@ -39,8 +39,8 @@ import ( "github.com/authzed/spicedb/internal/middleware/servicespecific" "github.com/authzed/spicedb/pkg/datastore" consistencymw "github.com/authzed/spicedb/pkg/middleware/consistency" - "github.com/authzed/spicedb/pkg/middleware/readiness" logmw "github.com/authzed/spicedb/pkg/middleware/logging" + "github.com/authzed/spicedb/pkg/middleware/readiness" "github.com/authzed/spicedb/pkg/middleware/requestid" "github.com/authzed/spicedb/pkg/middleware/serverversion" "github.com/authzed/spicedb/pkg/releases" From 2eefa2616fa2f60d4fecf105ff4d9ca2818b98e3 Mon Sep 17 00:00:00 2001 From: ivanauth Date: Thu, 29 Jan 2026 12:09:30 -0500 Subject: [PATCH 6/7] fix: remove private fields from generated ToOption method --- pkg/cmd/server/zz_generated.middlewareoption.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/pkg/cmd/server/zz_generated.middlewareoption.go b/pkg/cmd/server/zz_generated.middlewareoption.go index 2e8a73841..24c494f3c 100644 --- a/pkg/cmd/server/zz_generated.middlewareoption.go +++ b/pkg/cmd/server/zz_generated.middlewareoption.go @@ -46,8 +46,6 @@ func (m *MiddlewareOption) ToOption() MiddlewareOptionOption { to.MismatchingZedTokenOption = m.MismatchingZedTokenOption to.MemoryUsageProvider = m.MemoryUsageProvider to.ReadinessChecker = m.ReadinessChecker - to.unaryDatastoreMiddleware = m.unaryDatastoreMiddleware - to.streamDatastoreMiddleware = m.streamDatastoreMiddleware } } From f918ab5375e8a8dfc27fa727b3c9cbd00d4aff00 Mon Sep 17 00:00:00 2001 From: ivanauth Date: Thu, 5 Mar 2026 14:35:16 -0500 Subject: [PATCH 7/7] retrigger CI