Skip to content

Commit 4f5c806

Browse files
committed
loopdb: add helper methods to update swap costs
This commit adds the necessary sqlc code and SwapStore function to update swap costs for all swaps in one transaction.
1 parent 08aa4db commit 4f5c806

File tree

8 files changed

+245
-0
lines changed

8 files changed

+245
-0
lines changed

loopdb/interface.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,11 @@ type SwapStore interface {
6565
// it's decoding using the proto package's `Unmarshal` method.
6666
FetchLiquidityParams(ctx context.Context) ([]byte, error)
6767

68+
// BatchUpdateLoopOutSwapCosts updates the swap costs for a batch of
69+
// loop out swaps.
70+
BatchUpdateLoopOutSwapCosts(ctx context.Context,
71+
swaps map[lntypes.Hash]SwapCost) error
72+
6873
// Close closes the underlying database.
6974
Close() error
7075
}

loopdb/sql_store.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,38 @@ func (s *BaseDB) BatchInsertUpdate(ctx context.Context,
407407
})
408408
}
409409

410+
// BatchUpdateLoopOutSwapCosts updates the swap costs for a batch of loop out
411+
// swaps.
412+
func (b *BaseDB) BatchUpdateLoopOutSwapCosts(ctx context.Context,
413+
costs map[lntypes.Hash]SwapCost) error {
414+
415+
writeOpts := &SqliteTxOptions{}
416+
return b.ExecTx(ctx, writeOpts, func(tx *sqlc.Queries) error {
417+
for swapHash, cost := range costs {
418+
lastUpdateID, err := tx.GetLastUpdateID(
419+
ctx, swapHash[:],
420+
)
421+
if err != nil {
422+
return err
423+
}
424+
425+
err = tx.OverrideSwapCosts(
426+
ctx, sqlc.OverrideSwapCostsParams{
427+
ID: lastUpdateID,
428+
ServerCost: int64(cost.Server),
429+
OnchainCost: int64(cost.Onchain),
430+
OffchainCost: int64(cost.Offchain),
431+
},
432+
)
433+
if err != nil {
434+
return err
435+
}
436+
}
437+
438+
return nil
439+
})
440+
}
441+
410442
// loopToInsertArgs converts a SwapContract struct to the arguments needed to
411443
// insert it into the database.
412444
func loopToInsertArgs(hash lntypes.Hash,

loopdb/sql_test.go

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"github.com/lightninglabs/loop/loopdb/sqlc"
1414
"github.com/lightninglabs/loop/test"
1515
"github.com/lightningnetwork/lnd/keychain"
16+
"github.com/lightningnetwork/lnd/lntypes"
1617
"github.com/lightningnetwork/lnd/routing/route"
1718
"github.com/stretchr/testify/require"
1819
)
@@ -396,6 +397,124 @@ func TestIssue615(t *testing.T) {
396397
require.NoError(t, err)
397398
}
398399

