Skip to content

Commit 2c48076

Browse files
committed
fixup! lnwire: add OnionMessagePayload
1 parent 64d1879 commit 2c48076

File tree

1 file changed

+79
-96
lines changed

1 file changed

+79
-96
lines changed

lnwire/onion_msg_payload.go

Lines changed: 79 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import (
77
"io"
88
"sort"
99

10-
"github.com/btcsuite/btcd/btcec/v2"
10+
sphinx "github.com/lightningnetwork/lightning-onion"
1111
"github.com/lightningnetwork/lnd/tlv"
1212
)
1313

@@ -51,26 +51,28 @@ var (
5151
type OnionMessagePayload struct {
5252
// ReplyPath contains a blinded path that can be used to respond to an
5353
// onion message.
54-
ReplyPath *ReplyPath
54+
ReplyPath *sphinx.BlindedPath
5555

5656
// EncryptedData contains encrypted data for the recipient.
5757
EncryptedData []byte
5858

59-
// FinalHopPayloads contains any tlvs with type > 64 that
60-
FinalHopPayloads []*FinalHopPayload
59+
// FinalHopTLVs contains any tlvs with type >= 64 that
60+
FinalHopTLVs []*FinalHopTLV
6161
}
6262

6363
// NewOnionMessagePayload creates a new OnionMessagePayload.
6464
func NewOnionMessagePayload() *OnionMessagePayload {
6565
return &OnionMessagePayload{}
6666
}
6767

68-
// Encode encodes an onion message's final payload.
68+
// Encode encodes an onion message's payload.
69+
//
70+
// This is part of the lnwire.Message interface.
6971
func (o *OnionMessagePayload) Encode() ([]byte, error) {
7072
var records []tlv.Record
7173

7274
if o.ReplyPath != nil {
73-
records = append(records, o.ReplyPath.record())
75+
records = append(records, replyPathRecord(o.ReplyPath))
7476
}
7577

7678
if len(o.EncryptedData) != 0 {
@@ -80,16 +82,16 @@ func (o *OnionMessagePayload) Encode() ([]byte, error) {
8082
records = append(records, record)
8183
}
8284

83-
for _, finalHopPayload := range o.FinalHopPayloads {
84-
if err := finalHopPayload.Validate(); err != nil {
85+
for _, finalHopTLV := range o.FinalHopTLVs {
86+
if err := finalHopTLV.Validate(); err != nil {
8587
return nil, err
8688
}
8789

8890
// Create a primitive record that just writes the final hop
89-
// payload's bytes directly. The creating function should have
91+
// tlv's bytes as-is. The creating function should have
9092
// encoded the value correctly.
9193
record := tlv.MakePrimitiveRecord(
92-
finalHopPayload.TLVType, &finalHopPayload.Value,
94+
finalHopTLV.TLVType, &finalHopTLV.Value,
9395
)
9496
records = append(records, record)
9597
}
@@ -112,27 +114,27 @@ func (o *OnionMessagePayload) Encode() ([]byte, error) {
112114
}
113115

114116
// Decode decodes an onion message's payload.
115-
func (o *OnionMessagePayload) Decode(r io.Reader) (*OnionMessagePayload,
116-
map[tlv.Type][]byte, error) {
117-
117+
//
118+
// This is part of the lnwire.Message interface.
119+
func (o *OnionMessagePayload) Decode(r io.Reader) (map[tlv.Type][]byte, error) {
118120
var (
119-
invoicePayload = &FinalHopPayload{
121+
invoicePayload = &FinalHopTLV{
120122
TLVType: InvoiceNamespaceType,
121123
}
122124

123-
invoiceErrorPayload = &FinalHopPayload{
125+
invoiceErrorPayload = &FinalHopTLV{
124126
TLVType: InvoiceErrorNamespaceType,
125127
}
126128

127-
invoiceRequestPayload = &FinalHopPayload{
129+
invoiceRequestPayload = &FinalHopTLV{
128130
TLVType: InvoiceRequestNamespaceType,
129131
}
130132
)
131133
// Create a non-nil entry so that we can directly decode into it.
132-
o.ReplyPath = &ReplyPath{}
134+
o.ReplyPath = &sphinx.BlindedPath{}
133135

134136
records := []tlv.Record{
135-
o.ReplyPath.record(),
137+
replyPathRecord(o.ReplyPath),
136138
tlv.MakePrimitiveRecord(
137139
encryptedDataTLVType, &o.EncryptedData,
138140
),
@@ -160,12 +162,12 @@ func (o *OnionMessagePayload) Decode(r io.Reader) (*OnionMessagePayload,
160162

161163
stream, err := tlv.NewStream(records...)
162164
if err != nil {
163-
return nil, nil, fmt.Errorf("new stream: %w", err)
165+
return nil, fmt.Errorf("new stream: %w", err)
164166
}
165167

166168
tlvMap, err := stream.DecodeWithParsedTypesP2P(r)
167169
if err != nil {
168-
return nil, tlvMap, fmt.Errorf("decode stream: %w", err)
170+
return tlvMap, fmt.Errorf("decode stream: %w", err)
169171
}
170172

171173
// If our reply path wasn't populated, replace it with a nil entry.
@@ -190,13 +192,13 @@ func (o *OnionMessagePayload) Decode(r io.Reader) (*OnionMessagePayload,
190192
}
191193

192194
// Add the payload to our message's final hop payloads.
193-
payload := &FinalHopPayload{
195+
payload := &FinalHopTLV{
194196
TLVType: tlvType,
195197
Value: tlvBytes,
196198
}
197199

198-
o.FinalHopPayloads = append(
199-
o.FinalHopPayloads, payload,
200+
o.FinalHopTLVs = append(
201+
o.FinalHopTLVs, payload,
200202
)
201203
}
202204

@@ -205,37 +207,37 @@ func (o *OnionMessagePayload) Decode(r io.Reader) (*OnionMessagePayload,
205207
// have been added in the loop above, because we recognized the TLV so
206208
// len(tlvMap[invoiceType].tlvBytes) will be zero (thus, skipped above).
207209
if _, ok := tlvMap[InvoiceNamespaceType]; ok {
208-
o.FinalHopPayloads = append(
209-
o.FinalHopPayloads, invoicePayload,
210+
o.FinalHopTLVs = append(
211+
o.FinalHopTLVs, invoicePayload,
210212
)
211213
}
212214

213215
if _, ok := tlvMap[InvoiceErrorNamespaceType]; ok {
214-
o.FinalHopPayloads = append(
215-
o.FinalHopPayloads, invoiceErrorPayload,
216+
o.FinalHopTLVs = append(
217+
o.FinalHopTLVs, invoiceErrorPayload,
216218
)
217219
}
218220

219221
if _, ok := tlvMap[InvoiceRequestNamespaceType]; ok {
220-
o.FinalHopPayloads = append(
221-
o.FinalHopPayloads, invoiceRequestPayload,
222+
o.FinalHopTLVs = append(
223+
o.FinalHopTLVs, invoiceRequestPayload,
222224
)
223225
}
224226

225227
// Iteration through maps occurs in random order - sort final hop
226-
// payloads in ascending order to make this decoding function
228+
// TLVs in ascending order to make this decoding function
227229
// deterministic.
228-
sort.SliceStable(o.FinalHopPayloads, func(i, j int) bool {
229-
return o.FinalHopPayloads[i].TLVType <
230-
o.FinalHopPayloads[j].TLVType
230+
sort.SliceStable(o.FinalHopTLVs, func(i, j int) bool {
231+
return o.FinalHopTLVs[i].TLVType <
232+
o.FinalHopTLVs[j].TLVType
231233
})
232234

233-
return o, tlvMap, nil
235+
return tlvMap, nil
234236
}
235237

236-
// FinalHopPayload contains values reserved for the final hop, which are just
238+
// FinalHopTLV contains values reserved for the final hop, which are just
237239
// directly read from the tlv stream.
238-
type FinalHopPayload struct {
240+
type FinalHopTLV struct {
239241
// TLVType is the type for the payload.
240242
TLVType tlv.Type
241243

@@ -248,60 +250,51 @@ type FinalHopPayload struct {
248250
// Validate performs validation of items added to the final hop's payload in an
249251
// onion. This function returns an error if a tlv is not within the range
250252
// reserved for final payload.
251-
func (f *FinalHopPayload) Validate() error {
253+
func (f *FinalHopTLV) Validate() error {
252254
if f.TLVType < finalHopPayloadStart {
253255
return fmt.Errorf("%w: %v", ErrNotFinalPayload, f.TLVType)
254256
}
255257

256258
return nil
257259
}
258260

259-
// ReplyPath is a blinded path used to respond to onion messages.
260-
type ReplyPath struct {
261-
// FirstNodeID is the pubkey of the first node in the reply path.
262-
FirstNodeID *btcec.PublicKey
263-
264-
// BlindingPoint is the ephemeral pubkey used in route blinding.
265-
BlindingPoint *btcec.PublicKey
266-
267-
// Hops is a set of blinded hops in the route, starting with the blinded
268-
// introduction node (first node id).
269-
Hops []*BlindedHop
270-
}
271-
272-
// record produces a tlv record for a reply path.
273-
func (r *ReplyPath) record() tlv.Record {
261+
// replyPathRecord produces a tlv record for a reply path.
262+
func replyPathRecord(r *sphinx.BlindedPath) tlv.Record {
274263
return tlv.MakeDynamicRecord(
275-
replyPathType, r, r.size, encodeReplyPath, decodeReplyPath,
264+
replyPathType, r, replyPathSize(r), encodeReplyPath,
265+
decodeReplyPath,
276266
)
277267
}
278268

279-
// size returns the encoded size of our reply path.
280-
func (r *ReplyPath) size() uint64 {
281-
// First node pubkey 33 + blinding point pubkey 33 + 1 byte for uint8
282-
// for our hop count.
283-
size := uint64(33 + 33 + 1)
269+
// replyPathSize returns the encoded size of a reply path.
270+
func replyPathSize(r *sphinx.BlindedPath) func() uint64 {
271+
return func() uint64 {
272+
// First node pubkey 33 + blinding point pubkey 33 + 1 byte for
273+
// uint8 for our hop count.
274+
size := uint64(33 + 33 + 1)
284275

285-
// Add each hop's size to our total.
286-
for _, hop := range r.Hops {
287-
size += hop.size()
288-
}
276+
// Add each hop's size to our total.
277+
for _, hop := range r.BlindedHops {
278+
size += blindedHopSize(hop)
279+
}
289280

290-
return size
281+
return size
282+
}
291283
}
292284

293285
// encodeReplyPath encodes a reply path tlv.
294286
func encodeReplyPath(w io.Writer, val interface{}, buf *[8]byte) error {
295-
if p, ok := val.(*ReplyPath); ok {
296-
if err := tlv.EPubKey(w, &p.FirstNodeID, buf); err != nil {
287+
if p, ok := val.(*sphinx.BlindedPath); ok {
288+
err := tlv.EPubKey(w, &p.IntroductionPoint, buf)
289+
if err != nil {
297290
return fmt.Errorf("encode first node id: %w", err)
298291
}
299292

300293
if err := tlv.EPubKey(w, &p.BlindingPoint, buf); err != nil {
301-
return fmt.Errorf("encode blinded path: %w", err)
294+
return fmt.Errorf("encode blinding point: %w", err)
302295
}
303296

304-
hopCount := uint8(len(p.Hops))
297+
hopCount := uint8(len(p.BlindedHops))
305298
if hopCount == 0 {
306299
return ErrNoHops
307300
}
@@ -310,7 +303,7 @@ func encodeReplyPath(w io.Writer, val interface{}, buf *[8]byte) error {
310303
return fmt.Errorf("encode hop count: %w", err)
311304
}
312305

313-
for i, hop := range p.Hops {
306+
for i, hop := range p.BlindedHops {
314307
if err := encodeBlindedHop(w, hop, buf); err != nil {
315308
return fmt.Errorf("hop %v: %w", i, err)
316309
}
@@ -319,7 +312,7 @@ func encodeReplyPath(w io.Writer, val interface{}, buf *[8]byte) error {
319312
return nil
320313
}
321314

322-
return tlv.NewTypeForEncodingErr(val, "*ReplyPath")
315+
return tlv.NewTypeForEncodingErr(val, "*sphinx.BlindedPath")
323316
}
324317

325318
// decodeReplyPath decodes a reply path tlv.
@@ -329,8 +322,8 @@ func decodeReplyPath(r io.Reader, val interface{}, buf *[8]byte,
329322
// If we have the correct type, and the length is sufficient (first node
330323
// pubkey (33) + blinding point (33) + hop count (1) = 67 bytes), decode
331324
// the reply path.
332-
if p, ok := val.(*ReplyPath); ok && l > 67 {
333-
err := tlv.DPubKey(r, &p.FirstNodeID, buf, 33)
325+
if p, ok := val.(*sphinx.BlindedPath); ok && l > 67 {
326+
err := tlv.DPubKey(r, &p.IntroductionPoint, buf, 33)
334327
if err != nil {
335328
return fmt.Errorf("decode first id: %w", err)
336329
}
@@ -350,62 +343,52 @@ func decodeReplyPath(r io.Reader, val interface{}, buf *[8]byte,
350343
}
351344

352345
for i := 0; i < int(hopCount); i++ {
353-
hop := &BlindedHop{}
346+
hop := &sphinx.BlindedHopInfo{}
354347
if err := decodeBlindedHop(r, hop, buf); err != nil {
355348
return fmt.Errorf("decode hop: %w", err)
356349
}
357350

358-
p.Hops = append(p.Hops, hop)
351+
p.BlindedHops = append(p.BlindedHops, hop)
359352
}
360353

361354
return nil
362355
}
363356

364-
return tlv.NewTypeForDecodingErr(val, "*ReplyPath", l, l)
365-
}
366-
367-
// BlindedHop contains a blinded node ID and encrypted data used to send onion
368-
// messages over blinded routes.
369-
type BlindedHop struct {
370-
// BlindedNodeID is the blinded node id of a node in the path.
371-
BlindedNodeID *btcec.PublicKey
372-
373-
// EncryptedData is the encrypted data to be included for the node.
374-
EncryptedData []byte
357+
return tlv.NewTypeForDecodingErr(val, "*sphinx.BlindedPath", l, l)
375358
}
376359

377-
// size returns the encoded size of a blinded hop.
378-
func (b *BlindedHop) size() uint64 {
360+
// blindedHopSize returns the encoded size of a blinded hop.
361+
func blindedHopSize(b *sphinx.BlindedHopInfo) uint64 {
379362
// 33 byte pubkey + 2 bytes uint16 length + var bytes.
380-
return uint64(33 + 2 + len(b.EncryptedData))
363+
return uint64(33 + 2 + len(b.CipherText))
381364
}
382365

383366
// encodeBlindedHop encodes a blinded hop tlv.
384367
func encodeBlindedHop(w io.Writer, val interface{}, buf *[8]byte) error {
385-
if b, ok := val.(*BlindedHop); ok {
386-
if err := tlv.EPubKey(w, &b.BlindedNodeID, buf); err != nil {
368+
if b, ok := val.(*sphinx.BlindedHopInfo); ok {
369+
if err := tlv.EPubKey(w, &b.BlindedNodePub, buf); err != nil {
387370
return fmt.Errorf("encode blinded id: %w", err)
388371
}
389372

390-
dataLen := uint16(len(b.EncryptedData))
373+
dataLen := uint16(len(b.CipherText))
391374
if err := tlv.EUint16(w, &dataLen, buf); err != nil {
392375
return fmt.Errorf("data len: %w", err)
393376
}
394377

395-
if err := tlv.EVarBytes(w, &b.EncryptedData, buf); err != nil {
378+
if err := tlv.EVarBytes(w, &b.CipherText, buf); err != nil {
396379
return fmt.Errorf("encode encrypted data: %w", err)
397380
}
398381

399382
return nil
400383
}
401384

402-
return tlv.NewTypeForEncodingErr(val, "*BlindedHop")
385+
return tlv.NewTypeForEncodingErr(val, "*sphinx.BlindedHopInfo")
403386
}
404387

405388
// decodeBlindedHop decodes a blinded hop tlv.
406389
func decodeBlindedHop(r io.Reader, val interface{}, buf *[8]byte) error {
407-
if b, ok := val.(*BlindedHop); ok {
408-
err := tlv.DPubKey(r, &b.BlindedNodeID, buf, 33)
390+
if b, ok := val.(*sphinx.BlindedHopInfo); ok {
391+
err := tlv.DPubKey(r, &b.BlindedNodePub, buf, 33)
409392
if err != nil {
410393
return fmt.Errorf("decode blinded id: %w", err)
411394
}
@@ -416,13 +399,13 @@ func decodeBlindedHop(r io.Reader, val interface{}, buf *[8]byte) error {
416399
return fmt.Errorf("decode data len: %w", err)
417400
}
418401

419-
err = tlv.DVarBytes(r, &b.EncryptedData, buf, uint64(dataLen))
402+
err = tlv.DVarBytes(r, &b.CipherText, buf, uint64(dataLen))
420403
if err != nil {
421404
return fmt.Errorf("decode data: %w", err)
422405
}
423406

424407
return nil
425408
}
426409

427-
return tlv.NewTypeForDecodingErr(val, "*BlindedHop", 0, 0)
410+
return tlv.NewTypeForDecodingErr(val, "*sphinx.BlindedHopInfo", 0, 0)
428411
}

0 commit comments

Comments
 (0)