Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 123 additions & 2 deletions src/lib.nr
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ pub fn sort<T, let N: u32>(input: [T; N]) -> [T; N]
where
T: std::cmp::Ord + std::cmp::Eq,
{
/// Safety: validated using check_shuffle. See its documentation for more info.
let sorted = unsafe { quicksort(input) };

for i in 0..N - 1 {
Expand All @@ -36,8 +37,11 @@ pub fn sort_via<T, let N: u32>(input: [T; N], sortfn: fn(T, T) -> bool) -> [T; N
where
T: std::cmp::Eq,
{
/// Safety: validated using check_shuffle. See its documentation for more info.
let sorted = unsafe { quicksort_explicit(input, sortfn) };

// TODO
println(sorted);
for i in 0..N - 1 {
assert(sortfn(sorted[i], sorted[i + 1]));
}
Expand Down Expand Up @@ -67,6 +71,7 @@ pub fn sort_extended<T, let N: u32>(
where
T: std::cmp::Eq,
{
/// Safety: validated using check_shuffle. See its documentation for more info.
let sorted = unsafe { quicksort_explicit(input, sortfn) };

for i in 0..N - 1 {
Expand All @@ -80,6 +85,7 @@ pub struct SortResult<T, let N: u32> {
pub sorted: [T; N],
pub sort_indices: [Field; N],
}

pub fn sort_advanced<T, let N: u32>(
input: [T; N],
sortfn: unconstrained fn(T, T) -> bool,
Expand All @@ -88,6 +94,7 @@ pub fn sort_advanced<T, let N: u32>(
where
T: std::cmp::Eq,
{
/// Safety: validated using check_shuffle. See its documentation for more info.
let sorted = unsafe { quicksort_explicit(input, sortfn) };

let sort_indices = get_shuffle_indices(input, sorted);
Expand All @@ -103,6 +110,86 @@ mod test {
use crate::sort_extended;
use crate::sort_via;

use crate::quicksort::quicksort_recursive::quicksort as quicksort_recursive;
use crate::quicksort::quicksort_recursive_explicit::quicksort as quicksort_recursive_explicit;
use dep::check_shuffle::check_shuffle;

/**
* Given an input array of type T, return a sorted array, using unconstrained recursion.
* Type `T` must satisfy the Ord and Eq trait
* The Eq function is used within an unconstrained function so its constraint-efficiently is not relevant
* Note: sort_extended is likely more efficient as constraining `x < y` can typically be described
* more efficiently than returning a boolean that describes whether `x < y`
**/
fn sort_recursive<T, let N: u32>(input: [T; N]) -> [T; N]
where
T: std::cmp::Ord + std::cmp::Eq,
{
/// Safety: validated using check_shuffle. See its documentation for more info.
let sorted = unsafe { quicksort_recursive(input) };

for i in 0..N - 1 {
assert(sorted[i] <= sorted[i + 1]);
}
check_shuffle(input, sorted);
sorted
}

/**
* Given an input array of type T, return a sorted array, using unconstrained recursion.
* Type `T` must satisfy the Eq trait
* The Eq function is used within an unconstrained function so its constraint-efficiently is not relevant
*
* The `sortfn` parameter is a function that descibes whether, given two elements `a, b` of type T, `a <= b`
* Note: sort_extended is likely more efficient as constraining `x < y` can typically be described
* more efficiently than returning a boolean that describes whether `x < y`
**/
fn sort_via_recursive<T, let N: u32>(input: [T; N], sortfn: fn(T, T) -> bool) -> [T; N]
where
T: std::cmp::Eq,
{
/// Safety: validated using check_shuffle. See its documentation for more info.
let sorted = unsafe { quicksort_recursive_explicit(input, sortfn) };

for i in 0..N - 1 {
assert(sortfn(sorted[i], sorted[i + 1]));
}
check_shuffle(input, sorted);
sorted
}

/**
* Given an input array of type T, return a sorted array, using unconstrained recursion.
* Type `T` must satisfy the Eq trait
* The Eq function is used within an unconstrained function so its constraint-efficiently is not relevant
*
* The `sortfn` parameter is a function that descibes whether, given two elements `a, b` of type T, `a <= b`
* The `sortfn_assert` parameter is a function that *asserts* that `a <= b`
*
* `sortfn` is used in unconstrained functions only
* `sortfn_assert` is used in constrained functions

* Note: This is likely the most efficient sort function as constraining `x < y` can typically be described
* more efficiently than returning a boolean that describes whether `x < y`
**/
pub fn sort_extended_recursive<T, let N: u32>(
input: [T; N],
sortfn: unconstrained fn(T, T) -> bool,
sortfn_assert: fn(T, T) -> (),
) -> [T; N]
where
T: std::cmp::Eq,
{
/// Safety: validated using check_shuffle. See its documentation for more info.
let sorted = unsafe { quicksort_recursive_explicit(input, sortfn) };

for i in 0..N - 1 {
sortfn_assert(sorted[i], sorted[i + 1]);
}
check_shuffle(input, sorted);
sorted
}

fn sort_u32(a: u32, b: u32) -> bool {
a <= b
}
Expand All @@ -118,6 +205,9 @@ mod test {
let b = _b as Field;

let diff = b - a;
// TODO cleanup
println(diff);
// assert(0.lt(diff));
diff.assert_max_bit_size::<32>();
}

Expand All @@ -131,6 +221,16 @@ mod test {
assert(sorted == expected);
}

#[test]
fn test_sort_recursive() {
let mut arr: [u32; 7] = [3, 6, 8, 10, 1, 2, 1];

let sorted = sort_recursive(arr);

let expected: [u32; 7] = [1, 1, 2, 3, 6, 8, 10];
assert(sorted == expected);
}

#[test]
fn test_sort_via() {
let mut arr: [u32; 7] = [3, 6, 8, 10, 1, 2, 1];
Expand All @@ -142,10 +242,31 @@ mod test {
}

#[test]
fn test_sort_extended() {
fn test_sort_via_recursive() {
let mut arr: [u32; 7] = [3, 6, 8, 10, 1, 2, 1];

let sorted = sort_via_recursive(arr, sort_u32);

let expected: [u32; 7] = [1, 1, 2, 3, 6, 8, 10];
assert(sorted == expected);
}

// TODO: re-enable after test_sort_via
// #[test]
// fn test_sort_extended() {
// let mut arr: [u32; 7] = [3, 6, 8, 10, 1, 2, 1];
//
// let sorted = sort_extended(arr, __sort_u32, unconditional_lt);
//
// let expected: [u32; 7] = [1, 1, 2, 3, 6, 8, 10];
// assert(sorted == expected);
// }

#[test]
fn test_sort_extended_recursive() {
let mut arr: [u32; 7] = [3, 6, 8, 10, 1, 2, 1];

let sorted = sort_extended(arr, __sort_u32, unconditional_lt);
let sorted = sort_extended_recursive(arr, __sort_u32, unconditional_lt);

let expected: [u32; 7] = [1, 1, 2, 3, 6, 8, 10];
assert(sorted == expected);
Expand Down
2 changes: 2 additions & 0 deletions src/quicksort.nr
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
pub mod quicksort;
pub mod quicksort_explicit;
pub mod quicksort_recursive;
pub mod quicksort_recursive_explicit;
15 changes: 10 additions & 5 deletions src/quicksort/quicksort.nr
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,21 @@ where
i
}

unconstrained fn quicksort_recursive<T, let N: u32>(arr: &mut [T; N], low: u32, high: u32)
unconstrained fn quicksort_loop<T, let N: u32>(arr: &mut [T; N], low: u32, high: u32)
where
T: std::cmp::Ord,
{
if low < high {
let mut low = low;
let mut high = high;
loop {
if low >= high {
break;
}
let pivot_index = partition(arr, low, high);
if pivot_index > 0 {
quicksort_recursive(arr, low, pivot_index - 1);
high = pivot_index - 1;
}
quicksort_recursive(arr, pivot_index + 1, high);
low = pivot_index + 1;
}
}

Expand All @@ -45,7 +50,7 @@ where
{
let mut arr: [T; N] = _arr;
if arr.len() <= 1 {} else {
quicksort_recursive(&mut arr, 0, arr.len() - 1);
quicksort_loop(&mut arr, 0, arr.len() - 1);
}
arr
}
Expand Down
20 changes: 12 additions & 8 deletions src/quicksort/quicksort_explicit.nr
Original file line number Diff line number Diff line change
Expand Up @@ -28,29 +28,33 @@ unconstrained fn partition<T, let N: u32>(
i
}

unconstrained fn quicksort_recursive<T, let N: u32>(
unconstrained fn quicksort_loop<T, let N: u32>(
arr: &mut [T; N],
low: u32,
high: u32,
sortfn: unconstrained fn(T, T) -> bool,
) {
if low < high {
let mut low = low;
let mut high = high;
loop {
if low >= high {
break;
}
let pivot_index = partition(arr, low, high, sortfn);
if pivot_index > 0 {
quicksort_recursive(arr, low, pivot_index - 1, sortfn);
high = pivot_index - 1;
}
quicksort_recursive(arr, pivot_index + 1, high, sortfn);
low = pivot_index + 1;
}
}

pub unconstrained fn quicksort<T, let N: u32>(
_arr: [T; N],
arr: [T; N],
sortfn: unconstrained fn(T, T) -> bool,
) -> [T; N] {
let mut arr: [T; N] = _arr;
let mut arr: [T; N] = arr;
if arr.len() <= 1 {} else {
quicksort_recursive(&mut arr, 0, arr.len() - 1, sortfn);
quicksort_loop(&mut arr, 0, arr.len() - 1, sortfn);
}
arr
}

51 changes: 51 additions & 0 deletions src/quicksort/quicksort_recursive.nr
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
pub trait Swap {
fn swap(&mut self, i: u32, j: u32);
}

impl<T, let N: u32> Swap for [T; N] {
fn swap(&mut self, i: u32, j: u32) {
let temp = self[i];
self[i] = self[j];
self[j] = temp;
}
}

unconstrained fn partition<T, let N: u32>(arr: &mut [T; N], low: u32, high: u32) -> u32
where
T: std::cmp::Ord,
{
let pivot = high;
let mut i = low;
for j in low..high {
if (arr[j] < arr[pivot]) {
arr.swap(i, j);
i += 1;
}
}
arr.swap(i, pivot);
i
}

unconstrained fn quicksort_recursive<T, let N: u32>(arr: &mut [T; N], low: u32, high: u32)
where
T: std::cmp::Ord,
{
if low < high {
let pivot_index = partition(arr, low, high);
if pivot_index > 0 {
quicksort_recursive(arr, low, pivot_index - 1);
}
quicksort_recursive(arr, pivot_index + 1, high);
}
}

pub unconstrained fn quicksort<T, let N: u32>(_arr: [T; N]) -> [T; N]
where
T: std::cmp::Ord,
{
let mut arr: [T; N] = _arr;
if arr.len() <= 1 {} else {
quicksort_recursive(&mut arr, 0, arr.len() - 1);
}
arr
}
56 changes: 56 additions & 0 deletions src/quicksort/quicksort_recursive_explicit.nr
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
pub trait Swap {
fn swap(&mut self, i: u32, j: u32);
}

impl<T, let N: u32> Swap for [T; N] {
fn swap(&mut self, i: u32, j: u32) {
let temp = self[i];
self[i] = self[j];
self[j] = temp;
}
}

unconstrained fn partition<T, let N: u32>(
arr: &mut [T; N],
low: u32,
high: u32,
sortfn: unconstrained fn(T, T) -> bool,
) -> u32 {
let pivot = high;
let mut i = low;
for j in low..high {
if (sortfn(arr[j], arr[pivot])) {
arr.swap(i, j);
i += 1;
}
}
arr.swap(i, pivot);
i
}

unconstrained fn quicksort_recursive<T, let N: u32>(
arr: &mut [T; N],
low: u32,
high: u32,
sortfn: unconstrained fn(T, T) -> bool,
) {
if low < high {
let pivot_index = partition(arr, low, high, sortfn);
if pivot_index > 0 {
quicksort_recursive(arr, low, pivot_index - 1, sortfn);
}
quicksort_recursive(arr, pivot_index + 1, high, sortfn);
}
}

pub unconstrained fn quicksort<T, let N: u32>(
_arr: [T; N],
sortfn: unconstrained fn(T, T) -> bool,
) -> [T; N] {
let mut arr: [T; N] = _arr;
if arr.len() <= 1 {} else {
quicksort_recursive(&mut arr, 0, arr.len() - 1, sortfn);
}
arr
}

Loading