Skip to content

Commit a2c22e0

Browse files
committed
implement reversible ternary operations (just pow atm)
1 parent 78e115f commit a2c22e0

File tree

9 files changed

+247
-50
lines changed

9 files changed

+247
-50
lines changed

graalpython/com.oracle.graal.python.test/src/tests/cpyext/test_object.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,10 @@ def test_pow(self):
7171
nb_power="test_pow"
7272
)
7373
tester = TestPow()
74-
assert tester ** 12 == (tester, 12, None)
75-
assert 12 ** tester == (12, tester, None)
76-
assert pow(tester, 48, 2) == (tester, 48, 2)
77-
assert pow(48, tester, 2) == (48, tester, 2)
74+
assert tester ** 12 == (tester, 12, None), tester ** 12
75+
assert 12 ** tester == (12, tester, None), 12 ** tester
76+
assert pow(tester, 48, 2) == (tester, 48, 2), pow(tester, 48, 2)
77+
assert pow(48, tester, 2) == (48, tester, 2), pow(48, tester, 2)
7878

7979
def test_int(self):
8080
TestInt = CPyExtType("TestInt",

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/BuiltinFunctions.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
import static com.oracle.graal.python.nodes.BuiltinNames.MIN;
4646
import static com.oracle.graal.python.nodes.BuiltinNames.NEXT;
4747
import static com.oracle.graal.python.nodes.BuiltinNames.ORD;
48+
import static com.oracle.graal.python.nodes.BuiltinNames.POW;
4849
import static com.oracle.graal.python.nodes.BuiltinNames.PRINT;
4950
import static com.oracle.graal.python.nodes.BuiltinNames.REPR;
5051
import static com.oracle.graal.python.nodes.BuiltinNames.ROUND;
@@ -101,6 +102,7 @@
101102
import com.oracle.graal.python.nodes.attributes.WriteAttributeToObjectNode;
102103
import com.oracle.graal.python.nodes.call.CallNode;
103104
import com.oracle.graal.python.nodes.call.special.LookupAndCallBinaryNode;
105+
import com.oracle.graal.python.nodes.call.special.LookupAndCallTernaryNode;
104106
import com.oracle.graal.python.nodes.call.special.LookupAndCallUnaryNode;
105107
import com.oracle.graal.python.nodes.call.special.LookupAndCallUnaryNode.NoAttributeHandler;
106108
import com.oracle.graal.python.nodes.classes.IsSubtypeNode;
@@ -109,6 +111,7 @@
109111
import com.oracle.graal.python.nodes.expression.BinaryArithmetic;
110112
import com.oracle.graal.python.nodes.expression.BinaryComparisonNode;
111113
import com.oracle.graal.python.nodes.expression.CastToBooleanNode;
114+
import com.oracle.graal.python.nodes.expression.TernaryArithmetic;
112115
import com.oracle.graal.python.nodes.frame.ReadCallerFrameNode;
113116
import com.oracle.graal.python.nodes.function.PythonBuiltinNode;
114117
import com.oracle.graal.python.nodes.function.builtins.PythonTernaryBuiltinNode;
@@ -1056,6 +1059,17 @@ public Object doIt(Object[] args) {
10561059
}
10571060
}
10581061

1062+
@Builtin(name = POW, minNumOfArguments = 2, keywordArguments = {"z"})
1063+
@GenerateNodeFactory
1064+
public abstract static class PowNode extends PythonBuiltinNode {
1065+
@Child LookupAndCallTernaryNode powNode = TernaryArithmetic.Pow.create();
1066+
1067+
@Specialization
1068+
Object doIt(Object x, Object y, Object z) {
1069+
return powNode.execute(x, y, z);
1070+
}
1071+
}
1072+
10591073
// sum(iterable[, start])
10601074
@Builtin(name = SUM, fixedNumOfArguments = 1, keywordArguments = {"start"})
10611075
@GenerateNodeFactory

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/ints/IntBuiltins.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -893,6 +893,12 @@ PInt doPInt(PInt left, PInt right, @SuppressWarnings("unused") PNone none) {
893893
return factory().createInt((long) value);
894894
}
895895

