diff --git a/go.mod b/go.mod index be175d13..f0223d16 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.23.0 require ( github.com/aws/aws-sdk-go-v2 v0.17.0 github.com/bits-and-blooms/bloom/v3 v3.1.0 - github.com/code-payments/code-protobuf-api v1.19.1-0.20250602171721-c057e3310d81 + github.com/code-payments/code-protobuf-api v1.19.1-0.20250603030803-cbe2bfca5052 github.com/code-payments/code-vm-indexer v0.1.11-0.20241028132209-23031e814fba github.com/emirpasic/gods v1.12.0 github.com/envoyproxy/protoc-gen-validate v1.2.1 diff --git a/go.sum b/go.sum index 2d0aa842..ac4ec632 100644 --- a/go.sum +++ b/go.sum @@ -78,8 +78,8 @@ github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnht github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I= github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= -github.com/code-payments/code-protobuf-api v1.19.1-0.20250602171721-c057e3310d81 h1:sX2VijdiCok3hJI1Yv/a+NslvVzCTolpNDQIOFzk9DE= -github.com/code-payments/code-protobuf-api v1.19.1-0.20250602171721-c057e3310d81/go.mod h1:ee6TzKbgMS42ZJgaFEMG3c4R3dGOiffHSu6MrY7WQvs= +github.com/code-payments/code-protobuf-api v1.19.1-0.20250603030803-cbe2bfca5052 h1:lfxaakPHAWFPukrqsUn8nYdpw1WaXQfP4KLCzmL8UxU= +github.com/code-payments/code-protobuf-api v1.19.1-0.20250603030803-cbe2bfca5052/go.mod h1:ee6TzKbgMS42ZJgaFEMG3c4R3dGOiffHSu6MrY7WQvs= github.com/code-payments/code-vm-indexer v0.1.11-0.20241028132209-23031e814fba h1:Bkp+gmeb6Y2PWXfkSCTMBGWkb2P1BujRDSjWeI+0j5I= github.com/code-payments/code-vm-indexer v0.1.11-0.20241028132209-23031e814fba/go.mod h1:jSiifpiBpyBQ8q2R0MGEbkSgWC6sbdRTyDBntmW+j1E= github.com/containerd/continuity v0.0.0-20190827140505-75bee3e2ccb6 h1:NmTXa/uVnDyp0TY5MKi197+3HWcnYWfnHGyaFthlnGw= diff --git a/pkg/code/async/sequencer/fulfillment_handler.go b/pkg/code/async/sequencer/fulfillment_handler.go index d599328d..e9fbd227 100644 --- a/pkg/code/async/sequencer/fulfillment_handler.go +++ b/pkg/code/async/sequencer/fulfillment_handler.go @@ -7,6 +7,7 @@ import ( "github.com/mr-tron/base58" commonpb "github.com/code-payments/code-protobuf-api/generated/go/common/v1" + transactionpb "github.com/code-payments/code-protobuf-api/generated/go/transaction/v2" indexerpb "github.com/code-payments/code-vm-indexer/generated/indexer/v1" "github.com/code-payments/code-server/pkg/code/common" @@ -209,7 +210,7 @@ func (h *NoPrivacyTransferWithAuthorityFulfillmentHandler) CanSubmitToBlockchain // The source user account is a Code account, so we must validate it exists on // the blockchain prior to sending funds from it. - isSourceAccountCreated, err := isTokenAccountOnBlockchain(ctx, h.data, fulfillmentRecord.Source) + isSourceAccountCreated, err := isAccountInitialized(ctx, h.data, fulfillmentRecord.Source) if err != nil { return false, err } else if !isSourceAccountCreated { @@ -217,12 +218,27 @@ func (h *NoPrivacyTransferWithAuthorityFulfillmentHandler) CanSubmitToBlockchain } // The destination user account might be a Code account or external wallet, so we - // must validate it exists on the blockchain prior to send funds to it. - isDestinationAccountCreated, err := isTokenAccountOnBlockchain(ctx, h.data, *fulfillmentRecord.Destination) + // must validate it exists on the blockchain prior to sending funds to it, or if we'll + // be creating it at time of send. + destinationAccount, err := common.NewAccountFromPublicKeyString(*fulfillmentRecord.Destination) if err != nil { return false, err - } else if !isDestinationAccountCreated { - return false, nil + } + isInternalTransfer, err := isInternalVmTransfer(ctx, h.data, destinationAccount) + if err != nil { + return false, err + } + hasCreateOnSendFee, err := h.data.HasFeeAction(ctx, fulfillmentRecord.Intent, transactionpb.FeePaymentAction_CREATE_ON_SEND_WITHDRAWAL) + if err != nil { + return false, err + } + if isInternalTransfer || !hasCreateOnSendFee { + isDestinationAccountCreated, err := isAccountInitialized(ctx, h.data, *fulfillmentRecord.Destination) + if err != nil { + return false, err + } else if !isDestinationAccountCreated { + return false, nil + } } // Check whether there's an earlier fulfillment that should be scheduled first @@ -269,17 +285,17 @@ func (h *NoPrivacyTransferWithAuthorityFulfillmentHandler) SupportsOnDemandTrans } func (h *NoPrivacyTransferWithAuthorityFulfillmentHandler) MakeOnDemandTransaction(ctx context.Context, fulfillmentRecord *fulfillment.Record, selectedNonce *transaction_util.Nonce) (*solana.Transaction, error) { - virtualSignatureBytes, err := base58.Decode(*fulfillmentRecord.VirtualSignature) + actionRecord, err := h.data.GetActionById(ctx, fulfillmentRecord.Intent, fulfillmentRecord.ActionId) if err != nil { return nil, err } - virtualNonce, err := common.NewAccountFromPublicKeyString(*fulfillmentRecord.VirtualNonce) + virtualSignatureBytes, err := base58.Decode(*fulfillmentRecord.VirtualSignature) if err != nil { return nil, err } - actionRecord, err := h.data.GetActionById(ctx, fulfillmentRecord.Intent, fulfillmentRecord.ActionId) + virtualNonce, err := common.NewAccountFromPublicKeyString(*fulfillmentRecord.VirtualNonce) if err != nil { return nil, err } @@ -354,6 +370,24 @@ func (h *NoPrivacyTransferWithAuthorityFulfillmentHandler) MakeOnDemandTransacti *actionRecord.Quantity, ) } else { + isCreateOnSend, err := h.data.HasFeeAction(ctx, fulfillmentRecord.Intent, transactionpb.FeePaymentAction_CREATE_ON_SEND_WITHDRAWAL) + if err != nil { + return &solana.Transaction{}, err + } + + var destinationOwnerAccount *common.Account + if isCreateOnSend { + intentRecord, err := h.data.GetIntent(ctx, fulfillmentRecord.Intent) + if err != nil { + return nil, err + } + + destinationOwnerAccount, err = common.NewAccountFromPublicKeyString(intentRecord.SendPublicPaymentMetadata.DestinationOwnerAccount) + if err != nil { + return nil, err + } + } + txn, makeTxnErr = transaction_util.MakeExternalTransferWithAuthorityTransaction( selectedNonce.Account, selectedNonce.Blockhash, @@ -368,7 +402,10 @@ func (h *NoPrivacyTransferWithAuthorityFulfillmentHandler) MakeOnDemandTransacti sourceMemory, sourceIndex, + isCreateOnSend, + destinationOwnerAccount, destinationTokenAccount, + *actionRecord.Quantity, ) } @@ -397,7 +434,7 @@ func (h *NoPrivacyWithdrawFulfillmentHandler) CanSubmitToBlockchain(ctx context. // The source user account is a Code account, so we must validate it exists on // the blockchain prior to sending funds from it. - isSourceAccountCreated, err := isTokenAccountOnBlockchain(ctx, h.data, fulfillmentRecord.Source) + isSourceAccountCreated, err := isAccountInitialized(ctx, h.data, fulfillmentRecord.Source) if err != nil { return false, err } else if !isSourceAccountCreated { @@ -406,7 +443,7 @@ func (h *NoPrivacyWithdrawFulfillmentHandler) CanSubmitToBlockchain(ctx context. // The destination user account might be a Code account or external wallet, so we // must validate it exists on the blockchain prior to send funds to it. - isDestinationAccountCreated, err := isTokenAccountOnBlockchain(ctx, h.data, *fulfillmentRecord.Destination) + isDestinationAccountCreated, err := isAccountInitialized(ctx, h.data, *fulfillmentRecord.Destination) if err != nil { return false, err } else if !isDestinationAccountCreated { @@ -700,7 +737,6 @@ func (h *CloseEmptyTimelockAccountFulfillmentHandler) OnFailure(ctx context.Cont // is dust in the account. // // todo: Implement auto-recovery when we know the account is empty - // todo: Do "something" to indicate the client needs to resign a new transaction return false, nil } @@ -712,7 +748,7 @@ func (h *CloseEmptyTimelockAccountFulfillmentHandler) IsRevoked(ctx context.Cont return false, false, nil } -func isTokenAccountOnBlockchain(ctx context.Context, data code_data.Provider, address string) (bool, error) { +func isAccountInitialized(ctx context.Context, data code_data.Provider, address string) (bool, error) { // Try our cache of Code timelock accounts timelockRecord, err := data.GetTimelockByVault(ctx, address) if err == timelock.ErrTimelockNotFound { diff --git a/pkg/code/async/sequencer/intent_handler.go b/pkg/code/async/sequencer/intent_handler.go index 99ae91a0..29addbf5 100644 --- a/pkg/code/async/sequencer/intent_handler.go +++ b/pkg/code/async/sequencer/intent_handler.go @@ -38,10 +38,6 @@ func (h *OpenAccountsIntentHandler) OnActionUpdated(ctx context.Context, intentI } for _, actionRecord := range actionRecords { - if actionRecord.ActionType != action.OpenAccount { - continue - } - // Intent is failed if at least one OpenAccount action fails if actionRecord.State == action.StateFailed { return markIntentFailed(ctx, h.data, intentId) @@ -73,11 +69,11 @@ func (h *SendPublicPaymentIntentHandler) OnActionUpdated(ctx context.Context, in } actionRecordsToCheck := actionRecords - if len(actionRecords) > 1 { + if len(actionRecords) > 2 { // Do not include the auto-return action, which is a different server-side // initiated intent using the final action here. // - // todo: Assumes > 1 case is just remote send + // todo: Assumes > 2 case is just remote send, but saves a DB call actionRecordsToCheck = actionRecordsToCheck[:len(actionRecordsToCheck)-1] } diff --git a/pkg/code/common/account.go b/pkg/code/common/account.go index 44d18093..735cc074 100644 --- a/pkg/code/common/account.go +++ b/pkg/code/common/account.go @@ -451,7 +451,7 @@ func ValidateExternalTokenAccount(ctx context.Context, data code_data.Provider, default: // Unfortunate if Solana is down, but this only impacts withdraw flows, // and we need to guarantee this isn't going to something that's not - // a Kin token acocunt. + // a core mint token acocunt. return false, "", err } } diff --git a/pkg/code/common/subsidizer.go b/pkg/code/common/subsidizer.go index 6756eb30..3d3770b5 100644 --- a/pkg/code/common/subsidizer.go +++ b/pkg/code/common/subsidizer.go @@ -30,18 +30,15 @@ const ( ) // todo: doesn't consider external deposits +// todo: need a better system given fees are dynamic, we'll consider the worst case for each fulfillment type to be safe var ( - // This doesn't account for recovery of rent, which implies some fulfillments - // actually have negative fees. We often need to think about "in flight" costs - // and SOL balances for our subsidizer, so we exclude rent recovery which - // ensures our estimates are always on the conservative side of things. lamportsByFulfillment = map[fulfillment.Type]uint64{ - fulfillment.InitializeLockedTimelockAccount: 5000, // 0.000005 SOL (5000 lamports per signature) - fulfillment.NoPrivacyTransferWithAuthority: 5000, // 0.000005 SOL (5000 lamports per signature) - fulfillment.NoPrivacyWithdraw: 5000, // 0.000005 SOL (5000 lamports per signature) - fulfillment.CloseEmptyTimelockAccount: 5000, // 0.000005 SOL (5000 lamports per signature) + fulfillment.InitializeLockedTimelockAccount: 5050, + fulfillment.NoPrivacyTransferWithAuthority: 203928 + 5125, + fulfillment.NoPrivacyWithdraw: 5100, + fulfillment.CloseEmptyTimelockAccount: 5100, } - lamportsPerCreateNonceAccount uint64 = 1450000 // 0.00145 SOL + lamportsPerCreateNonceAccount uint64 = 1450000 ) var ( diff --git a/pkg/code/data/action/action.go b/pkg/code/data/action/action.go index 1e355641..1ac47f77 100644 --- a/pkg/code/data/action/action.go +++ b/pkg/code/data/action/action.go @@ -4,6 +4,8 @@ import ( "errors" "time" + transactionpb "github.com/code-payments/code-protobuf-api/generated/go/transaction/v2" + "github.com/code-payments/code-server/pkg/code/data/intent" "github.com/code-payments/code-server/pkg/pointer" ) @@ -49,7 +51,7 @@ type Record struct { Source string // Source token account involved Destination *string // Destination token account involved, when it makes sense - // Kin quark amount involved, when it makes sense. This must be set for actions + // Core mint quark amount involved, when it makes sense. This must be set for actions // that make balance changes across Code accounts! For deferred actions that are // initially in the unknown state, the balance may be nil and updated at a later // time. Store implementations will enforce which actions will allow quantity updates. @@ -60,6 +62,8 @@ type Record struct { // use cases before forming a firm opinion. Quantity *uint64 + FeeType *transactionpb.FeePaymentAction_FeeType + State State CreatedAt time.Time @@ -92,6 +96,10 @@ func (r *Record) Validate() error { return errors.New("quantity is required when set") } + if r.FeeType != nil && *r.FeeType == transactionpb.FeePaymentAction_UNKNOWN { + return errors.New("fee type is required when set") + } + return nil } @@ -109,6 +117,8 @@ func (r *Record) Clone() Record { Destination: pointer.StringCopy(r.Destination), Quantity: pointer.Uint64Copy(r.Quantity), + FeeType: (*transactionpb.FeePaymentAction_FeeType)(pointer.Int32Copy((*int32)(r.FeeType))), + State: r.State, CreatedAt: r.CreatedAt, @@ -125,8 +135,10 @@ func (r *Record) CopyTo(dst *Record) { dst.ActionType = r.ActionType dst.Source = r.Source - dst.Destination = r.Destination - dst.Quantity = r.Quantity + dst.Destination = pointer.StringCopy(r.Destination) + dst.Quantity = pointer.Uint64Copy(r.Quantity) + + dst.FeeType = (*transactionpb.FeePaymentAction_FeeType)(pointer.Int32Copy((*int32)(r.FeeType))) dst.State = r.State diff --git a/pkg/code/data/action/memory/store.go b/pkg/code/data/action/memory/store.go index 29223cca..c4b3e3a9 100644 --- a/pkg/code/data/action/memory/store.go +++ b/pkg/code/data/action/memory/store.go @@ -6,6 +6,8 @@ import ( "sync" "time" + transactionpb "github.com/code-payments/code-protobuf-api/generated/go/transaction/v2" + "github.com/code-payments/code-server/pkg/code/data/action" "github.com/code-payments/code-server/pkg/code/data/intent" "github.com/code-payments/code-server/pkg/pointer" @@ -111,6 +113,22 @@ func (s *store) filterByActionType(items []*action.Record, want action.Type) []* return res } +func (s *store) filterByFeeType(items []*action.Record, want transactionpb.FeePaymentAction_FeeType) []*action.Record { + var res []*action.Record + + for _, item := range items { + if item.FeeType == nil { + continue + } + + if *item.FeeType == want { + res = append(res, item) + } + } + + return res +} + func (s *store) filterByState(items []*action.Record, include bool, states ...action.State) []*action.Record { var res []*action.Record @@ -295,6 +313,17 @@ func (s *store) GetGiftCardAutoReturnAction(ctx context.Context, giftCardVault s return &cloned, nil } +// CountFeeActions implements action.store.CountFeeActions +func (s *store) CountFeeActions(ctx context.Context, intent string, feeType transactionpb.FeePaymentAction_FeeType) (uint64, error) { + s.mu.Lock() + defer s.mu.Unlock() + + items := s.findByIntent(intent) + items = s.filterByFeeType(items, feeType) + items = s.filterByState(items, false, action.StateRevoked) + return uint64(len(items)), nil +} + func (s *store) getNetBalance(account string) int64 { var res int64 diff --git a/pkg/code/data/action/postgres/model.go b/pkg/code/data/action/postgres/model.go index fe1cc96a..b49a0930 100644 --- a/pkg/code/data/action/postgres/model.go +++ b/pkg/code/data/action/postgres/model.go @@ -9,9 +9,12 @@ import ( "github.com/jmoiron/sqlx" + transactionpb "github.com/code-payments/code-protobuf-api/generated/go/transaction/v2" + "github.com/code-payments/code-server/pkg/code/data/action" "github.com/code-payments/code-server/pkg/code/data/intent" pgutil "github.com/code-payments/code-server/pkg/database/postgres" + "github.com/code-payments/code-server/pkg/pointer" ) const ( @@ -27,6 +30,7 @@ type model struct { Source string `db:"source"` Destination sql.NullString `db:"destination"` Quantity sql.NullInt64 `db:"quantity"` + FeeType sql.NullInt32 `db:"fee_type"` State uint `db:"state"` CreatedAt time.Time `db:"created_at"` } @@ -48,6 +52,12 @@ func toModel(obj *action.Record) (*model, error) { quantity.Int64 = int64(*obj.Quantity) } + var feeType sql.NullInt32 + if obj.FeeType != nil { + feeType.Valid = true + feeType.Int32 = int32(*obj.FeeType) + } + return &model{ Intent: obj.Intent, IntentType: uint(obj.IntentType), @@ -56,23 +66,13 @@ func toModel(obj *action.Record) (*model, error) { Source: obj.Source, Destination: destination, Quantity: quantity, + FeeType: feeType, State: uint(obj.State), CreatedAt: obj.CreatedAt, }, nil } func fromModel(obj *model) *action.Record { - var destination *string - if obj.Destination.Valid { - destination = &obj.Destination.String - } - - var quantity *uint64 - if obj.Quantity.Valid { - value := uint64(obj.Quantity.Int64) - quantity = &value - } - return &action.Record{ Id: uint64(obj.Id.Int64), Intent: obj.Intent, @@ -80,8 +80,9 @@ func fromModel(obj *model) *action.Record { ActionId: uint32(obj.ActionId), ActionType: action.Type(obj.ActionType), Source: obj.Source, - Destination: destination, - Quantity: quantity, + Destination: pointer.StringIfValid(obj.Destination.Valid, obj.Destination.String), + Quantity: pointer.Uint64IfValid(obj.Quantity.Valid, uint64(obj.Quantity.Int64)), + FeeType: (*transactionpb.FeePaymentAction_FeeType)(pointer.Int32IfValid(obj.FeeType.Valid, obj.FeeType.Int32)), State: action.State(obj.State), CreatedAt: obj.CreatedAt, } @@ -104,7 +105,7 @@ func (m *model) dbUpdate(ctx context.Context, db *sqlx.DB) error { query := fmt.Sprintf(`UPDATE `+tableName+` SET state = $3%s WHERE intent = $1 AND action_id = $2 - RETURNING id, intent, intent_type, action_id, action_type, source, destination, quantity, state, created_at + RETURNING id, intent, intent_type, action_id, action_type, source, destination, quantity, fee_type, state, created_at `, quantityUpdateStmt) err := tx.QueryRowxContext( @@ -123,7 +124,7 @@ func (m *model) dbUpdate(ctx context.Context, db *sqlx.DB) error { func dbPutAllInTx(ctx context.Context, tx *sqlx.Tx, models []*model) ([]*model, error) { var res []*model - query := `INSERT INTO ` + tableName + ` (intent, intent_type, action_id, action_type, source, destination, quantity, state, created_at) VALUES ` + query := `INSERT INTO ` + tableName + ` (intent, intent_type, action_id, action_type, source, destination, quantity, fee_type, state, created_at) VALUES ` var parameters []interface{} for i, model := range models { @@ -133,8 +134,8 @@ func dbPutAllInTx(ctx context.Context, tx *sqlx.Tx, models []*model) ([]*model, baseIndex := len(parameters) query += fmt.Sprintf( - `($%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d)`, - baseIndex+1, baseIndex+2, baseIndex+3, baseIndex+4, baseIndex+5, baseIndex+6, baseIndex+7, baseIndex+8, baseIndex+9, + `($%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d)`, + baseIndex+1, baseIndex+2, baseIndex+3, baseIndex+4, baseIndex+5, baseIndex+6, baseIndex+7, baseIndex+8, baseIndex+9, baseIndex+10, ) if i != len(models)-1 { @@ -150,12 +151,13 @@ func dbPutAllInTx(ctx context.Context, tx *sqlx.Tx, models []*model) ([]*model, model.Source, model.Destination, model.Quantity, + model.FeeType, model.State, model.CreatedAt, ) } - query += ` RETURNING id, intent, intent_type, action_id, action_type, source, destination, quantity, state, created_at` + query += ` RETURNING id, intent, intent_type, action_id, action_type, source, destination, quantity, fee_type, state, created_at` err := tx.SelectContext( ctx, @@ -173,7 +175,7 @@ func dbPutAllInTx(ctx context.Context, tx *sqlx.Tx, models []*model) ([]*model, func dbGetById(ctx context.Context, db *sqlx.DB, intent string, actionId uint32) (*model, error) { res := &model{} - query := `SELECT id, intent, intent_type, action_id, action_type, source, destination, quantity, state, created_at + query := `SELECT id, intent, intent_type, action_id, action_type, source, destination, quantity, fee_type, state, created_at FROM ` + tableName + ` WHERE intent = $1 AND action_id = $2 LIMIT 1` @@ -188,7 +190,7 @@ func dbGetById(ctx context.Context, db *sqlx.DB, intent string, actionId uint32) func dbGetAllByIntent(ctx context.Context, db *sqlx.DB, intent string) ([]*model, error) { res := []*model{} - query := `SELECT id, intent, intent_type, action_id, action_type, source, destination, quantity, state, created_at + query := `SELECT id, intent, intent_type, action_id, action_type, source, destination, quantity, fee_type, state, created_at FROM ` + tableName + ` WHERE intent = $1 ORDER BY action_id ASC` @@ -208,7 +210,7 @@ func dbGetAllByIntent(ctx context.Context, db *sqlx.DB, intent string) ([]*model func dbGetAllByAddress(ctx context.Context, db *sqlx.DB, address string) ([]*model, error) { res := []*model{} - query := `SELECT id, intent, intent_type, action_id, action_type, source, destination, quantity, state, created_at + query := `SELECT id, intent, intent_type, action_id, action_type, source, destination, quantity, fee_type, state, created_at FROM ` + tableName + ` WHERE source = $1 OR destination = $1` @@ -249,6 +251,7 @@ func dbGetNetBalance(ctx context.Context, db *sqlx.DB, account string) (int64, e } return res.Int64, nil } + func dbGetNetBalanceBatch(ctx context.Context, db *sqlx.DB, accounts ...string) (map[string]int64, error) { if len(accounts) == 0 { return make(map[string]int64), nil @@ -297,7 +300,7 @@ func dbGetNetBalanceBatch(ctx context.Context, db *sqlx.DB, accounts ...string) func dbGetGiftCardClaimedAction(ctx context.Context, db *sqlx.DB, giftCardVault string) (*model, error) { res := []*model{} - query := `SELECT id, intent, intent_type, action_id, action_type, source, destination, quantity, state, created_at + query := `SELECT id, intent, intent_type, action_id, action_type, source, destination, quantity, fee_type, state, created_at FROM ` + tableName + ` WHERE source = $1 AND action_type = $2 AND intent_type = $3 AND state != $4 LIMIT 2` @@ -327,7 +330,7 @@ func dbGetGiftCardClaimedAction(ctx context.Context, db *sqlx.DB, giftCardVault func dbGetGiftCardAutoReturnAction(ctx context.Context, db *sqlx.DB, giftCardVault string) (*model, error) { res := []*model{} - query := `SELECT id, intent, intent_type, action_id, action_type, source, destination, quantity, state, created_at + query := `SELECT id, intent, intent_type, action_id, action_type, source, destination, quantity, fee_type, state, created_at FROM ` + tableName + ` WHERE source = $1 AND action_type = $2 AND intent_type = $3 AND state != $4 LIMIT 2` @@ -353,3 +356,22 @@ func dbGetGiftCardAutoReturnAction(ctx context.Context, db *sqlx.DB, giftCardVau return res[0], nil } + +func dbCountFeeActions(ctx context.Context, db *sqlx.DB, intent string, feeType transactionpb.FeePaymentAction_FeeType) (uint64, error) { + var res uint64 + + query := `SELECT COUNT(*) FROM ` + tableName + ` WHERE intent = $1 AND fee_type = $2 AND state != $3` + + err := db.GetContext( + ctx, + &res, + query, + intent, + feeType, + action.StateRevoked, + ) + if err != nil { + return 0, err + } + return res, nil +} diff --git a/pkg/code/data/action/postgres/store.go b/pkg/code/data/action/postgres/store.go index c841300e..c786d493 100644 --- a/pkg/code/data/action/postgres/store.go +++ b/pkg/code/data/action/postgres/store.go @@ -8,6 +8,7 @@ import ( "github.com/jmoiron/sqlx" + transactionpb "github.com/code-payments/code-protobuf-api/generated/go/transaction/v2" "github.com/code-payments/code-server/pkg/code/data/action" pgutil "github.com/code-payments/code-server/pkg/database/postgres" ) @@ -156,3 +157,7 @@ func (s *store) GetGiftCardAutoReturnAction(ctx context.Context, giftCardVault s } return fromModel(model), nil } + +func (s *store) CountFeeActions(ctx context.Context, intent string, feeType transactionpb.FeePaymentAction_FeeType) (uint64, error) { + return dbCountFeeActions(ctx, s.db, intent, feeType) +} diff --git a/pkg/code/data/action/postgres/store_test.go b/pkg/code/data/action/postgres/store_test.go index 86758484..927ba4e1 100644 --- a/pkg/code/data/action/postgres/store_test.go +++ b/pkg/code/data/action/postgres/store_test.go @@ -32,6 +32,8 @@ const ( destination TEXT NULL, quantity INTEGER NULL, + fee_type INTEGER NULL, + state INTEGER NOT NULL, created_at timestamp with time zone NOT NULL, diff --git a/pkg/code/data/action/store.go b/pkg/code/data/action/store.go index 7775d78a..0d7d1bff 100644 --- a/pkg/code/data/action/store.go +++ b/pkg/code/data/action/store.go @@ -3,6 +3,8 @@ package action import ( "context" "errors" + + transactionpb "github.com/code-payments/code-protobuf-api/generated/go/transaction/v2" ) var ( @@ -29,7 +31,7 @@ type Store interface { // todo: Support paging for accounts that might have many actions when a use case emerges GetAllByAddress(ctx context.Context, address string) ([]*Record, error) - // GetNetBalance gets the net balance of Kin in quarks after appying actions + // GetNetBalance gets the net balance of core mint quarks after appying actions // that operate on balances. GetNetBalance(ctx context.Context, account string) (int64, error) @@ -47,4 +49,7 @@ type Store interface { // as a source. This DB cannot validate the account type, so that must be done // prior to making this call elsewhere. GetGiftCardAutoReturnAction(ctx context.Context, giftCardVault string) (*Record, error) + + // CountFeeActions counts the number of fee actions of the specified type for an intent + CountFeeActions(ctx context.Context, intent string, feeType transactionpb.FeePaymentAction_FeeType) (uint64, error) } diff --git a/pkg/code/data/action/tests/tests.go b/pkg/code/data/action/tests/tests.go index 73665c40..d1583bac 100644 --- a/pkg/code/data/action/tests/tests.go +++ b/pkg/code/data/action/tests/tests.go @@ -10,6 +10,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + transactionpb "github.com/code-payments/code-protobuf-api/generated/go/transaction/v2" + "github.com/code-payments/code-server/pkg/code/data/action" "github.com/code-payments/code-server/pkg/code/data/intent" "github.com/code-payments/code-server/pkg/pointer" @@ -24,6 +26,7 @@ func RunTests(t *testing.T, s action.Store, teardown func()) { testGetNetBalance, testGetGiftCardClaimedAction, testGetGiftCardAutoReturnAction, + testCountCountFeeActions, } { tf(t, s) teardown() @@ -47,6 +50,8 @@ func testRoundTrip(t *testing.T, s action.Store) { Destination: pointer.String("destination"), Quantity: nil, + FeeType: (*transactionpb.FeePaymentAction_FeeType)(pointer.Int32((int32)(transactionpb.FeePaymentAction_CREATE_ON_SEND_WITHDRAWAL))), + State: action.StateConfirmed, } @@ -94,12 +99,13 @@ func testBatchPut(t *testing.T, s action.Store) { for i := 0; i < 1000; i++ { actionRecord := &action.Record{ Intent: fmt.Sprintf("intent%d", i), - IntentType: intent.SendPrivatePayment, + IntentType: intent.SendPublicPayment, ActionId: uint32(i), - ActionType: action.PrivateTransfer, + ActionType: action.NoPrivacyTransfer, Source: fmt.Sprintf("source%d", i), Destination: pointer.String(fmt.Sprintf("destination%d", i)), Quantity: pointer.Uint64(uint64(i + 1)), + FeeType: (*transactionpb.FeePaymentAction_FeeType)(pointer.Int32((int32)(transactionpb.FeePaymentAction_CREATE_ON_SEND_WITHDRAWAL))), CreatedAt: time.Now().Add(time.Duration(i) * time.Second), } cloned := actionRecord.Clone() @@ -371,6 +377,44 @@ func testGetGiftCardAutoReturnAction(t *testing.T, s action.Store) { }) } +func testCountCountFeeActions(t *testing.T, s action.Store) { + t.Run("testCountCountFeeActions", func(t *testing.T) { + ctx := context.Background() + + feeType := transactionpb.FeePaymentAction_CREATE_ON_SEND_WITHDRAWAL + records := []*action.Record{ + {Intent: "i1", IntentType: intent.SendPublicPayment, ActionId: 0, ActionType: action.NoPrivacyTransfer, Source: "a1", Destination: pointer.String("destination"), FeeType: &feeType, State: action.StateUnknown}, + {Intent: "i1", IntentType: intent.SendPublicPayment, ActionId: 1, ActionType: action.NoPrivacyTransfer, Source: "a1", Destination: pointer.String("destination"), FeeType: &feeType, State: action.StatePending}, + {Intent: "i1", IntentType: intent.SendPublicPayment, ActionId: 2, ActionType: action.NoPrivacyTransfer, Source: "a1", Destination: pointer.String("destination"), FeeType: &feeType, State: action.StateFailed}, + {Intent: "i1", IntentType: intent.SendPublicPayment, ActionId: 3, ActionType: action.NoPrivacyTransfer, Source: "a1", Destination: pointer.String("destination"), FeeType: &feeType, State: action.StateConfirmed}, + {Intent: "i1", IntentType: intent.SendPublicPayment, ActionId: 4, ActionType: action.NoPrivacyTransfer, Source: "a1", Destination: pointer.String("destination"), State: action.StateConfirmed}, + {Intent: "i1", IntentType: intent.SendPublicPayment, ActionId: 5, ActionType: action.NoPrivacyTransfer, Source: "a1", Destination: pointer.String("destination"), State: action.StateConfirmed}, + + {Intent: "i2", IntentType: intent.SendPublicPayment, ActionId: 1, ActionType: action.NoPrivacyTransfer, Source: "a2", Destination: pointer.String("destination"), FeeType: &feeType, State: action.StateRevoked}, + {Intent: "i2", IntentType: intent.SendPublicPayment, ActionId: 0, ActionType: action.NoPrivacyTransfer, Source: "a2", Destination: pointer.String("destination"), FeeType: &feeType, State: action.StatePending}, + {Intent: "i2", IntentType: intent.SendPublicPayment, ActionId: 2, ActionType: action.NoPrivacyTransfer, Source: "a2", Destination: pointer.String("destination"), State: action.StatePending}, + } + + require.NoError(t, s.PutAll(ctx, records...)) + + count, err := s.CountFeeActions(ctx, "i1", transactionpb.FeePaymentAction_CREATE_ON_SEND_WITHDRAWAL) + require.NoError(t, err) + assert.EqualValues(t, 4, count) + + count, err = s.CountFeeActions(ctx, "i2", transactionpb.FeePaymentAction_CREATE_ON_SEND_WITHDRAWAL) + require.NoError(t, err) + assert.EqualValues(t, 1, count) + + count, err = s.CountFeeActions(ctx, "i3", transactionpb.FeePaymentAction_CREATE_ON_SEND_WITHDRAWAL) + require.NoError(t, err) + assert.EqualValues(t, 0, count) + + count, err = s.CountFeeActions(ctx, "i1", transactionpb.FeePaymentAction_UNKNOWN) + require.NoError(t, err) + assert.EqualValues(t, 0, count) + }) +} + func assertEquivalentRecords(t *testing.T, obj1, obj2 *action.Record) { assert.Equal(t, obj1.Intent, obj2.Intent) assert.Equal(t, obj1.IntentType, obj2.IntentType) @@ -382,5 +426,7 @@ func assertEquivalentRecords(t *testing.T, obj1, obj2 *action.Record) { assert.EqualValues(t, obj1.Destination, obj2.Destination) assert.EqualValues(t, obj1.Quantity, obj2.Quantity) + assert.EqualValues(t, obj1.FeeType, obj2.FeeType) + assert.Equal(t, obj1.State, obj2.State) } diff --git a/pkg/code/data/deposit/memory/store.go b/pkg/code/data/deposit/memory/store.go index 8799b08e..6c844845 100644 --- a/pkg/code/data/deposit/memory/store.go +++ b/pkg/code/data/deposit/memory/store.go @@ -70,7 +70,7 @@ func (s *store) GetQuarkAmount(_ context.Context, account string) (uint64, error s.mu.Lock() defer s.mu.Unlock() - return s.getKinAmount(account), nil + return s.getQuarkAmount(account), nil } // GetQuarkAmountBatch implements deposit.Store.GetQuarkAmountBatch @@ -80,7 +80,7 @@ func (s *store) GetQuarkAmountBatch(_ context.Context, accounts ...string) (map[ res := make(map[string]uint64) for _, account := range accounts { - res[account] = s.getKinAmount(account) + res[account] = s.getQuarkAmount(account) } return res, nil } @@ -93,7 +93,7 @@ func (s *store) GetUsdAmount(ctx context.Context, account string) (float64, erro return s.getUsdAmount(account), nil } -func (s *store) getKinAmount(account string) uint64 { +func (s *store) getQuarkAmount(account string) uint64 { items := s.findByDestination(account) items = s.filterFinalized(items) return s.sumAmounts(items) diff --git a/pkg/code/data/external_test.go b/pkg/code/data/external_test.go index 93c376ff..753ce9a1 100644 --- a/pkg/code/data/external_test.go +++ b/pkg/code/data/external_test.go @@ -30,7 +30,7 @@ func TestComputeAllExchangeRates_HappyPath(t *testing.T) { } func TestComputeAllExchangeRates_UsdRateMissing(t *testing.T) { - kinRates := map[string]float64{ + coreMintRates := map[string]float64{ "cad": 1.0, } @@ -41,6 +41,6 @@ func TestComputeAllExchangeRates_UsdRateMissing(t *testing.T) { "aud": 0.66, } - _, err := computeAllExchangeRates(kinRates, usdRates) + _, err := computeAllExchangeRates(coreMintRates, usdRates) assert.Error(t, err) } diff --git a/pkg/code/data/internal.go b/pkg/code/data/internal.go index 39eaaf2d..4189a463 100644 --- a/pkg/code/data/internal.go +++ b/pkg/code/data/internal.go @@ -18,6 +18,7 @@ import ( timelock_token "github.com/code-payments/code-server/pkg/solana/timelock/v1" commonpb "github.com/code-payments/code-protobuf-api/generated/go/common/v1" + transactionpb "github.com/code-payments/code-protobuf-api/generated/go/transaction/v2" "github.com/code-payments/code-server/pkg/code/data/account" "github.com/code-payments/code-server/pkg/code/data/action" @@ -102,6 +103,25 @@ type DatabaseData interface { GetAccountInfoCountRequiringDepositSync(ctx context.Context) (uint64, error) GetAccountInfoCountRequiringAutoReturnCheck(ctx context.Context) (uint64, error) + // Actions + // -------------------------------------------------------------------------------- + PutAllActions(ctx context.Context, records ...*action.Record) error + UpdateAction(ctx context.Context, record *action.Record) error + GetActionById(ctx context.Context, intent string, actionId uint32) (*action.Record, error) + GetAllActionsByIntent(ctx context.Context, intent string) ([]*action.Record, error) + GetAllActionsByAddress(ctx context.Context, address string) ([]*action.Record, error) + GetNetBalanceFromActions(ctx context.Context, address string) (int64, error) + GetNetBalanceFromActionsBatch(ctx context.Context, accounts ...string) (map[string]int64, error) + GetGiftCardClaimedAction(ctx context.Context, giftCardVault string) (*action.Record, error) + GetGiftCardAutoReturnAction(ctx context.Context, giftCardVault string) (*action.Record, error) + CountFeeActions(ctx context.Context, intent string, feeType transactionpb.FeePaymentAction_FeeType) (uint64, error) + HasFeeAction(ctx context.Context, intent string, feeType transactionpb.FeePaymentAction_FeeType) (bool, error) + + // Balance + // -------------------------------------------------------------------------------- + SaveBalanceCheckpoint(ctx context.Context, record *balance.Record) error + GetBalanceCheckpoint(ctx context.Context, account string) (*balance.Record, error) + // Currency // -------------------------------------------------------------------------------- GetExchangeRate(ctx context.Context, code currency_lib.Code, t time.Time) (*currency.ExchangeRateRecord, error) @@ -109,25 +129,28 @@ type DatabaseData interface { GetExchangeRateHistory(ctx context.Context, code currency_lib.Code, opts ...query.Option) ([]*currency.ExchangeRateRecord, error) ImportExchangeRates(ctx context.Context, data *currency.MultiRateRecord) error - // Vault + // CVM RAM // -------------------------------------------------------------------------------- - GetKey(ctx context.Context, public_key string) (*vault.Record, error) - GetKeyCount(ctx context.Context) (uint64, error) - GetKeyCountByState(ctx context.Context, state vault.State) (uint64, error) - GetAllKeysByState(ctx context.Context, state vault.State, opts ...query.Option) ([]*vault.Record, error) - SaveKey(ctx context.Context, record *vault.Record) error + InitializeVmMemory(ctx context.Context, record *cvm_ram.Record) error + FreeVmMemoryByIndex(ctx context.Context, memoryAccount string, index uint16) error + FreeVmMemoryByAddress(ctx context.Context, address string) error + ReserveVmMemory(ctx context.Context, vm string, accountType cvm.VirtualAccountType, address string) (string, uint16, error) - // Nonce + // CVM Storage // -------------------------------------------------------------------------------- - GetNonce(ctx context.Context, address string) (*nonce.Record, error) - GetNonceCount(ctx context.Context, env nonce.Environment, instance string) (uint64, error) - GetNonceCountByState(ctx context.Context, env nonce.Environment, instance string, state nonce.State) (uint64, error) - GetNonceCountByStateAndPurpose(ctx context.Context, env nonce.Environment, instance string, state nonce.State, purpose nonce.Purpose) (uint64, error) - GetAllNonceByState(ctx context.Context, env nonce.Environment, instance string, state nonce.State, opts ...query.Option) ([]*nonce.Record, error) - BatchClaimAvailableNoncesByPurpose(ctx context.Context, env nonce.Environment, instance string, purpose nonce.Purpose, limit int, nodeID string, minExpireAt, maxExpireAt time.Time) ([]*nonce.Record, error) - SaveNonce(ctx context.Context, record *nonce.Record) error + InitializeVmStorage(ctx context.Context, record *cvm_storage.Record) error + FindAnyVmStorageWithAvailableCapacity(ctx context.Context, vm string, purpose cvm_storage.Purpose, minCapacity uint64) (*cvm_storage.Record, error) + ReserveVmStorage(ctx context.Context, vm string, purpose cvm_storage.Purpose, address string) (string, error) + + // Deposits + // -------------------------------------------------------------------------------- + SaveExternalDeposit(ctx context.Context, record *deposit.Record) error + GetExternalDeposit(ctx context.Context, signature, destination string) (*deposit.Record, error) + GetTotalExternalDepositedAmountInQuarks(ctx context.Context, account string) (uint64, error) + GetTotalExternalDepositedAmountInQuarksBatch(ctx context.Context, accounts ...string) (map[string]uint64, error) + GetTotalExternalDepositedAmountInUsd(ctx context.Context, account string) (float64, error) - // Fulfillment + // Fulfillments // -------------------------------------------------------------------------------- GetFulfillmentById(ctx context.Context, id uint64) (*fulfillment.Record, error) GetFulfillmentBySignature(ctx context.Context, signature string) (*fulfillment.Record, error) @@ -155,7 +178,7 @@ type DatabaseData interface { UpdateFulfillment(ctx context.Context, record *fulfillment.Record) error MarkFulfillmentAsActivelyScheduled(ctx context.Context, id uint64) error - // Intent + // Intents // -------------------------------------------------------------------------------- SaveIntent(ctx context.Context, record *intent.Record) error GetIntent(ctx context.Context, intentID string) (*intent.Record, error) @@ -166,22 +189,10 @@ type DatabaseData interface { GetGiftCardClaimedIntent(ctx context.Context, giftCardVault string) (*intent.Record, error) GetTransactedAmountForAntiMoneyLaundering(ctx context.Context, owner string, since time.Time) (uint64, float64, error) - // Action + // Merkle Trees // -------------------------------------------------------------------------------- - PutAllActions(ctx context.Context, records ...*action.Record) error - UpdateAction(ctx context.Context, record *action.Record) error - GetActionById(ctx context.Context, intent string, actionId uint32) (*action.Record, error) - GetAllActionsByIntent(ctx context.Context, intent string) ([]*action.Record, error) - GetAllActionsByAddress(ctx context.Context, address string) ([]*action.Record, error) - GetNetBalanceFromActions(ctx context.Context, address string) (int64, error) - GetNetBalanceFromActionsBatch(ctx context.Context, accounts ...string) (map[string]int64, error) - GetGiftCardClaimedAction(ctx context.Context, giftCardVault string) (*action.Record, error) - GetGiftCardAutoReturnAction(ctx context.Context, giftCardVault string) (*action.Record, error) - - // Transaction - // -------------------------------------------------------------------------------- - GetTransaction(ctx context.Context, sig string) (*transaction.Record, error) - SaveTransaction(ctx context.Context, record *transaction.Record) error + InitializeNewMerkleTree(ctx context.Context, name string, levels uint8, seeds []merkletree.Seed, readOnly bool) (*merkletree.MerkleTree, error) + LoadExistingMerkleTree(ctx context.Context, name string, readOnly bool) (*merkletree.MerkleTree, error) // Messaging // -------------------------------------------------------------------------------- @@ -189,28 +200,15 @@ type DatabaseData interface { GetMessages(ctx context.Context, account string) ([]*messaging.Record, error) DeleteMessage(ctx context.Context, account string, messageID uuid.UUID) error - // Timelock - // -------------------------------------------------------------------------------- - SaveTimelock(ctx context.Context, record *timelock.Record) error - GetTimelockByAddress(ctx context.Context, address string) (*timelock.Record, error) - GetTimelockByVault(ctx context.Context, vault string) (*timelock.Record, error) - GetTimelockByDepositPda(ctx context.Context, depositPda string) (*timelock.Record, error) - GetTimelockByVaultBatch(ctx context.Context, vaults ...string) (map[string]*timelock.Record, error) - GetAllTimelocksByState(ctx context.Context, state timelock_token.TimelockState, opts ...query.Option) ([]*timelock.Record, error) - GetTimelockCountByState(ctx context.Context, state timelock_token.TimelockState) (uint64, error) - - // Merkle Tree - // -------------------------------------------------------------------------------- - InitializeNewMerkleTree(ctx context.Context, name string, levels uint8, seeds []merkletree.Seed, readOnly bool) (*merkletree.MerkleTree, error) - LoadExistingMerkleTree(ctx context.Context, name string, readOnly bool) (*merkletree.MerkleTree, error) - - // External Deposits + // Nonces // -------------------------------------------------------------------------------- - SaveExternalDeposit(ctx context.Context, record *deposit.Record) error - GetExternalDeposit(ctx context.Context, signature, destination string) (*deposit.Record, error) - GetTotalExternalDepositedAmountInQuarks(ctx context.Context, account string) (uint64, error) - GetTotalExternalDepositedAmountInQuarksBatch(ctx context.Context, accounts ...string) (map[string]uint64, error) - GetTotalExternalDepositedAmountInUsd(ctx context.Context, account string) (float64, error) + GetNonce(ctx context.Context, address string) (*nonce.Record, error) + GetNonceCount(ctx context.Context, env nonce.Environment, instance string) (uint64, error) + GetNonceCountByState(ctx context.Context, env nonce.Environment, instance string, state nonce.State) (uint64, error) + GetNonceCountByStateAndPurpose(ctx context.Context, env nonce.Environment, instance string, state nonce.State, purpose nonce.Purpose) (uint64, error) + GetAllNonceByState(ctx context.Context, env nonce.Environment, instance string, state nonce.State, opts ...query.Option) ([]*nonce.Record, error) + BatchClaimAvailableNoncesByPurpose(ctx context.Context, env nonce.Environment, instance string, purpose nonce.Purpose, limit int, nodeID string, minExpireAt, maxExpireAt time.Time) ([]*nonce.Record, error) + SaveNonce(ctx context.Context, record *nonce.Record) error // Rendezvous // -------------------------------------------------------------------------------- @@ -219,23 +217,28 @@ type DatabaseData interface { DeleteRendezvous(ctx context.Context, key, address string) error GetRendezvous(ctx context.Context, key string) (*rendezvous.Record, error) - // Balance + // Timelocks // -------------------------------------------------------------------------------- - SaveBalanceCheckpoint(ctx context.Context, record *balance.Record) error - GetBalanceCheckpoint(ctx context.Context, account string) (*balance.Record, error) + SaveTimelock(ctx context.Context, record *timelock.Record) error + GetTimelockByAddress(ctx context.Context, address string) (*timelock.Record, error) + GetTimelockByVault(ctx context.Context, vault string) (*timelock.Record, error) + GetTimelockByDepositPda(ctx context.Context, depositPda string) (*timelock.Record, error) + GetTimelockByVaultBatch(ctx context.Context, vaults ...string) (map[string]*timelock.Record, error) + GetAllTimelocksByState(ctx context.Context, state timelock_token.TimelockState, opts ...query.Option) ([]*timelock.Record, error) + GetTimelockCountByState(ctx context.Context, state timelock_token.TimelockState) (uint64, error) - // CVM RAM + // Transactions // -------------------------------------------------------------------------------- - InitializeVmMemory(ctx context.Context, record *cvm_ram.Record) error - FreeVmMemoryByIndex(ctx context.Context, memoryAccount string, index uint16) error - FreeVmMemoryByAddress(ctx context.Context, address string) error - ReserveVmMemory(ctx context.Context, vm string, accountType cvm.VirtualAccountType, address string) (string, uint16, error) + GetTransaction(ctx context.Context, sig string) (*transaction.Record, error) + SaveTransaction(ctx context.Context, record *transaction.Record) error - // CVM Storage + // Vault // -------------------------------------------------------------------------------- - InitializeVmStorage(ctx context.Context, record *cvm_storage.Record) error - FindAnyVmStorageWithAvailableCapacity(ctx context.Context, vm string, purpose cvm_storage.Purpose, minCapacity uint64) (*cvm_storage.Record, error) - ReserveVmStorage(ctx context.Context, vm string, purpose cvm_storage.Purpose, address string) (string, error) + GetKey(ctx context.Context, public_key string) (*vault.Record, error) + GetKeyCount(ctx context.Context) (uint64, error) + GetKeyCountByState(ctx context.Context, state vault.State) (uint64, error) + GetAllKeysByState(ctx context.Context, state vault.State, opts ...query.Option) ([]*vault.Record, error) + SaveKey(ctx context.Context, record *vault.Record) error // ExecuteInTx executes fn with a single DB transaction that is scoped to the call. // This enables more complex transactions that can span many calls across the provider. @@ -248,21 +251,21 @@ type DatabaseData interface { type DatabaseProvider struct { accounts account.Store + actions action.Store + balance balance.Store currencies currency.Store - vault vault.Store - nonces nonce.Store + cvmRam cvm_ram.Store + cvmStorage cvm_storage.Store + deposits deposit.Store fulfillments fulfillment.Store intents intent.Store - actions action.Store - transactions transaction.Store + merkleTrees merkletree.Store messages messaging.Store - timelock timelock.Store - merkleTree merkletree.Store - deposits deposit.Store + nonces nonce.Store rendezvous rendezvous.Store - balance balance.Store - cvmRam cvm_ram.Store - cvmStorage cvm_storage.Store + timelocks timelock.Store + transactions transaction.Store + vault vault.Store exchangeCache cache.Cache timelockCache cache.Cache @@ -293,21 +296,21 @@ func NewDatabaseProvider(dbConfig *pg.Config) (DatabaseData, error) { return &DatabaseProvider{ accounts: account_postgres_client.New(db), + actions: action_postgres_client.New(db), + balance: balance_postgres_client.New(db), currencies: currency_postgres_client.New(db), - nonces: nonce_postgres_client.New(db), + cvmRam: cvm_ram_postgres_client.New(db), + cvmStorage: cvm_storage_postgres_client.New(db), + deposits: deposit_postgres_client.New(db), fulfillments: fulfillment_postgres_client.New(db), intents: intent_postgres_client.New(db), - actions: action_postgres_client.New(db), - transactions: transaction_postgres_client.New(db), + merkleTrees: merkletree_postgres_client.New(db), messages: messaging_postgres_client.New(db), - timelock: timelock_postgres_client.New(db), - vault: vault_postgres_client.New(db), - merkleTree: merkletree_postgres_client.New(db), - deposits: deposit_postgres_client.New(db), + nonces: nonce_postgres_client.New(db), rendezvous: rendezvous_postgres_client.New(db), - balance: balance_postgres_client.New(db), - cvmRam: cvm_ram_postgres_client.New(db), - cvmStorage: cvm_storage_postgres_client.New(db), + timelocks: timelock_postgres_client.New(db), + transactions: transaction_postgres_client.New(db), + vault: vault_postgres_client.New(db), exchangeCache: cache.NewCache(maxExchangeRateCacheBudget), timelockCache: cache.NewCache(maxTimelockCacheBudget), @@ -319,21 +322,21 @@ func NewDatabaseProvider(dbConfig *pg.Config) (DatabaseData, error) { func NewTestDatabaseProvider() DatabaseData { return &DatabaseProvider{ accounts: account_memory_client.New(), + actions: action_memory_client.New(), + balance: balance_memory_client.New(), currencies: currency_memory_client.New(), - nonces: nonce_memory_client.New(), + cvmRam: cvm_ram_memory_client.New(), + cvmStorage: cvm_storage_memory_client.New(), + deposits: deposit_memory_client.New(), fulfillments: fulfillment_memory_client.New(), intents: intent_memory_client.New(), - actions: action_memory_client.New(), - transactions: transaction_memory_client.New(), - timelock: timelock_memory_client.New(), - vault: vault_memory_client.New(), - merkleTree: merkletree_memory_client.New(), + merkleTrees: merkletree_memory_client.New(), messages: messaging_memory_client.New(), - deposits: deposit_memory_client.New(), + nonces: nonce_memory_client.New(), rendezvous: rendezvous_memory_client.New(), - balance: balance_memory_client.New(), - cvmRam: cvm_ram_memory_client.New(), - cvmStorage: cvm_storage_memory_client.New(), + timelocks: timelock_memory_client.New(), + transactions: transaction_memory_client.New(), + vault: vault_memory_client.New(), exchangeCache: cache.NewCache(maxExchangeRateCacheBudget), timelockCache: nil, // Shouldn't be used for tests @@ -384,7 +387,56 @@ func (dp *DatabaseProvider) GetAccountInfoCountRequiringAutoReturnCheck(ctx cont return dp.accounts.CountRequiringAutoReturnCheck(ctx) } -// Currency +// Actions +// -------------------------------------------------------------------------------- +func (dp *DatabaseProvider) PutAllActions(ctx context.Context, records ...*action.Record) error { + return dp.actions.PutAll(ctx, records...) +} +func (dp *DatabaseProvider) UpdateAction(ctx context.Context, record *action.Record) error { + return dp.actions.Update(ctx, record) +} +func (dp *DatabaseProvider) GetActionById(ctx context.Context, intent string, actionId uint32) (*action.Record, error) { + return dp.actions.GetById(ctx, intent, actionId) +} +func (dp *DatabaseProvider) GetAllActionsByIntent(ctx context.Context, intent string) ([]*action.Record, error) { + return dp.actions.GetAllByIntent(ctx, intent) +} +func (dp *DatabaseProvider) GetAllActionsByAddress(ctx context.Context, address string) ([]*action.Record, error) { + return dp.actions.GetAllByAddress(ctx, address) +} +func (dp *DatabaseProvider) GetNetBalanceFromActions(ctx context.Context, address string) (int64, error) { + return dp.actions.GetNetBalance(ctx, address) +} +func (dp *DatabaseProvider) GetNetBalanceFromActionsBatch(ctx context.Context, accounts ...string) (map[string]int64, error) { + return dp.actions.GetNetBalanceBatch(ctx, accounts...) +} +func (dp *DatabaseProvider) GetGiftCardClaimedAction(ctx context.Context, giftCardVault string) (*action.Record, error) { + return dp.actions.GetGiftCardClaimedAction(ctx, giftCardVault) +} +func (dp *DatabaseProvider) GetGiftCardAutoReturnAction(ctx context.Context, giftCardVault string) (*action.Record, error) { + return dp.actions.GetGiftCardAutoReturnAction(ctx, giftCardVault) +} +func (dp *DatabaseProvider) CountFeeActions(ctx context.Context, intent string, feeType transactionpb.FeePaymentAction_FeeType) (uint64, error) { + return dp.actions.CountFeeActions(ctx, intent, feeType) +} +func (dp *DatabaseProvider) HasFeeAction(ctx context.Context, intent string, feeType transactionpb.FeePaymentAction_FeeType) (bool, error) { + count, err := dp.actions.CountFeeActions(ctx, intent, feeType) + if err != nil { + return false, err + } + return count > 0, nil +} + +// Balance +// -------------------------------------------------------------------------------- +func (dp *DatabaseProvider) SaveBalanceCheckpoint(ctx context.Context, record *balance.Record) error { + return dp.balance.SaveCheckpoint(ctx, record) +} +func (dp *DatabaseProvider) GetBalanceCheckpoint(ctx context.Context, account string) (*balance.Record, error) { + return dp.balance.GetCheckpoint(ctx, account) +} + +// Currencies // -------------------------------------------------------------------------------- func (dp *DatabaseProvider) GetExchangeRate(ctx context.Context, code currency_lib.Code, t time.Time) (*currency.ExchangeRateRecord, error) { key := fmt.Sprintf("%s:%s", code, t.Truncate(5*time.Minute).Format(time.RFC3339)) @@ -437,62 +489,52 @@ func (dp *DatabaseProvider) ImportExchangeRates(ctx context.Context, data *curre return dp.currencies.Put(ctx, data) } -// Vault +// CVM RAM // -------------------------------------------------------------------------------- -func (dp *DatabaseProvider) GetKey(ctx context.Context, public_key string) (*vault.Record, error) { - return dp.vault.Get(ctx, public_key) -} - -func (dp *DatabaseProvider) GetKeyCount(ctx context.Context) (uint64, error) { - return dp.vault.Count(ctx) +func (dp *DatabaseProvider) InitializeVmMemory(ctx context.Context, record *cvm_ram.Record) error { + return dp.cvmRam.InitializeMemory(ctx, record) } - -func (dp *DatabaseProvider) GetKeyCountByState(ctx context.Context, state vault.State) (uint64, error) { - return dp.vault.CountByState(ctx, state) +func (dp *DatabaseProvider) FreeVmMemoryByIndex(ctx context.Context, memoryAccount string, index uint16) error { + return dp.cvmRam.FreeMemoryByIndex(ctx, memoryAccount, index) } - -func (dp *DatabaseProvider) GetAllKeysByState(ctx context.Context, state vault.State, opts ...query.Option) ([]*vault.Record, error) { - req, err := query.DefaultPaginationHandlerWithLimit(25, opts...) - if err != nil { - return nil, err - } - - return dp.vault.GetAllByState(ctx, state, req.Cursor, req.Limit, req.SortBy) +func (dp *DatabaseProvider) FreeVmMemoryByAddress(ctx context.Context, address string) error { + return dp.cvmRam.FreeMemoryByAddress(ctx, address) } -func (dp *DatabaseProvider) SaveKey(ctx context.Context, record *vault.Record) error { - return dp.vault.Save(ctx, record) +func (dp *DatabaseProvider) ReserveVmMemory(ctx context.Context, vm string, accountType cvm.VirtualAccountType, address string) (string, uint16, error) { + return dp.cvmRam.ReserveMemory(ctx, vm, accountType, address) } -// Nonce +// CVM Storage // -------------------------------------------------------------------------------- -func (dp *DatabaseProvider) GetNonce(ctx context.Context, address string) (*nonce.Record, error) { - return dp.nonces.Get(ctx, address) -} -func (dp *DatabaseProvider) GetNonceCount(ctx context.Context, env nonce.Environment, instance string) (uint64, error) { - return dp.nonces.Count(ctx, env, instance) +func (dp *DatabaseProvider) InitializeVmStorage(ctx context.Context, record *cvm_storage.Record) error { + return dp.cvmStorage.InitializeStorage(ctx, record) } -func (dp *DatabaseProvider) GetNonceCountByState(ctx context.Context, env nonce.Environment, instance string, state nonce.State) (uint64, error) { - return dp.nonces.CountByState(ctx, env, instance, state) +func (dp *DatabaseProvider) FindAnyVmStorageWithAvailableCapacity(ctx context.Context, vm string, purpose cvm_storage.Purpose, minCapacity uint64) (*cvm_storage.Record, error) { + return dp.cvmStorage.FindAnyWithAvailableCapacity(ctx, vm, purpose, minCapacity) } -func (dp *DatabaseProvider) GetNonceCountByStateAndPurpose(ctx context.Context, env nonce.Environment, instance string, state nonce.State, purpose nonce.Purpose) (uint64, error) { - return dp.nonces.CountByStateAndPurpose(ctx, env, instance, state, purpose) +func (dp *DatabaseProvider) ReserveVmStorage(ctx context.Context, vm string, purpose cvm_storage.Purpose, address string) (string, error) { + return dp.cvmStorage.ReserveStorage(ctx, vm, purpose, address) } -func (dp *DatabaseProvider) GetAllNonceByState(ctx context.Context, env nonce.Environment, instance string, state nonce.State, opts ...query.Option) ([]*nonce.Record, error) { - req, err := query.DefaultPaginationHandler(opts...) - if err != nil { - return nil, err - } - return dp.nonces.GetAllByState(ctx, env, instance, state, req.Cursor, req.Limit, req.SortBy) +// Deposits +// -------------------------------------------------------------------------------- +func (dp *DatabaseProvider) SaveExternalDeposit(ctx context.Context, record *deposit.Record) error { + return dp.deposits.Save(ctx, record) } -func (dp *DatabaseProvider) BatchClaimAvailableNoncesByPurpose(ctx context.Context, env nonce.Environment, instance string, purpose nonce.Purpose, limit int, nodeID string, minExpireAt, maxExpireAt time.Time) ([]*nonce.Record, error) { - return dp.nonces.BatchClaimAvailableByPurpose(ctx, env, instance, purpose, limit, nodeID, minExpireAt, maxExpireAt) +func (dp *DatabaseProvider) GetExternalDeposit(ctx context.Context, signature, account string) (*deposit.Record, error) { + return dp.deposits.Get(ctx, signature, account) } -func (dp *DatabaseProvider) SaveNonce(ctx context.Context, record *nonce.Record) error { - return dp.nonces.Save(ctx, record) +func (dp *DatabaseProvider) GetTotalExternalDepositedAmountInQuarks(ctx context.Context, account string) (uint64, error) { + return dp.deposits.GetQuarkAmount(ctx, account) +} +func (dp *DatabaseProvider) GetTotalExternalDepositedAmountInQuarksBatch(ctx context.Context, accounts ...string) (map[string]uint64, error) { + return dp.deposits.GetQuarkAmountBatch(ctx, accounts...) +} +func (dp *DatabaseProvider) GetTotalExternalDepositedAmountInUsd(ctx context.Context, account string) (float64, error) { + return dp.deposits.GetUsdAmount(ctx, account) } -// Fulfillment +// Fulfillments // -------------------------------------------------------------------------------- func (dp *DatabaseProvider) GetFulfillmentById(ctx context.Context, id uint64) (*fulfillment.Record, error) { return dp.fulfillments.GetById(ctx, id) @@ -580,7 +622,7 @@ func (dp *DatabaseProvider) MarkFulfillmentAsActivelyScheduled(ctx context.Conte return dp.fulfillments.MarkAsActivelyScheduled(ctx, id) } -// Intent +// Intents // -------------------------------------------------------------------------------- func (dp *DatabaseProvider) GetIntent(ctx context.Context, intentID string) (*intent.Record, error) { return dp.intents.Get(ctx, intentID) @@ -619,70 +661,82 @@ func (dp *DatabaseProvider) GetTransactedAmountForAntiMoneyLaundering(ctx contex return dp.intents.GetTransactedAmountForAntiMoneyLaundering(ctx, owner, since) } -// Action +// Merkle Trees // -------------------------------------------------------------------------------- -func (dp *DatabaseProvider) PutAllActions(ctx context.Context, records ...*action.Record) error { - return dp.actions.PutAll(ctx, records...) +func (dp *DatabaseProvider) InitializeNewMerkleTree(ctx context.Context, name string, levels uint8, seeds []merkletree.Seed, readOnly bool) (*merkletree.MerkleTree, error) { + return merkletree.InitializeNew(ctx, dp.merkleTrees, name, levels, seeds, readOnly) } -func (dp *DatabaseProvider) UpdateAction(ctx context.Context, record *action.Record) error { - return dp.actions.Update(ctx, record) +func (dp *DatabaseProvider) LoadExistingMerkleTree(ctx context.Context, name string, readOnly bool) (*merkletree.MerkleTree, error) { + return merkletree.LoadExisting(ctx, dp.merkleTrees, name, readOnly) } -func (dp *DatabaseProvider) GetActionById(ctx context.Context, intent string, actionId uint32) (*action.Record, error) { - return dp.actions.GetById(ctx, intent, actionId) + +// Messaging +// -------------------------------------------------------------------------------- +func (dp *DatabaseProvider) CreateMessage(ctx context.Context, record *messaging.Record) error { + return dp.messages.Insert(ctx, record) } -func (dp *DatabaseProvider) GetAllActionsByIntent(ctx context.Context, intent string) ([]*action.Record, error) { - return dp.actions.GetAllByIntent(ctx, intent) +func (dp *DatabaseProvider) GetMessages(ctx context.Context, account string) ([]*messaging.Record, error) { + return dp.messages.Get(ctx, account) } -func (dp *DatabaseProvider) GetAllActionsByAddress(ctx context.Context, address string) ([]*action.Record, error) { - return dp.actions.GetAllByAddress(ctx, address) +func (dp *DatabaseProvider) DeleteMessage(ctx context.Context, account string, messageID uuid.UUID) error { + return dp.messages.Delete(ctx, account, messageID) } -func (dp *DatabaseProvider) GetNetBalanceFromActions(ctx context.Context, address string) (int64, error) { - return dp.actions.GetNetBalance(ctx, address) + +// Nonces +// -------------------------------------------------------------------------------- +func (dp *DatabaseProvider) GetNonce(ctx context.Context, address string) (*nonce.Record, error) { + return dp.nonces.Get(ctx, address) } -func (dp *DatabaseProvider) GetNetBalanceFromActionsBatch(ctx context.Context, accounts ...string) (map[string]int64, error) { - return dp.actions.GetNetBalanceBatch(ctx, accounts...) +func (dp *DatabaseProvider) GetNonceCount(ctx context.Context, env nonce.Environment, instance string) (uint64, error) { + return dp.nonces.Count(ctx, env, instance) } -func (dp *DatabaseProvider) GetGiftCardClaimedAction(ctx context.Context, giftCardVault string) (*action.Record, error) { - return dp.actions.GetGiftCardClaimedAction(ctx, giftCardVault) +func (dp *DatabaseProvider) GetNonceCountByState(ctx context.Context, env nonce.Environment, instance string, state nonce.State) (uint64, error) { + return dp.nonces.CountByState(ctx, env, instance, state) } -func (dp *DatabaseProvider) GetGiftCardAutoReturnAction(ctx context.Context, giftCardVault string) (*action.Record, error) { - return dp.actions.GetGiftCardAutoReturnAction(ctx, giftCardVault) +func (dp *DatabaseProvider) GetNonceCountByStateAndPurpose(ctx context.Context, env nonce.Environment, instance string, state nonce.State, purpose nonce.Purpose) (uint64, error) { + return dp.nonces.CountByStateAndPurpose(ctx, env, instance, state, purpose) } +func (dp *DatabaseProvider) GetAllNonceByState(ctx context.Context, env nonce.Environment, instance string, state nonce.State, opts ...query.Option) ([]*nonce.Record, error) { + req, err := query.DefaultPaginationHandler(opts...) + if err != nil { + return nil, err + } -// Transaction -// -------------------------------------------------------------------------------- -func (dp *DatabaseProvider) GetTransaction(ctx context.Context, sig string) (*transaction.Record, error) { - return dp.transactions.Get(ctx, sig) + return dp.nonces.GetAllByState(ctx, env, instance, state, req.Cursor, req.Limit, req.SortBy) } -func (dp *DatabaseProvider) SaveTransaction(ctx context.Context, record *transaction.Record) error { - return dp.transactions.Put(ctx, record) +func (dp *DatabaseProvider) BatchClaimAvailableNoncesByPurpose(ctx context.Context, env nonce.Environment, instance string, purpose nonce.Purpose, limit int, nodeID string, minExpireAt, maxExpireAt time.Time) ([]*nonce.Record, error) { + return dp.nonces.BatchClaimAvailableByPurpose(ctx, env, instance, purpose, limit, nodeID, minExpireAt, maxExpireAt) +} +func (dp *DatabaseProvider) SaveNonce(ctx context.Context, record *nonce.Record) error { + return dp.nonces.Save(ctx, record) } -// Messaging +// Rendezvous // -------------------------------------------------------------------------------- -func (dp *DatabaseProvider) CreateMessage(ctx context.Context, record *messaging.Record) error { - return dp.messages.Insert(ctx, record) +func (dp *DatabaseProvider) PutRendezvous(ctx context.Context, record *rendezvous.Record) error { + return dp.rendezvous.Put(ctx, record) } - -func (dp *DatabaseProvider) GetMessages(ctx context.Context, account string) ([]*messaging.Record, error) { - return dp.messages.Get(ctx, account) +func (dp *DatabaseProvider) ExtendRendezvousExpiry(ctx context.Context, key, address string, expiry time.Time) error { + return dp.rendezvous.ExtendExpiry(ctx, key, address, expiry) } - -func (dp *DatabaseProvider) DeleteMessage(ctx context.Context, account string, messageID uuid.UUID) error { - return dp.messages.Delete(ctx, account, messageID) +func (dp *DatabaseProvider) DeleteRendezvous(ctx context.Context, key, address string) error { + return dp.rendezvous.Delete(ctx, key, address) +} +func (dp *DatabaseProvider) GetRendezvous(ctx context.Context, key string) (*rendezvous.Record, error) { + return dp.rendezvous.Get(ctx, key) } -// Timelock +// Timelocks // -------------------------------------------------------------------------------- func (dp *DatabaseProvider) SaveTimelock(ctx context.Context, record *timelock.Record) error { - return dp.timelock.Save(ctx, record) + return dp.timelocks.Save(ctx, record) } func (dp *DatabaseProvider) GetTimelockByAddress(ctx context.Context, address string) (*timelock.Record, error) { // todo: add caching if this becomes a heavy hitter like GetByVault - return dp.timelock.GetByAddress(ctx, address) + return dp.timelocks.GetByAddress(ctx, address) } func (dp *DatabaseProvider) GetTimelockByVaultBatch(ctx context.Context, vaults ...string) (map[string]*timelock.Record, error) { - records, err := dp.timelock.GetByVaultBatch(ctx, vaults...) + records, err := dp.timelocks.GetByVaultBatch(ctx, vaults...) if err != nil { return nil, err } @@ -714,7 +768,7 @@ func (dp *DatabaseProvider) GetTimelockByVaultBatch(ctx context.Context, vaults func (dp *DatabaseProvider) GetTimelockByVault(ctx context.Context, vault string) (*timelock.Record, error) { // Don't use a cache if it hasn't been setup (eg. test implementation) if dp.timelockCache == nil { - return dp.timelock.GetByVault(ctx, vault) + return dp.timelocks.GetByVault(ctx, vault) } // todo: Use a cache implementation that has TTLs and refreshes lol @@ -741,7 +795,7 @@ func (dp *DatabaseProvider) GetTimelockByVault(ctx context.Context, vault string } // Cached value is still stale, so fetch from the DB - record, err := dp.timelock.GetByVault(ctx, vault) + record, err := dp.timelocks.GetByVault(ctx, vault) if err == nil { cacheEntry.record = record.Clone() cacheEntry.lastUpdatedAt = time.Now() @@ -750,7 +804,7 @@ func (dp *DatabaseProvider) GetTimelockByVault(ctx context.Context, vault string } // Record not cached, so fetch it and insert the initial cache entry - record, err := dp.timelock.GetByVault(ctx, vault) + record, err := dp.timelocks.GetByVault(ctx, vault) if err == nil { cacheEntry := &timelockCacheEntry{ record: record.Clone(), @@ -761,7 +815,7 @@ func (dp *DatabaseProvider) GetTimelockByVault(ctx context.Context, vault string return record, err } func (dp *DatabaseProvider) GetTimelockByDepositPda(ctx context.Context, depositPda string) (*timelock.Record, error) { - return dp.timelock.GetByDepositPda(ctx, depositPda) + return dp.timelocks.GetByDepositPda(ctx, depositPda) } func (dp *DatabaseProvider) GetAllTimelocksByState(ctx context.Context, state timelock_token.TimelockState, opts ...query.Option) ([]*timelock.Record, error) { req, err := query.DefaultPaginationHandler(opts...) @@ -769,85 +823,40 @@ func (dp *DatabaseProvider) GetAllTimelocksByState(ctx context.Context, state ti return nil, err } - return dp.timelock.GetAllByState(ctx, state, req.Cursor, req.Limit, req.SortBy) + return dp.timelocks.GetAllByState(ctx, state, req.Cursor, req.Limit, req.SortBy) } func (dp *DatabaseProvider) GetTimelockCountByState(ctx context.Context, state timelock_token.TimelockState) (uint64, error) { - return dp.timelock.GetCountByState(ctx, state) + return dp.timelocks.GetCountByState(ctx, state) } -// Merkle Tree -func (dp *DatabaseProvider) InitializeNewMerkleTree(ctx context.Context, name string, levels uint8, seeds []merkletree.Seed, readOnly bool) (*merkletree.MerkleTree, error) { - return merkletree.InitializeNew(ctx, dp.merkleTree, name, levels, seeds, readOnly) -} -func (dp *DatabaseProvider) LoadExistingMerkleTree(ctx context.Context, name string, readOnly bool) (*merkletree.MerkleTree, error) { - return merkletree.LoadExisting(ctx, dp.merkleTree, name, readOnly) -} - -// External Deposits -// -------------------------------------------------------------------------------- -func (dp *DatabaseProvider) SaveExternalDeposit(ctx context.Context, record *deposit.Record) error { - return dp.deposits.Save(ctx, record) -} -func (dp *DatabaseProvider) GetExternalDeposit(ctx context.Context, signature, account string) (*deposit.Record, error) { - return dp.deposits.Get(ctx, signature, account) -} -func (dp *DatabaseProvider) GetTotalExternalDepositedAmountInQuarks(ctx context.Context, account string) (uint64, error) { - return dp.deposits.GetQuarkAmount(ctx, account) -} -func (dp *DatabaseProvider) GetTotalExternalDepositedAmountInQuarksBatch(ctx context.Context, accounts ...string) (map[string]uint64, error) { - return dp.deposits.GetQuarkAmountBatch(ctx, accounts...) -} -func (dp *DatabaseProvider) GetTotalExternalDepositedAmountInUsd(ctx context.Context, account string) (float64, error) { - return dp.deposits.GetUsdAmount(ctx, account) -} - -// Rendezvous +// Transactions // -------------------------------------------------------------------------------- -func (dp *DatabaseProvider) PutRendezvous(ctx context.Context, record *rendezvous.Record) error { - return dp.rendezvous.Put(ctx, record) -} -func (dp *DatabaseProvider) ExtendRendezvousExpiry(ctx context.Context, key, address string, expiry time.Time) error { - return dp.rendezvous.ExtendExpiry(ctx, key, address, expiry) -} -func (dp *DatabaseProvider) DeleteRendezvous(ctx context.Context, key, address string) error { - return dp.rendezvous.Delete(ctx, key, address) -} -func (dp *DatabaseProvider) GetRendezvous(ctx context.Context, key string) (*rendezvous.Record, error) { - return dp.rendezvous.Get(ctx, key) -} - -// Balance -// -------------------------------------------------------------------------------- -func (dp *DatabaseProvider) SaveBalanceCheckpoint(ctx context.Context, record *balance.Record) error { - return dp.balance.SaveCheckpoint(ctx, record) +func (dp *DatabaseProvider) GetTransaction(ctx context.Context, sig string) (*transaction.Record, error) { + return dp.transactions.Get(ctx, sig) } -func (dp *DatabaseProvider) GetBalanceCheckpoint(ctx context.Context, account string) (*balance.Record, error) { - return dp.balance.GetCheckpoint(ctx, account) +func (dp *DatabaseProvider) SaveTransaction(ctx context.Context, record *transaction.Record) error { + return dp.transactions.Put(ctx, record) } -// VM RAM +// Vault // -------------------------------------------------------------------------------- -func (dp *DatabaseProvider) InitializeVmMemory(ctx context.Context, record *cvm_ram.Record) error { - return dp.cvmRam.InitializeMemory(ctx, record) -} -func (dp *DatabaseProvider) FreeVmMemoryByIndex(ctx context.Context, memoryAccount string, index uint16) error { - return dp.cvmRam.FreeMemoryByIndex(ctx, memoryAccount, index) +func (dp *DatabaseProvider) GetKey(ctx context.Context, public_key string) (*vault.Record, error) { + return dp.vault.Get(ctx, public_key) } -func (dp *DatabaseProvider) FreeVmMemoryByAddress(ctx context.Context, address string) error { - return dp.cvmRam.FreeMemoryByAddress(ctx, address) +func (dp *DatabaseProvider) GetKeyCount(ctx context.Context) (uint64, error) { + return dp.vault.Count(ctx) } -func (dp *DatabaseProvider) ReserveVmMemory(ctx context.Context, vm string, accountType cvm.VirtualAccountType, address string) (string, uint16, error) { - return dp.cvmRam.ReserveMemory(ctx, vm, accountType, address) +func (dp *DatabaseProvider) GetKeyCountByState(ctx context.Context, state vault.State) (uint64, error) { + return dp.vault.CountByState(ctx, state) } +func (dp *DatabaseProvider) GetAllKeysByState(ctx context.Context, state vault.State, opts ...query.Option) ([]*vault.Record, error) { + req, err := query.DefaultPaginationHandlerWithLimit(25, opts...) + if err != nil { + return nil, err + } -// VM Storage -// -------------------------------------------------------------------------------- -func (dp *DatabaseProvider) InitializeVmStorage(ctx context.Context, record *cvm_storage.Record) error { - return dp.cvmStorage.InitializeStorage(ctx, record) -} -func (dp *DatabaseProvider) FindAnyVmStorageWithAvailableCapacity(ctx context.Context, vm string, purpose cvm_storage.Purpose, minCapacity uint64) (*cvm_storage.Record, error) { - return dp.cvmStorage.FindAnyWithAvailableCapacity(ctx, vm, purpose, minCapacity) + return dp.vault.GetAllByState(ctx, state, req.Cursor, req.Limit, req.SortBy) } -func (dp *DatabaseProvider) ReserveVmStorage(ctx context.Context, vm string, purpose cvm_storage.Purpose, address string) (string, error) { - return dp.cvmStorage.ReserveStorage(ctx, vm, purpose, address) +func (dp *DatabaseProvider) SaveKey(ctx context.Context, record *vault.Record) error { + return dp.vault.Save(ctx, record) } diff --git a/pkg/code/server/transaction/action_handler.go b/pkg/code/server/transaction/action_handler.go index 4b029ca7..211865cd 100644 --- a/pkg/code/server/transaction/action_handler.go +++ b/pkg/code/server/transaction/action_handler.go @@ -183,11 +183,10 @@ func (h *OpenAccountActionHandler) OnSaveToDB(ctx context.Context) error { } type NoPrivacyTransferActionHandler struct { - source *common.TimelockAccounts - destination *common.Account - amount uint64 - isFeePayment bool // Internally, the mechanics of a fee payment are exactly the same - isCodeFeePayment bool + source *common.TimelockAccounts + destination *common.Account + amount uint64 + feeType transactionpb.FeePaymentAction_FeeType // Internally, the mechanics of a fee payment are exactly the same } func NewNoPrivacyTransferActionHandler(protoAction *transactionpb.NoPrivacyTransferAction) (CreateActionHandler, error) { @@ -207,10 +206,9 @@ func NewNoPrivacyTransferActionHandler(protoAction *transactionpb.NoPrivacyTrans } return &NoPrivacyTransferActionHandler{ - source: source, - destination: destination, - amount: protoAction.Amount, - isFeePayment: false, + source: source, + destination: destination, + amount: protoAction.Amount, }, nil } @@ -225,24 +223,11 @@ func NewFeePaymentActionHandler(protoAction *transactionpb.FeePaymentAction, fee return nil, err } - var destination *common.Account - var isCodeFeePayment bool - if protoAction.Type == transactionpb.FeePaymentAction_CODE { - destination = feeCollector - isCodeFeePayment = true - } else { - destination, err = common.NewAccountFromProto(protoAction.Destination) - if err != nil { - return nil, err - } - } - return &NoPrivacyTransferActionHandler{ - source: source, - destination: destination, - amount: protoAction.Amount, - isFeePayment: true, - isCodeFeePayment: isCodeFeePayment, + source: source, + destination: feeCollector, + amount: protoAction.Amount, + feeType: protoAction.Type, }, nil } @@ -257,21 +242,20 @@ func (h *NoPrivacyTransferActionHandler) PopulateMetadata(actionRecord *action.R actionRecord.Quantity = &h.amount + if h.isFeePayment() { + actionRecord.FeeType = &h.feeType + } + actionRecord.State = action.StatePending return nil } func (h *NoPrivacyTransferActionHandler) GetServerParameter() *transactionpb.ServerParameter { - if h.isFeePayment { - var codeDestination *commonpb.SolanaAccountId - if h.isCodeFeePayment { - codeDestination = h.destination.ToProto() - } - + if h.isFeePayment() { return &transactionpb.ServerParameter{ Type: &transactionpb.ServerParameter_FeePayment{ FeePayment: &transactionpb.FeePaymentServerParameter{ - CodeDestination: codeDestination, + Destination: h.destination.ToProto(), }, }, } @@ -312,7 +296,6 @@ func (h *NoPrivacyTransferActionHandler) GetFulfillmentMetadata( source: h.source.Vault, destination: h.destination, fulfillmentOrderingIndex: 0, - disableActiveScheduling: h.isFeePayment, }, nil default: return nil, errors.New("invalid transaction index") @@ -323,6 +306,10 @@ func (h *NoPrivacyTransferActionHandler) OnSaveToDB(ctx context.Context) error { return nil } +func (h *NoPrivacyTransferActionHandler) isFeePayment() bool { + return h.feeType != transactionpb.FeePaymentAction_UNKNOWN +} + type NoPrivacyWithdrawActionHandler struct { source *common.TimelockAccounts destination *common.Account diff --git a/pkg/code/server/transaction/config.go b/pkg/code/server/transaction/config.go index 5f8aae11..439e4594 100644 --- a/pkg/code/server/transaction/config.go +++ b/pkg/code/server/transaction/config.go @@ -15,9 +15,6 @@ const ( DisableSubmitIntentConfigEnvName = envConfigPrefix + "DISABLE_SUBMIT_INTENT" defaultDisableSubmitIntent = false - DisableBlockchainChecksConfigEnvName = envConfigPrefix + "DISABLE_BLOCKCHAIN_CHECKS" - defaultDisableBlockchainChecks = false - SubmitIntentTimeoutConfigEnvName = envConfigPrefix + "SUBMIT_INTENT_TIMEOUT" defaultSubmitIntentTimeout = 5 * time.Second @@ -27,6 +24,9 @@ const ( FeeCollectorTokenPublicKeyConfigEnvName = envConfigPrefix + "FEE_COLLECTOR_TOKEN_PUBLIC_KEY" defaultFeeCollectorPublicKey = "invalid" // Ensure something valid is set + CreateOnSendUsdWithdrawalFeeConfigEnvName = envConfigPrefix + "WITHDRAWAL_CREATE_ON_SEND_WITHDRAWAL_USD_FEE" + defaultCreateOnSendWithdrawalUsdFee = 0.50 + EnableAirdropsConfigEnvName = envConfigPrefix + "ENABLE_AIRDROPS" defaultEnableAirdrops = false @@ -38,17 +38,18 @@ const ( ) type conf struct { - disableSubmitIntent config.Bool - disableAntispamChecks config.Bool // To avoid limits during testing - disableAmlChecks config.Bool // To avoid limits during testing - disableBlockchainChecks config.Bool - submitIntentTimeout config.Duration - clientReceiveTimeout config.Duration - feeCollectorTokenPublicKey config.String - enableAirdrops config.Bool - airdropperOwnerPublicKey config.String - maxAirdropUsdValue config.Float64 - stripedLockParallelization config.Uint64 + disableSubmitIntent config.Bool + disableAntispamChecks config.Bool // To avoid limits during testing + disableAmlChecks config.Bool // To avoid limits during testing + disableBlockchainChecks config.Bool // To avoid blockchain checks during testing + submitIntentTimeout config.Duration + clientReceiveTimeout config.Duration + feeCollectorTokenPublicKey config.String + createOnSendWithdrawalUsdFee config.Float64 + enableAirdrops config.Bool + airdropperOwnerPublicKey config.String + maxAirdropUsdValue config.Float64 + stripedLockParallelization config.Uint64 } // ConfigProvider defines how config values are pulled @@ -58,17 +59,18 @@ type ConfigProvider func() *conf func WithEnvConfigs() ConfigProvider { return func() *conf { return &conf{ - disableSubmitIntent: env.NewBoolConfig(DisableSubmitIntentConfigEnvName, defaultDisableSubmitIntent), - disableAntispamChecks: wrapper.NewBoolConfig(memory.NewConfig(false), false), - disableAmlChecks: wrapper.NewBoolConfig(memory.NewConfig(false), false), - disableBlockchainChecks: env.NewBoolConfig(DisableBlockchainChecksConfigEnvName, defaultDisableBlockchainChecks), - submitIntentTimeout: env.NewDurationConfig(SubmitIntentTimeoutConfigEnvName, defaultSubmitIntentTimeout), - clientReceiveTimeout: env.NewDurationConfig(ClientReceiveTimeoutConfigEnvName, defaultClientReceiveTimeout), - feeCollectorTokenPublicKey: env.NewStringConfig(FeeCollectorTokenPublicKeyConfigEnvName, defaultFeeCollectorPublicKey), - enableAirdrops: env.NewBoolConfig(EnableAirdropsConfigEnvName, defaultEnableAirdrops), - airdropperOwnerPublicKey: env.NewStringConfig(AirdropperOwnerPublicKeyEnvName, defaultAirdropperOwnerPublicKey), - maxAirdropUsdValue: env.NewFloat64Config(MaxAirdropUsdValueEnvName, defaultMaxAirdropUsdValue), - stripedLockParallelization: wrapper.NewUint64Config(memory.NewConfig(8192), 8192), + disableSubmitIntent: env.NewBoolConfig(DisableSubmitIntentConfigEnvName, defaultDisableSubmitIntent), + disableAntispamChecks: wrapper.NewBoolConfig(memory.NewConfig(false), false), + disableAmlChecks: wrapper.NewBoolConfig(memory.NewConfig(false), false), + disableBlockchainChecks: wrapper.NewBoolConfig(memory.NewConfig(false), false), + submitIntentTimeout: env.NewDurationConfig(SubmitIntentTimeoutConfigEnvName, defaultSubmitIntentTimeout), + clientReceiveTimeout: env.NewDurationConfig(ClientReceiveTimeoutConfigEnvName, defaultClientReceiveTimeout), + feeCollectorTokenPublicKey: env.NewStringConfig(FeeCollectorTokenPublicKeyConfigEnvName, defaultFeeCollectorPublicKey), + createOnSendWithdrawalUsdFee: env.NewFloat64Config(CreateOnSendUsdWithdrawalFeeConfigEnvName, defaultCreateOnSendWithdrawalUsdFee), + enableAirdrops: env.NewBoolConfig(EnableAirdropsConfigEnvName, defaultEnableAirdrops), + airdropperOwnerPublicKey: env.NewStringConfig(AirdropperOwnerPublicKeyEnvName, defaultAirdropperOwnerPublicKey), + maxAirdropUsdValue: env.NewFloat64Config(MaxAirdropUsdValueEnvName, defaultMaxAirdropUsdValue), + stripedLockParallelization: wrapper.NewUint64Config(memory.NewConfig(8192), 8192), } } } @@ -85,17 +87,18 @@ type testOverrides struct { func withManualTestOverrides(overrides *testOverrides) ConfigProvider { return func() *conf { return &conf{ - disableSubmitIntent: wrapper.NewBoolConfig(memory.NewConfig(overrides.disableSubmitIntent), defaultDisableSubmitIntent), - disableAntispamChecks: wrapper.NewBoolConfig(memory.NewConfig(!overrides.enableAntispamChecks), false), - disableAmlChecks: wrapper.NewBoolConfig(memory.NewConfig(!overrides.enableAmlChecks), false), - disableBlockchainChecks: wrapper.NewBoolConfig(memory.NewConfig(true), true), - submitIntentTimeout: wrapper.NewDurationConfig(memory.NewConfig(defaultSubmitIntentTimeout), defaultSubmitIntentTimeout), - clientReceiveTimeout: wrapper.NewDurationConfig(memory.NewConfig(overrides.clientReceiveTimeout), defaultClientReceiveTimeout), - feeCollectorTokenPublicKey: wrapper.NewStringConfig(memory.NewConfig(overrides.feeCollectorTokenPublicKey), defaultFeeCollectorPublicKey), - enableAirdrops: wrapper.NewBoolConfig(memory.NewConfig(overrides.enableAirdrops), false), - airdropperOwnerPublicKey: wrapper.NewStringConfig(memory.NewConfig(defaultAirdropperOwnerPublicKey), defaultAirdropperOwnerPublicKey), - maxAirdropUsdValue: wrapper.NewFloat64Config(memory.NewConfig(defaultMaxAirdropUsdValue), defaultMaxAirdropUsdValue), - stripedLockParallelization: wrapper.NewUint64Config(memory.NewConfig(4), 4), + disableSubmitIntent: wrapper.NewBoolConfig(memory.NewConfig(overrides.disableSubmitIntent), defaultDisableSubmitIntent), + disableAntispamChecks: wrapper.NewBoolConfig(memory.NewConfig(!overrides.enableAntispamChecks), false), + disableAmlChecks: wrapper.NewBoolConfig(memory.NewConfig(!overrides.enableAmlChecks), false), + disableBlockchainChecks: wrapper.NewBoolConfig(memory.NewConfig(true), true), + submitIntentTimeout: wrapper.NewDurationConfig(memory.NewConfig(defaultSubmitIntentTimeout), defaultSubmitIntentTimeout), + clientReceiveTimeout: wrapper.NewDurationConfig(memory.NewConfig(overrides.clientReceiveTimeout), defaultClientReceiveTimeout), + feeCollectorTokenPublicKey: wrapper.NewStringConfig(memory.NewConfig(overrides.feeCollectorTokenPublicKey), defaultFeeCollectorPublicKey), + createOnSendWithdrawalUsdFee: wrapper.NewFloat64Config(memory.NewConfig(defaultCreateOnSendWithdrawalUsdFee), defaultCreateOnSendWithdrawalUsdFee), + enableAirdrops: wrapper.NewBoolConfig(memory.NewConfig(overrides.enableAirdrops), false), + airdropperOwnerPublicKey: wrapper.NewStringConfig(memory.NewConfig(defaultAirdropperOwnerPublicKey), defaultAirdropperOwnerPublicKey), + maxAirdropUsdValue: wrapper.NewFloat64Config(memory.NewConfig(defaultMaxAirdropUsdValue), defaultMaxAirdropUsdValue), + stripedLockParallelization: wrapper.NewUint64Config(memory.NewConfig(4), 4), } } } diff --git a/pkg/code/server/transaction/intent.go b/pkg/code/server/transaction/intent.go index d41e2163..0653109a 100644 --- a/pkg/code/server/transaction/intent.go +++ b/pkg/code/server/transaction/intent.go @@ -25,8 +25,8 @@ import ( "github.com/code-payments/code-server/pkg/code/data/action" "github.com/code-payments/code-server/pkg/code/data/fulfillment" "github.com/code-payments/code-server/pkg/code/data/intent" - "github.com/code-payments/code-server/pkg/code/data/timelock" "github.com/code-payments/code-server/pkg/code/transaction" + currency_lib "github.com/code-payments/code-server/pkg/currency" "github.com/code-payments/code-server/pkg/grpc/client" "github.com/code-payments/code-server/pkg/pointer" "github.com/code-payments/code-server/pkg/solana" @@ -210,7 +210,7 @@ func (s *transactionServer) SubmitIntent(streamer transactionpb.Transaction_Subm intentRecord := &intent.Record{ IntentId: intentId, InitiatorOwnerAccount: initiatorOwnerAccount.PublicKey().ToBase58(), - State: intent.StateUnknown, + State: intent.StatePending, CreatedAt: time.Now(), } @@ -554,6 +554,13 @@ func (s *transactionServer) SubmitIntent(streamer transactionpb.Transaction_Subm return err } + // Save additional state related to the intent + err = intentHandler.OnSaveToDB(ctx, intentRecord) + if err != nil { + log.WithError(err).Warn("failure executing intent db save callback") + return err + } + // Save all actions err = s.data.PutAllActions(ctx, actionRecords...) if err != nil { @@ -561,6 +568,15 @@ func (s *transactionServer) SubmitIntent(streamer transactionpb.Transaction_Subm return err } + // Save additional state related to each action + for _, actionHandler := range actionHandlers { + err = actionHandler.OnSaveToDB(ctx) + if err != nil { + log.WithError(err).Warn("failure executing action db save callback handler") + return err + } + } + // Save all fulfillment records fulfillmentRecordsToSave := make([]*fulfillment.Record, 0) for i, fulfillmentWithMetadata := range fulfillments { @@ -587,29 +603,6 @@ func (s *transactionServer) SubmitIntent(streamer transactionpb.Transaction_Subm return err } - // Save additional state related to each action - for _, actionHandler := range actionHandlers { - err = actionHandler.OnSaveToDB(ctx) - if err != nil { - log.WithError(err).Warn("failure executing action db save callback handler") - return err - } - } - - // Save additional state related to the intent - err = intentHandler.OnSaveToDB(ctx, intentRecord) - if err != nil { - log.WithError(err).Warn("failure executing intent db save callback") - return err - } - - // Mark the intent as pending once everything else has succeeded - err = s.markIntentAsPending(ctx, intentRecord) - if err != nil { - log.WithError(err).Warn("failure marking the intent as pending") - return err - } - return nil }) if err != nil { @@ -672,21 +665,6 @@ func (s *transactionServer) boundedSubmitIntentRecv(ctx context.Context, streame } } -func (s *transactionServer) markIntentAsPending(ctx context.Context, record *intent.Record) error { - if record.State != intent.StateUnknown { - return nil - } - - // After one minute, we mark the intent as revoked, so avoid the race with - // a time-based check until we have distributed locks - if time.Since(record.CreatedAt) > time.Minute { - return errors.New("took too long to mark intent as pending") - } - - record.State = intent.StatePending - return s.data.SaveIntent(ctx, record) -} - func (s *transactionServer) GetIntentMetadata(ctx context.Context, req *transactionpb.GetIntentMetadataRequest) (*transactionpb.GetIntentMetadataResponse, error) { intentId := base58.Encode(req.IntentId.Value) @@ -811,7 +789,6 @@ func (s *transactionServer) GetIntentMetadata(ctx context.Context, req *transact }, nil } -// todo: Test the blockchain checks when we have a mocked Solana client func (s *transactionServer) CanWithdrawToAccount(ctx context.Context, req *transactionpb.CanWithdrawToAccountRequest) (*transactionpb.CanWithdrawToAccountResponse, error) { log := s.log.WithField("method", "CanWithdrawToAccount") log = client.InjectLoggingMetadata(ctx, log) @@ -827,46 +804,33 @@ func (s *transactionServer) CanWithdrawToAccount(ctx context.Context, req *trans log = log.WithField("account", accountToCheck.PublicKey().ToBase58()) // - // Part 1: Is this a legacy Code timelock account? If so, deny it. + // Part 1: Is this a timelock vault? If so, only allow primary accounts. // - timelockRecord, err := s.data.GetTimelockByVault(ctx, accountToCheck.PublicKey().ToBase58()) + accountInfoRecord, err := s.data.GetAccountInfoByTokenAddress(ctx, accountToCheck.PublicKey().ToBase58()) switch err { case nil: - case timelock.ErrTimelockNotFound: + return &transactionpb.CanWithdrawToAccountResponse{ + IsValidPaymentDestination: accountInfoRecord.AccountType == commonpb.AccountType_PRIMARY, + AccountType: transactionpb.CanWithdrawToAccountResponse_TokenAccount, + }, nil + case account.ErrAccountInfoNotFound: // Nothing to do default: - log.WithError(err).Warn("failure checking timelock db") + log.WithError(err).Warn("failure checking account info db") return nil, status.Error(codes.Internal, "") } // - // Part 2: Is this a privacy-based timelock vault? If so, only allow primary accounts. - // - - if timelockRecord != nil { - accountInfoRecord, err := s.data.GetAccountInfoByTokenAddress(ctx, accountToCheck.PublicKey().ToBase58()) - if err == nil { - return &transactionpb.CanWithdrawToAccountResponse{ - AccountType: transactionpb.CanWithdrawToAccountResponse_TokenAccount, - IsValidPaymentDestination: accountInfoRecord.AccountType == commonpb.AccountType_PRIMARY, - }, nil - } else { - log.WithError(err).Warn("failure checking account info db") - return nil, status.Error(codes.Internal, "") - } - } - - // - // Part 3: Is this an opened Kin token account? If so, allow it. + // Part 2: Is this an opened core mint token account? If so, allow it. // _, err = s.data.GetBlockchainTokenAccountInfo(ctx, accountToCheck.PublicKey().ToBase58(), solana.CommitmentFinalized) switch err { case nil: return &transactionpb.CanWithdrawToAccountResponse{ - AccountType: transactionpb.CanWithdrawToAccountResponse_TokenAccount, IsValidPaymentDestination: true, + AccountType: transactionpb.CanWithdrawToAccountResponse_TokenAccount, }, nil case token.ErrAccountNotFound, solana.ErrNoAccountInfo, token.ErrInvalidTokenAccount: // Nothing to do @@ -876,7 +840,8 @@ func (s *transactionServer) CanWithdrawToAccount(ctx context.Context, req *trans } // - // Part 4: Is this an owner account with an opened Core Mint ATA? If so, allow it. + // Part 3: Is this an owner account with an opened Core Mint ATA? If so, allow it. + // If not, indicate to the client to pay a fee for a create-on-send withdrawal. // ata, err := accountToCheck.ToAssociatedTokenAccount(common.CoreMintAccount) @@ -885,7 +850,6 @@ func (s *transactionServer) CanWithdrawToAccount(ctx context.Context, req *trans return nil, status.Error(codes.Internal, "") } - var requiresInitialization bool _, err = s.data.GetBlockchainTokenAccountInfo(ctx, ata.PublicKey().ToBase58(), solana.CommitmentFinalized) switch err { case nil: @@ -895,20 +859,25 @@ func (s *transactionServer) CanWithdrawToAccount(ctx context.Context, req *trans }, nil case token.ErrAccountNotFound, solana.ErrNoAccountInfo: // ATA doesn't exist, and we won't be subsidizing it. Let the client know - // they should initialize it first. - requiresInitialization = true + // they require a fee. + return &transactionpb.CanWithdrawToAccountResponse{ + IsValidPaymentDestination: true, + AccountType: transactionpb.CanWithdrawToAccountResponse_OwnerAccount, + RequiresInitialization: true, + FeeAmount: &transactionpb.ExchangeDataWithoutRate{ + Currency: string(currency_lib.USD), + NativeAmount: s.conf.createOnSendWithdrawalUsdFee.Get(ctx), + }, + }, nil case token.ErrInvalidTokenAccount: - // Nothing to do + return &transactionpb.CanWithdrawToAccountResponse{ + IsValidPaymentDestination: false, + AccountType: transactionpb.CanWithdrawToAccountResponse_Unknown, + }, nil default: log.WithError(err).Warn("failure checking against blockchain as an owner account") return nil, status.Error(codes.Internal, "") } - - return &transactionpb.CanWithdrawToAccountResponse{ - AccountType: transactionpb.CanWithdrawToAccountResponse_Unknown, - IsValidPaymentDestination: false, - RequiresInitialization: requiresInitialization, - }, nil } func (s *transactionServer) VoidGiftCard(ctx context.Context, req *transactionpb.VoidGiftCardRequest) (*transactionpb.VoidGiftCardResponse, error) { diff --git a/pkg/code/server/transaction/intent_handler.go b/pkg/code/server/transaction/intent_handler.go index 555ff75d..da8bf4d3 100644 --- a/pkg/code/server/transaction/intent_handler.go +++ b/pkg/code/server/transaction/intent_handler.go @@ -176,7 +176,7 @@ func (h *OpenAccountsIntentHandler) AllowCreation(ctx context.Context, intentRec // Part 5: Validate fee payments // - return validateFeePayments(simResult) + return validateFeePayments(ctx, h.data, h.conf, intentRecord, simResult) } func (h *OpenAccountsIntentHandler) validateActions(ctx context.Context, initiatiorOwnerAccount *common.Account, actions []*transactionpb.Action) error { @@ -300,10 +300,12 @@ func (h *SendPublicPaymentIntentHandler) PopulateMetadata(ctx context.Context, i if destinationAccountInfo != nil { intentRecord.SendPublicPaymentMetadata.DestinationOwnerAccount = destinationAccountInfo.OwnerAccount - } - - if intentRecord.SendPublicPaymentMetadata.IsRemoteSend && intentRecord.SendPublicPaymentMetadata.IsWithdrawal { - return newIntentValidationError("remote send cannot be a withdraw") + } else if typedProtoMetadata.IsWithdrawal && typedProtoMetadata.DestinationOwner != nil { + destinationOwner, err := common.NewAccountFromProto(typedProtoMetadata.DestinationOwner) + if err != nil { + return err + } + intentRecord.SendPublicPaymentMetadata.DestinationOwnerAccount = destinationOwner.PublicKey().ToBase58() } return nil @@ -414,7 +416,7 @@ func (h *SendPublicPaymentIntentHandler) AllowCreation(ctx context.Context, inte // Part 6: Validate fee payments // - err = validateFeePayments(simResult) + err = validateFeePayments(ctx, h.data, h.conf, intentRecord, simResult) if err != nil { return err } @@ -434,7 +436,6 @@ func (h *SendPublicPaymentIntentHandler) AllowCreation(ctx context.Context, inte ) } -// todo: For remote send, we still need to fully validate the auto-return action func (h *SendPublicPaymentIntentHandler) validateActions( ctx context.Context, initiatorOwnerAccount *common.Account, @@ -444,13 +445,6 @@ func (h *SendPublicPaymentIntentHandler) validateActions( actions []*transactionpb.Action, simResult *LocalSimulationResult, ) error { - if !metadata.IsRemoteSend && len(actions) != 1 { - return newIntentValidationError("expected 1 action") - } - if metadata.IsRemoteSend && len(actions) != 3 { - return newIntentValidationError("expected 3 actions") - } - var source *common.Account var err error if metadata.Source != nil { @@ -472,7 +466,32 @@ func (h *SendPublicPaymentIntentHandler) validateActions( return err } - // Part 1: Check the source and destination accounts are valid + // + // Part 1: High-level action validation based on intent metadata + // + + if metadata.IsRemoteSend && metadata.IsWithdrawal { + return newIntentValidationError("remote send cannot be a withdraw") + } + + if !metadata.IsWithdrawal && !metadata.IsRemoteSend && len(actions) != 1 { + return newIntentValidationError("expected 1 action for payment") + } + if metadata.IsWithdrawal && len(actions) != 1 && len(actions) != 2 { + return newIntentValidationError("expected 1 or 2 actions for withdrawal") + } + if metadata.IsRemoteSend && len(actions) != 3 { + return newIntentValidationError("expected 3 actions for remote send") + } + + // + // Part 2: Check the source and destination accounts are valid + // + + sourceAccountRecords, ok := initiatorAccountsByVault[source.PublicKey().ToBase58()] + if !ok || sourceAccountRecords.General.AccountType != commonpb.AccountType_PRIMARY { + return newIntentValidationError("source account must be a deposit account") + } destinationAccountInfo, err := h.data.GetAccountInfoByTokenAddress(ctx, destination.PublicKey().ToBase58()) switch err { @@ -487,48 +506,115 @@ func (h *SendPublicPaymentIntentHandler) validateActions( return newIntentValidationError("destination account must be a deposit account") } + // Fee payments are not required for Code->Code public withdraws + if metadata.IsWithdrawal && simResult.HasAnyFeePayments() { + return newIntentValidationErrorf("%s fee payment not required for code destination", transactionpb.FeePaymentAction_CREATE_ON_SEND_WITHDRAWAL.String()) + } + // And the destination cannot be the source of funds, since that results in a no-op if source.PublicKey().ToBase58() == destinationAccountInfo.TokenAccount { return newIntentValidationError("payment is a no-op") } case account.ErrAccountInfoNotFound: - // Check whether the destination account is a core mint token account that's - // been created on the blockchain. Exception is made when we're doing a remote - // send, since we expect the gift card account to no yet exist. - if !metadata.IsRemoteSend && !h.conf.disableBlockchainChecks.Get(ctx) { + err = func() error { + // Destination is to a brand new gift card that will be created as part of this + // intent + if metadata.IsRemoteSend { + return nil + } + + // All payments to external destinations must be withdraws + if !metadata.IsWithdrawal { + return newIntentValidationError("payments to external destinations must be withdrawals") + } + + // Ensure the destination is the core mint ATA for the client-provided owner, + // if provided. We'll check later if this is absolutely required. + if metadata.DestinationOwner != nil { + destinationOwner, err := common.NewAccountFromProto(metadata.DestinationOwner) + if err != nil { + return err + } + + ata, err := destinationOwner.ToAssociatedTokenAccount(common.CoreMintAccount) + if err != nil { + return err + } + + if ata.PublicKey().ToBase58() != destination.PublicKey().ToBase58() { + return newIntentValidationErrorf("destination is not the ata for %s", destinationOwner.PublicKey().ToBase58()) + } + } + + // Technically we should always enforce a fee payment, but these checks are only + // disabled for tests + if h.conf.disableBlockchainChecks.Get(ctx) { + return nil + } + + // Check whether the destination account is a core mint token account that's + // been created on the blockchain. If not, a fee is required err = validateExternalTokenAccountWithinIntent(ctx, h.data, destination) - if err != nil { - return err + switch err { + case nil: + if simResult.HasAnyFeePayments() { + return newIntentValidationErrorf("%s fee payment not required when external destination exists", transactionpb.FeePaymentAction_CREATE_ON_SEND_WITHDRAWAL.String()) + } + default: + if !strings.Contains(strings.ToLower(err.Error()), "doesn't exist on the blockchain") { + return err + } + + if !simResult.HasAnyFeePayments() { + return newIntentValidationErrorf("%s fee payment is required", transactionpb.FeePaymentAction_CREATE_ON_SEND_WITHDRAWAL.String()) + } + + if metadata.DestinationOwner == nil { + return newIntentValidationError("destination owner account is required to derive ata") + } } + + return nil + }() + if err != nil { + return err } default: return err } - sourceAccountRecords, ok := initiatorAccountsByVault[source.PublicKey().ToBase58()] - if !ok || sourceAccountRecords.General.AccountType != commonpb.AccountType_PRIMARY { - return newIntentValidationError("source account must be a deposit account") - } - // - // Part 2: Validate actions match intent metadata + // Part 3 Validate actions match intent metadata // // - // Part 2.1: Check destination account is paid exact quark amount from the deposit account + // Part 3.1: Check destination account is paid exact quark amount from the deposit account + // minus any fees // + expectedDestinationPayment := int64(metadata.ExchangeData.Quarks) + + // Minimal validation required here since validateFeePayments generically handles + // most checks that isn't specific to an intent. + feePayments := simResult.GetFeePayments() + if len(feePayments) > 1 { + return newIntentValidationError("expected at most 1 fee payment") + } + for _, feePayment := range feePayments { + expectedDestinationPayment += feePayment.DeltaQuarks + } + destinationSimulation, ok := simResult.SimulationsByAccount[destination.PublicKey().ToBase58()] if !ok { return newIntentValidationErrorf("must send payment to destination account %s", destination.PublicKey().ToBase58()) } else if destinationSimulation.Transfers[0].IsPrivate || destinationSimulation.Transfers[0].IsWithdraw { return newActionValidationError(destinationSimulation.Transfers[0].Action, "payment sent to destination must be a public transfer") - } else if destinationSimulation.GetDeltaQuarks(false) != int64(metadata.ExchangeData.Quarks) { - return newActionValidationErrorf(destinationSimulation.Transfers[0].Action, "must send %d quarks to destination account", metadata.ExchangeData.Quarks) + } else if destinationSimulation.GetDeltaQuarks(false) != expectedDestinationPayment { + return newActionValidationErrorf(destinationSimulation.Transfers[0].Action, "must send %d quarks to destination account", expectedDestinationPayment) } // - // Part 2.2: Check that the user's deposit account was used as the source of funds + // Part 3.2: Check that the user's deposit account was used as the source of funds // as specified in the metadata // @@ -539,14 +625,14 @@ func (h *SendPublicPaymentIntentHandler) validateActions( return newActionValidationErrorf(sourceSimulation.Transfers[0].Action, "must send %d quarks from source account", metadata.ExchangeData.Quarks) } - // Part 3: Generic validation of actions that move money + // Part 4: Generic validation of actions that move money err = validateMoneyMovementActionUserAccounts(intent.SendPublicPayment, initiatorAccountsByVault, actions) if err != nil { return err } - // Part 4: Validate open and closed accounts + // Part 5: Validate open and closed accounts if metadata.IsRemoteSend { if len(simResult.GetOpenedAccounts()) != 1 { @@ -789,7 +875,7 @@ func (h *ReceivePaymentsPubliclyIntentHandler) AllowCreation(ctx context.Context // Part 6: Validate fee payments // - err = validateFeePayments(simResult) + err = validateFeePayments(ctx, h.data, h.conf, intentRecord, simResult) if err != nil { return err } @@ -1062,7 +1148,6 @@ func validateExternalTokenAccountWithinIntent(ctx context.Context, data code_dat } func validateExchangeDataWithinIntent(ctx context.Context, data code_data.Provider, proto *transactionpb.ExchangeData) error { - // Validate exchange data fully using the common method isValid, message, err := currency_util.ValidateClientExchangeData(ctx, data, proto) if err != nil { return err @@ -1075,10 +1160,82 @@ func validateExchangeDataWithinIntent(ctx context.Context, data code_data.Provid return nil } -func validateFeePayments(simResult *LocalSimulationResult) error { - if simResult.HasAnyFeePayments() { +func validateFeePayments( + ctx context.Context, + data code_data.Provider, + conf *conf, + intentRecord *intent.Record, + simResult *LocalSimulationResult, +) error { + var isFeeOptional bool + var expectedFeeType transactionpb.FeePaymentAction_FeeType + switch intentRecord.IntentType { + case intent.SendPublicPayment: + if intentRecord.SendPublicPaymentMetadata.IsWithdrawal { + isFeeOptional = true + expectedFeeType = transactionpb.FeePaymentAction_CREATE_ON_SEND_WITHDRAWAL + } + } + + if simResult.HasAnyFeePayments() && expectedFeeType == transactionpb.FeePaymentAction_UNKNOWN { return newIntentValidationError("intent doesn't require a fee payment") } + if expectedFeeType == transactionpb.FeePaymentAction_UNKNOWN { + return nil + } + + if !simResult.HasAnyFeePayments() && !isFeeOptional { + return newIntentValidationErrorf("expected a %s fee payment", expectedFeeType.String()) + } + if !simResult.HasAnyFeePayments() && isFeeOptional { + return nil + } + + feePayments := simResult.GetFeePayments() + if len(feePayments) > 1 { + return newIntentValidationError("expected at most 1 fee payment") + } else if len(feePayments) == 0 { + return nil + } + feePayment := feePayments[0] + + if feePayment.Action.GetFeePayment().Type != expectedFeeType { + return newActionValidationErrorf(feePayment.Action, "expected a %s fee payment", expectedFeeType.String()) + } + + var expectedUsdValue float64 + switch expectedFeeType { + case transactionpb.FeePaymentAction_CREATE_ON_SEND_WITHDRAWAL: + expectedUsdValue = conf.createOnSendWithdrawalUsdFee.Get(ctx) + default: + return errors.New("unhandled fee type") + } + + feeAmount := feePayment.DeltaQuarks + if feeAmount >= 0 { + return newActionValidationError(feePayment.Action, "fee payment amount is negative") + } + feeAmount = -feeAmount // Because it's coming out of a user account in this simulation + + var foundUsdExchangeRecord bool + usdExchangeRecords, err := currency_util.GetPotentialClientExchangeRates(ctx, data, currency_lib.USD) + if err != nil { + return err + } + for _, exchangeRecord := range usdExchangeRecords { + usdValue := exchangeRecord.Rate * float64(feeAmount) / float64(common.CoreMintQuarksPerUnit) + + // Allow for some small margin of error + if usdValue > expectedUsdValue-0.0001 && usdValue < expectedUsdValue+0.0001 { + foundUsdExchangeRecord = true + break + } + } + + if !foundUsdExchangeRecord { + return newActionValidationErrorf(feePayment.Action, "code fee payment amount must be $%.2f USD", expectedUsdValue) + } + return nil } diff --git a/pkg/code/server/transaction/local_simulation.go b/pkg/code/server/transaction/local_simulation.go index f1364779..89699eb2 100644 --- a/pkg/code/server/transaction/local_simulation.go +++ b/pkg/code/server/transaction/local_simulation.go @@ -29,7 +29,6 @@ type TokenAccountSimulation struct { IsAutoReturned bool } -// todo: Make it easier to extract accounts from a TransferSimulation (see some fee payment validation logic) type TransferSimulation struct { Action *transactionpb.Action IsPrivate bool @@ -153,9 +152,7 @@ func LocalSimulation(ctx context.Context, data code_data.Provider, actions []*tr }, }, }, - // todo: Doesn't specify destination, but that's not required yet, - // and makes other validation more complex since it's based - // on the simulation. + // todo: Doesn't specify destination, but that's not required yet ) case *transactionpb.Action_NoPrivacyWithdraw: source, err := common.NewAccountFromProto(typedAction.NoPrivacyWithdraw.Source) diff --git a/pkg/code/transaction/instruction.go b/pkg/code/transaction/instruction.go index 868b7d29..c4167535 100644 --- a/pkg/code/transaction/instruction.go +++ b/pkg/code/transaction/instruction.go @@ -6,8 +6,6 @@ import ( "github.com/code-payments/code-server/pkg/solana/system" ) -// todo: start moving instruction construction code here - func makeAdvanceNonceInstruction(nonce *common.Account) (solana.Instruction, error) { return system.AdvanceNonce( nonce.PublicKey().ToBytes(), diff --git a/pkg/code/transaction/transaction.go b/pkg/code/transaction/transaction.go index 3a183dd8..64429151 100644 --- a/pkg/code/transaction/transaction.go +++ b/pkg/code/transaction/transaction.go @@ -10,6 +10,7 @@ import ( "github.com/code-payments/code-server/pkg/solana" compute_budget "github.com/code-payments/code-server/pkg/solana/computebudget" "github.com/code-payments/code-server/pkg/solana/cvm" + "github.com/code-payments/code-server/pkg/solana/token" ) // todo: The argument sizes are blowing out of proportion, though there's likely @@ -210,7 +211,7 @@ func MakeInternalTransferWithAuthorityTransaction( destinationMemory *common.Account, destinationIndex uint16, - kinAmountInQuarks uint64, + coreMintQuarks uint64, ) (solana.Transaction, error) { mergedMemoryBanks, err := MergeMemoryBanks(nonceMemory, sourceMemory, destinationMemory) if err != nil { @@ -218,7 +219,7 @@ func MakeInternalTransferWithAuthorityTransaction( } vixn := cvm.NewTransferVirtualInstruction(&cvm.TransferVirtualInstructionArgs{ - Amount: kinAmountInQuarks, + Amount: coreMintQuarks, Signature: cvm.Signature(virtualSignature), }) @@ -260,8 +261,11 @@ func MakeExternalTransferWithAuthorityTransaction( sourceMemory *common.Account, sourceIndex uint16, + isCreateOnSend bool, + externalDestinationOwner *common.Account, externalDestination *common.Account, - kinAmountInQuarks uint64, + + coreMintQuarks uint64, ) (solana.Transaction, error) { mergedMemoryBanks, err := MergeMemoryBanks(nonceMemory, sourceMemory) if err != nil { @@ -273,7 +277,7 @@ func MakeExternalTransferWithAuthorityTransaction( vmOmnibusPublicKeyBytes := ed25519.PublicKey(vmOmnibus.PublicKey().ToBytes()) vixn := cvm.NewExternalTransferVirtualInstruction(&cvm.TransferVirtualInstructionArgs{ - Amount: kinAmountInQuarks, + Amount: coreMintQuarks, Signature: cvm.Signature(virtualSignature), }) @@ -294,11 +298,34 @@ func MakeExternalTransferWithAuthorityTransaction( }, ) + computeLimit := 100_000 + if isCreateOnSend { + computeLimit = 125_000 + } + instructions := []solana.Instruction{ compute_budget.SetComputeUnitPrice(1_000), - compute_budget.SetComputeUnitLimit(100_000), - execInstruction, + compute_budget.SetComputeUnitLimit(uint32(computeLimit)), + } + if isCreateOnSend { + if externalDestinationOwner == nil { + return solana.Transaction{}, errors.New("destination owner is required") + } + + createIdempotentInstruction, ata, err := token.CreateAssociatedTokenAccountIdempotent( + common.GetSubsidizer().PublicKey().ToBytes(), + externalDestinationOwner.PublicKey().ToBytes(), + common.CoreMintAccount.PublicKey().ToBytes(), + ) + if err != nil { + return solana.Transaction{}, err + } else if !bytes.Equal(externalDestination.PublicKey().ToBytes(), ata) { + return solana.Transaction{}, errors.New("invalid destination owner") + } + + instructions = append(instructions, createIdempotentInstruction) } + instructions = append(instructions, execInstruction) return MakeNoncedTransaction(nonce, bh, instructions...) } diff --git a/pkg/pointer/pointer.go b/pkg/pointer/pointer.go index bcb5b5cd..e3a51bec 100644 --- a/pkg/pointer/pointer.go +++ b/pkg/pointer/pointer.go @@ -32,6 +32,36 @@ func StringCopy(value *string) *string { return String(*value) } +// Int32 returns a pointer to the provided int32 value +func Int32(value int32) *int32 { + return &value +} + +// Int32OrDefault returns the pointer if not nil, otherwise the default value +func Int32OrDefault(value *int32, defaultValue int32) *int32 { + if value != nil { + return value + } + return &defaultValue +} + +// Int32IfValid returns a pointer to the value if it's valid, otherwise nil +func Int32IfValid(valid bool, value int32) *int32 { + if valid { + return &value + } + return nil +} + +// Int32Copy returns a pointer that's a copy of the provided value +func Int32Copy(value *int32) *int32 { + if value == nil { + return nil + } + + return Int32(*value) +} + // Uint32 returns a pointer to the provided uint32 value func Uint32(value uint32) *uint32 { return &value diff --git a/pkg/solana/instruction.go b/pkg/solana/instruction.go index 295f5c96..e2e2006a 100644 --- a/pkg/solana/instruction.go +++ b/pkg/solana/instruction.go @@ -85,7 +85,7 @@ type Instruction struct { } // NewInstruction creates a new instruction. -func NewInstruction(program, data ed25519.PublicKey, accounts ...AccountMeta) Instruction { +func NewInstruction(program ed25519.PublicKey, data []byte, accounts ...AccountMeta) Instruction { return Instruction{ Program: program, Data: data, diff --git a/pkg/solana/token/associated.go b/pkg/solana/token/associated.go index f11a8e9a..d193cb0b 100644 --- a/pkg/solana/token/associated.go +++ b/pkg/solana/token/associated.go @@ -10,6 +10,11 @@ import ( "github.com/code-payments/code-server/pkg/solana/system" ) +const ( + commandCreate uint8 = iota + commandCreateIdempotent +) + // AssociatedTokenAccountProgramKey is the address of the associated token account program that should be used. // // Current key: ATokenGPvbdGVxr1b2hvZbsiqW5xWH25efTNsLJA8knL @@ -27,7 +32,7 @@ func GetAssociatedAccount(wallet, mint ed25519.PublicKey) (ed25519.PublicKey, er ) } -// Reference: https://github.com/solana-labs/solana-program-library/blob/0639953c7dd0f5228c3ceda3ba68fece3b46ff1d/associated-token-account/program/src/lib.rs#L54 +// Reference: https://github.com/solana-program/associated-token-account/blob/0588a2c3558cc93c31d27bcc96f97cf559a767bc/program/src/instruction.rs#L9-L17 func CreateAssociatedTokenAccount(subsidizer, wallet, mint ed25519.PublicKey) (solana.Instruction, ed25519.PublicKey, error) { addr, err := GetAssociatedAccount(wallet, mint) if err != nil { @@ -36,7 +41,7 @@ func CreateAssociatedTokenAccount(subsidizer, wallet, mint ed25519.PublicKey) (s return solana.NewInstruction( AssociatedTokenAccountProgramKey, - []byte{}, + []byte{commandCreate}, solana.NewAccountMeta(subsidizer, true), solana.NewAccountMeta(addr, false), solana.NewReadonlyAccountMeta(wallet, false), @@ -63,9 +68,12 @@ func DecompileCreateAssociatedAccount(m solana.Message, index int) (*DecompiledC if !bytes.Equal(m.Accounts[i.ProgramIndex], AssociatedTokenAccountProgramKey) { return nil, solana.ErrIncorrectProgram } - if len(i.Data) != 0 { + if len(i.Data) != 1 { return nil, errors.Errorf("unexpected data") } + if i.Data[0] != commandCreate { + return nil, errors.Errorf("unexpected instruction data") + } if len(i.Accounts) != 7 { return nil, errors.Errorf("invalid number of accounts: %d (expected %d)", len(i.Accounts), 7) } @@ -87,3 +95,67 @@ func DecompileCreateAssociatedAccount(m solana.Message, index int) (*DecompiledC Mint: m.Accounts[i.Accounts[3]], }, nil } + +// Reference: https://github.com/solana-program/associated-token-account/blob/0588a2c3558cc93c31d27bcc96f97cf559a767bc/program/src/instruction.rs#L19-L28 +func CreateAssociatedTokenAccountIdempotent(subsidizer, wallet, mint ed25519.PublicKey) (solana.Instruction, ed25519.PublicKey, error) { + addr, err := GetAssociatedAccount(wallet, mint) + if err != nil { + return solana.Instruction{}, nil, err + } + + return solana.NewInstruction( + AssociatedTokenAccountProgramKey, + []byte{commandCreateIdempotent}, + solana.NewAccountMeta(subsidizer, true), + solana.NewAccountMeta(addr, false), + solana.NewReadonlyAccountMeta(wallet, false), + solana.NewReadonlyAccountMeta(mint, false), + solana.NewReadonlyAccountMeta(system.ProgramKey[:], false), + solana.NewReadonlyAccountMeta(ProgramKey, false), + solana.NewReadonlyAccountMeta(system.RentSysVar, false), + ), addr, nil +} + +type DecompiledCreateAssociatedAccountIdempotent struct { + Subsidizer ed25519.PublicKey + Address ed25519.PublicKey + Owner ed25519.PublicKey + Mint ed25519.PublicKey +} + +func DecompileCreateAssociatedAccountIdempotent(m solana.Message, index int) (*DecompiledCreateAssociatedAccountIdempotent, error) { + if index >= len(m.Instructions) { + return nil, errors.Errorf("instruction doesn't exist at %d", index) + } + + i := m.Instructions[index] + if !bytes.Equal(m.Accounts[i.ProgramIndex], AssociatedTokenAccountProgramKey) { + return nil, solana.ErrIncorrectProgram + } + if len(i.Data) != 1 { + return nil, errors.Errorf("unexpected data") + } + if i.Data[0] != commandCreateIdempotent { + return nil, errors.Errorf("unexpected instruction data") + } + if len(i.Accounts) != 7 { + return nil, errors.Errorf("invalid number of accounts: %d (expected %d)", len(i.Accounts), 7) + } + + if !bytes.Equal(m.Accounts[i.Accounts[4]], system.ProgramKey[:]) { + return nil, errors.Errorf("system program key mismatch") + } + if !bytes.Equal(m.Accounts[i.Accounts[5]], ProgramKey) { + return nil, errors.Errorf("token program key mismatch") + } + if !bytes.Equal(m.Accounts[i.Accounts[6]], system.RentSysVar) { + return nil, errors.Errorf("rent sysvar mismatch") + } + + return &DecompiledCreateAssociatedAccountIdempotent{ + Subsidizer: m.Accounts[i.Accounts[0]], + Address: m.Accounts[i.Accounts[1]], + Owner: m.Accounts[i.Accounts[2]], + Mint: m.Accounts[i.Accounts[3]], + }, nil +} diff --git a/pkg/solana/token/associated_test.go b/pkg/solana/token/associated_test.go index b604b6ef..23fbfb4e 100644 --- a/pkg/solana/token/associated_test.go +++ b/pkg/solana/token/associated_test.go @@ -35,7 +35,8 @@ func TestCreateAssociatedAccount(t *testing.T) { require.NoError(t, err) assert.Equal(t, expectedAddr, addr) - assert.Empty(t, instruction.Data) + assert.Len(t, instruction.Data, 1) + assert.Equal(t, commandCreate, instruction.Data[0]) assert.Equal(t, 7, len(instruction.Accounts)) assert.True(t, instruction.Accounts[0].IsSigner) assert.True(t, instruction.Accounts[0].IsWritable) @@ -56,3 +57,36 @@ func TestCreateAssociatedAccount(t *testing.T) { assert.Equal(t, keys[1], decompiled.Owner) assert.Equal(t, keys[2], decompiled.Mint) } + +func TestCreateAssociatedAccountIdempotent(t *testing.T) { + keys := generateKeys(t, 3) + + expectedAddr, err := GetAssociatedAccount(keys[1], keys[2]) + require.NoError(t, err) + + instruction, addr, err := CreateAssociatedTokenAccountIdempotent(keys[0], keys[1], keys[2]) + require.NoError(t, err) + assert.Equal(t, expectedAddr, addr) + + assert.Len(t, instruction.Data, 1) + assert.Equal(t, commandCreateIdempotent, instruction.Data[0]) + assert.Equal(t, 7, len(instruction.Accounts)) + assert.True(t, instruction.Accounts[0].IsSigner) + assert.True(t, instruction.Accounts[0].IsWritable) + assert.False(t, instruction.Accounts[1].IsSigner) + assert.True(t, instruction.Accounts[1].IsWritable) + for i := 2; i < len(instruction.Accounts); i++ { + assert.False(t, instruction.Accounts[i].IsSigner) + assert.False(t, instruction.Accounts[i].IsWritable) + } + + assert.EqualValues(t, system.ProgramKey[:], instruction.Accounts[4].PublicKey) + assert.EqualValues(t, ProgramKey, instruction.Accounts[5].PublicKey) + assert.EqualValues(t, system.RentSysVar, instruction.Accounts[6].PublicKey) + + decompiled, err := DecompileCreateAssociatedAccountIdempotent(solana.NewTransaction(keys[0], instruction).Message, 0) + assert.NoError(t, err) + assert.Equal(t, keys[0], decompiled.Subsidizer) + assert.Equal(t, keys[1], decompiled.Owner) + assert.Equal(t, keys[2], decompiled.Mint) +}