Skip to content

Commit 8a5c58c

Browse files
committed
feat: Bedrock Store and Marketplace model support in HubService
Feat: Bedrock Store and Marketplace model support in HubService fix: Adding test for tags fix: Adding more fields fix: Fixing tag generation feat: Adding BRS and marketplace model support fix: Adding better flag fix: Making fields non-required fix: Adding more parsing logic fix: Renaming fix: Adding test fix: reverting comment fix: Adding tests fix: Adding initial prop support logic fix: Adding initial prop management fix: Adding new stuff fix: linting fix: trigger github acitons fix: revertin
1 parent f7c6cd3 commit 8a5c58c

File tree

11 files changed

+251
-60
lines changed

11 files changed

+251
-60
lines changed

src/sagemaker/jumpstart/enums.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,12 @@ class VariableTypes(str, Enum):
8282
BOOL = "bool"
8383

8484

85+
class HubContentCapability(str, Enum):
86+
"""Enum class for HubContent capabilities."""
87+
88+
BEDROCK_CONSOLE = "BEDROCK_CONSOLE"
89+
90+
8591
class JumpStartTag(str, Enum):
8692
"""Enum class for tag keys to apply to JumpStart models."""
8793

@@ -99,6 +105,8 @@ class JumpStartTag(str, Enum):
99105

100106
HUB_CONTENT_ARN = "sagemaker-sdk:hub-content-arn"
101107

108+
BEDROCK = "sagemaker-sdk:bedrock"
109+
102110

103111
class SerializerType(str, Enum):
104112
"""Enum class for serializers associated with JumpStart models."""

src/sagemaker/jumpstart/factory/model.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
from sagemaker.model_metrics import ModelMetrics
4242
from sagemaker.metadata_properties import MetadataProperties
4343
from sagemaker.drift_check_baselines import DriftCheckBaselines
44-
from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType
44+
from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType, HubContentCapability
4545
from sagemaker.jumpstart.types import (
4646
HubContentType,
4747
JumpStartModelDeployKwargs,
@@ -51,7 +51,9 @@
5151
)
5252
from sagemaker.jumpstart.utils import (
5353
add_hub_content_arn_tags,
54+
add_bedrock_store_tags,
5455
add_jumpstart_model_info_tags,
56+
add_bedrock_store_tags,
5557
get_default_jumpstart_session_with_user_agent_suffix,
5658
get_neo_content_bucket,
5759
get_top_ranked_config_name,
@@ -488,6 +490,10 @@ def _add_tags_to_kwargs(kwargs: JumpStartModelDeployKwargs) -> Dict[str, Any]:
488490
)
489491
kwargs.tags = add_hub_content_arn_tags(kwargs.tags, hub_content_arn=hub_content_arn)
490492

493+
if hasattr(kwargs.specs, "capabilities") and kwargs.specs.capabilities is not None:
494+
if HubContentCapability.BEDROCK_CONSOLE in kwargs.specs.capabilities:
495+
kwargs.tags = add_bedrock_store_tags(kwargs.tags, compatibility="compatible")
496+
491497
return kwargs
492498

493499

src/sagemaker/jumpstart/hub/interfaces.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -451,12 +451,14 @@ def from_json(self, json_obj: str) -> None:
451451
class HubModelDocument(HubDataHolderType):
452452
"""Data class for model type HubContentDocument from session.describe_hub_content()."""
453453

454-
SCHEMA_VERSION = "2.2.0"
454+
SCHEMA_VERSION = "2.3.0"
455455

