Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion quaddtype/numpy_quaddtype/src/scalar_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,11 +228,23 @@ QuadPrecision_int(QuadPrecisionObject *self)
}
}

template <binary_op_quad_def sleef_op, binary_op_longdouble_def longdouble_op>
static PyObject *
quad_ternary_power_func(PyObject *op1, PyObject *op2, PyObject *mod)
{
if (mod != Py_None) {
PyErr_SetString(PyExc_TypeError,
"pow() 3rd argument not allowed unless all arguments are integers");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The way the if statement above is written, the bit starting with "unless" is incorrect unless I'm missing something. Not sure if you meant to implement what's in the error or if the error is wrong.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cpython's pow function take 3 args but the last arg is only valid if the inputs are integers, for float values it passes the py_none there.
quaddtype is floating-point so we anyways don't need that argument and if in case somebody explicitly called pow with a 3rd argument (mod) then it'll be a mistake and error out

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the standard behaviour so we are in this PR implementing the same

In [5]: pow(2.0, 1,2)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[5], line 1
----> 1 pow(2.0, 1,2)

TypeError: pow() 3rd argument not allowed unless all arguments are integers

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In quaddtype it should be as

In [2]: a = QuadPrecision("1")

In [3]: pow(a, 2)
Out[3]: QuadPrecision('1.0e+000', backend='sleef')

In [4]: pow(a, 2, 1)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[4], line 1
----> 1 pow(a, 2, 1)

TypeError: pow() 3rd argument not allowed unless all arguments are integers

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for explaining.

return NULL;
}
return quad_binary_func<sleef_op, longdouble_op>(op1, op2);
}

PyNumberMethods quad_as_scalar = {
.nb_add = (binaryfunc)quad_binary_func<quad_add, ld_add>,
.nb_subtract = (binaryfunc)quad_binary_func<quad_sub, ld_sub>,
.nb_multiply = (binaryfunc)quad_binary_func<quad_mul, ld_mul>,
.nb_power = (ternaryfunc)quad_binary_func<quad_pow, ld_pow>,
.nb_power = (ternaryfunc)quad_ternary_power_func<quad_pow, ld_pow>,
.nb_negative = (unaryfunc)quad_unary_func<quad_negative, ld_negative>,
.nb_positive = (unaryfunc)quad_unary_func<quad_positive, ld_positive>,
.nb_absolute = (unaryfunc)quad_unary_func<quad_absolute, ld_absolute>,
Expand Down
11 changes: 10 additions & 1 deletion quaddtype/tests/test_quaddtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -5411,4 +5411,13 @@ def test_float_to_quad_sign_preserve(dtype, val):
if np.isnan(val):
assert np.isnan(result), f"NaN failed for {dtype}"
else:
assert result == val, f"{val} failed for {dtype}"
assert result == val, f"{val} failed for {dtype}"

@pytest.mark.parametrize("val, pow", [(2, 112), (2, -112), (10, 34), (10, -34)])
def test_quadprecision_large_exponents(val, pow):
mp.prec = 113
mp_value = mp.mpf(val) ** pow
value = QuadPrecision(val) ** pow
value_str = mp.nstr(mp.mpf(str(value)), 33)
expected_str = mp.nstr(mp_value, 33)
assert value_str == expected_str, f"QuadPrecision({val}) ** {pow} = {value_str}, expected {expected_str}"