@@ -832,3 +832,81 @@ def test_setup_for_nova_recipe_sets_model_type(mock_resolve_save, sagemaker_sess
832832
833833 # Verify that model_type hyperparameter was set correctly
834834 assert pytorch ._hyperparameters .get ("model_type" ) == "amazon.nova.llama-2-7b"
835+
836+
837+ @patch ("sagemaker.pytorch.estimator.PyTorch._recipe_resolve_and_save" )
838+ def test_setup_for_nova_recipe_with_reward_lambda (mock_resolve_save , sagemaker_session ):
839+ """Test that _setup_for_nova_recipe correctly handles reward lambda configuration."""
840+ # Create a mock recipe with reward lambda config
841+ recipe = OmegaConf .create (
842+ {
843+ "run" : {
844+ "model_type" : "amazon.nova.foobar3" ,
845+ "model_name_or_path" : "foobar/foobar-3-8b" ,
846+ "reward_lambda_arn" : "arn:aws:lambda:us-west-2:123456789012:function:reward-function" ,
847+ "replicas" : 1 ,
848+ },
849+ }
850+ )
851+
852+ with patch (
853+ "sagemaker.pytorch.estimator.PyTorch._recipe_load" , return_value = ("nova_recipe" , recipe )
854+ ):
855+ mock_resolve_save .return_value = recipe
856+
857+ pytorch = PyTorch (
858+ training_recipe = "nova_recipe" ,
859+ role = ROLE ,
860+ sagemaker_session = sagemaker_session ,
861+ instance_count = INSTANCE_COUNT ,
862+ instance_type = INSTANCE_TYPE_GPU ,
863+ image_uri = IMAGE_URI ,
864+ framework_version = "1.13.1" ,
865+ py_version = "py3" ,
866+ )
867+
868+ # Check that the Nova recipe was correctly identified
869+ assert pytorch .is_nova_or_eval_recipe is True
870+
871+ # Verify that reward_lambda_arn hyperparameter was set correctly
872+ assert (
873+ pytorch ._hyperparameters .get ("reward_lambda_arn" )
874+ == "arn:aws:lambda:us-west-2:123456789012:function:reward-function"
875+ )
876+
877+
878+ @patch ("sagemaker.pytorch.estimator.PyTorch._recipe_resolve_and_save" )
879+ def test_setup_for_nova_recipe_without_reward_lambda (mock_resolve_save , sagemaker_session ):
880+ """Test that _setup_for_nova_recipe does not set reward_lambda_arn when not present."""
881+ # Create a mock recipe without reward lambda config
882+ recipe = OmegaConf .create (
883+ {
884+ "run" : {
885+ "model_type" : "amazon.nova.foobar3" ,
886+ "model_name_or_path" : "foobar/foobar-3-8b" ,
887+ "replicas" : 1 ,
888+ },
889+ }
890+ )
891+
892+ with patch (
893+ "sagemaker.pytorch.estimator.PyTorch._recipe_load" , return_value = ("nova_recipe" , recipe )
894+ ):
895+ mock_resolve_save .return_value = recipe
896+
897+ pytorch = PyTorch (
898+ training_recipe = "nova_recipe" ,
899+ role = ROLE ,
900+ sagemaker_session = sagemaker_session ,
901+ instance_count = INSTANCE_COUNT ,
902+ instance_type = INSTANCE_TYPE_GPU ,
903+ image_uri = IMAGE_URI ,
904+ framework_version = "1.13.1" ,
905+ py_version = "py3" ,
906+ )
907+
908+ # Check that the Nova recipe was correctly identified
909+ assert pytorch .is_nova_or_eval_recipe is True
910+
911+ # Verify that reward_lambda_arn hyperparameter was not set
912+ assert "reward_lambda_arn" not in pytorch ._hyperparameters
0 commit comments