Skip to content
Closed
213 changes: 132 additions & 81 deletions channeldb/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,27 +221,70 @@ const (
// A tlv type definition used to serialize an outpoint's indexStatus
// for use in the outpoint index.
indexStatusType tlv.Type = 0
)

// A tlv type definition used to serialize and deserialize a KeyLocator
// from the database.
keyLocType tlv.Type = 1
// chanAuxData houses the auxiliary data that is stored for each channel in a
// TLV stream within the root bucket. This is stored as a TLV stream appended
// to the existing hard-coded fields in the channel's root bucket.
type chanAuxData struct {
revokeKeyLoc tlv.RecordT[tlv.TlvType1, keyLocRecord]

// A tlv type used to serialize and deserialize the
// `InitialLocalBalance` field.
initialLocalBalanceType tlv.Type = 2
initialLocalBalance tlv.RecordT[tlv.TlvType2, uint64]

// A tlv type used to serialize and deserialize the
// `InitialRemoteBalance` field.
initialRemoteBalanceType tlv.Type = 3
initialRemoteBalance tlv.RecordT[tlv.TlvType3, uint64]

// A tlv type definition used to serialize and deserialize the
// confirmed ShortChannelID for a zero-conf channel.
realScidType tlv.Type = 4
realScid tlv.RecordT[tlv.TlvType4, lnwire.ShortChannelID]

// A tlv type definition used to serialize and deserialize the
// Memo for the channel channel.
channelMemoType tlv.Type = 5
)
memo tlv.OptionalRecordT[tlv.TlvType5, []byte]

tapscriptRoot tlv.OptionalRecordT[tlv.TlvType6, [32]byte]
}

// toOpeChan converts the chanAuxData to an OpenChannel by setting the relevant
// fields in the OpenChannel struct.
func (c *chanAuxData) toOpenChan(o *OpenChannel) {
o.RevocationKeyLocator = c.revokeKeyLoc.Val.KeyLocator
o.InitialLocalBalance = lnwire.MilliSatoshi(c.initialLocalBalance.Val)
o.InitialRemoteBalance = lnwire.MilliSatoshi(c.initialRemoteBalance.Val)
o.confirmedScid = c.realScid.Val
c.memo.WhenSomeV(func(memo []byte) {
o.Memo = memo
})
c.tapscriptRoot.WhenSomeV(func(h [32]byte) {
o.TapscriptRoot = fn.Some(chainhash.Hash(h))
})
}

// newChanAuxDataFromChan creates a new chanAuxData from the given channel.
func newChanAuxDataFromChan(openChan *OpenChannel) *chanAuxData {
c := &chanAuxData{
revokeKeyLoc: tlv.NewRecordT[tlv.TlvType1, keyLocRecord](
keyLocRecord{openChan.RevocationKeyLocator},
),
initialLocalBalance: tlv.NewPrimitiveRecord[tlv.TlvType2, uint64](
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: line length here and below

uint64(openChan.InitialLocalBalance),
),
initialRemoteBalance: tlv.NewPrimitiveRecord[tlv.TlvType3, uint64](
uint64(openChan.InitialRemoteBalance),
),
realScid: tlv.NewRecordT[tlv.TlvType4, lnwire.ShortChannelID](
openChan.confirmedScid,
),
}

if len(openChan.Memo) != 0 {
c.memo = tlv.SomeRecordT(
tlv.NewPrimitiveRecord[tlv.TlvType5](openChan.Memo),
)
}
openChan.TapscriptRoot.WhenSome(func(h chainhash.Hash) {
c.tapscriptRoot = tlv.SomeRecordT(
tlv.NewPrimitiveRecord[tlv.TlvType6]([32]byte(h)),
)
})

return c
}

// indexStatus is an enum-like type that describes what state the
// outpoint is in. Currently only two possible values.
Expand Down Expand Up @@ -320,6 +363,11 @@ const (
// SimpleTaprootFeatureBit indicates that the simple-taproot-chans
// feature bit was negotiated during the lifetime of the channel.
SimpleTaprootFeatureBit ChannelType = 1 << 10

// TapscriptRootBit indicates that this is a musig2 channel with a top
// level tapscript commitment. This MUST be set along with the
// SimpleTaprootFeatureBit.
TapscriptRootBit ChannelType = 1 << 11
)

