Skip to content

Commit 9893c9c

Browse files
author
Matthias Zimmermann
committed
simplify EntityKey type
1 parent 898f54e commit 9893c9c

File tree

9 files changed

+155
-144
lines changed

9 files changed

+155
-144
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ SDK versions are tracked in the following files:
119119
Pytest is used for unit and integration testing.
120120
```bash
121121
uv run pytest # Run all tests
122-
uv run pytest -k test_create_entity_simple --log-cli-level=INFO # Specific tests via keyword, print at info log level
122+
uv run pytest -k test_create_entity_simple --log-cli-level=info # Specific tests via keyword, print at info log level
123123
```
124124

125125
If an `.env` file is present the unit tests are run against the specifice RPC coordinates and test accounts.

src/arkiv/contract.py

Lines changed: 7 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from web3.types import RPCEndpoint
1010

1111
# Import EntityKey for munger type checking
12-
from .types import EntityKey
1312

1413
STORAGE_ADDRESS: Final[ChecksumAddress] = Web3.to_checksum_address(
1514
"0x0000000000000000000000000000000060138453"
@@ -60,48 +59,33 @@
6059
]
6160

6261

63-
def custom_munger(module: Any, *args: Any, **kwargs: Any) -> Any:
64-
"""Custom munger for RPC methods that automatically converts EntityKey objects."""
65-
processed_args = tuple(
66-
arg.value if isinstance(arg, EntityKey) else arg for arg in args
67-
)
68-
69-
processed_kwargs = {
70-
key: value.value if isinstance(value, EntityKey) else value
71-
for key, value in kwargs.items()
72-
}
73-
74-
# Apply default munger to processed arguments
75-
return default_root_munger(module, *processed_args, **processed_kwargs)
76-
77-
7862
FUNCTIONS_ABI: dict[str, Method[Any]] = {
7963
"get_storage_value": Method(
8064
json_rpc_method=RPCEndpoint("golembase_getStorageValue"),
81-
mungers=[custom_munger],
65+
mungers=[default_root_munger],
8266
),
8367
"get_entity_metadata": Method(
8468
json_rpc_method=RPCEndpoint("golembase_getEntityMetaData"),
85-
mungers=[custom_munger],
69+
mungers=[default_root_munger],
8670
),
8771
"get_entities_to_expire_at_block": Method(
8872
json_rpc_method=RPCEndpoint("golembase_getEntitiesToExpireAtBlock"),
89-
mungers=[custom_munger],
73+
mungers=[default_root_munger],
9074
),
9175
"get_entity_count": Method(
9276
json_rpc_method=RPCEndpoint("golembase_getEntityCount"),
93-
mungers=[custom_munger],
77+
mungers=[default_root_munger],
9478
),
9579
"get_all_entity_keys": Method(
9680
json_rpc_method=RPCEndpoint("golembase_getAllEntityKeys"),
97-
mungers=[custom_munger],
81+
mungers=[default_root_munger],
9882
),
9983
"get_entities_of_owner": Method(
10084
json_rpc_method=RPCEndpoint("golembase_getEntitiesOfOwner"),
101-
mungers=[custom_munger],
85+
mungers=[default_root_munger],
10286
),
10387
"query_entities": Method(
10488
json_rpc_method=RPCEndpoint("golembase_queryEntities"),
105-
mungers=[custom_munger],
89+
mungers=[default_root_munger],
10690
),
10791
}

src/arkiv/exceptions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
"""Arkiv client exceptions."""
22

33

4+
class EntityKeyException(Exception):
5+
pass
6+
7+
48
class AccountNameException(Exception):
59
pass
610

src/arkiv/types.py

Lines changed: 2 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -2,89 +2,13 @@
22

33
from collections.abc import Sequence
44
from dataclasses import dataclass
5+
from typing import NewType
56

67
from eth_typing import ChecksumAddress, HexStr
78
from hexbytes import HexBytes
89

