Skip to content

Commit ee68b89

Browse files
committed
Rudimentary testing for binary elwise special cases
1 parent 6a4fffc commit ee68b89

File tree

1 file changed

+162
-36
lines changed

1 file changed

+162
-36
lines changed

array_api_tests/special_cases.py renamed to array_api_tests/test_special_cases.py

Lines changed: 162 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -6,34 +6,32 @@
66

77
import pytest
88
from attr import dataclass
9-
from hypothesis import assume, given
9+
from hypothesis import HealthCheck, assume, given, settings
1010

11+
from . import dtype_helpers as dh
1112
from . import hypothesis_helpers as hh
1213
from . import shape_helpers as sh
1314
from . import xps
1415
from ._array_module import mod as xp
1516
from .stubs import category_to_funcs
1617

17-
repr_to_value = {
18-
"NaN": float("nan"),
19-
"+infinity": float("infinity"),
20-
"infinity": float("infinity"),
21-
"-infinity": float("-infinity"),
22-
"+0": 0.0,
23-
"0": 0.0,
24-
"-0": -0.0,
25-
"+1": 1.0,
26-
"1": 1.0,
27-
"-1": -1.0,
28-
"+π/2": math.pi / 2,
29-
"π/2": math.pi / 2,
30-
"-π/2": -math.pi / 2,
31-
}
18+
19+
def is_pos_zero(n: float) -> bool:
20+
return n == 0 and math.copysign(1, n) == 1
21+
22+
23+
def is_neg_zero(n: float) -> bool:
24+
return n == 0 and math.copysign(1, n) == -1
3225

3326

3427
def make_eq(v: float) -> Callable[[float], bool]:
3528
if math.isnan(v):
3629
return math.isnan
30+
if v == 0:
31+
if is_pos_zero(v):
32+
return is_pos_zero
33+
else:
34+
return is_neg_zero
3735

3836
def eq(i: float) -> bool:
3937
return i == v
@@ -42,6 +40,8 @@ def eq(i: float) -> bool:
4240

4341

4442
def make_rough_eq(v: float) -> Callable[[float], bool]:
43+
assert math.isfinite(v) # sanity check
44+
4545
def rough_eq(i: float) -> bool:
4646
return math.isclose(i, v, abs_tol=0.01)
4747

@@ -73,21 +73,52 @@ def or_(i: float):
7373
return or_
7474

7575

76-
r_value = re.compile(r"``([^\s]+)``")
77-
r_approx_value = re.compile(
78-
rf"an implementation-dependent approximation to {r_value.pattern}"
79-
)
76+
repr_to_value = {
77+
"NaN": float("nan"),
78+
"infinity": float("infinity"),
79+
"0": 0.0,
80+
"1": 1.0,
81+
}
82+
83+
r_value = re.compile(r"([+-]?)(.+)")
84+
r_pi = re.compile(r"(\d?)π(?:/(\d))?")
8085

8186

8287
@dataclass
8388
class ValueParseError(ValueError):
8489
value: str
8590

8691

87-
def parse_value(value: str) -> float:
88-
if m := r_value.match(value):
89-
return repr_to_value[m.group(1)]
90-
raise ValueParseError(value)
92+
def parse_value(s_value: str) -> float:
93+
assert not s_value.startswith("``") and not s_value.endswith("``") # sanity check
94+
m = r_value.match(s_value)
95+
if m is None:
96+
raise ValueParseError(s_value)
97+
if pi_m := r_pi.match(m.group(2)):
98+
value = math.pi
99+
if numerator := pi_m.group(1):
100+
value *= int(numerator)
101+
if denominator := pi_m.group(2):
102+
value /= int(denominator)
103+
else:
104+
value = repr_to_value[m.group(2)]
105+
if sign := m.group(1):
106+
if sign == "-":
107+
value *= -1
108+
return value
109+
110+
111+
r_inline_code = re.compile(r"``([^\s]+)``")
112+
r_approx_value = re.compile(
113+
rf"an implementation-dependent approximation to {r_inline_code.pattern}"
114+
)
115+
116+
117+
def parse_inline_code(inline_code: str) -> float:
118+
if m := r_inline_code.match(inline_code):
119+
return parse_value(m.group(1))
120+
else:
121+
raise ValueParseError(inline_code)
91122

92123

93124
class Result(NamedTuple):
@@ -96,22 +127,24 @@ class Result(NamedTuple):
96127
strict_check: bool
97128

98129

99-
def parse_result(result: str) -> Result:
100-
if m := r_value.match(result):
101-
repr_ = m.group(1)
130+
def parse_result(s_result: str) -> Result:
131+
match = None
132+
if m := r_inline_code.match(s_result):
133+
match = m
102134
strict_check = True
103-
elif m := r_approx_value.match(result):
104-
repr_ = m.group(1)
135+
elif m := r_approx_value.match(s_result):
136+
match = m
105137
strict_check = False
106138
else:
107-
raise ValueParseError(result)
108-
value = repr_to_value[repr_]
139+
raise ValueParseError(s_result)
140+
value = parse_value(match.group(1))
141+
repr_ = match.group(1)
109142
return Result(value, repr_, strict_check)
110143

111144