// IsSingleFunder returns true if the channel type if one of the known single
Expand Down Expand Up @@ -390,6 +438,12 @@ func (c ChannelType) IsTaproot() bool {
return c&SimpleTaprootFeatureBit == SimpleTaprootFeatureBit
}

// HasTapscriptRoot returns true if the channel is using a top level tapscript
// root commmitment.
func (c ChannelType) HasTapscriptRoot() bool {
return c&TapscriptRootBit == TapscriptRootBit
}

// ChannelConstraints represents a set of constraints meant to allow a node to
// limit their exposure, enact flow control and ensure that all HTLCs are
// economically relevant. This struct will be mirrored for both sides of the
Expand Down Expand Up @@ -852,6 +906,10 @@ type OpenChannel struct {
// channel that will be useful to our future selves.
Memo []byte

// TapscriptRoot is an optional tapscript root used to derive the
// musig2 funding output.
TapscriptRoot fn.Option[chainhash.Hash]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: move this to next commit since it's not yet used here?


// TODO(roasbeef): eww
Db *ChannelStateDB

Expand Down Expand Up @@ -3932,26 +3990,23 @@ func putChanInfo(chanBucket kvdb.RwBucket, channel *OpenChannel) error {
return err
}

// Convert balance fields into uint64.
localBalance := uint64(channel.InitialLocalBalance)
remoteBalance := uint64(channel.InitialRemoteBalance)
auxData := newChanAuxDataFromChan(channel)

tlvRecords := []tlv.Record{
auxData.revokeKeyLoc.Record(),
auxData.initialLocalBalance.Record(),
auxData.initialRemoteBalance.Record(),
auxData.realScid.Record(),
}
auxData.memo.WhenSome(func(memo tlv.RecordT[tlv.TlvType5, []byte]) {
tlvRecords = append(tlvRecords, memo.Record())
})
auxData.tapscriptRoot.WhenSome(func(root tlv.RecordT[tlv.TlvType6, [32]byte]) { //nolint:lll
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: format as:

	auxData.tapscriptRoot.WhenSome(
		func(root tlv.RecordT[tlv.TlvType6, [32]byte]) {
			tlvRecords = append(tlvRecords, root.Record())
		}
	)

instead? Or place the //nolint:lll above the line instead of at the end to make it even longer?

tlvRecords = append(tlvRecords, root.Record())
})

// Create the tlv stream.
tlvStream, err := tlv.NewStream(
// Write the RevocationKeyLocator as the first entry in a tlv
// stream.
MakeKeyLocRecord(
keyLocType, &channel.RevocationKeyLocator,
),
tlv.MakePrimitiveRecord(
initialLocalBalanceType, &localBalance,
),
tlv.MakePrimitiveRecord(
initialRemoteBalanceType, &remoteBalance,
),
MakeScidRecord(realScidType, &channel.confirmedScid),
tlv.MakePrimitiveRecord(channelMemoType, &channel.Memo),
)
tlvStream, err := tlv.NewStream(tlvRecords...)
if err != nil {
return err
}
Expand Down Expand Up @@ -4146,46 +4201,39 @@ func fetchChanInfo(chanBucket kvdb.RBucket, channel *OpenChannel) error {
}
}

// Create balance fields in uint64, and Memo field as byte slice.
var (
localBalance uint64
remoteBalance uint64
memo []byte
)
var auxData chanAuxData
memo := auxData.memo.Zero()
tapscriptRoot := auxData.tapscriptRoot.Zero()

// Create the tlv stream.
tlvStream, err := tlv.NewStream(
// Write the RevocationKeyLocator as the first entry in a tlv
// stream.
MakeKeyLocRecord(
keyLocType, &channel.RevocationKeyLocator,
),
tlv.MakePrimitiveRecord(
initialLocalBalanceType, &localBalance,
),
tlv.MakePrimitiveRecord(
initialRemoteBalanceType, &remoteBalance,
),
MakeScidRecord(realScidType, &channel.confirmedScid),
tlv.MakePrimitiveRecord(channelMemoType, &memo),
auxData.revokeKeyLoc.Record(),
auxData.initialLocalBalance.Record(),
auxData.initialRemoteBalance.Record(),
auxData.realScid.Record(),
memo.Record(),
tapscriptRoot.Record(),
)
if err != nil {
return err
}

