Skip to content

Commit 410dbb3

Browse files
committed
Enforce zero-sensitive min/max
1 parent d306fb2 commit 410dbb3

File tree

2 files changed

+47
-13
lines changed

2 files changed

+47
-13
lines changed

quaddtype/numpy_quaddtype/src/ops.hpp

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -514,33 +514,53 @@ quad_mod(const Sleef_quad *a, const Sleef_quad *b)
514514
static inline Sleef_quad
515515
quad_minimum(const Sleef_quad *in1, const Sleef_quad *in2)
516516
{
517-
return Sleef_iunordq1(*in1, *in2) ? (Sleef_iunordq1(*in1, *in1) ? *in1 : *in2)
518-
: Sleef_icmpltq1(*in1, *in2) ? *in1
519-
: *in2;
517+
if (Sleef_iunordq1(*in1, *in2)) {
518+
return Sleef_iunordq1(*in1, *in1) ? *in1 : *in2;
519+
}
520+
// minimum(-0.0, +0.0) = -0.0
521+
if (Sleef_icmpeqq1(*in1, QUAD_ZERO) && Sleef_icmpeqq1(*in2, QUAD_ZERO)) {
522+
return Sleef_icmpleq1(Sleef_copysignq1(QUAD_ONE, *in1), Sleef_copysignq1(QUAD_ONE, *in2)) ? *in1 : *in2;
523+
}
524+
return Sleef_fminq1(*in1, *in2);
520525
}
521526

522527
static inline Sleef_quad
523528
quad_maximum(const Sleef_quad *in1, const Sleef_quad *in2)
524529
{
525-
return Sleef_iunordq1(*in1, *in2) ? (Sleef_iunordq1(*in1, *in1) ? *in1 : *in2)
526-
: Sleef_icmpgtq1(*in1, *in2) ? *in1
527-
: *in2;
530+
if (Sleef_iunordq1(*in1, *in2)) {
531+
return Sleef_iunordq1(*in1, *in1) ? *in1 : *in2;
532+
}
533+
// maximum(-0.0, +0.0) = +0.0
534+
if (Sleef_icmpeqq1(*in1, QUAD_ZERO) && Sleef_icmpeqq1(*in2, QUAD_ZERO)) {
535+
return Sleef_icmpgeq1(Sleef_copysignq1(QUAD_ONE, *in1), Sleef_copysignq1(QUAD_ONE, *in2)) ? *in1 : *in2;
536+
}
537+
return Sleef_fmaxq1(*in1, *in2);
528538
}
529539

530540
static inline Sleef_quad
531541
quad_fmin(const Sleef_quad *in1, const Sleef_quad *in2)
532542
{
533-
return Sleef_iunordq1(*in1, *in2) ? (Sleef_iunordq1(*in2, *in2) ? *in1 : *in2)
534-
: Sleef_icmpleq1(*in1, *in2) ? *in1
535-
: *in2;
543+
if (Sleef_iunordq1(*in1, *in2)) {
544+
return Sleef_iunordq1(*in2, *in2) ? *in1 : *in2;
545+
}
546+
// fmin(-0.0, +0.0) = -0.0
547+
if (Sleef_icmpeqq1(*in1, QUAD_ZERO) && Sleef_icmpeqq1(*in2, QUAD_ZERO)) {
548+
return Sleef_icmpleq1(Sleef_copysignq1(QUAD_ONE, *in1), Sleef_copysignq1(QUAD_ONE, *in2)) ? *in1 : *in2;
549+
}
550+
return Sleef_fminq1(*in1, *in2);
536551
}
537552

538553
static inline Sleef_quad
539554
quad_fmax(const Sleef_quad *in1, const Sleef_quad *in2)
540555
{
541-
return Sleef_iunordq1(*in1, *in2) ? (Sleef_iunordq1(*in2, *in2) ? *in1 : *in2)
542-
: Sleef_icmpgeq1(*in1, *in2) ? *in1
543-
: *in2;
556+
if (Sleef_iunordq1(*in1, *in2)) {
557+
return Sleef_iunordq1(*in2, *in2) ? *in1 : *in2;
558+
}
559+
// maximum(-0.0, +0.0) = +0.0
560+
if (Sleef_icmpeqq1(*in1, QUAD_ZERO) && Sleef_icmpeqq1(*in2, QUAD_ZERO)) {
561+
return Sleef_icmpgeq1(Sleef_copysignq1(QUAD_ONE, *in1), Sleef_copysignq1(QUAD_ONE, *in2)) ? *in1 : *in2;
562+
}
563+
return Sleef_fmaxq1(*in1, *in2);
544564
}
545565

546566
static inline Sleef_quad

quaddtype/tests/test_quaddtype.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,13 @@ def test_array_minmax(op, a, b):
9797
quad_res = op_func(quad_a, quad_b)
9898
float_res = op_func(float_a, float_b)
9999

100+
# native implementation may not be sensitive to zero signs
101+
# but we want to enforce it for the quad dtype
102+
# e.g. min(+0.0, -0.0) = -0.0
103+
if float_a == 0.0 and float_b == 0.0:
104+
assert float_res == 0.0
105+
float_res = np.copysign(0.0, op_fun(np.copysign(1.0, float_a), np.copysign(1.0, float_b)))
106+
100107
np.testing.assert_array_equal(quad_res.astype(float), float_res)
101108

102109
# Check sign for zero results
@@ -116,6 +123,13 @@ def test_array_aminmax(op, a, b):
116123
quad_res = op_func(quad_ab)
117124
float_res = op_func(float_ab)
118125

126+
# native implementation may not be sensitive to zero signs
127+
# but we want to enforce it for the quad dtype
128+
# e.g. min(+0.0, -0.0) = -0.0
129+
if float(a) == 0.0 and float(b) == 0.0:
130+
assert float_res == 0.0
131+
float_res = np.copysign(0.0, op_fun(np.array([np.copysign(1.0, float(a)), np.copysign(1.0, float(b))])))
132+
119133
np.testing.assert_array_equal(np.array(quad_res).astype(float), float_res)
120134

121135
# Check sign for zero results
@@ -490,7 +504,7 @@ def test_mod(a, b, backend, op):
490504
"0.9", "-0.9", "0.9999", "-0.9999",
491505
"1.1", "-1.1", "1.0001", "-1.0001",
492506
# Medium values
493-
"10.0", "-10.0", "20.0", "-20.0",
507+
"10.0", "-10.0", "20.0x", "-20.0",
494508
# Large values
495509
"100.0", "200.0", "700.0", "1000.0",
496510
"-100.0", "-200.0", "-700.0", "-1000.0",

0 commit comments

Comments
 (0)