Skip to content

Commit 4326c20

Browse files
committed
Fix: LookupAndCallInplaceNode didn't correctly handle ternary in-place ops
1 parent cfd64d3 commit 4326c20

File tree

2 files changed

+93
-28
lines changed

2 files changed

+93
-28
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/expression/InplaceArithmetic.java

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747

4848
import com.oracle.graal.python.nodes.PRaiseNode;
4949
import com.oracle.graal.python.nodes.SpecialMethodNames;
50+
import com.oracle.truffle.api.nodes.Node.Child;
5051

5152
public enum InplaceArithmetic {
5253
IAdd(SpecialMethodNames.__IADD__, "+="),
@@ -65,13 +66,20 @@ public enum InplaceArithmetic {
6566

6667
private final String methodName;
6768
private final String operator;
69+
private final boolean isTernary;
6870
private final Supplier<LookupAndCallInplaceNode.NotImplementedHandler> notImplementedHandler;
6971

7072
InplaceArithmetic(String methodName, String operator) {
73+
this(methodName, operator, false);
74+
}
75+
76+
InplaceArithmetic(String methodName, String operator, boolean isTernary) {
7177
this.methodName = methodName;
7278
this.operator = operator;
79+
this.isTernary = isTernary;
7380
this.notImplementedHandler = () -> new LookupAndCallInplaceNode.NotImplementedHandler() {
74-
@Child private PRaiseNode raiseNode = PRaiseNode.create();
81+
@Child
82+
private PRaiseNode raiseNode = PRaiseNode.create();
7583

7684
@Override
7785
public Object execute(Object arg, Object arg2) {
@@ -93,6 +101,9 @@ public LookupAndCallInplaceNode create(ExpressionNode left, ExpressionNode right
93101
}
94102

95103
public LookupAndCallInplaceNode create() {
104+
if (isTernary) {
105+
return LookupAndCallInplaceNode.createWithTernary(methodName, null, null, null, notImplementedHandler);
106+
}
96107
return LookupAndCallInplaceNode.createWithBinary(methodName, null, null, notImplementedHandler);
97108
}
98109
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/expression/LookupAndCallInplaceNode.java

Lines changed: 81 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,10 @@
4545
import com.oracle.graal.python.nodes.PNodeWithContext;
4646
import com.oracle.graal.python.nodes.attributes.LookupInheritedAttributeNode;
4747
import com.oracle.graal.python.nodes.call.special.CallBinaryMethodNode;
48+
import com.oracle.graal.python.nodes.call.special.CallTernaryMethodNode;
4849
import com.oracle.graal.python.nodes.call.special.LookupAndCallBinaryNode;
50+
import com.oracle.graal.python.nodes.call.special.LookupAndCallTernaryNode;
51+
import com.oracle.graal.python.nodes.literal.ObjectLiteralNode;
4952
import com.oracle.graal.python.util.Supplier;
5053
import com.oracle.truffle.api.CompilerDirectives;
5154
import com.oracle.truffle.api.dsl.Cached;
@@ -55,6 +58,7 @@
5558

5659
@NodeChild("arg")
5760
@NodeChild("arg2")
61+
@NodeChild("arg3")
5862
public abstract class LookupAndCallInplaceNode extends ExpressionNode {
5963

6064
public abstract static class NotImplementedHandler extends PNodeWithContext {
@@ -66,7 +70,10 @@ public abstract static class NotImplementedHandler extends PNodeWithContext {
6670
protected final String reverseBinaryOpName;
6771
protected final Supplier<NotImplementedHandler> handlerFactory;
6872

69-
@Child private CallBinaryMethodNode dispatchNode;
73+
@Child private CallBinaryMethodNode callBinaryMethodNode;
74+
@Child private CallTernaryMethodNode callTernaryMethodNode;
75+
@Child private LookupAndCallBinaryNode lookupAndCallBinaryNode;
76+
@Child private LookupAndCallTernaryNode lookupAndCallTernaryNode;
7077
@Child private NotImplementedHandler handler;
7178

7279
LookupAndCallInplaceNode(String inplaceOpName, String binaryOpName, String reverseBinaryOpName, Supplier<NotImplementedHandler> handlerFactory) {
@@ -76,45 +83,82 @@ public abstract static class NotImplementedHandler extends PNodeWithContext {
7683
this.handlerFactory = handlerFactory;
7784
}
7885

79-
public static LookupAndCallInplaceNode create(String inplaceOpName) {
80-
return LookupAndCallInplaceNodeGen.create(inplaceOpName, null, null, null, null, null);
86+
public static LookupAndCallInplaceNode createWithBinary(String inplaceOpName, ExpressionNode left, ExpressionNode right, Supplier<LookupAndCallInplaceNode.NotImplementedHandler> handlerFactory) {
87+
String binaryOpName = inplaceOpName.replaceFirst("__i", "__");
88+
String reverseBinaryOpName = inplaceOpName.replaceFirst("__i", "__r");
89+
return LookupAndCallInplaceNodeGen.create(inplaceOpName, binaryOpName, reverseBinaryOpName, handlerFactory, left, right, new ObjectLiteralNode(PNone.NO_VALUE));
8190
}
8291

83-
public static LookupAndCallInplaceNode createWithBinary(String inplaceOpName, ExpressionNode left, ExpressionNode right, Supplier<LookupAndCallInplaceNode.NotImplementedHandler> handlerFactory) {
84-
return LookupAndCallInplaceNodeGen.create(inplaceOpName, inplaceOpName.replaceFirst("__i", "__"), inplaceOpName.replaceFirst("__i", "__r"), handlerFactory, left, right);
92+
public static LookupAndCallInplaceNode createWithTernary(String inplaceOpName, ExpressionNode x, ExpressionNode y, ExpressionNode z,
93+
Supplier<LookupAndCallInplaceNode.NotImplementedHandler> handlerFactory) {
94+
String binaryOpName = inplaceOpName.replaceFirst("__i", "__");
95+
String reverseBinaryOpName = inplaceOpName.replaceFirst("__i", "__r");
96+
return LookupAndCallInplaceNodeGen.create(inplaceOpName, binaryOpName, reverseBinaryOpName, handlerFactory, x, y, z);
8597
}
8698

8799
public static LookupAndCallInplaceNode create(String inplaceOpName, String binaryOpName) {
88-
return LookupAndCallInplaceNodeGen.create(inplaceOpName, binaryOpName, null, null, null, null);
100+
return LookupAndCallInplaceNodeGen.create(inplaceOpName, binaryOpName, null, null, null, null, null);
89101
}
90102

91103
public static LookupAndCallInplaceNode create(String inplaceOpName, String binaryOpName, String reverseBinaryOpName, Supplier<NotImplementedHandler> handlerFactory) {
92-
return LookupAndCallInplaceNodeGen.create(inplaceOpName, binaryOpName, reverseBinaryOpName, handlerFactory, null, null);
104+
return LookupAndCallInplaceNodeGen.create(inplaceOpName, binaryOpName, reverseBinaryOpName, handlerFactory, null, null, null);
105+
}
106+
107+
private CallBinaryMethodNode ensureBinaryCallNode() {
108+
if (callBinaryMethodNode == null) {
109+
CompilerDirectives.transferToInterpreterAndInvalidate();
110+
callBinaryMethodNode = insert(CallBinaryMethodNode.create());
111+
}
112+
return callBinaryMethodNode;
113+
}
114+
115+
private CallTernaryMethodNode ensureTernaryCallNode() {
116+
if (callTernaryMethodNode == null) {
117+
CompilerDirectives.transferToInterpreterAndInvalidate();
118+
callTernaryMethodNode = insert(CallTernaryMethodNode.create());
119+
}
120+
return callTernaryMethodNode;
121+
}
122+
123+
private LookupAndCallBinaryNode ensureLookupAndCallBinaryNode() {
124+
if (lookupAndCallBinaryNode == null) {
125+
CompilerDirectives.transferToInterpreterAndInvalidate();
126+
lookupAndCallBinaryNode = insert(LookupAndCallBinaryNode.create(binaryOpName, reverseBinaryOpName));
127+
}
128+
return lookupAndCallBinaryNode;
93129
}
94130

95-
private CallBinaryMethodNode ensureDispatch() {
96-
if (dispatchNode == null) {
131+
private LookupAndCallTernaryNode ensureLookupAndCallTernaryNode() {
132+
if (lookupAndCallTernaryNode == null) {
97133
CompilerDirectives.transferToInterpreterAndInvalidate();
98-
dispatchNode = insert(CallBinaryMethodNode.create());
134+
lookupAndCallTernaryNode = insert(LookupAndCallTernaryNode.createReversible(binaryOpName, null));
99135
}
100-
return dispatchNode;
136+
return lookupAndCallTernaryNode;
101137
}
102138

103-
protected boolean hasBinaryVersion() {
139+
protected boolean hasNonInplaceOperator() {
104140
return binaryOpName != null;
105141
}
106142

107-
public abstract Object execute(VirtualFrame frame, Object left, Object right);
143+
public final Object execute(VirtualFrame frame, Object left, Object right) {
144+
return executeTernary(frame, left, right, PNone.NO_VALUE);
145+
}
146+
147+
public abstract Object executeTernary(VirtualFrame frame, Object x, Object y, Object z);
108148

109-
@Specialization(guards = "!hasBinaryVersion()")
110-
Object callObject(VirtualFrame frame, Object left, Object right,
149+
@Specialization(guards = "!hasNonInplaceOperator()")
150+
Object doInplaceOnly(VirtualFrame frame, Object left, Object right, Object z,
111151
@Cached("create(inplaceOpName)") LookupInheritedAttributeNode getattr) {
112152
Object leftCallable = getattr.execute(left);
113153
Object result;
114154
if (leftCallable == PNone.NO_VALUE) {
115155
result = PNotImplemented.NOT_IMPLEMENTED;
116156
} else {
117-
result = ensureDispatch().executeObject(frame, leftCallable, left, right);
157+
if (z == PNone.NO_VALUE) {
158+
result = ensureBinaryCallNode().executeObject(frame, leftCallable, left, right);
159+
} else {
160+
result = ensureTernaryCallNode().execute(frame, leftCallable, left, right, z);
161+
}
118162
}
119163
if (handlerFactory != null && result == PNotImplemented.NOT_IMPLEMENTED) {
120164
if (handler == null) {
@@ -126,25 +170,35 @@ Object callObject(VirtualFrame frame, Object left, Object right,
126170
return result;
127171
}
128172

129-
@Specialization(guards = "hasBinaryVersion()")
130-
Object callObject(VirtualFrame frame, Object left, Object right,
131-
@Cached("create(inplaceOpName)") LookupInheritedAttributeNode getattrInplace,
132-
@Cached("create(binaryOpName, reverseBinaryOpName)") LookupAndCallBinaryNode binaryNode) {
173+
@Specialization(guards = "hasNonInplaceOperator()")
174+
Object doBinary(VirtualFrame frame, Object left, Object right, Object z,
175+
@Cached("create(inplaceOpName)") LookupInheritedAttributeNode getattrInplace) {
133176
Object result = PNotImplemented.NOT_IMPLEMENTED;
134177
Object inplaceCallable = getattrInplace.execute(left);
178+
boolean isBinary = z == PNone.NO_VALUE;
135179
if (inplaceCallable != PNone.NO_VALUE) {
136-
result = ensureDispatch().executeObject(frame, inplaceCallable, left, right);
137-
if (result != PNotImplemented.NOT_IMPLEMENTED) {
138-
return result;
180+
if (isBinary) {
181+
result = ensureBinaryCallNode().executeObject(frame, inplaceCallable, left, right);
182+
} else {
183+
result = ensureTernaryCallNode().execute(frame, inplaceCallable, left, right, z);
139184
}
140-
}
141-
if (binaryNode != null) {
142-
result = binaryNode.executeObject(frame, left, right);
143185
if (result != PNotImplemented.NOT_IMPLEMENTED) {
144186
return result;
145187
}
146188
}
147-
if (handlerFactory != null && result == PNotImplemented.NOT_IMPLEMENTED) {
189+
190+
// try non-inplace variant
191+
if (isBinary) {
192+
result = ensureLookupAndCallBinaryNode().executeObject(frame, left, right);
193+
} else {
194+
result = ensureLookupAndCallTernaryNode().execute(frame, left, right, z);
195+
}
196+
if (result != PNotImplemented.NOT_IMPLEMENTED) {
197+
return result;
198+
}
199+
200+
if (handlerFactory != null) {
201+
assert result == PNotImplemented.NOT_IMPLEMENTED;
148202
if (handler == null) {
149203
CompilerDirectives.transferToInterpreterAndInvalidate();
150204
handler = insert(handlerFactory.get());

0 commit comments

Comments
 (0)