diff --git a/pkg/code/async/account/gift_card.go b/pkg/code/async/account/gift_card.go index 697da8ba..7fca28fc 100644 --- a/pkg/code/async/account/gift_card.go +++ b/pkg/code/async/account/gift_card.go @@ -351,8 +351,8 @@ func markFulfillmentAsActivelyScheduled(ctx context.Context, data code_data.Prov return errors.New("expected fulfillment in unknown state") } - // Note: different than Save, since we don't have distributed locks - return data.MarkFulfillmentAsActivelyScheduled(ctx, fulfillmentRecord.Id) + fulfillmentRecord.DisableActiveScheduling = false + return data.UpdateFulfillment(ctx, fulfillmentRecord) } // Must be unique, but consistent for idempotency, and ideally fit in a 32 diff --git a/pkg/code/async/geyser/external_deposit.go b/pkg/code/async/geyser/external_deposit.go index a3bac1fc..4737bfae 100644 --- a/pkg/code/async/geyser/external_deposit.go +++ b/pkg/code/async/geyser/external_deposit.go @@ -3,6 +3,7 @@ package async_geyser import ( "context" "crypto/sha256" + "database/sql" "fmt" "strconv" "strings" @@ -274,6 +275,8 @@ func processPotentialExternalDepositIntoVm(ctx context.Context, data code_data.P if err == nil { syncedDepositCache.Insert(cacheKey, true, 1) return nil + } else if err != deposit.ErrDepositNotFound { + return errors.Wrap(err, "error checking for existing external deposit record") } ownerAccount, err := common.NewAccountFromPublicKeyString(accountInfoRecord.OwnerAccount) @@ -287,42 +290,49 @@ func processPotentialExternalDepositIntoVm(ctx context.Context, data code_data.P } usdMarketValue := usdExchangeRecord.Rate * float64(deltaQuarksIntoOmnibus) / float64(common.CoreMintQuarksPerUnit) - // For transaction history - intentRecord := &intent.Record{ - IntentId: getExternalDepositIntentID(signature, userVirtualTimelockVaultAccount), - IntentType: intent.ExternalDeposit, + err = data.ExecuteInTx(ctx, sql.LevelDefault, func(ctx context.Context) error { + // For transaction history + intentRecord := &intent.Record{ + IntentId: getExternalDepositIntentID(signature, userVirtualTimelockVaultAccount), + IntentType: intent.ExternalDeposit, - InitiatorOwnerAccount: ownerAccount.PublicKey().ToBase58(), + InitiatorOwnerAccount: ownerAccount.PublicKey().ToBase58(), - ExternalDepositMetadata: &intent.ExternalDepositMetadata{ - DestinationTokenAccount: userVirtualTimelockVaultAccount.PublicKey().ToBase58(), - Quantity: uint64(deltaQuarksIntoOmnibus), - UsdMarketValue: usdMarketValue, - }, + ExternalDepositMetadata: &intent.ExternalDepositMetadata{ + DestinationTokenAccount: userVirtualTimelockVaultAccount.PublicKey().ToBase58(), + Quantity: uint64(deltaQuarksIntoOmnibus), + UsdMarketValue: usdMarketValue, + }, - State: intent.StateConfirmed, - CreatedAt: time.Now(), - } - err = data.SaveIntent(ctx, intentRecord) - if err != nil { - return errors.Wrap(err, "error saving intent record") - } + State: intent.StateConfirmed, + CreatedAt: time.Now(), + } + err = data.SaveIntent(ctx, intentRecord) + if err != nil { + return errors.Wrap(err, "error saving intent record") + } - // For tracking in cached balances - externalDepositRecord := &deposit.Record{ - Signature: signature, - Destination: userVirtualTimelockVaultAccount.PublicKey().ToBase58(), - Amount: uint64(deltaQuarksIntoOmnibus), - UsdMarketValue: usdMarketValue, + // For tracking in cached balances + externalDepositRecord := &deposit.Record{ + Signature: signature, + Destination: userVirtualTimelockVaultAccount.PublicKey().ToBase58(), + Amount: uint64(deltaQuarksIntoOmnibus), + UsdMarketValue: usdMarketValue, - Slot: tokenBalances.Slot, - ConfirmationState: transaction.ConfirmationFinalized, + Slot: tokenBalances.Slot, + ConfirmationState: transaction.ConfirmationFinalized, - CreatedAt: time.Now(), - } - err = data.SaveExternalDeposit(ctx, externalDepositRecord) + CreatedAt: time.Now(), + } + err = data.SaveExternalDeposit(ctx, externalDepositRecord) + if err != nil { + return errors.Wrap(err, "error saving external deposit record") + } + + return nil + }) if err != nil { - return errors.Wrap(err, "error saving external deposit record") + return err } syncedDepositCache.Insert(cacheKey, true, 1) diff --git a/pkg/code/async/sequencer/action_handler.go b/pkg/code/async/sequencer/action_handler.go index 3f78d433..f99dce43 100644 --- a/pkg/code/async/sequencer/action_handler.go +++ b/pkg/code/async/sequencer/action_handler.go @@ -170,25 +170,6 @@ func markActionFailed(ctx context.Context, data code_data.Provider, intentId str return data.UpdateAction(ctx, record) } -func markActionRevoked(ctx context.Context, data code_data.Provider, intentId string, actionId uint32) error { - record, err := data.GetActionById(ctx, intentId, actionId) - if err != nil { - return err - } - - if record.State == action.StateRevoked { - return nil - } - - err = validateActionState(record, action.StateUnknown) - if err != nil { - return err - } - - record.State = action.StateRevoked - return data.UpdateAction(ctx, record) -} - func getActionHandlers(data code_data.Provider) map[action.Type]ActionHandler { handlersByType := make(map[action.Type]ActionHandler) handlersByType[action.OpenAccount] = NewOpenAccountActionHandler(data) diff --git a/pkg/code/async/sequencer/utils.go b/pkg/code/async/sequencer/utils.go index 46b54b0f..6183ba83 100644 --- a/pkg/code/async/sequencer/utils.go +++ b/pkg/code/async/sequencer/utils.go @@ -116,8 +116,8 @@ func markFulfillmentAsActivelyScheduled(ctx context.Context, data code_data.Prov return nil } - // Note: different than Save, since we don't have distributed locks - return data.MarkFulfillmentAsActivelyScheduled(ctx, fulfillmentRecord.Id) + fulfillmentRecord.DisableActiveScheduling = false + return data.UpdateFulfillment(ctx, fulfillmentRecord) } func (p *service) sendToBlockchain(ctx context.Context, record *fulfillment.Record) error { diff --git a/pkg/code/data/action/action.go b/pkg/code/data/action/action.go index 1ac47f77..61ce9213 100644 --- a/pkg/code/data/action/action.go +++ b/pkg/code/data/action/action.go @@ -66,6 +66,8 @@ type Record struct { State State + Version uint64 + CreatedAt time.Time } @@ -121,6 +123,8 @@ func (r *Record) Clone() Record { State: r.State, + Version: r.Version, + CreatedAt: r.CreatedAt, } } @@ -142,6 +146,8 @@ func (r *Record) CopyTo(dst *Record) { dst.State = r.State + dst.Version = r.Version + dst.CreatedAt = r.CreatedAt } diff --git a/pkg/code/data/action/memory/store.go b/pkg/code/data/action/memory/store.go index c4b3e3a9..27b76cc7 100644 --- a/pkg/code/data/action/memory/store.go +++ b/pkg/code/data/action/memory/store.go @@ -179,6 +179,7 @@ func (s *store) PutAll(ctx context.Context, records ...*action.Record) error { if record.CreatedAt.IsZero() { record.CreatedAt = time.Now() } + record.Version++ cloned := record.Clone() s.records = append(s.records, &cloned) @@ -193,14 +194,20 @@ func (s *store) Update(ctx context.Context, record *action.Record) error { defer s.mu.Unlock() if item := s.find(record); item != nil { + if record.Version != item.Version { + return action.ErrStaleVersion + } + record.Version++ + if record.IntentType == intent.SendPublicPayment && record.ActionType == action.NoPrivacyWithdraw { item.Quantity = pointer.Uint64Copy(record.Quantity) } item.State = record.State + item.Version = record.Version return nil } - return action.ErrActionNotFound + return action.ErrStaleVersion } // GetById implements action.store.GetById diff --git a/pkg/code/data/action/postgres/model.go b/pkg/code/data/action/postgres/model.go index 79970498..233d79e6 100644 --- a/pkg/code/data/action/postgres/model.go +++ b/pkg/code/data/action/postgres/model.go @@ -32,6 +32,7 @@ type model struct { Quantity sql.NullInt64 `db:"quantity"` FeeType sql.NullInt32 `db:"fee_type"` State uint `db:"state"` + Version int64 `db:"version"` CreatedAt time.Time `db:"created_at"` } @@ -68,6 +69,7 @@ func toModel(obj *action.Record) (*model, error) { Quantity: quantity, FeeType: feeType, State: uint(obj.State), + Version: int64(obj.Version), CreatedAt: obj.CreatedAt, }, nil } @@ -84,6 +86,7 @@ func fromModel(obj *model) *action.Record { Quantity: pointer.Uint64IfValid(obj.Quantity.Valid, uint64(obj.Quantity.Int64)), FeeType: (*transactionpb.FeePaymentAction_FeeType)(pointer.Int32IfValid(obj.FeeType.Valid, obj.FeeType.Int32)), State: action.State(obj.State), + Version: uint64(obj.Version), CreatedAt: obj.CreatedAt, } } @@ -94,18 +97,19 @@ func (m *model) dbUpdate(ctx context.Context, db *sqlx.DB) error { params := []interface{}{ m.Intent, m.ActionId, + m.Version, m.State, } if m.IntentType == uint(intent.SendPublicPayment) && m.ActionType == uint(action.NoPrivacyWithdraw) { - quantityUpdateStmt = ", quantity = $4" + quantityUpdateStmt = ", quantity = $5" params = append(params, m.Quantity) } query := fmt.Sprintf(`UPDATE `+tableName+` - SET state = $3%s - WHERE intent = $1 AND action_id = $2 - RETURNING id, intent, intent_type, action_id, action_type, source, destination, quantity, fee_type, state, created_at + SET state = $4%s, version = version + 1 + WHERE intent = $1 AND action_id = $2 AND version = $3 + RETURNING id, intent, intent_type, action_id, action_type, source, destination, quantity, fee_type, state, version, created_at `, quantityUpdateStmt) err := tx.QueryRowxContext( @@ -114,7 +118,7 @@ func (m *model) dbUpdate(ctx context.Context, db *sqlx.DB) error { params..., ).StructScan(m) if err != nil { - return pgutil.CheckNoRows(err, action.ErrActionNotFound) + return pgutil.CheckNoRows(err, action.ErrStaleVersion) } return nil @@ -124,7 +128,7 @@ func (m *model) dbUpdate(ctx context.Context, db *sqlx.DB) error { func dbPutAllInTx(ctx context.Context, tx *sqlx.Tx, models []*model) ([]*model, error) { var res []*model - query := `INSERT INTO ` + tableName + ` (intent, intent_type, action_id, action_type, source, destination, quantity, fee_type, state, created_at) VALUES ` + query := `INSERT INTO ` + tableName + ` (intent, intent_type, action_id, action_type, source, destination, quantity, fee_type, state, version, created_at) VALUES ` var parameters []interface{} for i, model := range models { @@ -134,8 +138,8 @@ func dbPutAllInTx(ctx context.Context, tx *sqlx.Tx, models []*model) ([]*model, baseIndex := len(parameters) query += fmt.Sprintf( - `($%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d)`, - baseIndex+1, baseIndex+2, baseIndex+3, baseIndex+4, baseIndex+5, baseIndex+6, baseIndex+7, baseIndex+8, baseIndex+9, baseIndex+10, + `($%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d + 1, $%d)`, + baseIndex+1, baseIndex+2, baseIndex+3, baseIndex+4, baseIndex+5, baseIndex+6, baseIndex+7, baseIndex+8, baseIndex+9, baseIndex+10, baseIndex+11, ) if i != len(models)-1 { @@ -153,11 +157,12 @@ func dbPutAllInTx(ctx context.Context, tx *sqlx.Tx, models []*model) ([]*model, model.Quantity, model.FeeType, model.State, + model.Version, model.CreatedAt, ) } - query += ` RETURNING id, intent, intent_type, action_id, action_type, source, destination, quantity, fee_type, state, created_at` + query += ` RETURNING id, intent, intent_type, action_id, action_type, source, destination, quantity, fee_type, state, version, created_at` err := tx.SelectContext( ctx, @@ -175,7 +180,7 @@ func dbPutAllInTx(ctx context.Context, tx *sqlx.Tx, models []*model) ([]*model, func dbGetById(ctx context.Context, db *sqlx.DB, intent string, actionId uint32) (*model, error) { res := &model{} - query := `SELECT id, intent, intent_type, action_id, action_type, source, destination, quantity, fee_type, state, created_at + query := `SELECT id, intent, intent_type, action_id, action_type, source, destination, quantity, fee_type, state, version, created_at FROM ` + tableName + ` WHERE intent = $1 AND action_id = $2 LIMIT 1` @@ -190,7 +195,7 @@ func dbGetById(ctx context.Context, db *sqlx.DB, intent string, actionId uint32) func dbGetAllByIntent(ctx context.Context, db *sqlx.DB, intent string) ([]*model, error) { res := []*model{} - query := `SELECT id, intent, intent_type, action_id, action_type, source, destination, quantity, fee_type, state, created_at + query := `SELECT id, intent, intent_type, action_id, action_type, source, destination, quantity, fee_type, state, version, created_at FROM ` + tableName + ` WHERE intent = $1 ORDER BY action_id ASC` @@ -210,7 +215,7 @@ func dbGetAllByIntent(ctx context.Context, db *sqlx.DB, intent string) ([]*model func dbGetAllByAddress(ctx context.Context, db *sqlx.DB, address string) ([]*model, error) { res := []*model{} - query := `SELECT id, intent, intent_type, action_id, action_type, source, destination, quantity, fee_type, state, created_at + query := `SELECT id, intent, intent_type, action_id, action_type, source, destination, quantity, fee_type, state, version, created_at FROM ` + tableName + ` WHERE source = $1 OR destination = $1` @@ -300,7 +305,7 @@ func dbGetNetBalanceBatch(ctx context.Context, db *sqlx.DB, accounts ...string) func dbGetGiftCardClaimedAction(ctx context.Context, db *sqlx.DB, giftCardVault string) (*model, error) { res := []*model{} - query := `SELECT id, intent, intent_type, action_id, action_type, source, destination, quantity, fee_type, state, created_at + query := `SELECT id, intent, intent_type, action_id, action_type, source, destination, quantity, fee_type, state, version, created_at FROM ` + tableName + ` WHERE source = $1 AND action_type = $2 AND intent_type = $3 AND state != $4 LIMIT 2` @@ -330,7 +335,7 @@ func dbGetGiftCardClaimedAction(ctx context.Context, db *sqlx.DB, giftCardVault func dbGetGiftCardAutoReturnAction(ctx context.Context, db *sqlx.DB, giftCardVault string) (*model, error) { res := []*model{} - query := `SELECT id, intent, intent_type, action_id, action_type, source, destination, quantity, fee_type, state, created_at + query := `SELECT id, intent, intent_type, action_id, action_type, source, destination, quantity, fee_type, state, version, created_at FROM ` + tableName + ` WHERE source = $1 AND action_type = $2 AND intent_type = $3 AND state != $4 LIMIT 2` diff --git a/pkg/code/data/action/postgres/store.go b/pkg/code/data/action/postgres/store.go index c786d493..246a24d1 100644 --- a/pkg/code/data/action/postgres/store.go +++ b/pkg/code/data/action/postgres/store.go @@ -89,7 +89,15 @@ func (s *store) Update(ctx context.Context, record *action.Record) error { return err } - return model.dbUpdate(ctx, s.db) + err = model.dbUpdate(ctx, s.db) + if err != nil { + return err + } + + updated := fromModel(model) + updated.CopyTo(record) + + return nil } // GetById implements action.store.GetById diff --git a/pkg/code/data/action/postgres/store_test.go b/pkg/code/data/action/postgres/store_test.go index 927ba4e1..e8e27d90 100644 --- a/pkg/code/data/action/postgres/store_test.go +++ b/pkg/code/data/action/postgres/store_test.go @@ -36,6 +36,8 @@ const ( state INTEGER NOT NULL, + version INTEGER NOT NULL, + created_at timestamp with time zone NOT NULL, CONSTRAINT codewallet__core_action__uniq__intent__and__action_id UNIQUE (intent, action_id) diff --git a/pkg/code/data/action/store.go b/pkg/code/data/action/store.go index 0d7d1bff..85d405ed 100644 --- a/pkg/code/data/action/store.go +++ b/pkg/code/data/action/store.go @@ -8,9 +8,10 @@ import ( ) var ( - ErrActionNotFound = errors.New("no action ecord could be found") - ErrMultipleActionsFound = errors.New("multiple action records found") - ErrActionExists = errors.New("action record already exists") + ErrActionNotFound = errors.New("no action could be found") + ErrMultipleActionsFound = errors.New("multiple actions found") + ErrActionExists = errors.New("action already exists") + ErrStaleVersion = errors.New("action version is stale") ) type Store interface { diff --git a/pkg/code/data/action/tests/tests.go b/pkg/code/data/action/tests/tests.go index d1583bac..7bb4ea27 100644 --- a/pkg/code/data/action/tests/tests.go +++ b/pkg/code/data/action/tests/tests.go @@ -21,6 +21,7 @@ func RunTests(t *testing.T, s action.Store, teardown func()) { for _, tf := range []func(t *testing.T, s action.Store){ testRoundTrip, testBatchPut, + testUpdateStaleRecord, testGetAllByIntent, testGetAllByAddress, testGetNetBalance, @@ -58,19 +59,19 @@ func testRoundTrip(t *testing.T, s action.Store) { _, err := s.GetById(ctx, expected.Intent, expected.ActionId) assert.Equal(t, action.ErrActionNotFound, err) - assert.Equal(t, action.ErrActionNotFound, s.Update(ctx, expected)) + assert.Equal(t, action.ErrStaleVersion, s.Update(ctx, expected)) cloned := expected.Clone() require.NoError(t, s.PutAll(ctx, expected)) + assert.EqualValues(t, 1, expected.Id) + assert.EqualValues(t, 1, expected.Version) actual, err := s.GetById(ctx, expected.Intent, expected.ActionId) require.NoError(t, err) - - assert.True(t, expected.Id > 0) assert.Equal(t, expected.Id, actual.Id) + assert.Equal(t, expected.Version, actual.Version) assert.True(t, expected.CreatedAt.After(start)) assert.Equal(t, expected.CreatedAt, actual.CreatedAt) - assertEquivalentRecords(t, &cloned, actual) assert.Equal(t, action.ErrActionExists, s.PutAll(ctx, expected)) @@ -78,10 +79,15 @@ func testRoundTrip(t *testing.T, s action.Store) { expected.Quantity = pointer.Uint64(12345) expected.State = action.StateFailed cloned = expected.Clone() + require.NoError(t, s.Update(ctx, expected)) + assert.EqualValues(t, 1, expected.Id) + assert.EqualValues(t, 2, expected.Version) actual, err = s.GetById(ctx, expected.Intent, expected.ActionId) require.NoError(t, err) + assert.Equal(t, expected.Id, actual.Id) + assert.Equal(t, expected.Version, actual.Version) assertEquivalentRecords(t, &cloned, actual) }) } @@ -116,6 +122,7 @@ func testBatchPut(t *testing.T, s action.Store) { for i, actual := range inserted { assert.True(t, actual.Id > 0) + assert.EqualValues(t, 1, actual.Version) assertEquivalentRecords(t, expected[i], actual) fetched, err := s.GetById(ctx, expected[i].Intent, expected[i].ActionId) @@ -141,6 +148,7 @@ func testBatchPut(t *testing.T, s action.Store) { fetched, err := s.GetById(ctx, expected[i].Intent, expected[i].ActionId) require.NoError(t, err) assert.Equal(t, actual.Id, fetched.Id) + assert.EqualValues(t, 1, actual.Version) assertEquivalentRecords(t, expected[i], fetched) } @@ -162,6 +170,61 @@ func testBatchPut(t *testing.T, s action.Store) { }) } +func testUpdateStaleRecord(t *testing.T, s action.Store) { + t.Run("testUpdateStaleRecord", func(t *testing.T) { + ctx := context.Background() + + start := time.Now() + + expected := &action.Record{ + Intent: "intent", + IntentType: intent.SendPublicPayment, + + ActionId: 1, + ActionType: action.NoPrivacyWithdraw, + + Source: "source", + Destination: pointer.String("destination"), + Quantity: nil, + + FeeType: (*transactionpb.FeePaymentAction_FeeType)(pointer.Int32((int32)(transactionpb.FeePaymentAction_CREATE_ON_SEND_WITHDRAWAL))), + + State: action.StateConfirmed, + } + + _, err := s.GetById(ctx, expected.Intent, expected.ActionId) + assert.Equal(t, action.ErrActionNotFound, err) + + cloned := expected.Clone() + require.NoError(t, s.PutAll(ctx, expected)) + assert.EqualValues(t, 1, expected.Id) + assert.EqualValues(t, 1, expected.Version) + + actual, err := s.GetById(ctx, expected.Intent, expected.ActionId) + require.NoError(t, err) + assert.Equal(t, expected.Id, actual.Id) + assert.Equal(t, expected.Version, actual.Version) + assert.True(t, expected.CreatedAt.After(start)) + assert.Equal(t, expected.CreatedAt, actual.CreatedAt) + assertEquivalentRecords(t, &cloned, actual) + + stale := expected.Clone() + stale.Quantity = pointer.Uint64(12345) + stale.State = action.StateFailed + stale.Version -= 1 + + assert.Equal(t, action.ErrStaleVersion, s.Update(ctx, &stale)) + assert.EqualValues(t, 1, stale.Id) + assert.EqualValues(t, 0, stale.Version) + + actual, err = s.GetById(ctx, expected.Intent, expected.ActionId) + require.NoError(t, err) + assert.EqualValues(t, 1, actual.Id) + assert.EqualValues(t, 1, actual.Version) + assertEquivalentRecords(t, &cloned, actual) + }) +} + func testGetAllByIntent(t *testing.T, s action.Store) { t.Run("testGetAllByIntent", func(t *testing.T) { ctx := context.Background() diff --git a/pkg/code/data/fulfillment/fulfillment.go b/pkg/code/data/fulfillment/fulfillment.go index 056f4929..e4221428 100644 --- a/pkg/code/data/fulfillment/fulfillment.go +++ b/pkg/code/data/fulfillment/fulfillment.go @@ -10,12 +10,6 @@ import ( "github.com/code-payments/code-server/pkg/pointer" ) -var ( - ErrFulfillmentNotFound = errors.New("no records could be found") - ErrFulfillmentExists = errors.New("fulfillment exists") - ErrInvalidFulfillment = errors.New("invalid fulfillment") -) - type Type uint8 const ( @@ -89,6 +83,8 @@ type Record struct { State State + Version uint64 + CreatedAt time.Time } @@ -154,6 +150,7 @@ func (r *Record) Clone() Record { FulfillmentOrderingIndex: r.FulfillmentOrderingIndex, DisableActiveScheduling: r.DisableActiveScheduling, State: r.State, + Version: r.Version, CreatedAt: r.CreatedAt, } } @@ -179,6 +176,7 @@ func (r *Record) CopyTo(dst *Record) { dst.FulfillmentOrderingIndex = r.FulfillmentOrderingIndex dst.DisableActiveScheduling = r.DisableActiveScheduling dst.State = r.State + dst.Version = r.Version dst.CreatedAt = r.CreatedAt } diff --git a/pkg/code/data/fulfillment/memory/store.go b/pkg/code/data/fulfillment/memory/store.go index 1cb08d09..7d65b9c8 100644 --- a/pkg/code/data/fulfillment/memory/store.go +++ b/pkg/code/data/fulfillment/memory/store.go @@ -478,6 +478,8 @@ func (s *store) PutAll(ctx context.Context, records ...*fulfillment.Record) erro data.CreatedAt = time.Now() } + data.Version++ + c := data.Clone() s.records = append(s.records, &c) } @@ -501,9 +503,15 @@ func (s *store) Update(ctx context.Context, data *fulfillment.Record) error { item := s.findById(data.Id) if item == nil { - return fulfillment.ErrFulfillmentNotFound + return fulfillment.ErrStaleVersion } + if data.Version != item.Version { + return fulfillment.ErrStaleVersion + } + + data.Version++ + item.Data = data.Data item.Signature = pointer.StringCopy(data.Signature) @@ -514,8 +522,12 @@ func (s *store) Update(ctx context.Context, data *fulfillment.Record) error { item.VirtualNonce = pointer.StringCopy(data.VirtualNonce) item.VirtualBlockhash = pointer.StringCopy(data.VirtualBlockhash) + item.DisableActiveScheduling = data.DisableActiveScheduling + item.State = data.State + item.Version = data.Version + if item.IntentType == intent.SendPublicPayment && item.FulfillmentType == fulfillment.NoPrivacyWithdraw { item.IntentOrderingIndex = data.IntentOrderingIndex item.ActionOrderingIndex = data.ActionOrderingIndex @@ -525,20 +537,6 @@ func (s *store) Update(ctx context.Context, data *fulfillment.Record) error { return nil } -func (s *store) MarkAsActivelyScheduled(ctx context.Context, id uint64) error { - s.mu.Lock() - defer s.mu.Unlock() - - item := s.findById(id) - if item == nil { - return fulfillment.ErrFulfillmentNotFound - } - - item.DisableActiveScheduling = false - - return nil -} - func (s *store) GetById(ctx context.Context, id uint64) (*fulfillment.Record, error) { if id == 0 { return nil, fulfillment.ErrFulfillmentNotFound diff --git a/pkg/code/data/fulfillment/postgres/model.go b/pkg/code/data/fulfillment/postgres/model.go index acfd710e..ed1992a5 100644 --- a/pkg/code/data/fulfillment/postgres/model.go +++ b/pkg/code/data/fulfillment/postgres/model.go @@ -11,6 +11,7 @@ import ( "github.com/code-payments/code-server/pkg/code/data/action" "github.com/code-payments/code-server/pkg/code/data/fulfillment" "github.com/code-payments/code-server/pkg/code/data/intent" + "github.com/code-payments/code-server/pkg/pointer" pgutil "github.com/code-payments/code-server/pkg/database/postgres" q "github.com/code-payments/code-server/pkg/database/query" @@ -41,6 +42,7 @@ type fulfillmentModel struct { FulfillmentOrderingIndex uint32 `db:"fulfillment_ordering_index"` DisableActiveScheduling bool `db:"disable_active_scheduling"` State uint `db:"state"` + Version int64 `db:"version"` CreatedAt time.Time `db:"created_at"` BatchInsertionId int `db:"batch_insertion_id"` } @@ -117,46 +119,12 @@ func toFulfillmentModel(obj *fulfillment.Record) (*fulfillmentModel, error) { FulfillmentOrderingIndex: obj.FulfillmentOrderingIndex, DisableActiveScheduling: obj.DisableActiveScheduling, State: uint(obj.State), + Version: int64(obj.Version), CreatedAt: obj.CreatedAt, }, nil } func fromFulfillmentModel(obj *fulfillmentModel) *fulfillment.Record { - var sig *string - if obj.Signature.Valid { - sig = &obj.Signature.String - } - - var nonce *string - if obj.Nonce.Valid { - nonce = &obj.Nonce.String - } - - var blockhash *string - if obj.Blockhash.Valid { - blockhash = &obj.Blockhash.String - } - - var virtualSig *string - if obj.VirtualSignature.Valid { - virtualSig = &obj.VirtualSignature.String - } - - var virtualNonce *string - if obj.VirtualNonce.Valid { - virtualNonce = &obj.VirtualNonce.String - } - - var virtualBlockhash *string - if obj.VirtualBlockhash.Valid { - virtualBlockhash = &obj.VirtualBlockhash.String - } - - var destination *string - if obj.Destination.Valid { - destination = &obj.Destination.String - } - return &fulfillment.Record{ Id: uint64(obj.Id), Intent: obj.Intent, @@ -165,19 +133,20 @@ func fromFulfillmentModel(obj *fulfillmentModel) *fulfillment.Record { ActionType: action.Type(obj.ActionType), FulfillmentType: fulfillment.Type(obj.FulfillmentType), Data: obj.Data, - Signature: sig, - Nonce: nonce, - Blockhash: blockhash, - VirtualSignature: virtualSig, - VirtualNonce: virtualNonce, - VirtualBlockhash: virtualBlockhash, + Signature: pointer.StringIfValid(obj.Signature.Valid, obj.Signature.String), + Nonce: pointer.StringIfValid(obj.Nonce.Valid, obj.Nonce.String), + Blockhash: pointer.StringIfValid(obj.Blockhash.Valid, obj.Blockhash.String), + VirtualSignature: pointer.StringIfValid(obj.VirtualSignature.Valid, obj.VirtualSignature.String), + VirtualNonce: pointer.StringIfValid(obj.VirtualNonce.Valid, obj.VirtualNonce.String), + VirtualBlockhash: pointer.StringIfValid(obj.VirtualBlockhash.Valid, obj.VirtualBlockhash.String), Source: obj.Source, - Destination: destination, + Destination: pointer.StringIfValid(obj.Destination.Valid, obj.Destination.String), IntentOrderingIndex: obj.IntentOrderingIndex, ActionOrderingIndex: obj.ActionOrderingIndex, FulfillmentOrderingIndex: obj.FulfillmentOrderingIndex, DisableActiveScheduling: obj.DisableActiveScheduling, State: fulfillment.State(obj.State), + Version: uint64(obj.Version), CreatedAt: obj.CreatedAt.UTC(), } } @@ -372,10 +341,12 @@ func (m *fulfillmentModel) dbUpdate(ctx context.Context, db *sqlx.DB) error { m.VirtualSignature, m.VirtualNonce, m.VirtualBlockhash, + m.DisableActiveScheduling, + m.Version, } if m.IntentType == uint(intent.SendPublicPayment) && m.FulfillmentType == uint(fulfillment.NoPrivacyWithdraw) { - preSortingUpdateStmt = ", intent_ordering_index = $10, action_ordering_index = $11, fulfillment_ordering_index = $12" + preSortingUpdateStmt = ", intent_ordering_index = $12, action_ordering_index = $13, fulfillment_ordering_index = $14" params = append( params, m.IntentOrderingIndex, @@ -385,10 +356,10 @@ func (m *fulfillmentModel) dbUpdate(ctx context.Context, db *sqlx.DB) error { } query := fmt.Sprintf(`UPDATE `+fulfillmentTableName+` - SET signature = $2, nonce = $3, blockhash = $4, data = $5, state = $6, virtual_signature = $7, virtual_nonce = $8, virtual_blockhash = $9%s - WHERE id = $1 + SET signature = $2, nonce = $3, blockhash = $4, data = $5, state = $6, virtual_signature = $7, virtual_nonce = $8, virtual_blockhash = $9, disable_active_scheduling = $10%s, version = version + 1 + WHERE id = $1 AND version = $11 RETURNING - id, intent, intent_type, action_id, action_type, fulfillment_type, data, signature, nonce, blockhash, virtual_signature, virtual_nonce, virtual_blockhash, source, destination, intent_ordering_index, action_ordering_index, fulfillment_ordering_index, disable_active_scheduling, state, created_at`, + id, intent, intent_type, action_id, action_type, fulfillment_type, data, signature, nonce, blockhash, virtual_signature, virtual_nonce, virtual_blockhash, source, destination, intent_ordering_index, action_ordering_index, fulfillment_ordering_index, disable_active_scheduling, state, version, created_at`, preSortingUpdateStmt, ) @@ -398,7 +369,7 @@ func (m *fulfillmentModel) dbUpdate(ctx context.Context, db *sqlx.DB) error { params..., ).StructScan(m) - return pgutil.CheckNoRows(err, fulfillment.ErrFulfillmentNotFound) + return pgutil.CheckNoRows(err, fulfillment.ErrStaleVersion) }) } @@ -406,7 +377,7 @@ func dbPutAllInTx(ctx context.Context, tx *sqlx.Tx, models []*fulfillmentModel) var res []*fulfillmentModel query := `WITH inserted AS (` - query += `INSERT INTO ` + fulfillmentTableName + ` (intent, intent_type, action_id, action_type, fulfillment_type, data, signature, nonce, blockhash, virtual_signature, virtual_nonce, virtual_blockhash, source, destination, intent_ordering_index, action_ordering_index, fulfillment_ordering_index, disable_active_scheduling, state, created_at, batch_insertion_id) VALUES ` + query += `INSERT INTO ` + fulfillmentTableName + ` (intent, intent_type, action_id, action_type, fulfillment_type, data, signature, nonce, blockhash, virtual_signature, virtual_nonce, virtual_blockhash, source, destination, intent_ordering_index, action_ordering_index, fulfillment_ordering_index, disable_active_scheduling, state, version, created_at, batch_insertion_id) VALUES ` var parameters []interface{} for i, model := range models { @@ -420,8 +391,8 @@ func dbPutAllInTx(ctx context.Context, tx *sqlx.Tx, models []*fulfillmentModel) baseIndex := len(parameters) query += fmt.Sprintf( - `($%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d)`, - baseIndex+1, baseIndex+2, baseIndex+3, baseIndex+4, baseIndex+5, baseIndex+6, baseIndex+7, baseIndex+8, baseIndex+9, baseIndex+10, baseIndex+11, baseIndex+12, baseIndex+13, baseIndex+14, baseIndex+15, baseIndex+16, baseIndex+17, baseIndex+18, baseIndex+19, baseIndex+20, baseIndex+21, + `($%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d + 1, $%d, $%d)`, + baseIndex+1, baseIndex+2, baseIndex+3, baseIndex+4, baseIndex+5, baseIndex+6, baseIndex+7, baseIndex+8, baseIndex+9, baseIndex+10, baseIndex+11, baseIndex+12, baseIndex+13, baseIndex+14, baseIndex+15, baseIndex+16, baseIndex+17, baseIndex+18, baseIndex+19, baseIndex+20, baseIndex+21, baseIndex+22, ) if i != len(models)-1 { @@ -451,12 +422,13 @@ func dbPutAllInTx(ctx context.Context, tx *sqlx.Tx, models []*fulfillmentModel) model.FulfillmentOrderingIndex, model.DisableActiveScheduling, model.State, + model.Version, model.CreatedAt, batchInsertionId, ) } - query += ` RETURNING id, intent, intent_type, action_id, action_type, fulfillment_type, data, signature, nonce, blockhash, virtual_signature, virtual_nonce, virtual_blockhash, source, destination, intent_ordering_index, action_ordering_index, fulfillment_ordering_index, disable_active_scheduling, state, created_at, batch_insertion_id) ` + query += ` RETURNING id, intent, intent_type, action_id, action_type, fulfillment_type, data, signature, nonce, blockhash, virtual_signature, virtual_nonce, virtual_blockhash, source, destination, intent_ordering_index, action_ordering_index, fulfillment_ordering_index, disable_active_scheduling, state, version, created_at, batch_insertion_id) ` // Kind of hacky, but we don't really have a great PK for on demand transactions // that allows us to update the corresponding record that was passed in (for example, @@ -477,30 +449,6 @@ func dbPutAllInTx(ctx context.Context, tx *sqlx.Tx, models []*fulfillmentModel) return res, nil } -func dbMarkAsActivelyScheduled(ctx context.Context, db *sqlx.DB, id uint64) error { - return pgutil.ExecuteInTx(ctx, db, sql.LevelDefault, func(tx *sqlx.Tx) error { - if id == 0 { - return fulfillment.ErrFulfillmentNotFound - } - - query := `UPDATE ` + fulfillmentTableName + ` SET disable_active_scheduling = false WHERE id = $1` - res, err := tx.ExecContext(ctx, query, id) - if err != nil { - return err - } - - rowsAffected, err := res.RowsAffected() - if err != nil { - return err - } - - if rowsAffected == 0 { - return fulfillment.ErrFulfillmentNotFound - } - return nil - }) -} - func dbGetById(ctx context.Context, db *sqlx.DB, id uint64) (*fulfillmentModel, error) { if id == 0 { return nil, fulfillment.ErrFulfillmentNotFound @@ -508,7 +456,7 @@ func dbGetById(ctx context.Context, db *sqlx.DB, id uint64) (*fulfillmentModel, res := &fulfillmentModel{} - query := `SELECT id, intent, intent_type, action_id, action_type, fulfillment_type, data, signature, nonce, blockhash, virtual_signature, virtual_nonce, virtual_blockhash, source, destination, intent_ordering_index, action_ordering_index, fulfillment_ordering_index, disable_active_scheduling, state, created_at + query := `SELECT id, intent, intent_type, action_id, action_type, fulfillment_type, data, signature, nonce, blockhash, virtual_signature, virtual_nonce, virtual_blockhash, source, destination, intent_ordering_index, action_ordering_index, fulfillment_ordering_index, disable_active_scheduling, state, version, created_at FROM ` + fulfillmentTableName + ` WHERE id = $1 LIMIT 1` @@ -527,7 +475,7 @@ func dbGetBySignature(ctx context.Context, db *sqlx.DB, signature string) (*fulf res := &fulfillmentModel{} - query := `SELECT id, intent, intent_type, action_id, action_type, fulfillment_type, data, signature, nonce, blockhash, virtual_signature, virtual_nonce, virtual_blockhash, source, destination, intent_ordering_index, action_ordering_index, fulfillment_ordering_index, disable_active_scheduling, state, created_at + query := `SELECT id, intent, intent_type, action_id, action_type, fulfillment_type, data, signature, nonce, blockhash, virtual_signature, virtual_nonce, virtual_blockhash, source, destination, intent_ordering_index, action_ordering_index, fulfillment_ordering_index, disable_active_scheduling, state, version, created_at FROM ` + fulfillmentTableName + ` WHERE signature = $1 LIMIT 1` @@ -546,7 +494,7 @@ func dbGetByVirtualSignature(ctx context.Context, db *sqlx.DB, signature string) res := &fulfillmentModel{} - query := `SELECT id, intent, intent_type, action_id, action_type, fulfillment_type, data, signature, nonce, blockhash, virtual_signature, virtual_nonce, virtual_blockhash, source, destination, intent_ordering_index, action_ordering_index, fulfillment_ordering_index, disable_active_scheduling, state, created_at + query := `SELECT id, intent, intent_type, action_id, action_type, fulfillment_type, data, signature, nonce, blockhash, virtual_signature, virtual_nonce, virtual_blockhash, source, destination, intent_ordering_index, action_ordering_index, fulfillment_ordering_index, disable_active_scheduling, state, version, created_at FROM ` + fulfillmentTableName + ` WHERE virtual_signature = $1 LIMIT 1` @@ -561,7 +509,7 @@ func dbGetByVirtualSignature(ctx context.Context, db *sqlx.DB, signature string) func dbGetAllByState(ctx context.Context, db *sqlx.DB, state fulfillment.State, includeDisabledActiveScheduling bool, cursor q.Cursor, limit uint64, direction q.Ordering) ([]*fulfillmentModel, error) { res := []*fulfillmentModel{} - query := `SELECT id, intent, intent_type, action_id, action_type, fulfillment_type, data, signature, nonce, blockhash, virtual_signature, virtual_nonce, virtual_blockhash, source, destination, intent_ordering_index, action_ordering_index, fulfillment_ordering_index, disable_active_scheduling, state, created_at + query := `SELECT id, intent, intent_type, action_id, action_type, fulfillment_type, data, signature, nonce, blockhash, virtual_signature, virtual_nonce, virtual_blockhash, source, destination, intent_ordering_index, action_ordering_index, fulfillment_ordering_index, disable_active_scheduling, state, version, created_at FROM ` + fulfillmentTableName + ` WHERE (state = $1 AND %s) ` @@ -590,7 +538,7 @@ func dbGetAllByState(ctx context.Context, db *sqlx.DB, state fulfillment.State, func dbGetAllByIntent(ctx context.Context, db *sqlx.DB, intent string, cursor q.Cursor, limit uint64, direction q.Ordering) ([]*fulfillmentModel, error) { res := []*fulfillmentModel{} - query := `SELECT id, intent, intent_type, action_id, action_type, fulfillment_type, data, signature, nonce, blockhash, virtual_signature, virtual_nonce, virtual_blockhash, source, destination, intent_ordering_index, action_ordering_index, fulfillment_ordering_index, disable_active_scheduling, state, created_at + query := `SELECT id, intent, intent_type, action_id, action_type, fulfillment_type, data, signature, nonce, blockhash, virtual_signature, virtual_nonce, virtual_blockhash, source, destination, intent_ordering_index, action_ordering_index, fulfillment_ordering_index, disable_active_scheduling, state, version, created_at FROM ` + fulfillmentTableName + ` WHERE (intent = $1) ` @@ -613,7 +561,7 @@ func dbGetAllByIntent(ctx context.Context, db *sqlx.DB, intent string, cursor q. func dbGetAllByAction(ctx context.Context, db *sqlx.DB, intentId string, actionId uint32) ([]*fulfillmentModel, error) { res := []*fulfillmentModel{} - query := `SELECT id, intent, intent_type, action_id, action_type, fulfillment_type, data, signature, nonce, blockhash, virtual_signature, virtual_nonce, virtual_blockhash, source, destination, intent_ordering_index, action_ordering_index, fulfillment_ordering_index, disable_active_scheduling, state, created_at + query := `SELECT id, intent, intent_type, action_id, action_type, fulfillment_type, data, signature, nonce, blockhash, virtual_signature, virtual_nonce, virtual_blockhash, source, destination, intent_ordering_index, action_ordering_index, fulfillment_ordering_index, disable_active_scheduling, state, version, created_at FROM ` + fulfillmentTableName + ` WHERE (intent = $1 and action_id = $2) ` @@ -633,7 +581,7 @@ func dbGetAllByAction(ctx context.Context, db *sqlx.DB, intentId string, actionI func dbGetAllByTypeAndAction(ctx context.Context, db *sqlx.DB, fulfillmentType fulfillment.Type, intentId string, actionId uint32) ([]*fulfillmentModel, error) { res := []*fulfillmentModel{} - query := `SELECT id, intent, intent_type, action_id, action_type, fulfillment_type, data, signature, nonce, blockhash, virtual_signature, virtual_nonce, virtual_blockhash, source, destination, intent_ordering_index, action_ordering_index, fulfillment_ordering_index, disable_active_scheduling, state, created_at + query := `SELECT id, intent, intent_type, action_id, action_type, fulfillment_type, data, signature, nonce, blockhash, virtual_signature, virtual_nonce, virtual_blockhash, source, destination, intent_ordering_index, action_ordering_index, fulfillment_ordering_index, disable_active_scheduling, state, version, created_at FROM ` + fulfillmentTableName + ` WHERE intent = $1 AND action_id = $2 AND fulfillment_type = $3 ` @@ -653,7 +601,7 @@ func dbGetAllByTypeAndAction(ctx context.Context, db *sqlx.DB, fulfillmentType f func dbGetFirstSchedulableByAddressAsSource(ctx context.Context, db *sqlx.DB, address string) (*fulfillmentModel, error) { res := &fulfillmentModel{} - query := `SELECT id, intent, intent_type, action_id, action_type, fulfillment_type, data, signature, nonce, blockhash, virtual_signature, virtual_nonce, virtual_blockhash, source, destination, intent_ordering_index, action_ordering_index, fulfillment_ordering_index, disable_active_scheduling, state, created_at + query := `SELECT id, intent, intent_type, action_id, action_type, fulfillment_type, data, signature, nonce, blockhash, virtual_signature, virtual_nonce, virtual_blockhash, source, destination, intent_ordering_index, action_ordering_index, fulfillment_ordering_index, disable_active_scheduling, state, version, created_at FROM ` + fulfillmentTableName + ` WHERE (source = $1 AND (state = $2 OR state = $3)) ORDER BY intent_ordering_index ASC, action_ordering_index ASC, fulfillment_ordering_index ASC @@ -669,7 +617,7 @@ func dbGetFirstSchedulableByAddressAsSource(ctx context.Context, db *sqlx.DB, ad func dbGetFirstSchedulableByAddressAsDestination(ctx context.Context, db *sqlx.DB, address string) (*fulfillmentModel, error) { res := &fulfillmentModel{} - query := `SELECT id, intent, intent_type, action_id, action_type, fulfillment_type, data, signature, nonce, blockhash, virtual_signature, virtual_nonce, virtual_blockhash, source, destination, intent_ordering_index, action_ordering_index, fulfillment_ordering_index, disable_active_scheduling, state, created_at + query := `SELECT id, intent, intent_type, action_id, action_type, fulfillment_type, data, signature, nonce, blockhash, virtual_signature, virtual_nonce, virtual_blockhash, source, destination, intent_ordering_index, action_ordering_index, fulfillment_ordering_index, disable_active_scheduling, state, version, created_at FROM ` + fulfillmentTableName + ` WHERE (destination = $1 AND (state = $2 OR state = $3)) ORDER BY intent_ordering_index ASC, action_ordering_index ASC, fulfillment_ordering_index ASC @@ -685,7 +633,7 @@ func dbGetFirstSchedulableByAddressAsDestination(ctx context.Context, db *sqlx.D func dbGetFirstSchedulableByType(ctx context.Context, db *sqlx.DB, fulfillmentType fulfillment.Type) (*fulfillmentModel, error) { res := &fulfillmentModel{} - query := `SELECT id, intent, intent_type, action_id, action_type, fulfillment_type, data, signature, nonce, blockhash, virtual_signature, virtual_nonce, virtual_blockhash, source, destination, intent_ordering_index, action_ordering_index, fulfillment_ordering_index, disable_active_scheduling, state, created_at + query := `SELECT id, intent, intent_type, action_id, action_type, fulfillment_type, data, signature, nonce, blockhash, virtual_signature, virtual_nonce, virtual_blockhash, source, destination, intent_ordering_index, action_ordering_index, fulfillment_ordering_index, disable_active_scheduling, state, version, created_at FROM ` + fulfillmentTableName + ` WHERE (fulfillment_type = $1 AND (state = $2 OR state = $3)) ORDER BY intent_ordering_index ASC, action_ordering_index ASC, fulfillment_ordering_index ASC @@ -701,7 +649,7 @@ func dbGetFirstSchedulableByType(ctx context.Context, db *sqlx.DB, fulfillmentTy func dbGetNextSchedulableByAddress(ctx context.Context, db *sqlx.DB, address string, intentOrderingIndex uint64, actionOrderingIndex, fulfillmentOrderingIndex uint32) (*fulfillmentModel, error) { res := &fulfillmentModel{} - query := `SELECT id, intent, intent_type, action_id, action_type, fulfillment_type, data, signature, nonce, blockhash, virtual_signature, virtual_nonce, virtual_blockhash, source, destination, intent_ordering_index, action_ordering_index, fulfillment_ordering_index, disable_active_scheduling, state, created_at + query := `SELECT id, intent, intent_type, action_id, action_type, fulfillment_type, data, signature, nonce, blockhash, virtual_signature, virtual_nonce, virtual_blockhash, source, destination, intent_ordering_index, action_ordering_index, fulfillment_ordering_index, disable_active_scheduling, state, version, created_at FROM ` + fulfillmentTableName + ` WHERE ((source = $1 OR destination = $1) AND (state = $2 OR state = $3) AND (intent_ordering_index > $4 OR (intent_ordering_index = $4 AND action_ordering_index > $5) OR (intent_ordering_index = $4 AND action_ordering_index = $5 AND fulfillment_ordering_index > $6))) ORDER BY intent_ordering_index ASC, action_ordering_index ASC, fulfillment_ordering_index ASC diff --git a/pkg/code/data/fulfillment/postgres/store.go b/pkg/code/data/fulfillment/postgres/store.go index f6d151a8..7f4d5afb 100644 --- a/pkg/code/data/fulfillment/postgres/store.go +++ b/pkg/code/data/fulfillment/postgres/store.go @@ -126,11 +126,6 @@ func (s *store) Update(ctx context.Context, record *fulfillment.Record) error { return nil } -// MarkAsActivelyScheduled implements fulfillment.Store.MarkAsActivelyScheduled -func (s *store) MarkAsActivelyScheduled(ctx context.Context, id uint64) error { - return dbMarkAsActivelyScheduled(ctx, s.db, id) -} - // GetById implements fulfillment.Store.GetById func (s *store) GetById(ctx context.Context, id uint64) (*fulfillment.Record, error) { obj, err := dbGetById(ctx, s.db, id) diff --git a/pkg/code/data/fulfillment/postgres/store_test.go b/pkg/code/data/fulfillment/postgres/store_test.go index 13530cd6..32e15224 100644 --- a/pkg/code/data/fulfillment/postgres/store_test.go +++ b/pkg/code/data/fulfillment/postgres/store_test.go @@ -35,7 +35,7 @@ const ( nonce TEXT NULL, blockhash TEXT NULL, - virtual_signature TEXT NULL, + virtual_signature TEXT NULL UNIQUE, virtual_nonce TEXT NULL, virtual_blockhash TEXT NULL, @@ -50,6 +50,8 @@ const ( state INTEGER NOT NULL, + version INTEGER NOT NULL, + batch_insertion_id INTEGER NOT NULL, created_at timestamp with time zone NOT NULL diff --git a/pkg/code/data/fulfillment/store.go b/pkg/code/data/fulfillment/store.go index fbe0b097..6a018a23 100644 --- a/pkg/code/data/fulfillment/store.go +++ b/pkg/code/data/fulfillment/store.go @@ -3,9 +3,17 @@ package fulfillment import ( "context" + "github.com/pkg/errors" + "github.com/code-payments/code-server/pkg/database/query" ) +var ( + ErrFulfillmentNotFound = errors.New("no fulfillment could be found") + ErrFulfillmentExists = errors.New("fulfillment exists") + ErrStaleVersion = errors.New("fulfillment version is stale") +) + type Store interface { // Count returns the total count of fulfillment records. Count(ctx context.Context) (uint64, error) @@ -50,7 +58,6 @@ type Store interface { // Update updates an existing fulfillment record // // Note 1: Updating pre-sorting metadata is allowed but limited to certain fulfillment types - // Note 2: Updating DisableActiveScheduling is done in MarkAsActivelyScheduled, due to no distributed locks existing Update(ctx context.Context, record *Record) error // GetById find the fulfillment recofd for a given ID @@ -62,9 +69,6 @@ type Store interface { // GetByVirtualSignature finds the fulfillment record for a given virtual signature. GetByVirtualSignature(ctx context.Context, signature string) (*Record, error) - // MarkAsActivelyScheduled marks a fulfillment as actively scheduled - MarkAsActivelyScheduled(ctx context.Context, id uint64) error - // GetAllByState returns all fulfillment records for a given state. // // Returns ErrNotFound if no records are found. diff --git a/pkg/code/data/fulfillment/tests/tests.go b/pkg/code/data/fulfillment/tests/tests.go index 5835c05e..3c072e44 100644 --- a/pkg/code/data/fulfillment/tests/tests.go +++ b/pkg/code/data/fulfillment/tests/tests.go @@ -21,7 +21,8 @@ func RunTests(t *testing.T, s fulfillment.Store, teardown func()) { for _, tf := range []func(t *testing.T, s fulfillment.Store){ testRoundTrip, testBatchPut, - testUpdate, + testUpdateHappyPath, + testUpdateStaleVersion, testGetAllByState, testGetAllByIntent, testGetAllByAction, @@ -75,15 +76,18 @@ func testRoundTrip(t *testing.T, s fulfillment.Store) { err = s.PutAll(ctx, &expected) require.NoError(t, err) assert.EqualValues(t, 1, expected.Id) + assert.EqualValues(t, 1, expected.Version) actual, err = s.GetBySignature(ctx, "test_signature") require.NoError(t, err) assert.EqualValues(t, 1, actual.Id) + assert.EqualValues(t, 1, actual.Version) assertEquivalentRecords(t, actual, &cloned) actual, err = s.GetByVirtualSignature(ctx, "test_virtual_signature") require.NoError(t, err) assert.EqualValues(t, 1, actual.Id) + assert.EqualValues(t, 1, actual.Version) assertEquivalentRecords(t, actual, &cloned) actual, err = s.GetById(ctx, 2) @@ -125,10 +129,12 @@ func testRoundTrip(t *testing.T, s fulfillment.Store) { err = s.PutAll(ctx, &expected) require.NoError(t, err) assert.True(t, expected.Id >= 2) + assert.EqualValues(t, 1, expected.Version) actual, err = s.GetById(ctx, expected.Id) require.NoError(t, err) assert.EqualValues(t, expected.Id, actual.Id) + assert.EqualValues(t, 1, actual.Version) assertEquivalentRecords(t, actual, &cloned) assert.Equal(t, fulfillment.ErrFulfillmentExists, s.PutAll(ctx, &expected)) @@ -174,13 +180,18 @@ func testBatchPut(t *testing.T, s fulfillment.Store) { for i, fulfillmentRecord := range inserted { assert.EqualValues(t, i+1, fulfillmentRecord.Id) + assert.EqualValues(t, 1, fulfillmentRecord.Version) actual, err := s.GetById(ctx, fulfillmentRecord.Id) require.NoError(t, err) + assert.EqualValues(t, i+1, actual.Id) + assert.EqualValues(t, 1, actual.Version) assertEquivalentRecords(t, expected[i], actual) actual, err = s.GetBySignature(ctx, *fulfillmentRecord.Signature) require.NoError(t, err) + assert.EqualValues(t, i+1, actual.Id) + assert.EqualValues(t, 1, actual.Version) assertEquivalentRecords(t, expected[i], actual) } @@ -222,12 +233,10 @@ func testBatchPut(t *testing.T, s fulfillment.Store) { }) } -func testUpdate(t *testing.T, s fulfillment.Store) { - t.Run("testUpdate", func(t *testing.T) { +func testUpdateHappyPath(t *testing.T, s fulfillment.Store) { + t.Run("testUpdateHappyPath", func(t *testing.T) { ctx := context.Background() - assert.Equal(t, fulfillment.ErrFulfillmentNotFound, s.MarkAsActivelyScheduled(ctx, 1)) - expected := fulfillment.Record{ Intent: "test_intent", IntentType: intent.SendPublicPayment, @@ -253,21 +262,20 @@ func testUpdate(t *testing.T, s fulfillment.Store) { err := s.PutAll(ctx, &expected) require.NoError(t, err) assert.EqualValues(t, 1, expected.Id) - - require.NoError(t, s.MarkAsActivelyScheduled(ctx, 1)) - actual, err := s.GetById(ctx, 1) - require.NoError(t, err) - assert.False(t, actual.DisableActiveScheduling) - expected.DisableActiveScheduling = false + assert.EqualValues(t, 1, expected.Version) expected.State = fulfillment.StatePending + expected.DisableActiveScheduling = false cloned := expected.Clone() err = s.Update(ctx, &expected) require.NoError(t, err) assert.EqualValues(t, 1, expected.Id) + assert.EqualValues(t, 2, expected.Version) - actual, err = s.GetById(ctx, 1) + actual, err := s.GetById(ctx, 1) require.NoError(t, err) + assert.EqualValues(t, 1, actual.Id) + assert.EqualValues(t, 2, actual.Version) assertEquivalentRecords(t, actual, &cloned) expected.Signature = pointer.String("test_signature") @@ -278,32 +286,86 @@ func testUpdate(t *testing.T, s fulfillment.Store) { err = s.Update(ctx, &expected) require.NoError(t, err) assert.EqualValues(t, 1, expected.Id) + assert.EqualValues(t, 3, expected.Version) actual, err = s.GetBySignature(ctx, "test_signature") require.NoError(t, err) + assert.EqualValues(t, 1, actual.Id) + assert.EqualValues(t, 3, actual.Version) assertEquivalentRecords(t, actual, &cloned) expected.IntentOrderingIndex = math.MaxInt64 expected.ActionOrderingIndex = math.MaxInt32 expected.FulfillmentOrderingIndex = math.MaxInt32 - 1 - - actual, err = s.GetBySignature(ctx, "test_signature") - require.NoError(t, err) - assertEquivalentRecords(t, actual, &cloned) - expected.Data = nil expected.State = fulfillment.StateConfirmed cloned = expected.Clone() err = s.Update(ctx, &expected) require.NoError(t, err) assert.EqualValues(t, 1, expected.Id) + assert.EqualValues(t, 4, expected.Version) actual, err = s.GetBySignature(ctx, "test_signature") require.NoError(t, err) + assert.EqualValues(t, 1, actual.Id) + assert.EqualValues(t, 4, actual.Version) assertEquivalentRecords(t, actual, &cloned) expected.Id = 100 - assert.Equal(t, fulfillment.ErrFulfillmentNotFound, s.Update(ctx, &expected)) + assert.Equal(t, fulfillment.ErrStaleVersion, s.Update(ctx, &expected)) + }) +} + +func testUpdateStaleVersion(t *testing.T, s fulfillment.Store) { + t.Run("testUpdateStaleVersion", func(t *testing.T) { + ctx := context.Background() + + expected := fulfillment.Record{ + Intent: "test_intent", + IntentType: intent.SendPublicPayment, + ActionId: 4, + ActionType: action.NoPrivacyWithdraw, + FulfillmentType: fulfillment.NoPrivacyWithdraw, + Data: nil, + Signature: nil, + Nonce: nil, + Blockhash: nil, + Source: "test_source", + Destination: pointer.String("test_destination"), + IntentOrderingIndex: 1, + ActionOrderingIndex: 2, + FulfillmentOrderingIndex: 3, + DisableActiveScheduling: true, + State: fulfillment.StateUnknown, + CreatedAt: time.Now(), + } + cloned := expected.Clone() + + err := s.PutAll(ctx, &expected) + require.NoError(t, err) + assert.EqualValues(t, 1, expected.Id) + assert.EqualValues(t, 1, expected.Version) + + actual, err := s.GetById(ctx, 1) + require.NoError(t, err) + assert.EqualValues(t, 1, actual.Id) + assert.EqualValues(t, 1, actual.Version) + assertEquivalentRecords(t, actual, &cloned) + + stale := expected.Clone() + stale.State = fulfillment.StatePending + stale.DisableActiveScheduling = false + stale.Version -= 1 + + assert.Equal(t, fulfillment.ErrStaleVersion, s.Update(ctx, &stale)) + assert.EqualValues(t, 1, stale.Id) + assert.EqualValues(t, 0, stale.Version) + + actual, err = s.GetById(ctx, 1) + require.NoError(t, err) + assert.EqualValues(t, 1, actual.Id) + assert.EqualValues(t, 1, actual.Version) + assertEquivalentRecords(t, actual, &cloned) }) } diff --git a/pkg/code/data/intent/intent.go b/pkg/code/data/intent/intent.go index c8a5edf5..5c4c04a9 100644 --- a/pkg/code/data/intent/intent.go +++ b/pkg/code/data/intent/intent.go @@ -7,12 +7,6 @@ import ( "github.com/code-payments/code-server/pkg/currency" ) -var ( - ErrIntentNotFound = errors.New("no records could be found") - ErrInvalidIntent = errors.New("invalid intent") - ErrMultilpeIntentsFound = errors.New("multiple records found") -) - type State uint8 const ( @@ -58,6 +52,8 @@ type Record struct { State State + Version uint64 + CreatedAt time.Time } @@ -146,6 +142,8 @@ func (r *Record) Clone() Record { State: r.State, + Version: r.Version, + CreatedAt: r.CreatedAt, } } @@ -167,6 +165,8 @@ func (r *Record) CopyTo(dst *Record) { dst.State = r.State + dst.Version = r.Version + dst.CreatedAt = r.CreatedAt } diff --git a/pkg/code/data/intent/memory/store.go b/pkg/code/data/intent/memory/store.go index 1ddcf5ff..46cda81b 100644 --- a/pkg/code/data/intent/memory/store.go +++ b/pkg/code/data/intent/memory/store.go @@ -266,8 +266,14 @@ func (s *store) Save(ctx context.Context, data *intent.Record) error { s.last++ if item := s.find(data); item != nil { - // Only update state + if item.Version != data.Version { + return intent.ErrStaleVersion + } + + data.Version++ + item.State = data.State + item.Version = data.Version } else { if data.Id == 0 { data.Id = s.last @@ -275,6 +281,8 @@ func (s *store) Save(ctx context.Context, data *intent.Record) error { if data.CreatedAt.IsZero() { data.CreatedAt = time.Now() } + data.Version++ + c := data.Clone() s.records = append(s.records, &c) } diff --git a/pkg/code/data/intent/postgres/model.go b/pkg/code/data/intent/postgres/model.go index 9449d61d..7e88f41b 100644 --- a/pkg/code/data/intent/postgres/model.go +++ b/pkg/code/data/intent/postgres/model.go @@ -42,6 +42,7 @@ type intentModel struct { IsMicroPayment bool `db:"is_micro_payment"` ExtendedMetadata []byte `db:"extended_metadata"` State uint `db:"state"` + Version int64 `db:"version"` CreatedAt time.Time `db:"created_at"` } @@ -62,6 +63,7 @@ func toIntentModel(obj *intent.Record) (*intentModel, error) { ExtendedMetadata: obj.ExtendedMetadata, State: uint(obj.State), CreatedAt: obj.CreatedAt, + Version: int64(obj.Version), } switch obj.IntentType { @@ -110,6 +112,7 @@ func fromIntentModel(obj *intentModel) *intent.Record { InitiatorOwnerAccount: obj.InitiatorOwner, ExtendedMetadata: obj.ExtendedMetadata, State: intent.State(obj.State), + Version: uint64(obj.Version), CreatedAt: obj.CreatedAt.UTC(), } @@ -159,16 +162,16 @@ func fromIntentModel(obj *intentModel) *intent.Record { func (m *intentModel) dbSave(ctx context.Context, db *sqlx.DB) error { return pgutil.ExecuteInTx(ctx, db, sql.LevelDefault, func(tx *sqlx.Tx) error { query := `INSERT INTO ` + intentTableName + ` - (intent_id, intent_type, owner, source, destination_owner, destination, quantity, exchange_currency, exchange_rate, native_amount, usd_market_value, is_withdraw, is_deposit, is_remote_send, is_returned, is_issuer_voiding_gift_card, is_micro_payment, extended_metadata, state, created_at) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20) + (intent_id, intent_type, owner, source, destination_owner, destination, quantity, exchange_currency, exchange_rate, native_amount, usd_market_value, is_withdraw, is_deposit, is_remote_send, is_returned, is_issuer_voiding_gift_card, is_micro_payment, extended_metadata, state, version, created_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20 + 1, $21) ON CONFLICT (intent_id) DO UPDATE - SET state = $19 - WHERE ` + intentTableName + `.intent_id = $1 + SET state = $19, version = ` + intentTableName + `.version + 1 + WHERE ` + intentTableName + `.intent_id = $1 AND ` + intentTableName + `.version = $20 RETURNING - id, intent_id, intent_type, owner, source, destination_owner, destination, quantity, exchange_currency, exchange_rate, native_amount, usd_market_value, is_withdraw, is_deposit, is_remote_send, is_returned, is_issuer_voiding_gift_card, is_micro_payment, extended_metadata, state, created_at` + id, intent_id, intent_type, owner, source, destination_owner, destination, quantity, exchange_currency, exchange_rate, native_amount, usd_market_value, is_withdraw, is_deposit, is_remote_send, is_returned, is_issuer_voiding_gift_card, is_micro_payment, extended_metadata, state, version, created_at` err := tx.QueryRowxContext( ctx, @@ -192,17 +195,18 @@ func (m *intentModel) dbSave(ctx context.Context, db *sqlx.DB) error { m.IsMicroPayment, m.ExtendedMetadata, m.State, + m.Version, m.CreatedAt, ).StructScan(m) - return pgutil.CheckNoRows(err, intent.ErrInvalidIntent) + return pgutil.CheckNoRows(err, intent.ErrStaleVersion) }) } func dbGetIntent(ctx context.Context, db *sqlx.DB, intentID string) (*intentModel, error) { res := &intentModel{} - query := `SELECT id, intent_id, intent_type, owner, source, destination_owner, destination, quantity, exchange_currency, exchange_rate, native_amount, usd_market_value, is_withdraw, is_deposit, is_remote_send, is_returned, is_issuer_voiding_gift_card, is_micro_payment, extended_metadata, state, created_at + query := `SELECT id, intent_id, intent_type, owner, source, destination_owner, destination, quantity, exchange_currency, exchange_rate, native_amount, usd_market_value, is_withdraw, is_deposit, is_remote_send, is_returned, is_issuer_voiding_gift_card, is_micro_payment, extended_metadata, state, version, created_at FROM ` + intentTableName + ` WHERE intent_id = $1 LIMIT 1` @@ -217,7 +221,7 @@ func dbGetIntent(ctx context.Context, db *sqlx.DB, intentID string) (*intentMode func dbGetAllByOwner(ctx context.Context, db *sqlx.DB, owner string, cursor q.Cursor, limit uint64, direction q.Ordering) ([]*intentModel, error) { res := []*intentModel{} - query := `SELECT id, intent_id, intent_type, owner, source, destination_owner, destination, quantity, exchange_currency, exchange_rate, native_amount, usd_market_value, is_withdraw, is_deposit, is_remote_send, is_returned, is_issuer_voiding_gift_card, is_micro_payment, extended_metadata, state, created_at + query := `SELECT id, intent_id, intent_type, owner, source, destination_owner, destination, quantity, exchange_currency, exchange_rate, native_amount, usd_market_value, is_withdraw, is_deposit, is_remote_send, is_returned, is_issuer_voiding_gift_card, is_micro_payment, extended_metadata, state, version, created_at FROM ` + intentTableName + ` WHERE (owner = $1 OR destination_owner = $1) ` @@ -240,7 +244,7 @@ func dbGetAllByOwner(ctx context.Context, db *sqlx.DB, owner string, cursor q.Cu func dbGetLatestByInitiatorAndType(ctx context.Context, db *sqlx.DB, intentType intent.Type, owner string) (*intentModel, error) { res := &intentModel{} - query := `SELECT id, intent_id, intent_type, owner, source, destination_owner, destination, quantity, exchange_currency, exchange_rate, native_amount, usd_market_value, is_withdraw, is_deposit, is_remote_send, is_returned, is_issuer_voiding_gift_card, is_micro_payment, extended_metadata, state, created_at + query := `SELECT id, intent_id, intent_type, owner, source, destination_owner, destination, quantity, exchange_currency, exchange_rate, native_amount, usd_market_value, is_withdraw, is_deposit, is_remote_send, is_returned, is_issuer_voiding_gift_card, is_micro_payment, extended_metadata, state, version, created_at FROM ` + intentTableName + ` WHERE owner = $1 AND intent_type = $2 ORDER BY created_at DESC @@ -256,7 +260,7 @@ func dbGetLatestByInitiatorAndType(ctx context.Context, db *sqlx.DB, intentType func dbGetOriginalGiftCardIssuedIntent(ctx context.Context, db *sqlx.DB, giftCardVault string) (*intentModel, error) { res := []*intentModel{} - query := `SELECT id, intent_id, intent_type, owner, source, destination_owner, destination, quantity, exchange_currency, exchange_rate, native_amount, usd_market_value, is_withdraw, is_deposit, is_remote_send, is_returned, is_issuer_voiding_gift_card, is_micro_payment, extended_metadata, state, created_at + query := `SELECT id, intent_id, intent_type, owner, source, destination_owner, destination, quantity, exchange_currency, exchange_rate, native_amount, usd_market_value, is_withdraw, is_deposit, is_remote_send, is_returned, is_issuer_voiding_gift_card, is_micro_payment, extended_metadata, state, version, created_at FROM ` + intentTableName + ` WHERE destination = $1 and intent_type = $2 AND state != $3 AND is_remote_send IS TRUE LIMIT 2 @@ -288,7 +292,7 @@ func dbGetOriginalGiftCardIssuedIntent(ctx context.Context, db *sqlx.DB, giftCar func dbGetGiftCardClaimedIntent(ctx context.Context, db *sqlx.DB, giftCardVault string) (*intentModel, error) { res := []*intentModel{} - query := `SELECT id, intent_id, intent_type, owner, source, destination_owner, destination, quantity, exchange_currency, exchange_rate, native_amount, usd_market_value, is_withdraw, is_deposit, is_remote_send, is_returned, is_issuer_voiding_gift_card, is_micro_payment, extended_metadata, state, created_at + query := `SELECT id, intent_id, intent_type, owner, source, destination_owner, destination, quantity, exchange_currency, exchange_rate, native_amount, usd_market_value, is_withdraw, is_deposit, is_remote_send, is_returned, is_issuer_voiding_gift_card, is_micro_payment, extended_metadata, state, version, created_at FROM ` + intentTableName + ` WHERE source = $1 and intent_type = $2 AND state != $3 AND is_remote_send IS TRUE LIMIT 2 diff --git a/pkg/code/data/intent/postgres/store.go b/pkg/code/data/intent/postgres/store.go index ba5f08c3..d1b7b658 100644 --- a/pkg/code/data/intent/postgres/store.go +++ b/pkg/code/data/intent/postgres/store.go @@ -5,9 +5,10 @@ import ( "database/sql" "time" + "github.com/jmoiron/sqlx" + "github.com/code-payments/code-server/pkg/code/data/intent" "github.com/code-payments/code-server/pkg/database/query" - "github.com/jmoiron/sqlx" ) type store struct { diff --git a/pkg/code/data/intent/postgres/store_test.go b/pkg/code/data/intent/postgres/store_test.go index 2d96f9ac..db46128b 100644 --- a/pkg/code/data/intent/postgres/store_test.go +++ b/pkg/code/data/intent/postgres/store_test.go @@ -26,9 +26,9 @@ const ( intent_type INTEGER NOT NULL, owner text NOT NULL, - source text NULL, - destination text NULL, - destination_owner text NULL, + source TEXT NULL, + destination TEXT NULL, + destination_owner TEXT NULL, quantity bigint NULL CHECK (quantity >= 0), @@ -46,9 +46,11 @@ const ( extended_metadata BYTEA NULL, - state integer NOT NULL, + state INTEGER NOT NULL, - created_at timestamp with time zone NOT NULL + version INTEGER NOT NULL, + + created_at TIMESTAMP WITH TIME ZONE NOT NULL ); ` diff --git a/pkg/code/data/intent/store.go b/pkg/code/data/intent/store.go index 7d63e432..083b9327 100644 --- a/pkg/code/data/intent/store.go +++ b/pkg/code/data/intent/store.go @@ -2,11 +2,18 @@ package intent import ( "context" + "errors" "time" "github.com/code-payments/code-server/pkg/database/query" ) +var ( + ErrIntentNotFound = errors.New("no records could be found") + ErrMultilpeIntentsFound = errors.New("multiple records found") + ErrStaleVersion = errors.New("intent version is stale") +) + type Store interface { // Save creates or updates an intent on the store. Save(ctx context.Context, record *Record) error diff --git a/pkg/code/data/intent/tests/tests.go b/pkg/code/data/intent/tests/tests.go index 94ea885c..70247660 100644 --- a/pkg/code/data/intent/tests/tests.go +++ b/pkg/code/data/intent/tests/tests.go @@ -18,7 +18,8 @@ func RunTests(t *testing.T, s intent.Store, teardown func()) { testExternalDepositRoundTrip, testSendPublicPaymentRoundTrip, testReceivePaymentsPubliclyRoundTrip, - testUpdate, + testUpdateHappyPath, + testUpdateStaleRecord, testGetLatestByInitiatorAndType, testGetOriginalGiftCardIssuedIntent, testGetGiftCardClaimedIntent, @@ -50,6 +51,8 @@ func testOpenAccountsRoundTrip(t *testing.T, s intent.Store) { cloned := expected.Clone() err = s.Save(ctx, &expected) require.NoError(t, err) + assert.EqualValues(t, 1, expected.Id) + assert.EqualValues(t, 1, expected.Version) actual, err = s.Get(ctx, "test_intent_id") require.NoError(t, err) @@ -61,6 +64,7 @@ func testOpenAccountsRoundTrip(t *testing.T, s intent.Store) { assert.Equal(t, cloned.State, actual.State) assert.Equal(t, cloned.CreatedAt.Unix(), actual.CreatedAt.Unix()) assert.EqualValues(t, 1, actual.Id) + assert.EqualValues(t, 1, actual.Version) }) } @@ -89,6 +93,8 @@ func testExternalDepositRoundTrip(t *testing.T, s intent.Store) { cloned := expected.Clone() err = s.Save(ctx, &expected) require.NoError(t, err) + assert.EqualValues(t, 1, expected.Id) + assert.EqualValues(t, 1, expected.Version) actual, err = s.Get(ctx, "test_intent_id") require.NoError(t, err) @@ -103,6 +109,7 @@ func testExternalDepositRoundTrip(t *testing.T, s intent.Store) { assert.Equal(t, cloned.State, actual.State) assert.Equal(t, cloned.CreatedAt.Unix(), actual.CreatedAt.Unix()) assert.EqualValues(t, 1, actual.Id) + assert.EqualValues(t, 1, actual.Version) }) } @@ -139,6 +146,8 @@ func testSendPublicPaymentRoundTrip(t *testing.T, s intent.Store) { cloned := expected.Clone() err = s.Save(ctx, &expected) require.NoError(t, err) + assert.EqualValues(t, 1, expected.Id) + assert.EqualValues(t, 1, expected.Version) actual, err = s.Get(ctx, "test_intent_id") require.NoError(t, err) @@ -159,6 +168,7 @@ func testSendPublicPaymentRoundTrip(t *testing.T, s intent.Store) { assert.Equal(t, cloned.State, actual.State) assert.Equal(t, cloned.CreatedAt.Unix(), actual.CreatedAt.Unix()) assert.EqualValues(t, 1, actual.Id) + assert.EqualValues(t, 1, actual.Version) }) } @@ -193,6 +203,8 @@ func testReceivePaymentsPubliclyRoundTrip(t *testing.T, s intent.Store) { cloned := expected.Clone() err = s.Save(ctx, &expected) require.NoError(t, err) + assert.EqualValues(t, 1, expected.Id) + assert.EqualValues(t, 1, expected.Version) actual, err = s.Get(ctx, "test_intent_id") require.NoError(t, err) @@ -213,11 +225,12 @@ func testReceivePaymentsPubliclyRoundTrip(t *testing.T, s intent.Store) { assert.Equal(t, cloned.State, actual.State) assert.Equal(t, cloned.CreatedAt.Unix(), actual.CreatedAt.Unix()) assert.EqualValues(t, 1, actual.Id) + assert.EqualValues(t, 1, actual.Version) }) } -func testUpdate(t *testing.T, s intent.Store) { - t.Run("testUpdate", func(t *testing.T) { +func testUpdateHappyPath(t *testing.T, s intent.Store) { + t.Run("testUpdateHappyPath", func(t *testing.T) { ctx := context.Background() expected := intent.Record{ @@ -231,16 +244,54 @@ func testUpdate(t *testing.T, s intent.Store) { err := s.Save(ctx, &expected) require.NoError(t, err) assert.EqualValues(t, 1, expected.Id) + assert.EqualValues(t, 1, expected.Version) expected.State = intent.StatePending err = s.Save(ctx, &expected) require.NoError(t, err) + assert.EqualValues(t, 1, expected.Id) + assert.EqualValues(t, 2, expected.Version) actual, err := s.Get(ctx, "test_intent_id") require.NoError(t, err) assert.Equal(t, intent.StatePending, actual.State) assert.EqualValues(t, 1, actual.Id) + assert.EqualValues(t, 2, actual.Version) + }) +} + +func testUpdateStaleRecord(t *testing.T, s intent.Store) { + t.Run("testUpdateStaleRecord", func(t *testing.T) { + ctx := context.Background() + + expected := intent.Record{ + IntentId: "test_intent_id", + IntentType: intent.OpenAccounts, + InitiatorOwnerAccount: "test_owner", + OpenAccountsMetadata: &intent.OpenAccountsMetadata{}, + State: intent.StateUnknown, + CreatedAt: time.Now(), + } + err := s.Save(ctx, &expected) + require.NoError(t, err) + assert.EqualValues(t, 1, expected.Id) + assert.EqualValues(t, 1, expected.Version) + + stale := expected.Clone() + expected.State = intent.StatePending + stale.Version -= 1 + + err = s.Save(ctx, &stale) + assert.Equal(t, intent.ErrStaleVersion, err) + assert.EqualValues(t, 1, stale.Id) + assert.EqualValues(t, 0, stale.Version) + + actual, err := s.Get(ctx, "test_intent_id") + require.NoError(t, err) + assert.Equal(t, intent.StateUnknown, actual.State) + assert.EqualValues(t, 1, actual.Id) + assert.EqualValues(t, 1, actual.Version) }) } diff --git a/pkg/code/data/internal.go b/pkg/code/data/internal.go index e66c36fa..d14d4889 100644 --- a/pkg/code/data/internal.go +++ b/pkg/code/data/internal.go @@ -176,7 +176,6 @@ type DatabaseData interface { GetNextSchedulableFulfillmentByAddress(ctx context.Context, address string, intentOrderingIndex uint64, actionOrderingIndex, fulfillmentOrderingIndex uint32) (*fulfillment.Record, error) PutAllFulfillments(ctx context.Context, records ...*fulfillment.Record) error UpdateFulfillment(ctx context.Context, record *fulfillment.Record) error - MarkFulfillmentAsActivelyScheduled(ctx context.Context, id uint64) error // Intents // -------------------------------------------------------------------------------- @@ -618,9 +617,6 @@ func (dp *DatabaseProvider) PutAllFulfillments(ctx context.Context, records ...* func (dp *DatabaseProvider) UpdateFulfillment(ctx context.Context, record *fulfillment.Record) error { return dp.fulfillments.Update(ctx, record) } -func (dp *DatabaseProvider) MarkFulfillmentAsActivelyScheduled(ctx context.Context, id uint64) error { - return dp.fulfillments.MarkAsActivelyScheduled(ctx, id) -} // Intents // -------------------------------------------------------------------------------- diff --git a/pkg/code/data/nonce/nonce.go b/pkg/code/data/nonce/nonce.go index 9b4484bb..cd0cc142 100644 --- a/pkg/code/data/nonce/nonce.go +++ b/pkg/code/data/nonce/nonce.go @@ -23,11 +23,6 @@ const ( EnvironmentInstanceSolanaTestnet = "testnet" ) -var ( - ErrStaleVersion = errors.New("nonce version is stale") - ErrNonceNotFound = errors.New("no records could be found") -) - type State uint8 const ( diff --git a/pkg/code/data/nonce/store.go b/pkg/code/data/nonce/store.go index 7288b2c8..e2a60f2d 100644 --- a/pkg/code/data/nonce/store.go +++ b/pkg/code/data/nonce/store.go @@ -2,11 +2,17 @@ package nonce import ( "context" + "errors" "time" "github.com/code-payments/code-server/pkg/database/query" ) +var ( + ErrStaleVersion = errors.New("nonce version is stale") + ErrNonceNotFound = errors.New("no nonce could be found") +) + type Store interface { // Count returns the total count of nonce accounts within an environment instance Count(ctx context.Context, env Environment, instance string) (uint64, error) diff --git a/pkg/code/data/nonce/tests/tests.go b/pkg/code/data/nonce/tests/tests.go index 07f97425..8ed0dcd5 100644 --- a/pkg/code/data/nonce/tests/tests.go +++ b/pkg/code/data/nonce/tests/tests.go @@ -20,7 +20,8 @@ import ( func RunTests(t *testing.T, s nonce.Store, teardown func()) { for _, tf := range []func(t *testing.T, s nonce.Store){ testRoundTrip, - testUpdate, + testUpdateHappyPath, + testUpdateStaleRecord, testGetAllByState, testGetCount, testBatchClaimAvailableByPurpose, @@ -65,8 +66,8 @@ func testRoundTrip(t *testing.T, s nonce.Store) { }) } -func testUpdate(t *testing.T, s nonce.Store) { - t.Run("testUpdate", func(t *testing.T) { +func testUpdateHappyPath(t *testing.T, s nonce.Store) { + t.Run("testUpdateHappyPath", func(t *testing.T) { ctx := context.Background() expected := nonce.Record{ @@ -110,6 +111,51 @@ func testUpdate(t *testing.T, s nonce.Store) { }) } +func testUpdateStaleRecord(t *testing.T, s nonce.Store) { + t.Run("testUpdateStaleRecord", func(t *testing.T) { + ctx := context.Background() + + expected := nonce.Record{ + Address: "test_address", + Authority: "test_authority", + Blockhash: "test_blockhash", + Environment: nonce.EnvironmentSolana, + EnvironmentInstance: nonce.EnvironmentInstanceSolanaMainnet, + Purpose: nonce.PurposeInternalServerProcess, + } + cloned := expected.Clone() + err := s.Save(ctx, &expected) + require.NoError(t, err) + assert.EqualValues(t, 1, expected.Id) + assert.EqualValues(t, 1, expected.Version) + + actual, err := s.Get(ctx, "test_address") + require.NoError(t, err) + assertEquivalentRecords(t, &cloned, actual) + assert.EqualValues(t, 1, actual.Id) + assert.EqualValues(t, 1, actual.Version) + + stale := actual.Clone() + stale.State = nonce.StateClaimed + stale.Blockhash = "test_blockhash2" + stale.Signature = "test_signature" + stale.ClaimNodeID = pointer.String("test_claim_node_id") + stale.ClaimExpiresAt = pointer.Time(time.Now().Add(time.Hour)) + stale.Version -= 1 + + err = s.Save(ctx, &stale) + assert.Equal(t, nonce.ErrStaleVersion, err) + assert.EqualValues(t, 1, stale.Id) + assert.EqualValues(t, 0, stale.Version) + + actual, err = s.Get(ctx, "test_address") + require.NoError(t, err) + assertEquivalentRecords(t, &cloned, actual) + assert.EqualValues(t, 1, actual.Id) + assert.EqualValues(t, 1, actual.Version) + }) +} + func testGetAllByState(t *testing.T, s nonce.Store) { t.Run("testGetAllByState", func(t *testing.T) { ctx := context.Background() diff --git a/pkg/code/server/account/server.go b/pkg/code/server/account/server.go index ee964ebe..f460296c 100644 --- a/pkg/code/server/account/server.go +++ b/pkg/code/server/account/server.go @@ -379,7 +379,7 @@ func (s *server) getProtoAccountInfo(ctx context.Context, records *common.Accoun // Gift cards that are close to the auto-return window are marked as expired in // a consistent manner as SubmitIntent to avoid race conditions with the auto-return. - if time.Since(records.General.CreatedAt) > async_account.GiftCardExpiry-15*time.Minute { + if time.Since(records.General.CreatedAt) >= async_account.GiftCardExpiry { claimState = accountpb.TokenAccountInfo_CLAIM_STATE_EXPIRED } diff --git a/pkg/code/server/transaction/intent.go b/pkg/code/server/transaction/intent.go index f6a81f36..07fbd6d2 100644 --- a/pkg/code/server/transaction/intent.go +++ b/pkg/code/server/transaction/intent.go @@ -215,15 +215,10 @@ func (s *transactionServer) SubmitIntent(streamer transactionpb.Transaction_Subm CreatedAt: time.Now(), } - // Distributed locking. This is a partial view, since additional locking - // requirements may not be known until populating intent metadata. - intentLock := s.intentLocks.Get([]byte(intentId)) initiatorOwnerLock := s.ownerLocks.Get(initiatorOwnerAccount.PublicKey().ToBytes()) - intentLock.Lock() initiatorOwnerLock.Lock() defer func() { initiatorOwnerLock.Unlock() - intentLock.Unlock() }() existingIntentRecord, err := s.data.GetIntent(ctx, intentId) @@ -956,7 +951,7 @@ func (s *transactionServer) VoidGiftCard(ctx context.Context, req *transactionpb }, nil } - if time.Since(accountInfoRecord.CreatedAt) > async_account.GiftCardExpiry-15*time.Minute { + if time.Since(accountInfoRecord.CreatedAt) >= async_account.GiftCardExpiry { return &transactionpb.VoidGiftCardResponse{ Result: transactionpb.VoidGiftCardResponse_OK, }, nil diff --git a/pkg/code/server/transaction/intent_handler.go b/pkg/code/server/transaction/intent_handler.go index 8ffcdfe8..9b8f642b 100644 --- a/pkg/code/server/transaction/intent_handler.go +++ b/pkg/code/server/transaction/intent_handler.go @@ -1315,9 +1315,7 @@ func validateClaimedGiftCard(ctx context.Context, data code_data.Provider, giftC // Part 6: Are we within the threshold for auto-return back to the issuer? // - // todo: I think we use the same trick of doing deadline - x minutes to avoid race - // conditions without distributed locks. - if time.Since(accountInfoRecord.CreatedAt) > async_account.GiftCardExpiry-15*time.Minute { + if time.Since(accountInfoRecord.CreatedAt) >= async_account.GiftCardExpiry-15*time.Minute { return newStaleStateError("gift card is expired") } diff --git a/pkg/code/server/transaction/server.go b/pkg/code/server/transaction/server.go index 0b557b72..28a7a989 100644 --- a/pkg/code/server/transaction/server.go +++ b/pkg/code/server/transaction/server.go @@ -38,8 +38,6 @@ type transactionServer struct { feeCollector *common.Account - // todo: distributed locks - intentLocks *sync_util.StripedLock ownerLocks *sync_util.StripedLock giftCardLocks *sync_util.StripedLock @@ -81,7 +79,6 @@ func NewTransactionServer( noncePool: noncePool, - intentLocks: sync_util.NewStripedLock(stripedLockParallelization), ownerLocks: sync_util.NewStripedLock(stripedLockParallelization), giftCardLocks: sync_util.NewStripedLock(stripedLockParallelization), }