if err := tlvStream.Decode(r); err != nil {
tlvs, err := tlvStream.DecodeWithParsedTypes(r)
if err != nil {
return err
}

// Attach the balance fields.
channel.InitialLocalBalance = lnwire.MilliSatoshi(localBalance)
channel.InitialRemoteBalance = lnwire.MilliSatoshi(remoteBalance)

// Attach the memo field if non-empty.
if len(memo) > 0 {
channel.Memo = memo
if _, ok := tlvs[memo.TlvType()]; ok {
auxData.memo = tlv.SomeRecordT(memo)
}
if _, ok := tlvs[tapscriptRoot.TlvType()]; ok {
auxData.tapscriptRoot = tlv.SomeRecordT(tapscriptRoot)
}

// Assign all the relevant fields from the aux data into the actual
// open channel.
auxData.toOpenChan(channel)

channel.Packager = NewChannelPackager(channel.ShortChannelID)

// Finally, read the optional shutdown scripts.
Expand Down Expand Up @@ -4342,8 +4390,27 @@ func deleteThawHeight(chanBucket kvdb.RwBucket) error {
return chanBucket.Delete(frozenChanKey)
}

// EKeyLocator is an encoder for keychain.KeyLocator.
func EKeyLocator(w io.Writer, val interface{}, buf *[8]byte) error {
// keyLocRecord is a wrapper struct around keychain.KeyLocator to implement the
// tlv.RecordProducer interface.
type keyLocRecord struct {
keychain.KeyLocator
}

// Record creates a Record out of a KeyLocator using the passed Type and the
// EKeyLocator and DKeyLocator functions. The size will always be 8 as
// KeyFamily is uint32 and the Index is uint32.
//
// NOTE: This is part of the tlv.RecordProducer interface.
func (k *keyLocRecord) Record() tlv.Record {
// Note that we set the type here as zero, as when used with a
// tlv.RecordT, the type param will be used as the type.
return tlv.MakeStaticRecord(
0, &k.KeyLocator, 8, eKeyLocator, dKeyLocator,
)
}

// eKeyLocator is an encoder for keychain.KeyLocator.
func eKeyLocator(w io.Writer, val interface{}, buf *[8]byte) error {
if v, ok := val.(*keychain.KeyLocator); ok {
err := tlv.EUint32T(w, uint32(v.Family), buf)
if err != nil {
Expand All @@ -4355,8 +4422,8 @@ func EKeyLocator(w io.Writer, val interface{}, buf *[8]byte) error {
return tlv.NewTypeForEncodingErr(val, "keychain.KeyLocator")
}

// DKeyLocator is a decoder for keychain.KeyLocator.
func DKeyLocator(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
// dKeyLocator is a decoder for keychain.KeyLocator.
func dKeyLocator(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We shouldn't un-export these two functions, since we apparently use them in Loop.

if v, ok := val.(*keychain.KeyLocator); ok {
var family uint32
err := tlv.DUint32(r, &family, buf, 4)
Expand All @@ -4370,22 +4437,6 @@ func DKeyLocator(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
return tlv.NewTypeForDecodingErr(val, "keychain.KeyLocator", l, 8)
}

// MakeKeyLocRecord creates a Record out of a KeyLocator using the passed
// Type and the EKeyLocator and DKeyLocator functions. The size will always be
// 8 as KeyFamily is uint32 and the Index is uint32.
func MakeKeyLocRecord(typ tlv.Type, keyLoc *keychain.KeyLocator) tlv.Record {
return tlv.MakeStaticRecord(typ, keyLoc, 8, EKeyLocator, DKeyLocator)
}

// MakeScidRecord creates a Record out of a ShortChannelID using the passed
// Type and the EShortChannelID and DShortChannelID functions. The size will
// always be 8 for the ShortChannelID.
func MakeScidRecord(typ tlv.Type, scid *lnwire.ShortChannelID) tlv.Record {
return tlv.MakeStaticRecord(
typ, scid, 8, lnwire.EShortChannelID, lnwire.DShortChannelID,
)
}

// ShutdownInfo contains various info about the shutdown initiation of a
// channel.
type ShutdownInfo struct {
Expand Down
11 changes: 8 additions & 3 deletions channeldb/channel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/channeldb/models"
"github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnmock"
Expand Down Expand Up @@ -171,7 +172,7 @@ func fundingPointOption(chanPoint wire.OutPoint) testChannelOption {
}

// channelIDOption is an option which sets the short channel ID of the channel.
var channelIDOption = func(chanID lnwire.ShortChannelID) testChannelOption {
func channelIDOption(chanID lnwire.ShortChannelID) testChannelOption {
return func(params *testChannelParams) {
params.channel.ShortChannelID = chanID
}
Expand Down Expand Up @@ -311,6 +312,9 @@ func createTestChannelState(t *testing.T, cdb *ChannelStateDB) *OpenChannel {
uniqueOutputIndex.Add(1)
op := wire.OutPoint{Hash: key, Index: uniqueOutputIndex.Load()}

var tapscriptRoot chainhash.Hash
copy(tapscriptRoot[:], bytes.Repeat([]byte{1}, 32))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While we're at it, should we also throw in a value for the memo since that doesn't seem to be covered yet?


return &OpenChannel{
ChanType: SingleFunderBit | FrozenBit,
ChainHash: key,
Expand Down Expand Up @@ -353,6 +357,7 @@ func createTestChannelState(t *testing.T, cdb *ChannelStateDB) *OpenChannel {
ThawHeight: uint32(defaultPendingHeight),
InitialLocalBalance: lnwire.MilliSatoshi(9000),
InitialRemoteBalance: lnwire.MilliSatoshi(3000),
TapscriptRoot: fn.Some(tapscriptRoot),
}
}

Expand Down Expand Up @@ -1523,14 +1528,14 @@ func TestKeyLocatorEncoding(t *testing.T) {
buf [8]byte
)

err := EKeyLocator(&b, &keyLoc, &buf)
err := eKeyLocator(&b, &keyLoc, &buf)
require.NoError(t, err, "unable to encode key locator")

// Next, we'll attempt to decode the bytes into a new KeyLocator.
r := bytes.NewReader(b.Bytes())
var decodedKeyLoc keychain.KeyLocator

err = DKeyLocator(r, &decodedKeyLoc, &buf, 8)
err = dKeyLocator(r, &decodedKeyLoc, &buf, 8)
require.NoError(t, err, "unable to decode key locator")

// Finally, we'll compare that the original KeyLocator and the decoded
Expand Down
6 changes: 5 additions & 1 deletion contractcourt/chain_watcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lnwallet"
)
Expand Down Expand Up @@ -301,8 +302,11 @@ func (c *chainWatcher) Start() error {
err error
)
if chanState.ChanType.IsTaproot() {
fundingOpts := fn.MapOptionZ(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice trick 😎

chanState.TapscriptRoot, lnwallet.TapscriptRootToOpt,
)
c.fundingPkScript, _, err = input.GenTaprootFundingScript(
localKey, remoteKey, 0,
localKey, remoteKey, 0, fundingOpts...,
)
if err != nil {
return err
Expand Down
5 changes: 5 additions & 0 deletions funding/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/channeldb/models"
"github.com/lightningnetwork/lnd/discovery"
"github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/labels"
Expand Down Expand Up @@ -2853,8 +2854,12 @@ func makeFundingScript(channel *channeldb.OpenChannel) ([]byte, error) {
remoteKey := channel.RemoteChanCfg.MultiSigKey.PubKey

if channel.ChanType.IsTaproot() {
fundingOpts := fn.MapOptionZ(
channel.TapscriptRoot, lnwallet.TapscriptRootToOpt,
)
pkScript, _, err := input.GenTaprootFundingScript(
localKey, remoteKey, int64(channel.Capacity),
fundingOpts...,
)
if err != nil {
return nil, err
Expand Down
Loading