Skip to content

Commit 9bd8126

Browse files
authored
Merge pull request #17 from eriknw/cancel
Add `.cancel` method to cancel pending tasks.
2 parents 4f30f9b + 9b02fb9 commit 9bd8126

File tree

1 file changed

+38
-4
lines changed

1 file changed

+38
-4
lines changed

afar/core.py

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import dis
22
import inspect
3+
import weakref
34

45
import innerscope
56
from dask import distributed
@@ -104,6 +105,8 @@ def __init__(self, *names, data=None):
104105
# If we do this, we shouldn't keep data around.
105106
self._is_singleton = data is None
106107
self._frame = None
108+
# Used to cancel work
109+
self._client_to_futures = weakref.WeakKeyDictionary()
107110
# For now, save the following to help debug
108111
self._where = None
109112
self._magic_func = None
@@ -177,8 +180,14 @@ def __enter__(self):
177180
return self.data
178181

179182
def __exit__(self, exc_type, exc_value, exc_traceback):
183+
self._where = None
180184
try:
181185
return self._exit(exc_type, exc_value, exc_traceback)
186+
except KeyboardInterrupt:
187+
# Cancel all pending tasks
188+
if self._where == "remotely":
189+
self.cancel()
190+
raise
182191
finally:
183192
self._frame = None
184193
self._lines = None
@@ -225,6 +234,11 @@ def _exit(self, exc_type, exc_value, exc_traceback):
225234
if self._where == "remotely":
226235
if client is None:
227236
client = distributed.client._get_global_client()
237+
if client not in self._client_to_futures:
238+
weak_futures = weakref.WeakSet()
239+
self._client_to_futures[client] = weak_futures
240+
else:
241+
weak_futures = self._client_to_futures[client]
228242
to_scatter = self.data.keys() & self._magic_func._scoped.outer_scope.keys()
229243
if to_scatter:
230244
# Scatter value in `self.data` that we need in this calculation.
@@ -244,30 +258,37 @@ def _exit(self, exc_type, exc_value, exc_traceback):
244258
del self._magic_func._scoped.outer_scope[key]
245259
# Scatter magic_func to avoid "Large object" UserWarning
246260
magic_func = client.scatter(self._magic_func)
261+
weak_futures.add(magic_func)
247262
remote_dict = client.submit(
248263
run_afar, magic_func, names, futures, pure=False, **submit_kwargs
249264
)
250-
del magic_func # Let go ASAP
265+
weak_futures.add(remote_dict)
266+
magic_func.release() # Let go ASAP
251267
if display_expr:
252268
repr_val = client.submit(
253269
reprs.repr_afar,
254270
client.submit(get_afar, remote_dict, "_afar_return_value_"),
255271
self._magic_func._repr_methods,
256272
)
273+
weak_futures.add(repr_val)
257274
if self._gather_data:
258275
futures_to_name = {
259276
client.submit(get_afar, remote_dict, name, **submit_kwargs): name
260277
for name in names
261278
}
262-
del remote_dict # Let go ASAP
279+
weak_futures.update(futures_to_name)
280+
remote_dict.release() # Let go ASAP
263281
for future, result in distributed.as_completed(futures_to_name, with_results=True):
264282
self.data[futures_to_name[future]] = result
265283
else:
266284
for name in names:
267-
self.data[name] = client.submit(get_afar, remote_dict, name, **submit_kwargs)
268-
del remote_dict # Let go ASAP
285+
future = client.submit(get_afar, remote_dict, name, **submit_kwargs)
286+
weak_futures.add(future)
287+
self.data[name] = future
288+
remote_dict.release() # Let go ASAP
269289
if display_expr:
270290
reprs.display_repr(repr_val.result()) # This blocks!
291+
repr_val.release()
271292
elif self._where == "locally":
272293
# Run locally. This is handy for testing and debugging.
273294
results = self._magic_func()
@@ -285,6 +306,18 @@ def _exit(self, exc_type, exc_value, exc_traceback):
285306
frame.f_locals.update((name, self.data[name]) for name in names)
286307
return True
287308

309+
def cancel(self, *, client=None, force=False):
310+
"""Cancel pending tasks"""
311+
if client is not None:
312+
items = [(client, self._client_to_futures[client])]
313+
else:
314+
items = self._client_to_futures.items()
315+
for client, weak_futures in items:
316+
client.cancel(
317+
[future for future in weak_futures if future.status == "pending"], force=force
318+
)
319+
weak_futures.clear()
320+
288321

289322
class Get(Run):
290323
"""Unlike ``run``, ``get`` automatically gathers the data locally"""
@@ -319,6 +352,7 @@ def abracadabra(runner):
319352
key: val
320353
for key, val in scoped.outer_scope.items()
321354
if isinstance(val, distributed.Future)
355+
# TODO: what can/should we do if the future is in a bad state?
322356
}
323357
for key in futures:
324358
del scoped.outer_scope[key]

0 commit comments

Comments
 (0)