Skip to content

Commit f1b1401

Browse files
committed
feature: Added condition to allow eval recipe.
1 parent 7eff90a commit f1b1401

File tree

1 file changed

+22
-5
lines changed

1 file changed

+22
-5
lines changed

src/sagemaker/pytorch/estimator.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,23 @@ def _is_nova_recipe(recipe):
163163
return bool(has_nova_model) or bool(has_distillation)
164164

165165

166+
def _is_eval_recipe(recipe):
167+
"""Check if the recipe is an eval recipe.
168+
169+
An eval recipe is identified by:
170+
1. Having a evaluation section
171+
172+
Args:
173+
recipe (OmegaConf): The loaded recipe configuration
174+
175+
Returns:
176+
bool: True if the recipe is an eval recipe, False otherwise
177+
"""
178+
# Check for eval model
179+
eval_config = recipe.get("evaluation", {})
180+
return bool(eval_config)
181+
182+
166183
def _recipe_initialize_args(source_dir):
167184
"""Initialize the arguments dictionary for recipe setup.
168185
@@ -949,7 +966,7 @@ def _device_validate_and_get_type(kwargs, recipe):
949966
if "instance_type" not in kwargs:
950967
raise ValueError("Must pass instance type to estimator when using training recipes.")
951968

952-
if not _is_nova_recipe(recipe) and "trainer" not in recipe:
969+
if not _is_nova_recipe(recipe) and "trainer" not in recipe and not _is_eval_recipe(recipe):
953970
raise ValueError("Supplied recipe does not contain required field trainer.")
954971

955972
instance_type = kwargs["instance_type"].split(".")[1]
@@ -973,15 +990,15 @@ def _device_handle_instance_count(kwargs, recipe):
973990
"""
974991
# Check if instance_count is already provided in kwargs
975992

976-
is_nova = _is_nova_recipe(recipe)
993+
is_nova_or_eval = _is_nova_recipe(recipe) or _is_eval_recipe(recipe)
977994
if "instance_count" in kwargs:
978995
# Warn if there are conflicting configurations in the recipe
979996
if "num_nodes" in recipe.get("trainer", {}):
980997
logger.warning(
981998
"Using instance_count argument to estimator to set number "
982999
"of nodes. Ignoring trainer -> num_nodes in recipe."
9831000
)
984-
if is_nova and "replicas" in recipe.get("run", {}):
1001+
if is_nova_or_eval and "replicas" in recipe.get("run", {}):
9851002
logger.warning(
9861003
"Using instance_count argument to estimator to set number "
9871004
"of nodes. Ignoring run -> replicas in recipe."
@@ -993,7 +1010,7 @@ def _device_handle_instance_count(kwargs, recipe):
9931010
kwargs["instance_count"] = recipe["trainer"]["num_nodes"]
9941011
return
9951012

996-
if is_nova and "run" in recipe and "replicas" in recipe["run"]:
1013+
if is_nova_or_eval and "run" in recipe and "replicas" in recipe["run"]:
9971014
kwargs["instance_count"] = recipe["run"]["replicas"]
9981015
return
9991016

@@ -1137,7 +1154,7 @@ def _setup_for_training_recipe(self, training_recipe, recipe_overrides, source_d
11371154
# Merge with overrides
11381155
recipe = OmegaConf.merge(recipe, recipe_overrides)
11391156

1140-
self.is_nova_recipe = _is_nova_recipe(recipe)
1157+
self.is_nova_recipe = _is_nova_recipe(recipe) or _is_eval_recipe(recipe)
11411158
if self.is_nova_recipe:
11421159
return self._setup_for_nova_recipe(
11431160
recipe,

0 commit comments

Comments
 (0)