@@ -723,6 +723,51 @@ def test_start_job_with_cli_args_label_selection_not_json_str(
723723 )
724724 self .assertEqual (result .exit_code , 1 )
725725
726+ @mock .patch ("yaml.dump" )
727+ @mock .patch ("hyperpod_cli.clients.kubernetes_client.KubernetesClient.__new__" )
728+ @mock .patch ("hyperpod_cli.commands.job.JobValidator" )
729+ @mock .patch ("boto3.Session" )
730+ def test_start_job_with_cli_args_pre_script_and_post_script (
731+ self ,
732+ mock_boto3 ,
733+ mock_validator_cls ,
734+ mock_kubernetes_client ,
735+ mock_yaml_dump ,
736+ ):
737+ mock_validator = mock_validator_cls .return_value
738+ mock_validator .validate_aws_credential .return_value = True
739+ mock_kubernetes_client .get_current_context_namespace .return_value = "kubeflow"
740+ mock_yaml_dump .return_value = None
741+ result = self .runner .invoke (
742+ start_job ,
743+ [
744+ "--job-name" ,
745+ "test-job" ,
746+ "--instance-type" ,
747+ "ml.c5.xlarge" ,
748+ "--image" ,
749+ "pytorch:1.9.0-cuda11.1-cudnn8-runtime" ,
750+ "--node-count" ,
751+ "2" ,
752+ "--label-selector" ,
753+ "{NonJsonStr" ,
754+ "--entry-script" ,
755+ "/opt/train/src/train.py" ,
756+ "--pre-script" ,
757+ "echo 'test', echo 'test 1'" ,
758+ "--post-script" ,
759+ "echo 'test 1', echo 'test 2'"
760+ ],
761+ )
762+
763+ # Assert that yaml.dump was called with the correct configuration
764+ mock_yaml_dump .assert_called_once ()
765+ call_args = mock_yaml_dump .call_args [0 ]
766+ self .assertEqual (call_args [0 ]['cluster' ]['cluster_config' ]['pre_script' ], ["echo 'test'" , " echo 'test 1'" ])
767+ self .assertEqual (call_args [0 ]['cluster' ]['cluster_config' ]['post_script' ], ["echo 'test 1'" , " echo 'test 2'" ])
768+
769+ self .assertEqual (result .exit_code , 1 )
770+
726771 @mock .patch ("yaml.dump" )
727772 @mock .patch ("hyperpod_cli.clients.kubernetes_client.KubernetesClient.__new__" )
728773 @mock .patch ("hyperpod_cli.commands.job.JobValidator" )
0 commit comments