33import inspect
44import io
55import sys
6+ import weakref
67
78import innerscope
89from dask import distributed
@@ -107,6 +108,8 @@ def __init__(self, *names, data=None):
107108 # If we do this, we shouldn't keep data around.
108109 self ._is_singleton = data is None
109110 self ._frame = None
111+ # Used to cancel work
112+ self ._client_to_futures = weakref .WeakKeyDictionary ()
110113 # For now, save the following to help debug
111114 self ._where = None
112115 self ._magic_func = None
@@ -180,8 +183,14 @@ def __enter__(self):
180183 return self .data
181184
182185 def __exit__ (self , exc_type , exc_value , exc_traceback ):
186+ self ._where = None
183187 try :
184188 return self ._exit (exc_type , exc_value , exc_traceback )
189+ except KeyboardInterrupt :
190+ # Cancel all pending tasks
191+ if self ._where == "remotely" :
192+ self .cancel ()
193+ raise
185194 finally :
186195 self ._frame = None
187196 self ._lines = None
@@ -228,6 +237,11 @@ def _exit(self, exc_type, exc_value, exc_traceback):
228237 if self ._where == "remotely" :
229238 if client is None :
230239 client = distributed .client ._get_global_client ()
240+ if client not in self ._client_to_futures :
241+ weak_futures = weakref .WeakSet ()
242+ self ._client_to_futures [client ] = weak_futures
243+ else :
244+ weak_futures = self ._client_to_futures [client ]
231245 to_scatter = self .data .keys () & self ._magic_func ._scoped .outer_scope .keys ()
232246 if to_scatter :
233247 # Scatter value in `self.data` that we need in this calculation.
@@ -247,39 +261,51 @@ def _exit(self, exc_type, exc_value, exc_traceback):
247261 del self ._magic_func ._scoped .outer_scope [key ]
248262 # Scatter magic_func to avoid "Large object" UserWarning
249263 magic_func = client .scatter (self ._magic_func )
264+ weak_futures .add (magic_func )
250265 remote_dict = client .submit (
251266 run_afar , magic_func , names , futures , pure = False , ** submit_kwargs
252267 )
253- del magic_func # Let go ASAP
268+ weak_futures .add (remote_dict )
269+ magic_func .release () # Let go ASAP
254270 if display_expr :
255271 repr_val = client .submit (
256272 reprs .repr_afar ,
257273 client .submit (get_afar , remote_dict , "_afar_return_value_" ),
258274 self ._magic_func ._repr_methods ,
259275 )
260- stdout_val = client .submit (get_afar , remote_dict , "_afar_stdout_" )
261- stderr_val = client .submit (get_afar , remote_dict , "_afar_stderr_" )
276+ weak_futures .add (repr_val )
277+ stdout_future = client .submit (get_afar , remote_dict , "_afar_stdout_" )
278+ weak_futures .add (stdout_future )
279+ stderr_future = client .submit (get_afar , remote_dict , "_afar_stderr_" )
280+ weak_futures .add (stderr_future )
262281 if self ._gather_data :
263282 futures_to_name = {
264283 client .submit (get_afar , remote_dict , name , ** submit_kwargs ): name
265284 for name in names
266285 }
267- del remote_dict # Let go ASAP
286+ weak_futures .update (futures_to_name )
287+ remote_dict .release () # Let go ASAP
268288 for future , result in distributed .as_completed (futures_to_name , with_results = True ):
269289 self .data [futures_to_name [future ]] = result
270290 else :
271291 for name in names :
272- self .data [name ] = client .submit (get_afar , remote_dict , name , ** submit_kwargs )
273- del remote_dict # Let go ASAP
292+ future = client .submit (get_afar , remote_dict , name , ** submit_kwargs )
293+ weak_futures .add (future )
294+ self .data [name ] = future
295+ remote_dict .release () # Let go ASAP
296+
274297 # blocks!
275- stdout_val = stdout_val .result ()
298+ stdout_val = stdout_future .result ()
276299 if stdout_val :
277300 print (stdout_val , end = "" )
278- stderr_val = stderr_val .result ()
301+ stdout_future .release ()
302+ stderr_val = stderr_future .result ()
279303 if stderr_val :
280304 print (stderr_val , end = "" , file = sys .stderr )
305+ stderr_future .release ()
281306 if display_expr :
282307 reprs .display_repr (repr_val .result ()) # This blocks!
308+ repr_val .release ()
283309 elif self ._where == "locally" :
284310 # Run locally. This is handy for testing and debugging.
285311 results = self ._magic_func ()
@@ -297,6 +323,18 @@ def _exit(self, exc_type, exc_value, exc_traceback):
297323 frame .f_locals .update ((name , self .data [name ]) for name in names )
298324 return True
299325
326+ def cancel (self , * , client = None , force = False ):
327+ """Cancel pending tasks"""
328+ if client is not None :
329+ items = [(client , self ._client_to_futures [client ])]
330+ else :
331+ items = self ._client_to_futures .items ()
332+ for client , weak_futures in items :
333+ client .cancel (
334+ [future for future in weak_futures if future .status == "pending" ], force = force
335+ )
336+ weak_futures .clear ()
337+
300338
301339class Get (Run ):
302340 """Unlike ``run``, ``get`` automatically gathers the data locally"""
@@ -331,6 +369,7 @@ def abracadabra(runner):
331369 key : val
332370 for key , val in scoped .outer_scope .items ()
333371 if isinstance (val , distributed .Future )
372+ # TODO: what can/should we do if the future is in a bad state?
334373 }
335374 for key in futures :
336375 del scoped .outer_scope [key ]
0 commit comments