Skip to content

Commit 5575fec

Browse files
switchrpc: update TrackOnion rpc for new error handling
1 parent f18da2a commit 5575fec

File tree

3 files changed

+272
-190
lines changed

3 files changed

+272
-190
lines changed

itest/lnd_sendonion_test.go

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ func testSendOnion(ht *lntest.HarnessTest) {
9393
HopPubkeys: onionResp.HopPubkeys,
9494
}
9595
trackResp := alice.RPC.TrackOnion(trackReq)
96-
require.Equal(ht, invoices[0].RPreimage, trackResp.Preimage)
96+
require.Equal(ht, invoices[0].RPreimage, trackResp.GetPreimage())
9797

9898
// The invoice should show as settled for Dave.
9999
ht.AssertInvoiceSettled(dave, invoices[0].PaymentAddr)
@@ -200,7 +200,7 @@ func testSendOnionTwice(ht *lntest.HarnessTest) {
200200
HopPubkeys: onionResp.HopPubkeys,
201201
}
202202
trackResp := alice.RPC.TrackOnion(trackReq)
203-
require.Equal(ht, preimage[:], trackResp.Preimage)
203+
require.Equal(ht, preimage[:], trackResp.GetPreimage())
204204

205205
// Now that the original HTLC attempt has settled, we'll send the same
206206
// onion again with the same attempt ID.
@@ -275,9 +275,6 @@ func testTrackOnion(ht *lntest.HarnessTest) {
275275
require.True(ht, resp.Success, "expected successful onion send")
276276
require.Empty(ht, resp.ErrorMessage, "unexpected failure to send onion")
277277

278-
serverErrorStr := ""
279-
clientErrorStr := ""
280-
281278
// Track the payment providing all necessary information to delegate
282279
// error decryption to the server. We expect this to fail as Dave is not
283280
// expecting payment.
@@ -288,10 +285,11 @@ func testTrackOnion(ht *lntest.HarnessTest) {
288285
HopPubkeys: onionResp.HopPubkeys,
289286
}
290287
trackResp := alice.RPC.TrackOnion(trackReq)
291-
require.NotEmpty(ht, trackResp.ErrorMessage,
292-
"expected onion tracking error")
288+
serverFailure := trackResp.GetFailureDetails()
289+
require.NotNil(ht, serverFailure, "expected onion tracking error")
293290

294-
serverErrorStr = trackResp.ErrorMessage
291+
serverFwdFailure := serverFailure.GetForwardingFailure()
292+
require.NotNil(ht, serverFwdFailure, "expected forwarding failure")
295293

296294
// Now we'll track the same payment attempt, but we'll specify that
297295
// we want to handle the error decryption ourselves client side.
@@ -300,16 +298,18 @@ func testTrackOnion(ht *lntest.HarnessTest) {
300298
PaymentHash: paymentHash,
301299
}
302300
trackResp = alice.RPC.TrackOnion(trackReq)
303-
require.NotNil(ht, trackResp.EncryptedError, "expected encrypted error")
301+
clientFailure := trackResp.GetFailureDetails()
302+
require.NotNil(ht, clientFailure, "expected client tracking error")
303+
304+
encryptedErrorBytes := clientFailure.GetEncryptedErrorData()
305+
require.NotNil(ht, encryptedErrorBytes, "expected encrypted error")
304306

305307
// Decrypt and inspect the error from the TrackOnion RPC response.
306308
sessionKey, _ := btcec.PrivKeyFromBytes(onionResp.SessionKey)
307309
var pubKeys []*btcec.PublicKey
308310
for _, keyBytes := range onionResp.HopPubkeys {
309311
pubKey, err := btcec.ParsePubKey(keyBytes)
310-
if err != nil {
311-
ht.Fatalf("Failed to parse public key: %v", err)
312-
}
312+
require.NoError(ht, err, "Failed to parse public key")
313313
pubKeys = append(pubKeys, pubKey)
314314
}
315315

@@ -323,14 +323,19 @@ func testTrackOnion(ht *lntest.HarnessTest) {
323323
}
324324

325325
// Simulate an RPC client decrypting the onion error.
326-
encryptedError := lnwire.OpaqueReason(trackResp.EncryptedError)
327-
forwardingError, err := errorDecryptor.DecryptError(encryptedError)
328-
require.Nil(ht, err, "unable to decrypt error")
329-
330-
clientErrorStr = forwardingError.Error()
326+
encryptedError := lnwire.OpaqueReason(encryptedErrorBytes)
327+
clientFwdErr, err := errorDecryptor.DecryptError(encryptedError)
328+
require.NoError(ht, err, "unable to decrypt error")
329+
330+
// Finally, assert that the structured forwarding failure is the same
331+
// whether it was decrypted on the server or on the client.
332+
serverFwdErr, err := switchrpc.UnmarshallForwardingError(
333+
serverFwdFailure,
334+
)
335+
require.NoError(ht, err, "unable to decode server forwarding failure")
331336

332-
serverFwdErr, err := switchrpc.ParseForwardingError(serverErrorStr)
333-
require.Nil(ht, err, "expected to parse forwarding error from server")
334-
require.Equal(ht, serverFwdErr.Error(), clientErrorStr, "expect error "+
335-
"message to match whether handled by client or server")
337+
require.Equal(ht, serverFwdFailure.FailureSourceIndex,
338+
uint32(clientFwdErr.FailureSourceIdx), "source index mismatch")
339+
require.Equal(ht, serverFwdErr.WireMessage(),
340+
clientFwdErr.WireMessage(), "wire message mismatch")
336341
}

