Skip to content

Commit 6a4fffc

Browse files
committed
Rudimentary special case tests on runtime
1 parent 61eec43 commit 6a4fffc

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+240
-5383
lines changed

array_api_tests/special_cases.py

Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
import inspect
2+
import math
3+
import re
4+
from typing import Callable, Dict, NamedTuple, Pattern
5+
from warnings import warn
6+
7+
import pytest
8+
from attr import dataclass
9+
from hypothesis import assume, given
10+
11+
from . import hypothesis_helpers as hh
12+
from . import shape_helpers as sh
13+
from . import xps
14+
from ._array_module import mod as xp
15+
from .stubs import category_to_funcs
16+
17+
repr_to_value = {
18+
"NaN": float("nan"),
19+
"+infinity": float("infinity"),
20+
"infinity": float("infinity"),
21+
"-infinity": float("-infinity"),
22+
"+0": 0.0,
23+
"0": 0.0,
24+
"-0": -0.0,
25+
"+1": 1.0,
26+
"1": 1.0,
27+
"-1": -1.0,
28+
"+π/2": math.pi / 2,
29+
"π/2": math.pi / 2,
30+
"-π/2": -math.pi / 2,
31+
}
32+
33+
34+
def make_eq(v: float) -> Callable[[float], bool]:
35+
if math.isnan(v):
36+
return math.isnan
37+
38+
def eq(i: float) -> bool:
39+
return i == v
40+
41+
return eq
42+
43+
44+
def make_rough_eq(v: float) -> Callable[[float], bool]:
45+
def rough_eq(i: float) -> bool:
46+
return math.isclose(i, v, abs_tol=0.01)
47+
48+
return rough_eq
49+
50+
51+
def make_gt(v: float):
52+
assert not math.isnan(v) # sanity check
53+
54+
def gt(i: float):
55+
return i > v
56+
57+
return gt
58+
59+
60+
def make_lt(v: float):
61+
assert not math.isnan(v) # sanity check
62+
63+
def lt(i: float):
64+
return i < v
65+
66+
return lt
67+
68+
69+
def make_or(cond1: Callable, cond2: Callable):
70+
def or_(i: float):
71+
return cond1(i) or cond2(i)
72+
73+
return or_
74+
75+
76+
r_value = re.compile(r"``([^\s]+)``")
77+
r_approx_value = re.compile(
78+
rf"an implementation-dependent approximation to {r_value.pattern}"
79+
)
80+
81+
82+
@dataclass
83+
class ValueParseError(ValueError):
84+
value: str
85+
86+
87+
def parse_value(value: str) -> float:
88+
if m := r_value.match(value):
89+
return repr_to_value[m.group(1)]
90+
raise ValueParseError(value)
91+
92+
93+
class Result(NamedTuple):
94+
value: float
95+
repr_: str
96+
strict_check: bool
97+
98+
99+
def parse_result(result: str) -> Result:
100+
if m := r_value.match(result):
101+
repr_ = m.group(1)
102+
strict_check = True
103+
elif m := r_approx_value.match(result):
104+
repr_ = m.group(1)
105+
strict_check = False
106+
else:
107+
raise ValueParseError(result)
108+
value = repr_to_value[repr_]
109+
return Result(value, repr_, strict_check)
110+
111+
112+
r_special_cases = re.compile(
113+
r"\*\*Special [Cc]ases\*\*\n\n\s*"
114+
r"For floating-point operands,\n\n"
115+
r"((?:\s*-\s*.*\n)+)"
116+
)
117+
r_case = re.compile(r"\s+-\s*(.*)\.\n?")
118+
r_remaining_case = re.compile("In the remaining cases.+")
119+
120+
121+
unary_pattern_to_condition_factory: Dict[Pattern, Callable] = {
122+
re.compile("If ``x_i`` is greater than (.+), the result is (.+)"): make_gt,
123+
re.compile("If ``x_i`` is less than (.+), the result is (.+)"): make_lt,
124+
re.compile("If ``x_i`` is either (.+) or (.+), the result is (.+)"): (
125+
lambda v1, v2: make_or(make_eq(v1), make_eq(v2))
126+
),
127+
# This pattern must come after the previous patterns to avoid unwanted matches
128+
re.compile("If ``x_i`` is (.+), the result is (.+)"): make_eq,
129+
re.compile(
130+
"If two integers are equally close to ``x_i``, the result is (.+)"
131+
): lambda: (lambda i: (abs(i) - math.floor(abs(i))) == 0.5),
132+
}
133+
134+
135+
def parse_unary_docstring(docstring: str) -> Dict[Callable, Result]:
136+
match = r_special_cases.search(docstring)
137+
if match is None:
138+
return {}
139+
cases = match.group(1).split("\n")[:-1]
140+
condition_to_result = {}
141+
for line in cases:
142+
if m := r_case.match(line):
143+
case = m.group(1)
144+
else:
145+
warn(f"line not machine-readable: '{line}'")
146+
continue
147+
for pattern, make_cond in unary_pattern_to_condition_factory.items():
148+
if m := pattern.search(case):
149+
*s_values, s_result = m.groups()
150+
try:
151+
values = [parse_value(v) for v in s_values]
152+
except ValueParseError as e:
153+
warn(f"value not machine-readable: '{e.value}'")
154+
break
155+
cond = make_cond(*values)
156+
try:
157+
result = parse_result(s_result)
158+
except ValueParseError as e:
159+
warn(f"result not machine-readable: '{e.value}'")
160+
break
161+
condition_to_result[cond] = result
162+
break
163+
else:
164+
if not r_remaining_case.search(case):
165+
warn(f"case not machine-readable: '{case}'")
166+
return condition_to_result
167+
168+
169+
unary_params = []
170+
for stub in category_to_funcs["elementwise"]:
171+
if stub.__doc__ is None:
172+
warn(f"{stub.__name__}() stub has no docstring")
173+
continue
174+
marks = []
175+
try:
176+
func = getattr(xp, stub.__name__)
177+
except AttributeError:
178+
marks.append(
179+
pytest.mark.skip(reason=f"{stub.__name__} not found in array module")
180+
)
181+
func = None
182+
sig = inspect.signature(stub)
183+
param_names = list(sig.parameters.keys())
184+
if len(sig.parameters) == 0:
185+
warn(f"{func=} has no parameters")
186+
continue
187+
if param_names[0] == "x":
188+
if condition_to_result := parse_unary_docstring(stub.__doc__):
189+
p = pytest.param(stub.__name__, func, condition_to_result, id=stub.__name__)
190+
unary_params.append(p)
191+
continue
192+
if len(sig.parameters) == 1:
193+
warn(f"{func=} has one parameter '{param_names[0]}' which is not named 'x'")
194+
continue
195+
if param_names[0] == "x1" and param_names[1] == "x2":
196+
pass # TODO
197+
else:
198+
warn(
199+
f"{func=} starts with two parameters '{param_names[0]}' and "
200+
f"'{param_names[1]}', which are not named 'x1' and 'x2'"
201+
)
202+
203+
204+
# good_example is a flag that tells us whether Hypothesis generated an array
205+
# with at least on element that is special-cased. We reject the example when
206+
# its False - Hypothesis will complain if we reject too many examples, thus
207+
# indicating we should modify the array strategy being used.
208+
209+
210+
@pytest.mark.parametrize("func_name, func, condition_to_result", unary_params)
211+
@given(x=xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1)))
212+
def test_unary_special_cases(func_name, func, condition_to_result, x):
213+
res = func(x)
214+
good_example = False
215+
for idx in sh.ndindex(res.shape):
216+
in_ = float(x[idx])
217+
for cond, result in condition_to_result.items():
218+
if cond(in_):
219+
good_example = True
220+
out = float(res[idx])
221+
f_in = f"{sh.fmt_idx('x', idx)}={in_}"
222+
f_out = f"{sh.fmt_idx('out', idx)}={out}"
223+
if result.strict_check:
224+
msg = (
225+
f"{f_out}, but should be {result.repr_} [{func_name}()]\n"
226+
f"{f_in}"
227+
)
228+
if math.isnan(result.value):
229+
assert math.isnan(out), msg
230+
else:
231+
assert out == result.value, msg
232+
else:
233+
assert math.isfinite(result.value) # sanity check
234+
assert math.isclose(out, result.value, abs_tol=0.1), (
235+
f"{f_out}, but should be roughly {result.repr_}={result.value} "
236+
f"[{func_name}()]\n"
237+
f"{f_in}"
238+
)
239+
break
240+
assume(good_example)

array_api_tests/special_cases/__init__.py

Whitespace-only changes.

array_api_tests/special_cases/test_abs.py

Lines changed: 0 additions & 53 deletions
This file was deleted.

array_api_tests/special_cases/test_acos.py

Lines changed: 0 additions & 66 deletions
This file was deleted.

0 commit comments

Comments
 (0)