Skip to content

Commit 784d236

Browse files
committed
channeldb: consolidate root bucket TLVs into new struct
In this commit, we consolidate the root bucket TLVs into a new struct. This makes it easier to see all the new TLV fields at a glance. We also convert TLV usage to use the new type param based APis.
1 parent 007f968 commit 784d236

File tree

2 files changed

+100
-83
lines changed

2 files changed

+100
-83
lines changed

channeldb/channel.go

Lines changed: 98 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -221,27 +221,60 @@ const (
221221
// A tlv type definition used to serialize an outpoint's indexStatus
222222
// for use in the outpoint index.
223223
indexStatusType tlv.Type = 0
224+
)
224225

225-
// A tlv type definition used to serialize and deserialize a KeyLocator
226-
// from the database.
227-
keyLocType tlv.Type = 1
226+
// chanAuxData houses the auxiliary data that is stored for each channel in a
227+
// TLV stream within the root bucket. This is stored as a TLV stream appended
228+
// to the existing hard-coded fields in the channel's root bucket.
229+
type chanAuxData struct {
230+
revokeKeyLoc tlv.RecordT[tlv.TlvType1, keyLocRecord]
228231

229-
// A tlv type used to serialize and deserialize the
230-
// `InitialLocalBalance` field.
231-
initialLocalBalanceType tlv.Type = 2
232+
initialLocalBalance tlv.RecordT[tlv.TlvType2, uint64]
232233

233-
// A tlv type used to serialize and deserialize the
234-
// `InitialRemoteBalance` field.
235-
initialRemoteBalanceType tlv.Type = 3
234+
initialRemoteBalance tlv.RecordT[tlv.TlvType3, uint64]
236235

237-
// A tlv type definition used to serialize and deserialize the
238-
// confirmed ShortChannelID for a zero-conf channel.
239-
realScidType tlv.Type = 4
236+
realScid tlv.RecordT[tlv.TlvType4, lnwire.ShortChannelID]
240237

241-
// A tlv type definition used to serialize and deserialize the
242-
// Memo for the channel channel.
243-
channelMemoType tlv.Type = 5
244-
)
238+
memo tlv.OptionalRecordT[tlv.TlvType5, []byte]
239+
}
240+
241+
// toOpeChan converts the chanAuxData to an OpenChannel by setting the relevant
242+
// fields in the OpenChannel struct.
243+
func (c *chanAuxData) toOpenChan(o *OpenChannel) {
244+
o.RevocationKeyLocator = c.revokeKeyLoc.Val.KeyLocator
245+
o.InitialLocalBalance = lnwire.MilliSatoshi(c.initialLocalBalance.Val)
246+
o.InitialRemoteBalance = lnwire.MilliSatoshi(c.initialRemoteBalance.Val)
247+
o.confirmedScid = c.realScid.Val
248+
c.memo.WhenSomeV(func(memo []byte) {
249+
o.Memo = memo
250+
})
251+
}
252+
253+
// newChanAuxDataFromChan creates a new chanAuxData from the given channel.
254+
func newChanAuxDataFromChan(openChan *OpenChannel) *chanAuxData {
255+
c := &chanAuxData{
256+
revokeKeyLoc: tlv.NewRecordT[tlv.TlvType1, keyLocRecord](
257+
keyLocRecord{openChan.RevocationKeyLocator},
258+
),
259+
initialLocalBalance: tlv.NewPrimitiveRecord[tlv.TlvType2, uint64](
260+
uint64(openChan.InitialLocalBalance),
261+
),
262+
initialRemoteBalance: tlv.NewPrimitiveRecord[tlv.TlvType3, uint64](
263+
uint64(openChan.InitialRemoteBalance),
264+
),
265+
realScid: tlv.NewRecordT[tlv.TlvType4, lnwire.ShortChannelID](
266+
openChan.confirmedScid,
267+
),
268+
}
269+
270+
if len(openChan.Memo) == 0 {
271+
c.memo = tlv.SomeRecordT(
272+
tlv.NewPrimitiveRecord[tlv.TlvType5](openChan.Memo),
273+
)
274+
}
275+
276+
return c
277+
}
245278

