Skip to content

Commit e167f5a

Browse files
authored
xmss: add simpler validation (#240)
* xmss: add simpler validation * cleanup * small fix
1 parent 60b7398 commit e167f5a

File tree

8 files changed

+69
-48
lines changed

8 files changed

+69
-48
lines changed
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
"""Internal validation utilities for the XMSS scheme."""
2+
3+
from __future__ import annotations
4+
5+
from typing import Any
6+
7+
8+
def enforce_strict_types(instance: Any, **field_types: type) -> None:
9+
"""
10+
Validate that specified fields are exact types, not subclasses.
11+
12+
This is a helper function to be called from Pydantic model validators.
13+
14+
It enforces that field values are exactly the declared type, preventing
15+
type confusion attacks where a malicious subclass could override behavior.
16+
17+
Args:
18+
instance: The model instance being validated.
19+
**field_types: Mapping of field names to their exact expected types.
20+
21+
Raises:
22+
TypeError: If any field is a subclass rather than the exact type.
23+
"""
24+
for field_name, expected_type in field_types.items():
25+
value = getattr(instance, field_name)
26+
if type(value) is not expected_type:
27+
raise TypeError(
28+
f"{field_name} must be exactly {expected_type.__name__}, not a subclass"
29+
)

src/lean_spec/subspecs/xmss/interface.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
)
1919
from lean_spec.types import StrictBaseModel, Uint64
2020

