Skip to content

Commit 63938db

Browse files
Improve build tools (#436)
1 parent ae3d1f8 commit 63938db

File tree

6 files changed

+62
-72
lines changed

6 files changed

+62
-72
lines changed

build_requires.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
# mypy: disable-error-code=no-redef
2+
13
from typing import Union, Mapping
24

35
from setuptools import build_meta

mypy.ini

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,13 @@ disallow_untyped_defs = True
33
check_untyped_defs = True
44
warn_return_any = True
55
python_version = 3.12
6+
explicit_package_bases = True
67
mypy_path = $MYPY_CONFIG_FILE_DIR/src,$MYPY_CONFIG_FILE_DIR/tests
7-
packages = amulet,test_amulet_core
8+
files =
9+
src,
10+
tests,
11+
tools,
12+
get_compiler,
13+
build_requires.py,
14+
requirements.py,
15+
setup.py

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@ docs = [
2929
]
3030
dev = [
3131
"setuptools>=42",
32+
"types-setuptools",
3233
"versioneer",
34+
"types-versioneer",
3335
"packaging",
3436
"wheel",
3537
"pybind11_stubgen>=2.5.4",

setup.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from pathlib import Path
55
import platform
66
from tempfile import TemporaryDirectory
7+
from typing import TypeAlias, TYPE_CHECKING
78

89
from setuptools import setup, Extension, Command
910
from setuptools.command.build_ext import build_ext
@@ -13,15 +14,20 @@
1314
import requirements
1415

1516

16-
def fix_path(path: str) -> str:
17+
def fix_path(path: str | os.PathLike[str]) -> str:
1718
return os.path.realpath(path).replace(os.sep, "/")
1819

1920

2021
cmdclass: dict[str, type[Command]] = versioneer.get_cmdclass()
2122

23+
if TYPE_CHECKING:
24+
BuildExt: TypeAlias = build_ext
25+
else:
26+
BuildExt = cmdclass.get("build_ext", build_ext)
2227

23-
class CMakeBuild(cmdclass.get("build_ext", build_ext)):
24-
def build_extension(self, ext):
28+
29+
class CMakeBuild(BuildExt):
30+
def build_extension(self, ext: Extension) -> None:
2531
import pybind11
2632
import amulet.pybind11_extensions
2733
import amulet.io
@@ -80,7 +86,7 @@ def build_extension(self, ext):
8086
raise RuntimeError("Error installing amulet-core")
8187

8288

83-
cmdclass["build_ext"] = CMakeBuild
89+
cmdclass["build_ext"] = CMakeBuild # type: ignore
8490

8591

8692
setup(

tools/cmake_generate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def fix_path(path: str) -> str:
1818
RootDir = fix_path(os.path.dirname(os.path.dirname(__file__)))
1919

2020

21-
def main():
21+
def main() -> None:
2222
platform_args = []
2323
if sys.platform == "win32":
2424
platform_args.extend(["-G", "Visual Studio 17 2022"])

tools/generate_pybind_stubs.py

Lines changed: 38 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
)
1616

1717

18-
def union_sub_func(match: re.Match) -> str:
18+
def union_sub_func(match: re.Match[str]) -> str:
1919
return f'{match.group("variable")}: typing.TypeAlias = {match.group("value")}'
2020

2121

@@ -36,13 +36,20 @@ def str_sub_func(match: re.Match) -> str:
3636
return f"{match.group('var')}: str"
3737

3838

39+
CompilerConfigPattern = re.compile(r"compiler_config: dict.*")
40+
41+
42+
def compiler_config_sub_func(match: re.Match) -> str:
43+
return "compiler_config: dict"
44+
45+
3946
EqPattern = re.compile(
4047
r"(?P<indent>[ \t]+)def __eq__\(self, arg0: (?P<other>[a-zA-Z1-9.]+)\) -> (?P<return>[a-zA-Z1-9.]+):"
4148
r"(?P<ellipsis_docstring>\s*((\.\.\.)|(\"\"\"(.|\n)*?\"\"\")))"
4249
)
4350

4451

45-
def eq_sub_func(match: re.Match) -> str:
52+
def eq_sub_func(match: re.Match[str]) -> str:
4653
"""
4754
if one - add @overload and overloaded signature
4855
@@ -95,72 +102,37 @@ def get_package_dir(name: str) -> str:
95102
return os.path.realpath(os.path.dirname(get_module_path(name)))
96103

97104

98-
def patch_stubgen():
105+
def patch_stubgen() -> None:
106+
class_member_blacklist: set[Identifier] = FilterClassMembers._FilterClassMembers__class_member_blacklist # type: ignore
107+
attribute_blacklist: set[Identifier] = FilterClassMembers._FilterClassMembers__attribute_blacklist # type: ignore
108+
99109
# Is there a better way to add items to the blacklist?
100110
# Pybind11
101-
FilterClassMembers._FilterClassMembers__class_member_blacklist.add(
102-
Identifier("_pybind11_conduit_v1_")
103-
)
111+
class_member_blacklist.add(Identifier("_pybind11_conduit_v1_"))
104112
# Python
105-
FilterClassMembers._FilterClassMembers__class_member_blacklist.add(
106-
Identifier("__new__")
107-
)
108-
FilterClassMembers._FilterClassMembers__class_member_blacklist.add(
109-
Identifier("__subclasshook__")
110-
)
113+
class_member_blacklist.add(Identifier("__new__"))
114+
class_member_blacklist.add(Identifier("__subclasshook__"))
111115
# Pickle
112-
FilterClassMembers._FilterClassMembers__class_member_blacklist.add(
113-
Identifier("__getnewargs__")
114-
)
115-
FilterClassMembers._FilterClassMembers__class_member_blacklist.add(
116-
Identifier("__getstate__")
117-
)
118-
FilterClassMembers._FilterClassMembers__class_member_blacklist.add(
119-
Identifier("__setstate__")
120-
)
116+
class_member_blacklist.add(Identifier("__getnewargs__"))
117+
class_member_blacklist.add(Identifier("__getstate__"))
118+
class_member_blacklist.add(Identifier("__setstate__"))
121119
# ABC
122-
FilterClassMembers._FilterClassMembers__attribute_blacklist.add(
123-
Identifier("__abstractmethods__")
124-
)
125-
FilterClassMembers._FilterClassMembers__attribute_blacklist.add(
126-
Identifier("__orig_bases__")
127-
)
128-
FilterClassMembers._FilterClassMembers__attribute_blacklist.add(
129-
Identifier("__parameters__")
130-
)
131-
FilterClassMembers._FilterClassMembers__attribute_blacklist.add(
132-
Identifier("_abc_impl")
133-
)
120+
attribute_blacklist.add(Identifier("__abstractmethods__"))
121+
attribute_blacklist.add(Identifier("__orig_bases__"))
122+
attribute_blacklist.add(Identifier("__parameters__"))
123+
attribute_blacklist.add(Identifier("_abc_impl"))
134124
# Protocol
135-
FilterClassMembers._FilterClassMembers__attribute_blacklist.add(
136-
Identifier("__protocol_attrs__")
137-
)
138-
FilterClassMembers._FilterClassMembers__attribute_blacklist.add(
139-
Identifier("__non_callable_proto_members__")
140-
)
141-
FilterClassMembers._FilterClassMembers__attribute_blacklist.add(
142-
Identifier("_is_protocol")
143-
)
144-
FilterClassMembers._FilterClassMembers__attribute_blacklist.add(
145-
Identifier("_is_runtime_protocol")
146-
)
125+
attribute_blacklist.add(Identifier("__protocol_attrs__"))
126+
attribute_blacklist.add(Identifier("__non_callable_proto_members__"))
127+
attribute_blacklist.add(Identifier("_is_protocol"))
128+
attribute_blacklist.add(Identifier("_is_runtime_protocol"))
147129
# dataclass
148-
FilterClassMembers._FilterClassMembers__attribute_blacklist.add(
149-
Identifier("__dataclass_fields__")
150-
)
151-
FilterClassMembers._FilterClassMembers__attribute_blacklist.add(
152-
Identifier("__dataclass_params__")
153-
)
154-
FilterClassMembers._FilterClassMembers__attribute_blacklist.add(
155-
Identifier("__match_args__")
156-
)
130+
attribute_blacklist.add(Identifier("__dataclass_fields__"))
131+
attribute_blacklist.add(Identifier("__dataclass_params__"))
132+
attribute_blacklist.add(Identifier("__match_args__"))
157133
# Buffer protocol
158-
FilterClassMembers._FilterClassMembers__class_member_blacklist.add(
159-
Identifier("__buffer__")
160-
)
161-
FilterClassMembers._FilterClassMembers__class_member_blacklist.add(
162-
Identifier("__release_buffer__")
163-
)
134+
class_member_blacklist.add(Identifier("__buffer__"))
135+
class_member_blacklist.add(Identifier("__release_buffer__"))
164136

165137

166138
def main() -> None:
@@ -236,6 +208,7 @@ def main() -> None:
236208
pyi = UnionPattern.sub(union_sub_func, pyi)
237209
pyi = ClassVarUnionPattern.sub(class_var_union_sub_func, pyi)
238210
pyi = VersionPattern.sub(str_sub_func, pyi)
211+
pyi = CompilerConfigPattern.sub(compiler_config_sub_func, pyi)
239212
pyi = GenericAliasPattern.sub(generic_alias_sub_func, pyi)
240213
pyi = pyi.replace(
241214
"__hash__: typing.ClassVar[None] = None",
@@ -244,13 +217,12 @@ def main() -> None:
244217
pyi = EqPattern.sub(eq_sub_func, pyi)
245218
pyi = pyi.replace("**kwargs)", "**kwargs: typing.Any)")
246219
pyi_split = [l.rstrip("\r") for l in pyi.split("\n")]
247-
for hidden_import in []:
220+
for hidden_import in ["typing", "types"]:
248221
if hidden_import in pyi and f"import {hidden_import}" not in pyi_split:
249-
pyi_split.insert(2, f"import {hidden_import}")
250-
if "import typing" not in pyi_split:
251-
pyi_split.insert(2, "import typing")
252-
if "import types" not in pyi_split:
253-
pyi_split.insert(2, "import types")
222+
pyi_split.insert(
223+
pyi_split.index("from __future__ import annotations") + 1,
224+
f"import {hidden_import}",
225+
)
254226
pyi = "\n".join(pyi_split)
255227
with open(stub_path, "w", encoding="utf-8") as f:
256228
f.write(pyi)

0 commit comments

Comments
 (0)