Skip to content

Commit 36a137e

Browse files
authored
Merge pull request #3 from tcoratger/poseidon-spec
subspec: implement Poseidon2 spec with tests
2 parents 5cf1afa + acc4699 commit 36a137e

File tree

4 files changed

+488
-6
lines changed

4 files changed

+488
-6
lines changed
Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,13 @@
1-
"""
2-
Specifications for the Poseidon2 cryptographic permutation for
3-
zero-knowledge applications.
4-
"""
1+
"""Specification for the Poseidon2 permutation."""
2+
3+
from .permutation import (
4+
PARAMS_16,
5+
PARAMS_24,
6+
permute,
7+
)
8+
9+
__all__ = [
10+
"permute",
11+
"PARAMS_16",
12+
"PARAMS_24",
13+
]
Lines changed: 338 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,338 @@
1+
# subspecs/poseidon2/permutation.py
2+
3+
"""
4+
A minimal Python specification for the Poseidon2 permutation.
5+
6+
The design is based on the paper "Poseidon2: A Faster Version of the Poseidon
7+
Hash Function" (https://eprint.iacr.org/2023/323).
8+
"""
9+
10+
from itertools import chain
11+
from typing import List
12+
13+
from pydantic import BaseModel, ConfigDict, Field, model_validator
14+
15+
from ..koalabear.field import Fp
16+
17+
# =================================================================
18+
# Poseidon2 Parameter Definitions
19+
# =================================================================
20+
21+
S_BOX_DEGREE = 3
22+
"""
23+
The S-box exponent `d`.
24+
25+
For fields where `gcd(d, p-1) = 1`, `x -> x^d` is a permutation.
26+
27+
For KoalaBear, `d=3` is chosen for its low degree.
28+
"""
29+
30+
31+
class Poseidon2Params(BaseModel):
32+
"""
33+
Encapsulates all necessary parameters for a specific Poseidon2 instance.
34+
35+
This structure holds the configuration for a given state width, including
36+
the number of rounds and the constants for the internal linear layer.
37+
"""
38+
39+
# Configuration to make the model immutable and prevent extra arguments.
40+
model_config = ConfigDict(
41+
frozen=True,
42+
extra="forbid",
43+
arbitrary_types_allowed=True,
44+
)
45+
46+
width: int = Field(gt=0, description="The size of the state (t).")
47+
rounds_f: int = Field(gt=0, description="Total number of 'full' rounds.")
48+
rounds_p: int = Field(
49+
ge=0, description="Total number of 'partial' rounds."
50+
)
51+
internal_diag_vectors: List[Fp] = Field(
52+
min_length=1,
53+
description=(
54+
"Diagonal vectors for the efficient "
55+
"internal linear layer matrix (M_I)."
56+
),
57+
)
58+
59+
@model_validator(mode="after")
60+
def check_vector_length(self) -> "Poseidon2Params":
61+
"""Length of the diagonal vector should match the state width."""
62+
if len(self.internal_diag_vectors) != self.width:
63+
raise ValueError(
64+
f"Length of internal diagonal vector "
65+
"({len(self.internal_diag_vectors)}) "
66+
f"must be equal to width ({self.width})."
67+
)
68+
return self
69+
70+
71+
def _generate_spec_test_round_constants(params: Poseidon2Params) -> List[Fp]:
72+
"""
73+
Generates a deterministic list of round constants for testing the spec.
74+
75+
!!! WARNING !!!
76+
This function produces a simple, predictable sequence of integers for the
77+
sole purpose of testing the permutation's algebraic structure. Production
78+
implementations MUST use constants generated from a secure,
79+
unpredictable source.
80+
81+
Args:
82+
params: The object defining the permutation's configuration.
83+
84+
Returns:
85+
A list of Fp elements to be used as round constants for tests.
86+
"""
87+
# The total number of constants needed for the entire permutation.
88+
total_constants = (params.rounds_f * params.width) + params.rounds_p
89+
90+
# For the specification, we generate the constants as a deterministic
91+
# sequence of integers. This is sufficient to define the mechanics.
92+
return [Fp(value=i) for i in range(total_constants)]
93+
94+
95+
# Parameters for WIDTH = 16
96+
PARAMS_16 = Poseidon2Params(
97+
width=16,
98+
rounds_f=8,
99+
rounds_p=20,
100+
internal_diag_vectors=[
101+
Fp(value=-2),
102+
Fp(value=1),
103+
Fp(value=2),
104+
Fp(value=1) / Fp(value=2),
105+
Fp(value=3),
106+
Fp(value=4),
107+
Fp(value=-1) / Fp(value=2),
108+
Fp(value=-3),
109+
Fp(value=-4),
110+
Fp(value=1) / Fp(value=2**8),
111+
Fp(value=1) / Fp(value=8),
112+
Fp(value=1) / Fp(value=2**24),
113+
Fp(value=-1) / Fp(value=2**8),
114+
Fp(value=-1) / Fp(value=8),
115+
Fp(value=-1) / Fp(value=16),
116+
Fp(value=-1) / Fp(value=2**24),
117+
],
118+
)
119+
120+
# Parameters for WIDTH = 24
121+
PARAMS_24 = Poseidon2Params(
122+
width=24,
123+
rounds_f=8,
124+
rounds_p=23,
125+
internal_diag_vectors=[
126+
Fp(value=-2),
127+
Fp(value=1),
128+
Fp(value=2),
129+
Fp(value=1) / Fp(value=2),
130+
Fp(value=3),
131+
Fp(value=4),
132+
Fp(value=-1) / Fp(value=2),
133+
Fp(value=-3),
134+
Fp(value=-4),
135+
Fp(value=1) / Fp(value=2**8),
136+
Fp(value=1) / Fp(value=4),
137+
Fp(value=1) / Fp(value=8),
138+
Fp(value=1) / Fp(value=16),
139+
Fp(value=1) / Fp(value=32),
140+
Fp(value=1) / Fp(value=64),
141+
Fp(value=1) / Fp(value=2**24),
142+
Fp(value=-1) / Fp(value=2**8),
143+
Fp(value=-1) / Fp(value=8),
144+
Fp(value=-1) / Fp(value=16),
145+
Fp(value=-1) / Fp(value=32),
146+
Fp(value=-1) / Fp(value=64),
147+
Fp(value=-1) / Fp(value=2**7),
148+
Fp(value=-1) / Fp(value=2**9),
149+
Fp(value=-1) / Fp(value=2**24),
150+
],
151+
)
152+
153+
# Base 4x4 matrix, used in the external linear layer.
154+
M4_MATRIX = [
155+
[Fp(value=2), Fp(value=3), Fp(value=1), Fp(value=1)],
156+
[Fp(value=1), Fp(value=2), Fp(value=3), Fp(value=1)],
157+
[Fp(value=1), Fp(value=1), Fp(value=2), Fp(value=3)],
158+
[Fp(value=3), Fp(value=1), Fp(value=1), Fp(value=2)],
159+
]
160+
161+
# =================================================================
162+
# Linear Layers
163+
# =================================================================
164+
165+
166+
def _apply_m4(chunk: List[Fp]) -> List[Fp]:
167+
"""
168+
Applies the 4x4 M4 MDS matrix to a 4-element chunk of the state.
169+
This is a helper function for the external linear layer.
170+
171+
Args:
172+
chunk: A list of 4 Fp elements.
173+
174+
Returns:
175+
The transformed 4-element chunk.
176+
"""
177+
# Initialize the result vector with zeros.
178+
result = [Fp(value=0)] * 4
179+
# Perform standard matrix-vector multiplication.
180+
for i in range(4):
181+
for j in range(4):
182+
result[i] += M4_MATRIX[i][j] * chunk[j]
183+
return result
184+
185+
186+
def external_linear_layer(state: List[Fp], width: int) -> List[Fp]:
187+
"""
188+
Applies the external linear layer (M_E).
189+
190+
This layer provides strong diffusion across the entire state and is used
191+
in the full rounds. For a state of size t=4k, it's constructed from the
192+
base M4 matrix to form a larger circulant-like matrix, which is efficient
193+
while ensuring that a change in any single element affects all other
194+
elements after application.
195+
196+
The process follows Appendix B of the paper.
197+
198+
Args:
199+
state: The current state vector.
200+
width: The width `t` of the state.
201+
202+
Returns:
203+
The state vector after applying the external linear layer.
204+
"""
205+
# Apply the M4 matrix to each 4-element chunk of the state.
206+
#
207+
# This provides strong local diffusion within each block.
208+
state_after_m4 = list(
209+
chain.from_iterable(
210+
_apply_m4(state[i : i + 4]) for i in range(0, width, 4)
211+
)
212+
)
213+
214+
# Apply the outer circulant structure for global diffusion.
215+
#
216+
# This is equivalent to multiplying by circ(2*I, I, ..., I)
217+
# after the M4 stage.
218+
#
219+
# We precompute the four sums of elements at the same offset in each chunk.
220+
# For each k in 0..4:
221+
# sums[k] = state[k] + state[4 + k] + state[8 + k] + ... up to width
222+
sums = [
223+
sum((state_after_m4[j + k] for j in range(0, width, 4)), Fp(value=0))
224+
for k in range(4)
225+
]
226+
227+
# Add the corresponding sum to each element of the state.
228+
state_after_circulant = [
229+
s + sums[i % 4] for i, s in enumerate(state_after_m4)
230+
]
231+
232+
return state_after_circulant
233+
234+
235+
def internal_linear_layer(
236+
state: List[Fp], params: Poseidon2Params
237+
) -> List[Fp]:
238+
"""
239+
Applies the internal linear layer (M_I).
240+
241+
This layer is used during partial rounds and is optimized for speed. Its
242+
matrix is constructed as M_I = J + D, where J is the all-ones matrix and D
243+
is a diagonal matrix. This structure allows the matrix-vector product to be
244+
computed in O(t) time instead of O(t^2), as M_I * s = J*s + D*s.
245+
The term J*s is a vector where each element is the sum of
246+
all elements in s.
247+
248+
Args:
249+
state: The current state vector.
250+
params: The Poseidon2Params object containing the diagonal vectors.
251+
252+
Returns:
253+
The state vector after applying the internal linear layer.
254+
"""
255+
# Calculate the sum of all elements in the state vector.
256+
s_sum = sum(state, Fp(value=0))
257+
# For each element s_i, compute s_i' = d_i * s_i + sum(s).
258+
# This is the efficient computation of (J + D)s.
259+
new_state = [
260+
s * d + s_sum
261+
for s, d in zip(state, params.internal_diag_vectors, strict=False)
262+
]
263+
return new_state
264+
265+
266+
# =================================================================
267+
# Core Permutation
268+
# =================================================================
269+
270+
271+
def permute(state: List[Fp], params: Poseidon2Params) -> List[Fp]:
272+
"""
273+
Performs the full Poseidon2 permutation on the given state.
274+
275+
The permutation follows the structure:
276+
Initial Layer -> Full Rounds -> Partial Rounds -> Full Rounds
277+
278+
Args:
279+
state: A list of Fp elements representing the current state.
280+
params: The object defining the permutation's configuration.
281+
282+
Returns:
283+
The new state after applying the permutation.
284+
"""
285+
# Ensure the input state has the correct dimensions.
286+
if len(state) != params.width:
287+
raise ValueError(f"Input state must have length {params.width}")
288+
289+
# Generate the deterministic round constants for this parameter set.
290+
round_constants = _generate_spec_test_round_constants(params)
291+
# The number of full rounds is split between the beginning and end.
292+
half_rounds_f = params.rounds_f // 2
293+
# Initialize index for accessing the flat list of round constants.
294+
const_idx = 0
295+
296+
# 1. Initial Linear Layer
297+
#
298+
# Another linear layer is applied at the start to prevent certain algebraic
299+
# attacks by ensuring the permutation begins with a diffusion layer.
300+
state = external_linear_layer(list(state), params.width)
301+
302+
# 2. First Half of Full Rounds (R_F / 2)
303+
for _r in range(half_rounds_f):
304+
# Add round constants to the entire state.
305+
state = [
306+
s + round_constants[const_idx + i] for i, s in enumerate(state)
307+
]
308+
const_idx += params.width
309+
# Apply the S-box (x -> x^d) to the full state.
310+
state = [s**S_BOX_DEGREE for s in state]
311+
# Apply the external linear layer for diffusion.
312+
state = external_linear_layer(state, params.width)
313+
314+
# 3. Partial Rounds (R_P)
315+
for _r in range(params.rounds_p):
316+
# Add a single round constant to the first state element.
317+
state[0] += round_constants[const_idx]
318+
const_idx += 1
319+
# Apply the S-box to the first state element only.
320+
#
321+
# This is the main optimization of the Hades design.
322+
state[0] = state[0] ** S_BOX_DEGREE
323+
# Apply the internal linear layer.
324+
state = internal_linear_layer(state, params)
325+
326+
# 4. Second Half of Full Rounds (R_F / 2)
327+
for _r in range(half_rounds_f):
328+
# Add round constants to the entire state.
329+
state = [
330+
s + round_constants[const_idx + i] for i, s in enumerate(state)
331+
]
332+
const_idx += params.width
333+
# Apply the S-box to the full state.
334+
state = [s**S_BOX_DEGREE for s in state]
335+
# Apply the external linear layer for diffusion.
336+
state = external_linear_layer(state, params.width)
337+
338+
return state

0 commit comments

Comments
 (0)