Skip to content

Commit ef3bf7b

Browse files
author
Malav Shastri
committed
Add test for pytorch reward lambda
1 parent 9059c25 commit ef3bf7b

File tree

1 file changed

+78
-0
lines changed

1 file changed

+78
-0
lines changed

tests/unit/test_pytorch_nova.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -832,3 +832,81 @@ def test_setup_for_nova_recipe_sets_model_type(mock_resolve_save, sagemaker_sess
832832

833833
# Verify that model_type hyperparameter was set correctly
834834
assert pytorch._hyperparameters.get("model_type") == "amazon.nova.llama-2-7b"
835+
836+
837+
@patch("sagemaker.pytorch.estimator.PyTorch._recipe_resolve_and_save")
def test_setup_for_nova_recipe_with_reward_lambda(mock_resolve_save, sagemaker_session):
    """Test that _setup_for_nova_recipe correctly handles reward lambda configuration."""
    # Recipe carrying a reward lambda ARN in its run section.
    expected_arn = "arn:aws:lambda:us-west-2:123456789012:function:reward-function"
    run_section = {
        "model_type": "amazon.nova.foobar3",
        "model_name_or_path": "foobar/foobar-3-8b",
        "reward_lambda_arn": expected_arn,
        "replicas": 1,
    }
    recipe = OmegaConf.create({"run": run_section})

    with patch(
        "sagemaker.pytorch.estimator.PyTorch._recipe_load",
        return_value=("nova_recipe", recipe),
    ):
        mock_resolve_save.return_value = recipe

        estimator = PyTorch(
            training_recipe="nova_recipe",
            role=ROLE,
            sagemaker_session=sagemaker_session,
            instance_count=INSTANCE_COUNT,
            instance_type=INSTANCE_TYPE_GPU,
            image_uri=IMAGE_URI,
            framework_version="1.13.1",
            py_version="py3",
        )

        # The recipe must be recognized as a Nova recipe.
        assert estimator.is_nova_or_eval_recipe is True

        # The reward lambda ARN from the recipe should be copied into hyperparameters.
        assert estimator._hyperparameters.get("reward_lambda_arn") == expected_arn
876+
877+
878+
@patch("sagemaker.pytorch.estimator.PyTorch._recipe_resolve_and_save")
def test_setup_for_nova_recipe_without_reward_lambda(mock_resolve_save, sagemaker_session):
    """Test that _setup_for_nova_recipe does not set reward_lambda_arn when not present."""
    # Recipe whose run section deliberately omits any reward lambda config.
    recipe = OmegaConf.create(
        {
            "run": {
                "model_type": "amazon.nova.foobar3",
                "model_name_or_path": "foobar/foobar-3-8b",
                "replicas": 1,
            }
        }
    )

    load_target = "sagemaker.pytorch.estimator.PyTorch._recipe_load"
    with patch(load_target, return_value=("nova_recipe", recipe)):
        mock_resolve_save.return_value = recipe

        estimator = PyTorch(
            training_recipe="nova_recipe",
            role=ROLE,
            sagemaker_session=sagemaker_session,
            instance_count=INSTANCE_COUNT,
            instance_type=INSTANCE_TYPE_GPU,
            image_uri=IMAGE_URI,
            framework_version="1.13.1",
            py_version="py3",
        )

        # The recipe must still be recognized as a Nova recipe.
        assert estimator.is_nova_or_eval_recipe is True

        # No reward lambda in the recipe means no such hyperparameter is set.
        assert "reward_lambda_arn" not in estimator._hyperparameters

0 commit comments

Comments (0)