Skip to content

Commit 07f1270

Browse files
committed
🤖 ignore redundant scalar binop types, reject number bitops, stop skipping M & m
1 parent 3582d21 commit 07f1270

File tree

1 file changed

+42
-20
lines changed

1 file changed

+42
-20
lines changed

‎tool/testgen.py

Lines changed: 42 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,9 @@
4949
"c": frozenset({f"{NP}.complex64", f"{NP}.complex128", f"{NP}.clongdouble"}),
5050
}
5151

52-
DATETIME_OPS: Final = {"+", "-"}
53-
TIMEDELTA_OPS: Final = DATETIME_OPS | {"*", "/", "//", "%"}
54-
BITWISE_OPS: Final = {"<<", ">>", "&", "^", "|"}
52+
DATETIME_OPS: Final = {"add", "sub"}
53+
TIMEDELTA_OPS: Final = DATETIME_OPS | {"mul", "truediv", "floordiv", "mod", "divmod"}
54+
BITWISE_OPS: Final = {"lshift", "rshift", "and", "or", "xor"}
5555
BITWISE_CHARS: Final = "?bhilqBHILQ"
5656

5757
INTP_EXPR: Final = f"{NP}.{np.intp.__name__}"
@@ -659,35 +659,57 @@ class ScalarOps(TestGen):
659659
}
660660

661661
ops: Final[dict[str, _BinOp]]
662+
names: Final[dict[str, str]]
663+
reject: Final[frozenset[str]]
662664

663665
def __init__(self, kind: _BinOpKind, /) -> None:
666+
reject: set[str] = set()
667+
ignore: set[str] = set()
668+
669+
if kind in {"modular", "bitwise"}:
670+
# TODO(jorenham): Reject `inexact`
671+
reject |= {"bhilBHILefdgFDG", "FDG", "F", "D", "G"}
672+
673+
# ignore the non-standard concrete complex (and float if bitwise) types
674+
# to avoid generating many redundant rejection tests
675+
ignore |= set("FGM")
676+
if kind == "bitwise":
677+
ignore |= set("efgm")
678+
664679
match kind:
665680
case "arithmetic":
666681
ops = self.OPS_ARITHMETIC
667682
case "modular":
668683
ops = self.OPS_MODULAR
669684
case "bitwise":
670685
ops = self.OPS_BITWISE
686+
reject |= {"efdgFDG", "efdg", "e", "f", "d", "g"}
671687
case "comparison":
672688
ops = self.OPS_COMPARISON
673689

674690
self.ops = ops
691+
self.names = {k: name for k, name in self.NAMES.items() if k not in ignore}
692+
self.reject = frozenset(reject)
693+
675694
self.testname = self.testname.format(kind)
676695

677696
super().__init__()
678697

679698
def _is_builtin(self, key: str, /) -> bool:
680-
return len(key) > 1 and self.NAMES[key].endswith("_py")
699+
return len(key) > 1 and self.names[key].endswith("_py")
681700

682701
def _is_abstract(self, key: str, /) -> bool:
683-
return len(key) > 1 and not self.NAMES[key].endswith("_py")
702+
return len(key) > 1 and not self.names[key].endswith("_py")
684703

685704
def _decompose(self, key: str, /) -> tuple[_Scalar, ...]:
686705
if not self._is_abstract(key):
687706
return (_scalar(key),)
688707
return tuple(map(_scalar, key))
689708

690709
def _evaluate_concrete(self, op: str, lhs: str, rhs: str, /) -> str | None:
710+
if lhs in self.reject or rhs in self.reject:
711+
return None
712+
691713
fn = self.ops[op]
692714
nout = 2 if fn.__module__ == "builtins" else 1
693715

@@ -722,26 +744,26 @@ def _evaluate_concrete(self, op: str, lhs: str, rhs: str, /) -> str | None:
722744
return f"tuple[{', '.join(result_exprs)}]" if nout > 1 else result_exprs[0]
723745