400+
// TestBatchUpdateCost tests that we can batch update the cost of multiple swaps
401+
// at once.
402+
func TestBatchUpdateCost(t *testing.T) {
403+
// Create a new sqlite store for testing.
404+
store := NewTestDB(t)
405+
406+
destAddr := test.GetDestAddr(t, 0)
407+
initiationTime := time.Date(2018, 11, 1, 0, 0, 0, 0, time.UTC)
408+
409+
testContract := LoopOutContract{
410+
SwapContract: SwapContract{
411+
AmountRequested: 100,
412+
CltvExpiry: 144,
413+
HtlcKeys: HtlcKeys{
414+
SenderScriptKey: senderKey,
415+
ReceiverScriptKey: receiverKey,
416+
SenderInternalPubKey: senderInternalKey,
417+
ReceiverInternalPubKey: receiverInternalKey,
418+
ClientScriptKeyLocator: keychain.KeyLocator{
419+
Family: 1,
420+
Index: 2,
421+
},
422+
},
423+
MaxMinerFee: 10,
424+
MaxSwapFee: 20,
425+
426+
InitiationHeight: 99,
427+
428+
InitiationTime: initiationTime,
429+
ProtocolVersion: ProtocolVersionMuSig2,
430+
},
431+
MaxPrepayRoutingFee: 40,
432+
PrepayInvoice: "prepayinvoice",
433+
DestAddr: destAddr,
434+
SwapInvoice: "swapinvoice",
435+
MaxSwapRoutingFee: 30,
436+
SweepConfTarget: 2,
437+
HtlcConfirmations: 2,
438+
SwapPublicationDeadline: initiationTime,
439+
PaymentTimeout: time.Second * 11,
440+
}
441+
442+
makeSwap := func(preimage lntypes.Preimage) *LoopOutContract {
443+
contract := testContract
444+
contract.Preimage = preimage
445+
446+
return &contract
447+
}
448+
449+
// Next, we'll add two swaps to the database.
450+
preimage1 := testPreimage
451+
preimage2 := lntypes.Preimage{4, 4, 4}
452+
453+
ctxb := context.Background()
454+
swap1 := makeSwap(preimage1)
455+
swap2 := makeSwap(preimage2)
456+
457+
hash1 := swap1.Preimage.Hash()
458+
err := store.CreateLoopOut(ctxb, hash1, swap1)
459+
require.NoError(t, err)
460+
461+
hash2 := swap2.Preimage.Hash()
462+
err = store.CreateLoopOut(ctxb, hash2, swap2)
463+
require.NoError(t, err)
464+
465+
// Add an update to both swaps containing the cost.
466+
err = store.UpdateLoopOut(
467+
ctxb, hash1, testTime,
468+
SwapStateData{
469+
State: StateSuccess,
470+
Cost: SwapCost{
471+
Server: 1,
472+
Onchain: 2,
473+
Offchain: 3,
474+
},
475+
},
476+
)
477+
require.NoError(t, err)
478+
479+
err = store.UpdateLoopOut(
480+
ctxb, hash2, testTime,
481+
SwapStateData{
482+
State: StateSuccess,
483+
Cost: SwapCost{
484+
Server: 4,
485+
Onchain: 5,
486+
Offchain: 6,
487+
},
488+
},
489+
)
490+
require.NoError(t, err)
491+
492+
updateMap := map[lntypes.Hash]SwapCost{
493+
hash1: {
494+
Server: 2,
495+
Onchain: 3,
496+
Offchain: 4,
497+
},
498+
hash2: {
499+
Server: 6,
500+
Onchain: 7,
501+
Offchain: 8,
502+
},
503+
}
504+
require.NoError(t, store.BatchUpdateLoopOutSwapCosts(ctxb, updateMap))
505+
506+
swaps, err := store.FetchLoopOutSwaps(ctxb)
507+
require.NoError(t, err)
508+
require.Len(t, swaps, 2)
509+
510+
swapsMap := make(map[lntypes.Hash]*LoopOut)
511+
swapsMap[swaps[0].Hash] = swaps[0]
512+
swapsMap[swaps[1].Hash] = swaps[1]
513+
514+
require.Equal(t, updateMap[hash1], swapsMap[hash1].State().Cost)
515+
require.Equal(t, updateMap[hash2], swapsMap[hash2].State().Cost)
516+
}
517+
399518
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
400519

401520
func randomString(length int) string {

loopdb/sqlc/querier.go

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

loopdb/sqlc/queries/swaps.sql

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,3 +133,19 @@ INSERT INTO htlc_keys(
133133
) VALUES (
134134
$1, $2, $3, $4, $5, $6, $7
135135
);
136+
137+
-- name: GetLastUpdateID :one
138+
SELECT id
139+
FROM swap_updates
140+
WHERE swap_hash = $1
141+
ORDER BY update_timestamp DESC
142+
LIMIT 1;
143+
144+
-- name: OverrideSwapCosts :exec
145+
UPDATE swap_updates
146+
SET
147+
server_cost = $2,
148+
onchain_cost = $3,
149+
offchain_cost = $4
150+
WHERE id = $1;
151+

loopdb/sqlc/swaps.sql.go

Lines changed: 41 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

loopdb/store.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1009,3 +1009,11 @@ func (b *boltSwapStore) BatchInsertUpdate(ctx context.Context,
10091009

10101010
return errUnimplemented
10111011
}
1012+
1013+
// BatchUpdateLoopOutSwapCosts updates the swap costs for a batch of loop out
1014+
// swaps.
1015+
func (b *boltSwapStore) BatchUpdateLoopOutSwapCosts(ctx context.Context,
1016+
costs map[lntypes.Hash]SwapCost) error {
1017+
1018+
return errUnimplemented
1019+
}

loopdb/store_mock.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package loopdb
33
import (
44
"context"
55
"errors"
6+
"fmt"
67
"testing"
78
"time"
89

@@ -337,3 +338,24 @@ func (b *StoreMock) BatchInsertUpdate(ctx context.Context,
337338

338339
return errors.New("not implemented")
339340
}
341+
342+
// BatchUpdateLoopOutSwapCosts updates the swap costs for a batch of loop out
343+
// swaps.
344+
func (s *StoreMock) BatchUpdateLoopOutSwapCosts(ctx context.Context,
345+
costs map[lntypes.Hash]SwapCost) error {
346+
347+
for hash, cost := range costs {
348+
if _, ok := s.LoopOutUpdates[hash]; !ok {
349+
return fmt.Errorf("swap has no updates: %v", hash)
350+
}
351+
352+
updates, ok := s.LoopOutUpdates[hash]
353+
if !ok {
354+
return fmt.Errorf("swap has no updates: %v", hash)
355+
}
356+
357+
updates[len(updates)-1].Cost = cost
358+
}
359+
360+
return nil
361+
}

0 commit comments

Comments
 (0)