lnrpc/switchrpc/switch_server.go

Lines changed: 137 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@ import (
1212
"math/big"
1313
"os"
1414
"path/filepath"
15-
"strconv"
16-
"strings"
1715

1816
"github.com/btcsuite/btcd/btcec/v2"
1917
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
@@ -435,15 +433,19 @@ func (s *Server) TrackOnion(ctx context.Context,
435433
req.AttemptId, hash, errorDecryptor,
436434
)
437435
if err != nil {
438-
message, code := translateErrorForRPC(err)
439-
440436
log.Errorf("GetAttemptResult failed for attempt_id=%d of "+
441-
" payment=%x: %v", req.AttemptId, hash, message)
437+
" payment=%x: %v", req.AttemptId, hash, err)
442438

443-
return &TrackOnionResponse{
444-
ErrorCode: code,
445-
ErrorMessage: message,
446-
}, nil
439+
// If the payment ID is not found, we return a NotFound error.
440+
if errors.Is(err, htlcswitch.ErrPaymentIDNotFound) {
441+
return nil, status.Errorf(codes.NotFound,
442+
"payment with attempt ID %d not found",
443+
req.AttemptId)
444+
}
445+
446+
// For other errors, we return an internal error.
447+
return nil, status.Errorf(codes.Internal,
448+
"GetAttemptResult failed: %v", err)
447449
}
448450

449451
// The switch knows about this payment, we'll wait for a result to be
@@ -456,12 +458,10 @@ func (s *Server) TrackOnion(ctx context.Context,
456458
select {
457459
case result, ok = <-resultChan:
458460
if !ok {
459-
// This channel is closed when the Switch shuts down.
460-
return &TrackOnionResponse{
461-
ErrorCode: ErrorCode_SWITCH_EXITING,
462-
ErrorMessage: htlcswitch.ErrSwitchExiting.
463-
Error(),
464-
}, nil
461+
// This channel is closed when the Switch shuts down. We
462+
// return a gRPC error to the client.
463+
return nil, status.Error(codes.Unavailable,
464+
htlcswitch.ErrSwitchExiting.Error())
465465
}
466466

467467
case <-ctx.Done():
@@ -470,25 +470,30 @@ func (s *Server) TrackOnion(ctx context.Context,
470470
return nil, status.FromContextError(ctx.Err()).Err()
471471
}
472472

473-
// The attempt result arrived so the HTLC is no longer in-flight.
473+
// The attempt result arrived so the HTLC is no longer in-flight. If
474+
// the payment failed, we build a structured response for the client.
474475
if result.Error != nil {
475-
message, code := translateErrorForRPC(result.Error)
476-
477476
log.Errorf("Payment via onion failed for payment=%v: %v",
478-
hash, message)
477+
hash, result.Error)
479478

480-
return &TrackOnionResponse{
481-
ErrorCode: code,
482-
ErrorMessage: message,
483-
}, nil
479+
details := marshallFailureDetails(result.Error)
480+
481+
return newTrackOnionFailureResponse(details), nil
484482
}
485483

484+
// If the server was unable to decrypt the error, it will be returned
485+
// as an encrypted byte slice. We populate the response accordingly.
486486
if len(result.EncryptedError) > 0 {
487-
log.Errorf("Payment via onion failed for payment=%v", hash)
487+
log.Errorf("Payment via onion failed for payment=%v with "+
488+
"encrypted error", hash)
488489

489-
return &TrackOnionResponse{
490-
EncryptedError: result.EncryptedError,
491-
}, nil
490+
details := &FailureDetails{
491+
Failure: &FailureDetails_EncryptedErrorData{
492+
EncryptedErrorData: result.EncryptedError,
493+
},
494+
}
495+
496+
return newTrackOnionFailureResponse(details), nil
492497
}
493498

494499
// If we have reached this point, we expect a valid preimage for a
@@ -497,17 +502,21 @@ func (s *Server) TrackOnion(ctx context.Context,
497502
log.Errorf("Payment %v completed without a valid preimage or "+
498503
"error", hash)
499504

500-
return &TrackOnionResponse{
505+
details := &FailureDetails{
501506
ErrorCode: ErrorCode_INTERNAL,
502507
ErrorMessage: ErrAmbiguousPaymentState.Error(),
503-
}, nil
508+
}
509+
510+
return newTrackOnionFailureResponse(details), nil
504511
}
505512

506513
log.Debugf("Received preimage via onion attempt_id=%d for payment=%v",
507514
req.AttemptId, hash)
508515

509516
return &TrackOnionResponse{
510-
Preimage: result.Preimage[:],
517+
Result: &TrackOnionResponse_Preimage{
518+
Preimage: result.Preimage[:],
519+
},
511520
}, nil
512521
}
513522

