Skip to content

Commit 836deee

Browse files
Merge branch 'master' into master
2 parents 6bb4e7e + 811d3ae commit 836deee

File tree

29 files changed

+662
-157
lines changed

29 files changed

+662
-157
lines changed

CHANGELOG.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,19 @@
11
# Changelog
22

3+
## v2.226.1 (2024-07-17)
4+
5+
## v2.226.0 (2024-07-12)
6+
7+
### Features
8+
9+
* Curated hub improvements
10+
* InferenceSpec support for MMS and testing
11+
12+
### Bug Fixes and Other Changes
13+
14+
* ModelBuilder not passing HF_TOKEN to model.
15+
* update image_uri_configs 07-10-2024 07:18:04 PST
16+
317
## v2.225.0 (2024-07-10)
418

519
### Features

VERSION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2.225.1.dev0
1+
2.226.2.dev0

src/sagemaker/image_uri_config/pytorch.json

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1008,6 +1008,7 @@
10081008
"us-gov-west-1": "442386744353",
10091009
"us-iso-east-1": "886529160074",
10101010
"us-isob-east-1": "094389454867",
1011+
"us-isof-south-1": "454834333376",
10111012
"us-west-1": "763104351884",
10121013
"us-west-2": "763104351884"
10131014
},
@@ -1051,6 +1052,7 @@
10511052
"us-gov-west-1": "442386744353",
10521053
"us-iso-east-1": "886529160074",
10531054
"us-isob-east-1": "094389454867",
1055+
"us-isof-south-1": "454834333376",
10541056
"us-west-1": "763104351884",
10551057
"us-west-2": "763104351884"
10561058
},
@@ -2329,6 +2331,7 @@
23292331
"us-gov-west-1": "442386744353",
23302332
"us-iso-east-1": "886529160074",
23312333
"us-isob-east-1": "094389454867",
2334+
"us-isof-south-1": "454834333376",
23322335
"us-west-1": "763104351884",
23332336
"us-west-2": "763104351884"
23342337
},
@@ -2372,6 +2375,7 @@
23722375
"us-gov-west-1": "442386744353",
23732376
"us-iso-east-1": "886529160074",
23742377
"us-isob-east-1": "094389454867",
2378+
"us-isof-south-1": "454834333376",
23752379
"us-west-1": "763104351884",
23762380
"us-west-2": "763104351884"
23772381
},
@@ -2415,6 +2419,7 @@
24152419
"us-gov-west-1": "442386744353",
24162420
"us-iso-east-1": "886529160074",
24172421
"us-isob-east-1": "094389454867",
2422+
"us-isof-south-1": "454834333376",
24182423
"us-west-1": "763104351884",
24192424
"us-west-2": "763104351884"
24202425
},

src/sagemaker/image_uri_config/tensorflow.json

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2140,6 +2140,7 @@
21402140
"us-gov-west-1": "442386744353",
21412141
"us-iso-east-1": "886529160074",
21422142
"us-isob-east-1": "094389454867",
2143+
"us-isof-south-1": "454834333376",
21432144
"us-west-1": "763104351884",
21442145
"us-west-2": "763104351884"
21452146
},
@@ -2180,6 +2181,7 @@
21802181
"us-gov-west-1": "442386744353",
21812182
"us-iso-east-1": "886529160074",
21822183
"us-isob-east-1": "094389454867",
2184+
"us-isof-south-1": "454834333376",
21832185
"us-west-1": "763104351884",
21842186
"us-west-2": "763104351884"
21852187
},
@@ -4352,6 +4354,7 @@
43524354
"us-gov-west-1": "442386744353",
43534355
"us-iso-east-1": "886529160074",
43544356
"us-isob-east-1": "094389454867",
4357+
"us-isof-south-1": "454834333376",
43554358
"us-west-1": "763104351884",
43564359
"us-west-2": "763104351884"
43574360
},
@@ -4395,6 +4398,7 @@
43954398
"us-gov-west-1": "442386744353",
43964399
"us-iso-east-1": "886529160074",
43974400
"us-isob-east-1": "094389454867",
4401+
"us-isof-south-1": "454834333376",
43984402
"us-west-1": "763104351884",
43994403
"us-west-2": "763104351884"
44004404
},

src/sagemaker/jumpstart/accessors.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
"""This module contains accessors related to SageMaker JumpStart."""
1515
from __future__ import absolute_import
1616
import functools
17+
import logging
1718
from typing import Any, Dict, List, Optional
1819
import boto3
1920

