@@ -493,6 +493,90 @@ def test_start_job_with_cli_args(
493493 print (f"Exception: { result .exception } " )
494494 self .assertEqual (result .exit_code , 0 )
495495
@mock.patch('subprocess.run')
@mock.patch("yaml.dump")
@mock.patch("os.path.exists", return_value=True)
@mock.patch("os.remove", return_value=None)
@mock.patch("hyperpod_cli.utils.get_cluster_console_url")
@mock.patch("hyperpod_cli.clients.kubernetes_client.KubernetesClient.__new__")
@mock.patch("hyperpod_cli.commands.job.JobValidator")
@mock.patch("boto3.Session")
def test_start_job_default_label_selector_config(
    self,
    mock_boto3,
    mock_validator_cls,
    mock_kubernetes_client,
    mock_get_console_link,
    mock_remove,
    mock_exists,
    mock_yaml_dump,
    mock_subprocess_run,
):
    """Check the label selector written when no label selector option is passed.

    Invokes ``start_job`` with only job name, instance type, image, node
    count and entry script, then inspects the first object handed to the
    patched ``yaml.dump`` and asserts that
    ``config['cluster']['cluster_config']['label_selector']`` equals the
    expected default: required health status and instance type, a preferred
    deep-health-check status, and a single weight of 100.
    """
    # Setup mocks
    mock_validator = mock_validator_cls.return_value
    mock_validator.validate_aws_credential.return_value = True
    # NOTE(review): this configures the patched ``__new__`` mock itself, not
    # its ``return_value`` -- confirm this is the object start_job reads.
    mock_kubernetes_client.get_current_context_namespace.return_value = "kubeflow"
    mock_get_console_link.return_value = "test-console-link"
    mock_subprocess_run.return_value = subprocess.CompletedProcess(
        args=['some_command'],
        returncode=0,
        stdout='Command executed successfully',
        stderr='',
    )

    expected_default_label_selector_config = {
        "required": {
            "sagemaker.amazonaws.com/node-health-status": ["Schedulable"],
            "beta.kubernetes.io/instance-type": ["ml.c5.xlarge"],
        },
        "preferred": {"sagemaker.amazonaws.com/deep-health-check-status": ["Passed"]},
        "weights": [100],
    }

    # Record every object passed to yaml.dump so the generated config can be
    # inspected after the command returns.
    configs_dumped = []

    def capture_yaml_dump(config, *args, **kwargs):
        configs_dumped.append(config)
        return None

    mock_yaml_dump.side_effect = capture_yaml_dump

    # Run the command
    result = self.runner.invoke(
        start_job,
        [
            "--job-name", "test-job",
            "--instance-type", "ml.c5.xlarge",
            "--image", "pytorch:1.9.0-cuda11.1-cudnn8-runtime",
            "--node-count", "2",
            "--entry-script", "/opt/train/src/train.py",
        ],
        catch_exceptions=False,
    )

    # Verify the command executed successfully. The diagnostics live in the
    # assertion message so they are shown exactly when the check fails.
    self.assertEqual(
        result.exit_code,
        0,
        f"Exit code: {result.exit_code}\nOutput: {result.output}",
    )

    # Get the config that was generated (the first one dumped).
    self.assertTrue(configs_dumped, "No config was generated")
    config = configs_dumped[0]

    # Verify label_selector configuration
    self.assertIn('cluster', config)
    self.assertIn('cluster_config', config['cluster'])
    self.assertIn('label_selector', config['cluster']['cluster_config'])

    self.assertEqual(
        config['cluster']['cluster_config']['label_selector'],
        expected_default_label_selector_config,
    )
578+
579+
496580 @mock .patch ('subprocess.run' )
497581 @mock .patch ("yaml.dump" )
498582 @mock .patch ("os.path.exists" , return_value = True )
0 commit comments