From eb3f92cb7d4e68ac6064936a3789e0b30fa13ab6 Mon Sep 17 00:00:00 2001 From: chrstfu <105246221+chrstfu@users.noreply.github.com> Date: Fri, 25 Oct 2024 13:11:24 -0400 Subject: [PATCH 01/12] feat: Adding Bedrock Store model support for HubService (#1539) * feat: Adding BRS support --- src/sagemaker/jumpstart/enums.py | 8 + src/sagemaker/jumpstart/factory/model.py | 7 +- src/sagemaker/jumpstart/hub/parser_utils.py | 7 + src/sagemaker/jumpstart/hub/parsers.py | 6 +- src/sagemaker/jumpstart/hub/utils.py | 14 ++ src/sagemaker/jumpstart/types.py | 26 ++-- src/sagemaker/jumpstart/utils.py | 19 ++- .../jumpstart/model/test_jumpstart_model.py | 145 ++++++++++++++++++ tests/integ/sagemaker/jumpstart/utils.py | 17 ++ .../sagemaker/jumpstart/hub/test_utils.py | 4 +- 10 files changed, 232 insertions(+), 21 deletions(-) diff --git a/src/sagemaker/jumpstart/enums.py b/src/sagemaker/jumpstart/enums.py index a83964e394..91f547afb6 100644 --- a/src/sagemaker/jumpstart/enums.py +++ b/src/sagemaker/jumpstart/enums.py @@ -82,6 +82,12 @@ class VariableTypes(str, Enum): BOOL = "bool" +class HubContentCapability(str, Enum): + """Enum class for HubContent capabilities.""" + + BEDROCK_CONSOLE = "BEDROCK_CONSOLE" + + class JumpStartTag(str, Enum): """Enum class for tag keys to apply to JumpStart models.""" @@ -99,6 +105,8 @@ class JumpStartTag(str, Enum): HUB_CONTENT_ARN = "sagemaker-sdk:hub-content-arn" + BEDROCK = "sagemaker-sdk:bedrock" + class SerializerType(str, Enum): """Enum class for serializers associated with JumpStart models.""" diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 82bc1fc174..412c7066a6 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -42,7 +42,7 @@ from sagemaker.model_metrics import ModelMetrics from sagemaker.metadata_properties import MetadataProperties from sagemaker.drift_check_baselines import DriftCheckBaselines -from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType +from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType, HubContentCapability from sagemaker.jumpstart.types import ( HubContentType, JumpStartModelDeployKwargs, @@ -53,6 +53,7 @@ from sagemaker.jumpstart.utils import ( add_hub_content_arn_tags, add_jumpstart_model_info_tags, + add_bedrock_store_tags, get_default_jumpstart_session_with_user_agent_suffix, get_top_ranked_config_name, update_dict_if_key_not_present, @@ -495,6 +496,10 @@ def _add_tags_to_kwargs(kwargs: JumpStartModelDeployKwargs) -> Dict[str, Any]: ) kwargs.tags = add_hub_content_arn_tags(kwargs.tags, hub_content_arn=hub_content_arn) + if hasattr(kwargs.specs, "capabilities") and kwargs.specs.capabilities is not None: + if HubContentCapability.BEDROCK_CONSOLE in kwargs.specs.capabilities: + kwargs.tags = add_bedrock_store_tags(kwargs.tags, compatibility="compatible") + return kwargs diff --git a/src/sagemaker/jumpstart/hub/parser_utils.py b/src/sagemaker/jumpstart/hub/parser_utils.py index 0983122d09..08dabe1f78 100644 --- a/src/sagemaker/jumpstart/hub/parser_utils.py +++ b/src/sagemaker/jumpstart/hub/parser_utils.py @@ -18,10 +18,17 @@ from typing import Any, Dict, List, Optional +<<<<<<< HEAD def camel_to_snake(camel_case_string: str) -> str: """Converts camelCase to snake_case_string using a regex. This regex cannot handle whitespace ("camelString TwoWords") +======= +def pascal_to_snake(camel_case_string: str) -> str: + """Converts PascalCase to snake_case_string using a regex. + + This regex cannot handle whitespace ("PascalString TwoWords") +>>>>>>> ff3eae05 (feat: Adding Bedrock Store model support for HubService (#1539)) """ return re.sub(r"(?>>>>>> ff3eae05 (feat: Adding Bedrock Store model support for HubService (#1539)) try: return _get_hub_model_version_for_open_weight_version( hub_content_summaries, hub_model_version ) +<<<<<<< HEAD except KeyError: marketplace_hub_content_version = _get_hub_model_version_for_marketplace_version( hub_content_summaries, hub_model_version @@ -230,6 +238,12 @@ def get_hub_model_version( if marketplace_hub_content_version: return marketplace_hub_content_version raise +======= + except KeyError as e: + if marketplace_hub_content_version: + return marketplace_hub_content_version + raise e +>>>>>>> ff3eae05 (feat: Adding Bedrock Store model support for HubService (#1539)) def _get_hub_model_version_for_open_weight_version( diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index cb989ca4d4..0f49f23ee7 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -40,7 +40,7 @@ from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements from sagemaker.enums import EndpointType from sagemaker.jumpstart.hub.parser_utils import ( - camel_to_snake, + pascal_to_snake, walk_and_apply_json, ) from sagemaker.model_life_cycle import ModelLifeCycle @@ -241,7 +241,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: return if self._is_hub_content: - json_obj = walk_and_apply_json(json_obj, camel_to_snake) + json_obj = walk_and_apply_json(json_obj, pascal_to_snake) self.framework = json_obj.get("framework") self.framework_version = json_obj.get("framework_version") @@ -295,7 +295,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: """ if self._is_hub_content: - json_obj = walk_and_apply_json(json_obj, camel_to_snake) + json_obj = walk_and_apply_json(json_obj, pascal_to_snake) self.name = json_obj["name"] self.type = json_obj["type"] self.default = json_obj["default"] @@ -363,7 +363,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: Args: json_obj (Dict[str, Any]): Dictionary representation of environment variable. """ - json_obj = walk_and_apply_json(json_obj, camel_to_snake) + json_obj = walk_and_apply_json(json_obj, pascal_to_snake) self.name = json_obj["name"] self.type = json_obj["type"] self.default = json_obj["default"] @@ -413,7 +413,7 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: return if self._is_hub_content: - json_obj = walk_and_apply_json(json_obj, camel_to_snake) + json_obj = walk_and_apply_json(json_obj, pascal_to_snake) self.default_content_type = json_obj["default_content_type"] self.supported_content_types = json_obj["supported_content_types"] self.default_accept_type = json_obj["default_accept_type"] @@ -467,7 +467,7 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: return if self._is_hub_content: - json_obj = walk_and_apply_json(json_obj, camel_to_snake) + json_obj = walk_and_apply_json(json_obj, pascal_to_snake) self.raw_payload = json_obj self.content_type = json_obj["content_type"] self.body = json_obj.get("body") @@ -540,7 +540,7 @@ def from_describe_hub_content_response(self, response: Optional[Dict[str, Any]]) if response is None: return - response = walk_and_apply_json(response, camel_to_snake) + response = walk_and_apply_json(response, pascal_to_snake) self.aliases: Optional[dict] = response.get("aliases") self.regional_aliases = None self.variants: Optional[dict] = response.get("variants") @@ -1180,7 +1180,7 @@ def __init__(self, spec: Optional[Dict[str, Any]], is_hub_content=False): spec (Dict[str, Any]): Dictionary representation of training config ranking. """ if is_hub_content: - spec = walk_and_apply_json(spec, camel_to_snake) + spec = walk_and_apply_json(spec, pascal_to_snake) self.from_json(spec) def from_json(self, json_obj: Dict[str, Any]) -> None: @@ -1286,7 +1286,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: json_obj (Dict[str, Any]): Dictionary representation of spec. """ if self._is_hub_content: - json_obj = walk_and_apply_json(json_obj, camel_to_snake) + json_obj = walk_and_apply_json(json_obj, pascal_to_snake) self.model_id: str = json_obj.get("model_id") self.url: str = json_obj.get("url") self.version: str = json_obj.get("version") @@ -1515,7 +1515,7 @@ def __init__( ValueError: If the component field is invalid. """ if is_hub_content: - component = walk_and_apply_json(component, camel_to_snake) + component = walk_and_apply_json(component, pascal_to_snake) self.component_name = component_name super().__init__(component, is_hub_content) self.from_json(component) @@ -1568,8 +1568,8 @@ def __init__( The list of components that are used to construct the resolved config. """ if is_hub_content: - config = walk_and_apply_json(config, camel_to_snake) - base_fields = walk_and_apply_json(base_fields, camel_to_snake) + config = walk_and_apply_json(config, pascal_to_snake) + base_fields = walk_and_apply_json(base_fields, pascal_to_snake) self.base_fields = base_fields self.config_components: Dict[str, JumpStartConfigComponent] = config_components self.benchmark_metrics: Dict[str, List[JumpStartBenchmarkStat]] = ( @@ -1735,7 +1735,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: """ super().from_json(json_obj) if self._is_hub_content: - json_obj = walk_and_apply_json(json_obj, camel_to_snake) + json_obj = walk_and_apply_json(json_obj, pascal_to_snake) self.inference_config_components: Optional[Dict[str, JumpStartConfigComponent]] = ( { component_name: JumpStartConfigComponent(component_name, component) diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index d5c769efe0..e80028d00e 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -35,7 +35,7 @@ from sagemaker.jumpstart import constants, enums from sagemaker.jumpstart import accessors -from sagemaker.jumpstart.hub.parser_utils import camel_to_snake, snake_to_upper_camel +from sagemaker.jumpstart.hub.parser_utils import pascal_to_snake, snake_to_upper_camel from sagemaker.s3 import parse_s3_url from sagemaker.jumpstart.exceptions import ( DeprecatedJumpStartModelError, @@ -455,6 +455,21 @@ def add_hub_content_arn_tags( return tags +def add_bedrock_store_tags( + tags: Optional[List[TagsDict]], + compatibility: str, +) -> Optional[List[TagsDict]]: + """Adds custom Hub arn tag to JumpStart related resources.""" + + tags = add_single_jumpstart_tag( + compatibility, + enums.JumpStartTag.BEDROCK, + tags, + is_uri=False, + ) + return tags + + def add_jumpstart_uri_tags( tags: Optional[List[TagsDict]] = None, inference_model_uri: Optional[Union[str, dict]] = None, @@ -1163,7 +1178,7 @@ def get_jumpstart_configs( return ( { config_name: metadata_configs.configs[ - camel_to_snake(snake_to_upper_camel(config_name)) + pascal_to_snake(snake_to_upper_camel(config_name)) ] for config_name in config_names } diff --git a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py index c3dd9c96fb..3032640969 100644 --- a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py +++ b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py @@ -418,3 +418,148 @@ def test_jumpstart_session_with_config_name(): "md/js_model_id#meta-textgeneration-llama-2-7b md/js_model_ver#* md/js_config#tgi" in mock_make_request.call_args[0][1]["headers"]["User-Agent"] ) + + +def _setup_test_hub_with_reference(public_hub_model_id: str): + session = get_sm_session() + + try: + session.create_hub( + hub_name=TEST_HUB_WITH_REFERENCE, + hub_description="this is my sagemaker hub", + hub_display_name="Mock Hub", + hub_search_keywords=["mock", "hub", "123"], + s3_storage_config={"S3OutputPath": "s3://my-hub-bucket/"}, + tags=[{"Key": "tag-key-1", "Value": "tag-value-1"}], + ) + except Exception as e: + if "ResourceInUse" in str(e): + print("Hub already exists") + else: + raise e + + try: + session.create_hub_content_reference( + hub_name=TEST_HUB_WITH_REFERENCE, + source_hub_content_arn=( + f"arn:aws:sagemaker:{session.boto_region_name}:aws:" + f"hub-content/SageMakerPublicHub/Model/{public_hub_model_id}" + ), + ) + except Exception as e: + if "ResourceInUse" in str(e): + print("Reference already exists") + else: + raise e + + +def _teardown_test_hub_with_reference(public_hub_model_id: str): + session = get_sm_session() + + try: + session.delete_hub_content_reference( + hub_name=TEST_HUB_WITH_REFERENCE, + hub_content_type="ModelReference", + hub_content_name=public_hub_model_id, + ) + except Exception as e: + if "ResourceInUse" in str(e): + print("Reference already exists") + else: + raise e + + try: + session.delete_hub(hub_name=TEST_HUB_WITH_REFERENCE) + except Exception as e: + if "ResourceInUse" in str(e): + print("Hub already exists") + else: + raise e + + +# Currently JumpStartModel does not pull from HubService for the Public Hub. +def test_model_reference_marketplace_model(setup): + session = get_sm_session() + + # TODO: hardcoded model ID is brittle - should be dynamic pull via ListHubContents + public_hub_marketplace_model_id = "upstage-solar-mini-chat" + _setup_test_hub_with_reference(public_hub_marketplace_model_id) + + JumpStartModel( # Retrieving MP model None -> defaults to latest SemVer + model_id=public_hub_marketplace_model_id, + hub_name=TEST_HUB_WITH_REFERENCE, + role=session.get_caller_identity_arn(), + sagemaker_session=session, + ) + + model_semver = JumpStartModel( # Retrieving MP model SemVer -> uses SemVer + model_id=public_hub_marketplace_model_id, + hub_name=TEST_HUB_WITH_REFERENCE, + role=session.get_caller_identity_arn(), + sagemaker_session=session, + model_version="1.0.0", + ) + + model_marketplace_version = JumpStartModel( # Retrieving MP model MP version -> uses MPver + model_id=public_hub_marketplace_model_id, + hub_name=TEST_HUB_WITH_REFERENCE, + role=session.get_caller_identity_arn(), + sagemaker_session=session, + model_version="240612.5", + ) + + _teardown_test_hub_with_reference(public_hub_marketplace_model_id) # Cleanup before assertions + + assert model_semver.model_version == model_marketplace_version.model_version + + +def test_model_reference_marketplace_model_deployment(setup): + session = get_sm_session() + public_hub_marketplace_model_id = "upstage-solar-mini-chat" + _setup_test_hub_with_reference(public_hub_marketplace_model_id) + + marketplace_model = JumpStartModel( # Retrieving MP model MP version -> uses MPver + model_id=public_hub_marketplace_model_id, + hub_name=TEST_HUB_WITH_REFERENCE, + role=session.get_caller_identity_arn(), + sagemaker_session=session, + model_version="240612.5", + ) + predictor = marketplace_model.deploy( + tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], + accept_eula=True, + ) + + predictor.delete_predictor() + _teardown_test_hub_with_reference(public_hub_marketplace_model_id) + + +def test_bedrock_store_model_tags_from_hub_service(setup): + + session = get_sm_session() + brs_model_id = "huggingface-llm-gemma-2b-instruct" + _setup_test_hub_with_reference(brs_model_id) + + brs_model = JumpStartModel( + model_id=brs_model_id, + hub_name=TEST_HUB_WITH_REFERENCE, + role=session.get_caller_identity_arn(), + sagemaker_session=session, + ) + + predictor = brs_model.deploy( + tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], + accept_eula=True, + ) + + endpoint_arn = ( + f"arn:aws:sagemaker:{session.boto_region_name}:" + f"{session.account_id()}:endpoint/{predictor.endpoint_name}" + ) + tags = session.list_tags(endpoint_arn) + + predictor.delete_predictor() # Cleanup before assertions + _teardown_test_hub_with_reference(brs_model_id) + + expected_tag = {"Key": "sagemaker-sdk:bedrock", "Value": "compatible"} + assert expected_tag in tags diff --git a/tests/integ/sagemaker/jumpstart/utils.py b/tests/integ/sagemaker/jumpstart/utils.py index 5898b4b2a8..2ed7b6b78b 100644 --- a/tests/integ/sagemaker/jumpstart/utils.py +++ b/tests/integ/sagemaker/jumpstart/utils.py @@ -53,6 +53,23 @@ def get_sm_session() -> Session: return Session(boto_session=boto3.Session(region_name=JUMPSTART_DEFAULT_REGION_NAME)) +<<<<<<< HEAD +======= +def get_sm_session_with_override() -> Session: + # [TODO]: Remove service endpoint override before GA + # boto3.set_stream_logger(name='botocore', level=logging.DEBUG) + boto_session = boto3.Session(region_name="us-west-2") + sagemaker = boto3.client( + service_name="sagemaker", + endpoint_url="https://sagemaker.gamma.us-west-2.ml-platform.aws.a2z.com", + ) + return Session( + boto_session=boto_session, + sagemaker_client=sagemaker, + ) + + +>>>>>>> ff3eae05 (feat: Adding Bedrock Store model support for HubService (#1539)) def get_training_dataset_for_model_and_version(model_id: str, version: str) -> dict: return TRAINING_DATASET_MODEL_DICT[(model_id, version)] diff --git a/tests/unit/sagemaker/jumpstart/hub/test_utils.py b/tests/unit/sagemaker/jumpstart/hub/test_utils.py index 22bc527b18..0ef740398c 100644 --- a/tests/unit/sagemaker/jumpstart/hub/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/hub/test_utils.py @@ -268,7 +268,7 @@ def test_walk_and_apply_json(): } result = parser_utils.walk_and_apply_json( - test_json, parser_utils.camel_to_snake, ["ignore_my_children"] + test_json, parser_utils.pascal_to_snake, ["ignore_my_children"] ) assert result == { "camel_case_key": "value", @@ -294,7 +294,7 @@ def test_walk_and_apply_json_no_stop(): "CamelCaseObjectListKey": {"instance.ml.type.xlarge": [{"ShouldChangeMe": "string"}]}, } - result = parser_utils.walk_and_apply_json(test_json, parser_utils.camel_to_snake) + result = parser_utils.walk_and_apply_json(test_json, parser_utils.pascal_to_snake) assert result == { "camel_case_key": "value", "camel_case_object_key": { From 9a7b6defae7954d9afd3c3842c14e64d77b5a6db Mon Sep 17 00:00:00 2001 From: chrstfu <105246221+chrstfu@users.noreply.github.com> Date: Fri, 15 Nov 2024 18:34:55 -0500 Subject: [PATCH 02/12] chore: Merge from main (#1600) * prepare release v2.232.2 * update development version to v2.232.3.dev0 * fix: Use Miniforge to replace MambaForge (#4884) * Use Miniforge to replace MambaForge * Fix download url * tests: Implement integration tests covering JumpStart PrivateHub workflows (#4883) * tests: Implement integration tests covering JumpStart PrivateHub workflows * linting * formating * Only delete the pytest session specific test * change scope to session * address nits * Address test failures * address typo * address comments * resolve flake8 errors * implement throttle handling * flake8 * flake8 * Adding more assertions --------- Co-authored-by: malavhs * chore: add lmi image config in me-central-1 (#4887) * changes for PT 2.4 currency upgrade (#4890) Co-authored-by: Andrew Tian * chore(deps): bump pyspark from 3.3.1 to 3.3.2 in /requirements/extras (#4894) Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * update cloudpickle version to >=2.2.1 (#4891) * update cloudpickle version to >=2.2.1 in pyproject.toml (#4899) * Revert "update cloudpickle version to >=2.2.1 in pyproject.toml (#4899)" (#4902) * release: huggingface tgi neuronx 0.0.25 image (#4893) * release: huggingface tgi neuronx 0.0.25 image * tests * add tgi 2.3.1 * update default version * update py version * fix tests * test * Revert "fix tests" This reverts commit 9374c7518e97bd845f952f923d8784cbedb02f02. * Revert "Revert "fix tests"" This reverts commit 20d46d187a29638bcb6025a82af40e55e3331685. * Revert "test" This reverts commit 90f6e0b5253a17825a0bdb7d570640bbb274199f. * fix: fixing typo in dependecy setup (#4905) charset-normalizer is misplet in the requirements.txt files * Fix: Returning ModelPackage object on register of PipelineModel (#4878) Co-authored-by: Keshav Chandak * fix: modified pull request template (#4906) Added warning to spell check dependencies added and ensure they exist in PyPi * Fix: image_uris graviton image uri (#4909) * change: update image_uri_configs 09-27-2024 07:18:01 PST * change: update image_uri_configs 10-03-2024 07:17:59 PST * change: update image_uri_configs 10-17-2024 07:17:55 PST * change: update image_uri_configs 10-23-2024 11:26:03 PST * change: adding eu-central-2 bucket info to JS constants (#4907) * change: adding eu-central-2 bucket info to JS constants * change: adding zrh image uris for dji neuronx --------- Co-authored-by: shaernev * fix: Skip pytorch tests incompatible with latest version 2.4.0 (#4910) * change: update image_uri_configs 10-29-2024 07:17:56 PST * prepare release v2.232.3 * update development version to v2.232.4.dev0 * change: Updates for DJL 0.30.0 release (#4892) Co-authored-by: pintaoz-aws <167920275+pintaoz-aws@users.noreply.github.com> * bumping smp version from 2.6.0 to 2.6.1 (#4913) Co-authored-by: Andrew Tian Co-authored-by: nileshvd <113946607+nileshvd@users.noreply.github.com> * feat: Marketplace model support in HubService (#4916) * feat: Marketplace model support in HubService * fix: removing field * fix: Reverting name change for code coverage * fix: Adding more code coverage * fix: linting * fix: Fixing coverage tests * fix: Fixing integration tests * fix: Minor fixes * feat: triton v24.09 (#4908) * fix: Fixing JumpStart Tests (#4917) * fix: Fixing tests * fix: fixing test name * fix: dummy commit * fix: reverting dummy commit * fix: Removing flakey tests --------- Co-authored-by: nileshvd <113946607+nileshvd@users.noreply.github.com> * fix: merge * fix: Commenting out marketplace test * fix: Linting --------- Co-authored-by: ci Co-authored-by: pintaoz-aws <167920275+pintaoz-aws@users.noreply.github.com> Co-authored-by: Malav Shastri <57682969+malav-shastri@users.noreply.github.com> Co-authored-by: malavhs Co-authored-by: Haotian An <33510317+Captainia@users.noreply.github.com> Co-authored-by: adtian2 <55163384+adtian2@users.noreply.github.com> Co-authored-by: Andrew Tian Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: liujiaor <128006184+liujiaorr@users.noreply.github.com> Co-authored-by: ASHWIN KRISHNA Co-authored-by: Keshav Chandak Co-authored-by: Keshav Chandak Co-authored-by: Erick Benitez-Ramos <141277478+benieric@users.noreply.github.com> Co-authored-by: sagemaker-bot Co-authored-by: Ernev Sharma Co-authored-by: shaernev Co-authored-by: Tyler Osterberg Co-authored-by: nileshvd <113946607+nileshvd@users.noreply.github.com> Co-authored-by: Aaqib --- requirements/extras/test_requirements.txt | 3 ++ src/sagemaker/jumpstart/hub/parser_utils.py | 7 ++++ src/sagemaker/jumpstart/hub/parsers.py | 6 +-- src/sagemaker/jumpstart/hub/utils.py | 12 ++++++ src/sagemaker/jumpstart/types.py | 26 ++++++------- src/sagemaker/jumpstart/utils.py | 4 +- .../jumpstart/model/test_jumpstart_model.py | 39 ++++++++++--------- tests/integ/sagemaker/jumpstart/utils.py | 3 ++ .../sagemaker/jumpstart/hub/test_utils.py | 4 +- 9 files changed, 65 insertions(+), 39 deletions(-) diff --git a/requirements/extras/test_requirements.txt b/requirements/extras/test_requirements.txt index 9664a63e1d..cc8068e077 100644 --- a/requirements/extras/test_requirements.txt +++ b/requirements/extras/test_requirements.txt @@ -25,7 +25,10 @@ pyvis==0.2.1 pandas==1.4.4 scikit-learn==1.3.0 cloudpickle>=2.2.1 +<<<<<<< HEAD jsonpickle<4.0.0 +======= +>>>>>>> 42acb4f4 (chore: Merge from main (#1600)) PyYAML==6.0 # TODO find workaround xgboost>=1.6.2,<=1.7.6 diff --git a/src/sagemaker/jumpstart/hub/parser_utils.py b/src/sagemaker/jumpstart/hub/parser_utils.py index 08dabe1f78..259ba84c23 100644 --- a/src/sagemaker/jumpstart/hub/parser_utils.py +++ b/src/sagemaker/jumpstart/hub/parser_utils.py @@ -18,6 +18,7 @@ from typing import Any, Dict, List, Optional +<<<<<<< HEAD <<<<<<< HEAD def camel_to_snake(camel_case_string: str) -> str: """Converts camelCase to snake_case_string using a regex. @@ -29,6 +30,12 @@ def pascal_to_snake(camel_case_string: str) -> str: This regex cannot handle whitespace ("PascalString TwoWords") >>>>>>> ff3eae05 (feat: Adding Bedrock Store model support for HubService (#1539)) +======= +def camel_to_snake(camel_case_string: str) -> str: + """Converts camelCase to snake_case_string using a regex. + + This regex cannot handle whitespace ("camelString TwoWords") +>>>>>>> 42acb4f4 (chore: Merge from main (#1600)) """ return re.sub(r"(?>>>>>> ff3eae05 (feat: Adding Bedrock Store model support for HubService (#1539)) +======= +>>>>>>> 42acb4f4 (chore: Merge from main (#1600)) try: return _get_hub_model_version_for_open_weight_version( hub_content_summaries, hub_model_version ) <<<<<<< HEAD +<<<<<<< HEAD +======= +>>>>>>> 42acb4f4 (chore: Merge from main (#1600)) except KeyError: marketplace_hub_content_version = _get_hub_model_version_for_marketplace_version( hub_content_summaries, hub_model_version ) +<<<<<<< HEAD if marketplace_hub_content_version: return marketplace_hub_content_version raise @@ -244,6 +251,11 @@ def get_hub_model_version( return marketplace_hub_content_version raise e >>>>>>> ff3eae05 (feat: Adding Bedrock Store model support for HubService (#1539)) +======= + if marketplace_hub_content_version: + return marketplace_hub_content_version + raise +>>>>>>> 42acb4f4 (chore: Merge from main (#1600)) def _get_hub_model_version_for_open_weight_version( diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 0f49f23ee7..cb989ca4d4 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -40,7 +40,7 @@ from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements from sagemaker.enums import EndpointType from sagemaker.jumpstart.hub.parser_utils import ( - pascal_to_snake, + camel_to_snake, walk_and_apply_json, ) from sagemaker.model_life_cycle import ModelLifeCycle @@ -241,7 +241,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: return if self._is_hub_content: - json_obj = walk_and_apply_json(json_obj, pascal_to_snake) + json_obj = walk_and_apply_json(json_obj, camel_to_snake) self.framework = json_obj.get("framework") self.framework_version = json_obj.get("framework_version") @@ -295,7 +295,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: """ if self._is_hub_content: - json_obj = walk_and_apply_json(json_obj, pascal_to_snake) + json_obj = walk_and_apply_json(json_obj, camel_to_snake) self.name = json_obj["name"] self.type = json_obj["type"] self.default = json_obj["default"] @@ -363,7 +363,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: Args: json_obj (Dict[str, Any]): Dictionary representation of environment variable. """ - json_obj = walk_and_apply_json(json_obj, pascal_to_snake) + json_obj = walk_and_apply_json(json_obj, camel_to_snake) self.name = json_obj["name"] self.type = json_obj["type"] self.default = json_obj["default"] @@ -413,7 +413,7 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: return if self._is_hub_content: - json_obj = walk_and_apply_json(json_obj, pascal_to_snake) + json_obj = walk_and_apply_json(json_obj, camel_to_snake) self.default_content_type = json_obj["default_content_type"] self.supported_content_types = json_obj["supported_content_types"] self.default_accept_type = json_obj["default_accept_type"] @@ -467,7 +467,7 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: return if self._is_hub_content: - json_obj = walk_and_apply_json(json_obj, pascal_to_snake) + json_obj = walk_and_apply_json(json_obj, camel_to_snake) self.raw_payload = json_obj self.content_type = json_obj["content_type"] self.body = json_obj.get("body") @@ -540,7 +540,7 @@ def from_describe_hub_content_response(self, response: Optional[Dict[str, Any]]) if response is None: return - response = walk_and_apply_json(response, pascal_to_snake) + response = walk_and_apply_json(response, camel_to_snake) self.aliases: Optional[dict] = response.get("aliases") self.regional_aliases = None self.variants: Optional[dict] = response.get("variants") @@ -1180,7 +1180,7 @@ def __init__(self, spec: Optional[Dict[str, Any]], is_hub_content=False): spec (Dict[str, Any]): Dictionary representation of training config ranking. """ if is_hub_content: - spec = walk_and_apply_json(spec, pascal_to_snake) + spec = walk_and_apply_json(spec, camel_to_snake) self.from_json(spec) def from_json(self, json_obj: Dict[str, Any]) -> None: @@ -1286,7 +1286,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: json_obj (Dict[str, Any]): Dictionary representation of spec. """ if self._is_hub_content: - json_obj = walk_and_apply_json(json_obj, pascal_to_snake) + json_obj = walk_and_apply_json(json_obj, camel_to_snake) self.model_id: str = json_obj.get("model_id") self.url: str = json_obj.get("url") self.version: str = json_obj.get("version") @@ -1515,7 +1515,7 @@ def __init__( ValueError: If the component field is invalid. """ if is_hub_content: - component = walk_and_apply_json(component, pascal_to_snake) + component = walk_and_apply_json(component, camel_to_snake) self.component_name = component_name super().__init__(component, is_hub_content) self.from_json(component) @@ -1568,8 +1568,8 @@ def __init__( The list of components that are used to construct the resolved config. """ if is_hub_content: - config = walk_and_apply_json(config, pascal_to_snake) - base_fields = walk_and_apply_json(base_fields, pascal_to_snake) + config = walk_and_apply_json(config, camel_to_snake) + base_fields = walk_and_apply_json(base_fields, camel_to_snake) self.base_fields = base_fields self.config_components: Dict[str, JumpStartConfigComponent] = config_components self.benchmark_metrics: Dict[str, List[JumpStartBenchmarkStat]] = ( @@ -1735,7 +1735,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: """ super().from_json(json_obj) if self._is_hub_content: - json_obj = walk_and_apply_json(json_obj, pascal_to_snake) + json_obj = walk_and_apply_json(json_obj, camel_to_snake) self.inference_config_components: Optional[Dict[str, JumpStartConfigComponent]] = ( { component_name: JumpStartConfigComponent(component_name, component) diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index e80028d00e..46e5f8a847 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -35,7 +35,7 @@ from sagemaker.jumpstart import constants, enums from sagemaker.jumpstart import accessors -from sagemaker.jumpstart.hub.parser_utils import pascal_to_snake, snake_to_upper_camel +from sagemaker.jumpstart.hub.parser_utils import camel_to_snake, snake_to_upper_camel from sagemaker.s3 import parse_s3_url from sagemaker.jumpstart.exceptions import ( DeprecatedJumpStartModelError, @@ -1178,7 +1178,7 @@ def get_jumpstart_configs( return ( { config_name: metadata_configs.configs[ - pascal_to_snake(snake_to_upper_camel(config_name)) + camel_to_snake(snake_to_upper_camel(config_name)) ] for config_name in config_names } diff --git a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py index 3032640969..b938f489df 100644 --- a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py +++ b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py @@ -513,25 +513,26 @@ def test_model_reference_marketplace_model(setup): assert model_semver.model_version == model_marketplace_version.model_version -def test_model_reference_marketplace_model_deployment(setup): - session = get_sm_session() - public_hub_marketplace_model_id = "upstage-solar-mini-chat" - _setup_test_hub_with_reference(public_hub_marketplace_model_id) - - marketplace_model = JumpStartModel( # Retrieving MP model MP version -> uses MPver - model_id=public_hub_marketplace_model_id, - hub_name=TEST_HUB_WITH_REFERENCE, - role=session.get_caller_identity_arn(), - sagemaker_session=session, - model_version="240612.5", - ) - predictor = marketplace_model.deploy( - tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], - accept_eula=True, - ) - - predictor.delete_predictor() - _teardown_test_hub_with_reference(public_hub_marketplace_model_id) +# TODO: PySDK test account not subscribed to this model +# def test_model_reference_marketplace_model_deployment(setup): +# session = get_sm_session() +# public_hub_marketplace_model_id = "upstage-solar-mini-chat" +# _setup_test_hub_with_reference(public_hub_marketplace_model_id) + +# marketplace_model = JumpStartModel( # Retrieving MP model MP version -> uses MPver +# model_id=public_hub_marketplace_model_id, +# hub_name=TEST_HUB_WITH_REFERENCE, +# role=session.get_caller_identity_arn(), +# sagemaker_session=session, +# model_version="240612.5", +# ) +# predictor = marketplace_model.deploy( +# tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], +# accept_eula=True, +# ) + +# predictor.delete_predictor() +# _teardown_test_hub_with_reference(public_hub_marketplace_model_id) def test_bedrock_store_model_tags_from_hub_service(setup): diff --git a/tests/integ/sagemaker/jumpstart/utils.py b/tests/integ/sagemaker/jumpstart/utils.py index 2ed7b6b78b..881372f421 100644 --- a/tests/integ/sagemaker/jumpstart/utils.py +++ b/tests/integ/sagemaker/jumpstart/utils.py @@ -53,6 +53,7 @@ def get_sm_session() -> Session: return Session(boto_session=boto3.Session(region_name=JUMPSTART_DEFAULT_REGION_NAME)) +<<<<<<< HEAD <<<<<<< HEAD ======= def get_sm_session_with_override() -> Session: @@ -70,6 +71,8 @@ def get_sm_session_with_override() -> Session: >>>>>>> ff3eae05 (feat: Adding Bedrock Store model support for HubService (#1539)) +======= +>>>>>>> 42acb4f4 (chore: Merge from main (#1600)) def get_training_dataset_for_model_and_version(model_id: str, version: str) -> dict: return TRAINING_DATASET_MODEL_DICT[(model_id, version)] diff --git a/tests/unit/sagemaker/jumpstart/hub/test_utils.py b/tests/unit/sagemaker/jumpstart/hub/test_utils.py index 0ef740398c..22bc527b18 100644 --- a/tests/unit/sagemaker/jumpstart/hub/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/hub/test_utils.py @@ -268,7 +268,7 @@ def test_walk_and_apply_json(): } result = parser_utils.walk_and_apply_json( - test_json, parser_utils.pascal_to_snake, ["ignore_my_children"] + test_json, parser_utils.camel_to_snake, ["ignore_my_children"] ) assert result == { "camel_case_key": "value", @@ -294,7 +294,7 @@ def test_walk_and_apply_json_no_stop(): "CamelCaseObjectListKey": {"instance.ml.type.xlarge": [{"ShouldChangeMe": "string"}]}, } - result = parser_utils.walk_and_apply_json(test_json, parser_utils.pascal_to_snake) + result = parser_utils.walk_and_apply_json(test_json, parser_utils.camel_to_snake) assert result == { "camel_case_key": "value", "camel_case_object_key": { From 0201f2a7bc88c55cb1024d45df8422b6ffcfa0ed Mon Sep 17 00:00:00 2001 From: chrstfu <105246221+chrstfu@users.noreply.github.com> Date: Mon, 18 Nov 2024 17:39:30 -0500 Subject: [PATCH 03/12] feat: AMI support for BRM (#1589) * feat: AMI support for BRM --- src/sagemaker/jumpstart/factory/model.py | 2 ++ src/sagemaker/jumpstart/hub/interfaces.py | 3 +++ src/sagemaker/jumpstart/hub/parsers.py | 6 +++++- src/sagemaker/jumpstart/model.py | 2 ++ src/sagemaker/jumpstart/types.py | 4 ++++ src/sagemaker/model.py | 2 ++ src/sagemaker/session.py | 4 ++++ 7 files changed, 22 insertions(+), 1 deletion(-) diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 412c7066a6..328e1e8227 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -662,6 +662,7 @@ def get_deploy_kwargs( config_name: Optional[str] = None, routing_config: Optional[Dict[str, Any]] = None, model_access_configs: Optional[Dict[str, ModelAccessConfig]] = None, + inference_ami_version: Optional[str] = None, ) -> JumpStartModelDeployKwargs: """Returns kwargs required to call `deploy` on `sagemaker.estimator.Model` object.""" @@ -699,6 +700,7 @@ def get_deploy_kwargs( config_name=config_name, routing_config=routing_config, model_access_configs=model_access_configs, + inference_ami_version=inference_ami_version, ) deploy_kwargs, orig_session = _set_temp_sagemaker_session_if_not_set(kwargs=deploy_kwargs) deploy_kwargs.specs = verify_model_region_and_return_specs( diff --git a/src/sagemaker/jumpstart/hub/interfaces.py b/src/sagemaker/jumpstart/hub/interfaces.py index 69a468a0b4..fd38868dcc 100644 --- a/src/sagemaker/jumpstart/hub/interfaces.py +++ b/src/sagemaker/jumpstart/hub/interfaces.py @@ -471,6 +471,7 @@ class HubModelDocument(HubDataHolderType): "hosting_use_script_uri", "hosting_eula_uri", "hosting_model_package_arn", + "inference_ami_version", "model_subscription_link", "inference_configs", "inference_config_components", @@ -593,6 +594,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: self.hosting_eula_uri: Optional[str] = json_obj.get("HostingEulaUri") self.hosting_model_package_arn: Optional[str] = json_obj.get("HostingModelPackageArn") + self.inference_ami_version: Optional[str] = json_obj.get("InferenceAmiVersion") + self.model_subscription_link: Optional[str] = json_obj.get("ModelSubscriptionLink") self.inference_config_rankings = self._get_config_rankings(json_obj) diff --git a/src/sagemaker/jumpstart/hub/parsers.py b/src/sagemaker/jumpstart/hub/parsers.py index 51da974217..01b6c5fe87 100644 --- a/src/sagemaker/jumpstart/hub/parsers.py +++ b/src/sagemaker/jumpstart/hub/parsers.py @@ -72,7 +72,11 @@ def get_model_spec_arg_keys( """ arg_keys: List[str] = [] if arg_type == ModelSpecKwargType.DEPLOY: - arg_keys = ["ModelDataDownloadTimeout", "ContainerStartupHealthCheckTimeout"] + arg_keys = [ + "ModelDataDownloadTimeout", + "ContainerStartupHealthCheckTimeout", + "InferenceAmiVersion", + ] elif arg_type == ModelSpecKwargType.ESTIMATOR: arg_keys = [ "EncryptInterContainerTraffic", diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index c173ae55ff..b0b54db557 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -666,6 +666,7 @@ def deploy( endpoint_type: EndpointType = EndpointType.MODEL_BASED, routing_config: Optional[Dict[str, Any]] = None, model_access_configs: Optional[Dict[str, ModelAccessConfig]] = None, + inference_ami_version: Optional[str] = None, ) -> PredictorBase: """Creates endpoint by calling base ``Model`` class `deploy` method. @@ -808,6 +809,7 @@ def deploy( config_name=self.config_name, routing_config=routing_config, model_access_configs=model_access_configs, + inference_ami_version=inference_ami_version, ) if ( self.model_type == JumpStartModelType.PROPRIETARY diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index cb989ca4d4..f9502a3538 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1389,6 +1389,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: self.hosting_model_package_arns: Optional[Dict] = ( model_package_arns if model_package_arns is not None else {} ) + self.hosting_use_script_uri: bool = json_obj.get("hosting_use_script_uri", True) self.hosting_instance_type_variants: Optional[JumpStartInstanceTypeVariants] = ( @@ -2245,6 +2246,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs): "routing_config", "specs", "model_access_configs", + "inference_ami_version", ] SERIALIZATION_EXCLUSION_SET = { @@ -2298,6 +2300,7 @@ def __init__( config_name: Optional[str] = None, routing_config: Optional[Dict[str, Any]] = None, model_access_configs: Optional[Dict[str, CoreModelAccessConfig]] = None, + inference_ami_version: Optional[str] = None, ) -> None: """Instantiates JumpStartModelDeployKwargs object.""" @@ -2336,6 +2339,7 @@ def __init__( self.config_name = config_name self.routing_config = routing_config self.model_access_configs = model_access_configs + self.inference_ami_version = inference_ami_version class JumpStartEstimatorInitKwargs(JumpStartKwargs): diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index b78a4a2a64..863bbf376c 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -1383,6 +1383,7 @@ def deploy( inference_component_name=None, routing_config: Optional[Dict[str, Any]] = None, model_reference_arn: Optional[str] = None, + inference_ami_version: Optional[str] = None, **kwargs, ): """Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``. @@ -1652,6 +1653,7 @@ def deploy( container_startup_health_check_timeout=container_startup_health_check_timeout, managed_instance_scaling=managed_instance_scaling_config, routing_config=routing_config, + inference_ami_version=inference_ami_version, ) self.sagemaker_session.endpoint_from_production_variants( diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index bbc2c81904..de23472fcf 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -7735,6 +7735,7 @@ def production_variant( container_startup_health_check_timeout=None, managed_instance_scaling=None, routing_config=None, + inference_ami_version=None, ): """Create a production variant description suitable for use in a ``ProductionVariant`` list. @@ -7799,6 +7800,9 @@ def production_variant( RoutingConfig=routing_config, ) + if inference_ami_version: + production_variant_configuration["InferenceAmiVersion"] = inference_ami_version + return production_variant_configuration From 12b5d626cc624697234fb2dbab00b1d0b13a58ea Mon Sep 17 00:00:00 2001 From: Tritin Truong Date: Thu, 17 Oct 2024 16:30:24 -0700 Subject: [PATCH 04/12] feature: Support SageMakerTrainingPlan for training jobs (#1544) Co-authored-by: Tritin Truong --- src/sagemaker/estimator.py | 18 +++-- src/sagemaker/job.py | 4 ++ src/sagemaker/jumpstart/estimator.py | 4 ++ src/sagemaker/jumpstart/factory/estimator.py | 2 + src/sagemaker/jumpstart/types.py | 3 + src/sagemaker/session.py | 69 ++++++++++++++++++++ tests/unit/test_estimator.py | 18 +++++ tests/unit/test_job.py | 48 +++++++++++++- 8 files changed, 159 insertions(+), 7 deletions(-) diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index ea51a86101..6efc04c88e 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -185,6 +185,7 @@ def __init__( disable_output_compression: bool = False, enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None, enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None, + training_plan: Optional[Union[str, PipelineVariable]] = None, **kwargs, ): """Initialize an ``EstimatorBase`` instance. @@ -554,6 +555,8 @@ def __init__( Specifies whether RemoteDebug is enabled for the training job. enable_session_tag_chaining (bool or PipelineVariable): Optional. Specifies whether SessionTagChaining is enabled for the training job. + training_plan (str or PipelineVariable): Optional. + Specifies which training plan arn to use for the training job """ instance_count = renamed_kwargs( "train_instance_count", "instance_count", instance_count, kwargs @@ -762,8 +765,7 @@ def __init__( self.tensorboard_output_config = tensorboard_output_config - self.debugger_rule_configs = None - self.collection_configs = None + self.debugger_rule_configs, self.collection_configs = None, None self.enable_sagemaker_metrics = enable_sagemaker_metrics @@ -774,6 +776,7 @@ def __init__( sagemaker_session=self.sagemaker_session, ) + self.profiler_rule_configs, self.profiler_rules = None, None self.profiler_config = profiler_config self.disable_profiler = resolve_value_from_config( direct_input=disable_profiler, @@ -796,8 +799,6 @@ def __init__( ) or _instance_type_supports_profiler(self.instance_type): self.disable_profiler = True - self.profiler_rule_configs = None - self.profiler_rules = None self.debugger_rules = None self.disable_output_compression = disable_output_compression validate_source_code_input_against_pipeline_variables( @@ -807,6 +808,8 @@ def __init__( enable_network_isolation=self._enable_network_isolation, ) + self.training_plan = training_plan + # Internal flag self._is_output_path_set_from_default_bucket_and_prefix = False @@ -1960,6 +1963,9 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na "KeepAlivePeriodInSeconds" ] + if "TrainingPlanArn" in job_details["ResourceConfig"]: + init_params["training_plan"] = job_details["ResourceConfig"]["TrainingPlanArn"] + has_hps = "HyperParameters" in job_details init_params["hyperparameters"] = job_details["HyperParameters"] if has_hps else {} @@ -2840,6 +2846,7 @@ def __init__( enable_infra_check: Optional[Union[bool, PipelineVariable]] = None, enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None, enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None, + training_plan: Optional[Union[str, PipelineVariable]] = None, **kwargs, ): """Initialize an ``Estimator`` instance. @@ -3205,6 +3212,8 @@ def __init__( Specifies whether RemoteDebug is enabled for the training job enable_session_tag_chaining (bool or PipelineVariable): Optional. Specifies whether SessionTagChaining is enabled for the training job + training_plan (str or PipelineVariable): Optional. + Specifies which training plan arn to use for the training job """ self.image_uri = image_uri self._hyperparameters = hyperparameters.copy() if hyperparameters else {} @@ -3258,6 +3267,7 @@ def __init__( disable_output_compression=disable_output_compression, enable_remote_debug=enable_remote_debug, enable_session_tag_chaining=enable_session_tag_chaining, + training_plan=training_plan, **kwargs, ) diff --git a/src/sagemaker/job.py b/src/sagemaker/job.py index 7040c376ab..210dd426c5 100644 --- a/src/sagemaker/job.py +++ b/src/sagemaker/job.py @@ -83,6 +83,7 @@ def _load_config(inputs, estimator, expand_role=True, validate_uri=True): estimator.volume_size, estimator.volume_kms_key, estimator.keep_alive_period_in_seconds, + estimator.training_plan, ) stop_condition = _Job._prepare_stop_condition(estimator.max_run, estimator.max_wait) vpc_config = estimator.get_vpc_config() @@ -294,6 +295,7 @@ def _prepare_resource_config( volume_size, volume_kms_key, keep_alive_period_in_seconds, + training_plan, ): """Placeholder docstring""" resource_config = { @@ -319,6 +321,8 @@ def _prepare_resource_config( ) resource_config["InstanceCount"] = instance_count resource_config["InstanceType"] = instance_type + if training_plan is not None: + resource_config["TrainingPlanArn"] = training_plan return resource_config diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index 8b30317a52..a41c9ed952 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -115,6 +115,7 @@ def __init__( enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None, config_name: Optional[str] = None, enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None, + training_plan: Optional[Union[str, PipelineVariable]] = None, ): """Initializes a ``JumpStartEstimator``. @@ -511,6 +512,8 @@ def __init__( Name of the training configuration to apply to the Estimator. (Default: None). enable_session_tag_chaining (bool or PipelineVariable): Optional. Specifies whether SessionTagChaining is enabled for the training job + training_plan (str or PipelineVariable): Optional. + Specifies which training plan arn to use for the training job Raises: ValueError: If the model ID is not recognized by JumpStart. @@ -599,6 +602,7 @@ def _validate_model_id_and_get_type_hook(): enable_remote_debug=enable_remote_debug, config_name=config_name, enable_session_tag_chaining=enable_session_tag_chaining, + training_plan=training_plan, ) self.hub_arn = estimator_init_kwargs.hub_arn diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 84c9d09c3d..e4020a39bd 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -144,6 +144,7 @@ def get_init_kwargs( enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None, config_name: Optional[str] = None, enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None, + training_plan: Optional[Union[str, PipelineVariable]] = None, ) -> JumpStartEstimatorInitKwargs: """Returns kwargs required to instantiate `sagemaker.estimator.Estimator` object.""" @@ -205,6 +206,7 @@ def get_init_kwargs( enable_remote_debug=enable_remote_debug, config_name=config_name, enable_session_tag_chaining=enable_session_tag_chaining, + training_plan=training_plan, ) estimator_init_kwargs, orig_session = _set_temp_sagemaker_session_if_not_set( diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index f9502a3538..f59e2eddf4 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -2406,6 +2406,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs): "hub_content_type", "model_reference_arn", "specs", + "training_plan", ] SERIALIZATION_EXCLUSION_SET = { @@ -2479,6 +2480,7 @@ def __init__( enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None, config_name: Optional[str] = None, enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None, + training_plan: Optional[Union[str, PipelineVariable]] = None, ) -> None: """Instantiates JumpStartEstimatorInitKwargs object.""" @@ -2541,6 +2543,7 @@ def __init__( self.enable_remote_debug = enable_remote_debug self.config_name = config_name self.enable_session_tag_chaining = enable_session_tag_chaining + self.training_plan = training_plan class JumpStartEstimatorFitKwargs(JumpStartKwargs): diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index de23472fcf..04a7326557 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -2470,6 +2470,75 @@ def describe_training_job(self, job_name): """ return self.sagemaker_client.describe_training_job(TrainingJobName=job_name) + def describe_training_plan(self, training_plan_name): + """Calls the DescribeTrainingPlan API for the given training plan and returns the response. + + Args: + training_plan_name (str): The name of the training plan to describe. + + Returns: + dict: A dictionary response with the training plan description. + """ + return self.sagemaker_client.describe_training_plan(TrainingPlanName=training_plan_name) + + def list_training_plans( + self, + filters=None, + requested_start_time_after=None, + requested_start_time_before=None, + start_time_after=None, + start_time_before=None, + sort_order=None, + sort_by=None, + max_results=None, + next_token=None, + ): + """Calls the ListrTrainingPlan API for the given filters and returns the response. + + Args: + filters (dict): A dictionary of key-value pairs used to filter the training plans. + Default to None. + requested_start_time_after (datetime): A timestamp that filters the results + to only include training plans with a requested start time after this timestamp. + requested_start_time_before (datetime): A timestamp that filters the results + to only include training plans with a requested start time before this timestamp. + start_time_after (datetime): A timestamp that filters the results + to only include training plans with an actual start time after this timestamp. + start_time_before (datetime): A timestamp that filters the results + to only include training plans with an actual start time before this timestamp. + sort_order (str): The order that the training plans will be listed in result. + Default to None. + sort_by (str): The value that the training plans will be sorted by. + Default to None. + max_results (int): The number of candidates will be listed in results, + between 1 and 100. Default to None. If None, will return all the training_plans. + next_token (str): The pagination token. Default to None. + + Returns: + dict: A dictionary containing the following keys: + - "TrainingPlanSummaries": A list of dictionaries, where each dictionary represents + a training plan. + - "NextToken": A token to retrieve the next set of results, if there are more + than the maximum number of results returned. + """ + list_training_plan_args = {} + + def check_object(key, value): + if value is not None: + list_training_plan_args[key] = value + + check_object("Filters", filters) + check_object("SortBy", sort_by) + check_object("SortOrder", sort_order) + check_object("RequestedStartTimeAfter", requested_start_time_after) + check_object("RequestedStartTimeBefore", requested_start_time_before) + check_object("StartTimeAfter", start_time_after) + check_object("StartTimeBefore", start_time_before) + check_object("NextToken", next_token) + check_object("MaxResults", max_results) + + return self.sagemaker_client.list_training_plans(**list_training_plan_args) + def auto_ml( self, input_config, diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 0bc84d29d0..8294eb0039 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -89,6 +89,7 @@ INSTANCE_COUNT = 1 INSTANCE_TYPE = "c4.4xlarge" KEEP_ALIVE_PERIOD_IN_SECONDS = 1800 +TRAINING_PLAN = "arn:aws:sagemaker:us-west-2:336:training-plan/test_training_plan" ACCELERATOR_TYPE = "ml.eia.medium" ROLE = "DummyRole" IMAGE_URI = "fakeimage" @@ -861,6 +862,23 @@ def test_framework_with_keep_alive_period(sagemaker_session): assert args["resource_config"]["KeepAlivePeriodInSeconds"] == KEEP_ALIVE_PERIOD_IN_SECONDS +def test_framework_with_training_plan(sagemaker_session): + f = DummyFramework( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + instance_groups=[ + InstanceGroup("group1", "ml.c4.xlarge", 1), + InstanceGroup("group2", "ml.m4.xlarge", 2), + ], + training_plan=TRAINING_PLAN, + ) + f.fit("s3://mydata") + sagemaker_session.train.assert_called_once() + _, args = sagemaker_session.train.call_args + assert args["resource_config"]["TrainingPlanArn"] == TRAINING_PLAN + + def test_framework_with_both_training_repository_config(sagemaker_session): f = DummyFramework( entry_point=SCRIPT_PATH, diff --git a/tests/unit/test_job.py b/tests/unit/test_job.py index 603b494e5a..c93a381c11 100644 --- a/tests/unit/test_job.py +++ b/tests/unit/test_job.py @@ -31,6 +31,7 @@ INSTANCE_COUNT = 1 INSTANCE_TYPE = "c4.4xlarge" KEEP_ALIVE_PERIOD = 1800 +TRAINING_PLAN = "arn:aws:sagemaker:us-west-2:336:training-plan/test_training_plan" INSTANCE_GROUP = InstanceGroup("group", "ml.c4.xlarge", 1) VOLUME_SIZE = 1 MAX_RUNTIME = 1 @@ -633,7 +634,13 @@ def test_prepare_output_config_kms_key_none(): def test_prepare_resource_config(): resource_config = _Job._prepare_resource_config( - INSTANCE_COUNT, INSTANCE_TYPE, None, VOLUME_SIZE, None, None + INSTANCE_COUNT, + INSTANCE_TYPE, + None, + VOLUME_SIZE, + None, + None, + None, ) assert resource_config == { @@ -643,9 +650,35 @@ def test_prepare_resource_config(): } +def test_prepare_resource_config_with_training_plan(): + resource_config = _Job._prepare_resource_config( + INSTANCE_COUNT, + INSTANCE_TYPE, + None, + VOLUME_SIZE, + VOLUME_KMS_KEY, + None, + TRAINING_PLAN, + ) + + assert resource_config == { + "InstanceCount": INSTANCE_COUNT, + "InstanceType": INSTANCE_TYPE, + "VolumeSizeInGB": VOLUME_SIZE, + "VolumeKmsKeyId": VOLUME_KMS_KEY, + "TrainingPlanArn": TRAINING_PLAN, + } + + def test_prepare_resource_config_with_keep_alive_period(): resource_config = _Job._prepare_resource_config( - INSTANCE_COUNT, INSTANCE_TYPE, None, VOLUME_SIZE, VOLUME_KMS_KEY, KEEP_ALIVE_PERIOD + INSTANCE_COUNT, + INSTANCE_TYPE, + None, + VOLUME_SIZE, + VOLUME_KMS_KEY, + KEEP_ALIVE_PERIOD, + None, ) assert resource_config == { @@ -659,7 +692,13 @@ def test_prepare_resource_config_with_keep_alive_period(): def test_prepare_resource_config_with_volume_kms(): resource_config = _Job._prepare_resource_config( - INSTANCE_COUNT, INSTANCE_TYPE, None, VOLUME_SIZE, VOLUME_KMS_KEY, None + INSTANCE_COUNT, + INSTANCE_TYPE, + None, + VOLUME_SIZE, + VOLUME_KMS_KEY, + None, + None, ) assert resource_config == { @@ -678,6 +717,7 @@ def test_prepare_resource_config_with_heterogeneous_cluster(): VOLUME_SIZE, None, None, + None, ) assert resource_config == { @@ -698,6 +738,7 @@ def test_prepare_resource_config_with_instance_groups_instance_type_instance_cou VOLUME_SIZE, None, None, + None, ) assert "instance_count and instance_type cannot be set when instance_groups is set" in str( error @@ -713,6 +754,7 @@ def test_prepare_resource_config_with_instance_groups_instance_type_instance_cou VOLUME_SIZE, None, None, + None, ) assert "instance_count and instance_type must be set if instance_groups is not set" in str( error From 92b69320510f2d358a27b9ce5a615852e4464bf1 Mon Sep 17 00:00:00 2001 From: Erick Benitez-Ramos Date: Wed, 4 Dec 2024 08:30:59 -0800 Subject: [PATCH 05/12] fix test_requiremenets.txt --- requirements/extras/test_requirements.txt | 3 --- 1 file changed, 3 deletions(-) diff --git a/requirements/extras/test_requirements.txt b/requirements/extras/test_requirements.txt index cc8068e077..9664a63e1d 100644 --- a/requirements/extras/test_requirements.txt +++ b/requirements/extras/test_requirements.txt @@ -25,10 +25,7 @@ pyvis==0.2.1 pandas==1.4.4 scikit-learn==1.3.0 cloudpickle>=2.2.1 -<<<<<<< HEAD jsonpickle<4.0.0 -======= ->>>>>>> 42acb4f4 (chore: Merge from main (#1600)) PyYAML==6.0 # TODO find workaround xgboost>=1.6.2,<=1.7.6 From 3538b66b38f6aa78b6895b409c732bfa8c4d9b49 Mon Sep 17 00:00:00 2001 From: Erick Benitez-Ramos Date: Wed, 4 Dec 2024 08:41:17 -0800 Subject: [PATCH 06/12] fix merge artifact --- src/sagemaker/jumpstart/hub/parser_utils.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/src/sagemaker/jumpstart/hub/parser_utils.py b/src/sagemaker/jumpstart/hub/parser_utils.py index 259ba84c23..0983122d09 100644 --- a/src/sagemaker/jumpstart/hub/parser_utils.py +++ b/src/sagemaker/jumpstart/hub/parser_utils.py @@ -18,24 +18,10 @@ from typing import Any, Dict, List, Optional -<<<<<<< HEAD -<<<<<<< HEAD def camel_to_snake(camel_case_string: str) -> str: """Converts camelCase to snake_case_string using a regex. This regex cannot handle whitespace ("camelString TwoWords") -======= -def pascal_to_snake(camel_case_string: str) -> str: - """Converts PascalCase to snake_case_string using a regex. - - This regex cannot handle whitespace ("PascalString TwoWords") ->>>>>>> ff3eae05 (feat: Adding Bedrock Store model support for HubService (#1539)) -======= -def camel_to_snake(camel_case_string: str) -> str: - """Converts camelCase to snake_case_string using a regex. - - This regex cannot handle whitespace ("camelString TwoWords") ->>>>>>> 42acb4f4 (chore: Merge from main (#1600)) """ return re.sub(r"(? Date: Wed, 4 Dec 2024 09:11:34 -0800 Subject: [PATCH 07/12] fix merge artifact --- src/sagemaker/jumpstart/hub/utils.py | 49 ++++++++++++------------ tests/integ/sagemaker/jumpstart/utils.py | 7 ---- 2 files changed, 24 insertions(+), 32 deletions(-) diff --git a/src/sagemaker/jumpstart/hub/utils.py b/src/sagemaker/jumpstart/hub/utils.py index b9eeb248a6..84d3a13024 100644 --- a/src/sagemaker/jumpstart/hub/utils.py +++ b/src/sagemaker/jumpstart/hub/utils.py @@ -219,43 +219,42 @@ def get_hub_model_version( except Exception as ex: raise Exception(f"Failed calling list_hub_content_versions: {str(ex)}") -<<<<<<< HEAD -<<<<<<< HEAD -======= - marketplace_hub_content_version = _get_hub_model_version_for_marketplace_version( - hub_content_summaries, hub_model_version - ) ->>>>>>> ff3eae05 (feat: Adding Bedrock Store model support for HubService (#1539)) -======= ->>>>>>> 42acb4f4 (chore: Merge from main (#1600)) +def get_hub_model_version( + hub_name: str, + hub_model_name: str, + hub_model_type: str, + hub_model_version: Optional[str] = None, + sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, +) -> str: + """Returns available Jumpstart hub model version. + + It will attempt both a semantic HubContent version search and Marketplace version search. + If the Marketplace version is also semantic, this function will default to HubContent version. + + Raises: + ClientError: If the specified model is not found in the hub. + KeyError: If the specified model version is not found. + """ + + try: + hub_content_summaries = sagemaker_session.list_hub_content_versions( + hub_name=hub_name, hub_content_name=hub_model_name, hub_content_type=hub_model_type + ).get("HubContentSummaries") + except Exception as ex: + raise Exception(f"Failed calling list_hub_content_versions: {str(ex)}") + try: return _get_hub_model_version_for_open_weight_version( hub_content_summaries, hub_model_version ) -<<<<<<< HEAD -<<<<<<< HEAD -======= ->>>>>>> 42acb4f4 (chore: Merge from main (#1600)) except KeyError: marketplace_hub_content_version = _get_hub_model_version_for_marketplace_version( hub_content_summaries, hub_model_version ) -<<<<<<< HEAD - if marketplace_hub_content_version: - return marketplace_hub_content_version - raise -======= - except KeyError as e: - if marketplace_hub_content_version: - return marketplace_hub_content_version - raise e ->>>>>>> ff3eae05 (feat: Adding Bedrock Store model support for HubService (#1539)) -======= if marketplace_hub_content_version: return marketplace_hub_content_version raise ->>>>>>> 42acb4f4 (chore: Merge from main (#1600)) def _get_hub_model_version_for_open_weight_version( diff --git a/tests/integ/sagemaker/jumpstart/utils.py b/tests/integ/sagemaker/jumpstart/utils.py index 881372f421..d462112a2a 100644 --- a/tests/integ/sagemaker/jumpstart/utils.py +++ b/tests/integ/sagemaker/jumpstart/utils.py @@ -53,9 +53,6 @@ def get_sm_session() -> Session: return Session(boto_session=boto3.Session(region_name=JUMPSTART_DEFAULT_REGION_NAME)) -<<<<<<< HEAD -<<<<<<< HEAD -======= def get_sm_session_with_override() -> Session: # [TODO]: Remove service endpoint override before GA # boto3.set_stream_logger(name='botocore', level=logging.DEBUG) @@ -69,10 +66,6 @@ def get_sm_session_with_override() -> Session: sagemaker_client=sagemaker, ) - ->>>>>>> ff3eae05 (feat: Adding Bedrock Store model support for HubService (#1539)) -======= ->>>>>>> 42acb4f4 (chore: Merge from main (#1600)) def get_training_dataset_for_model_and_version(model_id: str, version: str) -> dict: return TRAINING_DATASET_MODEL_DICT[(model_id, version)] From d465be958b43ffcfa1b1f6af036e9ce94b382218 Mon Sep 17 00:00:00 2001 From: Erick Benitez-Ramos Date: Wed, 4 Dec 2024 09:33:58 -0800 Subject: [PATCH 08/12] fix codestyles --- src/sagemaker/jumpstart/hub/utils.py | 25 ------------------------ tests/integ/sagemaker/jumpstart/utils.py | 1 + 2 files changed, 1 insertion(+), 25 deletions(-) diff --git a/src/sagemaker/jumpstart/hub/utils.py b/src/sagemaker/jumpstart/hub/utils.py index 84d3a13024..77540926c6 100644 --- a/src/sagemaker/jumpstart/hub/utils.py +++ b/src/sagemaker/jumpstart/hub/utils.py @@ -195,31 +195,6 @@ def is_gated_bucket(bucket_name: str) -> bool: return bucket_name in constants.JUMPSTART_GATED_BUCKET_NAME_SET -def get_hub_model_version( - hub_name: str, - hub_model_name: str, - hub_model_type: str, - hub_model_version: Optional[str] = None, - sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, -) -> str: - """Returns available Jumpstart hub model version. - - It will attempt both a semantic HubContent version search and Marketplace version search. - If the Marketplace version is also semantic, this function will default to HubContent version. - - Raises: - ClientError: If the specified model is not found in the hub. - KeyError: If the specified model version is not found. - """ - - try: - hub_content_summaries = sagemaker_session.list_hub_content_versions( - hub_name=hub_name, hub_content_name=hub_model_name, hub_content_type=hub_model_type - ).get("HubContentSummaries") - except Exception as ex: - raise Exception(f"Failed calling list_hub_content_versions: {str(ex)}") - - def get_hub_model_version( hub_name: str, hub_model_name: str, diff --git a/tests/integ/sagemaker/jumpstart/utils.py b/tests/integ/sagemaker/jumpstart/utils.py index d462112a2a..d439ef7e95 100644 --- a/tests/integ/sagemaker/jumpstart/utils.py +++ b/tests/integ/sagemaker/jumpstart/utils.py @@ -66,6 +66,7 @@ def get_sm_session_with_override() -> Session: sagemaker_client=sagemaker, ) + def get_training_dataset_for_model_and_version(model_id: str, version: str) -> dict: return TRAINING_DATASET_MODEL_DICT[(model_id, version)] From 00f5e26bc8c2f8e27334caac74bc2e760050a74f Mon Sep 17 00:00:00 2001 From: Chinmayee Shah Date: Wed, 4 Dec 2024 09:35:33 -0800 Subject: [PATCH 09/12] Hotfix to construct rubik uri correctly (#1646) Co-authored-by: Mufaddal Rohawala <89424143+mufaddal-rohawala@users.noreply.github.com> --- src/sagemaker/pytorch/training_recipes.json | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/sagemaker/pytorch/training_recipes.json b/src/sagemaker/pytorch/training_recipes.json index df60f95df9..5aeccce5a1 100644 --- a/src/sagemaker/pytorch/training_recipes.json +++ b/src/sagemaker/pytorch/training_recipes.json @@ -5,7 +5,9 @@ "gpu_image" : { "framework": "pytorch-smp", "version": "2.4.1", - "additional_args": {} + "additional_args": { + "container_version": "cu121" + } }, "neuron_image" : { "framework": "hyperpod-recipes-neuron", From fd2440c5f6f3b6ec8a50f2a8bd34bb159104071c Mon Sep 17 00:00:00 2001 From: Erick Benitez-Ramos Date: Wed, 4 Dec 2024 09:37:24 -0800 Subject: [PATCH 10/12] fix gpu_image uri --- src/sagemaker/modules/train/sm_recipes/training_recipes.json | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/sagemaker/modules/train/sm_recipes/training_recipes.json b/src/sagemaker/modules/train/sm_recipes/training_recipes.json index 400e13f08a..a51513f49f 100644 --- a/src/sagemaker/modules/train/sm_recipes/training_recipes.json +++ b/src/sagemaker/modules/train/sm_recipes/training_recipes.json @@ -5,7 +5,9 @@ "gpu_image" : { "framework": "pytorch-smp", "version": "2.4.1", - "additional_args": {} + "additional_args": { + "container_version": "cu121" + } }, "neuron_image": { "framework": "hyperpod-recipes-neuron", From c3f857c0f676c5cd10be64a0cf7c90ba13962476 Mon Sep 17 00:00:00 2001 From: Erick Benitez-Ramos Date: Wed, 4 Dec 2024 14:07:01 -0800 Subject: [PATCH 11/12] update boto3 and sagemaker-core version --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4657f41737..be05949d02 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ classifiers = [ ] dependencies = [ "attrs>=23.1.0,<24", - "boto3>=1.34.142,<2.0", + "boto3>=1.35.75,<2.0", "cloudpickle==2.2.1", "docker", "fastapi", @@ -49,7 +49,7 @@ dependencies = [ "psutil", "PyYAML~=6.0", "requests", - "sagemaker-core>=1.0.15,<2.0.0", + "sagemaker-core>=1.0.17,<2.0.0", "schema", "smdebug_rulesconfig==1.0.1", "tblib>=1.7.0,<4", From a6adf07f75cde9e5ac47e6d81b7d46ccf5219ac3 Mon Sep 17 00:00:00 2001 From: Erick Benitez-Ramos Date: Wed, 4 Dec 2024 15:02:15 -0800 Subject: [PATCH 12/12] Fix unit tests --- tests/unit/sagemaker/modules/train/test_model_trainer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/unit/sagemaker/modules/train/test_model_trainer.py b/tests/unit/sagemaker/modules/train/test_model_trainer.py index 049ebaa9c4..093da20ab8 100644 --- a/tests/unit/sagemaker/modules/train/test_model_trainer.py +++ b/tests/unit/sagemaker/modules/train/test_model_trainer.py @@ -285,6 +285,7 @@ def test_train_with_intelligent_defaults_training_job_space( volume_kms_key_id=None, keep_alive_period_in_seconds=None, instance_groups=None, + training_plan_arn=None, ), vpc_config=None, session=ANY, @@ -825,6 +826,7 @@ def mock_upload_data(path, bucket, key_prefix): volume_kms_key_id=compute.volume_kms_key_id, keep_alive_period_in_seconds=compute.keep_alive_period_in_seconds, instance_groups=None, + training_plan_arn=None, ), vpc_config=VpcConfig( security_group_ids=networking.security_group_ids,