|
1 |
| -use rand::prelude::*; |
| 1 | +use std::collections::HashMap; |
| 2 | + |
2 | 3 | use triton_vm::prelude::*;
|
3 | 4 |
|
4 |
| -use crate::empty_stack; |
5 | 5 | use crate::prelude::*;
|
6 |
| -use crate::traits::deprecated_snippet::DeprecatedSnippet; |
7 |
| -use crate::InitVmState; |
8 |
| - |
9 |
| -#[derive(Clone, Debug)] |
| 6 | +use crate::traits::basic_snippet::Reviewer; |
| 7 | +use crate::traits::basic_snippet::SignOffFingerprint; |
| 8 | + |
| 9 | +/// Add two `u32`s and crash on overflow. |
| 10 | +/// |
| 11 | +/// ### Behavior |
| 12 | +/// |
| 13 | +/// ```text |
| 14 | +/// BEFORE: _ [right: 32] [left: u32] |
| 15 | +/// AFTER: _ [left + right: u32] |
| 16 | +/// ``` |
| 17 | +/// |
| 18 | +/// ### Preconditions |
| 19 | +/// |
| 20 | +/// - all input arguments are properly [`BFieldCodec`] encoded |
| 21 | +/// - the sum of `left` and `right` is less than or equal to [`u32::MAX`] |
| 22 | +/// |
| 23 | +/// ### Postconditions |
| 24 | +/// |
| 25 | +/// - the output is the sum of the input |
| 26 | +/// - the output is properly [`BFieldCodec`] encoded |
| 27 | +#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, Hash)] |
10 | 28 | pub struct SafeAdd;
|
11 | 29 |
|
12 |
| -impl DeprecatedSnippet for SafeAdd { |
13 |
| - fn entrypoint_name(&self) -> String { |
14 |
| - "tasmlib_arithmetic_u32_safe_add".to_string() |
15 |
| - } |
16 |
| - |
17 |
| - fn input_field_names(&self) -> Vec<String> { |
18 |
| - vec!["rhs".to_string(), "lhs".to_string()] |
19 |
| - } |
20 |
| - |
21 |
| - fn input_types(&self) -> Vec<DataType> { |
22 |
| - vec![DataType::U32, DataType::U32] |
23 |
| - } |
24 |
| - |
25 |
| - fn output_field_names(&self) -> Vec<String> { |
26 |
| - vec!["lhs + rhs".to_string()] |
27 |
| - } |
28 |
| - |
29 |
| - fn output_types(&self) -> Vec<DataType> { |
30 |
| - vec![DataType::U32] |
31 |
| - } |
32 |
| - |
33 |
| - fn stack_diff(&self) -> isize { |
34 |
| - -1 |
35 |
| - } |
36 |
| - |
37 |
| - fn function_code(&self, _library: &mut crate::library::Library) -> String { |
38 |
| - let entrypoint = self.entrypoint_name(); |
39 |
| - format!( |
40 |
| - " |
41 |
| - {entrypoint}: |
42 |
| - add // _ lhs + rhs |
43 |
| - dup 0 // _ (lhs + rhs) (lhs + rhs) |
44 |
| - split // _ (lhs + rhs) hi lo |
45 |
| - pop 1 // _ (lhs + rhs) hi |
46 |
| - push 0 // _ (lhs + rhs) hi 0 |
47 |
| - eq // _ (lhs + rhs) (hi == 0) |
48 |
| - assert // _ (lhs + rhs) |
49 |
| - return |
50 |
| - " |
51 |
| - ) |
52 |
| - } |
| 30 | +impl SafeAdd { |
| 31 | + pub const OVERFLOW_ERROR_ID: i128 = 450; |
| 32 | +} |
53 | 33 |
|
54 |
| - fn crash_conditions(&self) -> Vec<String> { |
55 |
| - vec!["u32 overflow".to_string()] |
| 34 | +impl BasicSnippet for SafeAdd { |
| 35 | + fn inputs(&self) -> Vec<(DataType, String)> { |
| 36 | + ["right", "left"] |
| 37 | + .map(|s| (DataType::U32, s.to_string())) |
| 38 | + .to_vec() |
56 | 39 | }
|
57 | 40 |
|
58 |
| - fn gen_input_states(&self) -> Vec<InitVmState> { |
59 |
| - let mut ret: Vec<InitVmState> = vec![]; |
60 |
| - for _ in 0..10 { |
61 |
| - let mut stack = empty_stack(); |
62 |
| - let lhs = thread_rng().gen_range(0..u32::MAX / 2); |
63 |
| - let rhs = thread_rng().gen_range(0..u32::MAX / 2); |
64 |
| - let lhs = BFieldElement::new(lhs as u64); |
65 |
| - let rhs = BFieldElement::new(rhs as u64); |
66 |
| - stack.push(rhs); |
67 |
| - stack.push(lhs); |
68 |
| - ret.push(InitVmState::with_stack(stack)); |
69 |
| - } |
70 |
| - |
71 |
| - ret |
| 41 | + fn outputs(&self) -> Vec<(DataType, String)> { |
| 42 | + vec![(DataType::U32, "left + right".to_string())] |
72 | 43 | }
|
73 | 44 |
|
74 |
| - fn common_case_input_state(&self) -> InitVmState { |
75 |
| - InitVmState::with_stack( |
76 |
| - [ |
77 |
| - empty_stack(), |
78 |
| - vec![BFieldElement::new(1 << 16), BFieldElement::new(1 << 15)], |
79 |
| - ] |
80 |
| - .concat(), |
81 |
| - ) |
| 45 | + fn entrypoint(&self) -> String { |
| 46 | + "tasmlib_arithmetic_u32_safe_add".to_string() |
82 | 47 | }
|
83 | 48 |
|
84 |
| - fn worst_case_input_state(&self) -> InitVmState { |
85 |
| - InitVmState::with_stack( |
86 |
| - [ |
87 |
| - empty_stack(), |
88 |
| - vec![ |
89 |
| - BFieldElement::new((1 << 30) - 1), |
90 |
| - BFieldElement::new((1 << 31) - 1), |
91 |
| - ], |
92 |
| - ] |
93 |
| - .concat(), |
| 49 | + fn code(&self, _: &mut Library) -> Vec<LabelledInstruction> { |
| 50 | + triton_asm!( |
| 51 | + {self.entrypoint()}: |
| 52 | + add // _ sum |
| 53 | + dup 0 // _ sum sum |
| 54 | + split // _ sum hi lo |
| 55 | + pop 1 // _ sum hi |
| 56 | + push 0 // _ sum hi 0 |
| 57 | + eq // _ sum (hi == 0) |
| 58 | + assert error_id {Self::OVERFLOW_ERROR_ID} |
| 59 | + return |
94 | 60 | )
|
95 | 61 | }
|
96 | 62 |
|
97 |
| - fn rust_shadowing( |
98 |
| - &self, |
99 |
| - stack: &mut Vec<BFieldElement>, |
100 |
| - _std_in: Vec<BFieldElement>, |
101 |
| - _secret_in: Vec<BFieldElement>, |
102 |
| - _memory: &mut std::collections::HashMap<BFieldElement, BFieldElement>, |
103 |
| - ) { |
104 |
| - let lhs: u32 = stack.pop().unwrap().try_into().unwrap(); |
105 |
| - let rhs: u32 = stack.pop().unwrap().try_into().unwrap(); |
106 |
| - |
107 |
| - let sum = lhs + rhs; |
108 |
| - stack.push(BFieldElement::new(sum as u64)); |
| 63 | + fn sign_offs(&self) -> HashMap<Reviewer, SignOffFingerprint> { |
| 64 | + let mut sign_offs = HashMap::new(); |
| 65 | + sign_offs.insert(Reviewer("ferdinand"), 0x92d8f083ca55e1d.into()); |
| 66 | + sign_offs |
109 | 67 | }
|
110 | 68 | }
|
111 | 69 |
|
112 | 70 | #[cfg(test)]
|
113 | 71 | mod tests {
|
114 |
| - use std::collections::HashMap; |
| 72 | + use super::*; |
| 73 | + use crate::test_prelude::*; |
115 | 74 |
|
116 |
| - use num::Zero; |
| 75 | + impl Closure for SafeAdd { |
| 76 | + type Args = (u32, u32); |
117 | 77 |
|
118 |
| - use super::*; |
119 |
| - use crate::test_helpers::test_rust_equivalence_given_input_values_deprecated; |
120 |
| - use crate::test_helpers::test_rust_equivalence_multiple_deprecated; |
| 78 | + fn rust_shadow(&self, stack: &mut Vec<BFieldElement>) { |
| 79 | + let (right, left) = pop_encodable::<Self::Args>(stack); |
| 80 | + let sum = left.checked_add(right).unwrap(); |
| 81 | + push_encodable(stack, &sum); |
| 82 | + } |
121 | 83 |
|
122 |
| - #[test] |
123 |
| - fn snippet_test() { |
124 |
| - test_rust_equivalence_multiple_deprecated(&SafeAdd, true); |
125 |
| - } |
| 84 | + fn pseudorandom_args( |
| 85 | + &self, |
| 86 | + seed: [u8; 32], |
| 87 | + bench_case: Option<BenchmarkCase>, |
| 88 | + ) -> Self::Args { |
| 89 | + let Some(bench_case) = bench_case else { |
| 90 | + let mut rng = StdRng::from_seed(seed); |
| 91 | + let left = rng.gen(); |
| 92 | + let right = rng.gen_range(0..=u32::MAX - left); |
| 93 | + |
| 94 | + return (right, left); |
| 95 | + }; |
| 96 | + |
| 97 | + match bench_case { |
| 98 | + BenchmarkCase::CommonCase => (1 << 16, 1 << 15), |
| 99 | + BenchmarkCase::WorstCase => (u32::MAX >> 1, u32::MAX >> 2), |
| 100 | + } |
| 101 | + } |
126 | 102 |
|
127 |
| - #[test] |
128 |
| - fn safe_add_simple_test() { |
129 |
| - prop_safe_add(1000, 1, Some(1001)); |
130 |
| - prop_safe_add(10_000, 900, Some(10_900)); |
| 103 | + fn corner_case_args(&self) -> Vec<Self::Args> { |
| 104 | + vec![(0, u32::MAX)] |
| 105 | + } |
131 | 106 | }
|
132 | 107 |
|
133 |
| - #[should_panic] |
134 | 108 | #[test]
|
135 |
| - fn overflow_test() { |
136 |
| - prop_safe_add((1 << 31) + 1000, 1 << 31, None); |
| 109 | + fn rust_shadow() { |
| 110 | + ShadowedClosure::new(SafeAdd).test(); |
137 | 111 | }
|
138 | 112 |
|
139 |
| - fn prop_safe_add(lhs: u32, rhs: u32, _expected: Option<u32>) { |
140 |
| - let mut init_stack = empty_stack(); |
141 |
| - init_stack.push(BFieldElement::new(rhs as u64)); |
142 |
| - init_stack.push(BFieldElement::new(lhs as u64)); |
143 |
| - |
144 |
| - let expected = lhs.checked_add(rhs); |
145 |
| - let expected = [ |
146 |
| - empty_stack(), |
147 |
| - vec![expected |
148 |
| - .map(|x| BFieldElement::new(x as u64)) |
149 |
| - .unwrap_or_else(BFieldElement::zero)], |
150 |
| - ] |
151 |
| - .concat(); |
152 |
| - |
153 |
| - test_rust_equivalence_given_input_values_deprecated::<SafeAdd>( |
154 |
| - &SafeAdd, |
155 |
| - &init_stack, |
156 |
| - &[], |
157 |
| - HashMap::default(), |
158 |
| - Some(&expected), |
159 |
| - ); |
| 113 | + #[proptest] |
| 114 | + fn overflow_crashes_vm( |
| 115 | + #[filter(#left != 0)] left: u32, |
| 116 | + #[strategy(u32::MAX - #left + 1..)] right: u32, |
| 117 | + ) { |
| 118 | + debug_assert!(left.checked_add(right).is_none()); |
| 119 | + test_assertion_failure( |
| 120 | + &ShadowedClosure::new(SafeAdd), |
| 121 | + InitVmState::with_stack(SafeAdd.set_up_test_stack((left, right))), |
| 122 | + &[SafeAdd::OVERFLOW_ERROR_ID], |
| 123 | + ) |
160 | 124 | }
|
161 | 125 | }
|
162 | 126 |
|
163 | 127 | #[cfg(test)]
|
164 | 128 | mod benches {
|
165 | 129 | use super::*;
|
166 |
| - use crate::snippet_bencher::bench_and_write; |
| 130 | + use crate::test_prelude::*; |
167 | 131 |
|
168 | 132 | #[test]
|
169 | 133 | fn safe_add_benchmark() {
|
170 |
| - bench_and_write(SafeAdd); |
| 134 | + ShadowedClosure::new(SafeAdd).bench(); |
171 | 135 | }
|
172 | 136 | }
|
0 commit comments