Skip to content

Commit 07bd7e8

Browse files
committed
add fast paths for common unary operations
1 parent 39e2664 commit 07bd7e8

File tree

3 files changed

+135
-54
lines changed

3 files changed

+135
-54
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/PythonCextBuiltins.java

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,7 @@
221221
import com.oracle.graal.python.nodes.expression.InplaceArithmetic;
222222
import com.oracle.graal.python.nodes.expression.LookupAndCallInplaceNode;
223223
import com.oracle.graal.python.nodes.expression.UnaryArithmetic;
224+
import com.oracle.graal.python.nodes.expression.UnaryOpNode;
224225
import com.oracle.graal.python.nodes.frame.GetCurrentFrameRef;
225226
import com.oracle.graal.python.nodes.function.FunctionRootNode;
226227
import com.oracle.graal.python.nodes.function.PythonBuiltinBaseNode;
@@ -3614,12 +3615,12 @@ abstract static class PyNumberUnaryOp extends PythonBinaryBuiltinNode {
36143615
@Specialization(guards = {"cachedOp == op", "left.isIntLike()"}, limit = "MAX_CACHE_SIZE")
36153616
static Object doIntLikePrimitiveWrapper(VirtualFrame frame, PrimitiveNativeWrapper left, @SuppressWarnings("unused") int op,
36163617
@Cached("op") @SuppressWarnings("unused") int cachedOp,
3617-
@Cached("createCallNode(op)") LookupAndCallUnaryNode callNode,
3618+
@Cached("createCallNode(op)") UnaryOpNode callNode,
36183619
@Cached ToNewRefNode toSulongNode,
36193620
@Cached TransformExceptionToNativeNode transformExceptionToNativeNode,
36203621
@Cached GetNativeNullNode getNativeNullNode) {
36213622
try {
3622-
return toSulongNode.execute(callNode.executeObject(frame, left.getLong()));
3623+
return toSulongNode.execute(callNode.execute(frame, left.getLong()));
36233624
} catch (PException e) {
36243625
transformExceptionToNativeNode.execute(e);
36253626
return toSulongNode.execute(getNativeNullNode.execute());
@@ -3630,7 +3631,7 @@ static Object doIntLikePrimitiveWrapper(VirtualFrame frame, PrimitiveNativeWrapp
36303631
static Object doObject(VirtualFrame frame, Object left, @SuppressWarnings("unused") int op,
36313632
@Cached AsPythonObjectNode leftToJava,
36323633
@Cached("op") @SuppressWarnings("unused") int cachedOp,
3633-
@Cached("createCallNode(op)") LookupAndCallUnaryNode callNode,
3634+
@Cached("createCallNode(op)") UnaryOpNode callNode,
36343635
@Cached ToNewRefNode toSulongNode,
36353636
@Cached TransformExceptionToNativeNode transformExceptionToNativeNode,
36363637
@Cached GetNativeNullNode getNativeNullNode) {
@@ -3643,7 +3644,7 @@ static Object doObject(VirtualFrame frame, Object left, @SuppressWarnings("unuse
36433644
} else {
36443645
leftValue = leftToJava.execute(left);
36453646
}
3646-
result = callNode.executeObject(frame, leftValue);
3647+
result = callNode.execute(frame, leftValue);
36473648
} catch (PException e) {
36483649
transformExceptionToNativeNode.execute(e);
36493650
result = getNativeNullNode.execute();
@@ -3654,7 +3655,7 @@ static Object doObject(VirtualFrame frame, Object left, @SuppressWarnings("unuse
36543655
/**
36553656
* This needs to stay in sync with {@code abstract.c: enum e_unaryop}.
36563657
*/
3657-
static LookupAndCallUnaryNode createCallNode(int op) {
3658+
static UnaryOpNode createCallNode(int op) {
36583659
UnaryArithmetic unaryArithmetic;
36593660
switch (op) {
36603661
case 0:

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/expression/UnaryArithmetic.java

Lines changed: 126 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -52,58 +52,25 @@
5252
import com.oracle.graal.python.nodes.call.special.LookupAndCallUnaryNode.NoAttributeHandler;
5353
import com.oracle.graal.python.util.Supplier;
5454
import com.oracle.truffle.api.CompilerDirectives;
55+
import com.oracle.truffle.api.dsl.Cached;
56+
import com.oracle.truffle.api.dsl.ImportStatic;
57+
import com.oracle.truffle.api.dsl.Specialization;
5558
import com.oracle.truffle.api.frame.VirtualFrame;
56-
import com.oracle.truffle.api.nodes.NodeCost;
5759
import com.oracle.truffle.api.nodes.RootNode;
5860

5961
public enum UnaryArithmetic {
60-
Pos(SpecialMethodNames.__POS__, "+"),
61-
Neg(SpecialMethodNames.__NEG__, "-"),
62-
Invert(SpecialMethodNames.__INVERT__, "~");
63-
64-
private final String methodName;
65-
private final String operator;
66-
private final Supplier<NoAttributeHandler> noAttributeHandler;
67-
68-
UnaryArithmetic(String methodName, String operator) {
69-
this.methodName = methodName;
70-
this.operator = operator;
71-
this.noAttributeHandler = () -> new NoAttributeHandler() {
72-
@Child private PRaiseNode raiseNode = PRaiseNode.create();
73-
74-
@Override
75-
public Object execute(Object receiver) {
76-
throw raiseNode.raise(TypeError, ErrorMessages.BAD_OPERAND_FOR, "unary ", operator, receiver);
77-
}
78-
};
79-
}
80-
81-
public String getMethodName() {
82-
return methodName;
83-
}
62+
Pos(UnaryArithmeticFactory.PosNodeGen::create),
63+
Neg(UnaryArithmeticFactory.NegNodeGen::create),
64+
Invert(UnaryArithmeticFactory.InvertNodeGen::create);
8465

85-
public String getOperator() {
86-
return operator;
66+
interface CreateUnaryOp {
67+
UnaryOpNode create(ExpressionNode left);
8768
}
8869

89-
public static final class UnaryArithmeticExpression extends ExpressionNode {
90-
@Child private LookupAndCallUnaryNode callNode;
91-
@Child private ExpressionNode operand;
92-
93-
private UnaryArithmeticExpression(LookupAndCallUnaryNode callNode, ExpressionNode operand) {
94-
this.callNode = callNode;
95-
this.operand = operand;
96-
}
97-
98-
@Override
99-
public Object execute(VirtualFrame frame) {
100-
return callNode.executeObject(frame, operand.execute(frame));
101-
}
70+
private final CreateUnaryOp create;
10271

103-
@Override
104-
public NodeCost getCost() {
105-
return NodeCost.NONE;
106-
}
72+
UnaryArithmetic(CreateUnaryOp create) {
73+
this.create = create;
10774
}
10875

10976
/**
@@ -115,7 +82,7 @@ public NodeCost getCost() {
11582
static final class CallUnaryArithmeticRootNode extends CallArithmeticRootNode {
11683
private static final Signature SIGNATURE_UNARY = new Signature(1, false, -1, false, new String[]{"$self"}, null);
11784

118-
@Child private LookupAndCallUnaryNode callUnaryNode;
85+
@Child private UnaryOpNode callUnaryNode;
11986

12087
private final UnaryArithmetic unaryOperator;
12188

@@ -135,16 +102,16 @@ protected Object doCall(VirtualFrame frame) {
135102
CompilerDirectives.transferToInterpreterAndInvalidate();
136103
callUnaryNode = insert(unaryOperator.create());
137104
}
138-
return callUnaryNode.executeObject(frame, PArguments.getArgument(frame, 0));
105+
return callUnaryNode.execute(frame, PArguments.getArgument(frame, 0));
139106
}
140107
}
141108

142109
public ExpressionNode create(ExpressionNode receiver) {
143-
return new UnaryArithmeticExpression(LookupAndCallUnaryNode.create(methodName, noAttributeHandler), receiver);
110+
return create.create(receiver);
144111
}
145112

146-
public LookupAndCallUnaryNode create() {
147-
return LookupAndCallUnaryNode.create(methodName, noAttributeHandler);
113+
public UnaryOpNode create() {
114+
return create.create(null);
148115
}
149116

150117
/**
@@ -154,4 +121,114 @@ public LookupAndCallUnaryNode create() {
154121
public RootNode createRootNode(PythonLanguage language) {
155122
return new CallUnaryArithmeticRootNode(language, this);
156123
}
124+
125+
@ImportStatic(SpecialMethodNames.class)
126+
public abstract static class UnaryArithmeticNode extends UnaryOpNode {
127+
128+
static Supplier<NoAttributeHandler> createHandler(String operator) {
129+
130+
return () -> new NoAttributeHandler() {
131+
@Child private PRaiseNode raiseNode = PRaiseNode.create();
132+
133+
@Override
134+
public Object execute(Object receiver) {
135+
throw raiseNode.raise(TypeError, ErrorMessages.BAD_OPERAND_FOR, "unary ", operator, receiver);
136+
}
137+
};
138+
}
139+
140+
static LookupAndCallUnaryNode createCallNode(String name, Supplier<NoAttributeHandler> handler) {
141+
return LookupAndCallUnaryNode.create(name, handler);
142+
}
143+
}
144+
145+
/*
146+
*
147+
* All the following fast paths need to be kept in sync with the corresponding builtin functions
148+
* in IntBuiltins and FloatBuiltins.
149+
*
150+
*/
151+
152+
public abstract static class PosNode extends UnaryArithmeticNode {
153+
154+
static final Supplier<NoAttributeHandler> NOT_IMPLEMENTED = createHandler("+");
155+
156+
@Specialization
157+
static int pos(int arg) {
158+
return arg;
159+
}
160+
161+
@Specialization
162+
static long pos(long arg) {
163+
return arg;
164+
}
165+
166+
@Specialization
167+
static double pos(double arg) {
168+
return arg;
169+
}
170+
171+
@Specialization
172+
static Object doGeneric(VirtualFrame frame, Object arg,
173+
@Cached("createCallNode(__POS__, NOT_IMPLEMENTED)") LookupAndCallUnaryNode callNode) {
174+
return callNode.executeObject(frame, arg);
175+
}
176+
}
177+
178+
public abstract static class NegNode extends UnaryArithmeticNode {
179+
180+
static final Supplier<NoAttributeHandler> NOT_IMPLEMENTED = createHandler("-");
181+
182+
@Specialization(rewriteOn = ArithmeticException.class)
183+
static int neg(int arg) {
184+
return Math.negateExact(arg);
185+
}
186+
187+
@Specialization
188+
static long negOvf(int arg) {
189+
return -((long) arg);
190+
}
191+
192+
@Specialization(rewriteOn = ArithmeticException.class)
193+
static long neg(long arg) {
194+
return Math.negateExact(arg);
195+
}
196+
197+
@Specialization
198+
static double neg(double arg) {
199+
return -arg;
200+
}
201+
202+
@Specialization
203+
static Object doGeneric(VirtualFrame frame, Object arg,
204+
@Cached("createCallNode(__NEG__, NOT_IMPLEMENTED)") LookupAndCallUnaryNode callNode) {
205+
return callNode.executeObject(frame, arg);
206+
}
207+
}
208+
209+
public abstract static class InvertNode extends UnaryArithmeticNode {
210+
211+
static final Supplier<NoAttributeHandler> NOT_IMPLEMENTED = createHandler("~");
212+
213+
@Specialization
214+
static int invert(boolean arg) {
215+
return ~(arg ? 1 : 0);
216+
}
217+
218+
@Specialization
219+
static int invert(int arg) {
220+
return ~arg;
221+
}
222+
223+
@Specialization
224+
static long invert(long arg) {
225+
return ~arg;
226+
}
227+
228+
@Specialization
229+
static Object doGeneric(VirtualFrame frame, Object arg,
230+
@Cached("createCallNode(__INVERT__, NOT_IMPLEMENTED)") LookupAndCallUnaryNode callNode) {
231+
return callNode.executeObject(frame, arg);
232+
}
233+
}
157234
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/expression/UnaryOpNode.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,11 @@
2626
package com.oracle.graal.python.nodes.expression;
2727

2828
import com.oracle.truffle.api.dsl.NodeChild;
29+
import com.oracle.truffle.api.frame.VirtualFrame;
2930

3031
@NodeChild(value = "operand", type = ExpressionNode.class)
3132
public abstract class UnaryOpNode extends ExpressionNode {
3233
public abstract ExpressionNode getOperand();
34+
35+
public abstract Object execute(VirtualFrame frame, Object value);
3336
}

0 commit comments

Comments
 (0)