Skip to content

Commit 099f556

Browse files
Roasbeefguggero
authored andcommitted
lnwire: add CustomRecords to shutdown message
1 parent 9a972e1 commit 099f556

File tree

3 files changed

+178
-22
lines changed

3 files changed

+178
-22
lines changed

lnwire/lnwire_test.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,7 @@ func randCustomRecords(t *testing.T, r *rand.Rand) CustomRecords {
435435
key := MinCustomRecordsTlvType + keyOffset
436436

437437
// Values are byte slices of any length.
438-
value := make([]byte, r.Intn(100))
438+
value := make([]byte, r.Intn(10))
439439
_, err := r.Read(value)
440440
require.NoError(t, err)
441441

@@ -791,7 +791,6 @@ func TestLightningWireProtocol(t *testing.T) {
791791
req := Shutdown{
792792
ChannelID: ChannelID(c),
793793
Address: shutdownAddr,
794-
ExtraData: make([]byte, 0),
795794
}
796795

797796
if r.Int31()%2 == 0 {
@@ -953,12 +952,14 @@ func TestLightningWireProtocol(t *testing.T) {
953952
// Only create the slice if there will be any signatures
954953
// in it to prevent false positive test failures due to
955954
// an empty slice versus a nil slice.
956-
numSigs := uint16(r.Int31n(1019))
955+
numSigs := uint16(r.Int31n(500))
957956
if numSigs > 0 {
958957
req.HtlcSigs = make([]Sig, numSigs)
959958
}
960959
for i := 0; i < int(numSigs); i++ {
961-
req.HtlcSigs[i], err = NewSigFromSignature(testSig)
960+
req.HtlcSigs[i], err = NewSigFromSignature(
961+
testSig,
962+
)
962963
if err != nil {
963964
t.Fatalf("unable to parse sig: %v", err)
964965
return

lnwire/shutdown.go

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,11 @@ type Shutdown struct {
3838
// co-op sign offer.
3939
ShutdownNonce ShutdownNonceTLV
4040

41+
// CustomRecords maps TLV types to byte slices, storing arbitrary data
42+
// intended for inclusion in the ExtraData field of the Shutdown
43+
// message.
44+
CustomRecords CustomRecords
45+
4146
// ExtraData is the set of data that was appended to this message to
4247
// fill out the full maximum transport message size. These fields can
4348
// be used to specify optional data such as custom TLV fields.
@@ -56,7 +61,7 @@ func NewShutdown(cid ChannelID, addr DeliveryAddress) *Shutdown {
5661
// interface.
5762
var _ Message = (*Shutdown)(nil)
5863

59-
// Decode deserializes a serialized Shutdown stored in the passed io.Reader
64+
// Decode deserializes a serialized Shutdown from the passed io.Reader,
6065
// observing the specified protocol version.
6166
//
6267
// This is part of the lnwire.Message interface.
@@ -71,20 +76,23 @@ func (s *Shutdown) Decode(r io.Reader, pver uint32) error {
7176
return err
7277
}
7378

79+
// Extract TLV records from the extra data field.
7480
musigNonce := s.ShutdownNonce.Zero()
75-
typeMap, err := tlvRecords.ExtractRecords(&musigNonce)
81+
82+
customRecords, parsed, extraData, err := ParseAndExtractCustomRecords(
83+
tlvRecords, &musigNonce,
84+
)
7685
if err != nil {
7786
return err
7887
}
7988

80-
// Set the corresponding TLV types if they were included in the stream.
81-
if val, ok := typeMap[s.ShutdownNonce.TlvType()]; ok && val == nil {
89+
// Assign the parsed records back to the message.
90+
if _, ok := parsed[musigNonce.TlvType()]; ok {
8291
s.ShutdownNonce = tlv.SomeRecordT(musigNonce)
8392
}
8493

85-
if len(tlvRecords) != 0 {
86-
s.ExtraData = tlvRecords
87-
}
94+
s.CustomRecords = customRecords
95+
s.ExtraData = extraData
8896

8997
return nil
9098
}
@@ -94,26 +102,28 @@ func (s *Shutdown) Decode(r io.Reader, pver uint32) error {
94102
//
95103
// This is part of the lnwire.Message interface.
96104
func (s *Shutdown) Encode(w *bytes.Buffer, pver uint32) error {
97-
recordProducers := make([]tlv.RecordProducer, 0, 1)
98-
s.ShutdownNonce.WhenSome(
99-
func(nonce tlv.RecordT[ShutdownNonceType, Musig2Nonce]) {
100-
recordProducers = append(recordProducers, &nonce)
101-
},
102-
)
103-
err := EncodeMessageExtraData(&s.ExtraData, recordProducers...)
104-
if err != nil {
105+
if err := WriteChannelID(w, s.ChannelID); err != nil {
105106
return err
106107
}
107108

108-
if err := WriteChannelID(w, s.ChannelID); err != nil {
109+
if err := WriteDeliveryAddress(w, s.Address); err != nil {
109110
return err
110111
}
111112

112-
if err := WriteDeliveryAddress(w, s.Address); err != nil {
113+
// Only include nonce in extra data if present.
114+
var records []tlv.RecordProducer
115+
s.ShutdownNonce.WhenSome(
116+
func(nonce tlv.RecordT[ShutdownNonceType, Musig2Nonce]) {
117+
records = append(records, &nonce)
118+
},
119+
)
120+
121+
extraData, err := MergeAndEncode(records, s.ExtraData, s.CustomRecords)
122+
if err != nil {
113123
return err
114124
}
115125

116-
return WriteBytes(w, s.ExtraData)
126+
return WriteBytes(w, extraData)
117127
}
118128

119129
// MsgType returns the integer uniquely identifying this message type on the

lnwire/shutdown_test.go

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
package lnwire
2+
3+
import (
4+
"bytes"
5+
"fmt"
6+
"testing"
7+
8+
"github.com/btcsuite/btcd/btcec/v2/schnorr/musig2"
9+
"github.com/lightningnetwork/lnd/tlv"
10+
"github.com/stretchr/testify/require"
11+
)
12+
13+
// testCaseShutdown is a test case for the Shutdown message.
14+
type testCaseShutdown struct {
15+
// Msg is the message to be encoded and decoded.
16+
Msg Shutdown
17+
18+
// ExpectEncodeError is a flag that indicates whether we expect the
19+
// encoding of the message to fail.
20+
ExpectEncodeError bool
21+
}
22+
23+
// generateShutdownTestCases generates a set of Shutdown message test cases.
24+
func generateShutdownTestCases(t *testing.T) []testCaseShutdown {
25+
// Firstly, we'll set basic values for the message fields.
26+
//
27+
// Generate random channel ID.
28+
chanIDBytes, err := generateRandomBytes(32)
29+
require.NoError(t, err)
30+
31+
var chanID ChannelID
32+
copy(chanID[:], chanIDBytes)
33+
34+
// Generate random payment preimage.
35+
paymentPreimageBytes, err := generateRandomBytes(32)
36+
require.NoError(t, err)
37+
38+
var paymentPreimage [32]byte
39+
copy(paymentPreimage[:], paymentPreimageBytes)
40+
41+
deliveryAddr, err := generateRandomBytes(16)
42+
require.NoError(t, err)
43+
44+
// Define custom records.
45+
recordKey1 := uint64(MinCustomRecordsTlvType + 1)
46+
recordValue1, err := generateRandomBytes(10)
47+
require.NoError(t, err)
48+
49+
recordKey2 := uint64(MinCustomRecordsTlvType + 2)
50+
recordValue2, err := generateRandomBytes(10)
51+
require.NoError(t, err)
52+
53+
customRecords := CustomRecords{
54+
recordKey1: recordValue1,
55+
recordKey2: recordValue2,
56+
}
57+
58+
dummyPubKey, err := pubkeyFromHex(
59+
"0228f2af0abe322403480fb3ee172f7f1601e67d1da6cad40b54c4468d4" +
60+
"8236c39",
61+
)
62+
require.NoError(t, err)
63+
64+
muSig2Nonce, err := musig2.GenNonces(musig2.WithPublicKey(dummyPubKey))
65+
require.NoError(t, err)
66+
67+
// Construct an instance of extra data that contains records with TLV
68+
// types below the minimum custom records threshold and that lack
69+
// corresponding fields in the message struct. Content should persist in
70+
// the extra data field after encoding and decoding.
71+
var (
72+
recordBytes45 = []byte("recordBytes45")
73+
tlvRecord45 = tlv.NewPrimitiveRecord[tlv.TlvType45](
74+
recordBytes45,
75+
)
76+
77+
recordBytes55 = []byte("recordBytes55")
78+
tlvRecord55 = tlv.NewPrimitiveRecord[tlv.TlvType55](
79+
recordBytes55,
80+
)
81+
)
82+
83+
var extraData ExtraOpaqueData
84+
err = extraData.PackRecords(
85+
[]tlv.RecordProducer{&tlvRecord45, &tlvRecord55}...,
86+
)
87+
require.NoError(t, err)
88+
89+
return []testCaseShutdown{
90+
{
91+
Msg: Shutdown{
92+
ChannelID: chanID,
93+
CustomRecords: customRecords,
94+
ExtraData: extraData,
95+
Address: deliveryAddr,
96+
},
97+
},
98+
{
99+
Msg: Shutdown{
100+
ChannelID: chanID,
101+
CustomRecords: customRecords,
102+
ExtraData: extraData,
103+
Address: deliveryAddr,
104+
ShutdownNonce: SomeShutdownNonce(
105+
muSig2Nonce.PubNonce,
106+
),
107+
},
108+
},
109+
}
110+
}
111+
112+
// TestShutdownEncodeDecode tests Shutdown message encoding and decoding for all
113+
// supported field values.
114+
func TestShutdownEncodeDecode(t *testing.T) {
115+
t.Parallel()
116+
117+
// Generate test cases.
118+
testCases := generateShutdownTestCases(t)
119+
120+
// Execute test cases.
121+
for tcIdx, tc := range testCases {
122+
t.Run(fmt.Sprintf("testcase-%d", tcIdx), func(t *testing.T) {
123+
// Encode test case message.
124+
var buf bytes.Buffer
125+
err := tc.Msg.Encode(&buf, 0)
126+
127+
// Check if we expect an encoding error.
128+
if tc.ExpectEncodeError {
129+
require.Error(t, err)
130+
return
131+
}
132+
133+
require.NoError(t, err)
134+
135+
// Decode the encoded message bytes message.
136+
var actualMsg Shutdown
137+
decodeReader := bytes.NewReader(buf.Bytes())
138+
err = actualMsg.Decode(decodeReader, 0)
139+
require.NoError(t, err)
140+
141+
// Compare the two messages to ensure equality.
142+
require.Equal(t, tc.Msg, actualMsg)
143+
})
144+
}
145+
}

0 commit comments

Comments
 (0)