|
1 |
| -use rand::prelude::*; |
| 1 | +use std::collections::HashMap; |
| 2 | + |
2 | 3 | use triton_vm::prelude::*;
|
3 | 4 |
|
4 |
| -use crate::empty_stack; |
5 | 5 | use crate::prelude::*;
|
6 |
| -use crate::traits::deprecated_snippet::DeprecatedSnippet; |
7 |
| -use crate::InitVmState; |
8 |
| - |
9 |
| -/// If the inputs, are valid u32s, then the output is guaranteed to be to. |
10 |
| -/// Crashes on overflow. |
11 |
| -#[derive(Clone, Debug)] |
| 6 | +use crate::traits::basic_snippet::Reviewer; |
| 7 | +use crate::traits::basic_snippet::SignOffFingerprint; |
| 8 | + |
| 9 | +/// Multiply two `u32`s and crash on overflow. |
| 10 | +/// |
| 11 | +/// ### Behavior |
| 12 | +/// |
| 13 | +/// ```text |
| 14 | +/// BEFORE: _ [right: 32] [left: u32] |
| 15 | +/// AFTER: _ [left · right: u32] |
| 16 | +/// ``` |
| 17 | +/// |
| 18 | +/// ### Preconditions |
| 19 | +/// |
| 20 | +/// - all input arguments are properly [`BFieldCodec`] encoded |
| 21 | +/// - the product of `left` and `right` is less than or equal to [`u32::MAX`] |
| 22 | +/// |
| 23 | +/// ### Postconditions |
| 24 | +/// |
| 25 | +/// - the output is the product of the input |
| 26 | +/// - the output is properly [`BFieldCodec`] encoded |
| 27 | +#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, Hash)] |
12 | 28 | pub struct SafeMul;
|
13 | 29 |
|
14 |
| -impl DeprecatedSnippet for SafeMul { |
15 |
| - fn entrypoint_name(&self) -> String { |
16 |
| - "tasmlib_arithmetic_u32_safe_mul".to_string() |
17 |
| - } |
18 |
| - |
19 |
| - fn input_field_names(&self) -> Vec<String> { |
20 |
| - vec!["lhs".to_string(), "rhs".to_string()] |
21 |
| - } |
22 |
| - |
23 |
| - fn input_types(&self) -> Vec<DataType> { |
24 |
| - vec![DataType::U32, DataType::U32] |
25 |
| - } |
26 |
| - |
27 |
| - fn output_field_names(&self) -> Vec<String> { |
28 |
| - vec!["lhs * rhs".to_string()] |
29 |
| - } |
30 |
| - |
31 |
| - fn output_types(&self) -> Vec<DataType> { |
32 |
| - vec![DataType::U32] |
33 |
| - } |
34 |
| - |
35 |
| - fn stack_diff(&self) -> isize { |
36 |
| - -1 |
37 |
| - } |
38 |
| - |
39 |
| - fn function_code(&self, _library: &mut crate::library::Library) -> String { |
40 |
| - let entrypoint = self.entrypoint_name(); |
41 |
| - format!( |
42 |
| - " |
43 |
| - // BEFORE: _ rhs lhs |
44 |
| - // AFTER: _ (lhs * rhs) |
45 |
| - {entrypoint}: |
46 |
| - mul |
47 |
| - dup 0 // _ (lhs * rhs) (lhs * rhs) |
48 |
| - split // _ (lhs * rhs) hi lo |
49 |
| - pop 1 // _ (lhs * rhs) hi |
50 |
| - push 0 // _ (lhs * rhs) hi 0 |
51 |
| - eq // _ (lhs * rhs) (hi == 0) |
52 |
| - assert // _ (lhs * rhs) |
53 |
| - return |
54 |
| - " |
55 |
| - ) |
56 |
| - } |
| 30 | +impl SafeMul { |
| 31 | + pub const OVERFLOW_ERROR_ID: i128 = 460; |
| 32 | +} |
57 | 33 |
|
58 |
| - fn crash_conditions(&self) -> Vec<String> { |
59 |
| - vec!["result overflows u32".to_string()] |
| 34 | +impl BasicSnippet for SafeMul { |
| 35 | + fn inputs(&self) -> Vec<(DataType, String)> { |
| 36 | + ["right", "left"] |
| 37 | + .map(|s| (DataType::U32, s.to_string())) |
| 38 | + .to_vec() |
60 | 39 | }
|
61 | 40 |
|
62 |
| - fn gen_input_states(&self) -> Vec<InitVmState> { |
63 |
| - let mut ret: Vec<InitVmState> = vec![]; |
64 |
| - for _ in 0..10 { |
65 |
| - let mut stack = empty_stack(); |
66 |
| - let lhs = thread_rng().gen_range(0..(1 << 16)); |
67 |
| - let rhs = thread_rng().gen_range(0..(1 << 16)); |
68 |
| - let lhs = BFieldElement::new(lhs as u64); |
69 |
| - let rhs = BFieldElement::new(rhs as u64); |
70 |
| - stack.push(lhs); |
71 |
| - stack.push(rhs); |
72 |
| - ret.push(InitVmState::with_stack(stack)); |
73 |
| - } |
74 |
| - |
75 |
| - ret |
| 41 | + fn outputs(&self) -> Vec<(DataType, String)> { |
| 42 | + vec![(DataType::U32, "left · right".to_string())] |
76 | 43 | }
|
77 | 44 |
|
78 |
| - fn common_case_input_state(&self) -> InitVmState { |
79 |
| - InitVmState::with_stack( |
80 |
| - [ |
81 |
| - empty_stack(), |
82 |
| - vec![BFieldElement::new(1 << 8), BFieldElement::new(1 << 9)], |
83 |
| - ] |
84 |
| - .concat(), |
85 |
| - ) |
| 45 | + fn entrypoint(&self) -> String { |
| 46 | + "tasmlib_arithmetic_u32_safe_mul".to_string() |
86 | 47 | }
|
87 | 48 |
|
88 |
| - fn worst_case_input_state(&self) -> InitVmState { |
89 |
| - InitVmState::with_stack( |
90 |
| - [ |
91 |
| - empty_stack(), |
92 |
| - vec![ |
93 |
| - BFieldElement::new((1 << 15) - 1), |
94 |
| - BFieldElement::new((1 << 16) - 1), |
95 |
| - ], |
96 |
| - ] |
97 |
| - .concat(), |
| 49 | + fn code(&self, _: &mut Library) -> Vec<LabelledInstruction> { |
| 50 | + triton_asm!( |
| 51 | + // BEFORE: _ right left |
| 52 | + // AFTER: _ product |
| 53 | + {self.entrypoint()}: |
| 54 | + mul |
| 55 | + dup 0 // _ product product |
| 56 | + split // _ product hi lo |
| 57 | + pop 1 // _ product hi |
| 58 | + push 0 // _ product hi 0 |
| 59 | + eq // _ product (hi == 0) |
| 60 | + assert error_id {Self::OVERFLOW_ERROR_ID} |
| 61 | + return |
98 | 62 | )
|
99 | 63 | }
|
100 | 64 |
|
101 |
| - fn rust_shadowing( |
102 |
| - &self, |
103 |
| - stack: &mut Vec<BFieldElement>, |
104 |
| - _std_in: Vec<BFieldElement>, |
105 |
| - _secret_in: Vec<BFieldElement>, |
106 |
| - _memory: &mut std::collections::HashMap<BFieldElement, BFieldElement>, |
107 |
| - ) { |
108 |
| - let lhs: u32 = stack.pop().unwrap().try_into().unwrap(); |
109 |
| - let rhs: u32 = stack.pop().unwrap().try_into().unwrap(); |
110 |
| - |
111 |
| - let prod = lhs * rhs; |
112 |
| - stack.push(BFieldElement::new(prod as u64)); |
| 65 | + fn sign_offs(&self) -> HashMap<Reviewer, SignOffFingerprint> { |
| 66 | + let mut sign_offs = HashMap::new(); |
| 67 | + sign_offs.insert(Reviewer("ferdinand"), 0x3836d772ff7b6165.into()); |
| 68 | + sign_offs |
113 | 69 | }
|
114 | 70 | }
|
115 | 71 |
|
116 | 72 | #[cfg(test)]
|
117 | 73 | mod tests {
|
118 |
| - use std::collections::HashMap; |
119 |
| - |
120 |
| - use num::Zero; |
121 |
| - |
122 | 74 | use super::*;
|
123 |
| - use crate::test_helpers::test_rust_equivalence_given_input_values_deprecated; |
124 |
| - use crate::test_helpers::test_rust_equivalence_multiple_deprecated; |
| 75 | + use crate::test_prelude::*; |
125 | 76 |
|
126 |
| - #[test] |
127 |
| - fn snippet_test() { |
128 |
| - test_rust_equivalence_multiple_deprecated(&SafeMul, true); |
129 |
| - } |
| 77 | + impl Closure for SafeMul { |
| 78 | + type Args = (u32, u32); |
130 | 79 |
|
131 |
| - #[test] |
132 |
| - fn safe_sub_simple_test() { |
133 |
| - prop_safe_mul(1000, 1, Some(1000)); |
134 |
| - prop_safe_mul(10_000, 900, Some(9_000_000)); |
135 |
| - prop_safe_mul(1, 1, Some(1)); |
136 |
| - prop_safe_mul(10_000, 10_000, Some(100_000_000)); |
137 |
| - prop_safe_mul(u32::MAX, 1, Some(u32::MAX)); |
138 |
| - prop_safe_mul(1, u32::MAX, Some(u32::MAX)); |
139 |
| - } |
| 80 | + fn rust_shadow(&self, stack: &mut Vec<BFieldElement>) { |
| 81 | + let (right, left) = pop_encodable::<Self::Args>(stack); |
| 82 | + let product = left.checked_mul(right).unwrap(); |
| 83 | + push_encodable(stack, &product); |
| 84 | + } |
140 | 85 |
|
141 |
| - #[should_panic] |
142 |
| - #[test] |
143 |
| - fn overflow_test() { |
144 |
| - prop_safe_mul(1 << 16, 1 << 16, None); |
145 |
| - } |
| 86 | + fn pseudorandom_args( |
| 87 | + &self, |
| 88 | + seed: [u8; 32], |
| 89 | + bench_case: Option<BenchmarkCase>, |
| 90 | + ) -> Self::Args { |
| 91 | + let Some(bench_case) = bench_case else { |
| 92 | + let mut rng = StdRng::from_seed(seed); |
| 93 | + let left = rng.gen_range(1..=u32::MAX); |
| 94 | + let right = rng.gen_range(0..=u32::MAX / left); |
| 95 | + |
| 96 | + return (right, left); |
| 97 | + }; |
| 98 | + |
| 99 | + match bench_case { |
| 100 | + BenchmarkCase::CommonCase => (1 << 8, 1 << 9), |
| 101 | + BenchmarkCase::WorstCase => (1 << 15, 1 << 16), |
| 102 | + } |
| 103 | + } |
146 | 104 |
|
147 |
| - #[should_panic] |
148 |
| - #[test] |
149 |
| - fn overflow_test_2() { |
150 |
| - prop_safe_mul(1 << 31, 2, None); |
| 105 | + fn corner_case_args(&self) -> Vec<Self::Args> { |
| 106 | + [0, 1] |
| 107 | + .into_iter() |
| 108 | + .cartesian_product([0, 1, u32::MAX]) |
| 109 | + .collect() |
| 110 | + } |
151 | 111 | }
|
152 | 112 |
|
153 |
| - #[should_panic] |
154 | 113 | #[test]
|
155 |
| - fn overflow_test_3() { |
156 |
| - prop_safe_mul(2, 1 << 31, None); |
| 114 | + fn rust_shadow() { |
| 115 | + ShadowedClosure::new(SafeMul).test(); |
157 | 116 | }
|
158 | 117 |
|
159 |
| - fn prop_safe_mul(lhs: u32, rhs: u32, _expected: Option<u32>) { |
160 |
| - let mut init_stack = empty_stack(); |
161 |
| - init_stack.push(BFieldElement::new(rhs as u64)); |
162 |
| - init_stack.push(BFieldElement::new(lhs as u64)); |
163 |
| - |
164 |
| - let expected = lhs.checked_mul(rhs); |
165 |
| - let expected = [ |
166 |
| - empty_stack(), |
167 |
| - vec![expected |
168 |
| - .map(|x| BFieldElement::new(x as u64)) |
169 |
| - .unwrap_or_else(BFieldElement::zero)], |
170 |
| - ] |
171 |
| - .concat(); |
172 |
| - |
173 |
| - test_rust_equivalence_given_input_values_deprecated( |
174 |
| - &SafeMul, |
175 |
| - &init_stack, |
176 |
| - &[], |
177 |
| - HashMap::default(), |
178 |
| - Some(&expected), |
179 |
| - ); |
| 118 | + #[proptest] |
| 119 | + fn overflow_crashes_vm( |
| 120 | + #[strategy(1_u32..)] left: u32, |
| 121 | + #[strategy(u32::MAX / #left..)] |
| 122 | + #[filter(#left.checked_mul(#right).is_none())] |
| 123 | + right: u32, |
| 124 | + ) { |
| 125 | + test_assertion_failure( |
| 126 | + &ShadowedClosure::new(SafeMul), |
| 127 | + InitVmState::with_stack(SafeMul.set_up_test_stack((left, right))), |
| 128 | + &[SafeMul::OVERFLOW_ERROR_ID], |
| 129 | + ) |
180 | 130 | }
|
181 | 131 | }
|
182 | 132 |
|
183 | 133 | #[cfg(test)]
|
184 | 134 | mod benches {
|
185 | 135 | use super::*;
|
186 |
| - use crate::snippet_bencher::bench_and_write; |
| 136 | + use crate::test_prelude::*; |
187 | 137 |
|
188 | 138 | #[test]
|
189 | 139 | fn safe_mul_benchmark() {
|
190 |
| - bench_and_write(SafeMul); |
| 140 | + ShadowedClosure::new(SafeMul).bench(); |
191 | 141 | }
|
192 | 142 | }
|
0 commit comments