Skip to content

Commit 9f15206

Browse files
committed
feature: Added condition to allow eval recipe.
1 parent 79c8294 commit 9f15206

File tree

1 file changed

+19
-4
lines changed

1 file changed

+19
-4
lines changed

src/sagemaker/pytorch/estimator.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,21 @@ def _is_nova_recipe(recipe):
162162

163163
return bool(has_nova_model) or bool(has_distillation)
164164

165+
def _is_eval_recipe(recipe):
166+
"""Check if the recipe is an eval recipe.
167+
168+
An eval recipe is identified by:
169+
1. Having an evaluation section
170+
171+
Args:
172+
recipe (OmegaConf): The loaded recipe configuration
173+
174+
Returns:
175+
bool: True if the recipe is an eval recipe, False otherwise
176+
"""
177+
# Check for an evaluation section
178+
eval_config = recipe.get("evaluation", {})
179+
return bool(eval_config)
165180

166181
def _recipe_initialize_args(source_dir):
167182
"""Initialize the arguments dictionary for recipe setup.
@@ -949,7 +964,7 @@ def _device_validate_and_get_type(kwargs, recipe):
949964
if "instance_type" not in kwargs:
950965
raise ValueError("Must pass instance type to estimator when using training recipes.")
951966

952-
if not _is_nova_recipe(recipe) and "trainer" not in recipe:
967+
if not _is_nova_recipe(recipe) and "trainer" not in recipe and not _is_eval_recipe(recipe):
953968
raise ValueError("Supplied recipe does not contain required field trainer.")
954969

955970
instance_type = kwargs["instance_type"].split(".")[1]
@@ -973,15 +988,15 @@ def _device_handle_instance_count(kwargs, recipe):
973988
"""
974989
# Check if instance_count is already provided in kwargs
975990

976-
is_nova = _is_nova_recipe(recipe)
991+
is_nova_or_eval = _is_nova_recipe(recipe) or _is_eval_recipe(recipe)
977992
if "instance_count" in kwargs:
978993
# Warn if there are conflicting configurations in the recipe
979994
if "num_nodes" in recipe.get("trainer", {}):
980995
logger.warning(
981996
"Using instance_count argument to estimator to set number "
982997
"of nodes. Ignoring trainer -> num_nodes in recipe."
983998
)
984-
if is_nova and "replicas" in recipe.get("run", {}):
999+
if is_nova_or_eval and "replicas" in recipe.get("run", {}):
9851000
logger.warning(
9861001
"Using instance_count argument to estimator to set number "
9871002
"of nodes. Ignoring run -> replicas in recipe."
@@ -993,7 +1008,7 @@ def _device_handle_instance_count(kwargs, recipe):
9931008
kwargs["instance_count"] = recipe["trainer"]["num_nodes"]
9941009
return
9951010

996-
if is_nova and "run" in recipe and "replicas" in recipe["run"]:
1011+
if is_nova_or_eval and "run" in recipe and "replicas" in recipe["run"]:
9971012
kwargs["instance_count"] = recipe["run"]["replicas"]
9981013
return
9991014

0 commit comments

Comments
 (0)