diff --git a/internal/datastore/common/errors.go b/internal/datastore/common/errors.go index 5344873c6..b6bb7fbf8 100644 --- a/internal/datastore/common/errors.go +++ b/internal/datastore/common/errors.go @@ -176,3 +176,33 @@ type RevisionUnavailableError struct { func NewRevisionUnavailableError(err error) error { return RevisionUnavailableError{err} } + +// SchemaNotInitializedError is returned when a datastore operation fails because the +// required database tables do not exist. This typically means that migrations have not been run. +type SchemaNotInitializedError struct { + error +} + +func (err SchemaNotInitializedError) GRPCStatus() *status.Status { + // TODO: Update to use ERROR_REASON_DATASTORE_NOT_MIGRATED once authzed/api#159 is merged + return spiceerrors.WithCodeAndDetails( + err, + codes.FailedPrecondition, + spiceerrors.ForReason( + v1.ErrorReason_ERROR_REASON_UNSPECIFIED, + map[string]string{}, + ), + ) +} + +func (err SchemaNotInitializedError) Unwrap() error { + return err.error +} + +// NewSchemaNotInitializedError creates a new SchemaNotInitializedError with a helpful message +// instructing the user to run migrations. +func NewSchemaNotInitializedError(underlying error) error { + return SchemaNotInitializedError{ + fmt.Errorf("%w. please run \"spicedb datastore migrate\"", underlying), + } +} diff --git a/internal/datastore/common/errors_test.go b/internal/datastore/common/errors_test.go new file mode 100644 index 000000000..e53f44672 --- /dev/null +++ b/internal/datastore/common/errors_test.go @@ -0,0 +1,55 @@ +package common + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" +) + +func TestSchemaNotInitializedError(t *testing.T) { + underlyingErr := fmt.Errorf("relation \"caveat\" does not exist (SQLSTATE 42P01)") + err := NewSchemaNotInitializedError(underlyingErr) + + t.Run("error message contains migration instructions", func(t *testing.T) { + require.Contains(t, err.Error(), "spicedb datastore migrate") + // The error message now includes the underlying error first, followed by the instruction + require.Contains(t, err.Error(), "relation \"caveat\" does not exist") + }) + + t.Run("unwrap returns underlying error", func(t *testing.T) { + var schemaErr SchemaNotInitializedError + require.ErrorAs(t, err, &schemaErr) + require.ErrorIs(t, schemaErr.Unwrap(), underlyingErr) + }) + + t.Run("grpc status is FailedPrecondition", func(t *testing.T) { + var schemaErr SchemaNotInitializedError + require.ErrorAs(t, err, &schemaErr) + status := schemaErr.GRPCStatus() + require.Equal(t, codes.FailedPrecondition, status.Code()) + }) + + t.Run("can be detected with errors.As", func(t *testing.T) { + var schemaErr SchemaNotInitializedError + require.ErrorAs(t, err, &schemaErr) + }) + + t.Run("wrapped error preserves chain", func(t *testing.T) { + wrappedErr := fmt.Errorf("outer: %w", err) + var schemaErr SchemaNotInitializedError + require.ErrorAs(t, wrappedErr, &schemaErr) + }) + + t.Run("grpc status extractable from wrapped error", func(t *testing.T) { + // This tests the scenario where SchemaNotInitializedError is wrapped + // by another fmt.Errorf (e.g., in crdb/caveat.go). The gRPC library + // uses errors.As to extract GRPCStatus from wrapped errors. + wrappedErr := fmt.Errorf("outer context: %w", err) + var schemaErr SchemaNotInitializedError + require.ErrorAs(t, wrappedErr, &schemaErr) + status := schemaErr.GRPCStatus() + require.Equal(t, codes.FailedPrecondition, status.Code()) + }) +} diff --git a/internal/datastore/crdb/migrations/driver.go b/internal/datastore/crdb/migrations/driver.go index 74dea5970..db3b72986 100644 --- a/internal/datastore/crdb/migrations/driver.go +++ b/internal/datastore/crdb/migrations/driver.go @@ -15,8 +15,6 @@ import ( const ( errUnableToInstantiate = "unable to instantiate CRDBDriver: %w" - postgresMissingTableErrorCode = "42P01" - queryLoadVersion = "SELECT version_num from schema_version" queryWriteVersion = "UPDATE schema_version SET version_num=$1 WHERE version_num=$2" ) @@ -52,7 +50,7 @@ func (apd *CRDBDriver) Version(ctx context.Context) (string, error) { if err := apd.db.QueryRow(ctx, queryLoadVersion).Scan(&loaded); err != nil { var pgErr *pgconn.PgError - if errors.As(err, &pgErr) && pgErr.Code == postgresMissingTableErrorCode { + if errors.As(err, &pgErr) && pgErr.Code == pgxcommon.PgMissingTable { return "", nil } return "", fmt.Errorf("unable to load alembic revision: %w", err) diff --git a/internal/datastore/mysql/common/errors.go b/internal/datastore/mysql/common/errors.go new file mode 100644 index 000000000..7eaf73393 --- /dev/null +++ b/internal/datastore/mysql/common/errors.go @@ -0,0 +1,38 @@ +package common + +import ( + "errors" + + "github.com/go-sql-driver/mysql" + + dscommon "github.com/authzed/spicedb/internal/datastore/common" +) + +const ( + // mysqlMissingTableErrorNumber is the MySQL error number for "table doesn't exist". + // This corresponds to MySQL error 1146 (ER_NO_SUCH_TABLE) with SQLSTATE 42S02. + mysqlMissingTableErrorNumber = 1146 +) + +// IsMissingTableError returns true if the error is a MySQL error indicating a missing table. +// This typically happens when migrations have not been run. +func IsMissingTableError(err error) bool { + var mysqlErr *mysql.MySQLError + return errors.As(err, &mysqlErr) && mysqlErr.Number == mysqlMissingTableErrorNumber +} + +// WrapMissingTableError checks if the error is a missing table error and wraps it with +// a helpful message instructing the user to run migrations. If it's not a missing table error, +// it returns nil. If it's already a SchemaNotInitializedError, it returns the original error +// to preserve the wrapped error through the call chain. +func WrapMissingTableError(err error) error { + // Don't double-wrap if already a SchemaNotInitializedError - return original to preserve it + var schemaErr dscommon.SchemaNotInitializedError + if errors.As(err, &schemaErr) { + return err + } + if IsMissingTableError(err) { + return dscommon.NewSchemaNotInitializedError(err) + } + return nil +} diff --git a/internal/datastore/mysql/common/errors_test.go b/internal/datastore/mysql/common/errors_test.go new file mode 100644 index 000000000..6f2498491 --- /dev/null +++ b/internal/datastore/mysql/common/errors_test.go @@ -0,0 +1,112 @@ +package common + +import ( + "fmt" + "testing" + + "github.com/go-sql-driver/mysql" + "github.com/stretchr/testify/require" + + dscommon "github.com/authzed/spicedb/internal/datastore/common" +) + +func TestIsMissingTableError(t *testing.T) { + t.Run("returns true for missing table error", func(t *testing.T) { + mysqlErr := &mysql.MySQLError{ + Number: mysqlMissingTableErrorNumber, + Message: "Table 'spicedb.caveat' doesn't exist", + } + require.True(t, IsMissingTableError(mysqlErr)) + }) + + t.Run("returns false for other mysql errors", func(t *testing.T) { + mysqlErr := &mysql.MySQLError{ + Number: 1062, // Duplicate entry error + Message: "Duplicate entry '1' for key 'PRIMARY'", + } + require.False(t, IsMissingTableError(mysqlErr)) + }) + + t.Run("returns false for non-mysql errors", func(t *testing.T) { + err := fmt.Errorf("some other error") + require.False(t, IsMissingTableError(err)) + }) + + t.Run("returns false for nil error", func(t *testing.T) { + require.False(t, IsMissingTableError(nil)) + }) + + t.Run("returns true for wrapped missing table error", func(t *testing.T) { + mysqlErr := &mysql.MySQLError{ + Number: mysqlMissingTableErrorNumber, + Message: "Table 'spicedb.caveat' doesn't exist", + } + wrappedErr := fmt.Errorf("query failed: %w", mysqlErr) + require.True(t, IsMissingTableError(wrappedErr)) + }) +} + +func TestWrapMissingTableError(t *testing.T) { + t.Run("wraps missing table error", func(t *testing.T) { + mysqlErr := &mysql.MySQLError{ + Number: mysqlMissingTableErrorNumber, + Message: "Table 'spicedb.caveat' doesn't exist", + } + wrapped := WrapMissingTableError(mysqlErr) + require.Error(t, wrapped) + + var schemaErr dscommon.SchemaNotInitializedError + require.ErrorAs(t, wrapped, &schemaErr) + require.Contains(t, wrapped.Error(), "spicedb datastore migrate") + }) + + t.Run("returns nil for non-missing-table errors", func(t *testing.T) { + mysqlErr := &mysql.MySQLError{ + Number: 1062, // Duplicate entry error + Message: "Duplicate entry '1' for key 'PRIMARY'", + } + require.NoError(t, WrapMissingTableError(mysqlErr)) + }) + + t.Run("returns nil for non-mysql errors", func(t *testing.T) { + err := fmt.Errorf("some other error") + require.NoError(t, WrapMissingTableError(err)) + }) + + t.Run("returns nil for nil error", func(t *testing.T) { + require.NoError(t, WrapMissingTableError(nil)) + }) + + t.Run("preserves original error in chain", func(t *testing.T) { + mysqlErr := &mysql.MySQLError{ + Number: mysqlMissingTableErrorNumber, + Message: "Table 'spicedb.caveat' doesn't exist", + } + wrapped := WrapMissingTableError(mysqlErr) + require.Error(t, wrapped) + + // The original mysql error should be accessible via unwrapping + var foundMySQLErr *mysql.MySQLError + require.ErrorAs(t, wrapped, &foundMySQLErr) + require.Equal(t, uint16(mysqlMissingTableErrorNumber), foundMySQLErr.Number) + }) + + t.Run("does not double-wrap already wrapped errors", func(t *testing.T) { + mysqlErr := &mysql.MySQLError{ + Number: mysqlMissingTableErrorNumber, + Message: "Table 'spicedb.caveat' doesn't exist", + } + // First wrap + wrapped := WrapMissingTableError(mysqlErr) + require.Error(t, wrapped) + + // Second wrap should return the already-wrapped error (preserving it through call chain) + doubleWrapped := WrapMissingTableError(wrapped) + require.Error(t, doubleWrapped) + require.Equal(t, wrapped, doubleWrapped) + + // Should still be detectable as SchemaNotInitializedError + var schemaErr dscommon.SchemaNotInitializedError + require.ErrorAs(t, doubleWrapped, &schemaErr) + }) +} diff --git a/internal/datastore/postgres/common/errors.go b/internal/datastore/postgres/common/errors.go index 2728252aa..e7e191706 100644 --- a/internal/datastore/postgres/common/errors.go +++ b/internal/datastore/postgres/common/errors.go @@ -19,6 +19,10 @@ const ( pgReadOnlyTransaction = "25006" pgQueryCanceled = "57014" pgInvalidArgument = "22023" + + // PgMissingTable is the Postgres error code for "relation does not exist". + // This is used to detect when migrations have not been run. + PgMissingTable = "42P01" ) var ( @@ -106,3 +110,26 @@ func ConvertToWriteConstraintError(livingTupleConstraints []string, err error) e return nil } + +// IsMissingTableError returns true if the error is a Postgres error indicating a missing table. +// This typically happens when migrations have not been run. +func IsMissingTableError(err error) bool { + var pgerr *pgconn.PgError + return errors.As(err, &pgerr) && pgerr.Code == PgMissingTable +} + +// WrapMissingTableError checks if the error is a missing table error and wraps it with +// a helpful message instructing the user to run migrations. If it's not a missing table error, +// it returns nil. If it's already a SchemaNotInitializedError, it returns the original error +// to preserve the wrapped error through the call chain. +func WrapMissingTableError(err error) error { + // Don't double-wrap if already a SchemaNotInitializedError - return original to preserve it + var schemaErr dscommon.SchemaNotInitializedError + if errors.As(err, &schemaErr) { + return err + } + if IsMissingTableError(err) { + return dscommon.NewSchemaNotInitializedError(err) + } + return nil +} diff --git a/internal/datastore/postgres/common/errors_test.go b/internal/datastore/postgres/common/errors_test.go new file mode 100644 index 000000000..8410ef255 --- /dev/null +++ b/internal/datastore/postgres/common/errors_test.go @@ -0,0 +1,112 @@ +package common + +import ( + "fmt" + "testing" + + "github.com/jackc/pgx/v5/pgconn" + "github.com/stretchr/testify/require" + + dscommon "github.com/authzed/spicedb/internal/datastore/common" +) + +func TestIsMissingTableError(t *testing.T) { + t.Run("returns true for missing table error", func(t *testing.T) { + pgErr := &pgconn.PgError{ + Code: PgMissingTable, + Message: "relation \"caveat\" does not exist", + } + require.True(t, IsMissingTableError(pgErr)) + }) + + t.Run("returns false for other postgres errors", func(t *testing.T) { + pgErr := &pgconn.PgError{ + Code: pgSerializationFailure, + Message: "could not serialize access", + } + require.False(t, IsMissingTableError(pgErr)) + }) + + t.Run("returns false for non-postgres errors", func(t *testing.T) { + err := fmt.Errorf("some other error") + require.False(t, IsMissingTableError(err)) + }) + + t.Run("returns false for nil error", func(t *testing.T) { + require.False(t, IsMissingTableError(nil)) + }) + + t.Run("returns true for wrapped missing table error", func(t *testing.T) { + pgErr := &pgconn.PgError{ + Code: PgMissingTable, + Message: "relation \"caveat\" does not exist", + } + wrappedErr := fmt.Errorf("query failed: %w", pgErr) + require.True(t, IsMissingTableError(wrappedErr)) + }) +} + +func TestWrapMissingTableError(t *testing.T) { + t.Run("wraps missing table error", func(t *testing.T) { + pgErr := &pgconn.PgError{ + Code: PgMissingTable, + Message: "relation \"caveat\" does not exist", + } + wrapped := WrapMissingTableError(pgErr) + require.Error(t, wrapped) + + var schemaErr dscommon.SchemaNotInitializedError + require.ErrorAs(t, wrapped, &schemaErr) + require.Contains(t, wrapped.Error(), "spicedb datastore migrate") + }) + + t.Run("returns nil for non-missing-table errors", func(t *testing.T) { + pgErr := &pgconn.PgError{ + Code: pgSerializationFailure, + Message: "could not serialize access", + } + require.NoError(t, WrapMissingTableError(pgErr)) + }) + + t.Run("returns nil for non-postgres errors", func(t *testing.T) { + err := fmt.Errorf("some other error") + require.NoError(t, WrapMissingTableError(err)) + }) + + t.Run("returns nil for nil error", func(t *testing.T) { + require.NoError(t, WrapMissingTableError(nil)) + }) + + t.Run("preserves original error in chain", func(t *testing.T) { + pgErr := &pgconn.PgError{ + Code: PgMissingTable, + Message: "relation \"caveat\" does not exist", + } + wrapped := WrapMissingTableError(pgErr) + require.Error(t, wrapped) + + // The original postgres error should be accessible via unwrapping + var foundPgErr *pgconn.PgError + require.ErrorAs(t, wrapped, &foundPgErr) + require.Equal(t, PgMissingTable, foundPgErr.Code) + }) + + t.Run("does not double-wrap already wrapped errors", func(t *testing.T) { + pgErr := &pgconn.PgError{ + Code: PgMissingTable, + Message: "relation \"caveat\" does not exist", + } + // First wrap + wrapped := WrapMissingTableError(pgErr) + require.Error(t, wrapped) + + // Second wrap should return the already-wrapped error (preserving it through call chain) + doubleWrapped := WrapMissingTableError(wrapped) + require.Error(t, doubleWrapped) + require.Equal(t, wrapped, doubleWrapped) + + // Should still be detectable as SchemaNotInitializedError + var schemaErr dscommon.SchemaNotInitializedError + require.ErrorAs(t, doubleWrapped, &schemaErr) + }) +} diff --git a/internal/datastore/postgres/migrations/driver.go b/internal/datastore/postgres/migrations/driver.go index 1054ecde4..522ae790e 100644 --- a/internal/datastore/postgres/migrations/driver.go +++ b/internal/datastore/postgres/migrations/driver.go @@ -15,8 +15,6 @@ import ( "github.com/authzed/spicedb/pkg/migrate" ) -const postgresMissingTableErrorCode = "42P01" - var tracer = otel.Tracer("spicedb/internal/datastore/common") // AlembicPostgresDriver implements a schema migration facility for use in @@ -74,7 +72,7 @@ func (apd *AlembicPostgresDriver) Version(ctx context.Context) (string, error) { if err := apd.db.QueryRow(ctx, "SELECT version_num from alembic_version").Scan(&loaded); err != nil { var pgErr *pgconn.PgError - if errors.As(err, &pgErr) && pgErr.Code == postgresMissingTableErrorCode { + if errors.As(err, &pgErr) && pgErr.Code == pgxcommon.PgMissingTable { return "", nil } return "", fmt.Errorf("unable to load alembic revision: %w", err) diff --git a/internal/datastore/spanner/errors.go b/internal/datastore/spanner/errors.go new file mode 100644 index 000000000..6ab079449 --- /dev/null +++ b/internal/datastore/spanner/errors.go @@ -0,0 +1,40 @@ +package spanner + +import ( + "errors" + "strings" + + "cloud.google.com/go/spanner" + "google.golang.org/grpc/codes" + + "github.com/authzed/spicedb/internal/datastore/common" +) + +// IsMissingTableError returns true if the error is a Spanner error indicating a missing table. +// This typically happens when migrations have not been run. +func IsMissingTableError(err error) bool { + if spanner.ErrCode(err) == codes.NotFound { + // Check if it's specifically about a missing table + errMsg := err.Error() + if strings.Contains(errMsg, "Table not found") { + return true + } + } + return false +} + +// WrapMissingTableError checks if the error is a missing table error and wraps it with +// a helpful message instructing the user to run migrations. If it's not a missing table error, +// it returns nil. If it's already a SchemaNotInitializedError, it returns the original error +// to preserve the wrapped error through the call chain. +func WrapMissingTableError(err error) error { + // Don't double-wrap if already a SchemaNotInitializedError - return original to preserve it + var schemaErr common.SchemaNotInitializedError + if errors.As(err, &schemaErr) { + return err + } + if IsMissingTableError(err) { + return common.NewSchemaNotInitializedError(err) + } + return nil +} diff --git a/pkg/cmd/server/defaults.go b/pkg/cmd/server/defaults.go index 3d3ef0a37..34521141f 100644 --- a/pkg/cmd/server/defaults.go +++ b/pkg/cmd/server/defaults.go @@ -40,6 +40,7 @@ import ( "github.com/authzed/spicedb/pkg/datastore" consistencymw "github.com/authzed/spicedb/pkg/middleware/consistency" logmw "github.com/authzed/spicedb/pkg/middleware/logging" + "github.com/authzed/spicedb/pkg/middleware/readiness" "github.com/authzed/spicedb/pkg/middleware/requestid" "github.com/authzed/spicedb/pkg/middleware/serverversion" "github.com/authzed/spicedb/pkg/releases" @@ -173,6 +174,7 @@ const ( DefaultMiddlewareGRPCProm = "grpcprom" DefaultMiddlewareServerVersion = "serverversion" DefaultMiddlewareMemoryProtection = "memoryprotection" + DefaultMiddlewareReadiness = "readiness" DefaultInternalMiddlewareDispatch = "dispatch" DefaultInternalMiddlewareDatastore = "datastore" @@ -194,6 +196,7 @@ type MiddlewareOption struct { MismatchingZedTokenOption consistencymw.MismatchingTokenOption `debugmap:"visible"` MemoryUsageProvider memoryprotection.MemoryUsageProvider `debugmap:"hidden"` + ReadinessChecker readiness.ReadinessChecker `debugmap:"hidden"` unaryDatastoreMiddleware *ReferenceableMiddleware[grpc.UnaryServerInterceptor] `debugmap:"hidden"` streamDatastoreMiddleware *ReferenceableMiddleware[grpc.StreamServerInterceptor] `debugmap:"hidden"` @@ -307,6 +310,12 @@ func DefaultUnaryMiddleware(opts MiddlewareOption) (*MiddlewareChain[grpc.UnaryS WithInterceptor(grpcMetricsUnaryInterceptor). Done(), + NewUnaryMiddleware(). + WithName(DefaultMiddlewareReadiness). + WithInterceptor(readiness.NewGate(opts.ReadinessChecker).UnaryServerInterceptor()). + EnsureAlreadyExecuted(DefaultMiddlewareGRPCProm). // so that prom middleware reports blocked requests + Done(), + NewUnaryMiddleware(). WithName(DefaultMiddlewareMemoryProtection). WithInterceptor(selector.UnaryServerInterceptor( @@ -388,6 +397,12 @@ func DefaultStreamingMiddleware(opts MiddlewareOption) (*MiddlewareChain[grpc.St WithInterceptor(grpcMetricsStreamingInterceptor). Done(), + NewStreamMiddleware(). + WithName(DefaultMiddlewareReadiness). + WithInterceptor(readiness.NewGate(opts.ReadinessChecker).StreamServerInterceptor()). + EnsureInterceptorAlreadyExecuted(DefaultMiddlewareGRPCProm). // so that prom middleware reports blocked requests + Done(), + NewStreamMiddleware(). WithName(DefaultMiddlewareMemoryProtection). WithInterceptor(memoryProtectionStreamInterceptor.StreamServerInterceptor()). diff --git a/pkg/cmd/server/defaults_test.go b/pkg/cmd/server/defaults_test.go index d4e9d0ce3..9de8e1d37 100644 --- a/pkg/cmd/server/defaults_test.go +++ b/pkg/cmd/server/defaults_test.go @@ -35,6 +35,7 @@ func TestWithDatastore(t *testing.T) { "service", consistency.TreatMismatchingTokensAsError, memoryprotection.NewNoopMemoryUsageProvider(), + nil, // ReadinessChecker nil, nil, } @@ -78,6 +79,7 @@ func TestWithDatastoreMiddleware(t *testing.T) { "service", consistency.TreatMismatchingTokensAsError, memoryprotection.NewNoopMemoryUsageProvider(), + nil, // ReadinessChecker nil, nil, } diff --git a/pkg/cmd/server/server.go b/pkg/cmd/server/server.go index 966e30f63..f6cf78e16 100644 --- a/pkg/cmd/server/server.go +++ b/pkg/cmd/server/server.go @@ -423,6 +423,7 @@ func (c *Config) Complete(ctx context.Context) (RunnableServer, error) { MiddlewareServiceLabel: serverName, MismatchingZedTokenOption: mismatchZedTokenOption, MemoryUsageProvider: memoryUsageProvider, + ReadinessChecker: ds, } opts = opts.WithDatastore(ds) diff --git a/pkg/cmd/server/server_test.go b/pkg/cmd/server/server_test.go index 9226b849c..db87d25f6 100644 --- a/pkg/cmd/server/server_test.go +++ b/pkg/cmd/server/server_test.go @@ -456,7 +456,7 @@ func TestModifyUnaryMiddleware(t *testing.T) { }, }} - opt := MiddlewareOption{logging.Logger, nil, false, nil, false, false, false, "testing", consistency.TreatMismatchingTokensAsFullConsistency, memoryprotection.NewNoopMemoryUsageProvider(), nil, nil} + opt := MiddlewareOption{logging.Logger, nil, false, nil, false, false, false, "testing", consistency.TreatMismatchingTokensAsFullConsistency, memoryprotection.NewNoopMemoryUsageProvider(), nil, nil, nil} opt = opt.WithDatastore(nil) defaultMw, err := DefaultUnaryMiddleware(opt) @@ -484,7 +484,7 @@ func TestModifyStreamingMiddleware(t *testing.T) { }, }} - opt := MiddlewareOption{logging.Logger, nil, false, nil, false, false, false, "testing", consistency.TreatMismatchingTokensAsFullConsistency, memoryprotection.NewNoopMemoryUsageProvider(), nil, nil} + opt := MiddlewareOption{logging.Logger, nil, false, nil, false, false, false, "testing", consistency.TreatMismatchingTokensAsFullConsistency, memoryprotection.NewNoopMemoryUsageProvider(), nil, nil, nil} opt = opt.WithDatastore(nil) defaultMw, err := DefaultStreamingMiddleware(opt) diff --git a/pkg/cmd/server/zz_generated.middlewareoption.go b/pkg/cmd/server/zz_generated.middlewareoption.go index f31ca52e8..24c494f3c 100644 --- a/pkg/cmd/server/zz_generated.middlewareoption.go +++ b/pkg/cmd/server/zz_generated.middlewareoption.go @@ -5,6 +5,7 @@ import ( dispatch "github.com/authzed/spicedb/internal/dispatch" memoryprotection "github.com/authzed/spicedb/internal/middleware/memoryprotection" consistency "github.com/authzed/spicedb/pkg/middleware/consistency" + readiness "github.com/authzed/spicedb/pkg/middleware/readiness" defaults "github.com/creasty/defaults" auth "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/auth" zerolog "github.com/rs/zerolog" @@ -44,6 +45,7 @@ func (m *MiddlewareOption) ToOption() MiddlewareOptionOption { to.MiddlewareServiceLabel = m.MiddlewareServiceLabel to.MismatchingZedTokenOption = m.MismatchingZedTokenOption to.MemoryUsageProvider = m.MemoryUsageProvider + to.ReadinessChecker = m.ReadinessChecker } } @@ -169,3 +171,10 @@ func WithMemoryUsageProvider(memoryUsageProvider memoryprotection.MemoryUsagePro m.MemoryUsageProvider = memoryUsageProvider } } + +// WithReadinessChecker returns an option that can set ReadinessChecker on a MiddlewareOption +func WithReadinessChecker(readinessChecker readiness.ReadinessChecker) MiddlewareOptionOption { + return func(m *MiddlewareOption) { + m.ReadinessChecker = readinessChecker + } +} diff --git a/pkg/middleware/readiness/readiness.go b/pkg/middleware/readiness/readiness.go new file mode 100644 index 000000000..8b835a116 --- /dev/null +++ b/pkg/middleware/readiness/readiness.go @@ -0,0 +1,195 @@ +package readiness + +import ( + "context" + "strings" + "sync" + "time" + + "golang.org/x/sync/singleflight" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/authzed/spicedb/pkg/datastore" +) + +const ( + // healthCheckPrefix is the gRPC method prefix for health checks. + // We bypass readiness checks for health endpoints so Kubernetes probes work. + healthCheckPrefix = "/grpc.health.v1.Health/" + + // readyCacheTTL is how long to cache a positive ready state. + readyCacheTTL = 500 * time.Millisecond + + // notReadyCacheTTL is how long to cache a negative ready state. + // Shorter than readyCacheTTL to allow faster recovery detection. + notReadyCacheTTL = 100 * time.Millisecond + + // readinessCheckTimeout is the maximum time to wait for a readiness check. + // We use a dedicated context with this timeout inside singleflight to ensure + // consistent behavior regardless of which request's context triggered the check. + readinessCheckTimeout = 5 * time.Second +) + +// ReadinessChecker is the interface for checking datastore readiness. +type ReadinessChecker interface { + ReadyState(ctx context.Context) (datastore.ReadyState, error) +} + +// Gate blocks gRPC requests until the datastore is ready. +// It caches the ready state briefly to avoid overwhelming the datastore +// with readiness checks on every request. +type Gate struct { + checker ReadinessChecker + + // singleflight prevents thundering herd when cache expires + sfGroup singleflight.Group + + mu sync.RWMutex + cachedReady bool // GUARDED_BY(mu) + cachedMessage string // GUARDED_BY(mu) + cacheTime time.Time // GUARDED_BY(mu) +} + +// NewGate creates a new readiness gate with the given checker. +// If checker is nil, the gate will pass through all requests without checking. +func NewGate(checker ReadinessChecker) *Gate { + return &Gate{checker: checker} +} + +// readinessResult holds the result of a readiness check for singleflight. +type readinessResult struct { + ready bool + message string +} + +// isMigrationIssue returns true if the not-ready message indicates +// the database schema hasn't been migrated. Other not-ready reasons +// (like connection pool warmup) are transient and shouldn't block requests. +func isMigrationIssue(msg string) bool { + return strings.Contains(msg, "not migrated") || strings.Contains(msg, "migration") +} + +// isReady checks if the datastore is ready, using a cached value if available. +// Uses singleflight to prevent thundering herd on cache expiry. +// Only blocks requests for migration-related issues; transient states like +// connection pool warmup are allowed through. +func (g *Gate) isReady(ctx context.Context) (bool, string) { + // If no checker is configured, pass through + if g.checker == nil { + return true, "" + } + + // Fast path: check cache with read lock + g.mu.RLock() + elapsed := time.Since(g.cacheTime) + ttl := readyCacheTTL + if !g.cachedReady { + ttl = notReadyCacheTTL + } + if elapsed < ttl { + ready, msg := g.cachedReady, g.cachedMessage + g.mu.RUnlock() + return ready, msg + } + g.mu.RUnlock() + + // Slow path: use singleflight to deduplicate concurrent checks + result, _, _ := g.sfGroup.Do("readiness", func() (any, error) { + // Double-check cache after acquiring singleflight + g.mu.RLock() + elapsed := time.Since(g.cacheTime) + ttl := readyCacheTTL + if !g.cachedReady { + ttl = notReadyCacheTTL + } + if elapsed < ttl { + ready, msg := g.cachedReady, g.cachedMessage + g.mu.RUnlock() + return readinessResult{ready: ready, message: msg}, nil + } + g.mu.RUnlock() + + // Use an independent context with timeout for the readiness check. + // This ensures consistent behavior when multiple requests are coalesced + // by singleflight - we don't want the check to fail because the first + // request's context was cancelled. + checkCtx, cancel := context.WithTimeout(context.Background(), readinessCheckTimeout) + defer cancel() + + state, err := g.checker.ReadyState(checkCtx) + if err != nil { + // On error checking readiness, allow requests through. + // If the datastore is truly unavailable, requests will fail + // with appropriate errors from the datastore layer. + return readinessResult{ready: true, message: ""}, nil + } + + // Only block requests for migration-related issues. + // Transient states (connection pool warmup, etc.) should not block. + ready := state.IsReady || !isMigrationIssue(state.Message) + + // Update cache + g.mu.Lock() + g.cachedReady = ready + g.cachedMessage = state.Message + g.cacheTime = time.Now() + g.mu.Unlock() + + return readinessResult{ready: ready, message: state.Message}, nil + }) + + r := result.(readinessResult) + return r.ready, r.message +} + +// formatNotReadyError creates a user-friendly error message based on the readiness failure reason. +// TODO(authzed/api#159): Once ERROR_REASON_DATASTORE_NOT_MIGRATED is available in the API, +// use spiceerrors.WithCodeAndReason to include the structured error reason. +func formatNotReadyError(msg string) error { + // Check if this is a migration-related issue + if strings.Contains(msg, "not migrated") || strings.Contains(msg, "migration") { + return status.Errorf(codes.FailedPrecondition, + "SpiceDB datastore is not migrated. Please run 'spicedb datastore migrate'. Details: %s", msg) + } + // Generic not-ready message for other cases (connection issues, pool not ready, etc.) + return status.Errorf(codes.FailedPrecondition, + "SpiceDB datastore is not ready. Details: %s", msg) +} + +// UnaryServerInterceptor returns a gRPC unary interceptor that blocks +// requests until the datastore is ready. +func (g *Gate) UnaryServerInterceptor() grpc.UnaryServerInterceptor { + return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { + // Bypass health checks so Kubernetes probes work + if strings.HasPrefix(info.FullMethod, healthCheckPrefix) { + return handler(ctx, req) + } + + ready, msg := g.isReady(ctx) + if !ready { + return nil, formatNotReadyError(msg) + } + + return handler(ctx, req) + } +} + +// StreamServerInterceptor returns a gRPC stream interceptor that blocks +// streams until the datastore is ready. +func (g *Gate) StreamServerInterceptor() grpc.StreamServerInterceptor { + return func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + // Bypass health checks so Kubernetes probes work + if strings.HasPrefix(info.FullMethod, healthCheckPrefix) { + return handler(srv, ss) + } + + ready, msg := g.isReady(ss.Context()) + if !ready { + return formatNotReadyError(msg) + } + + return handler(srv, ss) + } +} diff --git a/pkg/middleware/readiness/readiness_test.go b/pkg/middleware/readiness/readiness_test.go new file mode 100644 index 000000000..2b6fda668 --- /dev/null +++ b/pkg/middleware/readiness/readiness_test.go @@ -0,0 +1,393 @@ +package readiness + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/authzed/spicedb/pkg/datastore" +) + +type mockChecker struct { + ready bool + message string + err error + callCount atomic.Int32 +} + +func (m *mockChecker) ReadyState(_ context.Context) (datastore.ReadyState, error) { + m.callCount.Add(1) + if m.err != nil { + return datastore.ReadyState{}, m.err + } + return datastore.ReadyState{ + IsReady: m.ready, + Message: m.message, + }, nil +} + +func TestGate_BlocksWhenNotReady(t *testing.T) { + checker := &mockChecker{ + ready: false, + message: "datastore is not migrated", + } + gate := NewGate(checker) + + interceptor := gate.UnaryServerInterceptor() + _, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{ + FullMethod: "/authzed.api.v1.PermissionsService/CheckPermission", + }, func(ctx context.Context, req any) (any, error) { + t.Fatal("handler should not be called when not ready") + return nil, nil + }) + + require.Error(t, err) + st, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, codes.FailedPrecondition, st.Code()) + require.Contains(t, st.Message(), "not migrated") + require.Contains(t, st.Message(), "spicedb datastore migrate") +} + +func TestGate_AllowsWhenReady(t *testing.T) { + checker := &mockChecker{ready: true} + gate := NewGate(checker) + + handlerCalled := false + interceptor := gate.UnaryServerInterceptor() + _, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{ + FullMethod: "/authzed.api.v1.PermissionsService/CheckPermission", + }, func(ctx context.Context, req any) (any, error) { + handlerCalled = true + return "response", nil + }) + + require.NoError(t, err) + require.True(t, handlerCalled) +} + +func TestGate_BypassesHealthCheck(t *testing.T) { + checker := &mockChecker{ + ready: false, + message: "not ready", + } + gate := NewGate(checker) + + handlerCalled := false + interceptor := gate.UnaryServerInterceptor() + _, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{ + FullMethod: "/grpc.health.v1.Health/Check", + }, func(ctx context.Context, req any) (any, error) { + handlerCalled = true + return "ok", nil + }) + + require.NoError(t, err) + require.True(t, handlerCalled) + // Checker should not be called for health checks + require.Equal(t, int32(0), checker.callCount.Load()) +} + +func TestGate_BypassesHealthWatch(t *testing.T) { + checker := &mockChecker{ + ready: false, + message: "not ready", + } + gate := NewGate(checker) + + handlerCalled := false + interceptor := gate.UnaryServerInterceptor() + _, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{ + FullMethod: "/grpc.health.v1.Health/Watch", + }, func(ctx context.Context, req any) (any, error) { + handlerCalled = true + return "ok", nil + }) + + require.NoError(t, err) + require.True(t, handlerCalled) +} + +func TestGate_CachesReadyState(t *testing.T) { + checker := &mockChecker{ready: true} + gate := NewGate(checker) + + interceptor := gate.UnaryServerInterceptor() + info := &grpc.UnaryServerInfo{FullMethod: "/test/Method"} + handler := func(ctx context.Context, req any) (any, error) { + return nil, nil + } + + // First call should check readiness + _, _ = interceptor(context.Background(), nil, info, handler) + require.Equal(t, int32(1), checker.callCount.Load()) + + // Second call should use cache + _, _ = interceptor(context.Background(), nil, info, handler) + require.Equal(t, int32(1), checker.callCount.Load()) + + // Third call should use cache + _, _ = interceptor(context.Background(), nil, info, handler) + require.Equal(t, int32(1), checker.callCount.Load()) +} + +func TestGate_CacheExpires(t *testing.T) { + checker := &mockChecker{ready: true} + gate := NewGate(checker) + + // Override cache time to test expiry + gate.mu.Lock() + gate.cachedReady = true + gate.cacheTime = time.Now().Add(-2 * readyCacheTTL) // Expired + gate.mu.Unlock() + + interceptor := gate.UnaryServerInterceptor() + _, _ = interceptor(context.Background(), nil, &grpc.UnaryServerInfo{ + FullMethod: "/test/Method", + }, func(ctx context.Context, req any) (any, error) { + return nil, nil + }) + + // Should have called checker because cache expired + require.Equal(t, int32(1), checker.callCount.Load()) +} + +func TestGate_SingleflightPreventsThunderingHerd(t *testing.T) { + // Slow checker that takes 100ms + checker := &mockChecker{ready: true} + gate := NewGate(checker) + + interceptor := gate.UnaryServerInterceptor() + info := &grpc.UnaryServerInfo{FullMethod: "/test/Method"} + handler := func(ctx context.Context, req any) (any, error) { + return nil, nil + } + + // Launch 10 concurrent requests + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _, _ = interceptor(context.Background(), nil, info, handler) + }() + } + wg.Wait() + + // Singleflight should have deduplicated the calls + // We might get 1 or 2 calls depending on timing, but definitely not 10 + require.LessOrEqual(t, checker.callCount.Load(), int32(2)) +} + +func TestGate_PassesThroughOnCheckerError(t *testing.T) { + checker := &mockChecker{ + err: errors.New("connection refused"), + } + gate := NewGate(checker) + + handlerCalled := false + interceptor := gate.UnaryServerInterceptor() + _, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{ + FullMethod: "/test/Method", + }, func(ctx context.Context, req any) (any, error) { + handlerCalled = true + return "ok", nil + }) + + // Errors during readiness check should allow requests through. + // If the datastore is truly unavailable, requests will fail at the + // datastore layer with appropriate errors. + require.NoError(t, err) + require.True(t, handlerCalled) +} + +func TestGate_StreamInterceptorBlocksWhenNotReady(t *testing.T) { + checker := &mockChecker{ + ready: false, + message: "not migrated", + } + gate := NewGate(checker) + + interceptor := gate.StreamServerInterceptor() + err := interceptor(nil, &mockServerStream{}, &grpc.StreamServerInfo{ + FullMethod: "/authzed.api.v1.WatchService/Watch", + }, func(srv any, stream grpc.ServerStream) error { + t.Fatal("handler should not be called when not ready") + return nil + }) + + require.Error(t, err) + st, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, codes.FailedPrecondition, st.Code()) +} + +func TestGate_StreamInterceptorAllowsWhenReady(t *testing.T) { + checker := &mockChecker{ready: true} + gate := NewGate(checker) + + handlerCalled := false + interceptor := gate.StreamServerInterceptor() + err := interceptor(nil, &mockServerStream{}, &grpc.StreamServerInfo{ + FullMethod: "/authzed.api.v1.WatchService/Watch", + }, func(srv any, stream grpc.ServerStream) error { + handlerCalled = true + return nil + }) + + require.NoError(t, err) + require.True(t, handlerCalled) +} + +// mockServerStream implements grpc.ServerStream for testing +type mockServerStream struct { + grpc.ServerStream + ctx context.Context +} + +func (m *mockServerStream) Context() context.Context { + if m.ctx != nil { + return m.ctx + } + return context.Background() +} + +func TestGate_NilCheckerPassesThrough(t *testing.T) { + gate := NewGate(nil) + + handlerCalled := false + interceptor := gate.UnaryServerInterceptor() + _, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{ + FullMethod: "/authzed.api.v1.PermissionsService/CheckPermission", + }, func(ctx context.Context, req any) (any, error) { + handlerCalled = true + return "response", nil + }) + + require.NoError(t, err) + require.True(t, handlerCalled) +} + +func TestGate_ErrorDoesNotCache(t *testing.T) { + checker := &mockChecker{err: errors.New("temporary connection error")} + gate := NewGate(checker) + + interceptor := gate.UnaryServerInterceptor() + info := &grpc.UnaryServerInfo{FullMethod: "/test/Method"} + + // First call: checker errors but request passes through + handlerCalled := false + _, err := interceptor(context.Background(), nil, info, func(ctx context.Context, req any) (any, error) { + handlerCalled = true + return "ok", nil + }) + require.NoError(t, err) + require.True(t, handlerCalled) + require.Equal(t, int32(1), checker.callCount.Load()) + + // Second call should retry checker (errors are not cached) + handlerCalled = false + _, err = interceptor(context.Background(), nil, info, func(ctx context.Context, req any) (any, error) { + handlerCalled = true + return "ok", nil + }) + require.NoError(t, err) + require.True(t, handlerCalled) + require.Equal(t, int32(2), checker.callCount.Load()) +} + +func TestGate_NegativeCacheReducesChecks(t *testing.T) { + checker := &mockChecker{ + ready: false, + message: "datastore is not migrated", + } + gate := NewGate(checker) + + interceptor := gate.UnaryServerInterceptor() + info := &grpc.UnaryServerInfo{FullMethod: "/test/Method"} + handler := func(ctx context.Context, req any) (any, error) { + return nil, nil + } + + // First call checks readiness + _, _ = interceptor(context.Background(), nil, info, handler) + require.Equal(t, int32(1), checker.callCount.Load()) + + // Second call should use negative cache (within notReadyCacheTTL) + _, _ = interceptor(context.Background(), nil, info, handler) + require.Equal(t, int32(1), checker.callCount.Load()) + + // Third call should use negative cache + _, _ = interceptor(context.Background(), nil, info, handler) + require.Equal(t, int32(1), checker.callCount.Load()) +} + +func TestGate_ErrorMessageContextAware(t *testing.T) { + t.Run("migration issue blocks with actionable message", func(t *testing.T) { + checker := &mockChecker{ + ready: false, + message: "datastore is not migrated: currently at revision \"\"", + } + gate := NewGate(checker) + + interceptor := gate.UnaryServerInterceptor() + _, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{ + FullMethod: "/test/Method", + }, func(ctx context.Context, req any) (any, error) { + t.Fatal("handler should not be called for migration issues") + return nil, nil + }) + + require.Error(t, err) + require.Contains(t, err.Error(), "spicedb datastore migrate") + require.Contains(t, err.Error(), "not migrated") + }) + + t.Run("connection pool issue passes through", func(t *testing.T) { + checker := &mockChecker{ + ready: false, + message: "spicedb does not have the required minimum connection count", + } + gate := NewGate(checker) + + handlerCalled := false + interceptor := gate.UnaryServerInterceptor() + _, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{ + FullMethod: "/test/Method", + }, func(ctx context.Context, req any) (any, error) { + handlerCalled = true + return "ok", nil + }) + + require.NoError(t, err) + require.True(t, handlerCalled) + }) + + t.Run("generic not ready passes through", func(t *testing.T) { + checker := &mockChecker{ + ready: false, + message: "some other issue", + } + gate := NewGate(checker) + + handlerCalled := false + interceptor := gate.UnaryServerInterceptor() + _, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{ + FullMethod: "/test/Method", + }, func(ctx context.Context, req any) (any, error) { + handlerCalled = true + return "ok", nil + }) + + require.NoError(t, err) + require.True(t, handlerCalled) + }) +}