@@ -163,6 +163,23 @@ def _is_nova_recipe(recipe):
163163 return bool (has_nova_model ) or bool (has_distillation )
164164
165165
166+ def _is_eval_recipe (recipe ):
167+ """Check if the recipe is an eval recipe.
168+
169+ An eval recipe is identified by:
170+ 1. Having a evaluation section
171+
172+ Args:
173+ recipe (OmegaConf): The loaded recipe configuration
174+
175+ Returns:
176+ bool: True if the recipe is an eval recipe, False otherwise
177+ """
178+ # Check for eval model
179+ eval_config = recipe .get ("evaluation" , {})
180+ return bool (eval_config )
181+
182+
166183def _recipe_initialize_args (source_dir ):
167184 """Initialize the arguments dictionary for recipe setup.
168185
@@ -526,7 +543,7 @@ def __init__(
526543 :class:`~sagemaker.estimator.Framework` and
527544 :class:`~sagemaker.estimator.EstimatorBase`.
528545 """
529- self .is_nova_recipe = False
546+ self .is_nova_or_eval_recipe = False
530547 if training_recipe is not None :
531548 if entry_point is not None :
532549 logger .warning ("Argument entry_point will be ignored with training_recipe." )
@@ -538,7 +555,7 @@ def __init__(
538555 training_recipe , recipe_overrides , source_dir , kwargs
539556 )
540557
541- if self .is_nova_recipe and image_uri is None :
558+ if self .is_nova_or_eval_recipe and image_uri is None :
542559 raise ValueError ("Must supply image_uri for nova jobs." )
543560
544561 entry_point = args ["entry_point" ]
@@ -569,7 +586,7 @@ def __init__(
569586 source_dir ,
570587 hyperparameters ,
571588 image_uri = image_uri ,
572- is_nova_job = self .is_nova_recipe ,
589+ is_nova_job = self .is_nova_or_eval_recipe ,
573590 ** kwargs ,
574591 )
575592
@@ -702,8 +719,8 @@ def fit(
702719 """
703720 # Handle recipe upload and input channel creation if we have a recipe
704721 if (
705- self .is_nova_recipe is not None
706- and self .is_nova_recipe
722+ self .is_nova_or_eval_recipe is not None
723+ and self .is_nova_or_eval_recipe
707724 and hasattr (self , "training_recipe_file" )
708725 and self .training_recipe_file
709726 ):
@@ -949,7 +966,7 @@ def _device_validate_and_get_type(kwargs, recipe):
949966 if "instance_type" not in kwargs :
950967 raise ValueError ("Must pass instance type to estimator when using training recipes." )
951968
952- if not _is_nova_recipe (recipe ) and "trainer" not in recipe :
969+ if not _is_nova_recipe (recipe ) and "trainer" not in recipe and not _is_eval_recipe ( recipe ) :
953970 raise ValueError ("Supplied recipe does not contain required field trainer." )
954971
955972 instance_type = kwargs ["instance_type" ].split ("." )[1 ]
@@ -973,15 +990,15 @@ def _device_handle_instance_count(kwargs, recipe):
973990 """
974991 # Check if instance_count is already provided in kwargs
975992
976- is_nova = _is_nova_recipe (recipe )
993+ is_nova_or_eval = _is_nova_recipe ( recipe ) or _is_eval_recipe (recipe )
977994 if "instance_count" in kwargs :
978995 # Warn if there are conflicting configurations in the recipe
979996 if "num_nodes" in recipe .get ("trainer" , {}):
980997 logger .warning (
981998 "Using instance_count argument to estimator to set number "
982999 "of nodes. Ignoring trainer -> num_nodes in recipe."
9831000 )
984- if is_nova and "replicas" in recipe .get ("run" , {}):
1001+ if is_nova_or_eval and "replicas" in recipe .get ("run" , {}):
9851002 logger .warning (
9861003 "Using instance_count argument to estimator to set number "
9871004 "of nodes. Ignoring run -> replicas in recipe."
@@ -993,7 +1010,7 @@ def _device_handle_instance_count(kwargs, recipe):
9931010 kwargs ["instance_count" ] = recipe ["trainer" ]["num_nodes" ]
9941011 return
9951012
996- if is_nova and "run" in recipe and "replicas" in recipe ["run" ]:
1013+ if is_nova_or_eval and "run" in recipe and "replicas" in recipe ["run" ]:
9971014 kwargs ["instance_count" ] = recipe ["run" ]["replicas" ]
9981015 return
9991016
@@ -1137,8 +1154,8 @@ def _setup_for_training_recipe(self, training_recipe, recipe_overrides, source_d
11371154 # Merge with overrides
11381155 recipe = OmegaConf .merge (recipe , recipe_overrides )
11391156
1140- self .is_nova_recipe = _is_nova_recipe (recipe )
1141- if self .is_nova_recipe :
1157+ self .is_nova_or_eval_recipe = _is_nova_recipe ( recipe ) or _is_eval_recipe (recipe )
1158+ if self .is_nova_or_eval_recipe :
11421159 return self ._setup_for_nova_recipe (
11431160 recipe ,
11441161 recipe_name ,
0 commit comments