diff --git a/htlcswitch/link.go b/htlcswitch/link.go index 4c81964cc24..f6b29f96afd 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -2605,40 +2605,6 @@ func (l *channelLink) canSendHtlc(policy models.ForwardingPolicy, // forwarded. availableBandwidth := l.Bandwidth() - auxBandwidth, externalErr := fn.MapOptionZ( - l.cfg.AuxTrafficShaper, - func(ts AuxTrafficShaper) fn.Result[OptionalBandwidth] { - var htlcBlob fn.Option[tlv.Blob] - blob, err := customRecords.Serialize() - if err != nil { - return fn.Err[OptionalBandwidth]( - fmt.Errorf("unable to serialize "+ - "custom records: %w", err)) - } - - if len(blob) > 0 { - htlcBlob = fn.Some(blob) - } - - return l.AuxBandwidth(amt, originalScid, htlcBlob, ts) - }, - ).Unpack() - if externalErr != nil { - l.log.Errorf("Unable to determine aux bandwidth: %v", - externalErr) - - return NewLinkError(&lnwire.FailTemporaryNodeFailure{}) - } - - if auxBandwidth.IsHandled && auxBandwidth.Bandwidth.IsSome() { - auxBandwidth.Bandwidth.WhenSome( - func(bandwidth lnwire.MilliSatoshi) { - availableBandwidth = bandwidth - }, - ) - } - - // Check to see if there is enough balance in this channel. if amt > availableBandwidth { l.log.Warnf("insufficient bandwidth to route htlc: %v is "+ "larger than %v", amt, availableBandwidth) diff --git a/lnwallet/aux_test_utils.go b/lnwallet/aux_test_utils.go new file mode 100644 index 00000000000..eb140644ab8 --- /dev/null +++ b/lnwallet/aux_test_utils.go @@ -0,0 +1,39 @@ +package lnwallet + +import ( + "github.com/lightningnetwork/lnd/lnwire" +) + +// NewTestAuxHtlcDescriptor creates an AuxHtlcDescriptor for testing purposes. +// This function allows tests to create descriptors with specific commit heights +// and entry types, which are normally unexported fields. +func NewTestAuxHtlcDescriptor( + chanID lnwire.ChannelID, + rHash PaymentHash, + timeout uint32, + amount lnwire.MilliSatoshi, + htlcIndex uint64, + parentIndex uint64, + entryType uint8, + customRecords lnwire.CustomRecords, + addHeightLocal uint64, + addHeightRemote uint64, + removeHeightLocal uint64, + removeHeightRemote uint64, +) AuxHtlcDescriptor { + + return AuxHtlcDescriptor{ + ChanID: chanID, + RHash: rHash, + Timeout: timeout, + Amount: amount, + HtlcIndex: htlcIndex, + ParentIndex: parentIndex, + EntryType: updateType(entryType), + CustomRecords: customRecords, + addCommitHeightLocal: addHeightLocal, + addCommitHeightRemote: addHeightRemote, + removeCommitHeightLocal: removeHeightLocal, + removeCommitHeightRemote: removeHeightRemote, + } +} diff --git a/lnwallet/channel.go b/lnwallet/channel.go index 484a019da5b..cc01a9c3a26 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -833,6 +833,14 @@ type LightningChannel struct { // is created. type ChannelOpt func(*channelOpts) +// AuxHtlcValidator is a function that validates whether an HTLC can be added +// to a custom channel. It is called during HTLC validation with the current +// channel state and HTLC details. This allows external components (like the +// traffic shaper) to perform final validation checks against the most +// up-to-date channel state before the HTLC is committed. +type AuxHtlcValidator func(amount, linkBandwidth lnwire.MilliSatoshi, + customRecords lnwire.CustomRecords, view AuxHtlcView) error + // channelOpts is the set of options used to create a new channel. type channelOpts struct { localNonce *musig2.Nonces @@ -842,6 +850,10 @@ type channelOpts struct { auxSigner fn.Option[AuxSigner] auxResolver fn.Option[AuxContractResolver] + // auxHtlcValidator is an optional validator that performs custom + // validation on HTLCs before they are added to the channel state. + auxHtlcValidator fn.Option[AuxHtlcValidator] + skipNonceInit bool } @@ -894,6 +906,15 @@ func WithAuxResolver(resolver AuxContractResolver) ChannelOpt { } } +// WithAuxHtlcValidator is used to specify a custom HTLC validator for the +// channel. This validator will be called during HTLC addition to perform +// final validation checks against the most up-to-date channel state. +func WithAuxHtlcValidator(validator AuxHtlcValidator) ChannelOpt { + return func(o *channelOpts) { + o.auxHtlcValidator = fn.Some(validator) + } +} + // defaultChannelOpts returns the set of default options for a new channel. func defaultChannelOpts() *channelOpts { return &channelOpts{} @@ -2738,9 +2759,15 @@ func (lc *LightningChannel) FetchLatestAuxHTLCView() AuxHtlcView { lc.RLock() defer lc.RUnlock() - return newAuxHtlcView(lc.fetchHTLCView( - lc.updateLogs.Remote.logIndex, lc.updateLogs.Local.logIndex, - )) + nextHeight := lc.commitChains.Local.tip().height + 1 + remoteACKedIndex := lc.commitChains.Local.tail().messageIndices.Remote + view := lc.fetchHTLCView( + remoteACKedIndex, lc.updateLogs.Local.logIndex, + ) + + view.NextHeight = nextHeight + + return newAuxHtlcView(view) } // fetchHTLCView returns all the candidate HTLC updates which should be @@ -6061,7 +6088,7 @@ func (lc *LightningChannel) addHTLC(htlc *lnwire.UpdateAddHTLC, defer lc.Unlock() pd := lc.htlcAddDescriptor(htlc, openKey) - if err := lc.validateAddHtlc(pd, buffer); err != nil { + if err := lc.validateAddHtlc(pd, buffer, true); err != nil { return 0, err } @@ -6179,7 +6206,7 @@ func (lc *LightningChannel) MayAddOutgoingHtlc(amt lnwire.MilliSatoshi) error { // Enforce the FeeBuffer because we are evaluating whether we can add // another htlc to the channel state. - if err := lc.validateAddHtlc(pd, FeeBuffer); err != nil { + if err := lc.validateAddHtlc(pd, FeeBuffer, false); err != nil { lc.log.Debugf("May add outgoing htlc rejected: %v", err) return err } @@ -6215,7 +6242,8 @@ func (lc *LightningChannel) htlcAddDescriptor(htlc *lnwire.UpdateAddHTLC, // validateAddHtlc validates the addition of an outgoing htlc to our local and // remote commitments. func (lc *LightningChannel) validateAddHtlc(pd *paymentDescriptor, - buffer BufferType) error { + buffer BufferType, finalCheck bool) error { + // Make sure adding this HTLC won't violate any of the constraints we // must keep on the commitment transactions. remoteACKedIndex := lc.commitChains.Local.tail().messageIndices.Remote @@ -6243,6 +6271,59 @@ func (lc *LightningChannel) validateAddHtlc(pd *paymentDescriptor, return err } + // In order to avoid unnecessary validations of the aux bandwidth that + // may be costly to perform, let's skip unless this is the final check + // before adding the HTLC to the channel. + if !finalCheck { + return nil + } + + // If an auxiliary HTLC validator is configured, call it now to perform + // custom validation checks against the current channel state. This is + // the final validation point before the HTLC is added to the update + // log, ensuring that the validator sees the most up-to-date state + // including all previously validated HTLCs in this batch. + // + // NOTE: This is called after the standard commitment sanity checks to + // ensure we only perform (potentially) expensive custom validation on + // HTLCs that have already passed the basic Lightning protocol + // constraints. + err = fn.MapOptionZ( + lc.opts.auxHtlcValidator, + func(validator AuxHtlcValidator) error { + // Fetch the current HTLC view which includes all + // pending HTLCs that haven't been committed yet. This + // provides the validator with the most accurate state. + commitChain := lc.commitChains.Local + remoteIndex := commitChain.tail().messageIndices.Remote + view := lc.fetchHTLCView( + remoteIndex, + lc.updateLogs.Local.logIndex, + ) + + nextHeight := lc.commitChains.Local.tip().height + 1 + view.NextHeight = nextHeight + + lc.log.Infof("Setting view nextheight=%v", nextHeight) + + auxView := newAuxHtlcView(view) + + // Get the current available balance for the link + // bandwidth check. This is needed for the reserve + // validation in the traffic shaper. We use NoBuffer + // since this is the final check before adding the HTLC. + linkBandwidth, _ := lc.availableBalance(NoBuffer) + + return validator( + pd.Amount, linkBandwidth, pd.CustomRecords, + auxView, + ) + }, + ) + if err != nil { + return fmt.Errorf("aux HTLC validation failed: %w", err) + } + return nil } diff --git a/peer/brontide.go b/peer/brontide.go index 8d02ca6e539..79f153ed2e3 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -51,6 +51,7 @@ import ( "github.com/lightningnetwork/lnd/pool" "github.com/lightningnetwork/lnd/protofsm" "github.com/lightningnetwork/lnd/queue" + "github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/subscribe" "github.com/lightningnetwork/lnd/ticker" "github.com/lightningnetwork/lnd/tlv" @@ -1139,6 +1140,16 @@ func (p *Brontide) loadActiveChannels(chans []*channeldb.OpenChannel) ( }, ) + p.cfg.AuxTrafficShaper.WhenSome( + func(ts htlcswitch.AuxTrafficShaper) { + val := p.createHtlcValidator(dbChan, ts) + chanOpts = append( + chanOpts, + lnwallet.WithAuxHtlcValidator(val), + ) + }, + ) + lnChan, err := lnwallet.NewLightningChannel( p.cfg.Signer, dbChan, p.cfg.SigPool, chanOpts..., ) @@ -5228,6 +5239,15 @@ func (p *Brontide) addActiveChannel(c *lnpeer.NewChannel) error { chanOpts = append(chanOpts, lnwallet.WithAuxResolver(s)) }) + p.cfg.AuxTrafficShaper.WhenSome( + func(ts htlcswitch.AuxTrafficShaper) { + val := p.createHtlcValidator(c.OpenChannel, ts) + chanOpts = append( + chanOpts, lnwallet.WithAuxHtlcValidator(val), + ) + }, + ) + // If not already active, we'll add this channel to the set of active // channels, so we can look it up later easily according to its channel // ID. @@ -5434,6 +5454,72 @@ func (p *Brontide) scaleTimeout(timeout time.Duration) time.Duration { return timeout } +// createHtlcValidator creates an HTLC validator function that performs final +// aux balance validation before HTLCs are added to the channel state. This +// validator calls into the traffic shaper's PaymentBandwidth method to check +// external balance against the most up-to-date channel state, preventing race +// conditions where multiple HTLCs could be approved based on stale bandwidth. +func (p *Brontide) createHtlcValidator(dbChan *channeldb.OpenChannel, + ts htlcswitch.AuxTrafficShaper) lnwallet.AuxHtlcValidator { + + return func(amount, linkBandwidth lnwire.MilliSatoshi, + customRecords lnwire.CustomRecords, + view lnwallet.AuxHtlcView) error { + + // Get the short channel ID for logging. + scid := dbChan.ShortChannelID + + // Extract the HTLC custom records to pass to the traffic + // shaper. + var htlcBlob fn.Option[tlv.Blob] + if len(customRecords) > 0 { + blob, err := customRecords.Serialize() + if err != nil { + return fmt.Errorf("unable to serialize "+ + "custom records: %w", err) + } + htlcBlob = fn.Some(blob) + } + + // Get the funding and commitment blobs for this channel. + fundingBlob := dbChan.CustomBlob + commitmentBlob := dbChan.LocalCommitment.CustomBlob + + // Fetch the peer's public key. + peerBytes := p.IdentityKey().SerializeCompressed() + peer, err := route.NewVertexFromBytes(peerBytes) + if err != nil { + return fmt.Errorf("failed to create vertex from peer "+ + "pub key: %w", err) + } + + // Call the traffic shaper's PaymentBandwidth method with the + // current state. This performs the same bandwidth checks as + // during pathfinding/forwarding, but against the absolute + // latest channel state. + // + // The linkBandwidth is provided by the channel and represents + // the current available balance, which is used by the traffic + // shaper to ensure we don't dip below channel reserves. + bandwidth, err := ts.PaymentBandwidth( + fundingBlob, htlcBlob, commitmentBlob, + linkBandwidth, amount, view, peer, + ) + if err != nil { + return fmt.Errorf("traffic shaper bandwidth check "+ + "failed: %w", err) + } + + if amount > bandwidth { + return fmt.Errorf("insufficient aux bandwidth: "+ + "need %v, have %v (scid=%v)", amount, + bandwidth, scid) + } + + return nil + } +} + // CoopCloseUpdates is a struct used to communicate updates for an active close // to the caller. type CoopCloseUpdates struct {