Skip to content

Commit fc37065

Browse files
authored
Merge branch 'master' into master
2 parents 7215bd6 + e710f1d commit fc37065

File tree

14 files changed

+423
-50
lines changed

14 files changed

+423
-50
lines changed

src/sagemaker/image_uri_config/djl-lmi.json

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,44 @@
33
"inference"
44
],
55
"version_aliases": {
6-
"latest": "0.29.0"
6+
"latest": "0.30.0"
77
},
88
"versions": {
9+
"0.30.0": {
10+
"registries": {
11+
"af-south-1": "626614931356",
12+
"il-central-1": "780543022126",
13+
"ap-east-1": "871362719292",
14+
"ap-northeast-1": "763104351884",
15+
"ap-northeast-2": "763104351884",
16+
"ap-northeast-3": "364406365360",
17+
"ap-south-1": "763104351884",
18+
"ap-southeast-1": "763104351884",
19+
"ap-southeast-2": "763104351884",
20+
"ap-southeast-3": "907027046896",
21+
"ca-central-1": "763104351884",
22+
"cn-north-1": "727897471807",
23+
"cn-northwest-1": "727897471807",
24+
"eu-central-1": "763104351884",
25+
"eu-north-1": "763104351884",
26+
"eu-west-1": "763104351884",
27+
"eu-west-2": "763104351884",
28+
"eu-west-3": "763104351884",
29+
"eu-south-1": "692866216735",
30+
"me-south-1": "217643126080",
31+
"me-central-1": "914824155844",
32+
"sa-east-1": "763104351884",
33+
"us-east-1": "763104351884",
34+
"us-east-2": "763104351884",
35+
"us-gov-east-1": "446045086412",
36+
"us-gov-west-1": "442386744353",
37+
"us-west-1": "763104351884",
38+
"us-west-2": "763104351884",
39+
"ca-west-1": "204538143572"
40+
},
41+
"repository": "djl-inference",
42+
"tag_prefix": "0.30.0-lmi12.0.0-cu124"
43+
},
944
"0.29.0": {
1045
"registries": {
1146
"af-south-1": "626614931356",

src/sagemaker/image_uri_config/djl-tensorrtllm.json

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,43 @@
33
"inference"
44
],
55
"version_aliases": {
6-
"latest": "0.29.0"
6+
"latest": "0.30.0"
77
},
88
"versions": {
9+
"0.30.0": {
10+
"registries": {
11+
"af-south-1": "626614931356",
12+
"il-central-1": "780543022126",
13+
"ap-east-1": "871362719292",
14+
"ap-northeast-1": "763104351884",
15+
"ap-northeast-2": "763104351884",
16+
"ap-northeast-3": "364406365360",
17+
"ap-south-1": "763104351884",
18+
"ap-southeast-1": "763104351884",
19+
"ap-southeast-2": "763104351884",
20+
"ap-southeast-3": "907027046896",
21+
"ca-central-1": "763104351884",
22+
"cn-north-1": "727897471807",
23+
"cn-northwest-1": "727897471807",
24+
"eu-central-1": "763104351884",
25+
"eu-north-1": "763104351884",
26+
"eu-west-1": "763104351884",
27+
"eu-west-2": "763104351884",
28+
"eu-west-3": "763104351884",
29+
"eu-south-1": "692866216735",
30+
"me-south-1": "217643126080",
31+
"sa-east-1": "763104351884",
32+
"us-east-1": "763104351884",
33+
"us-east-2": "763104351884",
34+
"us-gov-east-1": "446045086412",
35+
"us-gov-west-1": "442386744353",
36+
"us-west-1": "763104351884",
37+
"us-west-2": "763104351884",
38+
"ca-west-1": "204538143572"
39+
},
40+
"repository": "djl-inference",
41+
"tag_prefix": "0.30.0-tensorrtllm0.12.0-cu125"
42+
},
943
"0.29.0": {
1044
"registries": {
1145
"af-south-1": "626614931356",

src/sagemaker/image_uri_config/pytorch-smp.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
"2.2": "2.3.1",
1010
"2.2.0": "2.3.1",
1111
"2.3.1": "2.5.0",
12-
"2.4.1": "2.6.0"
12+
"2.4.1": "2.6.1"
1313
},
1414
"versions": {
1515
"2.0.1": {
@@ -162,7 +162,7 @@
162162
},
163163
"repository": "smdistributed-modelparallel"
164164
},
165-
"2.6.0": {
165+
"2.6.1": {
166166
"py_versions": [
167167
"py311"
168168
],

src/sagemaker/jumpstart/hub/interfaces.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -451,12 +451,14 @@ def from_json(self, json_obj: str) -> None:
451451
class HubModelDocument(HubDataHolderType):
452452
"""Data class for model type HubContentDocument from session.describe_hub_content()."""
453453

454-
SCHEMA_VERSION = "2.2.0"
454+
SCHEMA_VERSION = "2.3.0"
455455

456456
__slots__ = [
457457
"url",
458458
"min_sdk_version",
459459
"training_supported",
460+
"model_types",
461+
"capabilities",
460462
"incremental_training_supported",
461463
"dynamic_container_deployment_supported",
462464
"hosting_ecr_uri",
@@ -469,6 +471,7 @@ class HubModelDocument(HubDataHolderType):
469471
"hosting_use_script_uri",
470472
"hosting_eula_uri",
471473
"hosting_model_package_arn",
474+
"model_subscription_link",
472475
"inference_configs",
473476
"inference_config_components",
474477
"inference_config_rankings",
@@ -549,18 +552,22 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
549552
Args:
550553
json_obj (Dict[str, Any]): Dictionary representation of hub model document.
551554
"""
552-
self.url: str = json_obj["Url"]
553-
self.min_sdk_version: str = json_obj["MinSdkVersion"]
554-
self.hosting_ecr_uri: Optional[str] = json_obj["HostingEcrUri"]
555-
self.hosting_artifact_uri = json_obj["HostingArtifactUri"]
556-
self.hosting_script_uri = json_obj["HostingScriptUri"]
557-
self.inference_dependencies: List[str] = json_obj["InferenceDependencies"]
555+
self.url: str = json_obj.get("Url")
556+
self.min_sdk_version: str = json_obj.get("MinSdkVersion")
557+
self.hosting_ecr_uri: Optional[str] = json_obj.get("HostingEcrUri")
558+
self.hosting_artifact_uri = json_obj.get("HostingArtifactUri")
559+
self.hosting_script_uri = json_obj.get("HostingScriptUri")
560+
self.inference_dependencies: List[str] = json_obj.get("InferenceDependencies")
558561
self.inference_environment_variables: List[JumpStartEnvironmentVariable] = [
559562
JumpStartEnvironmentVariable(env_variable, is_hub_content=True)
560-
for env_variable in json_obj["InferenceEnvironmentVariables"]
563+
for env_variable in json_obj.get("InferenceEnvironmentVariables", [])
561564
]
562-
self.training_supported: bool = bool(json_obj["TrainingSupported"])
563-
self.incremental_training_supported: bool = bool(json_obj["IncrementalTrainingSupported"])
565+
self.model_types: Optional[List[str]] = json_obj.get("ModelTypes")
566+
self.capabilities: Optional[List[str]] = json_obj.get("Capabilities")
567+
self.training_supported: bool = bool(json_obj.get("TrainingSupported"))
568+
self.incremental_training_supported: bool = bool(
569+
json_obj.get("IncrementalTrainingSupported")
570+
)
564571
self.dynamic_container_deployment_supported: Optional[bool] = (
565572
bool(json_obj.get("DynamicContainerDeploymentSupported"))
566573
if json_obj.get("DynamicContainerDeploymentSupported")
@@ -586,6 +593,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
586593
self.hosting_eula_uri: Optional[str] = json_obj.get("HostingEulaUri")
587594
self.hosting_model_package_arn: Optional[str] = json_obj.get("HostingModelPackageArn")
588595

596+
self.model_subscription_link: Optional[str] = json_obj.get("ModelSubscriptionLink")
597+
589598
self.inference_config_rankings = self._get_config_rankings(json_obj)
590599
self.inference_config_components = self._get_config_components(json_obj)
591600
self.inference_configs = self._get_configs(json_obj)

src/sagemaker/jumpstart/hub/parser_utils.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,11 @@
1919

2020

2121
def camel_to_snake(camel_case_string: str) -> str:
22-
"""Converts camelCaseString or UpperCamelCaseString to snake_case_string."""
23-
snake_case_string = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel_case_string)
24-
if "-" in snake_case_string:
25-
# remove any hyphen from the string for accurate conversion.
26-
snake_case_string = snake_case_string.replace("-", "")
27-
return re.sub("([a-z0-9])([A-Z])", r"\1_\2", snake_case_string).lower()
22+
"""Converts PascalCase to snake_case_string using a regex.
23+
24+
This regex cannot handle whitespace ("PascalString TwoWords")
25+
"""
26+
return re.sub(r"(?<!^)(?=[A-Z])", "_", camel_case_string).lower()
2827

2928

3029
def snake_to_upper_camel(snake_case_string: str) -> str:

src/sagemaker/jumpstart/hub/parsers.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,8 @@ def make_model_specs_from_describe_hub_content_response(
137137
hub_model_document: HubModelDocument = response.hub_content_document
138138
specs["url"] = hub_model_document.url
139139
specs["min_sdk_version"] = hub_model_document.min_sdk_version
140+
specs["model_types"] = hub_model_document.model_types
141+
specs["capabilities"] = hub_model_document.capabilities
140142
specs["training_supported"] = bool(hub_model_document.training_supported)
141143
specs["incremental_training_supported"] = bool(
142144
hub_model_document.incremental_training_supported
@@ -146,15 +148,19 @@ def make_model_specs_from_describe_hub_content_response(
146148
specs["inference_config_components"] = hub_model_document.inference_config_components
147149
specs["inference_config_rankings"] = hub_model_document.inference_config_rankings
148150

149-
hosting_artifact_bucket, hosting_artifact_key = parse_s3_url( # pylint: disable=unused-variable
150-
hub_model_document.hosting_artifact_uri
151-
)
152-
specs["hosting_artifact_key"] = hosting_artifact_key
153-
specs["hosting_artifact_uri"] = hub_model_document.hosting_artifact_uri
154-
hosting_script_bucket, hosting_script_key = parse_s3_url( # pylint: disable=unused-variable
155-
hub_model_document.hosting_script_uri
156-
)
157-
specs["hosting_script_key"] = hosting_script_key
151+
if hub_model_document.hosting_artifact_uri:
152+
_, hosting_artifact_key = parse_s3_url( # pylint: disable=unused-variable
153+
hub_model_document.hosting_artifact_uri
154+
)
155+
specs["hosting_artifact_key"] = hosting_artifact_key
156+
specs["hosting_artifact_uri"] = hub_model_document.hosting_artifact_uri
157+
158+
if hub_model_document.hosting_script_uri:
159+
_, hosting_script_key = parse_s3_url( # pylint: disable=unused-variable
160+
hub_model_document.hosting_script_uri
161+
)
162+
specs["hosting_script_key"] = hosting_script_key
163+
158164
specs["inference_environment_variables"] = hub_model_document.inference_environment_variables
159165
specs["inference_vulnerable"] = False
160166
specs["inference_dependencies"] = hub_model_document.inference_dependencies
@@ -220,6 +226,8 @@ def make_model_specs_from_describe_hub_content_response(
220226
if hub_model_document.hosting_model_package_arn:
221227
specs["hosting_model_package_arns"] = {region: hub_model_document.hosting_model_package_arn}
222228

229+
specs["model_subscription_link"] = hub_model_document.model_subscription_link
230+
223231
specs["hosting_use_script_uri"] = hub_model_document.hosting_use_script_uri
224232

225233
specs["hosting_instance_type_variants"] = hub_model_document.hosting_instance_type_variants

src/sagemaker/jumpstart/hub/utils.py

Lines changed: 65 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
"""This module contains utilities related to SageMaker JumpStart Hub."""
1515
from __future__ import absolute_import
1616
import re
17-
from typing import Optional
17+
from typing import Optional, List, Any
1818
from sagemaker.jumpstart.hub.types import S3ObjectLocation
1919
from sagemaker.s3_utils import parse_s3_url
2020
from sagemaker.session import Session
@@ -23,6 +23,14 @@
2323
from sagemaker.jumpstart import constants
2424
from packaging.specifiers import SpecifierSet, InvalidSpecifier
2525

26+
PROPRIETARY_VERSION_KEYWORD = "@marketplace-version:"
27+
28+
29+
def _convert_str_to_optional(string: str) -> Optional[str]:
30+
if string == "None":
31+
string = None
32+
return string
33+
2634

2735
def get_info_from_hub_resource_arn(
2836
arn: str,
@@ -37,7 +45,7 @@ def get_info_from_hub_resource_arn(
3745
hub_name = match.group(4)
3846
hub_content_type = match.group(5)
3947
hub_content_name = match.group(6)
40-
hub_content_version = match.group(7)
48+
hub_content_version = _convert_str_to_optional(match.group(7))
4149

4250
return HubArnExtractedInfo(
4351
partition=partition,
@@ -194,10 +202,14 @@ def get_hub_model_version(
194202
hub_model_version: Optional[str] = None,
195203
sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
196204
) -> str:
197-
"""Returns available Jumpstart hub model version
205+
"""Returns available Jumpstart hub model version.
206+
207+
It will attempt both a semantic HubContent version search and Marketplace version search.
208+
If the Marketplace version is also semantic, this function will default to HubContent version.
198209
199210
Raises:
200211
ClientError: If the specified model is not found in the hub.
212+
KeyError: If the specified model version is not found.
201213
"""
202214

203215
try:
@@ -207,6 +219,22 @@ def get_hub_model_version(
207219
except Exception as ex:
208220
raise Exception(f"Failed calling list_hub_content_versions: {str(ex)}")
209221

222+
try:
223+
return _get_hub_model_version_for_open_weight_version(
224+
hub_content_summaries, hub_model_version
225+
)
226+
except KeyError:
227+
marketplace_hub_content_version = _get_hub_model_version_for_marketplace_version(
228+
hub_content_summaries, hub_model_version
229+
)
230+
if marketplace_hub_content_version:
231+
return marketplace_hub_content_version
232+
raise
233+
234+
235+
def _get_hub_model_version_for_open_weight_version(
236+
hub_content_summaries: List[Any], hub_model_version: Optional[str] = None
237+
) -> str:
210238
available_model_versions = [model.get("HubContentVersion") for model in hub_content_summaries]
211239

212240
if hub_model_version == "*" or hub_model_version is None:
@@ -222,3 +250,37 @@ def get_hub_model_version(
222250
hub_model_version = str(max(available_versions_filtered))
223251

224252
return hub_model_version
253+
254+
255+
def _get_hub_model_version_for_marketplace_version(
256+
hub_content_summaries: List[Any], marketplace_version: str
257+
) -> Optional[str]:
258+
"""Returns the HubContent version associated with the Marketplace version.
259+
260+
This function will check within the HubContentSearchKeywords for the proprietary version.
261+
"""
262+
for model in hub_content_summaries:
263+
model_search_keywords = model.get("HubContentSearchKeywords", [])
264+
if _hub_search_keywords_contains_marketplace_version(
265+
model_search_keywords, marketplace_version
266+
):
267+
return model.get("HubContentVersion")
268+
269+
return None
270+
271+
272+
def _hub_search_keywords_contains_marketplace_version(
273+
model_search_keywords: List[str], marketplace_version: str
274+
) -> bool:
275+
proprietary_version_keyword = next(
276+
filter(lambda s: s.startswith(PROPRIETARY_VERSION_KEYWORD), model_search_keywords), None
277+
)
278+
279+
if not proprietary_version_keyword:
280+
return False
281+
282+
proprietary_version = proprietary_version_keyword.lstrip(PROPRIETARY_VERSION_KEYWORD)
283+
if proprietary_version == marketplace_version:
284+
return True
285+
286+
return False

src/sagemaker/jumpstart/types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1200,6 +1200,8 @@ class JumpStartMetadataBaseFields(JumpStartDataHolderType):
12001200
"url",
12011201
"version",
12021202
"min_sdk_version",
1203+
"model_types",
1204+
"capabilities",
12031205
"incremental_training_supported",
12041206
"hosting_ecr_specs",
12051207
"hosting_ecr_uri",
@@ -1287,6 +1289,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
12871289
json_obj.get("incremental_training_supported", False)
12881290
)
12891291
if self._is_hub_content:
1292+
self.capabilities: Optional[List[str]] = json_obj.get("capabilities")
1293+
self.model_types: Optional[List[str]] = json_obj.get("model_types")
12901294
self.hosting_ecr_uri: Optional[str] = json_obj.get("hosting_ecr_uri")
12911295
self._non_serializable_slots.append("hosting_ecr_specs")
12921296
else:

0 commit comments

Comments
 (0)