77import pybind11_stubgen
88from pybind11_stubgen .structs import Identifier
99from pybind11_stubgen .parser .mixins .filter import FilterClassMembers
10- from pybind11_stubgen import main as pybind11_stubgen_main
10+
11+
12+ ForwardRefPattern = re .compile (r"ForwardRef\('(?P<variable>[a-zA-Z_][a-zA-Z0-9_]*)'\)" )
13+
14+ QuotePattern = re .compile (r"'(?P<variable>[a-zA-Z_][a-zA-Z0-9_]*)'" )
15+
16+
17+ def fix_value (value : str ) -> str :
18+ value = value .replace ("NoneType" , "None" )
19+ value = ForwardRefPattern .sub (lambda match : match .group ("variable" ), value )
20+ value = QuotePattern .sub (lambda match : match .group ("variable" ), value )
21+ return value
22+
1123
1224UnionPattern = re .compile (
13- r"^(?P<variable>[a-zA-Z_][a-zA-Z0-9_]*): types\.UnionType\s*#\s*value = (?P<value>.*)$" ,
25+ r"^(?P<variable>[a-zA-Z_][a-zA-Z0-9_]*): ( types\.UnionType|typing\._UnionGenericAlias) \s*#\s*value = (?P<value>.*)$" ,
1426 flags = re .MULTILINE ,
1527)
1628
1729
18- def union_sub_func (match : re .Match ) -> str :
19- return f'{ match .group ("variable" )} : typing.TypeAlias = { match .group ("value" )} '
30+ def union_sub_func (match : re .Match [ str ] ) -> str :
31+ return f'{ match .group ("variable" )} : typing.TypeAlias = { fix_value ( match .group ("value" ) )} '
2032
2133
2234ClassVarUnionPattern = re .compile (
@@ -26,7 +38,7 @@ def union_sub_func(match: re.Match) -> str:
2638
2739
2840def class_var_union_sub_func (match : re .Match ) -> str :
29- return f'{ match .group ("variable" )} : typing.ClassVar[typing. TypeAlias] = { match .group ("value" )} '
41+ return f'{ match .group ("variable" )} : typing.TypeAlias = { fix_value ( match .group ("value" ) )} '
3042
3143
3244VersionPattern = re .compile (r"(?P<var>[a-zA-Z0-9_].*): str = '.*?'" )
@@ -36,13 +48,20 @@ def str_sub_func(match: re.Match) -> str:
3648 return f"{ match .group ('var' )} : str"
3749
3850
51+ CompilerConfigPattern = re .compile (r"compiler_config: dict.*" )
52+
53+
54+ def compiler_config_sub_func (match : re .Match ) -> str :
55+ return "compiler_config: dict"
56+
57+
3958EqPattern = re .compile (
4059 r"(?P<indent>[ \t]+)def __eq__\(self, arg0: (?P<other>[a-zA-Z1-9.]+)\) -> (?P<return>[a-zA-Z1-9.]+):"
4160 r"(?P<ellipsis_docstring>\s*((\.\.\.)|(\"\"\"(.|\n)*?\"\"\")))"
4261)
4362
4463
45- def eq_sub_func (match : re .Match ) -> str :
64+ def eq_sub_func (match : re .Match [ str ] ) -> str :
4665 """
4766 if one - add @overload and overloaded signature
4867
@@ -80,7 +99,7 @@ def eq_sub_func(match: re.Match) -> str:
8099
81100
82101def generic_alias_sub_func (match : re .Match ) -> str :
83- return f" { match .group (' variable' )} : typing.TypeAlias = { match .group (' value' ) } "
102+ return f' { match .group (" variable" )} : typing.TypeAlias = { fix_value ( match .group (" value" )) } '
84103
85104
86105def get_module_path (name : str ) -> str :
@@ -95,72 +114,37 @@ def get_package_dir(name: str) -> str:
95114 return os .path .realpath (os .path .dirname (get_module_path (name )))
96115
97116
98- def patch_stubgen ():
117+ def patch_stubgen () -> None :
118+ class_member_blacklist : set [Identifier ] = FilterClassMembers ._FilterClassMembers__class_member_blacklist # type: ignore
119+ attribute_blacklist : set [Identifier ] = FilterClassMembers ._FilterClassMembers__attribute_blacklist # type: ignore
120+
99121 # Is there a better way to add items to the blacklist?
100122 # Pybind11
101- FilterClassMembers ._FilterClassMembers__class_member_blacklist .add (
102- Identifier ("_pybind11_conduit_v1_" )
103- )
123+ class_member_blacklist .add (Identifier ("_pybind11_conduit_v1_" ))
104124 # Python
105- FilterClassMembers ._FilterClassMembers__class_member_blacklist .add (
106- Identifier ("__new__" )
107- )
108- FilterClassMembers ._FilterClassMembers__class_member_blacklist .add (
109- Identifier ("__subclasshook__" )
110- )
125+ class_member_blacklist .add (Identifier ("__new__" ))
126+ class_member_blacklist .add (Identifier ("__subclasshook__" ))
111127 # Pickle
112- FilterClassMembers ._FilterClassMembers__class_member_blacklist .add (
113- Identifier ("__getnewargs__" )
114- )
115- FilterClassMembers ._FilterClassMembers__class_member_blacklist .add (
116- Identifier ("__getstate__" )
117- )
118- FilterClassMembers ._FilterClassMembers__class_member_blacklist .add (
119- Identifier ("__setstate__" )
120- )
128+ class_member_blacklist .add (Identifier ("__getnewargs__" ))
129+ class_member_blacklist .add (Identifier ("__getstate__" ))
130+ class_member_blacklist .add (Identifier ("__setstate__" ))
121131 # ABC
122- FilterClassMembers ._FilterClassMembers__attribute_blacklist .add (
123- Identifier ("__abstractmethods__" )
124- )
125- FilterClassMembers ._FilterClassMembers__attribute_blacklist .add (
126- Identifier ("__orig_bases__" )
127- )
128- FilterClassMembers ._FilterClassMembers__attribute_blacklist .add (
129- Identifier ("__parameters__" )
130- )
131- FilterClassMembers ._FilterClassMembers__attribute_blacklist .add (
132- Identifier ("_abc_impl" )
133- )
132+ attribute_blacklist .add (Identifier ("__abstractmethods__" ))
133+ attribute_blacklist .add (Identifier ("__orig_bases__" ))
134+ attribute_blacklist .add (Identifier ("__parameters__" ))
135+ attribute_blacklist .add (Identifier ("_abc_impl" ))
134136 # Protocol
135- FilterClassMembers ._FilterClassMembers__attribute_blacklist .add (
136- Identifier ("__protocol_attrs__" )
137- )
138- FilterClassMembers ._FilterClassMembers__attribute_blacklist .add (
139- Identifier ("__non_callable_proto_members__" )
140- )
141- FilterClassMembers ._FilterClassMembers__attribute_blacklist .add (
142- Identifier ("_is_protocol" )
143- )
144- FilterClassMembers ._FilterClassMembers__attribute_blacklist .add (
145- Identifier ("_is_runtime_protocol" )
146- )
137+ attribute_blacklist .add (Identifier ("__protocol_attrs__" ))
138+ attribute_blacklist .add (Identifier ("__non_callable_proto_members__" ))
139+ attribute_blacklist .add (Identifier ("_is_protocol" ))
140+ attribute_blacklist .add (Identifier ("_is_runtime_protocol" ))
147141 # dataclass
148- FilterClassMembers ._FilterClassMembers__attribute_blacklist .add (
149- Identifier ("__dataclass_fields__" )
150- )
151- FilterClassMembers ._FilterClassMembers__attribute_blacklist .add (
152- Identifier ("__dataclass_params__" )
153- )
154- FilterClassMembers ._FilterClassMembers__attribute_blacklist .add (
155- Identifier ("__match_args__" )
156- )
142+ attribute_blacklist .add (Identifier ("__dataclass_fields__" ))
143+ attribute_blacklist .add (Identifier ("__dataclass_params__" ))
144+ attribute_blacklist .add (Identifier ("__match_args__" ))
157145 # Buffer protocol
158- FilterClassMembers ._FilterClassMembers__class_member_blacklist .add (
159- Identifier ("__buffer__" )
160- )
161- FilterClassMembers ._FilterClassMembers__class_member_blacklist .add (
162- Identifier ("__release_buffer__" )
163- )
146+ class_member_blacklist .add (Identifier ("__buffer__" ))
147+ class_member_blacklist .add (Identifier ("__release_buffer__" ))
164148
165149
166150def main () -> None :
@@ -236,6 +220,7 @@ def main() -> None:
236220 pyi = UnionPattern .sub (union_sub_func , pyi )
237221 pyi = ClassVarUnionPattern .sub (class_var_union_sub_func , pyi )
238222 pyi = VersionPattern .sub (str_sub_func , pyi )
223+ pyi = CompilerConfigPattern .sub (compiler_config_sub_func , pyi )
239224 pyi = GenericAliasPattern .sub (generic_alias_sub_func , pyi )
240225 pyi = pyi .replace (
241226 "__hash__: typing.ClassVar[None] = None" ,
@@ -244,13 +229,12 @@ def main() -> None:
244229 pyi = EqPattern .sub (eq_sub_func , pyi )
245230 pyi = pyi .replace ("**kwargs)" , "**kwargs: typing.Any)" )
246231 pyi_split = [l .rstrip ("\r " ) for l in pyi .split ("\n " )]
247- for hidden_import in []:
232+ for hidden_import in ["typing" , "types" ]:
248233 if hidden_import in pyi and f"import { hidden_import } " not in pyi_split :
249- pyi_split .insert (2 , f"import { hidden_import } " )
250- if "import typing" not in pyi_split :
251- pyi_split .insert (2 , "import typing" )
252- if "import types" not in pyi_split :
253- pyi_split .insert (2 , "import types" )
234+ pyi_split .insert (
235+ pyi_split .index ("from __future__ import annotations" ) + 1 ,
236+ f"import { hidden_import } " ,
237+ )
254238 pyi = "\n " .join (pyi_split )
255239 with open (stub_path , "w" , encoding = "utf-8" ) as f :
256240 f .write (pyi )
0 commit comments