diff --git a/src/lean_spec/subspecs/xmss/interface.py b/src/lean_spec/subspecs/xmss/interface.py index d13a058d..244d6bf7 100644 --- a/src/lean_spec/subspecs/xmss/interface.py +++ b/src/lean_spec/subspecs/xmss/interface.py @@ -27,13 +27,14 @@ from .containers import PublicKey, SecretKey, Signature from .prf import PROD_PRF, TEST_PRF, Prf from .rand import PROD_RAND, TEST_RAND, Rand +from .subtree import HashSubTree from .tweak_hash import ( PROD_TWEAK_HASHER, TEST_TWEAK_HASHER, TweakHasher, ) from .types import HashDigestVector -from .utils import bottom_tree_from_prf_key, expand_activation_time +from .utils import expand_activation_time class GeneralizedXmssScheme(StrictBaseModel): @@ -162,23 +163,23 @@ def key_gen( actual_num_active_epochs = num_bottom_trees * leafs_per_bottom_tree # Step 2: Generate the first two bottom trees (kept in memory). - left_bottom_tree = bottom_tree_from_prf_key( - self.prf, - self.hasher, - self.rand, - config, - prf_key, - Uint64(start_bottom_tree_index), - parameter, + left_bottom_tree = HashSubTree.from_prf_key( + prf=self.prf, + hasher=self.hasher, + rand=self.rand, + config=config, + prf_key=prf_key, + bottom_tree_index=Uint64(start_bottom_tree_index), + parameter=parameter, ) - right_bottom_tree = bottom_tree_from_prf_key( - self.prf, - self.hasher, - self.rand, - config, - prf_key, - Uint64(start_bottom_tree_index + 1), - parameter, + right_bottom_tree = HashSubTree.from_prf_key( + prf=self.prf, + hasher=self.hasher, + rand=self.rand, + config=config, + prf_key=prf_key, + bottom_tree_index=Uint64(start_bottom_tree_index + 1), + parameter=parameter, ) # Collect roots for building the top tree. @@ -189,21 +190,18 @@ def key_gen( # Step 3: Generate remaining bottom trees (only their roots). for i in range(start_bottom_tree_index + 2, end_bottom_tree_index): - tree = bottom_tree_from_prf_key( - self.prf, - self.hasher, - self.rand, - config, - prf_key, - Uint64(i), - parameter, + tree = HashSubTree.from_prf_key( + prf=self.prf, + hasher=self.hasher, + rand=self.rand, + config=config, + prf_key=prf_key, + bottom_tree_index=Uint64(i), + parameter=parameter, ) - root = tree.root() - bottom_tree_roots.append(root) + bottom_tree_roots.append(tree.root()) # Step 4: Build the top tree from bottom tree roots. - from .subtree import HashSubTree - top_tree = HashSubTree.new_top_tree( hasher=self.hasher, rand=self.rand, @@ -554,7 +552,7 @@ def advance_preparation(self, sk: SecretKey) -> SecretKey: # Compute the next bottom tree (the one after the current right tree) new_right_tree_index = sk.left_bottom_tree_index + Uint64(2) - new_right_bottom_tree = bottom_tree_from_prf_key( + new_right_bottom_tree = HashSubTree.from_prf_key( prf=self.prf, hasher=self.hasher, rand=self.rand, diff --git a/src/lean_spec/subspecs/xmss/subtree.py b/src/lean_spec/subspecs/xmss/subtree.py index 451c5cc0..b0de2712 100644 --- a/src/lean_spec/subspecs/xmss/subtree.py +++ b/src/lean_spec/subspecs/xmss/subtree.py @@ -20,10 +20,13 @@ HashTreeLayers, HashTreeOpening, Parameter, + PRFKey, ) from .utils import get_padded_layer if TYPE_CHECKING: + from .constants import XmssConfig + from .prf import Prf from .rand import Rand from .tweak_hash import TweakHasher @@ -318,6 +321,93 @@ def new_bottom_tree( layers=HashTreeLayers(data=truncated + [root_layer]), ) + @classmethod + def from_prf_key( + cls, + prf: "Prf", + hasher: "TweakHasher", + rand: "Rand", + config: "XmssConfig", + prf_key: PRFKey, + bottom_tree_index: Uint64, + parameter: Parameter, + ) -> "HashSubTree": + """ + Generates a single bottom tree on-demand from the PRF key. + + This is a key component of the top-bottom tree approach: instead of storing all + one-time secret keys, we regenerate them on-demand using the PRF. This enables + O(sqrt(LIFETIME)) memory usage. + + ### Algorithm + + 1. **Determine epoch range**: Bottom tree `i` covers epochs + `[i * sqrt(LIFETIME), (i+1) * sqrt(LIFETIME))` + + 2. **Generate leaves**: For each epoch in parallel: + - For each chain (0 to DIMENSION-1): + - Derive secret start: `PRF(prf_key, epoch, chain_index)` + - Compute public end: hash chain for `BASE - 1` steps + - Hash all chain ends to get the leaf + + 3. **Build bottom tree**: Construct the bottom tree from the leaves + + Args: + prf: The PRF instance for key derivation. + hasher: The tweakable hash instance. + rand: Random generator for padding values. + config: The XMSS configuration. + prf_key: The master PRF secret key. + bottom_tree_index: The index of the bottom tree to generate (0, 1, 2, ...). + parameter: The public parameter `P` for the hash function. + + Returns: + A `HashSubTree` representing the requested bottom tree. + """ + # Calculate the number of leaves per bottom tree: sqrt(LIFETIME). + leafs_per_bottom_tree = 1 << (config.LOG_LIFETIME // 2) + + # Determine the epoch range for this bottom tree. + start_epoch = bottom_tree_index * Uint64(leafs_per_bottom_tree) + end_epoch = start_epoch + Uint64(leafs_per_bottom_tree) + + # Generate leaf hashes for all epochs in this bottom tree. + leaf_hashes: List[HashDigestVector] = [] + + for epoch in range(int(start_epoch), int(end_epoch)): + # For each epoch, compute the one-time public key (chain endpoints). + chain_ends: List[HashDigestVector] = [] + + for chain_index in range(config.DIMENSION): + # Derive the secret start of the chain from the PRF key. + start_digest = prf.apply(prf_key, Uint64(epoch), Uint64(chain_index)) + + # Compute the public end by hashing BASE - 1 times. + end_digest = hasher.hash_chain( + parameter=parameter, + epoch=Uint64(epoch), + chain_index=chain_index, + start_step=0, + num_steps=config.BASE - 1, + start_digest=start_digest, + ) + chain_ends.append(end_digest) + + # Hash the chain ends to get the leaf for this epoch. + leaf_tweak = TreeTweak(level=0, index=epoch) + leaf_hash = hasher.apply(parameter, leaf_tweak, chain_ends) + leaf_hashes.append(leaf_hash) + + # Build the bottom tree from the leaf hashes. + return cls.new_bottom_tree( + hasher=hasher, + rand=rand, + depth=config.LOG_LIFETIME, + bottom_tree_index=bottom_tree_index, + parameter=parameter, + leaves=leaf_hashes, + ) + def root(self) -> HashDigestVector: """ Extracts the root digest from this subtree. diff --git a/src/lean_spec/subspecs/xmss/utils.py b/src/lean_spec/subspecs/xmss/utils.py index b0f2dfde..0a1d205a 100644 --- a/src/lean_spec/subspecs/xmss/utils.py +++ b/src/lean_spec/subspecs/xmss/utils.py @@ -1,17 +1,11 @@ """Utility functions for the XMSS signature scheme.""" -from typing import TYPE_CHECKING, List +from typing import List from ...types.uint import Uint64 from ..koalabear import Fp, P -from .constants import XmssConfig from .rand import Rand -from .types import HashDigestList, HashDigestVector, HashTreeLayer, Parameter, PRFKey - -if TYPE_CHECKING: - from .prf import Prf - from .subtree import HashSubTree - from .tweak_hash import TweakHasher +from .types import HashDigestList, HashDigestVector, HashTreeLayer def get_padded_layer( @@ -159,93 +153,3 @@ def expand_activation_time( end_bottom_tree_index = end // c return (start_bottom_tree_index, end_bottom_tree_index) - - -def bottom_tree_from_prf_key( - prf: "Prf", - hasher: "TweakHasher", - rand: Rand, - config: XmssConfig, - prf_key: PRFKey, - bottom_tree_index: Uint64, - parameter: Parameter, -) -> "HashSubTree": - """ - Generates a single bottom tree on-demand from the PRF key. - - This is a key component of the top-bottom tree approach: instead of storing all - one-time secret keys, we regenerate them on-demand using the PRF. This enables - O(sqrt(LIFETIME)) memory usage. - - ### Algorithm - - 1. **Determine epoch range**: Bottom tree `i` covers epochs - `[i * sqrt(LIFETIME), (i+1) * sqrt(LIFETIME))` - - 2. **Generate leaves**: For each epoch in parallel: - - For each chain (0 to DIMENSION-1): - - Derive secret start: `PRF(prf_key, epoch, chain_index)` - - Compute public end: hash chain for `BASE - 1` steps - - Hash all chain ends to get the leaf - - 3. **Build bottom tree**: Construct the bottom tree from the leaves - - Args: - prf: The PRF instance for key derivation. - hasher: The tweakable hash instance. - rand: Random generator for padding values. - config: The XMSS configuration. - prf_key: The master PRF secret key. - bottom_tree_index: The index of the bottom tree to generate (0, 1, 2, ...). - parameter: The public parameter `P` for the hash function. - - Returns: - A `HashSubTree` representing the requested bottom tree. - """ - from .tweak_hash import TreeTweak - - # Calculate the number of leaves per bottom tree: sqrt(LIFETIME). - leafs_per_bottom_tree = 1 << (config.LOG_LIFETIME // 2) - - # Determine the epoch range for this bottom tree. - start_epoch = bottom_tree_index * Uint64(leafs_per_bottom_tree) - end_epoch = start_epoch + Uint64(leafs_per_bottom_tree) - - # Generate leaf hashes for all epochs in this bottom tree. - leaf_hashes: List[HashDigestVector] = [] - - for epoch in range(int(start_epoch), int(end_epoch)): - # For each epoch, compute the one-time public key (chain endpoints). - chain_ends: List[HashDigestVector] = [] - - for chain_index in range(config.DIMENSION): - # Derive the secret start of the chain from the PRF key. - start_digest = prf.apply(prf_key, Uint64(epoch), Uint64(chain_index)) - - # Compute the public end by hashing BASE - 1 times. - end_digest = hasher.hash_chain( - parameter=parameter, - epoch=Uint64(epoch), - chain_index=chain_index, - start_step=0, - num_steps=config.BASE - 1, - start_digest=start_digest, - ) - chain_ends.append(end_digest) - - # Hash the chain ends to get the leaf for this epoch. - leaf_tweak = TreeTweak(level=0, index=epoch) - leaf_hash = hasher.apply(parameter, leaf_tweak, chain_ends) - leaf_hashes.append(leaf_hash) - - # Build the bottom tree from the leaf hashes. - from .subtree import HashSubTree - - return HashSubTree.new_bottom_tree( - hasher=hasher, - rand=rand, - depth=config.LOG_LIFETIME, - bottom_tree_index=bottom_tree_index, - parameter=parameter, - leaves=leaf_hashes, - ) diff --git a/tests/lean_spec/subspecs/xmss/test_utils.py b/tests/lean_spec/subspecs/xmss/test_utils.py index 42cbc879..d09b41e7 100644 --- a/tests/lean_spec/subspecs/xmss/test_utils.py +++ b/tests/lean_spec/subspecs/xmss/test_utils.py @@ -9,10 +9,10 @@ from lean_spec.subspecs.xmss.constants import TEST_CONFIG from lean_spec.subspecs.xmss.prf import TEST_PRF from lean_spec.subspecs.xmss.rand import TEST_RAND +from lean_spec.subspecs.xmss.subtree import HashSubTree from lean_spec.subspecs.xmss.tweak_hash import TEST_TWEAK_HASHER from lean_spec.subspecs.xmss.types import HashTreeLayer, Parameter from lean_spec.subspecs.xmss.utils import ( - bottom_tree_from_prf_key, expand_activation_time, int_to_base_p, ) @@ -109,8 +109,8 @@ def test_expand_activation_time( assert actual_end_epoch <= lifetime -def test_bottom_tree_from_prf_key() -> None: - """Tests that bottom_tree_from_prf_key generates a valid bottom tree.""" +def test_hash_subtree_from_prf_key() -> None: + """Tests that HashSubTree.from_prf_key generates a valid bottom tree.""" config = TEST_CONFIG # Generate a PRF key @@ -122,7 +122,7 @@ def test_bottom_tree_from_prf_key() -> None: ) # Generate bottom tree 0 - bottom_tree = bottom_tree_from_prf_key( + bottom_tree = HashSubTree.from_prf_key( prf=TEST_PRF, hasher=TEST_TWEAK_HASHER, rand=TEST_RAND, @@ -147,8 +147,8 @@ def test_bottom_tree_from_prf_key() -> None: assert len(leaf_layer.nodes) == leafs_per_bottom_tree -def test_bottom_tree_from_prf_key_deterministic() -> None: - """Tests that bottom_tree_from_prf_key is deterministic.""" +def test_hash_subtree_from_prf_key_deterministic() -> None: + """Tests that HashSubTree.from_prf_key is deterministic.""" config = TEST_CONFIG prf_key = TEST_PRF.key_gen() parameter = Parameter( @@ -156,7 +156,7 @@ def test_bottom_tree_from_prf_key_deterministic() -> None: ) # Generate the same bottom tree twice - tree1 = bottom_tree_from_prf_key( + tree1 = HashSubTree.from_prf_key( prf=TEST_PRF, hasher=TEST_TWEAK_HASHER, rand=TEST_RAND, @@ -166,7 +166,7 @@ def test_bottom_tree_from_prf_key_deterministic() -> None: parameter=parameter, ) - tree2 = bottom_tree_from_prf_key( + tree2 = HashSubTree.from_prf_key( prf=TEST_PRF, hasher=TEST_TWEAK_HASHER, rand=TEST_RAND, @@ -183,7 +183,7 @@ def test_bottom_tree_from_prf_key_deterministic() -> None: ) -def test_bottom_tree_from_prf_key_different_indices() -> None: +def test_hash_subtree_from_prf_key_different_indices() -> None: """Tests that different bottom tree indices produce different trees.""" config = TEST_CONFIG prf_key = TEST_PRF.key_gen() @@ -192,7 +192,7 @@ def test_bottom_tree_from_prf_key_different_indices() -> None: ) # Generate two different bottom trees - tree0 = bottom_tree_from_prf_key( + tree0 = HashSubTree.from_prf_key( prf=TEST_PRF, hasher=TEST_TWEAK_HASHER, rand=TEST_RAND, @@ -202,7 +202,7 @@ def test_bottom_tree_from_prf_key_different_indices() -> None: parameter=parameter, ) - tree1 = bottom_tree_from_prf_key( + tree1 = HashSubTree.from_prf_key( prf=TEST_PRF, hasher=TEST_TWEAK_HASHER, rand=TEST_RAND,