|
1 |
| -use rand::prelude::*; |
| 1 | +use std::collections::HashMap; |
| 2 | + |
2 | 3 | use triton_vm::prelude::*;
|
3 |
| -use twenty_first::prelude::U32s; |
4 | 4 |
|
5 |
| -use crate::empty_stack; |
6 | 5 | use crate::prelude::*;
|
7 |
| -use crate::push_encodable; |
8 |
| -use crate::traits::deprecated_snippet::DeprecatedSnippet; |
9 |
| -use crate::InitVmState; |
| 6 | +use crate::traits::basic_snippet::Reviewer; |
| 7 | +use crate::traits::basic_snippet::SignOffFingerprint; |
10 | 8 |
|
11 | 9 | #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
|
12 | 10 | pub struct WrappingMul;
|
13 | 11 |
|
14 |
| -impl DeprecatedSnippet for WrappingMul { |
15 |
| - fn entrypoint_name(&self) -> String { |
16 |
| - "tasmlib_arithmetic_u64_wrapping_mul".to_string() |
17 |
| - } |
18 |
| - |
19 |
| - fn input_field_names(&self) -> Vec<String> { |
20 |
| - vec![ |
21 |
| - "lhs_hi".to_string(), |
22 |
| - "lhs_lo".to_string(), |
23 |
| - "rhs_hi".to_string(), |
24 |
| - "rhs_lo".to_string(), |
25 |
| - ] |
26 |
| - } |
27 |
| - |
28 |
| - fn input_types(&self) -> Vec<DataType> { |
29 |
| - vec![DataType::U64, DataType::U64] |
30 |
| - } |
31 |
| - |
32 |
| - fn output_field_names(&self) -> Vec<String> { |
33 |
| - vec!["prod_hi".to_string(), "prod_lo".to_string()] |
| 12 | +impl BasicSnippet for WrappingMul { |
| 13 | + fn inputs(&self) -> Vec<(DataType, String)> { |
| 14 | + ["rhs", "lhs"] |
| 15 | + .map(|side| (DataType::U64, side.to_string())) |
| 16 | + .to_vec() |
34 | 17 | }
|
35 | 18 |
|
36 |
| - fn output_types(&self) -> Vec<DataType> { |
37 |
| - vec![DataType::U64] |
| 19 | + fn outputs(&self) -> Vec<(DataType, String)> { |
| 20 | + vec![(DataType::U64, "product".to_string())] |
38 | 21 | }
|
39 | 22 |
|
40 |
| - fn stack_diff(&self) -> isize { |
41 |
| - -2 |
| 23 | + fn entrypoint(&self) -> String { |
| 24 | + "tasmlib_arithmetic_u64_wrapping_mul".to_string() |
42 | 25 | }
|
43 | 26 |
|
44 |
| - fn function_code(&self, _library: &mut Library) -> String { |
45 |
| - let entrypoint = self.entrypoint_name(); |
46 |
| - |
47 |
| - format!( |
48 |
| - " |
49 |
| - // BEFORE: _ rhs_hi rhs_lo lhs_hi lhs_lo |
50 |
| - // AFTER: _ prod_hi prod_lo |
51 |
| - {entrypoint}: |
52 |
| - // `lhs_lo * rhs_lo`: |
53 |
| - dup 0 dup 3 mul |
54 |
| - // _ rhs_hi rhs_lo lhs_hi lhs_lo (lhs_lo * rhs_lo) |
55 |
| -
|
56 |
| - // `rhs_hi * lhs_lo` (consume `rhs_hi` and `lhs_lo`): |
57 |
| - swap 4 |
58 |
| - mul |
59 |
| - // _ (lhs_lo * rhs_lo) rhs_lo lhs_hi (lhs_lo * rhs_hi) |
60 |
| -
|
61 |
| - // `rhs_lo * lhs_hi` (consume `rhs_lo` and `lhs_hi`): |
62 |
| - swap 2 |
63 |
| - mul |
64 |
| - // _ (lhs_lo * rhs_lo) (lhs_lo * rhs_hi) (lhs_hi * rhs_lo) |
65 |
| -
|
66 |
| - // rename to: a, b, c: |
67 |
| - // _ a b c |
68 |
| -
|
69 |
| - // Calculate `prod_hi = a_hi + b_lo + c_lo`: |
70 |
| - split |
71 |
| - swap 1 |
72 |
| - pop 1 |
73 |
| - // _ a b c_lo |
74 |
| -
|
75 |
| - swap 1 |
76 |
| - split |
77 |
| - swap 1 |
78 |
| - pop 1 |
79 |
| - // _ a c_lo b_lo |
80 |
| -
|
81 |
| - swap 2 |
82 |
| - split |
83 |
| - // _ b_lo c_lo a_hi a_lo |
84 |
| -
|
85 |
| - swap 3 |
86 |
| - // _ a_lo c_lo a_hi b_lo |
87 |
| -
|
88 |
| - add |
89 |
| - add |
90 |
| - // _ a_lo (c_lo + a_hi + b_lo) |
91 |
| -
|
92 |
| - split |
93 |
| - swap 1 |
94 |
| - pop 1 |
95 |
| - // _ a_lo (c_lo + a_hi + b_lo)_lo |
96 |
| -
|
97 |
| - swap 1 |
98 |
| - // _ (c_lo + a_hi + b_lo)_lo a_lo |
99 |
| -
|
100 |
| - // _ prod_hi prod_lo |
101 |
| -
|
102 |
| - return |
103 |
| - " |
| 27 | + fn code(&self, _: &mut Library) -> Vec<LabelledInstruction> { |
| 28 | + triton_asm!( |
| 29 | + // BEFORE: _ right_hi right_lo left_hi left_lo |
| 30 | + // AFTER: _ prod_hi prod_lo |
| 31 | + {self.entrypoint()}: |
| 32 | + /* left_lo · right_lo */ |
| 33 | + dup 0 |
| 34 | + dup 3 |
| 35 | + mul |
| 36 | + // _ right_hi right_lo left_hi left_lo (left_lo · right_lo) |
| 37 | + |
| 38 | + /* left_lo · right_hi (consume both) */ |
| 39 | + swap 4 |
| 40 | + mul |
| 41 | + // _ (left_lo · right_lo) right_lo left_hi (left_lo · right_hi) |
| 42 | + |
| 43 | + /* left_hi · right_lo (consume both) */ |
| 44 | + swap 2 |
| 45 | + mul |
| 46 | + // _ (left_lo · right_lo) (left_lo · right_hi) (left_hi · right_lo) |
| 47 | + // _ lolo lohi hilo |
| 48 | + |
| 49 | + /* prod_hi = lolo_hi + lohi_lo + hilo_lo */ |
| 50 | + split |
| 51 | + pick 1 |
| 52 | + pop 1 |
| 53 | + // _ lolo lohi hilo_lo |
| 54 | + |
| 55 | + pick 1 |
| 56 | + split |
| 57 | + pick 1 |
| 58 | + pop 1 |
| 59 | + // _ lolo hilo_lo lohi_lo |
| 60 | + |
| 61 | + pick 2 |
| 62 | + split |
| 63 | + // _ hilo_lo lohi_lo lolo_hi lolo_lo |
| 64 | + // _ hilo_lo lohi_lo lolo_hi prod_lo |
| 65 | + |
| 66 | + place 3 |
| 67 | + add |
| 68 | + add |
| 69 | + // _ prod_lo (hilo_lo + lohi_lo + lolo_hi) |
| 70 | + |
| 71 | + split |
| 72 | + pick 1 |
| 73 | + pop 1 |
| 74 | + // _ prod_lo (hilo_lo + lohi_lo + lolo_hi)_lo |
| 75 | + // _ prod_lo prod_hi |
| 76 | + |
| 77 | + place 1 |
| 78 | + return |
104 | 79 | )
|
105 | 80 | }
|
106 | 81 |
|
107 |
| - fn crash_conditions(&self) -> Vec<String> { |
108 |
| - todo!() |
109 |
| - } |
110 |
| - |
111 |
| - fn gen_input_states(&self) -> Vec<InitVmState> { |
112 |
| - let mut rng = rand::thread_rng(); |
113 |
| - |
114 |
| - let mut ret = vec![]; |
115 |
| - for _ in 0..10 { |
116 |
| - ret.push(prepare_state(rng.next_u64(), rng.next_u64())); |
117 |
| - } |
118 |
| - |
119 |
| - ret |
120 |
| - } |
121 |
| - |
122 |
| - fn common_case_input_state(&self) -> InitVmState { |
123 |
| - prepare_state(1 << 60, (1 << 42) - 1) |
124 |
| - } |
125 |
| - |
126 |
| - fn worst_case_input_state(&self) -> InitVmState { |
127 |
| - prepare_state(1 << 60, (1 << 42) - 1) |
128 |
| - } |
129 |
| - |
130 |
| - fn rust_shadowing( |
131 |
| - &self, |
132 |
| - stack: &mut Vec<BFieldElement>, |
133 |
| - _std_in: Vec<BFieldElement>, |
134 |
| - _secret_in: Vec<BFieldElement>, |
135 |
| - _memory: &mut std::collections::HashMap<BFieldElement, BFieldElement>, |
136 |
| - ) { |
137 |
| - // top element on stack |
138 |
| - let a_lo: u32 = stack.pop().unwrap().try_into().unwrap(); |
139 |
| - let a_hi: u32 = stack.pop().unwrap().try_into().unwrap(); |
140 |
| - let a = ((a_hi as u64) << 32) + a_lo as u64; |
141 |
| - |
142 |
| - let b_lo: u32 = stack.pop().unwrap().try_into().unwrap(); |
143 |
| - let b_hi: u32 = stack.pop().unwrap().try_into().unwrap(); |
144 |
| - let b = ((b_hi as u64) << 32) + b_lo as u64; |
145 |
| - |
146 |
| - let prod = a.wrapping_mul(b); |
147 |
| - |
148 |
| - stack.push(BFieldElement::new(prod >> 32)); |
149 |
| - stack.push(BFieldElement::new(prod & u32::MAX as u64)); |
| 82 | + fn sign_offs(&self) -> HashMap<Reviewer, SignOffFingerprint> { |
| 83 | + let mut sign_offs = HashMap::new(); |
| 84 | + sign_offs.insert(Reviewer("ferdinand"), 0x98526c7c401009ed.into()); |
| 85 | + sign_offs |
150 | 86 | }
|
151 | 87 | }
|
152 | 88 |
|
153 |
| -fn prepare_state(a: u64, b: u64) -> InitVmState { |
154 |
| - let a = U32s::<2>::try_from(a).unwrap(); |
155 |
| - let b = U32s::<2>::try_from(b).unwrap(); |
156 |
| - let mut init_stack = empty_stack(); |
157 |
| - push_encodable(&mut init_stack, &a); |
158 |
| - push_encodable(&mut init_stack, &b); |
159 |
| - InitVmState::with_stack(init_stack) |
160 |
| -} |
161 |
| - |
162 | 89 | #[cfg(test)]
|
163 | 90 | mod tests {
|
164 |
| - use std::collections::HashMap; |
| 91 | + use super::*; |
| 92 | + use crate::test_prelude::*; |
165 | 93 |
|
166 |
| - use num::Zero; |
| 94 | + impl Closure for WrappingMul { |
| 95 | + type Args = (u64, u64); |
167 | 96 |
|
168 |
| - use super::*; |
169 |
| - use crate::empty_stack; |
170 |
| - use crate::test_helpers::test_rust_equivalence_given_input_values_deprecated; |
171 |
| - use crate::test_helpers::test_rust_equivalence_multiple_deprecated; |
| 97 | + fn rust_shadow(&self, stack: &mut Vec<BFieldElement>) { |
| 98 | + let (right, left) = pop_encodable::<Self::Args>(stack); |
| 99 | + push_encodable(stack, &left.wrapping_mul(right)); |
| 100 | + } |
172 | 101 |
|
173 |
| - #[test] |
174 |
| - fn wrapping_mul_u64_test() { |
175 |
| - test_rust_equivalence_multiple_deprecated(&WrappingMul, true); |
| 102 | + fn pseudorandom_args( |
| 103 | + &self, |
| 104 | + seed: [u8; 32], |
| 105 | + bench_case: Option<BenchmarkCase>, |
| 106 | + ) -> Self::Args { |
| 107 | + match bench_case { |
| 108 | + Some(BenchmarkCase::CommonCase) => (1 << 31, (1 << 25) - 1), |
| 109 | + Some(BenchmarkCase::WorstCase) => (1 << 53, (1 << 33) - 1), |
| 110 | + None => StdRng::from_seed(seed).gen(), |
| 111 | + } |
| 112 | + } |
176 | 113 | }
|
177 | 114 |
|
178 | 115 | #[test]
|
179 |
| - fn wrapping_mul_u64_simple() { |
180 |
| - let mut init_stack = empty_stack(); |
181 |
| - init_stack.push(BFieldElement::zero()); |
182 |
| - init_stack.push(BFieldElement::new(100)); |
183 |
| - init_stack.push(BFieldElement::zero()); |
184 |
| - init_stack.push(BFieldElement::new(200)); |
185 |
| - |
186 |
| - let mut expected = empty_stack(); |
187 |
| - expected.push(BFieldElement::zero()); |
188 |
| - expected.push(BFieldElement::new(20_000)); |
189 |
| - test_rust_equivalence_given_input_values_deprecated( |
190 |
| - &WrappingMul, |
191 |
| - &init_stack, |
192 |
| - &[], |
193 |
| - HashMap::default(), |
194 |
| - Some(&expected), |
195 |
| - ); |
| 116 | + fn rust_shadow() { |
| 117 | + ShadowedClosure::new(WrappingMul).test(); |
196 | 118 | }
|
197 | 119 | }
|
198 | 120 |
|
199 | 121 | #[cfg(test)]
|
200 | 122 | mod benches {
|
201 | 123 | use super::*;
|
202 |
| - use crate::snippet_bencher::bench_and_write; |
| 124 | + use crate::test_prelude::*; |
203 | 125 |
|
204 | 126 | #[test]
|
205 |
| - fn wrappingmul_u64_benchmark() { |
206 |
| - bench_and_write(WrappingMul); |
| 127 | + fn benchmark() { |
| 128 | + ShadowedClosure::new(WrappingMul).bench(); |
207 | 129 | }
|
208 | 130 | }
|
0 commit comments