@@ -191,26 +191,40 @@ def f(x: Array) -> Array:
191191 xp_assert_equal (y .compute (), x_cp + 1 ) # type: ignore[attr-defined] # pyright: ignore[reportUnknownArgumentType,reportAttributeAccessIssue]
192192
193193
194- @pytest .mark .xfail_xp_backend (Backend .JAX , reason = "unknown shape" )
195194def test_lazy_apply_none_shape_in_args (xp : ModuleType , library : Backend ):
196195 x = xp .asarray ([1 , 1 , 2 , 2 , 2 ])
197196
198- xp2 = np if library is Backend .DASK else xp
199-
200- # Single output
201- values = lazy_apply (xp2 .unique_values , x , shape = (None ,))
202- xp_assert_equal (values , xp .asarray ([1 , 2 ]))
203-
204- # Multi output
197+ # TODO mxp = meta_namespace(x, xp=xp)
198+ mxp = np if library is Backend .DASK else xp
205199 int_type = xp .asarray (0 ).dtype
206- values , counts = lazy_apply (
207- xp2 .unique_counts ,
208- x ,
209- shape = ((None ,), (None ,)),
210- dtype = (x .dtype , int_type ),
211- )
212- xp_assert_equal (values , xp .asarray ([1 , 2 ]))
213- xp_assert_equal (counts , xp .asarray ([2 , 3 ]))
200+
201+ if library is Backend .JAX :
202+ # Single output
203+ with pytest .raises (ValueError , match = "Output shape must be fully known" ):
204+ _ = lazy_apply (mxp .unique_values , x , shape = (None ,))
205+
206+ # Multi output
207+ with pytest .raises (ValueError , match = "Output shape must be fully known" ):
208+ _ = lazy_apply (
209+ mxp .unique_counts ,
210+ x ,
211+ shape = ((None ,), (None ,)),
212+ dtype = (x .dtype , int_type ),
213+ )
214+ else :
215+ # Single output
216+ values = lazy_apply (mxp .unique_values , x , shape = (None ,))
217+ xp_assert_equal (values , xp .asarray ([1 , 2 ]))
218+
219+ # Multi output
220+ values , counts = lazy_apply (
221+ mxp .unique_counts ,
222+ x ,
223+ shape = ((None ,), (None ,)),
224+ dtype = (x .dtype , int_type ),
225+ )
226+ xp_assert_equal (values , xp .asarray ([1 , 2 ]))
227+ xp_assert_equal (counts , xp .asarray ([2 , 3 ]))
214228
215229
216230def check_lazy_apply_none_shape_broadcast (x : Array ) -> Array :
@@ -349,10 +363,8 @@ def eager(
349363def test_lazy_apply_kwargs (xp : ModuleType , library : Backend , as_numpy : bool ):
350364 """When as_numpy=True, search and replace arrays in the (nested) keywords arguments
351365 with numpy arrays, and leave the rest untouched."""
352- expect_cls = (
353- np .ndarray if as_numpy or library is Backend .DASK else type (xp .asarray (0 ))
354- )
355366 x = xp .asarray (0 )
367+ expect_cls = np .ndarray if as_numpy or library is Backend .DASK else type (x )
356368 actual = check_lazy_apply_kwargs (x , expect_cls , as_numpy ) # pyright: ignore[reportUnknownArgumentType]
357369 xp_assert_equal (actual , x + 1 )
358370
0 commit comments