77 "io"
88 "sort"
99
10- "github.com/btcsuite/btcd/btcec/v2 "
10+ sphinx "github.com/lightningnetwork/lightning-onion "
1111 "github.com/lightningnetwork/lnd/tlv"
1212)
1313
@@ -51,26 +51,28 @@ var (
5151type OnionMessagePayload struct {
5252 // ReplyPath contains a blinded path that can be used to respond to an
5353 // onion message.
54- ReplyPath * ReplyPath
54+ ReplyPath * sphinx. BlindedPath
5555
5656 // EncryptedData contains encrypted data for the recipient.
5757 EncryptedData []byte
5858
59- // FinalHopPayloads contains any tlvs with type > 64 that
60- FinalHopPayloads []* FinalHopPayload
59+ // FinalHopTLVs contains any tlvs with type >= 64 that
60+ FinalHopTLVs []* FinalHopTLV
6161}
6262
6363// NewOnionMessagePayload creates a new OnionMessagePayload.
6464func NewOnionMessagePayload () * OnionMessagePayload {
6565 return & OnionMessagePayload {}
6666}
6767
68- // Encode encodes an onion message's final payload.
68+ // Encode encodes an onion message's payload.
69+ //
70+ // This is part of the lnwire.Message interface.
6971func (o * OnionMessagePayload ) Encode () ([]byte , error ) {
7072 var records []tlv.Record
7173
7274 if o .ReplyPath != nil {
73- records = append (records , o .ReplyPath . record ( ))
75+ records = append (records , replyPathRecord ( o .ReplyPath ))
7476 }
7577
7678 if len (o .EncryptedData ) != 0 {
@@ -80,16 +82,16 @@ func (o *OnionMessagePayload) Encode() ([]byte, error) {
8082 records = append (records , record )
8183 }
8284
83- for _ , finalHopPayload := range o .FinalHopPayloads {
84- if err := finalHopPayload .Validate (); err != nil {
85+ for _ , finalHopTLV := range o .FinalHopTLVs {
86+ if err := finalHopTLV .Validate (); err != nil {
8587 return nil , err
8688 }
8789
8890 // Create a primitive record that just writes the final hop
89- // payload 's bytes directly . The creating function should have
91+ // tlv 's bytes as-is . The creating function should have
9092 // encoded the value correctly.
9193 record := tlv .MakePrimitiveRecord (
92- finalHopPayload .TLVType , & finalHopPayload .Value ,
94+ finalHopTLV .TLVType , & finalHopTLV .Value ,
9395 )
9496 records = append (records , record )
9597 }
@@ -112,27 +114,27 @@ func (o *OnionMessagePayload) Encode() ([]byte, error) {
112114}
113115
114116// Decode decodes an onion message's payload.
115- func ( o * OnionMessagePayload ) Decode ( r io. Reader ) ( * OnionMessagePayload ,
116- map [tlv. Type ][] byte , error ) {
117-
117+ //
118+ // This is part of the lnwire.Message interface.
119+ func ( o * OnionMessagePayload ) Decode ( r io. Reader ) ( map [tlv. Type ][] byte , error ) {
118120 var (
119- invoicePayload = & FinalHopPayload {
121+ invoicePayload = & FinalHopTLV {
120122 TLVType : InvoiceNamespaceType ,
121123 }
122124
123- invoiceErrorPayload = & FinalHopPayload {
125+ invoiceErrorPayload = & FinalHopTLV {
124126 TLVType : InvoiceErrorNamespaceType ,
125127 }
126128
127- invoiceRequestPayload = & FinalHopPayload {
129+ invoiceRequestPayload = & FinalHopTLV {
128130 TLVType : InvoiceRequestNamespaceType ,
129131 }
130132 )
131133 // Create a non-nil entry so that we can directly decode into it.
132- o .ReplyPath = & ReplyPath {}
134+ o .ReplyPath = & sphinx. BlindedPath {}
133135
134136 records := []tlv.Record {
135- o .ReplyPath . record ( ),
137+ replyPathRecord ( o .ReplyPath ),
136138 tlv .MakePrimitiveRecord (
137139 encryptedDataTLVType , & o .EncryptedData ,
138140 ),
@@ -160,12 +162,12 @@ func (o *OnionMessagePayload) Decode(r io.Reader) (*OnionMessagePayload,
160162
161163 stream , err := tlv .NewStream (records ... )
162164 if err != nil {
163- return nil , nil , fmt .Errorf ("new stream: %w" , err )
165+ return nil , fmt .Errorf ("new stream: %w" , err )
164166 }
165167
166168 tlvMap , err := stream .DecodeWithParsedTypesP2P (r )
167169 if err != nil {
168- return nil , tlvMap , fmt .Errorf ("decode stream: %w" , err )
170+ return tlvMap , fmt .Errorf ("decode stream: %w" , err )
169171 }
170172
171173 // If our reply path wasn't populated, replace it with a nil entry.
@@ -190,13 +192,13 @@ func (o *OnionMessagePayload) Decode(r io.Reader) (*OnionMessagePayload,
190192 }
191193
192194 // Add the payload to our message's final hop payloads.
193- payload := & FinalHopPayload {
195+ payload := & FinalHopTLV {
194196 TLVType : tlvType ,
195197 Value : tlvBytes ,
196198 }
197199
198- o .FinalHopPayloads = append (
199- o .FinalHopPayloads , payload ,
200+ o .FinalHopTLVs = append (
201+ o .FinalHopTLVs , payload ,
200202 )
201203 }
202204
@@ -205,37 +207,37 @@ func (o *OnionMessagePayload) Decode(r io.Reader) (*OnionMessagePayload,
205207 // have been added in the loop above, because we recognized the TLV so
206208 // len(tlvMap[invoiceType].tlvBytes) will be zero (thus, skipped above).
207209 if _ , ok := tlvMap [InvoiceNamespaceType ]; ok {
208- o .FinalHopPayloads = append (
209- o .FinalHopPayloads , invoicePayload ,
210+ o .FinalHopTLVs = append (
211+ o .FinalHopTLVs , invoicePayload ,
210212 )
211213 }
212214
213215 if _ , ok := tlvMap [InvoiceErrorNamespaceType ]; ok {
214- o .FinalHopPayloads = append (
215- o .FinalHopPayloads , invoiceErrorPayload ,
216+ o .FinalHopTLVs = append (
217+ o .FinalHopTLVs , invoiceErrorPayload ,
216218 )
217219 }
218220
219221 if _ , ok := tlvMap [InvoiceRequestNamespaceType ]; ok {
220- o .FinalHopPayloads = append (
221- o .FinalHopPayloads , invoiceRequestPayload ,
222+ o .FinalHopTLVs = append (
223+ o .FinalHopTLVs , invoiceRequestPayload ,
222224 )
223225 }
224226
225227 // Iteration through maps occurs in random order - sort final hop
226- // payloads in ascending order to make this decoding function
228+ // TLVs in ascending order to make this decoding function
227229 // deterministic.
228- sort .SliceStable (o .FinalHopPayloads , func (i , j int ) bool {
229- return o .FinalHopPayloads [i ].TLVType <
230- o .FinalHopPayloads [j ].TLVType
230+ sort .SliceStable (o .FinalHopTLVs , func (i , j int ) bool {
231+ return o .FinalHopTLVs [i ].TLVType <
232+ o .FinalHopTLVs [j ].TLVType
231233 })
232234
233- return o , tlvMap , nil
235+ return tlvMap , nil
234236}
235237
236- // FinalHopPayload contains values reserved for the final hop, which are just
238+ // FinalHopTLV contains values reserved for the final hop, which are just
237239// directly read from the tlv stream.
238- type FinalHopPayload struct {
240+ type FinalHopTLV struct {
239241 // TLVType is the type for the payload.
240242 TLVType tlv.Type
241243
@@ -248,60 +250,51 @@ type FinalHopPayload struct {
248250// Validate performs validation of items added to the final hop's payload in an
249251// onion. This function returns an error if a tlv is not within the range
250252// reserved for final payload.
251- func (f * FinalHopPayload ) Validate () error {
253+ func (f * FinalHopTLV ) Validate () error {
252254 if f .TLVType < finalHopPayloadStart {
253255 return fmt .Errorf ("%w: %v" , ErrNotFinalPayload , f .TLVType )
254256 }
255257
256258 return nil
257259}
258260
259- // ReplyPath is a blinded path used to respond to onion messages.
260- type ReplyPath struct {
261- // FirstNodeID is the pubkey of the first node in the reply path.
262- FirstNodeID * btcec.PublicKey
263-
264- // BlindingPoint is the ephemeral pubkey used in route blinding.
265- BlindingPoint * btcec.PublicKey
266-
267- // Hops is a set of blinded hops in the route, starting with the blinded
268- // introduction node (first node id).
269- Hops []* BlindedHop
270- }
271-
272- // record produces a tlv record for a reply path.
273- func (r * ReplyPath ) record () tlv.Record {
261+ // replyPathRecord produces a tlv record for a reply path.
262+ func replyPathRecord (r * sphinx.BlindedPath ) tlv.Record {
274263 return tlv .MakeDynamicRecord (
275- replyPathType , r , r .size , encodeReplyPath , decodeReplyPath ,
264+ replyPathType , r , replyPathSize (r ), encodeReplyPath ,
265+ decodeReplyPath ,
276266 )
277267}
278268
279- // size returns the encoded size of our reply path.
280- func (r * ReplyPath ) size () uint64 {
281- // First node pubkey 33 + blinding point pubkey 33 + 1 byte for uint8
282- // for our hop count.
283- size := uint64 (33 + 33 + 1 )
269+ // replyPathSize returns the encoded size of a reply path.
270+ func replyPathSize (r * sphinx.BlindedPath ) func () uint64 {
271+ return func () uint64 {
272+ // First node pubkey 33 + blinding point pubkey 33 + 1 byte for
273+ // uint8 for our hop count.
274+ size := uint64 (33 + 33 + 1 )
284275
285- // Add each hop's size to our total.
286- for _ , hop := range r .Hops {
287- size += hop . size ( )
288- }
276+ // Add each hop's size to our total.
277+ for _ , hop := range r .BlindedHops {
278+ size += blindedHopSize ( hop )
279+ }
289280
290- return size
281+ return size
282+ }
291283}
292284
293285// encodeReplyPath encodes a reply path tlv.
294286func encodeReplyPath (w io.Writer , val interface {}, buf * [8 ]byte ) error {
295- if p , ok := val .(* ReplyPath ); ok {
296- if err := tlv .EPubKey (w , & p .FirstNodeID , buf ); err != nil {
287+ if p , ok := val .(* sphinx.BlindedPath ); ok {
288+ err := tlv .EPubKey (w , & p .IntroductionPoint , buf )
289+ if err != nil {
297290 return fmt .Errorf ("encode first node id: %w" , err )
298291 }
299292
300293 if err := tlv .EPubKey (w , & p .BlindingPoint , buf ); err != nil {
301- return fmt .Errorf ("encode blinded path : %w" , err )
294+ return fmt .Errorf ("encode blinding point : %w" , err )
302295 }
303296
304- hopCount := uint8 (len (p .Hops ))
297+ hopCount := uint8 (len (p .BlindedHops ))
305298 if hopCount == 0 {
306299 return ErrNoHops
307300 }
@@ -310,7 +303,7 @@ func encodeReplyPath(w io.Writer, val interface{}, buf *[8]byte) error {
310303 return fmt .Errorf ("encode hop count: %w" , err )
311304 }
312305
313- for i , hop := range p .Hops {
306+ for i , hop := range p .BlindedHops {
314307 if err := encodeBlindedHop (w , hop , buf ); err != nil {
315308 return fmt .Errorf ("hop %v: %w" , i , err )
316309 }
@@ -319,7 +312,7 @@ func encodeReplyPath(w io.Writer, val interface{}, buf *[8]byte) error {
319312 return nil
320313 }
321314
322- return tlv .NewTypeForEncodingErr (val , "*ReplyPath " )
315+ return tlv .NewTypeForEncodingErr (val , "*sphinx.BlindedPath " )
323316}
324317
325318// decodeReplyPath decodes a reply path tlv.
@@ -329,8 +322,8 @@ func decodeReplyPath(r io.Reader, val interface{}, buf *[8]byte,
329322 // If we have the correct type, and the length is sufficient (first node
330323 // pubkey (33) + blinding point (33) + hop count (1) = 67 bytes), decode
331324 // the reply path.
332- if p , ok := val .(* ReplyPath ); ok && l > 67 {
333- err := tlv .DPubKey (r , & p .FirstNodeID , buf , 33 )
325+ if p , ok := val .(* sphinx. BlindedPath ); ok && l > 67 {
326+ err := tlv .DPubKey (r , & p .IntroductionPoint , buf , 33 )
334327 if err != nil {
335328 return fmt .Errorf ("decode first id: %w" , err )
336329 }
@@ -350,62 +343,52 @@ func decodeReplyPath(r io.Reader, val interface{}, buf *[8]byte,
350343 }
351344
352345 for i := 0 ; i < int (hopCount ); i ++ {
353- hop := & BlindedHop {}
346+ hop := & sphinx. BlindedHopInfo {}
354347 if err := decodeBlindedHop (r , hop , buf ); err != nil {
355348 return fmt .Errorf ("decode hop: %w" , err )
356349 }
357350
358- p .Hops = append (p .Hops , hop )
351+ p .BlindedHops = append (p .BlindedHops , hop )
359352 }
360353
361354 return nil
362355 }
363356
364- return tlv .NewTypeForDecodingErr (val , "*ReplyPath" , l , l )
365- }
366-
367- // BlindedHop contains a blinded node ID and encrypted data used to send onion
368- // messages over blinded routes.
369- type BlindedHop struct {
370- // BlindedNodeID is the blinded node id of a node in the path.
371- BlindedNodeID * btcec.PublicKey
372-
373- // EncryptedData is the encrypted data to be included for the node.
374- EncryptedData []byte
357+ return tlv .NewTypeForDecodingErr (val , "*sphinx.BlindedPath" , l , l )
375358}
376359
377- // size returns the encoded size of a blinded hop.
378- func (b * BlindedHop ) size ( ) uint64 {
360+ // blindedHopSize returns the encoded size of a blinded hop.
361+ func blindedHopSize (b * sphinx. BlindedHopInfo ) uint64 {
379362 // 33 byte pubkey + 2 bytes uint16 length + var bytes.
380- return uint64 (33 + 2 + len (b .EncryptedData ))
363+ return uint64 (33 + 2 + len (b .CipherText ))
381364}
382365
383366// encodeBlindedHop encodes a blinded hop tlv.
384367func encodeBlindedHop (w io.Writer , val interface {}, buf * [8 ]byte ) error {
385- if b , ok := val .(* BlindedHop ); ok {
386- if err := tlv .EPubKey (w , & b .BlindedNodeID , buf ); err != nil {
368+ if b , ok := val .(* sphinx. BlindedHopInfo ); ok {
369+ if err := tlv .EPubKey (w , & b .BlindedNodePub , buf ); err != nil {
387370 return fmt .Errorf ("encode blinded id: %w" , err )
388371 }
389372
390- dataLen := uint16 (len (b .EncryptedData ))
373+ dataLen := uint16 (len (b .CipherText ))
391374 if err := tlv .EUint16 (w , & dataLen , buf ); err != nil {
392375 return fmt .Errorf ("data len: %w" , err )
393376 }
394377
395- if err := tlv .EVarBytes (w , & b .EncryptedData , buf ); err != nil {
378+ if err := tlv .EVarBytes (w , & b .CipherText , buf ); err != nil {
396379 return fmt .Errorf ("encode encrypted data: %w" , err )
397380 }
398381
399382 return nil
400383 }
401384
402- return tlv .NewTypeForEncodingErr (val , "*BlindedHop " )
385+ return tlv .NewTypeForEncodingErr (val , "*sphinx.BlindedHopInfo " )
403386}
404387
405388// decodeBlindedHop decodes a blinded hop tlv.
406389func decodeBlindedHop (r io.Reader , val interface {}, buf * [8 ]byte ) error {
407- if b , ok := val .(* BlindedHop ); ok {
408- err := tlv .DPubKey (r , & b .BlindedNodeID , buf , 33 )
390+ if b , ok := val .(* sphinx. BlindedHopInfo ); ok {
391+ err := tlv .DPubKey (r , & b .BlindedNodePub , buf , 33 )
409392 if err != nil {
410393 return fmt .Errorf ("decode blinded id: %w" , err )
411394 }
@@ -416,13 +399,13 @@ func decodeBlindedHop(r io.Reader, val interface{}, buf *[8]byte) error {
416399 return fmt .Errorf ("decode data len: %w" , err )
417400 }
418401
419- err = tlv .DVarBytes (r , & b .EncryptedData , buf , uint64 (dataLen ))
402+ err = tlv .DVarBytes (r , & b .CipherText , buf , uint64 (dataLen ))
420403 if err != nil {
421404 return fmt .Errorf ("decode data: %w" , err )
422405 }
423406
424407 return nil
425408 }
426409
427- return tlv .NewTypeForDecodingErr (val , "*BlindedHop " , 0 , 0 )
410+ return tlv .NewTypeForDecodingErr (val , "*sphinx.BlindedHopInfo " , 0 , 0 )
428411}
0 commit comments