diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 257a9e71af..25804bba74 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -462,19 +462,24 @@ def _retrieval_function( HubContentType.MODEL_REFERENCE, }: - hub_arn_extracted_info = hub_utils.get_info_from_hub_resource_arn(id_info) + hub_resource_arn_extracted_info = hub_utils.get_info_from_hub_resource_arn(id_info) + hub_arn = hub_utils.construct_hub_arn_from_name( + hub_name=hub_resource_arn_extracted_info.hub_name, + region=hub_resource_arn_extracted_info.region, + account_id=hub_resource_arn_extracted_info.account_id, + ) model_version: str = hub_utils.get_hub_model_version( - hub_model_name=hub_arn_extracted_info.hub_content_name, + hub_model_name=hub_resource_arn_extracted_info.hub_content_name, hub_model_type=data_type.value, - hub_name=hub_arn_extracted_info.hub_name, + hub_name=hub_arn, sagemaker_session=self._sagemaker_session, - hub_model_version=hub_arn_extracted_info.hub_content_version, + hub_model_version=hub_resource_arn_extracted_info.hub_content_version, ) hub_model_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content( - hub_name=hub_arn_extracted_info.hub_name, - hub_content_name=hub_arn_extracted_info.hub_content_name, + hub_name=hub_arn, + hub_content_name=hub_resource_arn_extracted_info.hub_content_name, hub_content_version=model_version, hub_content_type=data_type.value, ) diff --git a/src/sagemaker/jumpstart/hub/utils.py b/src/sagemaker/jumpstart/hub/utils.py index 2624796b3f..535bc5e9be 100644 --- a/src/sagemaker/jumpstart/hub/utils.py +++ b/src/sagemaker/jumpstart/hub/utils.py @@ -67,10 +67,11 @@ def construct_hub_arn_from_name( hub_name: str, region: Optional[str] = None, session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + account_id: Optional[str] = None, ) -> str: """Constructs a Hub arn from the Hub name using default Session values.""" - account_id = session.account_id() + account_id = account_id or session.account_id() region = region or session.boto_region_name partition = aws_partition(region)