Skip to content

Commit 2aa0a60

Browse files
delgadof (Francisco Delgado)
andauthored
change the create dist job function to support creating a single nod… (#240)
* change the create dist job function to support creating a single node job and distributed jobs Signed-off-by: Francisco Delgado <[email protected]> * Modified the formatting to meet the repository requirements Signed-off-by: Francisco Delgado <[email protected]> * Added testing for single and multi-node dist workloads. Signed-off-by: Francisco Delgado <[email protected]> --------- Signed-off-by: Francisco Delgado <[email protected]> Co-authored-by: Francisco Delgado <[email protected]>
1 parent 9512c3b commit 2aa0a60

File tree

2 files changed

+192
-30
lines changed

2 files changed

+192
-30
lines changed

nemo_run/core/execution/dgxcloud.py

Lines changed: 54 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -202,41 +202,76 @@ def move_data(self, token: str, project_id: str, cluster_id: str, sleep: float =
202202
resp.text,
203203
)
204204

205-
def create_distributed_job(self, token: str, project_id: str, cluster_id: str, name: str):
205+
def create_training_job(
206+
self, token: str, project_id: str, cluster_id: str, name: str
207+
) -> requests.Response:
206208
"""
207-
Creates a distributed PyTorch job using the provided project/cluster IDs.
209+
Creates a training job on DGX Cloud using the provided project/cluster IDs.
210+
For multi-node jobs, creates a distributed workload. Otherwise creates a single-node training.
211+
212+
Args:
213+
token: Authentication token for DGX Cloud API
214+
project_id: ID of the project to create the job in
215+
cluster_id: ID of the cluster to create the job on
216+
name: Name for the job
217+
218+
Returns:
219+
Response object from the API request
208220
"""
221+
# Validate inputs
222+
if not token or not project_id or not cluster_id:
223+
raise ValueError("Token, project ID, and cluster ID are required")
209224

210-
url = f"{self.base_url}/workloads/distributed"
211-
headers = self._default_headers(token=token)
225+
if self.nodes < 1:
226+
raise ValueError("Node count must be at least 1")
212227

213-
payload = {
228+
# Common payload elements
229+
common_payload = {
214230
"name": name,
215231
"useGivenNameAsPrefix": True,
216232
"projectId": project_id,
217233
"clusterId": cluster_id,
218-
"spec": {
219-
"command": f"/bin/bash {self.pvc_job_dir}/launch_script.sh",
220-
"image": self.container_image,
234+
}
235+
236+
# Common spec elements
237+
common_spec = {
238+
"command": f"/bin/bash {self.pvc_job_dir}/launch_script.sh",
239+
"image": self.container_image,
240+
"compute": {"gpuDevicesRequest": self.gpus_per_node},
241+
"storage": {"pvc": self.pvcs},
242+
"environmentVariables": [
243+
{"name": key, "value": value} for key, value in self.env_vars.items()
244+
],
245+
**self.custom_spec,
246+
}
247+
248+
# Determine endpoint and build payload based on node count
249+
if self.nodes > 1:
250+
url = f"{self.base_url}/workloads/distributed"
251+
252+
# Add distributed-specific parameters
253+
distributed_spec = {
221254
"distributedFramework": self.distributed_framework,
222255
"minReplicas": self.nodes,
223256
"maxReplicas": self.nodes,
224257
"numWorkers": self.nodes,
225-
"compute": {"gpuDevicesRequest": self.gpus_per_node},
226-
"storage": {"pvc": self.pvcs},
227-
"environmentVariables": [
228-
{"name": key, "value": value} for key, value in self.env_vars.items()
229-
],
230-
**self.custom_spec,
231-
},
232-
}
258+
}
233259

260+
payload = {**common_payload, "spec": {**common_spec, **distributed_spec}}
261+
else:
262+
url = f"{self.base_url}/workloads/trainings"
263+
payload = {**common_payload, "spec": common_spec}
264+
265+
headers = self._default_headers(token=token)
234266
response = requests.post(url, json=payload, headers=headers)
267+
235268
logger.debug(
236-
"Created distributed job; response code=%s, content=%s",
269+
"Created %s job; response code=%s, content=%s",
270+
"distributed" if self.nodes > 1 else "training",
237271
response.status_code,
238272
response.text.strip(),
239273
)
274+
240275
return response
241276

242277
def launch(self, name: str, cmd: list[str]) -> tuple[str, str]:
@@ -262,8 +297,8 @@ def launch(self, name: str, cmd: list[str]) -> tuple[str, str]:
262297
logger.info("Creating data movement workload")
263298
self.move_data(token, project_id, cluster_id)
264299

265-
logger.info("Creating distributed workload")
266-
resp = self.create_distributed_job(token, project_id, cluster_id, name)
300+
logger.info("Creating training workload")
301+
resp = self.create_training_job(token, project_id, cluster_id, name)
267302
if resp.status_code not in [200, 202]:
268303
raise RuntimeError(
269304
f"Failed to create job, status_code={resp.status_code}, reason={resp.text}"

test/core/execution/test_dgxcloud.py

Lines changed: 138 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,65 @@ def test_move_data_failed(self, mock_status, mock_create, mock_sleep):
334334
mock_status.assert_called()
335335

336336
@patch("requests.post")
337-
def test_create_distributed_job(self, mock_post):
337+
def test_create_training_job_single_node(self, mock_post):
338+
"""Test that single node jobs use the correct training endpoint and payload structure."""
339+
mock_response = MagicMock()
340+
mock_response.status_code = 200
341+
mock_response.text = '{"status": "submitted"}'
342+
mock_post.return_value = mock_response
343+
344+
executor = DGXCloudExecutor(
345+
base_url="https://dgxapi.example.com",
346+
app_id="test_app_id",
347+
app_secret="test_app_secret",
348+
project_name="test_project",
349+
container_image="nvcr.io/nvidia/test:latest",
350+
nodes=1,
351+
gpus_per_node=8,
352+
pvc_nemo_run_dir="/workspace/nemo_run",
353+
pvcs=[{"path": "workspace", "claimName": "test-claim"}],
354+
)
355+
executor.pvc_job_dir = "/workspace/nemo_run/job_dir"
356+
executor.env_vars = {"TEST_VAR": "test_value"}
357+
358+
response = executor.create_training_job(
359+
token="test_token",
360+
project_id="proj_id",
361+
cluster_id="cluster_id",
362+
name="test_job",
363+
)
364+
365+
assert response == mock_response
366+
367+
# Check if the API call is made correctly for single node
368+
mock_post.assert_called_once()
369+
args, kwargs = mock_post.call_args
370+
371+
# Verify single node endpoint
372+
assert args[0] == "https://dgxapi.example.com/workloads/trainings"
373+
374+
# Verify payload structure for single node job
375+
assert kwargs["json"]["name"] == "test_job"
376+
assert kwargs["json"]["projectId"] == "proj_id"
377+
assert kwargs["json"]["clusterId"] == "cluster_id"
378+
assert kwargs["json"]["spec"]["image"] == "nvcr.io/nvidia/test:latest"
379+
assert (
380+
kwargs["json"]["spec"]["command"]
381+
== "/bin/bash /workspace/nemo_run/job_dir/launch_script.sh"
382+
)
383+
assert kwargs["json"]["spec"]["compute"]["gpuDevicesRequest"] == 8
384+
385+
# Verify distributed-specific fields are NOT present
386+
assert "distributedFramework" not in kwargs["json"]["spec"]
387+
assert "minReplicas" not in kwargs["json"]["spec"]
388+
assert "maxReplicas" not in kwargs["json"]["spec"]
389+
assert "numWorkers" not in kwargs["json"]["spec"]
390+
391+
assert kwargs["headers"] == executor._default_headers(token="test_token")
392+
393+
@patch("requests.post")
394+
def test_create_training_job_multi_node(self, mock_post):
395+
"""Test that multi-node jobs use the correct distributed endpoint and payload structure."""
338396
mock_response = MagicMock()
339397
mock_response.status_code = 200
340398
mock_response.text = '{"status": "submitted"}'
@@ -348,13 +406,14 @@ def test_create_distributed_job(self, mock_post):
348406
container_image="nvcr.io/nvidia/test:latest",
349407
nodes=2,
350408
gpus_per_node=8,
409+
distributed_framework="PyTorch",
351410
pvc_nemo_run_dir="/workspace/nemo_run",
352411
pvcs=[{"path": "workspace", "claimName": "test-claim"}],
353412
)
354413
executor.pvc_job_dir = "/workspace/nemo_run/job_dir"
355414
executor.env_vars = {"TEST_VAR": "test_value"}
356415

357-
response = executor.create_distributed_job(
416+
response = executor.create_training_job(
358417
token="test_token",
359418
project_id="proj_id",
360419
cluster_id="cluster_id",
@@ -363,10 +422,14 @@ def test_create_distributed_job(self, mock_post):
363422

364423
assert response == mock_response
365424

366-
# Check if the API call is made correctly
425+
# Check if the API call is made correctly for multi-node
367426
mock_post.assert_called_once()
368-
# The URL is the first argument to post
369427
args, kwargs = mock_post.call_args
428+
429+
# Verify multi-node endpoint
430+
assert args[0] == "https://dgxapi.example.com/workloads/distributed"
431+
432+
# Verify payload structure for multi-node job
370433
assert kwargs["json"]["name"] == "test_job"
371434
assert kwargs["json"]["projectId"] == "proj_id"
372435
assert kwargs["json"]["clusterId"] == "cluster_id"
@@ -375,18 +438,24 @@ def test_create_distributed_job(self, mock_post):
375438
kwargs["json"]["spec"]["command"]
376439
== "/bin/bash /workspace/nemo_run/job_dir/launch_script.sh"
377440
)
378-
assert kwargs["json"]["spec"]["numWorkers"] == 2
379441
assert kwargs["json"]["spec"]["compute"]["gpuDevicesRequest"] == 8
380-
assert kwargs["json"]["spec"]["environmentVariables"] == [
381-
{"name": "TEST_VAR", "value": "test_value"}
382-
]
442+
443+
# Verify distributed-specific fields
444+
assert kwargs["json"]["spec"]["distributedFramework"] == "PyTorch"
445+
assert kwargs["json"]["spec"]["minReplicas"] == 2
446+
assert kwargs["json"]["spec"]["maxReplicas"] == 2
447+
assert kwargs["json"]["spec"]["numWorkers"] == 2
448+
383449
assert kwargs["headers"] == executor._default_headers(token="test_token")
384450

385451
@patch.object(DGXCloudExecutor, "get_auth_token")
386452
@patch.object(DGXCloudExecutor, "get_project_and_cluster_id")
387453
@patch.object(DGXCloudExecutor, "move_data")
388-
@patch.object(DGXCloudExecutor, "create_distributed_job")
389-
def test_launch_success(self, mock_create_job, mock_move_data, mock_get_ids, mock_get_token):
454+
@patch.object(DGXCloudExecutor, "create_training_job")
455+
def test_launch_single_node(
456+
self, mock_create_job, mock_move_data, mock_get_ids, mock_get_token
457+
):
458+
"""Test that launch correctly handles single-node job submission."""
390459
mock_get_token.return_value = "test_token"
391460
mock_get_ids.return_value = ("proj_id", "cluster_id")
392461

@@ -402,7 +471,10 @@ def test_launch_success(self, mock_create_job, mock_move_data, mock_get_ids, moc
402471
app_secret="test_app_secret",
403472
project_name="test_project",
404473
container_image="nvcr.io/nvidia/test:latest",
474+
nodes=1, # Single node
475+
gpus_per_node=8, # 8 GPUs per node
405476
pvc_nemo_run_dir="/workspace/nemo_run",
477+
pvcs=[{"path": "/workspace", "claimName": "test-claim"}],
406478
)
407479
executor.job_dir = tmp_dir
408480

@@ -411,13 +483,68 @@ def test_launch_success(self, mock_create_job, mock_move_data, mock_get_ids, moc
411483
assert job_id == "job123"
412484
assert status == "Pending"
413485
assert os.path.exists(os.path.join(tmp_dir, "launch_script.sh"))
486+
487+
# Verify launch script contents for single node
488+
with open(os.path.join(tmp_dir, "launch_script.sh"), "r") as f:
489+
script = f.read()
490+
assert "python train.py" in script
491+
414492
mock_get_token.assert_called_once()
415493
mock_get_ids.assert_called_once_with("test_token")
416494
mock_move_data.assert_called_once_with("test_token", "proj_id", "cluster_id")
417495
mock_create_job.assert_called_once_with(
418496
"test_token", "proj_id", "cluster_id", "test-job"
419497
)
420498

499+
@patch.object(DGXCloudExecutor, "get_auth_token")
500+
@patch.object(DGXCloudExecutor, "get_project_and_cluster_id")
501+
@patch.object(DGXCloudExecutor, "move_data")
502+
@patch.object(DGXCloudExecutor, "create_training_job")
503+
def test_launch_multi_node(self, mock_create_job, mock_move_data, mock_get_ids, mock_get_token):
504+
"""Test that launch correctly handles multi-node job submission."""
505+
mock_get_token.return_value = "test_token"
506+
mock_get_ids.return_value = ("proj_id", "cluster_id")
507+
508+
mock_response = MagicMock()
509+
mock_response.status_code = 200
510+
mock_response.json.return_value = {"workloadId": "job456", "actualPhase": "Pending"}
511+
mock_create_job.return_value = mock_response
512+
513+
with tempfile.TemporaryDirectory() as tmp_dir:
514+
executor = DGXCloudExecutor(
515+
base_url="https://dgxapi.example.com",
516+
app_id="test_app_id",
517+
app_secret="test_app_secret",
518+
project_name="test_project",
519+
container_image="nvcr.io/nvidia/test:latest",
520+
nodes=2, # Multi-node
521+
gpus_per_node=8,
522+
distributed_framework="PyTorch",
523+
pvc_nemo_run_dir="/workspace/nemo_run",
524+
pvcs=[{"path": "/workspace", "claimName": "test-claim"}],
525+
)
526+
executor.job_dir = tmp_dir
527+
528+
job_id, status = executor.launch(
529+
"test_multi_job", ["python", "-m", "torch.distributed.run", "train.py"]
530+
)
531+
532+
assert job_id == "job456"
533+
assert status == "Pending"
534+
assert os.path.exists(os.path.join(tmp_dir, "launch_script.sh"))
535+
536+
# Verify launch script contents for multi-node
537+
with open(os.path.join(tmp_dir, "launch_script.sh"), "r") as f:
538+
script = f.read()
539+
assert "python -m torch.distributed.run train.py" in script
540+
541+
mock_get_token.assert_called_once()
542+
mock_get_ids.assert_called_once_with("test_token")
543+
mock_move_data.assert_called_once_with("test_token", "proj_id", "cluster_id")
544+
mock_create_job.assert_called_once_with(
545+
"test_token", "proj_id", "cluster_id", "test-multi-job"
546+
)
547+
421548
@patch.object(DGXCloudExecutor, "get_auth_token")
422549
def test_launch_no_token(self, mock_get_token):
423550
mock_get_token.return_value = None
@@ -455,7 +582,7 @@ def test_launch_no_project_id(self, mock_get_ids, mock_get_token):
455582
@patch.object(DGXCloudExecutor, "get_auth_token")
456583
@patch.object(DGXCloudExecutor, "get_project_and_cluster_id")
457584
@patch.object(DGXCloudExecutor, "move_data")
458-
@patch.object(DGXCloudExecutor, "create_distributed_job")
585+
@patch.object(DGXCloudExecutor, "create_training_job")
459586
def test_launch_job_creation_failed(
460587
self, mock_create_job, mock_move_data, mock_get_ids, mock_get_token
461588
):

0 commit comments

Comments
 (0)