1+ # subspecs/poseidon2/permutation.py
2+
13"""
24A minimal Python specification for the Poseidon2 permutation.
35
68"""
79
810from itertools import chain
9- from typing import List , NamedTuple
11+ from typing import List
12+
13+ from pydantic import BaseModel , ConfigDict , Field , model_validator
1014
1115from ..koalabear .field import Fp
1216
2428"""
2529
2630
27- class Poseidon2Params (NamedTuple ):
31+ class Poseidon2Params (BaseModel ):
2832 """
2933 Encapsulates all necessary parameters for a specific Poseidon2 instance.
3034
3135 This structure holds the configuration for a given state width, including
3236 the number of rounds and the constants for the internal linear layer.
37+ """
3338
34- Attributes:
35- WIDTH (int): The size of the state (t).
36-
37- ROUNDS_F (int): The total number of "full" rounds, where the S-box is
38- applied to the entire state.
39-
40- ROUNDS_P (int): The number of "partial" rounds, where the S-box is
41- applied to only the first element of the state.
39+ # Configuration to make the model immutable and prevent extra arguments.
40+ model_config = ConfigDict (
41+ frozen = True ,
42+ extra = "forbid" ,
43+ arbitrary_types_allowed = True ,
44+ )
4245
43- INTERNAL_DIAG_VECTORS (List[Fp]): The diagonal vectors for the
44- efficient internal linear layer matrix (M_I).
45- """
46+ width : int = Field (gt = 0 , description = "The size of the state (t)." )
47+ rounds_f : int = Field (gt = 0 , description = "Total number of 'full' rounds." )
48+ rounds_p : int = Field (
49+ ge = 0 , description = "Total number of 'partial' rounds."
50+ )
51+ internal_diag_vectors : List [Fp ] = Field (
52+ min_length = 1 ,
53+ description = (
54+ "Diagonal vectors for the efficient "
55+ "internal linear layer matrix (M_I)."
56+ ),
57+ )
4658
47- WIDTH : int
48- ROUNDS_F : int
49- ROUNDS_P : int
50- INTERNAL_DIAG_VECTORS : List [Fp ]
59+ @model_validator (mode = "after" )
60+ def check_vector_length (self ) -> "Poseidon2Params" :
61+ """Length of the diagonal vector should match the state width."""
62+ if len (self .internal_diag_vectors ) != self .width :
63+ raise ValueError (
64+ f"Length of internal diagonal vector "
65+ "({len(self.internal_diag_vectors)}) "
66+ f"must be equal to width ({ self .width } )."
67+ )
68+ return self
5169
5270
5371def _generate_round_constants (params : Poseidon2Params ) -> List [Fp ]:
@@ -66,11 +84,11 @@ def _generate_round_constants(params: Poseidon2Params) -> List[Fp]:
6684 # The total number of constants needed for the entire permutation.
6785 #
6886 # This is the sum of constants for all full rounds and all partial rounds.
69- # - Full rounds require `WIDTH ` constants each
87+ # - Full rounds require `width ` constants each
7088 # (one for each state element).
7189 # - Partial rounds require 1 constant each
7290 # (for the first state element).
73- total_constants = (params .ROUNDS_F * params .WIDTH ) + params .ROUNDS_P
91+ total_constants = (params .rounds_f * params .width ) + params .rounds_p
7492
7593 # For the specification, we generate the constants as a deterministic d
7694 # sequence of integers.
@@ -84,10 +102,10 @@ def _generate_round_constants(params: Poseidon2Params) -> List[Fp]:
84102
85103# Parameters for WIDTH = 16
86104PARAMS_16 = Poseidon2Params (
87- WIDTH = 16 ,
88- ROUNDS_F = 8 ,
89- ROUNDS_P = 20 ,
90- INTERNAL_DIAG_VECTORS = [
105+ width = 16 ,
106+ rounds_f = 8 ,
107+ rounds_p = 20 ,
108+ internal_diag_vectors = [
91109 Fp (value = - 2 ),
92110 Fp (value = 1 ),
93111 Fp (value = 2 ),
@@ -109,10 +127,10 @@ def _generate_round_constants(params: Poseidon2Params) -> List[Fp]:
109127
110128# Parameters for WIDTH = 24
111129PARAMS_24 = Poseidon2Params (
112- WIDTH = 24 ,
113- ROUNDS_F = 8 ,
114- ROUNDS_P = 23 ,
115- INTERNAL_DIAG_VECTORS = [
130+ width = 24 ,
131+ rounds_f = 8 ,
132+ rounds_p = 23 ,
133+ internal_diag_vectors = [
116134 Fp (value = - 2 ),
117135 Fp (value = 1 ),
118136 Fp (value = 2 ),
@@ -203,9 +221,12 @@ def external_linear_layer(state: List[Fp], width: int) -> List[Fp]:
203221
204222 # Apply the outer circulant structure for global diffusion.
205223 #
224+ # This is equivalent to multiplying by circ(2*I, I, ..., I)
225+ # after the M4 stage.
226+ #
206227 # We precompute the four sums of elements at the same offset in each chunk.
207228 # For each k in 0..4:
208- # sums[k] = state[k] + state[4 + k] + state[8 + k] + ... up to width
229+ # sums[k] = state[k] + state[4 + k] + state[8 + k] + ... up to width
209230 sums = [
210231 sum ((state_after_m4 [j + k ] for j in range (0 , width , 4 )), Fp (value = 0 ))
211232 for k in range (4 )
@@ -245,7 +266,7 @@ def internal_linear_layer(
245266 # This is the efficient computation of (J + D)s.
246267 new_state = [
247268 s * d + s_sum
248- for s , d in zip (state , params .INTERNAL_DIAG_VECTORS , strict = False )
269+ for s , d in zip (state , params .internal_diag_vectors , strict = False )
249270 ]
250271 return new_state
251272
@@ -270,36 +291,36 @@ def permute(state: List[Fp], params: Poseidon2Params) -> List[Fp]:
270291 The new state after applying the permutation.
271292 """
272293 # Ensure the input state has the correct dimensions.
273- if len (state ) != params .WIDTH :
274- raise ValueError (f"Input state must have length { params .WIDTH } " )
294+ if len (state ) != params .width :
295+ raise ValueError (f"Input state must have length { params .width } " )
275296
276297 # Generate the deterministic round constants for this parameter set.
277298 round_constants = _generate_round_constants (params )
278299 # The number of full rounds is split between the beginning and end.
279- half_rounds_f = params .ROUNDS_F // 2
300+ half_rounds_f = params .rounds_f // 2
280301 # Initialize index for accessing the flat list of round constants.
281302 const_idx = 0
282303
283304 # 1. Initial Linear Layer
284305 #
285306 # Another linear layer is applied at the start to prevent certain algebraic
286307 # attacks by ensuring the permutation begins with a diffusion layer.
287- state = external_linear_layer (list (state ), params .WIDTH )
308+ state = external_linear_layer (list (state ), params .width )
288309
289310 # 2. First Half of Full Rounds (R_F / 2)
290311 for _r in range (half_rounds_f ):
291312 # Add round constants to the entire state.
292313 state = [
293314 s + round_constants [const_idx + i ] for i , s in enumerate (state )
294315 ]
295- const_idx += params .WIDTH
316+ const_idx += params .width
296317 # Apply the S-box (x -> x^d) to the full state.
297318 state = [s ** S_BOX_DEGREE for s in state ]
298319 # Apply the external linear layer for diffusion.
299- state = external_linear_layer (state , params .WIDTH )
320+ state = external_linear_layer (state , params .width )
300321
301322 # 3. Partial Rounds (R_P)
302- for _r in range (params .ROUNDS_P ):
323+ for _r in range (params .rounds_p ):
303324 # Add a single round constant to the first state element.
304325 state [0 ] += round_constants [const_idx ]
305326 const_idx += 1
@@ -316,10 +337,10 @@ def permute(state: List[Fp], params: Poseidon2Params) -> List[Fp]:
316337 state = [
317338 s + round_constants [const_idx + i ] for i , s in enumerate (state )
318339 ]
319- const_idx += params .WIDTH
340+ const_idx += params .width
320341 # Apply the S-box to the full state.
321342 state = [s ** S_BOX_DEGREE for s in state ]
322343 # Apply the external linear layer for diffusion.
323- state = external_linear_layer (state , params .WIDTH )
344+ state = external_linear_layer (state , params .width )
324345
325346 return state
0 commit comments