Skip to content

Commit 9fdf740

Browse files
committed
feat: support iscoroutinefunction introspection
1 parent a8ec893 commit 9fdf740

File tree

5 files changed

+48
-1
lines changed

5 files changed

+48
-1
lines changed

mypyc/codegen/emitmodule.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -954,6 +954,13 @@ def generate_module_def(self, emitter: Emitter, module_name: str, module: Module
954954
emitter.emit_lines(f"if (unlikely(!{type_struct}))", " goto fail;")
955955

956956
emitter.emit_lines("if (CPyGlobalsInit() < 0)", " goto fail;")
957+
# Patch async native functions so they're recognized as coroutines
958+
for fn in module.functions:
959+
if fn.decl.is_async and fn.class_name is None:
960+
emitter.emit_line(f'PyObject *func = PyObject_GetAttrString({module_static}, "{fn.decl.name}");')
961+
emitter.emit_line("if (!func) goto fail;")
962+
emitter.emit_line("if (!CPyPatchAsyncCode(func)) { Py_DECREF(func); goto fail; }")
963+
emitter.emit_line("Py_DECREF(func);")
957964

958965
self.generate_top_level_call(module, emitter)
959966

mypyc/ir/func_ir.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ def __init__(
140140
is_prop_setter: bool = False,
141141
is_prop_getter: bool = False,
142142
implicit: bool = False,
143+
is_async: bool = False,
143144
) -> None:
144145
self.name = name
145146
self.class_name = class_name
@@ -159,6 +160,7 @@ def __init__(
159160
# If True, not present in the mypy AST and must be synthesized during irbuild
160161
# Currently only supported for property getters/setters
161162
self.implicit = implicit
163+
self.is_async = is_async
162164

163165
# This is optional because this will be set to the line number when the corresponding
164166
# FuncIR is created
@@ -204,6 +206,7 @@ def serialize(self) -> JsonDict:
204206
"is_prop_setter": self.is_prop_setter,
205207
"is_prop_getter": self.is_prop_getter,
206208
"implicit": self.implicit,
209+
"is_async": self.is_async,
207210
}
208211

209212
# TODO: move this to FuncIR?
@@ -226,6 +229,7 @@ def deserialize(cls, data: JsonDict, ctx: DeserMaps) -> FuncDecl:
226229
data["is_prop_setter"],
227230
data["is_prop_getter"],
228231
data["implicit"],
232+
data["is_async"],
229233
)
230234

231235

mypyc/irbuild/prepare.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,14 @@ def prepare_func_def(
185185
else (FUNC_CLASSMETHOD if fdef.is_class else FUNC_NORMAL)
186186
)
187187
sig = mapper.fdef_to_sig(fdef, options.strict_dunders_typing)
188-
decl = FuncDecl(fdef.name, class_name, module_name, sig, kind)
188+
decl = FuncDecl(
189+
fdef.name,
190+
class_name,
191+
module_name,
192+
sig,
193+
kind=kind,
194+
is_async=fdef.is_coroutine,
195+
)
189196
mapper.func_to_decl[fdef] = decl
190197
return decl
191198

mypyc/lib-rt/pythonsupport.c

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,4 +103,20 @@ CPyLong_AsSsize_tAndOverflow_(PyObject *vv, int *overflow)
103103
}
104104

105105

106+
// CPy support for async functions: patch code object to include CO_COROUTINE
107+
PyObject* CPyPatchAsyncCode(PyObject* func) {
108+
PyObject* code = PyObject_GetAttrString(func, "__code__");
109+
if (!code) {
110+
return NULL;
111+
}
112+
PyCodeObject* codeobj = (PyCodeObject*)code;
113+
codeobj->co_flags |= CO_COROUTINE;
114+
int res = PyObject_SetAttrString(func, "__code__", code);
115+
Py_DECREF(code);
116+
if (res < 0) {
117+
return NULL;
118+
}
119+
Py_INCREF(func);
120+
return func;
121+
}
106122
#endif

mypyc/test-data/run-async.test

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -561,3 +561,16 @@ def test_bool() -> None:
561561

562562
[file asyncio/__init__.pyi]
563563
def run(x: object) -> object: ...
564+
565+
[case testIsCoroutineFunction]
566+
import asyncio
567+
import inspect
568+
569+
async def foo():
570+
return 1
571+
572+
def test_asyncio_iscoroutinefunction():
573+
assert asyncio.iscoroutinefunction(foo) is True, "foo should be recognized as coroutine function"
574+
575+
def test_inspect_iscoroutinefunction():
576+
assert inspect.iscoroutinefunction(foo) is True, "foo should be recognized as coroutine function"

0 commit comments

Comments
 (0)