|
| 1 | +import inspect |
| 2 | +import math |
| 3 | +import re |
| 4 | +from typing import Callable, Dict, NamedTuple, Pattern |
| 5 | +from warnings import warn |
| 6 | + |
| 7 | +import pytest |
| 8 | +from attr import dataclass |
| 9 | +from hypothesis import assume, given |
| 10 | + |
| 11 | +from . import hypothesis_helpers as hh |
| 12 | +from . import shape_helpers as sh |
| 13 | +from . import xps |
| 14 | +from ._array_module import mod as xp |
| 15 | +from .stubs import category_to_funcs |
| 16 | + |
| 17 | +repr_to_value = { |
| 18 | + "NaN": float("nan"), |
| 19 | + "+infinity": float("infinity"), |
| 20 | + "infinity": float("infinity"), |
| 21 | + "-infinity": float("-infinity"), |
| 22 | + "+0": 0.0, |
| 23 | + "0": 0.0, |
| 24 | + "-0": -0.0, |
| 25 | + "+1": 1.0, |
| 26 | + "1": 1.0, |
| 27 | + "-1": -1.0, |
| 28 | + "+π/2": math.pi / 2, |
| 29 | + "π/2": math.pi / 2, |
| 30 | + "-π/2": -math.pi / 2, |
| 31 | +} |
| 32 | + |
| 33 | + |
| 34 | +def make_eq(v: float) -> Callable[[float], bool]: |
| 35 | + if math.isnan(v): |
| 36 | + return math.isnan |
| 37 | + |
| 38 | + def eq(i: float) -> bool: |
| 39 | + return i == v |
| 40 | + |
| 41 | + return eq |
| 42 | + |
| 43 | + |
| 44 | +def make_rough_eq(v: float) -> Callable[[float], bool]: |
| 45 | + def rough_eq(i: float) -> bool: |
| 46 | + return math.isclose(i, v, abs_tol=0.01) |
| 47 | + |
| 48 | + return rough_eq |
| 49 | + |
| 50 | + |
| 51 | +def make_gt(v: float): |
| 52 | + assert not math.isnan(v) # sanity check |
| 53 | + |
| 54 | + def gt(i: float): |
| 55 | + return i > v |
| 56 | + |
| 57 | + return gt |
| 58 | + |
| 59 | + |
| 60 | +def make_lt(v: float): |
| 61 | + assert not math.isnan(v) # sanity check |
| 62 | + |
| 63 | + def lt(i: float): |
| 64 | + return i < v |
| 65 | + |
| 66 | + return lt |
| 67 | + |
| 68 | + |
| 69 | +def make_or(cond1: Callable, cond2: Callable): |
| 70 | + def or_(i: float): |
| 71 | + return cond1(i) or cond2(i) |
| 72 | + |
| 73 | + return or_ |
| 74 | + |
| 75 | + |
| 76 | +r_value = re.compile(r"``([^\s]+)``") |
| 77 | +r_approx_value = re.compile( |
| 78 | + rf"an implementation-dependent approximation to {r_value.pattern}" |
| 79 | +) |
| 80 | + |
| 81 | + |
| 82 | +@dataclass |
| 83 | +class ValueParseError(ValueError): |
| 84 | + value: str |
| 85 | + |
| 86 | + |
| 87 | +def parse_value(value: str) -> float: |
| 88 | + if m := r_value.match(value): |
| 89 | + return repr_to_value[m.group(1)] |
| 90 | + raise ValueParseError(value) |
| 91 | + |
| 92 | + |
| 93 | +class Result(NamedTuple): |
| 94 | + value: float |
| 95 | + repr_: str |
| 96 | + strict_check: bool |
| 97 | + |
| 98 | + |
| 99 | +def parse_result(result: str) -> Result: |
| 100 | + if m := r_value.match(result): |
| 101 | + repr_ = m.group(1) |
| 102 | + strict_check = True |
| 103 | + elif m := r_approx_value.match(result): |
| 104 | + repr_ = m.group(1) |
| 105 | + strict_check = False |
| 106 | + else: |
| 107 | + raise ValueParseError(result) |
| 108 | + value = repr_to_value[repr_] |
| 109 | + return Result(value, repr_, strict_check) |
| 110 | + |
| 111 | + |
| 112 | +r_special_cases = re.compile( |
| 113 | + r"\*\*Special [Cc]ases\*\*\n\n\s*" |
| 114 | + r"For floating-point operands,\n\n" |
| 115 | + r"((?:\s*-\s*.*\n)+)" |
| 116 | +) |
| 117 | +r_case = re.compile(r"\s+-\s*(.*)\.\n?") |
| 118 | +r_remaining_case = re.compile("In the remaining cases.+") |
| 119 | + |
| 120 | + |
| 121 | +unary_pattern_to_condition_factory: Dict[Pattern, Callable] = { |
| 122 | + re.compile("If ``x_i`` is greater than (.+), the result is (.+)"): make_gt, |
| 123 | + re.compile("If ``x_i`` is less than (.+), the result is (.+)"): make_lt, |
| 124 | + re.compile("If ``x_i`` is either (.+) or (.+), the result is (.+)"): ( |
| 125 | + lambda v1, v2: make_or(make_eq(v1), make_eq(v2)) |
| 126 | + ), |
| 127 | + # This pattern must come after the previous patterns to avoid unwanted matches |
| 128 | + re.compile("If ``x_i`` is (.+), the result is (.+)"): make_eq, |
| 129 | + re.compile( |
| 130 | + "If two integers are equally close to ``x_i``, the result is (.+)" |
| 131 | + ): lambda: (lambda i: (abs(i) - math.floor(abs(i))) == 0.5), |
| 132 | +} |
| 133 | + |
| 134 | + |
| 135 | +def parse_unary_docstring(docstring: str) -> Dict[Callable, Result]: |
| 136 | + match = r_special_cases.search(docstring) |
| 137 | + if match is None: |
| 138 | + return {} |
| 139 | + cases = match.group(1).split("\n")[:-1] |
| 140 | + condition_to_result = {} |
| 141 | + for line in cases: |
| 142 | + if m := r_case.match(line): |
| 143 | + case = m.group(1) |
| 144 | + else: |
| 145 | + warn(f"line not machine-readable: '{line}'") |
| 146 | + continue |
| 147 | + for pattern, make_cond in unary_pattern_to_condition_factory.items(): |
| 148 | + if m := pattern.search(case): |
| 149 | + *s_values, s_result = m.groups() |
| 150 | + try: |
| 151 | + values = [parse_value(v) for v in s_values] |
| 152 | + except ValueParseError as e: |
| 153 | + warn(f"value not machine-readable: '{e.value}'") |
| 154 | + break |
| 155 | + cond = make_cond(*values) |
| 156 | + try: |
| 157 | + result = parse_result(s_result) |
| 158 | + except ValueParseError as e: |
| 159 | + warn(f"result not machine-readable: '{e.value}'") |
| 160 | + break |
| 161 | + condition_to_result[cond] = result |
| 162 | + break |
| 163 | + else: |
| 164 | + if not r_remaining_case.search(case): |
| 165 | + warn(f"case not machine-readable: '{case}'") |
| 166 | + return condition_to_result |
| 167 | + |
| 168 | + |
| 169 | +unary_params = [] |
| 170 | +for stub in category_to_funcs["elementwise"]: |
| 171 | + if stub.__doc__ is None: |
| 172 | + warn(f"{stub.__name__}() stub has no docstring") |
| 173 | + continue |
| 174 | + marks = [] |
| 175 | + try: |
| 176 | + func = getattr(xp, stub.__name__) |
| 177 | + except AttributeError: |
| 178 | + marks.append( |
| 179 | + pytest.mark.skip(reason=f"{stub.__name__} not found in array module") |
| 180 | + ) |
| 181 | + func = None |
| 182 | + sig = inspect.signature(stub) |
| 183 | + param_names = list(sig.parameters.keys()) |
| 184 | + if len(sig.parameters) == 0: |
| 185 | + warn(f"{func=} has no parameters") |
| 186 | + continue |
| 187 | + if param_names[0] == "x": |
| 188 | + if condition_to_result := parse_unary_docstring(stub.__doc__): |
| 189 | + p = pytest.param(stub.__name__, func, condition_to_result, id=stub.__name__) |
| 190 | + unary_params.append(p) |
| 191 | + continue |
| 192 | + if len(sig.parameters) == 1: |
| 193 | + warn(f"{func=} has one parameter '{param_names[0]}' which is not named 'x'") |
| 194 | + continue |
| 195 | + if param_names[0] == "x1" and param_names[1] == "x2": |
| 196 | + pass # TODO |
| 197 | + else: |
| 198 | + warn( |
| 199 | + f"{func=} starts with two parameters '{param_names[0]}' and " |
| 200 | + f"'{param_names[1]}', which are not named 'x1' and 'x2'" |
| 201 | + ) |
| 202 | + |
| 203 | + |
| 204 | +# good_example is a flag that tells us whether Hypothesis generated an array |
| 205 | +# with at least on element that is special-cased. We reject the example when |
| 206 | +# its False - Hypothesis will complain if we reject too many examples, thus |
| 207 | +# indicating we should modify the array strategy being used. |
| 208 | + |
| 209 | + |
| 210 | +@pytest.mark.parametrize("func_name, func, condition_to_result", unary_params) |
| 211 | +@given(x=xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1))) |
| 212 | +def test_unary_special_cases(func_name, func, condition_to_result, x): |
| 213 | + res = func(x) |
| 214 | + good_example = False |
| 215 | + for idx in sh.ndindex(res.shape): |
| 216 | + in_ = float(x[idx]) |
| 217 | + for cond, result in condition_to_result.items(): |
| 218 | + if cond(in_): |
| 219 | + good_example = True |
| 220 | + out = float(res[idx]) |
| 221 | + f_in = f"{sh.fmt_idx('x', idx)}={in_}" |
| 222 | + f_out = f"{sh.fmt_idx('out', idx)}={out}" |
| 223 | + if result.strict_check: |
| 224 | + msg = ( |
| 225 | + f"{f_out}, but should be {result.repr_} [{func_name}()]\n" |
| 226 | + f"{f_in}" |
| 227 | + ) |
| 228 | + if math.isnan(result.value): |
| 229 | + assert math.isnan(out), msg |
| 230 | + else: |
| 231 | + assert out == result.value, msg |
| 232 | + else: |
| 233 | + assert math.isfinite(result.value) # sanity check |
| 234 | + assert math.isclose(out, result.value, abs_tol=0.1), ( |
| 235 | + f"{f_out}, but should be roughly {result.repr_}={result.value} " |
| 236 | + f"[{func_name}()]\n" |
| 237 | + f"{f_in}" |
| 238 | + ) |
| 239 | + break |
| 240 | + assume(good_example) |
0 commit comments