diff --git a/go.work.sum b/go.work.sum index 27faa6382..e5fdc6b38 100644 --- a/go.work.sum +++ b/go.work.sum @@ -973,6 +973,7 @@ github.com/kisielk/errcheck v1.5.0 h1:e8esj/e4R+SAOwFwN+n3zr0nYeCyeweozKfO23MvHz github.com/kisielk/gotool v1.0.0 h1:AV2c/EiW3KqPNT9ZKl07ehoAGi4C5/01Cfbblndcapg= github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46 h1:veS9QfglfvqAw2e+eeNT/SbGySq8ajECXJ9e4fPoLhY= github.com/klauspost/compress v1.15.9/go.mod h1:PhcZ0MbTNciWF3rruxRgKxI5NkcHHrHUDtV4Yw2GlzU= +github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= github.com/klauspost/cpuid/v2 v2.2.5 h1:0E5MSMDEoAulmXNFquVs//DdoomxaoTY1kUhbc/qbZg= github.com/klauspost/cpuid/v2 v2.2.5/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= github.com/konsorten/go-windows-terminal-sequences v1.0.1 h1:mweAR1A6xJ3oS2pRaGiHgQ4OO8tzTaLawm8vnODuwDk= @@ -1023,6 +1024,7 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f h1:KUppIJq7/+SVif2QVs3tOP0zanoHgBEVAwHxUSIzRqU= github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= +github.com/nats-io/nkeys v0.4.12/go.mod h1:MT59A1HYcjIcyQDJStTfaOY6vhy9XTUjOFo+SVsvpBg= github.com/onsi/ginkgo/v2 v2.11.0 h1:WgqUCUt/lT6yXoQ8Wef0fsNn5cAuMK7+KT9UFRz2tcU= github.com/onsi/ginkgo/v2 v2.11.0/go.mod h1:ZhrRA5XmEE3x3rhlzamx/JJvujdZoJ2uvgI7kR0iZvM= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= diff --git a/pkg/gofr/migration/dgraph.go b/pkg/gofr/migration/dgraph.go index 546620643..1489816ad 100644 --- a/pkg/gofr/migration/dgraph.go +++ b/pkg/gofr/migration/dgraph.go @@ -3,6 +3,7 @@ package migration import ( "context" "encoding/json" + "errors" "fmt" "time" @@ -22,6 +23,12 @@ type dgraphMigrator struct { migrator } +type dgraphTxn interface { + Mutate(ctx context.Context, mu *api.Mutation) (*api.Response, error) + Commit(ctx context.Context) error + Discard(ctx context.Context) error +} + const ( // dgraphSchema defines the migration schema with fully qualified predicate names. dgraphSchema = ` @@ -47,6 +54,10 @@ const ( ` ) +var ( + errInvalidDgraphTxn = errors.New("invalid dgraph transaction type") +) + // apply creates a new dgraphMigrator. func (ds dgraphDS) apply(m migrator) migrator { return dgraphMigrator{ @@ -144,33 +155,53 @@ func (dm dgraphMigrator) beginTransaction(c *container.Container) transactionDat // commitMigration commits the migration and records its metadata. func (dm dgraphMigrator) commitMigration(c *container.Container, data transactionData) error { - // Build the JSON payload for the migration record. - payload := map[string]any{ - "migrations": []map[string]any{ - { - "migrations.version": data.MigrationNumber, - "migrations.method": "UP", - "migrations.start_time": data.StartTime.Format(time.RFC3339), - "migrations.duration": time.Since(data.StartTime).Milliseconds(), - }, - }, - } - - jsonPayload, err := json.Marshal(payload) - if err != nil { - return err - } - - _, err = c.DGraph.Mutate(context.Background(), &api.Mutation{ - SetJson: jsonPayload, - }) - if err != nil { - return err - } - - c.Debugf("Inserted record for migration %v in Dgraph migrations", data.MigrationNumber) - - return dm.migrator.commitMigration(c, data) + ctx := context.Background() + + // Build the JSON payload for the migration record. + payload := map[string]any{ + "migrations": []map[string]any{ + { + "migrations.version": data.MigrationNumber, + "migrations.method": "UP", + "migrations.start_time": data.StartTime.Format(time.RFC3339), + "migrations.duration": time.Since(data.StartTime).Milliseconds(), + }, + }, + } + + jsonPayload, err := json.Marshal(payload) + if err != nil { + return err + } + + tx, ok := c.DGraph.NewTxn().(dgraphTxn) + if !ok { + return errInvalidDgraphTxn + } + + defer func() { + if err = tx.Discard(ctx); err != nil { + c.Error("dgraph: transaction discard failed", err) + } + }() + + c.Debugf("Executing Dgraph migration mutation for version %v", data.MigrationNumber) + + _, err = tx.Mutate(ctx, &api.Mutation{ + SetJson: jsonPayload, + }) + if err != nil { + return err + } + + err = tx.Commit(ctx) + if err != nil { + return err + } + + c.Debugf("Inserted record for migration %v in Dgraph migrations", data.MigrationNumber) + + return dm.migrator.commitMigration(c, data) } // rollback handles migration failure and rollback. diff --git a/pkg/gofr/migration/dgraph_test.go b/pkg/gofr/migration/dgraph_test.go index 954f33d7f..127e98a32 100644 --- a/pkg/gofr/migration/dgraph_test.go +++ b/pkg/gofr/migration/dgraph_test.go @@ -2,9 +2,11 @@ package migration import ( "context" + "errors" "testing" "time" + "github.com/dgraph-io/dgo/v210/protos/api" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" @@ -13,6 +15,103 @@ import ( "gofr.dev/pkg/gofr/testutil" ) +func Test_DGraphCommitMigration_InvalidTxn(t *testing.T) { + migratorWithDGraph, mockDGraph, mockContainer := dgraphSetup(t) + + td := transactionData{ + StartTime: time.Now(), + MigrationNumber: 42, + } + + mockDGraph.EXPECT().NewTxn().Return(&fakeNotTxn{}) + + err := migratorWithDGraph.commitMigration(mockContainer, td) + + assert.Equal(t, errInvalidDgraphTxn, err) +} + +type fakeNotTxn struct{} + +type mockDgraphTxn2 struct { + mutateErr error + commitErr error + discarded bool + commitCalled bool +} + +func (m *mockDgraphTxn2) Mutate(context.Context, *api.Mutation) (*api.Response, error) { + return nil, m.mutateErr +} + +func (m *mockDgraphTxn2) Commit(context.Context) error { + m.commitCalled = true + return m.commitErr +} + +func (m *mockDgraphTxn2) Discard(context.Context) error { + m.discarded = true + return nil +} + +func Test_DGraphCommitMigration_MutateError(t *testing.T) { + migratorWithDGraph, mockDGraph, mockContainer := dgraphSetup(t) + + td := transactionData{ + StartTime: time.Now(), + MigrationNumber: 43, + } + + txn := &mockDgraphTxn2{mutateErr: errors.New("mutation failed")} + mockDGraph.EXPECT().NewTxn().Return(txn) + + err := migratorWithDGraph.commitMigration(mockContainer, td) + + assert.EqualError(t, err, "mutation failed") + assert.True(t, txn.discarded) + assert.False(t, txn.commitCalled) +} + +func Test_DGraphCommitMigration_CommitError(t *testing.T) { + migratorWithDGraph, mockDGraph, mockContainer := dgraphSetup(t) + + td := transactionData{ + StartTime: time.Now(), + MigrationNumber: 44, + } + + txn := &mockDgraphTxn2{commitErr: errors.New("commit failed")} + mockDGraph.EXPECT().NewTxn().Return(txn) + + err := migratorWithDGraph.commitMigration(mockContainer, td) + + assert.EqualError(t, err, "commit failed") + assert.True(t, txn.discarded) + assert.True(t, txn.commitCalled) +} + +type mockDgraphTxn struct { + mutateErr error + commitErr error + commitDone bool + discarded bool +} + +func (m *mockDgraphTxn) Mutate(context.Context, *api.Mutation) (*api.Response, error) { + return nil, m.mutateErr +} + +func (m *mockDgraphTxn) Commit(context.Context) error { + m.commitDone = true + + return m.commitErr +} + +func (m *mockDgraphTxn) Discard(context.Context) error { + m.discarded = true + + return nil +} + func dgraphSetup(t *testing.T) (migrator, *container.MockDgraph, *container.Container) { t.Helper() @@ -99,11 +198,14 @@ func Test_DGraphCommitMigration(t *testing.T) { timeNow := time.Now() testCases := []struct { - desc string - err error + desc string + mutateErr error + commitErr error + err error }{ - {"success", nil}, - {"mutation failed", context.DeadlineExceeded}, + {"success", nil, nil, nil}, + {"mutation failed", context.DeadlineExceeded, nil, context.DeadlineExceeded}, + {"commit failed", nil, context.Canceled, context.Canceled}, } td := transactionData{ @@ -112,11 +214,19 @@ func Test_DGraphCommitMigration(t *testing.T) { } for i, tc := range testCases { - mockDGraph.EXPECT().Mutate(gomock.Any(), gomock.Any()).Return(nil, tc.err) + tx := &mockDgraphTxn{mutateErr: tc.mutateErr, commitErr: tc.commitErr} + mockDGraph.EXPECT().NewTxn().Return(tx) err := migratorWithDGraph.commitMigration(mockContainer, td) assert.Equal(t, tc.err, err, "TEST[%v]\n %v Failed!", i, tc.desc) + assert.True(t, tx.discarded, "TEST[%v]\n %v Failed!", i, tc.desc) + + if tc.mutateErr == nil { + assert.True(t, tx.commitDone, "TEST[%v]\n %v Failed!", i, tc.desc) + } else { + assert.False(t, tx.commitDone, "TEST[%v]\n %v Failed!", i, tc.desc) + } } }