Skip to content

Commit 37523b6

Browse files
authored
Merge pull request #10100 from Abdulkbk/fix-chanid-flag
commands: fix how we parse chan ids args at CLI level
2 parents 9ffbb97 + 754c254 commit 37523b6

File tree

2 files changed

+101
-33
lines changed

2 files changed

+101
-33
lines changed

cmd/commands/cmd_payments.go

Lines changed: 43 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -183,13 +183,13 @@ func PaymentFlags() []cli.Flag {
183183
cancelableFlag,
184184
cltvLimitFlag,
185185
lastHopFlag,
186-
cli.Int64SliceFlag{
186+
cli.StringSliceFlag{
187187
Name: "outgoing_chan_id",
188188
Usage: "short channel id of the outgoing channel to " +
189189
"use for the first hop of the payment; can " +
190190
"be specified multiple times in the same " +
191191
"command",
192-
Value: &cli.Int64Slice{},
192+
Value: &cli.StringSlice{},
193193
},
194194
cli.BoolFlag{
195195
Name: "force, f",
@@ -521,12 +521,11 @@ func SendPaymentRequest(ctx *cli.Context, req *routerrpc.SendPaymentRequest,
521521

522522
lnClient := lnrpc.NewLightningClient(lnConn)
523523

524-
outChan := ctx.Int64Slice("outgoing_chan_id")
525-
if len(outChan) != 0 {
526-
req.OutgoingChanIds = make([]uint64, len(outChan))
527-
for i, c := range outChan {
528-
req.OutgoingChanIds[i] = uint64(c)
529-
}
524+
var err error
525+
outChan := ctx.StringSlice("outgoing_chan_id")
526+
req.OutgoingChanIds, err = parseChanIDs(outChan)
527+
if err != nil {
528+
return fmt.Errorf("unable to decode outgoing_chan_ids: %w", err)
530529
}
531530

532531
if ctx.IsSet(lastHopFlag.Name) {
@@ -1282,17 +1281,9 @@ func queryRoutes(ctx *cli.Context) error {
12821281
}
12831282

12841283
outgoingChanIds := ctx.StringSlice("outgoing_chan_id")
1285-
if len(outgoingChanIds) != 0 {
1286-
req.OutgoingChanIds = make([]uint64, len(outgoingChanIds))
1287-
for i, chanID := range outgoingChanIds {
1288-
id, err := strconv.ParseUint(chanID, 10, 64)
1289-
if err != nil {
1290-
return fmt.Errorf("invalid outgoing_chan_id "+
1291-
"argument: %w", err)
1292-
}
1293-
1294-
req.OutgoingChanIds[i] = id
1295-
}
1284+
req.OutgoingChanIds, err = parseChanIDs(outgoingChanIds)
1285+
if err != nil {
1286+
return fmt.Errorf("unable to decode outgoing_chan_id: %w", err)
12961287
}
12971288

12981289
if ctx.IsSet("route_hints") {
@@ -1585,13 +1576,13 @@ var forwardingHistoryCommand = cli.Command{
15851576
Usage: "skip the peer alias lookup per forwarding " +
15861577
"event in order to improve performance",
15871578
},
1588-
cli.Int64SliceFlag{
1579+
cli.StringSliceFlag{
15891580
Name: "incoming_chan_ids",
15901581
Usage: "the short channel id of the incoming " +
15911582
"channel to filter events by; can be " +
15921583
"specified multiple times in the same command",
15931584
},
1594-
cli.Int64SliceFlag{
1585+
cli.StringSliceFlag{
15951586
Name: "outgoing_chan_ids",
15961587
Usage: "the short channel id of the outgoing " +
15971588
"channel to filter events by; can be " +
@@ -1677,21 +1668,19 @@ func forwardingHistory(ctx *cli.Context) error {
16771668
NumMaxEvents: maxEvents,
16781669
PeerAliasLookup: lookupPeerAlias,
16791670
}
1680-
outgoingChannelIDs := ctx.Int64Slice("outgoing_chan_ids")
1681-
if len(outgoingChannelIDs) != 0 {
1682-
req.OutgoingChanIds = make([]uint64, len(outgoingChannelIDs))
1683-
for i, c := range outgoingChannelIDs {
1684-
req.OutgoingChanIds[i] = uint64(c)
1685-
}
1671+
1672+
outgoingChannelIDs := ctx.StringSlice("outgoing_chan_ids")
1673+
req.OutgoingChanIds, err = parseChanIDs(outgoingChannelIDs)
1674+
if err != nil {
1675+
return fmt.Errorf("unable to decode outgoing_chan_ids: %w", err)
16861676
}
16871677

1688-
incomingChannelIDs := ctx.Int64Slice("incoming_chan_ids")
1689-
if len(incomingChannelIDs) != 0 {
1690-
req.IncomingChanIds = make([]uint64, len(incomingChannelIDs))
1691-
for i, c := range incomingChannelIDs {
1692-
req.IncomingChanIds[i] = uint64(c)
1693-
}
1678+
incomingChannelIDs := ctx.StringSlice("incoming_chan_ids")
1679+
req.IncomingChanIds, err = parseChanIDs(incomingChannelIDs)
1680+
if err != nil {
1681+
return fmt.Errorf("unable to decode incoming_chan_ids: %w", err)
16941682
}
1683+
16951684
resp, err := client.ForwardingHistory(ctxc, req)
16961685
if err != nil {
16971686
return err
@@ -2060,3 +2049,24 @@ func ordinalNumber(num uint32) string {
20602049
return fmt.Sprintf("%dth", num)
20612050
}
20622051
}
2052+
2053+
// parseChanIDs parses a slice of strings containing short channel IDs into a
2054+
// slice of uint64 values.
2055+
func parseChanIDs(idStrings []string) ([]uint64, error) {
2056+
// Return early if no chan IDs are passed.
2057+
if len(idStrings) == 0 {
2058+
return nil, nil
2059+
}
2060+
2061+
chanIDs := make([]uint64, len(idStrings))
2062+
for i, idStr := range idStrings {
2063+
scid, err := strconv.ParseUint(idStr, 10, 64)
2064+
if err != nil {
2065+
return nil, err
2066+
}
2067+
2068+
chanIDs[i] = scid
2069+
}
2070+
2071+
return chanIDs, nil
2072+
}

cmd/commands/commands_test.go

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,3 +434,61 @@ func TestParseBlockHeightInputs(t *testing.T) {
434434
})
435435
}
436436
}
437+
438+
// TestParseChanIDs tests the parseChanIDs function with various
439+
// valid and invalid input values and verifies the output.
440+
func TestParseChanIDs(t *testing.T) {
441+
t.Parallel()
442+
443+
testCases := []struct {
444+
name string
445+
chanIDs []string
446+
expected []uint64
447+
expectedErr bool
448+
}{
449+
{
450+
name: "valid chan ids",
451+
chanIDs: []string{
452+
"1499733860352000", "17592186044552773633",
453+
},
454+
expected: []uint64{
455+
1499733860352000, 17592186044552773633,
456+
},
457+
expectedErr: false,
458+
},
459+
{
460+
name: "invalid chan id",
461+
chanIDs: []string{
462+
"channel id",
463+
},
464+
expected: []uint64{},
465+
expectedErr: true,
466+
},
467+
{
468+
name: "negative chan id",
469+
chanIDs: []string{
470+
"-10000",
471+
},
472+
expected: []uint64{},
473+
expectedErr: true,
474+
},
475+
{
476+
name: "empty chan ids",
477+
chanIDs: []string{},
478+
expected: nil,
479+
expectedErr: false,
480+
},
481+
}
482+
483+
for _, tc := range testCases {
484+
t.Run(tc.name, func(t *testing.T) {
485+
chanIDs, err := parseChanIDs(tc.chanIDs)
486+
if tc.expectedErr {
487+
require.Error(t, err)
488+
return
489+
}
490+
require.NoError(t, err)
491+
require.Equal(t, tc.expected, chanIDs)
492+
})
493+
}
494+
}

0 commit comments

Comments
 (0)