246279
// indexStatus is an enum-like type that describes what state the
247280
// outpoint is in. Currently only two possible values.
@@ -852,6 +885,10 @@ type OpenChannel struct {
852885
// channel that will be useful to our future selves.
853886
Memo []byte
854887

888+
// TapscriptRoot is an optional tapscript root used to derive the
889+
// musig2 funding output.
890+
TapscriptRoot fn.Option[chainhash.Hash]
891+
855892
// TODO(roasbeef): eww
856893
Db *ChannelStateDB
857894

@@ -3932,26 +3969,20 @@ func putChanInfo(chanBucket kvdb.RwBucket, channel *OpenChannel) error {
39323969
return err
39333970
}
39343971

3935-
// Convert balance fields into uint64.
3936-
localBalance := uint64(channel.InitialLocalBalance)
3937-
remoteBalance := uint64(channel.InitialRemoteBalance)
3972+
auxData := newChanAuxDataFromChan(channel)
3973+
3974+
tlvRecords := []tlv.Record{
3975+
auxData.revokeKeyLoc.Record(),
3976+
auxData.initialLocalBalance.Record(),
3977+
auxData.initialRemoteBalance.Record(),
3978+
auxData.realScid.Record(),
3979+
}
3980+
auxData.memo.WhenSome(func(memo tlv.RecordT[tlv.TlvType5, []byte]) {
3981+
tlvRecords = append(tlvRecords, memo.Record())
3982+
})
39383983

39393984
// Create the tlv stream.
3940-
tlvStream, err := tlv.NewStream(
3941-
// Write the RevocationKeyLocator as the first entry in a tlv
3942-
// stream.
3943-
MakeKeyLocRecord(
3944-
keyLocType, &channel.RevocationKeyLocator,
3945-
),
3946-
tlv.MakePrimitiveRecord(
3947-
initialLocalBalanceType, &localBalance,
3948-
),
3949-
tlv.MakePrimitiveRecord(
3950-
initialRemoteBalanceType, &remoteBalance,
3951-
),
3952-
MakeScidRecord(realScidType, &channel.confirmedScid),
3953-
tlv.MakePrimitiveRecord(channelMemoType, &channel.Memo),
3954-
)
3985+
tlvStream, err := tlv.NewStream(tlvRecords...)
39553986
if err != nil {
39563987
return err
39573988
}
@@ -4146,28 +4177,16 @@ func fetchChanInfo(chanBucket kvdb.RBucket, channel *OpenChannel) error {
41464177
}
41474178
}
41484179

4149-
// Create balance fields in uint64, and Memo field as byte slice.
4150-
var (
4151-
localBalance uint64
4152-
remoteBalance uint64
4153-
memo []byte
4154-
)
4180+
var auxData chanAuxData
4181+
zeroMemo := auxData.memo.Zero()
41554182

