Skip to content

Commit c45167c

Browse files
committed
fix: handle empty keys
Handle empty keys, both when sent in RPC requests and in the local API.
1 parent 067f8ab commit c45167c

File tree

6 files changed

+100
-18
lines changed

6 files changed

+100
-18
lines changed

dht_test.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1910,3 +1910,51 @@ func TestProtocolUpgrade(t *testing.T) {
19101910
t.Fatalf("Expected 'buzz' got '%s'", string(value))
19111911
}
19121912
}
1913+
1914+
func TestInvalidKeys(t *testing.T) {
1915+
ctx, cancel := context.WithCancel(context.Background())
1916+
defer cancel()
1917+
1918+
nDHTs := 2
1919+
dhts := setupDHTS(t, ctx, nDHTs)
1920+
defer func() {
1921+
for i := 0; i < nDHTs; i++ {
1922+
dhts[i].Close()
1923+
defer dhts[i].host.Close()
1924+
}
1925+
}()
1926+
1927+
t.Logf("connecting %d dhts in a ring", nDHTs)
1928+
for i := 0; i < nDHTs; i++ {
1929+
connect(t, ctx, dhts[i], dhts[(i+1)%len(dhts)])
1930+
}
1931+
1932+
querier := dhts[0]
1933+
_, err := querier.GetClosestPeers(ctx, "")
1934+
if err == nil {
1935+
t.Fatal("get closest peers should have failed")
1936+
}
1937+
1938+
_, err = querier.FindProviders(ctx, cid.Cid{})
1939+
switch err {
1940+
case routing.ErrNotFound, routing.ErrNotSupported, kb.ErrLookupFailure:
1941+
t.Fatal("failed with the wrong error: ", err)
1942+
case nil:
1943+
t.Fatal("find providers should have failed")
1944+
}
1945+
1946+
_, err = querier.FindPeer(ctx, peer.ID(""))
1947+
if err != peer.ErrEmptyPeerID {
1948+
t.Fatal("expected to fail due to the empty peer ID")
1949+
}
1950+
1951+
_, err = querier.GetValue(ctx, "")
1952+
if err == nil {
1953+
t.Fatal("expected to have failed")
1954+
}
1955+
1956+
err = querier.PutValue(ctx, "", []byte("foobar"))
1957+
if err == nil {
1958+
t.Fatal("expected to have failed")
1959+
}
1960+
}

handlers.go

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,16 +52,15 @@ func (dht *IpfsDHT) handlerForMsgType(t pb.Message_MessageType) dhtHandler {
5252
}
5353

