Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 28 additions & 30 deletions src/lean_spec/subspecs/xmss/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,14 @@
from .containers import PublicKey, SecretKey, Signature
from .prf import PROD_PRF, TEST_PRF, Prf
from .rand import PROD_RAND, TEST_RAND, Rand
from .subtree import HashSubTree
from .tweak_hash import (
PROD_TWEAK_HASHER,
TEST_TWEAK_HASHER,
TweakHasher,
)
from .types import HashDigestVector
from .utils import bottom_tree_from_prf_key, expand_activation_time
from .utils import expand_activation_time


class GeneralizedXmssScheme(StrictBaseModel):
Expand Down Expand Up @@ -162,23 +163,23 @@ def key_gen(
actual_num_active_epochs = num_bottom_trees * leafs_per_bottom_tree

# Step 2: Generate the first two bottom trees (kept in memory).
left_bottom_tree = bottom_tree_from_prf_key(
self.prf,
self.hasher,
self.rand,
config,
prf_key,
Uint64(start_bottom_tree_index),
parameter,
left_bottom_tree = HashSubTree.from_prf_key(
prf=self.prf,
hasher=self.hasher,
rand=self.rand,
config=config,
prf_key=prf_key,
bottom_tree_index=Uint64(start_bottom_tree_index),
parameter=parameter,
)
right_bottom_tree = bottom_tree_from_prf_key(
self.prf,
self.hasher,
self.rand,
config,
prf_key,
Uint64(start_bottom_tree_index + 1),
parameter,
right_bottom_tree = HashSubTree.from_prf_key(
prf=self.prf,
hasher=self.hasher,
rand=self.rand,
config=config,
prf_key=prf_key,
bottom_tree_index=Uint64(start_bottom_tree_index + 1),
parameter=parameter,
)

# Collect roots for building the top tree.
Expand All @@ -189,21 +190,18 @@ def key_gen(

# Step 3: Generate remaining bottom trees (only their roots).
for i in range(start_bottom_tree_index + 2, end_bottom_tree_index):
tree = bottom_tree_from_prf_key(
self.prf,
self.hasher,
self.rand,
config,
prf_key,
Uint64(i),
parameter,
tree = HashSubTree.from_prf_key(
prf=self.prf,
hasher=self.hasher,
rand=self.rand,
config=config,
prf_key=prf_key,
bottom_tree_index=Uint64(i),
parameter=parameter,
)
root = tree.root()
bottom_tree_roots.append(root)
bottom_tree_roots.append(tree.root())

# Step 4: Build the top tree from bottom tree roots.
from .subtree import HashSubTree

top_tree = HashSubTree.new_top_tree(
hasher=self.hasher,
rand=self.rand,
Expand Down Expand Up @@ -554,7 +552,7 @@ def advance_preparation(self, sk: SecretKey) -> SecretKey:

# Compute the next bottom tree (the one after the current right tree)
new_right_tree_index = sk.left_bottom_tree_index + Uint64(2)
new_right_bottom_tree = bottom_tree_from_prf_key(
new_right_bottom_tree = HashSubTree.from_prf_key(
prf=self.prf,
hasher=self.hasher,
rand=self.rand,
Expand Down
90 changes: 90 additions & 0 deletions src/lean_spec/subspecs/xmss/subtree.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,13 @@
HashTreeLayers,
HashTreeOpening,
Parameter,
PRFKey,
)
from .utils import get_padded_layer

if TYPE_CHECKING:
from .constants import XmssConfig
from .prf import Prf
from .rand import Rand
from .tweak_hash import TweakHasher

Expand Down Expand Up @@ -318,6 +321,93 @@ def new_bottom_tree(
layers=HashTreeLayers(data=truncated + [root_layer]),
)

@classmethod
def from_prf_key(
cls,
prf: "Prf",
hasher: "TweakHasher",
rand: "Rand",
config: "XmssConfig",
prf_key: PRFKey,
bottom_tree_index: Uint64,
parameter: Parameter,
) -> "HashSubTree":
"""
Generates a single bottom tree on-demand from the PRF key.

This is a key component of the top-bottom tree approach: instead of storing all
one-time secret keys, we regenerate them on-demand using the PRF. This enables
O(sqrt(LIFETIME)) memory usage.

### Algorithm

1. **Determine epoch range**: Bottom tree `i` covers epochs
`[i * sqrt(LIFETIME), (i+1) * sqrt(LIFETIME))`

2. **Generate leaves**: For each epoch in parallel:
- For each chain (0 to DIMENSION-1):
- Derive secret start: `PRF(prf_key, epoch, chain_index)`
- Compute public end: hash chain for `BASE - 1` steps
- Hash all chain ends to get the leaf

3. **Build bottom tree**: Construct the bottom tree from the leaves

Args:
prf: The PRF instance for key derivation.
hasher: The tweakable hash instance.
rand: Random generator for padding values.
config: The XMSS configuration.
prf_key: The master PRF secret key.
bottom_tree_index: The index of the bottom tree to generate (0, 1, 2, ...).
parameter: The public parameter `P` for the hash function.

Returns:
A `HashSubTree` representing the requested bottom tree.
"""
# Calculate the number of leaves per bottom tree: sqrt(LIFETIME).
leafs_per_bottom_tree = 1 << (config.LOG_LIFETIME // 2)

# Determine the epoch range for this bottom tree.
start_epoch = bottom_tree_index * Uint64(leafs_per_bottom_tree)
end_epoch = start_epoch + Uint64(leafs_per_bottom_tree)

# Generate leaf hashes for all epochs in this bottom tree.
leaf_hashes: List[HashDigestVector] = []

for epoch in range(int(start_epoch), int(end_epoch)):
# For each epoch, compute the one-time public key (chain endpoints).
chain_ends: List[HashDigestVector] = []

for chain_index in range(config.DIMENSION):
# Derive the secret start of the chain from the PRF key.
start_digest = prf.apply(prf_key, Uint64(epoch), Uint64(chain_index))

# Compute the public end by hashing BASE - 1 times.
end_digest = hasher.hash_chain(
parameter=parameter,
epoch=Uint64(epoch),
chain_index=chain_index,
start_step=0,
num_steps=config.BASE - 1,
start_digest=start_digest,
)
chain_ends.append(end_digest)

# Hash the chain ends to get the leaf for this epoch.
leaf_tweak = TreeTweak(level=0, index=epoch)
leaf_hash = hasher.apply(parameter, leaf_tweak, chain_ends)
leaf_hashes.append(leaf_hash)

# Build the bottom tree from the leaf hashes.
return cls.new_bottom_tree(
hasher=hasher,
rand=rand,
depth=config.LOG_LIFETIME,
bottom_tree_index=bottom_tree_index,
parameter=parameter,
leaves=leaf_hashes,
)

def root(self) -> HashDigestVector:
"""
Extracts the root digest from this subtree.
Expand Down
100 changes: 2 additions & 98 deletions src/lean_spec/subspecs/xmss/utils.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,11 @@
"""Utility functions for the XMSS signature scheme."""

from typing import TYPE_CHECKING, List
from typing import List

from ...types.uint import Uint64
from ..koalabear import Fp, P
from .constants import XmssConfig
from .rand import Rand
from .types import HashDigestList, HashDigestVector, HashTreeLayer, Parameter, PRFKey

if TYPE_CHECKING:
from .prf import Prf
from .subtree import HashSubTree
from .tweak_hash import TweakHasher
from .types import HashDigestList, HashDigestVector, HashTreeLayer


def get_padded_layer(
Expand Down Expand Up @@ -159,93 +153,3 @@ def expand_activation_time(
end_bottom_tree_index = end // c

return (start_bottom_tree_index, end_bottom_tree_index)


def bottom_tree_from_prf_key(
prf: "Prf",
hasher: "TweakHasher",
rand: Rand,
config: XmssConfig,
prf_key: PRFKey,
bottom_tree_index: Uint64,
parameter: Parameter,
) -> "HashSubTree":
"""
Generates a single bottom tree on-demand from the PRF key.

This is a key component of the top-bottom tree approach: instead of storing all
one-time secret keys, we regenerate them on-demand using the PRF. This enables
O(sqrt(LIFETIME)) memory usage.

### Algorithm

1. **Determine epoch range**: Bottom tree `i` covers epochs
`[i * sqrt(LIFETIME), (i+1) * sqrt(LIFETIME))`

2. **Generate leaves**: For each epoch in parallel:
- For each chain (0 to DIMENSION-1):
- Derive secret start: `PRF(prf_key, epoch, chain_index)`
- Compute public end: hash chain for `BASE - 1` steps
- Hash all chain ends to get the leaf

3. **Build bottom tree**: Construct the bottom tree from the leaves

Args:
prf: The PRF instance for key derivation.
hasher: The tweakable hash instance.
rand: Random generator for padding values.
config: The XMSS configuration.
prf_key: The master PRF secret key.
bottom_tree_index: The index of the bottom tree to generate (0, 1, 2, ...).
parameter: The public parameter `P` for the hash function.

Returns:
A `HashSubTree` representing the requested bottom tree.
"""
from .tweak_hash import TreeTweak

# Calculate the number of leaves per bottom tree: sqrt(LIFETIME).
leafs_per_bottom_tree = 1 << (config.LOG_LIFETIME // 2)

# Determine the epoch range for this bottom tree.
start_epoch = bottom_tree_index * Uint64(leafs_per_bottom_tree)
end_epoch = start_epoch + Uint64(leafs_per_bottom_tree)

# Generate leaf hashes for all epochs in this bottom tree.
leaf_hashes: List[HashDigestVector] = []

for epoch in range(int(start_epoch), int(end_epoch)):
# For each epoch, compute the one-time public key (chain endpoints).
chain_ends: List[HashDigestVector] = []

for chain_index in range(config.DIMENSION):
# Derive the secret start of the chain from the PRF key.
start_digest = prf.apply(prf_key, Uint64(epoch), Uint64(chain_index))

# Compute the public end by hashing BASE - 1 times.
end_digest = hasher.hash_chain(
parameter=parameter,
epoch=Uint64(epoch),
chain_index=chain_index,
start_step=0,
num_steps=config.BASE - 1,
start_digest=start_digest,
)
chain_ends.append(end_digest)

# Hash the chain ends to get the leaf for this epoch.
leaf_tweak = TreeTweak(level=0, index=epoch)
leaf_hash = hasher.apply(parameter, leaf_tweak, chain_ends)
leaf_hashes.append(leaf_hash)

# Build the bottom tree from the leaf hashes.
from .subtree import HashSubTree

return HashSubTree.new_bottom_tree(
hasher=hasher,
rand=rand,
depth=config.LOG_LIFETIME,
bottom_tree_index=bottom_tree_index,
parameter=parameter,
leaves=leaf_hashes,
)
Loading
Loading