Skip to content

Feat: ReInvent 2024 Late Release #4948

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Dec 5, 2024
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ classifiers = [
]
dependencies = [
"attrs>=23.1.0,<24",
"boto3>=1.34.142,<2.0",
"boto3>=1.35.75,<2.0",
"cloudpickle==2.2.1",
"docker",
"fastapi",
Expand All @@ -49,7 +49,7 @@ dependencies = [
"psutil",
"PyYAML~=6.0",
"requests",
"sagemaker-core>=1.0.15,<2.0.0",
"sagemaker-core>=1.0.17,<2.0.0",
"schema",
"smdebug_rulesconfig==1.0.1",
"tblib>=1.7.0,<4",
Expand Down
18 changes: 14 additions & 4 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@
disable_output_compression: bool = False,
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
training_plan: Optional[Union[str, PipelineVariable]] = None,
**kwargs,
):
"""Initialize an ``EstimatorBase`` instance.
Expand Down Expand Up @@ -554,6 +555,8 @@
Specifies whether RemoteDebug is enabled for the training job.
enable_session_tag_chaining (bool or PipelineVariable): Optional.
Specifies whether SessionTagChaining is enabled for the training job.
training_plan (str or PipelineVariable): Optional.
Specifies which training plan arn to use for the training job
"""
instance_count = renamed_kwargs(
"train_instance_count", "instance_count", instance_count, kwargs
Expand Down Expand Up @@ -762,8 +765,7 @@

self.tensorboard_output_config = tensorboard_output_config

self.debugger_rule_configs = None
self.collection_configs = None
self.debugger_rule_configs, self.collection_configs = None, None

self.enable_sagemaker_metrics = enable_sagemaker_metrics

Expand All @@ -774,6 +776,7 @@
sagemaker_session=self.sagemaker_session,
)

self.profiler_rule_configs, self.profiler_rules = None, None
self.profiler_config = profiler_config
self.disable_profiler = resolve_value_from_config(
direct_input=disable_profiler,
Expand All @@ -796,8 +799,6 @@
) or _instance_type_supports_profiler(self.instance_type):
self.disable_profiler = True

self.profiler_rule_configs = None
self.profiler_rules = None
self.debugger_rules = None
self.disable_output_compression = disable_output_compression
validate_source_code_input_against_pipeline_variables(
Expand All @@ -807,6 +808,8 @@
enable_network_isolation=self._enable_network_isolation,
)

self.training_plan = training_plan

# Internal flag
self._is_output_path_set_from_default_bucket_and_prefix = False

Expand Down Expand Up @@ -1960,6 +1963,9 @@
"KeepAlivePeriodInSeconds"
]

if "TrainingPlanArn" in job_details["ResourceConfig"]:
init_params["training_plan"] = job_details["ResourceConfig"]["TrainingPlanArn"]

Check warning on line 1967 in src/sagemaker/estimator.py

View check run for this annotation

Codecov / codecov/patch

src/sagemaker/estimator.py#L1967

Added line #L1967 was not covered by tests

has_hps = "HyperParameters" in job_details
init_params["hyperparameters"] = job_details["HyperParameters"] if has_hps else {}

