Skip to content

Commit 5f60cc7

Browse files
committed
signoff: Inner product of three rows with weights
1 parent 9605fc4 commit 5f60cc7

File tree

1 file changed

+66
-56
lines changed

1 file changed

+66
-56
lines changed

tasm-lib/src/array/inner_product_of_three_rows_with_weights.rs

Lines changed: 66 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,33 @@
11
use arbitrary::Arbitrary;
22
use strum::Display;
3-
use triton_vm::prelude::triton_asm;
3+
use triton_vm::prelude::*;
44
use triton_vm::table::master_table::MasterAuxTable;
55
use triton_vm::table::master_table::MasterMainTable;
66
use triton_vm::table::master_table::MasterTable;
77

88
use crate::data_type::ArrayType;
99
use crate::prelude::*;
1010

11+
/// The type of field element used in
12+
/// [`InnerProductOfThreeRowsWithWeights`].
1113
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Display, Arbitrary)]
1214
pub enum MainElementType {
15+
/// Corresponds to [`BFieldElement`].
1316
Bfe,
17+
18+
/// Corresponds to [`XFieldElement`].
1419
Xfe,
1520
}
1621

22+
impl MainElementType {
23+
fn dot_step(&self) -> LabelledInstruction {
24+
match self {
25+
Self::Bfe => triton_instr!(xb_dot_step),
26+
Self::Xfe => triton_instr!(xx_dot_step),
27+
}
28+
}
29+
}
30+
1731
impl From<MainElementType> for DataType {
1832
fn from(value: MainElementType) -> Self {
1933
match value {
@@ -26,9 +40,25 @@ impl From<MainElementType> for DataType {
2640
/// Calculate inner products of Triton VM
2741
/// [execution trace](triton_vm::table::master_table) rows with weights.
2842
///
29-
/// Calculate inner product of both main columns and auxiliary columns with weights. Returns one
30-
/// scalar in the form of an auxiliary-field element. Main column can be either a base field
31-
/// element, or an auxiliary-field element.
43+
/// Calculate inner product of both main columns and auxiliary columns with
44+
/// weights. Returns one scalar in the form of an auxiliary-field element.
45+
/// The main column can be either a base field element, or an auxiliary-field
46+
/// element; see also [`MainElementType`].
47+
///
48+
/// ### Behavior
49+
///
50+
/// ```text
51+
/// BEFORE: _
52+
/// AFTER: _
53+
/// ```
54+
///
55+
/// ### Preconditions
56+
///
57+
/// None.
58+
///
59+
/// ### Postconditions
60+
///
61+
/// None.
3262
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
3363
pub struct InnerProductOfThreeRowsWithWeights {
3464
main_length: usize,
@@ -82,43 +112,32 @@ impl BasicSnippet for InnerProductOfThreeRowsWithWeights {
82112
format!("tasmlib_array_inner_product_of_three_rows_with_weights_{element_ty}_mainrowelem")
83113
}
84114

85-
fn code(
86-
&self,
87-
_library: &mut crate::library::Library,
88-
) -> Vec<triton_vm::prelude::LabelledInstruction> {
89-
let entrypoint = self.entrypoint();
90-
let acc_all_main_rows = match self.main_element_type {
91-
MainElementType::Bfe => triton_asm![xb_dot_step; self.main_length],
92-
MainElementType::Xfe => triton_asm![xx_dot_step; self.main_length],
93-
};
94-
let acc_all_aux_rows = triton_asm![xx_dot_step; self.aux_length];
95-
115+
fn code(&self, _: &mut Library) -> Vec<LabelledInstruction> {
96116
triton_asm! {
97117
// BEFORE: _ *aux_row *main_row *weights
98118
// AFTER: _ [inner_product; 3]
99-
{entrypoint}:
119+
{self.entrypoint()}:
100120
push 0
101121
push 0
102122
push 0
103-
// _ *aux_row *main_row *weights 0 0 0
123+
// _ *aux_row *main_row *weights [0: XFE]
104124

105125
pick 3
106126
pick 4
107-
// _ *aux_row 0 0 0 *weights *main_row
127+
// _ *aux_row [0: XFE] *weights *main_row
108128

109-
{&acc_all_main_rows}
110-
// _ *aux_row acc2 acc1 acc0 *weights_next garbage
129+
{&vec![self.main_element_type.dot_step(); self.main_length]}
130+
// _ *aux_row [acc: XFE] *weights_next garbage
111131

112132
pop 1
113133
pick 4
114-
// _ acc2 acc1 acc0 *weights_next *aux_row
134+
// _ [acc: XFE] *weights_next *aux_row
115135

116-
{&acc_all_aux_rows}
117-
// _ acc2 acc1 acc0 garbage garbage
136+
{&triton_asm![xx_dot_step; self.aux_length]}
137+
// _ [acc: XFE] garbage garbage
118138

119139
pop 2
120-
// _ result2 result1 result0
121-
// _ [result; 3]
140+
// _ [result: XFE]
122141

123142
return
124143
}
@@ -138,25 +157,21 @@ mod tests {
138157
fn three_rows_tvm_parameters_xfe_main_test() {
139158
let snippet =
140159
InnerProductOfThreeRowsWithWeights::triton_vm_parameters(MainElementType::Xfe);
141-
ShadowedFunction::new(snippet).test()
160+
ShadowedAccessor::new(snippet).test();
142161
}
143162

144163
#[test]
145164
fn three_rows_tvm_parameters_bfe_main_test() {
146165
let snippet =
147166
InnerProductOfThreeRowsWithWeights::triton_vm_parameters(MainElementType::Bfe);
148-
ShadowedFunction::new(snippet).test()
167+
ShadowedAccessor::new(snippet).test();
149168
}
150169

151-
#[test]
152-
fn works_with_main_or_aux_column_count_of_zero() {
153-
for snippet in [
154-
InnerProductOfThreeRowsWithWeights::new(0, MainElementType::Bfe, 8),
155-
InnerProductOfThreeRowsWithWeights::new(0, MainElementType::Xfe, 14),
156-
InnerProductOfThreeRowsWithWeights::new(12, MainElementType::Bfe, 0),
157-
InnerProductOfThreeRowsWithWeights::new(16, MainElementType::Xfe, 0),
158-
] {
159-
ShadowedFunction::new(snippet).test()
170+
#[proptest(cases = 10)]
171+
fn main_or_aux_column_count_can_be_zero(#[strategy(0_usize..500)] len: usize) {
172+
for elt_ty in [MainElementType::Bfe, MainElementType::Xfe] {
173+
ShadowedAccessor::new(InnerProductOfThreeRowsWithWeights::new(0, elt_ty, len)).test();
174+
ShadowedAccessor::new(InnerProductOfThreeRowsWithWeights::new(len, elt_ty, 0)).test();
160175
}
161176
}
162177

@@ -168,26 +183,23 @@ mod tests {
168183
) {
169184
let snippet =
170185
InnerProductOfThreeRowsWithWeights::new(main_length, main_element_type, aux_length);
171-
ShadowedFunction::new(snippet).test()
186+
ShadowedAccessor::new(snippet).test();
172187
}
173188

174-
impl Function for InnerProductOfThreeRowsWithWeights {
189+
impl Accessor for InnerProductOfThreeRowsWithWeights {
175190
fn rust_shadow(
176191
&self,
177192
stack: &mut Vec<BFieldElement>,
178-
memory: &mut HashMap<BFieldElement, BFieldElement>,
193+
memory: &HashMap<BFieldElement, BFieldElement>,
179194
) {
180-
// read stack: _ *e *b *w
181195
let weights_address = stack.pop().unwrap();
182196
let main_row_address = stack.pop().unwrap();
183-
let auxiliary_row_address = stack.pop().unwrap();
197+
let aux_row_address = stack.pop().unwrap();
184198

185-
// read arrays
186199
let weights_len = self.main_length + self.aux_length;
187-
let weights: Vec<XFieldElement> =
188-
array_from_memory(weights_address, weights_len, memory);
189-
let aux_row: Vec<XFieldElement> =
190-
array_from_memory(auxiliary_row_address, self.aux_length, memory);
200+
let weights = array_from_memory::<XFieldElement>(weights_address, weights_len, memory);
201+
let aux_row =
202+
array_from_memory::<XFieldElement>(aux_row_address, self.aux_length, memory);
191203

192204
let main_row_as_xfes = match self.main_element_type {
193205
MainElementType::Bfe => {
@@ -201,23 +213,21 @@ mod tests {
201213
}
202214
};
203215

204-
// compute inner product
205216
let inner_product = main_row_as_xfes
206217
.into_iter()
207218
.chain(aux_row)
208219
.zip_eq(weights)
209220
.map(|(element, weight)| element * weight)
210221
.sum::<XFieldElement>();
211222

212-
// write inner product back to stack
213-
stack.extend(inner_product.coefficients.into_iter().rev());
223+
push_encodable(stack, &inner_product)
214224
}
215225

216226
fn pseudorandom_initial_state(
217227
&self,
218228
seed: [u8; 32],
219-
_bench_case: Option<BenchmarkCase>,
220-
) -> FunctionInitialState {
229+
_: Option<BenchmarkCase>,
230+
) -> AccessorInitialState {
221231
let mut rng = StdRng::from_seed(seed);
222232
let main_address = rng.gen();
223233
let aux_address = rng.gen();
@@ -241,7 +251,7 @@ mod tests {
241251
let mut stack = self.init_stack_for_isolated_run();
242252
stack.extend([aux_address, main_address, weights_address]);
243253

244-
FunctionInitialState { stack, memory }
254+
AccessorInitialState { stack, memory }
245255
}
246256
}
247257
}
@@ -254,17 +264,17 @@ mod benches {
254264
/// Benchmark the calculation of the (in-domain) current rows that happen in the
255265
/// main-loop, where all revealed FRI values are verified.
256266
#[test]
257-
fn inner_product_of_three_rows_bench_current_tvm_main_is_bfe() {
267+
fn bench_current_tvm_bfe() {
258268
let snippet =
259269
InnerProductOfThreeRowsWithWeights::triton_vm_parameters(MainElementType::Bfe);
260-
ShadowedFunction::new(snippet).bench();
270+
ShadowedAccessor::new(snippet).bench();
261271
}
262272

263273
/// Benchmark the calculation of the out-of-domain current and next row values.
264274
#[test]
265-
fn inner_product_of_three_rows_bench_current_tvm_main_is_xfe() {
275+
fn bench_current_tvm_xfe() {
266276
let snippet =
267277
InnerProductOfThreeRowsWithWeights::triton_vm_parameters(MainElementType::Xfe);
268-
ShadowedFunction::new(snippet).bench();
278+
ShadowedAccessor::new(snippet).bench();
269279
}
270280
}

0 commit comments

Comments
 (0)