Skip to content

Commit b395d7b

Browse files
authored
Merge pull request #549 from numpy/testgen-improvements
1 parent 8254e64 commit b395d7b

File tree

2 files changed

+59
-32
lines changed

2 files changed

+59
-32
lines changed

src/numpy-stubs/@test/generated/literal_bool_ops.pyi

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tool/testgen.py

Lines changed: 58 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
TypeVar,
2424
cast,
2525
final,
26+
get_args,
2627
overload,
2728
)
2829
from typing_extensions import override
@@ -62,12 +63,17 @@
6263
"xor",
6364
]
6465
_UnOpName: TypeAlias = Literal["abs", "neg", "pos", "invert"]
65-
_OpName: TypeAlias = Literal[_BinOpName, _UnOpName]
66+
_OpName: TypeAlias = Literal[_UnOpName, _BinOpName]
6667

6768
###
6869

69-
ROOT_DIR: Final = Path(__file__).parent.parent
70-
TARGET_DIR: Final = ROOT_DIR / "src" / "numpy-stubs" / "@test" / "generated"
70+
DIR_ROOT: Final = Path(__file__).parent.parent
71+
DIR_SRC: Final = DIR_ROOT / "src"
72+
DIRS_TARGET: Final = {
73+
dir_package.stem: dir_package / "@test" / "generated"
74+
for dir_package in DIR_SRC.iterdir()
75+
if dir_package.is_dir()
76+
}
7177

7278
TAB: Final = " " * 4
7379
BR: Final = "\n"
@@ -366,6 +372,7 @@ def _strip_preamble(source: str) -> tuple[str | None, str]:
366372

367373

368374
class TestGen(abc.ABC):
375+
package: ClassVar[str]
369376
stdlib_imports: ClassVar[tuple[str, ...]] = ("from typing import assert_type",)
370377
numpy_imports: ClassVar[tuple[str, ...]] = (f"import numpy as {NP}",)
371378

@@ -383,7 +390,7 @@ def __init__(self) -> None:
383390
@property
384391
def path(self) -> Path:
385392
assert self.testname
386-
return TARGET_DIR / f"{self.testname}.pyi"
393+
return DIRS_TARGET[self.package] / f"{self.testname}.pyi"
387394

388395
def get_names(self) -> Iterable[tuple[str, str]]:
389396
return ()
@@ -411,7 +418,7 @@ def _generate_section(self, /, *lines: str) -> Generator[str]:
411418

412419
def _generate_preamble(self) -> Generator[str]:
413420
timestamp = f"{np.datetime64('now')}Z"
414-
here = Path(__file__).relative_to(ROOT_DIR)
421+
here = Path(__file__).relative_to(DIR_ROOT)
415422

416423
yield f"# {PREAMBLE_PREFIX} {timestamp} with {here}"
417424

@@ -492,7 +499,7 @@ def regenerate(self, /, *, always: bool = False) -> Iterator[str]:
492499
head_new, body_new = _strip_preamble(src_new)
493500
assert head_new, src_new
494501

495-
path_new = str(self.path.relative_to(ROOT_DIR))
502+
path_new = str(self.path.relative_to(DIR_ROOT))
496503
date_new = head_new.split(" ", 1)[0]
497504

498505
if src_old := self._read():
@@ -516,13 +523,14 @@ def regenerate(self, /, *, always: bool = False) -> Iterator[str]:
516523
tofile=path_new,
517524
fromfiledate=date_old,
518525
tofiledate=date_new if write else date_old,
519-
n=0,
526+
n=1,
520527
lineterm=BR,
521528
)
522529

523530

524531
@final
525532
class EMath(TestGen):
533+
package = "numpy-stubs"
526534
testname = "emath"
527535