724746
def _assert_stmt(self, op: str, lhs: str, rhs: str, /) -> str | None:
725-
expr_eval = op.format(self.NAMES[lhs], self.NAMES[rhs])
726-
is_op = self.ops[op].__module__ != "builtins" # not the case for divmod
747+
expr_eval = op.format(self.names[lhs], self.names[rhs])
727748

728749
if not (expr_type := self._evaluate_concrete(op, lhs, rhs)):
729750
# generate rejection test, while avoiding trivial cases
751+
opname = self.ops[op].__name__.removesuffix("_")
730752
if (
731-
# ignore bitwise ops if either arg is not a bitwise char
732-
(op in BITWISE_OPS and {lhs, rhs} - set(BITWISE_CHARS))
753+
# ignore bitwise ops if neither arg is a bitwise char
754+
(opname in BITWISE_OPS and not {lhs, rhs} & set(BITWISE_CHARS))
733755
# ignore if either arg is datetime and and not a datetime op
734-
or (op not in DATETIME_OPS and "M" in {lhs, rhs})
756+
or (opname not in DATETIME_OPS and "M" in {lhs, rhs})
735757
# ignore if either arg is timedelta and and not a timedelta op
736-
or (op not in TIMEDELTA_OPS and "m" in {lhs, rhs})
758+
or (opname not in TIMEDELTA_OPS and "m" in {lhs, rhs})
737759
):
738760
return None
739761

740762
# pyright special casing
741-
if is_op:
742-
pyright_rules = ["OperatorIssue"]
743-
else:
763+
if opname == "divmod":
744764
pyright_rules = ["ArgumentType", "CallIssue"]
765+
else:
766+
pyright_rules = ["OperatorIssue"]
745767
pyright_ignore = ", ".join(map("report{}".format, pyright_rules))
746768

747769
return " ".join((
@@ -769,7 +791,7 @@ def _assert_stmt(self, op: str, lhs: str, rhs: str, /) -> str | None:
769791
if (
770792
op in self.OPS_ARITHMETIC | self.OPS_MODULAR
771793
and lhs == rhs
772-
and (abstract_arg := self.ABSTRACT_TYPES.get(self.NAMES[lhs]))
794+
and (abstract_arg := self.ABSTRACT_TYPES.get(self.names[lhs]))
773795
):
774796
if abstract_arg == "integer" and " / " not in op:
775797
mypy_ignore = "assert-type, operator"
@@ -790,19 +812,19 @@ def _generate_names_section(self) -> Generator[str]:
790812
@override
791813
def get_names(self) -> Iterable[tuple[str, str]]:
792814
# builtin scalars
793-
for builtin, name in self.NAMES.items():
815+
for builtin, name in self.names.items():
794816
if self._is_builtin(builtin):
795817
yield name, builtin
796818

797819
# constrete numpy scalars
798820
yield "", ""
799-
for char, name in self.NAMES.items():
821+
for char, name in self.names.items():
800822
if len(char) == 1:
801823
yield name, _sctype_expr(np.dtype(char))
802824

803825
# abstract numpy scalars
804826
yield "", ""
805-
for char, kind in self.NAMES.items():
827+
for char, kind in self.names.items():
806828
if self._is_abstract(char):
807829
yield kind, f"{NP}.{self.ABSTRACT_TYPES[kind]}"
808830

@@ -813,13 +835,13 @@ def get_testcases(self) -> Iterable[str | None]:
813835

814836
yield from self._generate_section(f"__[r]{opname}__")
815837

816-
for lhs in self.NAMES:
838+
for lhs in self.names:
817839
if self._is_builtin(lhs):
818840
# will cause false positives on pyright; as designed, of course
819841
continue
820842

821843
n = 0
822-
for rhs in self.NAMES:
844+
for rhs in self.names:
823845
if fn.__name__ in {"eq", "ne"} and self._is_abstract(rhs):
824846
# will be inferred by mypy as `Any` for some reason
825847
continue

0 commit comments

Comments
 (0)