@@ -662,7 +671,6 @@ func (s *Server) BuildOnion(_ context.Context,
662671
func translateErrorForRPC(err error) (string, ErrorCode) {
663672
var (
664673
clearTextErr htlcswitch.ClearTextError
665-
fwdErr *htlcswitch.ForwardingError
666674
)
667675

668676
switch {
@@ -680,20 +688,6 @@ func translateErrorForRPC(err error) (string, ErrorCode) {
680688
return err.Error(), ErrorCode_SWITCH_EXITING
681689

682690
case errors.As(err, &clearTextErr):
683-
// If this is a forwarding error, we'll handle it specially.
684-
if errors.As(err, &fwdErr) {
685-
encodedError, encodeErr := encodeForwardingError(fwdErr)
686-
if encodeErr != nil {
687-
return fmt.Sprintf("failed to encode wire "+
688-
"message: %v", encodeErr),
689-
ErrorCode_INTERNAL
690-
}
691-
692-
return encodedError,
693-
ErrorCode_FORWARDING_ERROR
694-
}
695-
696-
// Otherwise, we'll just encode the clear text error.
697691
var buf bytes.Buffer
698692
encodeErr := lnwire.EncodeFailure(
699693
&buf, clearTextErr.WireMessage(), 0,
@@ -712,48 +706,116 @@ func translateErrorForRPC(err error) (string, ErrorCode) {
712706
}
713707
}
714708

715-
// encodeForwardingError converts a forwarding error from the switch to the
716-
// format we can package for delivery to SendOnion rpc clients. We preserve the
717-
// failure message from the wire as well as the index along the route where the
718-
// failure occurred.
719-
func encodeForwardingError(e *htlcswitch.ForwardingError) (string, error) {
720-
var buf bytes.Buffer
721-
err := lnwire.EncodeFailure(&buf, e.WireMessage(), 0)
722-
if err != nil {
723-
return "", fmt.Errorf("failed to encode wire message: %w", err)
709+
// newTrackOnionFailureResponse is a helper function that wraps a
710+
// PaymentFailureDetails message in a TrackOnionResponse.
711+
func newTrackOnionFailureResponse(
712+
details *FailureDetails) *TrackOnionResponse {
713+
714+
return &TrackOnionResponse{
715+
Result: &TrackOnionResponse_FailureDetails{
716+
FailureDetails: details,
717+
},
718+
}
719+
}
720+
721+
// marshallFailureDetails creates the FailureDetails message for the
722+
// TrackOnion response body.
723+
func marshallFailureDetails(err error) *FailureDetails {
724+
var (
725+
clearTextErr htlcswitch.ClearTextError
726+
fwdErr *htlcswitch.ForwardingError
727+
)
728+
729+
details := &FailureDetails{
730+
ErrorMessage: err.Error(),
724731
}
725732

726-
return fmt.Sprintf("%d@%s", e.FailureSourceIdx,
727-
hex.EncodeToString(buf.Bytes())), nil
733+
switch {
734+
case errors.Is(err, htlcswitch.ErrPaymentIDNotFound):
735+
details.ErrorCode = ErrorCode_PAYMENT_ID_NOT_FOUND
736+
737+
case errors.Is(err, htlcswitch.ErrUnreadableFailureMessage):
738+
details.ErrorCode = ErrorCode_UNREADABLE_FAILURE_MESSAGE
739+
740+
case errors.Is(err, htlcswitch.ErrSwitchExiting):
741+
details.ErrorCode = ErrorCode_SWITCH_EXITING
742+
743+
case errors.As(err, &clearTextErr):
744+
var buf bytes.Buffer
745+
746+
encodeErr := lnwire.EncodeFailure(
747+
&buf, clearTextErr.WireMessage(), 0,
748+
)
749+
if encodeErr != nil {
750+
log.Errorf("failed to encode wire message: %v",
751+
encodeErr)
752+
details.ErrorCode = ErrorCode_INTERNAL
753+
754+
return details
755+
}
756+
757+
if errors.As(err, &fwdErr) {
758+
details.Failure = &FailureDetails_ForwardingFailure{
759+
ForwardingFailure: &ForwardingFailure{
760+
FailureSourceIndex: uint32(
761+
fwdErr.FailureSourceIdx,
762+
),
763+
WireMessage: buf.Bytes(),
764+
},
765+
}
766+
} else {
767+
details.Failure = &FailureDetails_ClearTextFailure{
768+
ClearTextFailure: &ClearTextFailure{
769+
WireMessage: buf.Bytes(),
770+
},
771+
}
772+
}
773+
774+
default:
775+
details.ErrorCode = ErrorCode_INTERNAL
776+
}
777+
778+
return details
728779
}
729780

730-
// ParseForwardingError converts an error from the format in SendOnion rpc
731-
// protos to a forwarding error type.
732-
func ParseForwardingError(errStr string) (*htlcswitch.ForwardingError, error) {
733-
parts := strings.SplitN(errStr, "@", 2)
734-
if len(parts) != 2 {
735-
return nil, fmt.Errorf("invalid forwarding error format: %s",
736-
errStr)
781+
// UnmarshallForwardingError converts a protobuf ForwardingFailure message into
782+
// an htlcswitch.ForwardingError.
783+
func UnmarshallForwardingError(f *ForwardingFailure) (
784+
*htlcswitch.ForwardingError, error) {
785+
786+
if f == nil {
787+
return nil, fmt.Errorf("cannot parse nil ForwardingFailure")
737788
}
738789

739-
idx, err := strconv.Atoi(parts[0])
790+
wireMsg, err := UnmarshallFailureMessage(f.WireMessage)
740791
if err != nil {
741-
return nil, fmt.Errorf("invalid forwarding error index: %s",
742-
errStr)
792+
return nil, fmt.Errorf("failed to decode wire message: %w", err)
743793
}
744794

745-
wireMsgBytes, err := hex.DecodeString(parts[1])
746-
if err != nil {
747-
return nil, fmt.Errorf("invalid forwarding error wire "+
748-
"message: %s", errStr)
795+
return htlcswitch.NewForwardingError(
796+
wireMsg, int(f.FailureSourceIndex),
797+
), nil
798+
}
799+
800+
// UnmarshallLinkError converts a protobuf ClearTextFailure message into the an
801+
// htlcswitch.LinkError.
802+
func UnmarshallLinkError(f *ClearTextFailure) (*htlcswitch.LinkError, error) {
803+
if f == nil {
804+
return nil, fmt.Errorf("cannot parse nil ClearTextFailure")
749805
}
750806

751-
r := bytes.NewReader(wireMsgBytes)
752-
wireMsg, err := lnwire.DecodeFailure(r, 0)
807+
wireMsg, err := UnmarshallFailureMessage(f.WireMessage)
753808
if err != nil {
754-
return nil, fmt.Errorf("failed to decode wire message: %w",
755-
err)
809+
return nil, fmt.Errorf("failed to decode wire message: %w", err)
756810
}
757811

758-
return htlcswitch.NewForwardingError(wireMsg, idx), nil
812+
return htlcswitch.NewLinkError(wireMsg), nil
813+
}
814+
815+
// UnmarshallFailureMessage decodes a raw wire message byte slice into a rich
816+
// lnwire.FailureMessage object.
817+
func UnmarshallFailureMessage(wireMsg []byte) (lnwire.FailureMessage, error) {
818+
r := bytes.NewReader(wireMsg)
819+
820+
return lnwire.DecodeFailure(r, 0)
759821
}

0 commit comments

Comments
 (0)