Skip to content

Commit 18d43d5

Browse files
committed
fix using pydantic
1 parent 3f6226c commit 18d43d5

File tree

2 files changed

+61
-40
lines changed

2 files changed

+61
-40
lines changed

src/lean_spec/subspecs/poseidon2/permutation.py

Lines changed: 59 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
# subspecs/poseidon2/permutation.py
2+
13
"""
24
A minimal Python specification for the Poseidon2 permutation.
35
@@ -6,7 +8,9 @@
68
"""
79

810
from itertools import chain
9-
from typing import List, NamedTuple
11+
from typing import List
12+
13+
from pydantic import BaseModel, ConfigDict, Field, model_validator
1014

1115
from ..koalabear.field import Fp
1216

@@ -24,30 +28,44 @@
2428
"""
2529

2630

27-
class Poseidon2Params(NamedTuple):
31+
class Poseidon2Params(BaseModel):
2832
"""
2933
Encapsulates all necessary parameters for a specific Poseidon2 instance.
3034
3135
This structure holds the configuration for a given state width, including
3236
the number of rounds and the constants for the internal linear layer.
37+
"""
3338

34-
Attributes:
35-
WIDTH (int): The size of the state (t).
36-
37-
ROUNDS_F (int): The total number of "full" rounds, where the S-box is
38-
applied to the entire state.
39-
40-
ROUNDS_P (int): The number of "partial" rounds, where the S-box is
41-
applied to only the first element of the state.
39+
# Configuration to make the model immutable and prevent extra arguments.
40+
model_config = ConfigDict(
41+
frozen=True,
42+
extra="forbid",
43+
arbitrary_types_allowed=True,
44+
)
4245

43-
INTERNAL_DIAG_VECTORS (List[Fp]): The diagonal vectors for the
44-
efficient internal linear layer matrix (M_I).
45-
"""
46+
width: int = Field(gt=0, description="The size of the state (t).")
47+
rounds_f: int = Field(gt=0, description="Total number of 'full' rounds.")
48+
rounds_p: int = Field(
49+
ge=0, description="Total number of 'partial' rounds."
50+
)
51+
internal_diag_vectors: List[Fp] = Field(
52+
min_length=1,
53+
description=(
54+
"Diagonal vectors for the efficient "
55+
"internal linear layer matrix (M_I)."
56+
),
57+
)
4658

47-
WIDTH: int
48-
ROUNDS_F: int
49-
ROUNDS_P: int
50-
INTERNAL_DIAG_VECTORS: List[Fp]
59+
@model_validator(mode="after")
60+
def check_vector_length(self) -> "Poseidon2Params":
61+
"""Length of the diagonal vector should match the state width."""
62+
if len(self.internal_diag_vectors) != self.width:
63+
raise ValueError(
64+
f"Length of internal diagonal vector "
65+
"({len(self.internal_diag_vectors)}) "
66+
f"must be equal to width ({self.width})."
67+
)
68+
return self
5169

5270

5371
def _generate_round_constants(params: Poseidon2Params) -> List[Fp]:
@@ -66,11 +84,11 @@ def _generate_round_constants(params: Poseidon2Params) -> List[Fp]:
6684
# The total number of constants needed for the entire permutation.
6785
#
6886
# This is the sum of constants for all full rounds and all partial rounds.
69-
# - Full rounds require `WIDTH` constants each
87+
# - Full rounds require `width` constants each
7088
# (one for each state element).
7189
# - Partial rounds require 1 constant each
7290
# (for the first state element).
73-
total_constants = (params.ROUNDS_F * params.WIDTH) + params.ROUNDS_P
91+
total_constants = (params.rounds_f * params.width) + params.rounds_p
7492

7593
# For the specification, we generate the constants as a deterministic d
7694
# sequence of integers.
@@ -84,10 +102,10 @@ def _generate_round_constants(params: Poseidon2Params) -> List[Fp]:
84102

