Skip to content

Commit dd361d0

Browse files
committed
test kwargs with jit
1 parent b99ca91 commit dd361d0

File tree

1 file changed

+15
-9
lines changed

1 file changed

+15
-9
lines changed

tests/test_lazy.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,11 @@
1616
]
1717

1818

19-
@pytest.mark.parametrize("as_numpy", [False, pytest.param(True, marks=skip_as_numpy)])
20-
def test_lazy_apply_kwargs(xp: ModuleType, library: Backend, as_numpy: bool) -> None:
21-
expect = np.ndarray if as_numpy or library is Backend.DASK else type(xp.asarray(0))
19+
class NT(NamedTuple):
20+
a: Array
2221

23-
class NT(NamedTuple):
24-
a: Array
2522

23+
def check_lazy_apply_kwargs(x: Array, expect: type, as_numpy: bool) -> Array:
2624
def f(
2725
x: Array,
2826
z: dict[str, list[Array] | tuple[Array, ...] | NT],
@@ -36,10 +34,9 @@ def f(
3634
assert isinstance(z["baz"][0], expect)
3735
assert msg == "Hello World"
3836
assert msgs[0] == "Hello World"
39-
return x
37+
return x + 1
4038

41-
x = xp.asarray(0)
42-
y = lazy_apply( # pyright: ignore[reportCallIssue]
39+
return lazy_apply( # pyright: ignore[reportCallIssue]
4340
f,
4441
x,
4542
z={"foo": NT(x), "bar": [x], "baz": (x,)},
@@ -49,7 +46,16 @@ def f(
4946
dtype=x.dtype,
5047
as_numpy=as_numpy,
5148
)
52-
xp_assert_equal(x, y)
49+
50+
lazy_xp_function(check_lazy_apply_kwargs, static_argnames=("expect", "as_numpy"))
51+
52+
53+
@pytest.mark.parametrize("as_numpy", [False, pytest.param(True, marks=skip_as_numpy)])
54+
def test_lazy_apply_kwargs(xp: ModuleType, library: Backend, as_numpy: bool) -> None:
55+
expect = np.ndarray if as_numpy or library is Backend.DASK else type(xp.asarray(0))
56+
x = xp.asarray(0)
57+
actual = check_lazy_apply_kwargs(x, expect, as_numpy)
58+
xp_assert_equal(actual, x + 1)
5359

5460

5561
class CustomError(Exception):

0 commit comments

Comments
 (0)