Skip to content

Commit 60b7398

Browse files
authored
xmss: touchups to make things clearer and more modern (#239)
* xmss: touchups to make things clearer and more modern * some touchups * touchups
1 parent 836b17f commit 60b7398

File tree

12 files changed

+182
-226
lines changed

12 files changed

+182
-226
lines changed

src/lean_spec/subspecs/xmss/containers.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,12 +72,32 @@ def verify(
7272
public_key: PublicKey,
7373
epoch: "Uint64",
7474
message: bytes,
75-
scheme: GeneralizedXmssScheme,
75+
scheme: "GeneralizedXmssScheme",
7676
) -> bool:
77-
"""Verify the signature using XMSS verification algorithm."""
77+
"""
78+
Verify the signature using XMSS verification algorithm.
79+
80+
This is a convenience method that delegates to `scheme.verify()`.
81+
82+
Invalid or malformed signatures return `False`.
83+
84+
Expected exceptions:
85+
- `ValueError` for invalid epochs,
86+
- `IndexError` for malformed signatures
87+
are caught and converted to `False`.
88+
89+
Args:
90+
public_key: The public key to verify against.
91+
epoch: The epoch the signature corresponds to.
92+
message: The message that was supposedly signed.
93+
scheme: The XMSS scheme instance to use for verification.
94+
95+
Returns:
96+
`True` if the signature is valid, `False` otherwise.
97+
"""
7898
try:
7999
return scheme.verify(public_key, epoch, message, self)
80-
except Exception:
100+
except (ValueError, IndexError):
81101
return False
82102

83103

src/lean_spec/subspecs/xmss/hypercube.py

Lines changed: 27 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -33,30 +33,29 @@
3333

3434
import bisect
3535
import math
36+
from dataclasses import dataclass
3637
from functools import lru_cache
3738
from itertools import accumulate
38-
from typing import List, Tuple
39-
40-
from lean_spec.types import StrictBaseModel
4139

4240
MAX_DIMENSION = 100
4341
"""The maximum dimension `v` for which layer sizes will be precomputed."""
4442

4543

46-
class LayerInfo(StrictBaseModel):
44+
@dataclass(frozen=True, slots=True)
45+
class LayerInfo:
4746
"""
48-
A data structure to store precomputed sizes and cumulative sums for the
49-
layers of a single hypercube configuration (fixed `w` and `v`).
47+
Precomputed sizes and cumulative sums for a hypercube configuration.
5048
51-
This object makes subsequent calculations, like finding the total size of a
52-
range of layers, highly efficient.
49+
This immutable data structure enables O(1) lookups for layer sizes and
50+
range sums, which is critical for efficient hypercube mapping.
5351
"""
5452

55-
sizes: List[int]
56-
"""A list where `sizes[d]` is the number of vertices in layer `d`."""
57-
prefix_sums: List[int]
53+
sizes: tuple[int, ...]
54+
"""Tuple where `sizes[d]` is the number of vertices in layer `d`."""
55+
56+
prefix_sums: tuple[int, ...]
5857
"""
59-
A list where `prefix_sums[d]` is the cumulative number of vertices from
58+
Tuple where `prefix_sums[d]` is the cumulative number of vertices from
6059
layer 0 up to and including layer `d`.
6160
6261
Mathematically: `prefix_sums[d] = sizes[0] + ... + sizes[d]`.
@@ -125,7 +124,7 @@ def _calculate_layer_size(w: int, v: int, d: int) -> int:
125124

126125

127126
@lru_cache(maxsize=None)
128-
def prepare_layer_info(w: int) -> List[LayerInfo]:
127+
def prepare_layer_info(w: int) -> tuple[LayerInfo, ...]:
129128
"""
130129
Precomputes and caches layer information using a direct combinatorial formula.
131130
@@ -138,24 +137,25 @@ def prepare_layer_info(w: int) -> List[LayerInfo]:
138137
w: The base of the hypercube.
139138
140139
Returns:
141-
A list where `list[v]` is a `LayerInfo` object for a `v`-dim hypercube.
140+
A tuple where `tuple[v]` is a `LayerInfo` object for a `v`-dim hypercube.
142141
"""
143-
all_info = [LayerInfo(sizes=[], prefix_sums=[])] * (MAX_DIMENSION + 1)
142+
# Initialize with empty placeholder for index 0
143+
all_info: list[LayerInfo] = [LayerInfo(sizes=(), prefix_sums=())] * (MAX_DIMENSION + 1)
144144

145145
for v in range(1, MAX_DIMENSION + 1):
146146
# The maximum possible distance `d` in a v-dimensional hypercube.
147147
max_d = (w - 1) * v
148148

149149
# Directly compute the size of each layer using the helper function.
150-
sizes = [_calculate_layer_size(w, v, d) for d in range(max_d + 1)]
150+
sizes = tuple(_calculate_layer_size(w, v, d) for d in range(max_d + 1))
151151

152-
# Compute the cumulative sums from the list of sizes.
153-
prefix_sums = list(accumulate(sizes))
152+
# Compute the cumulative sums from the tuple of sizes.
153+
prefix_sums = tuple(accumulate(sizes))
154154

155155
# Store the complete layer info for the current dimension `v`.
156156
all_info[v] = LayerInfo(sizes=sizes, prefix_sums=prefix_sums)
157157

158-
return all_info
158+
return tuple(all_info)
159159

160160

161161
def get_layer_size(w: int, v: int, d: int) -> int:
@@ -168,7 +168,7 @@ def hypercube_part_size(w: int, v: int, d: int) -> int:
168168
return prepare_layer_info(w)[v].prefix_sums[d]
169169

170170

171-
def hypercube_find_layer(w: int, v: int, x: int) -> Tuple[int, int]:
171+
def hypercube_find_layer(w: int, v: int, x: int) -> tuple[int, int]:
172172
"""
173173
Given a global index `x`, finds its layer `d` and local offset `remainder`.
174174
@@ -203,7 +203,7 @@ def hypercube_find_layer(w: int, v: int, x: int) -> Tuple[int, int]:
203203
return d, remainder
204204

205205

206-
def map_to_vertex(w: int, v: int, d: int, x: int) -> List[int]:
206+
def map_to_vertex(w: int, v: int, d: int, x: int) -> list[int]:
207207
"""
208208
Maps an integer index `x` to a unique vertex in a specific hypercube layer.
209209
@@ -228,7 +228,7 @@ def map_to_vertex(w: int, v: int, d: int, x: int) -> List[int]:
228228
if x >= layer_size:
229229
raise ValueError("Index x is out of bounds for the given layer.")
230230

231-
vertex: List[int] = []
231+
vertex: list[int] = []
232232
# Track remaining distance and index.
233233
d_curr, x_curr = d, x
234234

@@ -239,20 +239,17 @@ def map_to_vertex(w: int, v: int, d: int, x: int) -> List[int]:
239239

240240
# This loop finds which block of sub-hypercubes the index `x_curr` falls into.
241241
#
242-
# It skips over full blocks by subtracting their size
243-
# from `x_curr` until the correct one is found.
244-
ji = -1 # Sentinel value
242+
# It skips over full blocks by subtracting their size from `x_curr` until found.
243+
ji = None
245244
range_start = max(0, d_curr - (w - 1) * dim_remaining)
246245
for j in range(range_start, min(w, d_curr + 1)):
247246
count = prev_dim_layer_info.sizes[d_curr - j]
248-
if x_curr >= count:
249-
x_curr -= count
250-
else:
251-
# Found the correct block.
247+
if x_curr < count:
252248
ji = j
253249
break
250+
x_curr -= count
254251

255-
if ji == -1:
252+
if ji is None:
256253
raise RuntimeError("Internal logic error: failed to find coordinate")
257254

258255
# Convert the block's distance contribution `ji` to a coordinate `ai`.

0 commit comments

Comments
 (0)