Skip to content

Commit f57fca8

Browse files
authored
types: small refactor (leanEthereum#65)
1 parent e571d0d commit f57fca8

File tree

5 files changed

+40
-85
lines changed

5 files changed

+40
-85
lines changed

src/lean_spec/types/bitfields.py

Lines changed: 16 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -54,16 +54,12 @@ def _validate_vector_data(cls, v: Any) -> Tuple[Boolean, ...]:
5454
if not isinstance(v, (list, tuple)):
5555
v = tuple(v)
5656

57-
# Convert each bit to Boolean
58-
typed_values = tuple(Boolean(item) if not isinstance(item, Boolean) else item for item in v)
59-
60-
if len(typed_values) != cls.LENGTH:
57+
if len(v) != cls.LENGTH:
6158
raise ValueError(
62-
f"{cls.__name__} requires exactly {cls.LENGTH} bits, "
63-
f"but {len(typed_values)} were provided."
59+
f"{cls.__name__} requires exactly {cls.LENGTH} bits, but {len(v)} were provided."
6460
)
6561

66-
return typed_values
62+
return tuple(Boolean(item) for item in v)
6763

6864
@classmethod
6965
def is_fixed_size(cls) -> bool:
@@ -104,9 +100,7 @@ def encode_bytes(self) -> bytes:
104100
byte_array = bytearray(byte_len)
105101
for i, bit in enumerate(self.data):
106102
if bit:
107-
byte_index = i // 8
108-
bit_index_in_byte = i % 8
109-
byte_array[byte_index] |= 1 << bit_index_in_byte
103+
byte_array[i // 8] |= 1 << (i % 8)
110104
return bytes(byte_array)
111105

112106
@classmethod
@@ -120,8 +114,8 @@ def decode_bytes(cls, data: bytes) -> Self:
120114
if len(data) != expected_len:
121115
raise ValueError(f"{cls.__name__} expected {expected_len} bytes, got {len(data)}")
122116

123-
bits_generator = (Boolean((data[i // 8] >> (i % 8)) & 1) for i in range(cls.LENGTH))
124-
return cls(data=tuple(bits_generator))
117+
bits = tuple(Boolean((data[i // 8] >> (i % 8)) & 1) for i in range(cls.LENGTH))
118+
return cls(data=bits)
125119

126120

127121
class BaseBitlist(SSZModel):
@@ -158,18 +152,10 @@ def _validate_list_data(cls, v: Any) -> Tuple[Boolean, ...]:
158152
f"{cls.__name__} cannot contain more than {cls.LIMIT} bits, got {len(elements)}"
159153
)
160154

161-
# Convert and validate each bit
162-
typed_values = []
163-
for i, element in enumerate(elements):
164-
if isinstance(element, Boolean):
165-
typed_values.append(element)
166-
else:
167-
try:
168-
typed_values.append(Boolean(element))
169-
except Exception as e:
170-
raise ValueError(f"Bit {i} cannot be converted to Boolean: {e}") from e
171-
172-
return tuple(typed_values)
155+
try:
156+
return tuple(Boolean(element) for element in elements)
157+
except Exception as e:
158+
raise ValueError(f"Cannot convert elements to Boolean: {e}") from e
173159

174160
def __getitem__(self, key: int | slice) -> Boolean | tuple[Boolean, ...]:
175161
"""Get a bit by index or slice."""
@@ -219,19 +205,15 @@ def encode_bytes(self) -> bytes:
219205
# Pack data bits.
220206
for i, bit in enumerate(self.data):
221207
if bit:
222-
byte_index = i // 8
223-
bit_index_in_byte = i % 8
224-
byte_array[byte_index] |= 1 << bit_index_in_byte
208+
byte_array[i // 8] |= 1 << (i % 8)
225209

226210
# Place delimiter bit (1) immediately after the last data bit.
227211
if num_bits % 8 == 0:
228212
# Delimiter starts a new byte.
229213
return bytes(byte_array) + b"\x01"
230214
else:
231215
# Delimiter lives in the last byte at position num_bits % 8.
232-
delimiter_byte_index = num_bits // 8
233-
delimiter_bit_index = num_bits % 8
234-
byte_array[delimiter_byte_index] |= 1 << delimiter_bit_index
216+
byte_array[num_bits // 8] |= 1 << (num_bits % 8)
235217
return bytes(byte_array)
236218

237219
@classmethod
@@ -250,11 +232,9 @@ def decode_bytes(cls, data: bytes) -> Self:
250232
for byte_idx in range(len(data) - 1, -1, -1):
251233
byte_val = data[byte_idx]
252234
if byte_val != 0:
253-
# Find the rightmost 1 bit in this byte.
254-
for bit_idx in range(7, -1, -1):
255-
if (byte_val >> bit_idx) & 1:
256-
delimiter_pos = byte_idx * 8 + bit_idx
257-
break
235+
# Find the highest set bit in this byte using bit_length
236+
bit_idx = byte_val.bit_length() - 1
237+
delimiter_pos = byte_idx * 8 + bit_idx
258238
break
259239

260240
if delimiter_pos is None:
@@ -267,14 +247,5 @@ def decode_bytes(cls, data: bytes) -> Self:
267247
f"{cls.__name__} decoded length {num_data_bits} exceeds limit {cls.LIMIT}"
268248
)
269249

270-
bits = []
271-
for i in range(num_data_bits):
272-
byte_index = i // 8
273-
bit_index_in_byte = i % 8
274-
if byte_index < len(data):
275-
bit_value = bool((data[byte_index] >> bit_index_in_byte) & 1)
276-
else:
277-
bit_value = False
278-
bits.append(bit_value)
279-
250+
bits = [bool((data[i // 8] >> (i % 8)) & 1) for i in range(num_data_bits)]
280251
return cls(data=bits)

src/lean_spec/types/boolean.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,10 @@ def __new__(cls, value: bool | int) -> Self:
3737
if not isinstance(value, int):
3838
raise TypeError(f"Expected bool or int, got {type(value).__name__}")
3939

40-
int_value = int(value)
41-
if int_value not in (0, 1):
42-
raise ValueError(f"Boolean value must be 0 or 1, not {int_value}")
40+
if value not in (0, 1):
41+
raise ValueError(f"Boolean value must be 0 or 1, not {value}")
4342

44-
return super().__new__(cls, int_value)
43+
return super().__new__(cls, value)
4544

4645
@classmethod
4746
def __get_pydantic_core_schema__(
@@ -142,7 +141,7 @@ def __and__(self, other: Any) -> Self:
142141
"""Handle the bitwise AND operator (`&`) strictly."""
143142
if not isinstance(other, type(self)):
144143
self._raise_type_error(other, "&")
145-
return type(self)(super().__and__(other))
144+
return type(self)(int(self) & int(other))
146145

147146
def __rand__(self, other: Any) -> Self:
148147
"""Handle the reverse bitwise AND operator (`&`) strictly."""
@@ -152,7 +151,7 @@ def __or__(self, other: Any) -> Self:
152151
"""Handle the bitwise OR operator (`|`) strictly."""
153152
if not isinstance(other, type(self)):
154153
self._raise_type_error(other, "|")
155-
return type(self)(super().__or__(other))
154+
return type(self)(int(self) | int(other))
156155

157156
def __ror__(self, other: Any) -> Self:
158157
"""Handle the reverse bitwise OR operator (`|`) strictly."""
@@ -162,7 +161,7 @@ def __xor__(self, other: Any) -> Self:
162161
"""Handle the bitwise XOR operator (`^`) strictly."""
163162
if not isinstance(other, type(self)):
164163
self._raise_type_error(other, "^")
165-
return type(self)(super().__xor__(other))
164+
return type(self)(int(self) ^ int(other))
166165

167166
def __rxor__(self, other: Any) -> Self:
168167
"""Handle the reverse bitwise XOR operator (`^`) strictly."""
@@ -176,9 +175,7 @@ def __eq__(self, other: object) -> bool:
176175
177176
It returns `False` for all other types.
178177
"""
179-
if isinstance(other, int):
180-
return int(self) == int(other)
181-
return False
178+
return isinstance(other, int) and int(self) == int(other)
182179

183180
def __ne__(self, other: object) -> bool:
184181
"""

src/lean_spec/types/byte_arrays.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,8 @@ def _coerce_to_bytes(value: Any) -> bytes:
3434
if isinstance(value, (bytes, bytearray)):
3535
return bytes(value)
3636
if isinstance(value, str):
37-
s = value[2:] if value.startswith("0x") else value
3837
# bytes.fromhex handles empty string and validates hex characters
39-
return bytes.fromhex(s)
38+
return bytes.fromhex(value.removeprefix("0x"))
4039
if isinstance(value, Iterable):
4140
# bytes(bytearray(iterable)) enforces each element is an int in 0..255
4241
return bytes(bytearray(value))
@@ -185,9 +184,7 @@ def __hash__(self) -> int:
185184

186185
def hex(self, sep: str | bytes | None = None, bytes_per_sep: SupportsIndex = 1) -> str:
187186
"""Return the hexadecimal string representation of the underlying bytes."""
188-
if sep is None:
189-
return bytes(self).hex()
190-
return bytes(self).hex(sep, bytes_per_sep)
187+
return bytes(self).hex() if sep is None else bytes(self).hex(sep, bytes_per_sep)
191188

192189

193190
class Bytes1(BaseBytes):
@@ -288,11 +285,11 @@ def deserialize(cls, stream: IO[bytes], scope: int) -> Self:
288285
"""
289286
if scope < 0:
290287
raise ValueError("Invalid scope for ByteList: negative")
288+
if scope > cls.LIMIT:
289+
raise ValueError(f"ByteList[{cls.LIMIT}] scope {scope} exceeds limit")
291290
data = stream.read(scope)
292291
if len(data) != scope:
293292
raise IOError("Stream ended prematurely while decoding ByteList")
294-
if len(data) > cls.LIMIT:
295-
raise ValueError(f"ByteList[{cls.LIMIT}] decoded length {len(data)} exceeds limit")
296293
return cls(data=data)
297294

298295
def encode_bytes(self) -> bytes:
@@ -316,8 +313,6 @@ def __bytes__(self) -> bytes:
316313

317314
def __add__(self, other: Any) -> bytes:
318315
"""Return the concatenation of the byte list and the argument."""
319-
if isinstance(other, (bytes, bytearray)):
320-
return self.data + bytes(other)
321316
return self.data + bytes(other)
322317

323318
def __radd__(self, other: Any) -> bytes:

src/lean_spec/types/collections.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,7 @@ def serialize(self, stream: IO[bytes]) -> int:
8484
"""Serialize the vector to a binary stream."""
8585
# If elements are fixed-size, serialize them back-to-back.
8686
if self.is_fixed_size():
87-
total_bytes_written = 0
88-
for element in self.data:
89-
total_bytes_written += element.serialize(stream)
90-
return total_bytes_written
87+
return sum(element.serialize(stream) for element in self.data)
9188
# If elements are variable-size, serialize their offsets, then their data.
9289
else:
9390
# Use a temporary in-memory stream to hold the serialized variable data.
@@ -230,10 +227,7 @@ def serialize(self, stream: IO[bytes]) -> int:
230227
# Lists are always variable-size, so we serialize offsets + data
231228
if self.ELEMENT_TYPE.is_fixed_size():
232229
# Fixed-size elements: serialize them back-to-back
233-
total_bytes_written = 0
234-
for element in self.data:
235-
total_bytes_written += element.serialize(stream)
236-
return total_bytes_written
230+
return sum(element.serialize(stream) for element in self.data)
237231
else:
238232
# Variable-size elements: serialize offsets, then data
239233
variable_data_stream = io.BytesIO()
@@ -268,10 +262,10 @@ def deserialize(cls, stream: IO[bytes], scope: int) -> Self:
268262
return cls(data=elements)
269263
else:
270264
# Variable-size elements: read offsets first, then data
271-
if scope < OFFSET_BYTE_LENGTH:
265+
if scope == 0:
272266
# Empty list case
273-
if scope == 0:
274-
return cls(data=[])
267+
return cls(data=[])
268+
if scope < OFFSET_BYTE_LENGTH:
275269
raise ValueError(f"Invalid scope for variable-size list: {scope}")
276270

277271
# Read the first offset to determine the number of elements.

src/lean_spec/types/container.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,12 @@ def _get_ssz_field_type(annotation: Any) -> Type[SSZType]:
3434
TypeError: If the annotation is not a valid SSZType class.
3535
"""
3636
# Check if it's a class and is a subclass of SSZType
37-
if inspect.isclass(annotation) and issubclass(annotation, SSZType):
38-
return annotation
39-
40-
# If we get here, the annotation is not a valid SSZType
41-
raise TypeError(
42-
f"Field annotation {annotation} is not a valid SSZType class. "
43-
f"Container fields must be concrete SSZType subclasses."
44-
)
37+
if not (inspect.isclass(annotation) and issubclass(annotation, SSZType)):
38+
raise TypeError(
39+
f"Field annotation {annotation} is not a valid SSZType class. "
40+
f"Container fields must be concrete SSZType subclasses."
41+
)
42+
return annotation
4543

4644

4745
class Container(SSZModel):
@@ -134,7 +132,7 @@ def serialize(self, stream: IO[bytes]) -> int:
134132
variable_data = []
135133

136134
# Process each field in definition order
137-
for field_name, _field_info in type(self).model_fields.items():
135+
for field_name in type(self).model_fields:
138136
# Get the field value and its type
139137
value = getattr(self, field_name)
140138
# Use the actual runtime type of the value, which should be an SSZType
@@ -160,7 +158,7 @@ def serialize(self, stream: IO[bytes]) -> int:
160158
if part: # Fixed-size field data
161159
stream.write(part)
162160
else: # Variable-size field offset
163-
Uint32(offset).serialize(stream)
161+
stream.write(Uint32(offset).encode_bytes())
164162
offset += len(variable_data[var_index])
165163
var_index += 1
166164

0 commit comments

Comments
 (0)