Skip to content

Commit 786e2ec

Browse files
committed
air: add dot product air
1 parent 2d9ec8b commit 786e2ec

File tree

6 files changed

+300
-0
lines changed

6 files changed

+300
-0
lines changed
Lines changed: 266 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,266 @@
1+
use std::borrow::Borrow;
2+
3+
use p3_air::{Air, AirBuilder, BaseAir};
4+
use p3_field::PrimeCharacteristicRing;
5+
use p3_matrix::Matrix;
6+
7+
use crate::{
8+
constant::{DOT_PRODUCT_AIR_COLUMNS, EF},
9+
witness::dot_product::WitnessDotProduct,
10+
};
11+
12+
/// Dot Product AIR
13+
///
14+
/// ## Trace Layout
15+
///
16+
/// Each dot product is computed recursively, row by row, from the last element to the first.
17+
/// The `start_flag` column marks the beginning of a new dot product computation (the first row
18+
/// in its sequence). The `len` column acts as a counter, decreasing with each step.
19+
///
20+
/// An example trace for two dot products might look like this:
21+
///
22+
/// | start_flag | len | addr_a | addr_b | addr_res | val_a | val_b | res | computation |
23+
/// |:----------:|:---:|:------:|:------:|:--------:|:------:|:------:|:------:|:-------------:|
24+
/// | 1 | 4 | 90 | 211 | 74 | m[90] | m[211] | m[74] | v_a*v_b + C_next |
25+
/// | 0 | 3 | 91 | 212 | 74 | m[91] | m[212] | m[74] | v_a*v_b + C_next |
26+
/// | 0 | 2 | 92 | 213 | 74 | m[92] | m[213] | m[74] | v_a*v_b + C_next |
27+
/// | 0 | 1 | 93 | 214 | 74 | m[93] | m[214] | m[74] | v_a*v_b |
28+
/// | 1 | 10 | 1008 | 854 | 325 | m[1008]| m[854] | m[325] | v_a*v_b + C_next |
29+
/// | ... | ... | ... | ... | ... | ... | ... | ... | ... |
30+
31+
#[derive(Debug, Default)]
32+
pub struct DotProductAir;
33+
34+
impl<F> BaseAir<F> for DotProductAir {
35+
fn width(&self) -> usize {
36+
DOT_PRODUCT_AIR_COLUMNS
37+
}
38+
}
39+
40+
impl<AB: AirBuilder> Air<AB> for DotProductAir {
41+
#[inline]
42+
fn eval(&self, builder: &mut AB) {
43+
// Get a view of the main execution trace.
44+
let main = builder.main();
45+
46+
// Get the current row (`local`) and the next row (`next`) from the trace.
47+
let local = main.row_slice(0).unwrap();
48+
let local = local.borrow();
49+
50+
let next = main.row_slice(1).unwrap();
51+
let next = next.borrow();
52+
53+
// Destructure the local row into named variables for clarity.
54+
let [
55+
start_flag_local,
56+
len_local,
57+
addr_a_local,
58+
addr_b_local,
59+
_addr_res_local,
60+
val_a_local,
61+
val_b_local,
62+
res_local,
63+
computation_local,
64+
]: [AB::Expr; DOT_PRODUCT_AIR_COLUMNS] = local
65+
.iter()
66+
.map(|v| v.clone().into())
67+
.collect::<Vec<_>>()
68+
.try_into()
69+
.unwrap();
70+
71+
// Destructure the next row into named variables.
72+
let [
73+
start_flag_next,
74+
len_next,
75+
addr_a_next,
76+
addr_b_next,
77+
_addr_res_next,
78+
_val_a_next,
79+
_val_b_next,
80+
_res_next,
81+
computation_next,
82+
]: [AB::Expr; DOT_PRODUCT_AIR_COLUMNS] = next
83+
.iter()
84+
.map(|v| v.clone().into())
85+
.collect::<Vec<_>>()
86+
.try_into()
87+
.unwrap();
88+
89+
// TRANSITION CONSTRAINTS
90+
91+
// This constraint ensures that the `start_flag` is always boolean.
92+
//
93+
// It's checked on the `next` row, as the last row of the trace will have a dummy next row.
94+
builder.assert_bool(start_flag_next.clone());
95+
96+
// This is the core recursive constraint for the dot product.
97+
//
98+
// `computation_local` = `val_a * val_b` + `computation_next` (if continuing a product)
99+
//
100+
// If the next row starts a new dot product (`start_flag_next`=1), `computation_next` is ignored.
101+
let product_local = val_a_local * val_b_local;
102+
let not_start_flag_next = AB::Expr::ONE - start_flag_next.clone();
103+
builder.assert_eq(
104+
computation_local.clone(),
105+
start_flag_next.clone() * product_local.clone()
106+
+ not_start_flag_next.clone() * (product_local + computation_next),
107+
);
108+
109+
// When not starting a new product, the length must decrement by 1.
110+
// `(1 - start_flag_next) * (len_local - (len_next + 1)) = 0`
111+
builder.assert_zero(
112+
not_start_flag_next.clone() * (len_local.clone() - (len_next + AB::Expr::ONE)),
113+
);
114+
115+
// If the remaining length is 1, the next row must start a new product (`start_flag_next` = 1).
116+
//
117+
// This is enforced by `(len_local - 1) * (1 - start_flag_next) = 0`.
118+
builder.assert_zero((len_local - AB::Expr::ONE) * (AB::Expr::ONE - start_flag_next));
119+
120+
// When not starting a new product, address `a` must increment by 1.
121+
// `(1 - start_flag_next) * (addr_a_next - (addr_a_local + 1)) = 0`
122+
builder.assert_zero(
123+
not_start_flag_next.clone() * (addr_a_next - (addr_a_local + AB::Expr::ONE)),
124+
);
125+
126+
// When not starting a new product, address `b` must increment by 1.
127+
// `(1 - start_flag_next) * (addr_b_next - (addr_b_local + 1)) = 0`
128+
builder.assert_zero(not_start_flag_next * (addr_b_next - (addr_b_local + AB::Expr::ONE)));
129+
130+
// If this is the first row of a dot product (`start_flag_local` = 1), the accumulated
131+
// `computation_local` must equal the final result `res_local`.
132+
builder.assert_zero(start_flag_local * (computation_local - res_local));
133+
}
134+
}
135+
136+
/// ## Build Dot Product Columns
137+
///
138+
/// This function constructs the execution trace (witness) for the Dot Product AIR.
139+
/// It takes a high-level description of dot product operations and expands it into the
140+
/// row-by-row format required by the AIR constraints.
141+
///
142+
/// ### Arguments
143+
/// * `witness`: A slice of `WitnessDotProduct` structs, each describing one dot product.
144+
///
145+
/// ### Returns
146+
/// A tuple containing:
147+
/// * A vector of columns representing the complete, padded execution trace.
148+
/// * The number of padding rows that were added.
149+
pub fn build_dot_product_columns(witness: &[WitnessDotProduct]) -> (Vec<Vec<EF>>, usize) {
150+
// Initialize vectors for each column of the trace.
151+
//
152+
// These will be populated and returned.
153+
let (
154+
mut flag,
155+
mut len,
156+
mut index_a,
157+
mut index_b,
158+
mut index_res,
159+
mut value_a,
160+
mut value_b,
161+
mut res,
162+
mut computation,
163+
) = (
164+
Vec::new(),
165+
Vec::new(),
166+
Vec::new(),
167+
Vec::new(),
168+
Vec::new(),
169+
Vec::new(),
170+
Vec::new(),
171+
Vec::new(),
172+
Vec::new(),
173+
);
174+
175+
// Process each high-level dot product operation from the witness.
176+
for dot_product in witness {
177+
// A dot product must have at least one term.
178+
assert!(dot_product.len > 0, "Dot product length must be positive.");
179+
180+
// Build the `computation` column
181+
//
182+
// This is the most complex column, representing the recursive accumulation.
183+
// We build it backwards, from the last term to the first.
184+
let mut current_computation = vec![EF::ZERO; dot_product.len];
185+
let last_idx = dot_product.len - 1;
186+
187+
// Base case: The computation for the last term is just the product of the last elements.
188+
current_computation[last_idx] =
189+
dot_product.slice_0[last_idx] * dot_product.slice_1[last_idx];
190+
191+
// Recursive step: Iterate backwards from the second-to-last term.
192+
for i in (0..last_idx).rev() {
193+
// The computation at step `i` is the product of elements at `i` plus the computation from step `i+1`.
194+
current_computation[i] =
195+
current_computation[i + 1] + dot_product.slice_0[i] * dot_product.slice_1[i];
196+
}
197+
// Add the fully computed trace for this dot product to the main computation column.
198+
computation.extend(current_computation);
199+
200+
// Build the other columns for the current dot product
201+
202+
// The `flag` column is:
203+
// - 1 for the first row and
204+
// - 0 for all subsequent rows of this operation.
205+
flag.push(EF::ONE);
206+
flag.extend(vec![EF::ZERO; dot_product.len - 1]);
207+
208+
// The `len` column is a countdown from the total length to 1.
209+
len.extend((1..=dot_product.len).rev().map(EF::from_usize));
210+
211+
// The `index_a` and `index_b` columns are the memory addresses, incrementing from the start.
212+
index_a.extend(
213+
(dot_product.addr_0..(dot_product.addr_0 + dot_product.len)).map(EF::from_usize),
214+
);
215+
index_b.extend(
216+
(dot_product.addr_1..(dot_product.addr_1 + dot_product.len)).map(EF::from_usize),
217+
);
218+
219+
// The `index_res` column holds the constant result address, repeated for every row.
220+
index_res.extend(vec![EF::from_usize(dot_product.addr_res); dot_product.len]);
221+
222+
// The `value_a` and `value_b` columns are direct copies of the input slices.
223+
value_a.extend_from_slice(&dot_product.slice_0);
224+
value_b.extend_from_slice(&dot_product.slice_1);
225+
226+
// The `res` column holds the final dot product result, repeated for every row.
227+
res.extend(vec![dot_product.res; dot_product.len]);
228+
}
229+
230+
// Pad the trace to a power-of-two length
231+
//
232+
// This is required for efficient polynomial commitment schemes (e.g., using FFTs).
233+
let padding_len = flag.len().next_power_of_two() - flag.len();
234+
235+
// If there is padding, add it to the trace
236+
if padding_len > 0 {
237+
// Use `start_flag=1` and `len=1` for padding rows. This is a simple state that
238+
// trivially satisfies the transition constraints when the other values are zero.
239+
flag.extend(vec![EF::ONE; padding_len]);
240+
len.extend(vec![EF::ONE; padding_len]);
241+
// The rest of the padding values can be zero.
242+
index_a.extend(vec![EF::ZERO; padding_len]);
243+
index_b.extend(vec![EF::ZERO; padding_len]);
244+
index_res.extend(vec![EF::ZERO; padding_len]);
245+
value_a.extend(vec![EF::ZERO; padding_len]);
246+
value_b.extend(vec![EF::ZERO; padding_len]);
247+
res.extend(vec![EF::ZERO; padding_len]);
248+
computation.extend(vec![EF::ZERO; padding_len]);
249+
}
250+
251+
// Return the completed columns and the amount of padding added.
252+
(
253+
vec![
254+
flag,
255+
len,
256+
index_a,
257+
index_b,
258+
index_res,
259+
value_a,
260+
value_b,
261+
res,
262+
computation,
263+
],
264+
padding_len,
265+
)
266+
}