9-
1010
# Unique key for all entities
11-
@dataclass(frozen=True)
12-
class EntityKey:
13-
"""EntityKey dataclass that wraps HexStr for entity identification."""
14-
15-
value: HexStr
16-
17-
def __init__(self, value: str | int | HexBytes | HexStr) -> None:
18-
"""Create an EntityKey from various input types."""
19-
if isinstance(value, str) and value.startswith("0x"):
20-
# Already a hex string - validate length
21-
if len(value) != 66: # 0x + 64 hex chars
22-
raise ValueError(
23-
f"EntityKey hex string must be 66 characters (0x + 64 hex), got {len(value)}"
24-
)
25-
object.__setattr__(self, "value", HexStr(value.lower()))
26-
elif isinstance(value, int):
27-
# Convert integer to hex string with 0x prefix
28-
if value < 0:
29-
raise ValueError("EntityKey cannot be negative")
30-
hex_str = f"0x{value:064x}" # 64 chars = 32 bytes = 256 bits
31-
object.__setattr__(self, "value", HexStr(hex_str))
32-
elif isinstance(value, (HexBytes, bytes)):
33-
# Convert bytes to hex string
34-
if len(value) != 32: # 32 bytes = 256 bits
35-
raise ValueError(
36-
f"EntityKey bytes must be exactly 32 bytes, got {len(value)}"
37-
)
38-
object.__setattr__(self, "value", HexStr(f"0x{value.hex()}"))
39-
elif isinstance(value, str):
40-
# Plain string, assume it needs 0x prefix
41-
if len(value) % 2 != 0:
42-
value = "0" + value # Pad with leading zero if odd length
43-
if len(value) != 64: # Should be 64 hex characters
44-
raise ValueError(
45-
f"EntityKey hex string (without 0x) must be 64 characters, got {len(value)}"
46-
)
47-
object.__setattr__(self, "value", HexStr(f"0x{value.lower()}"))
48-
else:
49-
# Try to convert via HexBytes first, then to hex string
50-
try:
51-
hex_bytes = HexBytes(value)
52-
if len(hex_bytes) != 32:
53-
raise ValueError(
54-
f"EntityKey must represent exactly 32 bytes, got {len(hex_bytes)}"
55-
)
56-
object.__setattr__(self, "value", HexStr(f"0x{hex_bytes.hex()}"))
57-
except Exception as e:
58-
raise ValueError(
59-
f"Cannot convert {type(value)} to EntityKey: {e}"
60-
) from e
61-
62-
def __str__(self) -> str:
63-
"""String representation."""
64-
return self.value
65-
66-
def __repr__(self) -> str:
67-
"""Repr representation."""
68-
return f"EntityKey('{self.value}')"
69-
70-
def __eq__(self, other: object) -> bool:
71-
"""Equality comparison."""
72-
if isinstance(other, EntityKey):
73-
return self.value == other.value
74-
return False
75-
76-
def __hash__(self) -> int:
77-
"""Hash for use in sets/dicts."""
78-
return hash(self.value)
79-
80-
@property
81-
def hex(self) -> str:
82-
"""Get hex string without 0x prefix."""
83-
return self.value[2:]
84-
85-
def to_bytes(self) -> bytes:
86-
"""Convert to bytes."""
87-
return bytes.fromhex(self.hex)
11+
EntityKey = NewType("EntityKey", HexStr)
8812

8913

9014
type AnnotationValue = str | int # Only str or non-negative int allowed

src/arkiv/utils.py

Lines changed: 62 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import Any
55

66
import rlp # type: ignore[import-untyped]
7+
from eth_typing import HexStr
78
from hexbytes import HexBytes
89
from web3 import Web3
910
from web3.contract import Contract
@@ -12,6 +13,7 @@
1213

