@@ -841,6 +841,7 @@ def write_to_module(
841
841
find_replace : ty .Optional [ty .List [ty .Tuple [str , str ]]] = None ,
842
842
inline_intra_pkg : bool = False ,
843
843
additional_imports : ty .Optional [ty .List [ImportStatement ]] = None ,
844
+ interface_module : bool = False ,
844
845
):
845
846
"""Writes the given imports, constants, classes, and functions to the file at the given path,
846
847
merging with existing code if it exists"""
@@ -875,9 +876,13 @@ def write_to_module(
875
876
existing_imports = parse_imports (existing_import_strs , relative_to = module_name )
876
877
converter_imports = []
877
878
879
+ src_module_name = self .untranslate_submodule (module_name )
880
+ if interface_module :
881
+ src_module_name = "." .join (src_module_name .split ("." )[:- 1 ])
882
+
878
883
for klass in used .classes :
879
884
if (
880
- klass .__module__ == module_name
885
+ klass .__module__ == src_module_name
881
886
and f"\n class { klass .__name__ } (" not in code_str
882
887
):
883
888
try :
@@ -912,7 +917,7 @@ def write_to_module(
912
917
913
918
for func in sorted (used .functions , key = attrgetter ("__name__" )):
914
919
if (
915
- func .__module__ == module_name
920
+ func .__module__ == src_module_name
916
921
and f"\n def { func .__name__ } (" not in code_str
917
922
):
918
923
if func .__name__ in self .functions :
0 commit comments