-
Notifications
You must be signed in to change notification settings - Fork 1.2k
feat: Curated hub improvements #4760
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 21 commits
b1f5cd8
6b9f390
cb66608
964de22
5392504
269dc08
502063f
f553357
7571a55
5ab02e4
37a36c8
3fe2774
10dba2c
5449eb5
559ef2e
7e307bf
5f7e955
38495dc
90006f6
c331b0c
ac45eea
2f07130
6fb3223
0f3f434
65b61a6
d8b173d
ccb640c
2f5f29b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,6 +14,7 @@ | |
"""This module contains accessors related to SageMaker JumpStart.""" | ||
from __future__ import absolute_import | ||
import functools | ||
import logging | ||
from typing import Any, Dict, List, Optional | ||
import boto3 | ||
|
||
|
@@ -289,15 +290,6 @@ def get_model_specs( | |
|
||
if hub_arn: | ||
try: | ||
hub_model_arn = construct_hub_model_arn_from_inputs( | ||
hub_arn=hub_arn, model_name=model_id, version=version | ||
) | ||
model_specs = JumpStartModelsAccessor._cache.get_hub_model( | ||
hub_model_arn=hub_model_arn | ||
) | ||
model_specs.set_hub_content_type(HubContentType.MODEL) | ||
return model_specs | ||
except: # noqa: E722 | ||
hub_model_arn = construct_hub_model_reference_arn_from_inputs( | ||
hub_arn=hub_arn, model_name=model_id, version=version | ||
) | ||
|
@@ -307,6 +299,21 @@ def get_model_specs( | |
model_specs.set_hub_content_type(HubContentType.MODEL_REFERENCE) | ||
return model_specs | ||
|
||
except Exception as ex: | ||
logging.info( | ||
"Recieved exeption while calling APIs for ContentType Model, \ | ||
retrying with ContentType ModelReference: " | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually it other way around. The code is first attempting using ModelReference and then as Model. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ohh yeah I forgot I changed it recently, let me correct this. Thanks There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. also identified one more place where we were trying with Model first and then ModelRef. Changed that as well, in the Hub class. Thanks |
||
+ str(ex) | ||
) | ||
hub_model_arn = construct_hub_model_arn_from_inputs( | ||
hub_arn=hub_arn, model_name=model_id, version=version | ||
) | ||
model_specs = JumpStartModelsAccessor._cache.get_hub_model( | ||
hub_model_arn=hub_model_arn | ||
) | ||
model_specs.set_hub_content_type(HubContentType.MODEL) | ||
return model_specs | ||
|
||
return JumpStartModelsAccessor._cache.get_specs( # type: ignore | ||
model_id=model_id, version_str=version, model_type=model_type | ||
) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,7 +23,6 @@ | |
from sagemaker.session import Session | ||
|
||
from sagemaker.jumpstart.constants import ( | ||
DEFAULT_JUMPSTART_SAGEMAKER_SESSION, | ||
JUMPSTART_LOGGER, | ||
) | ||
from sagemaker.jumpstart.types import ( | ||
|
@@ -68,7 +67,9 @@ def __init__( | |
self, | ||
hub_name: str, | ||
bucket_name: Optional[str] = None, | ||
sagemaker_session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, | ||
sagemaker_session: Optional[ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i wouldn't set this as a default argument since this function will get invoked whenever the module is imported, which may cause slow latency or errors on some systems. can you set in constructor body instead? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. thanks changed it in the new revision |
||
Session | ||
] = utils.get_default_jumpstart_session_with_user_agent_suffix(is_hub_content=True), | ||
) -> None: | ||
"""Instantiates a SageMaker ``Hub``. | ||
|
||
|
@@ -288,7 +289,10 @@ def describe_model( | |
) | ||
|
||
except Exception as ex: | ||
logging.info("Recieved expection while calling APIs for ContentType Model: " + str(ex)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. typo |
||
logging.info( | ||
"Recieved exeption while calling APIs for ContentType Model, retrying with ContentType ModelReference: " | ||
+ str(ex) | ||
) | ||
model_version = get_hub_model_version( | ||
hub_model_name=model_name, | ||
hub_model_type=HubContentType.MODEL_REFERENCE.value, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -193,7 +193,11 @@ def get_hub_model_version( | |
hub_model_version: Optional[str] = None, | ||
sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, | ||
) -> str: | ||
"""Returns available Jumpstart hub model version""" | ||
"""Returns available Jumpstart hub model version | ||
|
||
Raises: | ||
ResourceNotFound: If the specified model is not found in the hub. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the exception type should be a Python class, not the error returned by the API. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes my bad, updated it to ClientError exception instead |
||
""" | ||
|
||
try: | ||
hub_content_summaries = sagemaker_session.list_hub_content_versions( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -429,6 +429,7 @@ def attach( | |
cls, | ||
endpoint_name: str, | ||
inference_component_name: Optional[str] = None, | ||
hub_name: Optional[str] = None, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. backward-incompatible change, please add arg to the end of the list. |
||
model_id: Optional[str] = None, | ||
model_version: Optional[str] = None, | ||
sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION, | ||
|
@@ -457,6 +458,7 @@ def attach( | |
model_id=model_id, | ||
model_version=model_version, | ||
sagemaker_session=sagemaker_session, | ||
hub_name=hub_name, | ||
) | ||
model.endpoint_name = endpoint_name | ||
model.inference_component_name = inference_component_name | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -384,12 +384,12 @@ def add_jumpstart_model_id_version_tags( | |
|
||
def add_hub_content_arn_tags( | ||
tags: Optional[List[TagsDict]], | ||
hub_arn: str, | ||
hub_content_arn: str, | ||
) -> Optional[List[TagsDict]]: | ||
"""Adds custom Hub arn tag to JumpStart related resources.""" | ||
|
||
tags = add_single_jumpstart_tag( | ||
hub_arn, | ||
hub_content_arn, | ||
enums.JumpStartTag.HUB_CONTENT_ARN, | ||
tags, | ||
is_uri=False, | ||
|
@@ -1012,24 +1012,37 @@ def get_jumpstart_configs( | |
) | ||
|
||
|
||
def get_jumpstart_user_agent_extra_suffix(model_id: str, model_version: str) -> str: | ||
def get_jumpstart_user_agent_extra_suffix( | ||
model_id: Optional[str], model_version: Optional[str], is_hub_content: Optional[bool] | ||
) -> str: | ||
"""Returns the model-specific user agent string to be added to requests.""" | ||
sagemaker_python_sdk_headers = get_user_agent_extra_suffix() | ||
jumpstart_specific_suffix = f"md/js_model_id#{model_id} md/js_model_ver#{model_version}" | ||
return ( | ||
sagemaker_python_sdk_headers | ||
if os.getenv(constants.ENV_VARIABLE_DISABLE_JUMPSTART_TELEMETRY, None) | ||
else f"{sagemaker_python_sdk_headers} {jumpstart_specific_suffix}" | ||
) | ||
hub_specific_suffix = f"md/js_is_hub_content#{is_hub_content}" | ||
|
||
if os.getenv(constants.ENV_VARIABLE_DISABLE_JUMPSTART_TELEMETRY, None): | ||
headers = sagemaker_python_sdk_headers | ||
elif model_id is None and model_version is None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can we only add the tag There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. thanks changed it in the new revision |
||
headers = f"{sagemaker_python_sdk_headers} {hub_specific_suffix}" | ||
else: | ||
headers = ( | ||
f"{sagemaker_python_sdk_headers} {jumpstart_specific_suffix} {hub_specific_suffix}" | ||
) | ||
|
||
return headers | ||
|
||
|
||
def get_default_jumpstart_session_with_user_agent_suffix( | ||
model_id: str, model_version: str | ||
model_id: Optional[str] = None, | ||
model_version: Optional[str] = None, | ||
is_hub_content: Optional[bool] = False, | ||
) -> Session: | ||
"""Returns default JumpStart SageMaker Session with model-specific user agent suffix.""" | ||
botocore_session = botocore.session.get_session() | ||
botocore_config = botocore.config.Config( | ||
user_agent_extra=get_jumpstart_user_agent_extra_suffix(model_id, model_version), | ||
user_agent_extra=get_jumpstart_user_agent_extra_suffix( | ||
model_id, model_version, is_hub_content | ||
), | ||
) | ||
botocore_session.set_default_client_config(botocore_config) | ||
# shallow copy to not affect default session constant | ||
|
Uh oh!
There was an error while loading. Please reload this page.