1
1
use arbitrary:: Arbitrary ;
2
2
use strum:: Display ;
3
- use triton_vm:: prelude:: triton_asm ;
3
+ use triton_vm:: prelude:: * ;
4
4
use triton_vm:: table:: master_table:: MasterAuxTable ;
5
5
use triton_vm:: table:: master_table:: MasterMainTable ;
6
6
use triton_vm:: table:: master_table:: MasterTable ;
7
7
8
8
use crate :: data_type:: ArrayType ;
9
9
use crate :: prelude:: * ;
10
10
11
+ /// The type of field element used in
12
+ /// [`InnerProductOfThreeRowsWithWeights`].
11
13
#[ derive( Debug , Copy , Clone , Eq , PartialEq , Hash , Display , Arbitrary ) ]
12
14
pub enum MainElementType {
15
+ /// Corresponds to [`BFieldElement`].
13
16
Bfe ,
17
+
18
+ /// Corresponds to [`XFieldElement`].
14
19
Xfe ,
15
20
}
16
21
22
+ impl MainElementType {
23
+ fn dot_step ( & self ) -> LabelledInstruction {
24
+ match self {
25
+ Self :: Bfe => triton_instr ! ( xb_dot_step) ,
26
+ Self :: Xfe => triton_instr ! ( xx_dot_step) ,
27
+ }
28
+ }
29
+ }
30
+
17
31
impl From < MainElementType > for DataType {
18
32
fn from ( value : MainElementType ) -> Self {
19
33
match value {
@@ -26,9 +40,25 @@ impl From<MainElementType> for DataType {
26
40
/// Calculate inner products of Triton VM
27
41
/// [execution trace](triton_vm::table::master_table) rows with weights.
28
42
///
29
- /// Calculate inner product of both main columns and auxiliary columns with weights. Returns one
30
- /// scalar in the form of an auxiliary-field element. Main column can be either a base field
31
- /// element, or an auxiliary-field element.
43
+ /// Calculate inner product of both main columns and auxiliary columns with
44
+ /// weights. Returns one scalar in the form of an auxiliary-field element.
45
+ /// The main column can be either a base field element, or an auxiliary-field
46
+ /// element; see also [`MainElementType`].
47
+ ///
48
+ /// ### Behavior
49
+ ///
50
+ /// ```text
51
+ /// BEFORE: _
52
+ /// AFTER: _
53
+ /// ```
54
+ ///
55
+ /// ### Preconditions
56
+ ///
57
+ /// None.
58
+ ///
59
+ /// ### Postconditions
60
+ ///
61
+ /// None.
32
62
#[ derive( Debug , Copy , Clone , Eq , PartialEq , Hash ) ]
33
63
pub struct InnerProductOfThreeRowsWithWeights {
34
64
main_length : usize ,
@@ -82,43 +112,32 @@ impl BasicSnippet for InnerProductOfThreeRowsWithWeights {
82
112
format ! ( "tasmlib_array_inner_product_of_three_rows_with_weights_{element_ty}_mainrowelem" )
83
113
}
84
114
85
- fn code (
86
- & self ,
87
- _library : & mut crate :: library:: Library ,
88
- ) -> Vec < triton_vm:: prelude:: LabelledInstruction > {
89
- let entrypoint = self . entrypoint ( ) ;
90
- let acc_all_main_rows = match self . main_element_type {
91
- MainElementType :: Bfe => triton_asm ! [ xb_dot_step; self . main_length] ,
92
- MainElementType :: Xfe => triton_asm ! [ xx_dot_step; self . main_length] ,
93
- } ;
94
- let acc_all_aux_rows = triton_asm ! [ xx_dot_step; self . aux_length] ;
95
-
115
+ fn code ( & self , _: & mut Library ) -> Vec < LabelledInstruction > {
96
116
triton_asm ! {
97
117
// BEFORE: _ *aux_row *main_row *weights
98
118
// AFTER: _ [inner_product; 3]
99
- { entrypoint} :
119
+ { self . entrypoint( ) } :
100
120
push 0
101
121
push 0
102
122
push 0
103
- // _ *aux_row *main_row *weights 0 0 0
123
+ // _ *aux_row *main_row *weights [0: XFE]
104
124
105
125
pick 3
106
126
pick 4
107
- // _ *aux_row 0 0 0 *weights *main_row
127
+ // _ *aux_row [0: XFE] *weights *main_row
108
128
109
- { & acc_all_main_rows }
110
- // _ *aux_row acc2 acc1 acc0 *weights_next garbage
129
+ { & vec! [ self . main_element_type . dot_step ( ) ; self . main_length ] }
130
+ // _ *aux_row [acc: XFE] *weights_next garbage
111
131
112
132
pop 1
113
133
pick 4
114
- // _ acc2 acc1 acc0 *weights_next *aux_row
134
+ // _ [acc: XFE] *weights_next *aux_row
115
135
116
- { & acc_all_aux_rows }
117
- // _ acc2 acc1 acc0 garbage garbage
136
+ { & triton_asm! [ xx_dot_step ; self . aux_length ] }
137
+ // _ [acc: XFE] garbage garbage
118
138
119
139
pop 2
120
- // _ result2 result1 result0
121
- // _ [result; 3]
140
+ // _ [result: XFE]
122
141
123
142
return
124
143
}
@@ -138,25 +157,21 @@ mod tests {
138
157
fn three_rows_tvm_parameters_xfe_main_test ( ) {
139
158
let snippet =
140
159
InnerProductOfThreeRowsWithWeights :: triton_vm_parameters ( MainElementType :: Xfe ) ;
141
- ShadowedFunction :: new ( snippet) . test ( )
160
+ ShadowedAccessor :: new ( snippet) . test ( ) ;
142
161
}
143
162
144
163
#[ test]
145
164
fn three_rows_tvm_parameters_bfe_main_test ( ) {
146
165
let snippet =
147
166
InnerProductOfThreeRowsWithWeights :: triton_vm_parameters ( MainElementType :: Bfe ) ;
148
- ShadowedFunction :: new ( snippet) . test ( )
167
+ ShadowedAccessor :: new ( snippet) . test ( ) ;
149
168
}
150
169
151
- #[ test]
152
- fn works_with_main_or_aux_column_count_of_zero ( ) {
153
- for snippet in [
154
- InnerProductOfThreeRowsWithWeights :: new ( 0 , MainElementType :: Bfe , 8 ) ,
155
- InnerProductOfThreeRowsWithWeights :: new ( 0 , MainElementType :: Xfe , 14 ) ,
156
- InnerProductOfThreeRowsWithWeights :: new ( 12 , MainElementType :: Bfe , 0 ) ,
157
- InnerProductOfThreeRowsWithWeights :: new ( 16 , MainElementType :: Xfe , 0 ) ,
158
- ] {
159
- ShadowedFunction :: new ( snippet) . test ( )
170
+ #[ proptest( cases = 10 ) ]
171
+ fn main_or_aux_column_count_can_be_zero ( #[ strategy( 0_usize ..500 ) ] len : usize ) {
172
+ for elt_ty in [ MainElementType :: Bfe , MainElementType :: Xfe ] {
173
+ ShadowedAccessor :: new ( InnerProductOfThreeRowsWithWeights :: new ( 0 , elt_ty, len) ) . test ( ) ;
174
+ ShadowedAccessor :: new ( InnerProductOfThreeRowsWithWeights :: new ( len, elt_ty, 0 ) ) . test ( ) ;
160
175
}
161
176
}
162
177
@@ -168,26 +183,23 @@ mod tests {
168
183
) {
169
184
let snippet =
170
185
InnerProductOfThreeRowsWithWeights :: new ( main_length, main_element_type, aux_length) ;
171
- ShadowedFunction :: new ( snippet) . test ( )
186
+ ShadowedAccessor :: new ( snippet) . test ( ) ;
172
187
}
173
188
174
- impl Function for InnerProductOfThreeRowsWithWeights {
189
+ impl Accessor for InnerProductOfThreeRowsWithWeights {
175
190
fn rust_shadow (
176
191
& self ,
177
192
stack : & mut Vec < BFieldElement > ,
178
- memory : & mut HashMap < BFieldElement , BFieldElement > ,
193
+ memory : & HashMap < BFieldElement , BFieldElement > ,
179
194
) {
180
- // read stack: _ *e *b *w
181
195
let weights_address = stack. pop ( ) . unwrap ( ) ;
182
196
let main_row_address = stack. pop ( ) . unwrap ( ) ;
183
- let auxiliary_row_address = stack. pop ( ) . unwrap ( ) ;
197
+ let aux_row_address = stack. pop ( ) . unwrap ( ) ;
184
198
185
- // read arrays
186
199
let weights_len = self . main_length + self . aux_length ;
187
- let weights: Vec < XFieldElement > =
188
- array_from_memory ( weights_address, weights_len, memory) ;
189
- let aux_row: Vec < XFieldElement > =
190
- array_from_memory ( auxiliary_row_address, self . aux_length , memory) ;
200
+ let weights = array_from_memory :: < XFieldElement > ( weights_address, weights_len, memory) ;
201
+ let aux_row =
202
+ array_from_memory :: < XFieldElement > ( aux_row_address, self . aux_length , memory) ;
191
203
192
204
let main_row_as_xfes = match self . main_element_type {
193
205
MainElementType :: Bfe => {
@@ -201,23 +213,21 @@ mod tests {
201
213
}
202
214
} ;
203
215
204
- // compute inner product
205
216
let inner_product = main_row_as_xfes
206
217
. into_iter ( )
207
218
. chain ( aux_row)
208
219
. zip_eq ( weights)
209
220
. map ( |( element, weight) | element * weight)
210
221
. sum :: < XFieldElement > ( ) ;
211
222
212
- // write inner product back to stack
213
- stack. extend ( inner_product. coefficients . into_iter ( ) . rev ( ) ) ;
223
+ push_encodable ( stack, & inner_product)
214
224
}
215
225
216
226
fn pseudorandom_initial_state (
217
227
& self ,
218
228
seed : [ u8 ; 32 ] ,
219
- _bench_case : Option < BenchmarkCase > ,
220
- ) -> FunctionInitialState {
229
+ _ : Option < BenchmarkCase > ,
230
+ ) -> AccessorInitialState {
221
231
let mut rng = StdRng :: from_seed ( seed) ;
222
232
let main_address = rng. gen ( ) ;
223
233
let aux_address = rng. gen ( ) ;
@@ -241,7 +251,7 @@ mod tests {
241
251
let mut stack = self . init_stack_for_isolated_run ( ) ;
242
252
stack. extend ( [ aux_address, main_address, weights_address] ) ;
243
253
244
- FunctionInitialState { stack, memory }
254
+ AccessorInitialState { stack, memory }
245
255
}
246
256
}
247
257
}
@@ -254,17 +264,17 @@ mod benches {
254
264
/// Benchmark the calculation of the (in-domain) current rows that happen in the
255
265
/// main-loop, where all revealed FRI values are verified.
256
266
#[ test]
257
- fn inner_product_of_three_rows_bench_current_tvm_main_is_bfe ( ) {
267
+ fn bench_current_tvm_bfe ( ) {
258
268
let snippet =
259
269
InnerProductOfThreeRowsWithWeights :: triton_vm_parameters ( MainElementType :: Bfe ) ;
260
- ShadowedFunction :: new ( snippet) . bench ( ) ;
270
+ ShadowedAccessor :: new ( snippet) . bench ( ) ;
261
271
}
262
272
263
273
/// Benchmark the calculation of the out-of-domain current and next row values.
264
274
#[ test]
265
- fn inner_product_of_three_rows_bench_current_tvm_main_is_xfe ( ) {
275
+ fn bench_current_tvm_xfe ( ) {
266
276
let snippet =
267
277
InnerProductOfThreeRowsWithWeights :: triton_vm_parameters ( MainElementType :: Xfe ) ;
268
- ShadowedFunction :: new ( snippet) . bench ( ) ;
278
+ ShadowedAccessor :: new ( snippet) . bench ( ) ;
269
279
}
270
280
}
0 commit comments