Skip to content

Commit 57a4da9

Browse files
author
Adam Hrbac
committed
Add firstiter async hook
This can be used by various async frameworks to ensure asyncgens shut down - e.g. asyncio.shutdown_asyncgens(). This is even more important in GraalPy, since we do not run the finalizers.
1 parent 7250e4c commit 57a4da9

File tree

4 files changed

+58
-15
lines changed

4 files changed

+58
-15
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/SysModuleBuiltins.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1136,6 +1136,12 @@ Object getProfile() {
11361136
abstract static class SetAsyncgenHooks extends PythonBuiltinNode {
11371137
@Specialization
11381138
Object setAsyncgenHooks(Object firstIter, Object finalizer) {
1139+
if (firstIter != PNone.NO_VALUE && firstIter != PNone.NONE) {
1140+
getContext().getThreadState(getLanguage()).setAsyncgenFirstIter(firstIter);
1141+
} else if (firstIter == PNone.NONE) {
1142+
getContext().getThreadState(getLanguage()).setAsyncgenFirstIter(null);
1143+
}
1144+
// Ignore finalizer, since we don't have a useful place to call it
11391145
return PNone.NONE;
11401146
}
11411147
}
@@ -1146,7 +1152,9 @@ abstract static class GetAsyncgenHooks extends PythonBuiltinNode {
11461152
@Specialization
11471153
Object setAsyncgenHooks() {
11481154
// TODO: use asyncgen_hooks object
1149-
return factory().createTuple(new Object[]{PNone.NONE, PNone.NONE});
1155+
PythonContext.PythonThreadState threadState = getContext().getThreadState(getLanguage());
1156+
Object firstiter = threadState.getAsyncgenFirstIter();
1157+
return factory().createTuple(new Object[]{firstiter == null ? PNone.NONE : firstiter, PNone.NONE});
11501158
}
11511159
}
11521160

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/asyncio/AsyncGeneratorBuiltins.java

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -40,17 +40,24 @@
4040
*/
4141
package com.oracle.graal.python.builtins.objects.asyncio;
4242

43+
import static com.oracle.graal.python.nodes.SpecialMethodNames.J___AITER__;
44+
import static com.oracle.graal.python.nodes.SpecialMethodNames.J___ANEXT__;
45+
46+
import java.util.List;
47+
4348
import com.oracle.graal.python.builtins.Builtin;
4449
import com.oracle.graal.python.builtins.CoreFunctions;
4550
import com.oracle.graal.python.builtins.PythonBuiltinClassType;
4651
import com.oracle.graal.python.builtins.PythonBuiltins;
4752
import com.oracle.graal.python.builtins.objects.PNone;
4853
import com.oracle.graal.python.builtins.objects.generator.GeneratorBuiltins;
54+
import com.oracle.graal.python.nodes.call.special.CallUnaryMethodNode;
4955
import com.oracle.graal.python.nodes.function.PythonBuiltinBaseNode;
5056
import com.oracle.graal.python.nodes.function.PythonBuiltinNode;
5157
import com.oracle.graal.python.nodes.function.builtins.PythonBinaryBuiltinNode;
5258
import com.oracle.graal.python.nodes.function.builtins.PythonUnaryBuiltinNode;
5359
import com.oracle.graal.python.runtime.PAsyncGen;
60+
import com.oracle.graal.python.runtime.PythonContext;
5461
import com.oracle.truffle.api.dsl.Bind;
5562
import com.oracle.truffle.api.dsl.Cached;
5663
import com.oracle.truffle.api.dsl.GenerateNodeFactory;
@@ -60,13 +67,20 @@
6067
import com.oracle.truffle.api.nodes.Node;
6168
import com.oracle.truffle.api.profiles.InlinedConditionProfile;
6269

63-
import java.util.List;
64-
65-
import static com.oracle.graal.python.nodes.SpecialMethodNames.J___AITER__;
66-
import static com.oracle.graal.python.nodes.SpecialMethodNames.J___ANEXT__;
67-
6870
@CoreFunctions(extendClasses = PythonBuiltinClassType.PAsyncGenerator)
6971
public class AsyncGeneratorBuiltins extends PythonBuiltins {
72+
private static void callHooks(VirtualFrame frame, PAsyncGen self, PythonContext.PythonThreadState state, CallUnaryMethodNode invokeFirstIter) {
73+
Object firstIter = state.getAsyncgenFirstIter();
74+
if (firstIter == null) {
75+
return;
76+
}
77+
if (self.isHookCalled()) {
78+
return;
79+
}
80+
self.setHookCalled(true);
81+
invokeFirstIter.executeObject(frame, firstIter, self);
82+
}
83+
7084
@Override
7185
protected List<? extends NodeFactory<? extends PythonBuiltinBaseNode>> getNodeFactories() {
7286
return AsyncGeneratorBuiltinsFactory.getFactories();
@@ -116,7 +130,9 @@ public boolean isRunning(PAsyncGen self) {
116130
@GenerateNodeFactory
117131
public abstract static class ASend extends PythonBinaryBuiltinNode {
118132
@Specialization
119-
public Object aSend(PAsyncGen self, Object sent) {
133+
public Object aSend(VirtualFrame frame, PAsyncGen self, Object sent,
134+
@Cached CallUnaryMethodNode callFirstIter) {
135+
callHooks(frame, self, getContext().getThreadState(getLanguage()), callFirstIter);
120136
return factory().createAsyncGeneratorASend(self, sent);
121137
}
122138
}
@@ -127,7 +143,9 @@ public abstract static class AThrow extends PythonBuiltinNode {
127143
public abstract Object execute(VirtualFrame frame, PAsyncGen self, Object arg1, Object arg2, Object arg3);
128144

129145
@Specialization
130-
public Object athrow(PAsyncGen self, Object arg1, Object arg2, Object arg3) {
146+
public Object athrow(VirtualFrame frame, PAsyncGen self, Object arg1, Object arg2, Object arg3,
147+
@Cached CallUnaryMethodNode callFirstIter) {
148+
callHooks(frame, self, getContext().getThreadState(getLanguage()), callFirstIter);
131149
return factory().createAsyncGeneratorAThrow(self, arg1, arg2, arg3);
132150
}
133151
}
@@ -145,7 +163,9 @@ public Object aIter(PAsyncGen self) {
145163
@GenerateNodeFactory
146164
public abstract static class ANext extends PythonUnaryBuiltinNode {
147165
@Specialization
148-
public Object aNext(PAsyncGen self) {
166+
public Object aNext(VirtualFrame frame, PAsyncGen self,
167+
@Cached CallUnaryMethodNode callFirstIter) {
168+
callHooks(frame, self, getContext().getThreadState(getLanguage()), callFirstIter);
149169
return factory().createAsyncGeneratorASend(self, PNone.NONE);
150170
}
151171
}
@@ -154,7 +174,9 @@ public Object aNext(PAsyncGen self) {
154174
@GenerateNodeFactory
155175
public abstract static class AClose extends PythonUnaryBuiltinNode {
156176
@Specialization
157-
public Object aClose(PAsyncGen self) {
177+
public Object aClose(VirtualFrame frame, PAsyncGen self,
178+
@Cached CallUnaryMethodNode callFirstIter) {
179+
callHooks(frame, self, getContext().getThreadState(getLanguage()), callFirstIter);
158180
return factory().createAsyncGeneratorAThrow(self, null, PNone.NO_VALUE, PNone.NO_VALUE);
159181
}
160182
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/runtime/PAsyncGen.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949

5050
public final class PAsyncGen extends PGenerator {
5151
private boolean closed = false;
52-
private boolean hooksCalled = false;
52+
private boolean hookCalled = false;
5353
private boolean runningAsync = false;
5454

5555
public static PAsyncGen create(PythonLanguage lang, TruffleString name, TruffleString qualname, PBytecodeRootNode rootNode, RootCallTarget[] callTargets, Object[] arguments) {
@@ -69,12 +69,12 @@ public void markClosed() {
6969
this.closed = true;
7070
}
7171

72-
public boolean isHooksCalled() {
73-
return hooksCalled;
72+
public boolean isHookCalled() {
73+
return hookCalled;
7474
}
7575

76-
public void setHooksCalled(boolean hooksCalled) {
77-
this.hooksCalled = hooksCalled;
76+
public void setHookCalled(boolean hookCalled) {
77+
this.hookCalled = hookCalled;
7878
}
7979

8080
public boolean isRunningAsync() {

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/runtime/PythonContext.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,11 @@ public static final class PythonThreadState {
300300
*/
301301
Object runningEventLoop;
302302

303+
/*
304+
* A callable that should be called for the first iteration of an async generators.
305+
*/
306+
Object asyncgenFirstIter;
307+
303308
/*
304309
* Counter for C-level recursion depth used for Py_(Enter/Leave)RecursiveCall.
305310
*/
@@ -477,6 +482,14 @@ public void profilingStart() {
477482
public void profilingStop() {
478483
this.profiling = false;
479484
}
485+
486+
public Object getAsyncgenFirstIter() {
487+
return asyncgenFirstIter;
488+
}
489+
490+
public void setAsyncgenFirstIter(Object asyncgenFirstIter) {
491+
this.asyncgenFirstIter = asyncgenFirstIter;
492+
}
480493
}
481494

482495
private static final class AtExitHook {

0 commit comments

Comments
 (0)