21+
from ._validation import enforce_strict_types
2122
from .constants import (
2223
PROD_CONFIG,
2324
TEST_CONFIG,
@@ -60,18 +61,16 @@ class GeneralizedXmssScheme(StrictBaseModel):
6061
"""Random data generator for key generation."""
6162

6263
@model_validator(mode="after")
63-
def enforce_strict_types(self) -> "GeneralizedXmssScheme":
64+
def _validate_strict_types(self) -> "GeneralizedXmssScheme":
6465
"""Reject subclasses to prevent type confusion attacks."""
65-
if type(self.config) is not XmssConfig:
66-
raise TypeError("config must be exactly XmssConfig, not a subclass")
67-
if type(self.prf) is not Prf:
68-
raise TypeError("prf must be exactly Prf, not a subclass")
69-
if type(self.hasher) is not TweakHasher:
70-
raise TypeError("hasher must be exactly TweakHasher, not a subclass")
71-
if type(self.encoder) is not TargetSumEncoder:
72-
raise TypeError("encoder must be exactly TargetSumEncoder, not a subclass")
73-
if type(self.rand) is not Rand:
74-
raise TypeError("rand must be exactly Rand, not a subclass")
66+
enforce_strict_types(
67+
self,
68+
config=XmssConfig,
69+
prf=Prf,
70+
hasher=TweakHasher,
71+
encoder=TargetSumEncoder,
72+
rand=Rand,
73+
)
7574
return self
7675

7776
def key_gen(

src/lean_spec/subspecs/xmss/message_hash.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from lean_spec.types import StrictBaseModel, Uint64
4040

4141
from ..koalabear import Fp, P
42+
from ._validation import enforce_strict_types
4243
from .constants import (
4344
PROD_CONFIG,
4445
TEST_CONFIG,
@@ -64,12 +65,9 @@ class MessageHasher(StrictBaseModel):
6465
"""Poseidon hash engine."""
6566

6667
@model_validator(mode="after")
67-
def enforce_strict_types(self) -> "MessageHasher":
68+
def _validate_strict_types(self) -> "MessageHasher":
6869
"""Reject subclasses to prevent type confusion attacks."""
69-
if type(self.config) is not XmssConfig:
70-
raise TypeError("config must be exactly XmssConfig, not a subclass")
71-
if type(self.poseidon) is not PoseidonXmss:
72-
raise TypeError("poseidon must be exactly PoseidonXmss, not a subclass")
70+
enforce_strict_types(self, config=XmssConfig, poseidon=PoseidonXmss)
7371
return self
7472

7573
def encode_message(self, message: bytes) -> list[Fp]:

src/lean_spec/subspecs/xmss/poseidon.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
Poseidon2Params,
3434
permute,
3535
)
36+
from ._validation import enforce_strict_types
3637
from .utils import int_to_base_p
3738

3839

@@ -46,12 +47,9 @@ class PoseidonXmss(StrictBaseModel):
4647
"""Poseidon2 parameters for 24-width permutation."""
4748

4849
@model_validator(mode="after")
49-
def enforce_strict_types(self) -> "PoseidonXmss":
50+
def _validate_strict_types(self) -> "PoseidonXmss":
5051
"""Reject subclasses to prevent type confusion attacks."""
51-
if type(self.params16) is not Poseidon2Params:
52-
raise TypeError("params16 must be exactly Poseidon2Params, not a subclass")
53-
if type(self.params24) is not Poseidon2Params:
54-
raise TypeError("params24 must be exactly Poseidon2Params, not a subclass")
52+
enforce_strict_types(self, params16=Poseidon2Params, params24=Poseidon2Params)
5553
return self
5654

5755
def compress(self, input_vec: list[Fp], width: int, output_len: int) -> list[Fp]:

src/lean_spec/subspecs/xmss/prf.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from lean_spec.subspecs.koalabear import Fp
1818
from lean_spec.types import StrictBaseModel, Uint64
1919

20+
from ._validation import enforce_strict_types
2021
from .constants import (
2122
PRF_KEY_LENGTH,
2223
PROD_CONFIG,
@@ -109,10 +110,9 @@ class Prf(StrictBaseModel):
109110
"""Configuration parameters for the PRF."""
110111

111112
@model_validator(mode="after")
112-
def enforce_strict_types(self) -> "Prf":
113+
def _validate_strict_types(self) -> "Prf":
113114
"""Reject subclasses to prevent type confusion attacks."""
114-
if type(self.config) is not XmssConfig:
115-
raise TypeError("config must be exactly XmssConfig, not a subclass")
115+
enforce_strict_types(self, config=XmssConfig)
116116
return self
117117

118118
def key_gen(self) -> PRFKey:

src/lean_spec/subspecs/xmss/rand.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from lean_spec.types import StrictBaseModel
88

99
from ..koalabear import Fp, P
10+
from ._validation import enforce_strict_types
1011
from .constants import PROD_CONFIG, TEST_CONFIG, XmssConfig
1112
from .types import HashDigestVector, Parameter, Randomness
1213

@@ -18,10 +19,9 @@ class Rand(StrictBaseModel):
1819
"""Configuration parameters for the random generator."""
1920

2021
@model_validator(mode="after")
21-
def enforce_strict_types(self) -> "Rand":
22+
def _validate_strict_types(self) -> "Rand":
2223
"""Reject subclasses to prevent type confusion attacks."""
23-
if type(self.config) is not XmssConfig:
24-
raise TypeError("config must be exactly XmssConfig, not a subclass")
24+
enforce_strict_types(self, config=XmssConfig)
2525
return self
2626

2727
def field_elements(self, length: int) -> list[Fp]:

src/lean_spec/subspecs/xmss/target_sum.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from lean_spec.types import StrictBaseModel, Uint64
1212

13+
from ._validation import enforce_strict_types
1314
from .constants import PROD_CONFIG, TEST_CONFIG, XmssConfig
1415
from .message_hash import (
1516
PROD_MESSAGE_HASHER,
@@ -34,12 +35,9 @@ class TargetSumEncoder(StrictBaseModel):
3435
"""Message hasher for encoding."""
3536

3637
@model_validator(mode="after")
37-
def enforce_strict_types(self) -> "TargetSumEncoder":
38+
def _validate_strict_types(self) -> "TargetSumEncoder":
3839
"""Reject subclasses to prevent type confusion attacks."""
39-
if type(self.config) is not XmssConfig:
40-
raise TypeError("config must be exactly XmssConfig, not a subclass")
41-
if type(self.message_hasher) is not MessageHasher:
42-
raise TypeError("message_hasher must be exactly MessageHasher, not a subclass")
40+
enforce_strict_types(self, config=XmssConfig, message_hasher=MessageHasher)
4341
return self
4442

4543
def encode(

src/lean_spec/subspecs/xmss/tweak_hash.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from lean_spec.types import StrictBaseModel, Uint64
3232

3333
from ..koalabear import Fp
34+
from ._validation import enforce_strict_types
3435
from .constants import (
3536
PROD_CONFIG,
3637
TEST_CONFIG,
@@ -86,12 +87,9 @@ class TweakHasher(StrictBaseModel):
8687
"""Poseidon permutation instance for hashing."""
8788

8889
@model_validator(mode="after")
89-
def enforce_strict_types(self) -> "TweakHasher":
90+
def _validate_strict_types(self) -> "TweakHasher":
9091
"""Reject subclasses to prevent type confusion attacks."""
91-
if type(self.config) is not XmssConfig:
92-
raise TypeError("config must be exactly XmssConfig, not a subclass")
93-
if type(self.poseidon) is not PoseidonXmss:
94-
raise TypeError("poseidon must be exactly PoseidonXmss, not a subclass")
92+
enforce_strict_types(self, config=XmssConfig, poseidon=PoseidonXmss)
9593
return self
9694

9795
def _encode_tweak(self, tweak: TreeTweak | ChainTweak, length: int) -> list[Fp]:
@@ -123,17 +121,18 @@ def _encode_tweak(self, tweak: TreeTweak | ChainTweak, length: int) -> list[Fp]:
123121
# Pack the tweak's integer fields into a single large integer.
124122
#
125123
# A hardcoded prefix is included for domain separation between tweak types.
126-
if isinstance(tweak, TreeTweak):
127-
# Packing scheme: (level << 40) | (index << 8) | PREFIX
128-
acc = (tweak.level << 40) | (int(tweak.index) << 8) | TWEAK_PREFIX_TREE.value
129-
else:
130-
# Packing scheme: (epoch << 24) | (chain_index << 16) | (step << 8) | PREFIX
131-
acc = (
132-
(int(tweak.epoch) << 24)
133-
| (tweak.chain_index << 16)
134-
| (tweak.step << 8)
135-
| TWEAK_PREFIX_CHAIN.value
136-
)
124+
match tweak:
125+
case TreeTweak(level=level, index=index):
126+
# Packing scheme: (level << 40) | (index << 8) | PREFIX
127+
acc = (level << 40) | (int(index) << 8) | TWEAK_PREFIX_TREE.value
128+
case ChainTweak(epoch=epoch, chain_index=chain_index, step=step):
129+
# Packing scheme: (epoch << 24) | (chain_index << 16) | (step << 8) | PREFIX
130+
acc = (
131+
(int(epoch) << 24)
132+
| (chain_index << 16)
133+
| (step << 8)
134+
| TWEAK_PREFIX_CHAIN.value
135+
)
137136

138137
# Decompose the packed integer `acc` into a list of base-P field elements.
139138
return int_to_base_p(acc, length)

0 commit comments

Comments
 (0)