Skip to content

Commit ec3f169

Browse files
committed
[GR-20983] Fix crash when binary comparison methods don't return boolean.
PullRequest: graalpython/808
2 parents e7cef7b + 639e34c commit ec3f169

File tree

10 files changed

+185
-172
lines changed

10 files changed

+185
-172
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/BuiltinFunctions.java

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1227,6 +1227,8 @@ public int len(VirtualFrame frame, Object obj,
12271227

12281228
public abstract static class MinMaxNode extends PythonBuiltinNode {
12291229

1230+
@CompilationFinal private boolean seenNonBoolean = false;
1231+
12301232
protected final BinaryComparisonNode createComparison() {
12311233
if (this instanceof MaxNode) {
12321234
return BinaryComparisonNode.create(SpecialMethodNames.__GT__, SpecialMethodNames.__LT__, ">");
@@ -1240,16 +1242,18 @@ Object maxSequence(VirtualFrame frame, PythonObject arg1, Object[] args, @Suppre
12401242
@Cached("create()") GetIteratorNode getIterator,
12411243
@Cached("create()") GetNextNode next,
12421244
@Cached("createComparison()") BinaryComparisonNode compare,
1245+
@Shared("castToBooleanNode") @Cached("createIfTrueNode()") CoerceToBooleanNode castToBooleanNode,
12431246
@Cached("create()") IsBuiltinClassProfile errorProfile1,
12441247
@Cached("create()") IsBuiltinClassProfile errorProfile2) {
1245-
return minmaxSequenceWithKey(frame, arg1, args, null, getIterator, next, compare, null, errorProfile1, errorProfile2);
1248+
return minmaxSequenceWithKey(frame, arg1, args, null, getIterator, next, compare, castToBooleanNode, null, errorProfile1, errorProfile2);
12461249
}
12471250

12481251
@Specialization(guards = "args.length == 0")
12491252
Object minmaxSequenceWithKey(VirtualFrame frame, PythonObject arg1, @SuppressWarnings("unused") Object[] args, PythonObject keywordArg,
12501253
@Cached("create()") GetIteratorNode getIterator,
12511254
@Cached("create()") GetNextNode next,
12521255
@Cached("createComparison()") BinaryComparisonNode compare,
1256+
@Shared("castToBooleanNode") @Cached("createIfTrueNode()") CoerceToBooleanNode castToBooleanNode,
12531257
@Cached("create()") CallNode keyCall,
12541258
@Cached("create()") IsBuiltinClassProfile errorProfile1,
12551259
@Cached("create()") IsBuiltinClassProfile errorProfile2) {
@@ -1271,7 +1275,19 @@ Object minmaxSequenceWithKey(VirtualFrame frame, PythonObject arg1, @SuppressWar
12711275
break;
12721276
}
12731277
Object nextKey = applyKeyFunction(frame, keywordArg, keyCall, nextValue);
1274-
if (compare.executeBool(frame, nextKey, currentKey)) {
1278+
boolean isTrue;
1279+
if (!seenNonBoolean) {
1280+
try {
1281+
isTrue = compare.executeBool(frame, nextKey, currentKey);
1282+
} catch (UnexpectedResultException e) {
1283+
CompilerDirectives.transferToInterpreterAndInvalidate();
1284+
seenNonBoolean = true;
1285+
isTrue = castToBooleanNode.executeBoolean(frame, e.getResult());
1286+
}
1287+
} else {
1288+
isTrue = castToBooleanNode.executeBoolean(frame, compare.executeWith(frame, nextKey, currentKey));
1289+
}
1290+
if (isTrue) {
12751291
currentKey = nextKey;
12761292
currentValue = nextValue;
12771293
}
@@ -1297,15 +1313,34 @@ Object minmaxBinaryWithKey(VirtualFrame frame, Object arg1, Object[] args, Pytho
12971313
Object currentKey = applyKeyFunction(frame, keywordArg, keyCall, currentValue);
12981314
Object nextValue = args[0];
12991315
Object nextKey = applyKeyFunction(frame, keywordArg, keyCall, nextValue);
1300-
if (castToBooleanNode.executeBoolean(frame, compare.executeWith(frame, nextKey, currentKey))) {
1316+
boolean isTrue;
1317+
try {
1318+
isTrue = compare.executeBool(frame, nextKey, currentKey);
1319+
} catch (UnexpectedResultException e) {
1320+
CompilerDirectives.transferToInterpreterAndInvalidate();
1321+
seenNonBoolean = true;
1322+
isTrue = castToBooleanNode.executeBoolean(frame, e.getResult());
1323+
}
1324+
if (isTrue) {
13011325
currentKey = nextKey;
13021326
currentValue = nextValue;
13031327
}
13041328
if (moreThanTwo.profile(args.length > 1)) {
13051329
for (int i = 0; i < args.length; i++) {
13061330
nextValue = args[i];
13071331
nextKey = applyKeyFunction(frame, keywordArg, keyCall, nextValue);
1308-
if (compare.executeBool(frame, nextKey, currentKey)) {
1332+
if (!seenNonBoolean) {
1333+
try {
1334+
isTrue = compare.executeBool(frame, nextKey, currentKey);
1335+
} catch (UnexpectedResultException e) {
1336+
CompilerDirectives.transferToInterpreterAndInvalidate();
1337+
seenNonBoolean = true;
1338+
isTrue = castToBooleanNode.executeBoolean(frame, e.getResult());
1339+
}
1340+
} else {
1341+
isTrue = castToBooleanNode.executeBoolean(frame, compare.executeWith(frame, nextKey, currentKey));
1342+
}
1343+
if (isTrue) {
13091344
currentKey = nextKey;
13101345
currentValue = nextValue;
13111346
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/OperatorModuleBuiltins.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,15 +175,15 @@ public boolean doString(String value1, String value2) {
175175
private @Child BinaryComparisonNode equalsNode;
176176

177177
@Fallback
178-
public boolean doObject(VirtualFrame frame, Object value1, Object value2) {
178+
public Object doObject(VirtualFrame frame, Object value1, Object value2) {
179179
if (value1 == value2) {
180180
return true;
181181
}
182182
if (equalsNode == null) {
183183
CompilerDirectives.transferToInterpreterAndInvalidate();
184184
equalsNode = insert((BinaryComparisonNode.create(SpecialMethodNames.__EQ__, SpecialMethodNames.__EQ__, "==")));
185185
}
186-
return equalsNode.executeBool(frame, value1, value2);
186+
return equalsNode.executeWith(frame, value1, value2);
187187
}
188188
}
189189

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/bytes/BytesBuiltins.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,6 @@
8080
import com.oracle.graal.python.nodes.call.special.LookupAndCallUnaryNode;
8181
import com.oracle.graal.python.nodes.control.GetIteratorExpressionNode.GetIteratorNode;
8282
import com.oracle.graal.python.nodes.control.GetNextNode;
83-
import com.oracle.graal.python.nodes.expression.BinaryComparisonNode;
8483
import com.oracle.graal.python.nodes.function.PythonBuiltinBaseNode;
8584
import com.oracle.graal.python.nodes.function.PythonBuiltinNode;
8685
import com.oracle.graal.python.nodes.function.builtins.PythonBinaryBuiltinNode;
@@ -671,16 +670,17 @@ private int getLength(SequenceStorage s) {
671670
@ImportStatic(SpecialMethodNames.class)
672671
public abstract static class ByteArrayCountNode extends PythonBinaryBuiltinNode {
673672

674-
@Specialization
673+
@Specialization(limit = "5")
675674
int count(VirtualFrame frame, PIBytesLike byteArray, Object arg,
675+
@CachedLibrary("arg") PythonObjectLibrary argLib,
676+
@CachedLibrary(limit = "1") PythonObjectLibrary otherLib,
676677
@Cached("createClassProfile()") ValueProfile storeProfile,
677-
@Cached("createNotNormalized()") SequenceStorageNodes.GetItemNode getItemNode,
678-
@Cached("create(__EQ__, __EQ__, __EQ__)") BinaryComparisonNode eqNode) {
678+
@Cached("createNotNormalized()") SequenceStorageNodes.GetItemNode getItemNode) {
679679

680680
SequenceStorage profiled = storeProfile.profile(byteArray.getSequenceStorage());
681681
int cnt = 0;
682682
for (int i = 0; i < profiled.length(); i++) {
683-
if (eqNode.executeBool(frame, arg, getItemNode.execute(frame, profiled, i))) {
683+
if (argLib.equals(arg, getItemNode.execute(frame, profiled, i), otherLib)) {
684684
cnt++;
685685
}
686686
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/foreign/ForeignObjectBuiltins.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -418,7 +418,7 @@ protected ForeignBinaryComparisonNode(BinaryComparisonNode genericOp) {
418418
Object doComparisonBool(VirtualFrame frame, Object left, Object right,
419419
@CachedLibrary(limit = "3") InteropLibrary lib) {
420420
try {
421-
return comparisonNode.executeBool(frame, lib.asBoolean(left), right);
421+
return comparisonNode.executeWith(frame, lib.asBoolean(left), right);
422422
} catch (UnsupportedMessageException e) {
423423
throw new IllegalStateException("object does not unpack to boolean for comparison as it claims to");
424424
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/list/ListBuiltins.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,6 @@
8181
import com.oracle.graal.python.nodes.builtins.ListNodes.IndexNode;
8282
import com.oracle.graal.python.nodes.call.special.LookupAndCallUnaryNode;
8383
import com.oracle.graal.python.nodes.control.GetIteratorExpressionNode.GetIteratorNode;
84-
import com.oracle.graal.python.nodes.expression.BinaryComparisonNode;
8584
import com.oracle.graal.python.nodes.function.PythonBuiltinBaseNode;
8685
import com.oracle.graal.python.nodes.function.PythonBuiltinNode;
8786
import com.oracle.graal.python.nodes.function.builtins.PythonBinaryBuiltinNode;
@@ -706,16 +705,17 @@ private int getLength(SequenceStorage s) {
706705
@GenerateNodeFactory
707706
public abstract static class ListCountNode extends PythonBuiltinNode {
708707

709-
@Specialization
708+
@Specialization(limit = "5")
710709
long count(VirtualFrame frame, PList self, Object value,
711710
@Cached("createNotNormalized()") SequenceStorageNodes.GetItemNode getItemNode,
712711
@Cached("create()") SequenceStorageNodes.LenNode lenNode,
713-
@Cached("create(__EQ__, __EQ__, __EQ__)") BinaryComparisonNode eqNode) {
712+
@CachedLibrary("value") PythonObjectLibrary valueLib,
713+
@CachedLibrary(limit = "16") PythonObjectLibrary otherLib) {
714714
long count = 0;
715715
SequenceStorage s = self.getSequenceStorage();
716716
for (int i = 0; i < lenNode.execute(s); i++) {
717717
Object object = getItemNode.execute(frame, s, i);
718-
if (eqNode.executeBool(frame, object, value)) {
718+
if (valueLib.equals(value, object, otherLib)) {
719719
count++;
720720
}
721721
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/object/ObjectBuiltins.java

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@
9292
import com.oracle.graal.python.nodes.object.GetLazyClassNode;
9393
import com.oracle.graal.python.nodes.object.IsBuiltinClassProfile;
9494
import com.oracle.truffle.api.CompilerDirectives;
95+
import com.oracle.truffle.api.CompilerDirectives.CompilationFinal;
9596
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
9697
import com.oracle.truffle.api.dsl.Cached;
9798
import com.oracle.truffle.api.dsl.Fallback;
@@ -102,6 +103,7 @@
102103
import com.oracle.truffle.api.frame.VirtualFrame;
103104
import com.oracle.truffle.api.interop.UnsupportedMessageException;
104105
import com.oracle.truffle.api.library.CachedLibrary;
106+
import com.oracle.truffle.api.nodes.UnexpectedResultException;
105107
import com.oracle.truffle.api.profiles.BranchProfile;
106108
import com.oracle.truffle.api.profiles.ConditionProfile;
107109

@@ -118,7 +120,6 @@ protected List<? extends NodeFactory<? extends PythonBuiltinBaseNode>> getNodeFa
118120
abstract static class ClassNode extends PythonBinaryBuiltinNode {
119121
@Child private LookupAttributeInMRONode lookupSlotsInSelf;
120122
@Child private LookupAttributeInMRONode lookupSlotsInOther;
121-
@Child private BinaryComparisonNode slotsAreEqual;
122123
@Child private TypeNodes.GetNameNode getTypeNameNode;
123124

124125
private static final String ERROR_MESSAGE = "__class__ assignment only supported for heap types or ModuleType subclasses";
@@ -142,6 +143,7 @@ LazyPythonClass setClass(@SuppressWarnings("unused") Object self, @SuppressWarni
142143
@Specialization
143144
PNone setClass(VirtualFrame frame, PythonObject self, PythonAbstractClass value,
144145
@CachedLibrary(limit = "2") PythonObjectLibrary lib,
146+
@CachedLibrary(limit = "2") PythonObjectLibrary slotsLib,
145147
@Cached("create()") BranchProfile errorValueBranch,
146148
@Cached("create()") BranchProfile errorSelfBranch,
147149
@Cached("create()") BranchProfile errorSlotsBranch,
@@ -158,7 +160,7 @@ PNone setClass(VirtualFrame frame, PythonObject self, PythonAbstractClass value,
158160
Object selfSlots = getLookupSlotsInSelf().execute(lazyClass);
159161
if (selfSlots != PNone.NO_VALUE) {
160162
Object otherSlots = getLookupSlotsInOther().execute(value);
161-
if (otherSlots == PNone.NO_VALUE || !getSlotsAreEqual().executeBool(frame, selfSlots, otherSlots)) {
163+
if (otherSlots == PNone.NO_VALUE || !slotsLib.equals(selfSlots, otherSlots, slotsLib)) {
162164
errorSlotsBranch.enter();
163165
throw raise(TypeError, "__class__ assignment: '%s' object layout differs from '%s'", getTypeName(value), getTypeName(lazyClass));
164166
}
@@ -167,14 +169,6 @@ PNone setClass(VirtualFrame frame, PythonObject self, PythonAbstractClass value,
167169
return PNone.NONE;
168170
}
169171

170-
private BinaryComparisonNode getSlotsAreEqual() {
171-
if (slotsAreEqual == null) {
172-
CompilerDirectives.transferToInterpreterAndInvalidate();
173-
slotsAreEqual = insert(BinaryComparisonNode.create(__EQ__, null, "=="));
174-
}
175-
return slotsAreEqual;
176-
}
177-
178172
private LookupAttributeInMRONode getLookupSlotsInSelf() {
179173
if (lookupSlotsInSelf == null) {
180174
CompilerDirectives.transferToInterpreterAndInvalidate();
@@ -591,6 +585,7 @@ Object formatFail(@SuppressWarnings("unused") Object self, @SuppressWarnings("un
591585
@GenerateNodeFactory
592586
abstract static class RichCompareNode extends PythonTernaryBuiltinNode {
593587
protected static final int NO_SLOW_PATH = Integer.MAX_VALUE;
588+
@CompilationFinal private boolean seenNonBoolean = false;
594589

595590
protected BinaryComparisonNode createOp(String op) {
596591
return (BinaryComparisonNode) PythonLanguage.getCurrent().getNodeFactory().createComparisonOperation(op, null, null);
@@ -599,8 +594,19 @@ protected BinaryComparisonNode createOp(String op) {
599594
@Specialization(guards = "op.equals(cachedOp)", limit = "NO_SLOW_PATH")
600595
boolean richcmp(VirtualFrame frame, Object left, Object right, @SuppressWarnings("unused") String op,
601596
@SuppressWarnings("unused") @Cached("op") String cachedOp,
602-
@Cached("createOp(op)") BinaryComparisonNode node) {
603-
return node.executeBool(frame, left, right);
597+
@Cached("createOp(op)") BinaryComparisonNode node,
598+
@Cached("createIfTrueNode()") CoerceToBooleanNode castToBooleanNode) {
599+
if (!seenNonBoolean) {
600+
try {
601+
return node.executeBool(frame, left, right);
602+
} catch (UnexpectedResultException e) {
603+
CompilerDirectives.transferToInterpreterAndInvalidate();
604+
seenNonBoolean = true;
605+
return castToBooleanNode.executeBoolean(frame, e.getResult());
606+
}
607+
} else {
608+
return castToBooleanNode.executeBoolean(frame, node.executeWith(frame, left, right));
609+
}
604610
}
605611
}
606612

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/referencetype/ReferenceTypeBuiltins.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -167,9 +167,9 @@ public String repr(PReferenceType self,
167167
@GenerateNodeFactory
168168
public abstract static class RefTypeEqNode extends PythonBuiltinNode {
169169
@Specialization(guards = {"self.getObject() != null", "other.getObject() != null"})
170-
boolean eq(VirtualFrame frame, PReferenceType self, PReferenceType other,
170+
Object eq(VirtualFrame frame, PReferenceType self, PReferenceType other,
171171
@Cached("create(__EQ__, __EQ__, __EQ__)") BinaryComparisonNode eqNode) {
172-
return eqNode.executeBool(frame, self.getObject(), other.getObject());
172+
return eqNode.executeWith(frame, self.getObject(), other.getObject());
173173
}
174174

175175
@Specialization(guards = "self.getObject() == null || other.getObject() == null")
@@ -183,9 +183,9 @@ boolean eq(PReferenceType self, PReferenceType other) {
183183
@GenerateNodeFactory
184184
public abstract static class RefTypeNeNode extends PythonBuiltinNode {
185185
@Specialization(guards = {"self.getObject() != null", "other.getObject() != null"})
186-
boolean ne(VirtualFrame frame, PReferenceType self, PReferenceType other,
186+
Object ne(VirtualFrame frame, PReferenceType self, PReferenceType other,
187187
@Cached("create(__NE__, __NE__, __NE__)") BinaryComparisonNode neNode) {
188-
return neNode.executeBool(frame, self.getObject(), other.getObject());
188+
return neNode.executeWith(frame, self.getObject(), other.getObject());
189189
}
190190

191191
@Specialization(guards = "self.getObject() == null || other.getObject() == null")

0 commit comments

Comments
 (0)