Skip to content

Commit 29be8aa

Browse files
authored
Merge branch 'master' into fix-remove-kwargs
2 parents 5dd0b9f + 9f372a5 commit 29be8aa

File tree

8 files changed

+229
-80
lines changed

8 files changed

+229
-80
lines changed

src/sagemaker/image_uri_config/sagemaker-tritonserver.json

Lines changed: 33 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,38 @@
77
"inference"
88
],
99
"versions": {
10+
"24.05": {
11+
"registries": {
12+
"af-south-1": "626614931356",
13+
"il-central-1": "780543022126",
14+
"ap-east-1": "871362719292",
15+
"ap-northeast-1": "763104351884",
16+
"ap-northeast-2": "763104351884",
17+
"ap-northeast-3": "364406365360",
18+
"ap-south-1": "763104351884",
19+
"ap-southeast-1": "763104351884",
20+
"ap-southeast-2": "763104351884",
21+
"ap-southeast-3": "907027046896",
22+
"ca-central-1": "763104351884",
23+
"cn-north-1": "727897471807",
24+
"cn-northwest-1": "727897471807",
25+
"eu-central-1": "763104351884",
26+
"eu-north-1": "763104351884",
27+
"eu-west-1": "763104351884",
28+
"eu-west-2": "763104351884",
29+
"eu-west-3": "763104351884",
30+
"eu-south-1": "692866216735",
31+
"me-south-1": "217643126080",
32+
"sa-east-1": "763104351884",
33+
"us-east-1": "763104351884",
34+
"us-east-2": "763104351884",
35+
"us-west-1": "763104351884",
36+
"us-west-2": "763104351884",
37+
"ca-west-1": "204538143572"
38+
},
39+
"repository": "sagemaker-tritonserver",
40+
"tag_prefix": "24.05-py3"
41+
},
1042
"24.03": {
1143
"registries": {
1244
"af-south-1": "626614931356",
@@ -104,4 +136,4 @@
104136
"tag_prefix": "23.12-py3"
105137
}
106138
}
107-
}
139+
}

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 17 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -69,6 +69,7 @@
6969
add_jumpstart_model_info_tags,
7070
get_eula_message,
7171
get_default_jumpstart_session_with_user_agent_suffix,
72+
get_top_ranked_config_name,
7273
update_dict_if_key_not_present,
7374
resolve_estimator_sagemaker_config_field,
7475
verify_model_region_and_return_specs,
@@ -204,7 +205,9 @@ def get_init_kwargs(
204205

205206
estimator_init_kwargs = _add_model_version_to_kwargs(estimator_init_kwargs)
206207
estimator_init_kwargs = _add_vulnerable_and_deprecated_status_to_kwargs(estimator_init_kwargs)
207-
estimator_init_kwargs = _add_sagemaker_session_to_kwargs(estimator_init_kwargs)
208+
estimator_init_kwargs = _add_sagemaker_session_with_custom_user_agent_to_kwargs(
209+
estimator_init_kwargs
210+
)
208211
estimator_init_kwargs = _add_region_to_kwargs(estimator_init_kwargs)
209212
estimator_init_kwargs = _add_instance_type_and_count_to_kwargs(estimator_init_kwargs)
210213
estimator_init_kwargs = _add_image_uri_to_kwargs(estimator_init_kwargs)
@@ -438,12 +441,17 @@ def _add_region_to_kwargs(kwargs: JumpStartKwargs) -> JumpStartKwargs:
438441
return kwargs
439442

440443

441-
def _add_sagemaker_session_to_kwargs(kwargs: JumpStartKwargs) -> JumpStartKwargs:
444+
def _add_sagemaker_session_with_custom_user_agent_to_kwargs(
445+
kwargs: JumpStartKwargs,
446+
) -> JumpStartKwargs:
442447
"""Sets session in kwargs based on default or override, returns full kwargs."""
443448
kwargs.sagemaker_session = (
444449
kwargs.sagemaker_session
445450
or get_default_jumpstart_session_with_user_agent_suffix(
446-
kwargs.model_id, kwargs.model_version, kwargs.hub_arn
451+
model_id=kwargs.model_id,
452+
model_version=kwargs.model_version,
453+
config_name=None,
454+
is_hub_content=kwargs.hub_arn is not None,
447455
)
448456
)
449457
return kwargs
@@ -903,21 +911,16 @@ def _add_config_name_to_kwargs(
903911
) -> JumpStartEstimatorInitKwargs:
904912
"""Sets tags in kwargs based on default or override, returns full kwargs."""
905913

906-
specs = verify_model_region_and_return_specs(
914+
kwargs.config_name = kwargs.config_name or get_top_ranked_config_name(
915+
region=kwargs.region,
907916
model_id=kwargs.model_id,
908-
version=kwargs.model_version,
917+
model_version=kwargs.model_version,
918+
sagemaker_session=kwargs.sagemaker_session,
909919
scope=JumpStartScriptScope.TRAINING,
910-
region=kwargs.region,
911-
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
920+
model_type=kwargs.model_type,
912921
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
913-
sagemaker_session=kwargs.sagemaker_session,
914-
config_name=kwargs.config_name,
922+
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
915923
hub_arn=kwargs.hub_arn,
916924
)
917925

918-
if specs.training_configs and specs.training_configs.get_top_config_from_ranking():
919-
kwargs.config_name = (
920-
kwargs.config_name or specs.training_configs.get_top_config_from_ranking().config_name
921-
)
922-
923926
return kwargs

src/sagemaker/jumpstart/factory/model.py

Lines changed: 84 additions & 56 deletions
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,7 @@
2929
)
3030
from sagemaker.jumpstart.artifacts.resource_names import _retrieve_resource_name_base
3131
from sagemaker.jumpstart.constants import (
32+
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
3233
INFERENCE_ENTRY_POINT_SCRIPT_NAME,
3334
JUMPSTART_DEFAULT_REGION_NAME,
3435
JUMPSTART_LOGGER,
@@ -54,6 +55,7 @@
5455
add_jumpstart_model_info_tags,
5556
get_default_jumpstart_session_with_user_agent_suffix,
5657
get_neo_content_bucket,
58+
get_top_ranked_config_name,
5759
update_dict_if_key_not_present,
5860
resolve_model_sagemaker_config_field,
5961
verify_model_region_and_return_specs,
@@ -155,15 +157,18 @@ def _add_region_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelIni
155157
return kwargs
156158

157159

158-
def _add_sagemaker_session_to_kwargs(
160+
def _add_sagemaker_session_with_custom_user_agent_to_kwargs(
159161
kwargs: Union[JumpStartModelInitKwargs, JumpStartModelDeployKwargs]
160162
) -> JumpStartModelInitKwargs:
161163
"""Sets session in kwargs based on default or override, returns full kwargs."""
162164

163165
kwargs.sagemaker_session = (
164166
kwargs.sagemaker_session
165167
or get_default_jumpstart_session_with_user_agent_suffix(
166-
kwargs.model_id, kwargs.model_version, kwargs.hub_arn
168+
model_id=kwargs.model_id,
169+
model_version=kwargs.model_version,
170+
config_name=kwargs.config_name,
171+
is_hub_content=kwargs.hub_arn is not None,
167172
)
168173
)
169174

@@ -244,6 +249,32 @@ def _add_instance_type_to_kwargs(
244249
kwargs.instance_type,
245250
)
246251

252+
specs = verify_model_region_and_return_specs(
253+
model_id=kwargs.model_id,
254+
version=kwargs.model_version,
255+
scope=JumpStartScriptScope.INFERENCE,
256+
region=kwargs.region,
257+
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
258+
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
259+
sagemaker_session=kwargs.sagemaker_session,
260+
model_type=kwargs.model_type,
261+
config_name=kwargs.config_name,
262+
)
263+
264+
if specs.inference_configs and kwargs.config_name not in specs.inference_configs.configs:
265+
return kwargs
266+
267+
resolved_config = (
268+
specs.inference_configs.configs[kwargs.config_name].resolved_config
269+
if specs.inference_configs
270+
else None
271+
)
272+
if resolved_config is None:
273+
return kwargs
274+
supported_instance_types = resolved_config.get("supported_inference_instance_types", [])
275+
if kwargs.instance_type not in supported_instance_types:
276+
JUMPSTART_LOGGER.warning("Overriding instance type to %s", kwargs.instance_type)
277+
247278
return kwargs
248279

249280

@@ -662,38 +693,25 @@ def _add_config_name_to_init_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSta
662693
ValueError: If the instance_type is not supported with the current config.
663694
"""
664695

665-
specs = verify_model_region_and_return_specs(
696+
# use the caller's session, or else the default JS session (without custom user agent),
697+
# in order to retrieve config name info
698+
temp_session = kwargs.sagemaker_session or DEFAULT_JUMPSTART_SAGEMAKER_SESSION
699+
700+
kwargs.config_name = kwargs.config_name or get_top_ranked_config_name(
701+
region=kwargs.region,
666702
model_id=kwargs.model_id,
667-
version=kwargs.model_version,
703+
model_version=kwargs.model_version,
704+
sagemaker_session=temp_session,
668705
scope=JumpStartScriptScope.INFERENCE,
669-
region=kwargs.region,
670-
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
671-
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
672-
sagemaker_session=kwargs.sagemaker_session,
673706
model_type=kwargs.model_type,
674-
config_name=kwargs.config_name,
707+
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
708+
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
675709
hub_arn=kwargs.hub_arn,
676710
)
677-
if specs.inference_configs:
678-
default_config_name = (
679-
specs.inference_configs.get_top_config_from_ranking().config_name
680-
if specs.inference_configs.get_top_config_from_ranking()
681-
else None
682-
)
683-
kwargs.config_name = kwargs.config_name or default_config_name
684-
685-
if not kwargs.config_name:
686-
return kwargs
687711

688-
if kwargs.config_name not in set(specs.inference_configs.configs.keys()):
689-
raise ValueError(
690-
f"Config {kwargs.config_name} is not supported for model {kwargs.model_id}."
691-
)
712+
if kwargs.config_name is None:
713+
return kwargs
692714

693-
resolved_config = specs.inference_configs.configs[kwargs.config_name].resolved_config
694-
supported_instance_types = resolved_config.get("supported_inference_instance_types", [])
695-
if kwargs.instance_type not in supported_instance_types:
696-
JUMPSTART_LOGGER.warning("Overriding instance type to %s", kwargs.instance_type)
697715
return kwargs
698716

699717

@@ -746,32 +764,41 @@ def _add_config_name_to_deploy_kwargs(
746764
ValueError: If the instance_type is not supported with the current config.
747765
"""
748766

749-
specs = verify_model_region_and_return_specs(
750-
model_id=kwargs.model_id,
751-
version=kwargs.model_version,
752-
scope=JumpStartScriptScope.INFERENCE,
753-
region=kwargs.region,
754-
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
755-
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
756-
sagemaker_session=kwargs.sagemaker_session,
757-
model_type=kwargs.model_type,
758-
config_name=kwargs.config_name,
759-
hub_arn=kwargs.hub_arn,
760-
)
767+
# use the caller's session, or else the default JS session (without custom user agent),
768+
# in order to retrieve config name info
769+
temp_session = kwargs.sagemaker_session or DEFAULT_JUMPSTART_SAGEMAKER_SESSION
761770

762771
if training_config_name:
763-
kwargs.config_name = _select_inference_config_from_training_config(
772+
773+
specs = verify_model_region_and_return_specs(
774+
model_id=kwargs.model_id,
775+
version=kwargs.model_version,
776+
scope=JumpStartScriptScope.INFERENCE,
777+
region=kwargs.region,
778+
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
779+
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
780+
sagemaker_session=temp_session,
781+
model_type=kwargs.model_type,
782+
config_name=kwargs.config_name,
783+
)
784+
default_config_name = _select_inference_config_from_training_config(
764785
specs=specs, training_config_name=training_config_name
765786
)
766-
return kwargs
767787

768-
if specs.inference_configs:
769-
default_config_name = (
770-
specs.inference_configs.get_top_config_from_ranking().config_name
771-
if specs.inference_configs.get_top_config_from_ranking()
772-
else None
788+
else:
789+
default_config_name = get_top_ranked_config_name(
790+
region=kwargs.region,
791+
model_id=kwargs.model_id,
792+
model_version=kwargs.model_version,
793+
sagemaker_session=temp_session,
794+
scope=JumpStartScriptScope.INFERENCE,
795+
model_type=kwargs.model_type,
796+
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
797+
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
798+
hub_arn=kwargs.hub_arn,
773799
)
774-
kwargs.config_name = kwargs.config_name or default_config_name
800+
801+
kwargs.config_name = kwargs.config_name or default_config_name
775802

776803
return kwargs
777804

@@ -850,15 +877,15 @@ def get_deploy_kwargs(
850877
routing_config=routing_config,
851878
)
852879

853-
deploy_kwargs = _add_sagemaker_session_to_kwargs(kwargs=deploy_kwargs)
880+
deploy_kwargs = _add_config_name_to_deploy_kwargs(
881+
kwargs=deploy_kwargs, training_config_name=training_config_name
882+
)
854883

855884
deploy_kwargs = _add_model_version_to_kwargs(kwargs=deploy_kwargs)
856885

857-
deploy_kwargs = _add_endpoint_name_to_kwargs(kwargs=deploy_kwargs)
886+
deploy_kwargs = _add_sagemaker_session_with_custom_user_agent_to_kwargs(kwargs=deploy_kwargs)
858887

859-
deploy_kwargs = _add_config_name_to_deploy_kwargs(
860-
kwargs=deploy_kwargs, training_config_name=training_config_name
861-
)
888+
deploy_kwargs = _add_endpoint_name_to_kwargs(kwargs=deploy_kwargs)
862889

863890
deploy_kwargs = _add_instance_type_to_kwargs(kwargs=deploy_kwargs)
864891

@@ -1041,11 +1068,14 @@ def get_init_kwargs(
10411068
)
10421069

10431070
model_init_kwargs = _add_vulnerable_and_deprecated_status_to_kwargs(kwargs=model_init_kwargs)
1071+
model_init_kwargs = _add_model_version_to_kwargs(kwargs=model_init_kwargs)
1072+
model_init_kwargs = _add_config_name_to_init_kwargs(kwargs=model_init_kwargs)
10441073

1045-
model_init_kwargs = _add_sagemaker_session_to_kwargs(kwargs=model_init_kwargs)
1074+
model_init_kwargs = _add_sagemaker_session_with_custom_user_agent_to_kwargs(
1075+
kwargs=model_init_kwargs
1076+
)
10461077
model_init_kwargs = _add_region_to_kwargs(kwargs=model_init_kwargs)
10471078

1048-
model_init_kwargs = _add_model_version_to_kwargs(kwargs=model_init_kwargs)
10491079
model_init_kwargs = _add_model_name_to_kwargs(kwargs=model_init_kwargs)
10501080

10511081
model_init_kwargs = _add_instance_type_to_kwargs(
@@ -1073,8 +1103,6 @@ def get_init_kwargs(
10731103

10741104
model_init_kwargs = _add_resources_to_kwargs(kwargs=model_init_kwargs)
10751105

1076-
model_init_kwargs = _add_config_name_to_init_kwargs(kwargs=model_init_kwargs)
1077-
10781106
model_init_kwargs = _add_additional_model_data_sources_to_kwargs(kwargs=model_init_kwargs)
10791107

10801108
return model_init_kwargs

src/sagemaker/jumpstart/types.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2458,7 +2458,7 @@ def __init__(
24582458
self.model_id = model_id
24592459
self.model_version = model_version
24602460
self.hub_arn = hub_arn
2461-
self.model_type = (model_type,)
2461+
self.model_type = model_type
24622462
self.instance_type = instance_type
24632463
self.instance_count = instance_count
24642464
self.region = region

0 commit comments

Comments
 (0)