Skip to content

Commit 3f6226c

Browse files
committed
subspec: implement Poseidon2 spec with tests
1 parent 5cf1afa commit 3f6226c

File tree

4 files changed

+475
-6
lines changed

4 files changed

+475
-6
lines changed
Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,13 @@
1-
"""
2-
Specifications for the Poseidon2 cryptographic permutation for
3-
zero-knowledge applications.
4-
"""
1+
"""Specification for the Poseidon2 permutation."""
2+
3+
from .permutation import (
4+
PARAMS_16,
5+
PARAMS_24,
6+
permute,
7+
)
8+
9+
__all__ = [
10+
"permute",
11+
"PARAMS_16",
12+
"PARAMS_24",
13+
]
Lines changed: 325 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,325 @@
1+
"""
2+
A minimal Python specification for the Poseidon2 permutation.
3+
4+
The design is based on the paper "Poseidon2: A Faster Version of the Poseidon
5+
Hash Function" (https://eprint.iacr.org/2023/323).
6+
"""
7+
8+
from itertools import chain
9+
from typing import List, NamedTuple
10+
11+
from ..koalabear.field import Fp
12+
13+
# =================================================================
14+
# Poseidon2 Parameter Definitions
15+
# =================================================================
16+
17+
S_BOX_DEGREE = 3
18+
"""
19+
The S-box exponent `d`.
20+
21+
For fields where `gcd(d, p-1) = 1`, `x -> x^d` is a permutation.
22+
23+
For KoalaBear, `d=3` is chosen for its low degree.
24+
"""
25+
26+
27+
class Poseidon2Params(NamedTuple):
28+
"""
29+
Encapsulates all necessary parameters for a specific Poseidon2 instance.
30+
31+
This structure holds the configuration for a given state width, including
32+
the number of rounds and the constants for the internal linear layer.
33+
34+
Attributes:
35+
WIDTH (int): The size of the state (t).
36+
37+
ROUNDS_F (int): The total number of "full" rounds, where the S-box is
38+
applied to the entire state.
39+
40+
ROUNDS_P (int): The number of "partial" rounds, where the S-box is
41+
applied to only the first element of the state.
42+
43+
INTERNAL_DIAG_VECTORS (List[Fp]): The diagonal vectors for the
44+
efficient internal linear layer matrix (M_I).
45+
"""
46+
47+
WIDTH: int
48+
ROUNDS_F: int
49+
ROUNDS_P: int
50+
INTERNAL_DIAG_VECTORS: List[Fp]
51+
52+
53+
def _generate_round_constants(params: Poseidon2Params) -> List[Fp]:
54+
"""
55+
Generates a deterministic list of round constants for the permutation.
56+
57+
Round constants are added in each round to break symmetries and prevent
58+
attacks like slide or interpolation attacks.
59+
60+
Args:
61+
params: The object defining the permutation's configuration.
62+
63+
Returns:
64+
A list of Fp elements to be used as round constants.
65+
"""
66+
# The total number of constants needed for the entire permutation.
67+
#
68+
# This is the sum of constants for all full rounds and all partial rounds.
69+
# - Full rounds require `WIDTH` constants each
70+
# (one for each state element).
71+
# - Partial rounds require 1 constant each
72+
# (for the first state element).
73+
total_constants = (params.ROUNDS_F * params.WIDTH) + params.ROUNDS_P
74+
75+
# For the specification, we generate the constants as a deterministic d
76+
# sequence of integers.
77+
#
78+
# This is sufficient to define the algorithm's mechanics.
79+
#
80+
# Real-world implementations would use constants generated from a secure,
81+
# pseudo-random source.
82+
return [Fp(value=i) for i in range(total_constants)]
83+
84+
85+
# Parameters for WIDTH = 16
86+
PARAMS_16 = Poseidon2Params(
87+
WIDTH=16,
88+
ROUNDS_F=8,
89+
ROUNDS_P=20,
90+
INTERNAL_DIAG_VECTORS=[
91+
Fp(value=-2),
92+
Fp(value=1),
93+
Fp(value=2),
94+
Fp(value=1) / Fp(value=2),
95+
Fp(value=3),
96+
Fp(value=4),
97+
Fp(value=-1) / Fp(value=2),
98+
Fp(value=-3),
99+
Fp(value=-4),
100+
Fp(value=1) / Fp(value=2**8),
101+
Fp(value=1) / Fp(value=8),
102+
Fp(value=1) / Fp(value=2**24),
103+
Fp(value=-1) / Fp(value=2**8),
104+
Fp(value=-1) / Fp(value=8),
105+
Fp(value=-1) / Fp(value=16),
106+
Fp(value=-1) / Fp(value=2**24),
107+
],
108+
)
109+
110+
# Parameters for WIDTH = 24
111+
PARAMS_24 = Poseidon2Params(
112+
WIDTH=24,
113+
ROUNDS_F=8,
114+
ROUNDS_P=23,
115+
INTERNAL_DIAG_VECTORS=[
116+
Fp(value=-2),
117+
Fp(value=1),
118+
Fp(value=2),
119+
Fp(value=1) / Fp(value=2),
120+
Fp(value=3),
121+
Fp(value=4),
122+
Fp(value=-1) / Fp(value=2),
123+
Fp(value=-3),
124+
Fp(value=-4),
125+
Fp(value=1) / Fp(value=2**8),
126+
Fp(value=1) / Fp(value=4),
127+
Fp(value=1) / Fp(value=8),
128+
Fp(value=1) / Fp(value=16),
129+
Fp(value=1) / Fp(value=32),
130+
Fp(value=1) / Fp(value=64),
131+
Fp(value=1) / Fp(value=2**24),
132+
Fp(value=-1) / Fp(value=2**8),
133+
Fp(value=-1) / Fp(value=8),
134+
Fp(value=-1) / Fp(value=16),
135+
Fp(value=-1) / Fp(value=32),
136+
Fp(value=-1) / Fp(value=64),
137+
Fp(value=-1) / Fp(value=2**7),
138+
Fp(value=-1) / Fp(value=2**9),
139+
Fp(value=-1) / Fp(value=2**24),
140+
],
141+
)
142+
143+
# Base 4x4 matrix, used in the external linear layer.
144+
M4_MATRIX = [
145+
[Fp(value=2), Fp(value=3), Fp(value=1), Fp(value=1)],
146+
[Fp(value=1), Fp(value=2), Fp(value=3), Fp(value=1)],
147+
[Fp(value=1), Fp(value=1), Fp(value=2), Fp(value=3)],
148+
[Fp(value=3), Fp(value=1), Fp(value=1), Fp(value=2)],
149+
]
150+
151+
# =================================================================
152+
# Linear Layers
153+
# =================================================================
154+
155+
156+
def _apply_m4(chunk: List[Fp]) -> List[Fp]:
157+
"""
158+
Applies the 4x4 M4 MDS matrix to a 4-element chunk of the state.
159+
This is a helper function for the external linear layer.
160+
161+
Args:
162+
chunk: A list of 4 Fp elements.
163+
164+
Returns:
165+
The transformed 4-element chunk.
166+
"""
167+
# Initialize the result vector with zeros.
168+
result = [Fp(value=0)] * 4
169+
# Perform standard matrix-vector multiplication.
170+
for i in range(4):
171+
for j in range(4):
172+
result[i] += M4_MATRIX[i][j] * chunk[j]
173+
return result
174+
175+
176+
def external_linear_layer(state: List[Fp], width: int) -> List[Fp]:
177+
"""
178+
Applies the external linear layer (M_E).
179+
180+
This layer provides strong diffusion across the entire state and is used
181+
in the full rounds. For a state of size t=4k, it's constructed from the
182+
base M4 matrix to form a larger circulant-like matrix, which is efficient
183+
while ensuring that a change in any single element affects all other
184+
elements after application.
185+
186+
The process follows Appendix B of the paper.
187+
188+
Args:
189+
state: The current state vector.
190+
width: The width `t` of the state.
191+
192+
Returns:
193+
The state vector after applying the external linear layer.
194+
"""
195+
# Apply the M4 matrix to each 4-element chunk of the state.
196+
#
197+
# This provides strong local diffusion within each block.
198+
state_after_m4 = list(
199+
chain.from_iterable(
200+
_apply_m4(state[i : i + 4]) for i in range(0, width, 4)
201+
)
202+
)
203+
204+
# Apply the outer circulant structure for global diffusion.
205+
#
206+
# We precompute the four sums of elements at the same offset in each chunk.
207+
# For each k in 0..4:
208+
# sums[k] = state[k] + state[4 + k] + state[8 + k] + ... up to width
209+
sums = [
210+
sum((state_after_m4[j + k] for j in range(0, width, 4)), Fp(value=0))
211+
for k in range(4)
212+
]
213+
214+
# Add the corresponding sum to each element of the state.
215+
state_after_circulant = [
216+
s + sums[i % 4] for i, s in enumerate(state_after_m4)
217+
]
218+
219+
return state_after_circulant
220+
221+
222+
def internal_linear_layer(
223+
state: List[Fp], params: Poseidon2Params
224+
) -> List[Fp]:
225+
"""
226+
Applies the internal linear layer (M_I).
227+
228+
This layer is used during partial rounds and is optimized for speed. Its
229+
matrix is constructed as M_I = J + D, where J is the all-ones matrix and D
230+
is a diagonal matrix. This structure allows the matrix-vector product to be
231+
computed in O(t) time instead of O(t^2), as M_I * s = J*s + D*s.
232+
The term J*s is a vector where each element is the sum of
233+
all elements in s.
234+
235+
Args:
236+
state: The current state vector.
237+
params: The Poseidon2Params object containing the diagonal vectors.
238+
239+
Returns:
240+
The state vector after applying the internal linear layer.
241+
"""
242+
# Calculate the sum of all elements in the state vector.
243+
s_sum = sum(state, Fp(value=0))
244+
# For each element s_i, compute s_i' = d_i * s_i + sum(s).
245+
# This is the efficient computation of (J + D)s.
246+
new_state = [
247+
s * d + s_sum
248+
for s, d in zip(state, params.INTERNAL_DIAG_VECTORS, strict=False)
249+
]
250+
return new_state
251+
252+
253+
# =================================================================
254+
# Core Permutation
255+
# =================================================================
256+
257+
258+
def permute(state: List[Fp], params: Poseidon2Params) -> List[Fp]:
259+
"""
260+
Performs the full Poseidon2 permutation on the given state.
261+
262+
The permutation follows the structure:
263+
Initial Layer -> Full Rounds -> Partial Rounds -> Full Rounds
264+
265+
Args:
266+
state: A list of Fp elements representing the current state.
267+
params: The object defining the permutation's configuration.
268+
269+
Returns:
270+
The new state after applying the permutation.
271+
"""
272+
# Ensure the input state has the correct dimensions.
273+
if len(state) != params.WIDTH:
274+
raise ValueError(f"Input state must have length {params.WIDTH}")
275+
276+
# Generate the deterministic round constants for this parameter set.
277+
round_constants = _generate_round_constants(params)
278+
# The number of full rounds is split between the beginning and end.
279+
half_rounds_f = params.ROUNDS_F // 2
280+
# Initialize index for accessing the flat list of round constants.
281+
const_idx = 0
282+
283+
# 1. Initial Linear Layer
284+
#
285+
# Another linear layer is applied at the start to prevent certain algebraic
286+
# attacks by ensuring the permutation begins with a diffusion layer.
287+
state = external_linear_layer(list(state), params.WIDTH)
288+
289+
# 2. First Half of Full Rounds (R_F / 2)
290+
for _r in range(half_rounds_f):
291+
# Add round constants to the entire state.
292+
state = [
293+
s + round_constants[const_idx + i] for i, s in enumerate(state)
294+
]
295+
const_idx += params.WIDTH
296+
# Apply the S-box (x -> x^d) to the full state.
297+
state = [s**S_BOX_DEGREE for s in state]
298+
# Apply the external linear layer for diffusion.
299+
state = external_linear_layer(state, params.WIDTH)
300+
301+
# 3. Partial Rounds (R_P)
302+
for _r in range(params.ROUNDS_P):
303+
# Add a single round constant to the first state element.
304+
state[0] += round_constants[const_idx]
305+
const_idx += 1
306+
# Apply the S-box to the first state element only.
307+
#
308+
# This is the main optimization of the Hades design.
309+
state[0] = state[0] ** S_BOX_DEGREE
310+
# Apply the internal linear layer.
311+
state = internal_linear_layer(state, params)
312+
313+
# 4. Second Half of Full Rounds (R_F / 2)
314+
for _r in range(half_rounds_f):
315+
# Add round constants to the entire state.
316+
state = [
317+
s + round_constants[const_idx + i] for i, s in enumerate(state)
318+
]
319+
const_idx += params.WIDTH
320+
# Apply the S-box to the full state.
321+
state = [s**S_BOX_DEGREE for s in state]
322+
# Apply the external linear layer for diffusion.
323+
state = external_linear_layer(state, params.WIDTH)
324+
325+
return state

0 commit comments

Comments
 (0)