41564183
// Create the tlv stream.
41574184
tlvStream, err := tlv.NewStream(
4158-
// Write the RevocationKeyLocator as the first entry in a tlv
4159-
// stream.
4160-
MakeKeyLocRecord(
4161-
keyLocType, &channel.RevocationKeyLocator,
4162-
),
4163-
tlv.MakePrimitiveRecord(
4164-
initialLocalBalanceType, &localBalance,
4165-
),
4166-
tlv.MakePrimitiveRecord(
4167-
initialRemoteBalanceType, &remoteBalance,
4168-
),
4169-
MakeScidRecord(realScidType, &channel.confirmedScid),
4170-
tlv.MakePrimitiveRecord(channelMemoType, &memo),
4185+
auxData.revokeKeyLoc.Record(),
4186+
auxData.initialLocalBalance.Record(),
4187+
auxData.initialRemoteBalance.Record(),
4188+
auxData.realScid.Record(),
4189+
zeroMemo.Record(),
41714190
)
41724191
if err != nil {
41734192
return err
@@ -4177,14 +4196,9 @@ func fetchChanInfo(chanBucket kvdb.RBucket, channel *OpenChannel) error {
41774196
return err
41784197
}
41794198

4180-
// Attach the balance fields.
4181-
channel.InitialLocalBalance = lnwire.MilliSatoshi(localBalance)
4182-
channel.InitialRemoteBalance = lnwire.MilliSatoshi(remoteBalance)
4183-
4184-
// Attach the memo field if non-empty.
4185-
if len(memo) > 0 {
4186-
channel.Memo = memo
4187-
}
4199+
// Assign all the relevant fields from the aux data into the actual
4200+
// open channel.
4201+
auxData.toOpenChan(channel)
41884202

41894203
channel.Packager = NewChannelPackager(channel.ShortChannelID)
41904204

@@ -4342,8 +4356,27 @@ func deleteThawHeight(chanBucket kvdb.RwBucket) error {
43424356
return chanBucket.Delete(frozenChanKey)
43434357
}
43444358

4345-
// EKeyLocator is an encoder for keychain.KeyLocator.
4346-
func EKeyLocator(w io.Writer, val interface{}, buf *[8]byte) error {
4359+
// keyLocRecord is a wrapper struct around keychain.KeyLocator to implement the
4360+
// tlv.RecordProducer interface.
4361+
type keyLocRecord struct {
4362+
keychain.KeyLocator
4363+
}
4364+
4365+
// Record creates a Record out of a KeyLocator using the passed Type and the
4366+
// EKeyLocator and DKeyLocator functions. The size will always be 8 as
4367+
// KeyFamily is uint32 and the Index is uint32.
4368+
//
4369+
// NOTE: This is part of the tlv.RecordProducer interface.
4370+
func (k *keyLocRecord) Record() tlv.Record {
4371+
// Note that we set the type here as zero, as when used with a
4372+
// tlv.RecordT, the type param will be used as the type.
4373+
return tlv.MakeStaticRecord(
4374+
0, &k.KeyLocator, 8, eKeyLocator, dKeyLocator,
4375+
)
4376+
}
4377+
4378+
// eKeyLocator is an encoder for keychain.KeyLocator.
4379+
func eKeyLocator(w io.Writer, val interface{}, buf *[8]byte) error {
43474380
if v, ok := val.(*keychain.KeyLocator); ok {
43484381
err := tlv.EUint32T(w, uint32(v.Family), buf)
43494382
if err != nil {
@@ -4355,8 +4388,8 @@ func EKeyLocator(w io.Writer, val interface{}, buf *[8]byte) error {
43554388
return tlv.NewTypeForEncodingErr(val, "keychain.KeyLocator")
43564389
}
43574390

4358-
// DKeyLocator is a decoder for keychain.KeyLocator.
4359-
func DKeyLocator(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
4391+
// dKeyLocator is a decoder for keychain.KeyLocator.
4392+
func dKeyLocator(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
43604393
if v, ok := val.(*keychain.KeyLocator); ok {
43614394
var family uint32
43624395
err := tlv.DUint32(r, &family, buf, 4)
@@ -4370,22 +4403,6 @@ func DKeyLocator(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
43704403
return tlv.NewTypeForDecodingErr(val, "keychain.KeyLocator", l, 8)
43714404
}
43724405

4373-
// MakeKeyLocRecord creates a Record out of a KeyLocator using the passed
4374-
// Type and the EKeyLocator and DKeyLocator functions. The size will always be
4375-
// 8 as KeyFamily is uint32 and the Index is uint32.
4376-
func MakeKeyLocRecord(typ tlv.Type, keyLoc *keychain.KeyLocator) tlv.Record {
4377-
return tlv.MakeStaticRecord(typ, keyLoc, 8, EKeyLocator, DKeyLocator)
4378-
}
4379-
4380-
// MakeScidRecord creates a Record out of a ShortChannelID using the passed
4381-
// Type and the EShortChannelID and DShortChannelID functions. The size will
4382-
// always be 8 for the ShortChannelID.
4383-
func MakeScidRecord(typ tlv.Type, scid *lnwire.ShortChannelID) tlv.Record {
4384-
return tlv.MakeStaticRecord(
4385-
typ, scid, 8, lnwire.EShortChannelID, lnwire.DShortChannelID,
4386-
)
4387-
}
4388-
43894406
// ShutdownInfo contains various info about the shutdown initiation of a
43904407
// channel.
43914408
type ShutdownInfo struct {

channeldb/channel_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1523,14 +1523,14 @@ func TestKeyLocatorEncoding(t *testing.T) {
15231523
buf [8]byte
15241524
)
15251525

1526-
err := EKeyLocator(&b, &keyLoc, &buf)
1526+
err := eKeyLocator(&b, &keyLoc, &buf)
15271527
require.NoError(t, err, "unable to encode key locator")
15281528

15291529
// Next, we'll attempt to decode the bytes into a new KeyLocator.
15301530
r := bytes.NewReader(b.Bytes())
15311531
var decodedKeyLoc keychain.KeyLocator
15321532

1533-
err = DKeyLocator(r, &decodedKeyLoc, &buf, 8)
1533+
err = dKeyLocator(r, &decodedKeyLoc, &buf, 8)
15341534
require.NoError(t, err, "unable to decode key locator")
15351535

15361536
// Finally, we'll compare that the original KeyLocator and the decoded

0 commit comments

Comments
 (0)