crates/leanVm/src/air/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
pub mod constant;
2+
pub mod dot_product;
23
pub mod vm;

crates/leanVm/src/constant.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::ops::Range;
2+
13
use p3_field::extension::BinomialExtensionField;
24
use p3_koala_bear::KoalaBear;
35

@@ -37,3 +39,10 @@ pub const PUBLIC_INPUT_START: MemoryAddress = MemoryAddress::new(PUBLIC_DATA_SEG
3739

3840
/// The maximum size of the memory.
3941
pub const MAX_MEMORY_SIZE: usize = 1 << 23;
42+
43+
// Dot product constants
44+
45+
/// The total number of columns in the Dot Product AIR.
46+
pub(crate) const DOT_PRODUCT_AIR_COLUMNS: usize = 9;
47+
/// Defines column groups for processing.
48+
pub(crate) const DOT_PRODUCT_AIR_COLUMN_GROUPS: [Range<usize>; 5] = [0..1, 1..2, 2..5, 5..8, 8..9];

crates/leanVm/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@ pub mod context;
55
pub mod core;
66
pub mod errors;
77
pub mod memory;
8+
pub mod witness;
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
use crate::constant::EF;
2+
3+
/// Holds the high-level witness data for a single dot product precompile execution.
4+
#[derive(Debug)]
5+
pub struct WitnessDotProduct {
6+
/// The CPU cycle at which the dot product operation is initiated.
7+
pub cycle: usize,
8+
/// The starting memory address (vectorized pointer) of the first input slice.
9+
pub addr_0: usize,
10+
/// The starting memory address (vectorized pointer) of the second input slice.
11+
pub addr_1: usize,
12+
/// The memory address (vectorized pointer) where the final result is stored.
13+
pub addr_res: usize,
14+
/// The number of elements in each input slice.
15+
pub len: usize,
16+
/// The actual data values of the first input slice.
17+
pub slice_0: Vec<EF>,
18+
/// The actual data values of the second input slice.
19+
pub slice_1: Vec<EF>,
20+
/// The final computed result of the dot product.
21+
pub res: EF,
22+
}

crates/leanVm/src/witness/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
pub mod dot_product;

0 commit comments

Comments
 (0)