@@ -334,7 +334,65 @@ def test_move_data_failed(self, mock_status, mock_create, mock_sleep):
334334 mock_status .assert_called ()
335335
336336 @patch ("requests.post" )
337- def test_create_distributed_job (self , mock_post ):
337+ def test_create_training_job_single_node (self , mock_post ):
338+ """Test that single node jobs use the correct training endpoint and payload structure."""
339+ mock_response = MagicMock ()
340+ mock_response .status_code = 200
341+ mock_response .text = '{"status": "submitted"}'
342+ mock_post .return_value = mock_response
343+
344+ executor = DGXCloudExecutor (
345+ base_url = "https://dgxapi.example.com" ,
346+ app_id = "test_app_id" ,
347+ app_secret = "test_app_secret" ,
348+ project_name = "test_project" ,
349+ container_image = "nvcr.io/nvidia/test:latest" ,
350+ nodes = 1 ,
351+ gpus_per_node = 8 ,
352+ pvc_nemo_run_dir = "/workspace/nemo_run" ,
353+ pvcs = [{"path" : "workspace" , "claimName" : "test-claim" }],
354+ )
355+ executor .pvc_job_dir = "/workspace/nemo_run/job_dir"
356+ executor .env_vars = {"TEST_VAR" : "test_value" }
357+
358+ response = executor .create_training_job (
359+ token = "test_token" ,
360+ project_id = "proj_id" ,
361+ cluster_id = "cluster_id" ,
362+ name = "test_job" ,
363+ )
364+
365+ assert response == mock_response
366+
367+ # Check if the API call is made correctly for single node
368+ mock_post .assert_called_once ()
369+ args , kwargs = mock_post .call_args
370+
371+ # Verify single node endpoint
372+ assert args [0 ] == "https://dgxapi.example.com/workloads/trainings"
373+
374+ # Verify payload structure for single node job
375+ assert kwargs ["json" ]["name" ] == "test_job"
376+ assert kwargs ["json" ]["projectId" ] == "proj_id"
377+ assert kwargs ["json" ]["clusterId" ] == "cluster_id"
378+ assert kwargs ["json" ]["spec" ]["image" ] == "nvcr.io/nvidia/test:latest"
379+ assert (
380+ kwargs ["json" ]["spec" ]["command" ]
381+ == "/bin/bash /workspace/nemo_run/job_dir/launch_script.sh"
382+ )
383+ assert kwargs ["json" ]["spec" ]["compute" ]["gpuDevicesRequest" ] == 8
384+
385+ # Verify distributed-specific fields are NOT present
386+ assert "distributedFramework" not in kwargs ["json" ]["spec" ]
387+ assert "minReplicas" not in kwargs ["json" ]["spec" ]
388+ assert "maxReplicas" not in kwargs ["json" ]["spec" ]
389+ assert "numWorkers" not in kwargs ["json" ]["spec" ]
390+
391+ assert kwargs ["headers" ] == executor ._default_headers (token = "test_token" )
392+
393+ @patch ("requests.post" )
394+ def test_create_training_job_multi_node (self , mock_post ):
395+ """Test that multi-node jobs use the correct distributed endpoint and payload structure."""
338396 mock_response = MagicMock ()
339397 mock_response .status_code = 200
340398 mock_response .text = '{"status": "submitted"}'
@@ -348,13 +406,14 @@ def test_create_distributed_job(self, mock_post):
348406 container_image = "nvcr.io/nvidia/test:latest" ,
349407 nodes = 2 ,
350408 gpus_per_node = 8 ,
409+ distributed_framework = "PyTorch" ,
351410 pvc_nemo_run_dir = "/workspace/nemo_run" ,
352411 pvcs = [{"path" : "workspace" , "claimName" : "test-claim" }],
353412 )
354413 executor .pvc_job_dir = "/workspace/nemo_run/job_dir"
355414 executor .env_vars = {"TEST_VAR" : "test_value" }
356415
357- response = executor .create_distributed_job (
416+ response = executor .create_training_job (
358417 token = "test_token" ,
359418 project_id = "proj_id" ,
360419 cluster_id = "cluster_id" ,
@@ -363,10 +422,14 @@ def test_create_distributed_job(self, mock_post):
363422
364423 assert response == mock_response
365424
366- # Check if the API call is made correctly
425+ # Check if the API call is made correctly for multi-node
367426 mock_post .assert_called_once ()
368- # The URL is the first argument to post
369427 args , kwargs = mock_post .call_args
428+
429+ # Verify multi-node endpoint
430+ assert args [0 ] == "https://dgxapi.example.com/workloads/distributed"
431+
432+ # Verify payload structure for multi-node job
370433 assert kwargs ["json" ]["name" ] == "test_job"
371434 assert kwargs ["json" ]["projectId" ] == "proj_id"
372435 assert kwargs ["json" ]["clusterId" ] == "cluster_id"
@@ -375,18 +438,24 @@ def test_create_distributed_job(self, mock_post):
375438 kwargs ["json" ]["spec" ]["command" ]
376439 == "/bin/bash /workspace/nemo_run/job_dir/launch_script.sh"
377440 )
378- assert kwargs ["json" ]["spec" ]["numWorkers" ] == 2
379441 assert kwargs ["json" ]["spec" ]["compute" ]["gpuDevicesRequest" ] == 8
380- assert kwargs ["json" ]["spec" ]["environmentVariables" ] == [
381- {"name" : "TEST_VAR" , "value" : "test_value" }
382- ]
442+
443+ # Verify distributed-specific fields
444+ assert kwargs ["json" ]["spec" ]["distributedFramework" ] == "PyTorch"
445+ assert kwargs ["json" ]["spec" ]["minReplicas" ] == 2
446+ assert kwargs ["json" ]["spec" ]["maxReplicas" ] == 2
447+ assert kwargs ["json" ]["spec" ]["numWorkers" ] == 2
448+
383449 assert kwargs ["headers" ] == executor ._default_headers (token = "test_token" )
384450
385451 @patch .object (DGXCloudExecutor , "get_auth_token" )
386452 @patch .object (DGXCloudExecutor , "get_project_and_cluster_id" )
387453 @patch .object (DGXCloudExecutor , "move_data" )
388- @patch .object (DGXCloudExecutor , "create_distributed_job" )
389- def test_launch_success (self , mock_create_job , mock_move_data , mock_get_ids , mock_get_token ):
454+ @patch .object (DGXCloudExecutor , "create_training_job" )
455+ def test_launch_single_node (
456+ self , mock_create_job , mock_move_data , mock_get_ids , mock_get_token
457+ ):
458+ """Test that launch correctly handles single-node job submission."""
390459 mock_get_token .return_value = "test_token"
391460 mock_get_ids .return_value = ("proj_id" , "cluster_id" )
392461
@@ -402,7 +471,10 @@ def test_launch_success(self, mock_create_job, mock_move_data, mock_get_ids, moc
402471 app_secret = "test_app_secret" ,
403472 project_name = "test_project" ,
404473 container_image = "nvcr.io/nvidia/test:latest" ,
474+ nodes = 1 , # Single node
475+ gpus_per_node = 8 , # 8 GPUs per node
405476 pvc_nemo_run_dir = "/workspace/nemo_run" ,
477+ pvcs = [{"path" : "/workspace" , "claimName" : "test-claim" }],
406478 )
407479 executor .job_dir = tmp_dir
408480
@@ -411,13 +483,68 @@ def test_launch_success(self, mock_create_job, mock_move_data, mock_get_ids, moc
411483 assert job_id == "job123"
412484 assert status == "Pending"
413485 assert os .path .exists (os .path .join (tmp_dir , "launch_script.sh" ))
486+
487+ # Verify launch script contents for single node
488+ with open (os .path .join (tmp_dir , "launch_script.sh" ), "r" ) as f :
489+ script = f .read ()
490+ assert "python train.py" in script
491+
414492 mock_get_token .assert_called_once ()
415493 mock_get_ids .assert_called_once_with ("test_token" )
416494 mock_move_data .assert_called_once_with ("test_token" , "proj_id" , "cluster_id" )
417495 mock_create_job .assert_called_once_with (
418496 "test_token" , "proj_id" , "cluster_id" , "test-job"
419497 )
420498
499+ @patch .object (DGXCloudExecutor , "get_auth_token" )
500+ @patch .object (DGXCloudExecutor , "get_project_and_cluster_id" )
501+ @patch .object (DGXCloudExecutor , "move_data" )
502+ @patch .object (DGXCloudExecutor , "create_training_job" )
503+ def test_launch_multi_node (self , mock_create_job , mock_move_data , mock_get_ids , mock_get_token ):
504+ """Test that launch correctly handles multi-node job submission."""
505+ mock_get_token .return_value = "test_token"
506+ mock_get_ids .return_value = ("proj_id" , "cluster_id" )
507+
508+ mock_response = MagicMock ()
509+ mock_response .status_code = 200
510+ mock_response .json .return_value = {"workloadId" : "job456" , "actualPhase" : "Pending" }
511+ mock_create_job .return_value = mock_response
512+
513+ with tempfile .TemporaryDirectory () as tmp_dir :
514+ executor = DGXCloudExecutor (
515+ base_url = "https://dgxapi.example.com" ,
516+ app_id = "test_app_id" ,
517+ app_secret = "test_app_secret" ,
518+ project_name = "test_project" ,
519+ container_image = "nvcr.io/nvidia/test:latest" ,
520+ nodes = 2 , # Multi-node
521+ gpus_per_node = 8 ,
522+ distributed_framework = "PyTorch" ,
523+ pvc_nemo_run_dir = "/workspace/nemo_run" ,
524+ pvcs = [{"path" : "/workspace" , "claimName" : "test-claim" }],
525+ )
526+ executor .job_dir = tmp_dir
527+
528+ job_id , status = executor .launch (
529+ "test_multi_job" , ["python" , "-m" , "torch.distributed.run" , "train.py" ]
530+ )
531+
532+ assert job_id == "job456"
533+ assert status == "Pending"
534+ assert os .path .exists (os .path .join (tmp_dir , "launch_script.sh" ))
535+
536+ # Verify launch script contents for multi-node
537+ with open (os .path .join (tmp_dir , "launch_script.sh" ), "r" ) as f :
538+ script = f .read ()
539+ assert "python -m torch.distributed.run train.py" in script
540+
541+ mock_get_token .assert_called_once ()
542+ mock_get_ids .assert_called_once_with ("test_token" )
543+ mock_move_data .assert_called_once_with ("test_token" , "proj_id" , "cluster_id" )
544+ mock_create_job .assert_called_once_with (
545+ "test_token" , "proj_id" , "cluster_id" , "test-multi-job"
546+ )
547+
421548 @patch .object (DGXCloudExecutor , "get_auth_token" )
422549 def test_launch_no_token (self , mock_get_token ):
423550 mock_get_token .return_value = None
@@ -455,7 +582,7 @@ def test_launch_no_project_id(self, mock_get_ids, mock_get_token):
455582 @patch .object (DGXCloudExecutor , "get_auth_token" )
456583 @patch .object (DGXCloudExecutor , "get_project_and_cluster_id" )
457584 @patch .object (DGXCloudExecutor , "move_data" )
458- @patch .object (DGXCloudExecutor , "create_distributed_job" )
 585+ @patch .object (DGXCloudExecutor , "create_training_job" )
459586 def test_launch_job_creation_failed (
460587 self , mock_create_job , mock_move_data , mock_get_ids , mock_get_token
461588 ):
0 commit comments