|
31 | 31 | from lean_spec.types import StrictBaseModel, Uint64 |
32 | 32 |
|
33 | 33 | from ..koalabear import Fp |
| 34 | +from ._validation import enforce_strict_types |
34 | 35 | from .constants import ( |
35 | 36 | PROD_CONFIG, |
36 | 37 | TEST_CONFIG, |
@@ -86,12 +87,9 @@ class TweakHasher(StrictBaseModel): |
86 | 87 | """Poseidon permutation instance for hashing.""" |
87 | 88 |
|
88 | 89 | @model_validator(mode="after") |
89 | | - def enforce_strict_types(self) -> "TweakHasher": |
| 90 | + def _validate_strict_types(self) -> "TweakHasher": |
90 | 91 | """Reject subclasses to prevent type confusion attacks.""" |
91 | | - if type(self.config) is not XmssConfig: |
92 | | - raise TypeError("config must be exactly XmssConfig, not a subclass") |
93 | | - if type(self.poseidon) is not PoseidonXmss: |
94 | | - raise TypeError("poseidon must be exactly PoseidonXmss, not a subclass") |
| 92 | + enforce_strict_types(self, config=XmssConfig, poseidon=PoseidonXmss) |
95 | 93 | return self |
96 | 94 |
|
97 | 95 | def _encode_tweak(self, tweak: TreeTweak | ChainTweak, length: int) -> list[Fp]: |
@@ -123,17 +121,18 @@ def _encode_tweak(self, tweak: TreeTweak | ChainTweak, length: int) -> list[Fp]: |
123 | 121 | # Pack the tweak's integer fields into a single large integer. |
124 | 122 | # |
125 | 123 | # A hardcoded prefix is included for domain separation between tweak types. |
126 | | - if isinstance(tweak, TreeTweak): |
127 | | - # Packing scheme: (level << 40) | (index << 8) | PREFIX |
128 | | - acc = (tweak.level << 40) | (int(tweak.index) << 8) | TWEAK_PREFIX_TREE.value |
129 | | - else: |
130 | | - # Packing scheme: (epoch << 24) | (chain_index << 16) | (step << 8) | PREFIX |
131 | | - acc = ( |
132 | | - (int(tweak.epoch) << 24) |
133 | | - | (tweak.chain_index << 16) |
134 | | - | (tweak.step << 8) |
135 | | - | TWEAK_PREFIX_CHAIN.value |
136 | | - ) |
| 124 | + match tweak: |
| 125 | + case TreeTweak(level=level, index=index): |
| 126 | + # Packing scheme: (level << 40) | (index << 8) | PREFIX |
| 127 | + acc = (level << 40) | (int(index) << 8) | TWEAK_PREFIX_TREE.value |
| 128 | + case ChainTweak(epoch=epoch, chain_index=chain_index, step=step): |
| 129 | + # Packing scheme: (epoch << 24) | (chain_index << 16) | (step << 8) | PREFIX |
| 130 | + acc = ( |
| 131 | + (int(epoch) << 24) |
| 132 | + | (chain_index << 16) |
| 133 | + | (step << 8) |
| 134 | + | TWEAK_PREFIX_CHAIN.value |
| 135 | + ) |
137 | 136 |
|
138 | 137 | # Decompose the packed integer `acc` into a list of base-P field elements. |
139 | 138 | return int_to_base_p(acc, length) |
|
0 commit comments