528536
VALUES: Final[dict[str, list[Any]]] = {
@@ -735,6 +743,7 @@ def get_testcases(self) -> Iterable[str | None]:
735743

736744
@final
737745
class LiteralBoolOps(TestGen):
746+
package = "numpy-stubs"
738747
testname = "literal_bool_ops"
739748

740749
UNOPS: ClassVar = {
@@ -887,6 +896,7 @@ def get_testcases(self) -> Iterable[str | None]:
887896

888897
@final
889898
class ScalarOps(TestGen):
899+
package = "numpy-stubs"
890900
testname = "scalar_ops_{}"
891901

892902
OPS_ARITHMETIC: ClassVar[dict[str, _BinOp]] = {
@@ -1144,6 +1154,7 @@ def get_testcases(self) -> Iterable[str | None]:
11441154

11451155

11461156
class NDArrayOps(TestGen):
1157+
package = "numpy-stubs"
11471158
testname = "ndarray_{}"
11481159
numpy_imports_extra: tuple[str, ...] = ("import _numtype as _nt",)
11491160

@@ -1527,36 +1538,52 @@ def get_testcases(self) -> Iterable[str | None]:
15271538
TESTGENS: Final[Sequence[TestGen]] = [
15281539
EMath(binary=False),
15291540
LiteralBoolOps(),
1530-
ScalarOps("arithmetic"),
1531-
ScalarOps("modular"),
1532-
ScalarOps("bitwise"),
1533-
ScalarOps("comparison"),
1534-
NDArrayOps("pos"),
1535-
NDArrayOps("neg"),
1536-
NDArrayOps("abs"),
1537-
NDArrayOps("invert"),
1538-
NDArrayOps("add"),
1539-
NDArrayOps("sub"),
1540-
NDArrayOps("mul"),
1541-
NDArrayOps("matmul"),
1542-
NDArrayOps("pow"),
1543-
NDArrayOps("truediv"),
1544-
NDArrayOps("floordiv"),
1545-
NDArrayOps("mod"),
1546-
NDArrayOps("divmod"),
1547-
NDArrayOps("lshift"),
1548-
NDArrayOps("rshift"),
1549-
NDArrayOps("and"),
1550-
NDArrayOps("xor"),
1551-
NDArrayOps("or"),
1541+
*(ScalarOps(op_kind) for op_kind in get_args(_BinOpKind)),
1542+
*(NDArrayOps(op_name) for op_name in get_args(_OpName)),
15521543
]
15531544

15541545

15551546
@np.errstate(all="ignore")
15561547
def main() -> None:
1557-
"""(Re)generate the `src/numpy-stubs/@test/generated/{}.pyi` type-tests."""
1548+
"""(Re)generate the `src/*/@test/generated/{}.pyi` type-tests."""
1549+
cwd = Path.cwd()
1550+
paths: dict[str, dict[Path, bool]] = {}
1551+
15581552
for testgen in TESTGENS:
1559-
sys.stdout.writelines(testgen.regenerate())
1553+
path = testgen.path
1554+
diff = testgen.regenerate()
1555+
diff_out, diff_check = itertools.tee(diff, 2)
1556+
sys.stderr.writelines(diff_out)
1557+
sys.stderr.write("\n")
1558+
sys.stderr.flush()
1559+
1560+
diff_count = sum(1 for _ in diff_check)
1561+
if not diff_count:
1562+
sys.stdout.write(f"skipped ./{path.relative_to(cwd)}\n")
1563+
sys.stdout.flush()
1564+
1565+
package_paths = paths.setdefault(testgen.package, {})
1566+
assert path not in package_paths, path
1567+
package_paths[path] = bool(diff_count)
1568+
1569+
orphans: list[Path] = []
1570+
for package, testdir in DIRS_TARGET.items():
1571+
if not testdir.exists():
1572+
continue
1573+
assert testdir.is_dir()
1574+
1575+
known = paths.get(package, {})
1576+
for path in testdir.rglob("*.pyi"):
1577+
assert path.is_file()
1578+
if path not in known:
1579+
orphans.append(path)
1580+
1581+
for orphan in orphans:
1582+
assert orphan.is_file()
1583+
orphan.unlink()
1584+
1585+
sys.stderr.write(f"removed ./{orphan.relative_to(cwd)}\n")
1586+
sys.stderr.flush()
15601587

15611588

15621589
if __name__ == "__main__":

0 commit comments

Comments
 (0)