98
98
import com .oracle .truffle .api .dsl .Cached ;
99
99
import com .oracle .truffle .api .dsl .Fallback ;
100
100
import com .oracle .truffle .api .dsl .GenerateNodeFactory ;
101
+ import com .oracle .truffle .api .dsl .ImportStatic ;
101
102
import com .oracle .truffle .api .dsl .NodeFactory ;
102
103
import com .oracle .truffle .api .dsl .Specialization ;
103
104
import com .oracle .truffle .api .frame .VirtualFrame ;
@@ -221,9 +222,10 @@ protected ArgumentClinicProvider getArgumentClinic() {
221
222
222
223
@ Builtin (name = __EQ__ , minNumOfPositionalArgs = 2 )
223
224
@ GenerateNodeFactory
225
+ @ ImportStatic (BufferFormat .class )
224
226
abstract static class EqNode extends PythonBinaryBuiltinNode {
225
227
226
- @ Specialization (guards = "shouldCompareBytes( left, right)" )
228
+ @ Specialization (guards = { " left.getFormat() == right.getFormat()" , "!isFloatingPoint(left.getFormat())" } )
227
229
static boolean eqBytes (PArray left , PArray right ) {
228
230
if (left .getLength () != right .getLength ()) {
229
231
return false ;
@@ -237,7 +239,7 @@ static boolean eqBytes(PArray left, PArray right) {
237
239
return true ;
238
240
}
239
241
240
- @ Specialization (guards = "!shouldCompareBytes( left, right)" )
242
+ @ Specialization (guards = "left.getFormat() != right.getFormat( )" )
241
243
static boolean eqItems (PArray left , PArray right ,
242
244
@ CachedLibrary (limit = "4" ) PythonObjectLibrary lib ,
243
245
@ Cached ArrayNodes .GetValueNode getLeft ,
@@ -253,21 +255,44 @@ static boolean eqItems(PArray left, PArray right,
253
255
return true ;
254
256
}
255
257
258
+ // Separate specialization for float/double is needed because of NaN comparisons
259
+ @ Specialization (guards = {"left.getFormat() == right.getFormat()" , "isFloatingPoint(left.getFormat())" })
260
+ static boolean eqDoubles (PArray left , PArray right ,
261
+ @ Cached ArrayNodes .GetValueNode getLeft ,
262
+ @ Cached ArrayNodes .GetValueNode getRight ) {
263
+ if (left .getLength () != right .getLength ()) {
264
+ return false ;
265
+ }
266
+ for (int i = 0 ; i < left .getLength (); i ++) {
267
+ double leftValue = (Double ) getLeft .execute (left , i );
268
+ double rightValue = (Double ) getRight .execute (right , i );
269
+ if (leftValue != rightValue ) {
270
+ return false ;
271
+ }
272
+ }
273
+ return true ;
274
+ }
275
+
256
276
@ Specialization (guards = "!isArray(right)" )
257
277
@ SuppressWarnings ("unused" )
258
278
static Object eq (PArray left , Object right ) {
259
279
return PNotImplemented .NOT_IMPLEMENTED ;
260
280
}
261
281
282
+ protected static boolean shouldCompareDouble (PArray left , PArray right ) {
283
+ return left .getFormat () == right .getFormat () && (left .getFormat () == BufferFormat .DOUBLE || left .getFormat () == BufferFormat .FLOAT );
284
+ }
285
+
262
286
protected static boolean shouldCompareBytes (PArray left , PArray right ) {
263
287
return left .getFormat () == right .getFormat () && left .getFormat () != BufferFormat .DOUBLE && left .getFormat () != BufferFormat .FLOAT ;
264
288
}
265
289
}
266
290
291
+ @ ImportStatic (BufferFormat .class )
267
292
abstract static class AbstractComparisonNode extends PythonBinaryBuiltinNode {
268
293
269
- @ Specialization
270
- boolean cmp (VirtualFrame frame , PArray left , PArray right ,
294
+ @ Specialization ( guards = "!isFloatingPoint(left.getFormat()) || (left.getFormat() != right.getFormat())" )
295
+ boolean cmpItems (VirtualFrame frame , PArray left , PArray right ,
271
296
@ CachedLibrary (limit = "4" ) PythonObjectLibrary lib ,
272
297
@ Cached ("createComparison()" ) BinaryComparisonNode compareNode ,
273
298
@ Cached ("createIfTrueNode()" ) CoerceToBooleanNode coerceToBooleanNode ,
@@ -284,6 +309,24 @@ boolean cmp(VirtualFrame frame, PArray left, PArray right,
284
309
return compareLengths (left .getLength (), right .getLength ());
285
310
}
286
311
312
+ // Separate specialization for float/double is needed because of NaN comparisons
313
+ @ Specialization (guards = {"isFloatingPoint(left.getFormat())" , "left.getFormat() == right.getFormat()" })
314
+ boolean cmpDoubles (VirtualFrame frame , PArray left , PArray right ,
315
+ @ Cached ("createComparison()" ) BinaryComparisonNode compareNode ,
316
+ @ Cached ("createIfTrueNode()" ) CoerceToBooleanNode coerceToBooleanNode ,
317
+ @ Cached ArrayNodes .GetValueNode getLeft ,
318
+ @ Cached ArrayNodes .GetValueNode getRight ) {
319
+ int commonLength = Math .min (left .getLength (), right .getLength ());
320
+ for (int i = 0 ; i < commonLength ; i ++) {
321
+ double leftValue = (Double ) getLeft .execute (left , i );
322
+ double rightValue = (Double ) getRight .execute (right , i );
323
+ if (leftValue != rightValue ) {
324
+ return coerceToBooleanNode .executeBoolean (frame , compareNode .executeWith (frame , leftValue , rightValue ));
325
+ }
326
+ }
327
+ return compareLengths (left .getLength (), right .getLength ());
328
+ }
329
+
287
330
@ Specialization (guards = "!isArray(right)" )
288
331
@ SuppressWarnings ("unused" )
289
332
static Object cmp (PArray left , Object right ) {
0 commit comments