@@ -198,6 +198,7 @@ def generate_class(model_class: ModelMetaclass, indent_size: int = 0, max_width:
198198 cls_name = model_class .__name__
199199 bases = f" : public { ', ' .join (b .__name__ for b in base_init .keys ())} " if base_init else ""
200200 default_constructor = ""
201+ default_pydantic_init = ""
201202 indent = " " * indent_size
202203 newline = "\n "
203204 newline_indent = f"{ newline } { indent } "
@@ -209,6 +210,7 @@ def generate_class(model_class: ModelMetaclass, indent_size: int = 0, max_width:
209210
210211 if default_init_args :
211212 default_constructor = f"{ indent } { cls_name } () :\n "
213+ default_pydantic_init = f"{ indent } .def(py::init<>())\n "
212214
213215 if base_init :
214216 default_constructor += \
@@ -244,8 +246,7 @@ def generate_class(model_class: ModelMetaclass, indent_size: int = 0, max_width:
244246 pydantic_bases = ", " + ", " .join (base .__name__ for base in base_init .keys ()) if base_init else ""
245247 pydantic_init = "\n " .join (args_wrapper .wrap (f"{ ', ' .join (types )} >(), { ', ' .join (kwargs )} " ))
246248 pydantic_def = f"""{ indent } py::class_<{ cls_name } { pydantic_bases } >(m, "{ cls_name } ")
247- { indent } .def(py::init<>())
248- { indent } .def(py::init<{ pydantic_init } )
249+ { default_pydantic_init } { indent } .def(py::init<{ pydantic_init } )
249250 { indent } { newline_indent .join (pydantic_attrs )} ;"""
250251
251252 return struct_def , pydantic_def , tuple (f"#include { i } " for i in sorted (all_includes ))
@@ -271,8 +272,9 @@ def generate_enum(enum_typ: EnumType, indent_size: int = 0, max_width: int = 110
271272def generate_module (module_name : str , output_dir : str , indent_size : int = 4 , max_width : int = 110 ):
272273 single_newline = "\n "
273274 double_newline = "\n \n "
274- dot = r'.'
275- slash = r'/'
275+ dot = r"."
276+ slash = r"/"
277+ indent = " " * indent_size
276278
277279 module = import_module (module_name )
278280 generated_root = Path (output_dir )
@@ -304,9 +306,17 @@ def generate_module(module_name: str, output_dir: str, indent_size: int = 4, max
304306 if self_include in includes :
305307 includes .remove (self_include )
306308
309+ imports = []
310+ for include in (i for i in includes if namespace in i ):
311+ import_parts = include .split (slash )
312+ import_parts .insert (- 2 , "__pybind__" )
313+ imprt = '.' .join (import_parts ).replace ('#include ' , '' ).replace ('.h' , '' )
314+ imports .append (f"{ indent } py::module_::import({ imprt } );" )
315+
307316 enum_contents = f"\n { double_newline .join (enum_defs )} { single_newline if struct_defs else '' } " if enum_defs else ""
308317 struct_contents = f"\n { double_newline .join (struct_defs )} " if struct_defs else ""
309318 include_contents = f"\n { single_newline .join (includes )} \n " if includes else ""
319+ import_contents = f"\n { single_newline .join (imports )} \n " if imports else ""
310320
311321 header_contents = f"""
312322#ifndef { guard }
@@ -331,7 +341,7 @@ def generate_module(module_name: str, output_dir: str, indent_size: int = 4, max
331341
332342
333343PYBIND11_MODULE({ module_base_name } , m)
334- {{
344+ {{{ import_contents }
335345{ double_newline .join (pydantic_defs )}
336346}}
337347"""
0 commit comments