diff --git a/.gitignore b/.gitignore index 148a4040d..7be4317f0 100644 --- a/.gitignore +++ b/.gitignore @@ -9,4 +9,5 @@ coverage*.txt .claude/ e2e/newenemy/spicedb e2e/newenemy/cockroach -e2e/newenemy/chaosd \ No newline at end of file +e2e/newenemy/chaosd +__pycache__/ diff --git a/internal/datastore/crdb/pool/pool.go b/internal/datastore/crdb/pool/pool.go index 4699b33bd..73fc8f2c9 100644 --- a/internal/datastore/crdb/pool/pool.go +++ b/internal/datastore/crdb/pool/pool.go @@ -8,7 +8,6 @@ import ( "sync" "time" - "github.com/ccoveille/go-safecast/v2" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgxpool" @@ -17,6 +16,7 @@ import ( "github.com/authzed/spicedb/internal/datastore/postgres/common" log "github.com/authzed/spicedb/internal/logging" + "github.com/authzed/spicedb/pkg/spiceerrors" ) // pgxPool interface is the subset of pgxpool.Pool that RetryPool needs @@ -156,21 +156,13 @@ func (p *RetryPool) ID() string { // MaxConns returns the MaxConns configured on the underlying pool func (p *RetryPool) MaxConns() uint32 { // This should be non-negative - maxConns, err := safecast.Convert[uint32](p.pool.Config().MaxConns) - if err != nil { - maxConns = 0 - } - return maxConns + return spiceerrors.MustSafecast[uint32](p.pool.Config().MaxConns) } // MinConns returns the MinConns configured on the underlying pool func (p *RetryPool) MinConns() uint32 { // This should be non-negative - minConns, err := safecast.Convert[uint32](p.pool.Config().MinConns) - if err != nil { - minConns = 0 - } - return minConns + return spiceerrors.MustSafecast[uint32](p.pool.Config().MinConns) } // ExecFunc is a replacement for pgxpool.pgxPool.Exec that allows resetting the diff --git a/internal/datastore/postgres/snapshot.go b/internal/datastore/postgres/snapshot.go index 0054c16b4..80ad9b80f 100644 --- a/internal/datastore/postgres/snapshot.go +++ b/internal/datastore/postgres/snapshot.go @@ -7,8 +7,9 @@ import ( "strconv" "strings" - "github.com/ccoveille/go-safecast/v2" "github.com/jackc/pgx/v5/pgtype" + + "github.com/authzed/spicedb/pkg/spiceerrors" ) // RegisterTypes registers pgSnapshot and xid8 with a pgtype.ConnInfo. @@ -273,10 +274,7 @@ func (s pgSnapshot) markInProgress(txid uint64) pgSnapshot { startingXipLen := len(newSnapshot.xipList) for numToDrop = 0; numToDrop < startingXipLen; numToDrop++ { // numToDrop should be nonnegative - uintNumToDrop, err := safecast.Convert[uint64](numToDrop) - if err != nil { - uintNumToDrop = 0 - } + uintNumToDrop := spiceerrors.MustSafecast[uint64](numToDrop) if newSnapshot.xipList[startingXipLen-1-numToDrop] != newSnapshot.xmax-uintNumToDrop-1 { break diff --git a/internal/developmentmembership/onrset.go b/internal/developmentmembership/onrset.go index ba53fe0a6..72aff73ec 100644 --- a/internal/developmentmembership/onrset.go +++ b/internal/developmentmembership/onrset.go @@ -1,9 +1,8 @@ package developmentmembership import ( - "github.com/ccoveille/go-safecast/v2" - "github.com/authzed/spicedb/pkg/genutil/mapz" + "github.com/authzed/spicedb/pkg/spiceerrors" "github.com/authzed/spicedb/pkg/tuple" ) @@ -26,11 +25,7 @@ func NewONRSet(onrs ...tuple.ObjectAndRelation) ONRSet { // Length returns the size of the set. func (ons ONRSet) Length() uint64 { // This is the length of a set so we should never fall out of bounds. - length, err := safecast.Convert[uint64](ons.onrs.Len()) - if err != nil { - return 0 - } - return length + return spiceerrors.MustSafecast[uint64](ons.onrs.Len()) } // IsEmpty returns whether the set is empty. diff --git a/pkg/cache/cache.go b/pkg/cache/cache.go index b18a1bb65..23195be06 100644 --- a/pkg/cache/cache.go +++ b/pkg/cache/cache.go @@ -3,9 +3,10 @@ package cache import ( "time" - "github.com/ccoveille/go-safecast/v2" "github.com/dustin/go-humanize" "github.com/rs/zerolog" + + "github.com/authzed/spicedb/pkg/spiceerrors" ) // KeyString is an interface for keys that can be converted to strings. @@ -53,10 +54,7 @@ type Config struct { } func (c *Config) MarshalZerologObject(e *zerolog.Event) { - maxCost, err := safecast.Convert[uint64](c.MaxCost) - if err != nil { - maxCost = 0 - } + maxCost := spiceerrors.MustSafecast[uint64](c.MaxCost) e. Str("maxCost", humanize.IBytes(maxCost)). Int64("numCounters", c.NumCounters). diff --git a/pkg/composableschemadsl/compiler/translator.go b/pkg/composableschemadsl/compiler/translator.go index ee7ab52bb..2c57a9d1e 100644 --- a/pkg/composableschemadsl/compiler/translator.go +++ b/pkg/composableschemadsl/compiler/translator.go @@ -9,7 +9,6 @@ import ( "path/filepath" "strings" - "github.com/ccoveille/go-safecast/v2" "github.com/jzelinskie/stringz" "github.com/authzed/spicedb/internal/logging" @@ -20,6 +19,7 @@ import ( "github.com/authzed/spicedb/pkg/genutil/mapz" "github.com/authzed/spicedb/pkg/namespace" core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/spiceerrors" ) type translationContext struct { @@ -334,14 +334,8 @@ func getSourcePosition(dslNode *dslNode, mapper input.PositionMapper) *core.Sour return nil } - uintLine, err := safecast.Convert[uint64](line) - if err != nil { - uintLine = 0 - } - uintCol, err := safecast.Convert[uint64](col) - if err != nil { - uintCol = 0 - } + uintLine := spiceerrors.MustSafecast[uint64](line) + uintCol := spiceerrors.MustSafecast[uint64](col) return &core.SourcePosition{ ZeroIndexedLineNumber: uintLine, diff --git a/pkg/schemadsl/compiler/translator.go b/pkg/schemadsl/compiler/translator.go index 2543e0537..adb5db3ba 100644 --- a/pkg/schemadsl/compiler/translator.go +++ b/pkg/schemadsl/compiler/translator.go @@ -9,7 +9,6 @@ import ( "path/filepath" "strings" - "github.com/ccoveille/go-safecast/v2" "github.com/jzelinskie/stringz" "github.com/authzed/spicedb/internal/logging" @@ -20,6 +19,7 @@ import ( core "github.com/authzed/spicedb/pkg/proto/core/v1" "github.com/authzed/spicedb/pkg/schemadsl/dslshape" "github.com/authzed/spicedb/pkg/schemadsl/input" + "github.com/authzed/spicedb/pkg/spiceerrors" ) type translationContext struct { @@ -345,14 +345,8 @@ func getSourcePosition(dslNode *dslNode, mapper input.PositionMapper) *core.Sour return nil } - uintLine, err := safecast.Convert[uint64](line) - if err != nil { - uintLine = 0 - } - uintCol, err := safecast.Convert[uint64](col) - if err != nil { - uintCol = 0 - } + uintLine := spiceerrors.MustSafecast[uint64](line) + uintCol := spiceerrors.MustSafecast[uint64](col) return &core.SourcePosition{ ZeroIndexedLineNumber: uintLine, diff --git a/pkg/spiceerrors/bug.go b/pkg/spiceerrors/bug.go index ebac3549e..79208e0df 100644 --- a/pkg/spiceerrors/bug.go +++ b/pkg/spiceerrors/bug.go @@ -5,7 +5,10 @@ import ( "os" "strings" + "github.com/ccoveille/go-safecast/v2" "github.com/go-errors/errors" + + log "github.com/authzed/spicedb/internal/logging" ) // IsInTests returns true if go test is running @@ -33,3 +36,27 @@ func MustBugf(format string, args ...any) error { e := errors.Errorf(format, args...) return fmt.Errorf("BUG: %s", e.ErrorStack()) } + +// MustSafecast converts a value from one numeric type to another using safecast. +// If the conversion fails (value out of range), it panics in tests and returns +// the zero value in production. This should only be used where the value is +// expected to always be convertible (e.g., converting from a statically defined +// value or a value known to be non-negative). +func MustSafecast[To, From safecast.Number](from From) To { + result, err := safecast.Convert[To](from) + if err != nil { + if IsInTests() { + panic(fmt.Sprintf("safecast conversion failed: %v (from %v to %T)", err, from, result)) + } + // In production, log a warning and return the zero value + var zero To + log.Warn(). + Interface("from_value", from). + Str("from_type", fmt.Sprintf("%T", from)). + Str("to_type", fmt.Sprintf("%T", zero)). + Err(err). + Msg("MustSafecast conversion failed in production, returning zero value") + return zero + } + return result +} diff --git a/pkg/spiceerrors/bug_test.go b/pkg/spiceerrors/bug_test.go index d5bd139f9..f88d821ae 100644 --- a/pkg/spiceerrors/bug_test.go +++ b/pkg/spiceerrors/bug_test.go @@ -1,6 +1,7 @@ package spiceerrors import ( + "os" "testing" "github.com/stretchr/testify/assert" @@ -14,3 +15,59 @@ func TestMustBug(t *testing.T) { require.Error(t, err) }, "The code did not panic") } + +func TestMustSafecast(t *testing.T) { + require.True(t, IsInTests()) + + // Test successful conversion + t.Run("successful conversion", func(t *testing.T) { + result := MustSafecast[uint64](42) + assert.Equal(t, uint64(42), result) + + result2 := MustSafecast[int32](100) + assert.Equal(t, int32(100), result2) + }) + + // Test that conversion failure panics in tests + t.Run("conversion failure panics in tests", func(t *testing.T) { + assert.Panics(t, func() { + // Try to convert a negative number to unsigned + MustSafecast[uint64](-1) + }, "Expected panic on invalid conversion") + }) + + // Test conversion from larger to smaller type that fits + t.Run("conversion within range", func(t *testing.T) { + result := MustSafecast[uint32](uint64(100)) + assert.Equal(t, uint32(100), result) + }) + + // Test production behavior (returns zero value without panic) + t.Run("production behavior returns zero on failure", func(t *testing.T) { + // Temporarily simulate production by removing test flags from os.Args + originalArgs := os.Args + defer func() { + os.Args = originalArgs + }() + + // Remove all -test.* flags to simulate production + var nonTestArgs []string + for _, arg := range os.Args { + if len(arg) < 6 || arg[:6] != "-test." { + nonTestArgs = append(nonTestArgs, arg) + } + } + os.Args = nonTestArgs + + // Verify we're now simulating production + require.False(t, IsInTests(), "Should simulate production mode") + + // Test that conversion failure returns zero in production + result := MustSafecast[uint64](-1) + assert.Equal(t, uint64(0), result, "Expected zero value in production mode") + + // Test overflow case + result2 := MustSafecast[uint8](300) + assert.Equal(t, uint8(0), result2, "Expected zero value for overflow in production mode") + }) +}