Skip to content

Commit 78f31da

Browse files
committed
lnwire: refactor custom vs official tlv parsing into new func
We'll use this to update CommitSig in the next commit.
1 parent e4c97d2 commit 78f31da

File tree

3 files changed

+118
-24
lines changed

3 files changed

+118
-24
lines changed

lnwire/extra_bytes.go

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
11
package lnwire
22

3+
// For some reason golangci-lint has a false positive on the sort order of the
4+
// imports for the new "maps" package... We need the nolint directive here to
5+
// ignore that.
6+
//
7+
//nolint:gci
38
import (
49
"bytes"
510
"fmt"
611
"io"
12+
"maps"
713

814
"github.com/lightningnetwork/lnd/tlv"
915
)
@@ -194,3 +200,57 @@ func EncodeMessageExtraData(extraData *ExtraOpaqueData,
194200
// are all properly sorted.
195201
return extraData.PackRecords(recordProducers...)
196202
}
203+
204+
// wireTlvMap is a struct that holds the official records and custom records in
205+
// a TLV type map. This is useful for ensuring that the set of custom TLV
206+
// records are handled properly and don't overlap with the official records.
207+
type wireTlvMap struct {
208+
// officialTypes is the set of official records that are defined in the
209+
// spec.
210+
officialTypes tlv.TypeMap
211+
212+
// customTypes is the set of custom records that are not defined in
213+
// spec, and are used by higher level applications.
214+
customTypes tlv.TypeMap
215+
}
216+
217+
// newWireTlvMap creates a new tlv.TypeMap from the given set of parsed TLV
218+
// records. A struct with two maps are returned:
219+
//
220+
// 1. officialTypes: the set of official records that are defined in the
221+
// spec.
222+
//
223+
// 2. customTypes: the set of custom records that are not defined in
224+
// the spec.
225+
func newWireTlvMap(typeMap tlv.TypeMap) wireTlvMap {
226+
officialRecords := maps.Clone(typeMap)
227+
228+
// Any records from the extra data TLV map which are in the custom
229+
// records TLV type range will be included in the custom records field
230+
// and removed from the extra data field.
231+
customRecordsTlvMap := make(tlv.TypeMap, len(typeMap))
232+
for k, v := range typeMap {
233+
// Skip records that are not in the custom records TLV type
234+
// range.
235+
if k < MinCustomRecordsTlvType {
236+
continue
237+
}
238+
239+
// Include the record in the custom records map.
240+
customRecordsTlvMap[k] = v
241+
242+
// Now that the record is included in the custom records map,
243+
// we can remove it from the extra data TLV map.
244+
delete(officialRecords, k)
245+
}
246+
247+
return wireTlvMap{
248+
officialTypes: officialRecords,
249+
customTypes: customRecordsTlvMap,
250+
}
251+
}
252+
253+
// Len returns the total number of records in the wireTlvMap.
254+
func (w *wireTlvMap) Len() int {
255+
return len(w.officialTypes) + len(w.customTypes)
256+
}

lnwire/extra_bytes_test.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,11 @@ import (
77
"testing"
88
"testing/quick"
99

10+
"github.com/lightningnetwork/lnd/fn"
1011
"github.com/lightningnetwork/lnd/tlv"
1112
"github.com/stretchr/testify/require"
13+
"golang.org/x/exp/maps"
14+
"pgregory.net/rapid"
1215
)
1316

