Skip to content

Commit 9059c25

Browse files
author
Malav Shastri
committed
feat: Extract reward_lambda_arn from Nova recipes to training job hyperparameters
1 parent 2e2b27c commit 9059c25

File tree

3 files changed

+66
-0
lines changed

3 files changed

+66
-0
lines changed

src/sagemaker/modules/train/sm_recipes/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,12 @@ def _get_args_from_nova_recipe(
312312
if lambda_arn:
313313
args["hyperparameters"]["eval_lambda_arn"] = lambda_arn
314314

315+
# Handle reward lambda configuration
316+
run_config = recipe.get("run", {})
317+
reward_lambda_arn = run_config.get("reward_lambda_arn", "")
318+
if reward_lambda_arn:
319+
args["hyperparameters"]["reward_lambda_arn"] = reward_lambda_arn
320+
315321
_register_custom_resolvers()
316322

317323
# Resolve Final Recipe

src/sagemaker/pytorch/estimator.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1251,6 +1251,12 @@ def _setup_for_nova_recipe(
12511251
if lambda_arn:
12521252
args["hyperparameters"]["eval_lambda_arn"] = lambda_arn
12531253

1254+
# Handle reward lambda configuration
1255+
run_config = recipe.get("run", {})
1256+
reward_lambda_arn = run_config.get("reward_lambda_arn", "")
1257+
if reward_lambda_arn:
1258+
args["hyperparameters"]["reward_lambda_arn"] = reward_lambda_arn
1259+
12541260
# Resolve and save the final recipe
12551261
self._recipe_resolve_and_save(recipe, recipe_name, args["source_dir"])
12561262

tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -478,3 +478,57 @@ def test_get_args_from_nova_recipe_with_evaluation(test_case):
478478
recipe=recipe, compute=test_case["compute"], role=test_case["role"]
479479
)
480480
assert args == test_case["expected_args"]
481+
482+
483+
@pytest.mark.parametrize(
484+
"test_case",
485+
[
486+
{
487+
"recipe": {
488+
"run": {
489+
"model_type": "amazon.nova",
490+
"model_name_or_path": "dummy-test",
491+
"reward_lambda_arn": "arn:aws:lambda:us-east-1:123456789012:function:MyRewardLambdaFunction",
492+
},
493+
},
494+
"compute": Compute(instance_type="ml.m5.xlarge", instance_count=2),
495+
"role": "arn:aws:iam::123456789012:role/SageMakerRole",
496+
"expected_args": {
497+
"compute": Compute(instance_type="ml.m5.xlarge", instance_count=2),
498+
"hyperparameters": {
499+
"base_model": "dummy-test",
500+
"reward_lambda_arn": "arn:aws:lambda:us-east-1:123456789012:function:MyRewardLambdaFunction",
501+
},
502+
"training_image": None,
503+
"source_code": None,
504+
"distributed": None,
505+
},
506+
},
507+
{
508+
"recipe": {
509+
"run": {
510+
"model_type": "amazon.nova",
511+
"model_name_or_path": "dummy-test",
512+
# No reward_lambda_arn - should not be in hyperparameters
513+
},
514+
},
515+
"compute": Compute(instance_type="ml.m5.xlarge", instance_count=2),
516+
"role": "arn:aws:iam::123456789012:role/SageMakerRole",
517+
"expected_args": {
518+
"compute": Compute(instance_type="ml.m5.xlarge", instance_count=2),
519+
"hyperparameters": {
520+
"base_model": "dummy-test",
521+
},
522+
"training_image": None,
523+
"source_code": None,
524+
"distributed": None,
525+
},
526+
},
527+
],
528+
)
529+
def test_get_args_from_nova_recipe_with_reward_lambda(test_case):
530+
recipe = OmegaConf.create(test_case["recipe"])
531+
args, _ = _get_args_from_nova_recipe(
532+
recipe=recipe, compute=test_case["compute"], role=test_case["role"]
533+
)
534+
assert args == test_case["expected_args"]

0 commit comments

Comments
 (0)