@@ -575,6 +575,86 @@ def capture_yaml_dump(config, *args, **kwargs):
575575 print (f"Output: { result .output } " )
576576 if result .exception :
577577 print (f"Exception: { result .exception } " )
578+
579+ @mock .patch ('subprocess.run' )
580+ @mock .patch ("yaml.dump" )
581+ @mock .patch ("os.path.exists" , return_value = True )
582+ @mock .patch ("os.remove" , return_value = None )
583+ @mock .patch ("hyperpod_cli.utils.get_cluster_console_url" )
584+ @mock .patch ("hyperpod_cli.clients.kubernetes_client.KubernetesClient.__new__" )
585+ @mock .patch ("hyperpod_cli.commands.job.JobValidator" )
586+ @mock .patch ("boto3.Session" )
587+ def test_start_job_label_selector_preferred_instance_type (
588+ self ,
589+ mock_boto3 ,
590+ mock_validator_cls ,
591+ mock_kubernetes_client ,
592+ mock_get_console_link ,
593+ mock_remove ,
594+ mock_exists ,
595+ mock_yaml_dump ,
596+ mock_subprocess_run ,
597+ ):
598+ # Setup mocks
599+ mock_validator = mock_validator_cls .return_value
600+ mock_validator .validate_aws_credential .return_value = True
601+ mock_kubernetes_client .get_current_context_namespace .return_value = "kubeflow"
602+ mock_get_console_link .return_value = "test-console-link"
603+ mock_subprocess_run .return_value = subprocess .CompletedProcess (
604+ args = ['some_command' ],
605+ returncode = 0 ,
606+ stdout = 'Command executed successfully' ,
607+ stderr = ''
608+ )
609+
610+ expected_default_label_selector_config = {
611+ "preferred" : {"beta.kubernetes.io/instance-type" : ["ml.c5.xlarge" ]},
612+ }
613+
614+ # Capture the yaml.dump calls to inspect the config
615+ configs_dumped = []
616+ def capture_yaml_dump (config , * args , ** kwargs ):
617+ configs_dumped .append (config )
618+ print (f"Dumped config: { config } " )
619+ return None
620+ mock_yaml_dump .side_effect = capture_yaml_dump
621+
622+ # Run the command
623+ result = self .runner .invoke (
624+ start_job ,
625+ [
626+ "--job-name" , "test-job" ,
627+ "--instance-type" , "ml.c5.xlarge" ,
628+ "--image" , "pytorch:1.9.0-cuda11.1-cudnn8-runtime" ,
629+ "--node-count" , "2" ,
630+ "--entry-script" , "/opt/train/src/train.py" ,
631+ "--label_selector" ,
632+ '{"preferred": {"beta.kubernetes.io/instance-type": ["ml.c5.xlarge"]}}' ,
633+ ],
634+ catch_exceptions = False
635+ )
636+
637+ # Verify the command executed successfully
638+ self .assertEqual (result .exit_code , 0 )
639+
640+ # Get the config that was generated
641+ self .assertTrue (len (configs_dumped ) > 0 , "No config was generated" )
642+ config = configs_dumped [0 ] # Get the first config that was dumped
643+
644+ # Verify label_selector configuration
645+ self .assertIn ('cluster' , config )
646+ self .assertIn ('cluster_config' , config ['cluster' ])
647+ self .assertIn ('label_selector' , config ['cluster' ]['cluster_config' ])
648+
649+ self .assertEqual (
650+ config ['cluster' ]['cluster_config' ]['label_selector' ],
651+ expected_default_label_selector_config
652+ )
653+
654+ print (f"Exit code: { result .exit_code } " )
655+ print (f"Output: { result .output } " )
656+ if result .exception :
657+ print (f"Exception: { result .exception } " )
578658
579659 @mock .patch ('subprocess.run' )
580660 @mock .patch ("yaml.dump" )
0 commit comments