diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs index eaeedfd4..44aa89dc 100644 --- a/.git-blame-ignore-revs +++ b/.git-blame-ignore-revs @@ -11,4 +11,6 @@ a37f348ba27b6818e92fda8aee2406c653c671ea ec5a3b4e185c262b0a5f5b1631b84a09f766d80e 9058908b58ce627467ac34e768098a25f5863d31 c80e1823c2e738381ca02f27cea1e2b89dde0ac5 +# gh-402 +bdc84e8316046cb5bdc637067460057eef17d0f1 diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 84bcaa28..f074bd8e 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -737,48 +737,63 @@ def test_abs(ctx, data): if x.dtype in dh.int_dtypes: assume(xp.all(x > dh.dtype_ranges[x.dtype].min)) - out = ctx.func(x) + repro_snippet = ph.format_snippet(f"{ctx.func_name}({x!r})") + try: + out = ctx.func(x) - if x.dtype in dh.complex_dtypes: - assert out.dtype == dh.dtype_components[x.dtype] - else: - ph.assert_dtype(ctx.func_name, in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape(ctx.func_name, out_shape=out.shape, expected=x.shape) - unary_assert_against_refimpl( - ctx.func_name, - x, - out, - abs, # type: ignore - res_stype=float if x.dtype in dh.complex_dtypes else None, - expr_template="abs({})={}", - # filter_=lambda s: ( - # s == float("infinity") or (cmath.isfinite(s) and not ph.is_neg_zero(s)) - # ), - ) + if x.dtype in dh.complex_dtypes: + assert out.dtype == dh.dtype_components[x.dtype] + else: + ph.assert_dtype(ctx.func_name, in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape(ctx.func_name, out_shape=out.shape, expected=x.shape) + unary_assert_against_refimpl( + ctx.func_name, + x, + out, + abs, # type: ignore + res_stype=float if x.dtype in dh.complex_dtypes else None, + expr_template="abs({})={}", + # filter_=lambda s: ( + # s == float("infinity") or (cmath.isfinite(s) and not ph.is_neg_zero(s)) + # ), + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_acos(x): - out = xp.acos(x) - ph.assert_dtype("acos", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("acos", out_shape=out.shape, expected=x.shape) - refimpl = cmath.acos if x.dtype in dh.complex_dtypes else math.acos - filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and -1 <= s <= 1 - unary_assert_against_refimpl( - "acos", x, out, refimpl, filter_=filter_ - ) + repro_snippet = ph.format_snippet(f"xp.acos({x!r})") + try: + out = xp.acos(x) + ph.assert_dtype("acos", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("acos", out_shape=out.shape, expected=x.shape) + refimpl = cmath.acos if x.dtype in dh.complex_dtypes else math.acos + filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and -1 <= s <= 1 + unary_assert_against_refimpl( + "acos", x, out, refimpl, filter_=filter_ + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_acosh(x): - out = xp.acosh(x) - ph.assert_dtype("acosh", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("acosh", out_shape=out.shape, expected=x.shape) - refimpl = cmath.acosh if x.dtype in dh.complex_dtypes else math.acosh - filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and s >= 1 - unary_assert_against_refimpl( - "acosh", x, out, refimpl, filter_=filter_ - ) + repro_snippet = ph.format_snippet(f"xp.acosh({x!r})") + try: + out = xp.acosh(x) + ph.assert_dtype("acosh", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("acosh", out_shape=out.shape, expected=x.shape) + refimpl = cmath.acosh if x.dtype in dh.complex_dtypes else math.acosh + filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and s >= 1 + unary_assert_against_refimpl( + "acosh", x, out, refimpl, filter_=filter_ + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.parametrize("ctx,", make_binary_params("add", dh.numeric_dtypes)) @@ -787,71 +802,101 @@ def test_add(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) right = data.draw(ctx.right_strat, label=ctx.right_sym) - with hh.reject_overflow(): - res = ctx.func(left, right) + repro_snippet = ph.format_snippet(f"{ctx.func_name}({left!r}, {right!r})") + try: + with hh.reject_overflow(): + res = ctx.func(left, right) - binary_param_assert_dtype(ctx, left, right, res) - binary_param_assert_shape(ctx, left, right, res) - binary_param_assert_against_refimpl(ctx, left, right, res, "+", operator.add) + binary_param_assert_dtype(ctx, left, right, res) + binary_param_assert_shape(ctx, left, right, res) + binary_param_assert_against_refimpl(ctx, left, right, res, "+", operator.add) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_asin(x): - out = xp.asin(x) - ph.assert_dtype("asin", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("asin", out_shape=out.shape, expected=x.shape) - refimpl = cmath.asin if x.dtype in dh.complex_dtypes else math.asin - filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and -1 <= s <= 1 - unary_assert_against_refimpl( - "asin", x, out, refimpl, filter_=filter_ - ) + repro_snippet = ph.format_snippet(f"xp.asin({x!r})") + try: + out = xp.asin(x) + ph.assert_dtype("asin", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("asin", out_shape=out.shape, expected=x.shape) + refimpl = cmath.asin if x.dtype in dh.complex_dtypes else math.asin + filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and -1 <= s <= 1 + unary_assert_against_refimpl( + "asin", x, out, refimpl, filter_=filter_ + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_asinh(x): - out = xp.asinh(x) - ph.assert_dtype("asinh", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("asinh", out_shape=out.shape, expected=x.shape) - refimpl = cmath.asinh if x.dtype in dh.complex_dtypes else math.asinh - unary_assert_against_refimpl("asinh", x, out, refimpl) + repro_snippet = ph.format_snippet(f"xp.asinh({x!r})") + try: + out = xp.asinh(x) + ph.assert_dtype("asinh", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("asinh", out_shape=out.shape, expected=x.shape) + refimpl = cmath.asinh if x.dtype in dh.complex_dtypes else math.asinh + unary_assert_against_refimpl("asinh", x, out, refimpl) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_atan(x): - out = xp.atan(x) - ph.assert_dtype("atan", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("atan", out_shape=out.shape, expected=x.shape) - refimpl = cmath.atan if x.dtype in dh.complex_dtypes else math.atan - unary_assert_against_refimpl("atan", x, out, refimpl) + repro_snippet = ph.format_snippet(f"xp.atan({x!r})") + try: + out = xp.atan(x) + ph.assert_dtype("atan", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("atan", out_shape=out.shape, expected=x.shape) + refimpl = cmath.atan if x.dtype in dh.complex_dtypes else math.atan + unary_assert_against_refimpl("atan", x, out, refimpl) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(*hh.two_mutual_arrays(dh.real_float_dtypes)) def test_atan2(x1, x2): - out = xp.atan2(x1, x2) - _assert_correctness_binary( - "atan", - cmath.atan2 if x1.dtype in dh.complex_dtypes else math.atan2, - in_dtypes=[x1.dtype, x2.dtype], - in_shapes=[x1.shape, x2.shape], - in_arrs=[x1, x2], - out=out, - ) + repro_snippet = ph.format_snippet(f"xp.atan2({x1!r}, {x2!r})") + try: + out = xp.atan2(x1, x2) + _assert_correctness_binary( + "atan", + cmath.atan2 if x1.dtype in dh.complex_dtypes else math.atan2, + in_dtypes=[x1.dtype, x2.dtype], + in_shapes=[x1.shape, x2.shape], + in_arrs=[x1, x2], + out=out, + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_atanh(x): - out = xp.atanh(x) - ph.assert_dtype("atanh", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("atanh", out_shape=out.shape, expected=x.shape) - refimpl = cmath.atanh if x.dtype in dh.complex_dtypes else math.atanh - filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and -1 < s < 1 - unary_assert_against_refimpl( - "atanh", - x, - out, - refimpl, - filter_=filter_, - ) + repro_snippet = ph.format_snippet(f"xp.atanh({x!r})") + try: + out = xp.atanh(x) + ph.assert_dtype("atanh", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("atanh", out_shape=out.shape, expected=x.shape) + refimpl = cmath.atanh if x.dtype in dh.complex_dtypes else math.atanh + filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and -1 < s < 1 + unary_assert_against_refimpl( + "atanh", + x, + out, + refimpl, + filter_=filter_, + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.parametrize( @@ -862,15 +907,20 @@ def test_bitwise_and(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) right = data.draw(ctx.right_strat, label=ctx.right_sym) - res = ctx.func(left, right) + repro_snippet = ph.format_snippet(f"{ctx.func_name}({left!r}, {right!r})") + try: + res = ctx.func(left, right) - binary_param_assert_dtype(ctx, left, right, res) - binary_param_assert_shape(ctx, left, right, res) - if left.dtype == xp.bool: - refimpl = operator.and_ - else: - refimpl = lambda l, r: mock_int_dtype(l & r, res.dtype) - binary_param_assert_against_refimpl(ctx, left, right, res, "&", refimpl) + binary_param_assert_dtype(ctx, left, right, res) + binary_param_assert_shape(ctx, left, right, res) + if left.dtype == xp.bool: + refimpl = operator.and_ + else: + refimpl = lambda l, r: mock_int_dtype(l & r, res.dtype) + binary_param_assert_against_refimpl(ctx, left, right, res, "&", refimpl) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.parametrize( @@ -885,14 +935,19 @@ def test_bitwise_left_shift(ctx, data): else: assume(not xp.any(ah.isnegative(right))) - res = ctx.func(left, right) + repro_snippet = ph.format_snippet(f"{ctx.func_name}({left!r}, {right!r})") + try: + res = ctx.func(left, right) - binary_param_assert_dtype(ctx, left, right, res) - binary_param_assert_shape(ctx, left, right, res) - nbits = dh.dtype_nbits[res.dtype] - binary_param_assert_against_refimpl( - ctx, left, right, res, "<<", lambda l, r: l << r if r < nbits else 0 - ) + binary_param_assert_dtype(ctx, left, right, res) + binary_param_assert_shape(ctx, left, right, res) + nbits = dh.dtype_nbits[res.dtype] + binary_param_assert_against_refimpl( + ctx, left, right, res, "<<", lambda l, r: l << r if r < nbits else 0 + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.parametrize( @@ -902,15 +957,20 @@ def test_bitwise_left_shift(ctx, data): def test_bitwise_invert(ctx, data): x = data.draw(ctx.strat, label="x") - out = ctx.func(x) + repro_snippet = ph.format_snippet(f"{ctx.func_name}({x!r})") + try: + out = ctx.func(x) - ph.assert_dtype(ctx.func_name, in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape(ctx.func_name, out_shape=out.shape, expected=x.shape) - if x.dtype == xp.bool: - refimpl = operator.not_ - else: - refimpl = lambda s: mock_int_dtype(~s, x.dtype) - unary_assert_against_refimpl(ctx.func_name, x, out, refimpl, expr_template="~{}={}") + ph.assert_dtype(ctx.func_name, in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape(ctx.func_name, out_shape=out.shape, expected=x.shape) + if x.dtype == xp.bool: + refimpl = operator.not_ + else: + refimpl = lambda s: mock_int_dtype(~s, x.dtype) + unary_assert_against_refimpl(ctx.func_name, x, out, refimpl, expr_template="~{}={}") + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.parametrize( @@ -921,15 +981,20 @@ def test_bitwise_or(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) right = data.draw(ctx.right_strat, label=ctx.right_sym) - res = ctx.func(left, right) + repro_snippet = ph.format_snippet(f"{ctx.func_name}({left!r}, {right!r})") + try: + res = ctx.func(left, right) - binary_param_assert_dtype(ctx, left, right, res) - binary_param_assert_shape(ctx, left, right, res) - if left.dtype == xp.bool: - refimpl = operator.or_ - else: - refimpl = lambda l, r: mock_int_dtype(l | r, res.dtype) - binary_param_assert_against_refimpl(ctx, left, right, res, "|", refimpl) + binary_param_assert_dtype(ctx, left, right, res) + binary_param_assert_shape(ctx, left, right, res) + if left.dtype == xp.bool: + refimpl = operator.or_ + else: + refimpl = lambda l, r: mock_int_dtype(l | r, res.dtype) + binary_param_assert_against_refimpl(ctx, left, right, res, "|", refimpl) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.parametrize( @@ -944,13 +1009,18 @@ def test_bitwise_right_shift(ctx, data): else: assume(not xp.any(ah.isnegative(right))) - res = ctx.func(left, right) + repro_snippet = ph.format_snippet(f"{ctx.func_name}({left!r}, {right!r})") + try: + res = ctx.func(left, right) - binary_param_assert_dtype(ctx, left, right, res) - binary_param_assert_shape(ctx, left, right, res) - binary_param_assert_against_refimpl( - ctx, left, right, res, ">>", lambda l, r: mock_int_dtype(l >> r, res.dtype) - ) + binary_param_assert_dtype(ctx, left, right, res) + binary_param_assert_shape(ctx, left, right, res) + binary_param_assert_against_refimpl( + ctx, left, right, res, ">>", lambda l, r: mock_int_dtype(l >> r, res.dtype) + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.parametrize( @@ -961,24 +1031,32 @@ def test_bitwise_xor(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) right = data.draw(ctx.right_strat, label=ctx.right_sym) - res = ctx.func(left, right) - - binary_param_assert_dtype(ctx, left, right, res) - binary_param_assert_shape(ctx, left, right, res) - if left.dtype == xp.bool: - refimpl = operator.xor - else: - refimpl = lambda l, r: mock_int_dtype(l ^ r, res.dtype) - binary_param_assert_against_refimpl(ctx, left, right, res, "^", refimpl) + repro_snippet = ph.format_snippet(f"{ctx.func_name}({left!r}, {right!r})") + try: + res = ctx.func(left, right) + binary_param_assert_dtype(ctx, left, right, res) + binary_param_assert_shape(ctx, left, right, res) + if left.dtype == xp.bool: + refimpl = operator.xor + else: + refimpl = lambda l, r: mock_int_dtype(l ^ r, res.dtype) + binary_param_assert_against_refimpl(ctx, left, right, res, "^", refimpl) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=hh.real_dtypes, shape=hh.shapes())) def test_ceil(x): - out = xp.ceil(x) - ph.assert_dtype("ceil", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("ceil", out_shape=out.shape, expected=x.shape) - unary_assert_against_refimpl("ceil", x, out, math.ceil, strict_check=True) - + repro_snippet = ph.format_snippet(f"xp.ceil({x!r})") + try: + out = xp.ceil(x) + ph.assert_dtype("ceil", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("ceil", out_shape=out.shape, expected=x.shape) + unary_assert_against_refimpl("ceil", x, out, math.ceil, strict_check=True) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.min_version("2023.12") @given(x=hh.arrays(dtype=hh.real_dtypes, shape=hh.shapes()), data=st.data()) @@ -1009,141 +1087,163 @@ def test_clip(x, data): ("max", max, None)), label="kwargs") - out = xp.clip(x, **kw) - - # min and max do not participate in type promotion - ph.assert_dtype("clip", in_dtype=x.dtype, out_dtype=out.dtype) - - shapes = [x.shape] - if min is not None and not dh.is_scalar(min): - shapes.append(min.shape) - if max is not None and not dh.is_scalar(max): - shapes.append(max.shape) - expected_shape = sh.broadcast_shapes(*shapes) - ph.assert_shape("clip", out_shape=out.shape, expected=expected_shape) - - # This is based on right_scalar_assert_against_refimpl and - # binary_assert_against_refimpl. clip() is currently the only ternary - # elementwise function and the only function that supports arrays and - # scalars. However, where() (in test_searching_functions) is similar - # and if scalar support is added to it, we may want to factor out and - # reuse this logic. - - def refimpl(_x, _min, _max): - # Skip cases where _min and _max are integers whose values do not - # fit in the dtype of _x, since this behavior is unspecified. - if dh.is_int_dtype(x.dtype): - if _min is not None and _min not in dh.dtype_ranges[x.dtype]: - return None - if _max is not None and _max not in dh.dtype_ranges[x.dtype]: - return None - - # If min or max are float64 and x is float32, they will need to be - # downcast to float32. This could result in a round in the wrong - # direction meaning the resulting clipped value might not actually be - # between min and max. This behavior is unspecified, so skip any cases - # where x is within the rounding error of downcasting min or max. - if x.dtype == xp.float32: - if min is not None and not dh.is_scalar(min) and min.dtype == xp.float64 and math.isfinite(_min): - _min_float32 = float(xp.asarray(_min, dtype=xp.float32)) - if math.isinf(_min_float32): - return None - tol = abs(_min - _min_float32) - if math.isclose(_min, _min_float32, abs_tol=tol): - return None - if max is not None and not dh.is_scalar(max) and max.dtype == xp.float64 and math.isfinite(_max): - _max_float32 = float(xp.asarray(_max, dtype=xp.float32)) - if math.isinf(_max_float32): + repro_snippet = ph.format_snippet(f"xp.clip({x!r}, **kw) with {kw = }") + try: + out = xp.clip(x, **kw) + + # min and max do not participate in type promotion + ph.assert_dtype("clip", in_dtype=x.dtype, out_dtype=out.dtype) + + shapes = [x.shape] + if min is not None and not dh.is_scalar(min): + shapes.append(min.shape) + if max is not None and not dh.is_scalar(max): + shapes.append(max.shape) + expected_shape = sh.broadcast_shapes(*shapes) + ph.assert_shape("clip", out_shape=out.shape, expected=expected_shape) + + # This is based on right_scalar_assert_against_refimpl and + # binary_assert_against_refimpl. clip() is currently the only ternary + # elementwise function and the only function that supports arrays and + # scalars. However, where() (in test_searching_functions) is similar + # and if scalar support is added to it, we may want to factor out and + # reuse this logic. + + def refimpl(_x, _min, _max): + # Skip cases where _min and _max are integers whose values do not + # fit in the dtype of _x, since this behavior is unspecified. + if dh.is_int_dtype(x.dtype): + if _min is not None and _min not in dh.dtype_ranges[x.dtype]: return None - tol = abs(_max - _max_float32) - if math.isclose(_max, _max_float32, abs_tol=tol): + if _max is not None and _max not in dh.dtype_ranges[x.dtype]: return None - if (math.isnan(_x) - or (_min is not None and math.isnan(_min)) - or (_max is not None and math.isnan(_max))): - return math.nan - if _min is _max is None: - return _x - if _max is None: - return builtins.max(_x, _min) - if _min is None: - return builtins.min(_x, _max) - return builtins.min(builtins.max(_x, _min), _max) - - stype = dh.get_scalar_type(x.dtype) - min_shape = () if min is None or dh.is_scalar(min) else min.shape - max_shape = () if max is None or dh.is_scalar(max) else max.shape - - for x_idx, min_idx, max_idx, o_idx in sh.iter_indices( - x.shape, min_shape, max_shape, out.shape): - x_val = stype(x[x_idx]) - if min is None or dh.is_scalar(min): - min_val = min - else: - min_val = stype(min[min_idx]) - if max is None or dh.is_scalar(max): - max_val = max - else: - max_val = stype(max[max_idx]) - expected = refimpl(x_val, min_val, max_val) - if expected is None: - continue - out_val = stype(out[o_idx]) - if math.isnan(expected): - assert math.isnan(out_val), ( - f"out[{o_idx}]={out[o_idx]} but should be nan [clip()]\n" - f"x[{x_idx}]={x_val}, min[{min_idx}]={min_val}, max[{max_idx}]={max_val}" - ) - else: - if out.dtype == xp.float32: - # conversion to builtin float is prone to roundoff errors - close_enough = math.isclose(out_val, expected, rel_tol=EPS32) + # If min or max are float64 and x is float32, they will need to be + # downcast to float32. This could result in a round in the wrong + # direction meaning the resulting clipped value might not actually be + # between min and max. This behavior is unspecified, so skip any cases + # where x is within the rounding error of downcasting min or max. + if x.dtype == xp.float32: + if min is not None and not dh.is_scalar(min) and min.dtype == xp.float64 and math.isfinite(_min): + _min_float32 = float(xp.asarray(_min, dtype=xp.float32)) + if math.isinf(_min_float32): + return None + tol = abs(_min - _min_float32) + if math.isclose(_min, _min_float32, abs_tol=tol): + return None + if max is not None and not dh.is_scalar(max) and max.dtype == xp.float64 and math.isfinite(_max): + _max_float32 = float(xp.asarray(_max, dtype=xp.float32)) + if math.isinf(_max_float32): + return None + tol = abs(_max - _max_float32) + if math.isclose(_max, _max_float32, abs_tol=tol): + return None + + if (math.isnan(_x) + or (_min is not None and math.isnan(_min)) + or (_max is not None and math.isnan(_max))): + return math.nan + if _min is _max is None: + return _x + if _max is None: + return builtins.max(_x, _min) + if _min is None: + return builtins.min(_x, _max) + return builtins.min(builtins.max(_x, _min), _max) + + stype = dh.get_scalar_type(x.dtype) + min_shape = () if min is None or dh.is_scalar(min) else min.shape + max_shape = () if max is None or dh.is_scalar(max) else max.shape + + for x_idx, min_idx, max_idx, o_idx in sh.iter_indices( + x.shape, min_shape, max_shape, out.shape): + x_val = stype(x[x_idx]) + if min is None or dh.is_scalar(min): + min_val = min else: - close_enough = out_val == expected - - assert close_enough, ( - f"out[{o_idx}]={out[o_idx]} but should be {expected} [clip()]\n" - f"x[{x_idx}]={x_val}, min[{min_idx}]={min_val}, max[{max_idx}]={max_val}" - ) + min_val = stype(min[min_idx]) + if max is None or dh.is_scalar(max): + max_val = max + else: + max_val = stype(max[max_idx]) + expected = refimpl(x_val, min_val, max_val) + if expected is None: + continue + out_val = stype(out[o_idx]) + if math.isnan(expected): + assert math.isnan(out_val), ( + f"out[{o_idx}]={out[o_idx]} but should be nan [clip()]\n" + f"x[{x_idx}]={x_val}, min[{min_idx}]={min_val}, max[{max_idx}]={max_val}" + ) + else: + if out.dtype == xp.float32: + # conversion to builtin float is prone to roundoff errors + close_enough = math.isclose(out_val, expected, rel_tol=EPS32) + else: + close_enough = out_val == expected + + assert close_enough, ( + f"out[{o_idx}]={out[o_idx]} but should be {expected} [clip()]\n" + f"x[{x_idx}]={x_val}, min[{min_idx}]={min_val}, max[{max_idx}]={max_val}" + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.min_version("2022.12") @pytest.mark.skipif(hh.complex_dtypes.is_empty, reason="no complex data types to draw from") @given(hh.arrays(dtype=hh.complex_dtypes, shape=hh.shapes())) def test_conj(x): - out = xp.conj(x) - ph.assert_dtype("conj", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("conj", out_shape=out.shape, expected=x.shape) - unary_assert_against_refimpl("conj", x, out, operator.methodcaller("conjugate")) + repro_snippet = ph.format_snippet(f"xp.conj({x!r})") + try: + out = xp.conj(x) + ph.assert_dtype("conj", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("conj", out_shape=out.shape, expected=x.shape) + unary_assert_against_refimpl("conj", x, out, operator.methodcaller("conjugate")) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.min_version("2023.12") @given(*hh.two_mutual_arrays(dh.real_float_dtypes)) def test_copysign(x1, x2): - out = xp.copysign(x1, x2) - ph.assert_dtype("copysign", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype) - ph.assert_result_shape("copysign", in_shapes=[x1.shape, x2.shape], out_shape=out.shape) - binary_assert_against_refimpl("copysign", x1, x2, out, math.copysign) - + repro_snippet = ph.format_snippet(f"xp.copysign({x1!r}, {x2!r})") + try: + out = xp.copysign(x1, x2) + ph.assert_dtype("copysign", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype) + ph.assert_result_shape("copysign", in_shapes=[x1.shape, x2.shape], out_shape=out.shape) + binary_assert_against_refimpl("copysign", x1, x2, out, math.copysign) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_cos(x): - out = xp.cos(x) - ph.assert_dtype("cos", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("cos", out_shape=out.shape, expected=x.shape) - refimpl = cmath.cos if x.dtype in dh.complex_dtypes else math.cos - unary_assert_against_refimpl("cos", x, out, refimpl) - + repro_snippet = ph.format_snippet(f"xp.cos({x!r})") + try: + out = xp.cos(x) + ph.assert_dtype("cos", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("cos", out_shape=out.shape, expected=x.shape) + refimpl = cmath.cos if x.dtype in dh.complex_dtypes else math.cos + unary_assert_against_refimpl("cos", x, out, refimpl) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_cosh(x): - out = xp.cosh(x) - ph.assert_dtype("cosh", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("cosh", out_shape=out.shape, expected=x.shape) - refimpl = cmath.cosh if x.dtype in dh.complex_dtypes else math.cosh - unary_assert_against_refimpl("cosh", x, out, refimpl) - + repro_snippet = ph.format_snippet(f"xp.cosh({x!r})") + try: + out = xp.cosh(x) + ph.assert_dtype("cosh", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("cosh", out_shape=out.shape, expected=x.shape) + refimpl = cmath.cosh if x.dtype in dh.complex_dtypes else math.cosh + unary_assert_against_refimpl("cosh", x, out, refimpl) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.parametrize("ctx", make_binary_params("divide", dh.all_float_dtypes)) @given(data=st.data()) @@ -1153,19 +1253,24 @@ def test_divide(ctx, data): if ctx.right_is_scalar: assume # TODO: assume what? - res = ctx.func(left, right) - - binary_param_assert_dtype(ctx, left, right, res) - binary_param_assert_shape(ctx, left, right, res) - binary_param_assert_against_refimpl( - ctx, - left, - right, - res, - "/", - operator.truediv, - filter_=lambda s: cmath.isfinite(s) and s != 0, - ) + repro_snippet = ph.format_snippet(f"{ctx.func_name}({left!r}, {right!r})") + try: + res = ctx.func(left, right) + + binary_param_assert_dtype(ctx, left, right, res) + binary_param_assert_shape(ctx, left, right, res) + binary_param_assert_against_refimpl( + ctx, + left, + right, + res, + "/", + operator.truediv, + filter_=lambda s: cmath.isfinite(s) and s != 0, + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.parametrize("ctx", make_binary_params("equal", dh.all_dtypes)) @@ -1174,72 +1279,91 @@ def test_equal(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) right = data.draw(ctx.right_strat, label=ctx.right_sym) - out = ctx.func(left, right) - - binary_param_assert_dtype(ctx, left, right, out, xp.bool) - binary_param_assert_shape(ctx, left, right, out) - if not ctx.right_is_scalar: - # We manually promote the dtypes as incorrect internal type promotion - # could lead to false positives. For example - # - # >>> xp.equal( - # ... xp.asarray(1.0, dtype=xp.float32), - # ... xp.asarray(1.00000001, dtype=xp.float64), - # ... ) - # - # would erroneously be True if float64 downcasted to float32. - promoted_dtype = dh.promotion_table[left.dtype, right.dtype] - left = xp.astype(left, promoted_dtype) - right = xp.astype(right, promoted_dtype) - binary_param_assert_against_refimpl( - ctx, left, right, out, "==", operator.eq, res_stype=bool - ) + repro_snippet = ph.format_snippet(f"{ctx.func_name}({left!r}, {right!r})") + try: + out = ctx.func(left, right) + + binary_param_assert_dtype(ctx, left, right, out, xp.bool) + binary_param_assert_shape(ctx, left, right, out) + if not ctx.right_is_scalar: + # We manually promote the dtypes as incorrect internal type promotion + # could lead to false positives. For example + # + # >>> xp.equal( + # ... xp.asarray(1.0, dtype=xp.float32), + # ... xp.asarray(1.00000001, dtype=xp.float64), + # ... ) + # + # would erroneously be True if float64 downcasted to float32. + promoted_dtype = dh.promotion_table[left.dtype, right.dtype] + left = xp.astype(left, promoted_dtype) + right = xp.astype(right, promoted_dtype) + binary_param_assert_against_refimpl( + ctx, left, right, out, "==", operator.eq, res_stype=bool + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_exp(x): - out = xp.exp(x) - ph.assert_dtype("exp", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("exp", out_shape=out.shape, expected=x.shape) - refimpl = cmath.exp if x.dtype in dh.complex_dtypes else math.exp - unary_assert_against_refimpl("exp", x, out, refimpl) - + repro_snippet = ph.format_snippet(f"xp.exp({x!r})") + try: + out = xp.exp(x) + ph.assert_dtype("exp", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("exp", out_shape=out.shape, expected=x.shape) + refimpl = cmath.exp if x.dtype in dh.complex_dtypes else math.exp + unary_assert_against_refimpl("exp", x, out, refimpl) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_expm1(x): - out = xp.expm1(x) - ph.assert_dtype("expm1", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("expm1", out_shape=out.shape, expected=x.shape) - if x.dtype in dh.complex_dtypes: - def refimpl(z): - # There's no cmath.expm1. Use - # - # exp(x+yi) - 1 - # = exp(x)exp(yi) - 1 - # = exp(x)(cos(y) + sin(y)i) - 1 - # = (exp(x) - 1)cos(y) + (cos(y) - 1) + exp(x)sin(y)i - # = expm1(x)cos(y) - 2sin(y/2)^2 + exp(x)sin(y)i - # - # where 1 - cos(y) = 2sin(y/2)^2 is used to avoid loss of - # significance near y = 0. - re, im = z.real, z.imag - return math.expm1(re)*math.cos(im) - 2*math.sin(im/2)**2 + 1j*math.exp(re)*math.sin(im) - else: - refimpl = math.expm1 - unary_assert_against_refimpl("expm1", x, out, refimpl) + repro_snippet = ph.format_snippet(f"xp.expm1({x!r})") + try: + out = xp.expm1(x) + ph.assert_dtype("expm1", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("expm1", out_shape=out.shape, expected=x.shape) + if x.dtype in dh.complex_dtypes: + def refimpl(z): + # There's no cmath.expm1. Use + # + # exp(x+yi) - 1 + # = exp(x)exp(yi) - 1 + # = exp(x)(cos(y) + sin(y)i) - 1 + # = (exp(x) - 1)cos(y) + (cos(y) - 1) + exp(x)sin(y)i + # = expm1(x)cos(y) - 2sin(y/2)^2 + exp(x)sin(y)i + # + # where 1 - cos(y) = 2sin(y/2)^2 is used to avoid loss of + # significance near y = 0. + re, im = z.real, z.imag + return math.expm1(re)*math.cos(im) - 2*math.sin(im/2)**2 + 1j*math.exp(re)*math.sin(im) + else: + refimpl = math.expm1 + unary_assert_against_refimpl("expm1", x, out, refimpl) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=hh.real_dtypes, shape=hh.shapes())) def test_floor(x): - out = xp.floor(x) - ph.assert_dtype("floor", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("floor", out_shape=out.shape, expected=x.shape) - if x.dtype in dh.complex_dtypes: - def refimpl(z): - return complex(math.floor(z.real), math.floor(z.imag)) - else: - refimpl = math.floor - unary_assert_against_refimpl("floor", x, out, refimpl, strict_check=True) + repro_snippet = ph.format_snippet(f"xp.floor({x!r})") + try: + out = xp.floor(x) + ph.assert_dtype("floor", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("floor", out_shape=out.shape, expected=x.shape) + if x.dtype in dh.complex_dtypes: + def refimpl(z): + return complex(math.floor(z.real), math.floor(z.imag)) + else: + refimpl = math.floor + unary_assert_against_refimpl("floor", x, out, refimpl, strict_check=True) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.parametrize("ctx", make_binary_params("floor_divide", dh.real_dtypes)) @@ -1254,11 +1378,16 @@ def test_floor_divide(ctx, data): else: assume(not xp.any(right == 0)) - res = ctx.func(left, right) + repro_snippet = ph.format_snippet(f"{ctx.func_name}({left!r}, {right!r})") + try: + res = ctx.func(left, right) - binary_param_assert_dtype(ctx, left, right, res) - binary_param_assert_shape(ctx, left, right, res) - binary_param_assert_against_refimpl(ctx, left, right, res, "//", operator.floordiv) + binary_param_assert_dtype(ctx, left, right, res) + binary_param_assert_shape(ctx, left, right, res) + binary_param_assert_against_refimpl(ctx, left, right, res, "//", operator.floordiv) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.parametrize("ctx", make_binary_params("greater", dh.real_dtypes)) @@ -1267,18 +1396,23 @@ def test_greater(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) right = data.draw(ctx.right_strat, label=ctx.right_sym) - out = ctx.func(left, right) - - binary_param_assert_dtype(ctx, left, right, out, xp.bool) - binary_param_assert_shape(ctx, left, right, out) - if not ctx.right_is_scalar: - # See test_equal note - promoted_dtype = dh.promotion_table[left.dtype, right.dtype] - left = xp.astype(left, promoted_dtype) - right = xp.astype(right, promoted_dtype) - binary_param_assert_against_refimpl( - ctx, left, right, out, ">", operator.gt, res_stype=bool - ) + repro_snippet = ph.format_snippet(f"{ctx.func_name}({left!r}, {right!r})") + try: + out = ctx.func(left, right) + + binary_param_assert_dtype(ctx, left, right, out, xp.bool) + binary_param_assert_shape(ctx, left, right, out) + if not ctx.right_is_scalar: + # See test_equal note + promoted_dtype = dh.promotion_table[left.dtype, right.dtype] + left = xp.astype(left, promoted_dtype) + right = xp.astype(right, promoted_dtype) + binary_param_assert_against_refimpl( + ctx, left, right, out, ">", operator.gt, res_stype=bool + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.parametrize("ctx", make_binary_params("greater_equal", dh.real_dtypes)) @@ -1287,69 +1421,99 @@ def test_greater_equal(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) right = data.draw(ctx.right_strat, label=ctx.right_sym) - out = ctx.func(left, right) - - binary_param_assert_dtype(ctx, left, right, out, xp.bool) - binary_param_assert_shape(ctx, left, right, out) - if not ctx.right_is_scalar: - # See test_equal note - promoted_dtype = dh.promotion_table[left.dtype, right.dtype] - left = xp.astype(left, promoted_dtype) - right = xp.astype(right, promoted_dtype) - binary_param_assert_against_refimpl( - ctx, left, right, out, ">=", operator.ge, res_stype=bool - ) + repro_snippet = ph.format_snippet(f"{ctx.func_name}({left!r}, {right!r})") + try: + out = ctx.func(left, right) + + binary_param_assert_dtype(ctx, left, right, out, xp.bool) + binary_param_assert_shape(ctx, left, right, out) + if not ctx.right_is_scalar: + # See test_equal note + promoted_dtype = dh.promotion_table[left.dtype, right.dtype] + left = xp.astype(left, promoted_dtype) + right = xp.astype(right, promoted_dtype) + binary_param_assert_against_refimpl( + ctx, left, right, out, ">=", operator.ge, res_stype=bool + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.min_version("2023.12") @given(*hh.two_mutual_arrays(dh.real_float_dtypes)) def test_hypot(x1, x2): - out = xp.hypot(x1, x2) - _assert_correctness_binary( - "hypot", - math.hypot, - in_dtypes=[x1.dtype, x2.dtype], - in_shapes=[x1.shape, x2.shape], - in_arrs=[x1, x2], - out=out - ) + repro_snippet = ph.format_snippet(f"xp.hypot({x1!r}, {x2!r})") + try: + out = xp.hypot(x1, x2) + _assert_correctness_binary( + "hypot", + math.hypot, + in_dtypes=[x1.dtype, x2.dtype], + in_shapes=[x1.shape, x2.shape], + in_arrs=[x1, x2], + out=out + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.min_version("2022.12") @pytest.mark.skipif(hh.complex_dtypes.is_empty, reason="no complex data types to draw from") @given(hh.arrays(dtype=hh.complex_dtypes, shape=hh.shapes())) def test_imag(x): - out = xp.imag(x) - ph.assert_dtype("imag", in_dtype=x.dtype, out_dtype=out.dtype, expected=dh.dtype_components[x.dtype]) - ph.assert_shape("imag", out_shape=out.shape, expected=x.shape) - unary_assert_against_refimpl("imag", x, out, operator.attrgetter("imag")) + repro_snippet = ph.format_snippet(f"xp.imag({x!r})") + try: + out = xp.imag(x) + ph.assert_dtype("imag", in_dtype=x.dtype, out_dtype=out.dtype, expected=dh.dtype_components[x.dtype]) + ph.assert_shape("imag", out_shape=out.shape, expected=x.shape) + unary_assert_against_refimpl("imag", x, out, operator.attrgetter("imag")) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=hh.numeric_dtypes, shape=hh.shapes())) def test_isfinite(x): - out = xp.isfinite(x) - ph.assert_dtype("isfinite", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool) - ph.assert_shape("isfinite", out_shape=out.shape, expected=x.shape) - refimpl = cmath.isfinite if x.dtype in dh.complex_dtypes else math.isfinite - unary_assert_against_refimpl("isfinite", x, out, refimpl, res_stype=bool) + repro_snippet = ph.format_snippet(f"xp.isfinite({x!r})") + try: + out = xp.isfinite(x) + ph.assert_dtype("isfinite", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool) + ph.assert_shape("isfinite", out_shape=out.shape, expected=x.shape) + refimpl = cmath.isfinite if x.dtype in dh.complex_dtypes else math.isfinite + unary_assert_against_refimpl("isfinite", x, out, refimpl, res_stype=bool) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=hh.numeric_dtypes, shape=hh.shapes())) def test_isinf(x): - out = xp.isinf(x) - ph.assert_dtype("isfinite", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool) - ph.assert_shape("isinf", out_shape=out.shape, expected=x.shape) - refimpl = cmath.isinf if x.dtype in dh.complex_dtypes else math.isinf - unary_assert_against_refimpl("isinf", x, out, refimpl, res_stype=bool) + repro_snippet = ph.format_snippet(f"xp.isinf({x!r})") + try: + out = xp.isinf(x) + ph.assert_dtype("isfinite", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool) + ph.assert_shape("isinf", out_shape=out.shape, expected=x.shape) + refimpl = cmath.isinf if x.dtype in dh.complex_dtypes else math.isinf + unary_assert_against_refimpl("isinf", x, out, refimpl, res_stype=bool) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=hh.numeric_dtypes, shape=hh.shapes())) def test_isnan(x): - out = xp.isnan(x) - ph.assert_dtype("isnan", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool) - ph.assert_shape("isnan", out_shape=out.shape, expected=x.shape) - refimpl = cmath.isnan if x.dtype in dh.complex_dtypes else math.isnan - unary_assert_against_refimpl("isnan", x, out, refimpl, res_stype=bool) + repro_snippet = ph.format_snippet(f"xp.isnan({x!r})") + try: + out = xp.isnan(x) + ph.assert_dtype("isnan", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool) + ph.assert_shape("isnan", out_shape=out.shape, expected=x.shape) + refimpl = cmath.isnan if x.dtype in dh.complex_dtypes else math.isnan + unary_assert_against_refimpl("isnan", x, out, refimpl, res_stype=bool) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.parametrize("ctx", make_binary_params("less", dh.real_dtypes)) @@ -1358,18 +1522,23 @@ def test_less(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) right = data.draw(ctx.right_strat, label=ctx.right_sym) - out = ctx.func(left, right) - - binary_param_assert_dtype(ctx, left, right, out, xp.bool) - binary_param_assert_shape(ctx, left, right, out) - if not ctx.right_is_scalar: - # See test_equal note - promoted_dtype = dh.promotion_table[left.dtype, right.dtype] - left = xp.astype(left, promoted_dtype) - right = xp.astype(right, promoted_dtype) - binary_param_assert_against_refimpl( - ctx, left, right, out, "<", operator.lt, res_stype=bool - ) + repro_snippet = ph.format_snippet(f"{ctx.func_name}({left!r}, {right!r})") + try: + out = ctx.func(left, right) + + binary_param_assert_dtype(ctx, left, right, out, xp.bool) + binary_param_assert_shape(ctx, left, right, out) + if not ctx.right_is_scalar: + # See test_equal note + promoted_dtype = dh.promotion_table[left.dtype, right.dtype] + left = xp.astype(left, promoted_dtype) + right = xp.astype(right, promoted_dtype) + binary_param_assert_against_refimpl( + ctx, left, right, out, "<", operator.lt, res_stype=bool + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.parametrize("ctx", make_binary_params("less_equal", dh.real_dtypes)) @@ -1378,81 +1547,106 @@ def test_less_equal(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) right = data.draw(ctx.right_strat, label=ctx.right_sym) - out = ctx.func(left, right) - - binary_param_assert_dtype(ctx, left, right, out, xp.bool) - binary_param_assert_shape(ctx, left, right, out) - if not ctx.right_is_scalar: - # See test_equal note - promoted_dtype = dh.promotion_table[left.dtype, right.dtype] - left = xp.astype(left, promoted_dtype) - right = xp.astype(right, promoted_dtype) - binary_param_assert_against_refimpl( - ctx, left, right, out, "<=", operator.le, res_stype=bool - ) + repro_snippet = ph.format_snippet(f"{ctx.func_name}({left!r}, {right!r})") + try: + out = ctx.func(left, right) + + binary_param_assert_dtype(ctx, left, right, out, xp.bool) + binary_param_assert_shape(ctx, left, right, out) + if not ctx.right_is_scalar: + # See test_equal note + promoted_dtype = dh.promotion_table[left.dtype, right.dtype] + left = xp.astype(left, promoted_dtype) + right = xp.astype(right, promoted_dtype) + binary_param_assert_against_refimpl( + ctx, left, right, out, "<=", operator.le, res_stype=bool + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_log(x): - out = xp.log(x) - ph.assert_dtype("log", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("log", out_shape=out.shape, expected=x.shape) - refimpl = cmath.log if x.dtype in dh.complex_dtypes else math.log - filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and s > 0 - unary_assert_against_refimpl( - "log", x, out, refimpl, filter_=filter_ - ) + repro_snippet = ph.format_snippet(f"xp.log({x!r})") + try: + out = xp.log(x) + ph.assert_dtype("log", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("log", out_shape=out.shape, expected=x.shape) + refimpl = cmath.log if x.dtype in dh.complex_dtypes else math.log + filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and s > 0 + unary_assert_against_refimpl( + "log", x, out, refimpl, filter_=filter_ + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_log1p(x): - out = xp.log1p(x) - ph.assert_dtype("log1p", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("log1p", out_shape=out.shape, expected=x.shape) - # There isn't a cmath.log1p, and implementing one isn't straightforward - # (see - # https://stackoverflow.com/questions/78318212/unexpected-behaviour-of-log1p-numpy). - # For now, just use log(1+p) for complex inputs, which should hopefully be - # fine given the very loose numerical tolerances we use. If it isn't, we - # can try using something like a series expansion for small p. - if x.dtype in dh.complex_dtypes: - refimpl = lambda z: cmath.log(1+z) - else: - refimpl = math.log1p - filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and s > -1 - unary_assert_against_refimpl( - "log1p", x, out, refimpl, filter_=filter_ - ) + repro_snippet = ph.format_snippet(f"xp.log1p({x!r})") + try: + out = xp.log1p(x) + ph.assert_dtype("log1p", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("log1p", out_shape=out.shape, expected=x.shape) + # There isn't a cmath.log1p, and implementing one isn't straightforward + # (see + # https://stackoverflow.com/questions/78318212/unexpected-behaviour-of-log1p-numpy). + # For now, just use log(1+p) for complex inputs, which should hopefully be + # fine given the very loose numerical tolerances we use. If it isn't, we + # can try using something like a series expansion for small p. + if x.dtype in dh.complex_dtypes: + refimpl = lambda z: cmath.log(1+z) + else: + refimpl = math.log1p + filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and s > -1 + unary_assert_against_refimpl( + "log1p", x, out, refimpl, filter_=filter_ + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_log2(x): - out = xp.log2(x) - ph.assert_dtype("log2", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("log2", out_shape=out.shape, expected=x.shape) - if x.dtype in dh.complex_dtypes: - refimpl = lambda z: cmath.log(z)/math.log(2) - else: - refimpl = math.log2 - filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and s > 0 - unary_assert_against_refimpl( - "log2", x, out, refimpl, filter_=filter_ - ) + repro_snippet = ph.format_snippet(f"xp.log2({x!r})") + try: + out = xp.log2(x) + ph.assert_dtype("log2", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("log2", out_shape=out.shape, expected=x.shape) + if x.dtype in dh.complex_dtypes: + refimpl = lambda z: cmath.log(z)/math.log(2) + else: + refimpl = math.log2 + filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and s > 0 + unary_assert_against_refimpl( + "log2", x, out, refimpl, filter_=filter_ + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_log10(x): - out = xp.log10(x) - ph.assert_dtype("log10", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("log10", out_shape=out.shape, expected=x.shape) - if x.dtype in dh.complex_dtypes: - refimpl = lambda z: cmath.log(z)/math.log(10) - else: - refimpl = math.log10 - filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and s > 0 - unary_assert_against_refimpl( - "log10", x, out, refimpl, filter_=filter_ - ) + repro_snippet = ph.format_snippet(f"xp.log10({x!r})") + try: + out = xp.log10(x) + ph.assert_dtype("log10", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("log10", out_shape=out.shape, expected=x.shape) + if x.dtype in dh.complex_dtypes: + refimpl = lambda z: cmath.log(z)/math.log(10) + else: + refimpl = math.log10 + filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and s > 0 + unary_assert_against_refimpl( + "log10", x, out, refimpl, filter_=filter_ + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise def logaddexp_refimpl(l: float, r: float) -> float: @@ -1465,85 +1659,120 @@ def logaddexp_refimpl(l: float, r: float) -> float: @pytest.mark.min_version("2023.12") @given(*hh.two_mutual_arrays(dh.real_float_dtypes)) def test_logaddexp(x1, x2): - out = xp.logaddexp(x1, x2) - _assert_correctness_binary( - "logaddexp", - logaddexp_refimpl, - in_dtypes=[x1.dtype, x2.dtype], - in_shapes=[x1.shape, x2.shape], - in_arrs=[x1, x2], - out=out - ) + repro_snippet = ph.format_snippet(f"xp.logaddexp({x1!r}, {x2!r})") + try: + out = xp.logaddexp(x1, x2) + _assert_correctness_binary( + "logaddexp", + logaddexp_refimpl, + in_dtypes=[x1.dtype, x2.dtype], + in_shapes=[x1.shape, x2.shape], + in_arrs=[x1, x2], + out=out + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=xp.bool, shape=hh.shapes())) def test_logical_not(x): - out = xp.logical_not(x) - ph.assert_dtype("logical_not", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("logical_not", out_shape=out.shape, expected=x.shape) - unary_assert_against_refimpl( - "logical_not", x, out, operator.not_, expr_template="(not {})={}" - ) + repro_snippet = ph.format_snippet(f"xp.logical_not({x!r})") + try: + out = xp.logical_not(x) + ph.assert_dtype("logical_not", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("logical_not", out_shape=out.shape, expected=x.shape) + unary_assert_against_refimpl( + "logical_not", x, out, operator.not_, expr_template="(not {})={}" + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(*hh.two_mutual_arrays([xp.bool])) def test_logical_and(x1, x2): - out = xp.logical_and(x1, x2) - _assert_correctness_binary( - "logical_and", - operator.and_, - in_dtypes=[x1.dtype, x2.dtype], - in_shapes=[x1.shape, x2.shape], - in_arrs=[x1, x2], - out=out, - expr_template="({} and {})={}" - ) + repro_snippet = ph.format_snippet(f"xp.logical_and({x1!r}, {x2!r})") + try: + out = xp.logical_and(x1, x2) + _assert_correctness_binary( + "logical_and", + operator.and_, + in_dtypes=[x1.dtype, x2.dtype], + in_shapes=[x1.shape, x2.shape], + in_arrs=[x1, x2], + out=out, + expr_template="({} and {})={}" + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(*hh.two_mutual_arrays([xp.bool])) def test_logical_or(x1, x2): - out = xp.logical_or(x1, x2) - _assert_correctness_binary( - "logical_or", - operator.or_, - in_dtypes=[x1.dtype, x2.dtype], - in_shapes=[x1.shape, x2.shape], - in_arrs=[x1, x2], - out=out, - expr_template="({} or {})={}" - ) + repro_snippet = ph.format_snippet(f"xp.logical_or({x1!r}, {x2!r})") + try: + out = xp.logical_or(x1, x2) + _assert_correctness_binary( + "logical_or", + operator.or_, + in_dtypes=[x1.dtype, x2.dtype], + in_shapes=[x1.shape, x2.shape], + in_arrs=[x1, x2], + out=out, + expr_template="({} or {})={}" + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(*hh.two_mutual_arrays([xp.bool])) def test_logical_xor(x1, x2): - out = xp.logical_xor(x1, x2) - _assert_correctness_binary( - "logical_xor", - operator.xor, - in_dtypes=[x1.dtype, x2.dtype], - in_shapes=[x1.shape, x2.shape], - in_arrs=[x1, x2], - out=out, - expr_template="({} ^ {})={}" - ) + repro_snippet = ph.format_snippet(f"xp.logical_xor({x1!r}, {x2!r})") + try: + out = xp.logical_xor(x1, x2) + _assert_correctness_binary( + "logical_xor", + operator.xor, + in_dtypes=[x1.dtype, x2.dtype], + in_shapes=[x1.shape, x2.shape], + in_arrs=[x1, x2], + out=out, + expr_template="({} ^ {})={}" + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.min_version("2023.12") @given(*hh.two_mutual_arrays(dh.real_float_dtypes)) def test_maximum(x1, x2): - out = xp.maximum(x1, x2) - _assert_correctness_binary( - "maximum", max, [x1.dtype, x2.dtype], [x1.shape, x2.shape], (x1, x2), out, strict_check=True - ) + repro_snippet = ph.format_snippet(f"xp.maximum({x1!r}, {x2!r})") + try: + out = xp.maximum(x1, x2) + _assert_correctness_binary( + "maximum", max, [x1.dtype, x2.dtype], [x1.shape, x2.shape], (x1, x2), out, strict_check=True + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.min_version("2023.12") @given(*hh.two_mutual_arrays(dh.real_float_dtypes)) def test_minimum(x1, x2): - out = xp.minimum(x1, x2) - _assert_correctness_binary( - "minimum", min, [x1.dtype, x2.dtype], [x1.shape, x2.shape], (x1, x2), out, strict_check=True - ) + repro_snippet = ph.format_snippet(f"xp.minumum({x1!r}, {x2!r})") + try: + out = xp.minimum(x1, x2) + _assert_correctness_binary( + "minimum", min, [x1.dtype, x2.dtype], [x1.shape, x2.shape], (x1, x2), out, strict_check=True + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.parametrize("ctx", make_binary_params("multiply", dh.numeric_dtypes)) @@ -1552,11 +1781,16 @@ def test_multiply(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) right = data.draw(ctx.right_strat, label=ctx.right_sym) - res = ctx.func(left, right) + repro_snippet = ph.format_snippet(f"{ctx.func_name}({left!r}, {right!r})") + try: + res = ctx.func(left, right) - binary_param_assert_dtype(ctx, left, right, res) - binary_param_assert_shape(ctx, left, right, res) - binary_param_assert_against_refimpl(ctx, left, right, res, "*", operator.mul) + binary_param_assert_dtype(ctx, left, right, res) + binary_param_assert_shape(ctx, left, right, res) + binary_param_assert_against_refimpl(ctx, left, right, res, "*", operator.mul) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise # TODO: clarify if uints are acceptable, adjust accordingly @@ -1568,14 +1802,18 @@ def test_negative(ctx, data): if x.dtype in dh.int_dtypes: assume(xp.all(x > dh.dtype_ranges[x.dtype].min)) - out = ctx.func(x) - - ph.assert_dtype(ctx.func_name, in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape(ctx.func_name, out_shape=out.shape, expected=x.shape) - unary_assert_against_refimpl( - ctx.func_name, x, out, operator.neg, expr_template="-({})={}" # type: ignore - ) + repro_snippet = ph.format_snippet(f"{ctx.func_name}({x!r}") + try: + out = ctx.func(x) + ph.assert_dtype(ctx.func_name, in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape(ctx.func_name, out_shape=out.shape, expected=x.shape) + unary_assert_against_refimpl( + ctx.func_name, x, out, operator.neg, expr_template="-({})={}" # type: ignore + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.parametrize("ctx", make_binary_params("not_equal", dh.all_dtypes)) @given(data=st.data()) @@ -1583,18 +1821,23 @@ def test_not_equal(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) right = data.draw(ctx.right_strat, label=ctx.right_sym) - out = ctx.func(left, right) - - binary_param_assert_dtype(ctx, left, right, out, xp.bool) - binary_param_assert_shape(ctx, left, right, out) - if not ctx.right_is_scalar: - # See test_equal note - promoted_dtype = dh.promotion_table[left.dtype, right.dtype] - left = xp.astype(left, promoted_dtype) - right = xp.astype(right, promoted_dtype) - binary_param_assert_against_refimpl( - ctx, left, right, out, "!=", operator.ne, res_stype=bool - ) + repro_snippet = ph.format_snippet(f"{ctx.func_name}({left!r}, {right!r})") + try: + out = ctx.func(left, right) + + binary_param_assert_dtype(ctx, left, right, out, xp.bool) + binary_param_assert_shape(ctx, left, right, out) + if not ctx.right_is_scalar: + # See test_equal note + promoted_dtype = dh.promotion_table[left.dtype, right.dtype] + left = xp.astype(left, promoted_dtype) + right = xp.astype(right, promoted_dtype) + binary_param_assert_against_refimpl( + ctx, left, right, out, "!=", operator.ne, res_stype=bool + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.min_version("2024.12") @@ -1607,26 +1850,37 @@ def test_nextafter(shapes, dtype, data): x1 = data.draw(hh.arrays(dtype=dtype, shape=shapes[0]), label="x1") x2 = data.draw(hh.arrays(dtype=dtype, shape=shapes[0]), label="x2") - out = xp.nextafter(x1, x2) - _assert_correctness_binary( - "nextafter", - math.nextafter, - in_dtypes=[x1.dtype, x2.dtype], - in_shapes=[x1.shape, x2.shape], - in_arrs=[x1, x2], - out=out - ) + repro_snippet = ph.format_snippet(f"xp.nextafter({x1!r}, {x2!r})") + try: + out = xp.nextafter(x1, x2) + _assert_correctness_binary( + "nextafter", + math.nextafter, + in_dtypes=[x1.dtype, x2.dtype], + in_shapes=[x1.shape, x2.shape], + in_arrs=[x1, x2], + out=out + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise + @pytest.mark.parametrize("ctx", make_unary_params("positive", dh.numeric_dtypes)) @given(data=st.data()) def test_positive(ctx, data): x = data.draw(ctx.strat, label="x") - out = ctx.func(x) + repro_snippet = ph.format_snippet(f"{ctx.func_name}({x!r})") + try: + out = ctx.func(x) - ph.assert_dtype(ctx.func_name, in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape(ctx.func_name, out_shape=out.shape, expected=x.shape) - ph.assert_array_elements(ctx.func_name, out=out, expected=x) + ph.assert_dtype(ctx.func_name, in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape(ctx.func_name, out_shape=out.shape, expected=x.shape) + ph.assert_array_elements(ctx.func_name, out=out, expected=x) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.parametrize("ctx", make_binary_params("pow", dh.numeric_dtypes)) @@ -1641,38 +1895,53 @@ def test_pow(ctx, data): if dh.is_int_dtype(right.dtype): assume(xp.all(right >= 0)) - with hh.reject_overflow(): - res = ctx.func(left, right) + repro_snippet = ph.format_snippet(f"{ctx.func_name}({left!r}, {right!r})") + try: + with hh.reject_overflow(): + res = ctx.func(left, right) - binary_param_assert_dtype(ctx, left, right, res) - binary_param_assert_shape(ctx, left, right, res) - # Values testing pow is too finicky + binary_param_assert_dtype(ctx, left, right, res) + binary_param_assert_shape(ctx, left, right, res) + # Values testing pow is too finicky + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.min_version("2022.12") @pytest.mark.skipif(hh.complex_dtypes.is_empty, reason="no complex data types to draw from") @given(hh.arrays(dtype=hh.complex_dtypes, shape=hh.shapes())) def test_real(x): - out = xp.real(x) - ph.assert_dtype("real", in_dtype=x.dtype, out_dtype=out.dtype, expected=dh.dtype_components[x.dtype]) - ph.assert_shape("real", out_shape=out.shape, expected=x.shape) - unary_assert_against_refimpl("real", x, out, operator.attrgetter("real")) + repro_snippet = ph.format_snippet(f"xp.real({x!r})") + try: + out = xp.real(x) + ph.assert_dtype("real", in_dtype=x.dtype, out_dtype=out.dtype, expected=dh.dtype_components[x.dtype]) + ph.assert_shape("real", out_shape=out.shape, expected=x.shape) + unary_assert_against_refimpl("real", x, out, operator.attrgetter("real")) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.min_version("2024.12") @given(hh.arrays(dtype=hh.floating_dtypes, shape=hh.shapes(), elements=finite_kw)) def test_reciprocal(x): - out = xp.reciprocal(x) - ph.assert_dtype("reciprocal", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("reciprocal", out_shape=out.shape, expected=x.shape) - refimpl = lambda x: 1.0 / x - unary_assert_against_refimpl( - "reciprocal", - x, - out, - refimpl, - strict_check=True, - ) + repro_snippet = ph.format_snippet(f"xp.reciprocal({x!r})") + try: + out = xp.reciprocal(x) + ph.assert_dtype("reciprocal", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("reciprocal", out_shape=out.shape, expected=x.shape) + refimpl = lambda x: 1.0 / x + unary_assert_against_refimpl( + "reciprocal", + x, + out, + refimpl, + strict_check=True, + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.skip(reason="flaky") @@ -1686,88 +1955,128 @@ def test_remainder(ctx, data): else: assume(not xp.any(right == 0)) - res = ctx.func(left, right) + repro_snippet = ph.format_snippet(f"{ctx.func_name}({left!r}, {right!r})") + try: + res = ctx.func(left, right) - binary_param_assert_dtype(ctx, left, right, res) - binary_param_assert_shape(ctx, left, right, res) - binary_param_assert_against_refimpl(ctx, left, right, res, "%", operator.mod) + binary_param_assert_dtype(ctx, left, right, res) + binary_param_assert_shape(ctx, left, right, res) + binary_param_assert_against_refimpl(ctx, left, right, res, "%", operator.mod) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=hh.numeric_dtypes, shape=hh.shapes())) def test_round(x): - out = xp.round(x) - ph.assert_dtype("round", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("round", out_shape=out.shape, expected=x.shape) - if x.dtype in dh.complex_dtypes: - refimpl = lambda z: complex(round(z.real), round(z.imag)) - else: - refimpl = round - unary_assert_against_refimpl("round", x, out, refimpl, strict_check=True) + repro_snippet = ph.format_snippet(f"xp.round({x!r})") + try: + out = xp.round(x) + ph.assert_dtype("round", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("round", out_shape=out.shape, expected=x.shape) + if x.dtype in dh.complex_dtypes: + refimpl = lambda z: complex(round(z.real), round(z.imag)) + else: + refimpl = round + unary_assert_against_refimpl("round", x, out, refimpl, strict_check=True) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.min_version("2023.12") @given(hh.arrays(dtype=hh.real_floating_dtypes, shape=hh.shapes())) def test_signbit(x): - out = xp.signbit(x) - ph.assert_dtype("signbit", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool) - ph.assert_shape("signbit", out_shape=out.shape, expected=x.shape) - refimpl = lambda x: math.copysign(1.0, x) < 0 - unary_assert_against_refimpl("round", x, out, refimpl, strict_check=True) + repro_snippet = ph.format_snippet(f"xp.signbit({x!r})") + try: + out = xp.signbit(x) + ph.assert_dtype("signbit", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool) + ph.assert_shape("signbit", out_shape=out.shape, expected=x.shape) + refimpl = lambda x: math.copysign(1.0, x) < 0 + unary_assert_against_refimpl("round", x, out, refimpl, strict_check=True) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=hh.numeric_dtypes, shape=hh.shapes(), elements=finite_kw)) def test_sign(x): - out = xp.sign(x) - ph.assert_dtype("sign", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("sign", out_shape=out.shape, expected=x.shape) - refimpl = lambda x: x / abs(x) if x != 0 else 0 - unary_assert_against_refimpl( - "sign", - x, - out, - refimpl, - strict_check=True, - ) + repro_snippet = ph.format_snippet(f"xp.sign({x!r})") + try: + out = xp.sign(x) + ph.assert_dtype("sign", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("sign", out_shape=out.shape, expected=x.shape) + refimpl = lambda x: x / abs(x) if x != 0 else 0 + unary_assert_against_refimpl( + "sign", + x, + out, + refimpl, + strict_check=True, + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_sin(x): - out = xp.sin(x) - ph.assert_dtype("sin", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("sin", out_shape=out.shape, expected=x.shape) - refimpl = cmath.sin if x.dtype in dh.complex_dtypes else math.sin - unary_assert_against_refimpl("sin", x, out, refimpl) + repro_snippet = ph.format_snippet(f"xp.sin({x!r})") + try: + out = xp.sin(x) + ph.assert_dtype("sin", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("sin", out_shape=out.shape, expected=x.shape) + refimpl = cmath.sin if x.dtype in dh.complex_dtypes else math.sin + unary_assert_against_refimpl("sin", x, out, refimpl) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_sinh(x): - out = xp.sinh(x) - ph.assert_dtype("sinh", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("sinh", out_shape=out.shape, expected=x.shape) - refimpl = cmath.sinh if x.dtype in dh.complex_dtypes else math.sinh - unary_assert_against_refimpl("sinh", x, out, refimpl) + repro_snippet = ph.format_snippet(f"xp.sinh({x!r})") + try: + out = xp.sinh(x) + ph.assert_dtype("sinh", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("sinh", out_shape=out.shape, expected=x.shape) + refimpl = cmath.sinh if x.dtype in dh.complex_dtypes else math.sinh + unary_assert_against_refimpl("sinh", x, out, refimpl) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=hh.numeric_dtypes, shape=hh.shapes())) def test_square(x): - out = xp.square(x) - ph.assert_dtype("square", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("square", out_shape=out.shape, expected=x.shape) - unary_assert_against_refimpl( - "square", x, out, lambda s: s*s, expr_template="{}²={}" - ) + repro_snippet = ph.format_snippet(f"xp.square({x!r})") + try: + out = xp.square(x) + ph.assert_dtype("square", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("square", out_shape=out.shape, expected=x.shape) + unary_assert_against_refimpl( + "square", x, out, lambda s: s*s, expr_template="{}²={}" + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_sqrt(x): - out = xp.sqrt(x) - ph.assert_dtype("sqrt", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("sqrt", out_shape=out.shape, expected=x.shape) - refimpl = cmath.sqrt if x.dtype in dh.complex_dtypes else math.sqrt - filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and s >= 0 - unary_assert_against_refimpl( - "sqrt", x, out, refimpl, filter_=filter_ - ) + repro_snippet = ph.format_snippet(f"xp.sqrt({x!r})") + try: + out = xp.sqrt(x) + ph.assert_dtype("sqrt", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("sqrt", out_shape=out.shape, expected=x.shape) + refimpl = cmath.sqrt if x.dtype in dh.complex_dtypes else math.sqrt + filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and s >= 0 + unary_assert_against_refimpl( + "sqrt", x, out, refimpl, filter_=filter_ + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.parametrize("ctx", make_binary_params("subtract", dh.numeric_dtypes)) @@ -1776,50 +2085,73 @@ def test_subtract(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) right = data.draw(ctx.right_strat, label=ctx.right_sym) - with hh.reject_overflow(): - res = ctx.func(left, right) + repro_snippet = ph.format_snippet(f"{ctx.func_name}({left!r}, {right!r})") + try: + with hh.reject_overflow(): + res = ctx.func(left, right) - binary_param_assert_dtype(ctx, left, right, res) - binary_param_assert_shape(ctx, left, right, res) - binary_param_assert_against_refimpl(ctx, left, right, res, "-", operator.sub) + binary_param_assert_dtype(ctx, left, right, res) + binary_param_assert_shape(ctx, left, right, res) + binary_param_assert_against_refimpl(ctx, left, right, res, "-", operator.sub) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_tan(x): - out = xp.tan(x) - ph.assert_dtype("tan", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("tan", out_shape=out.shape, expected=x.shape) - refimpl = cmath.tan if x.dtype in dh.complex_dtypes else math.tan - unary_assert_against_refimpl("tan", x, out, refimpl) + repro_snippet = ph.format_snippet(f"xp.tan({x!r})") + try: + out = xp.tan(x) + ph.assert_dtype("tan", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("tan", out_shape=out.shape, expected=x.shape) + refimpl = cmath.tan if x.dtype in dh.complex_dtypes else math.tan + unary_assert_against_refimpl("tan", x, out, refimpl) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_tanh(x): - out = xp.tanh(x) - ph.assert_dtype("tanh", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("tanh", out_shape=out.shape, expected=x.shape) - refimpl = cmath.tanh if x.dtype in dh.complex_dtypes else math.tanh - unary_assert_against_refimpl("tanh", x, out, refimpl) - + repro_snippet = ph.format_snippet(f"xp.tanh({x!r})") + try: + out = xp.tanh(x) + ph.assert_dtype("tanh", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("tanh", out_shape=out.shape, expected=x.shape) + refimpl = cmath.tanh if x.dtype in dh.complex_dtypes else math.tanh + unary_assert_against_refimpl("tanh", x, out, refimpl) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=hh.real_dtypes, shape=xps.array_shapes())) def test_trunc(x): - out = xp.trunc(x) - ph.assert_dtype("trunc", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("trunc", out_shape=out.shape, expected=x.shape) - unary_assert_against_refimpl("trunc", x, out, math.trunc, strict_check=True) - + repro_snippet = ph.format_snippet(f"xp.trunc({x!r})") + try: + out = xp.trunc(x) + ph.assert_dtype("trunc", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("trunc", out_shape=out.shape, expected=x.shape) + unary_assert_against_refimpl("trunc", x, out, math.trunc, strict_check=True) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise def _check_binary_with_scalars(func_data, x1x2): x1, x2 = x1x2 func_name, refimpl, kwds, expected_dtype = func_data func = getattr(xp, func_name) - out = func(x1, x2) - in_dtypes, in_shapes, (x1a, x2a) = _convert_scalars_helper(x1, x2) - _assert_correctness_binary( - func_name, refimpl, in_dtypes, in_shapes, (x1a, x2a), out, expected_dtype, **kwds - ) + repro_snippet = ph.format_snippet(f"xp.{func_name}({x1!r}, {x2!r})") + try: + out = func(x1, x2) + in_dtypes, in_shapes, (x1a, x2a) = _convert_scalars_helper(x1, x2) + _assert_correctness_binary( + func_name, refimpl, in_dtypes, in_shapes, (x1a, x2a), out, expected_dtype, **kwds + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise def _filter_zero(x): return x != 0 if dh.is_scalar(x) else (not xp.any(x == 0)) @@ -1940,16 +2272,19 @@ def test_where_with_scalars(x1x2, data): condition = data.draw(hh.arrays(shape=shape, dtype=xp.bool)) - out = xp.where(condition, x1, x2) - - assert out.dtype == dtype, f"where: got {out.dtype = } for {dtype=}, {x1=} and {x2=}" - assert out.shape == shape, f"where: got {out.shape = } for {shape=}, {x1=} and {x2=}" - - # value test - for idx in sh.ndindex(shape): - if condition[idx]: - assert out[idx] == x1_arr[idx] - else: - assert out[idx] == x2_arr[idx] + repro_snippet = ph.format_snippet(f"xp.where({condition!r}, {x1!r}, {x2!r})") + try: + out = xp.where(condition, x1, x2) + assert out.dtype == dtype, f"where: got {out.dtype = } for {dtype=}, {x1=} and {x2=}" + assert out.shape == shape, f"where: got {out.shape = } for {shape=}, {x1=} and {x2=}" + # value test + for idx in sh.ndindex(shape): + if condition[idx]: + assert out[idx] == x1_arr[idx] + else: + assert out[idx] == x2_arr[idx] + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise