@@ -203,19 +203,35 @@ class CountingDaskScheduler(SchedulerGetCallable):
203203 """
204204 Dask scheduler that counts how many times `dask.compute` is called.
205205
206+ If the number of times exceeds 'max_count', it raises an error.
206207 This is a wrapper around Dask's own 'synchronous' scheduler.
208+
209+ Parameters
210+ ----------
211+ max_count : int
212+ Maximum number of allowed calls to `dask.compute`.
213+ msg : str
214+ Assertion to raise when the count exceeds `max_count`.
207215 """
208216
209217 count : int
218+ max_count : int
219+ msg : str
210220
211- def __init__ (self ): # numpydoc ignore=GL08
221+ def __init__ (self , max_count : int , msg : str ): # numpydoc ignore=GL08
212222 self .count = 0
223+ self .max_count = max_count
224+ self .msg = msg
213225
214226 @override
215227 def __call__ (self , dsk : Graph , keys : Sequence [Key ] | Key , ** kwargs : Any ) -> Any : # type: ignore[no-any-decorated,no-any-explicit] # numpydoc ignore=GL08
216228 import dask
217229
218230 self .count += 1
231+ # This should yield a nice traceback to the
232+ # offending line in the user's code
233+ assert self .count <= self .max_count , self .msg
234+
219235 return dask .get (dsk , keys , ** kwargs ) # type: ignore[attr-defined,no-untyped-call] # pyright: ignore[reportPrivateImportUsage]
220236
221237
@@ -228,21 +244,18 @@ def _allow_dask_compute(
228244 import dask .config
229245
230246 func_name = getattr (func , "__name__" , str (func ))
247+ n_str = f"only up to { n } " if n else "no"
248+ msg = (
249+ f"Called `dask.compute()` or `dask.persist()` { n + 1 } times, "
250+ f"but { n_str } calls are allowed. Set "
251+ f"`lazy_xp_function({ func_name } , allow_dask_compute={ n + 1 } ) "
252+ "to allow for more (but note that this will harm performance). "
253+ )
231254
232255 @wraps (func )
233256 def wrapper (* args : P .args , ** kwargs : P .kwargs ) -> T : # numpydoc ignore=GL08
234- scheduler = CountingDaskScheduler ()
257+ scheduler = CountingDaskScheduler (n , msg )
235258 with dask .config .set ({"scheduler" : scheduler }):
236- out = func (* args , ** kwargs )
237- if scheduler .count > n :
238- n_str = f"only up to { n } " if n else "no"
239- msg = (
240- f"Called `dask.compute()` or `dask.persist()` { scheduler .count } times, "
241- f"but { n_str } calls are allowed. Set "
242- f"`lazy_xp_function({ func_name } , allow_dask_compute={ scheduler .count } ) "
243- "to allow for more (but note that this will harm performance). "
244- )
245- raise AssertionError (msg )
246- return out
259+ return func (* args , ** kwargs )
247260
248261 return wrapper
0 commit comments