11
11
import difflib
12
12
import itertools
13
13
import operator as op
14
- from collections .abc import Callable , Generator , Iterable , Iterator , Sequence
15
- from typing import Any , ClassVar , Final , Literal , TypeAlias , TypeVar , cast , final
14
+ from collections .abc import Callable , Generator , Iterable , Iterator , Mapping , Sequence
15
+ from typing import (
16
+ Any ,
17
+ ClassVar ,
18
+ Final ,
19
+ Literal ,
20
+ TypeAlias ,
21
+ TypeVar ,
22
+ cast ,
23
+ final ,
24
+ )
16
25
from typing_extensions import override
17
26
18
27
import numpy as np
@@ -1105,6 +1114,17 @@ def __init__(self, opname: _OpName, /) -> None:
1105
1114
1106
1115
super ().__init__ ()
1107
1116
1117
+ @property
1118
+ def _scalars_py (self ) -> Mapping [str , type [complex | bytes | str ]]:
1119
+ kindmap = {"b" : bool , "i" : int , "f" : float , "c" : complex , "S" : bytes , "U" : str }
1120
+ kinds = {dtype .kind : "" for dtype in self .dtypes }
1121
+ return {f"{ kind } _py" : kindmap [kind ] for kind in kinds if kind in kindmap }
1122
+
1123
+ def _op_expr (self , lhs : str , rhs : str , / ) -> str :
1124
+ if self .opfunc .__name__ == "divmod" :
1125
+ return f"divmod({ lhs } , { rhs } )"
1126
+ return lhs + str (self .opfunc .__doc__ )[9 :- 2 ] + rhs
1127
+
1108
1128
@staticmethod
1109
1129
def _get_arrays (
1110
1130
dtype1 : np .dtype ,
@@ -1124,47 +1144,89 @@ def _get_arrays(
1124
1144
1125
1145
@override
1126
1146
def get_names (self ) -> Iterable [tuple [str , str ]]:
1147
+ # ndarays
1127
1148
for dtype in self .dtypes :
1128
- yield f"array_ { dtype_label (dtype )} _nd" , _array_expr (dtype , npt = True )
1149
+ yield f"{ dtype_label (dtype )} _nd" , _array_expr (dtype , npt = True )
1129
1150
1130
- @override
1131
- def get_testcases (self ) -> Iterable [str | None ]:
1132
- op_expr_template = str (self .opfunc .__doc__ )[8 :- 1 ]
1133
- op_expr_template = op_expr_template .replace ("a" , "{}" ).replace ("b" , "{}" )
1151
+ yield "" , "" # linebreak
1134
1152
1135
- yield from self ._generate_section ()
1153
+ # python scalars
1154
+ for name , pytype in self ._scalars_py .items ():
1155
+ yield name , pytype .__name__
1136
1156
1137
- for dtype1 in self .dtypes :
1138
- yielded = 0
1139
- for dtype2 in self .dtypes :
1140
- name1 = f"array_{ dtype_label (dtype1 )} _nd"
1141
- name2 = f"array_{ dtype_label (dtype2 )} _nd"
1142
- op_expr = op_expr_template .format (name1 , name2 )
1157
+ def _gen_testcases_np_nd (self , dtype1 : np .dtype , / ) -> Generator [str | None ]:
1158
+ name1 = f"{ dtype_label (dtype1 )} _nd"
1143
1159
1144
- arr1 , arr2 = self ._get_arrays (dtype1 , dtype2 )
1160
+ for dtype2 in self .dtypes :
1161
+ name2 = f"{ dtype_label (dtype2 )} _nd"
1162
+ expr = self ._op_expr (name1 , name2 )
1145
1163
1146
- try :
1147
- out = self .opfunc (arr1 , arr2 )
1148
- except TypeError :
1149
- if "O" in dtype1 .char + dtype2 .char :
1150
- # impossible to reject
1151
- continue
1164
+ arr1 , arr2 = self ._get_arrays (dtype1 , dtype2 )
1152
1165
1153
- testcase = " " .join (( # noqa: FLY002
1154
- op_expr ,
1155
- "# type: ignore[operator]" ,
1156
- "# pyright: ignore[reportOperatorIssue]" ,
1157
- ))
1158
- else :
1159
- out_type_expr = _array_expr (out .dtype , npt = True )
1160
- testcase = _expr_assert_type (op_expr , out_type_expr )
1166
+ try :
1167
+ out = self .opfunc (arr1 , arr2 )
1168
+ except TypeError :
1169
+ if "O" in dtype1 .char + dtype2 .char :
1170
+ # impossible to reject
1171
+ continue
1161
1172
1162
- yield testcase
1163
- yielded += 1
1173
+ testcase = " " .join (( # noqa: FLY002
1174
+ expr ,
1175
+ "# type: ignore[operator]" ,
1176
+ "# pyright: ignore[reportOperatorIssue]" ,
1177
+ ))
1178
+ else :
1179
+ out_type_expr = _array_expr (out .dtype , npt = True )
1180
+ testcase = _expr_assert_type (expr , out_type_expr )
1181
+
1182
+ yield testcase
1183
+
1184
+ def _gen_testcases_py_0d (
1185
+ self ,
1186
+ dtype : np .dtype ,
1187
+ / ,
1188
+ * ,
1189
+ reflect : bool = False ,
1190
+ ) -> Generator [str | None ]:
1191
+ name_np = f"{ dtype_label (dtype )} _nd"
1192
+
1193
+ for name_py , pytype in self ._scalars_py .items ():
1194
+ name1 , name2 = (name_py , name_np ) if reflect else (name_np , name_py )
1195
+
1196
+ val_np , val_py = self ._get_arrays (dtype , np .dtype (pytype ))[0 ], pytype (1 )
1197
+ val1 , val2 = (val_py , val_np ) if reflect else (val_np , val_py )
1164
1198
1165
- if yielded > 2 :
1166
- # avoid inserting excessive newlines
1167
- yield ""
1199
+ expr = self ._op_expr (name1 , name2 )
1200
+
1201
+ try :
1202
+ out = self .opfunc (val1 , val2 )
1203
+ except TypeError :
1204
+ if reflect and pytype is bytes :
1205
+ # impossible to reject
1206
+ continue
1207
+
1208
+ testcase = " " .join (( # noqa: FLY002
1209
+ expr ,
1210
+ "# type: ignore[operator]" ,
1211
+ "# pyright: ignore[reportOperatorIssue]" ,
1212
+ ))
1213
+ else :
1214
+ out_type_expr = _array_expr (out .dtype , npt = True )
1215
+ testcase = _expr_assert_type (expr , out_type_expr )
1216
+
1217
+ yield testcase
1218
+
1219
+ @override
1220
+ def get_testcases (self ) -> Iterable [str | None ]:
1221
+ yield from self ._generate_section ()
1222
+
1223
+ for dtype in self .dtypes :
1224
+ yield from self ._gen_testcases_np_nd (dtype )
1225
+ yield ""
1226
+ yield from self ._gen_testcases_py_0d (dtype )
1227
+ yield ""
1228
+ yield from self ._gen_testcases_py_0d (dtype , reflect = True )
1229
+ yield ""
1168
1230
1169
1231
1170
1232
TESTGENS : Final [Sequence [TestGen ]] = [
0 commit comments