Skip to content

Commit ef1494d

Browse files
committed
Fix comparisons of arrays containing NaN
1 parent f0c2497 commit ef1494d

File tree

3 files changed

+53
-4
lines changed

3 files changed

+53
-4
lines changed

graalpython/com.oracle.graal.python.test/src/tests/unittest_tags/test_array.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@
9292
*graalpython.lib-python.3.test.test_array.DoubleTest.test_iterator_pickle
9393
*graalpython.lib-python.3.test.test_array.DoubleTest.test_len
9494
*graalpython.lib-python.3.test.test_array.DoubleTest.test_mul
95+
*graalpython.lib-python.3.test.test_array.DoubleTest.test_nan
9596
*graalpython.lib-python.3.test.test_array.DoubleTest.test_obsolete_write_lock
9697
*graalpython.lib-python.3.test.test_array.DoubleTest.test_pickle
9798
*graalpython.lib-python.3.test.test_array.DoubleTest.test_pickle_for_empty_array
@@ -147,6 +148,7 @@
147148
*graalpython.lib-python.3.test.test_array.FloatTest.test_iterator_pickle
148149
*graalpython.lib-python.3.test.test_array.FloatTest.test_len
149150
*graalpython.lib-python.3.test.test_array.FloatTest.test_mul
151+
*graalpython.lib-python.3.test.test_array.FloatTest.test_nan
150152
*graalpython.lib-python.3.test.test_array.FloatTest.test_obsolete_write_lock
151153
*graalpython.lib-python.3.test.test_array.FloatTest.test_pickle
152154
*graalpython.lib-python.3.test.test_array.FloatTest.test_pickle_for_empty_array

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/array/ArrayBuiltins.java

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@
9898
import com.oracle.truffle.api.dsl.Cached;
9999
import com.oracle.truffle.api.dsl.Fallback;
100100
import com.oracle.truffle.api.dsl.GenerateNodeFactory;
101+
import com.oracle.truffle.api.dsl.ImportStatic;
101102
import com.oracle.truffle.api.dsl.NodeFactory;
102103
import com.oracle.truffle.api.dsl.Specialization;
103104
import com.oracle.truffle.api.frame.VirtualFrame;
@@ -221,9 +222,10 @@ protected ArgumentClinicProvider getArgumentClinic() {
221222

222223
@Builtin(name = __EQ__, minNumOfPositionalArgs = 2)
223224
@GenerateNodeFactory
225+
@ImportStatic(BufferFormat.class)
224226
abstract static class EqNode extends PythonBinaryBuiltinNode {
225227

226-
@Specialization(guards = "shouldCompareBytes(left, right)")
228+
@Specialization(guards = {"left.getFormat() == right.getFormat()", "!isFloatingPoint(left.getFormat())"})
227229
static boolean eqBytes(PArray left, PArray right) {
228230
if (left.getLength() != right.getLength()) {
229231
return false;
@@ -237,7 +239,7 @@ static boolean eqBytes(PArray left, PArray right) {
237239
return true;
238240
}
239241

240-
@Specialization(guards = "!shouldCompareBytes(left, right)")
242+
@Specialization(guards = "left.getFormat() != right.getFormat()")
241243
static boolean eqItems(PArray left, PArray right,
242244
@CachedLibrary(limit = "4") PythonObjectLibrary lib,
243245
@Cached ArrayNodes.GetValueNode getLeft,
@@ -253,21 +255,44 @@ static boolean eqItems(PArray left, PArray right,
253255
return true;
254256
}
255257

258+
// Separate specialization for float/double is needed because of NaN comparisons
259+
@Specialization(guards = {"left.getFormat() == right.getFormat()", "isFloatingPoint(left.getFormat())"})
260+
static boolean eqDoubles(PArray left, PArray right,
261+
@Cached ArrayNodes.GetValueNode getLeft,
262+
@Cached ArrayNodes.GetValueNode getRight) {
263+
if (left.getLength() != right.getLength()) {
264+
return false;
265+
}
266+
for (int i = 0; i < left.getLength(); i++) {
267+
double leftValue = (Double) getLeft.execute(left, i);
268+
double rightValue = (Double) getRight.execute(right, i);
269+
if (leftValue != rightValue) {
270+
return false;
271+
}
272+
}
273+
return true;
274+
}
275+
256276
@Specialization(guards = "!isArray(right)")
257277
@SuppressWarnings("unused")
258278
static Object eq(PArray left, Object right) {
259279
return PNotImplemented.NOT_IMPLEMENTED;
260280
}
261281

282+
protected static boolean shouldCompareDouble(PArray left, PArray right) {
283+
return left.getFormat() == right.getFormat() && (left.getFormat() == BufferFormat.DOUBLE || left.getFormat() == BufferFormat.FLOAT);
284+
}
285+
262286
protected static boolean shouldCompareBytes(PArray left, PArray right) {
263287
return left.getFormat() == right.getFormat() && left.getFormat() != BufferFormat.DOUBLE && left.getFormat() != BufferFormat.FLOAT;
264288
}
265289
}
266290

291+
@ImportStatic(BufferFormat.class)
267292
abstract static class AbstractComparisonNode extends PythonBinaryBuiltinNode {
268293

269-
@Specialization
270-
boolean cmp(VirtualFrame frame, PArray left, PArray right,
294+
@Specialization(guards = "!isFloatingPoint(left.getFormat()) || (left.getFormat() != right.getFormat())")
295+
boolean cmpItems(VirtualFrame frame, PArray left, PArray right,
271296
@CachedLibrary(limit = "4") PythonObjectLibrary lib,
272297
@Cached("createComparison()") BinaryComparisonNode compareNode,
273298
@Cached("createIfTrueNode()") CoerceToBooleanNode coerceToBooleanNode,
@@ -284,6 +309,24 @@ boolean cmp(VirtualFrame frame, PArray left, PArray right,
284309
return compareLengths(left.getLength(), right.getLength());
285310
}
286311

312+
// Separate specialization for float/double is needed because of NaN comparisons
313+
@Specialization(guards = {"isFloatingPoint(left.getFormat())", "left.getFormat() == right.getFormat()"})
314+
boolean cmpDoubles(VirtualFrame frame, PArray left, PArray right,
315+
@Cached("createComparison()") BinaryComparisonNode compareNode,
316+
@Cached("createIfTrueNode()") CoerceToBooleanNode coerceToBooleanNode,
317+
@Cached ArrayNodes.GetValueNode getLeft,
318+
@Cached ArrayNodes.GetValueNode getRight) {
319+
int commonLength = Math.min(left.getLength(), right.getLength());
320+
for (int i = 0; i < commonLength; i++) {
321+
double leftValue = (Double) getLeft.execute(left, i);
322+
double rightValue = (Double) getRight.execute(right, i);
323+
if (leftValue != rightValue) {
324+
return coerceToBooleanNode.executeBoolean(frame, compareNode.executeWith(frame, leftValue, rightValue));
325+
}
326+
}
327+
return compareLengths(left.getLength(), right.getLength());
328+
}
329+
287330
@Specialization(guards = "!isArray(right)")
288331
@SuppressWarnings("unused")
289332
static Object cmp(PArray left, Object right) {

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/util/BufferFormat.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,4 +135,8 @@ private static BufferFormat fromCharCommon(char fmtchar) {
135135
return null;
136136
}
137137

138+
public static boolean isFloatingPoint(BufferFormat format) {
139+
return format == FLOAT || format == DOUBLE;
140+
}
141+
138142
}

0 commit comments

Comments
 (0)