Skip to content

Commit fade08a

Browse files
committed
Provide generic way to use unary/binary/ternary/in-place operators without frame
1 parent 989adfd commit fade08a

File tree

7 files changed

+407
-151
lines changed

7 files changed

+407
-151
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/PythonLanguage.java

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
package com.oracle.graal.python;
2727

2828
import java.io.IOException;
29+
import java.lang.ref.WeakReference;
2930
import java.util.ArrayList;
3031
import java.util.Arrays;
3132
import java.util.concurrent.ConcurrentHashMap;
@@ -49,7 +50,11 @@
4950
import com.oracle.graal.python.nodes.NodeFactory;
5051
import com.oracle.graal.python.nodes.call.InvokeNode;
5152
import com.oracle.graal.python.nodes.control.TopLevelExceptionHandler;
53+
import com.oracle.graal.python.nodes.expression.BinaryArithmetic;
5254
import com.oracle.graal.python.nodes.expression.ExpressionNode;
55+
import com.oracle.graal.python.nodes.expression.InplaceArithmetic;
56+
import com.oracle.graal.python.nodes.expression.TernaryArithmetic;
57+
import com.oracle.graal.python.nodes.expression.UnaryArithmetic;
5358
import com.oracle.graal.python.nodes.function.PythonBuiltinBaseNode;
5459
import com.oracle.graal.python.parser.PythonParserImpl;
5560
import com.oracle.graal.python.runtime.PythonContext;
@@ -141,6 +146,15 @@ public final class PythonLanguage extends TruffleLanguage<PythonContext> {
141146

142147
private final NodeFactory nodeFactory;
143148
private final ConcurrentHashMap<String, RootCallTarget> builtinCallTargetCache = new ConcurrentHashMap<>();
149+
/**
150+
* A thread-safe map that maps arithmetic operators (i.e.
151+
* {@link com.oracle.graal.python.nodes.expression.UnaryArithmetic},
152+
* {@link com.oracle.graal.python.nodes.expression.BinaryArithmetic},
153+
* {@link com.oracle.graal.python.nodes.expression.TernaryArithmetic}, and
154+
* {@link com.oracle.graal.python.nodes.expression.InplaceArithmetic}) to call targets. Use this
155+
* map to retrieve a singleton instance (per engine) such that proper AST sharing is possible.
156+
*/
157+
private ConcurrentHashMap<Object, WeakReference<RootCallTarget>> arithmeticOperatorCallTargetCache;
144158

145159
@CompilationFinal(dimensions = 1) private static final Object[] CONTEXT_INSENSITIVE_SINGLETONS = new Object[]{PNone.NONE, PNone.NO_VALUE, PEllipsis.INSTANCE, PNotImplemented.NOT_IMPLEMENTED};
146160

@@ -645,4 +659,72 @@ public RootCallTarget getOrComputeBuiltinCallTarget(Builtin builtin, Class<? ext
645659
String key = builtin.name() + nodeClass.getName();
646660
return builtinCallTargetCache.computeIfAbsent(key, (k) -> supplier.apply(builtin));
647661
}
662+
663+
/**
664+
* Retrieve a call target for the given {@link UnaryArithmetic} operator. If the no such call
665+
* target exists yet, it will be created lazily. This method is thread-safe and should be used
666+
* for all contexts in this engine to enable AST sharing.
667+
*/
668+
public RootCallTarget getOrCreateUnaryArithmeticCallTarget(UnaryArithmetic unaryOperator) {
669+
return getOrCreateArithmeticCallTarget(unaryOperator, unaryOperator::createCallTarget);
670+
}
671+
672+
/**
673+
* Retrieve a call target for the given {@link BinaryArithmetic} operator. If the no such call
674+
* target exists yet, it will be created lazily. This method is thread-safe and should be used
675+
* for all contexts in this engine to enable AST sharing.
676+
*/
677+
public RootCallTarget getOrCreateBinaryArithmeticCallTarget(BinaryArithmetic unaryOperator) {
678+
return getOrCreateArithmeticCallTarget(unaryOperator, unaryOperator::createCallTarget);
679+
}
680+
681+
/**
682+
* Retrieve a call target for the given {@link TernaryArithmetic} operator. If the no such call
683+
* target exists yet, it will be created lazily. This method is thread-safe and should be used
684+
* for all contexts in this engine to enable AST sharing.
685+
*/
686+
public RootCallTarget getOrCreateTernaryArithmeticCallTarget(TernaryArithmetic unaryOperator) {
687+
return getOrCreateArithmeticCallTarget(unaryOperator, unaryOperator::createCallTarget);
688+
}
689+
690+
/**
691+
* Retrieve a call target for the given {@link InplaceArithmetic} operator. If the no such call
692+
* target exists yet, it will be created lazily. This method is thread-safe and should be used
693+
* for all contexts in this engine to enable AST sharing.
694+
*/
695+
public RootCallTarget getOrCreateInplaceArithmeticCallTarget(InplaceArithmetic unaryOperator) {
696+
return getOrCreateArithmeticCallTarget(unaryOperator, unaryOperator::createCallTarget);
697+
}
698+
699+
private RootCallTarget getOrCreateArithmeticCallTarget(Object arithmeticOperator, Function<PythonLanguage, RootCallTarget> supplier) {
700+
if (arithmeticOperatorCallTargetCache == null) {
701+
synchronized (this) {
702+
// need to do check a second time (now synchronized)
703+
if (arithmeticOperatorCallTargetCache == null) {
704+
arithmeticOperatorCallTargetCache = new ConcurrentHashMap<>();
705+
}
706+
}
707+
}
708+
WeakReference<RootCallTarget> ctRef = arithmeticOperatorCallTargetCache.compute(arithmeticOperator, (k, v) -> {
709+
RootCallTarget cachedCallTarget = v != null ? v.get() : null;
710+
if (cachedCallTarget == null) {
711+
return new WeakReference<>(supplier.apply(this));
712+
}
713+
return v;
714+
});
715+
716+
RootCallTarget callTarget = ctRef.get();
717+
if (callTarget == null) {
718+
// Bad luck: we ensured that there is a mapping in the cache but the weak value got
719+
// collected before we could strongly reference it. Now, we need to be conservative and
720+
// create the call target eagerly, hold a strong reference to it until we've put it into
721+
// the map.
722+
final RootCallTarget callTargetToCache = supplier.apply(this);
723+
callTarget = callTargetToCache;
724+
arithmeticOperatorCallTargetCache.computeIfAbsent(arithmeticOperator, (k) -> new WeakReference<>(callTargetToCache));
725+
}
726+
assert callTarget != null;
727+
return callTarget;
728+
729+
}
648730
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/cext/hpy/GraalHPyContextFunctions.java

Lines changed: 11 additions & 146 deletions
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,7 @@
5858
import static com.oracle.graal.python.builtins.objects.cext.hpy.GraalHPyNativeSymbols.GRAAL_HPY_MODULE_GET_LEGACY_METHODS;
5959
import static com.oracle.graal.python.builtins.objects.cext.hpy.GraalHPyNativeSymbols.GRAAL_HPY_WRITE_PTR;
6060

61-
import java.lang.ref.WeakReference;
6261
import java.nio.charset.StandardCharsets;
63-
import java.util.WeakHashMap;
64-
65-
import org.graalvm.collections.Pair;
6662

6763
import com.oracle.graal.python.PythonLanguage;
6864
import com.oracle.graal.python.builtins.PythonBuiltinClassType;
@@ -103,7 +99,6 @@
10399
import com.oracle.graal.python.builtins.objects.dict.PDict;
104100
import com.oracle.graal.python.builtins.objects.function.PArguments;
105101
import com.oracle.graal.python.builtins.objects.function.PBuiltinFunction;
106-
import com.oracle.graal.python.builtins.objects.function.Signature;
107102
import com.oracle.graal.python.builtins.objects.ints.PInt;
108103
import com.oracle.graal.python.builtins.objects.list.PList;
109104
import com.oracle.graal.python.builtins.objects.method.PBuiltinMethod;
@@ -113,7 +108,6 @@
113108
import com.oracle.graal.python.builtins.objects.type.TypeNodes.IsTypeNode;
114109
import com.oracle.graal.python.nodes.PGuards;
115110
import com.oracle.graal.python.nodes.PRaiseNode;
116-
import com.oracle.graal.python.nodes.PRootNode;
117111
import com.oracle.graal.python.nodes.SpecialAttributeNames;
118112
import com.oracle.graal.python.nodes.SpecialMethodNames;
119113
import com.oracle.graal.python.nodes.attributes.LookupInheritedAttributeNode;
@@ -123,13 +117,9 @@
123117
import com.oracle.graal.python.nodes.call.GenericInvokeNode;
124118
import com.oracle.graal.python.nodes.call.special.CallBinaryMethodNode;
125119
import com.oracle.graal.python.nodes.call.special.CallTernaryMethodNode;
126-
import com.oracle.graal.python.nodes.call.special.LookupAndCallBinaryNode;
127-
import com.oracle.graal.python.nodes.call.special.LookupAndCallTernaryNode;
128-
import com.oracle.graal.python.nodes.call.special.LookupAndCallUnaryNode;
129120
import com.oracle.graal.python.nodes.classes.IsSubtypeNode;
130121
import com.oracle.graal.python.nodes.expression.BinaryArithmetic;
131122
import com.oracle.graal.python.nodes.expression.InplaceArithmetic;
132-
import com.oracle.graal.python.nodes.expression.LookupAndCallInplaceNode;
133123
import com.oracle.graal.python.nodes.expression.TernaryArithmetic;
134124
import com.oracle.graal.python.nodes.expression.UnaryArithmetic;
135125
import com.oracle.graal.python.nodes.object.IsBuiltinClassProfile;
@@ -138,7 +128,6 @@
138128
import com.oracle.graal.python.nodes.util.CastToJavaIntLossyNode;
139129
import com.oracle.graal.python.nodes.util.CastToJavaLongExactNode;
140130
import com.oracle.graal.python.nodes.util.CastToJavaStringNode;
141-
import com.oracle.graal.python.runtime.ExecutionContext.CalleeContext;
142131
import com.oracle.graal.python.runtime.exception.PException;
143132
import com.oracle.graal.python.runtime.object.PythonObjectFactory;
144133
import com.oracle.graal.python.runtime.sequence.PSequence;
@@ -147,13 +136,11 @@
147136
import com.oracle.truffle.api.CompilerAsserts;
148137
import com.oracle.truffle.api.CompilerDirectives;
149138
import com.oracle.truffle.api.CompilerDirectives.CompilationFinal;
150-
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
151139
import com.oracle.truffle.api.RootCallTarget;
152140
import com.oracle.truffle.api.TruffleLogger;
153141
import com.oracle.truffle.api.dsl.Cached;
154142
import com.oracle.truffle.api.dsl.CachedLanguage;
155143
import com.oracle.truffle.api.dsl.Specialization;
156-
import com.oracle.truffle.api.frame.VirtualFrame;
157144
import com.oracle.truffle.api.interop.ArityException;
158145
import com.oracle.truffle.api.interop.InteropException;
159146
import com.oracle.truffle.api.interop.InteropLibrary;
@@ -386,110 +373,8 @@ Object execute(Object[] arguments,
386373
}
387374
}
388375

