diff --git a/pyproject.toml b/pyproject.toml index 4657f41737..be05949d02 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ classifiers = [ ] dependencies = [ "attrs>=23.1.0,<24", - "boto3>=1.34.142,<2.0", + "boto3>=1.35.75,<2.0", "cloudpickle==2.2.1", "docker", "fastapi", @@ -49,7 +49,7 @@ dependencies = [ "psutil", "PyYAML~=6.0", "requests", - "sagemaker-core>=1.0.15,<2.0.0", + "sagemaker-core>=1.0.17,<2.0.0", "schema", "smdebug_rulesconfig==1.0.1", "tblib>=1.7.0,<4", diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index ea51a86101..6efc04c88e 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -185,6 +185,7 @@ def __init__( disable_output_compression: bool = False, enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None, enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None, + training_plan: Optional[Union[str, PipelineVariable]] = None, **kwargs, ): """Initialize an ``EstimatorBase`` instance. @@ -554,6 +555,8 @@ def __init__( Specifies whether RemoteDebug is enabled for the training job. enable_session_tag_chaining (bool or PipelineVariable): Optional. Specifies whether SessionTagChaining is enabled for the training job. + training_plan (str or PipelineVariable): Optional. + Specifies which training plan arn to use for the training job """ instance_count = renamed_kwargs( "train_instance_count", "instance_count", instance_count, kwargs @@ -762,8 +765,7 @@ def __init__( self.tensorboard_output_config = tensorboard_output_config - self.debugger_rule_configs = None - self.collection_configs = None + self.debugger_rule_configs, self.collection_configs = None, None self.enable_sagemaker_metrics = enable_sagemaker_metrics @@ -774,6 +776,7 @@ def __init__( sagemaker_session=self.sagemaker_session, ) + self.profiler_rule_configs, self.profiler_rules = None, None self.profiler_config = profiler_config self.disable_profiler = resolve_value_from_config( direct_input=disable_profiler, @@ -796,8 +799,6 @@ def __init__( ) or _instance_type_supports_profiler(self.instance_type): self.disable_profiler = True - self.profiler_rule_configs = None - self.profiler_rules = None self.debugger_rules = None self.disable_output_compression = disable_output_compression validate_source_code_input_against_pipeline_variables( @@ -807,6 +808,8 @@ def __init__( enable_network_isolation=self._enable_network_isolation, ) + self.training_plan = training_plan + # Internal flag self._is_output_path_set_from_default_bucket_and_prefix = False @@ -1960,6 +1963,9 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na "KeepAlivePeriodInSeconds" ] + if "TrainingPlanArn" in job_details["ResourceConfig"]: + init_params["training_plan"] = job_details["ResourceConfig"]["TrainingPlanArn"] + has_hps = "HyperParameters" in job_details init_params["hyperparameters"] = job_details["HyperParameters"] if has_hps else {} @@ -2840,6 +2846,7 @@ def __init__( enable_infra_check: Optional[Union[bool, PipelineVariable]] = None, enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None, enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None, + training_plan: Optional[Union[str, PipelineVariable]] = None, **kwargs, ): """Initialize an ``Estimator`` instance. @@ -3205,6 +3212,8 @@ def __init__( Specifies whether RemoteDebug is enabled for the training job enable_session_tag_chaining (bool or PipelineVariable): Optional. Specifies whether SessionTagChaining is enabled for the training job + training_plan (str or PipelineVariable): Optional. + Specifies which training plan arn to use for the training job """ self.image_uri = image_uri self._hyperparameters = hyperparameters.copy() if hyperparameters else {} @@ -3258,6 +3267,7 @@ def __init__( disable_output_compression=disable_output_compression, enable_remote_debug=enable_remote_debug, enable_session_tag_chaining=enable_session_tag_chaining, + training_plan=training_plan, **kwargs, ) diff --git a/src/sagemaker/job.py b/src/sagemaker/job.py index 7040c376ab..210dd426c5 100644 --- a/src/sagemaker/job.py +++ b/src/sagemaker/job.py @@ -83,6 +83,7 @@ def _load_config(inputs, estimator, expand_role=True, validate_uri=True): estimator.volume_size, estimator.volume_kms_key, estimator.keep_alive_period_in_seconds, + estimator.training_plan, ) stop_condition = _Job._prepare_stop_condition(estimator.max_run, estimator.max_wait) vpc_config = estimator.get_vpc_config() @@ -294,6 +295,7 @@ def _prepare_resource_config( volume_size, volume_kms_key, keep_alive_period_in_seconds, + training_plan, ): """Placeholder docstring""" resource_config = { @@ -319,6 +321,8 @@ def _prepare_resource_config( ) resource_config["InstanceCount"] = instance_count resource_config["InstanceType"] = instance_type + if training_plan is not None: + resource_config["TrainingPlanArn"] = training_plan return resource_config diff --git a/src/sagemaker/jumpstart/enums.py b/src/sagemaker/jumpstart/enums.py index a83964e394..91f547afb6 100644 --- a/src/sagemaker/jumpstart/enums.py +++ b/src/sagemaker/jumpstart/enums.py @@ -82,6 +82,12 @@ class VariableTypes(str, Enum): BOOL = "bool" +class HubContentCapability(str, Enum): + """Enum class for HubContent capabilities.""" + + BEDROCK_CONSOLE = "BEDROCK_CONSOLE" + + class JumpStartTag(str, Enum): """Enum class for tag keys to apply to JumpStart models.""" @@ -99,6 +105,8 @@ class JumpStartTag(str, Enum): HUB_CONTENT_ARN = "sagemaker-sdk:hub-content-arn" + BEDROCK = "sagemaker-sdk:bedrock" + class SerializerType(str, Enum): """Enum class for serializers associated with JumpStart models.""" diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index 8b30317a52..a41c9ed952 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -115,6 +115,7 @@ def __init__( enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None, config_name: Optional[str] = None, enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None, + training_plan: Optional[Union[str, PipelineVariable]] = None, ): """Initializes a ``JumpStartEstimator``. @@ -511,6 +512,8 @@ def __init__( Name of the training configuration to apply to the Estimator. (Default: None). enable_session_tag_chaining (bool or PipelineVariable): Optional. Specifies whether SessionTagChaining is enabled for the training job + training_plan (str or PipelineVariable): Optional. + Specifies which training plan arn to use for the training job Raises: ValueError: If the model ID is not recognized by JumpStart. @@ -599,6 +602,7 @@ def _validate_model_id_and_get_type_hook(): enable_remote_debug=enable_remote_debug, config_name=config_name, enable_session_tag_chaining=enable_session_tag_chaining, + training_plan=training_plan, ) self.hub_arn = estimator_init_kwargs.hub_arn diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 84c9d09c3d..e4020a39bd 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -144,6 +144,7 @@ def get_init_kwargs( enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None, config_name: Optional[str] = None, enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None, + training_plan: Optional[Union[str, PipelineVariable]] = None, ) -> JumpStartEstimatorInitKwargs: """Returns kwargs required to instantiate `sagemaker.estimator.Estimator` object.""" @@ -205,6 +206,7 @@ def get_init_kwargs( enable_remote_debug=enable_remote_debug, config_name=config_name, enable_session_tag_chaining=enable_session_tag_chaining, + training_plan=training_plan, ) estimator_init_kwargs, orig_session = _set_temp_sagemaker_session_if_not_set( diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 82bc1fc174..328e1e8227 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -42,7 +42,7 @@ from sagemaker.model_metrics import ModelMetrics from sagemaker.metadata_properties import MetadataProperties from sagemaker.drift_check_baselines import DriftCheckBaselines -from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType +from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType, HubContentCapability from sagemaker.jumpstart.types import ( HubContentType, JumpStartModelDeployKwargs, @@ -53,6 +53,7 @@ from sagemaker.jumpstart.utils import ( add_hub_content_arn_tags, add_jumpstart_model_info_tags, + add_bedrock_store_tags, get_default_jumpstart_session_with_user_agent_suffix, get_top_ranked_config_name, update_dict_if_key_not_present, @@ -495,6 +496,10 @@ def _add_tags_to_kwargs(kwargs: JumpStartModelDeployKwargs) -> Dict[str, Any]: ) kwargs.tags = add_hub_content_arn_tags(kwargs.tags, hub_content_arn=hub_content_arn) + if hasattr(kwargs.specs, "capabilities") and kwargs.specs.capabilities is not None: + if HubContentCapability.BEDROCK_CONSOLE in kwargs.specs.capabilities: + kwargs.tags = add_bedrock_store_tags(kwargs.tags, compatibility="compatible") + return kwargs @@ -657,6 +662,7 @@ def get_deploy_kwargs( config_name: Optional[str] = None, routing_config: Optional[Dict[str, Any]] = None, model_access_configs: Optional[Dict[str, ModelAccessConfig]] = None, + inference_ami_version: Optional[str] = None, ) -> JumpStartModelDeployKwargs: """Returns kwargs required to call `deploy` on `sagemaker.estimator.Model` object.""" @@ -694,6 +700,7 @@ def get_deploy_kwargs( config_name=config_name, routing_config=routing_config, model_access_configs=model_access_configs, + inference_ami_version=inference_ami_version, ) deploy_kwargs, orig_session = _set_temp_sagemaker_session_if_not_set(kwargs=deploy_kwargs) deploy_kwargs.specs = verify_model_region_and_return_specs( diff --git a/src/sagemaker/jumpstart/hub/interfaces.py b/src/sagemaker/jumpstart/hub/interfaces.py index 69a468a0b4..fd38868dcc 100644 --- a/src/sagemaker/jumpstart/hub/interfaces.py +++ b/src/sagemaker/jumpstart/hub/interfaces.py @@ -471,6 +471,7 @@ class HubModelDocument(HubDataHolderType): "hosting_use_script_uri", "hosting_eula_uri", "hosting_model_package_arn", + "inference_ami_version", "model_subscription_link", "inference_configs", "inference_config_components", @@ -593,6 +594,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: self.hosting_eula_uri: Optional[str] = json_obj.get("HostingEulaUri") self.hosting_model_package_arn: Optional[str] = json_obj.get("HostingModelPackageArn") + self.inference_ami_version: Optional[str] = json_obj.get("InferenceAmiVersion") + self.model_subscription_link: Optional[str] = json_obj.get("ModelSubscriptionLink") self.inference_config_rankings = self._get_config_rankings(json_obj) diff --git a/src/sagemaker/jumpstart/hub/parsers.py b/src/sagemaker/jumpstart/hub/parsers.py index 51da974217..01b6c5fe87 100644 --- a/src/sagemaker/jumpstart/hub/parsers.py +++ b/src/sagemaker/jumpstart/hub/parsers.py @@ -72,7 +72,11 @@ def get_model_spec_arg_keys( """ arg_keys: List[str] = [] if arg_type == ModelSpecKwargType.DEPLOY: - arg_keys = ["ModelDataDownloadTimeout", "ContainerStartupHealthCheckTimeout"] + arg_keys = [ + "ModelDataDownloadTimeout", + "ContainerStartupHealthCheckTimeout", + "InferenceAmiVersion", + ] elif arg_type == ModelSpecKwargType.ESTIMATOR: arg_keys = [ "EncryptInterContainerTraffic", diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index c173ae55ff..b0b54db557 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -666,6 +666,7 @@ def deploy( endpoint_type: EndpointType = EndpointType.MODEL_BASED, routing_config: Optional[Dict[str, Any]] = None, model_access_configs: Optional[Dict[str, ModelAccessConfig]] = None, + inference_ami_version: Optional[str] = None, ) -> PredictorBase: """Creates endpoint by calling base ``Model`` class `deploy` method. @@ -808,6 +809,7 @@ def deploy( config_name=self.config_name, routing_config=routing_config, model_access_configs=model_access_configs, + inference_ami_version=inference_ami_version, ) if ( self.model_type == JumpStartModelType.PROPRIETARY diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index cb989ca4d4..f59e2eddf4 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1389,6 +1389,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: self.hosting_model_package_arns: Optional[Dict] = ( model_package_arns if model_package_arns is not None else {} ) + self.hosting_use_script_uri: bool = json_obj.get("hosting_use_script_uri", True) self.hosting_instance_type_variants: Optional[JumpStartInstanceTypeVariants] = ( @@ -2245,6 +2246,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs): "routing_config", "specs", "model_access_configs", + "inference_ami_version", ] SERIALIZATION_EXCLUSION_SET = { @@ -2298,6 +2300,7 @@ def __init__( config_name: Optional[str] = None, routing_config: Optional[Dict[str, Any]] = None, model_access_configs: Optional[Dict[str, CoreModelAccessConfig]] = None, + inference_ami_version: Optional[str] = None, ) -> None: """Instantiates JumpStartModelDeployKwargs object.""" @@ -2336,6 +2339,7 @@ def __init__( self.config_name = config_name self.routing_config = routing_config self.model_access_configs = model_access_configs + self.inference_ami_version = inference_ami_version class JumpStartEstimatorInitKwargs(JumpStartKwargs): @@ -2402,6 +2406,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs): "hub_content_type", "model_reference_arn", "specs", + "training_plan", ] SERIALIZATION_EXCLUSION_SET = { @@ -2475,6 +2480,7 @@ def __init__( enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None, config_name: Optional[str] = None, enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None, + training_plan: Optional[Union[str, PipelineVariable]] = None, ) -> None: """Instantiates JumpStartEstimatorInitKwargs object.""" @@ -2537,6 +2543,7 @@ def __init__( self.enable_remote_debug = enable_remote_debug self.config_name = config_name self.enable_session_tag_chaining = enable_session_tag_chaining + self.training_plan = training_plan class JumpStartEstimatorFitKwargs(JumpStartKwargs): diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index d5c769efe0..46e5f8a847 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -455,6 +455,21 @@ def add_hub_content_arn_tags( return tags +def add_bedrock_store_tags( + tags: Optional[List[TagsDict]], + compatibility: str, +) -> Optional[List[TagsDict]]: + """Adds custom Hub arn tag to JumpStart related resources.""" + + tags = add_single_jumpstart_tag( + compatibility, + enums.JumpStartTag.BEDROCK, + tags, + is_uri=False, + ) + return tags + + def add_jumpstart_uri_tags( tags: Optional[List[TagsDict]] = None, inference_model_uri: Optional[Union[str, dict]] = None, diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index b78a4a2a64..863bbf376c 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -1383,6 +1383,7 @@ def deploy( inference_component_name=None, routing_config: Optional[Dict[str, Any]] = None, model_reference_arn: Optional[str] = None, + inference_ami_version: Optional[str] = None, **kwargs, ): """Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``. @@ -1652,6 +1653,7 @@ def deploy( container_startup_health_check_timeout=container_startup_health_check_timeout, managed_instance_scaling=managed_instance_scaling_config, routing_config=routing_config, + inference_ami_version=inference_ami_version, ) self.sagemaker_session.endpoint_from_production_variants( diff --git a/src/sagemaker/modules/train/sm_recipes/training_recipes.json b/src/sagemaker/modules/train/sm_recipes/training_recipes.json index 400e13f08a..a51513f49f 100644 --- a/src/sagemaker/modules/train/sm_recipes/training_recipes.json +++ b/src/sagemaker/modules/train/sm_recipes/training_recipes.json @@ -5,7 +5,9 @@ "gpu_image" : { "framework": "pytorch-smp", "version": "2.4.1", - "additional_args": {} + "additional_args": { + "container_version": "cu121" + } }, "neuron_image": { "framework": "hyperpod-recipes-neuron", diff --git a/src/sagemaker/pytorch/training_recipes.json b/src/sagemaker/pytorch/training_recipes.json index df60f95df9..5aeccce5a1 100644 --- a/src/sagemaker/pytorch/training_recipes.json +++ b/src/sagemaker/pytorch/training_recipes.json @@ -5,7 +5,9 @@ "gpu_image" : { "framework": "pytorch-smp", "version": "2.4.1", - "additional_args": {} + "additional_args": { + "container_version": "cu121" + } }, "neuron_image" : { "framework": "hyperpod-recipes-neuron", diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index bbc2c81904..04a7326557 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -2470,6 +2470,75 @@ def describe_training_job(self, job_name): """ return self.sagemaker_client.describe_training_job(TrainingJobName=job_name) + def describe_training_plan(self, training_plan_name): + """Calls the DescribeTrainingPlan API for the given training plan and returns the response. + + Args: + training_plan_name (str): The name of the training plan to describe. + + Returns: + dict: A dictionary response with the training plan description. + """ + return self.sagemaker_client.describe_training_plan(TrainingPlanName=training_plan_name) + + def list_training_plans( + self, + filters=None, + requested_start_time_after=None, + requested_start_time_before=None, + start_time_after=None, + start_time_before=None, + sort_order=None, + sort_by=None, + max_results=None, + next_token=None, + ): + """Calls the ListrTrainingPlan API for the given filters and returns the response. + + Args: + filters (dict): A dictionary of key-value pairs used to filter the training plans. + Default to None. + requested_start_time_after (datetime): A timestamp that filters the results + to only include training plans with a requested start time after this timestamp. + requested_start_time_before (datetime): A timestamp that filters the results + to only include training plans with a requested start time before this timestamp. + start_time_after (datetime): A timestamp that filters the results + to only include training plans with an actual start time after this timestamp. + start_time_before (datetime): A timestamp that filters the results + to only include training plans with an actual start time before this timestamp. + sort_order (str): The order that the training plans will be listed in result. + Default to None. + sort_by (str): The value that the training plans will be sorted by. + Default to None. + max_results (int): The number of candidates will be listed in results, + between 1 and 100. Default to None. If None, will return all the training_plans. + next_token (str): The pagination token. Default to None. + + Returns: + dict: A dictionary containing the following keys: + - "TrainingPlanSummaries": A list of dictionaries, where each dictionary represents + a training plan. + - "NextToken": A token to retrieve the next set of results, if there are more + than the maximum number of results returned. + """ + list_training_plan_args = {} + + def check_object(key, value): + if value is not None: + list_training_plan_args[key] = value + + check_object("Filters", filters) + check_object("SortBy", sort_by) + check_object("SortOrder", sort_order) + check_object("RequestedStartTimeAfter", requested_start_time_after) + check_object("RequestedStartTimeBefore", requested_start_time_before) + check_object("StartTimeAfter", start_time_after) + check_object("StartTimeBefore", start_time_before) + check_object("NextToken", next_token) + check_object("MaxResults", max_results) + + return self.sagemaker_client.list_training_plans(**list_training_plan_args) + def auto_ml( self, input_config, @@ -7735,6 +7804,7 @@ def production_variant( container_startup_health_check_timeout=None, managed_instance_scaling=None, routing_config=None, + inference_ami_version=None, ): """Create a production variant description suitable for use in a ``ProductionVariant`` list. @@ -7799,6 +7869,9 @@ def production_variant( RoutingConfig=routing_config, ) + if inference_ami_version: + production_variant_configuration["InferenceAmiVersion"] = inference_ami_version + return production_variant_configuration diff --git a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py index c3dd9c96fb..b938f489df 100644 --- a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py +++ b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py @@ -418,3 +418,149 @@ def test_jumpstart_session_with_config_name(): "md/js_model_id#meta-textgeneration-llama-2-7b md/js_model_ver#* md/js_config#tgi" in mock_make_request.call_args[0][1]["headers"]["User-Agent"] ) + + +def _setup_test_hub_with_reference(public_hub_model_id: str): + session = get_sm_session() + + try: + session.create_hub( + hub_name=TEST_HUB_WITH_REFERENCE, + hub_description="this is my sagemaker hub", + hub_display_name="Mock Hub", + hub_search_keywords=["mock", "hub", "123"], + s3_storage_config={"S3OutputPath": "s3://my-hub-bucket/"}, + tags=[{"Key": "tag-key-1", "Value": "tag-value-1"}], + ) + except Exception as e: + if "ResourceInUse" in str(e): + print("Hub already exists") + else: + raise e + + try: + session.create_hub_content_reference( + hub_name=TEST_HUB_WITH_REFERENCE, + source_hub_content_arn=( + f"arn:aws:sagemaker:{session.boto_region_name}:aws:" + f"hub-content/SageMakerPublicHub/Model/{public_hub_model_id}" + ), + ) + except Exception as e: + if "ResourceInUse" in str(e): + print("Reference already exists") + else: + raise e + + +def _teardown_test_hub_with_reference(public_hub_model_id: str): + session = get_sm_session() + + try: + session.delete_hub_content_reference( + hub_name=TEST_HUB_WITH_REFERENCE, + hub_content_type="ModelReference", + hub_content_name=public_hub_model_id, + ) + except Exception as e: + if "ResourceInUse" in str(e): + print("Reference already exists") + else: + raise e + + try: + session.delete_hub(hub_name=TEST_HUB_WITH_REFERENCE) + except Exception as e: + if "ResourceInUse" in str(e): + print("Hub already exists") + else: + raise e + + +# Currently JumpStartModel does not pull from HubService for the Public Hub. +def test_model_reference_marketplace_model(setup): + session = get_sm_session() + + # TODO: hardcoded model ID is brittle - should be dynamic pull via ListHubContents + public_hub_marketplace_model_id = "upstage-solar-mini-chat" + _setup_test_hub_with_reference(public_hub_marketplace_model_id) + + JumpStartModel( # Retrieving MP model None -> defaults to latest SemVer + model_id=public_hub_marketplace_model_id, + hub_name=TEST_HUB_WITH_REFERENCE, + role=session.get_caller_identity_arn(), + sagemaker_session=session, + ) + + model_semver = JumpStartModel( # Retrieving MP model SemVer -> uses SemVer + model_id=public_hub_marketplace_model_id, + hub_name=TEST_HUB_WITH_REFERENCE, + role=session.get_caller_identity_arn(), + sagemaker_session=session, + model_version="1.0.0", + ) + + model_marketplace_version = JumpStartModel( # Retrieving MP model MP version -> uses MPver + model_id=public_hub_marketplace_model_id, + hub_name=TEST_HUB_WITH_REFERENCE, + role=session.get_caller_identity_arn(), + sagemaker_session=session, + model_version="240612.5", + ) + + _teardown_test_hub_with_reference(public_hub_marketplace_model_id) # Cleanup before assertions + + assert model_semver.model_version == model_marketplace_version.model_version + + +# TODO: PySDK test account not subscribed to this model +# def test_model_reference_marketplace_model_deployment(setup): +# session = get_sm_session() +# public_hub_marketplace_model_id = "upstage-solar-mini-chat" +# _setup_test_hub_with_reference(public_hub_marketplace_model_id) + +# marketplace_model = JumpStartModel( # Retrieving MP model MP version -> uses MPver +# model_id=public_hub_marketplace_model_id, +# hub_name=TEST_HUB_WITH_REFERENCE, +# role=session.get_caller_identity_arn(), +# sagemaker_session=session, +# model_version="240612.5", +# ) +# predictor = marketplace_model.deploy( +# tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], +# accept_eula=True, +# ) + +# predictor.delete_predictor() +# _teardown_test_hub_with_reference(public_hub_marketplace_model_id) + + +def test_bedrock_store_model_tags_from_hub_service(setup): + + session = get_sm_session() + brs_model_id = "huggingface-llm-gemma-2b-instruct" + _setup_test_hub_with_reference(brs_model_id) + + brs_model = JumpStartModel( + model_id=brs_model_id, + hub_name=TEST_HUB_WITH_REFERENCE, + role=session.get_caller_identity_arn(), + sagemaker_session=session, + ) + + predictor = brs_model.deploy( + tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], + accept_eula=True, + ) + + endpoint_arn = ( + f"arn:aws:sagemaker:{session.boto_region_name}:" + f"{session.account_id()}:endpoint/{predictor.endpoint_name}" + ) + tags = session.list_tags(endpoint_arn) + + predictor.delete_predictor() # Cleanup before assertions + _teardown_test_hub_with_reference(brs_model_id) + + expected_tag = {"Key": "sagemaker-sdk:bedrock", "Value": "compatible"} + assert expected_tag in tags diff --git a/tests/integ/sagemaker/jumpstart/utils.py b/tests/integ/sagemaker/jumpstart/utils.py index 5898b4b2a8..d439ef7e95 100644 --- a/tests/integ/sagemaker/jumpstart/utils.py +++ b/tests/integ/sagemaker/jumpstart/utils.py @@ -53,6 +53,20 @@ def get_sm_session() -> Session: return Session(boto_session=boto3.Session(region_name=JUMPSTART_DEFAULT_REGION_NAME)) +def get_sm_session_with_override() -> Session: + # [TODO]: Remove service endpoint override before GA + # boto3.set_stream_logger(name='botocore', level=logging.DEBUG) + boto_session = boto3.Session(region_name="us-west-2") + sagemaker = boto3.client( + service_name="sagemaker", + endpoint_url="https://sagemaker.gamma.us-west-2.ml-platform.aws.a2z.com", + ) + return Session( + boto_session=boto_session, + sagemaker_client=sagemaker, + ) + + def get_training_dataset_for_model_and_version(model_id: str, version: str) -> dict: return TRAINING_DATASET_MODEL_DICT[(model_id, version)] diff --git a/tests/unit/sagemaker/modules/train/test_model_trainer.py b/tests/unit/sagemaker/modules/train/test_model_trainer.py index 049ebaa9c4..093da20ab8 100644 --- a/tests/unit/sagemaker/modules/train/test_model_trainer.py +++ b/tests/unit/sagemaker/modules/train/test_model_trainer.py @@ -285,6 +285,7 @@ def test_train_with_intelligent_defaults_training_job_space( volume_kms_key_id=None, keep_alive_period_in_seconds=None, instance_groups=None, + training_plan_arn=None, ), vpc_config=None, session=ANY, @@ -825,6 +826,7 @@ def mock_upload_data(path, bucket, key_prefix): volume_kms_key_id=compute.volume_kms_key_id, keep_alive_period_in_seconds=compute.keep_alive_period_in_seconds, instance_groups=None, + training_plan_arn=None, ), vpc_config=VpcConfig( security_group_ids=networking.security_group_ids, diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 0bc84d29d0..8294eb0039 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -89,6 +89,7 @@ INSTANCE_COUNT = 1 INSTANCE_TYPE = "c4.4xlarge" KEEP_ALIVE_PERIOD_IN_SECONDS = 1800 +TRAINING_PLAN = "arn:aws:sagemaker:us-west-2:336:training-plan/test_training_plan" ACCELERATOR_TYPE = "ml.eia.medium" ROLE = "DummyRole" IMAGE_URI = "fakeimage" @@ -861,6 +862,23 @@ def test_framework_with_keep_alive_period(sagemaker_session): assert args["resource_config"]["KeepAlivePeriodInSeconds"] == KEEP_ALIVE_PERIOD_IN_SECONDS +def test_framework_with_training_plan(sagemaker_session): + f = DummyFramework( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + instance_groups=[ + InstanceGroup("group1", "ml.c4.xlarge", 1), + InstanceGroup("group2", "ml.m4.xlarge", 2), + ], + training_plan=TRAINING_PLAN, + ) + f.fit("s3://mydata") + sagemaker_session.train.assert_called_once() + _, args = sagemaker_session.train.call_args + assert args["resource_config"]["TrainingPlanArn"] == TRAINING_PLAN + + def test_framework_with_both_training_repository_config(sagemaker_session): f = DummyFramework( entry_point=SCRIPT_PATH, diff --git a/tests/unit/test_job.py b/tests/unit/test_job.py index 603b494e5a..c93a381c11 100644 --- a/tests/unit/test_job.py +++ b/tests/unit/test_job.py @@ -31,6 +31,7 @@ INSTANCE_COUNT = 1 INSTANCE_TYPE = "c4.4xlarge" KEEP_ALIVE_PERIOD = 1800 +TRAINING_PLAN = "arn:aws:sagemaker:us-west-2:336:training-plan/test_training_plan" INSTANCE_GROUP = InstanceGroup("group", "ml.c4.xlarge", 1) VOLUME_SIZE = 1 MAX_RUNTIME = 1 @@ -633,7 +634,13 @@ def test_prepare_output_config_kms_key_none(): def test_prepare_resource_config(): resource_config = _Job._prepare_resource_config( - INSTANCE_COUNT, INSTANCE_TYPE, None, VOLUME_SIZE, None, None + INSTANCE_COUNT, + INSTANCE_TYPE, + None, + VOLUME_SIZE, + None, + None, + None, ) assert resource_config == { @@ -643,9 +650,35 @@ def test_prepare_resource_config(): } +def test_prepare_resource_config_with_training_plan(): + resource_config = _Job._prepare_resource_config( + INSTANCE_COUNT, + INSTANCE_TYPE, + None, + VOLUME_SIZE, + VOLUME_KMS_KEY, + None, + TRAINING_PLAN, + ) + + assert resource_config == { + "InstanceCount": INSTANCE_COUNT, + "InstanceType": INSTANCE_TYPE, + "VolumeSizeInGB": VOLUME_SIZE, + "VolumeKmsKeyId": VOLUME_KMS_KEY, + "TrainingPlanArn": TRAINING_PLAN, + } + + def test_prepare_resource_config_with_keep_alive_period(): resource_config = _Job._prepare_resource_config( - INSTANCE_COUNT, INSTANCE_TYPE, None, VOLUME_SIZE, VOLUME_KMS_KEY, KEEP_ALIVE_PERIOD + INSTANCE_COUNT, + INSTANCE_TYPE, + None, + VOLUME_SIZE, + VOLUME_KMS_KEY, + KEEP_ALIVE_PERIOD, + None, ) assert resource_config == { @@ -659,7 +692,13 @@ def test_prepare_resource_config_with_keep_alive_period(): def test_prepare_resource_config_with_volume_kms(): resource_config = _Job._prepare_resource_config( - INSTANCE_COUNT, INSTANCE_TYPE, None, VOLUME_SIZE, VOLUME_KMS_KEY, None + INSTANCE_COUNT, + INSTANCE_TYPE, + None, + VOLUME_SIZE, + VOLUME_KMS_KEY, + None, + None, ) assert resource_config == { @@ -678,6 +717,7 @@ def test_prepare_resource_config_with_heterogeneous_cluster(): VOLUME_SIZE, None, None, + None, ) assert resource_config == { @@ -698,6 +738,7 @@ def test_prepare_resource_config_with_instance_groups_instance_type_instance_cou VOLUME_SIZE, None, None, + None, ) assert "instance_count and instance_type cannot be set when instance_groups is set" in str( error @@ -713,6 +754,7 @@ def test_prepare_resource_config_with_instance_groups_instance_type_instance_cou VOLUME_SIZE, None, None, + None, ) assert "instance_count and instance_type must be set if instance_groups is not set" in str( error