11package loopin
22
33import (
4+ "bytes"
45 "context"
56 "errors"
67 "fmt"
@@ -93,6 +94,11 @@ type StaticAddressLoopIn struct {
9394 // swap.
9495 DepositOutpoints []string
9596
97+ // SelectedAmount is the amount that the user selected for the swap. If
98+ // the user did not select an amount, the amount of all deposits is
99+ // used.
100+ SelectedAmount btcutil.Amount
101+
96102 // state is the current state of the swap.
97103 state fsm.StateType
98104
@@ -283,14 +289,25 @@ func (l *StaticAddressLoopIn) createHtlcTx(chainParams *chaincfg.Params,
283289 })
284290 }
285291
292+ // Determine the swap amount. If the user selected a specific amount, we
293+ // use that and use the difference to the total deposit amount as the
294+ // change.
295+ var (
296+ swapAmt = l .TotalDepositAmount ()
297+ changeAmount btcutil.Amount
298+ )
299+ if l .SelectedAmount > 0 {
300+ swapAmt = l .SelectedAmount
301+ changeAmount = l .TotalDepositAmount () - l .SelectedAmount
302+ }
303+
286304 // Calculate htlc tx fee for server provided fee rate.
287- weight := l .htlcWeight ()
305+ hasChange := changeAmount > 0
306+ weight := l .htlcWeight (hasChange )
288307 fee := feeRate .FeeForWeight (weight )
289308
290309 // Check if the server breaches our fee limits.
291- amt := float64 (l .TotalDepositAmount ())
292- feeLimit := btcutil .Amount (amt * maxFeePercentage )
293-
310+ feeLimit := btcutil .Amount (float64 (swapAmt ) * maxFeePercentage )
294311 if fee > feeLimit {
295312 return nil , fmt .Errorf ("htlc tx fee %v exceeds max fee %v" ,
296313 fee , feeLimit )
@@ -308,12 +325,20 @@ func (l *StaticAddressLoopIn) createHtlcTx(chainParams *chaincfg.Params,
308325
309326 // Create the sweep output
310327 sweepOutput := & wire.TxOut {
311- Value : int64 (l . TotalDepositAmount ()) - int64 ( fee ),
328+ Value : int64 (swapAmt - fee ),
312329 PkScript : pkscript ,
313330 }
314331
315332 msgTx .AddTxOut (sweepOutput )
316333
334+ // We expect change to be sent back to our static address output script.
335+ if changeAmount > 0 {
336+ msgTx .AddTxOut (& wire.TxOut {
337+ Value : int64 (changeAmount ),
338+ PkScript : l .AddressParams .PkScript ,
339+ })
340+ }
341+
317342 return msgTx , nil
318343}
319344
@@ -325,7 +350,7 @@ func (l *StaticAddressLoopIn) isHtlcTimedOut(height int32) bool {
325350}
326351
327352// htlcWeight returns the weight for the htlc transaction.
328- func (l * StaticAddressLoopIn ) htlcWeight () lntypes.WeightUnit {
353+ func (l * StaticAddressLoopIn ) htlcWeight (hasChange bool ) lntypes.WeightUnit {
329354 var weightEstimator input.TxWeightEstimator
330355 for i := 0 ; i < len (l .Deposits ); i ++ {
331356 weightEstimator .AddTaprootKeySpendInput (
@@ -335,6 +360,10 @@ func (l *StaticAddressLoopIn) htlcWeight() lntypes.WeightUnit {
335360
336361 weightEstimator .AddP2WSHOutput ()
337362
363+ if hasChange {
364+ weightEstimator .AddP2TROutput ()
365+ }
366+
338367 return weightEstimator .Weight ()
339368}
340369
@@ -373,11 +402,25 @@ func (l *StaticAddressLoopIn) createHtlcSweepTx(ctx context.Context,
373402 return nil , err
374403 }
375404
405+ // Check if the htlc tx has a change output. If so we need to select the
406+ // non-change output index to construct the sweep with.
407+ htlcInputIndex := uint32 (0 )
408+ if len (htlcTx .TxOut ) == 2 {
409+ // If the first htlc tx output matches our static address
410+ // script we need to select the second output to sweep from.
411+ if bytes .Equal (
412+ htlcTx .TxOut [0 ].PkScript , l .AddressParams .PkScript ,
413+ ) {
414+
415+ htlcInputIndex = 1
416+ }
417+ }
418+
376419 // Add the htlc input.
377420 sweepTx .AddTxIn (& wire.TxIn {
378421 PreviousOutPoint : wire.OutPoint {
379422 Hash : htlcTx .TxHash (),
380- Index : 0 ,
423+ Index : htlcInputIndex ,
381424 },
382425 SignatureScript : htlc .SigScript ,
383426 Sequence : htlc .SuccessSequence (),
0 commit comments