Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions ocp/data/internal.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ type DatabaseData interface {
FreeVmMemoryByIndex(ctx context.Context, memoryAccount string, index uint16) error
FreeVmMemoryByAddress(ctx context.Context, address string) error
ReserveVmMemory(ctx context.Context, vm string, accountType vm.VirtualAccountType, address string) (string, uint16, error)
GetVmMemoryLocationByAddress(ctx context.Context, address string) (string, uint16, error)

// VM Storage
// --------------------------------------------------------------------------------
Expand Down Expand Up @@ -876,6 +877,9 @@ func (dp *DatabaseProvider) FreeVmMemoryByAddress(ctx context.Context, address s
func (dp *DatabaseProvider) ReserveVmMemory(ctx context.Context, vm string, accountType vm.VirtualAccountType, address string) (string, uint16, error) {
return dp.vmRam.ReserveMemory(ctx, vm, accountType, address)
}
func (dp *DatabaseProvider) GetVmMemoryLocationByAddress(ctx context.Context, address string) (string, uint16, error) {
return dp.vmRam.GetMemoryLocationByAddress(ctx, address)
}

// VM Storage
// --------------------------------------------------------------------------------
Expand Down
29 changes: 29 additions & 0 deletions ocp/data/vm/ram/memory/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,20 @@ func (s *store) ReserveMemory(_ context.Context, vm string, accountType vm.Virtu
return "", 0, ram.ErrNoFreeMemory
}

// GetMemoryLocationByAddress implements vm.ram.Store.GetMemoryLocationByAddress
func (s *store) GetMemoryLocationByAddress(_ context.Context, address string) (string, uint16, error) {
s.mu.Lock()
defer s.mu.Unlock()

reservationKey, ok := s.storedVirtualAccounts[address]
if !ok {
return "", 0, ram.ErrNotReserved
}

memoryAccount, index := parseAccountIndexKey(reservationKey)
return memoryAccount, index, nil
}

func (s *store) find(data *ram.Record) *ram.Record {
for _, item := range s.records {
if item.Id == data.Id {
Expand Down Expand Up @@ -154,3 +168,18 @@ func (s *store) reset() {
func getAccountIndexKey(memoryAccount string, index uint16) string {
return fmt.Sprintf("%s:%d", memoryAccount, index)
}

func parseAccountIndexKey(key string) (string, uint16) {
var memoryAccount string
var index uint16
fmt.Sscanf(key, "%s", &memoryAccount)
// Find the last colon and parse
for i := len(key) - 1; i >= 0; i-- {
if key[i] == ':' {
memoryAccount = key[:i]
fmt.Sscanf(key[i+1:], "%d", &index)
break
}
}
return memoryAccount, index
}
15 changes: 15 additions & 0 deletions ocp/data/vm/ram/postgres/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,3 +211,18 @@ func dbReserveMemory(ctx context.Context, db *sqlx.DB, vm string, accountType vm
})
return memoryAccount, index, err
}

func dbGetMemoryLocationByAddress(ctx context.Context, db *sqlx.DB, address string) (string, uint16, error) {
var model allocatedMemoryModel

query := `SELECT id, vm, memory_account, index, is_allocated, stored_account_type, address, last_updated_at
FROM ` + allocatedMemoryTableName + `
WHERE address = $1 AND is_allocated`

err := db.QueryRowxContext(ctx, query, address).StructScan(&model)
if err != nil {
return "", 0, pgutil.CheckNoRows(err, ram.ErrNotReserved)
}

return model.MemoryAccount, model.Index, nil
}
5 changes: 5 additions & 0 deletions ocp/data/vm/ram/postgres/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,8 @@ func (s *store) FreeMemoryByAddress(ctx context.Context, address string) error {
func (s *store) ReserveMemory(ctx context.Context, vm string, accountType vm.VirtualAccountType, address string) (string, uint16, error) {
return dbReserveMemory(ctx, s.db, vm, accountType, address)
}

// GetMemoryLocationByAddress implements vm.ram.Store.GetMemoryLocationByAddress
func (s *store) GetMemoryLocationByAddress(ctx context.Context, address string) (string, uint16, error) {
return dbGetMemoryLocationByAddress(ctx, s.db, address)
}
3 changes: 3 additions & 0 deletions ocp/data/vm/ram/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,7 @@ type Store interface {

// ReserveMemory reserves a piece of memory in a VM for the virtual account address
ReserveMemory(ctx context.Context, vm string, accountType vm.VirtualAccountType, address string) (string, uint16, error)

// GetMemoryLocationByAddress returns the memory account and index for a reserved virtual account address
GetMemoryLocationByAddress(ctx context.Context, address string) (string, uint16, error)
}
13 changes: 13 additions & 0 deletions ocp/data/vm/ram/tests/tests.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,5 +107,18 @@ func testHappyPath(t *testing.T, s ram.Store) {

_, _, err = s.ReserveMemory(ctx, "vm1", vm.VirtualAccountTypeTimelock, "newvirtualaccount3")
assert.Equal(t, ram.ErrNoFreeMemory, err)

memoryAccount, index, err := s.GetMemoryLocationByAddress(ctx, "newvirtualaccount1")
require.NoError(t, err)
assert.Equal(t, "memoryaccount1", memoryAccount)
assert.Equal(t, freedIndex1, index)

memoryAccount, index, err = s.GetMemoryLocationByAddress(ctx, "newvirtualaccount2")
require.NoError(t, err)
assert.Equal(t, "memoryaccount1", memoryAccount)
assert.Equal(t, freedIndex2, index)

_, _, err = s.GetMemoryLocationByAddress(ctx, "nonexistent")
assert.Equal(t, ram.ErrNotReserved, err)
})
}
8 changes: 2 additions & 6 deletions ocp/rpc/transaction/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (

"go.uber.org/zap"

indexerpb "github.com/code-payments/code-vm-indexer/generated/indexer/v1"
transactionpb "github.com/code-payments/ocp-protobuf-api/generated/go/transaction/v1"

"github.com/code-payments/ocp-server/ocp/aml"
Expand All @@ -24,8 +23,7 @@ type transactionServer struct {

log *zap.Logger

data ocp_data.Provider
vmIndexerClient indexerpb.IndexerClient
data ocp_data.Provider

auth *auth_util.RPCSignatureVerifier

Expand All @@ -51,7 +49,6 @@ type transactionServer struct {
func NewTransactionServer(
log *zap.Logger,
data ocp_data.Provider,
vmIndexerClient indexerpb.IndexerClient,
submitIntentIntegration SubmitIntentIntegration,
airdropIntegration AirdropIntegration,
antispamGuard *antispam.Guard,
Expand Down Expand Up @@ -92,8 +89,7 @@ func NewTransactionServer(

log: log,

data: data,
vmIndexerClient: vmIndexerClient,
data: data,

auth: auth_util.NewRPCSignatureVerifier(log, data),

Expand Down
7 changes: 2 additions & 5 deletions ocp/rpc/transaction/swap.go
Original file line number Diff line number Diff line change
Expand Up @@ -255,9 +255,9 @@ func (s *transactionServer) StatefulSwap(streamer transactionpb.Transaction_Stat
// Section: On-demand account creation
//

err = vm.EnsureVirtualTimelockAccountIsInitialized(ctx, s.data, s.vmIndexerClient, toMint, owner, true)
err = vm.EnsureVirtualTimelockAccountIsInitialized(ctx, s.data, ownerDestinationTimelockVault, false)
if err != nil {
log.With(zap.Error(err)).Warn("timed out waiting for destination timelock account initialization")
log.With(zap.Error(err)).Warn("error ensuring destination virtual timelock account is initialized")
return handleStatefulSwapError(streamer, err)
}

Expand Down Expand Up @@ -288,7 +288,6 @@ func (s *transactionServer) StatefulSwap(streamer transactionpb.Transaction_Stat
if common.IsCoreMint(fromMint) {
swapHandler = NewCurrencyCreatorBuySwapHandler(
s.data,
s.vmIndexerClient,
owner,
swapAuthority,
toMint,
Expand All @@ -298,7 +297,6 @@ func (s *transactionServer) StatefulSwap(streamer transactionpb.Transaction_Stat
} else if common.IsCoreMint(toMint) {
swapHandler = NewCurrencyCreatorSellSwapHandler(
s.data,
s.vmIndexerClient,
owner,
swapAuthority,
fromMint,
Expand All @@ -308,7 +306,6 @@ func (s *transactionServer) StatefulSwap(streamer transactionpb.Transaction_Stat
} else {
swapHandler = NewCurrencyCreatorBuySellSwapHandler(
s.data,
s.vmIndexerClient,
owner,
swapAuthority,
fromMint,
Expand Down
44 changes: 24 additions & 20 deletions ocp/rpc/transaction/swap_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ package transaction
import (
"context"

indexerpb "github.com/code-payments/code-vm-indexer/generated/indexer/v1"

"github.com/code-payments/ocp-server/ocp/common"
ocp_data "github.com/code-payments/ocp-server/ocp/data"
vm_util "github.com/code-payments/ocp-server/ocp/vm"
Expand Down Expand Up @@ -36,8 +34,7 @@ type SwapHandler interface {
}

type CurrencyCreatorBuySwapHandler struct {
data ocp_data.Provider
vmIndexerClient indexerpb.IndexerClient
data ocp_data.Provider

buyer *common.Account
temporaryHolder *common.Account
Expand All @@ -54,16 +51,14 @@ type CurrencyCreatorBuySwapHandler struct {

func NewCurrencyCreatorBuySwapHandler(
data ocp_data.Provider,
vmIndexerClient indexerpb.IndexerClient,
buyer *common.Account,
temporaryHolder *common.Account,
mint *common.Account,
amount uint64,
nonce *common.Account,
) SwapHandler {
return &CurrencyCreatorBuySwapHandler{
data: data,
vmIndexerClient: vmIndexerClient,
data: data,

buyer: buyer,
temporaryHolder: temporaryHolder,
Expand Down Expand Up @@ -103,7 +98,12 @@ func (h *CurrencyCreatorBuySwapHandler) MakeInstructions(ctx context.Context) ([
return nil, err
}

h.memoryAccount, h.memoryIndex, err = vm_util.GetVirtualTimelockAccountLocationInMemory(ctx, h.vmIndexerClient, destinationVmConfig.Vm, h.buyer)
destinationTimelockAccounts, err := h.buyer.GetTimelockAccounts(destinationVmConfig)
if err != nil {
return nil, err
}

h.memoryAccount, h.memoryIndex, err = vm_util.GetVirtualTimelockAccountLocationInMemory(ctx, h.data, destinationTimelockAccounts.Vault)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -203,8 +203,7 @@ func (h *CurrencyCreatorBuySwapHandler) MakeInstructions(ctx context.Context) ([
}

type CurrencyCreatorSellSwapHandler struct {
data ocp_data.Provider
vmIndexerClient indexerpb.IndexerClient
data ocp_data.Provider

seller *common.Account
temporaryHolder *common.Account
Expand All @@ -221,16 +220,14 @@ type CurrencyCreatorSellSwapHandler struct {

func NewCurrencyCreatorSellSwapHandler(
data ocp_data.Provider,
vmIndexerClient indexerpb.IndexerClient,
seller *common.Account,
temporaryHolder *common.Account,
mint *common.Account,
amount uint64,
nonce *common.Account,
) SwapHandler {
return &CurrencyCreatorSellSwapHandler{
data: data,
vmIndexerClient: vmIndexerClient,
data: data,

seller: seller,
temporaryHolder: temporaryHolder,
Expand Down Expand Up @@ -280,7 +277,12 @@ func (h *CurrencyCreatorSellSwapHandler) MakeInstructions(ctx context.Context) (
return nil, err
}

h.memoryAccount, h.memoryIndex, err = vm_util.GetVirtualTimelockAccountLocationInMemory(ctx, h.vmIndexerClient, destinationVmConfig.Vm, h.seller)
destinationTimelockAccounts, err := h.seller.GetTimelockAccounts(destinationVmConfig)
if err != nil {
return nil, err
}

h.memoryAccount, h.memoryIndex, err = vm_util.GetVirtualTimelockAccountLocationInMemory(ctx, h.data, destinationTimelockAccounts.Vault)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -370,8 +372,7 @@ func (h *CurrencyCreatorSellSwapHandler) MakeInstructions(ctx context.Context) (
}

type CurrencyCreatorBuySellSwapHandler struct {
data ocp_data.Provider
vmIndexerClient indexerpb.IndexerClient
data ocp_data.Provider

swapper *common.Account
temporaryHolder *common.Account
Expand All @@ -389,7 +390,6 @@ type CurrencyCreatorBuySellSwapHandler struct {

func NewCurrencyCreatorBuySellSwapHandler(
data ocp_data.Provider,
vmIndexerClient indexerpb.IndexerClient,
swapper *common.Account,
temporaryHolder *common.Account,
fromMint *common.Account,
Expand All @@ -398,8 +398,7 @@ func NewCurrencyCreatorBuySellSwapHandler(
nonce *common.Account,
) SwapHandler {
return &CurrencyCreatorBuySellSwapHandler{
data: data,
vmIndexerClient: vmIndexerClient,
data: data,

swapper: swapper,
temporaryHolder: temporaryHolder,
Expand Down Expand Up @@ -450,7 +449,12 @@ func (h *CurrencyCreatorBuySellSwapHandler) MakeInstructions(ctx context.Context
return nil, err
}

h.memoryAccount, h.memoryIndex, err = vm_util.GetVirtualTimelockAccountLocationInMemory(ctx, h.vmIndexerClient, destinationVmConfig.Vm, h.swapper)
destinationTimelockAccounts, err := h.swapper.GetTimelockAccounts(destinationVmConfig)
if err != nil {
return nil, err
}

h.memoryAccount, h.memoryIndex, err = vm_util.GetVirtualTimelockAccountLocationInMemory(ctx, h.data, destinationTimelockAccounts.Vault)
if err != nil {
return nil, err
}
Expand Down
36 changes: 19 additions & 17 deletions ocp/vm/virtual_account.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,8 @@ import (
vm_program "github.com/code-payments/ocp-server/solana/vm"
)

func EnsureVirtualTimelockAccountIsInitialized(ctx context.Context, data ocp_data.Provider, vmIndexerClient indexerpb.IndexerClient, mint, owner *common.Account, waitForInitialization bool) error {
vmConfig, err := common.GetVmConfigForMint(ctx, data, mint)
if err != nil {
return err
}

timelockAccounts, err := owner.GetTimelockAccounts(vmConfig)
if err != nil {
return err
}

timelockRecord, err := data.GetTimelockByVault(ctx, timelockAccounts.Vault.PublicKey().ToBase58())
func EnsureVirtualTimelockAccountIsInitialized(ctx context.Context, data ocp_data.Provider, vault *common.Account, waitForInitialization bool) error {
timelockRecord, err := data.GetTimelockByVault(ctx, vault.PublicKey().ToBase58())
if err != nil {
return err
}
Expand All @@ -44,15 +34,21 @@ func EnsureVirtualTimelockAccountIsInitialized(ctx context.Context, data ocp_dat
if err != nil {
return err
}
} else {
return nil
}

if !waitForInitialization {
return nil
}

for range 60 {
_, _, err := GetVirtualTimelockAccountLocationInMemory(ctx, vmIndexerClient, vmConfig.Vm, owner)
if err == nil {
timelockRecord, err := data.GetTimelockByVault(ctx, vault.PublicKey().ToBase58())
if err != nil {
continue
}

if timelockRecord.ExistsOnBlockchain() {
return nil
}

Expand Down Expand Up @@ -101,12 +97,18 @@ func GetVirtualTimelockAccountStateInMemory(ctx context.Context, vmIndexerClient
return &state, memory, uint16(protoMemory.Index), nil
}

func GetVirtualTimelockAccountLocationInMemory(ctx context.Context, vmIndexerClient indexerpb.IndexerClient, vm, owner *common.Account) (*common.Account, uint16, error) {
_, memory, memoryIndex, err := GetVirtualTimelockAccountStateInMemory(ctx, vmIndexerClient, vm, owner)
func GetVirtualTimelockAccountLocationInMemory(ctx context.Context, data ocp_data.Provider, vault *common.Account) (*common.Account, uint16, error) {
memoryAddress, index, err := data.GetVmMemoryLocationByAddress(ctx, vault.PublicKey().ToBase58())
if err != nil {
return nil, 0, err
}
return memory, memoryIndex, nil

memoryAccount, err := common.NewAccountFromPublicKeyString(memoryAddress)
if err != nil {
return nil, 0, err
}

return memoryAccount, index, err
}

func GetVirtualDurableNonceAccountStateInMemory(ctx context.Context, vmIndexerClient indexerpb.IndexerClient, vm, nonce *common.Account) (*vm_program.VirtualDurableNonce, *common.Account, uint16, error) {
Expand Down
2 changes: 1 addition & 1 deletion ocp/worker/geyser/backup.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ func (p *runtime) backupExternalDepositWorker(runtimeCtx context.Context, interv
zap.String("mint", mintAccount.PublicKey().ToBase58()),
)

err = fixMissingExternalDeposits(tracedCtx, p.data, p.vmIndexerClient, p.integration, authorityAccount, mintAccount)
err = fixMissingExternalDeposits(tracedCtx, p.data, p.integration, authorityAccount, mintAccount)
if err != nil {
log.With(zap.Error(err)).Warn("failed to fix missing external deposits")
}
Expand Down
Loading
Loading