@@ -867,8 +867,16 @@ def generate_globals_init(self, emitter: Emitter) -> None:
867867
868868 def generate_module_def (self , emitter : Emitter , module_name : str , module : ModuleIR ) -> None :
869869 """Emit the PyModuleDef struct for a module and the module init function."""
870- # Emit module methods
871870 module_prefix = emitter .names .private_name (module_name )
871+ self .emit_module_exec_func (emitter , module_name , module_prefix , module )
872+ self .emit_module_methods (emitter , module_name , module_prefix , module )
873+ self .emit_module_def_struct (emitter , module_name , module_prefix )
874+ self .emit_module_init_func (emitter , module_name , module_prefix )
875+
876+ def emit_module_methods (
877+ self , emitter : Emitter , module_name : str , module_prefix : str , module : ModuleIR
878+ ) -> None :
879+ """Emit module methods (the static PyMethodDef table)."""
872880 emitter .emit_line (f"static PyMethodDef { module_prefix } module_methods[] = {{" )
873881 for fn in module .functions :
874882 if fn .class_name is not None or fn .name == TOP_LEVEL_NAME :
@@ -888,7 +896,10 @@ def generate_module_def(self, emitter: Emitter, module_name: str, module: Module
888896 emitter .emit_line ("};" )
889897 emitter .emit_line ()
890898
891- # Emit module definition struct
899+ def emit_module_def_struct (
900+ self , emitter : Emitter , module_name : str , module_prefix : str
901+ ) -> None :
902+ """Emit the static module definition struct (PyModuleDef)."""
892903 emitter .emit_lines (
893904 f"static struct PyModuleDef { module_prefix } module = {{" ,
894905 "PyModuleDef_HEAD_INIT," ,
@@ -900,36 +911,22 @@ def generate_module_def(self, emitter: Emitter, module_name: str, module: Module
900911 "};" ,
901912 )
902913 emitter .emit_line ()
903- # Emit module init function. If we are compiling just one module, this
904- # will be the C API init function. If we are compiling 2+ modules, we
905- # generate a shared library for the modules and shims that call into
906- # the shared library, and in this case we use an internal module
907- # initialized function that will be called by the shim.
908- if not self .use_shared_lib :
909- declaration = f"PyMODINIT_FUNC PyInit_{ module_name } (void)"
910- else :
911- declaration = f"PyObject *CPyInit_{ exported_name (module_name )} (void)"
912- emitter .emit_lines (declaration , "{" )
913- emitter .emit_line ("PyObject* modname = NULL;" )
914- # Store the module reference in a static and return it when necessary.
915- # This is separate from the *global* reference to the module that will
916- # be populated when it is imported by a compiled module. We want that
917- # reference to only be populated when the module has been successfully
918- # imported, whereas this we want to have to stop a circular import.
919- module_static = self .module_internal_static_name (module_name , emitter )
920914
921- emitter .emit_lines (
922- f"if ({ module_static } ) {{" ,
923- f"Py_INCREF({ module_static } );" ,
924- f"return { module_static } ;" ,
925- "}" ,
926- )
915+ def emit_module_exec_func (
916+ self , emitter : Emitter , module_name : str , module_prefix : str , module : ModuleIR
917+ ) -> None :
918+ """Emit the module init function.
927919
928- emitter .emit_lines (
929- f"{ module_static } = PyModule_Create(&{ module_prefix } module);" ,
930- f"if (unlikely({ module_static } == NULL))" ,
931- " goto fail;" ,
932- )
920+ If we are compiling just one module, this will be the C API init
921+ function. If we are compiling 2+ modules, we generate a shared
922+ library for the modules and shims that call into the shared
923+ library, and in this case we use an internal module initialized
924+ function that will be called by the shim.
925+ """
926+ declaration = f"static int { module_prefix } _exec(PyObject *module)"
927+ module_static = self .module_internal_static_name (module_name , emitter )
928+ emitter .emit_lines (declaration , "{" )
929+ emitter .emit_line ("PyObject* modname = NULL;" )
933930 emitter .emit_line (
934931 f'modname = PyObject_GetAttrString((PyObject *){ module_static } , "__name__");'
935932 )
@@ -959,8 +956,9 @@ def generate_module_def(self, emitter: Emitter, module_name: str, module: Module
959956
960957 emitter .emit_lines ("Py_DECREF(modname);" )
961958
962- emitter .emit_line (f"return { module_static } ;" )
963- emitter .emit_lines ("fail:" , f"Py_CLEAR({ module_static } );" , "Py_CLEAR(modname);" )
959+ emitter .emit_line ("return 0;" )
960+ emitter .emit_lines ("fail:" )
961+ emitter .emit_lines (f"Py_CLEAR({ module_static } );" , "Py_CLEAR(modname);" )
964962 for name , typ in module .final_names :
965963 static_name = emitter .static_name (name , module_name )
966964 emitter .emit_dec_ref (static_name , typ , is_xdec = True )
@@ -970,9 +968,44 @@ def generate_module_def(self, emitter: Emitter, module_name: str, module: Module
970968 # so we have to decref them
971969 for t in type_structs :
972970 emitter .emit_line (f"Py_CLEAR({ t } );" )
973- emitter .emit_line ("return NULL ;" )
971+ emitter .emit_line ("return -1 ;" )
974972 emitter .emit_line ("}" )
975973
974+ def emit_module_init_func (
975+ self , emitter : Emitter , module_name : str , module_prefix : str
976+ ) -> None :
977+ if not self .use_shared_lib :
978+ declaration = f"PyMODINIT_FUNC PyInit_{ module_name } (void)"
979+ else :
980+ declaration = f"PyObject *CPyInit_{ exported_name (module_name )} (void)"
981+ emitter .emit_lines (declaration , "{" )
982+
983+ exec_func = f"{ module_prefix } _exec"
984+
985+ # Store the module reference in a static and return it when necessary.
986+ # This is separate from the *global* reference to the module that will
987+ # be populated when it is imported by a compiled module. We want that
988+ # reference to only be populated when the module has been successfully
989+ # imported, whereas this we want to have to stop a circular import.
990+ module_static = self .module_internal_static_name (module_name , emitter )
991+
992+ emitter .emit_lines (
993+ f"if ({ module_static } ) {{" ,
994+ f"Py_INCREF({ module_static } );" ,
995+ f"return { module_static } ;" ,
996+ "}" ,
997+ )
998+
999+ emitter .emit_lines (
1000+ f"{ module_static } = PyModule_Create(&{ module_prefix } module);" ,
1001+ f"if (unlikely({ module_static } == NULL))" ,
1002+ " goto fail;" ,
1003+ )
1004+ emitter .emit_lines (f"if ({ exec_func } ({ module_static } ) != 0)" , " goto fail;" )
1005+ emitter .emit_line (f"return { module_static } ;" )
1006+ emitter .emit_lines ("fail:" , "return NULL;" )
1007+ emitter .emit_lines ("}" )
1008+
9761009 def generate_top_level_call (self , module : ModuleIR , emitter : Emitter ) -> None :
9771010 """Generate call to function representing module top level."""
9781011 # Optimization: we tend to put the top level last, so reverse iterate
0 commit comments