389-
static final class CallArithmeticRootNode extends PRootNode {
390-
private static final Signature SIGNATURE_UNARY = new Signature(1, false, -1, false, new String[]{"$self"}, null);
391-
private static final Signature SIGNATURE_BINARY = new Signature(2, false, -1, false, new String[]{"$self", "other"}, null);
392-
private static final Signature SIGNATURE_TERNARY = new Signature(3, false, -1, false, new String[]{"x", "y", "z"}, null);
393-
394-
@Child private LookupAndCallUnaryNode callUnaryNode;
395-
@Child private LookupAndCallBinaryNode callBinaryNode;
396-
@Child private LookupAndCallInplaceNode callInplaceNode;
397-
@Child private LookupAndCallTernaryNode callTernaryNode;
398-
@Child private CalleeContext calleeContext;
399-
400-
private final UnaryArithmetic unaryOperator;
401-
private final BinaryArithmetic binaryOperator;
402-
private final InplaceArithmetic inplaceOperator;
403-
private final TernaryArithmetic ternaryOperator;
404-
405-
@CompilationFinal private ConditionProfile customLocalsProfile;
406-
407-
CallArithmeticRootNode(PythonLanguage language, UnaryArithmetic unaryOperator, BinaryArithmetic binaryOperator, InplaceArithmetic inplaceOperator,
408-
TernaryArithmetic ternaryOperator) {
409-
super(language);
410-
this.unaryOperator = unaryOperator;
411-
this.binaryOperator = binaryOperator;
412-
this.inplaceOperator = inplaceOperator;
413-
this.ternaryOperator = ternaryOperator;
414-
}
415-
416-
@Override
417-
public Signature getSignature() {
418-
if (unaryOperator != null) {
419-
return SIGNATURE_UNARY;
420-
} else if (binaryOperator != null) {
421-
return SIGNATURE_BINARY;
422-
} else if (inplaceOperator != null || ternaryOperator != null) {
423-
return SIGNATURE_TERNARY;
424-
} else {
425-
throw CompilerDirectives.shouldNotReachHere();
426-
}
427-
}
428-
429-
@Override
430-
public boolean isPythonInternal() {
431-
return true;
432-
}
433-
434-
@Override
435-
public Object execute(VirtualFrame frame) {
436-
ensureCallNode();
437-
if (calleeContext == null) {
438-
CompilerDirectives.transferToInterpreterAndInvalidate();
439-
calleeContext = insert(CalleeContext.create());
440-
}
441-
if (customLocalsProfile == null) {
442-
CompilerDirectives.transferToInterpreterAndInvalidate();
443-
customLocalsProfile = ConditionProfile.create();
444-
}
445-
446-
CalleeContext.enter(frame, customLocalsProfile);
447-
try {
448-
if (unaryOperator != null) {
449-
return callUnaryNode.executeObject(frame, PArguments.getArgument(frame, 0));
450-
} else if (binaryOperator != null) {
451-
return callBinaryNode.executeObject(frame, PArguments.getArgument(frame, 0), PArguments.getArgument(frame, 1));
452-
} else if (inplaceOperator != null) {
453-
// most of the in-place operators are binary but there can also be ternary
454-
if (PArguments.getUserArgumentLength(frame) == 2) {
455-
return callInplaceNode.execute(frame, PArguments.getArgument(frame, 0), PArguments.getArgument(frame, 1));
456-
} else if (PArguments.getUserArgumentLength(frame) == 3) {
457-
return callInplaceNode.executeTernary(frame, PArguments.getArgument(frame, 0), PArguments.getArgument(frame, 1), PArguments.getArgument(frame, 2));
458-
}
459-
throw CompilerDirectives.shouldNotReachHere();
460-
} else if (ternaryOperator != null) {
461-
return callTernaryNode.execute(frame, PArguments.getArgument(frame, 0), PArguments.getArgument(frame, 1), PArguments.getArgument(frame, 2));
462-
} else {
463-
throw CompilerDirectives.shouldNotReachHere();
464-
}
465-
} finally {
466-
calleeContext.exit(frame, this);
467-
}
468-
}
469-
470-
private void ensureCallNode() {
471-
if (callUnaryNode == null && callBinaryNode == null && callInplaceNode == null && callTernaryNode == null) {
472-
CompilerDirectives.transferToInterpreterAndInvalidate();
473-
if (unaryOperator != null) {
474-
callUnaryNode = insert(unaryOperator.create());
475-
} else if (binaryOperator != null) {
476-
callBinaryNode = insert(binaryOperator.create());
477-
} else if (inplaceOperator != null) {
478-
callInplaceNode = insert(inplaceOperator.create());
479-
} else if (ternaryOperator != null) {
480-
callTernaryNode = insert(ternaryOperator.create());
481-
} else {
482-
throw CompilerDirectives.shouldNotReachHere();
483-
}
484-
}
485-
}
486-
}
487-
488376
@ExportLibrary(InteropLibrary.class)
489377
public static final class GraalHPyArithmetic extends GraalHPyContextFunction {
490-
// TODO(fa): move to PythonLanguage ?
491-
private static final WeakHashMap<Pair<PythonLanguage, Object>, WeakReference<RootCallTarget>> weakCallTargetMap = new WeakHashMap<>();
492-
493378
private final UnaryArithmetic unaryOperator;
494379
private final BinaryArithmetic binaryOperator;
495380
private final InplaceArithmetic inplaceOperator;
@@ -535,7 +420,7 @@ Object execute(Object[] arguments,
535420
@Cached TransformExceptionToNativeNode transformExceptionToNativeNode) throws ArityException {
536421

537422
// We need to do argument checking at this position because our helper root node that
538-
// just dispatches to the appropriate 'LookupAndCallXXXNode' won't do any arguemnt
423+
// just dispatches to the appropriate 'LookupAndCallXXXNode' won't do any argument
539424
// checking. So it would just crash if there are too few arguments or just ignore if
540425
// there are too many.
541426
checkArguments(arguments);
@@ -582,38 +467,18 @@ private void checkArguments(Object[] arguments) throws ArityException {
582467
private RootCallTarget getCallTarget(PythonLanguage language) {
583468
if (callTarget == null) {
584469
CompilerDirectives.transferToInterpreterAndInvalidate();
585-
callTarget = createCallTarget(language);
586-
}
587-
return callTarget;
588-
}
589-
590-
@TruffleBoundary
591-
private RootCallTarget createCallTarget(PythonLanguage language) {
592-
Pair<PythonLanguage, Object> key = Pair.create(language, getOp());
593-
RootCallTarget cachedCallTarget;
594-
synchronized (weakCallTargetMap) {
595-
WeakReference<RootCallTarget> ctRef = weakCallTargetMap.get(key);
596-
cachedCallTarget = ctRef != null ? ctRef.get() : null;
597-
if (cachedCallTarget == null) {
598-
cachedCallTarget = PythonUtils.getOrCreateCallTarget(new CallArithmeticRootNode(language, unaryOperator, binaryOperator, inplaceOperator, ternaryOperator));
599-
weakCallTargetMap.put(key, new WeakReference<>(cachedCallTarget));
470+
if (unaryOperator != null) {
471+
callTarget = language.getOrCreateUnaryArithmeticCallTarget(unaryOperator);
472+
} else if (binaryOperator != null) {
473+
callTarget = language.getOrCreateBinaryArithmeticCallTarget(binaryOperator);
474+
} else if (inplaceOperator != null) {
475+
callTarget = language.getOrCreateInplaceArithmeticCallTarget(inplaceOperator);
476+
} else if (ternaryOperator != null) {
477+
callTarget = language.getOrCreateTernaryArithmeticCallTarget(ternaryOperator);
600478
}
479+
throw CompilerDirectives.shouldNotReachHere();
601480
}
602-
assert cachedCallTarget != null;
603-
return cachedCallTarget;
604-
}
605-
606-
private Object getOp() {
607-
if (unaryOperator != null) {
608-
return unaryOperator;
609-
} else if (binaryOperator != null) {
610-
return binaryOperator;
611-
} else if (inplaceOperator != null) {
612-
return inplaceOperator;
613-
} else if (ternaryOperator != null) {
614-
return ternaryOperator;
615-
}
616-
throw CompilerDirectives.shouldNotReachHere();
481+
return callTarget;
617482
}
618483
}
619484

0 commit comments

Comments
 (0)