1314
from . import contract
1415
from .contract import STORAGE_ADDRESS
16+
from .exceptions import EntityKeyException
1517
from .types import (
1618
Annotation,
1719
AnnotationValue,
@@ -28,6 +30,52 @@
2830
logger = logging.getLogger(__name__)
2931

3032

33+
def to_entity_key(entity_key_int: int) -> EntityKey:
34+
hex_value = Web3.to_hex(entity_key_int)
35+
# ensure lenth is 66 (0x + 64 hex)
36+
if len(hex_value) < 66:
37+
hex_value = HexStr("0x" + hex_value[2:].zfill(64))
38+
return EntityKey(hex_value)
39+
40+
41+
def entity_key_to_bytes(entity_key: EntityKey) -> bytes:
42+
return bytes.fromhex(entity_key[2:]) # Strip '0x' prefix and convert to bytes
43+
44+
45+
def check_entity_key(entity_key: Any | None, label: str | None = None) -> None:
46+
"""Validates entity key."""
47+
prefix = ""
48+
if label:
49+
prefix = f"{label}: "
50+
51+
logger.info(f"{prefix}Checking entity key {entity_key}")
52+
53+
if entity_key is None:
54+
raise EntityKeyException("Entity key should not be None")
55+
if not isinstance(entity_key, str):
56+
raise EntityKeyException(
57+
f"Entity key type should be str but is: {type(entity_key)}"
58+
)
59+
if len(entity_key) != 66:
60+
raise EntityKeyException(
61+
f"Entity key should be 66 characters long (0x + 64 hex) but is: {len(entity_key)}"
62+
)
63+
if not is_hex_str(entity_key):
64+
raise EntityKeyException("Entity key should be a valid hex string")
65+
66+
67+
def is_hex_str(value: str) -> bool:
68+
if not isinstance(value, str):
69+
return False
70+
if value.startswith("0x"):
71+
value = value[2:]
72+
try:
73+
int(value, 16)
74+
return True
75+
except ValueError:
76+
return False
77+
78+
3179
def to_create_operation(
3280
payload: bytes | None = None,
3381
annotations: dict[str, AnnotationValue] | None = None,
@@ -116,33 +164,36 @@ def to_receipt(
116164
event_args: dict[str, Any] = event_data["args"]
117165
event_name = event_data["event"]
118166

167+
entity_key: EntityKey = to_entity_key(event_args["entityKey"])
168+
expiration_block: int = event_args["expirationBlock"]
169+
119170
match event_name:
120171
case contract.CREATED_EVENT:
121172
creates.append(
122173
CreateReceipt(
123-
entity_key=EntityKey(event_args["entityKey"]),
124-
expiration_block=int(event_args["expirationBlock"]),
174+
entity_key=entity_key,
175+
expiration_block=expiration_block,
125176
)
126177
)
127178
case contract.UPDATED_EVENT:
128179
updates.append(
129180
UpdateReceipt(
130-
entity_key=EntityKey(event_args["entityKey"]),
131-
expiration_block=int(event_args["expirationBlock"]),
181+
entity_key=entity_key,
182+
expiration_block=expiration_block,
132183
)
133184
)
134185
case contract.DELETED_EVENT:
135186
deletes.append(
136187
DeleteReceipt(
137-
entity_key=EntityKey(event_args["entityKey"]),
188+
entity_key=entity_key,
138189
)
139190
)
140191
case contract.EXTENDED_EVENT:
141192
extensions.append(
142193
ExtendReceipt(
143-
entity_key=EntityKey(event_args["entityKey"]),
144-
old_expiration_block=int(event_args["oldExpirationBlock"]),
145-
new_expiration_block=int(event_args["newExpirationBlock"]),
194+
entity_key=entity_key,
195+
old_expiration_block=event_args["oldExpirationBlock"],
196+
new_expiration_block=event_args["newExpirationBlock"],
146197
)
147198
)
148199
case _:
@@ -195,7 +246,7 @@ def format_annotation(annotation: Annotation) -> tuple[str, AnnotationValue]:
195246
# Update
196247
[
197248
[
198-
element.entity_key.to_bytes(),
249+
entity_key_to_bytes(element.entity_key),
199250
element.btl,
200251
element.data,
201252
list(map(format_annotation, element.string_annotations)),
@@ -206,14 +257,14 @@ def format_annotation(annotation: Annotation) -> tuple[str, AnnotationValue]:
206257
# Delete
207258
[
208259
[
209-
element.entity_key.to_bytes(),
260+
entity_key_to_bytes(element.entity_key),
210261
]
211262
for element in tx.deletes
212263
],
213264
# Extend
214265
[
215266
[
216-
element.entity_key.to_bytes(),
267+
entity_key_to_bytes(element.entity_key),
217268
element.number_of_blocks,
218269
]
219270
for element in tx.extensions

tests/test_entity_create.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,24 +7,14 @@
77

88
from arkiv.client import Arkiv
99
from arkiv.contract import STORAGE_ADDRESS
10-
from arkiv.types import EntityKey, Operations
11-
from arkiv.utils import to_create_operation, to_receipt, to_tx_params
10+
from arkiv.types import Operations
11+
from arkiv.utils import check_entity_key, to_create_operation, to_receipt, to_tx_params
1212

1313
logger = logging.getLogger(__name__)
1414

1515
TX_SUCCESS = 1
1616

1717

18-
def check_entity_key(label: str, entity_key: EntityKey) -> None:
19-
"""Check entity key validity."""
20-
logger.info(f"{label}: Checking entity key {entity_key}")
21-
assert entity_key is not None, f"{label}: Entity key should not be None"
22-
assert isinstance(entity_key, EntityKey), f"{label}: Entity key should be EntityKey"
23-
assert len(entity_key.to_bytes()) == 32, (
24-
f"{label}: Entity key should be 32 bytes long"
25-
)
26-
27-
2818
def check_tx_hash(label: str, tx_hash: HexBytes) -> None:
2919
"""Check transaction hash validity."""
3020
logger.info(f"{label}: Checking transaction hash {tx_hash.to_0x_hex()}")
@@ -161,7 +151,7 @@ def test_create_entity_simple(self, arkiv_client_http: Arkiv) -> None:
161151
)
162152

163153
label = "create_entity (a)"
164-
check_entity_key(label, entity_key)
154+
check_entity_key(entity_key, label)
165155
check_tx_hash(label, tx_hash)
166156

167157
entity = arkiv_client_http.arkiv.get_entity(entity_key)

tests/test_entity_create_parallel.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,9 +115,8 @@ def client_task(client: Arkiv, client_idx: int, num_entities: int) -> list[Entit
115115
@pytest.mark.parametrize(
116116
"num_clients,num_entities",
117117
[
118-
(2, 10),
119-
# (2, 3),
120-
# (4, 2),
118+
# (20, 50),
119+
(2, 3),
121120
],
122121
)
123122
def test_parallel_entity_creation(
@@ -130,6 +129,7 @@ def test_parallel_entity_creation(
130129
if not rpc_url:
131130
pytest.skip("No Arkiv node available for testing")
132131

132+
# Create Arkiv clients
133133
logger.info(f"Starting {num_clients} Arkiv clients...")
134134
clients = []
135135
for i in range(num_clients):
@@ -138,6 +138,8 @@ def test_parallel_entity_creation(
138138
client = create_client(container, rpc_url, client_idx)
139139
account: ChecksumAddress = cast(ChecksumAddress, client.eth.default_account)
140140
balance = client.eth.get_balance(account)
141+
142+
# Only use clients with non-zero balance
141143
if balance > 0:
142144
logger.info(f"Arkiv client[{client_idx}] started.")
143145
clients.append(client)
@@ -147,6 +149,7 @@ def test_parallel_entity_creation(
147149
# Remember start time
148150
start_time = time.time()
149151

152+
# Start all clients in separate threads (pseudo-parallelism)
150153
threads = []
151154
for client_idx in range(len(clients)):
152155
client = clients[client_idx]
@@ -164,6 +167,7 @@ def test_parallel_entity_creation(
164167
end_time = time.time()
165168
elapsed_time = end_time - start_time
166169

170+
logger.info(f"Total active clients: {len(clients)}")
167171
logger.info(f"Total successful entity creation TX: {tx_counter}")
168172
logger.info(f"TX creation start time: {start_time:.2f}, end time: {end_time:.2f}")
169173
logger.info(f"All Arkiv clients have completed in {elapsed_time:.2f} seconds.")

0 commit comments

Comments
 (0)