Skip to content

Commit 9d82db3

Browse files
Return an error when an unexpected state is provided
In order to avoid people relying in the state but not providing an encoder we now return an error.
1 parent 3cf3ece commit 9d82db3

File tree

4 files changed

+64
-18
lines changed

4 files changed

+64
-18
lines changed

driver/sql/postgres/projector.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package postgres
33
import (
44
"context"
55
"database/sql/driver"
6+
"errors"
67

78
"github.com/hellofresh/goengine/driver/sql"
89
)
@@ -45,3 +46,12 @@ func resolveErrorAction(
4546

4647
return errorFallthrough
4748
}
49+
50+
// defaultProjectionStateEncoder this `ProjectionStateEncoder` is used for a goeninge.Projection
51+
func defaultProjectionStateEncoder(state interface{}) ([]byte, error) {
52+
if state == nil {
53+
return []byte{'{', '}'}, nil
54+
}
55+
56+
return nil, errors.New("unexpected state provided (Did you forget to implement goengine.ProjectionSaga?)")
57+
}

driver/sql/postgres/projector_aggregate_storage.go

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,9 @@ func newAggregateProjectionStorage(
5353
if logger == nil {
5454
logger = goengine.NopLogger
5555
}
56+
if projectionStateEncoder == nil {
57+
projectionStateEncoder = defaultProjectionStateEncoder
58+
}
5659

5760
projectionTableQuoted := QuoteIdentifier(projectionTable)
5861
projectionTableStr := QuoteString(projectionTable)
@@ -115,15 +118,9 @@ func (a *aggregateProjectionStorage) LoadOutOfSync(ctx context.Context, conn *sq
115118
}
116119

117120
func (a *aggregateProjectionStorage) PersistState(conn *sql.Conn, notification *driverSQL.ProjectionNotification, state driverSQL.ProjectionState) error {
118-
var (
119-
err error
120-
encodedState = []byte{'{', '}'}
121-
)
122-
if a.projectionStateEncoder != nil {
123-
encodedState, err = a.projectionStateEncoder(state.ProjectionState)
124-
if err != nil {
125-
return err
126-
}
121+
encodedState, err := a.projectionStateEncoder(state.ProjectionState)
122+
if err != nil {
123+
return err
127124
}
128125

129126
_, err = conn.ExecContext(context.Background(), a.queryPersistState, notification.AggregateID, state.Position, encodedState)

driver/sql/postgres/projector_stream_storage.go

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,13 @@ func newStreamProjectionStorage(
3838
projectionStateEncoder driverSQL.ProjectionStateEncoder,
3939
logger goengine.Logger,
4040
) *streamProjectionStorage {
41+
if logger == nil {
42+
logger = goengine.NopLogger
43+
}
44+
if projectionStateEncoder == nil {
45+
projectionStateEncoder = defaultProjectionStateEncoder
46+
}
47+
4148
projectionTableQuoted := QuoteIdentifier(projectionTable)
4249
projectionTableStr := QuoteString(projectionTable)
4350

@@ -73,15 +80,9 @@ func newStreamProjectionStorage(
7380
}
7481

7582
func (s *streamProjectionStorage) PersistState(conn *sql.Conn, notification *driverSQL.ProjectionNotification, state driverSQL.ProjectionState) error {
76-
var (
77-
err error
78-
encodedState = []byte{'{', '}'}
79-
)
80-
if s.projectionStateEncoder != nil {
81-
encodedState, err = s.projectionStateEncoder(state.ProjectionState)
82-
if err != nil {
83-
return err
84-
}
83+
encodedState, err := s.projectionStateEncoder(state.ProjectionState)
84+
if err != nil {
85+
return err
8586
}
8687

8788
_, err = conn.ExecContext(context.Background(), s.queryPersistState, state.Position, encodedState, s.projectionName)
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// +build unit
2+
3+
package postgres
4+
5+
import (
6+
"fmt"
7+
"github.com/hellofresh/goengine"
8+
"github.com/stretchr/testify/assert"
9+
"testing"
10+
)
11+
12+
func TestDefaultProjectionStateEncoder(t *testing.T) {
13+
t.Run("Only accept nil values as valid", func(t *testing.T) {
14+
res, err := defaultProjectionStateEncoder(nil)
15+
16+
assert.Equal(t, []byte{'{', '}'}, res)
17+
assert.NoError(t, err)
18+
})
19+
20+
t.Run("Reject any state this not nil", func(t *testing.T) {
21+
var pointer *goengine.Projection
22+
testCases := []interface{}{
23+
struct{}{},
24+
pointer,
25+
"",
26+
0,
27+
}
28+
29+
for i, v := range testCases {
30+
t.Run(fmt.Sprintf("case %d", i), func(t *testing.T) {
31+
res, err := defaultProjectionStateEncoder(v)
32+
33+
assert.Error(t, err)
34+
assert.Nil(t, res)
35+
})
36+
}
37+
})
38+
}

0 commit comments

Comments
 (0)