Skip to content

Commit 9c6308f

Browse files
committed
instruction: fix poseidon instructions
1 parent 178b76e commit 9c6308f

File tree

3 files changed

+169
-89
lines changed

3 files changed

+169
-89
lines changed
Lines changed: 68 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
use p3_field::BasedVectorSpace;
12
use p3_symmetric::Permutation;
23

34
use crate::{
5+
bytecode::operand::{MemOrConstant, MemOrFp},
46
constant::{DIMENSION, F},
57
context::run_context::RunContext,
68
errors::vm::VirtualMachineError,
@@ -11,21 +13,21 @@ use crate::{
1113
/// Poseidon2 permutation over 16 field elements (2 inputs, 2 outputs).
1214
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
1315
pub struct Poseidon2_16Instruction {
14-
/// The starting offset `s` from `fp`. The instruction reads 4 pointers from `m[fp+s]` to `m[fp+s+3]`.
15-
pub shift: usize,
16+
/// A pointer to the first 8-element input vector.
17+
pub arg_a: MemOrConstant,
18+
/// A pointer to the second 8-element input vector.
19+
pub arg_b: MemOrConstant,
20+
/// A pointer to the location for the two 8-element output vectors.
21+
pub res: MemOrFp,
1622
}
1723

1824
impl Poseidon2_16Instruction {
1925
/// Executes the `Poseidon2_16` precompile instruction.
2026
///
21-
/// Reads four pointers from memory starting at `fp + shift`, representing:
22-
/// - two input vector addresses (`ptr_arg_0`, `ptr_arg_1`)
23-
/// - two output vector addresses (`ptr_res_0`, `ptr_res_1`)
24-
///
25-
/// Each input is an 8-element vector of `F`. The two inputs are concatenated,
26-
/// permuted using Poseidon2, and written back to the output locations.
27-
///
28-
/// The operation is: `Poseidon2(m_vec[ptr_0], m_vec[ptr_1]) -> (m_vec[ptr_2], m_vec[ptr_3])`
27+
/// This function resolves pointers from its operands to find the memory locations for
28+
/// two 8-element input vectors and two 8-element output vectors. It reads the inputs,
29+
/// concatenates them, applies the permutation, and writes the two resulting vectors
30+
/// back to their designated output locations.
2931
pub fn execute<Perm>(
3032
&self,
3133
run_context: &RunContext,
@@ -35,42 +37,50 @@ impl Poseidon2_16Instruction {
3537
where
3638
Perm: Permutation<[F; 2 * DIMENSION]>,
3739
{
38-
// Read Pointers from Memory
40+
// Pointer Resolution
3941
//
40-
// The instruction specifies 4 consecutive pointers starting at `fp + shift`.
41-
let base_ptr_addr = (run_context.fp + self.shift)?;
42-
let ptrs: [MemoryAddress; 4] = memory_manager.memory.get_array_as(base_ptr_addr)?;
43-
44-
// Convert the `MemoryValue` pointers to `MemoryAddress`.
45-
let [ptr_arg_0, ptr_arg_1, ptr_res_0, ptr_res_1] = ptrs;
42+
// Resolve the pointer to the first input vector from the `arg_a` operand.
43+
let ptr_arg_a: MemoryAddress = run_context
44+
.value_from_mem_or_constant(&self.arg_a, memory_manager)?
45+
.try_into()?;
46+
// Resolve the pointer to the second input vector from the `arg_b` operand.
47+
let ptr_arg_b: MemoryAddress = run_context
48+
.value_from_mem_or_constant(&self.arg_b, memory_manager)?
49+
.try_into()?;
50+
// Resolve the pointer for the start of the output block from the `res` operand.
51+
let ptr_res: MemoryAddress = run_context
52+
.value_from_mem_or_fp(&self.res, memory_manager)?
53+
.try_into()?;
54+
// The second output vector will be stored immediately after the first one.
55+
let ptr_res_b = (ptr_res + 1)?;
4656

47-
// Read Input Vectors
57+
// Data Reading
4858
//
49-
// Read the 8-element vectors from the locations pointed to by `ptr_arg_0` and `ptr_arg_1`.
59+
// Read the first 8-element input vector from memory.
5060
let arg0 = memory_manager
5161
.memory
52-
.get_array_as::<F, DIMENSION>(ptr_arg_0)?;
62+
.get_array_as::<F, DIMENSION>(ptr_arg_a)?;
63+
// Read the second 8-element input vector from memory.
5364
let arg1 = memory_manager
5465
.memory
55-
.get_array_as::<F, DIMENSION>(ptr_arg_1)?;
66+
.get_array_as::<F, DIMENSION>(ptr_arg_b)?;
5667

57-
// Perform Hashing
68+
// Hashing
5869
//
5970
// Concatenate the two input vectors into a single 16-element array for the permutation.
6071
let mut state = [arg0, arg1].concat().try_into().unwrap();
61-
6272
// Apply the Poseidon2 permutation to the state.
6373
perm.permute_mut(&mut state);
6474

65-
// Write Output Vectors
75+
// Data Writing
6676
//
6777
// Split the permuted state back into two 8-element output vectors.
6878
let res0: [F; DIMENSION] = state[..DIMENSION].try_into().unwrap();
6979
let res1: [F; DIMENSION] = state[DIMENSION..].try_into().unwrap();
70-
71-
// Write the output vectors to the memory locations pointed to by `ptr_res_0` and `ptr_res_1`.
72-
memory_manager.load_data(ptr_res_0, &res0)?;
73-
memory_manager.load_data(ptr_res_1, &res1)?;
80+
// Write the first output vector to its memory location.
81+
memory_manager.load_data(ptr_res, &res0)?;
82+
// Write the second output vector to its memory location.
83+
memory_manager.load_data(ptr_res_b, &res1)?;
7484

7585
Ok(())
7686
}
@@ -82,36 +92,51 @@ impl Poseidon2_16Instruction {
8292
run_context: &RunContext,
8393
memory_manager: &MemoryManager,
8494
) -> Result<WitnessPoseidon16, VirtualMachineError> {
85-
// Read the four pointers (input_a, input_b, output_a, output_b) from memory.
86-
let base_ptr_addr = (run_context.fp + self.shift)?;
87-
let [addr_input_a, addr_input_b, addr_output_a, addr_output_b]: [MemoryAddress; 4] =
88-
memory_manager.memory.get_array_as(base_ptr_addr)?;
95+
// Pointer Resolution
96+
//
97+
// Resolve the pointer to the first input vector from the `arg_a` operand.
98+
let addr_input_a: MemoryAddress = run_context
99+
.value_from_mem_or_constant(&self.arg_a, memory_manager)?
100+
.try_into()?;
101+
// Resolve the pointer to the second input vector from the `arg_b` operand.
102+
let addr_input_b: MemoryAddress = run_context
103+
.value_from_mem_or_constant(&self.arg_b, memory_manager)?
104+
.try_into()?;
105+
// Resolve the pointer for the start of the output block from the `res` operand.
106+
let addr_output: MemoryAddress = run_context
107+
.value_from_mem_or_fp(&self.res, memory_manager)?
108+
.try_into()?;
89109

90-
// Read the two 8-element input vectors from their respective addresses.
110+
// Data Reading
111+
//
112+
// Read the first 8-element input vector from its respective address.
91113
let value_a = memory_manager
92114
.memory
93115
.get_array_as::<F, DIMENSION>(addr_input_a)?;
116+
// Read the second 8-element input vector from its respective address.
94117
let value_b = memory_manager
95118
.memory
96119
.get_array_as::<F, DIMENSION>(addr_input_b)?;
97-
98-
// Read the two 8-element output vectors.
99-
let output_a = memory_manager
120+
// Read the full 16-element output from memory, starting at the output address.
121+
let output = memory_manager
100122
.memory
101-
.get_array_as::<F, DIMENSION>(addr_output_a)?;
102-
let output_b = memory_manager
103-
.memory
104-
.get_array_as::<F, DIMENSION>(addr_output_b)?;
123+
.get_vectorized_slice_extension(addr_output, 2)?;
124+
let output_coeffs: Vec<F> = output
125+
.iter()
126+
.flat_map(BasedVectorSpace::as_basis_coefficients_slice)
127+
.copied()
128+
.collect();
105129

130+
// Witness Construction
131+
//
106132
// Construct and return the witness struct.
107133
Ok(WitnessPoseidon16 {
108134
cycle: Some(cycle),
109135
addr_input_a,
110136
addr_input_b,
111-
// The output address is the start of the first output vector.
112-
addr_output: addr_output_a,
137+
addr_output,
113138
input: [value_a, value_b].concat().try_into().unwrap(),
114-
output: [output_a, output_b].concat().try_into().unwrap(),
139+
output: output_coeffs.try_into().unwrap(),
115140
})
116141
}
117142
}
Lines changed: 98 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,33 @@
1+
use p3_field::BasedVectorSpace;
12
use p3_symmetric::Permutation;
23

34
use crate::{
5+
bytecode::operand::{MemOrConstant, MemOrFp},
46
constant::{DIMENSION, F},
57
context::run_context::RunContext,
68
errors::vm::VirtualMachineError,
79
memory::{address::MemoryAddress, manager::MemoryManager},
10+
witness::poseidon::WitnessPoseidon24,
811
};
912

1013
/// Poseidon2 permutation over 24 field elements (3 inputs, 3 outputs).
1114
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
1215
pub struct Poseidon2_24Instruction {
13-
/// The starting offset from `fp`. The instruction reads 6 pointers from `m[fp+shift]` to `m[fp+shift+5]`.
14-
pub shift: usize,
16+
/// A pointer to the first two 8-element input vectors.
17+
pub arg_a: MemOrConstant,
18+
/// A pointer to the third 8-element input vector.
19+
pub arg_b: MemOrConstant,
20+
/// A pointer to the location of the third 8-element output vector.
21+
pub res: MemOrFp,
1522
}
1623

1724
impl Poseidon2_24Instruction {
1825
/// Executes the `Poseidon2_24` precompile instruction.
1926
///
20-
/// Reads six pointers from memory starting at `fp + shift`, representing:
21-
/// - three input vector addresses (`ptr_arg_0`, `ptr_arg_1`, `ptr_arg_2`)
22-
/// - three output vector addresses (`ptr_res_0`, `ptr_res_1`, `ptr_res_2`)
23-
///
24-
/// Each input is an 8-element vector of `F`. The three inputs are concatenated into
25-
/// a single 24-element state, permuted using Poseidon2, and written back to the three
26-
/// output locations as three 8-element vectors.
27-
///
28-
/// The operation is:
29-
/// `Poseidon2(m_vec[ptr_0], m_vec[ptr_1], m_vec[ptr_2]) -> (m_vec[ptr_3], m_vec[ptr_4], m_vec[ptr_5])`
27+
/// This function resolves pointers from its operands to find the memory locations for
28+
/// three 8-element input vectors and one 8-element output vector. It reads the inputs,
29+
/// concatenates them, applies the permutation, and writes the third resulting vector
30+
/// back to its designated output location.
3031
pub fn execute<Perm>(
3132
&self,
3233
run_context: &RunContext,
@@ -36,55 +37,109 @@ impl Poseidon2_24Instruction {
3637
where
3738
Perm: Permutation<[F; 3 * DIMENSION]>,
3839
{
39-
// Read Pointers from Memory
40+
// Pointer Resolution
4041
//
41-
// The instruction specifies 6 consecutive pointers starting at `fp + shift`.
42-
let ptr_addr = (run_context.fp + self.shift)?;
43-
let ptrs: [MemoryAddress; 6] = memory_manager.memory.get_array_as(ptr_addr)?;
42+
// Resolve the pointer to the first block of two input vectors from the `arg_a` operand.
43+
let ptr_arg_a: MemoryAddress = run_context
44+
.value_from_mem_or_constant(&self.arg_a, memory_manager)?
45+
.try_into()?;
46+
// The second input vector is located immediately after the first one.
47+
let ptr_arg_b = (ptr_arg_a + 1)?;
48+
// Resolve the pointer to the third input vector from the `arg_b` operand.
49+
let ptr_arg_c: MemoryAddress = run_context
50+
.value_from_mem_or_constant(&self.arg_b, memory_manager)?
51+
.try_into()?;
52+
// Resolve the pointer to the third output vector from the `res` operand.
53+
let ptr_res: MemoryAddress = run_context
54+
.value_from_mem_or_fp(&self.res, memory_manager)?
55+
.try_into()?;
4456

45-
// Convert the raw memory values into memory addresses.
46-
let [
47-
ptr_arg_0,
48-
ptr_arg_1,
49-
ptr_arg_2,
50-
ptr_res_0,
51-
ptr_res_1,
52-
ptr_res_2,
53-
] = ptrs;
54-
55-
// Read Input Vectors
57+
// Data Reading
5658
//
57-
// Each is an 8-element array of field elements.
59+
// Read the first 8-element input vector from memory.
5860
let arg0 = memory_manager
5961
.memory
60-
.get_array_as::<F, DIMENSION>(ptr_arg_0)?;
62+
.get_array_as::<F, DIMENSION>(ptr_arg_a)?;
63+
// Read the second 8-element input vector from memory.
6164
let arg1 = memory_manager
6265
.memory
63-
.get_array_as::<F, DIMENSION>(ptr_arg_1)?;
66+
.get_array_as::<F, DIMENSION>(ptr_arg_b)?;
67+
// Read the third 8-element input vector from memory.
6468
let arg2 = memory_manager
6569
.memory
66-
.get_array_as::<F, DIMENSION>(ptr_arg_2)?;
70+
.get_array_as::<F, DIMENSION>(ptr_arg_c)?;
6771

68-
// Perform Hashing
72+
// Hashing
6973
//
7074
// Concatenate the three input vectors into a single 24-element array for the permutation.
7175
let mut state = [arg0, arg1, arg2].concat().try_into().unwrap();
72-
7376
// Apply the Poseidon2 permutation to the state.
7477
perm.permute_mut(&mut state);
7578

76-
// Write Output Vectors
79+
// Data Writing
7780
//
78-
// Split the permuted state back into three 8-element output vectors.
79-
let res0: [F; DIMENSION] = state[..DIMENSION].try_into().unwrap();
80-
let res1: [F; DIMENSION] = state[DIMENSION..2 * DIMENSION].try_into().unwrap();
81-
let res2: [F; DIMENSION] = state[2 * DIMENSION..].try_into().unwrap();
82-
83-
// Write the output vectors to the memory locations pointed to by the result pointers.
84-
memory_manager.load_data(ptr_res_0, &res0)?;
85-
memory_manager.load_data(ptr_res_1, &res1)?;
86-
memory_manager.load_data(ptr_res_2, &res2)?;
81+
// Extract the last 8 elements of the permuted state, which is the result.
82+
let res: [F; DIMENSION] = state[2 * DIMENSION..].try_into().unwrap();
83+
// Write the result vector to the memory location pointed to by the result pointer.
84+
memory_manager.load_data(ptr_res, &res)?;
8785

8886
Ok(())
8987
}
88+
89+
/// Generates the witness for a `Poseidon2_24` instruction execution.
90+
pub fn generate_witness(
91+
&self,
92+
cycle: usize,
93+
run_context: &RunContext,
94+
memory_manager: &MemoryManager,
95+
) -> Result<WitnessPoseidon24, VirtualMachineError> {
96+
// Pointer Resolution
97+
//
98+
// Resolve the pointer to the first block of two input vectors from the `arg_a` operand.
99+
let addr_input_a: MemoryAddress = run_context
100+
.value_from_mem_or_constant(&self.arg_a, memory_manager)?
101+
.try_into()?;
102+
// Resolve the pointer to the third input vector from the `arg_b` operand.
103+
let addr_input_b: MemoryAddress = run_context
104+
.value_from_mem_or_constant(&self.arg_b, memory_manager)?
105+
.try_into()?;
106+
// Resolve the pointer to the third output vector from the `res` operand.
107+
let addr_output: MemoryAddress = run_context
108+
.value_from_mem_or_fp(&self.res, memory_manager)?
109+
.try_into()?;
110+
111+
// Data Reading
112+
//
113+
// Read the first two input vectors (a slice of 2 EF elements, total 16 F elements) from memory.
114+
let value_a = memory_manager
115+
.memory
116+
.get_vectorized_slice_extension(addr_input_a, 2)?;
117+
// Read the third input vector (a single EF element, total 8 F elements) from memory.
118+
let value_b = memory_manager.memory.get_extension(addr_input_b)?;
119+
// Read the third output vector (a single EF element, total 8 F elements) from memory.
120+
let output = memory_manager.memory.get_extension(addr_output)?;
121+
122+
// Witness Construction
123+
//
124+
// Get the F coefficients from the first two EF elements (value_a).
125+
let input_part1: Vec<F> = value_a
126+
.iter()
127+
.flat_map(BasedVectorSpace::as_basis_coefficients_slice)
128+
.copied()
129+
.collect();
130+
// Get the F coefficients from the third EF element (value_b).
131+
let input_part2: Vec<F> = value_b.as_basis_coefficients_slice().to_vec();
132+
// Concatenate the parts to form the full 24-element input for the witness.
133+
let full_input = [input_part1, input_part2].concat().try_into().unwrap();
134+
135+
// Construct and return the final witness struct.
136+
Ok(WitnessPoseidon24 {
137+
cycle: Some(cycle),
138+
addr_input_a,
139+
addr_input_b,
140+
addr_output,
141+
input: full_input,
142+
output: output.as_basis_coefficients_slice().try_into().unwrap(),
143+
})
144+
}
90145
}

crates/leanVm/src/witness/poseidon.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,11 @@ pub struct WitnessPoseidon24 {
2323
/// The CPU cycle at which this operation is initiated, if applicable.
2424
pub cycle: Option<usize>,
2525
/// The memory address (vectorized pointer, of size 2) of the first two 8-element input vectors.
26-
pub addr_input_a: usize,
26+
pub addr_input_a: MemoryAddress,
2727
/// The memory address (vectorized pointer, of size 1) of the third 8-element input vector.
28-
pub addr_input_b: usize,
28+
pub addr_input_b: MemoryAddress,
2929
/// The memory address (vectorized pointer, of size 1) where the relevant 8-element output vector is stored.
30-
pub addr_output: usize,
30+
pub addr_output: MemoryAddress,
3131
/// The full 24-element input state for the permutation.
3232
pub input: [F; 24],
3333
/// The last 8 elements of the 24-element output state from the permutation.

0 commit comments

Comments
 (0)