Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion pkg/gofr/migration/dgraph.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package migration
import (
"context"
"encoding/json"
"errors"
"fmt"
"time"

Expand All @@ -22,6 +23,12 @@ type dgraphMigrator struct {
migrator
}

type dgraphTxn interface {
Mutate(ctx context.Context, mu *api.Mutation) (*api.Response, error)
Commit(ctx context.Context) error
Discard(ctx context.Context) error
}

const (
// dgraphSchema defines the migration schema with fully qualified predicate names.
dgraphSchema = `
Expand All @@ -47,6 +54,10 @@ const (
`
)

var (
errInvalidDgraphTxn = errors.New("invalid dgraph transaction type")
)

// apply creates a new dgraphMigrator.
func (ds dgraphDS) apply(m migrator) migrator {
return dgraphMigrator{
Expand Down Expand Up @@ -144,6 +155,8 @@ func (dm dgraphMigrator) beginTransaction(c *container.Container) transactionDat

// commitMigration commits the migration and records its metadata.
func (dm dgraphMigrator) commitMigration(c *container.Container, data transactionData) error {
ctx := context.Background()

// Build the JSON payload for the migration record.
payload := map[string]any{
"migrations": []map[string]any{
Expand All @@ -161,13 +174,25 @@ func (dm dgraphMigrator) commitMigration(c *container.Container, data transactio
return err
}

_, err = c.DGraph.Mutate(context.Background(), &api.Mutation{
tx, ok := c.DGraph.NewTxn().(dgraphTxn)
if !ok {
return errInvalidDgraphTxn
}
Copy link

Copilot AI Mar 21, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

commitMigration has a new failure path when c.DGraph.NewTxn() doesn't implement the expected transaction interface, but the tests only cover mutate/commit errors. Add a test case where NewTxn() returns a non-transaction object (or nil) and assert errInvalidDgraphTxn is returned to prevent regressions in this guard logic.

Copilot uses AI. Check for mistakes.

defer tx.Discard(ctx)

_, err = tx.Mutate(ctx, &api.Mutation{
SetJson: jsonPayload,
})
if err != nil {
return err
}

err = tx.Commit(ctx)
if err != nil {
return err
}

c.Debugf("Inserted record for migration %v in Dgraph migrations", data.MigrationNumber)

return dm.migrator.commitMigration(c, data)
Expand Down
45 changes: 40 additions & 5 deletions pkg/gofr/migration/dgraph_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"testing"
"time"

"github.com/dgraph-io/dgo/v210/protos/api"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
Expand All @@ -13,6 +14,29 @@ import (
"gofr.dev/pkg/gofr/testutil"
)

type mockDgraphTxn struct {
mutateErr error
commitErr error
commitDone bool
discarded bool
}

func (m *mockDgraphTxn) Mutate(context.Context, *api.Mutation) (*api.Response, error) {
return nil, m.mutateErr
}

func (m *mockDgraphTxn) Commit(context.Context) error {
m.commitDone = true

return m.commitErr
}

func (m *mockDgraphTxn) Discard(context.Context) error {
m.discarded = true

return nil
}

func dgraphSetup(t *testing.T) (migrator, *container.MockDgraph, *container.Container) {
t.Helper()

Expand Down Expand Up @@ -99,11 +123,14 @@ func Test_DGraphCommitMigration(t *testing.T) {
timeNow := time.Now()

testCases := []struct {
desc string
err error
desc string
mutateErr error
commitErr error
err error
}{
{"success", nil},
{"mutation failed", context.DeadlineExceeded},
{"success", nil, nil, nil},
{"mutation failed", context.DeadlineExceeded, nil, context.DeadlineExceeded},
{"commit failed", nil, context.Canceled, context.Canceled},
}

td := transactionData{
Expand All @@ -112,11 +139,19 @@ func Test_DGraphCommitMigration(t *testing.T) {
}

for i, tc := range testCases {
mockDGraph.EXPECT().Mutate(gomock.Any(), gomock.Any()).Return(nil, tc.err)
tx := &mockDgraphTxn{mutateErr: tc.mutateErr, commitErr: tc.commitErr}
mockDGraph.EXPECT().NewTxn().Return(tx)

err := migratorWithDGraph.commitMigration(mockContainer, td)

assert.Equal(t, tc.err, err, "TEST[%v]\n %v Failed!", i, tc.desc)
assert.True(t, tx.discarded, "TEST[%v]\n %v Failed!", i, tc.desc)

if tc.mutateErr == nil {
assert.True(t, tx.commitDone, "TEST[%v]\n %v Failed!", i, tc.desc)
} else {
assert.False(t, tx.commitDone, "TEST[%v]\n %v Failed!", i, tc.desc)
}
}
}

Expand Down
Loading