@@ -3,6 +3,22 @@ use triton_vm::prelude::*;
3
3
use crate :: data_type:: ArrayType ;
4
4
use crate :: prelude:: * ;
5
5
6
+ /// Compute the inner product of two lists of [`XFieldElement`]s.
7
+ ///
8
+ /// ### Behavior
9
+ ///
10
+ /// ```text
11
+ /// BEFORE: _ *a *b
12
+ /// AFTER: _ [inner_product: XFieldElement]
13
+ /// ```
14
+ ///
15
+ /// ### Preconditions
16
+ ///
17
+ /// None.
18
+ ///
19
+ /// ### Postconditions
20
+ ///
21
+ /// None.
6
22
#[ derive( Debug , Copy , Clone , Eq , PartialEq , Hash ) ]
7
23
pub struct InnerProductOfXfes {
8
24
pub length : usize ,
@@ -12,20 +28,18 @@ impl InnerProductOfXfes {
12
28
pub fn new ( length : usize ) -> Self {
13
29
Self { length }
14
30
}
15
-
16
- fn argument_type ( & self ) -> DataType {
17
- DataType :: Array ( Box :: new ( ArrayType {
18
- element_type : DataType :: Xfe ,
19
- length : self . length ,
20
- } ) )
21
- }
22
31
}
23
32
24
33
impl BasicSnippet for InnerProductOfXfes {
25
34
fn inputs ( & self ) -> Vec < ( DataType , String ) > {
35
+ let argument_type = DataType :: Array ( Box :: new ( ArrayType {
36
+ element_type : DataType :: Xfe ,
37
+ length : self . length ,
38
+ } ) ) ;
39
+
26
40
vec ! [
27
- ( self . argument_type( ) , "*a" . to_owned( ) ) ,
28
- ( self . argument_type( ) , "*b" . to_owned( ) ) ,
41
+ ( argument_type. clone ( ) , "*a" . to_owned( ) ) ,
42
+ ( argument_type, "*b" . to_owned( ) ) ,
29
43
]
30
44
}
31
45
@@ -37,29 +51,26 @@ impl BasicSnippet for InnerProductOfXfes {
37
51
format ! ( "tasmlib_array_inner_product_of_{}_xfes" , self . length)
38
52
}
39
53
40
- fn code ( & self , _library : & mut Library ) -> Vec < LabelledInstruction > {
41
- let entrypoint = self . entrypoint ( ) ;
42
-
43
- let accumulate_all_indices = triton_asm ! [ xx_dot_step; self . length] ;
44
-
54
+ fn code ( & self , _: & mut Library ) -> Vec < LabelledInstruction > {
45
55
triton_asm ! (
46
- { entrypoint} :
47
- // _ *a *b
56
+ // BEFORE: _ *a *b
57
+ // AFTER: _ [inner_product: XFieldElement]
58
+ { self . entrypoint( ) } :
48
59
49
60
push 0
50
61
push 0
51
62
push 0
52
- // _ *a *b 0 0 0
63
+ // _ *a *b [0: XFE]
53
64
54
65
pick 4
55
66
pick 4
56
- // _ 0 0 0 *a *b
67
+ // _ [0: XFE] *a *b
57
68
58
- { & accumulate_all_indices }
59
- // _ acc2 acc1 acc0 *garbage0 *garbage1
69
+ { & triton_asm! [ xx_dot_step ; self . length ] }
70
+ // _ [acc: XFE] *garbage0 *garbage1
60
71
61
72
pop 2
62
- // _ acc2 acc1 acc0
73
+ // _ [acc: XFE]
63
74
64
75
return
65
76
)
@@ -68,179 +79,93 @@ impl BasicSnippet for InnerProductOfXfes {
68
79
69
80
#[ cfg( test) ]
70
81
mod tests {
71
- use num:: Zero ;
72
- use num_traits:: ConstZero ;
73
- use twenty_first:: math:: x_field_element:: EXTENSION_DEGREE ;
74
-
75
82
use super :: * ;
83
+ use crate :: rust_shadowing_helper_functions:: array:: array_from_memory;
76
84
use crate :: rust_shadowing_helper_functions:: array:: insert_as_array;
77
85
use crate :: rust_shadowing_helper_functions:: array:: insert_random_array;
78
86
use crate :: test_prelude:: * ;
79
87
80
- impl Function for InnerProductOfXfes {
88
+ impl Accessor for InnerProductOfXfes {
81
89
fn rust_shadow (
82
90
& self ,
83
91
stack : & mut Vec < BFieldElement > ,
84
- memory : & mut HashMap < BFieldElement , BFieldElement > ,
92
+ memory : & HashMap < BFieldElement , BFieldElement > ,
85
93
) {
86
- fn read_xfe (
87
- memory : & HashMap < BFieldElement , BFieldElement > ,
88
- address : BFieldElement ,
89
- ) -> XFieldElement {
90
- let coefficients = [
91
- memory. get ( & address) ,
92
- memory. get ( & ( address + bfe ! ( 1 ) ) ) ,
93
- memory. get ( & ( address + bfe ! ( 2 ) ) ) ,
94
- ]
95
- . map ( |b| b. copied ( ) . unwrap_or ( BFieldElement :: ZERO ) ) ;
96
- xfe ! ( coefficients)
97
- }
98
-
99
- let mut array_pointer_b = stack. pop ( ) . unwrap ( ) ;
100
- let mut array_pointer_a = stack. pop ( ) . unwrap ( ) ;
101
- let mut acc = XFieldElement :: zero ( ) ;
102
- for _ in 0 ..self . length {
103
- let element_a = read_xfe ( memory, array_pointer_a) ;
104
- let element_b = read_xfe ( memory, array_pointer_b) ;
105
- acc += element_a * element_b;
106
- array_pointer_b += bfe ! ( 3 ) ;
107
- array_pointer_a += bfe ! ( 3 ) ;
108
- }
109
-
110
- for word in acc. coefficients . into_iter ( ) . rev ( ) {
111
- stack. push ( word) ;
112
- }
94
+ let b = array_from_memory :: < XFieldElement > ( stack. pop ( ) . unwrap ( ) , self . length , memory) ;
95
+ let a = array_from_memory :: < XFieldElement > ( stack. pop ( ) . unwrap ( ) , self . length , memory) ;
96
+ let inner_product: XFieldElement = a. into_iter ( ) . zip ( b) . map ( |( a, b) | a * b) . sum ( ) ;
97
+
98
+ push_encodable ( stack, & inner_product) ;
113
99
}
114
100
115
101
fn pseudorandom_initial_state (
116
102
& self ,
117
103
seed : [ u8 ; 32 ] ,
118
- _bench_case : Option < BenchmarkCase > ,
119
- ) -> FunctionInitialState {
104
+ _ : Option < BenchmarkCase > ,
105
+ ) -> AccessorInitialState {
120
106
let mut rng = StdRng :: from_seed ( seed) ;
121
- let array_pointer_a = BFieldElement :: new ( rng. gen ( ) ) ;
122
- let mut array_pointer_b = BFieldElement :: new ( rng. gen ( ) ) ;
123
- while array_pointer_a. value ( ) . abs_diff ( array_pointer_b. value ( ) )
124
- < EXTENSION_DEGREE as u64 * self . length as u64
125
- {
126
- array_pointer_b = BFieldElement :: new ( rng. gen ( ) ) ;
127
- }
128
-
129
- self . prepare_state ( array_pointer_a, array_pointer_b)
107
+ let pointer_a = rng. gen ( ) ;
108
+ let pointer_b_offset = rng. gen_range ( self . length ..usize:: MAX - self . length ) ;
109
+ let pointer_b = pointer_a + bfe ! ( pointer_b_offset) ;
110
+
111
+ let mut memory = HashMap :: default ( ) ;
112
+ insert_random_array ( & DataType :: Xfe , pointer_a, self . length , & mut memory) ;
113
+ insert_random_array ( & DataType :: Xfe , pointer_b, self . length , & mut memory) ;
114
+
115
+ let mut stack = self . init_stack_for_isolated_run ( ) ;
116
+ stack. push ( pointer_a) ;
117
+ stack. push ( pointer_b) ;
118
+
119
+ AccessorInitialState { stack, memory }
130
120
}
131
121
132
- fn corner_case_initial_states ( & self ) -> Vec < FunctionInitialState > {
133
- let all_zeros = {
134
- let init_stack = [
135
- self . init_stack_for_isolated_run ( ) ,
136
- vec ! [ BFieldElement :: new( 0 ) , BFieldElement :: new( 1 << 40 ) ] ,
137
- ]
138
- . concat ( ) ;
139
- FunctionInitialState {
140
- stack : init_stack,
141
- memory : HashMap :: default ( ) ,
142
- }
122
+ fn corner_case_initial_states ( & self ) -> Vec < AccessorInitialState > {
123
+ let all_zeros = AccessorInitialState {
124
+ stack : [ self . init_stack_for_isolated_run ( ) , bfe_vec ! [ 0 , 1_u64 << 40 ] ] . concat ( ) ,
125
+ memory : HashMap :: default ( ) ,
143
126
} ;
144
127
145
128
vec ! [ all_zeros]
146
129
}
147
130
}
148
131
149
- impl InnerProductOfXfes {
150
- fn prepare_state (
151
- & self ,
152
- array_pointer_a : BFieldElement ,
153
- array_pointer_b : BFieldElement ,
154
- ) -> FunctionInitialState {
155
- let mut memory = HashMap :: default ( ) ;
156
- insert_random_array ( & DataType :: Xfe , array_pointer_a, self . length , & mut memory) ;
157
- insert_random_array ( & DataType :: Xfe , array_pointer_b, self . length , & mut memory) ;
158
-
159
- let mut init_stack = self . init_stack_for_isolated_run ( ) ;
160
- init_stack. push ( array_pointer_a) ;
161
- init_stack. push ( array_pointer_b) ;
162
- FunctionInitialState {
163
- stack : init_stack,
164
- memory,
165
- }
166
- }
167
- }
168
-
169
132
#[ test]
170
133
fn inner_product_of_xfes_pbt ( ) {
171
- let snippets = ( 0 ..20 )
172
- . chain ( 100 ..110 )
173
- . map ( |x| InnerProductOfXfes { length : x } ) ;
174
- for test_case in snippets {
175
- ShadowedFunction :: new ( test_case) . test ( )
134
+ for test_case in ( 0 ..20 ) . chain ( 100 ..110 ) . map ( InnerProductOfXfes :: new) {
135
+ ShadowedAccessor :: new ( test_case) . test ( )
176
136
}
177
137
}
178
138
179
139
#[ test]
180
140
fn inner_product_unit_test ( ) {
181
- let a = vec ! [
182
- XFieldElement :: new( [
183
- BFieldElement :: new( 3 ) ,
184
- BFieldElement :: zero( ) ,
185
- BFieldElement :: zero( ) ,
186
- ] ) ,
187
- XFieldElement :: new( [
188
- BFieldElement :: new( 5 ) ,
189
- BFieldElement :: zero( ) ,
190
- BFieldElement :: zero( ) ,
191
- ] ) ,
192
- ] ;
193
- let b = vec ! [
194
- XFieldElement :: new( [
195
- BFieldElement :: new( 501 ) ,
196
- BFieldElement :: zero( ) ,
197
- BFieldElement :: zero( ) ,
198
- ] ) ,
199
- XFieldElement :: new( [
200
- BFieldElement :: new( 1003 ) ,
201
- BFieldElement :: zero( ) ,
202
- BFieldElement :: zero( ) ,
203
- ] ) ,
204
- ] ;
205
-
206
- let expected_inner_product = XFieldElement :: new ( [
207
- BFieldElement :: new ( 3 * 501 + 5 * 1003 ) ,
208
- BFieldElement :: zero ( ) ,
209
- BFieldElement :: zero ( ) ,
210
- ] ) ;
211
- assert_eq ! (
212
- expected_inner_product,
213
- a. iter( )
214
- . cloned( )
215
- . zip( b. iter( ) . cloned( ) )
216
- . map( |( a, b) | a * b)
217
- . sum:: <XFieldElement >( )
218
- ) ;
141
+ let a = xfe_vec ! [ [ 3 , 0 , 0 ] , [ 5 , 0 , 0 ] ] ;
142
+ let b = xfe_vec ! [ [ 501 , 0 , 0 ] , [ 1003 , 0 , 0 ] ] ;
143
+ let inner_product = xfe ! ( [ 3 * 501 + 5 * 1003 , 0 , 0 ] ) ;
144
+
145
+ let rust_inner_product = a
146
+ . iter ( )
147
+ . zip ( & b)
148
+ . map ( |( & a, & b) | a * b)
149
+ . sum :: < XFieldElement > ( ) ;
150
+ debug_assert_eq ! ( inner_product, rust_inner_product) ;
219
151
220
152
let mut memory = HashMap :: default ( ) ;
221
- let array_pointer_a = BFieldElement :: new ( 1u64 << 44 ) ;
222
- insert_as_array ( array_pointer_a, & mut memory, a) ;
223
- let array_pointer_b = BFieldElement :: new ( 1u64 << 45 ) ;
224
- insert_as_array ( array_pointer_b, & mut memory, b) ;
225
-
226
- let snippet = InnerProductOfXfes { length : 2 } ;
227
- let expected_final_stack = [
228
- snippet. init_stack_for_isolated_run ( ) ,
229
- expected_inner_product
230
- . coefficients
231
- . into_iter ( )
232
- . rev ( )
233
- . collect_vec ( ) ,
234
- ]
235
- . concat ( ) ;
236
- let init_stack = [
237
- snippet. init_stack_for_isolated_run ( ) ,
238
- vec ! [ array_pointer_a, array_pointer_b] ,
239
- ]
240
- . concat ( ) ;
153
+ let pointer_a = bfe ! ( 1_u64 << 44 ) ;
154
+ let pointer_b = bfe ! ( 1_u64 << 45 ) ;
155
+ insert_as_array ( pointer_a, & mut memory, a) ;
156
+ insert_as_array ( pointer_b, & mut memory, b) ;
157
+
158
+ let snippet = InnerProductOfXfes :: new ( 2 ) ;
159
+ let mut initial_stack = snippet. init_stack_for_isolated_run ( ) ;
160
+ initial_stack. push ( pointer_a) ;
161
+ initial_stack. push ( pointer_b) ;
162
+
163
+ let mut expected_final_stack = snippet. init_stack_for_isolated_run ( ) ;
164
+ push_encodable ( & mut expected_final_stack, & inner_product) ;
165
+
241
166
test_rust_equivalence_given_complete_state (
242
- & ShadowedFunction :: new ( snippet) ,
243
- & init_stack ,
167
+ & ShadowedAccessor :: new ( snippet) ,
168
+ & initial_stack ,
244
169
& [ ] ,
245
170
& NonDeterminism :: default ( ) . with_ram ( memory) ,
246
171
& None ,
@@ -259,28 +184,12 @@ mod benches {
259
184
use crate :: test_prelude:: * ;
260
185
261
186
#[ test]
262
- fn inner_product_xfes_bench_100 ( ) {
263
- ShadowedFunction :: new ( InnerProductOfXfes { length : 100 } ) . bench ( ) ;
264
- }
265
-
266
- #[ test]
267
- fn inner_product_xfes_bench_200 ( ) {
268
- ShadowedFunction :: new ( InnerProductOfXfes { length : 200 } ) . bench ( ) ;
269
- }
187
+ fn benchmark ( ) {
188
+ ShadowedAccessor :: new ( InnerProductOfXfes :: new ( 100 ) ) . bench ( ) ;
189
+ ShadowedAccessor :: new ( InnerProductOfXfes :: new ( 200 ) ) . bench ( ) ;
270
190
271
- #[ test]
272
- fn inner_product_xfes_bench_num_columns_current_tvm ( ) {
273
- ShadowedFunction :: new ( InnerProductOfXfes {
274
- length : MasterMainTable :: NUM_COLUMNS + MasterAuxTable :: NUM_COLUMNS ,
275
- } )
276
- . bench ( ) ;
277
- }
278
-
279
- #[ test]
280
- fn inner_product_xfes_bench_num_constraints_current_tvm ( ) {
281
- ShadowedFunction :: new ( InnerProductOfXfes {
282
- length : MasterAuxTable :: NUM_CONSTRAINTS ,
283
- } )
284
- . bench ( ) ;
191
+ let num_columns = MasterMainTable :: NUM_COLUMNS + MasterAuxTable :: NUM_COLUMNS ;
192
+ ShadowedAccessor :: new ( InnerProductOfXfes :: new ( num_columns) ) . bench ( ) ;
193
+ ShadowedAccessor :: new ( InnerProductOfXfes :: new ( MasterAuxTable :: NUM_CONSTRAINTS ) ) . bench ( ) ;
285
194
}
286
195
}
0 commit comments