diff --git a/src/sagemaker/jumpstart/hub/hub.py b/src/sagemaker/jumpstart/hub/hub.py index 402b2ce534..692966cee4 100644 --- a/src/sagemaker/jumpstart/hub/hub.py +++ b/src/sagemaker/jumpstart/hub/hub.py @@ -16,15 +16,11 @@ from datetime import datetime import logging from typing import Optional, Dict, List, Any, Union -from botocore import exceptions from sagemaker.jumpstart.constants import JUMPSTART_MODEL_HUB_NAME from sagemaker.jumpstart.enums import JumpStartScriptScope from sagemaker.session import Session -from sagemaker.jumpstart.constants import ( - JUMPSTART_LOGGER, -) from sagemaker.jumpstart.types import ( HubContentType, ) @@ -32,9 +28,6 @@ from sagemaker.jumpstart.hub.utils import ( get_hub_model_version, get_info_from_hub_resource_arn, - create_hub_bucket_if_it_does_not_exist, - generate_default_hub_bucket_name, - create_s3_object_reference_from_uri, construct_hub_arn_from_name, ) @@ -42,9 +35,6 @@ list_jumpstart_models, ) -from sagemaker.jumpstart.hub.types import ( - S3ObjectLocation, -) from sagemaker.jumpstart.hub.interfaces import ( DescribeHubResponse, DescribeHubContentResponse, @@ -66,8 +56,8 @@ class Hub: def __init__( self, hub_name: str, + sagemaker_session: Session, bucket_name: Optional[str] = None, - sagemaker_session: Optional[Session] = None, ) -> None: """Instantiates a SageMaker ``Hub``. @@ -78,41 +68,11 @@ def __init__( """ self.hub_name = hub_name self.region = sagemaker_session.boto_region_name + self.bucket_name = bucket_name self._sagemaker_session = ( sagemaker_session or utils.get_default_jumpstart_session_with_user_agent_suffix(is_hub_content=True) ) - self.hub_storage_location = self._generate_hub_storage_location(bucket_name) - - def _fetch_hub_bucket_name(self) -> str: - """Retrieves hub bucket name from Hub config if exists""" - try: - hub_response = self._sagemaker_session.describe_hub(hub_name=self.hub_name) - hub_output_location = hub_response["S3StorageConfig"].get("S3OutputPath") - if hub_output_location: - location = create_s3_object_reference_from_uri(hub_output_location) - return location.bucket - default_bucket_name = generate_default_hub_bucket_name(self._sagemaker_session) - JUMPSTART_LOGGER.warning( - "There is not a Hub bucket associated with %s. Using %s", - self.hub_name, - default_bucket_name, - ) - return default_bucket_name - except exceptions.ClientError: - hub_bucket_name = generate_default_hub_bucket_name(self._sagemaker_session) - JUMPSTART_LOGGER.warning( - "There is not a Hub bucket associated with %s. Using %s", - self.hub_name, - hub_bucket_name, - ) - return hub_bucket_name - - def _generate_hub_storage_location(self, bucket_name: Optional[str] = None) -> None: - """Generates an ``S3ObjectLocation`` given a Hub name.""" - hub_bucket_name = bucket_name or self._fetch_hub_bucket_name() - curr_timestamp = datetime.now().timestamp() - return S3ObjectLocation(bucket=hub_bucket_name, key=f"{self.hub_name}-{curr_timestamp}") def _get_latest_model_version(self, model_id: str) -> str: """Populates the lastest version of a model from specs no matter what is passed. @@ -132,19 +92,22 @@ def create( tags: Optional[str] = None, ) -> Dict[str, str]: """Creates a hub with the given description""" + curr_timestamp = datetime.now().timestamp() - create_hub_bucket_if_it_does_not_exist( - self.hub_storage_location.bucket, self._sagemaker_session - ) + request = { + "hub_name": self.hub_name, + "hub_description": description, + "hub_display_name": display_name, + "hub_search_keywords": search_keywords, + "tags": tags, + } - return self._sagemaker_session.create_hub( - hub_name=self.hub_name, - hub_description=description, - hub_display_name=display_name, - hub_search_keywords=search_keywords, - s3_storage_config={"S3OutputPath": self.hub_storage_location.get_uri()}, - tags=tags, - ) + if self.bucket_name: + request["s3_storage_config"] = { + "S3OutputPath": (f"s3://{self.bucket_name}/{self.hub_name}-{curr_timestamp}") + } + + return self._sagemaker_session.create_hub(**request) def describe(self, hub_name: Optional[str] = None) -> DescribeHubResponse: """Returns descriptive information about the Hub""" diff --git a/src/sagemaker/jumpstart/hub/utils.py b/src/sagemaker/jumpstart/hub/utils.py index 75af019ca6..0df5e9d5c3 100644 --- a/src/sagemaker/jumpstart/hub/utils.py +++ b/src/sagemaker/jumpstart/hub/utils.py @@ -15,8 +15,6 @@ from __future__ import absolute_import import re from typing import Optional, List, Any -from sagemaker.jumpstart.hub.types import S3ObjectLocation -from sagemaker.s3_utils import parse_s3_url from sagemaker.session import Session from sagemaker.utils import aws_partition from sagemaker.jumpstart.types import HubContentType, HubArnExtractedInfo @@ -139,61 +137,6 @@ def generate_hub_arn_for_init_kwargs( return hub_arn -def generate_default_hub_bucket_name( - sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, -) -> str: - """Return the name of the default bucket to use in relevant Amazon SageMaker Hub interactions. - - Returns: - str: The name of the default bucket. If the name was not explicitly specified through - the Session or sagemaker_config, the bucket will take the form: - ``sagemaker-hubs-{region}-{AWS account ID}``. - """ - - region: str = sagemaker_session.boto_region_name - account_id: str = sagemaker_session.account_id() - - # TODO: Validate and fast fail - - return f"sagemaker-hubs-{region}-{account_id}" - - -def create_s3_object_reference_from_uri(s3_uri: Optional[str]) -> Optional[S3ObjectLocation]: - """Utiity to help generate an S3 object reference""" - if not s3_uri: - return None - - bucket, key = parse_s3_url(s3_uri) - - return S3ObjectLocation( - bucket=bucket, - key=key, - ) - - -def create_hub_bucket_if_it_does_not_exist( - bucket_name: Optional[str] = None, - sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, -) -> str: - """Creates the default SageMaker Hub bucket if it does not exist. - - Returns: - str: The name of the default bucket. Takes the form: - ``sagemaker-hubs-{region}-{AWS account ID}``. - """ - - region: str = sagemaker_session.boto_region_name - if bucket_name is None: - bucket_name: str = generate_default_hub_bucket_name(sagemaker_session) - - sagemaker_session._create_s3_bucket_if_it_does_not_exist( - bucket_name=bucket_name, - region=region, - ) - - return bucket_name - - def is_gated_bucket(bucket_name: str) -> bool: """Returns true if the bucket name is the JumpStart gated bucket.""" return bucket_name in constants.JUMPSTART_GATED_BUCKET_NAME_SET diff --git a/tests/unit/sagemaker/jumpstart/hub/test_hub.py b/tests/unit/sagemaker/jumpstart/hub/test_hub.py index 06f5473322..29efb6b31f 100644 --- a/tests/unit/sagemaker/jumpstart/hub/test_hub.py +++ b/tests/unit/sagemaker/jumpstart/hub/test_hub.py @@ -16,7 +16,6 @@ import pytest from mock import Mock from sagemaker.jumpstart.hub.hub import Hub -from sagemaker.jumpstart.hub.types import S3ObjectLocation REGION = "us-east-1" @@ -60,48 +59,34 @@ def test_instantiates(sagemaker_session): @pytest.mark.parametrize( - ("hub_name,hub_description,hub_bucket_name,hub_display_name,hub_search_keywords,tags"), + ("hub_name,hub_description,,hub_display_name,hub_search_keywords,tags"), [ - pytest.param("MockHub1", "this is my sagemaker hub", None, None, None, None), + pytest.param("MockHub1", "this is my sagemaker hub", None, None, None), pytest.param( "MockHub2", "this is my sagemaker hub two", - None, "DisplayMockHub2", ["mock", "hub", "123"], [{"Key": "tag-key-1", "Value": "tag-value-1"}], ), ], ) -@patch("sagemaker.jumpstart.hub.hub.Hub._generate_hub_storage_location") def test_create_with_no_bucket_name( - mock_generate_hub_storage_location, sagemaker_session, hub_name, hub_description, - hub_bucket_name, hub_display_name, hub_search_keywords, tags, ): - storage_location = S3ObjectLocation( - "sagemaker-hubs-us-east-1-123456789123", f"{hub_name}-{FAKE_TIME.timestamp()}" - ) - mock_generate_hub_storage_location.return_value = storage_location create_hub = {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"} sagemaker_session.create_hub = Mock(return_value=create_hub) - sagemaker_session.describe_hub.return_value = { - "S3StorageConfig": {"S3OutputPath": f"s3://{hub_bucket_name}/{storage_location.key}"} - } hub = Hub(hub_name=hub_name, sagemaker_session=sagemaker_session) request = { "hub_name": hub_name, "hub_description": hub_description, "hub_display_name": hub_display_name, "hub_search_keywords": hub_search_keywords, - "s3_storage_config": { - "S3OutputPath": f"s3://sagemaker-hubs-us-east-1-123456789123/{storage_location.key}" - }, "tags": tags, } response = hub.create( @@ -128,9 +113,9 @@ def test_create_with_no_bucket_name( ), ], ) -@patch("sagemaker.jumpstart.hub.hub.Hub._generate_hub_storage_location") +@patch("sagemaker.jumpstart.hub.hub.datetime") def test_create_with_bucket_name( - mock_generate_hub_storage_location, + mock_datetime, sagemaker_session, hub_name, hub_description, @@ -139,8 +124,8 @@ def test_create_with_bucket_name( hub_search_keywords, tags, ): - storage_location = S3ObjectLocation(hub_bucket_name, f"{hub_name}-{FAKE_TIME.timestamp()}") - mock_generate_hub_storage_location.return_value = storage_location + mock_datetime.now.return_value = FAKE_TIME + create_hub = {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"} sagemaker_session.create_hub = Mock(return_value=create_hub) hub = Hub(hub_name=hub_name, sagemaker_session=sagemaker_session, bucket_name=hub_bucket_name) @@ -149,7 +134,9 @@ def test_create_with_bucket_name( "hub_description": hub_description, "hub_display_name": hub_display_name, "hub_search_keywords": hub_search_keywords, - "s3_storage_config": {"S3OutputPath": f"s3://mock-bucket-123/{storage_location.key}"}, + "s3_storage_config": { + "S3OutputPath": f"s3://mock-bucket-123/{hub_name}-{FAKE_TIME.timestamp()}" + }, "tags": tags, } response = hub.create( diff --git a/tests/unit/sagemaker/jumpstart/hub/test_utils.py b/tests/unit/sagemaker/jumpstart/hub/test_utils.py index a0b824fc9b..5745a7f79c 100644 --- a/tests/unit/sagemaker/jumpstart/hub/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/hub/test_utils.py @@ -173,30 +173,6 @@ def test_generate_hub_arn_for_init_kwargs(): assert utils.generate_hub_arn_for_init_kwargs(hub_arn, None, mock_custom_session) == hub_arn -def test_create_hub_bucket_if_it_does_not_exist_hub_arn(): - mock_sagemaker_session = Mock() - mock_sagemaker_session.account_id.return_value = "123456789123" - mock_sagemaker_session.client("sts").get_caller_identity.return_value = { - "Account": "123456789123" - } - hub_arn = "arn:aws:sagemaker:us-west-2:12346789123:hub/my-awesome-hub" - # Mock custom session with custom values - mock_custom_session = Mock() - mock_custom_session.account_id.return_value = "000000000000" - mock_custom_session.boto_region_name = "us-east-2" - mock_sagemaker_session.boto_session.resource("s3").Bucket().creation_date = None - mock_sagemaker_session.boto_region_name = "us-east-1" - - bucket_name = "sagemaker-hubs-us-east-1-123456789123" - created_hub_bucket_name = utils.create_hub_bucket_if_it_does_not_exist( - sagemaker_session=mock_sagemaker_session - ) - - mock_sagemaker_session.boto_session.resource("s3").create_bucketassert_called_once() - assert created_hub_bucket_name == bucket_name - assert utils.generate_hub_arn_for_init_kwargs(hub_arn, None, mock_custom_session) == hub_arn - - def test_is_gated_bucket(): assert utils.is_gated_bucket("jumpstart-private-cache-prod-us-west-2") is True @@ -207,23 +183,6 @@ def test_is_gated_bucket(): assert utils.is_gated_bucket("") is False -def test_create_hub_bucket_if_it_does_not_exist(): - mock_sagemaker_session = Mock() - mock_sagemaker_session.account_id.return_value = "123456789123" - mock_sagemaker_session.client("sts").get_caller_identity.return_value = { - "Account": "123456789123" - } - mock_sagemaker_session.boto_session.resource("s3").Bucket().creation_date = None - mock_sagemaker_session.boto_region_name = "us-east-1" - bucket_name = "sagemaker-hubs-us-east-1-123456789123" - created_hub_bucket_name = utils.create_hub_bucket_if_it_does_not_exist( - sagemaker_session=mock_sagemaker_session - ) - - mock_sagemaker_session.boto_session.resource("s3").create_bucketassert_called_once() - assert created_hub_bucket_name == bucket_name - - @patch("sagemaker.session.Session") def test_get_hub_model_version_success(mock_session): hub_name = "test_hub"