-
Notifications
You must be signed in to change notification settings - Fork 1.2k
feat: Curated hub improvements #4760
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 25 commits
b1f5cd8
6b9f390
cb66608
964de22
5392504
269dc08
502063f
f553357
7571a55
5ab02e4
37a36c8
3fe2774
10dba2c
5449eb5
559ef2e
7e307bf
5f7e955
38495dc
90006f6
c331b0c
ac45eea
2f07130
6fb3223
0f3f434
65b61a6
d8b173d
ccb640c
2f5f29b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,6 +14,7 @@ | |
"""This module contains accessors related to SageMaker JumpStart.""" | ||
from __future__ import absolute_import | ||
import functools | ||
import logging | ||
from typing import Any, Dict, List, Optional | ||
import boto3 | ||
|
||
|
@@ -289,15 +290,6 @@ def get_model_specs( | |
|
||
if hub_arn: | ||
try: | ||
hub_model_arn = construct_hub_model_arn_from_inputs( | ||
hub_arn=hub_arn, model_name=model_id, version=version | ||
) | ||
model_specs = JumpStartModelsAccessor._cache.get_hub_model( | ||
hub_model_arn=hub_model_arn | ||
) | ||
model_specs.set_hub_content_type(HubContentType.MODEL) | ||
return model_specs | ||
except: # noqa: E722 | ||
hub_model_arn = construct_hub_model_reference_arn_from_inputs( | ||
hub_arn=hub_arn, model_name=model_id, version=version | ||
) | ||
|
@@ -307,6 +299,21 @@ def get_model_specs( | |
model_specs.set_hub_content_type(HubContentType.MODEL_REFERENCE) | ||
return model_specs | ||
|
||
except Exception as ex: | ||
logging.info( | ||
"Recieved exeption while calling APIs for ContentType ModelReference, \ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. typo: There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. also, high-level question for future PRs: shall we retry on all types of error? As in, if retry throttling as well? That may be fine, but just want to make sure that's a conscious decision. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah I think its fine to retry on all error types but now that I am thinking about it I feel like I can just restrict retry with a different contentType only in the case of ResourceNotFound errors There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. regardless engaging with @judyheflin for the error message |
||
retrying with ContentType Model: " | ||
+ str(ex) | ||
) | ||
hub_model_arn = construct_hub_model_arn_from_inputs( | ||
hub_arn=hub_arn, model_name=model_id, version=version | ||
) | ||
model_specs = JumpStartModelsAccessor._cache.get_hub_model( | ||
hub_model_arn=hub_model_arn | ||
) | ||
model_specs.set_hub_content_type(HubContentType.MODEL) | ||
return model_specs | ||
|
||
return JumpStartModelsAccessor._cache.get_specs( # type: ignore | ||
model_id=model_id, version_str=version, model_type=model_type | ||
) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,7 +23,6 @@ | |
from sagemaker.session import Session | ||
|
||
from sagemaker.jumpstart.constants import ( | ||
DEFAULT_JUMPSTART_SAGEMAKER_SESSION, | ||
JUMPSTART_LOGGER, | ||
) | ||
from sagemaker.jumpstart.types import ( | ||
|
@@ -68,7 +67,7 @@ def __init__( | |
self, | ||
hub_name: str, | ||
bucket_name: Optional[str] = None, | ||
sagemaker_session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, | ||
sagemaker_session: Optional[Session] = None, | ||
) -> None: | ||
"""Instantiates a SageMaker ``Hub``. | ||
|
||
|
@@ -79,7 +78,10 @@ def __init__( | |
""" | ||
self.hub_name = hub_name | ||
self.region = sagemaker_session.boto_region_name | ||
self._sagemaker_session = sagemaker_session | ||
self._sagemaker_session = ( | ||
sagemaker_session | ||
or utils.get_default_jumpstart_session_with_user_agent_suffix(is_hub_content=True) | ||
) | ||
self.hub_storage_location = self._generate_hub_storage_location(bucket_name) | ||
|
||
def _fetch_hub_bucket_name(self) -> str: | ||
|
@@ -274,8 +276,8 @@ def describe_model( | |
try: | ||
model_version = get_hub_model_version( | ||
hub_model_name=model_name, | ||
hub_model_type=HubContentType.MODEL.value, | ||
hub_name=self.hub_name, | ||
hub_model_type=HubContentType.MODEL_REFERENCE.value, | ||
hub_name=self.hub_name if not hub_name else hub_name, | ||
sagemaker_session=self._sagemaker_session, | ||
hub_model_version=model_version, | ||
) | ||
|
@@ -284,24 +286,27 @@ def describe_model( | |
hub_name=self.hub_name if not hub_name else hub_name, | ||
hub_content_name=model_name, | ||
hub_content_version=model_version, | ||
hub_content_type=HubContentType.MODEL.value, | ||
hub_content_type=HubContentType.MODEL_REFERENCE.value, | ||
) | ||
|
||
except Exception as ex: | ||
logging.info("Recieved expection while calling APIs for ContentType Model: " + str(ex)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. typo |
||
logging.info( | ||
"Recieved exeption while calling APIs for ContentType ModelReference, retrying with ContentType Model: " | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. typo |
||
+ str(ex) | ||
) | ||
model_version = get_hub_model_version( | ||
hub_model_name=model_name, | ||
hub_model_type=HubContentType.MODEL_REFERENCE.value, | ||
hub_name=self.hub_name, | ||
hub_model_type=HubContentType.MODEL.value, | ||
hub_name=self.hub_name if not hub_name else hub_name, | ||
sagemaker_session=self._sagemaker_session, | ||
hub_model_version=model_version, | ||
) | ||
|
||
hub_content_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content( | ||
hub_name=self.hub_name, | ||
hub_name=self.hub_name if not hub_name else hub_name, | ||
hub_content_name=model_name, | ||
hub_content_version=model_version, | ||
hub_content_type=HubContentType.MODEL_REFERENCE.value, | ||
hub_content_type=HubContentType.MODEL.value, | ||
) | ||
|
||
return DescribeHubContentResponse(hub_content_description) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -193,7 +193,11 @@ def get_hub_model_version( | |
hub_model_version: Optional[str] = None, | ||
sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, | ||
) -> str: | ||
"""Returns available Jumpstart hub model version""" | ||
"""Returns available Jumpstart hub model version | ||
|
||
Raises: | ||
ResourceNotFound: If the specified model is not found in the hub. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the exception type should be a Python class, not the error returned by the API. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes my bad, updated it to ClientError exception instead |
||
""" | ||
|
||
try: | ||
hub_content_summaries = sagemaker_session.list_hub_content_versions( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -429,6 +429,7 @@ def attach( | |
cls, | ||
endpoint_name: str, | ||
inference_component_name: Optional[str] = None, | ||
hub_name: Optional[str] = None, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. backward-incompatible change, please add arg to the end of the list. |
||
model_id: Optional[str] = None, | ||
model_version: Optional[str] = None, | ||
sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION, | ||
|
@@ -457,6 +458,7 @@ def attach( | |
model_id=model_id, | ||
model_version=model_version, | ||
sagemaker_session=sagemaker_session, | ||
hub_name=hub_name, | ||
) | ||
model.endpoint_name = endpoint_name | ||
model.inference_component_name = inference_component_name | ||
|
Uh oh!
There was an error while loading. Please reload this page.