Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ coverage*.txt
.claude/
e2e/newenemy/spicedb
e2e/newenemy/cockroach
e2e/newenemy/chaosd
e2e/newenemy/chaosd
__pycache__/
14 changes: 3 additions & 11 deletions internal/datastore/crdb/pool/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"sync"
"time"

"github.com/ccoveille/go-safecast/v2"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgxpool"
Expand All @@ -17,6 +16,7 @@ import (

"github.com/authzed/spicedb/internal/datastore/postgres/common"
log "github.com/authzed/spicedb/internal/logging"
"github.com/authzed/spicedb/pkg/spiceerrors"
)

// pgxPool interface is the subset of pgxpool.Pool that RetryPool needs
Expand Down Expand Up @@ -156,21 +156,13 @@ func (p *RetryPool) ID() string {
// MaxConns returns the MaxConns configured on the underlying pool
func (p *RetryPool) MaxConns() uint32 {
// This should be non-negative
maxConns, err := safecast.Convert[uint32](p.pool.Config().MaxConns)
if err != nil {
maxConns = 0
}
return maxConns
return spiceerrors.MustSafecast[uint32](p.pool.Config().MaxConns)
}

// MinConns returns the MinConns configured on the underlying pool
func (p *RetryPool) MinConns() uint32 {
// This should be non-negative
minConns, err := safecast.Convert[uint32](p.pool.Config().MinConns)
if err != nil {
minConns = 0
}
return minConns
return spiceerrors.MustSafecast[uint32](p.pool.Config().MinConns)
}

// ExecFunc is a replacement for pgxpool.pgxPool.Exec that allows resetting the
Expand Down
8 changes: 3 additions & 5 deletions internal/datastore/postgres/snapshot.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@ import (
"strconv"
"strings"

"github.com/ccoveille/go-safecast/v2"
"github.com/jackc/pgx/v5/pgtype"

"github.com/authzed/spicedb/pkg/spiceerrors"
)

// RegisterTypes registers pgSnapshot and xid8 with a pgtype.ConnInfo.
Expand Down Expand Up @@ -273,10 +274,7 @@ func (s pgSnapshot) markInProgress(txid uint64) pgSnapshot {
startingXipLen := len(newSnapshot.xipList)
for numToDrop = 0; numToDrop < startingXipLen; numToDrop++ {
// numToDrop should be nonnegative
uintNumToDrop, err := safecast.Convert[uint64](numToDrop)
if err != nil {
uintNumToDrop = 0
}
uintNumToDrop := spiceerrors.MustSafecast[uint64](numToDrop)

if newSnapshot.xipList[startingXipLen-1-numToDrop] != newSnapshot.xmax-uintNumToDrop-1 {
break
Expand Down
9 changes: 2 additions & 7 deletions internal/developmentmembership/onrset.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
package developmentmembership

import (
"github.com/ccoveille/go-safecast/v2"

"github.com/authzed/spicedb/pkg/genutil/mapz"
"github.com/authzed/spicedb/pkg/spiceerrors"
"github.com/authzed/spicedb/pkg/tuple"
)

Expand All @@ -26,11 +25,7 @@ func NewONRSet(onrs ...tuple.ObjectAndRelation) ONRSet {
// Length returns the size of the set.
func (ons ONRSet) Length() uint64 {
// This is the length of a set so we should never fall out of bounds.
length, err := safecast.Convert[uint64](ons.onrs.Len())
if err != nil {
return 0
}
return length
return spiceerrors.MustSafecast[uint64](ons.onrs.Len())
}

// IsEmpty returns whether the set is empty.
Expand Down
8 changes: 3 additions & 5 deletions pkg/cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
import (
"time"

"github.com/ccoveille/go-safecast/v2"
"github.com/dustin/go-humanize"
"github.com/rs/zerolog"

"github.com/authzed/spicedb/pkg/spiceerrors"
)

// KeyString is an interface for keys that can be converted to strings.
Expand Down Expand Up @@ -53,10 +54,7 @@
}

func (c *Config) MarshalZerologObject(e *zerolog.Event) {
maxCost, err := safecast.Convert[uint64](c.MaxCost)
if err != nil {
maxCost = 0
}
maxCost := spiceerrors.MustSafecast[uint64](c.MaxCost)

Check warning on line 57 in pkg/cache/cache.go

View check run for this annotation

Codecov / codecov/patch

pkg/cache/cache.go#L57

Added line #L57 was not covered by tests
e.
Str("maxCost", humanize.IBytes(maxCost)).
Int64("numCounters", c.NumCounters).
Expand Down
12 changes: 3 additions & 9 deletions pkg/composableschemadsl/compiler/translator.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"path/filepath"
"strings"

"github.com/ccoveille/go-safecast/v2"
"github.com/jzelinskie/stringz"

"github.com/authzed/spicedb/internal/logging"
Expand All @@ -20,6 +19,7 @@ import (
"github.com/authzed/spicedb/pkg/genutil/mapz"
"github.com/authzed/spicedb/pkg/namespace"
core "github.com/authzed/spicedb/pkg/proto/core/v1"
"github.com/authzed/spicedb/pkg/spiceerrors"
)

type translationContext struct {
Expand Down Expand Up @@ -334,14 +334,8 @@ func getSourcePosition(dslNode *dslNode, mapper input.PositionMapper) *core.Sour
return nil
}

uintLine, err := safecast.Convert[uint64](line)
if err != nil {
uintLine = 0
}
uintCol, err := safecast.Convert[uint64](col)
if err != nil {
uintCol = 0
}
uintLine := spiceerrors.MustSafecast[uint64](line)
uintCol := spiceerrors.MustSafecast[uint64](col)

return &core.SourcePosition{
ZeroIndexedLineNumber: uintLine,
Expand Down
12 changes: 3 additions & 9 deletions pkg/schemadsl/compiler/translator.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"path/filepath"
"strings"

"github.com/ccoveille/go-safecast/v2"
"github.com/jzelinskie/stringz"

"github.com/authzed/spicedb/internal/logging"
Expand All @@ -20,6 +19,7 @@ import (
core "github.com/authzed/spicedb/pkg/proto/core/v1"
"github.com/authzed/spicedb/pkg/schemadsl/dslshape"
"github.com/authzed/spicedb/pkg/schemadsl/input"
"github.com/authzed/spicedb/pkg/spiceerrors"
)

type translationContext struct {
Expand Down Expand Up @@ -345,14 +345,8 @@ func getSourcePosition(dslNode *dslNode, mapper input.PositionMapper) *core.Sour
return nil
}

uintLine, err := safecast.Convert[uint64](line)
if err != nil {
uintLine = 0
}
uintCol, err := safecast.Convert[uint64](col)
if err != nil {
uintCol = 0
}
uintLine := spiceerrors.MustSafecast[uint64](line)
uintCol := spiceerrors.MustSafecast[uint64](col)

return &core.SourcePosition{
ZeroIndexedLineNumber: uintLine,
Expand Down
27 changes: 27 additions & 0 deletions pkg/spiceerrors/bug.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
"os"
"strings"

"github.com/ccoveille/go-safecast/v2"
"github.com/go-errors/errors"

log "github.com/authzed/spicedb/internal/logging"
)

// IsInTests returns true if go test is running
Expand Down Expand Up @@ -33,3 +36,27 @@
e := errors.Errorf(format, args...)
return fmt.Errorf("BUG: %s", e.ErrorStack())
}

// MustSafecast converts a value from one numeric type to another using safecast.
// If the conversion fails (value out of range), it panics in tests and returns
// the zero value in production. This should only be used where the value is
// expected to always be convertible (e.g., converting from a statically defined
// value or a value known to be non-negative).
func MustSafecast[To, From safecast.Number](from From) To {
result, err := safecast.Convert[To](from)
if err != nil {
if IsInTests() {
panic(fmt.Sprintf("safecast conversion failed: %v (from %v to %T)", err, from, result))

Check warning on line 49 in pkg/spiceerrors/bug.go

View check run for this annotation

Codecov / codecov/patch

pkg/spiceerrors/bug.go#L48-L49

Added lines #L48 - L49 were not covered by tests
}
// In production, log a warning and return the zero value
var zero To
log.Warn().
Interface("from_value", from).
Str("from_type", fmt.Sprintf("%T", from)).
Str("to_type", fmt.Sprintf("%T", zero)).
Err(err).
Msg("MustSafecast conversion failed in production, returning zero value")
return zero

Check warning on line 59 in pkg/spiceerrors/bug.go

View check run for this annotation

Codecov / codecov/patch

pkg/spiceerrors/bug.go#L52-L59

Added lines #L52 - L59 were not covered by tests
}
return result
}
57 changes: 57 additions & 0 deletions pkg/spiceerrors/bug_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package spiceerrors

import (
"os"
"testing"

"github.com/stretchr/testify/assert"
Expand All @@ -14,3 +15,59 @@ func TestMustBug(t *testing.T) {
require.Error(t, err)
}, "The code did not panic")
}

func TestMustSafecast(t *testing.T) {
require.True(t, IsInTests())

// Test successful conversion
t.Run("successful conversion", func(t *testing.T) {
result := MustSafecast[uint64](42)
assert.Equal(t, uint64(42), result)

result2 := MustSafecast[int32](100)
assert.Equal(t, int32(100), result2)
})

// Test that conversion failure panics in tests
t.Run("conversion failure panics in tests", func(t *testing.T) {
assert.Panics(t, func() {
// Try to convert a negative number to unsigned
MustSafecast[uint64](-1)
}, "Expected panic on invalid conversion")
})

// Test conversion from larger to smaller type that fits
t.Run("conversion within range", func(t *testing.T) {
result := MustSafecast[uint32](uint64(100))
assert.Equal(t, uint32(100), result)
})

// Test production behavior (returns zero value without panic)
t.Run("production behavior returns zero on failure", func(t *testing.T) {
// Temporarily simulate production by removing test flags from os.Args
originalArgs := os.Args
defer func() {
os.Args = originalArgs
}()

// Remove all -test.* flags to simulate production
var nonTestArgs []string
for _, arg := range os.Args {
if len(arg) < 6 || arg[:6] != "-test." {
nonTestArgs = append(nonTestArgs, arg)
}
}
os.Args = nonTestArgs

// Verify we're now simulating production
require.False(t, IsInTests(), "Should simulate production mode")

// Test that conversion failure returns zero in production
result := MustSafecast[uint64](-1)
assert.Equal(t, uint64(0), result, "Expected zero value in production mode")

// Test overflow case
result2 := MustSafecast[uint8](300)
assert.Equal(t, uint8(0), result2, "Expected zero value for overflow in production mode")
})
}
Loading