diff --git a/mypyc/codegen/emitmodule.py b/mypyc/codegen/emitmodule.py index b8a19ac1d669..a3970b9c181e 100644 --- a/mypyc/codegen/emitmodule.py +++ b/mypyc/codegen/emitmodule.py @@ -867,8 +867,16 @@ def generate_globals_init(self, emitter: Emitter) -> None: def generate_module_def(self, emitter: Emitter, module_name: str, module: ModuleIR) -> None: """Emit the PyModuleDef struct for a module and the module init function.""" - # Emit module methods module_prefix = emitter.names.private_name(module_name) + self.emit_module_exec_func(emitter, module_name, module_prefix, module) + self.emit_module_methods(emitter, module_name, module_prefix, module) + self.emit_module_def_struct(emitter, module_name, module_prefix) + self.emit_module_init_func(emitter, module_name, module_prefix) + + def emit_module_methods( + self, emitter: Emitter, module_name: str, module_prefix: str, module: ModuleIR + ) -> None: + """Emit module methods (the static PyMethodDef table).""" emitter.emit_line(f"static PyMethodDef {module_prefix}module_methods[] = {{") for fn in module.functions: if fn.class_name is not None or fn.name == TOP_LEVEL_NAME: @@ -888,7 +896,10 @@ def generate_module_def(self, emitter: Emitter, module_name: str, module: Module emitter.emit_line("};") emitter.emit_line() - # Emit module definition struct + def emit_module_def_struct( + self, emitter: Emitter, module_name: str, module_prefix: str + ) -> None: + """Emit the static module definition struct (PyModuleDef).""" emitter.emit_lines( f"static struct PyModuleDef {module_prefix}module = {{", "PyModuleDef_HEAD_INIT,", @@ -900,36 +911,22 @@ def generate_module_def(self, emitter: Emitter, module_name: str, module: Module "};", ) emitter.emit_line() - # Emit module init function. If we are compiling just one module, this - # will be the C API init function. If we are compiling 2+ modules, we - # generate a shared library for the modules and shims that call into - # the shared library, and in this case we use an internal module - # initialized function that will be called by the shim. - if not self.use_shared_lib: - declaration = f"PyMODINIT_FUNC PyInit_{module_name}(void)" - else: - declaration = f"PyObject *CPyInit_{exported_name(module_name)}(void)" - emitter.emit_lines(declaration, "{") - emitter.emit_line("PyObject* modname = NULL;") - # Store the module reference in a static and return it when necessary. - # This is separate from the *global* reference to the module that will - # be populated when it is imported by a compiled module. We want that - # reference to only be populated when the module has been successfully - # imported, whereas this we want to have to stop a circular import. - module_static = self.module_internal_static_name(module_name, emitter) - emitter.emit_lines( - f"if ({module_static}) {{", - f"Py_INCREF({module_static});", - f"return {module_static};", - "}", - ) + def emit_module_exec_func( + self, emitter: Emitter, module_name: str, module_prefix: str, module: ModuleIR + ) -> None: + """Emit the module init function. - emitter.emit_lines( - f"{module_static} = PyModule_Create(&{module_prefix}module);", - f"if (unlikely({module_static} == NULL))", - " goto fail;", - ) + If we are compiling just one module, this will be the C API init + function. If we are compiling 2+ modules, we generate a shared + library for the modules and shims that call into the shared + library, and in this case we use an internal module initialized + function that will be called by the shim. + """ + declaration = f"static int {module_prefix}_exec(PyObject *module)" + module_static = self.module_internal_static_name(module_name, emitter) + emitter.emit_lines(declaration, "{") + emitter.emit_line("PyObject* modname = NULL;") emitter.emit_line( f'modname = PyObject_GetAttrString((PyObject *){module_static}, "__name__");' ) @@ -959,8 +956,9 @@ def generate_module_def(self, emitter: Emitter, module_name: str, module: Module emitter.emit_lines("Py_DECREF(modname);") - emitter.emit_line(f"return {module_static};") - emitter.emit_lines("fail:", f"Py_CLEAR({module_static});", "Py_CLEAR(modname);") + emitter.emit_line("return 0;") + emitter.emit_lines("fail:") + emitter.emit_lines(f"Py_CLEAR({module_static});", "Py_CLEAR(modname);") for name, typ in module.final_names: static_name = emitter.static_name(name, module_name) emitter.emit_dec_ref(static_name, typ, is_xdec=True) @@ -970,9 +968,44 @@ def generate_module_def(self, emitter: Emitter, module_name: str, module: Module # so we have to decref them for t in type_structs: emitter.emit_line(f"Py_CLEAR({t});") - emitter.emit_line("return NULL;") + emitter.emit_line("return -1;") emitter.emit_line("}") + def emit_module_init_func( + self, emitter: Emitter, module_name: str, module_prefix: str + ) -> None: + if not self.use_shared_lib: + declaration = f"PyMODINIT_FUNC PyInit_{module_name}(void)" + else: + declaration = f"PyObject *CPyInit_{exported_name(module_name)}(void)" + emitter.emit_lines(declaration, "{") + + exec_func = f"{module_prefix}_exec" + + # Store the module reference in a static and return it when necessary. + # This is separate from the *global* reference to the module that will + # be populated when it is imported by a compiled module. We want that + # reference to only be populated when the module has been successfully + # imported, whereas this we want to have to stop a circular import. + module_static = self.module_internal_static_name(module_name, emitter) + + emitter.emit_lines( + f"if ({module_static}) {{", + f"Py_INCREF({module_static});", + f"return {module_static};", + "}", + ) + + emitter.emit_lines( + f"{module_static} = PyModule_Create(&{module_prefix}module);", + f"if (unlikely({module_static} == NULL))", + " goto fail;", + ) + emitter.emit_lines(f"if ({exec_func}({module_static}) != 0)", " goto fail;") + emitter.emit_line(f"return {module_static};") + emitter.emit_lines("fail:", "return NULL;") + emitter.emit_lines("}") + def generate_top_level_call(self, module: ModuleIR, emitter: Emitter) -> None: """Generate call to function representing module top level.""" # Optimization: we tend to put the top level last, so reverse iterate