@@ -795,3 +795,40 @@ def test_setup_for_nova_recipe_with_distillation(mock_resolve_save, sagemaker_session):
         pytorch._hyperparameters.get("role_arn")
         == "arn:aws:iam::123456789012:role/SageMakerRole"
     )
+
+
+@patch("sagemaker.pytorch.estimator.PyTorch._recipe_resolve_and_save")
+def test_setup_for_nova_recipe_sets_model_type(mock_resolve_save, sagemaker_session):
+    """Test that _setup_for_nova_recipe correctly sets the model_type hyperparameter."""
+    # Create a mock Nova recipe with model_type
+    recipe = OmegaConf.create(
+        {
+            "run": {
+                "model_type": "amazon.nova.llama-2-7b",
+                "model_name_or_path": "llama/llama-2-7b",
+                "replicas": 1,
+            }
+        }
+    )
+
+    with patch(
+        "sagemaker.pytorch.estimator.PyTorch._recipe_load", return_value=("nova_recipe", recipe)
+    ):
+        mock_resolve_save.return_value = recipe
+
+        pytorch = PyTorch(
+            training_recipe="nova_recipe",
+            role=ROLE,
+            sagemaker_session=sagemaker_session,
+            instance_count=INSTANCE_COUNT,
+            instance_type=INSTANCE_TYPE_GPU,
+            image_uri=IMAGE_URI,
+            framework_version="1.13.1",
+            py_version="py3",
+        )
+
+        # Check that the Nova recipe was correctly identified
+        assert pytorch.is_nova_recipe is True
+
+        # Verify that the model_type hyperparameter was set correctly
+        assert pytorch._hyperparameters.get("model_type") == "amazon.nova.llama-2-7b"
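For reference, the behavior this test pins down can be reduced to a small standalone sketch. This is a hypothetical illustration of the expected effect (copying `run.model_type` from a Nova recipe into the estimator's hyperparameters), not the actual sagemaker SDK implementation; the helper name `copy_model_type_sketch` is invented for the example.

```python
from omegaconf import OmegaConf


# Hypothetical sketch, not the sagemaker SDK code: the test above only asserts
# the observable effect, i.e. that run.model_type from a Nova recipe ends up in
# the estimator's _hyperparameters dict.
def copy_model_type_sketch(recipe, hyperparameters):
    # OmegaConf.select returns None when the key path is absent.
    model_type = OmegaConf.select(recipe, "run.model_type")
    if model_type is not None:
        hyperparameters["model_type"] = model_type
    return hyperparameters


# Same recipe shape as the one built in the test.
recipe = OmegaConf.create(
    {"run": {"model_type": "amazon.nova.llama-2-7b", "model_name_or_path": "llama/llama-2-7b"}}
)
assert copy_model_type_sketch(recipe, {}) == {"model_type": "amazon.nova.llama-2-7b"}
```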