Skip to content
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ classifiers = [
]
dependencies = [
"attrs>=23.1.0,<24",
"boto3>=1.34.142,<2.0",
"boto3>=1.35.75,<2.0",
"cloudpickle==2.2.1",
"docker",
"fastapi",
Expand All @@ -49,7 +49,7 @@ dependencies = [
"psutil",
"PyYAML~=6.0",
"requests",
"sagemaker-core>=1.0.15,<2.0.0",
"sagemaker-core>=1.0.17,<2.0.0",
"schema",
"smdebug_rulesconfig==1.0.1",
"tblib>=1.7.0,<4",
Expand Down
18 changes: 14 additions & 4 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@
disable_output_compression: bool = False,
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
training_plan: Optional[Union[str, PipelineVariable]] = None,
**kwargs,
):
"""Initialize an ``EstimatorBase`` instance.
Expand Down Expand Up @@ -554,6 +555,8 @@
Specifies whether RemoteDebug is enabled for the training job.
enable_session_tag_chaining (bool or PipelineVariable): Optional.
Specifies whether SessionTagChaining is enabled for the training job.
training_plan (str or PipelineVariable): Optional.
Specifies which training plan arn to use for the training job
"""
instance_count = renamed_kwargs(
"train_instance_count", "instance_count", instance_count, kwargs
Expand Down Expand Up @@ -762,8 +765,7 @@

self.tensorboard_output_config = tensorboard_output_config

self.debugger_rule_configs = None
self.collection_configs = None
self.debugger_rule_configs, self.collection_configs = None, None

self.enable_sagemaker_metrics = enable_sagemaker_metrics

Expand All @@ -774,6 +776,7 @@
sagemaker_session=self.sagemaker_session,
)

self.profiler_rule_configs, self.profiler_rules = None, None
self.profiler_config = profiler_config
self.disable_profiler = resolve_value_from_config(
direct_input=disable_profiler,
Expand All @@ -796,8 +799,6 @@
) or _instance_type_supports_profiler(self.instance_type):
self.disable_profiler = True

self.profiler_rule_configs = None
self.profiler_rules = None
self.debugger_rules = None
self.disable_output_compression = disable_output_compression
validate_source_code_input_against_pipeline_variables(
Expand All @@ -807,6 +808,8 @@
enable_network_isolation=self._enable_network_isolation,
)

self.training_plan = training_plan

# Internal flag
self._is_output_path_set_from_default_bucket_and_prefix = False

Expand Down Expand Up @@ -1960,6 +1963,9 @@
"KeepAlivePeriodInSeconds"
]

if "TrainingPlanArn" in job_details["ResourceConfig"]:
init_params["training_plan"] = job_details["ResourceConfig"]["TrainingPlanArn"]

Check warning on line 1967 in src/sagemaker/estimator.py

View check run for this annotation

Codecov / codecov/patch

src/sagemaker/estimator.py#L1967

Added line #L1967 was not covered by tests

has_hps = "HyperParameters" in job_details
init_params["hyperparameters"] = job_details["HyperParameters"] if has_hps else {}

