Skip to content

Commit a2ca5ad

Browse files
authored
HTEX handles pre-serialized submissions (#3983)
# Description

Allow an interested party to utilize a custom task executor with HTEX, rather than the one provided by Parsl. For example, if a task is pre-serialized in a custom manner, the default `execute_task()` function will not know how to deserialize and run it.

As of this PR, the HTEX implements the context-variable keys `resource_spec` and `task_executor`. The `resource_spec` is nominally well-understood at this point. The `task_executor` is documented in the `submit_payload` docstring. An example, borrowed from that documentation:

```python
>>> htex: HighThroughputExecutor  # setup prior to this example
>>> ctxt = {
...     "task_executor": {
...         "f": "full.import.path.of.custom_execute_task",
...         "a": ("additional", "arguments"),
...         "k": {"some": "keyword", "args": "here"}
...     }
... }
>>> fn_buf = custom_serialize(task_func, *task_args, **task_kwargs)
>>> fut = htex.submit_payload(ctxt, fn_buf)
```

The custom `custom_execute_task` would be dynamically imported, and invoked within the `process_worker_pool.py` worker as if:

```python
from full.import.path.of import custom_execute_task

args = ("additional", "arguments")
kwargs = {"some": "keyword", "args": "here"}
result = custom_execute_task(fn_buf, *args, **kwargs)
```

# Changed Behaviour

There should be no change to existing workflows, but new workflows may be able to use `submit_payload` to fine-tune how tasks are executed within the worker.

## Type of change

- New feature
1 parent 6e427e7 commit a2ca5ad

File tree

5 files changed

+253
-15
lines changed

5 files changed

+253
-15
lines changed

parsl/executors/high_throughput/executor.py

Lines changed: 60 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -697,20 +697,11 @@ def submit(self, func: Callable, resource_specification: dict, *args, **kwargs)
697697

698698
self.validate_resource_spec(resource_specification)
699699

700-
if self.bad_state_is_set:
701-
raise self.executor_exception
702-
703-
self._task_counter += 1
704-
task_id = self._task_counter
705-
706700
# handle people sending blobs gracefully
707701
if logger.getEffectiveLevel() <= logging.DEBUG:
708702
args_to_print = tuple([ar if len(ar := repr(arg)) < 100 else (ar[:100] + '...') for arg in args])
709703
logger.debug("Pushing function {} to queue with args {}".format(func, args_to_print))
710704

711-
fut = HTEXFuture(task_id)
712-
self.tasks[task_id] = fut
713-
714705
try:
715706
fn_buf = pack_apply_message(func, args, kwargs, buffer_threshold=1 << 20)
716707
except TypeError:
@@ -720,12 +711,69 @@ def submit(self, func: Callable, resource_specification: dict, *args, **kwargs)
720711
if resource_specification:
721712
context["resource_spec"] = resource_specification
722713

723-
msg = {"task_id": task_id, "context": context, "buffer": fn_buf}
714+
return self.submit_payload(context, fn_buf)
715+
716+
def submit_payload(self, context: dict, buffer: bytes) -> HTEXFuture:
717+
"""
718+
Submit specially crafted payloads.
719+
720+
For use-cases where the ``HighThroughputExecutor`` consumer needs the payload
721+
handled by the worker in a special way. For example, if the function is
722+
serialized differently than Parsl's default approach, or if the task must
723+
be set up more precisely than Parsl's default ``execute_task`` allows.
724+
725+
An example interaction:
726+
727+
.. code-block:: python
728+
729+
>>> htex: HighThroughputExecutor # setup prior to this example
730+
>>> ctxt = {
731+
... "task_executor": {
732+
... "f": "full.import.path.of.custom_execute_task",
733+
... "a": ("additional", "arguments"),
734+
... "k": {"some": "keyword", "args": "here"}
735+
... }
736+
... }
737+
>>> fn_buf = custom_serialize(task_func, *task_args, **task_kwargs)
738+
>>> fut = htex.submit_payload(ctxt, fn_buf)
739+
740+
The custom ``custom_execute_task`` would be dynamically imported, and
741+
invoked as:
742+
743+
.. code-block:: python
744+
745+
args = ("additional", "arguments")
746+
kwargs = {"some": "keyword", "args": "here"}
747+
result = custom_execute_task(fn_buf, *args, **kwargs)
748+
749+
Parameters
750+
----------
751+
context:
752+
A task-specific context associated with the function buffer. Parsl
753+
currently implements the keys ``task_executor`` and ``resource_spec``
754+
755+
buffer:
756+
A serialized function, that will be deserialized and executed by
757+
``execute_task`` (or custom function, if ``task_executor`` is specified)
758+
759+
Returns
760+
-------
761+
An HTEXFuture (a normal Future, with the attribute ``.parsl_executor_task_id``
762+
set). The future will be set to done when the associated function buffer has
763+
been invoked and completed.
764+
"""
765+
if self.bad_state_is_set:
766+
raise self.executor_exception
767+
768+
self._task_counter += 1
769+
task_id = self._task_counter
770+
771+
fut = HTEXFuture(task_id)
772+
self.tasks[task_id] = fut
724773

725-
# Post task to the outgoing queue
774+
msg = {"task_id": task_id, "context": context, "buffer": buffer}
726775
self.outgoing_q.put(msg)
727776

728-
# Return the future
729777
return fut
730778

731779
@property

parsl/executors/high_throughput/process_worker_pool.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#!/usr/bin/env python3
22

33
import argparse
4+
import importlib
45
import logging
56
import math
67
import multiprocessing
@@ -17,7 +18,7 @@
1718
from multiprocessing.context import SpawnProcess
1819
from multiprocessing.managers import DictProxy
1920
from multiprocessing.sharedctypes import Synchronized
20-
from typing import Dict, List, Optional, Sequence
21+
from typing import Callable, Dict, List, Optional, Sequence
2122

2223
import psutil
2324
import zmq
@@ -778,8 +779,20 @@ def manager_is_alive():
778779

779780
_init_mpi_env(mpi_launcher=mpi_launcher, resource_spec=res_spec)
780781

782+
exec_func: Callable = execute_task
783+
exec_args = ()
784+
exec_kwargs = {}
785+
781786
try:
782-
result = execute_task(req['buffer'])
787+
if task_executor := ctxt.get("task_executor", None):
788+
mod_name, _, fn_name = task_executor["f"].rpartition(".")
789+
exec_mod = importlib.import_module(mod_name)
790+
exec_func = getattr(exec_mod, fn_name)
791+
792+
exec_args = task_executor.get("a", ())
793+
exec_kwargs = task_executor.get("k", {})
794+
795+
result = exec_func(req['buffer'], *exec_args, **exec_kwargs)
783796
serialized_result = serialize(result, buffer_threshold=1000000)
784797
except Exception as e:
785798
logger.info('Caught an exception: {}'.format(e))

parsl/tests/test_htex/test_htex.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import pytest
88

99
from parsl import HighThroughputExecutor, curvezmq
10+
from parsl.serialize.facade import pack_apply_message, unpack_apply_message
1011

1112
_MOCK_BASE = "parsl.executors.high_throughput.executor"
1213

@@ -19,11 +20,16 @@ def encrypted(request: pytest.FixtureRequest):
1920

2021

2122
@pytest.fixture
22-
def htex(encrypted: bool):
23+
def htex(encrypted: bool, tmpd_cwd):
2324
htex = HighThroughputExecutor(encrypted=encrypted)
25+
htex.max_workers_per_node = 1
26+
htex.run_dir = tmpd_cwd
27+
htex.provider.script_dir = tmpd_cwd
2428

2529
yield htex
2630

31+
if hasattr(htex, "outgoing_q"):
32+
htex.scale_in(blocks=1000)
2733
htex.shutdown()
2834

2935

@@ -146,3 +152,32 @@ def test_htex_interchange_launch_cmd(cmd: Optional[Sequence[str]]):
146152
else:
147153
htex = HighThroughputExecutor()
148154
assert htex.interchange_launch_cmd == ["interchange.py"]
155+
156+
157+
def dyn_exec(buf, *vec_y):
158+
f, a, _ = unpack_apply_message(buf)
159+
custom_args = [a, vec_y]
160+
return f(*custom_args)
161+
162+
163+
@pytest.mark.local
164+
def test_worker_dynamic_import(htex: HighThroughputExecutor):
165+
def _dot_prod(vec_x, vec_y):
166+
return sum(x * y for x, y in zip(vec_x, vec_y))
167+
168+
htex.start()
169+
htex.scale_out_facade(1)
170+
171+
num_array = tuple(range(10))
172+
173+
fn_buf = pack_apply_message(_dot_prod, num_array, {})
174+
ctxt = {
175+
"task_executor": {
176+
"f": f"{dyn_exec.__module__}.{dyn_exec.__name__}",
177+
"a": num_array, # prove "custom" dyn_exec
178+
}
179+
}
180+
val = htex.submit_payload(ctxt, fn_buf).result()
181+
exp_val = _dot_prod(num_array, num_array)
182+
183+
assert val == exp_val
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from unittest import mock
2+
3+
import pytest
4+
5+
from parsl import HighThroughputExecutor
6+
from parsl.executors.high_throughput import zmq_pipes
7+
8+
9+
@pytest.mark.local
10+
def test_submit_payload():
11+
htex = HighThroughputExecutor()
12+
htex.outgoing_q = mock.Mock(spec=zmq_pipes.TasksOutgoing)
13+
ctxt = {"some": "context"}
14+
buf = b'some buffer (function) payload'
15+
for task_num in range(1, 20):
16+
htex.outgoing_q.reset_mock()
17+
fut = htex.submit_payload(ctxt, buf)
18+
(msg,), _ = htex.outgoing_q.put.call_args
19+
20+
assert htex.tasks[fut.parsl_executor_task_id] is fut
21+
assert fut.parsl_executor_task_id == task_num, "Expect monotonic increase"
22+
assert msg["task_id"] == fut.parsl_executor_task_id
23+
assert msg["context"] == ctxt, "Expect no modification"
24+
assert msg["buffer"] == buf, "Expect no modification"

parsl/tests/unit/executors/high_throughput/test_process_worker_pool.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
1+
import os
2+
import pickle
13
import sys
24
from argparse import ArgumentError
5+
from unittest import mock
36

47
import pytest
58

9+
from parsl.app.errors import RemoteExceptionWrapper
610
from parsl.executors.high_throughput import process_worker_pool
11+
from parsl.executors.high_throughput.process_worker_pool import worker
12+
from parsl.multiprocessing import SpawnContext
13+
from parsl.serialize.facade import deserialize
714

815
if sys.version_info < (3, 12):
916
# exit_on_error bug; see https://github.com/python/cpython/issues/121018
@@ -72,3 +79,114 @@ def test_arg_parser_validates_cpu_affinity(valid, val):
7279
with pytest.raises(ArgumentError) as pyt_exc:
7380
p.parse_args(reqd_args)
7481
assert "must be one of" in pyt_exc.value.args[1]
82+
83+
84+
def _always_raise(*a, **k):
85+
raise ArithmeticError(f"{a=}\n{k=}")
86+
87+
88+
@pytest.mark.local
89+
def test_worker_dynamic_import_happy_path(tmpd_cwd):
90+
import_str = f"{_always_raise.__module__}.{_always_raise.__name__}"
91+
task_exec = {
92+
"f": import_str,
93+
"a": (1, 2),
94+
"k": {"a": "b"},
95+
}
96+
req = {
97+
"task_id": 15,
98+
"context": {"task_executor": task_exec},
99+
"buffer": b"some serialized value"
100+
}
101+
102+
try:
103+
task_args = [req["buffer"]]
104+
task_args.extend(task_exec["a"])
105+
_always_raise(*task_args, **task_exec["k"])
106+
except Exception as e:
107+
exp_exc = e
108+
else:
109+
raise RuntimeError("Test failure; this branch should not run")
110+
111+
q = mock.Mock(side_effect=(req, MemoryError("intentional test error")))
112+
q.get = q
113+
114+
block_id = "bid"
115+
worker_id = 1
116+
pool = 1
117+
(tmpd_cwd / f"block-{block_id}/{worker_id}").mkdir(parents=True)
118+
with pytest.raises(MemoryError):
119+
worker(
120+
worker_id,
121+
pool_id=str(pool),
122+
pool_size=pool,
123+
task_queue=q,
124+
result_queue=q,
125+
monitoring_queue=None,
126+
ready_worker_count=SpawnContext.Value("i", 0),
127+
tasks_in_progress={},
128+
cpu_affinity="none",
129+
accelerator=None,
130+
block_id=block_id,
131+
task_queue_timeout=0,
132+
manager_pid=os.getpid(),
133+
logdir=str(tmpd_cwd),
134+
debug=True,
135+
mpi_launcher="",
136+
)
137+
(result_pkl,), _ = q.put.call_args
138+
r = pickle.loads(result_pkl)
139+
assert "exception" in r
140+
wrapped_exc: RemoteExceptionWrapper = deserialize(r["exception"])
141+
exc = wrapped_exc.get_exception()
142+
assert isinstance(exc, type(exp_exc)), "Approximate equality"
143+
assert str(exp_exc) == str(exc), "Approximate equality; all args, kwargs conveyed"
144+
145+
146+
@pytest.mark.local
147+
def test_worker_bad_dynamic_import(tmpd_cwd):
148+
req = {
149+
"task_id": 15,
150+
"context": {
151+
"task_executor": {
152+
"f": "parsl.some.not_existing.module.__nope",
153+
"a": (1, 2),
154+
"k": {"a": "b"},
155+
},
156+
},
157+
"buffer": b"some serialized value"
158+
}
159+
160+
q = mock.Mock(side_effect=(req, MemoryError("intentional test error")))
161+
q.get = q
162+
163+
block_id = "bid"
164+
worker_id = 1
165+
pool = 1
166+
(tmpd_cwd / f"block-{block_id}/{worker_id}").mkdir(parents=True)
167+
with pytest.raises(MemoryError):
168+
worker(
169+
worker_id,
170+
pool_id=str(pool),
171+
pool_size=pool,
172+
task_queue=q,
173+
result_queue=q,
174+
monitoring_queue=None,
175+
ready_worker_count=SpawnContext.Value("i", 0),
176+
tasks_in_progress={},
177+
cpu_affinity="none",
178+
accelerator=None,
179+
block_id=block_id,
180+
task_queue_timeout=0,
181+
manager_pid=os.getpid(),
182+
logdir=str(tmpd_cwd),
183+
debug=True,
184+
mpi_launcher="",
185+
)
186+
(result_pkl,), _ = q.put.call_args
187+
r = pickle.loads(result_pkl)
188+
assert "exception" in r
189+
wrapped_exc: RemoteExceptionWrapper = deserialize(r["exception"])
190+
exc = wrapped_exc.get_exception()
191+
assert isinstance(exc, ModuleNotFoundError)
192+
assert "No module named" in str(exc)

0 commit comments

Comments
 (0)