|
| 1 | +"""Targeted edge-path tests for low-frequency runtime branches.""" |
| 2 | + |
| 3 | +from __future__ import annotations |
| 4 | + |
| 5 | +import numpy as np |
| 6 | + |
| 7 | +import shapix._memo as memo_mod |
| 8 | +from beartype import BeartypeConf |
| 9 | +from shapix import N |
| 10 | +from shapix._array_types import _StructChecker, _to_shape_spec, make_array_type |
| 11 | +from shapix._dtypes import FLOAT32, extract_dtype_str |
| 12 | +from shapix._memo import ShapeMemo, bindings_str, get_memo |
| 13 | +from shapix._shape import FixedDim, NamedDim, VariadicDim, check_shape |
| 14 | +from shapix._tree import _TreeFactory |
| 15 | +from shapix.claw import shapix_this_package |
| 16 | +from shapix.numpy import F32 |
| 17 | + |
| 18 | + |
| 19 | +class TestArrayFactoryEdges: |
| 20 | + def test_shape_checker_rejects_obj_with_dtype_but_no_shape(self) -> None: |
| 21 | + class DtypeOnly: |
| 22 | + dtype = np.dtype(np.float32) |
| 23 | + |
| 24 | + checker = _StructChecker(FLOAT32, (NamedDim("N"),)) |
| 25 | + assert checker(DtypeOnly()) is False |
| 26 | + |
| 27 | + def test_array_factory_repr(self) -> None: |
| 28 | + factory = make_array_type(np.ndarray, FLOAT32) |
| 29 | + assert repr(factory) == "Float32Array" |
| 30 | + |
| 31 | + def test_to_shape_spec_non_dimension_falls_back_to_named(self) -> None: |
| 32 | + class Token: |
| 33 | + def __str__(self) -> str: |
| 34 | + return "Token" |
| 35 | + |
| 36 | + spec = _to_shape_spec((Token(),)) |
| 37 | + assert spec == (NamedDim("Token"),) |
| 38 | + |
| 39 | + |
| 40 | +class TestShapeEdgeBranches: |
| 41 | + def test_variadic_prefix_error_is_returned(self) -> None: |
| 42 | + memo = ShapeMemo() |
| 43 | + spec = (FixedDim(2), VariadicDim("B"), NamedDim("C")) |
| 44 | + err = check_shape((3, 5, 7), spec, memo) |
| 45 | + assert "expected 2 but got 3" in err |
| 46 | + |
| 47 | + def test_variadic_suffix_error_is_returned(self) -> None: |
| 48 | + memo = ShapeMemo() |
| 49 | + spec = (NamedDim("N"), VariadicDim("B"), FixedDim(7)) |
| 50 | + err = check_shape((3, 5, 6), spec, memo) |
| 51 | + assert "expected 7 but got 6" in err |
| 52 | + |
| 53 | + def test_variadic_broadcast_failure_path(self) -> None: |
| 54 | + memo = ShapeMemo() |
| 55 | + spec = (VariadicDim("B", broadcastable=True), NamedDim("C")) |
| 56 | + |
| 57 | + assert check_shape((2, 3, 4), spec, memo) == "" |
| 58 | + err = check_shape((5, 4, 4), spec, memo) |
| 59 | + assert "cannot broadcast" in err |
| 60 | + |
| 61 | + |
| 62 | +class TestDtypeEdgeBranches: |
| 63 | + def test_extract_dtype_str_tensorflow_style_dtype(self) -> None: |
| 64 | + class FakeDType: |
| 65 | + as_numpy_dtype = np.float32 |
| 66 | + |
| 67 | + class FakeTensor: |
| 68 | + dtype = FakeDType() |
| 69 | + |
| 70 | + assert extract_dtype_str(FakeTensor()) == "float32" |
| 71 | + |
| 72 | + |
| 73 | +class TestMemoEdgeBranches: |
| 74 | + def test_get_memo_falls_back_if_frame_lookup_fails(self, monkeypatch) -> None: |
| 75 | + def _raise_value_error(_depth: int = 0) -> None: |
| 76 | + raise ValueError |
| 77 | + |
| 78 | + monkeypatch.setattr(memo_mod.sys, "_getframe", _raise_value_error) |
| 79 | + assert isinstance(get_memo(_depth=999), ShapeMemo) |
| 80 | + |
| 81 | + def test_bindings_str_formats_single_and_variadic_bindings(self) -> None: |
| 82 | + memo = ShapeMemo(single={"N": 3}, variadic={"B": (False, (2, 4))}) |
| 83 | + formatted = bindings_str(memo) |
| 84 | + assert "N=3" in formatted |
| 85 | + assert "~B=(2, 4)" in formatted |
| 86 | + |
| 87 | + |
| 88 | +class TestTreeFactoryEdgeBranches: |
| 89 | + def test_tuple_with_only_leaf_type(self) -> None: |
| 90 | + tree = _TreeFactory(object, name="Tree") |
| 91 | + hint = tree[(F32[N],)] |
| 92 | + assert hasattr(hint, "__metadata__") |
| 93 | + |
| 94 | + def test_tree_factory_repr(self) -> None: |
| 95 | + tree = _TreeFactory(object, name="MyTree") |
| 96 | + assert repr(tree) == "MyTree" |
| 97 | + |
| 98 | + def test_tree_factory_empty_tuple_raises(self) -> None: |
| 99 | + tree = _TreeFactory(object, name="Tree") |
| 100 | + import pytest |
| 101 | + |
| 102 | + with pytest.raises(TypeError, match="at least a leaf type"): |
| 103 | + tree[()] |
| 104 | + |
| 105 | + def test_tree_factory_single_item_not_tuple(self) -> None: |
| 106 | + tree = _TreeFactory(object, name="Tree") |
| 107 | + hint = tree[F32[N]] |
| 108 | + assert hasattr(hint, "__metadata__") |
| 109 | + |
| 110 | + def test_tree_single_ellipsis_raises(self) -> None: |
| 111 | + tree = _TreeFactory(object, name="Tree") |
| 112 | + import pytest |
| 113 | + |
| 114 | + with pytest.raises(TypeError, match="At least one structure name"): |
| 115 | + tree[F32[N], ...] |
| 116 | + |
| 117 | + def test_tree_both_ellipsis_raises(self) -> None: |
| 118 | + from shapix._tree import Structure |
| 119 | + |
| 120 | + tree = _TreeFactory(object, name="Tree") |
| 121 | + X = Structure("X") |
| 122 | + import pytest |
| 123 | + |
| 124 | + with pytest.raises(TypeError, match="Cannot have"): |
| 125 | + tree[F32[N], ..., X, ...] |
| 126 | + |
| 127 | + |
| 128 | +class TestArrayFactoryShapeSpecEdges: |
| 129 | + def test_scalar_dim_to_shape_spec(self) -> None: |
| 130 | + from shapix._dimensions import Scalar |
| 131 | + |
| 132 | + specs = _to_shape_spec((Scalar,)) |
| 133 | + assert specs == () |
| 134 | + |
| 135 | + def test_mixed_type_shape_spec(self) -> None: |
| 136 | + from shapix._dimensions import Dimension |
| 137 | + from shapix._shape import ANONYMOUS_VARIADIC |
| 138 | + |
| 139 | + specs = _to_shape_spec((3, Dimension("N"), Ellipsis)) |
| 140 | + assert specs[0] == FixedDim(3) |
| 141 | + assert specs[1] == NamedDim("N") |
| 142 | + assert specs[2] is ANONYMOUS_VARIADIC |
| 143 | + |
| 144 | + def test_memo_restored_on_failure(self) -> None: |
| 145 | + checker = _StructChecker(FLOAT32, (NamedDim("N"),)) |
| 146 | + arr = np.ones((10,), dtype=np.float32) |
| 147 | + # Should fail because N=5 != 10, and memo should be restored |
| 148 | + from shapix._memo import push_memo, pop_memo |
| 149 | + |
| 150 | + push_memo_ref = push_memo() |
| 151 | + push_memo_ref.single["N"] = 5 |
| 152 | + result = checker(arr) |
| 153 | + assert result is False |
| 154 | + pop_memo() |
| 155 | + |
| 156 | + |
| 157 | +class TestArrayLikeCheckerEdges: |
| 158 | + def test_casting_no_uses_strict_match(self) -> None: |
| 159 | + """casting='no' should only accept exact dtype match.""" |
| 160 | + from shapix._array_types import _ArrayLikeChecker |
| 161 | + |
| 162 | + checker = _ArrayLikeChecker(FLOAT32, (NamedDim("X"),), casting="no", name="F32Like") |
| 163 | + assert checker(np.ones(3, dtype=np.float32)) is True |
| 164 | + assert checker(np.ones(3, dtype=np.float64)) is False |
| 165 | + |
| 166 | + def test_wildcard_dtype_accepts_anything(self) -> None: |
| 167 | + """SHAPED's wildcard '*' in allowed should accept any dtype.""" |
| 168 | + from shapix._array_types import _ArrayLikeChecker |
| 169 | + from shapix._dtypes import SHAPED |
| 170 | + |
| 171 | + checker = _ArrayLikeChecker( |
| 172 | + SHAPED, (NamedDim("X"),), casting="same_kind", name="ShapedLike" |
| 173 | + ) |
| 174 | + assert checker(np.ones(3, dtype=np.float32)) is True |
| 175 | + assert checker(np.ones(3, dtype=np.int64)) is True |
| 176 | + assert checker(np.ones(3, dtype=np.bool_)) is True |
| 177 | + |
| 178 | + def test_asarray_failure_returns_false(self) -> None: |
| 179 | + """Objects that can't be converted to array should return False.""" |
| 180 | + from shapix._array_types import _ArrayLikeChecker |
| 181 | + |
| 182 | + checker = _ArrayLikeChecker( |
| 183 | + FLOAT32, (NamedDim("X"),), casting="same_kind", name="F32Like" |
| 184 | + ) |
| 185 | + # object() has no shape/dtype → takes slow path → np.asarray(object()) succeeds |
| 186 | + # but a custom __array__ that raises should trigger the except branch |
| 187 | + |
| 188 | + class Unconvertible: |
| 189 | + def __array__(self, *_a: object, **_kw: object) -> None: # noqa: PLW3201 |
| 190 | + raise TypeError("nope") |
| 191 | + |
| 192 | + assert checker(Unconvertible()) is False |
| 193 | + |
| 194 | + def test_asarray_slow_path_bad_dtype(self) -> None: |
| 195 | + """Slow-path object whose dtype string is empty should fail.""" |
| 196 | + from shapix._array_types import _ArrayLikeChecker |
| 197 | + |
| 198 | + # A plain list goes through slow path (no .shape/.dtype), converts to array |
| 199 | + # but complex128 can't cast to float32 under same_kind |
| 200 | + assert ( |
| 201 | + _ArrayLikeChecker( |
| 202 | + FLOAT32, (NamedDim("X"),), casting="same_kind", name="F32Like" |
| 203 | + )([(1 + 2j)]) |
| 204 | + is False |
| 205 | + ) |
| 206 | + |
| 207 | + def test_arraylike_memo_restore_on_shape_failure(self) -> None: |
| 208 | + """ArrayLikeChecker should restore memo on shape mismatch.""" |
| 209 | + from shapix._array_types import _ArrayLikeChecker |
| 210 | + from shapix._memo import push_memo, pop_memo |
| 211 | + |
| 212 | + checker = _ArrayLikeChecker( |
| 213 | + FLOAT32, (NamedDim("N"),), casting="same_kind", name="F32Like" |
| 214 | + ) |
| 215 | + memo = push_memo() |
| 216 | + memo.single["N"] = 5 |
| 217 | + |
| 218 | + # Should fail because N=5 but array has shape (10,) |
| 219 | + result = checker(np.ones(10, dtype=np.float32)) |
| 220 | + assert result is False |
| 221 | + assert memo.single["N"] == 5 # memo restored |
| 222 | + pop_memo() |
| 223 | + |
| 224 | + def test_arraylike_dtype_no_source_returns_false(self) -> None: |
| 225 | + """Object with dtype but no extractable string should fail.""" |
| 226 | + from shapix._array_types import _ArrayLikeChecker |
| 227 | + |
| 228 | + class WeirdDtype: |
| 229 | + pass |
| 230 | + |
| 231 | + class WeirdObj: |
| 232 | + dtype = WeirdDtype() |
| 233 | + shape = (3,) |
| 234 | + |
| 235 | + checker = _ArrayLikeChecker( |
| 236 | + FLOAT32, (NamedDim("X"),), casting="same_kind", name="F32Like" |
| 237 | + ) |
| 238 | + assert checker(WeirdObj()) is False |
| 239 | + |
| 240 | + def test_arraylike_can_cast_type_error(self) -> None: |
| 241 | + """np.can_cast TypeError should be caught gracefully.""" |
| 242 | + from shapix._array_types import _ArrayLikeChecker |
| 243 | + from shapix._dtypes import DtypeSpec |
| 244 | + |
| 245 | + # Create a spec with a bogus target dtype name |
| 246 | + bogus = DtypeSpec("Bogus", frozenset({"not_a_real_dtype"})) |
| 247 | + checker = _ArrayLikeChecker( |
| 248 | + bogus, (NamedDim("X"),), casting="same_kind", name="BogusLike" |
| 249 | + ) |
| 250 | + assert checker(np.ones(3, dtype=np.float32)) is False |
| 251 | + |
| 252 | + def test_arraylike_repr(self) -> None: |
| 253 | + from shapix._array_types import _ArrayLikeChecker |
| 254 | + |
| 255 | + checker = _ArrayLikeChecker( |
| 256 | + FLOAT32, (NamedDim("N"), NamedDim("C")), casting="same_kind", name="F32Like" |
| 257 | + ) |
| 258 | + assert repr(checker) == "F32Like[N, C]" |
| 259 | + |
| 260 | + def test_arraylike_factory_single_dim_subscript(self) -> None: |
| 261 | + """F32Like[N] (single dim, not tuple) should work.""" |
| 262 | + from shapix._array_types import make_array_like_type |
| 263 | + |
| 264 | + factory = make_array_like_type(FLOAT32, name="F32Like") |
| 265 | + hint = factory[N] |
| 266 | + assert hasattr(hint, "__metadata__") |
| 267 | + |
| 268 | + def test_arraylike_factory_repr(self) -> None: |
| 269 | + from shapix._array_types import make_array_like_type |
| 270 | + |
| 271 | + factory = make_array_like_type(FLOAT32, name="F32Like") |
| 272 | + assert repr(factory) == "F32Like" |
| 273 | + |
| 274 | + def test_arraylike_fail_obj_replay(self) -> None: |
| 275 | + """Second call with same failing obj should replay failure.""" |
| 276 | + from shapix._array_types import _ArrayLikeChecker |
| 277 | + |
| 278 | + checker = _ArrayLikeChecker(FLOAT32, (NamedDim("N"),), casting="no", name="F32Like") |
| 279 | + bad_obj = np.ones(3, dtype=np.int64) |
| 280 | + assert checker(bad_obj) is False # first call fails (casting='no'), sets _fail_obj |
| 281 | + assert checker(bad_obj) is False # replay |
| 282 | + |
| 283 | + def test_struct_checker_fail_obj_replay(self) -> None: |
| 284 | + """StructChecker should replay failure for same object.""" |
| 285 | + checker = _StructChecker(FLOAT32, (NamedDim("N"),)) |
| 286 | + bad = np.ones(3, dtype=np.int64) |
| 287 | + assert checker(bad) is False # dtype mismatch, sets _fail_obj |
| 288 | + assert checker(bad) is False # replay |
| 289 | + |
| 290 | + |
| 291 | +class TestClawWrapper: |
| 292 | + def test_shapix_this_package_delegates_to_beartype(self, monkeypatch) -> None: |
| 293 | + captured: dict[str, object] = {} |
| 294 | + |
| 295 | + def _fake_beartype_this_package(*, conf: object) -> None: |
| 296 | + captured["conf"] = conf |
| 297 | + |
| 298 | + import shapix.claw as claw_mod |
| 299 | + |
| 300 | + monkeypatch.setattr(claw_mod, "_beartype_this_package", _fake_beartype_this_package) |
| 301 | + conf = BeartypeConf() |
| 302 | + shapix_this_package(conf=conf) |
| 303 | + |
| 304 | + assert captured["conf"] is conf |
0 commit comments