Skip to content

Commit 2975d7b

Browse files
committed
signoff: Inner product of three rows with weights
1 parent cdfbb7e commit 2975d7b

File tree

1 file changed

+83
-56
lines changed

1 file changed

+83
-56
lines changed

tasm-lib/src/array/inner_product_of_three_rows_with_weights.rs

Lines changed: 83 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,37 @@
1+
use std::collections::HashMap;
2+
13
use arbitrary::Arbitrary;
24
use strum::Display;
3-
use triton_vm::prelude::triton_asm;
5+
use triton_vm::prelude::*;
46
use triton_vm::table::master_table::MasterAuxTable;
57
use triton_vm::table::master_table::MasterMainTable;
68
use triton_vm::table::master_table::MasterTable;
79

810
use crate::data_type::ArrayType;
911
use crate::prelude::*;
12+
use crate::traits::basic_snippet::Reviewer;
13+
use crate::traits::basic_snippet::SignOffFingerprint;
1014

15+
/// The type of field element used in
16+
/// [`InnerProductOfThreeRowsWithWeights`].
1117
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Display, Arbitrary)]
1218
pub enum MainElementType {
19+
/// Corresponds to [`BFieldElement`].
1320
Bfe,
21+
22+
/// Corresponds to [`XFieldElement`].
1423
Xfe,
1524
}
1625

26+
impl MainElementType {
27+
fn dot_step(&self) -> LabelledInstruction {
28+
match self {
29+
Self::Bfe => triton_instr!(xb_dot_step),
30+
Self::Xfe => triton_instr!(xx_dot_step),
31+
}
32+
}
33+
}
34+
1735
impl From<MainElementType> for DataType {
1836
fn from(value: MainElementType) -> Self {
1937
match value {
@@ -26,9 +44,25 @@ impl From<MainElementType> for DataType {
2644
/// Calculate inner products of Triton VM
2745
/// [execution trace](triton_vm::table::master_table) rows with weights.
2846
///
29-
/// Calculate inner product of both main columns and auxiliary columns with weights. Returns one
30-
/// scalar in the form of an auxiliary-field element. Main column can be either a base field
31-
/// element, or an auxiliary-field element.
47+
/// Calculate inner product of both main columns and auxiliary columns with
48+
/// weights. Returns one scalar in the form of an auxiliary-field element.
49+
/// The main column can be either a base field element, or an auxiliary-field
50+
/// element; see also [`MainElementType`].
51+
///
52+
/// ### Behavior
53+
///
54+
/// ```text
55+
/// BEFORE: _
56+
/// AFTER: _
57+
/// ```
58+
///
59+
/// ### Preconditions
60+
///
61+
/// None.
62+
///
63+
/// ### Postconditions
64+
///
65+
/// None.
3266
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
3367
pub struct InnerProductOfThreeRowsWithWeights {
3468
main_length: usize,
@@ -82,47 +116,49 @@ impl BasicSnippet for InnerProductOfThreeRowsWithWeights {
82116
format!("tasmlib_array_inner_product_of_three_rows_with_weights_{element_ty}_mainrowelem")
83117
}
84118

85-
fn code(
86-
&self,
87-
_library: &mut crate::library::Library,
88-
) -> Vec<triton_vm::prelude::LabelledInstruction> {
89-
let entrypoint = self.entrypoint();
90-
let acc_all_main_rows = match self.main_element_type {
91-
MainElementType::Bfe => triton_asm![xb_dot_step; self.main_length],
92-
MainElementType::Xfe => triton_asm![xx_dot_step; self.main_length],
93-
};
94-
let acc_all_aux_rows = triton_asm![xx_dot_step; self.aux_length];
95-
119+
fn code(&self, _: &mut Library) -> Vec<LabelledInstruction> {
96120
triton_asm! {
97121
// BEFORE: _ *aux_row *main_row *weights
98122
// AFTER: _ [inner_product; 3]
99-
{entrypoint}:
123+
{self.entrypoint()}:
100124
push 0
101125
push 0
102126
push 0
103-
// _ *aux_row *main_row *weights 0 0 0
127+
// _ *aux_row *main_row *weights [0: XFE]
104128

105129
pick 3
106130
pick 4
107-
// _ *aux_row 0 0 0 *weights *main_row
131+
// _ *aux_row [0: XFE] *weights *main_row
108132

109-
{&acc_all_main_rows}
110-
// _ *aux_row acc2 acc1 acc0 *weights_next garbage
133+
{&vec![self.main_element_type.dot_step(); self.main_length]}
134+
// _ *aux_row [acc: XFE] *weights_next garbage
111135

112136
pop 1
113137
pick 4
114-
// _ acc2 acc1 acc0 *weights_next *aux_row
138+
// _ [acc: XFE] *weights_next *aux_row
115139

116-
{&acc_all_aux_rows}
117-
// _ acc2 acc1 acc0 garbage garbage
140+
{&triton_asm![xx_dot_step; self.aux_length]}
141+
// _ [acc: XFE] garbage garbage
118142

119143
pop 2
120-
// _ result2 result1 result0
121-
// _ [result; 3]
144+
// _ [result: XFE]
122145

123146
return
124147
}
125148
}
149+
150+
fn sign_offs(&self) -> HashMap<Reviewer, SignOffFingerprint> {
151+
let mut sign_offs = HashMap::new();
152+
153+
if self == &Self::new(379, MainElementType::Bfe, 88) {
154+
sign_offs.insert(Reviewer("ferdinand"), 0x7e46570df803d4b.into());
155+
}
156+
if self == &Self::new(379, MainElementType::Xfe, 88) {
157+
sign_offs.insert(Reviewer("ferdinand"), 0x1c549ac8d61e6a70.into());
158+
}
159+
160+
sign_offs
161+
}
126162
}
127163

128164
#[cfg(test)]
@@ -138,25 +174,21 @@ mod tests {
138174
fn three_rows_tvm_parameters_xfe_main_test() {
139175
let snippet =
140176
InnerProductOfThreeRowsWithWeights::triton_vm_parameters(MainElementType::Xfe);
141-
ShadowedFunction::new(snippet).test()
177+
ShadowedAccessor::new(snippet).test();
142178
}
143179

144180
#[test]
145181
fn three_rows_tvm_parameters_bfe_main_test() {
146182
let snippet =
147183
InnerProductOfThreeRowsWithWeights::triton_vm_parameters(MainElementType::Bfe);
148-
ShadowedFunction::new(snippet).test()
184+
ShadowedAccessor::new(snippet).test();
149185
}
150186

151-
#[test]
152-
fn works_with_main_or_aux_column_count_of_zero() {
153-
for snippet in [
154-
InnerProductOfThreeRowsWithWeights::new(0, MainElementType::Bfe, 8),
155-
InnerProductOfThreeRowsWithWeights::new(0, MainElementType::Xfe, 14),
156-
InnerProductOfThreeRowsWithWeights::new(12, MainElementType::Bfe, 0),
157-
InnerProductOfThreeRowsWithWeights::new(16, MainElementType::Xfe, 0),
158-
] {
159-
ShadowedFunction::new(snippet).test()
187+
#[proptest(cases = 10)]
188+
fn main_or_aux_column_count_can_be_zero(#[strategy(0_usize..500)] len: usize) {
189+
for elt_ty in [MainElementType::Bfe, MainElementType::Xfe] {
190+
ShadowedAccessor::new(InnerProductOfThreeRowsWithWeights::new(0, elt_ty, len)).test();
191+
ShadowedAccessor::new(InnerProductOfThreeRowsWithWeights::new(len, elt_ty, 0)).test();
160192
}
161193
}
162194

@@ -168,26 +200,23 @@ mod tests {
168200
) {
169201
let snippet =
170202
InnerProductOfThreeRowsWithWeights::new(main_length, main_element_type, aux_length);
171-
ShadowedFunction::new(snippet).test()
203+
ShadowedAccessor::new(snippet).test();
172204
}
173205

174-
impl Function for InnerProductOfThreeRowsWithWeights {
206+
impl Accessor for InnerProductOfThreeRowsWithWeights {
175207
fn rust_shadow(
176208
&self,
177209
stack: &mut Vec<BFieldElement>,
178-
memory: &mut HashMap<BFieldElement, BFieldElement>,
210+
memory: &HashMap<BFieldElement, BFieldElement>,
179211
) {
180-
// read stack: _ *e *b *w
181212
let weights_address = stack.pop().unwrap();
182213
let main_row_address = stack.pop().unwrap();
183-
let auxiliary_row_address = stack.pop().unwrap();
214+
let aux_row_address = stack.pop().unwrap();
184215

185-
// read arrays
186216
let weights_len = self.main_length + self.aux_length;
187-
let weights: Vec<XFieldElement> =
188-
array_from_memory(weights_address, weights_len, memory);
189-
let aux_row: Vec<XFieldElement> =
190-
array_from_memory(auxiliary_row_address, self.aux_length, memory);
217+
let weights = array_from_memory::<XFieldElement>(weights_address, weights_len, memory);
218+
let aux_row =
219+
array_from_memory::<XFieldElement>(aux_row_address, self.aux_length, memory);
191220

192221
let main_row_as_xfes = match self.main_element_type {
193222
MainElementType::Bfe => {
@@ -201,23 +230,21 @@ mod tests {
201230
}
202231
};
203232

204-
// compute inner product
205233
let inner_product = main_row_as_xfes
206234
.into_iter()
207235
.chain(aux_row)
208236
.zip_eq(weights)
209237
.map(|(element, weight)| element * weight)
210238
.sum::<XFieldElement>();
211239

212-
// write inner product back to stack
213-
stack.extend(inner_product.coefficients.into_iter().rev());
240+
push_encodable(stack, &inner_product)
214241
}
215242

216243
fn pseudorandom_initial_state(
217244
&self,
218245
seed: [u8; 32],
219-
_bench_case: Option<BenchmarkCase>,
220-
) -> FunctionInitialState {
246+
_: Option<BenchmarkCase>,
247+
) -> AccessorInitialState {
221248
let mut rng = StdRng::from_seed(seed);
222249
let main_address = rng.gen();
223250
let aux_address = rng.gen();
@@ -241,7 +268,7 @@ mod tests {
241268
let mut stack = self.init_stack_for_isolated_run();
242269
stack.extend([aux_address, main_address, weights_address]);
243270

244-
FunctionInitialState { stack, memory }
271+
AccessorInitialState { stack, memory }
245272
}
246273
}
247274
}
@@ -254,17 +281,17 @@ mod benches {
254281
/// Benchmark the calculation of the (in-domain) current rows that happen in the
255282
/// main-loop, where all revealed FRI values are verified.
256283
#[test]
257-
fn inner_product_of_three_rows_bench_current_tvm_main_is_bfe() {
284+
fn bench_current_tvm_bfe() {
258285
let snippet =
259286
InnerProductOfThreeRowsWithWeights::triton_vm_parameters(MainElementType::Bfe);
260-
ShadowedFunction::new(snippet).bench();
287+
ShadowedAccessor::new(snippet).bench();
261288
}
262289

263290
/// Benchmark the calculation of the out-of-domain current and next row values.
264291
#[test]
265-
fn inner_product_of_three_rows_bench_current_tvm_main_is_xfe() {
292+
fn bench_current_tvm_xfe() {
266293
let snippet =
267294
InnerProductOfThreeRowsWithWeights::triton_vm_parameters(MainElementType::Xfe);
268-
ShadowedFunction::new(snippet).bench();
295+
ShadowedAccessor::new(snippet).bench();
269296
}
270297
}

0 commit comments

Comments
 (0)