Skip to content

feat: Marketplace model support in HubService #4900

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 19 additions & 10 deletions src/sagemaker/jumpstart/hub/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,12 +451,14 @@ def from_json(self, json_obj: str) -> None:
class HubModelDocument(HubDataHolderType):
"""Data class for model type HubContentDocument from session.describe_hub_content()."""

SCHEMA_VERSION = "2.2.0"
SCHEMA_VERSION = "2.3.0"

__slots__ = [
"url",
"min_sdk_version",
"training_supported",
"model_types",
"capabilities",
"incremental_training_supported",
"dynamic_container_deployment_supported",
"hosting_ecr_uri",
Expand All @@ -469,6 +471,7 @@ class HubModelDocument(HubDataHolderType):
"hosting_use_script_uri",
"hosting_eula_uri",
"hosting_model_package_arn",
"model_subscription_link",
"inference_configs",
"inference_config_components",
"inference_config_rankings",
Expand Down Expand Up @@ -549,18 +552,22 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
Args:
json_obj (Dict[str, Any]): Dictionary representation of hub model document.
"""
self.url: str = json_obj["Url"]
self.min_sdk_version: str = json_obj["MinSdkVersion"]
self.hosting_ecr_uri: Optional[str] = json_obj["HostingEcrUri"]
self.hosting_artifact_uri = json_obj["HostingArtifactUri"]
self.hosting_script_uri = json_obj["HostingScriptUri"]
self.inference_dependencies: List[str] = json_obj["InferenceDependencies"]
self.url: str = json_obj.get("Url")
self.min_sdk_version: str = json_obj.get("MinSdkVersion")
self.hosting_ecr_uri: Optional[str] = json_obj.get("HostingEcrUri")
self.hosting_artifact_uri = json_obj.get("HostingArtifactUri")
self.hosting_script_uri = json_obj.get("HostingScriptUri")
self.inference_dependencies: List[str] = json_obj.get("InferenceDependencies")
self.inference_environment_variables: List[JumpStartEnvironmentVariable] = [
JumpStartEnvironmentVariable(env_variable, is_hub_content=True)
for env_variable in json_obj["InferenceEnvironmentVariables"]
for env_variable in json_obj.get("InferenceEnvironmentVariables", [])
]
self.training_supported: bool = bool(json_obj["TrainingSupported"])
self.incremental_training_supported: bool = bool(json_obj["IncrementalTrainingSupported"])
self.model_types: Optional[List[str]] = json_obj.get("ModelTypes")
self.capabilities: Optional[List[str]] = json_obj.get("Capabilities")
self.training_supported: bool = bool(json_obj.get("TrainingSupported"))
self.incremental_training_supported: bool = bool(
json_obj.get("IncrementalTrainingSupported")
)
self.dynamic_container_deployment_supported: Optional[bool] = (
bool(json_obj.get("DynamicContainerDeploymentSupported"))
if json_obj.get("DynamicContainerDeploymentSupported")
Expand All @@ -586,6 +593,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
self.hosting_eula_uri: Optional[str] = json_obj.get("HostingEulaUri")
self.hosting_model_package_arn: Optional[str] = json_obj.get("HostingModelPackageArn")

self.model_subscription_link: Optional[str] = json_obj.get("ModelSubscriptionLink")

self.inference_config_rankings = self._get_config_rankings(json_obj)
self.inference_config_components = self._get_config_components(json_obj)
self.inference_configs = self._get_configs(json_obj)
Expand Down
10 changes: 3 additions & 7 deletions src/sagemaker/jumpstart/hub/parser_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,9 @@
from typing import Any, Dict, List, Optional


def camel_to_snake(camel_case_string: str) -> str:
    """Convert a camelCase or UpperCamelCase string to snake_case.

    The conversion runs in three steps: separate a capitalised word from the
    character preceding it, drop every hyphen, then separate any remaining
    lowercase-or-digit/uppercase pair before lowercasing the result.

    Args:
        camel_case_string (str): The string to convert.

    Returns:
        str: The lowercased, underscore-separated form of the input.
    """
    # Step 1: put "_" between any character and a following "Xyz"-shaped word.
    word_split = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel_case_string)
    # Step 2: hyphens are removed for accurate conversion; replace() leaves
    # the string untouched when no hyphen is present.
    hyphen_free = word_split.replace("-", "")
    # Step 3: split the remaining lower/digit -> upper transitions, then lowercase.
    fully_split = re.sub("([a-z0-9])([A-Z])", r"\1_\2", hyphen_free)
    return fully_split.lower()
def pascal_to_snake(camel_case_string: str) -> str:
    """Convert a PascalCase string to snake_case.

    An underscore is inserted in front of every ASCII uppercase letter except
    one at the very start of the string, and the result is lowercased. Each
    uppercase letter is split individually, so consecutive capitals are
    separated from one another as well.

    Args:
        camel_case_string (str): The PascalCase string to convert.

    Returns:
        str: The lowercased, underscore-separated form of the input.
    """
    pieces = []
    for position, char in enumerate(camel_case_string):
        # Only ASCII A-Z count as word starts; index 0 never gets a separator.
        if position and "A" <= char <= "Z":
            pieces.append("_")
        pieces.append(char)
    return "".join(pieces).lower()


def snake_to_upper_camel(snake_case_string: str) -> str:
Expand Down
34 changes: 22 additions & 12 deletions src/sagemaker/jumpstart/hub/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
HubModelDocument,
)
from sagemaker.jumpstart.hub.parser_utils import (
camel_to_snake,
pascal_to_snake,
snake_to_upper_camel,
walk_and_apply_json,
)
Expand Down Expand Up @@ -86,7 +86,7 @@ def get_model_spec_arg_keys(
arg_keys = []

if naming_convention == NamingConventionType.SNAKE_CASE:
arg_keys = [camel_to_snake(key) for key in arg_keys]
arg_keys = [pascal_to_snake(key) for key in arg_keys]
elif naming_convention == NamingConventionType.UPPER_CAMEL_CASE:
return arg_keys
else:
Expand Down Expand Up @@ -137,6 +137,8 @@ def make_model_specs_from_describe_hub_content_response(
hub_model_document: HubModelDocument = response.hub_content_document
specs["url"] = hub_model_document.url
specs["min_sdk_version"] = hub_model_document.min_sdk_version
specs["model_types"] = hub_model_document.model_types
specs["capabilities"] = hub_model_document.capabilities
specs["training_supported"] = bool(hub_model_document.training_supported)
specs["incremental_training_supported"] = bool(
hub_model_document.incremental_training_supported
Expand All @@ -146,15 +148,19 @@ def make_model_specs_from_describe_hub_content_response(
specs["inference_config_components"] = hub_model_document.inference_config_components
specs["inference_config_rankings"] = hub_model_document.inference_config_rankings

hosting_artifact_bucket, hosting_artifact_key = parse_s3_url( # pylint: disable=unused-variable
hub_model_document.hosting_artifact_uri
)
specs["hosting_artifact_key"] = hosting_artifact_key
specs["hosting_artifact_uri"] = hub_model_document.hosting_artifact_uri
hosting_script_bucket, hosting_script_key = parse_s3_url( # pylint: disable=unused-variable
hub_model_document.hosting_script_uri
)
specs["hosting_script_key"] = hosting_script_key
if hub_model_document.hosting_artifact_uri:
_, hosting_artifact_key = parse_s3_url( # pylint: disable=unused-variable
hub_model_document.hosting_artifact_uri
)
specs["hosting_artifact_key"] = hosting_artifact_key
specs["hosting_artifact_uri"] = hub_model_document.hosting_artifact_uri

if hub_model_document.hosting_script_uri:
_, hosting_script_key = parse_s3_url( # pylint: disable=unused-variable
hub_model_document.hosting_script_uri
)
specs["hosting_script_key"] = hosting_script_key

specs["inference_environment_variables"] = hub_model_document.inference_environment_variables
specs["inference_vulnerable"] = False
specs["inference_dependencies"] = hub_model_document.inference_dependencies
Expand Down Expand Up @@ -201,7 +207,7 @@ def make_model_specs_from_describe_hub_content_response(
default_payloads: Dict[str, Any] = {}
if hub_model_document.default_payloads is not None:
for alias, payload in hub_model_document.default_payloads.items():
default_payloads[alias] = walk_and_apply_json(payload.to_json(), camel_to_snake)
default_payloads[alias] = walk_and_apply_json(payload.to_json(), pascal_to_snake)
specs["default_payloads"] = default_payloads
specs["gated_bucket"] = hub_model_document.gated_bucket
specs["inference_volume_size"] = hub_model_document.inference_volume_size
Expand All @@ -219,6 +225,10 @@ def make_model_specs_from_describe_hub_content_response(

if hub_model_document.hosting_model_package_arn:
specs["hosting_model_package_arns"] = {region: hub_model_document.hosting_model_package_arn}

specs["model_subscription_link"] = hub_model_document.model_subscription_link

specs["model_subscription_link"] = hub_model_document.model_subscription_link

specs["hosting_use_script_uri"] = hub_model_document.hosting_use_script_uri

Expand Down
57 changes: 54 additions & 3 deletions src/sagemaker/jumpstart/hub/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""This module contains utilities related to SageMaker JumpStart Hub."""
from __future__ import absolute_import
import re
from typing import Optional
from typing import Optional, List, Any
from sagemaker.jumpstart.hub.types import S3ObjectLocation
from sagemaker.s3_utils import parse_s3_url
from sagemaker.session import Session
Expand All @@ -23,6 +23,8 @@
from sagemaker.jumpstart import constants
from packaging.specifiers import SpecifierSet, InvalidSpecifier

PROPRIETARY_VERSION_KEYWORD = "@marketplace-version:"


def get_info_from_hub_resource_arn(
arn: str,
Expand Down Expand Up @@ -207,6 +209,24 @@ def get_hub_model_version(
except Exception as ex:
raise Exception(f"Failed calling list_hub_content_versions: {str(ex)}")

open_weight_hub_content_version = _get_open_weight_hub_model_version(
hub_content_summaries, hub_model_version
)
if open_weight_hub_content_version:
return open_weight_hub_content_version

proprietary_hub_content_version = _get_proprietary_hub_model_version(
hub_content_summaries, hub_model_version
)
if proprietary_hub_content_version:
return proprietary_hub_content_version

raise KeyError(f"Could not find HubContent with specified version: {hub_model_version}")


def _get_open_weight_hub_model_version(
hub_content_summaries: List[Any], hub_model_version: Optional[str] = None
) -> Optional[str]:
available_model_versions = [model.get("HubContentVersion") for model in hub_content_summaries]

if hub_model_version == "*" or hub_model_version is None:
Expand All @@ -215,10 +235,41 @@ def get_hub_model_version(
try:
spec = SpecifierSet(f"=={hub_model_version}")
except InvalidSpecifier:
raise KeyError(f"Bad semantic version: {hub_model_version}")
return None
available_versions_filtered = list(spec.filter(available_model_versions))
if not available_versions_filtered:
raise KeyError("Model version not available in the Hub")
return None
hub_model_version = str(max(available_versions_filtered))

return hub_model_version


def _get_proprietary_hub_model_version(
hub_content_summaries: List[Any], proprietary_hub_model_version: str
) -> Optional[str]:

for model in hub_content_summaries:
model_search_keywords = model.get("HubContentSearchKeywords", [])
if _hub_search_keywords_contains_proprietary_version(
model_search_keywords, proprietary_hub_model_version
):
return model.get("HubContentVersion")

return None


def _hub_search_keywords_contains_proprietary_version(
model_search_keywords: List[str], proprietary_hub_model_version: str
) -> bool:
proprietary_version_keyword = next(
filter(lambda s: s.startswith(PROPRIETARY_VERSION_KEYWORD), model_search_keywords), None
)

if not proprietary_version_keyword:
return False

proprietary_version = proprietary_version_keyword.lstrip(PROPRIETARY_VERSION_KEYWORD)
if proprietary_version == proprietary_hub_model_version:
return True

return False
30 changes: 17 additions & 13 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements
from sagemaker.enums import EndpointType
from sagemaker.jumpstart.hub.parser_utils import (
camel_to_snake,
pascal_to_snake,
walk_and_apply_json,
)

Expand Down Expand Up @@ -239,7 +239,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
return

if self._is_hub_content:
json_obj = walk_and_apply_json(json_obj, camel_to_snake)
json_obj = walk_and_apply_json(json_obj, pascal_to_snake)

self.framework = json_obj.get("framework")
self.framework_version = json_obj.get("framework_version")
Expand Down Expand Up @@ -293,7 +293,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
"""

if self._is_hub_content:
json_obj = walk_and_apply_json(json_obj, camel_to_snake)
json_obj = walk_and_apply_json(json_obj, pascal_to_snake)
self.name = json_obj["name"]
self.type = json_obj["type"]
self.default = json_obj["default"]
Expand Down Expand Up @@ -361,7 +361,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
Args:
json_obj (Dict[str, Any]): Dictionary representation of environment variable.
"""
json_obj = walk_and_apply_json(json_obj, camel_to_snake)
json_obj = walk_and_apply_json(json_obj, pascal_to_snake)
self.name = json_obj["name"]
self.type = json_obj["type"]
self.default = json_obj["default"]
Expand Down Expand Up @@ -411,7 +411,7 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None:
return

if self._is_hub_content:
json_obj = walk_and_apply_json(json_obj, camel_to_snake)
json_obj = walk_and_apply_json(json_obj, pascal_to_snake)
self.default_content_type = json_obj["default_content_type"]
self.supported_content_types = json_obj["supported_content_types"]
self.default_accept_type = json_obj["default_accept_type"]
Expand Down Expand Up @@ -465,7 +465,7 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None:
return

if self._is_hub_content:
json_obj = walk_and_apply_json(json_obj, camel_to_snake)
json_obj = walk_and_apply_json(json_obj, pascal_to_snake)
self.raw_payload = json_obj
self.content_type = json_obj["content_type"]
self.body = json_obj.get("body")
Expand Down Expand Up @@ -538,7 +538,7 @@ def from_describe_hub_content_response(self, response: Optional[Dict[str, Any]])
if response is None:
return

response = walk_and_apply_json(response, camel_to_snake)
response = walk_and_apply_json(response, pascal_to_snake)
self.aliases: Optional[dict] = response.get("aliases")
self.regional_aliases = None
self.variants: Optional[dict] = response.get("variants")
Expand Down Expand Up @@ -1174,7 +1174,7 @@ def __init__(self, spec: Optional[Dict[str, Any]], is_hub_content=False):
spec (Dict[str, Any]): Dictionary representation of training config ranking.
"""
if is_hub_content:
spec = walk_and_apply_json(spec, camel_to_snake)
spec = walk_and_apply_json(spec, pascal_to_snake)
self.from_json(spec)

def from_json(self, json_obj: Dict[str, Any]) -> None:
Expand All @@ -1200,6 +1200,8 @@ class JumpStartMetadataBaseFields(JumpStartDataHolderType):
"url",
"version",
"min_sdk_version",
"model_types",
"capabilities",
"incremental_training_supported",
"hosting_ecr_specs",
"hosting_ecr_uri",
Expand Down Expand Up @@ -1278,7 +1280,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
json_obj (Dict[str, Any]): Dictionary representation of spec.
"""
if self._is_hub_content:
json_obj = walk_and_apply_json(json_obj, camel_to_snake)
json_obj = walk_and_apply_json(json_obj, pascal_to_snake)
self.model_id: str = json_obj.get("model_id")
self.url: str = json_obj.get("url")
self.version: str = json_obj.get("version")
Expand All @@ -1287,6 +1289,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
json_obj.get("incremental_training_supported", False)
)
if self._is_hub_content:
self.capabilities: Optional[List[str]] = json_obj.get("capabilities")
self.model_types: Optional[List[str]] = json_obj.get("model_types")
self.hosting_ecr_uri: Optional[str] = json_obj.get("hosting_ecr_uri")
self._non_serializable_slots.append("hosting_ecr_specs")
else:
Expand Down Expand Up @@ -1505,7 +1509,7 @@ def __init__(
ValueError: If the component field is invalid.
"""
if is_hub_content:
component = walk_and_apply_json(component, camel_to_snake)
component = walk_and_apply_json(component, pascal_to_snake)
self.component_name = component_name
super().__init__(component, is_hub_content)
self.from_json(component)
Expand Down Expand Up @@ -1558,8 +1562,8 @@ def __init__(
The list of components that are used to construct the resolved config.
"""
if is_hub_content:
config = walk_and_apply_json(config, camel_to_snake)
base_fields = walk_and_apply_json(base_fields, camel_to_snake)
config = walk_and_apply_json(config, pascal_to_snake)
base_fields = walk_and_apply_json(base_fields, pascal_to_snake)
self.base_fields = base_fields
self.config_components: Dict[str, JumpStartConfigComponent] = config_components
self.benchmark_metrics: Dict[str, List[JumpStartBenchmarkStat]] = (
Expand Down Expand Up @@ -1725,7 +1729,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
"""
super().from_json(json_obj)
if self._is_hub_content:
json_obj = walk_and_apply_json(json_obj, camel_to_snake)
json_obj = walk_and_apply_json(json_obj, pascal_to_snake)
self.inference_config_components: Optional[Dict[str, JumpStartConfigComponent]] = (
{
component_name: JumpStartConfigComponent(component_name, component)
Expand Down
Loading
Loading