Skip to content

Commit 441cd0f

Browse files
test: added unit test cases
1 parent eca509a commit 441cd0f

File tree

1 file changed

+45
-0
lines changed

1 file changed

+45
-0
lines changed

test/unit_tests/test_job.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -723,6 +723,51 @@ def test_start_job_with_cli_args_label_selection_not_json_str(
723723
)
724724
self.assertEqual(result.exit_code, 1)
725725

726+
@mock.patch("yaml.dump")
727+
@mock.patch("hyperpod_cli.clients.kubernetes_client.KubernetesClient.__new__")
728+
@mock.patch("hyperpod_cli.commands.job.JobValidator")
729+
@mock.patch("boto3.Session")
730+
def test_start_job_with_cli_args_pre_script_and_post_script(
731+
self,
732+
mock_boto3,
733+
mock_validator_cls,
734+
mock_kubernetes_client,
735+
mock_yaml_dump,
736+
):
737+
mock_validator = mock_validator_cls.return_value
738+
mock_validator.validate_aws_credential.return_value = True
739+
mock_kubernetes_client.get_current_context_namespace.return_value = "kubeflow"
740+
mock_yaml_dump.return_value = None
741+
result = self.runner.invoke(
742+
start_job,
743+
[
744+
"--job-name",
745+
"test-job",
746+
"--instance-type",
747+
"ml.c5.xlarge",
748+
"--image",
749+
"pytorch:1.9.0-cuda11.1-cudnn8-runtime",
750+
"--node-count",
751+
"2",
752+
"--label-selector",
753+
"{NonJsonStr",
754+
"--entry-script",
755+
"/opt/train/src/train.py",
756+
"--pre-script",
757+
"echo 'test', echo 'test 1'",
758+
"--post-script",
759+
"echo 'test 1', echo 'test 2'"
760+
],
761+
)
762+
763+
# Assert that yaml.dump was called with the correct configuration
764+
mock_yaml_dump.assert_called_once()
765+
call_args = mock_yaml_dump.call_args[0]
766+
self.assertEqual(call_args[0]['cluster']['cluster_config']['pre_script'], ["echo 'test'", " echo 'test 1'"])
767+
self.assertEqual(call_args[0]['cluster']['cluster_config']['post_script'], ["echo 'test 1'", " echo 'test 2'"])
768+
769+
self.assertEqual(result.exit_code, 1)
770+
726771
@mock.patch("yaml.dump")
727772
@mock.patch("hyperpod_cli.clients.kubernetes_client.KubernetesClient.__new__")
728773
@mock.patch("hyperpod_cli.commands.job.JobValidator")

0 commit comments

Comments
 (0)