@@ -200,7 +200,7 @@ def iter_tagged() -> Iterator[tuple[str, Callable[..., Any], dict[str, Any]]]:
200200 if is_dask_namespace (xp ):
201201 for name , func , tags in iter_tagged ():
202202 n = tags ["allow_dask_compute" ]
203- wrapped = _allow_dask_compute (func , n )
203+ wrapped = _dask_wrap (func , n )
204204 monkeypatch .setitem (globals_ , name , wrapped )
205205
206206 elif is_jax_namespace (xp ):
@@ -256,13 +256,15 @@ def __call__(self, dsk: Graph, keys: Sequence[Key] | Key, **kwargs: Any) -> Any:
256256 return dask .get (dsk , keys , ** kwargs ) # type: ignore[attr-defined,no-untyped-call] # pyright: ignore[reportPrivateImportUsage]
257257
258258
259- def _allow_dask_compute (
259+ def _dask_wrap (
260260 func : Callable [P , T ], n : int
261261) -> Callable [P , T ]: # numpydoc ignore=PR01,RT01
262262 """
263263 Wrap `func` to raise if it attempts to call `dask.compute` more than `n` times.
264+
265+ After the function returns, materialize the graph in order to re-raise exceptions.
264266 """
265- import dask . config
267+ import dask
266268
267269 func_name = getattr (func , "__name__" , str (func ))
268270 n_str = f"only up to { n } " if n else "no"
@@ -276,7 +278,12 @@ def _allow_dask_compute(
276278 @wraps (func )
277279 def wrapper (* args : P .args , ** kwargs : P .kwargs ) -> T : # numpydoc ignore=GL08
278280 scheduler = CountingDaskScheduler (n , msg )
279- with dask .config .set ({"scheduler" : scheduler }):
280- return func (* args , ** kwargs )
281+ with dask .config .set ({"scheduler" : scheduler }): # pyright: ignore[reportPrivateImportUsage]
282+ out = func (* args , ** kwargs )
283+
284+ # Block until the graph materializes and reraise exceptions. This allows
285+ # `pytest.raises` and `pytest.warns` to work as expected. Note that this would
286+ # not work on scheduler='distributed', as it would not block.
287+ return dask .persist (out , scheduler = "threads" )[0 ] # type: ignore[no-any-return,attr-defined,no-untyped-call,func-returns-value,index] # pyright: ignore[reportPrivateImportUsage]
281288
282289 return wrapper
0 commit comments