456456
__slots__ = [
457457
"url",
458458
"min_sdk_version",
459459
"training_supported",
460+
"model_types",
461+
"capabilities",
460462
"incremental_training_supported",
461463
"dynamic_container_deployment_supported",
462464
"hosting_ecr_uri",
@@ -469,6 +471,7 @@ class HubModelDocument(HubDataHolderType):
469471
"hosting_use_script_uri",
470472
"hosting_eula_uri",
471473
"hosting_model_package_arn",
474+
"model_subscription_link",
472475
"inference_configs",
473476
"inference_config_components",
474477
"inference_config_rankings",
@@ -549,18 +552,22 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
549552
Args:
550553
json_obj (Dict[str, Any]): Dictionary representation of hub model document.
551554
"""
552-
self.url: str = json_obj["Url"]
553-
self.min_sdk_version: str = json_obj["MinSdkVersion"]
554-
self.hosting_ecr_uri: Optional[str] = json_obj["HostingEcrUri"]
555-
self.hosting_artifact_uri = json_obj["HostingArtifactUri"]
556-
self.hosting_script_uri = json_obj["HostingScriptUri"]
557-
self.inference_dependencies: List[str] = json_obj["InferenceDependencies"]
555+
self.url: str = json_obj.get("Url")
556+
self.min_sdk_version: str = json_obj.get("MinSdkVersion")
557+
self.hosting_ecr_uri: Optional[str] = json_obj.get("HostingEcrUri")
558+
self.hosting_artifact_uri = json_obj.get("HostingArtifactUri")
559+
self.hosting_script_uri = json_obj.get("HostingScriptUri")
560+
self.inference_dependencies: List[str] = json_obj.get("InferenceDependencies")
558561
self.inference_environment_variables: List[JumpStartEnvironmentVariable] = [
559562
JumpStartEnvironmentVariable(env_variable, is_hub_content=True)
560-
for env_variable in json_obj["InferenceEnvironmentVariables"]
563+
for env_variable in json_obj.get("InferenceEnvironmentVariables", [])
561564
]
562-
self.training_supported: bool = bool(json_obj["TrainingSupported"])
563-
self.incremental_training_supported: bool = bool(json_obj["IncrementalTrainingSupported"])
565+
self.model_types: Optional[List[str]] = json_obj.get("ModelTypes")
566+
self.capabilities: Optional[List[str]] = json_obj.get("Capabilities")
567+
self.training_supported: bool = bool(json_obj.get("TrainingSupported"))
568+
self.incremental_training_supported: bool = bool(
569+
json_obj.get("IncrementalTrainingSupported")
570+
)
564571
self.dynamic_container_deployment_supported: Optional[bool] = (
565572
bool(json_obj.get("DynamicContainerDeploymentSupported"))
566573
if json_obj.get("DynamicContainerDeploymentSupported")
@@ -586,6 +593,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
586593
self.hosting_eula_uri: Optional[str] = json_obj.get("HostingEulaUri")
587594
self.hosting_model_package_arn: Optional[str] = json_obj.get("HostingModelPackageArn")
588595

596+
self.model_subscription_link: Optional[str] = json_obj.get("ModelSubscriptionLink")
597+
589598
self.inference_config_rankings = self._get_config_rankings(json_obj)
590599
self.inference_config_components = self._get_config_components(json_obj)
591600
self.inference_configs = self._get_configs(json_obj)

src/sagemaker/jumpstart/hub/parser_utils.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,9 @@
1818
from typing import Any, Dict, List, Optional
1919

2020

21-
def camel_to_snake(camel_case_string: str) -> str:
22-
"""Converts camelCaseString or UpperCamelCaseString to snake_case_string."""
23-
snake_case_string = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel_case_string)
24-
if "-" in snake_case_string:
25-
# remove any hyphen from the string for accurate conversion.
26-
snake_case_string = snake_case_string.replace("-", "")
27-
return re.sub("([a-z0-9])([A-Z])", r"\1_\2", snake_case_string).lower()
21+
def pascal_to_snake(camel_case_string: str) -> str:
22+
"""Converts PascalCase to snake_case_string."""
23+
return re.sub(r"(?<!^)(?=[A-Z])", "_", camel_case_string).lower()
2824

2925

3026
def snake_to_upper_camel(snake_case_string: str) -> str:

src/sagemaker/jumpstart/hub/parsers.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
HubModelDocument,
2828
)
2929
from sagemaker.jumpstart.hub.parser_utils import (
30-
camel_to_snake,
30+
pascal_to_snake,
3131
snake_to_upper_camel,
3232
walk_and_apply_json,
3333
)
@@ -86,7 +86,7 @@ def get_model_spec_arg_keys(
8686
arg_keys = []
8787

8888
if naming_convention == NamingConventionType.SNAKE_CASE:
89-
arg_keys = [camel_to_snake(key) for key in arg_keys]
89+
arg_keys = [pascal_to_snake(key) for key in arg_keys]
9090
elif naming_convention == NamingConventionType.UPPER_CAMEL_CASE:
9191
return arg_keys
9292
else:
@@ -137,6 +137,8 @@ def make_model_specs_from_describe_hub_content_response(
137137
hub_model_document: HubModelDocument = response.hub_content_document
138138
specs["url"] = hub_model_document.url
139139
specs["min_sdk_version"] = hub_model_document.min_sdk_version
140+
specs["model_types"] = hub_model_document.model_types
141+
specs["capabilities"] = hub_model_document.capabilities
140142
specs["training_supported"] = bool(hub_model_document.training_supported)
141143
specs["incremental_training_supported"] = bool(
142144
hub_model_document.incremental_training_supported
@@ -146,15 +148,19 @@ def make_model_specs_from_describe_hub_content_response(
146148
specs["inference_config_components"] = hub_model_document.inference_config_components
147149
specs["inference_config_rankings"] = hub_model_document.inference_config_rankings
148150

149-
hosting_artifact_bucket, hosting_artifact_key = parse_s3_url( # pylint: disable=unused-variable
150-
hub_model_document.hosting_artifact_uri
151-
)
152-
specs["hosting_artifact_key"] = hosting_artifact_key
153-
specs["hosting_artifact_uri"] = hub_model_document.hosting_artifact_uri
154-
hosting_script_bucket, hosting_script_key = parse_s3_url( # pylint: disable=unused-variable
155-
hub_model_document.hosting_script_uri
156-
)
157-
specs["hosting_script_key"] = hosting_script_key
151+
if hub_model_document.hosting_artifact_uri:
152+
_, hosting_artifact_key = parse_s3_url( # pylint: disable=unused-variable
153+
hub_model_document.hosting_artifact_uri
154+
)
155+
specs["hosting_artifact_key"] = hosting_artifact_key
156+
specs["hosting_artifact_uri"] = hub_model_document.hosting_artifact_uri
157+
158+
if hub_model_document.hosting_script_uri:
159+
_, hosting_script_key = parse_s3_url( # pylint: disable=unused-variable
160+
hub_model_document.hosting_script_uri
161+
)
162+
specs["hosting_script_key"] = hosting_script_key
163+
158164
specs["inference_environment_variables"] = hub_model_document.inference_environment_variables
159165
specs["inference_vulnerable"] = False
160166
specs["inference_dependencies"] = hub_model_document.inference_dependencies
@@ -201,7 +207,7 @@ def make_model_specs_from_describe_hub_content_response(
201207
default_payloads: Dict[str, Any] = {}
202208
if hub_model_document.default_payloads is not None:
203209
for alias, payload in hub_model_document.default_payloads.items():
204-
default_payloads[alias] = walk_and_apply_json(payload.to_json(), camel_to_snake)
210+
default_payloads[alias] = walk_and_apply_json(payload.to_json(), pascal_to_snake)
205211
specs["default_payloads"] = default_payloads
206212
specs["gated_bucket"] = hub_model_document.gated_bucket
207213
specs["inference_volume_size"] = hub_model_document.inference_volume_size
@@ -219,6 +225,10 @@ def make_model_specs_from_describe_hub_content_response(
219225

220226
if hub_model_document.hosting_model_package_arn:
221227
specs["hosting_model_package_arns"] = {region: hub_model_document.hosting_model_package_arn}
228+
229+
specs["model_subscription_link"] = hub_model_document.model_subscription_link
230+
231+
specs["model_subscription_link"] = hub_model_document.model_subscription_link
222232

223233
specs["hosting_use_script_uri"] = hub_model_document.hosting_use_script_uri
224234

src/sagemaker/jumpstart/hub/utils.py

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
"""This module contains utilities related to SageMaker JumpStart Hub."""
1515
from __future__ import absolute_import
1616
import re
17-
from typing import Optional
17+
from typing import Optional, List, Any
1818
from sagemaker.jumpstart.hub.types import S3ObjectLocation
1919
from sagemaker.s3_utils import parse_s3_url
2020
from sagemaker.session import Session
@@ -23,6 +23,8 @@
2323
from sagemaker.jumpstart import constants
2424
from packaging.specifiers import SpecifierSet, InvalidSpecifier
2525

26+
PROPRIETARY_VERSION_KEYWORD = "@marketplace-version:"
27+
2628

2729
def get_info_from_hub_resource_arn(
2830
arn: str,
@@ -117,8 +119,8 @@ def generate_hub_arn_for_init_kwargs(
117119

118120
hub_arn = None
119121
if hub_name:
120-
if hub_name == constants.JUMPSTART_MODEL_HUB_NAME:
121-
return None
122+
# if hub_name == constants.JUMPSTART_MODEL_HUB_NAME:
123+
# return None
122124
match = re.match(constants.HUB_ARN_REGEX, hub_name)
123125
if match:
124126
hub_arn = hub_name
@@ -207,6 +209,24 @@ def get_hub_model_version(
207209
except Exception as ex:
208210
raise Exception(f"Failed calling list_hub_content_versions: {str(ex)}")
209211

212+
open_weight_hub_content_version = _get_open_weight_hub_model_version(
213+
hub_content_summaries, hub_model_version
214+
)
215+
if open_weight_hub_content_version:
216+
return open_weight_hub_content_version
217+
218+
proprietary_hub_content_version = _get_proprietary_hub_model_version(
219+
hub_content_summaries, hub_model_version
220+
)
221+
if proprietary_hub_content_version:
222+
return proprietary_hub_content_version
223+
224+
raise KeyError(f"Could not find HubContent with specified version: {hub_model_version}")
225+
226+
227+
def _get_open_weight_hub_model_version(
228+
hub_content_summaries: List[Any], hub_model_version: Optional[str] = None
229+
) -> Optional[str]:
210230
available_model_versions = [model.get("HubContentVersion") for model in hub_content_summaries]
211231

212232
if hub_model_version == "*" or hub_model_version is None:
@@ -215,10 +235,41 @@ def get_hub_model_version(
215235
try:
216236
spec = SpecifierSet(f"=={hub_model_version}")
217237
except InvalidSpecifier:
218-
raise KeyError(f"Bad semantic version: {hub_model_version}")
238+
return None
219239
available_versions_filtered = list(spec.filter(available_model_versions))
220240
if not available_versions_filtered:
221-
raise KeyError("Model version not available in the Hub")
241+
return None
222242
hub_model_version = str(max(available_versions_filtered))
223243

224244
return hub_model_version
245+
246+
247+
def _get_proprietary_hub_model_version(
248+
hub_content_summaries: List[Any], proprietary_hub_model_version: str
249+
) -> Optional[str]:
250+
251+
for model in hub_content_summaries:
252+
model_search_keywords = model.get("HubContentSearchKeywords", [])
253+
if _hub_search_keywords_contains_proprietary_version(
254+
model_search_keywords, proprietary_hub_model_version
255+
):
256+
return model.get("HubContentVersion")
257+
258+
return None
259+
260+
261+
def _hub_search_keywords_contains_proprietary_version(
262+
model_search_keywords: List[str], proprietary_hub_model_version: str
263+
) -> bool:
264+
proprietary_version_keyword = next(
265+
filter(lambda s: s.startswith(PROPRIETARY_VERSION_KEYWORD), model_search_keywords), None
266+
)
267+
268+
if not proprietary_version_keyword:
269+
return False
270+
271+
proprietary_version = proprietary_version_keyword.lstrip(PROPRIETARY_VERSION_KEYWORD)
272+
if proprietary_version == proprietary_hub_model_version:
273+
return True
274+
275+
return False

0 commit comments

Comments
 (0)