@@ -289,15 +290,6 @@ def get_model_specs(
289290

290291
if hub_arn:
291292
try:
292-
hub_model_arn = construct_hub_model_arn_from_inputs(
293-
hub_arn=hub_arn, model_name=model_id, version=version
294-
)
295-
model_specs = JumpStartModelsAccessor._cache.get_hub_model(
296-
hub_model_arn=hub_model_arn
297-
)
298-
model_specs.set_hub_content_type(HubContentType.MODEL)
299-
return model_specs
300-
except: # noqa: E722
301293
hub_model_arn = construct_hub_model_reference_arn_from_inputs(
302294
hub_arn=hub_arn, model_name=model_id, version=version
303295
)
@@ -307,6 +299,21 @@ def get_model_specs(
307299
model_specs.set_hub_content_type(HubContentType.MODEL_REFERENCE)
308300
return model_specs
309301

302+
except Exception as ex:
303+
logging.info(
304+
"Received exception while calling APIs for ContentType ModelReference, \
305+
retrying with ContentType Model: "
306+
+ str(ex)
307+
)
308+
hub_model_arn = construct_hub_model_arn_from_inputs(
309+
hub_arn=hub_arn, model_name=model_id, version=version
310+
)
311+
model_specs = JumpStartModelsAccessor._cache.get_hub_model(
312+
hub_model_arn=hub_model_arn
313+
)
314+
model_specs.set_hub_content_type(HubContentType.MODEL)
315+
return model_specs
316+
310317
return JumpStartModelsAccessor._cache.get_specs( # type: ignore
311318
model_id=model_id, version_str=version, model_type=model_type
312319
)

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@
2929
_retrieve_model_package_model_artifact_s3_uri,
3030
)
3131
from sagemaker.jumpstart.artifacts.resource_names import _retrieve_resource_name_base
32+
from sagemaker.jumpstart.hub.utils import (
33+
construct_hub_model_arn_from_inputs,
34+
construct_hub_model_reference_arn_from_inputs,
35+
)
3236
from sagemaker.session import Session
3337
from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig
3438
from sagemaker.base_deserializers import BaseDeserializer
@@ -52,6 +56,7 @@
5256
from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType
5357
from sagemaker.jumpstart.factory import model
5458
from sagemaker.jumpstart.types import (
59+
HubContentType,
5560
JumpStartEstimatorDeployKwargs,
5661
JumpStartEstimatorFitKwargs,
5762
JumpStartEstimatorInitKwargs,
@@ -203,6 +208,11 @@ def get_init_kwargs(
203208
estimator_init_kwargs = _add_region_to_kwargs(estimator_init_kwargs)
204209
estimator_init_kwargs = _add_instance_type_and_count_to_kwargs(estimator_init_kwargs)
205210
estimator_init_kwargs = _add_image_uri_to_kwargs(estimator_init_kwargs)
211+
if hub_arn:
212+
estimator_init_kwargs = _add_model_reference_arn_to_kwargs(kwargs=estimator_init_kwargs)
213+
else:
214+
estimator_init_kwargs.model_reference_arn = None
215+
estimator_init_kwargs.hub_content_type = None
206216
estimator_init_kwargs = _add_model_uri_to_kwargs(estimator_init_kwargs)
207217
estimator_init_kwargs = _add_source_dir_to_kwargs(estimator_init_kwargs)
208218
estimator_init_kwargs = _add_entry_point_to_kwargs(estimator_init_kwargs)
@@ -433,7 +443,7 @@ def _add_sagemaker_session_to_kwargs(kwargs: JumpStartKwargs) -> JumpStartKwargs
433443
kwargs.sagemaker_session = (
434444
kwargs.sagemaker_session
435445
or get_default_jumpstart_session_with_user_agent_suffix(
436-
kwargs.model_id, kwargs.model_version
446+
kwargs.model_id, kwargs.model_version, kwargs.hub_arn
437447
)
438448
)
439449
return kwargs
@@ -528,7 +538,15 @@ def _add_tags_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstima
528538
)
529539

530540
if kwargs.hub_arn:
531-
kwargs.tags = add_hub_content_arn_tags(kwargs.tags, kwargs.hub_arn)
541+
if kwargs.model_reference_arn:
542+
hub_content_arn = construct_hub_model_reference_arn_from_inputs(
543+
kwargs.hub_arn, kwargs.model_id, kwargs.model_version
544+
)
545+
else:
546+
hub_content_arn = construct_hub_model_arn_from_inputs(
547+
kwargs.hub_arn, kwargs.model_id, kwargs.model_version
548+
)
549+
kwargs.tags = add_hub_content_arn_tags(kwargs.tags, hub_content_arn=hub_content_arn)
532550

533551
return kwargs
534552

@@ -553,6 +571,33 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE
553571
return kwargs
554572

555573

574+
def _add_model_reference_arn_to_kwargs(
575+
kwargs: JumpStartEstimatorInitKwargs,
576+
) -> JumpStartEstimatorInitKwargs:
577+
"""Sets Model Reference ARN if the hub content type is Model Reference, returns full kwargs."""
578+
579+
hub_content_type = verify_model_region_and_return_specs(
580+
model_id=kwargs.model_id,
581+
version=kwargs.model_version,
582+
hub_arn=kwargs.hub_arn,
583+
scope=JumpStartScriptScope.TRAINING,
584+
region=kwargs.region,
585+
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
586+
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
587+
sagemaker_session=kwargs.sagemaker_session,
588+
model_type=kwargs.model_type,
589+
).hub_content_type
590+
kwargs.hub_content_type = hub_content_type if kwargs.hub_arn else None
591+
592+
if hub_content_type == HubContentType.MODEL_REFERENCE:
593+
kwargs.model_reference_arn = construct_hub_model_reference_arn_from_inputs(
594+
hub_arn=kwargs.hub_arn, model_name=kwargs.model_id, version=kwargs.model_version
595+
)
596+
else:
597+
kwargs.model_reference_arn = None
598+
return kwargs
599+
600+
556601
def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstimatorInitKwargs:
557602
"""Sets model uri in kwargs based on default or override, returns full kwargs."""
558603

src/sagemaker/jumpstart/factory/model.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,10 @@
3434
JUMPSTART_LOGGER,
3535
)
3636
from sagemaker.model_card.model_card import ModelCard, ModelPackageModelCard
37-
from sagemaker.jumpstart.hub.utils import construct_hub_model_reference_arn_from_inputs
37+
from sagemaker.jumpstart.hub.utils import (
38+
construct_hub_model_arn_from_inputs,
39+
construct_hub_model_reference_arn_from_inputs,
40+
)
3841
from sagemaker.model_metrics import ModelMetrics
3942
from sagemaker.metadata_properties import MetadataProperties
4043
from sagemaker.drift_check_baselines import DriftCheckBaselines
@@ -156,12 +159,14 @@ def _add_sagemaker_session_to_kwargs(
156159
kwargs: Union[JumpStartModelInitKwargs, JumpStartModelDeployKwargs]
157160
) -> JumpStartModelInitKwargs:
158161
"""Sets session in kwargs based on default or override, returns full kwargs."""
162+
159163
kwargs.sagemaker_session = (
160164
kwargs.sagemaker_session
161165
or get_default_jumpstart_session_with_user_agent_suffix(
162-
kwargs.model_id, kwargs.model_version
166+
kwargs.model_id, kwargs.model_version, kwargs.hub_arn
163167
)
164168
)
169+
165170
return kwargs
166171

167172

@@ -273,6 +278,7 @@ def _add_model_reference_arn_to_kwargs(
273278
kwargs: JumpStartModelInitKwargs,
274279
) -> JumpStartModelInitKwargs:
275280
"""Sets Model Reference ARN if the hub content type is Model Reference, returns full kwargs."""
281+
276282
hub_content_type = verify_model_region_and_return_specs(
277283
model_id=kwargs.model_id,
278284
version=kwargs.model_version,
@@ -573,7 +579,15 @@ def _add_tags_to_kwargs(kwargs: JumpStartModelDeployKwargs) -> Dict[str, Any]:
573579
)
574580

575581
if kwargs.hub_arn:
576-
kwargs.tags = add_hub_content_arn_tags(kwargs.tags, kwargs.hub_arn)
582+
if kwargs.model_reference_arn:
583+
hub_content_arn = construct_hub_model_reference_arn_from_inputs(
584+
kwargs.hub_arn, kwargs.model_id, kwargs.model_version
585+
)
586+
else:
587+
hub_content_arn = construct_hub_model_arn_from_inputs(
588+
kwargs.hub_arn, kwargs.model_id, kwargs.model_version
589+
)
590+
kwargs.tags = add_hub_content_arn_tags(kwargs.tags, hub_content_arn=hub_content_arn)
577591

578592
return kwargs
579593

src/sagemaker/jumpstart/hub/hub.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from sagemaker.session import Session
2424

2525
from sagemaker.jumpstart.constants import (
26-
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
2726
JUMPSTART_LOGGER,
2827
)
2928
from sagemaker.jumpstart.types import (
@@ -68,7 +67,7 @@ def __init__(
6867
self,
6968
hub_name: str,
7069
bucket_name: Optional[str] = None,
71-
sagemaker_session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
70+
sagemaker_session: Optional[Session] = None,
7271
) -> None:
7372
"""Instantiates a SageMaker ``Hub``.
7473
@@ -79,7 +78,10 @@ def __init__(
7978
"""
8079
self.hub_name = hub_name
8180
self.region = sagemaker_session.boto_region_name
82-
self._sagemaker_session = sagemaker_session
81+
self._sagemaker_session = (
82+
sagemaker_session
83+
or utils.get_default_jumpstart_session_with_user_agent_suffix(is_hub_content=True)
84+
)
8385
self.hub_storage_location = self._generate_hub_storage_location(bucket_name)
8486

8587
def _fetch_hub_bucket_name(self) -> str:
@@ -274,8 +276,8 @@ def describe_model(
274276
try:
275277
model_version = get_hub_model_version(
276278
hub_model_name=model_name,
277-
hub_model_type=HubContentType.MODEL.value,
278-
hub_name=self.hub_name,
279+
hub_model_type=HubContentType.MODEL_REFERENCE.value,
280+
hub_name=self.hub_name if not hub_name else hub_name,
279281
sagemaker_session=self._sagemaker_session,
280282
hub_model_version=model_version,
281283
)
@@ -284,24 +286,27 @@ def describe_model(
284286
hub_name=self.hub_name if not hub_name else hub_name,
285287
hub_content_name=model_name,
286288
hub_content_version=model_version,
287-
hub_content_type=HubContentType.MODEL.value,
289+
hub_content_type=HubContentType.MODEL_REFERENCE.value,
288290
)
289291

290292
except Exception as ex:
291-
logging.info("Recieved expection while calling APIs for ContentType Model: " + str(ex))
293+
logging.info(
294+
"Received exception while calling APIs for ContentType ModelReference, retrying with ContentType Model: "
295+
+ str(ex)
296+
)
292297
model_version = get_hub_model_version(
293298
hub_model_name=model_name,
294-
hub_model_type=HubContentType.MODEL_REFERENCE.value,
295-
hub_name=self.hub_name,
299+
hub_model_type=HubContentType.MODEL.value,
300+
hub_name=self.hub_name if not hub_name else hub_name,
296301
sagemaker_session=self._sagemaker_session,
297302
hub_model_version=model_version,
298303
)
299304

300305
hub_content_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content(
301-
hub_name=self.hub_name,
306+
hub_name=self.hub_name if not hub_name else hub_name,
302307
hub_content_name=model_name,
303308
hub_content_version=model_version,
304-
hub_content_type=HubContentType.MODEL_REFERENCE.value,
309+
hub_content_type=HubContentType.MODEL.value,
305310
)
306311

307312
return DescribeHubContentResponse(hub_content_description)

src/sagemaker/jumpstart/hub/utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,11 @@ def get_hub_model_version(
193193
hub_model_version: Optional[str] = None,
194194
sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
195195
) -> str:
196-
"""Returns available Jumpstart hub model version"""
196+
"""Returns available Jumpstart hub model version
197+
198+
Raises:
199+
ClientError: If the specified model is not found in the hub.
200+
"""
197201

198202
try:
199203
hub_content_summaries = sagemaker_session.list_hub_content_versions(

src/sagemaker/jumpstart/model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -527,6 +527,7 @@ def attach(
527527
model_id: Optional[str] = None,
528528
model_version: Optional[str] = None,
529529
sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
530+
hub_name: Optional[str] = None,
530531
) -> "JumpStartModel":
531532
"""Attaches a JumpStartModel object to an existing SageMaker Endpoint.
532533
@@ -552,6 +553,7 @@ def attach(
552553
model_id=model_id,
553554
model_version=model_version,
554555
sagemaker_session=sagemaker_session,
556+
hub_name=hub_name,
555557
)
556558
model.endpoint_name = endpoint_name
557559
model.inference_component_name = inference_component_name

0 commit comments

Comments
 (0)