Skip to content

Commit 391ef57

Browse files
committed
loopout: enable p2tr without keyspend
1 parent 901a935 commit 391ef57

File tree

7 files changed

+171
-37
lines changed

7 files changed

+171
-37
lines changed

client.go

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -192,24 +192,39 @@ func (s *Client) FetchSwaps() ([]*SwapInfo, error) {
192192
swaps := make([]*SwapInfo, 0, len(loopInSwaps)+len(loopOutSwaps))
193193

194194
for _, swp := range loopOutSwaps {
195+
swapInfo := &SwapInfo{
196+
SwapType: swap.TypeOut,
197+
SwapContract: swp.Contract.SwapContract,
198+
SwapStateData: swp.State(),
199+
SwapHash: swp.Hash,
200+
LastUpdate: swp.LastUpdateTime(),
201+
}
202+
scriptVersion := GetHtlcScriptVersion(
203+
swp.Contract.ProtocolVersion,
204+
)
205+
206+
outputType := swap.HtlcP2WSH
207+
if scriptVersion == swap.HtlcV3 {
208+
outputType = swap.HtlcP2TR
209+
}
210+
195211
htlc, err := swap.NewHtlc(
196-
GetHtlcScriptVersion(swp.Contract.ProtocolVersion),
212+
scriptVersion,
197213
swp.Contract.CltvExpiry, swp.Contract.SenderKey,
198-
swp.Contract.ReceiverKey, swp.Hash, swap.HtlcP2WSH,
199-
s.lndServices.ChainParams,
214+
swp.Contract.ReceiverKey, swp.Hash,
215+
outputType, s.lndServices.ChainParams,
200216
)
201217
if err != nil {
202218
return nil, err
203219
}
204220

205-
swaps = append(swaps, &SwapInfo{
206-
SwapType: swap.TypeOut,
207-
SwapContract: swp.Contract.SwapContract,
208-
SwapStateData: swp.State(),
209-
SwapHash: swp.Hash,
210-
LastUpdate: swp.LastUpdateTime(),
211-
HtlcAddressP2WSH: htlc.Address,
212-
})
221+
if outputType == swap.HtlcP2TR {
222+
swapInfo.HtlcAddressP2TR = htlc.Address
223+
} else {
224+
swapInfo.HtlcAddressP2WSH = htlc.Address
225+
}
226+
227+
swaps = append(swaps, swapInfo)
213228
}
214229

215230
for _, swp := range loopInSwaps {
@@ -426,9 +441,9 @@ func (s *Client) LoopOut(globalCtx context.Context,
426441
// Return hash so that the caller can identify this swap in the updates
427442
// stream.
428443
return &LoopOutSwapInfo{
429-
SwapHash: swap.hash,
430-
HtlcAddressP2WSH: swap.htlc.Address,
431-
ServerMessage: initResult.serverMessage,
444+
SwapHash: swap.hash,
445+
HtlcAddress: swap.htlc.Address,
446+
ServerMessage: initResult.serverMessage,
432447
}, nil
433448
}
434449

client_test.go

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ func TestLoopOutResume(t *testing.T) {
159159
storedVersion := []loopdb.ProtocolVersion{
160160
loopdb.ProtocolVersionUnrecorded,
161161
loopdb.ProtocolVersionHtlcV2,
162+
loopdb.ProtocolVersionHtlcV3,
162163
}
163164

164165
for _, version := range storedVersion {
@@ -283,9 +284,15 @@ func testLoopOutResume(t *testing.T, confs uint32, expired, preimageRevealed,
283284

284285
// Assert that the loopout htlc equals to the expected one.
285286
scriptVersion := GetHtlcScriptVersion(protocolVersion)
287+
288+
outputType := swap.HtlcP2TR
289+
if scriptVersion != swap.HtlcV3 {
290+
outputType = swap.HtlcP2WSH
291+
}
292+
286293
htlc, err := swap.NewHtlc(
287294
scriptVersion, pendingSwap.Contract.CltvExpiry, senderKey,
288-
receiverKey, hash, swap.HtlcP2WSH, &chaincfg.TestNet3Params,
295+
receiverKey, hash, outputType, &chaincfg.TestNet3Params,
289296
)
290297
require.NoError(t, err)
291298
require.Equal(t, htlc.PkScript, confIntent.PkScript)
@@ -345,8 +352,15 @@ func testLoopOutSuccess(ctx *testContext, amt btcutil.Amount, hash lntypes.Hash,
345352
ctx.T.Fatalf("client not sweeping from htlc tx")
346353
}
347354

348-
preImageIndex := 1
349-
if scriptVersion == swap.HtlcV2 {
355+
var preImageIndex int
356+
switch scriptVersion {
357+
case swap.HtlcV1:
358+
preImageIndex = 1
359+
360+
case swap.HtlcV2:
361+
preImageIndex = 0
362+
363+
case swap.HtlcV3:
350364
preImageIndex = 0
351365
}
352366

interface.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -312,9 +312,9 @@ type LoopOutSwapInfo struct { // nolint:revive
312312
// SwapHash contains the sha256 hash of the swap preimage.
313313
SwapHash lntypes.Hash
314314

315-
// HtlcAddressP2WSH contains the native segwit swap htlc address that
316-
// the server will publish to.
317-
HtlcAddressP2WSH btcutil.Address
315+
// HtlcAddress contains the swap htlc address that the server will
316+
// publish to.
317+
HtlcAddress btcutil.Address
318318

319319
// ServerMessages is the human-readable message received from the loop
320320
// server.

liquidity/liquidity.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -386,8 +386,7 @@ func (m *Manager) autoloop(ctx context.Context) error {
386386
}
387387

388388
log.Infof("loop out automatically dispatched: hash: %v, "+
389-
"address: %v", loopOut.SwapHash,
390-
loopOut.HtlcAddressP2WSH)
389+
"address: %v", loopOut.SwapHash, loopOut.HtlcAddress)
391390
}
392391

393392
for _, in := range suggestion.InSwaps {

loopd/swapclient_server.go

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -150,13 +150,21 @@ func (s *swapClientServer) LoopOut(ctx context.Context,
150150
return nil, err
151151
}
152152

153-
return &clientrpc.SwapResponse{
154-
Id: info.SwapHash.String(),
155-
IdBytes: info.SwapHash[:],
156-
HtlcAddress: info.HtlcAddressP2WSH.String(),
157-
HtlcAddressP2Wsh: info.HtlcAddressP2WSH.String(),
158-
ServerMessage: info.ServerMessage,
159-
}, nil
153+
htlcAddress := info.HtlcAddress.String()
154+
resp := &clientrpc.SwapResponse{
155+
Id: info.SwapHash.String(),
156+
IdBytes: info.SwapHash[:],
157+
HtlcAddress: htlcAddress,
158+
ServerMessage: info.ServerMessage,
159+
}
160+
161+
if loopdb.CurrentProtocolVersion() < loopdb.ProtocolVersionHtlcV3 {
162+
resp.HtlcAddressP2Wsh = htlcAddress
163+
} else {
164+
resp.HtlcAddressP2Tr = htlcAddress
165+
}
166+
167+
return resp, nil
160168
}
161169

162170
func (s *swapClientServer) marshallSwap(loopSwap *loop.SwapInfo) (
@@ -252,8 +260,13 @@ func (s *swapClientServer) marshallSwap(loopSwap *loop.SwapInfo) (
252260

253261
case swap.TypeOut:
254262
swapType = clientrpc.SwapType_LOOP_OUT
255-
htlcAddressP2WSH = loopSwap.HtlcAddressP2WSH.EncodeAddress()
256-
htlcAddress = htlcAddressP2WSH
263+
if loopSwap.HtlcAddressP2WSH != nil {
264+
htlcAddressP2WSH = loopSwap.HtlcAddressP2WSH.EncodeAddress()
265+
htlcAddress = htlcAddressP2WSH
266+
} else {
267+
htlcAddressP2TR = loopSwap.HtlcAddressP2TR.EncodeAddress()
268+
htlcAddress = htlcAddressP2TR
269+
}
257270

258271
outGoingChanSet = loopSwap.OutgoingChanSet
259272

loopout.go

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -183,14 +183,20 @@ func newLoopOutSwap(globalCtx context.Context, cfg *swapConfig,
183183
}
184184

185185
swapKit := newSwapKit(
186-
swapHash, swap.TypeOut,
187-
cfg, &contract.SwapContract,
186+
swapHash, swap.TypeOut, cfg, &contract.SwapContract,
188187
)
189188

190189
swapKit.lastUpdateTime = initiationTime
191190

191+
scriptVersion := GetHtlcScriptVersion(loopdb.CurrentProtocolVersion())
192+
outputType := swap.HtlcP2TR
193+
if scriptVersion != swap.HtlcV3 {
194+
// Default to using P2WSH for legacy htlcs.
195+
outputType = swap.HtlcP2WSH
196+
}
197+
192198
// Create the htlc.
193-
htlc, err := swapKit.getHtlc(swap.HtlcP2WSH)
199+
htlc, err := swapKit.getHtlc(outputType)
194200
if err != nil {
195201
return nil, err
196202
}
@@ -239,12 +245,18 @@ func resumeLoopOutSwap(reqContext context.Context, cfg *swapConfig,
239245
log.Infof("Resuming loop out swap %v", hash)
240246

241247
swapKit := newSwapKit(
242-
hash, swap.TypeOut, cfg,
243-
&pend.Contract.SwapContract,
248+
hash, swap.TypeOut, cfg, &pend.Contract.SwapContract,
244249
)
245250

251+
scriptVersion := GetHtlcScriptVersion(pend.Contract.ProtocolVersion)
252+
outputType := swap.HtlcP2TR
253+
if scriptVersion != swap.HtlcV3 {
254+
// Default to using P2WSH for legacy htlcs.
255+
outputType = swap.HtlcP2WSH
256+
}
257+
246258
// Create the htlc.
247-
htlc, err := swapKit.getHtlc(swap.HtlcP2WSH)
259+
htlc, err := swapKit.getHtlc(outputType)
248260
if err != nil {
249261
return nil, err
250262
}

loopout_test.go

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,22 @@ import (
2424
// TestLoopOutPaymentParameters tests the first part of the loop out process up
2525
// to the point where the off-chain payments are made.
2626
func TestLoopOutPaymentParameters(t *testing.T) {
27+
t.Run("stable protocol", func(t *testing.T) {
28+
testLoopOutPaymentParameters(t)
29+
})
30+
31+
t.Run("experimental protocol", func(t *testing.T) {
32+
loopdb.EnableExperimentalProtocol()
33+
defer loopdb.ResetCurrentProtocolVersion()
34+
35+
testLoopOutPaymentParameters(t)
36+
})
37+
}
38+
39+
// TestLoopOutPaymentParameters tests the first part of the loop out process up
40+
// to the point where the off-chain payments are made.
41+
func testLoopOutPaymentParameters(t *testing.T) {
42+
2743
defer test.Guard(t)()
2844

2945
// Set up test context objects.
@@ -144,6 +160,19 @@ func TestLoopOutPaymentParameters(t *testing.T) {
144160
// TestLateHtlcPublish tests that the client is not revealing the preimage if
145161
// there are not enough blocks left.
146162
func TestLateHtlcPublish(t *testing.T) {
163+
t.Run("stable protocol", func(t *testing.T) {
164+
testLateHtlcPublish(t)
165+
})
166+
167+
t.Run("experimental protocol", func(t *testing.T) {
168+
loopdb.EnableExperimentalProtocol()
169+
defer loopdb.ResetCurrentProtocolVersion()
170+
171+
testLateHtlcPublish(t)
172+
})
173+
}
174+
175+
func testLateHtlcPublish(t *testing.T) {
147176
defer test.Guard(t)()
148177

149178
lnd := test.NewMockLnd()
@@ -232,6 +261,19 @@ func TestLateHtlcPublish(t *testing.T) {
232261
// TestCustomSweepConfTarget ensures we are able to sweep a Loop Out HTLC with a
233262
// custom confirmation target.
234263
func TestCustomSweepConfTarget(t *testing.T) {
264+
t.Run("stable protocol", func(t *testing.T) {
265+
testCustomSweepConfTarget(t)
266+
})
267+
268+
t.Run("experimental protocol", func(t *testing.T) {
269+
loopdb.EnableExperimentalProtocol()
270+
defer loopdb.ResetCurrentProtocolVersion()
271+
272+
testCustomSweepConfTarget(t)
273+
})
274+
}
275+
276+
func testCustomSweepConfTarget(t *testing.T) {
235277
defer test.Guard(t)()
236278

237279
lnd := test.NewMockLnd()
@@ -433,6 +475,19 @@ func TestCustomSweepConfTarget(t *testing.T) {
433475
// to start with a fee rate that will be too high, then progress to an
434476
// acceptable one.
435477
func TestPreimagePush(t *testing.T) {
478+
t.Run("stable protocol", func(t *testing.T) {
479+
testPreimagePush(t)
480+
})
481+
482+
t.Run("experimental protocol", func(t *testing.T) {
483+
loopdb.EnableExperimentalProtocol()
484+
defer loopdb.ResetCurrentProtocolVersion()
485+
486+
testPreimagePush(t)
487+
})
488+
}
489+
490+
func testPreimagePush(t *testing.T) {
436491
defer test.Guard(t)()
437492

438493
lnd := test.NewMockLnd()
@@ -604,6 +659,19 @@ func TestPreimagePush(t *testing.T) {
604659
// we have revealed our preimage, demonstrating that we do not reveal our
605660
// preimage once we've reached our expiry height.
606661
func TestExpiryBeforeReveal(t *testing.T) {
662+
t.Run("stable protocol", func(t *testing.T) {
663+
testExpiryBeforeReveal(t)
664+
})
665+
666+
t.Run("experimental protocol", func(t *testing.T) {
667+
loopdb.EnableExperimentalProtocol()
668+
defer loopdb.ResetCurrentProtocolVersion()
669+
670+
testExpiryBeforeReveal(t)
671+
})
672+
}
673+
674+
func testExpiryBeforeReveal(t *testing.T) {
607675
defer test.Guard(t)()
608676

609677
lnd := test.NewMockLnd()
@@ -719,6 +787,19 @@ func TestExpiryBeforeReveal(t *testing.T) {
719787
// TestFailedOffChainCancelation tests sending of a cancelation message to
720788
// the server when a swap fails due to off-chain routing.
721789
func TestFailedOffChainCancelation(t *testing.T) {
790+
t.Run("stable protocol", func(t *testing.T) {
791+
testFailedOffChainCancelation(t)
792+
})
793+
794+
t.Run("experimental protocol", func(t *testing.T) {
795+
loopdb.EnableExperimentalProtocol()
796+
defer loopdb.ResetCurrentProtocolVersion()
797+
798+
testFailedOffChainCancelation(t)
799+
})
800+
}
801+
802+
func testFailedOffChainCancelation(t *testing.T) {
722803
defer test.Guard(t)()
723804

724805
lnd := test.NewMockLnd()

0 commit comments

Comments
 (0)