896+
@Fallback
897+
@SuppressWarnings("unused")
898+
Object doFallback(Object x, Object y, Object z) {
899+
return PNotImplemented.NOT_IMPLEMENTED;
900+
}
901+
896902
@TruffleBoundary
897903
private BigInteger op(BigInteger a, long b) {
898904
try {

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/NodeFactory.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
import com.oracle.graal.python.nodes.expression.InplaceArithmetic;
6565
import com.oracle.graal.python.nodes.expression.IsNode;
6666
import com.oracle.graal.python.nodes.expression.OrNode;
67+
import com.oracle.graal.python.nodes.expression.TernaryArithmetic;
6768
import com.oracle.graal.python.nodes.expression.UnaryArithmetic;
6869
import com.oracle.graal.python.nodes.frame.DeleteGlobalNode;
6970
import com.oracle.graal.python.nodes.frame.DestructuringAssignmentNode;
@@ -359,7 +360,7 @@ public PNode createBinaryOperation(String string, PNode left, PNode right) {
359360
case "%":
360361
return BinaryArithmetic.Mod.create(left, right);
361362
case "**":
362-
return BinaryArithmetic.Pow.create(left, right);
363+
return TernaryArithmetic.Pow.create(left, right);
363364
case "<<":
364365
return BinaryArithmetic.LShift.create(left, right);
365366
case ">>":

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/call/special/LookupAndCallSpecialNode.java

Lines changed: 0 additions & 38 deletions
This file was deleted.

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/call/special/LookupAndCallTernaryNode.java

Lines changed: 131 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,38 +38,164 @@
3838
*/
3939
package com.oracle.graal.python.nodes.call.special;
4040

41+
import java.util.function.Supplier;
42+
43+
import com.oracle.graal.python.builtins.objects.PNone;
44+
import com.oracle.graal.python.builtins.objects.PNotImplemented;
45+
import com.oracle.graal.python.builtins.objects.type.PythonClass;
46+
import com.oracle.graal.python.nodes.EmptyNode;
47+
import com.oracle.graal.python.nodes.PBaseNode;
48+
import com.oracle.graal.python.nodes.PNode;
4149
import com.oracle.graal.python.nodes.attributes.GetAttributeNode;
50+
import com.oracle.graal.python.nodes.attributes.LookupInheritedAttributeNode;
51+
import com.oracle.graal.python.nodes.classes.IsSubtypeNode;
4252
import com.oracle.graal.python.nodes.object.GetClassNode;
53+
import com.oracle.truffle.api.CompilerDirectives;
4354
import com.oracle.truffle.api.dsl.Cached;
55+
import com.oracle.truffle.api.dsl.NodeChild;
56+
import com.oracle.truffle.api.dsl.NodeChildren;
4457
import com.oracle.truffle.api.dsl.Specialization;
58+
import com.oracle.truffle.api.profiles.BranchProfile;
59+
60+
@NodeChildren({@NodeChild("arg"), @NodeChild("arg2"), @NodeChild("arg3")})
61+
public abstract class LookupAndCallTernaryNode extends PNode {
62+
63+
public abstract static class NotImplementedHandler extends PBaseNode {
64+
public abstract Object execute(Object arg, Object arg2, Object arg3);
65+
}
4566

46-
public abstract class LookupAndCallTernaryNode extends LookupAndCallSpecialNode {
4767
private final String name;
68+
private final boolean isReversible;
4869
@Child private CallTernaryMethodNode dispatchNode = CallTernaryMethodNode.create();
70+
@Child private CallTernaryMethodNode reverseDispatchNode;
71+
@Child private CallTernaryMethodNode thirdDispatchNode;
72+
@Child private LookupInheritedAttributeNode getThirdAttrNode;
73+
@Child private NotImplementedHandler handler;
74+
protected final Supplier<NotImplementedHandler> handlerFactory;
4975

5076
public abstract Object execute(Object arg1, Object arg2, Object arg3);
5177

5278
public abstract Object execute(Object arg1, int arg2, Object arg3);
5379

5480
public static LookupAndCallTernaryNode create(String name) {
55-
return LookupAndCallTernaryNodeGen.create(name);
81+
return LookupAndCallTernaryNodeGen.create(name, false, null, null, null, null);
82+
}
83+
84+
public static LookupAndCallTernaryNode createReversible(String name, Supplier<NotImplementedHandler> handlerFactory) {
85+
return LookupAndCallTernaryNodeGen.create(name, true, handlerFactory, null, null, null);
5686
}
5787

58-
LookupAndCallTernaryNode(String name) {
88+
public static LookupAndCallTernaryNode createReversible(String name, Supplier<NotImplementedHandler> handlerFactory, PNode x, PNode y) {
89+
return LookupAndCallTernaryNodeGen.create(name, true, handlerFactory, x, y, EmptyNode.create());
90+
}
91+
92+
LookupAndCallTernaryNode(String name, boolean isReversible, Supplier<NotImplementedHandler> handlerFactory) {
5993
this.name = name;
94+
this.isReversible = isReversible;
95+
this.handlerFactory = handlerFactory;
6096
}
6197

62-
@Specialization
98+
protected boolean isReversible() {
99+
return isReversible;
100+
}
101+
102+
@Specialization(guards = "!isReversible()")
63103
Object callObject(Object arg1, int arg2, Object arg3,
64104
@Cached("create()") GetClassNode getclass,
65105
@Cached("create()") GetAttributeNode getattr) {
66106
return dispatchNode.execute(getattr.execute(getclass.execute(arg1), name), arg1, arg2, arg3);
67107
}
68108

69-
@Specialization
109+
@Specialization(guards = "!isReversible()")
70110
Object callObject(Object arg1, Object arg2, Object arg3,
71111
@Cached("create()") GetClassNode getclass,
72112
@Cached("create()") GetAttributeNode getattr) {
73113
return dispatchNode.execute(getattr.execute(getclass.execute(arg1), name), arg1, arg2, arg3);
74114
}
115+
116+
private CallTernaryMethodNode ensureReverseDispatch() {
117+
// this also serves as a branch profile
118+
if (reverseDispatchNode == null) {
119+
CompilerDirectives.transferToInterpreterAndInvalidate();
120+
reverseDispatchNode = insert(CallTernaryMethodNode.create());
121+
}
122+
return reverseDispatchNode;
123+
}
124+
125+
private LookupInheritedAttributeNode ensureGetAttrZ() {
126+
// this also serves as a branch profile
127+
if (getThirdAttrNode == null) {
128+
CompilerDirectives.transferToInterpreterAndInvalidate();
129+
getThirdAttrNode = insert(LookupInheritedAttributeNode.create());
130+
}
131+
return getThirdAttrNode;
132+
}
133+
134+
private CallTernaryMethodNode ensureThirdDispatch() {
135+
// this also serves as a branch profile
136+
if (thirdDispatchNode == null) {
137+
CompilerDirectives.transferToInterpreterAndInvalidate();
138+
thirdDispatchNode = insert(CallTernaryMethodNode.create());
139+
}
140+
return thirdDispatchNode;
141+
}
142+
143+
@Specialization(guards = "isReversible()")
144+
Object callObject(Object v, Object w, Object z,
145+
@Cached("create()") LookupInheritedAttributeNode getattr,
146+
@Cached("create()") LookupInheritedAttributeNode getattrR,
147+
@Cached("create()") GetClassNode getClass,
148+
@Cached("create()") GetClassNode getClassR,
149+
@Cached("create()") IsSubtypeNode isSubtype,
150+
@Cached("create()") BranchProfile notImplementedBranch) {
151+
Object result = PNotImplemented.NOT_IMPLEMENTED;
152+
Object leftCallable = getattr.execute(v, name);
153+
Object rightCallable = PNone.NO_VALUE;
154+
155+
PythonClass leftClass = getClass.execute(v);
156+
PythonClass rightClass = getClassR.execute(w);
157+
if (leftClass != rightClass) {
158+
rightCallable = getattrR.execute(w, name);
159+
if (rightCallable == leftCallable) {
160+
rightCallable = PNone.NO_VALUE;
161+
}
162+
}
163+
if (leftCallable != PNone.NO_VALUE) {
164+
if (rightCallable != PNone.NO_VALUE && isSubtype.execute(rightClass, leftClass)) {
165+
result = ensureReverseDispatch().execute(rightCallable, v, w, z);
166+
if (result != PNotImplemented.NOT_IMPLEMENTED) {
167+
return result;
168+
}
169+
rightCallable = PNone.NO_VALUE;
170+
}
171+
result = dispatchNode.execute(leftCallable, v, w, z);
172+
if (result != PNotImplemented.NOT_IMPLEMENTED) {
173+
return result;
174+
}
175+
}
176+
if (rightCallable != PNone.NO_VALUE) {
177+
result = ensureReverseDispatch().execute(rightCallable, v, w, z);
178+
if (result != PNotImplemented.NOT_IMPLEMENTED) {
179+
return result;
180+
}
181+
}
182+
183+
Object zCallable = ensureGetAttrZ().execute(z, name);
184+
if (zCallable != PNone.NO_VALUE && zCallable != leftCallable && zCallable != rightCallable) {
185+
ensureThirdDispatch().execute(zCallable, v, w, z);
186+
if (result != PNotImplemented.NOT_IMPLEMENTED) {
187+
return result;
188+
}
189+
}
190+
191+
notImplementedBranch.enter();
192+
if (handlerFactory != null) {
193+
if (handler == null) {
194+
CompilerDirectives.transferToInterpreterAndInvalidate();
195+
handler = insert(handlerFactory.get());
196+
}
197+
return handler.execute(v, w, z);
198+
}
199+
return result;
200+
}
75201
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/call/special/LookupAndCallVarargsNode.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,12 @@
3939
package com.oracle.graal.python.nodes.call.special;
4040

4141
import com.oracle.graal.python.builtins.objects.function.PKeyword;
42+
import com.oracle.graal.python.nodes.PBaseNode;
4243
import com.oracle.graal.python.nodes.attributes.LookupInheritedAttributeNode;
4344
import com.oracle.truffle.api.dsl.Cached;
4445
import com.oracle.truffle.api.dsl.Specialization;
4546

46-
public abstract class LookupAndCallVarargsNode extends LookupAndCallSpecialNode {
47+
public abstract class LookupAndCallVarargsNode extends PBaseNode {
4748
private final String name;
4849
@Child private CallVarargsMethodNode dispatchNode = CallVarargsMethodNode.create();
4950

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/expression/BinaryArithmetic.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@ public enum BinaryArithmetic {
5454
TrueDiv(SpecialMethodNames.__TRUEDIV__, "/"),
5555
FloorDiv(SpecialMethodNames.__FLOORDIV__, "//"),
5656
Mod(SpecialMethodNames.__MOD__, "%"),
57-
Pow(SpecialMethodNames.__POW__, "**"),
5857
LShift(SpecialMethodNames.__LSHIFT__, "<<"),
5958
RShift(SpecialMethodNames.__RSHIFT__, ">>"),
6059
And(SpecialMethodNames.__AND__, "&"),

0 commit comments

Comments
 (0)