Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions .github/workflows/build_docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@ on:
push:
branches:
- main
pull_request:
paths:
- 'docs/**'
- 'src/**'
- '.github/workflows/build_docs.yml'
workflow_dispatch: # Make sure this job can be triggered manually
# pull_request:
# paths:
# - 'docs/**'
# - 'src/**'
# - '.github/workflows/build_docs.yml'
workflow_dispatch: # Make sure this job can be triggered manually

jobs:
build:
Expand All @@ -23,7 +23,7 @@ jobs:
- uses: actions/configure-pages@v5
- uses: actions/setup-python@v6
with:
python-version: '3.13'
python-version: "3.13"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
Expand All @@ -39,7 +39,7 @@ jobs:
- name: Upload artifact
uses: actions/upload-pages-artifact@v4
with:
path: './docs/_build/html'
path: "./docs/_build/html"
- name: Deploy to GitHub Pages
id: deployment
uses: actions/deploy-pages@v4
67 changes: 67 additions & 0 deletions src/csrc/dtype.c
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <sleef.h>
#include <sleefquad.h>
#include <ctype.h>
#include <math.h>

#define PY_ARRAY_UNIQUE_SYMBOL QuadPrecType_ARRAY_API
#define PY_UFUNC_UNIQUE_SYMBOL QuadPrecType_UFUNC_API
Expand Down Expand Up @@ -389,6 +390,71 @@ quadprec_fromstr(char *s, void *dptr, char **endptr, PyArray_Descr *descr_generi
return 0;
}

/*
* Compare function for sorting operations (argsort, sort, etc.)
* Implements PyArray_CompareFunc.
* Returns: negative if a < b, positive if a > b, 0 if equal
*/
static int
quadprec_compare(void *a, void *b, void *arr)
{
PyArrayObject *array = (PyArrayObject *)arr;
QuadPrecDTypeObject *descr = (QuadPrecDTypeObject *)PyArray_DESCR(array);

if (descr->backend == BACKEND_SLEEF) {
Sleef_quad val_a = *(Sleef_quad *)a;
Sleef_quad val_b = *(Sleef_quad *)b;

// NaN is considered greater than all other values for sorting
int a_is_nan = Sleef_iunordq1(val_a, val_a);
int b_is_nan = Sleef_iunordq1(val_b, val_b);

if (a_is_nan && b_is_nan) {
return 0;
}
if (a_is_nan) {
return 1; /* NaN goes to the end */
}
if (b_is_nan) {
return -1;
}

if (Sleef_icmpltq1(val_a, val_b)) {
return -1;
}
if (Sleef_icmpgtq1(val_a, val_b)) {
return 1;
}
return 0;
}
else {
long double val_a = *(long double *)a;
long double val_b = *(long double *)b;

// NaN is considered greater than all other values for sorting
int a_is_nan = isnan(val_a);
int b_is_nan = isnan(val_b);

if (a_is_nan && b_is_nan) {
return 0;
}
if (a_is_nan) {
return 1; /* NaN goes to the end */
}
if (b_is_nan) {
return -1;
}

if (val_a < val_b) {
return -1;
}
if (val_a > val_b) {
return 1;
}
return 0;
}
}