5454
func (dht *IpfsDHT) handleGetValue(ctx context.Context, p peer.ID, pmes *pb.Message) (_ *pb.Message, err error) {
55-
// setup response
56-
resp := pb.NewMessage(pmes.GetType(), pmes.GetKey(), pmes.GetClusterLevel())
57-
5855
// first, is there even a key?
5956
k := pmes.GetKey()
6057
if len(k) == 0 {
6158
return nil, errors.New("handleGetValue but no key was provided")
62-
// TODO: send back an error response? could be bad, but the other node's hanging.
6359
}
6460

61+
// setup response
62+
resp := pb.NewMessage(pmes.GetType(), pmes.GetKey(), pmes.GetClusterLevel())
63+
6564
rec, err := dht.checkLocalDatastore(k)
6665
if err != nil {
6766
return nil, err
@@ -150,6 +149,10 @@ func cleanRecord(rec *recpb.Record) {
150149

151150
// Store a value in this peer local storage
152151
func (dht *IpfsDHT) handlePutValue(ctx context.Context, p peer.ID, pmes *pb.Message) (_ *pb.Message, err error) {
152+
if len(pmes.GetKey()) == 0 {
153+
return nil, errors.New("handleGetValue but no key was provided")
154+
}
155+
153156
rec := pmes.GetRecord()
154157
if rec == nil {
155158
logger.Debugw("got nil record from", "from", p)
@@ -253,6 +256,10 @@ func (dht *IpfsDHT) handleFindPeer(ctx context.Context, from peer.ID, pmes *pb.M
253256
resp := pb.NewMessage(pmes.GetType(), nil, pmes.GetClusterLevel())
254257
var closest []peer.ID
255258

259+
if len(pmes.GetKey()) == 0 {
260+
return nil, fmt.Errorf("handleFindPeer with empty key")
261+
}
262+
256263
// if looking for self... special case where we send it on CloserPeers.
257264
targetPid := peer.ID(pmes.GetKey())
258265
if targetPid == dht.self {
@@ -300,12 +307,15 @@ func (dht *IpfsDHT) handleFindPeer(ctx context.Context, from peer.ID, pmes *pb.M
300307
}
301308

302309
func (dht *IpfsDHT) handleGetProviders(ctx context.Context, p peer.ID, pmes *pb.Message) (_ *pb.Message, _err error) {
303-
resp := pb.NewMessage(pmes.GetType(), pmes.GetKey(), pmes.GetClusterLevel())
304310
key := pmes.GetKey()
305311
if len(key) > 80 {
306312
return nil, fmt.Errorf("handleGetProviders key size too large")
313+
} else if len(key) == 0 {
314+
return nil, fmt.Errorf("handleGetProviders key is empty")
307315
}
308316

317+
resp := pb.NewMessage(pmes.GetType(), pmes.GetKey(), pmes.GetClusterLevel())
318+
309319
// check if we have this value, to add ourselves as provider.
310320
has, err := dht.datastore.Has(convertToDsKey(key))
311321
if err != nil && err != ds.ErrNotFound {
@@ -341,7 +351,9 @@ func (dht *IpfsDHT) handleGetProviders(ctx context.Context, p peer.ID, pmes *pb.
341351
func (dht *IpfsDHT) handleAddProvider(ctx context.Context, p peer.ID, pmes *pb.Message) (_ *pb.Message, _err error) {
342352
key := pmes.GetKey()
343353
if len(key) > 80 {
344-
return nil, fmt.Errorf("handleAddProviders key size too large")
354+
return nil, fmt.Errorf("handleAddProvider key size too large")
355+
} else if len(key) == 0 {
356+
return nil, fmt.Errorf("handleAddProvider key is empty")
345357
}
346358

347359
logger.Debugf("adding provider", "from", p, "key", key)

handlers_test.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,27 @@ func TestCleanRecord(t *testing.T) {
6767
}
6868
}
6969

70+
func TestBadMessage(t *testing.T) {
71+
ctx, cancel := context.WithCancel(context.Background())
72+
defer cancel()
73+
74+
dht := setupDHT(ctx, t, false)
75+
76+
for _, typ := range []pb.Message_MessageType{
77+
pb.Message_PUT_VALUE, pb.Message_GET_VALUE, pb.Message_ADD_PROVIDER,
78+
pb.Message_GET_PROVIDERS, pb.Message_FIND_NODE,
79+
} {
80+
msg := &pb.Message{
81+
Type: typ,
82+
// explicitly avoid the key.
83+
}
84+
_, err := dht.handlerForMsgType(typ)(ctx, dht.Host().ID(), msg)
85+
if err == nil {
86+
t.Fatalf("expected processing message to fail for type %s", pb.Message_FIND_NODE)
87+
}
88+
}
89+
}
90+
7091
func BenchmarkHandleFindPeer(b *testing.B) {
7192
ctx, cancel := context.WithCancel(context.Background())
7293
defer cancel()

lookup.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,9 @@ func (lk loggableKeyBytes) String() string {
7272
// If the context is canceled, this function will return the context error along
7373
// with the closest K peers it has found so far.
7474
func (dht *IpfsDHT) GetClosestPeers(ctx context.Context, key string) (<-chan peer.ID, error) {
75+
if key == "" {
76+
return nil, fmt.Errorf("can't lookup empty key")
77+
}
7578
//TODO: I can break the interface! return []peer.ID
7679
lookupRes, err := dht.runLookupWithFollowup(ctx, key,
7780
func(ctx context.Context, p peer.ID) ([]*peer.AddrInfo, error) {

pb/message.go

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import (
55
"github.com/libp2p/go-libp2p-core/peer"
66

77
logging "github.com/ipfs/go-log"
8-
b58 "github.com/mr-tron/base58/base58"
98
ma "github.com/multiformats/go-multiaddr"
109
)
1110

@@ -138,16 +137,6 @@ func (m *Message) SetClusterLevel(level int) {
138137
m.ClusterLevelRaw = lvl + 1
139138
}
140139

141-
// Loggable turns a Message into machine-readable log output
142-
func (m *Message) Loggable() map[string]interface{} {
143-
return map[string]interface{}{
144-
"message": map[string]string{
145-
"type": m.Type.String(),
146-
"key": b58.Encode([]byte(m.GetKey())),
147-
},
148-
}
149-
}
150-
151140
// ConnectionType returns a Message_ConnectionType associated with the
152141
// network.Connectedness.
153142
func ConnectionType(c network.Connectedness) Message_ConnectionType {

routing.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,8 @@ func (dht *IpfsDHT) refreshRTIfNoShortcut(key kb.ID, lookupRes *lookupWithFollow
394394
func (dht *IpfsDHT) Provide(ctx context.Context, key cid.Cid, brdcst bool) (err error) {
395395
if !dht.enableProviders {
396396
return routing.ErrNotSupported
397+
} else if !key.Defined() {
398+
return fmt.Errorf("invalid cid: undefined")
397399
}
398400
logger.Debugw("finding provider", "cid", key)
399401

@@ -486,7 +488,10 @@ func (dht *IpfsDHT) makeProvRecord(key []byte) (*pb.Message, error) {
486488
func (dht *IpfsDHT) FindProviders(ctx context.Context, c cid.Cid) ([]peer.AddrInfo, error) {
487489
if !dht.enableProviders {
488490
return nil, routing.ErrNotSupported
491+
} else if !c.Defined() {
492+
return nil, fmt.Errorf("invalid cid: undefined")
489493
}
494+
490495
var providers []peer.AddrInfo
491496
for p := range dht.FindProvidersAsync(ctx, c, dht.bucketSize) {
492497
providers = append(providers, p)
@@ -500,7 +505,7 @@ func (dht *IpfsDHT) FindProviders(ctx context.Context, c cid.Cid) ([]peer.AddrIn
500505
// completes. Note: not reading from the returned channel may block the query
501506
// from progressing.
502507
func (dht *IpfsDHT) FindProvidersAsync(ctx context.Context, key cid.Cid, count int) <-chan peer.AddrInfo {
503-
if !dht.enableProviders {
508+
if !dht.enableProviders || !key.Defined() {
504509
peerOut := make(chan peer.AddrInfo)
505510
close(peerOut)
506511
return peerOut
@@ -613,6 +618,10 @@ func (dht *IpfsDHT) findProvidersAsyncRoutine(ctx context.Context, key multihash
613618

614619
// FindPeer searches for a peer with given ID.
615620
func (dht *IpfsDHT) FindPeer(ctx context.Context, id peer.ID) (_ peer.AddrInfo, err error) {
621+
if err := id.Validate(); err != nil {
622+
return peer.AddrInfo{}, err
623+
}
624+
616625
logger.Debugw("finding peer", "peer", id)
617626

618627
// Check if were already connected to them

0 commit comments

Comments
 (0)