Expand Down Expand Up @@ -2840,6 +2846,7 @@
enable_infra_check: Optional[Union[bool, PipelineVariable]] = None,
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
training_plan: Optional[Union[str, PipelineVariable]] = None,
**kwargs,
):
"""Initialize an ``Estimator`` instance.
Expand Down Expand Up @@ -3205,6 +3212,8 @@
Specifies whether RemoteDebug is enabled for the training job
enable_session_tag_chaining (bool or PipelineVariable): Optional.
Specifies whether SessionTagChaining is enabled for the training job
training_plan (str or PipelineVariable): Optional.
Specifies which training plan arn to use for the training job
"""
self.image_uri = image_uri
self._hyperparameters = hyperparameters.copy() if hyperparameters else {}
Expand Down Expand Up @@ -3258,6 +3267,7 @@
disable_output_compression=disable_output_compression,
enable_remote_debug=enable_remote_debug,
enable_session_tag_chaining=enable_session_tag_chaining,
training_plan=training_plan,
**kwargs,
)

Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def _load_config(inputs, estimator, expand_role=True, validate_uri=True):
estimator.volume_size,
estimator.volume_kms_key,
estimator.keep_alive_period_in_seconds,
estimator.training_plan,
)
stop_condition = _Job._prepare_stop_condition(estimator.max_run, estimator.max_wait)
vpc_config = estimator.get_vpc_config()
Expand Down Expand Up @@ -294,6 +295,7 @@ def _prepare_resource_config(
volume_size,
volume_kms_key,
keep_alive_period_in_seconds,
training_plan,
):
"""Placeholder docstring"""
resource_config = {
Expand All @@ -319,6 +321,8 @@ def _prepare_resource_config(
)
resource_config["InstanceCount"] = instance_count
resource_config["InstanceType"] = instance_type
if training_plan is not None:
resource_config["TrainingPlanArn"] = training_plan

return resource_config

Expand Down
8 changes: 8 additions & 0 deletions src/sagemaker/jumpstart/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,12 @@ class VariableTypes(str, Enum):
BOOL = "bool"


class HubContentCapability(str, Enum):
    """Enum class for HubContent capabilities.

    Members also subclass ``str``, so each member compares equal to its raw
    string value.
    """

    BEDROCK_CONSOLE = "BEDROCK_CONSOLE"


class JumpStartTag(str, Enum):
"""Enum class for tag keys to apply to JumpStart models."""

Expand All @@ -99,6 +105,8 @@ class JumpStartTag(str, Enum):

HUB_CONTENT_ARN = "sagemaker-sdk:hub-content-arn"

BEDROCK = "sagemaker-sdk:bedrock"


class SerializerType(str, Enum):
"""Enum class for serializers associated with JumpStart models."""
Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/jumpstart/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def __init__(
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
config_name: Optional[str] = None,
enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
training_plan: Optional[Union[str, PipelineVariable]] = None,
):
"""Initializes a ``JumpStartEstimator``.

Expand Down Expand Up @@ -511,6 +512,8 @@ def __init__(
Name of the training configuration to apply to the Estimator. (Default: None).
enable_session_tag_chaining (bool or PipelineVariable): Optional.
Specifies whether SessionTagChaining is enabled for the training job
training_plan (str or PipelineVariable): Optional.
Specifies which training plan arn to use for the training job

Raises:
ValueError: If the model ID is not recognized by JumpStart.
Expand Down Expand Up @@ -599,6 +602,7 @@ def _validate_model_id_and_get_type_hook():
enable_remote_debug=enable_remote_debug,
config_name=config_name,
enable_session_tag_chaining=enable_session_tag_chaining,
training_plan=training_plan,
)

self.hub_arn = estimator_init_kwargs.hub_arn
Expand Down
2 changes: 2 additions & 0 deletions src/sagemaker/jumpstart/factory/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ def get_init_kwargs(
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
config_name: Optional[str] = None,
enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
training_plan: Optional[Union[str, PipelineVariable]] = None,
) -> JumpStartEstimatorInitKwargs:
"""Returns kwargs required to instantiate `sagemaker.estimator.Estimator` object."""

Expand Down Expand Up @@ -205,6 +206,7 @@ def get_init_kwargs(
enable_remote_debug=enable_remote_debug,
config_name=config_name,
enable_session_tag_chaining=enable_session_tag_chaining,
training_plan=training_plan,
)

estimator_init_kwargs, orig_session = _set_temp_sagemaker_session_if_not_set(
Expand Down
9 changes: 8 additions & 1 deletion src/sagemaker/jumpstart/factory/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from sagemaker.model_metrics import ModelMetrics
from sagemaker.metadata_properties import MetadataProperties
from sagemaker.drift_check_baselines import DriftCheckBaselines
from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType
from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType, HubContentCapability
from sagemaker.jumpstart.types import (
HubContentType,
JumpStartModelDeployKwargs,
Expand All @@ -53,6 +53,7 @@
from sagemaker.jumpstart.utils import (
add_hub_content_arn_tags,
add_jumpstart_model_info_tags,
add_bedrock_store_tags,
get_default_jumpstart_session_with_user_agent_suffix,
get_top_ranked_config_name,
update_dict_if_key_not_present,
Expand Down Expand Up @@ -495,6 +496,10 @@
)
kwargs.tags = add_hub_content_arn_tags(kwargs.tags, hub_content_arn=hub_content_arn)

if hasattr(kwargs.specs, "capabilities") and kwargs.specs.capabilities is not None:
if HubContentCapability.BEDROCK_CONSOLE in kwargs.specs.capabilities:
kwargs.tags = add_bedrock_store_tags(kwargs.tags, compatibility="compatible")

Check warning on line 501 in src/sagemaker/jumpstart/factory/model.py

View check run for this annotation

Codecov / codecov/patch

src/sagemaker/jumpstart/factory/model.py#L500-L501

Added lines #L500 - L501 were not covered by tests

return kwargs


Expand Down Expand Up @@ -657,6 +662,7 @@
config_name: Optional[str] = None,
routing_config: Optional[Dict[str, Any]] = None,
model_access_configs: Optional[Dict[str, ModelAccessConfig]] = None,
inference_ami_version: Optional[str] = None,
) -> JumpStartModelDeployKwargs:
"""Returns kwargs required to call `deploy` on `sagemaker.estimator.Model` object."""

Expand Down Expand Up @@ -694,6 +700,7 @@
config_name=config_name,
routing_config=routing_config,
model_access_configs=model_access_configs,
inference_ami_version=inference_ami_version,
)
deploy_kwargs, orig_session = _set_temp_sagemaker_session_if_not_set(kwargs=deploy_kwargs)
deploy_kwargs.specs = verify_model_region_and_return_specs(
Expand Down
3 changes: 3 additions & 0 deletions src/sagemaker/jumpstart/hub/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,7 @@ class HubModelDocument(HubDataHolderType):
"hosting_use_script_uri",
"hosting_eula_uri",
"hosting_model_package_arn",
"inference_ami_version",
"model_subscription_link",
"inference_configs",
"inference_config_components",
Expand Down Expand Up @@ -593,6 +594,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
self.hosting_eula_uri: Optional[str] = json_obj.get("HostingEulaUri")
self.hosting_model_package_arn: Optional[str] = json_obj.get("HostingModelPackageArn")

self.inference_ami_version: Optional[str] = json_obj.get("InferenceAmiVersion")

self.model_subscription_link: Optional[str] = json_obj.get("ModelSubscriptionLink")

self.inference_config_rankings = self._get_config_rankings(json_obj)
Expand Down
6 changes: 5 additions & 1 deletion src/sagemaker/jumpstart/hub/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,11 @@ def get_model_spec_arg_keys(
"""
arg_keys: List[str] = []
if arg_type == ModelSpecKwargType.DEPLOY:
arg_keys = ["ModelDataDownloadTimeout", "ContainerStartupHealthCheckTimeout"]
arg_keys = [
"ModelDataDownloadTimeout",
"ContainerStartupHealthCheckTimeout",
"InferenceAmiVersion",
]
elif arg_type == ModelSpecKwargType.ESTIMATOR:
arg_keys = [
"EncryptInterContainerTraffic",
Expand Down
2 changes: 2 additions & 0 deletions src/sagemaker/jumpstart/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,7 @@ def deploy(
endpoint_type: EndpointType = EndpointType.MODEL_BASED,
routing_config: Optional[Dict[str, Any]] = None,
model_access_configs: Optional[Dict[str, ModelAccessConfig]] = None,
inference_ami_version: Optional[str] = None,
) -> PredictorBase:
"""Creates endpoint by calling base ``Model`` class `deploy` method.

Expand Down Expand Up @@ -808,6 +809,7 @@ def deploy(
config_name=self.config_name,
routing_config=routing_config,
model_access_configs=model_access_configs,
inference_ami_version=inference_ami_version,
)
if (
self.model_type == JumpStartModelType.PROPRIETARY
Expand Down
7 changes: 7 additions & 0 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1389,6 +1389,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
self.hosting_model_package_arns: Optional[Dict] = (
model_package_arns if model_package_arns is not None else {}
)

self.hosting_use_script_uri: bool = json_obj.get("hosting_use_script_uri", True)

self.hosting_instance_type_variants: Optional[JumpStartInstanceTypeVariants] = (
Expand Down Expand Up @@ -2245,6 +2246,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs):
"routing_config",
"specs",
"model_access_configs",
"inference_ami_version",
]

SERIALIZATION_EXCLUSION_SET = {
Expand Down Expand Up @@ -2298,6 +2300,7 @@ def __init__(
config_name: Optional[str] = None,
routing_config: Optional[Dict[str, Any]] = None,
model_access_configs: Optional[Dict[str, CoreModelAccessConfig]] = None,
inference_ami_version: Optional[str] = None,
) -> None:
"""Instantiates JumpStartModelDeployKwargs object."""

Expand Down Expand Up @@ -2336,6 +2339,7 @@ def __init__(
self.config_name = config_name
self.routing_config = routing_config
self.model_access_configs = model_access_configs
self.inference_ami_version = inference_ami_version


class JumpStartEstimatorInitKwargs(JumpStartKwargs):
Expand Down Expand Up @@ -2402,6 +2406,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs):
"hub_content_type",
"model_reference_arn",
"specs",
"training_plan",
]

SERIALIZATION_EXCLUSION_SET = {
Expand Down Expand Up @@ -2475,6 +2480,7 @@ def __init__(
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
config_name: Optional[str] = None,
enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
training_plan: Optional[Union[str, PipelineVariable]] = None,
) -> None:
"""Instantiates JumpStartEstimatorInitKwargs object."""

Expand Down Expand Up @@ -2537,6 +2543,7 @@ def __init__(
self.enable_remote_debug = enable_remote_debug
self.config_name = config_name
self.enable_session_tag_chaining = enable_session_tag_chaining
self.training_plan = training_plan


class JumpStartEstimatorFitKwargs(JumpStartKwargs):
Expand Down
15 changes: 15 additions & 0 deletions src/sagemaker/jumpstart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,21 @@ def add_hub_content_arn_tags(
return tags


def add_bedrock_store_tags(
    tags: Optional[List[TagsDict]],
    compatibility: str,
) -> Optional[List[TagsDict]]:
    """Adds the Bedrock tag to JumpStart related resources.

    Args:
        tags (Optional[List[TagsDict]]): Existing tags to add the Bedrock
            tag to, if any.
        compatibility (str): Value to store under the
            ``JumpStartTag.BEDROCK`` tag key.

    Returns:
        Optional[List[TagsDict]]: The tags as returned by
            ``add_single_jumpstart_tag``.
    """
    # Delegate to the shared single-tag helper; the tag value is passed
    # with is_uri=False, exactly as before.
    return add_single_jumpstart_tag(
        compatibility,
        enums.JumpStartTag.BEDROCK,
        tags,
        is_uri=False,
    )


def add_jumpstart_uri_tags(
tags: Optional[List[TagsDict]] = None,
inference_model_uri: Optional[Union[str, dict]] = None,
Expand Down
2 changes: 2 additions & 0 deletions src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1383,6 +1383,7 @@ def deploy(
inference_component_name=None,
routing_config: Optional[Dict[str, Any]] = None,
model_reference_arn: Optional[str] = None,
inference_ami_version: Optional[str] = None,
**kwargs,
):
"""Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``.
Expand Down Expand Up @@ -1652,6 +1653,7 @@ def deploy(
container_startup_health_check_timeout=container_startup_health_check_timeout,
managed_instance_scaling=managed_instance_scaling_config,
routing_config=routing_config,
inference_ami_version=inference_ami_version,
)

self.sagemaker_session.endpoint_from_production_variants(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
"gpu_image" : {
"framework": "pytorch-smp",
"version": "2.4.1",
"additional_args": {}
"additional_args": {
"container_version": "cu121"
}
},
"neuron_image": {
"framework": "hyperpod-recipes-neuron",
Expand Down
4 changes: 3 additions & 1 deletion src/sagemaker/pytorch/training_recipes.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
"gpu_image" : {
"framework": "pytorch-smp",
"version": "2.4.1",
"additional_args": {}
"additional_args": {
"container_version": "cu121"
}
},
"neuron_image" : {
"framework": "hyperpod-recipes-neuron",
Expand Down
Loading
Loading