
Commit c8bc7cc

Authored by Hemil Desai
Add support for job groups for local executor (#220)
* Add support for job groups for local executor
* handle proper exit
* fix
* fix
* fix

Signed-off-by: Hemil Desai <[email protected]>
1 parent 67de154 commit c8bc7cc

File tree

4 files changed: +64 −16 lines changed
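For context, the sketch below shows roughly how this change is exercised: a JobGroup of several tasks submitted to a LocalExecutor through an Experiment. The task contents, names, and exact exp.add/exp.run arguments are illustrative assumptions, not taken from this commit.

    # Hypothetical usage sketch: two scripts grouped into one JobGroup on the
    # LocalExecutor. Argument names (inline=, tail_logs=) are assumptions.
    import nemo_run as run

    if __name__ == "__main__":
        executor = run.LocalExecutor()
        with run.Experiment("local-job-group") as exp:
            # Passing a list of tasks to exp.add is assumed to create a JobGroup.
            exp.add(
                [
                    run.Script(inline="echo 'task 1'"),
                    run.Script(inline="echo 'task 2'"),
                ],
                executor=executor,
                name="group",
            )
            exp.run(tail_logs=True)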

nemo_run/run/job.py

Lines changed: 13 additions & 9 deletions
@@ -9,6 +9,7 @@
 from nemo_run.config import Config, ConfigurableMixin, Partial, Script
 from nemo_run.core.execution.base import Executor
 from nemo_run.core.execution.docker import DockerExecutor
+from nemo_run.core.execution.local import LocalExecutor
 from nemo_run.core.execution.slurm import SlurmExecutor
 from nemo_run.core.frontend.console.api import CONSOLE
 from nemo_run.core.serialization.zlib_json import ZlibJSONSerializer
@@ -211,7 +212,7 @@ class JobGroup(ConfigurableMixin):
     handle groups of related tasks.
     """

-    SUPPORTED_EXECUTORS = [SlurmExecutor, DockerExecutor]
+    SUPPORTED_EXECUTORS = [SlurmExecutor, DockerExecutor, LocalExecutor]

     id: str
     tasks: list[Union[Partial, Script]]
@@ -370,16 +371,19 @@ def launch(
         self.launched = True

     def wait(self, runner: Runner | None = None):
-        assert len(self.handles) == 1, "Only one handle is supported for task groups currently."
+        new_states = []
         try:
-            status = wait_and_exit(
-                app_handle=self.handles[0],
-                log=self.tail_logs,
-                runner=runner,
-            )
-            self.states = [status.state]
+            for handle in self.handles:
+                status = wait_and_exit(
+                    app_handle=handle,
+                    log=self.tail_logs,
+                    runner=runner,
+                )
+                new_states.append(status.state)
         except nemo_run.exceptions.UnknownStatusError:
-            self.states = [AppState.UNKNOWN]
+            new_states = [AppState.UNKNOWN]
+
+        self.states = new_states

     def cancel(self, runner: Runner):
         if not self.handles:
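In essence, JobGroup.wait now collects one final state per handle instead of asserting on a single handle. A standalone sketch of that pattern (wait_fn stands in for wait_and_exit; only AppState comes from torchx):

    # Simplified sketch of the new per-handle wait loop; not the actual class method.
    from torchx.specs.api import AppState

    def wait_all(handles, wait_fn):
        """Wait on each handle in turn; on an unknown status, mark the group UNKNOWN."""
        states = []
        try:
            for handle in handles:
                status = wait_fn(handle)  # stands in for wait_and_exit(app_handle=handle, ...)
                states.append(status.state)
        except Exception:  # the real code catches nemo_run.exceptions.UnknownStatusError
            states = [AppState.UNKNOWN]
        return states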

nemo_run/run/torchx_backend/launcher.py

Lines changed: 7 additions & 0 deletions
@@ -122,6 +122,7 @@ def wait_and_exit(
     log: bool,
     runner: Runner | None = None,
     timeout: int = 10,
+    log_join_timeout: int = 600,
 ) -> specs.AppStatus:
     if runner is None:
         runner = get_runner()
@@ -159,6 +160,12 @@

     logger.info(f"Job {app_id} finished: {status.state}")

+    if log_thread and log_thread.is_alive():
+        logger.debug("Waiting for log thread to complete...")
+        log_thread.join(timeout=log_join_timeout)
+        if log_thread.is_alive():
+            logger.warning("Log thread did not complete within timeout, some logs may be missing")
+
     return status

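The new log_join_timeout handling is the standard threading join-with-timeout pattern. A self-contained sketch (the thread target and logger setup are illustrative):

    # Minimal sketch of bounding how long we wait for a log-tailing thread,
    # mirroring the block added to wait_and_exit. tail_logs_forever is a stand-in.
    import logging
    import threading
    import time

    logger = logging.getLogger(__name__)

    def tail_logs_forever() -> None:
        time.sleep(1)  # placeholder for real log tailing

    log_thread = threading.Thread(target=tail_logs_forever, daemon=True)
    log_thread.start()

    # ... the job finishes here ...

    if log_thread and log_thread.is_alive():
        logger.debug("Waiting for log thread to complete...")
        log_thread.join(timeout=600)  # log_join_timeout defaults to 600 in the commit
        if log_thread.is_alive():
            logger.warning("Log thread did not complete within timeout, some logs may be missing")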

nemo_run/run/torchx_backend/schedulers/local.py

Lines changed: 30 additions & 2 deletions
@@ -44,11 +44,12 @@
     PopenRequest,
     _LocalAppDef,
 )
-from torchx.specs.api import AppDef, AppState, Role
+from torchx.specs.api import AppDef, AppState, Role, is_terminal, parse_app_handle

 from nemo_run.config import get_nemorun_home
 from nemo_run.core.execution.base import Executor
 from nemo_run.core.execution.local import LocalExecutor
+from nemo_run.run import experiment as run_experiment
 from nemo_run.run.torchx_backend.schedulers.api import SchedulerMixin

 try:
@@ -69,6 +70,7 @@ def __init__(
         image_provider_class: Callable[[LocalOpts], ImageProvider],
         cache_size: int = 100,
         extra_paths: Optional[list[str]] = None,
+        experiment: Optional[run_experiment.Experiment] = None,
     ) -> None:
         # NOTE: make sure any new init options are supported in create_scheduler(...)
         self.backend = "local"
@@ -86,6 +88,7 @@ def __init__(
         # sets lazily on submit or dryrun based on log_dir cfg
         self._base_log_dir: Optional[str] = None
         self._created_tmp_log_dir: bool = False
+        self.experiment = experiment

     def _submit_dryrun(self, app: AppDef, cfg: Executor) -> AppDryRunInfo[PopenRequest]:  # type: ignore
         assert isinstance(cfg, LocalExecutor), f"{cfg.__class__} not supported for local scheduler."
@@ -106,9 +109,31 @@ def schedule(self, dryrun_info: AppDryRunInfo[PopenRequest]) -> str:

     def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
         resp = super().describe(app_id=app_id)
-
         if resp:
             _save_job_dir(self._apps)
+            if self.experiment:
+                maybe_job_to_kill = None
+                for job in self.experiment.jobs:
+                    if isinstance(job, run_experiment.JobGroup):
+                        for handle in job.handles:
+                            _, _, job_id = parse_app_handle(handle)
+                            if job_id == app_id:
+                                maybe_job_to_kill = job
+                                break
+
+                if maybe_job_to_kill:
+                    to_kill = False
+                    for handle in maybe_job_to_kill.handles:
+                        _, _, _id = parse_app_handle(handle)
+                        resp = super().describe(app_id=_id)
+                        if resp and is_terminal(resp.state):
+                            to_kill = True
+
+                    if to_kill:
+                        for handle in maybe_job_to_kill.handles:
+                            _, _, _id = parse_app_handle(handle)
+                            self._apps[_id].kill()
+
         return resp

     saved_apps = _get_job_dirs()
@@ -173,11 +198,14 @@ def create_scheduler(
     extra_paths: Optional[list[str]] = None,
     **kwargs: Any,
 ) -> PersistentLocalScheduler:
+    experiment = kwargs.pop("experiment", None)
+
     return PersistentLocalScheduler(
         session_name=session_name,
         image_provider_class=CWDImageProvider,
         cache_size=cache_size,
         extra_paths=extra_paths,
+        experiment=experiment,
     )

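The group-kill logic in describe reduces to two questions: which JobGroup owns this app_id, and has any handle in that group already reached a terminal state? A reduced sketch under those assumptions (the jobs/handles shapes are simplified; parse_app_handle and is_terminal are the torchx helpers imported above):

    # Reduced sketch of the decision added to describe(); helper names are illustrative.
    from torchx.specs.api import is_terminal, parse_app_handle

    def group_for_app(jobs, app_id):
        """Return the job group whose handles contain app_id, if any."""
        for job in jobs:
            for handle in getattr(job, "handles", []):
                _, _, job_id = parse_app_handle(handle)
                if job_id == app_id:
                    return job
        return None

    def should_kill_group(job, describe_fn) -> bool:
        """True if any app in the group has already reached a terminal state."""
        for handle in job.handles:
            _, _, app_id = parse_app_handle(handle)
            resp = describe_fn(app_id)
            if resp and is_terminal(resp.state):
                return True
        return False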

test/core/execution/test_dgxcloud.py

Lines changed: 14 additions & 5 deletions
@@ -20,9 +20,9 @@

 import pytest

+from nemo_run.config import set_nemorun_home
 from nemo_run.core.execution.dgxcloud import DGXCloudExecutor, DGXCloudState
 from nemo_run.core.packaging.git import GitArchivePackager
-from nemo_run.config import set_nemorun_home


 class TestDGXCloudExecutor:
@@ -252,17 +252,20 @@ def test_delete_workload(self, mock_delete):
             headers=executor._default_headers(token="test_token"),
         )

+    @patch("time.sleep")
     @patch.object(DGXCloudExecutor, "create_data_mover_workload")
     @patch.object(DGXCloudExecutor, "status")
     @patch.object(DGXCloudExecutor, "delete_workload")
-    def test_move_data_success(self, mock_delete, mock_status, mock_create):
+    def test_move_data_success(self, mock_delete, mock_status, mock_create, mock_sleep):
         mock_response = MagicMock()
         mock_response.status_code = 200
         mock_response.json.return_value = {"workloadId": "job123", "actualPhase": "Pending"}
         mock_create.return_value = mock_response
         mock_delete.return_value = mock_response

-        mock_status.return_value = DGXCloudState.COMPLETED
+        # Set up status to change after first check to avoid infinite loop
+        # First return PENDING, then return COMPLETED
+        mock_status.side_effect = [DGXCloudState.PENDING, DGXCloudState.COMPLETED]

         executor = DGXCloudExecutor(
             base_url="https://dgxapi.example.com",
@@ -276,12 +279,17 @@ def test_move_data_success(self, mock_delete, mock_status, mock_create):

         executor.move_data(token="test_token", project_id="proj_id", cluster_id="cluster_id")

+        # Verify all expected calls were made
         mock_create.assert_called_once_with("test_token", "proj_id", "cluster_id")
         mock_status.assert_called()
         mock_delete.assert_called_once_with("test_token", "job123")

+        # Verify time.sleep was called
+        mock_sleep.assert_called()
+
+    @patch("time.sleep")
     @patch.object(DGXCloudExecutor, "create_data_mover_workload")
-    def test_move_data_data_mover_fail(self, mock_create):
+    def test_move_data_data_mover_fail(self, mock_create, mock_sleep):
         mock_response = MagicMock()
         mock_response.status_code = 400

@@ -298,9 +306,10 @@ def test_move_data_data_mover_fail(self, mock_create):
         with pytest.raises(RuntimeError, match="Failed to create data mover workload"):
             executor.move_data(token="test_token", project_id="proj_id", cluster_id="cluster_id")

+    @patch("time.sleep")
     @patch.object(DGXCloudExecutor, "create_data_mover_workload")
     @patch.object(DGXCloudExecutor, "status")
-    def test_move_data_failed(self, mock_status, mock_create):
+    def test_move_data_failed(self, mock_status, mock_create, mock_sleep):
         mock_response = MagicMock()
         mock_response.status_code = 200
         mock_response.json.return_value = {"workloadId": "job123", "actualPhase": "Pending"}
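The test changes follow the usual unittest.mock recipe for polling loops: patch time.sleep so the test never actually waits, and use side_effect so the mocked status advances between polls. A generic sketch (poll_until_done is a hypothetical function, not part of nemo_run):

    # Generic sketch of the patched-sleep polling pattern used by the updated tests.
    import time
    from unittest.mock import MagicMock, patch

    def poll_until_done(get_status) -> None:
        """Hypothetical polling loop, roughly like the wait inside DGXCloudExecutor.move_data."""
        while get_status() != "COMPLETED":
            time.sleep(5)  # patched out in the test below

    @patch("time.sleep")
    def test_poll_until_done(mock_sleep):
        get_status = MagicMock(side_effect=["PENDING", "COMPLETED"])
        poll_until_done(get_status)
        assert get_status.call_count == 2
        mock_sleep.assert_called()  # the loop slept once, with no real delay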

0 commit comments