Skip to content

Commit f11ad7a

Browse files
committed
Merge branch 'main' into simple_print
2 parents b306ef1 + 9bd8126 commit f11ad7a

File tree

1 file changed

+47
-8
lines changed

1 file changed

+47
-8
lines changed

afar/core.py

Lines changed: 47 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import inspect
44
import io
55
import sys
6+
import weakref
67

78
import innerscope
89
from dask import distributed
@@ -107,6 +108,8 @@ def __init__(self, *names, data=None):
107108
# If we do this, we shouldn't keep data around.
108109
self._is_singleton = data is None
109110
self._frame = None
111+
# Used to cancel work
112+
self._client_to_futures = weakref.WeakKeyDictionary()
110113
# For now, save the following to help debug
111114
self._where = None
112115
self._magic_func = None
@@ -180,8 +183,14 @@ def __enter__(self):
180183
return self.data
181184

182185
def __exit__(self, exc_type, exc_value, exc_traceback):
186+
self._where = None
183187
try:
184188
return self._exit(exc_type, exc_value, exc_traceback)
189+
except KeyboardInterrupt:
190+
# Cancel all pending tasks
191+
if self._where == "remotely":
192+
self.cancel()
193+
raise
185194
finally:
186195
self._frame = None
187196
self._lines = None
@@ -228,6 +237,11 @@ def _exit(self, exc_type, exc_value, exc_traceback):
228237
if self._where == "remotely":
229238
if client is None:
230239
client = distributed.client._get_global_client()
240+
if client not in self._client_to_futures:
241+
weak_futures = weakref.WeakSet()
242+
self._client_to_futures[client] = weak_futures
243+
else:
244+
weak_futures = self._client_to_futures[client]
231245
to_scatter = self.data.keys() & self._magic_func._scoped.outer_scope.keys()
232246
if to_scatter:
233247
# Scatter value in `self.data` that we need in this calculation.
@@ -247,39 +261,51 @@ def _exit(self, exc_type, exc_value, exc_traceback):
247261
del self._magic_func._scoped.outer_scope[key]
248262
# Scatter magic_func to avoid "Large object" UserWarning
249263
magic_func = client.scatter(self._magic_func)
264+
weak_futures.add(magic_func)
250265
remote_dict = client.submit(
251266
run_afar, magic_func, names, futures, pure=False, **submit_kwargs
252267
)
253-
del magic_func # Let go ASAP
268+
weak_futures.add(remote_dict)
269+
magic_func.release() # Let go ASAP
254270
if display_expr:
255271
repr_val = client.submit(
256272
reprs.repr_afar,
257273
client.submit(get_afar, remote_dict, "_afar_return_value_"),
258274
self._magic_func._repr_methods,
259275
)
260-
stdout_val = client.submit(get_afar, remote_dict, "_afar_stdout_")
261-
stderr_val = client.submit(get_afar, remote_dict, "_afar_stderr_")
276+
weak_futures.add(repr_val)
277+
stdout_future = client.submit(get_afar, remote_dict, "_afar_stdout_")
278+
weak_futures.add(stdout_future)
279+
stderr_future = client.submit(get_afar, remote_dict, "_afar_stderr_")
280+
weak_futures.add(stderr_future)
262281
if self._gather_data:
263282
futures_to_name = {
264283
client.submit(get_afar, remote_dict, name, **submit_kwargs): name
265284
for name in names
266285
}
267-
del remote_dict # Let go ASAP
286+
weak_futures.update(futures_to_name)
287+
remote_dict.release() # Let go ASAP
268288
for future, result in distributed.as_completed(futures_to_name, with_results=True):
269289
self.data[futures_to_name[future]] = result
270290
else:
271291
for name in names:
272-
self.data[name] = client.submit(get_afar, remote_dict, name, **submit_kwargs)
273-
del remote_dict # Let go ASAP
292+
future = client.submit(get_afar, remote_dict, name, **submit_kwargs)
293+
weak_futures.add(future)
294+
self.data[name] = future
295+
remote_dict.release() # Let go ASAP
296+
274297
# blocks!
275-
stdout_val = stdout_val.result()
298+
stdout_val = stdout_future.result()
276299
if stdout_val:
277300
print(stdout_val, end="")
278-
stderr_val = stderr_val.result()
301+
stdout_future.release()
302+
stderr_val = stderr_future.result()
279303
if stderr_val:
280304
print(stderr_val, end="", file=sys.stderr)
305+
stderr_future.release()
281306
if display_expr:
282307
reprs.display_repr(repr_val.result()) # This blocks!
308+
repr_val.release()
283309
elif self._where == "locally":
284310
# Run locally. This is handy for testing and debugging.
285311
results = self._magic_func()
@@ -297,6 +323,18 @@ def _exit(self, exc_type, exc_value, exc_traceback):
297323
frame.f_locals.update((name, self.data[name]) for name in names)
298324
return True
299325

326+
def cancel(self, *, client=None, force=False):
327+
"""Cancel pending tasks"""
328+
if client is not None:
329+
items = [(client, self._client_to_futures[client])]
330+
else:
331+
items = self._client_to_futures.items()
332+
for client, weak_futures in items:
333+
client.cancel(
334+
[future for future in weak_futures if future.status == "pending"], force=force
335+
)
336+
weak_futures.clear()
337+
300338

301339
class Get(Run):
302340
"""Unlike ``run``, ``get`` automatically gathers the data locally"""
@@ -331,6 +369,7 @@ def abracadabra(runner):
331369
key: val
332370
for key, val in scoped.outer_scope.items()
333371
if isinstance(val, distributed.Future)
372+
# TODO: what can/should we do if the future is in a bad state?
334373
}
335374
for key in futures:
336375
del scoped.outer_scope[key]

0 commit comments

Comments
 (0)