Skip to content

Commit b162611

Browse files
authored
xmss: mv KeyPair as a container (#242)
1 parent 5eb4708 commit b162611

File tree

3 files changed

+36
-44
lines changed

3 files changed

+36
-44
lines changed

packages/testing/src/consensus_testing/keys.py

Lines changed: 3 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,9 @@
3333
import tempfile
3434
import urllib.request
3535
from concurrent.futures import ProcessPoolExecutor
36-
from dataclasses import dataclass
3736
from functools import cache, partial
3837
from pathlib import Path
39-
from typing import TYPE_CHECKING, Iterator, Self
38+
from typing import TYPE_CHECKING, Iterator
4039

4140
from lean_spec.config import LEAN_ENV
4241
from lean_spec.subspecs.containers import AttestationData
@@ -46,7 +45,7 @@
4645
AttestationSignatures,
4746
)
4847
from lean_spec.subspecs.containers.slot import Slot
49-
from lean_spec.subspecs.xmss.containers import PublicKey, SecretKey, Signature
48+
from lean_spec.subspecs.xmss.containers import KeyPair, PublicKey, Signature
5049
from lean_spec.subspecs.xmss.interface import (
5150
PROD_SIGNATURE_SCHEME,
5251
TEST_SIGNATURE_SCHEME,
@@ -120,39 +119,6 @@ def get_shared_key_manager(max_slot: Slot = _DEFAULT_MAX_SLOT) -> XmssKeyManager
120119
"""Key lifetime in epochs (derived from DEFAULT_MAX_SLOT)."""
121120

122121

123-
@dataclass(frozen=True, slots=True)
124-
class KeyPair:
125-
"""
126-
Immutable XMSS key pair for a validator.
127-
128-
Attributes:
129-
public: Public key for signature verification.
130-
secret: Secret key containing Merkle tree structures.
131-
"""
132-
133-
public: PublicKey
134-
secret: SecretKey
135-
136-
@classmethod
137-
def from_dict(cls, data: Mapping[str, str]) -> Self:
138-
"""Deserialize from JSON-compatible dict with hex-encoded SSZ."""
139-
return cls(
140-
public=PublicKey.decode_bytes(bytes.fromhex(data["public"])),
141-
secret=SecretKey.decode_bytes(bytes.fromhex(data["secret"])),
142-
)
143-
144-
def to_dict(self) -> dict[str, str]:
145-
"""Serialize to JSON-compatible dict with hex-encoded SSZ."""
146-
return {
147-
"public": self.public.encode_bytes().hex(),
148-
"secret": self.secret.encode_bytes().hex(),
149-
}
150-
151-
def with_secret(self, secret: SecretKey) -> KeyPair:
152-
"""Return a new KeyPair with updated secret key (for state advancement)."""
153-
return KeyPair(public=self.public, secret=secret)
154-
155-
156122
def _get_keys_dir(scheme_name: str) -> Path:
157123
"""Get the keys directory path for the given scheme."""
158124
return Path(__file__).parent / "test_keys" / f"{scheme_name}_scheme"
@@ -298,7 +264,7 @@ def sign_attestation_data(
298264
prepared = self.scheme.get_prepared_interval(sk)
299265

300266
# Cache advanced state
301-
self._state[validator_id] = kp.with_secret(sk)
267+
self._state[validator_id] = kp._replace(secret=sk)
302268

303269
# Sign hash tree root of the attestation data
304270
message = attestation_data.data_root_bytes()

src/lean_spec/subspecs/xmss/containers.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from __future__ import annotations
99

10-
from typing import TYPE_CHECKING
10+
from typing import TYPE_CHECKING, Mapping, NamedTuple
1111

1212
from ...types import Uint64
1313
from ...types.container import Container
@@ -181,3 +181,31 @@ class SecretKey(Container):
181181
Together with `left_bottom_tree`, this provides a prepared interval of
182182
exactly `2 * sqrt(LIFETIME)` consecutive epochs.
183183
"""
184+
185+
186+
class KeyPair(NamedTuple):
187+
"""
188+
Immutable XMSS key pair for a validator.
189+
190+
Attributes:
191+
public: Public key for signature verification.
192+
secret: Secret key containing Merkle tree structures.
193+
"""
194+
195+
public: PublicKey
196+
secret: SecretKey
197+
198+
@classmethod
199+
def from_dict(cls, data: Mapping[str, str]) -> "KeyPair":
200+
"""Deserialize from JSON-compatible dict with hex-encoded SSZ."""
201+
return cls(
202+
public=PublicKey.decode_bytes(bytes.fromhex(data["public"])),
203+
secret=SecretKey.decode_bytes(bytes.fromhex(data["secret"])),
204+
)
205+
206+
def to_dict(self) -> dict[str, str]:
207+
"""Serialize to JSON-compatible dict with hex-encoded SSZ."""
208+
return {
209+
"public": self.public.encode_bytes().hex(),
210+
"secret": self.secret.encode_bytes().hex(),
211+
}

src/lean_spec/subspecs/xmss/interface.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
TEST_CONFIG,
2525
XmssConfig,
2626
)
27-
from .containers import PublicKey, SecretKey, Signature
27+
from .containers import KeyPair, PublicKey, SecretKey, Signature
2828
from .prf import PROD_PRF, TEST_PRF, Prf
2929
from .rand import PROD_RAND, TEST_RAND, Rand
3030
from .subtree import HashSubTree, combined_path, verify_path
@@ -73,9 +73,7 @@ def _validate_strict_types(self) -> "GeneralizedXmssScheme":
7373
)
7474
return self
7575

76-
def key_gen(
77-
self, activation_epoch: Uint64, num_active_epochs: Uint64
78-
) -> tuple[PublicKey, SecretKey]:
76+
def key_gen(self, activation_epoch: Uint64, num_active_epochs: Uint64) -> KeyPair:
7977
"""
8078
Generates a new cryptographic key pair for a specified range of epochs.
8179
@@ -120,7 +118,7 @@ def key_gen(
120118
- Will be rounded up to at least `2 * sqrt(LIFETIME)`.
121119
122120
Returns:
123-
A tuple containing the `PublicKey` and `SecretKey`.
121+
A `KeyPair` containing the public and secret keys.
124122
125123
Note:
126124
The actual activation epoch and num_active_epochs in the returned SecretKey
@@ -220,7 +218,7 @@ def key_gen(
220218
left_bottom_tree=left_bottom_tree,
221219
right_bottom_tree=right_bottom_tree,
222220
)
223-
return pk, sk
221+
return KeyPair(public=pk, secret=sk)
224222

225223
def sign(self, sk: SecretKey, epoch: Uint64, message: bytes) -> Signature:
226224
"""

0 commit comments

Comments
 (0)