Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 73 additions & 1 deletion src/lib.nr
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ pub fn sort<T, let N: u32>(input: [T; N]) -> [T; N]
where
T: std::cmp::Ord + std::cmp::Eq,
{
// Safety: validated using check_shuffle. See its documentation for more info.
let sorted = unsafe { quicksort(input) };

for i in 0..N - 1 {
Expand All @@ -36,6 +37,7 @@ pub fn sort_via<T, let N: u32>(input: [T; N], sortfn: fn(T, T) -> bool) -> [T; N
where
T: std::cmp::Eq,
{
// Safety: validated using check_shuffle. See its documentation for more info.
let sorted = unsafe { quicksort_explicit(input, sortfn) };

for i in 0..N - 1 {
Expand Down Expand Up @@ -67,6 +69,7 @@ pub fn sort_extended<T, let N: u32>(
where
T: std::cmp::Eq,
{
// Safety: validated using check_shuffle. See its documentation for more info.
let sorted = unsafe { quicksort_explicit(input, sortfn) };

for i in 0..N - 1 {
Expand All @@ -80,6 +83,7 @@ pub struct SortResult<T, let N: u32> {
pub sorted: [T; N],
pub sort_indices: [Field; N],
}

pub fn sort_advanced<T, let N: u32>(
input: [T; N],
sortfn: unconstrained fn(T, T) -> bool,
Expand All @@ -88,6 +92,7 @@ pub fn sort_advanced<T, let N: u32>(
where
T: std::cmp::Eq,
{
// Safety: validated using check_shuffle. See its documentation for more info.
let sorted = unsafe { quicksort_explicit(input, sortfn) };

let sort_indices = get_shuffle_indices(input, sorted);
Expand All @@ -98,11 +103,29 @@ where
SortResult { sorted, sort_indices }
}

global arr_with_100_values: [u32; 100] = [
42, 123, 87, 93, 48, 80, 50, 5, 104, 84, 70, 47, 119, 66, 71, 121, 3, 29, 42, 118, 2, 54, 89,
44, 81, 0, 26, 106, 68, 96, 84, 48, 95, 54, 45, 32, 89, 100, 109, 19, 37, 41, 19, 98, 53, 114,
107, 66, 6, 74, 13, 19, 105, 64, 123, 28, 44, 50, 89, 58, 123, 126, 21, 43, 86, 35, 21, 62, 82,
0, 108, 120, 72, 72, 62, 80, 12, 71, 70, 86, 116, 73, 38, 15, 127, 81, 30, 8, 125, 28, 26, 69,
114, 63, 27, 28, 61, 42, 13, 32,
];
global expected_with_100_values: [u32; 100] = [
0, 0, 2, 3, 5, 6, 8, 12, 13, 13, 15, 19, 19, 19, 21, 21, 26, 26, 27, 28, 28, 28, 29, 30, 32, 32,
35, 37, 38, 41, 42, 42, 42, 43, 44, 44, 45, 47, 48, 48, 50, 50, 53, 54, 54, 58, 61, 62, 62, 63,
64, 66, 66, 68, 69, 70, 70, 71, 71, 72, 72, 73, 74, 80, 80, 81, 81, 82, 84, 84, 86, 86, 87, 89,
89, 89, 93, 95, 96, 98, 100, 104, 105, 106, 107, 108, 109, 114, 114, 116, 118, 119, 120, 121,
123, 123, 123, 125, 126, 127,
];

mod test {
use crate::sort;
use crate::sort_extended;
use crate::sort_via;

use crate::arr_with_100_values;
use crate::expected_with_100_values;

fn sort_u32(a: u32, b: u32) -> bool {
a <= b
}
Expand Down Expand Up @@ -131,6 +154,34 @@ mod test {
assert(sorted == expected);
}

#[test]
fn test_sort_100_values() {
let mut arr: [u32; 100] = [
42, 123, 87, 93, 48, 80, 50, 5, 104, 84, 70, 47, 119, 66, 71, 121, 3, 29, 42, 118, 2,
54, 89, 44, 81, 0, 26, 106, 68, 96, 84, 48, 95, 54, 45, 32, 89, 100, 109, 19, 37, 41,
19, 98, 53, 114, 107, 66, 6, 74, 13, 19, 105, 64, 123, 28, 44, 50, 89, 58, 123, 126, 21,
43, 86, 35, 21, 62, 82, 0, 108, 120, 72, 72, 62, 80, 12, 71, 70, 86, 116, 73, 38, 15,
127, 81, 30, 8, 125, 28, 26, 69, 114, 63, 27, 28, 61, 42, 13, 32,
];

let sorted = sort(arr);

let expected: [u32; 100] = [
0, 0, 2, 3, 5, 6, 8, 12, 13, 13, 15, 19, 19, 19, 21, 21, 26, 26, 27, 28, 28, 28, 29, 30,
32, 32, 35, 37, 38, 41, 42, 42, 42, 43, 44, 44, 45, 47, 48, 48, 50, 50, 53, 54, 54, 58,
61, 62, 62, 63, 64, 66, 66, 68, 69, 70, 70, 71, 71, 72, 72, 73, 74, 80, 80, 81, 81, 82,
84, 84, 86, 86, 87, 89, 89, 89, 93, 95, 96, 98, 100, 104, 105, 106, 107, 108, 109, 114,
114, 116, 118, 119, 120, 121, 123, 123, 123, 125, 126, 127,
];
assert(sorted == expected);
}

#[test]
fn test_sort_100_values_comptime() {
let sorted = sort(arr_with_100_values);
assert(sorted == expected_with_100_values);
}

#[test]
fn test_sort_via() {
let mut arr: [u32; 7] = [3, 6, 8, 10, 1, 2, 1];
Expand All @@ -141,6 +192,28 @@ mod test {
assert(sorted == expected);
}

#[test]
fn test_sort_via_100_values() {
let mut arr: [u32; 100] = [
42, 123, 87, 93, 48, 80, 50, 5, 104, 84, 70, 47, 119, 66, 71, 121, 3, 29, 42, 118, 2,
54, 89, 44, 81, 0, 26, 106, 68, 96, 84, 48, 95, 54, 45, 32, 89, 100, 109, 19, 37, 41,
19, 98, 53, 114, 107, 66, 6, 74, 13, 19, 105, 64, 123, 28, 44, 50, 89, 58, 123, 126, 21,
43, 86, 35, 21, 62, 82, 0, 108, 120, 72, 72, 62, 80, 12, 71, 70, 86, 116, 73, 38, 15,
127, 81, 30, 8, 125, 28, 26, 69, 114, 63, 27, 28, 61, 42, 13, 32,
];

let sorted = sort_via(arr, sort_u32);

let expected: [u32; 100] = [
0, 0, 2, 3, 5, 6, 8, 12, 13, 13, 15, 19, 19, 19, 21, 21, 26, 26, 27, 28, 28, 28, 29, 30,
32, 32, 35, 37, 38, 41, 42, 42, 42, 43, 44, 44, 45, 47, 48, 48, 50, 50, 53, 54, 54, 58,
61, 62, 62, 63, 64, 66, 66, 68, 69, 70, 70, 71, 71, 72, 72, 73, 74, 80, 80, 81, 81, 82,
84, 84, 86, 86, 87, 89, 89, 89, 93, 95, 96, 98, 100, 104, 105, 106, 107, 108, 109, 114,
114, 116, 118, 119, 120, 121, 123, 123, 123, 125, 126, 127,
];
assert(sorted == expected);
}

#[test]
fn test_sort_extended() {
let mut arr: [u32; 7] = [3, 6, 8, 10, 1, 2, 1];
Expand All @@ -151,4 +224,3 @@ mod test {
assert(sorted == expected);
}
}

56 changes: 35 additions & 21 deletions src/quicksort/quicksort.nr
Original file line number Diff line number Diff line change
@@ -1,15 +1,3 @@
pub trait Swap {
fn swap(&mut self, i: u32, j: u32);
}

impl<T, let N: u32> Swap for [T; N] {
fn swap(&mut self, i: u32, j: u32) {
let temp = self[i];
self[i] = self[j];
self[j] = temp;
}
}

unconstrained fn partition<T, let N: u32>(arr: &mut [T; N], low: u32, high: u32) -> u32
where
T: std::cmp::Ord,
Expand All @@ -18,24 +6,50 @@ where
let mut i = low;
for j in low..high {
if (arr[j] < arr[pivot]) {
arr.swap(i, j);
// arr.swap(i, j);
{
let temp = arr[i];
arr[i] = arr[j];
arr[j] = temp;
}

i += 1;
}
}
arr.swap(i, pivot);

// arr.swap(i, pivot);
{
let temp = arr[i];
arr[i] = arr[pivot];
arr[pivot] = temp;
}

i
}

unconstrained fn quicksort_recursive<T, let N: u32>(arr: &mut [T; N], low: u32, high: u32)
unconstrained fn quicksort_loop<T, let N: u32>(arr: &mut [T; N], low: u32, high: u32)
where
T: std::cmp::Ord,
{
if low < high {
let pivot_index = partition(arr, low, high);
if pivot_index > 0 {
quicksort_recursive(arr, low, pivot_index - 1);
let mut stack: [(u32, u32)] = &[(low, high)];
// TODO(https://github.com/noir-lang/noir_sort/issues/22): use 'loop' once it's stabilized
for _ in 0..2 * N {
if stack.len() == 0 {
break;
}

let (new_stack, (new_low, new_high)) = stack.pop_back();
stack = new_stack;

if new_high < new_low + 1 {
continue;
}

let pivot_index = partition(arr, new_low, new_high);
stack = stack.push_back((pivot_index + 1, new_high));
if 0 < pivot_index {
stack = stack.push_back((new_low, pivot_index - 1));
}
quicksort_recursive(arr, pivot_index + 1, high);
}
}

Expand All @@ -45,7 +59,7 @@ where
{
let mut arr: [T; N] = _arr;
if arr.len() <= 1 {} else {
quicksort_recursive(&mut arr, 0, arr.len() - 1);
quicksort_loop(&mut arr, 0, arr.len() - 1);
}
arr
}
Expand Down
61 changes: 37 additions & 24 deletions src/quicksort/quicksort_explicit.nr
Original file line number Diff line number Diff line change
@@ -1,15 +1,3 @@
pub trait Swap {
fn swap(&mut self, i: u32, j: u32);
}

impl<T, let N: u32> Swap for [T; N] {
fn swap(&mut self, i: u32, j: u32) {
let temp = self[i];
self[i] = self[j];
self[j] = temp;
}
}

unconstrained fn partition<T, let N: u32>(
arr: &mut [T; N],
low: u32,
Expand All @@ -20,37 +8,62 @@ unconstrained fn partition<T, let N: u32>(
let mut i = low;
for j in low..high {
if (sortfn(arr[j], arr[pivot])) {
arr.swap(i, j);
// arr.swap(i, j);
{
let temp = arr[i];
arr[i] = arr[j];
arr[j] = temp;
}

i += 1;
}
}
arr.swap(i, pivot);

// arr.swap(i, pivot);
{
let temp = arr[i];
arr[i] = arr[pivot];
arr[pivot] = temp;
}

i
}

unconstrained fn quicksort_recursive<T, let N: u32>(
unconstrained fn quicksort_loop<T, let N: u32>(
arr: &mut [T; N],
low: u32,
high: u32,
sortfn: unconstrained fn(T, T) -> bool,
) {
if low < high {
let pivot_index = partition(arr, low, high, sortfn);
if pivot_index > 0 {
quicksort_recursive(arr, low, pivot_index - 1, sortfn);
let mut stack: [(u32, u32)] = &[(low, high)];
// TODO(https://github.com/noir-lang/noir_sort/issues/22): use 'loop' once it's stabilized
for _ in 0..2 * N {
if stack.len() == 0 {
break;
}

let (new_stack, (new_low, new_high)) = stack.pop_back();
stack = new_stack;

if new_high < new_low + 1 {
continue;
}

let pivot_index = partition(arr, new_low, new_high, sortfn);
stack = stack.push_back((pivot_index + 1, new_high));
if 0 < pivot_index {
stack = stack.push_back((new_low, pivot_index - 1));
}
quicksort_recursive(arr, pivot_index + 1, high, sortfn);
}
}

pub unconstrained fn quicksort<T, let N: u32>(
_arr: [T; N],
arr: [T; N],
sortfn: unconstrained fn(T, T) -> bool,
) -> [T; N] {
let mut arr: [T; N] = _arr;
let mut arr: [T; N] = arr;
if arr.len() <= 1 {} else {
quicksort_recursive(&mut arr, 0, arr.len() - 1, sortfn);
quicksort_loop(&mut arr, 0, arr.len() - 1, sortfn);
}
arr
}