diff --git a/.gitignore b/.gitignore index fc07847fba..3d90b52e01 100644 --- a/.gitignore +++ b/.gitignore @@ -37,4 +37,5 @@ src/sagemaker/modules/train/container_drivers/sourcecode.json src/sagemaker/modules/train/container_drivers/distributed.json tests/data/**/_repack_model.py tests/data/experiment/sagemaker-dev-1.0.tar.gz -src/sagemaker/serve/tmp_workspace \ No newline at end of file +src/sagemaker/serve/tmp_workspace +test-examples \ No newline at end of file diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 16e6ac1cd0..9b4beae5c4 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -905,6 +905,30 @@ def _json_encode_hyperparameters(hyperparameters: Dict[str, Any]) -> Dict[str, A } return hyperparameters + @staticmethod + def _nova_encode_hyperparameters(hyperparameters: Dict[str, Any]) -> Dict[str, Any]: + """Applies JSON encoding for Nova job hyperparameters, preserving string values. + + For Nova jobs, string values should not be JSON-encoded. + + Args: + hyperparameters (dict): Dictionary of hyperparameters. + + Returns: + dict: Dictionary with encoded hyperparameters. + """ + current_hyperparameters = hyperparameters + if current_hyperparameters is not None: + hyperparameters = {} + for k, v in current_hyperparameters.items(): + if is_pipeline_variable(v): + hyperparameters[str(k)] = v.to_string() + elif isinstance(v, str): + hyperparameters[str(k)] = v + else: + hyperparameters[str(k)] = json.dumps(v) + return hyperparameters + def _prepare_for_training(self, job_name=None): """Set any values in the estimator that need to be set before training. @@ -938,7 +962,11 @@ def _prepare_for_training(self, job_name=None): self.source_dir = updated_paths["source_dir"] self.dependencies = updated_paths["dependencies"] - if self.source_dir or self.entry_point or self.dependencies: + if ( + self.source_dir + or self.entry_point + or (self.dependencies and len(self.dependencies) > 0) + ): # validate source dir will raise a ValueError if there is something wrong with # the source directory. We are intentionally not handling it because this is a # critical error. @@ -3579,7 +3607,11 @@ def __init__( git_config=git_config, enable_network_isolation=enable_network_isolation, ) - if not is_pipeline_variable(entry_point) and entry_point.startswith("s3://"): + if ( + not is_pipeline_variable(entry_point) + and entry_point is not None + and entry_point.startswith("s3://") + ): raise ValueError( "Invalid entry point script: {}. Must be a path to a local file.".format( entry_point @@ -3599,6 +3631,7 @@ def __init__( self.checkpoint_s3_uri = checkpoint_s3_uri self.checkpoint_local_path = checkpoint_local_path self.enable_sagemaker_metrics = enable_sagemaker_metrics + self.is_nova_job = kwargs.get("is_nova_job", False) def _prepare_for_training(self, job_name=None): """Set hyperparameters needed for training. This method will also validate ``source_dir``. @@ -3713,7 +3746,10 @@ def _model_entry_point(self): def set_hyperparameters(self, **kwargs): """Escapes the dict argument as JSON, updates the private hyperparameter attribute.""" - self._hyperparameters.update(EstimatorBase._json_encode_hyperparameters(kwargs)) + if self.is_nova_job: + self._hyperparameters.update(EstimatorBase._nova_encode_hyperparameters(kwargs)) + else: + self._hyperparameters.update(EstimatorBase._json_encode_hyperparameters(kwargs)) def hyperparameters(self): """Returns the hyperparameters as a dictionary to use for training. @@ -3724,7 +3760,10 @@ def hyperparameters(self): Returns: dict[str, str]: The hyperparameters. """ - return EstimatorBase._json_encode_hyperparameters(self._hyperparameters) + if self.is_nova_job: + return EstimatorBase._nova_encode_hyperparameters(self._hyperparameters) + else: + return EstimatorBase._json_encode_hyperparameters(self._hyperparameters) @classmethod def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None): diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index 234f0c61fa..4a00b2dbc1 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -1063,7 +1063,7 @@ def validate_torch_distributed_distribution( ) # Check entry point type - if not entry_point.endswith(".py"): + if entry_point is not None and not entry_point.endswith(".py"): err_msg += ( "Unsupported entry point type for the distribution torch_distributed.\n" "Only python programs (*.py) are supported." diff --git a/src/sagemaker/modules/constants.py b/src/sagemaker/modules/constants.py index e64d85367d..eaf9d131ef 100644 --- a/src/sagemaker/modules/constants.py +++ b/src/sagemaker/modules/constants.py @@ -25,6 +25,10 @@ os.path.dirname(os.path.abspath(__file__)), "train/container_drivers" ) +SM_RECIPE = "recipe" +SM_RECIPE_YAML = "recipe.yaml" +SM_RECIPE_CONTAINER_PATH = f"/opt/ml/input/data/recipe/{SM_RECIPE_YAML}" + SOURCE_CODE_JSON = "sourcecode.json" DISTRIBUTED_JSON = "distributed.json" TRAIN_SCRIPT = "sm_train.sh" diff --git a/src/sagemaker/modules/train/model_trainer.py b/src/sagemaker/modules/train/model_trainer.py index eaabe5972a..24b7922895 100644 --- a/src/sagemaker/modules/train/model_trainer.py +++ b/src/sagemaker/modules/train/model_trainer.py @@ -85,6 +85,9 @@ SM_CODE_CONTAINER_PATH, SM_DRIVERS, SM_DRIVERS_LOCAL_PATH, + SM_RECIPE, + SM_RECIPE_YAML, + SM_RECIPE_CONTAINER_PATH, TRAIN_SCRIPT, DEFAULT_CONTAINER_ENTRYPOINT, DEFAULT_CONTAINER_ARGUMENTS, @@ -100,7 +103,12 @@ from sagemaker.telemetry.telemetry_logging import _telemetry_emitter from sagemaker.telemetry.constants import Feature from sagemaker.modules import logger -from sagemaker.modules.train.sm_recipes.utils import _get_args_from_recipe, _determine_device_type +from sagemaker.modules.train.sm_recipes.utils import ( + _get_args_from_recipe, + _determine_device_type, + _is_nova_recipe, + _load_base_recipe, +) class Mode(Enum): @@ -242,6 +250,7 @@ class ModelTrainer(BaseModel): _remote_debug_config: Optional[RemoteDebugConfig] = PrivateAttr(default=None) _metric_definitions: Optional[List[MetricDefinition]] = PrivateAttr(default=None) + _is_nova_recipe: Optional[bool] = PrivateAttr(default=None) _temp_recipe_train_dir: Optional[TemporaryDirectory] = PrivateAttr(default=None) CONFIGURABLE_ATTRIBUTES: ClassVar[List[str]] = [ @@ -449,6 +458,33 @@ def _validate_source_code(self, source_code: Optional[SourceCode]): + "Must be a valid file within the 'source_dir'.", ) + @staticmethod + def _validate_and_load_hyperparameters_file(hyperparameters_file: str) -> Dict[str, Any]: + """Validate the hyperparameters file.""" + if not os.path.exists(hyperparameters_file): + raise ValueError(f"Hyperparameters file not found: {hyperparameters_file}") + logger.info(f"Loading hyperparameters from file: {hyperparameters_file}") + with open(hyperparameters_file, "r") as f: + contents = f.read() + try: + hyperparameters = json.loads(contents) + logger.debug("Hyperparameters loaded as JSON") + return hyperparameters + except json.JSONDecodeError: + try: + logger.info(f"contents: {contents}") + hyperparameters = yaml.safe_load(contents) + if not isinstance(hyperparameters, dict): + raise ValueError("YAML contents must be a valid mapping") + logger.info(f"hyperparameters: {hyperparameters}") + logger.debug("Hyperparameters loaded as YAML") + return hyperparameters + except (yaml.YAMLError, ValueError): + raise ValueError( + f"Invalid hyperparameters file: {hyperparameters_file}. " + "Must be a valid JSON or YAML file." + ) + def model_post_init(self, __context: Any): """Post init method to perform custom validation and set default values.""" self._validate_training_image_and_algorithm_name(self.training_image, self.algorithm_name) @@ -510,27 +546,9 @@ def model_post_init(self, __context: Any): ) if self.hyperparameters and isinstance(self.hyperparameters, str): - if not os.path.exists(self.hyperparameters): - raise ValueError(f"Hyperparameters file not found: {self.hyperparameters}") - logger.info(f"Loading hyperparameters from file: {self.hyperparameters}") - with open(self.hyperparameters, "r") as f: - contents = f.read() - try: - self.hyperparameters = json.loads(contents) - logger.debug("Hyperparameters loaded as JSON") - except json.JSONDecodeError: - try: - logger.info(f"contents: {contents}") - self.hyperparameters = yaml.safe_load(contents) - if not isinstance(self.hyperparameters, dict): - raise ValueError("YAML contents must be a valid mapping") - logger.info(f"hyperparameters: {self.hyperparameters}") - logger.debug("Hyperparameters loaded as YAML") - except (yaml.YAMLError, ValueError): - raise ValueError( - f"Invalid hyperparameters file: {self.hyperparameters}. " - "Must be a valid JSON or YAML file." - ) + self.hyperparameters = self._validate_and_load_hyperparameters_file( + self.hyperparameters + ) if self.training_mode == Mode.SAGEMAKER_TRAINING_JOB: if self.output_data_config is None: @@ -613,6 +631,22 @@ def train( final_input_data_config = list(existing_channels.values()) + new_channels + if self._is_nova_recipe: + for input_data in final_input_data_config: + if input_data.channel_name == SM_RECIPE: + raise ValueError( + "Cannot use reserved channel name 'recipe' as an input channel name " + " for Nova Recipe" + ) + recipe_file_path = os.path.join(self._temp_recipe_train_dir.name, SM_RECIPE_YAML) + recipe_channel = self.create_input_data_channel( + channel_name=SM_RECIPE, + data_source=recipe_file_path, + key_prefix=input_data_key_prefix, + ) + final_input_data_config.append(recipe_channel) + self.hyperparameters.update({"sagemaker_recipe_local_path": SM_RECIPE_CONTAINER_PATH}) + if final_input_data_config: final_input_data_config = self._get_input_data_config( final_input_data_config, input_data_key_prefix @@ -1005,6 +1039,7 @@ def from_recipe( checkpoint_config: Optional[shapes.CheckpointConfig] = None, training_input_mode: Optional[str] = "File", environment: Optional[Dict[str, str]] = None, + hyperparameters: Optional[Union[Dict[str, Any], str]] = {}, tags: Optional[List[Tag]] = None, sagemaker_session: Optional[Session] = None, role: Optional[str] = None, @@ -1101,14 +1136,21 @@ def from_recipe( """ if compute.instance_type is None: raise ValueError( - "Must set ``instance_type`` in compute_config when using training recipes." + "Must set ``instance_type`` in ``compute`` input when using training recipes." ) device_type = _determine_device_type(compute.instance_type) - if device_type == "cpu": + recipe = _load_base_recipe( + training_recipe=training_recipe, recipe_overrides=recipe_overrides + ) + is_nova = _is_nova_recipe(recipe=recipe) + + if device_type == "cpu" and not is_nova: raise ValueError( - "Training recipes are not supported for CPU instances. " + "Training recipe is not supported for CPU instances. " + "Please provide a GPU or Tranium instance type." ) + if training_image is None and is_nova: + raise ValueError("training_image must be provided when using recipe for Nova.") if training_image_config and training_image is None: raise ValueError("training_image must be provided when using training_image_config.") @@ -1126,15 +1168,27 @@ def from_recipe( # - distributed # - compute # - hyperparameters - model_trainer_args, recipe_train_dir = _get_args_from_recipe( - training_recipe=training_recipe, + model_trainer_args, tmp_dir = _get_args_from_recipe( + training_recipe=recipe, recipe_overrides=recipe_overrides, requirements=requirements, compute=compute, region_name=sagemaker_session.boto_region_name, + role=role, ) if training_image is not None: model_trainer_args["training_image"] = training_image + if hyperparameters and not is_nova: + logger.warning( + "Hyperparameters are not supported for general training recipes. " + + "Ignoring hyperparameters input." + ) + if is_nova: + if hyperparameters and isinstance(hyperparameters, str): + hyperparameters = cls._validate_and_load_hyperparameters_file(hyperparameters) + model_trainer_args["hyperparameters"].update(hyperparameters) + elif hyperparameters and isinstance(hyperparameters, dict): + model_trainer_args["hyperparameters"].update(hyperparameters) model_trainer = cls( sagemaker_session=sagemaker_session, @@ -1151,8 +1205,8 @@ def from_recipe( tags=tags, **model_trainer_args, ) - - model_trainer._temp_recipe_train_dir = recipe_train_dir + model_trainer._is_nova_recipe = is_nova + model_trainer._temp_recipe_train_dir = tmp_dir return model_trainer def with_tensorboard_output_config( diff --git a/src/sagemaker/modules/train/sm_recipes/utils.py b/src/sagemaker/modules/train/sm_recipes/utils.py index 6b39add6cd..3b7659016e 100644 --- a/src/sagemaker/modules/train/sm_recipes/utils.py +++ b/src/sagemaker/modules/train/sm_recipes/utils.py @@ -19,20 +19,21 @@ import shutil import tempfile from urllib.request import urlretrieve -from typing import Dict, Any, Optional, Tuple +from typing import Dict, Any, Optional, Tuple, Union import omegaconf -from omegaconf import OmegaConf, dictconfig +from omegaconf import OmegaConf, dictconfig, DictConfig from sagemaker.image_uris import retrieve from sagemaker.modules import logger from sagemaker.modules.utils import _run_clone_command_silent +from sagemaker.modules.constants import SM_RECIPE_YAML from sagemaker.modules.configs import Compute, SourceCode from sagemaker.modules.distributed import Torchrun, SMP -def _try_resolve_recipe(recipe, key=None): +def _try_resolve_recipe(recipe: DictConfig, key=None) -> DictConfig: """Try to resolve recipe and return resolved recipe.""" if key is not None: recipe = dictconfig.DictConfig({key: recipe}) @@ -86,6 +87,8 @@ def _load_base_recipe( ) else: recipe_launcher_dir = tempfile.TemporaryDirectory(prefix="launcher_") + if training_recipes_cfg is None: + training_recipes_cfg = _load_recipes_cfg() launcher_repo = os.environ.get("TRAINING_LAUNCHER_GIT", None) or training_recipes_cfg.get( "launcher_repo" @@ -149,7 +152,7 @@ def _get_trainining_recipe_gpu_model_name_and_script(model_type: str): def _configure_gpu_args( training_recipes_cfg: Dict[str, Any], region_name: str, - recipe: OmegaConf, + recipe: DictConfig, recipe_train_dir: tempfile.TemporaryDirectory, ) -> Dict[str, Any]: """Configure arguments specific to GPU.""" @@ -231,12 +234,110 @@ def _configure_trainium_args( return args +def _is_nova_recipe( + recipe: DictConfig, +) -> bool: + """Check if the recipe is a Nova recipe. + + A recipe is considered a Nova recipe if it meets either of the following conditions: + + 1. It has a run section with: + - A model_type that includes "amazon.nova" + - A model_name_or_path field + + OR + + 2. It has a training_config section with: + - A distillation_data field + + Args: + recipe (DictConfig): The loaded recipe configuration + + Returns: + bool: True if the recipe is a Nova recipe, False otherwise + """ + run_config = recipe.get("run", {}) + model_type = run_config.get("model_type", "").lower() + has_nova_model = ( + model_type and "amazon.nova" in model_type and "model_name_or_path" in run_config + ) + + # Check for distillation data + training_config = recipe.get("training_config", {}) + has_distillation = training_config.get("distillation_data") is not None + return bool(has_nova_model) or bool(has_distillation) + + +def _get_args_from_nova_recipe( + recipe: DictConfig, + compute: Compute, + role: Optional[str] = None, +) -> Tuple[Dict[str, Any], tempfile.TemporaryDirectory]: + if not compute.instance_count and not recipe.get("run", {}).get("replicas", None): + raise ValueError("Must set ``instance_type`` in compute or ``replicas`` in recipe.") + compute.instance_count = compute.instance_count or recipe.get("run", {}).get("replicas") + + args = dict() + args.update({"hyperparameters": {}}) + + run_config = recipe.get("run", {}) + model_name_or_path = run_config.get("model_name_or_path") + if model_name_or_path: + if model_name_or_path.startswith("s3://"): + args["hyperparameters"]["base_model_location"] = model_name_or_path + else: + args["hyperparameters"]["base_model"] = model_name_or_path + + # Handle distillation configuration + training_config = recipe.get("training_config", {}) + distillation_data = training_config.get("distillation_data") + if bool(distillation_data): + args["hyperparameters"]["distillation_data"] = distillation_data + if not role: + raise ValueError("Must provide 'role' parameter when using Nova distillation") + args["hyperparameters"]["role_arn"] = role + + kms_key = training_config.get("kms_key") + if kms_key is None: + raise ValueError( + 'Nova distillation job recipe requires "kms_key" field in "training_config"' + ) + args["hyperparameters"]["kms_key"] = kms_key + + _register_custom_resolvers() + + # Resolve Final Recipe + final_recipe = _try_resolve_recipe(recipe) + if final_recipe is None: + final_recipe = _try_resolve_recipe(recipe, "recipes") + if final_recipe is None: + final_recipe = _try_resolve_recipe(recipe, "training") + if final_recipe is None: + raise RuntimeError("Could not resolve provided recipe.") + + # Save Final Recipe to tmp dir + recipe_local_dir = tempfile.TemporaryDirectory(prefix="recipe_") + final_recipe_path = os.path.join(recipe_local_dir.name, SM_RECIPE_YAML) + OmegaConf.save(config=final_recipe, f=final_recipe_path) + + args.update( + { + "compute": compute, + "training_image": None, + "source_code": None, + "distributed": None, + } + ) + return args, recipe_local_dir + + def _get_args_from_recipe( - training_recipe: str, + training_recipe: Union[str, DictConfig], compute: Compute, region_name: str, recipe_overrides: Optional[Dict[str, Any]], requirements: Optional[str], + role: Optional[str] = None, ) -> Tuple[Dict[str, Any], tempfile.TemporaryDirectory]: """Get arguments for ModelTrainer from a training recipe. @@ -252,8 +353,8 @@ def _get_args_from_recipe( ``` Args: - training_recipe (str): - Name of the training recipe or path to the recipe file. + training_recipe (Union[str, Dict[str, Any]]): + Name of the training recipe or path to the recipe file or loaded recipe Dict. compute (Compute): Compute configuration for training. region_name (str): @@ -267,7 +368,13 @@ def _get_args_from_recipe( raise ValueError("Must set `instance_type` in compute when using training recipes.") training_recipes_cfg = _load_recipes_cfg() - recipe = _load_base_recipe(training_recipe, recipe_overrides, training_recipes_cfg) + if isinstance(training_recipe, str): + recipe = _load_base_recipe(training_recipe, recipe_overrides, training_recipes_cfg) + else: + recipe = training_recipe + if _is_nova_recipe(recipe): + args, recipe_local_dir = _get_args_from_nova_recipe(recipe, compute, role=role) + return args, recipe_local_dir if "trainer" not in recipe: raise ValueError("Supplied recipe does not contain required field trainer.") @@ -281,7 +388,7 @@ def _get_args_from_recipe( if compute.instance_count is None: if "num_nodes" not in recipe["trainer"]: raise ValueError( - "Must provide Compute with instance_count or" " set trainer -> num_nodes in recipe." + "Must provide Compute with instance_count or set trainer -> num_nodes in recipe." ) compute.instance_count = recipe["trainer"]["num_nodes"] @@ -311,7 +418,7 @@ def _get_args_from_recipe( # Save Final Recipe to source_dir OmegaConf.save( - config=final_recipe, f=os.path.join(args["source_code"].source_dir, "recipe.yaml") + config=final_recipe, f=os.path.join(args["source_code"].source_dir, SM_RECIPE_YAML) ) # If recipe_requirements is provided, copy it to source_dir @@ -323,7 +430,7 @@ def _get_args_from_recipe( args.update( { "compute": compute, - "hyperparameters": {"config-path": ".", "config-name": "recipe.yaml"}, + "hyperparameters": {"config-path": ".", "config-name": SM_RECIPE_YAML}, } ) diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py index d56c100546..633317927b 100644 --- a/src/sagemaker/pytorch/estimator.py +++ b/src/sagemaker/pytorch/estimator.py @@ -19,6 +19,8 @@ import os import shutil import tempfile +import time +from datetime import datetime from typing import Union, Optional, Dict from urllib.request import urlretrieve @@ -27,6 +29,7 @@ from packaging.version import Version from sagemaker.estimator import Framework, EstimatorBase +from sagemaker.inputs import TrainingInput, FileSystemInput from sagemaker.fw_utils import ( framework_name_from_image, framework_version_from_tag, @@ -126,6 +129,170 @@ def _get_training_recipe_trainium_script(code_dir, source_dir): return script +def _is_nova_recipe(recipe): + """Check if the recipe is a Nova recipe. + + A Nova recipe is identified by: + 1. Having a run section + 2. The model_type in run has a "amazon.nova" prefix + 3. The run contains model_name_or_path + + OR + + 1. Has a training_config section + 2. The training config_section has a distillation_data field + + Args: + recipe (OmegaConf): The loaded recipe configuration + + Returns: + bool: True if the recipe is a Nova recipe, False otherwise + """ + # Check for nova model + run_config = recipe.get("run", {}) + model_type = run_config.get("model_type", "").lower() + has_nova_model = ( + model_type and "amazon.nova" in model_type and "model_name_or_path" in run_config + ) + + # Check for distillation data + training_config = recipe.get("training_config", {}) + has_distillation = training_config.get("distillation_data") is not None + + return bool(has_nova_model) or bool(has_distillation) + + +def _recipe_initialize_args(source_dir): + """Initialize the arguments dictionary for recipe setup. + + Args: + source_dir (str): Path to the source directory. + + Returns: + dict: Initialized arguments dictionary. + + Raises: + ValueError: If source_dir is not a local directory. + """ + args = {"hyperparameters": {}} + + if source_dir is None: + args["source_dir"] = "." + else: + if not os.path.exists(source_dir): + raise ValueError("When using training_recipe, source_dir must be a local directory.") + args["source_dir"] = source_dir + + return args + + +def _recipe_get_region_name(kwargs): + """Get the AWS region name from session or create a new session. + + Args: + kwargs (dict): Dictionary of keyword arguments. + + Returns: + str: AWS region name. + """ + if kwargs.get("sagemaker_session") is not None: + return kwargs.get("sagemaker_session").boto_region_name + return Session().boto_region_name + + +def _recipe_load_config(): + """Load the training recipes configuration from JSON file. + + Returns: + dict: Training recipes configuration. + """ + training_recipes_cfg_filename = os.path.join(os.path.dirname(__file__), "training_recipes.json") + with open(training_recipes_cfg_filename) as training_recipes_cfg_file: + return json.load(training_recipes_cfg_file) + + +def _recipe_load_from_yaml(training_recipe, temp_local_recipe): + """Load recipe from a YAML file or URL. + + Args: + training_recipe (str): Path to the training recipe. + temp_local_recipe (str): Path to the temporary local recipe file. + + Raises: + ValueError: If the recipe cannot be fetched. + """ + if os.path.isfile(training_recipe): + shutil.copy(training_recipe, temp_local_recipe) + else: + try: + urlretrieve(training_recipe, temp_local_recipe) + except Exception as e: + raise ValueError( + f"Could not fetch the provided recipe {training_recipe}: exception {str(e)}" + ) + + +def _recipe_load_predefined( + training_recipe, recipe_launcher_dir, temp_local_recipe, training_recipes_cfg +): + """Load a predefined recipe from the recipe launcher. + + Args: + training_recipe (str): Name of the predefined recipe. + recipe_launcher_dir (str): Path to the recipe launcher directory. + temp_local_recipe (str): Path to the temporary local recipe file. + training_recipes_cfg (dict): Training recipes configuration. + + Raises: + ValueError: If the recipe cannot be found. + """ + launcher_repo = os.environ.get("TRAINING_LAUNCHER_GIT", None) or training_recipes_cfg.get( + "launcher_repo" + ) + _run_clone_command(launcher_repo, recipe_launcher_dir) + recipe_path = os.path.join( + recipe_launcher_dir, + "recipes_collection", + "recipes", + training_recipe + ".yaml", + ) + if os.path.isfile(recipe_path): + shutil.copy(recipe_path, temp_local_recipe) + else: + raise ValueError(f"Recipe {training_recipe} not found.") + + +def _device_get_distribution(device_type): + """Get the distribution configuration based on device type. + + Args: + device_type (str): Device type (gpu, trainium, or cpu). + + Returns: + dict: Distribution configuration. + + Raises: + ValueError: If the device type is not supported. + """ + if device_type == "gpu": + smp_options = { + "enabled": True, + "parameters": { + "placement_strategy": "cluster", + }, + } + return { + "smdistributed": {"modelparallel": smp_options}, + "torch_distributed": {"enabled": True}, + } + elif device_type == "trainium": + return { + "torch_distributed": {"enabled": True}, + } + else: + return {} + + class PyTorch(Framework): """Handle end-to-end training and deployment of custom PyTorch code.""" @@ -358,6 +525,7 @@ def __init__( :class:`~sagemaker.estimator.Framework` and :class:`~sagemaker.estimator.EstimatorBase`. """ + self.is_nova_recipe = False if training_recipe is not None: if entry_point is not None: logger.warning("Argument entry_point will be ignored with training_recipe.") @@ -368,6 +536,10 @@ def __init__( args = self._setup_for_training_recipe( training_recipe, recipe_overrides, source_dir, kwargs ) + + if self.is_nova_recipe and image_uri is None: + raise ValueError("Must supply image_uri for nova jobs.") + entry_point = args["entry_point"] source_dir = args["source_dir"] hyperparameters = args["hyperparameters"] @@ -392,7 +564,12 @@ def __init__( kwargs["enable_sagemaker_metrics"] = True super(PyTorch, self).__init__( - entry_point, source_dir, hyperparameters, image_uri=image_uri, **kwargs + entry_point, + source_dir, + hyperparameters, + image_uri=image_uri, + is_nova_job=self.is_nova_recipe, + **kwargs, ) if "entry_point" not in kwargs: @@ -499,6 +676,72 @@ def hyperparameters(self): return hyperparameters + def fit( + self, + inputs: Optional[Union[str, Dict, TrainingInput, FileSystemInput]] = None, + wait: bool = True, + logs: str = "All", + job_name: Optional[str] = None, + experiment_config: Optional[Dict[str, str]] = None, + ): + """Train a model using the input training dataset. + + Adds the recipe file to the inputs when a training recipe is used. + + Args: + inputs (str or dict or sagemaker.inputs.TrainingInput or + sagemaker.inputs.FileSystemInput): Information about the training data. + wait (bool): Whether the call should wait until the job completes (default: True). + logs ([str]): A list of strings specifying which logs to print. + job_name (str): Training job name. + experiment_config (dict[str, str]): Experiment management configuration. + + Returns: + None or pipeline step arguments + """ + # Handle recipe upload and input channel creation if we have a recipe + if ( + self.is_nova_recipe is not None + and self.is_nova_recipe + and hasattr(self, "training_recipe_file") + and self.training_recipe_file + ): + # Upload the recipe to S3 if it hasn't been uploaded yet + if not hasattr(self, "recipe_s3_uri") or not self.recipe_s3_uri: + self.recipe_s3_uri = self._upload_recipe_to_s3( + self.sagemaker_session, self.training_recipe_file.name + ) + + # Prepare inputs dictionary + from sagemaker.inputs import TrainingInput + + if inputs is None: + inputs = {} + elif not isinstance(inputs, dict): + inputs = {"training": inputs} + + # Add the recipe channel + recipe_channel_name = "recipe" + inputs[recipe_channel_name] = TrainingInput( + s3_data=os.path.dirname(self.recipe_s3_uri), input_mode="File" + ) + + # Update hyperparameters to reference the recipe location in the container + recipe_filename = os.path.basename(self.training_recipe_file.name) + + self._hyperparameters.update( + { + "sagemaker_recipe_local_path": f"/opt/ml/input/data/{recipe_channel_name}/{recipe_filename}", + } + ) + return super(PyTorch, self).fit( + inputs=inputs, + wait=wait, + logs=logs, + job_name=job_name, + experiment_config=experiment_config, + ) + def create_model( self, model_server_workers=None, @@ -604,155 +847,209 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na return init_params - @classmethod - def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, source_dir, kwargs): - """Performs training recipe specific setup and returns recipe specific args. + # The old class methods have been replaced by static methods and module-level functions - Updates kwargs and returns a dictionary of args to use for estimator - initialization and setup when using a training recipe. Updates the paths in - the recipe for Sagemaker Jobs environment. + @staticmethod + def _recipe_load(training_recipe, recipe_launcher_dir, training_recipes_cfg): + """Load the recipe from file path, URL, or predefined recipe. Args: - training_recipe (str): A recipe which is a local file path, a url or a - sagemaker training recipe. - recipe_overrides (Dict): Dictionary specifying key values to override in the - source_dir (str): Path (absolute, or relative) to a directory where to copy - the scripts for training recipe. requirements.txt can also - go here. - kwargs (dict): Dictionary of args used for estimator initializaiton. + training_recipe (str): Path to the training recipe. + recipe_launcher_dir (str): Path to the recipe launcher directory. + training_recipes_cfg (dict): Training recipes configuration. + Returns: - dict containing arg values for estimator initialization and setup. + tuple: Recipe name and loaded recipe. + Raises: + ValueError: If the recipe cannot be fetched or found. """ - if kwargs.get("sagemaker_session") is not None: - region_name = kwargs.get("sagemaker_session").boto_region_name - else: - region_name = Session().boto_region_name - - training_recipes_cfg_filename = os.path.join( - os.path.dirname(__file__), "training_recipes.json" - ) - with open(training_recipes_cfg_filename) as training_recipes_cfg_file: - training_recipes_cfg = json.load(training_recipes_cfg_file) - - if recipe_overrides is None: - recipe_overrides = dict() - recipe_train_dir = tempfile.TemporaryDirectory(prefix="training_") - recipe_launcher_dir = tempfile.TemporaryDirectory(prefix="launcher_") - args = dict() - if source_dir is None: - args["source_dir"] = "." - else: - if not os.path.exists(source_dir): - raise ValueError( - "When using training_recipe, source_dir must be a local directory." - ) - args["source_dir"] = source_dir - recipe_name = os.path.splitext(os.path.basename(training_recipe))[0] temp_local_recipe = tempfile.NamedTemporaryFile(prefix=recipe_name, suffix=".yaml").name - if training_recipe.endswith(".yaml"): - if os.path.isfile(training_recipe): - shutil.copy(training_recipe, temp_local_recipe) + + try: + if training_recipe.endswith(".yaml"): + _recipe_load_from_yaml(training_recipe, temp_local_recipe) else: - try: - urlretrieve(training_recipe, temp_local_recipe) - except Exception as e: - raise ValueError( - f"Could not fetch the provided recipe {training_recipe}: exception {str(e)}" - ) + _recipe_load_predefined( + training_recipe, recipe_launcher_dir, temp_local_recipe, training_recipes_cfg + ) + + recipe = OmegaConf.load(temp_local_recipe) + os.unlink(temp_local_recipe) + return recipe_name, recipe + except Exception as e: + if os.path.exists(temp_local_recipe): + os.unlink(temp_local_recipe) + raise e + + @staticmethod + def _device_get_image_uri(args, device_type, recipe_config, region_name, recipe): + """Get the appropriate image URI based on device type. + + Args: + args (dict): Arguments dictionary. + device_type (str): Device type (gpu, trainium, or cpu). + recipe_config (dict): Training recipes configuration. + region_name (str): AWS region name. + recipe (OmegaConf): Recipe configuration. + + Returns: + str: Image URI or None if no image URI was found. + """ + if "default_image_uri" in args: + logger.debug("Image URI already exists") + return args["default_image_uri"] + elif device_type == "gpu": + logger.info("Using GPU training image") + return _get_training_recipe_image_uri(recipe_config.get("gpu_image"), region_name) + elif device_type == "trainium": + logger.info("Using Trainium training image") + return _get_training_recipe_image_uri(recipe_config.get("neuron_image"), region_name) else: - launcher_repo = os.environ.get( - "TRAINING_LAUNCHER_GIT", None - ) or training_recipes_cfg.get("launcher_repo") - _run_clone_command(launcher_repo, recipe_launcher_dir.name) - recipe = os.path.join( - recipe_launcher_dir.name, - "recipes_collection", - "recipes", - training_recipe + ".yaml", - ) - if os.path.isfile(recipe): - shutil.copy(recipe, temp_local_recipe) + return None + + @staticmethod + def _recipe_setup_nova(args, recipe): + """Set up configuration for Nova recipes. + + Args: + args (dict): Arguments dictionary. + recipe (OmegaConf): Recipe configuration. + kwargs (dict): Dictionary of keyword arguments. + """ + run_config = recipe.get("run", {}) + model_name_or_path = run_config.get("model_name_or_path") + + # Set hyperparameters based on model_name_or_path + if model_name_or_path: + if model_name_or_path.startswith("s3://"): + args["hyperparameters"]["base_model_location"] = model_name_or_path else: - raise ValueError(f"Recipe {training_recipe} not found.") + args["hyperparameters"]["base_model"] = model_name_or_path + + args["entry_point"] = None + args["source_dir"] = None - recipe = OmegaConf.load(temp_local_recipe) - os.unlink(temp_local_recipe) - recipe = OmegaConf.merge(recipe, recipe_overrides) + @staticmethod + def _device_validate_and_get_type(kwargs, recipe): + """Validate instance type and determine device type. + + Args: + kwargs (dict): Dictionary of keyword arguments. + recipe (OmegaConf): Recipe configuration. + Returns: + str: Device type (gpu, trainium, or cpu). + + Raises: + ValueError: If instance_type is not provided or recipe is invalid. + """ if "instance_type" not in kwargs: raise ValueError("Must pass instance type to estimator when using training recipes.") + + if not _is_nova_recipe(recipe) and "trainer" not in recipe: + raise ValueError("Supplied recipe does not contain required field trainer.") + instance_type = kwargs["instance_type"].split(".")[1] if instance_type.startswith(("p", "g")): - device_type = "gpu" + return "gpu" elif instance_type.startswith("trn"): - device_type = "trainium" + return "trainium" else: - device_type = "cpu" + return "cpu" - if "trainer" not in recipe: - raise ValueError("Supplied recipe does not contain required field trainer.") - if "instance_count" in kwargs and "num_nodes" in recipe["trainer"]: - logger.warning( - "Using instance_count argument to estimator to set number " - " of nodes. Ignoring trainer -> num_nodes in recipe." - ) - if "instance_count" not in kwargs: - if "num_nodes" not in recipe["trainer"]: - raise ValueError( - "Must set either instance_count argument for estimator or" - "set trainer -> num_nodes in recipe." + @staticmethod + def _device_handle_instance_count(kwargs, recipe): + """Handle instance count configuration. + + Args: + kwargs (dict): Dictionary of keyword arguments. + recipe (OmegaConf): Recipe configuration. + + Raises: + ValueError: If instance_count is not provided and cannot be found in the recipe. + """ + # Check if instance_count is already provided in kwargs + + is_nova = _is_nova_recipe(recipe) + if "instance_count" in kwargs: + # Warn if there are conflicting configurations in the recipe + if "num_nodes" in recipe.get("trainer", {}): + logger.warning( + "Using instance_count argument to estimator to set number " + "of nodes. Ignoring trainer -> num_nodes in recipe." ) + if is_nova and "replicas" in recipe.get("run", {}): + logger.warning( + "Using instance_count argument to estimator to set number " + "of nodes. Ignoring run -> replicas in recipe." + ) + return + + # Try to get instance_count from recipe + if "trainer" in recipe and "num_nodes" in recipe["trainer"]: kwargs["instance_count"] = recipe["trainer"]["num_nodes"] + return + + if is_nova and "run" in recipe and "replicas" in recipe["run"]: + kwargs["instance_count"] = recipe["run"]["replicas"] + return - # [TODO] Add image uris to image_uri_config/_.json and use image_uris.retrieve - # to retrieve the image uri below before we go GA. + # If we get here, we couldn't find instance_count anywhere + raise ValueError( + "Must set either instance_count argument for estimator or " + "set trainer -> num_nodes or run -> replicas in recipe for nova jobs." + ) + + @staticmethod + def _device_get_entry_point_script( + device_type, recipe_train_dir, recipe, source_dir, training_recipes_cfg + ): + """Get the entry point script based on device type. + + Args: + device_type (str): Device type (gpu, trainium, or cpu). + recipe_train_dir (str): Path to the recipe training directory. + recipe (OmegaConf): Recipe configuration. + source_dir (str): Path to the source directory. + training_recipes_cfg (dict): Training recipes configuration. + + Returns: + str: Path to the entry point script or None if not applicable. + """ if device_type == "gpu": adapter_repo = os.environ.get("TRAINING_ADAPTER_GIT", None) or training_recipes_cfg.get( "adapter_repo" ) - _run_clone_command(adapter_repo, recipe_train_dir.name) - script = _get_training_recipe_gpu_script( - recipe_train_dir.name, recipe, args["source_dir"] - ) - args["default_image_uri"] = _get_training_recipe_image_uri( - training_recipes_cfg.get("gpu_image"), region_name - ) - smp_options = { - "enabled": True, - "parameters": { - "placement_strategy": "cluster", - }, - } - args["distribution"] = { - "smdistributed": {"modelparallel": smp_options}, - "torch_distributed": {"enabled": True}, - } + _run_clone_command(adapter_repo, recipe_train_dir) + return _get_training_recipe_gpu_script(recipe_train_dir, recipe, source_dir) elif device_type == "trainium": - _run_clone_command(training_recipes_cfg.get("neuron_dist_repo"), recipe_train_dir.name) - script = _get_training_recipe_trainium_script(recipe_train_dir.name, args["source_dir"]) - args["default_image_uri"] = _get_training_recipe_image_uri( - training_recipes_cfg.get("neuron_image"), region_name - ) - args["distribution"] = { - "torch_distributed": {"enabled": True}, - } - else: + _run_clone_command(training_recipes_cfg.get("neuron_dist_repo"), recipe_train_dir) + return _get_training_recipe_trainium_script(recipe_train_dir, source_dir) + elif device_type == "cpu": raise ValueError( f"Devices of type {device_type} are not supported with training recipes." ) - args["entry_point"] = os.path.basename(script) + return None - recipe_train_dir.cleanup() - recipe_launcher_dir.cleanup() + def _recipe_resolve_and_save(self, recipe, recipe_name, source_dir): + """Resolve and save the final recipe configuration. - if "container" in recipe and not recipe["container"]: - logger.warning( - "Ignoring container from training_recipe. Use image_uri arg for estimator." - ) + Args: + recipe (OmegaConf): Recipe configuration. + recipe_name (str): Recipe name. + source_dir (str): Path to the source directory. + + Returns: + OmegaConf: Resolved recipe configuration. + Raises: + RuntimeError: If the recipe cannot be resolved. + """ _setup_omegaconf_resolvers() + + # Try different resolution strategies final_recipe = _try_resolve_recipe(recipe) if final_recipe is None: final_recipe = _try_resolve_recipe(recipe, "recipes") @@ -760,15 +1057,258 @@ def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, source_di final_recipe = _try_resolve_recipe(recipe, "training") if final_recipe is None: raise RuntimeError("Could not resolve provided recipe.") - cls.training_recipe_file = tempfile.NamedTemporaryFile( - dir=args["source_dir"], + + # Save the resolved recipe - this sets an instance attribute + self.training_recipe_file = tempfile.NamedTemporaryFile( + dir=source_dir, prefix=recipe_name + "_", suffix=".yaml", ) - OmegaConf.save(config=final_recipe, f=cls.training_recipe_file.name) - args["hyperparameters"] = { - "config-path": ".", - "config-name": os.path.basename(cls.training_recipe_file.name), - } + OmegaConf.save(config=final_recipe, f=self.training_recipe_file.name) + + return final_recipe + + def _upload_recipe_to_s3(self, session, recipe_file_path): + """Upload the recipe file to S3. + + Args: + session (sagemaker.session.Session): SageMaker session. + recipe_file_path (str): Path to the recipe file. + + Returns: + str: S3 URI of the uploaded recipe file. + """ + bucket = session.default_bucket() + key_prefix = session.default_bucket_prefix + + recipe_filename = os.path.basename(recipe_file_path) + + readable_date = datetime.fromtimestamp(int(time.time())) + date_format = readable_date.strftime("%Y-%m-%d") + + if key_prefix != "None" and key_prefix is not None: + s3_key = f"{key_prefix}/recipes/{date_format}_{recipe_filename[:-5]}" + else: + s3_key = f"recipes/{date_format}_{recipe_filename[:-5]}" + + # Upload the recipe file to S3 + s3_uri = session.upload_data( + path=recipe_file_path, + bucket=bucket, + key_prefix=os.path.dirname(os.path.join(s3_key, recipe_filename)), + ) + + # Return the full S3 URI to the recipe file + return f"{s3_uri}" + + def _setup_for_training_recipe(self, training_recipe, recipe_overrides, source_dir, kwargs): + """Performs training recipe specific setup and returns recipe specific args. + + Updates kwargs and returns a dictionary of args to use for estimator + initialization and setup when using a training recipe. + + Args: + training_recipe (str): A recipe which is a local file path, a url or a + sagemaker training recipe. + recipe_overrides (Dict): Dictionary specifying key values to override in the + training recipe. + source_dir (str): Path (absolute, or relative) to a directory where to copy + the scripts for training recipe. + kwargs (dict): Dictionary of args used for estimator initialization. + + Returns: + dict containing arg values for estimator initialization and setup. + """ + region_name = _recipe_get_region_name(kwargs) + training_recipes_cfg = _recipe_load_config() + recipe_overrides = recipe_overrides or {} + + # Create temporary directories for recipe processing + with ( + tempfile.TemporaryDirectory(prefix="training_") as recipe_train_dir, + tempfile.TemporaryDirectory(prefix="launcher_") as recipe_launcher_dir, + ): + # Load and process the recipe + recipe_name, recipe = PyTorch._recipe_load( + training_recipe, recipe_launcher_dir, training_recipes_cfg + ) + + # Merge with overrides + recipe = OmegaConf.merge(recipe, recipe_overrides) + + self.is_nova_recipe = _is_nova_recipe(recipe) + if self.is_nova_recipe: + return self._setup_for_nova_recipe( + recipe, + recipe_name, + source_dir, + kwargs, + ) + else: + return self._setup_for_standard_recipe( + recipe, + recipe_name, + source_dir, + kwargs, + recipe_train_dir, + training_recipes_cfg, + region_name, + ) + + def _setup_for_nova_recipe( + self, + recipe, + recipe_name, + source_dir, + kwargs, + ): + """Set up configuration specifically for Nova recipes. + + Args: + recipe (OmegaConf): Recipe configuration. + recipe_name (str): Recipe name. + source_dir (str): Path to the source directory. + kwargs (dict): Dictionary of keyword arguments. + + Returns: + dict: Arguments dictionary for estimator initialization. + """ + # Initialize args + args = _recipe_initialize_args(source_dir) + + # Set up Nova-specific configuration + run_config = recipe.get("run", {}) + model_name_or_path = run_config.get("model_name_or_path") + + # Set hyperparameters based on model_name_or_path + if model_name_or_path: + if model_name_or_path.startswith("s3://"): + args["hyperparameters"]["base_model_location"] = model_name_or_path + else: + args["hyperparameters"]["base_model"] = model_name_or_path + + args["entry_point"] = None + args["source_dir"] = None + args["distribution"] = {} + + logger.info("Remote debugging, profiler and debugger hooks are disabled for Nova recipes.") + kwargs["enable_remote_debug"] = False + kwargs["disable_profiler"] = True + kwargs["debugger_hook_config"] = False + + # Handle instance count for Nova recipes + if "instance_count" in kwargs: + if "replicas" in recipe.get("run", {}): + logger.warning( + "Using instance_count argument to estimator to set number " + "of nodes. Ignoring run -> replicas in recipe." + ) + elif "run" in recipe and "replicas" in recipe["run"]: + kwargs["instance_count"] = recipe["run"]["replicas"] + else: + raise ValueError( + "Must set either instance_count argument for estimator or " + "set run -> replicas in recipe for nova jobs." + ) + + training_config = recipe.get("training_config", {}) + is_distillation = training_config.get("distillation_data", {}) + if bool(is_distillation): + args["hyperparameters"]["distillation_data"] = is_distillation + args["hyperparameters"]["role_arn"] = kwargs["role"] + kms_key = training_config.get("kms_key") + if kms_key is None: + ValueError( + 'Nova distillation job recipe requires "kms_key" field in "training_config"' + ) + args["hyperparameters"]["kms_key"] = kms_key + + # Resolve and save the final recipe + self._recipe_resolve_and_save(recipe, recipe_name, args["source_dir"]) + + return args + + def _setup_for_standard_recipe( + self, + recipe, + recipe_name, + source_dir, + kwargs, + recipe_train_dir, + training_recipes_cfg, + region_name, + ): + """Set up configuration for standard (non-Nova) recipes. + + Args: + recipe (OmegaConf): Recipe configuration. + recipe_name (str): Recipe name. + source_dir (str): Path to the source directory. + kwargs (dict): Dictionary of keyword arguments. + recipe_train_dir (str): Path to the recipe training directory. + training_recipes_cfg (dict): Training recipes configuration. + region_name (str): AWS region name. + + Returns: + dict: Arguments dictionary for estimator initialization. + """ + # Initialize args + args = _recipe_initialize_args(source_dir) + + # Validate recipe structure + if "trainer" not in recipe: + raise ValueError("Supplied recipe does not contain required field trainer.") + + # Handle instance count for standard recipes + if "instance_count" in kwargs: + if "num_nodes" in recipe.get("trainer", {}): + logger.warning( + "Using instance_count argument to estimator to set number " + "of nodes. Ignoring trainer -> num_nodes in recipe." + ) + elif "trainer" in recipe and "num_nodes" in recipe["trainer"]: + kwargs["instance_count"] = recipe["trainer"]["num_nodes"] + else: + raise ValueError( + "Must set either instance_count argument for estimator or " + "set trainer -> num_nodes in recipe." + ) + + # Determine device type + device_type = PyTorch._device_validate_and_get_type(kwargs, recipe) + + # Get image URI + image_uri = PyTorch._device_get_image_uri( + args, device_type, training_recipes_cfg, region_name, recipe + ) + args["default_image_uri"] = image_uri if image_uri is not None else "" + + # Setup device-specific configuration + args["distribution"] = _device_get_distribution(device_type) + + # Set entry point if not already set + if "entry_point" not in args: + script = PyTorch._device_get_entry_point_script( + device_type, recipe_train_dir, recipe, args["source_dir"], training_recipes_cfg + ) + if script: + args["entry_point"] = os.path.basename(script) + + # Handle container configuration + if "container" in recipe and not recipe["container"]: + logger.warning( + "Ignoring container from training_recipe. Use image_uri arg for estimator." + ) + + # Resolve and save the final recipe + self._recipe_resolve_and_save(recipe, recipe_name, args["source_dir"]) + + # Update hyperparameters with recipe configuration + args["hyperparameters"].update( + { + "config-path": ".", + "config-name": os.path.basename(self.training_recipe_file.name), + } + ) return args diff --git a/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py b/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py index 585a4d2745..a58b1f641e 100644 --- a/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py +++ b/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py @@ -14,9 +14,10 @@ from __future__ import absolute_import import pytest -from unittest.mock import patch +from unittest.mock import patch, MagicMock import yaml +from omegaconf import OmegaConf from urllib.request import urlretrieve from tempfile import NamedTemporaryFile @@ -27,6 +28,8 @@ _configure_gpu_args, _configure_trainium_args, _get_trainining_recipe_gpu_model_name_and_script, + _is_nova_recipe, + _get_args_from_nova_recipe, ) from sagemaker.modules.utils import _run_clone_command_silent from sagemaker.modules.configs import Compute @@ -181,6 +184,35 @@ def test_get_args_from_recipe_compute( assert args is None +@patch("sagemaker.modules.train.sm_recipes.utils._get_args_from_nova_recipe") +def test_get_args_from_recipe_with_nova_and_role(mock_get_args_from_nova_recipe, temporary_recipe): + # Set up mock return value + mock_args = {"hyperparameters": {}} + mock_dir = MagicMock() + mock_get_args_from_nova_recipe.return_value = (mock_args, mock_dir) + + # Create a Nova recipe with distillation data + recipe = OmegaConf.create( + {"training_config": {"distillation_data": True, "kms_key": "alias/my-kms-key"}} + ) + compute = Compute(instance_type="ml.g5.xlarge") + role = "arn:aws:iam::123456789012:role/SageMakerRole" + + # Mock the Nova recipe detection to return True + with patch("sagemaker.modules.train.sm_recipes.utils._is_nova_recipe", return_value=True): + _get_args_from_recipe( + training_recipe=recipe, + compute=compute, + region_name="us-west-2", + recipe_overrides=None, + requirements=None, + role=role, + ) + + # Verify _get_args_from_nova_recipe was called with the role parameter + mock_get_args_from_nova_recipe.assert_called_once_with(recipe, compute, role=role) + + @pytest.mark.parametrize( "test_case", [ @@ -213,3 +245,199 @@ def test_get_trainining_recipe_gpu_model_name_and_script(test_case): model_base_name, script = _get_trainining_recipe_gpu_model_name_and_script(model_type) assert model_base_name == test_case["model_base_name"] assert script == test_case["script"] + + +@pytest.mark.parametrize( + "test_case", + [ + { + "recipe": { + "run": { + "model_type": "amazon.nova", + "model_name_or_path": "some-model", + } + }, + "is_nova": True, + }, + { + "recipe": { + "run": { + "model_type": "amazon.nova.other", + "model_name_or_path": "some-model", + } + }, + "is_nova": True, + }, + {"recipe": {"run": {"model_type": "amazon.nova.other"}}, "is_nova": False}, + { + "recipe": {"run": {"model_type": "other.model", "model_name_or_path": "some-model"}}, + "is_nova": False, + }, + { + "recipe": {"training_config": {"distillation_data": "s3://bucket/distillation-data"}}, + "is_nova": True, + }, + { + "recipe": {"training_config": {"some_other_field": "value"}}, + "is_nova": False, + }, + ], + ids=[ + "nova_model", + "nova_model_subtype", + "nova_missing_model_path", + "non_nova_model", + "distillation_data", + "no_distillation_data", + ], +) +def test_is_nova_recipe(test_case): + recipe = OmegaConf.create(test_case["recipe"]) + is_nova = _is_nova_recipe(recipe) + assert is_nova == test_case["is_nova"] + + +@pytest.mark.parametrize( + "test_case", + [ + { + "recipe": { + "run": {"model_type": "amazon.nova", "model_name_or_path": "dummy-test"}, + }, + "compute": Compute(instance_type="ml.m5.xlarge", instance_count=2), + "expected_args": { + "compute": Compute(instance_type="ml.m5.xlarge", instance_count=2), + "hyperparameters": {"base_model": "dummy-test"}, + "training_image": None, + "source_code": None, + "distributed": None, + }, + }, + { + "recipe": { + "run": { + "model_type": "amazon.nova", + "model_name_or_path": "s3://bucket/path/to/model", + }, + }, + "compute": Compute(instance_type="ml.m5.xlarge", instance_count=2), + "expected_args": { + "compute": Compute(instance_type="ml.m5.xlarge", instance_count=2), + "hyperparameters": {"base_model_location": "s3://bucket/path/to/model"}, + "training_image": None, + "source_code": None, + "distributed": None, + }, + }, + { + "recipe": { + "run": { + "model_type": "amazon.nova", + "model_name_or_path": "s3://bucket/path/to/model", + "replicas": 4, + }, + }, + "compute": Compute(instance_type="ml.m5.xlarge"), + "expected_args": { + "compute": Compute(instance_type="ml.m5.xlarge", instance_count=4), + "hyperparameters": {"base_model_location": "s3://bucket/path/to/model"}, + "training_image": None, + "source_code": None, + "distributed": None, + }, + }, + { + "recipe": { + "run": { + "model_type": "amazon.nova", + "model_name_or_path": "s3://bucket/path/to/model", + "replicas": 2, + }, + }, + "compute": Compute(instance_type="ml.m5.xlarge", instance_count=4), + "expected_args": { + "compute": Compute(instance_type="ml.m5.xlarge", instance_count=4), + "hyperparameters": {"base_model_location": "s3://bucket/path/to/model"}, + "training_image": None, + "source_code": None, + "distributed": None, + }, + }, + ], +) +def test_get_args_from_nova_recipe(test_case): + recipe = OmegaConf.create(test_case["recipe"]) + args, _ = _get_args_from_nova_recipe(recipe=recipe, compute=test_case["compute"]) + assert args == test_case["expected_args"] + + +@pytest.mark.parametrize( + "test_case", + [ + { + "recipe": { + "training_config": { + "distillation_data": "s3://bucket/distillation-data", + "kms_key": "alias/my-kms-key", + } + }, + "compute": Compute(instance_type="ml.m5.xlarge", instance_count=2), + "role": "arn:aws:iam::123456789012:role/SageMakerRole", + "expected_args": { + "compute": Compute(instance_type="ml.m5.xlarge", instance_count=2), + "hyperparameters": { + "distillation_data": "s3://bucket/distillation-data", + "role_arn": "arn:aws:iam::123456789012:role/SageMakerRole", + "kms_key": "alias/my-kms-key", + }, + "training_image": None, + "source_code": None, + "distributed": None, + }, + }, + ], +) +def test_get_args_from_nova_recipe_with_distillation(test_case): + recipe = OmegaConf.create(test_case["recipe"]) + args, _ = _get_args_from_nova_recipe( + recipe=recipe, compute=test_case["compute"], role=test_case["role"] + ) + assert args == test_case["expected_args"] + + +@pytest.mark.parametrize( + "test_case", + [ + { + "recipe": { + "training_config": { + "distillation_data": "s3://bucket/distillation-data", + # Missing kms_key + } + }, + "compute": Compute(instance_type="ml.m5.xlarge", instance_count=2), + "role": "arn:aws:iam::123456789012:role/SageMakerRole", + }, + { + "recipe": { + "training_config": { + "distillation_data": "s3://bucket/distillation-data", + "kms_key": "alias/my-kms-key", + } + }, + "compute": Compute(instance_type="ml.m5.xlarge", instance_count=2), + # Missing role + "role": None, + }, + ], + ids=[ + "missing_kms_key", + "missing_role", + ], +) +def test_get_args_from_nova_recipe_with_distillation_errors(test_case): + recipe = OmegaConf.create(test_case["recipe"]) + with pytest.raises(ValueError): + _get_args_from_nova_recipe( + recipe=recipe, compute=test_case["compute"], role=test_case.get("role") + ) diff --git a/tests/unit/sagemaker/modules/train/test_model_trainer.py b/tests/unit/sagemaker/modules/train/test_model_trainer.py index 23ea167ecf..184f9c30da 100644 --- a/tests/unit/sagemaker/modules/train/test_model_trainer.py +++ b/tests/unit/sagemaker/modules/train/test_model_trainer.py @@ -21,6 +21,7 @@ import pytest from pydantic import ValidationError from unittest.mock import patch, MagicMock, ANY, mock_open +from tempfile import NamedTemporaryFile from sagemaker import image_uris from sagemaker_core.main.resources import TrainingJob @@ -43,6 +44,7 @@ DISTRIBUTED_JSON, SOURCE_CODE_JSON, TRAIN_SCRIPT, + SM_RECIPE_CONTAINER_PATH, ) from sagemaker.modules.configs import ( Compute, @@ -1339,3 +1341,91 @@ def test_input_merge(mock_training_job, modules_session): input_mode="File", ), ] + + +@patch("sagemaker.modules.train.model_trainer._get_unique_name") +@patch("sagemaker.modules.train.model_trainer.TrainingJob") +def test_nova_recipe(mock_training_job, mock_unique_name, modules_session): + def mock_upload_data(path, bucket, key_prefix): + if os.path.isfile(path): + file_name = os.path.basename(path) + return f"s3://{bucket}/{key_prefix}/{file_name}" + else: + return f"s3://{bucket}/{key_prefix}" + + unique_name = "base-job-0123456789" + base_name = "base-job" + + modules_session.upload_data.side_effect = mock_upload_data + mock_unique_name.return_value = unique_name + + recipe_data = { + "run": { + "name": "dummy-model", + "model_type": "amazon.nova", + "model_name_or_path": "dummy-model", + } + } + with NamedTemporaryFile(suffix=".yaml", delete=False) as recipe: + with open(recipe.name, "w") as file: + yaml.dump(recipe_data, file) + + trainer = ModelTrainer.from_recipe( + training_recipe=recipe.name, + role=DEFAULT_ROLE, + sagemaker_session=modules_session, + compute=DEFAULT_COMPUTE_CONFIG, + training_image=DEFAULT_IMAGE, + base_job_name=base_name, + ) + + assert trainer._is_nova_recipe + + trainer.train() + mock_training_job.create.assert_called_once() + assert mock_training_job.create.call_args.kwargs["hyper_parameters"] == { + "base_model": "dummy-model", + "sagemaker_recipe_local_path": SM_RECIPE_CONTAINER_PATH, + } + + default_base_path = f"s3://{DEFAULT_BUCKET}/{DEFAULT_BUCKET_PREFIX}/{base_name}" + assert mock_training_job.create.call_args.kwargs["input_data_config"] == [ + Channel( + channel_name="recipe", + data_source=DataSource( + s3_data_source=S3DataSource( + s3_data_type="S3Prefix", + s3_uri=f"{default_base_path}/{unique_name}/input/recipe/recipe.yaml", + s3_data_distribution_type="FullyReplicated", + ) + ), + input_mode="File", + ) + ] + + +def test_nova_recipe_with_distillation(modules_session): + recipe_data = {"training_config": {"distillation_data": "true", "kms_key": "alias/my-kms-key"}} + + with NamedTemporaryFile(suffix=".yaml", delete=False) as recipe: + with open(recipe.name, "w") as file: + yaml.dump(recipe_data, file) + + # Create ModelTrainer from recipe + trainer = ModelTrainer.from_recipe( + training_recipe=recipe.name, + role=DEFAULT_ROLE, + sagemaker_session=modules_session, + compute=DEFAULT_COMPUTE_CONFIG, + training_image=DEFAULT_IMAGE, + ) + + # Verify that the hyperparameters were set correctly + assert trainer.hyperparameters == { + "distillation_data": "true", + "role_arn": DEFAULT_ROLE, + "kms_key": "alias/my-kms-key", + } + + # Clean up the temporary file + os.unlink(recipe.name) diff --git a/tests/unit/test_pytorch_nova.py b/tests/unit/test_pytorch_nova.py new file mode 100644 index 0000000000..f78bdcae7d --- /dev/null +++ b/tests/unit/test_pytorch_nova.py @@ -0,0 +1,753 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import +import pytest +import tempfile +from mock import Mock, patch +from omegaconf import OmegaConf + +from sagemaker.estimator import EstimatorBase + +from sagemaker.pytorch import PyTorch +from sagemaker.pytorch.estimator import ( + _is_nova_recipe, + _device_get_distribution, +) +from sagemaker.inputs import TrainingInput +from sagemaker.session_settings import SessionSettings + +# Constants for testing +ROLE = "Dummy" +REGION = "us-west-2" +BUCKET_NAME = "mybucket" +INSTANCE_COUNT = 1 +INSTANCE_TYPE = "ml.c4.4xlarge" +INSTANCE_TYPE_GPU = "ml.p4d.24xlarge" +IMAGE_URI = "sagemaker-pytorch" + + +@pytest.fixture(name="sagemaker_session") +def fixture_sagemaker_session(): + boto_mock = Mock(name="boto_session", region_name=REGION) + session = Mock( + name="sagemaker_session", + boto_session=boto_mock, + boto_region_name=REGION, + config=None, + local_mode=False, + s3_resource=None, + s3_client=None, + settings=SessionSettings(), + default_bucket_prefix=None, + ) + session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) + session.expand_role = Mock(name="expand_role", return_value=ROLE) + session.upload_data = Mock(return_value="s3://mybucket/recipes/nova-recipe.yaml") + session.sagemaker_config = {} + return session + + +def test_is_nova_recipe(): + """Test that _is_nova_recipe correctly identifies Nova recipes.""" + # Valid Nova recipe + recipe = OmegaConf.create( + { + "run": { + "model_type": "amazon.nova.foo-bar", + "model_name_or_path": "foo-bar/foo-bar123", + } + } + ) + assert _is_nova_recipe(recipe) is True + + # Not a Nova recipe - missing run section + recipe = OmegaConf.create( + { + "trainer": { + "model_type": "amazon.nova.foo-bar", + "model_name_or_path": "foo-bar/foo-bar123", + } + } + ) + assert _is_nova_recipe(recipe) is False + + # Not a Nova recipe - wrong model_type + recipe = OmegaConf.create( + {"run": {"model_type": "foo-bar3", "model_name_or_path": "foo-bar/foo-bar123"}} + ) + assert _is_nova_recipe(recipe) is False + + # Not a Nova recipe - missing model_name_or_path + recipe = OmegaConf.create({"run": {"model_type": "amazon.nova.foo-bar"}}) + assert _is_nova_recipe(recipe) is False + + +@patch("sagemaker.pytorch.estimator.PyTorch._recipe_resolve_and_save") +def test_setup_for_nova_recipe_with_model_name(mock_resolve_save, sagemaker_session): + """Test that _setup_for_nova_recipe correctly sets up hyperparameters for Nova recipes with model name.""" + # Create a mock recipe + recipe = OmegaConf.create( + { + "run": { + "model_type": "amazon.nova.foobar3", + "model_name_or_path": "foobar/foobar-3-8b", + "replicas": 4, + } + } + ) + + # Setup the expected return value + expected_args = { + "hyperparameters": {"base_model": "foobar/foobar-3-8b"}, + "entry_point": None, + "source_dir": None, + "distribution": {}, + "default_image_uri": IMAGE_URI, + } + + # Mock the _setup_for_nova_recipe method + with patch( + "sagemaker.pytorch.estimator.PyTorch._setup_for_nova_recipe", return_value=expected_args + ) as mock_nova_setup: + # Create the PyTorch estimator with mocked _recipe_load + with patch( + "sagemaker.pytorch.estimator.PyTorch._recipe_load", return_value=("nova_recipe", recipe) + ): + # Mock _recipe_resolve_and_save to return our recipe + mock_resolve_save.return_value = recipe + + pytorch = PyTorch( + training_recipe="nova_recipe", + role=ROLE, + sagemaker_session=sagemaker_session, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE_GPU, + image_uri=IMAGE_URI, + framework_version="1.13.1", + py_version="py3", + ) + + # Check that the Nova recipe was correctly identified + assert pytorch.is_nova_recipe is True + + # Verify _setup_for_nova_recipe was called + mock_nova_setup.assert_called_once() + call_args = mock_nova_setup.call_args + assert len(call_args[0]) >= 2 # Check that at least recipe and recipe_name were passed + assert call_args[0][0] == recipe # first arg should be recipe + assert call_args[0][1] == "nova_recipe" # second arg should be recipe_name + + +@patch("sagemaker.pytorch.estimator.PyTorch._recipe_resolve_and_save") +def test_setup_for_nova_recipe_with_s3_path(mock_resolve_save, sagemaker_session): + """Test that _setup_for_nova_recipe correctly sets up hyperparameters for Nova recipes with S3 path.""" + # Create a mock recipe with S3 path + recipe = OmegaConf.create( + { + "run": { + "model_type": "amazon.nova.foobar3", + "model_name_or_path": "s3://mybucket/models/foobar3", + "replicas": 4, + } + } + ) + + # Setup the expected return value + expected_args = { + "hyperparameters": {"base_model_location": "s3://mybucket/models/foobar3"}, + "entry_point": None, + "source_dir": None, + "distribution": {}, + "default_image_uri": IMAGE_URI, + } + + # Mock the _setup_for_nova_recipe method + with patch( + "sagemaker.pytorch.estimator.PyTorch._setup_for_nova_recipe", return_value=expected_args + ) as mock_nova_setup: + # Create the PyTorch estimator with mocked _recipe_load + with patch( + "sagemaker.pytorch.estimator.PyTorch._recipe_load", return_value=("nova_recipe", recipe) + ): + # Mock _recipe_resolve_and_save to return our recipe + mock_resolve_save.return_value = recipe + + pytorch = PyTorch( + training_recipe="nova_recipe", + role=ROLE, + sagemaker_session=sagemaker_session, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE_GPU, + image_uri=IMAGE_URI, + framework_version="1.13.1", + py_version="py3", + ) + + # Check that the Nova recipe was correctly identified + assert pytorch.is_nova_recipe is True + + # Verify _setup_for_nova_recipe was called + mock_nova_setup.assert_called_once() + + # Verify that hyperparameters were set correctly + assert ( + pytorch._hyperparameters.get("base_model_location") + == "s3://mybucket/models/foobar3" + ) + + +def test_device_handle_instance_count_with_nova_replicas(): + """Test that _device_handle_instance_count correctly gets instance_count from Nova recipe replicas.""" + # Create mock recipe with replicas + recipe = OmegaConf.create( + { + "run": { + "model_type": "amazon.nova.foobar3", + "model_name_or_path": "foobar/foobar-3-8b", + "replicas": 4, + } + } + ) + + # Test with no instance_count in kwargs + kwargs = {} + PyTorch._device_handle_instance_count(kwargs, recipe) + assert kwargs["instance_count"] == 4 + + +def test_device_handle_instance_count_with_nova_no_replicas(): + """Test that _device_handle_instance_count raises an error when no instance_count or replicas are provided.""" + # Create mock recipe without replicas + recipe = OmegaConf.create( + {"run": {"model_type": "amazon.nova.foobar3", "model_name_or_path": "foobar/foobar-3-8b"}} + ) + + # Test with no instance_count in kwargs + kwargs = {} + with pytest.raises(ValueError) as error: + PyTorch._device_handle_instance_count(kwargs, recipe) + + assert "Must set either instance_count argument for estimator or" in str(error) + + +@patch("sagemaker.pytorch.estimator.logger.warning") +def test_device_handle_instance_count_with_nova_both_provided(mock_warning): + """Test that _device_handle_instance_count warns when both instance_count and replicas are provided.""" + # Create mock recipe with replicas + recipe = OmegaConf.create( + { + "run": { + "model_type": "amazon.nova.foobar3", + "model_name_or_path": "foobar/foobar-3-8b", + "replicas": 4, + } + } + ) + + # Test with instance_count in kwargs + kwargs = {"instance_count": 2} + PyTorch._device_handle_instance_count(kwargs, recipe) + + # Verify warning was logged + mock_warning.assert_called_with( + "Using instance_count argument to estimator to set number " + "of nodes. Ignoring run -> replicas in recipe." + ) + + # Verify instance_count wasn't changed + assert kwargs["instance_count"] == 2 + + +def test_device_validate_and_get_type_with_nova(): + """Test that _device_validate_and_get_type works correctly with Nova recipes.""" + # Create mock recipe + recipe = OmegaConf.create( + {"run": {"model_type": "amazon.nova.foobar3", "model_name_or_path": "foobar/foobar-3-8b"}} + ) + + # Test with GPU instance type + kwargs = {"instance_type": INSTANCE_TYPE_GPU} + device_type = PyTorch._device_validate_and_get_type(kwargs, recipe) + assert device_type == "gpu" + + # Test with CPU instance type + kwargs = {"instance_type": INSTANCE_TYPE} + device_type = PyTorch._device_validate_and_get_type(kwargs, recipe) + assert device_type == "cpu" + + +def test_device_validate_and_get_type_no_instance_type(): + """Test that _device_validate_and_get_type raises an error when no instance_type is provided.""" + # Create mock recipe + recipe = OmegaConf.create( + {"run": {"model_type": "amazon.nova.foobar3", "model_name_or_path": "foobar/foobar-3-8b"}} + ) + + # Test with no instance_type + kwargs = {} + with pytest.raises(ValueError) as error: + PyTorch._device_validate_and_get_type(kwargs, recipe) + + assert "Must pass instance type to estimator" in str(error) + + +@patch("sagemaker.pytorch.estimator.PyTorch._recipe_load") +@patch("time.time", return_value=1714500000) # May 1, 2024 +def test_upload_recipe_to_s3(mock_time, mock_recipe_load, sagemaker_session): + """Test that _upload_recipe_to_s3 correctly uploads the recipe file to S3.""" + # Create a mock recipe that will be identified as a Nova recipe + mock_recipe = OmegaConf.create( + {"run": {"model_type": "amazon.nova.foobar3", "model_name_or_path": "foobar/foobar-3-8b"}} + ) + + # Set up the mock to return a recipe name and the mock recipe + mock_recipe_load.return_value = ("nova_recipe", mock_recipe) + + # Setup + pytorch = PyTorch( + role=ROLE, + sagemaker_session=sagemaker_session, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE_GPU, + image_uri=IMAGE_URI, + framework_version="1.13.1", + py_version="py3", + training_recipe="nova_recipe", + ) + + # Set Nova recipe attributes + pytorch.is_nova_recipe = True + + # Create a temporary file to use as the recipe file + with tempfile.NamedTemporaryFile(suffix=".yaml") as temp_file: + # Test uploading the recipe file to S3 + s3_uri = pytorch._upload_recipe_to_s3(sagemaker_session, temp_file.name) + + # Verify the upload_data method was called with the correct parameters + sagemaker_session.upload_data.assert_called_once() + + # Check that the S3 URI is returned correctly + assert s3_uri == sagemaker_session.upload_data.return_value + + +@patch("sagemaker.pytorch.estimator.PyTorch._recipe_load") +@patch("tempfile.NamedTemporaryFile") +@patch("omegaconf.OmegaConf.save") +@patch("sagemaker.pytorch.estimator._try_resolve_recipe") +def test_recipe_resolve_and_save( + mock_try_resolve, mock_save, mock_temp_file, mock_recipe_load, sagemaker_session +): + """Test that _recipe_resolve_and_save correctly resolves an`d saves the recipe.""" + # Create a mock recipe that will be identified as a Nova recipe + mock_recipe = OmegaConf.create( + {"run": {"model_type": "amazon.nova.foobar3", "model_name_or_path": "foobar/foobar-3-8b"}} + ) + + # Set up the mock to return a recipe name and the mock recipe + mock_recipe_load.return_value = ("nova_recipe", mock_recipe) + + # Setup + pytorch = PyTorch( + role=ROLE, + sagemaker_session=sagemaker_session, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE_GPU, + image_uri=IMAGE_URI, + framework_version="1.13.1", + py_version="py3", + training_recipe="nova_recipe", + ) + + # Set Nova recipe attributes + pytorch.is_nova_recipe = True + + # Mock the temporary file + mock_temp_file_instance = Mock() + mock_temp_file_instance.name = "/tmp/nova-recipe_12345.yaml" + mock_temp_file.return_value = mock_temp_file_instance + + # Create mock recipe + recipe = OmegaConf.create( + {"run": {"model_type": "amazon.nova.foobar3", "model_name_or_path": "foobar/foobar-3-8b"}} + ) + + # Mock the recipe resolution + mock_try_resolve.side_effect = [recipe, None, None] + + # Call the _recipe_resolve_and_save method + result = pytorch._recipe_resolve_and_save(recipe, "nova-recipe", ".") + + # Verify the recipe was resolved and saved + mock_try_resolve.assert_called_with(recipe) + mock_save.assert_called_with(config=recipe, f=mock_temp_file_instance.name) + + # Verify the result is the resolved recipe + assert result == recipe + + +@patch("sagemaker.pytorch.estimator.PyTorch._recipe_load") +@patch("sagemaker.pytorch.estimator.Framework.fit") +def test_fit_with_nova_recipe_s3_upload(mock_framework_fit, mock_recipe_load, sagemaker_session): + """Test that fit correctly uploads the recipe to S3 and adds it to the inputs.""" + # Create a mock recipe that will be identified as a Nova recipe + mock_recipe = OmegaConf.create( + {"run": {"model_type": "amazon.nova.foobar", "model_name_or_path": "foobar/foobar123"}} + ) + + # Set up the mock to return a recipe name and the mock recipe + mock_recipe_load.return_value = ("nova_recipe", mock_recipe) + + # Create a PyTorch estimator with a Nova recipe + with tempfile.NamedTemporaryFile(suffix=".yaml") as temp_file: + pytorch = PyTorch( + role=ROLE, + sagemaker_session=sagemaker_session, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE_GPU, + image_uri=IMAGE_URI, + framework_version="1.13.1", + py_version="py3", + training_recipe="nova_recipe", + ) + + # Set Nova recipe attributes + pytorch.is_nova_recipe = True + pytorch.training_recipe_file = temp_file + + # Mock the _upload_recipe_to_s3 method + with patch.object(pytorch, "_upload_recipe_to_s3") as mock_upload_recipe: + mock_upload_recipe.return_value = "s3://mybucket/recipes/nova-recipe.yaml" + + # Call the fit method + pytorch.fit() + + # Verify the upload_recipe_to_s3 method was called + mock_upload_recipe.assert_called_once_with(sagemaker_session, temp_file.name) + + # Verify the fit method was called with the recipe channel + call_args = mock_framework_fit.call_args[1] + assert "inputs" in call_args + assert "recipe" in call_args["inputs"] + + # Verify the hyperparameters were updated with the recipe path + assert "sagemaker_recipe_local_path" in pytorch._hyperparameters + + +@patch("sagemaker.pytorch.estimator.PyTorch._recipe_load") +@patch("sagemaker.pytorch.estimator.PyTorch._upload_recipe_to_s3") +@patch("sagemaker.pytorch.estimator.Framework.fit") +def test_fit_with_nova_recipe_and_inputs( + mock_framework_fit, mock_upload_recipe, mock_recipe_load, sagemaker_session +): + """Test that fit correctly handles Nova recipes with additional inputs.""" + # Create a mock recipe that will be identified as a Nova recipe + mock_recipe = OmegaConf.create( + {"run": {"model_type": "amazon.nova.foobar3", "model_name_or_path": "foobar/foobar-3-8b"}} + ) + + # Set up the mock to return a recipe name and the mock recipe + mock_recipe_load.return_value = ("nova_recipe", mock_recipe) + mock_upload_recipe.return_value = "s3://mybucket/recipes/nova-recipe.yaml" + + # Create a PyTorch estimator with a Nova recipe + with tempfile.NamedTemporaryFile(suffix=".yaml") as temp_file: + pytorch = PyTorch( + role=ROLE, + sagemaker_session=sagemaker_session, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE_GPU, + image_uri=IMAGE_URI, + framework_version="1.13.1", + py_version="py3", + training_recipe="nova_recipe", + ) + + # Set Nova recipe attributes + pytorch.is_nova_recipe = True + pytorch.training_recipe_file = temp_file + + # Create training inputs + train_input = TrainingInput(s3_data="s3://mybucket/train") + val_input = TrainingInput(s3_data="s3://mybucket/validation") + inputs = {"train": train_input, "validation": val_input} + + # Call the fit method with inputs + pytorch.fit(inputs=inputs) + + # Verify the fit method was called with both the recipe channel and the provided inputs + call_args = mock_framework_fit.call_args[1] + assert "inputs" in call_args + assert "recipe" in call_args["inputs"] + assert "train" in call_args["inputs"] + assert "validation" in call_args["inputs"] + + # Verify the hyperparameters were updated with the recipe path + assert "sagemaker_recipe_local_path" in pytorch._hyperparameters + + +def test_device_get_distribution(): + """Test that _device_get_distribution returns the correct distribution configuration.""" + # Test with GPU device type + gpu_distribution = _device_get_distribution("gpu") + expected_gpu_distribution = { + "torch_distributed": {"enabled": True}, + "smdistributed": { + "modelparallel": { + "enabled": True, + "parameters": { + "placement_strategy": "cluster", + }, + }, + }, + } + assert gpu_distribution == expected_gpu_distribution + + # Test with Trainium device type + trainium_distribution = _device_get_distribution("trainium") + expected_trainium_distribution = { + "torch_distributed": {"enabled": True}, + } + assert trainium_distribution == expected_trainium_distribution + + # Test with CPU device type + cpu_distribution = _device_get_distribution("cpu") + assert cpu_distribution == {} + + +@patch("sagemaker.pytorch.estimator.PyTorch._recipe_load") +@patch("sagemaker.pytorch.estimator.PyTorch._upload_recipe_to_s3") +@patch("sagemaker.pytorch.estimator.Framework.fit") +def test_fit_with_nova_recipe( + mock_framework_fit, mock_upload_recipe, mock_recipe_load, sagemaker_session +): + """Test that fit correctly handles Nova recipes.""" + + # Create a mock recipe that will be identified as a Nova recipe + mock_recipe = OmegaConf.create( + { + "run": { + "model_type": "amazon.nova.foo-bar", + "model_name_or_path": "foo-bar123", + } + } + ) + + # Set up the mock to return a recipe name and the mock recipe + mock_recipe_load.return_value = ("nova_recipe", mock_recipe) + + # Create a PyTorch estimator with a Nova recipe + with tempfile.NamedTemporaryFile(suffix=".yaml") as temp_file: + pytorch = PyTorch( + role=ROLE, + sagemaker_session=sagemaker_session, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE_GPU, + image_uri=IMAGE_URI, + framework_version="1.13.1", + py_version="py3", + training_recipe="nova_recipe", + ) + + # Set Nova recipe attributes + pytorch.is_nova_recipe = True + pytorch.training_recipe_file = temp_file + + # Mock the upload_recipe_to_s3 method + mock_upload_recipe.return_value = "s3://mybucket/recipes/nova-recipe.yaml" + + # Call the fit method + pytorch.fit() + + # Verify the upload_recipe_to_s3 method was called + mock_upload_recipe.assert_called_once_with(sagemaker_session, temp_file.name) + + # Verify the fit method was called with the recipe channel + call_args = mock_framework_fit.call_args[1] + assert "inputs" in call_args + assert "recipe" in call_args["inputs"] + + # Verify the hyperparameters were updated with the recipe path + assert "sagemaker_recipe_local_path" in pytorch._hyperparameters + + +def test_nova_encode_hyperparameters(): + """Test that _nova_encode_hyperparameters correctly preserves string values and encodes non-string values.""" + # Setup test hyperparameters + hyperparameters = { + "string_param": "string_value", + "int_param": 42, + "float_param": 3.14, + "bool_param": True, + "list_param": [1, 2, 3], + "dict_param": {"key": "value"}, + } + + # Call the method + encoded = EstimatorBase._nova_encode_hyperparameters(hyperparameters) + + # Verify string values are preserved + assert encoded["string_param"] == "string_value" + + # Verify non-string values are JSON-encoded + assert encoded["int_param"] == "42" + assert encoded["float_param"] == "3.14" + assert encoded["bool_param"] == "true" + assert encoded["list_param"] == "[1, 2, 3]" + assert encoded["dict_param"] == '{"key": "value"}' + + +def test_framework_set_hyperparameters_nova(): + """Test that Framework.set_hyperparameters uses _nova_encode_hyperparameters for Nova jobs.""" + # Setup + framework = PyTorch( + entry_point="dummy.py", + role=ROLE, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + framework_version="1.13.1", + py_version="py3", + image_uri=IMAGE_URI, + ) + + framework.is_nova_job = True + + # Add hyperparameters + framework.set_hyperparameters(string_param="string_value", int_param=42, bool_param=True) + + # Verify string values are preserved and non-string values are encoded + assert framework._hyperparameters["string_param"] == "string_value" + assert framework._hyperparameters["int_param"] == "42" + assert framework._hyperparameters["bool_param"] == "true" + + +def test_framework_set_hyperparameters_non_nova(): + """Test that Framework.set_hyperparameters uses _json_encode_hyperparameters for non-Nova jobs.""" + # Setup + framework = PyTorch( + entry_point="dummy.py", + role=ROLE, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + framework_version="1.13.1", + py_version="py3", + image_uri=IMAGE_URI, + ) + framework.is_nova_recipe = False + + # Add hyperparameters + framework.set_hyperparameters(string_param="string_value", int_param=42, bool_param=True) + + # Verify all values are JSON-encoded + assert framework._hyperparameters["string_param"] == '"string_value"' + assert framework._hyperparameters["int_param"] == "42" + assert framework._hyperparameters["bool_param"] == "true" + + +def test_framework_hyperparameters_nova(): + """Test that Framework.hyperparameters uses _nova_encode_hyperparameters for Nova jobs.""" + # Setup + framework = PyTorch( + entry_point="dummy.py", + role=ROLE, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + framework_version="1.13.1", + py_version="py3", + image_uri=IMAGE_URI, + ) + + framework.is_nova_job = True + + # Add hyperparameters directly to _hyperparameters + framework._hyperparameters = { + "string_param": "string_value", + "int_param": 42, + "bool_param": True, + } + + # Get hyperparameters + hyperparams = framework.hyperparameters() + + # Verify string values are preserved and non-string values are encoded + assert hyperparams["string_param"] == "string_value" + assert hyperparams["int_param"] == "42" + assert hyperparams["bool_param"] == "true" + + +@patch("sagemaker.pytorch.estimator.PyTorch._recipe_resolve_and_save") +def test_setup_for_nova_recipe_with_distillation(mock_resolve_save, sagemaker_session): + """Test that _setup_for_nova_recipe correctly handles distillation configurations.""" + # Create a mock recipe with distillation config + recipe = OmegaConf.create( + { + "run": { + "model_type": "amazon.nova.foobar3", + "model_name_or_path": "foobar/foobar-3-8b", + "replicas": 4, + }, + "training_config": { + "distillation_data": "s3://mybucket/distillation-data", + "kms_key": "alias/my-kms-key", + }, + } + ) + + # Setup the expected return value + expected_args = { + "hyperparameters": { + "base_model": "foobar/foobar-3-8b", + "distillation_data": "s3://mybucket/distillation-data", + "role_arn": "arn:aws:iam::123456789012:role/SageMakerRole", + "kms_key": "alias/my-kms-key", + }, + "entry_point": None, + "source_dir": None, + "distribution": {}, + "default_image_uri": IMAGE_URI, + } + + with patch( + "sagemaker.pytorch.estimator.PyTorch._setup_for_nova_recipe", return_value=expected_args + ) as mock_nova_setup: + with patch( + "sagemaker.pytorch.estimator.PyTorch._recipe_load", return_value=("nova_recipe", recipe) + ): + mock_resolve_save.return_value = recipe + + pytorch = PyTorch( + training_recipe="nova_recipe", + role="arn:aws:iam::123456789012:role/SageMakerRole", + sagemaker_session=sagemaker_session, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE_GPU, + image_uri=IMAGE_URI, + framework_version="1.13.1", + py_version="py3", + ) + + # Check that the Nova recipe was correctly identified + assert pytorch.is_nova_recipe is True + + # Verify _setup_for_nova_recipe was called + mock_nova_setup.assert_called_once() + + # Verify that hyperparameters were set correctly for distillation + assert ( + pytorch._hyperparameters.get("distillation_data") + == "s3://mybucket/distillation-data" + ) + assert pytorch._hyperparameters.get("kms_key") == "alias/my-kms-key" + assert ( + pytorch._hyperparameters.get("role_arn") + == "arn:aws:iam::123456789012:role/SageMakerRole" + )