1515)
1616
1717
18- def union_sub_func (match : re .Match ) -> str :
18+ def union_sub_func (match : re .Match [ str ] ) -> str :
1919 return f'{ match .group ("variable" )} : typing.TypeAlias = { match .group ("value" )} '
2020
2121
@@ -36,13 +36,20 @@ def str_sub_func(match: re.Match) -> str:
3636 return f"{ match .group ('var' )} : str"
3737
3838
39+ CompilerConfigPattern = re .compile (r"compiler_config: dict.*" )
40+
41+
42+ def compiler_config_sub_func (match : re .Match ) -> str :
43+ return "compiler_config: dict"
44+
45+
3946EqPattern = re .compile (
4047 r"(?P<indent>[ \t]+)def __eq__\(self, arg0: (?P<other>[a-zA-Z1-9.]+)\) -> (?P<return>[a-zA-Z1-9.]+):"
4148 r"(?P<ellipsis_docstring>\s*((\.\.\.)|(\"\"\"(.|\n)*?\"\"\")))"
4249)
4350
4451
45- def eq_sub_func (match : re .Match ) -> str :
52+ def eq_sub_func (match : re .Match [ str ] ) -> str :
4653 """
4754 if one - add @overload and overloaded signature
4855
@@ -95,72 +102,37 @@ def get_package_dir(name: str) -> str:
95102 return os .path .realpath (os .path .dirname (get_module_path (name )))
96103
97104
98- def patch_stubgen ():
105+ def patch_stubgen () -> None :
106+ class_member_blacklist : set [Identifier ] = FilterClassMembers ._FilterClassMembers__class_member_blacklist # type: ignore
107+ attribute_blacklist : set [Identifier ] = FilterClassMembers ._FilterClassMembers__attribute_blacklist # type: ignore
108+
99109 # Is there a better way to add items to the blacklist?
100110 # Pybind11
101- FilterClassMembers ._FilterClassMembers__class_member_blacklist .add (
102- Identifier ("_pybind11_conduit_v1_" )
103- )
111+ class_member_blacklist .add (Identifier ("_pybind11_conduit_v1_" ))
104112 # Python
105- FilterClassMembers ._FilterClassMembers__class_member_blacklist .add (
106- Identifier ("__new__" )
107- )
108- FilterClassMembers ._FilterClassMembers__class_member_blacklist .add (
109- Identifier ("__subclasshook__" )
110- )
113+ class_member_blacklist .add (Identifier ("__new__" ))
114+ class_member_blacklist .add (Identifier ("__subclasshook__" ))
111115 # Pickle
112- FilterClassMembers ._FilterClassMembers__class_member_blacklist .add (
113- Identifier ("__getnewargs__" )
114- )
115- FilterClassMembers ._FilterClassMembers__class_member_blacklist .add (
116- Identifier ("__getstate__" )
117- )
118- FilterClassMembers ._FilterClassMembers__class_member_blacklist .add (
119- Identifier ("__setstate__" )
120- )
116+ class_member_blacklist .add (Identifier ("__getnewargs__" ))
117+ class_member_blacklist .add (Identifier ("__getstate__" ))
118+ class_member_blacklist .add (Identifier ("__setstate__" ))
121119 # ABC
122- FilterClassMembers ._FilterClassMembers__attribute_blacklist .add (
123- Identifier ("__abstractmethods__" )
124- )
125- FilterClassMembers ._FilterClassMembers__attribute_blacklist .add (
126- Identifier ("__orig_bases__" )
127- )
128- FilterClassMembers ._FilterClassMembers__attribute_blacklist .add (
129- Identifier ("__parameters__" )
130- )
131- FilterClassMembers ._FilterClassMembers__attribute_blacklist .add (
132- Identifier ("_abc_impl" )
133- )
120+ attribute_blacklist .add (Identifier ("__abstractmethods__" ))
121+ attribute_blacklist .add (Identifier ("__orig_bases__" ))
122+ attribute_blacklist .add (Identifier ("__parameters__" ))
123+ attribute_blacklist .add (Identifier ("_abc_impl" ))
134124 # Protocol
135- FilterClassMembers ._FilterClassMembers__attribute_blacklist .add (
136- Identifier ("__protocol_attrs__" )
137- )
138- FilterClassMembers ._FilterClassMembers__attribute_blacklist .add (
139- Identifier ("__non_callable_proto_members__" )
140- )
141- FilterClassMembers ._FilterClassMembers__attribute_blacklist .add (
142- Identifier ("_is_protocol" )
143- )
144- FilterClassMembers ._FilterClassMembers__attribute_blacklist .add (
145- Identifier ("_is_runtime_protocol" )
146- )
125+ attribute_blacklist .add (Identifier ("__protocol_attrs__" ))
126+ attribute_blacklist .add (Identifier ("__non_callable_proto_members__" ))
127+ attribute_blacklist .add (Identifier ("_is_protocol" ))
128+ attribute_blacklist .add (Identifier ("_is_runtime_protocol" ))
147129 # dataclass
148- FilterClassMembers ._FilterClassMembers__attribute_blacklist .add (
149- Identifier ("__dataclass_fields__" )
150- )
151- FilterClassMembers ._FilterClassMembers__attribute_blacklist .add (
152- Identifier ("__dataclass_params__" )
153- )
154- FilterClassMembers ._FilterClassMembers__attribute_blacklist .add (
155- Identifier ("__match_args__" )
156- )
130+ attribute_blacklist .add (Identifier ("__dataclass_fields__" ))
131+ attribute_blacklist .add (Identifier ("__dataclass_params__" ))
132+ attribute_blacklist .add (Identifier ("__match_args__" ))
157133 # Buffer protocol
158- FilterClassMembers ._FilterClassMembers__class_member_blacklist .add (
159- Identifier ("__buffer__" )
160- )
161- FilterClassMembers ._FilterClassMembers__class_member_blacklist .add (
162- Identifier ("__release_buffer__" )
163- )
134+ class_member_blacklist .add (Identifier ("__buffer__" ))
135+ class_member_blacklist .add (Identifier ("__release_buffer__" ))
164136
165137
166138def main () -> None :
@@ -236,6 +208,7 @@ def main() -> None:
236208 pyi = UnionPattern .sub (union_sub_func , pyi )
237209 pyi = ClassVarUnionPattern .sub (class_var_union_sub_func , pyi )
238210 pyi = VersionPattern .sub (str_sub_func , pyi )
211+ pyi = CompilerConfigPattern .sub (compiler_config_sub_func , pyi )
239212 pyi = GenericAliasPattern .sub (generic_alias_sub_func , pyi )
240213 pyi = pyi .replace (
241214 "__hash__: typing.ClassVar[None] = None" ,
@@ -244,13 +217,12 @@ def main() -> None:
244217 pyi = EqPattern .sub (eq_sub_func , pyi )
245218 pyi = pyi .replace ("**kwargs)" , "**kwargs: typing.Any)" )
246219 pyi_split = [l .rstrip ("\r " ) for l in pyi .split ("\n " )]
247- for hidden_import in []:
220+ for hidden_import in ["typing" , "types" ]:
248221 if hidden_import in pyi and f"import { hidden_import } " not in pyi_split :
249- pyi_split .insert (2 , f"import { hidden_import } " )
250- if "import typing" not in pyi_split :
251- pyi_split .insert (2 , "import typing" )
252- if "import types" not in pyi_split :
253- pyi_split .insert (2 , "import types" )
222+ pyi_split .insert (
223+ pyi_split .index ("from __future__ import annotations" ) + 1 ,
224+ f"import { hidden_import } " ,
225+ )
254226 pyi = "\n " .join (pyi_split )
255227 with open (stub_path , "w" , encoding = "utf-8" ) as f :
256228 f .write (pyi )
0 commit comments