Skip to content

Commit 82ae022

Browse files
committed
lnwire21: add custom records parsing
We add the new custom records encoding/decoding logic to the "frozen" lnwire21 package. We can do this because nothing uses this logic yet. If the custom records logic changes, the changes should _not_ be added to the lnwire21 version.
1 parent 33ab4b9 commit 82ae022

File tree

1 file changed

+263
-0
lines changed

1 file changed

+263
-0
lines changed
Lines changed: 263 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,263 @@
1+
package lnwire
2+
3+
import (
4+
"bytes"
5+
"fmt"
6+
"io"
7+
"sort"
8+
9+
"github.com/lightningnetwork/lnd/fn"
10+
"github.com/lightningnetwork/lnd/tlv"
11+
)
12+
13+
const (
14+
// MinCustomRecordsTlvType is the minimum custom records TLV type as
15+
// defined in BOLT 01.
16+
MinCustomRecordsTlvType = 65536
17+
)
18+
19+
// CustomRecords stores a set of custom key/value pairs. Map keys are TLV types
20+
// which must be greater than or equal to MinCustomRecordsTlvType.
21+
type CustomRecords map[uint64][]byte
22+
23+
// NewCustomRecords creates a new CustomRecords instance from a
24+
// tlv.TypeMap.
25+
func NewCustomRecords(tlvMap tlv.TypeMap) (CustomRecords, error) {
26+
// Make comparisons in unit tests easy by returning nil if the map is
27+
// empty.
28+
if len(tlvMap) == 0 {
29+
return nil, nil
30+
}
31+
32+
customRecords := make(CustomRecords, len(tlvMap))
33+
for k, v := range tlvMap {
34+
customRecords[uint64(k)] = v
35+
}
36+
37+
// Validate the custom records.
38+
err := customRecords.Validate()
39+
if err != nil {
40+
return nil, fmt.Errorf("custom records from tlv map "+
41+
"validation error: %w", err)
42+
}
43+
44+
return customRecords, nil
45+
}
46+
47+
// ParseCustomRecords creates a new CustomRecords instance from a tlv.Blob.
48+
func ParseCustomRecords(b tlv.Blob) (CustomRecords, error) {
49+
return ParseCustomRecordsFrom(bytes.NewReader(b))
50+
}
51+
52+
// ParseCustomRecordsFrom creates a new CustomRecords instance from a reader.
53+
func ParseCustomRecordsFrom(r io.Reader) (CustomRecords, error) {
54+
typeMap, err := DecodeRecords(r)
55+
if err != nil {
56+
return nil, fmt.Errorf("error decoding HTLC record: %w", err)
57+
}
58+
59+
return NewCustomRecords(typeMap)
60+
}
61+
62+
// Validate checks that all custom records are in the custom type range.
63+
func (c CustomRecords) Validate() error {
64+
if c == nil {
65+
return nil
66+
}
67+
68+
for key := range c {
69+
if key < MinCustomRecordsTlvType {
70+
return fmt.Errorf("custom records entry with TLV "+
71+
"type below min: %d", MinCustomRecordsTlvType)
72+
}
73+
}
74+
75+
return nil
76+
}
77+
78+
// Copy returns a copy of the custom records.
79+
func (c CustomRecords) Copy() CustomRecords {
80+
if c == nil {
81+
return nil
82+
}
83+
84+
customRecords := make(CustomRecords, len(c))
85+
for k, v := range c {
86+
customRecords[k] = v
87+
}
88+
89+
return customRecords
90+
}
91+
92+
// ExtendRecordProducers extends the given records slice with the custom
93+
// records. The resultant records slice will be sorted if the given records
94+
// slice contains TLV types greater than or equal to MinCustomRecordsTlvType.
95+
func (c CustomRecords) ExtendRecordProducers(
96+
producers []tlv.RecordProducer) ([]tlv.RecordProducer, error) {
97+
98+
// If the custom records are nil or empty, there is nothing to do.
99+
if len(c) == 0 {
100+
return producers, nil
101+
}
102+
103+
// Validate the custom records.
104+
err := c.Validate()
105+
if err != nil {
106+
return nil, err
107+
}
108+
109+
// Ensure that the existing records slice TLV types are not also present
110+
// in the custom records. If they are, the resultant extended records
111+
// slice would erroneously contain duplicate TLV types.
112+
for _, rp := range producers {
113+
record := rp.Record()
114+
recordTlvType := uint64(record.Type())
115+
116+
_, foundDuplicateTlvType := c[recordTlvType]
117+
if foundDuplicateTlvType {
118+
return nil, fmt.Errorf("custom records contains a TLV "+
119+
"type that is already present in the "+
120+
"existing records: %d", recordTlvType)
121+
}
122+
}
123+
124+
// Convert the custom records map to a TLV record producer slice and
125+
// append them to the exiting records slice.
126+
customRecordProducers := RecordsAsProducers(tlv.MapToRecords(c))
127+
producers = append(producers, customRecordProducers...)
128+
129+
// If the records slice which was given as an argument included TLV
130+
// values greater than or equal to the minimum custom records TLV type
131+
// we will sort the extended records slice to ensure that it is ordered
132+
// correctly.
133+
SortProducers(producers)
134+
135+
return producers, nil
136+
}
137+
138+
// RecordProducers returns a slice of record producers for the custom records.
139+
func (c CustomRecords) RecordProducers() []tlv.RecordProducer {
140+
// If the custom records are nil or empty, return an empty slice.
141+
if len(c) == 0 {
142+
return nil
143+
}
144+
145+
// Convert the custom records map to a TLV record producer slice.
146+
records := tlv.MapToRecords(c)
147+
148+
return RecordsAsProducers(records)
149+
}
150+
151+
// Serialize serializes the custom records into a byte slice.
152+
func (c CustomRecords) Serialize() ([]byte, error) {
153+
records := tlv.MapToRecords(c)
154+
return EncodeRecords(records)
155+
}
156+
157+
// SerializeTo serializes the custom records into the given writer.
158+
func (c CustomRecords) SerializeTo(w io.Writer) error {
159+
records := tlv.MapToRecords(c)
160+
return EncodeRecordsTo(w, records)
161+
}
162+
163+
// ProduceRecordsSorted converts a slice of record producers into a slice of
164+
// records and then sorts it by type.
165+
func ProduceRecordsSorted(recordProducers ...tlv.RecordProducer) []tlv.Record {
166+
records := fn.Map(func(producer tlv.RecordProducer) tlv.Record {
167+
return producer.Record()
168+
}, recordProducers)
169+
170+
// Ensure that the set of records are sorted before we attempt to
171+
// decode from the stream, to ensure they're canonical.
172+
tlv.SortRecords(records)
173+
174+
return records
175+
}
176+
177+
// SortProducers sorts the given record producers by their type.
178+
func SortProducers(producers []tlv.RecordProducer) {
179+
sort.Slice(producers, func(i, j int) bool {
180+
recordI := producers[i].Record()
181+
recordJ := producers[j].Record()
182+
return recordI.Type() < recordJ.Type()
183+
})
184+
}
185+
186+
// TlvMapToRecords converts a TLV map into a slice of records.
187+
func TlvMapToRecords(tlvMap tlv.TypeMap) []tlv.Record {
188+
tlvMapGeneric := make(map[uint64][]byte)
189+
for k, v := range tlvMap {
190+
tlvMapGeneric[uint64(k)] = v
191+
}
192+
193+
return tlv.MapToRecords(tlvMapGeneric)
194+
}
195+
196+
// RecordsAsProducers converts a slice of records into a slice of record
197+
// producers.
198+
func RecordsAsProducers(records []tlv.Record) []tlv.RecordProducer {
199+
return fn.Map(func(record tlv.Record) tlv.RecordProducer {
200+
return &record
201+
}, records)
202+
}
203+
204+
// EncodeRecords encodes the given records into a byte slice.
205+
func EncodeRecords(records []tlv.Record) ([]byte, error) {
206+
var buf bytes.Buffer
207+
if err := EncodeRecordsTo(&buf, records); err != nil {
208+
return nil, err
209+
}
210+
211+
return buf.Bytes(), nil
212+
}
213+
214+
// EncodeRecordsTo encodes the given records into the given writer.
215+
func EncodeRecordsTo(w io.Writer, records []tlv.Record) error {
216+
tlvStream, err := tlv.NewStream(records...)
217+
if err != nil {
218+
return err
219+
}
220+
221+
return tlvStream.Encode(w)
222+
}
223+
224+
// DecodeRecords decodes the given byte slice into the given records and returns
225+
// the rest as a TLV type map.
226+
func DecodeRecords(r io.Reader,
227+
records ...tlv.Record) (tlv.TypeMap, error) {
228+
229+
tlvStream, err := tlv.NewStream(records...)
230+
if err != nil {
231+
return nil, err
232+
}
233+
234+
return tlvStream.DecodeWithParsedTypes(r)
235+
}
236+
237+
// DecodeRecordsP2P decodes the given byte slice into the given records and
238+
// returns the rest as a TLV type map. This function is identical to
239+
// DecodeRecords except that the record size is capped at 65535.
240+
func DecodeRecordsP2P(r *bytes.Reader,
241+
records ...tlv.Record) (tlv.TypeMap, error) {
242+
243+
tlvStream, err := tlv.NewStream(records...)
244+
if err != nil {
245+
return nil, err
246+
}
247+
248+
return tlvStream.DecodeWithParsedTypesP2P(r)
249+
}
250+
251+
// AssertUniqueTypes asserts that the given records have unique types.
252+
func AssertUniqueTypes(r []tlv.Record) error {
253+
seen := make(fn.Set[tlv.Type], len(r))
254+
for _, record := range r {
255+
t := record.Type()
256+
if seen.Contains(t) {
257+
return fmt.Errorf("duplicate record type: %d", t)
258+
}
259+
seen.Add(t)
260+
}
261+
262+
return nil
263+
}

0 commit comments

Comments
 (0)