Skip to content

Commit 522c779

Browse files
authored
🤖 testgen ndarray binop vs builtin scalar types (#502)
1 parent 43bcac0 commit 522c779

File tree

1 file changed

+96
-34
lines changed

1 file changed

+96
-34
lines changed

‎tool/testgen.py

Lines changed: 96 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,17 @@
1111
import difflib
1212
import itertools
1313
import operator as op
14-
from collections.abc import Callable, Generator, Iterable, Iterator, Sequence
15-
from typing import Any, ClassVar, Final, Literal, TypeAlias, TypeVar, cast, final
14+
from collections.abc import Callable, Generator, Iterable, Iterator, Mapping, Sequence
15+
from typing import (
16+
Any,
17+
ClassVar,
18+
Final,
19+
Literal,
20+
TypeAlias,
21+
TypeVar,
22+
cast,
23+
final,
24+
)
1625
from typing_extensions import override
1726

1827
import numpy as np
@@ -1105,6 +1114,17 @@ def __init__(self, opname: _OpName, /) -> None:
11051114

11061115
super().__init__()
11071116

1117+
@property
1118+
def _scalars_py(self) -> Mapping[str, type[complex | bytes | str]]:
1119+
kindmap = {"b": bool, "i": int, "f": float, "c": complex, "S": bytes, "U": str}
1120+
kinds = {dtype.kind: "" for dtype in self.dtypes}
1121+
return {f"{kind}_py": kindmap[kind] for kind in kinds if kind in kindmap}
1122+
1123+
def _op_expr(self, lhs: str, rhs: str, /) -> str:
1124+
if self.opfunc.__name__ == "divmod":
1125+
return f"divmod({lhs}, {rhs})"
1126+
return lhs + str(self.opfunc.__doc__)[9:-2] + rhs
1127+
11081128
@staticmethod
11091129
def _get_arrays(
11101130
dtype1: np.dtype,
@@ -1124,47 +1144,89 @@ def _get_arrays(
11241144

11251145
@override
11261146
def get_names(self) -> Iterable[tuple[str, str]]:
1147+
# ndarays
11271148
for dtype in self.dtypes:
1128-
yield f"array_{dtype_label(dtype)}_nd", _array_expr(dtype, npt=True)
1149+
yield f"{dtype_label(dtype)}_nd", _array_expr(dtype, npt=True)
11291150

1130-
@override
1131-
def get_testcases(self) -> Iterable[str | None]:
1132-
op_expr_template = str(self.opfunc.__doc__)[8:-1]
1133-
op_expr_template = op_expr_template.replace("a", "{}").replace("b", "{}")
1151+
yield "", "" # linebreak
11341152

1135-
yield from self._generate_section()
1153+
# python scalars
1154+
for name, pytype in self._scalars_py.items():
1155+
yield name, pytype.__name__
11361156

1137-
for dtype1 in self.dtypes:
1138-
yielded = 0
1139-
for dtype2 in self.dtypes:
1140-
name1 = f"array_{dtype_label(dtype1)}_nd"
1141-
name2 = f"array_{dtype_label(dtype2)}_nd"
1142-
op_expr = op_expr_template.format(name1, name2)
1157+
def _gen_testcases_np_nd(self, dtype1: np.dtype, /) -> Generator[str | None]:
1158+
name1 = f"{dtype_label(dtype1)}_nd"
11431159

1144-
arr1, arr2 = self._get_arrays(dtype1, dtype2)
1160+
for dtype2 in self.dtypes:
1161+
name2 = f"{dtype_label(dtype2)}_nd"
1162+
expr = self._op_expr(name1, name2)
11451163

1146-
try:
1147-
out = self.opfunc(arr1, arr2)
1148-
except TypeError:
1149-
if "O" in dtype1.char + dtype2.char:
1150-
# impossible to reject
1151-
continue
1164+
arr1, arr2 = self._get_arrays(dtype1, dtype2)
11521165

1153-
testcase = " ".join(( # noqa: FLY002
1154-
op_expr,
1155-
"# type: ignore[operator]",
1156-
"# pyright: ignore[reportOperatorIssue]",
1157-
))
1158-
else:
1159-
out_type_expr = _array_expr(out.dtype, npt=True)
1160-
testcase = _expr_assert_type(op_expr, out_type_expr)
1166+
try:
1167+
out = self.opfunc(arr1, arr2)
1168+
except TypeError:
1169+
if "O" in dtype1.char + dtype2.char:
1170+
# impossible to reject
1171+
continue
11611172

1162-
yield testcase
1163-
yielded += 1
1173+
testcase = " ".join(( # noqa: FLY002
1174+
expr,
1175+
"# type: ignore[operator]",
1176+
"# pyright: ignore[reportOperatorIssue]",
1177+
))
1178+
else:
1179+
out_type_expr = _array_expr(out.dtype, npt=True)
1180+
testcase = _expr_assert_type(expr, out_type_expr)
1181+
1182+
yield testcase
1183+
1184+
def _gen_testcases_py_0d(
1185+
self,
1186+
dtype: np.dtype,
1187+
/,
1188+
*,
1189+
reflect: bool = False,
1190+
) -> Generator[str | None]:
1191+
name_np = f"{dtype_label(dtype)}_nd"
1192+
1193+
for name_py, pytype in self._scalars_py.items():
1194+
name1, name2 = (name_py, name_np) if reflect else (name_np, name_py)
1195+
1196+
val_np, val_py = self._get_arrays(dtype, np.dtype(pytype))[0], pytype(1)
1197+
val1, val2 = (val_py, val_np) if reflect else (val_np, val_py)
11641198

1165-
if yielded > 2:
1166-
# avoid inserting excessive newlines
1167-
yield ""
1199+
expr = self._op_expr(name1, name2)
1200+
1201+
try:
1202+
out = self.opfunc(val1, val2)
1203+
except TypeError:
1204+
if reflect and pytype is bytes:
1205+
# impossible to reject
1206+
continue
1207+
1208+
testcase = " ".join(( # noqa: FLY002
1209+
expr,
1210+
"# type: ignore[operator]",
1211+
"# pyright: ignore[reportOperatorIssue]",
1212+
))
1213+
else:
1214+
out_type_expr = _array_expr(out.dtype, npt=True)
1215+
testcase = _expr_assert_type(expr, out_type_expr)
1216+
1217+
yield testcase
1218+
1219+
@override
1220+
def get_testcases(self) -> Iterable[str | None]:
1221+
yield from self._generate_section()
1222+
1223+
for dtype in self.dtypes:
1224+
yield from self._gen_testcases_np_nd(dtype)
1225+
yield ""
1226+
yield from self._gen_testcases_py_0d(dtype)
1227+
yield ""
1228+
yield from self._gen_testcases_py_0d(dtype, reflect=True)
1229+
yield ""
11681230

11691231

11701232
TESTGENS: Final[Sequence[TestGen]] = [

0 commit comments

Comments
 (0)