11package loopin
22
33import (
4+ "bytes"
45 "context"
56 "errors"
67 "fmt"
@@ -91,8 +92,15 @@ type StaticAddressLoopIn struct {
9192
9293 // The outpoints in the format txid:vout that are part of the loop-in
9394 // swap.
95+ // TODO(hieblmi): Replace this with a getter method that fetches the
96+ // outpoints from the deposits.
9497 DepositOutpoints []string
9598
99+ // SelectedAmount is the amount that the user selected for the swap. If
100+ // the user did not select an amount, the amount of all deposits is
101+ // used.
102+ SelectedAmount btcutil.Amount
103+
96104 // state is the current state of the swap.
97105 state fsm.StateType
98106
@@ -283,14 +291,25 @@ func (l *StaticAddressLoopIn) createHtlcTx(chainParams *chaincfg.Params,
283291 })
284292 }
285293
294+ // Determine the swap amount. If the user selected a specific amount, we
295+ // use that and use the difference to the total deposit amount as the
296+ // change.
297+ var (
298+ swapAmt = l .TotalDepositAmount ()
299+ changeAmount btcutil.Amount
300+ )
301+ if l .SelectedAmount > 0 {
302+ swapAmt = l .SelectedAmount
303+ changeAmount = l .TotalDepositAmount () - l .SelectedAmount
304+ }
305+
286306 // Calculate htlc tx fee for server provided fee rate.
287- weight := l .htlcWeight ()
307+ hasChange := changeAmount > 0
308+ weight := l .htlcWeight (hasChange )
288309 fee := feeRate .FeeForWeight (weight )
289310
290311 // Check if the server breaches our fee limits.
291- amt := float64 (l .TotalDepositAmount ())
292- feeLimit := btcutil .Amount (amt * maxFeePercentage )
293-
312+ feeLimit := btcutil .Amount (float64 (swapAmt ) * maxFeePercentage )
294313 if fee > feeLimit {
295314 return nil , fmt .Errorf ("htlc tx fee %v exceeds max fee %v" ,
296315 fee , feeLimit )
@@ -308,12 +327,20 @@ func (l *StaticAddressLoopIn) createHtlcTx(chainParams *chaincfg.Params,
308327
309328 // Create the sweep output
310329 sweepOutput := & wire.TxOut {
311- Value : int64 (l . TotalDepositAmount ()) - int64 ( fee ),
330+ Value : int64 (swapAmt - fee ),
312331 PkScript : pkscript ,
313332 }
314333
315334 msgTx .AddTxOut (sweepOutput )
316335
336+ // We expect change to be sent back to our static address output script.
337+ if changeAmount > 0 {
338+ msgTx .AddTxOut (& wire.TxOut {
339+ Value : int64 (changeAmount ),
340+ PkScript : l .AddressParams .PkScript ,
341+ })
342+ }
343+
317344 return msgTx , nil
318345}
319346
@@ -325,7 +352,7 @@ func (l *StaticAddressLoopIn) isHtlcTimedOut(height int32) bool {
325352}
326353
327354// htlcWeight returns the weight for the htlc transaction.
328- func (l * StaticAddressLoopIn ) htlcWeight () lntypes.WeightUnit {
355+ func (l * StaticAddressLoopIn ) htlcWeight (hasChange bool ) lntypes.WeightUnit {
329356 var weightEstimator input.TxWeightEstimator
330357 for i := 0 ; i < len (l .Deposits ); i ++ {
331358 weightEstimator .AddTaprootKeySpendInput (
@@ -335,6 +362,10 @@ func (l *StaticAddressLoopIn) htlcWeight() lntypes.WeightUnit {
335362
336363 weightEstimator .AddP2WSHOutput ()
337364
365+ if hasChange {
366+ weightEstimator .AddP2TROutput ()
367+ }
368+
338369 return weightEstimator .Weight ()
339370}
340371
@@ -373,11 +404,25 @@ func (l *StaticAddressLoopIn) createHtlcSweepTx(ctx context.Context,
373404 return nil , err
374405 }
375406
407+ // Check if the htlc tx has a change output. If so we need to select the
408+ // non-change output index to construct the sweep with.
409+ htlcInputIndex := uint32 (0 )
410+ if len (htlcTx .TxOut ) == 2 {
411+ // If the first htlc tx output matches our static address
412+ // script we need to select the second output to sweep from.
413+ if bytes .Equal (
414+ htlcTx .TxOut [0 ].PkScript , l .AddressParams .PkScript ,
415+ ) {
416+
417+ htlcInputIndex = 1
418+ }
419+ }
420+
376421 // Add the htlc input.
377422 sweepTx .AddTxIn (& wire.TxIn {
378423 PreviousOutPoint : wire.OutPoint {
379424 Hash : htlcTx .TxHash (),
380- Index : 0 ,
425+ Index : htlcInputIndex ,
381426 },
382427 SignatureScript : htlc .SigScript ,
383428 Sequence : htlc .SuccessSequence (),
0 commit comments