Skip to content

Commit 2f9d911

Browse files
authored
feat!: switch sort key type to u32, upgrade sort dependency (#20)
1 parent 5f3a49e commit 2f9d911

File tree

4 files changed

+84
-108
lines changed

4 files changed

+84
-108
lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ on:
88

99
env:
1010
CARGO_TERM_COLOR: always
11-
MINIMUM_NOIR_VERSION: v0.36.0
11+
MINIMUM_NOIR_VERSION: v1.0.0-beta.4
1212

1313
jobs:
1414
noir-version-list:

Nargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@ authors = [""]
55
compiler_version = ">=0.36.0"
66

77
[dependencies]
8-
sort = { tag = "v0.2.3", git = "https://github.com/noir-lang/noir_sort" }
8+
sort = { tag = "v0.3.0", git = "https://github.com/noir-lang/noir_sort" }

src/lib.nr

Lines changed: 35 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
mod mut_sparse_array;
22
use dep::sort::sort_advanced;
33

4-
unconstrained fn __sort_field_as_u32(lhs: Field, rhs: Field) -> bool {
4+
unconstrained fn __sort(lhs: u32, rhs: u32) -> bool {
55
// lhs.lt(rhs)
6-
lhs as u32 < rhs as u32
6+
lhs < rhs
77
}
88

9-
fn assert_sorted(lhs: Field, rhs: Field) {
10-
let result = (rhs - lhs - 1);
11-
result.assert_max_bit_size::<32>();
9+
fn assert_sorted(lhs: u32, rhs: u32) {
10+
assert(lhs < rhs);
1211
}
1312

1413
/**
@@ -24,10 +23,10 @@ fn assert_sorted(lhs: Field, rhs: Field) {
2423
**/
2524
struct MutSparseArrayBase<let N: u32, T, ComparisonFuncs> {
2625
values: [T; N + 3],
27-
keys: [Field; N + 2],
28-
linked_keys: [Field; N + 2],
29-
tail_ptr: Field,
30-
maximum: Field,
26+
keys: [u32; N + 2],
27+
linked_keys: [u32; N + 2],
28+
tail_ptr: u32,
29+
maximum: u32,
3130
}
3231

3332
struct U32RangeTraits {}
@@ -47,9 +46,9 @@ pub struct MutSparseArray<let N: u32, T> {
4746
* 2. values[0] is an empty object. when calling `get(idx)`, if `idx` is not in `keys` we will return `values[0]`
4847
**/
4948
pub struct SparseArray<let N: u32, T> {
50-
keys: [Field; N + 2],
49+
keys: [u32; N + 2],
5150
values: [T; N + 3],
52-
maximum: Field, // can be up to 2^32
51+
maximum: u32, // can be up to 2^32 - 1
5352
}
5453
impl<let N: u32, T> SparseArray<N, T>
5554
where
@@ -59,15 +58,16 @@ where
5958
/**
6059
* @brief construct a SparseArray
6160
**/
62-
pub(crate) fn create(_keys: [Field; N], _values: [T; N], size: Field) -> Self {
61+
pub(crate) fn create(_keys: [u32; N], _values: [T; N], size: u32) -> Self {
62+
assert(size >= 1);
6363
let _maximum = size - 1;
6464
let mut r: Self =
6565
SparseArray { keys: [0; N + 2], values: [T::default(); N + 3], maximum: _maximum };
6666

6767
// for any valid index, we want to ensure the following is satified:
6868
// self.keys[X] <= index <= self.keys[X+1]
6969
// this requires us to sort hte keys, and insert a startpoint and endpoint
70-
let sorted_keys = sort_advanced(_keys, __sort_field_as_u32, assert_sorted);
70+
let sorted_keys = sort_advanced(_keys, __sort, assert_sorted);
7171

7272
// insert start and endpoints
7373
r.keys[0] = 0;
@@ -103,45 +103,41 @@ where
103103
// because `self.keys` is sorted, we can simply validate that
104104
// sorted_keys.sorted[0] < 2^32
105105
// sorted_keys.sorted[N-1] < maximum
106-
sorted_keys.sorted[0].assert_max_bit_size::<32>();
107-
_maximum.assert_max_bit_size::<32>();
108-
(_maximum - sorted_keys.sorted[N - 1]).assert_max_bit_size::<32>();
106+
assert(_maximum >= sorted_keys.sorted[N - 1]);
109107
r
110108
}
111109

112110
/**
113111
* @brief determine whether `target` is present in `self.keys`
114112
* @details if `found == false`, `self.keys[found_index] < target < self.keys[found_index + 1]`
115113
**/
116-
unconstrained fn search_for_key(self, target: Field) -> (Field, Field) {
114+
unconstrained fn search_for_key(self, target: u32) -> (bool, u32) {
117115
let mut found = false;
118-
let mut found_index = 0;
116+
let mut found_index: u32 = 0;
119117
let mut previous_less_than_or_equal_to_target = false;
120118
for i in 0..N + 2 {
121119
// if target = 0xffffffff we need to be able to add 1 here, so use u64
122120
let current_less_than_or_equal_to_target = self.keys[i] as u64 <= target as u64;
123121
if (self.keys[i] == target) {
124122
found = true;
125-
found_index = i as Field;
123+
found_index = i;
126124
break;
127125
}
128126
if (previous_less_than_or_equal_to_target & !current_less_than_or_equal_to_target) {
129-
found_index = i as Field - 1;
127+
found_index = i - 1;
130128
break;
131129
}
132130
previous_less_than_or_equal_to_target = current_less_than_or_equal_to_target;
133131
}
134-
(found as Field, found_index)
132+
(found, found_index)
135133
}
136134

137135
/**
138136
* @brief return element `idx` from the sparse array
139137
* @details cost is 14.5 gates per lookup
140138
**/
141-
fn get(self, idx: Field) -> T {
139+
fn get(self, idx: u32) -> T {
142140
let (found, found_index) = unsafe { self.search_for_key(idx) };
143-
// bool check. 0.25 gates cheaper than a raw `bool` type. need to fix at some point
144-
assert(found * found == found);
145141

146142
// OK! So we have the following cases to check
147143
// 1. if `found` then `self.keys[found_index] == idx`
@@ -152,15 +148,13 @@ where
152148
// combine the two into the following single statement:
153149
// `self.keys[found_index] + 1 - found <= idx <= self.keys[found_index + 1 - found] - 1 + found
154150
let lhs = self.keys[found_index];
155-
let rhs = self.keys[found_index + 1 - found];
156-
let lhs_condition = idx - lhs - 1 + found;
157-
let rhs_condition = rhs - 1 + found - idx;
158-
lhs_condition.assert_max_bit_size::<32>();
159-
rhs_condition.assert_max_bit_size::<32>();
151+
let rhs = self.keys[found_index + 1 - found as u32];
152+
assert(lhs + 1 - found as u32 <= idx);
153+
assert(idx <= rhs + found as u32 - 1);
160154

161155
// self.keys[i] maps to self.values[i+1]
162156
// however...if we did not find a non-sparse entry, we want to return self.values[0] (the default value)
163-
let value_index = (found_index + 1) * found;
157+
let value_index = (found_index + 1) * found as u32;
164158
self.values[value_index]
165159
}
166160
}
@@ -179,7 +173,7 @@ mod test {
179173

180174
for i in 0..100 {
181175
if ((i != 1) & (i != 5) & (i != 7) & (i != 99)) {
182-
assert(example.get(i as Field) == 0);
176+
assert(example.get(i) == 0);
183177
}
184178
}
185179
}
@@ -188,34 +182,35 @@ mod test {
188182
fn test_sparse_lookup_boundary_cases() {
189183
// what about when keys[0] = 0 and keys[N-1] = 2^32 - 1?
190184
let example = SparseArray::create(
191-
[0, 99999, 7, 0xffffffff],
185+
[0, 99999, 7, 0xfffffffe],
192186
[123, 101112, 789, 456],
193-
0x100000000,
187+
0xffffffff,
194188
);
195189

196190
assert(example.get(0) == 123);
197191
assert(example.get(99999) == 101112);
198192
assert(example.get(7) == 789);
199-
assert(example.get(0xffffffff) == 456);
200-
assert(example.get(0xfffffffe) == 0);
193+
assert(example.get(0xfffffffe) == 456);
194+
assert(example.get(0xfffffffd) == 0);
201195
}
202196

203-
#[test(should_fail_with = "call to assert_max_bit_size")]
197+
#[test(should_fail)]
204198
fn test_sparse_lookup_overflow() {
205199
let example = SparseArray::create([1, 5, 7, 99999], [123, 456, 789, 101112], 100000);
206200

207201
assert(example.get(100000) == 0);
208202
}
209203

204+
/**
210205
#[test(should_fail_with = "call to assert_max_bit_size")]
211206
fn test_sparse_lookup_boundary_case_overflow() {
212207
let example =
213208
SparseArray::create([0, 5, 7, 0xffffffff], [123, 456, 789, 101112], 0x100000000);
214209
215210
assert(example.get(0x100000000) == 0);
216211
}
217-
218-
#[test(should_fail_with = "call to assert_max_bit_size")]
212+
**/
213+
#[test(should_fail)]
219214
fn test_sparse_lookup_key_exceeds_maximum() {
220215
let example =
221216
SparseArray::create([0, 5, 7, 0xffffffff], [123, 456, 789, 101112], 0xffffffff);
@@ -236,7 +231,7 @@ mod test {
236231

237232
for i in 0..100 {
238233
if ((i != 1) & (i != 5) & (i != 7) & (i != 99)) {
239-
assert(example.get(i as Field) == 0);
234+
assert(example.get(i) == 0);
240235
}
241236
}
242237
}
@@ -272,7 +267,7 @@ mod test {
272267
assert(example.get(99) == values[1]);
273268
for i in 0..100 {
274269
if ((i != 1) & (i != 5) & (i != 7) & (i != 99)) {
275-
assert(example.get(i as Field) == F::default());
270+
assert(example.get(i) == F::default());
276271
}
277272
}
278273
}

0 commit comments

Comments
 (0)