85103
# Parameters for WIDTH = 16
86104
PARAMS_16 = Poseidon2Params(
87-
WIDTH=16,
88-
ROUNDS_F=8,
89-
ROUNDS_P=20,
90-
INTERNAL_DIAG_VECTORS=[
105+
width=16,
106+
rounds_f=8,
107+
rounds_p=20,
108+
internal_diag_vectors=[
91109
Fp(value=-2),
92110
Fp(value=1),
93111
Fp(value=2),
@@ -109,10 +127,10 @@ def _generate_round_constants(params: Poseidon2Params) -> List[Fp]:
109127

110128
# Parameters for WIDTH = 24
111129
PARAMS_24 = Poseidon2Params(
112-
WIDTH=24,
113-
ROUNDS_F=8,
114-
ROUNDS_P=23,
115-
INTERNAL_DIAG_VECTORS=[
130+
width=24,
131+
rounds_f=8,
132+
rounds_p=23,
133+
internal_diag_vectors=[
116134
Fp(value=-2),
117135
Fp(value=1),
118136
Fp(value=2),
@@ -203,9 +221,12 @@ def external_linear_layer(state: List[Fp], width: int) -> List[Fp]:
203221

204222
# Apply the outer circulant structure for global diffusion.
205223
#
224+
# This is equivalent to multiplying by circ(2*I, I, ..., I)
225+
# after the M4 stage.
226+
#
206227
# We precompute the four sums of elements at the same offset in each chunk.
207228
# For each k in 0..4:
208-
# sums[k] = state[k] + state[4 + k] + state[8 + k] + ... up to width
229+
# sums[k] = state[k] + state[4 + k] + state[8 + k] + ... up to width
209230
sums = [
210231
sum((state_after_m4[j + k] for j in range(0, width, 4)), Fp(value=0))
211232
for k in range(4)
@@ -245,7 +266,7 @@ def internal_linear_layer(
245266
# This is the efficient computation of (J + D)s.
246267
new_state = [
247268
s * d + s_sum
248-
for s, d in zip(state, params.INTERNAL_DIAG_VECTORS, strict=False)
269+
for s, d in zip(state, params.internal_diag_vectors, strict=False)
249270
]
250271
return new_state
251272

@@ -270,36 +291,36 @@ def permute(state: List[Fp], params: Poseidon2Params) -> List[Fp]:
270291
The new state after applying the permutation.
271292
"""
272293
# Ensure the input state has the correct dimensions.
273-
if len(state) != params.WIDTH:
274-
raise ValueError(f"Input state must have length {params.WIDTH}")
294+
if len(state) != params.width:
295+
raise ValueError(f"Input state must have length {params.width}")
275296

276297
# Generate the deterministic round constants for this parameter set.
277298
round_constants = _generate_round_constants(params)
278299
# The number of full rounds is split between the beginning and end.
279-
half_rounds_f = params.ROUNDS_F // 2
300+
half_rounds_f = params.rounds_f // 2
280301
# Initialize index for accessing the flat list of round constants.
281302
const_idx = 0
282303

283304
# 1. Initial Linear Layer
284305
#
285306
# Another linear layer is applied at the start to prevent certain algebraic
286307
# attacks by ensuring the permutation begins with a diffusion layer.
287-
state = external_linear_layer(list(state), params.WIDTH)
308+
state = external_linear_layer(list(state), params.width)
288309

289310
# 2. First Half of Full Rounds (R_F / 2)
290311
for _r in range(half_rounds_f):
291312
# Add round constants to the entire state.
292313
state = [
293314
s + round_constants[const_idx + i] for i, s in enumerate(state)
294315
]
295-
const_idx += params.WIDTH
316+
const_idx += params.width
296317
# Apply the S-box (x -> x^d) to the full state.
297318
state = [s**S_BOX_DEGREE for s in state]
298319
# Apply the external linear layer for diffusion.
299-
state = external_linear_layer(state, params.WIDTH)
320+
state = external_linear_layer(state, params.width)
300321

301322
# 3. Partial Rounds (R_P)
302-
for _r in range(params.ROUNDS_P):
323+
for _r in range(params.rounds_p):
303324
# Add a single round constant to the first state element.
304325
state[0] += round_constants[const_idx]
305326
const_idx += 1
@@ -316,10 +337,10 @@ def permute(state: List[Fp], params: Poseidon2Params) -> List[Fp]:
316337
state = [
317338
s + round_constants[const_idx + i] for i, s in enumerate(state)
318339
]
319-
const_idx += params.WIDTH
340+
const_idx += params.width
320341
# Apply the S-box to the full state.
321342
state = [s**S_BOX_DEGREE for s in state]
322343
# Apply the external linear layer for diffusion.
323-
state = external_linear_layer(state, params.WIDTH)
344+
state = external_linear_layer(state, params.width)
324345

325346
return state

tests/lean_spec/subspecs/poseidon2/test_permutation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def test_permutation_vector(
131131
output_state = permute(input_state, params)
132132

133133
# Verify the output
134-
assert len(output_state) == params.WIDTH
134+
assert len(output_state) == params.width
135135
assert output_state == expected_output, (
136-
f"Permutation output for width {params.WIDTH} did not match."
136+
f"Permutation output for width {params.width} did not match."
137137
)

0 commit comments

Comments
 (0)