Skip to content

Commit 4d9d398

Browse files
authored
Merge pull request #349 from ellemouton/validate-dest-addr-network
loopd: verify that dest addr is for correct chain
2 parents 0fa3625 + 5399e60 commit 4d9d398

File tree

2 files changed

+145
-11
lines changed

2 files changed

+145
-11
lines changed

loopd/swapclient_server.go

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"sync"
99
"time"
1010

11+
"github.com/btcsuite/btcd/chaincfg"
1112
"github.com/btcsuite/btcutil"
1213
"github.com/lightninglabs/lndclient"
1314
"github.com/lightninglabs/loop"
@@ -34,6 +35,17 @@ const (
3435
minConfTarget = 2
3536
)
3637

38+
var (
39+
// errIncorrectChain is returned when the format of the
40+
// destination address provided does not match the active chain.
41+
errIncorrectChain = errors.New("invalid address format for the " +
42+
"active chain")
43+
44+
// errConfTargetTooLow is returned when the chosen confirmation target
45+
// is below the allowed minimum.
46+
errConfTargetTooLow = errors.New("confirmation target too low")
47+
)
48+
3749
// swapClientServer implements the grpc service exposed by loopd.
3850
type swapClientServer struct {
3951
network lndclient.Network
@@ -58,13 +70,6 @@ func (s *swapClientServer) LoopOut(ctx context.Context,
5870

5971
log.Infof("Loop out request received")
6072

61-
sweepConfTarget, err := validateConfTarget(
62-
in.SweepConfTarget, loop.DefaultSweepConfTarget,
63-
)
64-
if err != nil {
65-
return nil, err
66-
}
67-
6873
var sweepAddr btcutil.Address
6974
if in.Dest == "" {
7075
// Generate sweep address if none specified.
@@ -83,8 +88,10 @@ func (s *swapClientServer) LoopOut(ctx context.Context,
8388
}
8489
}
8590

86-
// Check that the label is valid.
87-
if err := labels.Validate(in.Label); err != nil {
91+
sweepConfTarget, err := validateLoopOutRequest(
92+
s.lnd.ChainParams, in.SweepConfTarget, sweepAddr, in.Label,
93+
)
94+
if err != nil {
8895
return nil, err
8996
}
9097

@@ -943,8 +950,9 @@ func validateConfTarget(target, defaultTarget int32) (int32, error) {
943950

944951
// Ensure the target respects our minimum threshold.
945952
case target < minConfTarget:
946-
return 0, fmt.Errorf("a confirmation target of at least %v "+
947-
"must be provided", minConfTarget)
953+
return 0, fmt.Errorf("%w: A confirmation target of at "+
954+
"least %v must be provided", errConfTargetTooLow,
955+
minConfTarget)
948956

949957
default:
950958
return target, nil
@@ -969,3 +977,22 @@ func validateLoopInRequest(htlcConfTarget int32, external bool) (int32, error) {
969977

970978
return validateConfTarget(htlcConfTarget, loop.DefaultHtlcConfTarget)
971979
}
980+
981+
// validateLoopOutRequest validates the confirmation target, destination
982+
// address and label of the loop out request.
983+
func validateLoopOutRequest(chainParams *chaincfg.Params, confTarget int32,
984+
sweepAddr btcutil.Address, label string) (int32, error) {
985+
// Check that the provided destination address has the correct format
986+
// for the active network.
987+
if !sweepAddr.IsForNet(chainParams) {
988+
return 0, fmt.Errorf("%w: Current active network is %s",
989+
errIncorrectChain, chainParams.Name)
990+
}
991+
992+
// Check that the label is valid.
993+
if err := labels.Validate(label); err != nil {
994+
return 0, err
995+
}
996+
997+
return validateConfTarget(confTarget, loop.DefaultSweepConfTarget)
998+
}

loopd/swapclient_server_test.go

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,24 @@
11
package loopd
22

33
import (
4+
"errors"
45
"testing"
56

7+
"github.com/btcsuite/btcd/chaincfg"
8+
"github.com/btcsuite/btcutil"
69
"github.com/lightninglabs/loop"
10+
"github.com/lightninglabs/loop/labels"
11+
"github.com/stretchr/testify/require"
12+
)
13+
14+
var (
15+
testnetAddr, _ = btcutil.NewAddressScriptHash(
16+
[]byte{123}, &chaincfg.TestNet3Params,
17+
)
18+
19+
mainnetAddr, _ = btcutil.NewAddressScriptHash(
20+
[]byte{123}, &chaincfg.MainNetParams,
21+
)
722
)
823

924
// TestValidateConfTarget tests all failure and success cases for our conf
@@ -143,3 +158,95 @@ func TestValidateLoopInRequest(t *testing.T) {
143158
})
144159
}
145160
}
161+
162+
// TestValidateLoopOutRequest tests validation of loop out requests.
163+
func TestValidateLoopOutRequest(t *testing.T) {
164+
tests := []struct {
165+
name string
166+
chain chaincfg.Params
167+
confTarget int32
168+
destAddr btcutil.Address
169+
label string
170+
err error
171+
expectedTarget int32
172+
}{
173+
{
174+
name: "mainnet address with mainnet backend",
175+
chain: chaincfg.MainNetParams,
176+
destAddr: mainnetAddr,
177+
label: "label ok",
178+
confTarget: 2,
179+
err: nil,
180+
expectedTarget: 2,
181+
},
182+
{
183+
name: "mainnet address with testnet backend",
184+
chain: chaincfg.TestNet3Params,
185+
destAddr: mainnetAddr,
186+
label: "label ok",
187+
confTarget: 2,
188+
err: errIncorrectChain,
189+
expectedTarget: 0,
190+
},
191+
{
192+
name: "testnet address with testnet backend",
193+
chain: chaincfg.TestNet3Params,
194+
destAddr: testnetAddr,
195+
label: "label ok",
196+
confTarget: 2,
197+
err: nil,
198+
expectedTarget: 2,
199+
},
200+
{
201+
name: "testnet address with mainnet backend",
202+
chain: chaincfg.MainNetParams,
203+
destAddr: testnetAddr,
204+
label: "label ok",
205+
confTarget: 2,
206+
err: errIncorrectChain,
207+
expectedTarget: 0,
208+
},
209+
{
210+
name: "invalid label",
211+
chain: chaincfg.MainNetParams,
212+
destAddr: mainnetAddr,
213+
label: labels.Reserved,
214+
confTarget: 2,
215+
err: labels.ErrReservedPrefix,
216+
expectedTarget: 0,
217+
},
218+
{
219+
name: "invalid conf target",
220+
chain: chaincfg.MainNetParams,
221+
destAddr: mainnetAddr,
222+
label: "label ok",
223+
confTarget: 1,
224+
err: errConfTargetTooLow,
225+
expectedTarget: 0,
226+
},
227+
{
228+
name: "default conf target",
229+
chain: chaincfg.MainNetParams,
230+
destAddr: mainnetAddr,
231+
label: "label ok",
232+
confTarget: 0,
233+
err: nil,
234+
expectedTarget: 9,
235+
},
236+
}
237+
238+
for _, test := range tests {
239+
test := test
240+
241+
t.Run(test.name, func(t *testing.T) {
242+
t.Parallel()
243+
244+
conf, err := validateLoopOutRequest(
245+
&test.chain, test.confTarget, test.destAddr,
246+
test.label,
247+
)
248+
require.True(t, errors.Is(err, test.err))
249+
require.Equal(t, test.expectedTarget, conf)
250+
})
251+
}
252+
}

0 commit comments

Comments
 (0)