static PyType_Slot QuadPrecDType_Slots[] = {
{NPY_DT_ensure_canonical, &ensure_canonical},
{NPY_DT_common_instance, &common_instance},
Expand All @@ -398,6 +464,7 @@ static PyType_Slot QuadPrecDType_Slots[] = {
{NPY_DT_getitem, &quadprec_getitem},
{NPY_DT_default_descr, &quadprec_default_descr},
{NPY_DT_get_constant, &quadprec_get_constant},
{NPY_DT_PyArray_ArrFuncs_compare, &quadprec_compare},

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note, as in the main comment this will be set for deprecation soon, technically is already an older version of the API.

{NPY_DT_PyArray_ArrFuncs_fill, &quadprec_fill},
{NPY_DT_PyArray_ArrFuncs_scanfunc, &quadprec_scanfunc},
{NPY_DT_PyArray_ArrFuncs_fromstr, &quadprec_fromstr},
Expand Down
57 changes: 57 additions & 0 deletions tests/test_quaddtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -5688,3 +5688,60 @@ def test_quadprecision_large_exponents(val, pow):
value_str = mp.nstr(mp.mpf(str(value)), 33)
expected_str = mp.nstr(mp_value, 33)
assert value_str == expected_str, f"QuadPrecision({val}) ** {pow} = {value_str}, expected {expected_str}"

class TestSortingOperations:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please add a newline above

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

"""Test suite for sorting operations using the compare slot."""

@pytest.mark.parametrize("backend", ["sleef", "longdouble"])
@pytest.mark.parametrize("input_arr,expected_sorted,expected_indices", [
# Basic integers: [3, 1, 4, 1, 5, 9, 2, 6] -> sorted: [1, 1, 2, 3, 4, 5, 6, 9]
([3, 1, 4, 1, 5, 9, 2, 6], [1, 1, 2, 3, 4, 5, 6, 9], [1, 3, 6, 0, 2, 4, 7, 5]),
# Negative numbers
([3, -1, 4, -5, 2, -6], [-6, -5, -1, 2, 3, 4], [5, 3, 1, 4, 0, 2]),
# Floating-point values
([3.14, 1.41, 2.72, 0.57], [0.57, 1.41, 2.72, 3.14], [3, 1, 2, 0]),
# Single element
([42], [42], [0]),
# Empty array
([], [], []),
])
def test_sort_and_argsort(self, backend, input_arr, expected_sorted, expected_indices):
"""Test sort and argsort with various inputs."""
x = np.array(input_arr, dtype=QuadPrecDType(backend=backend))

# Test sort
sorted_x = np.sort(x)
expected = np.array(expected_sorted, dtype=QuadPrecDType(backend=backend))
np.testing.assert_array_equal(sorted_x, expected)

# Test argsort
indices = np.argsort(x)
np.testing.assert_array_equal(indices, np.array(expected_indices))
# Verify argsort returns integer type
assert np.issubdtype(indices.dtype, np.integer)

@pytest.mark.parametrize("backend", ["sleef", "longdouble"])
@pytest.mark.parametrize("input_arr,check_fn", [
# NaN should be sorted to the end
([3, float('nan'), 1, 2], lambda s: s[0] == 1 and s[1] == 2 and s[2] == 3 and np.isnan(s[3])),
# Inf handling
([3, float('inf'), 1, float('-inf'), 2],
lambda s: s[0] == float('-inf') and s[1] == 1 and s[2] == 2 and s[3] == 3 and s[4] == float('inf')),
# Multiple NaNs
([float('nan'), 1, float('nan'), 2], lambda s: s[0] == 1 and s[1] == 2 and np.isnan(s[2]) and np.isnan(s[3])),
])
def test_sort_special_values(self, backend, input_arr, check_fn):
"""Test sorting with NaN and Inf values."""
x = np.array(input_arr, dtype=QuadPrecDType(backend=backend))
sorted_x = np.sort(x)
assert check_fn(sorted_x)

@pytest.mark.parametrize("kind", ["quicksort", "mergesort", "heapsort", "stable"])
@pytest.mark.parametrize("backend", ["sleef", "longdouble"])
def test_sort_algorithms(self, backend, kind):
"""Test that different sorting algorithms work with the compare function."""
x = np.array([5, 2, 8, 1, 9, 3], dtype=QuadPrecDType(backend=backend))
sorted_x = np.sort(x, kind=kind)
expected = np.array([1, 2, 3, 5, 8, 9], dtype=QuadPrecDType(backend=backend))
np.testing.assert_array_equal(sorted_x, expected)