Skip to content

Commit 70d83f7

Browse files
authored
Merge pull request #1060 from sumanjeet0012/fix/kademlia_key_decoding
Migrate DHT API to accept string keys.
2 parents e3a296d + cda110e commit 70d83f7

File tree

8 files changed

+111
-108
lines changed

8 files changed

+111
-108
lines changed

examples/kademlia/kademlia.py

Lines changed: 26 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import secrets
1414
import sys
1515

16-
import base58
1716
from multiaddr import (
1817
Multiaddr,
1918
)
@@ -32,9 +31,6 @@
3231
DHTMode,
3332
KadDHT,
3433
)
35-
from libp2p.kad_dht.utils import (
36-
create_key_from_binary,
37-
)
3834
from libp2p.tools.async_service import (
3935
background_trio_service,
4036
)
@@ -191,57 +187,46 @@ async def run_node(
191187
# Start the DHT service
192188
async with background_trio_service(dht):
193189
logger.info(f"DHT service started in {dht_mode.value} mode")
194-
val_key = create_key_from_binary(b"py-libp2p kademlia example value")
195-
content = b"Hello from python node "
196-
content_key = create_key_from_binary(content)
190+
191+
# Example 1: Simple Key-Value Storage
192+
# Just use a string key directly - DHT API accepts strings!
193+
key = "my-example-key"
194+
value = b"Hello from py-libp2p!"
195+
196+
# Example 2: Content Provider Advertisement
197+
content_id = "my-content-identifier"
197198

198199
if dht_mode == DHTMode.SERVER:
199-
# Store a value in the DHT
200-
msg = "Hello message from Sumanjeet"
201-
val_data = msg.encode()
202-
await dht.put_value(val_key, val_data)
203-
logger.info(
204-
f"Stored value '{val_data.decode()}'"
205-
f"with key: {base58.b58encode(val_key).decode()}"
206-
)
200+
# Store key-value pair in the DHT
201+
await dht.put_value(key, value)
202+
logger.info(f"Stored value: {value.decode()} with key: {key}")
207203

208-
# Advertise as content server
209-
success = await dht.provider_store.provide(content_key)
204+
# Advertise as a provider for content
205+
success = await dht.provide(content_id)
210206
if success:
211-
logger.info(
212-
"Successfully advertised as server"
213-
f"for content: {content_key.hex()}"
214-
)
207+
logger.info(f"Advertised as provider for content: {content_id}")
215208
else:
216-
logger.warning("Failed to advertise as content server")
209+
logger.warning("Failed to advertise as provider")
217210

218211
else:
219-
# retrieve the value
220-
logger.info(
221-
"Looking up key: %s", base58.b58encode(val_key).decode()
222-
)
223-
val_data = await dht.get_value(val_key)
224-
if val_data:
225-
try:
226-
logger.info(f"Retrieved value: {val_data.decode()}")
227-
except UnicodeDecodeError:
228-
logger.info(f"Retrieved value (bytes): {val_data!r}")
212+
# Retrieve value from DHT using the same key
213+
logger.info(f"Looking up key: {key}")
214+
retrieved_value = await dht.get_value(key)
215+
if retrieved_value:
216+
logger.info(f"Retrieved value: {retrieved_value.decode()}")
229217
else:
230218
logger.warning("Failed to retrieve value")
231219

232-
# Also check if we can find servers for our own content
233-
logger.info("Looking for servers of content: %s", content_key.hex())
234-
providers = await dht.provider_store.find_providers(content_key)
220+
# Find providers for content
221+
logger.info(f"Looking for providers of content: {content_id}")
222+
providers = await dht.find_providers(content_id)
235223
if providers:
236224
logger.info(
237-
"Found %d servers for content: %s",
238-
len(providers),
239-
[p.peer_id.pretty() for p in providers],
225+
f"Found {len(providers)} providers: "
226+
f"{[p.peer_id.pretty() for p in providers]}"
240227
)
241228
else:
242-
logger.warning(
243-
"No servers found for content %s", content_key.hex()
244-
)
229+
logger.warning("No providers found")
245230

246231
# Keep the node running
247232
while True:

libp2p/kad_dht/kad_dht.py

Lines changed: 51 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -763,46 +763,45 @@ async def find_peer(self, peer_id: ID) -> PeerInfo | None:
763763

764764
# Value storage and retrieval methods
765765

766-
async def put_value(self, key: bytes, value: bytes) -> None:
766+
async def put_value(self, key: str, value: bytes) -> None:
767767
"""
768768
Store a value in the DHT.
769+
770+
Args:
771+
key: String key (will be converted to bytes for storage)
772+
value: Binary value to store
773+
769774
"""
770-
logger.debug(f"Storing value for key {key.hex()}")
775+
logger.debug(f"Storing value for key {key}")
771776

772-
if key.decode("utf-8").startswith("/"):
773-
if self.validator is not None:
774-
# Dont allow local users to put bad values
775-
self.validator.validate(key.decode("utf-8"), value)
777+
# Validate if key starts with "/" (namespaced keys like /pk/...)
778+
if self.validator is not None and key.startswith("/"):
779+
self.validator.validate(key, value)
776780

777-
old_value_record = self.value_store.get(key)
778-
if old_value_record is not None and old_value_record.value != value:
779-
# Select which value is better
780-
try:
781-
index = self.validator.select(
782-
key.decode("utf-8"), [value, old_value_record.value]
783-
)
784-
if index != 0:
785-
raise ValueError(
786-
"Refusing to replace newer value with the older one"
787-
)
788-
except Exception as e:
789-
logger.debug(f"Validation error for key {key.hex()}: {e}")
790-
raise
781+
key_bytes = key.encode("utf-8")
782+
old_value_record = self.value_store.get(key_bytes)
783+
if old_value_record is not None and old_value_record.value != value:
784+
index = self.validator.select(key, [value, old_value_record.value])
785+
if index != 0:
786+
raise ValueError(
787+
"Refusing to replace newer value with the older one"
788+
)
789+
790+
# Convert string key to bytes for storage
791+
key_bytes = key.encode("utf-8")
791792

792793
# 1. Store locally first
793-
self.value_store.put(key, value)
794+
self.value_store.put(key_bytes, value)
794795
try:
795796
decoded_value = value.decode("utf-8")
796797
except UnicodeDecodeError:
797798
decoded_value = value.hex()
798-
logger.debug(
799-
f"Stored value locally for key {key.hex()} with value {decoded_value}"
800-
)
799+
logger.debug(f"Stored value locally for key {key} with value {decoded_value}")
801800

802801
# 2. Get closest peers, excluding self
803802
closest_peers = [
804803
peer
805-
for peer in self.routing_table.find_local_closest_peers(key)
804+
for peer in self.routing_table.find_local_closest_peers(key_bytes)
806805
if peer != self.local_peer_id
807806
]
808807
logger.debug(f"Found {len(closest_peers)} peers to store value at")
@@ -817,7 +816,7 @@ async def store_one(idx: int, peer: ID) -> None:
817816
try:
818817
with trio.move_on_after(QUERY_TIMEOUT):
819818
success = await self.value_store._store_at_peer(
820-
peer, key, value
819+
peer, key_bytes, value
821820
)
822821
batch_results[idx] = success
823822
if success:
@@ -835,19 +834,32 @@ async def store_one(idx: int, peer: ID) -> None:
835834

836835
logger.info(f"Successfully stored value at {stored_count} peers")
837836

838-
async def get_value(self, key: bytes) -> bytes | None:
839-
logger.debug(f"Getting value for key: {key.hex()}")
837+
async def get_value(self, key: str) -> bytes | None:
838+
"""
839+
Retrieve a value from the DHT.
840+
841+
Args:
842+
key: String key (will be converted to bytes for lookup)
843+
844+
Returns:
845+
The value if found, None otherwise
846+
847+
"""
848+
logger.debug(f"Getting value for key: {key}")
849+
850+
# Convert string key to bytes for lookup
851+
key_bytes = key.encode("utf-8")
840852

841853
# 1. Check local store first
842-
value_record = self.value_store.get(key)
854+
value_record = self.value_store.get(key_bytes)
843855
if value_record:
844856
logger.debug("Found value locally")
845857
return value_record.value
846858

847859
# 2. Get closest peers, excluding self
848860
closest_peers = [
849861
peer
850-
for peer in self.routing_table.find_local_closest_peers(key)
862+
for peer in self.routing_table.find_local_closest_peers(key_bytes)
851863
if peer != self.local_peer_id
852864
]
853865
logger.debug(f"Searching {len(closest_peers)} peers for value")
@@ -861,7 +873,7 @@ async def query_one(peer: ID) -> None:
861873
nonlocal found_value
862874
try:
863875
with trio.move_on_after(QUERY_TIMEOUT):
864-
value = await self.value_store._get_from_peer(peer, key)
876+
value = await self.value_store._get_from_peer(peer, key_bytes)
865877
if value is not None and found_value is None:
866878
found_value = value
867879
logger.debug(f"Found value at peer {peer}")
@@ -873,12 +885,12 @@ async def query_one(peer: ID) -> None:
873885
nursery.start_soon(query_one, peer)
874886

875887
if found_value is not None:
876-
self.value_store.put(key, found_value)
888+
self.value_store.put(key_bytes, found_value)
877889
logger.info("Successfully retrieved value from network")
878890
return found_value
879891

880892
# 4. Not found
881-
logger.warning(f"Value not found for key {key.hex()}")
893+
logger.warning(f"Value not found for key {key}")
882894
return None
883895

884896
# Add these methods in the Utility methods section
@@ -899,17 +911,19 @@ async def add_peer(self, peer_id: ID) -> bool:
899911
"""
900912
return await self.routing_table.add_peer(peer_id)
901913

902-
async def provide(self, key: bytes) -> bool:
914+
async def provide(self, key: str) -> bool:
903915
"""
904916
Reference to provider_store.provide for convenience.
905917
"""
906-
return await self.provider_store.provide(key)
918+
key_bytes = key.encode("utf-8")
919+
return await self.provider_store.provide(key_bytes)
907920

908-
async def find_providers(self, key: bytes, count: int = 20) -> list[PeerInfo]:
921+
async def find_providers(self, key: str, count: int = 20) -> list[PeerInfo]:
909922
"""
910923
Reference to provider_store.find_providers for convenience.
911924
"""
912-
return await self.provider_store.find_providers(key, count)
925+
key_bytes = key.encode("utf-8")
926+
return await self.provider_store.find_providers(key_bytes, count)
913927

914928
def get_routing_table_size(self) -> int:
915929
"""

libp2p/kad_dht/value_store.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def put(self, key: bytes, value: bytes, validity: float = 0.0) -> None:
7272
logger.debug(
7373
"Storing value for key %s... with validity %s", key.hex(), validity
7474
)
75-
record = make_put_record(key.decode("utf-8"), value)
75+
record = make_put_record(key, value)
7676
record.timeReceived = str(time.time)
7777

7878
self.store[key] = (record, validity)

libp2p/records/record.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
from libp2p.kad_dht.pb import kademlia_pb2 as record_pb2
22

33

4-
def make_put_record(key: str, value: bytes) -> record_pb2.Record:
4+
def make_put_record(key: bytes, value: bytes) -> record_pb2.Record:
55
"""
66
Create a new Record object with the specified key and value.
77
88
Args:
9-
key (str): The key for the record, which will be encoded as bytes.
9+
key (bytes): The key for the record.
1010
value (bytes): The value to associate with the key in the record.
1111
1212
Returns:
1313
record_pb2.Record: A Record object containing the provided key and value.
1414
1515
"""
1616
record = record_pb2.Record()
17-
record.key = key.encode()
17+
record.key = key
1818
record.value = value
1919
return record

newsfragments/1059.bugfix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Kademlia DHT API now accepts string keys instead of bytes (``put_value(key: str, ...)``). Fixes UnicodeDecodeError with binary multihash keys.

tests/core/kad_dht/test_kad_dht.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -255,14 +255,15 @@ async def test_put_and_get_value(dht_pair: tuple[KadDHT, KadDHT]):
255255
dht_a, dht_b = dht_pair
256256
# dht_a.peer_routing.routing_table.add_peer(dht_b.pe)
257257
peer_b_info = PeerInfo(dht_b.host.get_id(), dht_b.host.get_addrs())
258-
# Generate a random key and value
259-
key = b"rendom_key"
258+
# Generate a random key and value (use string key for API)
259+
key = "random_key"
260+
key_bytes = key.encode("utf-8")
260261
value = b"test-value"
261262

262263
# First add the value directly to node A's store to verify storage works
263-
dht_a.value_store.put(key, value)
264+
dht_a.value_store.put(key_bytes, value)
264265
logger.debug("Local value store: %s", dht_a.value_store.store)
265-
local_value_record = dht_a.value_store.get(key)
266+
local_value_record = dht_a.value_store.get(key_bytes)
266267
assert local_value_record is not None
267268
assert local_value_record.value == value, "Local value storage failed"
268269
print("number of nodes in peer store", dht_a.host.get_peerstore().peer_ids())
@@ -305,7 +306,7 @@ async def test_put_and_get_value(dht_pair: tuple[KadDHT, KadDHT]):
305306
assert record_b.seq == record_b_put_value.seq
306307

307308
# # Log debugging information
308-
logger.debug("Put value with key %s...", key.hex()[:10])
309+
logger.debug("Put value with key %s...", key[:10])
309310
logger.debug("Node A value store: %s", dht_a.value_store.store)
310311

311312
# # Allow more time for the value to propagate
@@ -353,15 +354,15 @@ async def test_put_and_get_value(dht_pair: tuple[KadDHT, KadDHT]):
353354
value = keypair.public_key.serialize()
354355

355356
with trio.fail_after(TEST_TIMEOUT):
356-
await dht_a.put_value(key.encode(), value)
357+
await dht_a.put_value(key, value) # Now accepts string directly
357358

358359
# INVALID KEY PAIR
359360
key = "/pk/abcdef1234567890" # Not a valid multihash
360361
value = b"not-a-real-key"
361362

362363
with trio.fail_after(TEST_TIMEOUT):
363364
with pytest.raises(InvalidRecordType, match="valid multihash"):
364-
await dht_a.put_value(key.encode(), value)
365+
await dht_a.put_value(key, value) # Now accepts string directly
365366

366367

367368
@pytest.mark.trio
@@ -372,10 +373,11 @@ async def test_provide_and_find_providers(dht_pair: tuple[KadDHT, KadDHT]):
372373

373374
# Generate a random content ID
374375
content = f"test-content-{uuid.uuid4()}".encode()
375-
content_id = b"randome_content"
376+
content_id = "randome_content" # String for API
377+
content_id_bytes = content_id.encode("utf-8") # Bytes for internal storage
376378

377379
# Store content on the first node
378-
dht_a.value_store.put(content_id, content)
380+
dht_a.value_store.put(content_id_bytes, content)
379381

380382
# An extra FIND_NODE req is sent between the 2 nodes while dht creation,
381383
# so both the nodes will have records of each other before PUT_VALUE req is sent
@@ -562,12 +564,13 @@ async def test_dht_req_fail_with_invalid_record_transfer(
562564
peer_b_info = PeerInfo(dht_b.host.get_id(), dht_b.host.get_addrs())
563565

564566
# Generate a random key and value
565-
key = b"rendom_key"
567+
key = "random_key" # String for API
568+
key_bytes = key.encode("utf-8") # Bytes for internal storage
566569
value = b"test-value"
567570

568571
# First add the value directly to node A's store to verify storage works
569-
dht_a.value_store.put(key, value)
570-
local_value = dht_a.value_store.get(key)
572+
dht_a.value_store.put(key_bytes, value)
573+
local_value = dht_a.value_store.get(key_bytes)
571574
assert local_value is not None
572575
assert local_value.value == value, "Local value storage failed"
573576
await dht_a.routing_table.add_peer(peer_b_info)
@@ -583,7 +586,7 @@ async def test_dht_req_fail_with_invalid_record_transfer(
583586
dht_a.host.get_peerstore().set_local_record(envelope)
584587

585588
await dht_a.put_value(key, value)
586-
retrieved_value_record = dht_b.value_store.get(key)
589+
retrieved_value_record = dht_b.value_store.get(key_bytes)
587590

588591
# This proves that DHT_B rejected DHT_A PUT_RECORD req upon receiving
589592
# the corrupted invalid record
@@ -596,7 +599,7 @@ async def test_dht_req_fail_with_invalid_record_transfer(
596599
dht_a.host.get_peerstore().set_local_record(false_envelope)
597600

598601
await dht_a.put_value(key, value)
599-
retrieved_value_record = dht_b.value_store.get(key)
602+
retrieved_value_record = dht_b.value_store.get(key_bytes)
600603

601604
# This proves that DHT_B rejected DHT_A PUT_RECORD req upon receving
602605
# the record with a different peer_id regardless of a valid signature

0 commit comments

Comments
 (0)