11import dis
22import inspect
3+ import weakref
34
45import innerscope
56from dask import distributed
@@ -104,6 +105,8 @@ def __init__(self, *names, data=None):
104105 # If we do this, we shouldn't keep data around.
105106 self ._is_singleton = data is None
106107 self ._frame = None
108+ # Used to cancel work
109+ self ._client_to_futures = weakref .WeakKeyDictionary ()
107110 # For now, save the following to help debug
108111 self ._where = None
109112 self ._magic_func = None
@@ -177,8 +180,14 @@ def __enter__(self):
177180 return self .data
178181
179182 def __exit__ (self , exc_type , exc_value , exc_traceback ):
183+ self ._where = None
180184 try :
181185 return self ._exit (exc_type , exc_value , exc_traceback )
186+ except KeyboardInterrupt :
187+ # Cancel all pending tasks
188+ if self ._where == "remotely" :
189+ self .cancel ()
190+ raise
182191 finally :
183192 self ._frame = None
184193 self ._lines = None
@@ -225,6 +234,11 @@ def _exit(self, exc_type, exc_value, exc_traceback):
225234 if self ._where == "remotely" :
226235 if client is None :
227236 client = distributed .client ._get_global_client ()
237+ if client not in self ._client_to_futures :
238+ weak_futures = weakref .WeakSet ()
239+ self ._client_to_futures [client ] = weak_futures
240+ else :
241+ weak_futures = self ._client_to_futures [client ]
228242 to_scatter = self .data .keys () & self ._magic_func ._scoped .outer_scope .keys ()
229243 if to_scatter :
230244 # Scatter value in `self.data` that we need in this calculation.
@@ -244,30 +258,37 @@ def _exit(self, exc_type, exc_value, exc_traceback):
244258 del self ._magic_func ._scoped .outer_scope [key ]
245259 # Scatter magic_func to avoid "Large object" UserWarning
246260 magic_func = client .scatter (self ._magic_func )
261+ weak_futures .add (magic_func )
247262 remote_dict = client .submit (
248263 run_afar , magic_func , names , futures , pure = False , ** submit_kwargs
249264 )
250- del magic_func # Let go ASAP
265+ weak_futures .add (remote_dict )
266+ magic_func .release () # Let go ASAP
251267 if display_expr :
252268 repr_val = client .submit (
253269 reprs .repr_afar ,
254270 client .submit (get_afar , remote_dict , "_afar_return_value_" ),
255271 self ._magic_func ._repr_methods ,
256272 )
273+ weak_futures .add (repr_val )
257274 if self ._gather_data :
258275 futures_to_name = {
259276 client .submit (get_afar , remote_dict , name , ** submit_kwargs ): name
260277 for name in names
261278 }
262- del remote_dict # Let go ASAP
279+ weak_futures .update (futures_to_name )
280+ remote_dict .release () # Let go ASAP
263281 for future , result in distributed .as_completed (futures_to_name , with_results = True ):
264282 self .data [futures_to_name [future ]] = result
265283 else :
266284 for name in names :
267- self .data [name ] = client .submit (get_afar , remote_dict , name , ** submit_kwargs )
268- del remote_dict # Let go ASAP
285+ future = client .submit (get_afar , remote_dict , name , ** submit_kwargs )
286+ weak_futures .add (future )
287+ self .data [name ] = future
288+ remote_dict .release () # Let go ASAP
269289 if display_expr :
270290 reprs .display_repr (repr_val .result ()) # This blocks!
291+ repr_val .release ()
271292 elif self ._where == "locally" :
272293 # Run locally. This is handy for testing and debugging.
273294 results = self ._magic_func ()
@@ -285,6 +306,18 @@ def _exit(self, exc_type, exc_value, exc_traceback):
285306 frame .f_locals .update ((name , self .data [name ]) for name in names )
286307 return True
287308
309+ def cancel (self , * , client = None , force = False ):
310+ """Cancel pending tasks"""
311+ if client is not None :
312+ items = [(client , self ._client_to_futures [client ])]
313+ else :
314+ items = self ._client_to_futures .items ()
315+ for client , weak_futures in items :
316+ client .cancel (
317+ [future for future in weak_futures if future .status == "pending" ], force = force
318+ )
319+ weak_futures .clear ()
320+
288321
289322class Get (Run ):
290323 """Unlike ``run``, ``get`` automatically gathers the data locally"""
@@ -319,6 +352,7 @@ def abracadabra(runner):
319352 key : val
320353 for key , val in scoped .outer_scope .items ()
321354 if isinstance (val , distributed .Future )
355+ # TODO: what can/should we do if the future is in a bad state?
322356 }
323357 for key in futures :
324358 del scoped .outer_scope [key ]
0 commit comments