@@ -130,13 +130,18 @@ def non_materializable4(x: Array) -> Array:
130130    return  non_materializable (x )
131131
132132
133+ def  non_materializable5 (x : Array ) ->  Array :
134+     return  non_materializable (x )
135+ 
136+ 
133137lazy_xp_function (good_lazy )
134138# Works on JAX and Dask 
135139lazy_xp_function (non_materializable2 , jax_jit = False , allow_dask_compute = 2 )
140+ lazy_xp_function (non_materializable3 , jax_jit = False , allow_dask_compute = True )
136141# Works on JAX, but not Dask 
137- lazy_xp_function (non_materializable3 , jax_jit = False , allow_dask_compute = 1 )
142+ lazy_xp_function (non_materializable4 , jax_jit = False , allow_dask_compute = 1 )
138143# Works neither on Dask nor JAX 
139- lazy_xp_function (non_materializable4 )
144+ lazy_xp_function (non_materializable5 )
140145
141146
142147def  test_lazy_xp_function (xp : ModuleType ):
@@ -147,29 +152,30 @@ def test_lazy_xp_function(xp: ModuleType):
147152    xp_assert_equal (non_materializable (x ), xp .asarray ([1.0 , 2.0 ]))
148153    # Wrapping explicitly disabled 
149154    xp_assert_equal (non_materializable2 (x ), xp .asarray ([1.0 , 2.0 ]))
155+     xp_assert_equal (non_materializable3 (x ), xp .asarray ([1.0 , 2.0 ]))
150156
151157    if  is_jax_namespace (xp ):
152-         xp_assert_equal (non_materializable3 (x ), xp .asarray ([1.0 , 2.0 ]))
158+         xp_assert_equal (non_materializable4 (x ), xp .asarray ([1.0 , 2.0 ]))
153159        with  pytest .raises (
154160            TypeError , match = "Attempted boolean conversion of traced array" 
155161        ):
156-             _  =  non_materializable4 (x )  # Wrapped 
162+             _  =  non_materializable5 (x )  # Wrapped 
157163
158164    elif  is_dask_namespace (xp ):
159165        with  pytest .raises (
160166            AssertionError ,
161167            match = r"dask\.compute.* 2 times, but only up to 1 calls are allowed" ,
162168        ):
163-             _  =  non_materializable3 (x )
169+             _  =  non_materializable4 (x )
164170        with  pytest .raises (
165171            AssertionError ,
166172            match = r"dask\.compute.* 1 times, but no calls are allowed" ,
167173        ):
168-             _  =  non_materializable4 (x )
174+             _  =  non_materializable5 (x )
169175
170176    else :
171-         xp_assert_equal (non_materializable3 (x ), xp .asarray ([1.0 , 2.0 ]))
172177        xp_assert_equal (non_materializable4 (x ), xp .asarray ([1.0 , 2.0 ]))
178+         xp_assert_equal (non_materializable5 (x ), xp .asarray ([1.0 , 2.0 ]))
173179
174180
175181def  static_params (x : Array , n : int , flag : bool  =  False ) ->  Array :
0 commit comments