1
1
mod mut_sparse_array ;
2
2
use dep::sort::sort_advanced ;
3
3
4
- unconstrained fn __sort_field_as_u32 (lhs : Field , rhs : Field ) -> bool {
4
+ unconstrained fn __sort (lhs : u32 , rhs : u32 ) -> bool {
5
5
// lhs.lt(rhs)
6
- lhs as u32 < rhs as u32
6
+ lhs < rhs
7
7
}
8
8
9
- fn assert_sorted (lhs : Field , rhs : Field ) {
10
- let result = (rhs - lhs - 1 );
11
- result .assert_max_bit_size ::<32 >();
9
+ fn assert_sorted (lhs : u32 , rhs : u32 ) {
10
+ assert (lhs < rhs );
12
11
}
13
12
14
13
/**
@@ -24,10 +23,10 @@ fn assert_sorted(lhs: Field, rhs: Field) {
24
23
**/
25
24
struct MutSparseArrayBase <let N : u32 , T , ComparisonFuncs > {
26
25
values : [T ; N + 3 ],
27
- keys : [Field ; N + 2 ],
28
- linked_keys : [Field ; N + 2 ],
29
- tail_ptr : Field ,
30
- maximum : Field ,
26
+ keys : [u32 ; N + 2 ],
27
+ linked_keys : [u32 ; N + 2 ],
28
+ tail_ptr : u32 ,
29
+ maximum : u32 ,
31
30
}
32
31
33
32
struct U32RangeTraits {}
@@ -47,9 +46,9 @@ pub struct MutSparseArray<let N: u32, T> {
47
46
* 2. values[0] is an empty object. when calling `get(idx)`, if `idx` is not in `keys` we will return `values[0]`
48
47
**/
49
48
pub struct SparseArray <let N : u32 , T > {
50
- keys : [Field ; N + 2 ],
49
+ keys : [u32 ; N + 2 ],
51
50
values : [T ; N + 3 ],
52
- maximum : Field , // can be up to 2^32
51
+ maximum : u32 , // can be up to 2^32 - 1
53
52
}
54
53
impl <let N : u32 , T > SparseArray <N , T >
55
54
where
@@ -59,15 +58,16 @@ where
59
58
/**
60
59
* @brief construct a SparseArray
61
60
**/
62
- pub (crate ) fn create (_keys : [Field ; N ], _values : [T ; N ], size : Field ) -> Self {
61
+ pub (crate ) fn create (_keys : [u32 ; N ], _values : [T ; N ], size : u32 ) -> Self {
62
+ assert (size >= 1 );
63
63
let _maximum = size - 1 ;
64
64
let mut r : Self =
65
65
SparseArray { keys : [0 ; N + 2 ], values : [T ::default (); N + 3 ], maximum : _maximum };
66
66
67
67
// for any valid index, we want to ensure the following is satified:
68
68
// self.keys[X] <= index <= self.keys[X+1]
69
69
// this requires us to sort hte keys, and insert a startpoint and endpoint
70
- let sorted_keys = sort_advanced (_keys , __sort_field_as_u32 , assert_sorted );
70
+ let sorted_keys = sort_advanced (_keys , __sort , assert_sorted );
71
71
72
72
// insert start and endpoints
73
73
r .keys [0 ] = 0 ;
@@ -103,45 +103,41 @@ where
103
103
// because `self.keys` is sorted, we can simply validate that
104
104
// sorted_keys.sorted[0] < 2^32
105
105
// sorted_keys.sorted[N-1] < maximum
106
- sorted_keys .sorted [0 ].assert_max_bit_size ::<32 >();
107
- _maximum .assert_max_bit_size ::<32 >();
108
- (_maximum - sorted_keys .sorted [N - 1 ]).assert_max_bit_size ::<32 >();
106
+ assert (_maximum >= sorted_keys .sorted [N - 1 ]);
109
107
r
110
108
}
111
109
112
110
/**
113
111
* @brief determine whether `target` is present in `self.keys`
114
112
* @details if `found == false`, `self.keys[found_index] < target < self.keys[found_index + 1]`
115
113
**/
116
- unconstrained fn search_for_key (self , target : Field ) -> (Field , Field ) {
114
+ unconstrained fn search_for_key (self , target : u32 ) -> (bool , u32 ) {
117
115
let mut found = false ;
118
- let mut found_index = 0 ;
116
+ let mut found_index : u32 = 0 ;
119
117
let mut previous_less_than_or_equal_to_target = false ;
120
118
for i in 0 ..N + 2 {
121
119
// if target = 0xffffffff we need to be able to add 1 here, so use u64
122
120
let current_less_than_or_equal_to_target = self .keys [i ] as u64 <= target as u64 ;
123
121
if (self .keys [i ] == target ) {
124
122
found = true ;
125
- found_index = i as Field ;
123
+ found_index = i ;
126
124
break ;
127
125
}
128
126
if (previous_less_than_or_equal_to_target & !current_less_than_or_equal_to_target ) {
129
- found_index = i as Field - 1 ;
127
+ found_index = i - 1 ;
130
128
break ;
131
129
}
132
130
previous_less_than_or_equal_to_target = current_less_than_or_equal_to_target ;
133
131
}
134
- (found as Field , found_index )
132
+ (found , found_index )
135
133
}
136
134
137
135
/**
138
136
* @brief return element `idx` from the sparse array
139
137
* @details cost is 14.5 gates per lookup
140
138
**/
141
- fn get (self , idx : Field ) -> T {
139
+ fn get (self , idx : u32 ) -> T {
142
140
let (found , found_index ) = unsafe { self .search_for_key (idx ) };
143
- // bool check. 0.25 gates cheaper than a raw `bool` type. need to fix at some point
144
- assert (found * found == found );
145
141
146
142
// OK! So we have the following cases to check
147
143
// 1. if `found` then `self.keys[found_index] == idx`
@@ -152,15 +148,13 @@ where
152
148
// combine the two into the following single statement:
153
149
// `self.keys[found_index] + 1 - found <= idx <= self.keys[found_index + 1 - found] - 1 + found
154
150
let lhs = self .keys [found_index ];
155
- let rhs = self .keys [found_index + 1 - found ];
156
- let lhs_condition = idx - lhs - 1 + found ;
157
- let rhs_condition = rhs - 1 + found - idx ;
158
- lhs_condition .assert_max_bit_size ::<32 >();
159
- rhs_condition .assert_max_bit_size ::<32 >();
151
+ let rhs = self .keys [found_index + 1 - found as u32 ];
152
+ assert (lhs + 1 - found as u32 <= idx );
153
+ assert (idx <= rhs + found as u32 - 1 );
160
154
161
155
// self.keys[i] maps to self.values[i+1]
162
156
// however...if we did not find a non-sparse entry, we want to return self.values[0] (the default value)
163
- let value_index = (found_index + 1 ) * found ;
157
+ let value_index = (found_index + 1 ) * found as u32 ;
164
158
self .values [value_index ]
165
159
}
166
160
}
@@ -179,7 +173,7 @@ mod test {
179
173
180
174
for i in 0 ..100 {
181
175
if ((i != 1 ) & (i != 5 ) & (i != 7 ) & (i != 99 )) {
182
- assert (example .get (i as Field ) == 0 );
176
+ assert (example .get (i ) == 0 );
183
177
}
184
178
}
185
179
}
@@ -188,34 +182,35 @@ mod test {
188
182
fn test_sparse_lookup_boundary_cases () {
189
183
// what about when keys[0] = 0 and keys[N-1] = 2^32 - 1?
190
184
let example = SparseArray ::create (
191
- [0 , 99999 , 7 , 0xffffffff ],
185
+ [0 , 99999 , 7 , 0xfffffffe ],
192
186
[123 , 101112 , 789 , 456 ],
193
- 0x100000000 ,
187
+ 0xffffffff ,
194
188
);
195
189
196
190
assert (example .get (0 ) == 123 );
197
191
assert (example .get (99999 ) == 101112 );
198
192
assert (example .get (7 ) == 789 );
199
- assert (example .get (0xffffffff ) == 456 );
200
- assert (example .get (0xfffffffe ) == 0 );
193
+ assert (example .get (0xfffffffe ) == 456 );
194
+ assert (example .get (0xfffffffd ) == 0 );
201
195
}
202
196
203
- #[test(should_fail_with = "call to assert_max_bit_size" )]
197
+ #[test(should_fail )]
204
198
fn test_sparse_lookup_overflow () {
205
199
let example = SparseArray ::create ([1 , 5 , 7 , 99999 ], [123 , 456 , 789 , 101112 ], 100000 );
206
200
207
201
assert (example .get (100000 ) == 0 );
208
202
}
209
203
204
+ /**
210
205
#[test(should_fail_with = "call to assert_max_bit_size")]
211
206
fn test_sparse_lookup_boundary_case_overflow() {
212
207
let example =
213
208
SparseArray::create([0, 5, 7, 0xffffffff], [123, 456, 789, 101112], 0x100000000);
214
209
215
210
assert(example.get(0x100000000) == 0);
216
211
}
217
-
218
- #[test(should_fail_with = "call to assert_max_bit_size" )]
212
+ **/
213
+ #[test(should_fail )]
219
214
fn test_sparse_lookup_key_exceeds_maximum () {
220
215
let example =
221
216
SparseArray ::create ([0 , 5 , 7 , 0xffffffff ], [123 , 456 , 789 , 101112 ], 0xffffffff );
@@ -236,7 +231,7 @@ mod test {
236
231
237
232
for i in 0 ..100 {
238
233
if ((i != 1 ) & (i != 5 ) & (i != 7 ) & (i != 99 )) {
239
- assert (example .get (i as Field ) == 0 );
234
+ assert (example .get (i ) == 0 );
240
235
}
241
236
}
242
237
}
@@ -272,7 +267,7 @@ mod test {
272
267
assert (example .get (99 ) == values [1 ]);
273
268
for i in 0 ..100 {
274
269
if ((i != 1 ) & (i != 5 ) & (i != 7 ) & (i != 99 )) {
275
- assert (example .get (i as Field ) == F ::default ());
270
+ assert (example .get (i ) == F ::default ());
276
271
}
277
272
}
278
273
}
0 commit comments