112145
r_special_cases = re.compile(
113-
r"\*\*Special [Cc]ases\*\*\n\n\s*"
114-
r"For floating-point operands,\n\n"
146+
r"\*\*Special [Cc]ases\*\*\n+\s*"
147+
r"For floating-point operands,\n+"
115148
r"((?:\s*-\s*.*\n)+)"
116149
)
117150
r_case = re.compile(r"\s+-\s*(.*)\.\n?")
@@ -148,7 +181,7 @@ def parse_unary_docstring(docstring: str) -> Dict[Callable, Result]:
148181
if m := pattern.search(case):
149182
*s_values, s_result = m.groups()
150183
try:
151-
values = [parse_value(v) for v in s_values]
184+
values = [parse_inline_code(v) for v in s_values]
152185
except ValueParseError as e:
153186
warn(f"value not machine-readable: '{e.value}'")
154187
break
@@ -166,7 +199,56 @@ def parse_unary_docstring(docstring: str) -> Dict[Callable, Result]:
166199
return condition_to_result
167200

168201

202+
binary_pattern_to_condition_factory: Dict[Pattern, Callable] = {
203+
re.compile(
204+
"If ``x1_i`` is (.+) and ``x2_i`` is (.+), the result is (.+)"
205+
): lambda v1, v2: lambda i1, i2: make_eq(v1)(i1)
206+
and make_eq(v2)(i2),
207+
}
208+
209+
210+
def parse_binary_docstring(docstring: str) -> Dict[Callable, Result]:
211+
match = r_special_cases.search(docstring)
212+
if match is None:
213+
return {}
214+
cases = match.group(1).split("\n")[:-1]
215+
condition_to_result = {}
216+
for line in cases:
217+
if m := r_case.match(line):
218+
case = m.group(1)
219+
else:
220+
warn(f"line not machine-readable: '{line}'")
221+
continue
222+
for pattern, make_cond in binary_pattern_to_condition_factory.items():
223+
if m := pattern.search(case):
224+
*s_values, s_result = m.groups()
225+
try:
226+
values = [parse_inline_code(v) for v in s_values]
227+
except ValueParseError as e:
228+
warn(f"value not machine-readable: '{e.value}'")
229+
break
230+
cond = make_cond(*values)
231+
if (
232+
"atan2" in docstring
233+
and is_pos_zero(values[0])
234+
and is_neg_zero(values[1])
235+
):
236+
breakpoint()
237+
try:
238+
result = parse_result(s_result)
239+
except ValueParseError as e:
240+
warn(f"result not machine-readable: '{e.value}'")
241+
break
242+
condition_to_result[cond] = result
243+
break
244+
else:
245+
if not r_remaining_case.search(case):
246+
warn(f"case not machine-readable: '{case}'")
247+
return condition_to_result
248+
249+
169250
unary_params = []
251+
binary_params = []
170252
for stub in category_to_funcs["elementwise"]:
171253
if stub.__doc__ is None:
172254
warn(f"{stub.__name__}() stub has no docstring")
@@ -193,7 +275,10 @@ def parse_unary_docstring(docstring: str) -> Dict[Callable, Result]:
193275
warn(f"{func=} has one parameter '{param_names[0]}' which is not named 'x'")
194276
continue
195277
if param_names[0] == "x1" and param_names[1] == "x2":
196-
pass # TODO
278+
if condition_to_result := parse_binary_docstring(stub.__doc__):
279+
p = pytest.param(stub.__name__, func, condition_to_result, id=stub.__name__)
280+
binary_params.append(p)
281+
continue
197282
else:
198283
warn(
199284
f"{func=} starts with two parameters '{param_names[0]}' and "
@@ -209,7 +294,7 @@ def parse_unary_docstring(docstring: str) -> Dict[Callable, Result]:
209294

210295
@pytest.mark.parametrize("func_name, func, condition_to_result", unary_params)
211296
@given(x=xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1)))
212-
def test_unary_special_cases(func_name, func, condition_to_result, x):
297+
def test_unary(func_name, func, condition_to_result, x):
213298
res = func(x)
214299
good_example = False
215300
for idx in sh.ndindex(res.shape):
@@ -238,3 +323,44 @@ def test_unary_special_cases(func_name, func, condition_to_result, x):
238323
)
239324
break
240325
assume(good_example)
326+
327+
328+
@pytest.mark.parametrize("func_name, func, condition_to_result", binary_params)
329+
@given(
330+
*hh.two_mutual_arrays(
331+
dtypes=dh.float_dtypes,
332+
two_shapes=hh.mutually_broadcastable_shapes(2, min_side=1),
333+
)
334+
)
335+
@settings(suppress_health_check=[HealthCheck.filter_too_much]) # TODO: remove
336+
def test_binary(func_name, func, condition_to_result, x1, x2):
337+
res = func(x1, x2)
338+
good_example = False
339+
for l_idx, r_idx, o_idx in sh.iter_indices(x1.shape, x2.shape, res.shape):
340+
l = float(x1[l_idx])
341+
r = float(x2[r_idx])
342+
for cond, result in condition_to_result.items():
343+
if cond(l, r):
344+
good_example = True
345+
out = float(res[o_idx])
346+
f_left = f"{sh.fmt_idx('x1', l_idx)}={l}"
347+
f_right = f"{sh.fmt_idx('x2', r_idx)}={r}"
348+
f_out = f"{sh.fmt_idx('out', o_idx)}={out}"
349+
if result.strict_check:
350+
msg = (
351+
f"{f_out}, but should be {result.repr_} [{func_name}()]\n"
352+
f"{f_left}, {f_right}"
353+
)
354+
if math.isnan(result.value):
355+
assert math.isnan(out), msg
356+
else:
357+
assert out == result.value, msg
358+
else:
359+
assert math.isfinite(result.value) # sanity check
360+
assert math.isclose(out, result.value, abs_tol=0.1), (
361+
f"{f_out}, but should be roughly {result.repr_}={result.value} "
362+
f"[{func_name}()]\n"
363+
f"{f_left}, {f_right}"
364+
)
365+
break
366+
assume(good_example)

0 commit comments

Comments
 (0)