@@ -162,6 +162,21 @@ def _is_nova_recipe(recipe):
162162
163163 return bool (has_nova_model ) or bool (has_distillation )
164164
165+ def _is_eval_recipe (recipe ):
166+ """Check if the recipe is an eval recipe.
167+
168+ An eval recipe is identified by:
169+ 1. Having a evaluation section
170+
171+ Args:
172+ recipe (OmegaConf): The loaded recipe configuration
173+
174+ Returns:
175+ bool: True if the recipe is an eval recipe, False otherwise
176+ """
177+ # Check for eval model
178+ eval_config = recipe .get ("evaluation" , {})
179+ return bool (eval_config )
165180
166181def _recipe_initialize_args (source_dir ):
167182 """Initialize the arguments dictionary for recipe setup.
@@ -949,7 +964,7 @@ def _device_validate_and_get_type(kwargs, recipe):
949964 if "instance_type" not in kwargs :
950965 raise ValueError ("Must pass instance type to estimator when using training recipes." )
951966
952- if not _is_nova_recipe (recipe ) and "trainer" not in recipe :
967+ if not _is_nova_recipe (recipe ) and "trainer" not in recipe and not _is_eval_recipe ( recipe ) :
953968 raise ValueError ("Supplied recipe does not contain required field trainer." )
954969
955970 instance_type = kwargs ["instance_type" ].split ("." )[1 ]
@@ -973,15 +988,15 @@ def _device_handle_instance_count(kwargs, recipe):
973988 """
974989 # Check if instance_count is already provided in kwargs
975990
976- is_nova = _is_nova_recipe (recipe )
991+ is_nova_or_eval = _is_nova_recipe ( recipe ) or _is_eval_recipe (recipe )
977992 if "instance_count" in kwargs :
978993 # Warn if there are conflicting configurations in the recipe
979994 if "num_nodes" in recipe .get ("trainer" , {}):
980995 logger .warning (
981996 "Using instance_count argument to estimator to set number "
982997 "of nodes. Ignoring trainer -> num_nodes in recipe."
983998 )
984- if is_nova and "replicas" in recipe .get ("run" , {}):
999+ if is_nova_or_eval and "replicas" in recipe .get ("run" , {}):
9851000 logger .warning (
9861001 "Using instance_count argument to estimator to set number "
9871002 "of nodes. Ignoring run -> replicas in recipe."
@@ -993,7 +1008,7 @@ def _device_handle_instance_count(kwargs, recipe):
9931008 kwargs ["instance_count" ] = recipe ["trainer" ]["num_nodes" ]
9941009 return
9951010
996- if is_nova and "run" in recipe and "replicas" in recipe ["run" ]:
1011+ if is_nova_or_eval and "run" in recipe and "replicas" in recipe ["run" ]:
9971012 kwargs ["instance_count" ] = recipe ["run" ]["replicas" ]
9981013 return
9991014
0 commit comments