diff --git a/ocp/data/internal.go b/ocp/data/internal.go index 1f1f356..23f88e8 100644 --- a/ocp/data/internal.go +++ b/ocp/data/internal.go @@ -237,6 +237,7 @@ type DatabaseData interface { FreeVmMemoryByIndex(ctx context.Context, memoryAccount string, index uint16) error FreeVmMemoryByAddress(ctx context.Context, address string) error ReserveVmMemory(ctx context.Context, vm string, accountType vm.VirtualAccountType, address string) (string, uint16, error) + GetVmMemoryLocationByAddress(ctx context.Context, address string) (string, uint16, error) // VM Storage // -------------------------------------------------------------------------------- @@ -876,6 +877,9 @@ func (dp *DatabaseProvider) FreeVmMemoryByAddress(ctx context.Context, address s func (dp *DatabaseProvider) ReserveVmMemory(ctx context.Context, vm string, accountType vm.VirtualAccountType, address string) (string, uint16, error) { return dp.vmRam.ReserveMemory(ctx, vm, accountType, address) } +func (dp *DatabaseProvider) GetVmMemoryLocationByAddress(ctx context.Context, address string) (string, uint16, error) { + return dp.vmRam.GetMemoryLocationByAddress(ctx, address) +} // VM Storage // -------------------------------------------------------------------------------- diff --git a/ocp/data/vm/ram/memory/store.go b/ocp/data/vm/ram/memory/store.go index f41f523..ff485dd 100644 --- a/ocp/data/vm/ram/memory/store.go +++ b/ocp/data/vm/ram/memory/store.go @@ -112,6 +112,20 @@ func (s *store) ReserveMemory(_ context.Context, vm string, accountType vm.Virtu return "", 0, ram.ErrNoFreeMemory } +// GetMemoryLocationByAddress implements vm.ram.Store.GetMemoryLocationByAddress +func (s *store) GetMemoryLocationByAddress(_ context.Context, address string) (string, uint16, error) { + s.mu.Lock() + defer s.mu.Unlock() + + reservationKey, ok := s.storedVirtualAccounts[address] + if !ok { + return "", 0, ram.ErrNotReserved + } + + memoryAccount, index := parseAccountIndexKey(reservationKey) + return memoryAccount, index, nil +} + func (s *store) find(data *ram.Record) *ram.Record { for _, item := range s.records { if item.Id == data.Id { @@ -154,3 +168,18 @@ func (s *store) reset() { func getAccountIndexKey(memoryAccount string, index uint16) string { return fmt.Sprintf("%s:%d", memoryAccount, index) } + +func parseAccountIndexKey(key string) (string, uint16) { + var memoryAccount string + var index uint16 + fmt.Sscanf(key, "%s", &memoryAccount) + // Find the last colon and parse + for i := len(key) - 1; i >= 0; i-- { + if key[i] == ':' { + memoryAccount = key[:i] + fmt.Sscanf(key[i+1:], "%d", &index) + break + } + } + return memoryAccount, index +} diff --git a/ocp/data/vm/ram/postgres/model.go b/ocp/data/vm/ram/postgres/model.go index e1f56d1..5e80dd0 100644 --- a/ocp/data/vm/ram/postgres/model.go +++ b/ocp/data/vm/ram/postgres/model.go @@ -211,3 +211,18 @@ func dbReserveMemory(ctx context.Context, db *sqlx.DB, vm string, accountType vm }) return memoryAccount, index, err } + +func dbGetMemoryLocationByAddress(ctx context.Context, db *sqlx.DB, address string) (string, uint16, error) { + var model allocatedMemoryModel + + query := `SELECT id, vm, memory_account, index, is_allocated, stored_account_type, address, last_updated_at + FROM ` + allocatedMemoryTableName + ` + WHERE address = $1 AND is_allocated` + + err := db.QueryRowxContext(ctx, query, address).StructScan(&model) + if err != nil { + return "", 0, pgutil.CheckNoRows(err, ram.ErrNotReserved) + } + + return model.MemoryAccount, model.Index, nil +} diff --git a/ocp/data/vm/ram/postgres/store.go b/ocp/data/vm/ram/postgres/store.go index c6d3c3d..bac5e5c 100644 --- a/ocp/data/vm/ram/postgres/store.go +++ b/ocp/data/vm/ram/postgres/store.go @@ -53,3 +53,8 @@ func (s *store) FreeMemoryByAddress(ctx context.Context, address string) error { func (s *store) ReserveMemory(ctx context.Context, vm string, accountType vm.VirtualAccountType, address string) (string, uint16, error) { return dbReserveMemory(ctx, s.db, vm, accountType, address) } + +// GetMemoryLocationByAddress implements vm.ram.Store.GetMemoryLocationByAddress +func (s *store) GetMemoryLocationByAddress(ctx context.Context, address string) (string, uint16, error) { + return dbGetMemoryLocationByAddress(ctx, s.db, address) +} diff --git a/ocp/data/vm/ram/store.go b/ocp/data/vm/ram/store.go index 93daf9a..1eaebe0 100644 --- a/ocp/data/vm/ram/store.go +++ b/ocp/data/vm/ram/store.go @@ -32,4 +32,7 @@ type Store interface { // ReserveMemory reserves a piece of memory in a VM for the virtual account address ReserveMemory(ctx context.Context, vm string, accountType vm.VirtualAccountType, address string) (string, uint16, error) + + // GetMemoryLocationByAddress returns the memory account and index for a reserved virtual account address + GetMemoryLocationByAddress(ctx context.Context, address string) (string, uint16, error) } diff --git a/ocp/data/vm/ram/tests/tests.go b/ocp/data/vm/ram/tests/tests.go index 5611096..5612a0d 100644 --- a/ocp/data/vm/ram/tests/tests.go +++ b/ocp/data/vm/ram/tests/tests.go @@ -107,5 +107,18 @@ func testHappyPath(t *testing.T, s ram.Store) { _, _, err = s.ReserveMemory(ctx, "vm1", vm.VirtualAccountTypeTimelock, "newvirtualaccount3") assert.Equal(t, ram.ErrNoFreeMemory, err) + + memoryAccount, index, err := s.GetMemoryLocationByAddress(ctx, "newvirtualaccount1") + require.NoError(t, err) + assert.Equal(t, "memoryaccount1", memoryAccount) + assert.Equal(t, freedIndex1, index) + + memoryAccount, index, err = s.GetMemoryLocationByAddress(ctx, "newvirtualaccount2") + require.NoError(t, err) + assert.Equal(t, "memoryaccount1", memoryAccount) + assert.Equal(t, freedIndex2, index) + + _, _, err = s.GetMemoryLocationByAddress(ctx, "nonexistent") + assert.Equal(t, ram.ErrNotReserved, err) }) } diff --git a/ocp/rpc/transaction/server.go b/ocp/rpc/transaction/server.go index 8a6516f..d9ae657 100644 --- a/ocp/rpc/transaction/server.go +++ b/ocp/rpc/transaction/server.go @@ -7,7 +7,6 @@ import ( "go.uber.org/zap" - indexerpb "github.com/code-payments/code-vm-indexer/generated/indexer/v1" transactionpb "github.com/code-payments/ocp-protobuf-api/generated/go/transaction/v1" "github.com/code-payments/ocp-server/ocp/aml" @@ -24,8 +23,7 @@ type transactionServer struct { log *zap.Logger - data ocp_data.Provider - vmIndexerClient indexerpb.IndexerClient + data ocp_data.Provider auth *auth_util.RPCSignatureVerifier @@ -51,7 +49,6 @@ type transactionServer struct { func NewTransactionServer( log *zap.Logger, data ocp_data.Provider, - vmIndexerClient indexerpb.IndexerClient, submitIntentIntegration SubmitIntentIntegration, airdropIntegration AirdropIntegration, antispamGuard *antispam.Guard, @@ -92,8 +89,7 @@ func NewTransactionServer( log: log, - data: data, - vmIndexerClient: vmIndexerClient, + data: data, auth: auth_util.NewRPCSignatureVerifier(log, data), diff --git a/ocp/rpc/transaction/swap.go b/ocp/rpc/transaction/swap.go index 4b6a2f7..49c68de 100644 --- a/ocp/rpc/transaction/swap.go +++ b/ocp/rpc/transaction/swap.go @@ -255,9 +255,9 @@ func (s *transactionServer) StatefulSwap(streamer transactionpb.Transaction_Stat // Section: On-demand account creation // - err = vm.EnsureVirtualTimelockAccountIsInitialized(ctx, s.data, s.vmIndexerClient, toMint, owner, true) + err = vm.EnsureVirtualTimelockAccountIsInitialized(ctx, s.data, ownerDestinationTimelockVault, false) if err != nil { - log.With(zap.Error(err)).Warn("timed out waiting for destination timelock account initialization") + log.With(zap.Error(err)).Warn("error ensuring destination virtual timelock account is initialized") return handleStatefulSwapError(streamer, err) } @@ -288,7 +288,6 @@ func (s *transactionServer) StatefulSwap(streamer transactionpb.Transaction_Stat if common.IsCoreMint(fromMint) { swapHandler = NewCurrencyCreatorBuySwapHandler( s.data, - s.vmIndexerClient, owner, swapAuthority, toMint, @@ -298,7 +297,6 @@ func (s *transactionServer) StatefulSwap(streamer transactionpb.Transaction_Stat } else if common.IsCoreMint(toMint) { swapHandler = NewCurrencyCreatorSellSwapHandler( s.data, - s.vmIndexerClient, owner, swapAuthority, fromMint, @@ -308,7 +306,6 @@ func (s *transactionServer) StatefulSwap(streamer transactionpb.Transaction_Stat } else { swapHandler = NewCurrencyCreatorBuySellSwapHandler( s.data, - s.vmIndexerClient, owner, swapAuthority, fromMint, diff --git a/ocp/rpc/transaction/swap_handler.go b/ocp/rpc/transaction/swap_handler.go index 25a9fc3..abf9dcf 100644 --- a/ocp/rpc/transaction/swap_handler.go +++ b/ocp/rpc/transaction/swap_handler.go @@ -3,8 +3,6 @@ package transaction import ( "context" - indexerpb "github.com/code-payments/code-vm-indexer/generated/indexer/v1" - "github.com/code-payments/ocp-server/ocp/common" ocp_data "github.com/code-payments/ocp-server/ocp/data" vm_util "github.com/code-payments/ocp-server/ocp/vm" @@ -36,8 +34,7 @@ type SwapHandler interface { } type CurrencyCreatorBuySwapHandler struct { - data ocp_data.Provider - vmIndexerClient indexerpb.IndexerClient + data ocp_data.Provider buyer *common.Account temporaryHolder *common.Account @@ -54,7 +51,6 @@ type CurrencyCreatorBuySwapHandler struct { func NewCurrencyCreatorBuySwapHandler( data ocp_data.Provider, - vmIndexerClient indexerpb.IndexerClient, buyer *common.Account, temporaryHolder *common.Account, mint *common.Account, @@ -62,8 +58,7 @@ func NewCurrencyCreatorBuySwapHandler( nonce *common.Account, ) SwapHandler { return &CurrencyCreatorBuySwapHandler{ - data: data, - vmIndexerClient: vmIndexerClient, + data: data, buyer: buyer, temporaryHolder: temporaryHolder, @@ -71,7 +66,7 @@ func NewCurrencyCreatorBuySwapHandler( amount: amount, nonce: nonce, - computeUnitLimit: 300_000, + computeUnitLimit: 150_000, computeUnitPrice: 1_000, memoValue: "buy_v0", } @@ -103,7 +98,12 @@ func (h *CurrencyCreatorBuySwapHandler) MakeInstructions(ctx context.Context) ([ return nil, err } - h.memoryAccount, h.memoryIndex, err = vm_util.GetVirtualTimelockAccountLocationInMemory(ctx, h.vmIndexerClient, destinationVmConfig.Vm, h.buyer) + destinationTimelockAccounts, err := h.buyer.GetTimelockAccounts(destinationVmConfig) + if err != nil { + return nil, err + } + + h.memoryAccount, h.memoryIndex, err = vm_util.GetVirtualTimelockAccountLocationInMemory(ctx, h.data, destinationTimelockAccounts.Vault) if err != nil { return nil, err } @@ -203,8 +203,7 @@ func (h *CurrencyCreatorBuySwapHandler) MakeInstructions(ctx context.Context) ([ } type CurrencyCreatorSellSwapHandler struct { - data ocp_data.Provider - vmIndexerClient indexerpb.IndexerClient + data ocp_data.Provider seller *common.Account temporaryHolder *common.Account @@ -221,7 +220,6 @@ type CurrencyCreatorSellSwapHandler struct { func NewCurrencyCreatorSellSwapHandler( data ocp_data.Provider, - vmIndexerClient indexerpb.IndexerClient, seller *common.Account, temporaryHolder *common.Account, mint *common.Account, @@ -229,8 +227,7 @@ func NewCurrencyCreatorSellSwapHandler( nonce *common.Account, ) SwapHandler { return &CurrencyCreatorSellSwapHandler{ - data: data, - vmIndexerClient: vmIndexerClient, + data: data, seller: seller, temporaryHolder: temporaryHolder, @@ -238,7 +235,7 @@ func NewCurrencyCreatorSellSwapHandler( amount: amount, nonce: nonce, - computeUnitLimit: 300_000, + computeUnitLimit: 175_000, computeUnitPrice: 1_000, memoValue: "sell_v0", } @@ -280,7 +277,12 @@ func (h *CurrencyCreatorSellSwapHandler) MakeInstructions(ctx context.Context) ( return nil, err } - h.memoryAccount, h.memoryIndex, err = vm_util.GetVirtualTimelockAccountLocationInMemory(ctx, h.vmIndexerClient, destinationVmConfig.Vm, h.seller) + destinationTimelockAccounts, err := h.seller.GetTimelockAccounts(destinationVmConfig) + if err != nil { + return nil, err + } + + h.memoryAccount, h.memoryIndex, err = vm_util.GetVirtualTimelockAccountLocationInMemory(ctx, h.data, destinationTimelockAccounts.Vault) if err != nil { return nil, err } @@ -370,8 +372,7 @@ func (h *CurrencyCreatorSellSwapHandler) MakeInstructions(ctx context.Context) ( } type CurrencyCreatorBuySellSwapHandler struct { - data ocp_data.Provider - vmIndexerClient indexerpb.IndexerClient + data ocp_data.Provider swapper *common.Account temporaryHolder *common.Account @@ -389,7 +390,6 @@ type CurrencyCreatorBuySellSwapHandler struct { func NewCurrencyCreatorBuySellSwapHandler( data ocp_data.Provider, - vmIndexerClient indexerpb.IndexerClient, swapper *common.Account, temporaryHolder *common.Account, fromMint *common.Account, @@ -398,8 +398,7 @@ func NewCurrencyCreatorBuySellSwapHandler( nonce *common.Account, ) SwapHandler { return &CurrencyCreatorBuySellSwapHandler{ - data: data, - vmIndexerClient: vmIndexerClient, + data: data, swapper: swapper, temporaryHolder: temporaryHolder, @@ -450,7 +449,12 @@ func (h *CurrencyCreatorBuySellSwapHandler) MakeInstructions(ctx context.Context return nil, err } - h.memoryAccount, h.memoryIndex, err = vm_util.GetVirtualTimelockAccountLocationInMemory(ctx, h.vmIndexerClient, destinationVmConfig.Vm, h.swapper) + destinationTimelockAccounts, err := h.swapper.GetTimelockAccounts(destinationVmConfig) + if err != nil { + return nil, err + } + + h.memoryAccount, h.memoryIndex, err = vm_util.GetVirtualTimelockAccountLocationInMemory(ctx, h.data, destinationTimelockAccounts.Vault) if err != nil { return nil, err } diff --git a/ocp/vm/virtual_account.go b/ocp/vm/virtual_account.go index 38ed8f1..9b58c25 100644 --- a/ocp/vm/virtual_account.go +++ b/ocp/vm/virtual_account.go @@ -14,18 +14,8 @@ import ( vm_program "github.com/code-payments/ocp-server/solana/vm" ) -func EnsureVirtualTimelockAccountIsInitialized(ctx context.Context, data ocp_data.Provider, vmIndexerClient indexerpb.IndexerClient, mint, owner *common.Account, waitForInitialization bool) error { - vmConfig, err := common.GetVmConfigForMint(ctx, data, mint) - if err != nil { - return err - } - - timelockAccounts, err := owner.GetTimelockAccounts(vmConfig) - if err != nil { - return err - } - - timelockRecord, err := data.GetTimelockByVault(ctx, timelockAccounts.Vault.PublicKey().ToBase58()) +func EnsureVirtualTimelockAccountIsInitialized(ctx context.Context, data ocp_data.Provider, vault *common.Account, waitForInitialization bool) error { + timelockRecord, err := data.GetTimelockByVault(ctx, vault.PublicKey().ToBase58()) if err != nil { return err } @@ -44,6 +34,8 @@ func EnsureVirtualTimelockAccountIsInitialized(ctx context.Context, data ocp_dat if err != nil { return err } + } else { + return nil } if !waitForInitialization { @@ -51,8 +43,12 @@ func EnsureVirtualTimelockAccountIsInitialized(ctx context.Context, data ocp_dat } for range 60 { - _, _, err := GetVirtualTimelockAccountLocationInMemory(ctx, vmIndexerClient, vmConfig.Vm, owner) - if err == nil { + timelockRecord, err := data.GetTimelockByVault(ctx, vault.PublicKey().ToBase58()) + if err != nil { + continue + } + + if timelockRecord.ExistsOnBlockchain() { return nil } @@ -101,12 +97,18 @@ func GetVirtualTimelockAccountStateInMemory(ctx context.Context, vmIndexerClient return &state, memory, uint16(protoMemory.Index), nil } -func GetVirtualTimelockAccountLocationInMemory(ctx context.Context, vmIndexerClient indexerpb.IndexerClient, vm, owner *common.Account) (*common.Account, uint16, error) { - _, memory, memoryIndex, err := GetVirtualTimelockAccountStateInMemory(ctx, vmIndexerClient, vm, owner) +func GetVirtualTimelockAccountLocationInMemory(ctx context.Context, data ocp_data.Provider, vault *common.Account) (*common.Account, uint16, error) { + memoryAddress, index, err := data.GetVmMemoryLocationByAddress(ctx, vault.PublicKey().ToBase58()) if err != nil { return nil, 0, err } - return memory, memoryIndex, nil + + memoryAccount, err := common.NewAccountFromPublicKeyString(memoryAddress) + if err != nil { + return nil, 0, err + } + + return memoryAccount, index, err } func GetVirtualDurableNonceAccountStateInMemory(ctx context.Context, vmIndexerClient indexerpb.IndexerClient, vm, nonce *common.Account) (*vm_program.VirtualDurableNonce, *common.Account, uint16, error) { diff --git a/ocp/worker/geyser/backup.go b/ocp/worker/geyser/backup.go index 5b842db..5310aed 100644 --- a/ocp/worker/geyser/backup.go +++ b/ocp/worker/geyser/backup.go @@ -157,7 +157,7 @@ func (p *runtime) backupExternalDepositWorker(runtimeCtx context.Context, interv zap.String("mint", mintAccount.PublicKey().ToBase58()), ) - err = fixMissingExternalDeposits(tracedCtx, p.data, p.vmIndexerClient, p.integration, authorityAccount, mintAccount) + err = fixMissingExternalDeposits(tracedCtx, p.data, p.integration, authorityAccount, mintAccount) if err != nil { log.With(zap.Error(err)).Warn("failed to fix missing external deposits") } diff --git a/ocp/worker/geyser/external_deposit.go b/ocp/worker/geyser/external_deposit.go index 65ada04..683ed17 100644 --- a/ocp/worker/geyser/external_deposit.go +++ b/ocp/worker/geyser/external_deposit.go @@ -11,7 +11,6 @@ import ( "github.com/mr-tron/base58" "github.com/pkg/errors" - indexerpb "github.com/code-payments/code-vm-indexer/generated/indexer/v1" commonpb "github.com/code-payments/ocp-protobuf-api/generated/go/common/v1" "github.com/code-payments/ocp-server/cache" @@ -40,8 +39,8 @@ var ( syncedDepositCache = cache.NewCache(1_000_000) ) -func fixMissingExternalDeposits(ctx context.Context, data ocp_data.Provider, vmIndexerClient indexerpb.IndexerClient, integration Integration, userAuthority, mint *common.Account) error { - err := maybeInitiateExternalDepositIntoVm(ctx, data, vmIndexerClient, userAuthority, mint) +func fixMissingExternalDeposits(ctx context.Context, data ocp_data.Provider, integration Integration, userAuthority, mint *common.Account) error { + err := maybeInitiateExternalDepositIntoVm(ctx, data, userAuthority, mint) if err != nil { return errors.Wrap(err, "error depositing into the vm") } @@ -65,7 +64,7 @@ func fixMissingExternalDeposits(ctx context.Context, data ocp_data.Provider, vmI return markDepositsAsSynced(ctx, data, userAuthority, mint) } -func maybeInitiateExternalDepositIntoVm(ctx context.Context, data ocp_data.Provider, vmIndexerClient indexerpb.IndexerClient, userAuthority, mint *common.Account) error { +func maybeInitiateExternalDepositIntoVm(ctx context.Context, data ocp_data.Provider, userAuthority, mint *common.Account) error { vmConfig, err := common.GetVmConfigForMint(ctx, data, mint) if err != nil { return err @@ -86,10 +85,10 @@ func maybeInitiateExternalDepositIntoVm(ctx context.Context, data ocp_data.Provi if balance == 0 { return nil } - return initiateExternalDepositIntoVm(ctx, data, vmIndexerClient, userAuthority, mint, balance) + return initiateExternalDepositIntoVm(ctx, data, userAuthority, mint, balance) } -func initiateExternalDepositIntoVm(ctx context.Context, data ocp_data.Provider, vmIndexerClient indexerpb.IndexerClient, userAuthority, mint *common.Account, balance uint64) error { +func initiateExternalDepositIntoVm(ctx context.Context, data ocp_data.Provider, userAuthority, mint *common.Account, balance uint64) error { vmConfig, err := common.GetVmConfigForMint(ctx, data, mint) if err != nil { return errors.Wrap(err, "error getting vm config") @@ -100,12 +99,12 @@ func initiateExternalDepositIntoVm(ctx context.Context, data ocp_data.Provider, return errors.Wrap(err, "error getting timelock accounts") } - err = vm_util.EnsureVirtualTimelockAccountIsInitialized(ctx, data, vmIndexerClient, mint, userAuthority, true) + err = vm_util.EnsureVirtualTimelockAccountIsInitialized(ctx, data, timelockAccounts.Vault, true) if err != nil { return errors.Wrap(err, "error ensuring vta is initialized") } - memoryAccount, memoryIndex, err := vm_util.GetVirtualTimelockAccountLocationInMemory(ctx, vmIndexerClient, vmConfig.Vm, userAuthority) + memoryAccount, memoryIndex, err := vm_util.GetVirtualTimelockAccountLocationInMemory(ctx, data, timelockAccounts.Vault) if err != nil { return errors.Wrap(err, "error getting vta location in memory") } diff --git a/ocp/worker/geyser/handler.go b/ocp/worker/geyser/handler.go index 7f29bde..4bb43c0 100644 --- a/ocp/worker/geyser/handler.go +++ b/ocp/worker/geyser/handler.go @@ -7,7 +7,6 @@ import ( "github.com/mr-tron/base58" "github.com/pkg/errors" - indexerpb "github.com/code-payments/code-vm-indexer/generated/indexer/v1" "github.com/code-payments/ocp-server/ocp/config" geyserpb "github.com/code-payments/ocp-server/ocp/worker/geyser/api/gen" @@ -29,18 +28,16 @@ type ProgramAccountUpdateHandler interface { } type TokenProgramAccountHandler struct { - conf *conf - data ocp_data.Provider - vmIndexerClient indexerpb.IndexerClient - integration Integration + conf *conf + data ocp_data.Provider + integration Integration } -func NewTokenProgramAccountHandler(conf *conf, data ocp_data.Provider, vmIndexerClient indexerpb.IndexerClient, integration Integration) ProgramAccountUpdateHandler { +func NewTokenProgramAccountHandler(conf *conf, data ocp_data.Provider, integration Integration) ProgramAccountUpdateHandler { return &TokenProgramAccountHandler{ - conf: conf, - data: data, - vmIndexerClient: vmIndexerClient, - integration: integration, + conf: conf, + data: data, + integration: integration, } } @@ -106,7 +103,7 @@ func (h *TokenProgramAccountHandler) Handle(ctx context.Context, update *geyserp } if unmarshalled.Amount > 0 { - err = initiateExternalDepositIntoVm(ctx, h.data, h.vmIndexerClient, userAuthorityAccount, mintAccount, unmarshalled.Amount) + err = initiateExternalDepositIntoVm(ctx, h.data, userAuthorityAccount, mintAccount, unmarshalled.Amount) if err != nil { return errors.Wrap(err, "error depositing into the vm") } @@ -118,8 +115,8 @@ func (h *TokenProgramAccountHandler) Handle(ctx context.Context, update *geyserp } } -func initializeProgramAccountUpdateHandlers(conf *conf, data ocp_data.Provider, vmIndexerClient indexerpb.IndexerClient, integration Integration) map[string]ProgramAccountUpdateHandler { +func initializeProgramAccountUpdateHandlers(conf *conf, data ocp_data.Provider, integration Integration) map[string]ProgramAccountUpdateHandler { return map[string]ProgramAccountUpdateHandler{ - base58.Encode(token.ProgramKey): NewTokenProgramAccountHandler(conf, data, vmIndexerClient, integration), + base58.Encode(token.ProgramKey): NewTokenProgramAccountHandler(conf, data, integration), } } diff --git a/ocp/worker/geyser/runtime.go b/ocp/worker/geyser/runtime.go index ec7cff4..3eb21c0 100644 --- a/ocp/worker/geyser/runtime.go +++ b/ocp/worker/geyser/runtime.go @@ -7,7 +7,6 @@ import ( "go.uber.org/zap" - indexerpb "github.com/code-payments/code-vm-indexer/generated/indexer/v1" "github.com/code-payments/ocp-server/ocp/worker" geyserpb "github.com/code-payments/ocp-server/ocp/worker/geyser/api/gen" timelock_token "github.com/code-payments/ocp-server/solana/timelock/v1" @@ -22,10 +21,9 @@ type eventWorkerMetrics struct { // todo: we can consolidate the various subscription streams into one type runtime struct { - log *zap.Logger - data ocp_data.Provider - vmIndexerClient indexerpb.IndexerClient - conf *conf + log *zap.Logger + data ocp_data.Provider + conf *conf integration Integration @@ -46,16 +44,15 @@ type runtime struct { backupExternalDepositWorkerStatus bool } -func New(log *zap.Logger, data ocp_data.Provider, vmIndexerClient indexerpb.IndexerClient, integration Integration, configProvider ConfigProvider) worker.Runtime { +func New(log *zap.Logger, data ocp_data.Provider, integration Integration, configProvider ConfigProvider) worker.Runtime { conf := configProvider() return &runtime{ log: log, data: data, - vmIndexerClient: vmIndexerClient, conf: configProvider(), integration: integration, programUpdatesChan: make(chan *geyserpb.SubscribeUpdateAccount, conf.programUpdateQueueSize.Get(context.Background())), - programUpdateHandlers: initializeProgramAccountUpdateHandlers(conf, data, vmIndexerClient, integration), + programUpdateHandlers: initializeProgramAccountUpdateHandlers(conf, data, integration), programUpdateWorkerMetrics: make(map[int]*eventWorkerMetrics), } } diff --git a/ocp/worker/sequencer/fulfillment_handler.go b/ocp/worker/sequencer/fulfillment_handler.go index f92f47b..473b573 100644 --- a/ocp/worker/sequencer/fulfillment_handler.go +++ b/ocp/worker/sequencer/fulfillment_handler.go @@ -345,11 +345,6 @@ func (h *NoPrivacyTransferWithAuthorityFulfillmentHandler) MakeOnDemandTransacti return nil, nil, err } - sourceAuthority, err := common.NewAccountFromPublicKeyString(sourceAccountInfoRecord.AuthorityAccount) - if err != nil { - return nil, nil, err - } - mint, err := common.NewAccountFromPublicKeyString(sourceAccountInfoRecord.MintAccount) if err != nil { return nil, nil, err @@ -370,7 +365,7 @@ func (h *NoPrivacyTransferWithAuthorityFulfillmentHandler) MakeOnDemandTransacti return nil, nil, err } - sourceMemory, sourceIndex, err := vm_util.GetVirtualTimelockAccountLocationInMemory(ctx, h.vmIndexerClient, vmConfig.Vm, sourceAuthority) + sourceMemory, sourceIndex, err := vm_util.GetVirtualTimelockAccountLocationInMemory(ctx, h.data, sourceVault) if err != nil { return nil, nil, err } @@ -383,17 +378,7 @@ func (h *NoPrivacyTransferWithAuthorityFulfillmentHandler) MakeOnDemandTransacti var txn solana.Transaction var makeTxnErr error if isInternal { - destinationAccountInfoRecord, err := h.data.GetAccountInfoByTokenAddress(ctx, destinationToken.PublicKey().ToBase58()) - if err != nil { - return nil, nil, err - } - - destinationAuthority, err := common.NewAccountFromPublicKeyString(destinationAccountInfoRecord.AuthorityAccount) - if err != nil { - return nil, nil, err - } - - destinationMemory, destinationIndex, err := vm_util.GetVirtualTimelockAccountLocationInMemory(ctx, h.vmIndexerClient, vmConfig.Vm, destinationAuthority) + destinationMemory, destinationIndex, err := vm_util.GetVirtualTimelockAccountLocationInMemory(ctx, h.data, destinationToken) if err != nil { return nil, nil, err } @@ -564,11 +549,6 @@ func (h *NoPrivacyWithdrawFulfillmentHandler) MakeOnDemandTransaction(ctx contex return nil, nil, err } - sourceAuthority, err := common.NewAccountFromPublicKeyString(sourceAccountInfoRecord.AuthorityAccount) - if err != nil { - return nil, nil, err - } - mint, err := common.NewAccountFromPublicKeyString(sourceAccountInfoRecord.MintAccount) if err != nil { return nil, nil, err @@ -589,7 +569,7 @@ func (h *NoPrivacyWithdrawFulfillmentHandler) MakeOnDemandTransaction(ctx contex return nil, nil, err } - sourceMemory, sourceIndex, err := vm_util.GetVirtualTimelockAccountLocationInMemory(ctx, h.vmIndexerClient, vmConfig.Vm, sourceAuthority) + sourceMemory, sourceIndex, err := vm_util.GetVirtualTimelockAccountLocationInMemory(ctx, h.data, sourceVault) if err != nil { return nil, nil, err } @@ -602,17 +582,7 @@ func (h *NoPrivacyWithdrawFulfillmentHandler) MakeOnDemandTransaction(ctx contex var txn solana.Transaction var makeTxnErr error if isInternal { - destinationAccountInfoRecord, err := h.data.GetAccountInfoByTokenAddress(ctx, destinationToken.PublicKey().ToBase58()) - if err != nil { - return nil, nil, err - } - - destinationAuthority, err := common.NewAccountFromPublicKeyString(destinationAccountInfoRecord.AuthorityAccount) - if err != nil { - return nil, nil, err - } - - destinationMemory, destinationIndex, err := vm_util.GetVirtualTimelockAccountLocationInMemory(ctx, h.vmIndexerClient, vmConfig.Vm, destinationAuthority) + destinationMemory, destinationIndex, err := vm_util.GetVirtualTimelockAccountLocationInMemory(ctx, h.data, destinationToken) if err != nil { return nil, nil, err } diff --git a/ocp/worker/swap/util.go b/ocp/worker/swap/util.go index 387aa9e..157ef6a 100644 --- a/ocp/worker/swap/util.go +++ b/ocp/worker/swap/util.go @@ -19,6 +19,7 @@ import ( "github.com/code-payments/ocp-server/ocp/data/swap" "github.com/code-payments/ocp-server/ocp/data/transaction" transaction_util "github.com/code-payments/ocp-server/ocp/transaction" + vm_util "github.com/code-payments/ocp-server/ocp/vm" "github.com/code-payments/ocp-server/solana" ) @@ -391,6 +392,30 @@ func (p *runtime) validateExternalWalletFunding(ctx context.Context, record *swa return true, nil } +func (p *runtime) ensureSwapDestinationIsInitialized(ctx context.Context, record *swap.Record) error { + toMint, err := common.NewAccountFromPublicKeyString(record.ToMint) + if err != nil { + return err + } + + owner, err := common.NewAccountFromPublicKeyString(record.Owner) + if err != nil { + return err + } + + destinationVmConfig, err := common.GetVmConfigForMint(ctx, p.data, toMint) + if err != nil { + return err + } + + destinationTimelockVault, err := owner.ToTimelockVault(destinationVmConfig) + if err != nil { + return err + } + + return vm_util.EnsureVirtualTimelockAccountIsInitialized(ctx, p.data, destinationTimelockVault, true) +} + func getSwapDepositIntentID(signature string, destination *common.Account) string { combined := fmt.Sprintf("%s-%s", signature, destination.PublicKey().ToBase58()) hashed := sha256.Sum256([]byte(combined)) diff --git a/ocp/worker/swap/worker.go b/ocp/worker/swap/worker.go index 297933c..73d54c5 100644 --- a/ocp/worker/swap/worker.go +++ b/ocp/worker/swap/worker.go @@ -184,6 +184,12 @@ func (p *runtime) handleStateFunded(ctx context.Context, record *swap.Record) er // todo: Return funds if the amount was wrong return p.markSwapCancelled(ctx, record) } + + err = p.ensureSwapDestinationIsInitialized(ctx, record) + if err != nil { + return errors.Wrap(err, "error ensuring swap destination is initialized") + } + return p.markSwapSubmitting(ctx, record) }