Skip to content

Commit a8f84c1

Browse files
committed
Add possibility to add a context function to run before each worker.
1 parent ebd16fd commit a8f84c1

File tree

3 files changed

+37
-0
lines changed

3 files changed

+37
-0
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,10 @@ These are the current HTCondor Backend parameters that can be set in the `parall
9595
- `log_dir_prefix`: Prefix for each of the HTCondor log files. If not specified, it will create a `logs` directory in the `initial_dir`.
9696
- `poll_interval`: Minimum time (in seconds) between polls to the HTCondor scheduler. Defaults to 5 seconds. A lower value will increase the load on the HTCondor collector as well as the filesystem, but will increase reactivity of the backend. A higher value will decrease the load, but the backend will take longer to react to changes in the queue. Important: there will be a poller for each nested parallel call, so the load on the system will be multiplied by the number of nested parallel calls.
9797
- `shared_data_dir`: Directory where the tasks and results will be saved. This directory must be shared across all the nodes in the HTCondor pool. If not specified, it will use a `joblib_htcondor_shared_data` directory inside the current working directory.
98+
- `delete_task_file_on_load`: Whether to delete the task file once the job enters the RUNNING state. This can be useful when the tasks are data-intensive and the shared disk space is limited. However, the task file can then no longer be loaded if the job fails, which makes it impossible to retry the task (default False).
9899
- `worker_log_level`: Log level for the worker. Defaults to `INFO`.
99100
- `throttle`: Throttle the number of jobs submitted at once. If a list, the first element is the throttle for the current level and the rest are for the nested levels (default None).
100101
- `batch_size`: Currently under development
101102
- `max_recursion_level`: Maximum recursion level for nested parallel calls. Defaults to 0 (no nested parallel calls allowed).
102103
- `export_metadata`: Export metadata to be used with the UI. Defaults to False.
104+
- `context_func`: A function to be called in the worker before running the actual function. It is called without arguments and can be used to set up the context for the worker, for example by setting some global variables or importing some modules. The function will be serialized, so it can't rely on global variables. If None, no function will be called (default None).

joblib_htcondor/backend.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,11 @@ class _HTCondorBackend(ParallelBackendBase):
358358
Export metadata to a file, to be used with the HTCondor Joblib Monitor.
359359
This increases the load on the filesystem considerably if the number
360360
of jobs is high and the duration is short (default False).
361+
context_func : Callable or None, optional
362+
A function to call in the worker before running the actual function.
363+
This can be used to set global configuration variables in the worker,
364+
for example. The function will be serialized, so it can't rely on
365+
global variables. If None, no function will be called (default None).
361366
362367
Raises
363368
------
@@ -388,6 +393,7 @@ def __init__( # noqa: C901
388393
batch_size: int = 1,
389394
max_recursion_level: int = 0,
390395
export_metadata: bool = False,
396+
context_func: Optional[Callable] = None,
391397
) -> None:
392398
super().__init__()
393399

@@ -426,6 +432,7 @@ def __init__( # noqa: C901
426432
self._batch_size = batch_size
427433
self._max_recursion_level = max_recursion_level
428434
self._export_metadata = export_metadata
435+
self._context_func = context_func
429436

430437
self._recursion_level = 0
431438
self._parent_uuid = None
@@ -458,6 +465,7 @@ def __init__( # noqa: C901
458465
logger.debug(f"Batch Size: {self._batch_size}")
459466
logger.debug(f"Max recursion level: {self._max_recursion_level}")
460467
logger.debug(f"Export metadata: {self._export_metadata}")
468+
logger.debug(f"Context function: {self._context_func}")
461469

462470
logger.debug(f"Recursion level: {self._recursion_level}")
463471
logger.debug(f"Parent UUID: {self._parent_uuid}")
@@ -591,6 +599,7 @@ def get_nested_backend(self) -> tuple["ParallelBackendBase", int]:
591599
batch_size=self._batch_size,
592600
max_recursion_level=self._max_recursion_level,
593601
export_metadata=self._export_metadata,
602+
context_func=self._context_func,
594603
recursion_level=self._recursion_level + 1,
595604
parent_uuid=self._this_batch_name,
596605
), self._n_jobs
@@ -617,6 +626,7 @@ def __reduce__(self) -> tuple[Callable, tuple]:
617626
self._batch_size,
618627
self._max_recursion_level,
619628
self._export_metadata,
629+
self._context_func,
620630
self._recursion_level,
621631
self._parent_uuid,
622632
),
@@ -793,6 +803,8 @@ def _submit(
793803

794804
# Create the DelayedSubmission object
795805
ds = DelayedSubmission(func)
806+
if self._context_func is not None:
807+
ds.set_context_func(self._context_func)
796808
delete_file_param = (
797809
"--delete-file-on-load" if self._delete_task_file_on_load else ""
798810
)
@@ -1201,6 +1213,7 @@ def build(
12011213
batch_size: int = 1,
12021214
max_recursion_level: int = -1,
12031215
export_metadata: bool = False,
1216+
context_func: Optional[Callable] = None,
12041217
recursion_level: int = 0,
12051218
parent_uuid: Optional[str] = None,
12061219
) -> _HTCondorBackend:
@@ -1261,6 +1274,13 @@ def build(
12611274
Monitor. This increases the load on the filesystem considerably if
12621275
the number of jobs is high and the duration is short
12631276
(default False).
1277+
context_func : callable or None, optional
1278+
A function to be called in the worker before running the actual
1279+
function. This can be used to set up the context for the worker,
1280+
for example by setting some global variables or importing some
1281+
modules. The function will be serialized, so it can't rely on
1282+
global variables. If None, no function will be called (default
1283+
None).
12641284
recursion_level : int, optional
12651285
Recursion level of the backend. With each nested
12661286
call, the recursion level increases by 1 (default 0).
@@ -1297,6 +1317,7 @@ def build(
12971317
batch_size=batch_size,
12981318
max_recursion_level=max_recursion_level,
12991319
export_metadata=export_metadata,
1320+
context_func=context_func,
13001321
)
13011322
out._recursion_level = recursion_level
13021323
out._parent_uuid = parent_uuid # type: ignore

joblib_htcondor/delayed_submission.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,12 @@ def __init__(
7474
self._done = False
7575
self._error = False
7676
self._done_timestamp = None
77+
self.context_func = None
7778

7879
def run(self) -> None:
7980
"""Run the function with the arguments and store the result."""
81+
if self.context_func is not None:
82+
self.context_func()
8083
try:
8184
self._result = self.func(*self.args, **self.kwargs) # type: ignore
8285
except BaseException as e: # noqa: BLE001
@@ -88,6 +91,17 @@ def run(self) -> None:
8891
self._done_timestamp = datetime.now()
8992
self._done = True
9093

94+
def set_context_func(self, context_func: Callable) -> None:
95+
"""Set a context function to be called prior to the main function.
96+
97+
Parameters
98+
----------
99+
context_func : callable
100+
The context function to call before running the main function.
101+
102+
"""
103+
self.context_func = context_func
104+
91105
def done(self) -> bool:
92106
"""Return whether the function has been run.
93107

0 commit comments

Comments
 (0)