Skip to content

Commit 57d2333

Browse files
authored
feature: Added condition to allow eval recipe. (#5298)
* feature: Added condition to allow eval recipe. * change: renamed is_nova_recipe to is_nova_or_eval_recipe
1 parent 7eff90a commit 57d2333

File tree

2 files changed

+39
-22
lines changed

2 files changed

+39
-22
lines changed

src/sagemaker/pytorch/estimator.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,23 @@ def _is_nova_recipe(recipe):
163163
return bool(has_nova_model) or bool(has_distillation)
164164

165165

166+
def _is_eval_recipe(recipe):
167+
"""Check if the recipe is an eval recipe.
168+
169+
An eval recipe is identified by:
170+
1. Having an evaluation section
171+
172+
Args:
173+
recipe (OmegaConf): The loaded recipe configuration
174+
175+
Returns:
176+
bool: True if the recipe is an eval recipe, False otherwise
177+
"""
178+
# Check for a non-empty evaluation section
179+
eval_config = recipe.get("evaluation", {})
180+
return bool(eval_config)
181+
182+
166183
def _recipe_initialize_args(source_dir):
167184
"""Initialize the arguments dictionary for recipe setup.
168185
@@ -526,7 +543,7 @@ def __init__(
526543
:class:`~sagemaker.estimator.Framework` and
527544
:class:`~sagemaker.estimator.EstimatorBase`.
528545
"""
529-
self.is_nova_recipe = False
546+
self.is_nova_or_eval_recipe = False
530547
if training_recipe is not None:
531548
if entry_point is not None:
532549
logger.warning("Argument entry_point will be ignored with training_recipe.")
@@ -538,7 +555,7 @@ def __init__(
538555
training_recipe, recipe_overrides, source_dir, kwargs
539556
)
540557

541-
if self.is_nova_recipe and image_uri is None:
558+
if self.is_nova_or_eval_recipe and image_uri is None:
542559
raise ValueError("Must supply image_uri for nova jobs.")
543560

544561
entry_point = args["entry_point"]
@@ -569,7 +586,7 @@ def __init__(
569586
source_dir,
570587
hyperparameters,
571588
image_uri=image_uri,
572-
is_nova_job=self.is_nova_recipe,
589+
is_nova_job=self.is_nova_or_eval_recipe,
573590
**kwargs,
574591
)
575592

@@ -702,8 +719,8 @@ def fit(
702719
"""
703720
# Handle recipe upload and input channel creation if we have a recipe
704721
if (
705-
self.is_nova_recipe is not None
706-
and self.is_nova_recipe
722+
self.is_nova_or_eval_recipe is not None
723+
and self.is_nova_or_eval_recipe
707724
and hasattr(self, "training_recipe_file")
708725
and self.training_recipe_file
709726
):
@@ -949,7 +966,7 @@ def _device_validate_and_get_type(kwargs, recipe):
949966
if "instance_type" not in kwargs:
950967
raise ValueError("Must pass instance type to estimator when using training recipes.")
951968

952-
if not _is_nova_recipe(recipe) and "trainer" not in recipe:
969+
if not _is_nova_recipe(recipe) and "trainer" not in recipe and not _is_eval_recipe(recipe):
953970
raise ValueError("Supplied recipe does not contain required field trainer.")
954971

955972
instance_type = kwargs["instance_type"].split(".")[1]
@@ -973,15 +990,15 @@ def _device_handle_instance_count(kwargs, recipe):
973990
"""
974991
# Check if instance_count is already provided in kwargs
975992

976-
is_nova = _is_nova_recipe(recipe)
993+
is_nova_or_eval = _is_nova_recipe(recipe) or _is_eval_recipe(recipe)
977994
if "instance_count" in kwargs:
978995
# Warn if there are conflicting configurations in the recipe
979996
if "num_nodes" in recipe.get("trainer", {}):
980997
logger.warning(
981998
"Using instance_count argument to estimator to set number "
982999
"of nodes. Ignoring trainer -> num_nodes in recipe."
9831000
)
984-
if is_nova and "replicas" in recipe.get("run", {}):
1001+
if is_nova_or_eval and "replicas" in recipe.get("run", {}):
9851002
logger.warning(
9861003
"Using instance_count argument to estimator to set number "
9871004
"of nodes. Ignoring run -> replicas in recipe."
@@ -993,7 +1010,7 @@ def _device_handle_instance_count(kwargs, recipe):
9931010
kwargs["instance_count"] = recipe["trainer"]["num_nodes"]
9941011
return
9951012

996-
if is_nova and "run" in recipe and "replicas" in recipe["run"]:
1013+
if is_nova_or_eval and "run" in recipe and "replicas" in recipe["run"]:
9971014
kwargs["instance_count"] = recipe["run"]["replicas"]
9981015
return
9991016

@@ -1137,8 +1154,8 @@ def _setup_for_training_recipe(self, training_recipe, recipe_overrides, source_d
11371154
# Merge with overrides
11381155
recipe = OmegaConf.merge(recipe, recipe_overrides)
11391156

1140-
self.is_nova_recipe = _is_nova_recipe(recipe)
1141-
if self.is_nova_recipe:
1157+
self.is_nova_or_eval_recipe = _is_nova_recipe(recipe) or _is_eval_recipe(recipe)
1158+
if self.is_nova_or_eval_recipe:
11421159
return self._setup_for_nova_recipe(
11431160
recipe,
11441161
recipe_name,

tests/unit/test_pytorch_nova.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def test_setup_for_nova_recipe_with_model_name(mock_resolve_save, sagemaker_sess
138138
)
139139

140140
# Check that the Nova recipe was correctly identified
141-
assert pytorch.is_nova_recipe is True
141+
assert pytorch.is_nova_or_eval_recipe is True
142142

143143
# Verify _setup_for_nova_recipe was called
144144
mock_nova_setup.assert_called_once()
@@ -194,7 +194,7 @@ def test_setup_for_nova_recipe_with_s3_path(mock_resolve_save, sagemaker_session
194194
)
195195

196196
# Check that the Nova recipe was correctly identified
197-
assert pytorch.is_nova_recipe is True
197+
assert pytorch.is_nova_or_eval_recipe is True
198198

199199
# Verify _setup_for_nova_recipe was called
200200
mock_nova_setup.assert_called_once()
@@ -326,7 +326,7 @@ def test_upload_recipe_to_s3(mock_time, mock_recipe_load, sagemaker_session):
326326
)
327327

328328
# Set Nova recipe attributes
329-
pytorch.is_nova_recipe = True
329+
pytorch.is_nova_or_eval_recipe = True
330330

331331
# Create a temporary file to use as the recipe file
332332
with tempfile.NamedTemporaryFile(suffix=".yaml") as temp_file:
@@ -369,7 +369,7 @@ def test_recipe_resolve_and_save(
369369
)
370370

371371
# Set Nova recipe attributes
372-
pytorch.is_nova_recipe = True
372+
pytorch.is_nova_or_eval_recipe = True
373373

374374
# Mock the temporary file
375375
mock_temp_file_instance = Mock()
@@ -421,7 +421,7 @@ def test_fit_with_nova_recipe_s3_upload(mock_framework_fit, mock_recipe_load, sa
421421
)
422422

423423
# Set Nova recipe attributes
424-
pytorch.is_nova_recipe = True
424+
pytorch.is_nova_or_eval_recipe = True
425425
pytorch.training_recipe_file = temp_file
426426

427427
# Mock the _upload_recipe_to_s3 method
@@ -473,7 +473,7 @@ def test_fit_with_nova_recipe_and_inputs(
473473
)
474474

475475
# Set Nova recipe attributes
476-
pytorch.is_nova_recipe = True
476+
pytorch.is_nova_or_eval_recipe = True
477477
pytorch.training_recipe_file = temp_file
478478

479479
# Create training inputs
@@ -559,7 +559,7 @@ def test_fit_with_nova_recipe(
559559
)
560560

561561
# Set Nova recipe attributes
562-
pytorch.is_nova_recipe = True
562+
pytorch.is_nova_or_eval_recipe = True
563563
pytorch.training_recipe_file = temp_file
564564

565565
# Mock the upload_recipe_to_s3 method
@@ -642,7 +642,7 @@ def test_framework_set_hyperparameters_non_nova():
642642
py_version="py3",
643643
image_uri=IMAGE_URI,
644644
)
645-
framework.is_nova_recipe = False
645+
framework.is_nova_or_eval_recipe = False
646646

647647
# Add hyperparameters
648648
framework.set_hyperparameters(string_param="string_value", int_param=42, bool_param=True)
@@ -719,7 +719,7 @@ def test_setup_for_nova_recipe_with_evaluation_lambda(mock_resolve_save, sagemak
719719
)
720720

721721
# Check that the Nova recipe was correctly identified
722-
assert pytorch.is_nova_recipe is True
722+
assert pytorch.is_nova_or_eval_recipe is True
723723

724724
# Verify that eval_lambda_arn hyperparameter was set correctly
725725
assert (
@@ -780,7 +780,7 @@ def test_setup_for_nova_recipe_with_distillation(mock_resolve_save, sagemaker_se
780780
)
781781

782782
# Check that the Nova recipe was correctly identified
783-
assert pytorch.is_nova_recipe is True
783+
assert pytorch.is_nova_or_eval_recipe is True
784784

785785
# Verify _setup_for_nova_recipe was called
786786
mock_nova_setup.assert_called_once()
@@ -828,7 +828,7 @@ def test_setup_for_nova_recipe_sets_model_type(mock_resolve_save, sagemaker_sess
828828
)
829829

830830
# Check that the Nova recipe was correctly identified
831-
assert pytorch.is_nova_recipe is True
831+
assert pytorch.is_nova_or_eval_recipe is True
832832

833833
# Verify that model_type hyperparameter was set correctly
834834
assert pytorch._hyperparameters.get("model_type") == "amazon.nova.llama-2-7b"

0 commit comments

Comments
 (0)