Skip to content

Commit ff1d418

Browse files
committed
Track swap PDAs in Timelock table
1 parent b83d00a commit ff1d418

File tree

9 files changed

+118
-11
lines changed

9 files changed

+118
-11
lines changed

pkg/code/common/account.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,9 @@ func (a *TimelockAccounts) ToDBRecord() *timelock.Record {
471471
DepositPdaAddress: a.VmDepositAccounts.Pda.PublicKey().ToBase58(),
472472
DepositPdaBump: a.VmDepositAccounts.PdaBump,
473473

474+
SwapPdaAddress: a.VmSwapAccounts.Pda.PublicKey().ToBase58(),
475+
SwapPdaBump: a.VmSwapAccounts.PdaBump,
476+
474477
UnlockAt: nil,
475478

476479
Block: 0,

pkg/code/data/internal.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,7 @@ type DatabaseData interface {
229229
GetTimelockByAddress(ctx context.Context, address string) (*timelock.Record, error)
230230
GetTimelockByVault(ctx context.Context, vault string) (*timelock.Record, error)
231231
GetTimelockByDepositPda(ctx context.Context, depositPda string) (*timelock.Record, error)
232+
GetTimelockBySwapPda(ctx context.Context, swapPda string) (*timelock.Record, error)
232233
GetTimelockByVaultBatch(ctx context.Context, vaults ...string) (map[string]*timelock.Record, error)
233234
GetAllTimelocksByState(ctx context.Context, state timelock_token.TimelockState, opts ...query.Option) ([]*timelock.Record, error)
234235
GetTimelockCountByState(ctx context.Context, state timelock_token.TimelockState) (uint64, error)
@@ -841,6 +842,9 @@ func (dp *DatabaseProvider) GetTimelockByVault(ctx context.Context, vault string
841842
func (dp *DatabaseProvider) GetTimelockByDepositPda(ctx context.Context, depositPda string) (*timelock.Record, error) {
842843
return dp.timelocks.GetByDepositPda(ctx, depositPda)
843844
}
845+
func (dp *DatabaseProvider) GetTimelockBySwapPda(ctx context.Context, swapPda string) (*timelock.Record, error) {
846+
return dp.timelocks.GetBySwapPda(ctx, swapPda)
847+
}
844848
func (dp *DatabaseProvider) GetAllTimelocksByState(ctx context.Context, state timelock_token.TimelockState, opts ...query.Option) ([]*timelock.Record, error) {
845849
req, err := query.DefaultPaginationHandler(opts...)
846850
if err != nil {

pkg/code/data/timelock/memory/store.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,16 @@ func (s *store) GetByDepositPda(ctx context.Context, depositPda string) (*timelo
118118
return nil, timelock.ErrTimelockNotFound
119119
}
120120

121+
func (s *store) GetBySwapPda(ctx context.Context, swapPda string) (*timelock.Record, error) {
122+
s.mu.Lock()
123+
defer s.mu.Unlock()
124+
125+
if item := s.findBySwapPda(swapPda); item != nil {
126+
return item.Clone(), nil
127+
}
128+
return nil, timelock.ErrTimelockNotFound
129+
}
130+
121131
// GetAllByState implements timelock.Store.GetAllByState
122132
func (s *store) GetAllByState(ctx context.Context, state timelock_token.TimelockState, cursor query.Cursor, limit uint64, direction query.Ordering) ([]*timelock.Record, error) {
123133
s.mu.Lock()
@@ -184,6 +194,15 @@ func (s *store) findByDepositPda(depositPda string) *timelock.Record {
184194
return nil
185195
}
186196

197+
func (s *store) findBySwapPda(swapPda string) *timelock.Record {
198+
for _, item := range s.records {
199+
if swapPda == item.SwapPdaAddress {
200+
return item
201+
}
202+
}
203+
return nil
204+
}
205+
187206
func (s *store) findByState(state timelock_token.TimelockState) []*timelock.Record {
188207
res := make([]*timelock.Record, 0)
189208
for _, item := range s.records {

pkg/code/data/timelock/postgres/model.go

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ type model struct {
3333
DepositPdaAddress string `db:"deposit_pda_address"`
3434
DepositPdaBump uint `db:"deposit_pda_bump"`
3535

36+
SwapPdaAddress string `db:"swap_pda_address"`
37+
SwapPdaBump uint `db:"swap_pda_bump"`
38+
3639
UnlockAt sql.NullInt64 `db:"unlock_at"`
3740

3841
Block uint64 `db:"block"`
@@ -63,6 +66,9 @@ func toModel(obj *timelock.Record) (*model, error) {
6366
DepositPdaAddress: obj.DepositPdaAddress,
6467
DepositPdaBump: uint(obj.DepositPdaBump),
6568

69+
SwapPdaAddress: obj.SwapPdaAddress,
70+
SwapPdaBump: uint(obj.SwapPdaBump),
71+
6672
UnlockAt: unlockAt,
6773

6874
Block: obj.Block,
@@ -92,6 +98,9 @@ func fromModel(obj *model) *timelock.Record {
9298
DepositPdaAddress: obj.DepositPdaAddress,
9399
DepositPdaBump: uint8(obj.DepositPdaBump),
94100

101+
SwapPdaAddress: obj.SwapPdaAddress,
102+
SwapPdaBump: uint8(obj.SwapPdaBump),
103+
95104
UnlockAt: unlockAt,
96105

97106
Block: obj.Block,
@@ -103,16 +112,16 @@ func fromModel(obj *model) *timelock.Record {
103112
func (m *model) dbSave(ctx context.Context, db *sqlx.DB) error {
104113
return pgutil.ExecuteInTx(ctx, db, sql.LevelDefault, func(tx *sqlx.Tx) error {
105114
query := `INSERT INTO ` + tableName + `
106-
(address, bump, vault_address, vault_bump, vault_owner, vault_state, deposit_pda_address, deposit_pda_bump, unlock_at, block, last_updated_at)
107-
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
115+
(address, bump, vault_address, vault_bump, vault_owner, vault_state, deposit_pda_address, deposit_pda_bump, swap_pda_address, swap_pda_bump, unlock_at, block, last_updated_at)
116+
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
108117
109118
ON CONFLICT (address)
110119
DO UPDATE
111-
SET vault_state = $6, unlock_at = $9, block = $10, last_updated_at = $11
112-
WHERE ` + tableName + `.address = $1 AND ` + tableName + `.vault_address = $3 AND ` + tableName + `.block < $10
120+
SET vault_state = $6, unlock_at = $11, block = $12, last_updated_at = $13
121+
WHERE ` + tableName + `.address = $1 AND ` + tableName + `.vault_address = $3 AND ` + tableName + `.block < $12
113122
114123
RETURNING
115-
id, address, bump, vault_address, vault_bump, vault_owner, vault_state, deposit_pda_address, deposit_pda_bump, unlock_at, block, last_updated_at`
124+
id, address, bump, vault_address, vault_bump, vault_owner, vault_state, deposit_pda_address, deposit_pda_bump, swap_pda_address, swap_pda_bump, unlock_at, block, last_updated_at`
116125

117126
m.LastUpdatedAt = time.Now()
118127

@@ -131,6 +140,9 @@ func (m *model) dbSave(ctx context.Context, db *sqlx.DB) error {
131140
m.DepositPdaAddress,
132141
m.DepositPdaBump,
133142

143+
m.SwapPdaAddress,
144+
m.SwapPdaBump,
145+
134146
m.UnlockAt,
135147

136148
m.Block,
@@ -146,7 +158,7 @@ func dbGetByAddress(ctx context.Context, db *sqlx.DB, address string) (*model, e
146158
res := &model{}
147159

148160
query := `SELECT
149-
id, address, bump, vault_address, vault_bump, vault_owner, vault_state, deposit_pda_address, deposit_pda_bump, unlock_at, block, last_updated_at
161+
id, address, bump, vault_address, vault_bump, vault_owner, vault_state, deposit_pda_address, deposit_pda_bump, swap_pda_address, swap_pda_bump, unlock_at, block, last_updated_at
150162
FROM ` + tableName + `
151163
WHERE address = $1
152164
LIMIT 1`
@@ -162,7 +174,7 @@ func dbGetByVault(ctx context.Context, db *sqlx.DB, vault string) (*model, error
162174
res := &model{}
163175

164176
query := `SELECT
165-
id, address, bump, vault_address, vault_bump, vault_owner, vault_state, deposit_pda_address, deposit_pda_bump, unlock_at, block, last_updated_at
177+
id, address, bump, vault_address, vault_bump, vault_owner, vault_state, deposit_pda_address, deposit_pda_bump, swap_pda_address, swap_pda_bump, unlock_at, block, last_updated_at
166178
FROM ` + tableName + `
167179
WHERE vault_address = $1
168180
LIMIT 1`
@@ -183,7 +195,7 @@ func dbGetByVaultBatch(ctx context.Context, db *sqlx.DB, vaults ...string) ([]*m
183195
}
184196

185197
query := fmt.Sprintf(
186-
`SELECT id, address, bump, vault_address, vault_bump, vault_owner, vault_state, deposit_pda_address, deposit_pda_bump, unlock_at, block, last_updated_at
198+
`SELECT id, address, bump, vault_address, vault_bump, vault_owner, vault_state, deposit_pda_address, deposit_pda_bump, swap_pda_address, swap_pda_bump, unlock_at, block, last_updated_at
187199
FROM `+tableName+`
188200
WHERE vault_address IN (%s)`,
189201
strings.Join(individualFilters, ", "),
@@ -203,7 +215,7 @@ func dbGetByDepositPda(ctx context.Context, db *sqlx.DB, depositPda string) (*mo
203215
res := &model{}
204216

205217
query := `SELECT
206-
id, address, bump, vault_address, vault_bump, vault_owner, vault_state, deposit_pda_address, deposit_pda_bump, unlock_at, block, last_updated_at
218+
id, address, bump, vault_address, vault_bump, vault_owner, vault_state, deposit_pda_address, deposit_pda_bump, swap_pda_address, swap_pda_bump, unlock_at, block, last_updated_at
207219
FROM ` + tableName + `
208220
WHERE deposit_pda_address = $1
209221
LIMIT 1`
@@ -215,11 +227,27 @@ func dbGetByDepositPda(ctx context.Context, db *sqlx.DB, depositPda string) (*mo
215227
return res, nil
216228
}
217229

230+
func dbGetBySwapPda(ctx context.Context, db *sqlx.DB, swapPda string) (*model, error) {
231+
res := &model{}
232+
233+
query := `SELECT
234+
id, address, bump, vault_address, vault_bump, vault_owner, vault_state, deposit_pda_address, deposit_pda_bump, swap_pda_address, swap_pda_bump, unlock_at, block, last_updated_at
235+
FROM ` + tableName + `
236+
WHERE swap_pda_address = $1
237+
LIMIT 1`
238+
239+
err := db.GetContext(ctx, res, query, swapPda)
240+
if err != nil {
241+
return nil, pgutil.CheckNoRows(err, timelock.ErrTimelockNotFound)
242+
}
243+
return res, nil
244+
}
245+
218246
func dbGetAllByState(ctx context.Context, db *sqlx.DB, state timelock_token.TimelockState, cursor q.Cursor, limit uint64, direction q.Ordering) ([]*model, error) {
219247
res := []*model{}
220248

221249
query := `SELECT
222-
id, address, bump, vault_address, vault_bump, vault_owner, vault_state, deposit_pda_address, deposit_pda_bump, unlock_at, block, last_updated_at
250+
id, address, bump, vault_address, vault_bump, vault_owner, vault_state, deposit_pda_address, deposit_pda_bump, swap_pda_address, swap_pda_bump, unlock_at, block, last_updated_at
223251
FROM ` + tableName + `
224252
WHERE (vault_state = $1)
225253
`

pkg/code/data/timelock/postgres/store.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,16 @@ func (s *store) GetByDepositPda(ctx context.Context, depositPda string) (*timelo
8383
return fromModel(model), nil
8484
}
8585

86+
// GetBySwapPda implements timelock.Store.GetBySwapPda
87+
func (s *store) GetBySwapPda(ctx context.Context, swapPda string) (*timelock.Record, error) {
88+
model, err := dbGetBySwapPda(ctx, s.db, swapPda)
89+
if err != nil {
90+
return nil, err
91+
}
92+
93+
return fromModel(model), nil
94+
}
95+
8696
// GetOldestByState implements timelock.Store.GetAllByState
8797
func (s *store) GetAllByState(ctx context.Context, state timelock_token.TimelockState, cursor query.Cursor, limit uint64, direction query.Ordering) ([]*timelock.Record, error) {
8898
res, err := dbGetAllByState(ctx, s.db, state, cursor, limit, direction)

pkg/code/data/timelock/postgres/store_test.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ const (
3333
deposit_pda_address TEXT NOT NULL,
3434
deposit_pda_bump INTEGER NOT NULL,
3535
36+
swap_pda_address TEXT NOT NULL,
37+
swap_pda_bump INTEGER NOT NULL,
38+
3639
unlock_at INTEGER,
3740
3841
block INTEGER NOT NULL,
@@ -44,7 +47,8 @@ const (
4447
CONSTRAINT codewallet__core_timelock__uniq__deposit_pda_address UNIQUE (deposit_pda_address),
4548
CONSTRAINT codewallet__core_timelock__uniq__address__and__vault_owner UNIQUE (address, vault_owner),
4649
CONSTRAINT codewallet__core_timelock__uniq__address__and__vault_address UNIQUE (address, vault_address),
47-
CONSTRAINT codewallet__core_timelock__uniq__address__and__deposit_pda_address UNIQUE (address, deposit_pda_address)
50+
CONSTRAINT codewallet__core_timelock__uniq__address__and__deposit_pda_address UNIQUE (address, deposit_pda_address),
51+
CONSTRAINT codewallet__core_timelock__uniq__address__and__deposit_swap_address UNIQUE (address, swap_pda_address)
4852
);
4953
`
5054

pkg/code/data/timelock/store.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ type Store interface {
2424
// GetByDepositPda gets a timelock account's state by the deposit PDA address
2525
GetByDepositPda(ctx context.Context, depositPda string) (*Record, error)
2626

27+
// GetBySwapPda gets a timelock account's state by the swap PDA address
28+
GetBySwapPda(ctx context.Context, depositPda string) (*Record, error)
29+
2730
// GetAllByState gets all timelock accounts in the provided state
2831
GetAllByState(ctx context.Context, state timelock_token.TimelockState, cursor query.Cursor, limit uint64, direction query.Ordering) ([]*Record, error)
2932

pkg/code/data/timelock/tests/tests.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@ func testHappyPath(t *testing.T, s timelock.Store) {
4444
DepositPdaAddress: "deposit",
4545
DepositPdaBump: 253,
4646

47+
SwapPdaAddress: "swap",
48+
SwapPdaBump: 252,
49+
4750
Block: 123456,
4851
}
4952
cloned := expected.Clone()
@@ -59,6 +62,9 @@ func testHappyPath(t *testing.T, s timelock.Store) {
5962
_, err = s.GetByDepositPda(ctx, expected.DepositPdaAddress)
6063
assert.Equal(t, timelock.ErrTimelockNotFound, err)
6164

65+
_, err = s.GetBySwapPda(ctx, expected.SwapPdaAddress)
66+
assert.Equal(t, timelock.ErrTimelockNotFound, err)
67+
6268
// Save the record
6369

6470
require.NoError(t, s.Save(ctx, expected))
@@ -79,6 +85,10 @@ func testHappyPath(t *testing.T, s timelock.Store) {
7985
require.NoError(t, err)
8086
assertEquivalentRecords(t, cloned, actual)
8187

88+
actual, err = s.GetBySwapPda(ctx, expected.SwapPdaAddress)
89+
require.NoError(t, err)
90+
assertEquivalentRecords(t, cloned, actual)
91+
8292
initialBlock := expected.Block
8393

8494
// Update the record's state
@@ -119,6 +129,10 @@ func testHappyPath(t *testing.T, s timelock.Store) {
119129
actual, err = s.GetByDepositPda(ctx, expected.DepositPdaAddress)
120130
require.NoError(t, err)
121131
assertEquivalentRecords(t, cloned, actual)
132+
133+
actual, err = s.GetBySwapPda(ctx, expected.SwapPdaAddress)
134+
require.NoError(t, err)
135+
assertEquivalentRecords(t, cloned, actual)
122136
})
123137
}
124138

@@ -140,6 +154,9 @@ func testBatchedMethods(t *testing.T, s timelock.Store) {
140154
DepositPdaAddress: fmt.Sprintf("deposit%d", i),
141155
DepositPdaBump: 253,
142156

157+
SwapPdaAddress: fmt.Sprintf("swap%d", i),
158+
SwapPdaBump: 252,
159+
143160
Block: uint64(i),
144161
}
145162

@@ -189,6 +206,9 @@ func testGetAllByState(t *testing.T, s timelock.Store) {
189206
DepositPdaAddress: fmt.Sprintf("deposit%d", i),
190207
DepositPdaBump: 253,
191208

209+
SwapPdaAddress: fmt.Sprintf("swap%d", i),
210+
SwapPdaBump: 252,
211+
192212
Block: uint64(i),
193213
}
194214

@@ -260,6 +280,9 @@ func testGetCountByState(t *testing.T, s timelock.Store) {
260280

261281
DepositPdaAddress: fmt.Sprintf("deposit-%s-%d", state, i),
262282
DepositPdaBump: 253,
283+
284+
SwapPdaAddress: fmt.Sprintf("swap-%s-%d", state, i),
285+
SwapPdaBump: 252,
263286
}
264287

265288
require.NoError(t, s.Save(ctx, record))

pkg/code/data/timelock/timelock.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ type Record struct {
3131
DepositPdaAddress string
3232
DepositPdaBump uint8
3333

34+
SwapPdaAddress string
35+
SwapPdaBump uint8
36+
3437
UnlockAt *uint64
3538

3639
Block uint64
@@ -78,6 +81,9 @@ func (r *Record) Clone() *Record {
7881
DepositPdaAddress: r.DepositPdaAddress,
7982
DepositPdaBump: r.DepositPdaBump,
8083

84+
SwapPdaAddress: r.SwapPdaAddress,
85+
SwapPdaBump: r.SwapPdaBump,
86+
8187
UnlockAt: unlockAt,
8288

8389
Block: r.Block,
@@ -106,6 +112,9 @@ func (r *Record) CopyTo(dst *Record) {
106112
dst.DepositPdaAddress = r.DepositPdaAddress
107113
dst.DepositPdaBump = r.DepositPdaBump
108114

115+
dst.SwapPdaAddress = r.SwapPdaAddress
116+
dst.SwapPdaBump = r.SwapPdaBump
117+
109118
dst.UnlockAt = unlockAt
110119

111120
dst.Block = r.Block
@@ -134,5 +143,9 @@ func (r *Record) Validate() error {
134143
return errors.New("deposit pda address is required")
135144
}
136145

146+
if len(r.SwapPdaAddress) == 0 {
147+
return errors.New("swap pda address is required")
148+
}
149+
137150
return nil
138151
}

0 commit comments

Comments
 (0)