1417
// TestExtraOpaqueDataEncodeDecode tests that we're able to encode/decode
@@ -206,3 +209,47 @@ func TestPackRecords(t *testing.T) {
206209
require.Equal(t, recordBytes2, extractedRecords[tlvType2.TypeVal()])
207210
require.Equal(t, recordBytes3, extractedRecords[tlvType3.TypeVal()])
208211
}
212+
213+
// TestNewWireTlvMap tests the newWireTlvMap function using property-based
214+
// testing.
215+
func TestNewWireTlvMap(t *testing.T) {
216+
rapid.Check(t, func(t *rapid.T) {
217+
// Make a random type map, using the generic Make which'll
218+
// figure out what type to generate.
219+
tlvTypeMap := rapid.Make[tlv.TypeMap]().Draw(t, "typeMap")
220+
221+
// Create a wireTlvMap from the generated type map, this'll
222+
// operate on our random input.
223+
result := newWireTlvMap(tlvTypeMap)
224+
225+
// Property 1: The sum of lengths of officialTypes and
226+
// customTypes should equal the length of the input typeMap.
227+
require.Equal(t, len(tlvTypeMap), result.Len())
228+
229+
// Property 2: All types in customTypes should be >=
230+
// MinCustomRecordsTlvType.
231+
require.True(t, fn.All(func(k tlv.Type) bool {
232+
return uint64(k) >= uint64(MinCustomRecordsTlvType)
233+
}, maps.Keys(result.customTypes)))
234+
235+
// Property 3: All types in officialTypes should be <
236+
// MinCustomRecordsTlvType.
237+
require.True(t, fn.All(func(k tlv.Type) bool {
238+
return uint64(k) < uint64(MinCustomRecordsTlvType)
239+
}, maps.Keys(result.officialTypes)))
240+
241+
// Property 4: The union of officialTypes and customTypes
242+
// should equal the input typeMap.
243+
unionMap := make(tlv.TypeMap)
244+
maps.Copy(unionMap, result.officialTypes)
245+
maps.Copy(unionMap, result.customTypes)
246+
require.Equal(t, tlvTypeMap, unionMap)
247+
248+
// Property 5: No type should appear in both officialTypes and
249+
// customTypes.
250+
require.True(t, fn.All(func(k tlv.Type) bool {
251+
_, exists := result.officialTypes[k]
252+
return !exists
253+
}, maps.Keys(result.customTypes)))
254+
})
255+
}

lnwire/update_add_htlc.go

Lines changed: 11 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -131,29 +131,14 @@ func (c *UpdateAddHTLC) Decode(r io.Reader, pver uint32) error {
131131
delete(extraDataTlvMap, c.BlindingPoint.TlvType())
132132
}
133133

134-
// Any records from the extra data TLV map which are in the custom
135-
// records TLV type range will be included in the custom records field
136-
// and removed from the extra data field.
137-
customRecordsTlvMap := make(tlv.TypeMap, len(extraDataTlvMap))
138-
for k, v := range extraDataTlvMap {
139-
// Skip records that are not in the custom records TLV type
140-
// range.
141-
if k < MinCustomRecordsTlvType {
142-
continue
143-
}
144-
145-
// Include the record in the custom records map.
146-
customRecordsTlvMap[k] = v
147-
148-
// Now that the record is included in the custom records map,
149-
// we can remove it from the extra data TLV map.
150-
delete(extraDataTlvMap, k)
151-
}
134+
// Parse through the remaining extra data map to separate the custom
135+
// records, from the set of official records.
136+
tlvTypes := newWireTlvMap(extraDataTlvMap)
152137

153138
// Set the custom records field to the custom records specific TLV
154139
// record map.
155140
customRecords, err := NewCustomRecordsFromTlvTypeMap(
156-
customRecordsTlvMap,
141+
tlvTypes.customTypes,
157142
)
158143
if err != nil {
159144
return err
@@ -162,21 +147,23 @@ func (c *UpdateAddHTLC) Decode(r io.Reader, pver uint32) error {
162147

163148
// Set custom records to nil if we didn't parse anything out of it so
164149
// that we can use assert.Equal in tests.
165-
if len(customRecordsTlvMap) == 0 {
150+
if len(customRecords) == 0 {
166151
c.CustomRecords = nil
167152
}
168153

169154
// Set extra data to nil if we didn't parse anything out of it so that
170155
// we can use assert.Equal in tests.
171-
if len(extraDataTlvMap) == 0 {
156+
if len(tlvTypes.officialTypes) == 0 {
172157
c.ExtraData = nil
173158
return nil
174159
}
175160

176161
// Encode the remaining records back into the extra data field. These
177-
// records are not in the custom records TLV type range and do not
178-
// have associated fields in the UpdateAddHTLC struct.
179-
c.ExtraData, err = NewExtraOpaqueDataFromTlvTypeMap(extraDataTlvMap)
162+
// records are not in the custom records TLV type range and do not have
163+
// associated fields in the UpdateAddHTLC struct.
164+
c.ExtraData, err = NewExtraOpaqueDataFromTlvTypeMap(
165+
tlvTypes.officialTypes,
166+
)
180167
if err != nil {
181168
return err
182169
}

0 commit comments

Comments
 (0)