Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 66 additions & 33 deletions mypyc/codegen/emitmodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -867,8 +867,16 @@ def generate_globals_init(self, emitter: Emitter) -> None:

def generate_module_def(self, emitter: Emitter, module_name: str, module: ModuleIR) -> None:
"""Emit the PyModuleDef struct for a module and the module init function."""
# Emit module methods
module_prefix = emitter.names.private_name(module_name)
self.emit_module_exec_func(emitter, module_name, module_prefix, module)
self.emit_module_methods(emitter, module_name, module_prefix, module)
self.emit_module_def_struct(emitter, module_name, module_prefix)
self.emit_module_init_func(emitter, module_name, module_prefix)

def emit_module_methods(
self, emitter: Emitter, module_name: str, module_prefix: str, module: ModuleIR
) -> None:
"""Emit module methods (the static PyMethodDef table)."""
emitter.emit_line(f"static PyMethodDef {module_prefix}module_methods[] = {{")
for fn in module.functions:
if fn.class_name is not None or fn.name == TOP_LEVEL_NAME:
Expand All @@ -888,7 +896,10 @@ def generate_module_def(self, emitter: Emitter, module_name: str, module: Module
emitter.emit_line("};")
emitter.emit_line()

# Emit module definition struct
def emit_module_def_struct(
self, emitter: Emitter, module_name: str, module_prefix: str
) -> None:
"""Emit the static module definition struct (PyModuleDef)."""
emitter.emit_lines(
f"static struct PyModuleDef {module_prefix}module = {{",
"PyModuleDef_HEAD_INIT,",
Expand All @@ -900,36 +911,22 @@ def generate_module_def(self, emitter: Emitter, module_name: str, module: Module
"};",
)
emitter.emit_line()
# Emit module init function. If we are compiling just one module, this
# will be the C API init function. If we are compiling 2+ modules, we
# generate a shared library for the modules and shims that call into
# the shared library, and in this case we use an internal module
# initialized function that will be called by the shim.
if not self.use_shared_lib:
declaration = f"PyMODINIT_FUNC PyInit_{module_name}(void)"
else:
declaration = f"PyObject *CPyInit_{exported_name(module_name)}(void)"
emitter.emit_lines(declaration, "{")
emitter.emit_line("PyObject* modname = NULL;")
# Store the module reference in a static and return it when necessary.
# This is separate from the *global* reference to the module that will
# be populated when it is imported by a compiled module. We want that
# reference to only be populated when the module has been successfully
# imported, whereas this we want to have to stop a circular import.
module_static = self.module_internal_static_name(module_name, emitter)

emitter.emit_lines(
f"if ({module_static}) {{",
f"Py_INCREF({module_static});",
f"return {module_static};",
"}",
)
def emit_module_exec_func(
self, emitter: Emitter, module_name: str, module_prefix: str, module: ModuleIR
) -> None:
"""Emit the module init function.

emitter.emit_lines(
f"{module_static} = PyModule_Create(&{module_prefix}module);",
f"if (unlikely({module_static} == NULL))",
" goto fail;",
)
If we are compiling just one module, this will be the C API init
function. If we are compiling 2+ modules, we generate a shared
library for the modules and shims that call into the shared
library, and in this case we use an internal module initialized
function that will be called by the shim.
"""
declaration = f"static int {module_prefix}_exec(PyObject *module)"
module_static = self.module_internal_static_name(module_name, emitter)
emitter.emit_lines(declaration, "{")
emitter.emit_line("PyObject* modname = NULL;")
emitter.emit_line(
f'modname = PyObject_GetAttrString((PyObject *){module_static}, "__name__");'
)
Expand Down Expand Up @@ -959,8 +956,9 @@ def generate_module_def(self, emitter: Emitter, module_name: str, module: Module

emitter.emit_lines("Py_DECREF(modname);")

emitter.emit_line(f"return {module_static};")
emitter.emit_lines("fail:", f"Py_CLEAR({module_static});", "Py_CLEAR(modname);")
emitter.emit_line("return 0;")
emitter.emit_lines("fail:")
emitter.emit_lines(f"Py_CLEAR({module_static});", "Py_CLEAR(modname);")
for name, typ in module.final_names:
static_name = emitter.static_name(name, module_name)
emitter.emit_dec_ref(static_name, typ, is_xdec=True)
Expand All @@ -970,9 +968,44 @@ def generate_module_def(self, emitter: Emitter, module_name: str, module: Module
# so we have to decref them
for t in type_structs:
emitter.emit_line(f"Py_CLEAR({t});")
emitter.emit_line("return NULL;")
emitter.emit_line("return -1;")
emitter.emit_line("}")

def emit_module_init_func(
self, emitter: Emitter, module_name: str, module_prefix: str
) -> None:
if not self.use_shared_lib:
declaration = f"PyMODINIT_FUNC PyInit_{module_name}(void)"
else:
declaration = f"PyObject *CPyInit_{exported_name(module_name)}(void)"
emitter.emit_lines(declaration, "{")

exec_func = f"{module_prefix}_exec"

# Store the module reference in a static and return it when necessary.
# This is separate from the *global* reference to the module that will
# be populated when it is imported by a compiled module. We want that
# reference to only be populated when the module has been successfully
# imported, whereas this we want to have to stop a circular import.
module_static = self.module_internal_static_name(module_name, emitter)

emitter.emit_lines(
f"if ({module_static}) {{",
f"Py_INCREF({module_static});",
f"return {module_static};",
"}",
)

emitter.emit_lines(
f"{module_static} = PyModule_Create(&{module_prefix}module);",
f"if (unlikely({module_static} == NULL))",
" goto fail;",
)
emitter.emit_lines(f"if ({exec_func}({module_static}) != 0)", " goto fail;")
emitter.emit_line(f"return {module_static};")
emitter.emit_lines("fail:", "return NULL;")
emitter.emit_lines("}")

def generate_top_level_call(self, module: ModuleIR, emitter: Emitter) -> None:
"""Generate call to function representing module top level."""
# Optimization: we tend to put the top level last, so reverse iterate
Expand Down