Skip to content

Commit b93d288

Browse files
committed
lnwire: add encoding/decoding unit tests for UpdateAddHtlc message
1 parent 2b3618c commit b93d288

File tree

2 files changed

+245
-0
lines changed

2 files changed

+245
-0
lines changed

lnwire/test_utils.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
package lnwire
2+
3+
import (
4+
"crypto/rand"
5+
"encoding/hex"
6+
7+
"github.com/btcsuite/btcd/btcec/v2"
8+
)
9+
10+
// pubkeyFromHex parses a Bitcoin public key from a hex encoded string.
11+
func pubkeyFromHex(keyHex string) (*btcec.PublicKey, error) {
12+
bytes, err := hex.DecodeString(keyHex)
13+
if err != nil {
14+
return nil, err
15+
}
16+
17+
return btcec.ParsePubKey(bytes)
18+
}
19+
20+
// generateRandomBytes returns a slice of n random bytes.
21+
func generateRandomBytes(n int) ([]byte, error) {
22+
b := make([]byte, n)
23+
_, err := rand.Read(b)
24+
if err != nil {
25+
return nil, err
26+
}
27+
28+
return b, nil
29+
}

lnwire/update_add_htlc_test.go

Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
package lnwire
2+
3+
import (
4+
"bytes"
5+
"testing"
6+
7+
"github.com/lightningnetwork/lnd/tlv"
8+
"github.com/stretchr/testify/require"
9+
)
10+
11+
// testCase is a test case for the UpdateAddHTLC message.
12+
type testCase struct {
13+
// Msg is the message to be encoded and decoded.
14+
Msg UpdateAddHTLC
15+
16+
// ExpectEncodeError is a flag that indicates whether we expect the
17+
// encoding of the message to fail.
18+
ExpectEncodeError bool
19+
}
20+
21+
// generateTestCases generates a set of UpdateAddHTLC message test cases.
22+
func generateTestCases(t *testing.T) []testCase {
23+
// Firstly, we'll set basic values for the message fields.
24+
//
25+
// Generate random channel ID.
26+
chanIDBytes, err := generateRandomBytes(32)
27+
require.NoError(t, err)
28+
29+
var chanID ChannelID
30+
copy(chanID[:], chanIDBytes)
31+
32+
// Generate random payment hash.
33+
paymentHashBytes, err := generateRandomBytes(32)
34+
require.NoError(t, err)
35+
36+
var paymentHash [32]byte
37+
copy(paymentHash[:], paymentHashBytes)
38+
39+
// Generate random onion blob.
40+
onionBlobBytes, err := generateRandomBytes(OnionPacketSize)
41+
require.NoError(t, err)
42+
43+
var onionBlob [OnionPacketSize]byte
44+
copy(onionBlob[:], onionBlobBytes)
45+
46+
// Define the blinding point.
47+
blinding, err := pubkeyFromHex(
48+
"0228f2af0abe322403480fb3ee172f7f1601e67d1da6cad40b54c4468d4" +
49+
"8236c39",
50+
)
51+
require.NoError(t, err)
52+
53+
blindingPoint := tlv.SomeRecordT(
54+
tlv.NewPrimitiveRecord[BlindingPointTlvType](blinding),
55+
)
56+
57+
// Define custom records.
58+
recordKey1 := uint64(MinCustomRecordsTlvType + 1)
59+
recordValue1, err := generateRandomBytes(10)
60+
require.NoError(t, err)
61+
62+
recordKey2 := uint64(MinCustomRecordsTlvType + 2)
63+
recordValue2, err := generateRandomBytes(10)
64+
require.NoError(t, err)
65+
66+
customRecords := CustomRecords{
67+
recordKey1: recordValue1,
68+
recordKey2: recordValue2,
69+
}
70+
71+
// Construct an instance of extra data that contains records with TLV
72+
// types below the minimum custom records threshold and that lack
73+
// corresponding fields in the message struct. Content should persist in
74+
// the extra data field after encoding and decoding.
75+
var (
76+
recordBytes45 = []byte("recordBytes45")
77+
tlvRecord45 = tlv.NewPrimitiveRecord[tlv.TlvType45](
78+
recordBytes45,
79+
)
80+
81+
recordBytes55 = []byte("recordBytes55")
82+
tlvRecord55 = tlv.NewPrimitiveRecord[tlv.TlvType55](
83+
recordBytes55,
84+
)
85+
)
86+
87+
var extraData ExtraOpaqueData
88+
err = extraData.PackRecords(
89+
[]tlv.RecordProducer{&tlvRecord45, &tlvRecord55}...,
90+
)
91+
require.NoError(t, err)
92+
93+
// Define test cases.
94+
testCases := make([]testCase, 0)
95+
96+
testCases = append(testCases, testCase{
97+
Msg: UpdateAddHTLC{
98+
ChanID: chanID,
99+
ID: 42,
100+
Amount: MilliSatoshi(1000),
101+
PaymentHash: paymentHash,
102+
Expiry: 43,
103+
OnionBlob: onionBlob,
104+
BlindingPoint: blindingPoint,
105+
CustomRecords: customRecords,
106+
ExtraData: extraData,
107+
},
108+
})
109+
110+
// Add a test case where the blinding point field is not populated.
111+
testCases = append(testCases, testCase{
112+
Msg: UpdateAddHTLC{
113+
ChanID: chanID,
114+
ID: 42,
115+
Amount: MilliSatoshi(1000),
116+
PaymentHash: paymentHash,
117+
Expiry: 43,
118+
OnionBlob: onionBlob,
119+
CustomRecords: customRecords,
120+
},
121+
})
122+
123+
// Add a test case where the custom records field is not populated.
124+
testCases = append(testCases, testCase{
125+
Msg: UpdateAddHTLC{
126+
ChanID: chanID,
127+
ID: 42,
128+
Amount: MilliSatoshi(1000),
129+
PaymentHash: paymentHash,
130+
Expiry: 43,
131+
OnionBlob: onionBlob,
132+
BlindingPoint: blindingPoint,
133+
},
134+
})
135+
136+
// Add a case where the custom records are invlaid.
137+
invalidCustomRecords := CustomRecords{
138+
MinCustomRecordsTlvType - 1: recordValue1,
139+
}
140+
141+
testCases = append(testCases, testCase{
142+
Msg: UpdateAddHTLC{
143+
ChanID: chanID,
144+
ID: 42,
145+
Amount: MilliSatoshi(1000),
146+
PaymentHash: paymentHash,
147+
Expiry: 43,
148+
OnionBlob: onionBlob,
149+
BlindingPoint: blindingPoint,
150+
CustomRecords: invalidCustomRecords,
151+
},
152+
ExpectEncodeError: true,
153+
})
154+
155+
return testCases
156+
}
157+
158+
// TestUpdateAddHtlcEncodeDecode tests UpdateAddHTLC message encoding and
159+
// decoding for all supported field values.
160+
func TestUpdateAddHtlcEncodeDecode(t *testing.T) {
161+
t.Parallel()
162+
163+
// Generate test cases.
164+
testCases := generateTestCases(t)
165+
166+
// Execute test cases.
167+
for tcIdx, tc := range testCases {
168+
t.Log("Running test case", tcIdx)
169+
170+
// Encode test case message.
171+
var buf bytes.Buffer
172+
err := tc.Msg.Encode(&buf, 0)
173+
174+
// Check if we expect an encoding error.
175+
if tc.ExpectEncodeError {
176+
require.Error(t, err)
177+
continue
178+
}
179+
require.NoError(t, err)
180+
181+
// Decode the encoded message bytes message.
182+
var actualMsg UpdateAddHTLC
183+
decodeReader := bytes.NewReader(buf.Bytes())
184+
err = actualMsg.Decode(decodeReader, 0)
185+
require.NoError(t, err)
186+
187+
// Compare the two messages to ensure equality one field at a
188+
// time.
189+
require.Equal(t, tc.Msg.ChanID, actualMsg.ChanID)
190+
require.Equal(t, tc.Msg.ID, actualMsg.ID)
191+
require.Equal(t, tc.Msg.Amount, actualMsg.Amount)
192+
require.Equal(t, tc.Msg.PaymentHash, actualMsg.PaymentHash)
193+
require.Equal(t, tc.Msg.OnionBlob, actualMsg.OnionBlob)
194+
require.Equal(t, tc.Msg.BlindingPoint, actualMsg.BlindingPoint)
195+
196+
// Check that the custom records field is as expected.
197+
if len(tc.Msg.CustomRecords) == 0 {
198+
require.Len(t, actualMsg.CustomRecords, 0)
199+
} else {
200+
require.Equal(
201+
t, tc.Msg.CustomRecords,
202+
actualMsg.CustomRecords,
203+
)
204+
}
205+
206+
// Check that the extra data field is as expected.
207+
if len(tc.Msg.ExtraData) == 0 {
208+
require.Len(t, actualMsg.ExtraData, 0)
209+
} else {
210+
require.Equal(
211+
t, tc.Msg.ExtraData,
212+
actualMsg.ExtraData,
213+
)
214+
}
215+
}
216+
}

0 commit comments

Comments
 (0)