Expand Down Expand Up @@ -2840,6 +2846,7 @@
enable_infra_check: Optional[Union[bool, PipelineVariable]] = None,
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
training_plan: Optional[Union[str, PipelineVariable]] = None,
**kwargs,
):
"""Initialize an ``Estimator`` instance.
Expand Down Expand Up @@ -3205,6 +3212,8 @@
Specifies whether RemoteDebug is enabled for the training job
enable_session_tag_chaining (bool or PipelineVariable): Optional.
Specifies whether SessionTagChaining is enabled for the training job
training_plan (str or PipelineVariable): Optional.
Specifies which training plan arn to use for the training job
"""
self.image_uri = image_uri
self._hyperparameters = hyperparameters.copy() if hyperparameters else {}
Expand Down Expand Up @@ -3258,6 +3267,7 @@
disable_output_compression=disable_output_compression,
enable_remote_debug=enable_remote_debug,
enable_session_tag_chaining=enable_session_tag_chaining,
training_plan=training_plan,
**kwargs,
)

Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def _load_config(inputs, estimator, expand_role=True, validate_uri=True):
estimator.volume_size,
estimator.volume_kms_key,
estimator.keep_alive_period_in_seconds,
estimator.training_plan,
)
stop_condition = _Job._prepare_stop_condition(estimator.max_run, estimator.max_wait)
vpc_config = estimator.get_vpc_config()
Expand Down Expand Up @@ -294,6 +295,7 @@ def _prepare_resource_config(
volume_size,
volume_kms_key,
keep_alive_period_in_seconds,
training_plan,
):
"""Placeholder docstring"""
resource_config = {
Expand All @@ -319,6 +321,8 @@ def _prepare_resource_config(
)
resource_config["InstanceCount"] = instance_count
resource_config["InstanceType"] = instance_type
if training_plan is not None:
resource_config["TrainingPlanArn"] = training_plan

return resource_config

Expand Down
8 changes: 8 additions & 0 deletions src/sagemaker/jumpstart/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,12 @@ class VariableTypes(str, Enum):
BOOL = "bool"


class HubContentCapability(str, Enum):
    """String-valued enum of capabilities that HubContent can declare."""

    # Each member's value is identical to its name, and because the enum
    # subclasses ``str`` a member compares equal to that plain string.
    BEDROCK_CONSOLE = "BEDROCK_CONSOLE"


class JumpStartTag(str, Enum):
"""Enum class for tag keys to apply to JumpStart models."""

Expand All @@ -99,6 +105,8 @@ class JumpStartTag(str, Enum):

HUB_CONTENT_ARN = "sagemaker-sdk:hub-content-arn"

BEDROCK = "sagemaker-sdk:bedrock"


class SerializerType(str, Enum):
"""Enum class for serializers associated with JumpStart models."""
Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/jumpstart/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def __init__(
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
config_name: Optional[str] = None,
enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
training_plan: Optional[Union[str, PipelineVariable]] = None,
):
"""Initializes a ``JumpStartEstimator``.

Expand Down Expand Up @@ -511,6 +512,8 @@ def __init__(
Name of the training configuration to apply to the Estimator. (Default: None).
enable_session_tag_chaining (bool or PipelineVariable): Optional.
Specifies whether SessionTagChaining is enabled for the training job
training_plan (str or PipelineVariable): Optional.
Specifies which training plan arn to use for the training job

Raises:
ValueError: If the model ID is not recognized by JumpStart.
Expand Down Expand Up @@ -599,6 +602,7 @@ def _validate_model_id_and_get_type_hook():
enable_remote_debug=enable_remote_debug,
config_name=config_name,
enable_session_tag_chaining=enable_session_tag_chaining,
training_plan=training_plan,
)

self.hub_arn = estimator_init_kwargs.hub_arn
Expand Down
2 changes: 2 additions & 0 deletions src/sagemaker/jumpstart/factory/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ def get_init_kwargs(
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
config_name: Optional[str] = None,
enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
training_plan: Optional[Union[str, PipelineVariable]] = None,
) -> JumpStartEstimatorInitKwargs:
"""Returns kwargs required to instantiate `sagemaker.estimator.Estimator` object."""

Expand Down Expand Up @@ -205,6 +206,7 @@ def get_init_kwargs(
enable_remote_debug=enable_remote_debug,
config_name=config_name,
enable_session_tag_chaining=enable_session_tag_chaining,
training_plan=training_plan,
)

estimator_init_kwargs, orig_session = _set_temp_sagemaker_session_if_not_set(
Expand Down
9 changes: 8 additions & 1 deletion src/sagemaker/jumpstart/factory/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from sagemaker.model_metrics import ModelMetrics
from sagemaker.metadata_properties import MetadataProperties
from sagemaker.drift_check_baselines import DriftCheckBaselines
from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType
from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType, HubContentCapability
from sagemaker.jumpstart.types import (
HubContentType,
JumpStartModelDeployKwargs,
Expand All @@ -53,6 +53,7 @@
from sagemaker.jumpstart.utils import (
add_hub_content_arn_tags,
add_jumpstart_model_info_tags,
add_bedrock_store_tags,
get_default_jumpstart_session_with_user_agent_suffix,
get_top_ranked_config_name,
update_dict_if_key_not_present,
Expand Down Expand Up @@ -495,6 +496,10 @@
)
kwargs.tags = add_hub_content_arn_tags(kwargs.tags, hub_content_arn=hub_content_arn)

if hasattr(kwargs.specs, "capabilities") and kwargs.specs.capabilities is not None:
if HubContentCapability.BEDROCK_CONSOLE in kwargs.specs.capabilities:
kwargs.tags = add_bedrock_store_tags(kwargs.tags, compatibility="compatible")

Check warning on line 501 in src/sagemaker/jumpstart/factory/model.py

View check run for this annotation

Codecov / codecov/patch

src/sagemaker/jumpstart/factory/model.py#L500-L501

Added lines #L500 - L501 were not covered by tests

return kwargs


Expand Down Expand Up @@ -657,6 +662,7 @@
config_name: Optional[str] = None,
routing_config: Optional[Dict[str, Any]] = None,
model_access_configs: Optional[Dict[str, ModelAccessConfig]] = None,
inference_ami_version: Optional[str] = None,
) -> JumpStartModelDeployKwargs:
"""Returns kwargs required to call `deploy` on `sagemaker.estimator.Model` object."""

Expand Down Expand Up @@ -694,6 +700,7 @@
config_name=config_name,
routing_config=routing_config,
model_access_configs=model_access_configs,
inference_ami_version=inference_ami_version,
)
deploy_kwargs, orig_session = _set_temp_sagemaker_session_if_not_set(kwargs=deploy_kwargs)
deploy_kwargs.specs = verify_model_region_and_return_specs(
Expand Down
3 changes: 3 additions & 0 deletions src/sagemaker/jumpstart/hub/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,7 @@ class HubModelDocument(HubDataHolderType):
"hosting_use_script_uri",
"hosting_eula_uri",
"hosting_model_package_arn",
"inference_ami_version",
"model_subscription_link",
"inference_configs",
"inference_config_components",
Expand Down Expand Up @@ -593,6 +594,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
self.hosting_eula_uri: Optional[str] = json_obj.get("HostingEulaUri")
self.hosting_model_package_arn: Optional[str] = json_obj.get("HostingModelPackageArn")

self.inference_ami_version: Optional[str] = json_obj.get("InferenceAmiVersion")

self.model_subscription_link: Optional[str] = json_obj.get("ModelSubscriptionLink")

self.inference_config_rankings = self._get_config_rankings(json_obj)
Expand Down
6 changes: 5 additions & 1 deletion src/sagemaker/jumpstart/hub/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,11 @@ def get_model_spec_arg_keys(
"""
arg_keys: List[str] = []
if arg_type == ModelSpecKwargType.DEPLOY:
arg_keys = ["ModelDataDownloadTimeout", "ContainerStartupHealthCheckTimeout"]
arg_keys = [
"ModelDataDownloadTimeout",
"ContainerStartupHealthCheckTimeout",
"InferenceAmiVersion",
]
elif arg_type == ModelSpecKwargType.ESTIMATOR:
arg_keys = [
"EncryptInterContainerTraffic",
Expand Down
2 changes: 2 additions & 0 deletions src/sagemaker/jumpstart/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,7 @@ def deploy(
endpoint_type: EndpointType = EndpointType.MODEL_BASED,
routing_config: Optional[Dict[str, Any]] = None,
model_access_configs: Optional[Dict[str, ModelAccessConfig]] = None,
inference_ami_version: Optional[str] = None,
) -> PredictorBase:
"""Creates endpoint by calling base ``Model`` class `deploy` method.

Expand Down Expand Up @@ -808,6 +809,7 @@ def deploy(
config_name=self.config_name,
routing_config=routing_config,
model_access_configs=model_access_configs,
inference_ami_version=inference_ami_version,
)
if (
self.model_type == JumpStartModelType.PROPRIETARY
Expand Down
7 changes: 7 additions & 0 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1389,6 +1389,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
self.hosting_model_package_arns: Optional[Dict] = (
model_package_arns if model_package_arns is not None else {}
)

self.hosting_use_script_uri: bool = json_obj.get("hosting_use_script_uri", True)

self.hosting_instance_type_variants: Optional[JumpStartInstanceTypeVariants] = (
Expand Down Expand Up @@ -2245,6 +2246,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs):
"routing_config",
"specs",
"model_access_configs",
"inference_ami_version",
]

SERIALIZATION_EXCLUSION_SET = {
Expand Down Expand Up @@ -2298,6 +2300,7 @@ def __init__(
config_name: Optional[str] = None,
routing_config: Optional[Dict[str, Any]] = None,
model_access_configs: Optional[Dict[str, CoreModelAccessConfig]] = None,
inference_ami_version: Optional[str] = None,
) -> None:
"""Instantiates JumpStartModelDeployKwargs object."""

Expand Down Expand Up @@ -2336,6 +2339,7 @@ def __init__(
self.config_name = config_name
self.routing_config = routing_config
self.model_access_configs = model_access_configs
self.inference_ami_version = inference_ami_version


class JumpStartEstimatorInitKwargs(JumpStartKwargs):
Expand Down Expand Up @@ -2402,6 +2406,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs):
"hub_content_type",
"model_reference_arn",
"specs",
"training_plan",
]

SERIALIZATION_EXCLUSION_SET = {
Expand Down Expand Up @@ -2475,6 +2480,7 @@ def __init__(
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
config_name: Optional[str] = None,
enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
training_plan: Optional[Union[str, PipelineVariable]] = None,
) -> None:
"""Instantiates JumpStartEstimatorInitKwargs object."""

Expand Down Expand Up @@ -2537,6 +2543,7 @@ def __init__(
self.enable_remote_debug = enable_remote_debug
self.config_name = config_name
self.enable_session_tag_chaining = enable_session_tag_chaining
self.training_plan = training_plan


class JumpStartEstimatorFitKwargs(JumpStartKwargs):
Expand Down
15 changes: 15 additions & 0 deletions src/sagemaker/jumpstart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,21 @@ def add_hub_content_arn_tags(
return tags


def add_bedrock_store_tags(
    tags: Optional[List[TagsDict]],
    compatibility: str,
) -> Optional[List[TagsDict]]:
    """Adds the Bedrock tag to JumpStart related resources.

    The tag is keyed by ``enums.JumpStartTag.BEDROCK`` and carries
    ``compatibility`` as its value. The work is delegated to
    ``add_single_jumpstart_tag``.

    Args:
        tags (Optional[List[TagsDict]]): Existing tags for the resource, or
            ``None`` if there are none yet.
        compatibility (str): Value to store under the Bedrock tag key.

    Returns:
        Optional[List[TagsDict]]: The tags returned by
        ``add_single_jumpstart_tag``.
    """

    # The value is a plain string rather than a URI, hence ``is_uri=False``.
    return add_single_jumpstart_tag(
        compatibility,
        enums.JumpStartTag.BEDROCK,
        tags,
        is_uri=False,
    )


def add_jumpstart_uri_tags(
tags: Optional[List[TagsDict]] = None,
inference_model_uri: Optional[Union[str, dict]] = None,
Expand Down
2 changes: 2 additions & 0 deletions src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1383,6 +1383,7 @@ def deploy(
inference_component_name=None,
routing_config: Optional[Dict[str, Any]] = None,
model_reference_arn: Optional[str] = None,
inference_ami_version: Optional[str] = None,
**kwargs,
):
"""Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``.
Expand Down Expand Up @@ -1652,6 +1653,7 @@ def deploy(
container_startup_health_check_timeout=container_startup_health_check_timeout,
managed_instance_scaling=managed_instance_scaling_config,
routing_config=routing_config,
inference_ami_version=inference_ami_version,
)

self.sagemaker_session.endpoint_from_production_variants(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
"gpu_image" : {
"framework": "pytorch-smp",
"version": "2.4.1",
"additional_args": {}
"additional_args": {
"container_version": "cu121"
}
},
"neuron_image": {
"framework": "hyperpod-recipes-neuron",
Expand Down
4 changes: 3 additions & 1 deletion src/sagemaker/pytorch/training_recipes.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
"gpu_image" : {
"framework": "pytorch-smp",
"version": "2.4.1",
"additional_args": {}
"additional_args": {
"container_version": "cu121"
}
},
"neuron_image" : {
"framework": "hyperpod-recipes-neuron",
Expand Down
Loading
Loading