Skip to content

Commit 78874b5

Browse files
committed
chore: telemetry for deployment configs
1 parent 3b99806 commit 78874b5

File tree

7 files changed

+171
-65
lines changed

7 files changed

+171
-65
lines changed

src/sagemaker/jumpstart/exceptions.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ def __init__(
150150
model. (Default: None).
151151
152152
"""
153+
version = version or "*"
153154
if message:
154155
self.message = message
155156
else:
@@ -198,6 +199,7 @@ def __init__(
198199
version: Optional[str] = None,
199200
message: Optional[str] = None,
200201
):
202+
version = version or "*"
201203
if message:
202204
self.message = message
203205
else:

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
add_jumpstart_model_info_tags,
7070
get_eula_message,
7171
get_default_jumpstart_session_with_user_agent_suffix,
72+
get_top_ranked_config_name,
7273
update_dict_if_key_not_present,
7374
resolve_estimator_sagemaker_config_field,
7475
verify_model_region_and_return_specs,
@@ -204,7 +205,7 @@ def get_init_kwargs(
204205

205206
estimator_init_kwargs = _add_model_version_to_kwargs(estimator_init_kwargs)
206207
estimator_init_kwargs = _add_vulnerable_and_deprecated_status_to_kwargs(estimator_init_kwargs)
207-
estimator_init_kwargs = _add_sagemaker_session_to_kwargs(estimator_init_kwargs)
208+
estimator_init_kwargs = _add_sagemaker_session_with_user_agent_to_kwargs(estimator_init_kwargs)
208209
estimator_init_kwargs = _add_region_to_kwargs(estimator_init_kwargs)
209210
estimator_init_kwargs = _add_instance_type_and_count_to_kwargs(estimator_init_kwargs)
210211
estimator_init_kwargs = _add_image_uri_to_kwargs(estimator_init_kwargs)
@@ -438,12 +439,15 @@ def _add_region_to_kwargs(kwargs: JumpStartKwargs) -> JumpStartKwargs:
438439
return kwargs
439440

440441

441-
def _add_sagemaker_session_to_kwargs(kwargs: JumpStartKwargs) -> JumpStartKwargs:
442+
def _add_sagemaker_session_with_user_agent_to_kwargs(kwargs: JumpStartKwargs) -> JumpStartKwargs:
442443
"""Sets session in kwargs based on default or override, returns full kwargs."""
443444
kwargs.sagemaker_session = (
444445
kwargs.sagemaker_session
445446
or get_default_jumpstart_session_with_user_agent_suffix(
446-
kwargs.model_id, kwargs.model_version, kwargs.hub_arn
447+
model_id=kwargs.model_id,
448+
model_version=kwargs.model_version,
449+
config_name=None,
450+
is_hub_content=kwargs.hub_arn is not None,
447451
)
448452
)
449453
return kwargs
@@ -903,20 +907,16 @@ def _add_config_name_to_kwargs(
903907
) -> JumpStartEstimatorInitKwargs:
904908
"""Sets tags in kwargs based on default or override, returns full kwargs."""
905909

906-
specs = verify_model_region_and_return_specs(
910+
kwargs.config_name = kwargs.config_name or get_top_ranked_config_name(
911+
region=kwargs.region,
907912
model_id=kwargs.model_id,
908-
version=kwargs.model_version,
913+
model_version=kwargs.model_version,
914+
sagemaker_session=kwargs.sagemaker_session,
909915
scope=JumpStartScriptScope.TRAINING,
910-
region=kwargs.region,
911-
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
916+
model_type=kwargs.model_type,
912917
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
913-
sagemaker_session=kwargs.sagemaker_session,
914-
config_name=kwargs.config_name,
918+
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
919+
hub_arn=kwargs.hub_arn,
915920
)
916921

917-
if specs.training_configs and specs.training_configs.get_top_config_from_ranking():
918-
kwargs.config_name = (
919-
kwargs.config_name or specs.training_configs.get_top_config_from_ranking().config_name
920-
)
921-
922922
return kwargs

src/sagemaker/jumpstart/factory/model.py

Lines changed: 74 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
)
3030
from sagemaker.jumpstart.artifacts.resource_names import _retrieve_resource_name_base
3131
from sagemaker.jumpstart.constants import (
32+
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
3233
INFERENCE_ENTRY_POINT_SCRIPT_NAME,
3334
JUMPSTART_DEFAULT_REGION_NAME,
3435
JUMPSTART_LOGGER,
@@ -54,6 +55,7 @@
5455
add_jumpstart_model_info_tags,
5556
get_default_jumpstart_session_with_user_agent_suffix,
5657
get_neo_content_bucket,
58+
get_top_ranked_config_name,
5759
update_dict_if_key_not_present,
5860
resolve_model_sagemaker_config_field,
5961
verify_model_region_and_return_specs,
@@ -155,15 +157,15 @@ def _add_region_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelIni
155157
return kwargs
156158

157159

158-
def _add_sagemaker_session_to_kwargs(
160+
def _add_sagemaker_session_with_custom_user_agent_to_kwargs(
159161
kwargs: Union[JumpStartModelInitKwargs, JumpStartModelDeployKwargs]
160162
) -> JumpStartModelInitKwargs:
161163
"""Sets session in kwargs based on default or override, returns full kwargs."""
162164

163165
kwargs.sagemaker_session = (
164166
kwargs.sagemaker_session
165167
or get_default_jumpstart_session_with_user_agent_suffix(
166-
kwargs.model_id, kwargs.model_version, kwargs.hub_arn
168+
kwargs.model_id, kwargs.model_version, kwargs.config_name, kwargs.hub_arn
167169
)
168170
)
169171

@@ -662,33 +664,47 @@ def _add_config_name_to_init_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSta
662664
ValueError: If the instance_type is not supported with the current config.
663665
"""
664666

667+
# we need to create a default JS session (without custom user agent)
668+
# in order to retrieve config name info
669+
temp_session = kwargs.sagemaker_session or DEFAULT_JUMPSTART_SAGEMAKER_SESSION
670+
671+
kwargs.config_name = kwargs.config_name or get_top_ranked_config_name(
672+
region=kwargs.region,
673+
model_id=kwargs.model_id,
674+
model_version=kwargs.model_version,
675+
sagemaker_session=temp_session,
676+
scope=JumpStartScriptScope.INFERENCE,
677+
model_type=kwargs.model_type,
678+
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
679+
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
680+
hub_arn=kwargs.hub_arn,
681+
)
682+
683+
if kwargs.config_name is None:
684+
return kwargs
685+
665686
specs = verify_model_region_and_return_specs(
666687
model_id=kwargs.model_id,
667688
version=kwargs.model_version,
668689
scope=JumpStartScriptScope.INFERENCE,
669690
region=kwargs.region,
670691
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
671692
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
672-
sagemaker_session=kwargs.sagemaker_session,
693+
sagemaker_session=temp_session,
673694
model_type=kwargs.model_type,
674695
config_name=kwargs.config_name,
675696
)
676-
if specs.inference_configs:
677-
default_config_name = specs.inference_configs.get_top_config_from_ranking().config_name
678-
kwargs.config_name = kwargs.config_name or default_config_name
679-
680-
if not kwargs.config_name:
681-
return kwargs
682-
683-
if kwargs.config_name not in set(specs.inference_configs.configs.keys()):
684-
raise ValueError(
685-
f"Config {kwargs.config_name} is not supported for model {kwargs.model_id}."
686-
)
687697

688-
resolved_config = specs.inference_configs.configs[kwargs.config_name].resolved_config
689-
supported_instance_types = resolved_config.get("supported_inference_instance_types", [])
690-
if kwargs.instance_type not in supported_instance_types:
691-
JUMPSTART_LOGGER.warning("Overriding instance type to %s", kwargs.instance_type)
698+
resolved_config = (
699+
specs.inference_configs.configs[kwargs.config_name].resolved_config
700+
if specs.inference_configs
701+
else None
702+
)
703+
if resolved_config is None:
704+
return kwargs
705+
supported_instance_types = resolved_config.get("supported_inference_instance_types", [])
706+
if kwargs.instance_type not in supported_instance_types:
707+
JUMPSTART_LOGGER.warning("Overriding instance type to %s", kwargs.instance_type)
692708
return kwargs
693709

694710

@@ -740,27 +756,41 @@ def _add_config_name_to_deploy_kwargs(
740756
ValueError: If the instance_type is not supported with the current config.
741757
"""
742758

743-
specs = verify_model_region_and_return_specs(
744-
model_id=kwargs.model_id,
745-
version=kwargs.model_version,
746-
scope=JumpStartScriptScope.INFERENCE,
747-
region=kwargs.region,
748-
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
749-
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
750-
sagemaker_session=kwargs.sagemaker_session,
751-
model_type=kwargs.model_type,
752-
config_name=kwargs.config_name,
753-
)
759+
# we need to create a default JS session (without custom user agent)
760+
# in order to retrieve config name info
761+
temp_session = kwargs.sagemaker_session or DEFAULT_JUMPSTART_SAGEMAKER_SESSION
754762

755763
if training_config_name:
756-
kwargs.config_name = _select_inference_config_from_training_config(
764+
765+
specs = verify_model_region_and_return_specs(
766+
model_id=kwargs.model_id,
767+
version=kwargs.model_version,
768+
scope=JumpStartScriptScope.INFERENCE,
769+
region=kwargs.region,
770+
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
771+
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
772+
sagemaker_session=temp_session,
773+
model_type=kwargs.model_type,
774+
config_name=kwargs.config_name,
775+
)
776+
default_config_name = _select_inference_config_from_training_config(
757777
specs=specs, training_config_name=training_config_name
758778
)
759-
return kwargs
760779

761-
if specs.inference_configs:
762-
default_config_name = specs.inference_configs.get_top_config_from_ranking().config_name
763-
kwargs.config_name = kwargs.config_name or default_config_name
780+
else:
781+
default_config_name = get_top_ranked_config_name(
782+
region=kwargs.region,
783+
model_id=kwargs.model_id,
784+
model_version=kwargs.model_version,
785+
sagemaker_session=temp_session,
786+
scope=JumpStartScriptScope.INFERENCE,
787+
model_type=kwargs.model_type,
788+
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
789+
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
790+
hub_arn=kwargs.hub_arn,
791+
)
792+
793+
kwargs.config_name = kwargs.config_name or default_config_name
764794

765795
return kwargs
766796

@@ -839,16 +869,16 @@ def get_deploy_kwargs(
839869
routing_config=routing_config,
840870
)
841871

842-
deploy_kwargs = _add_sagemaker_session_to_kwargs(kwargs=deploy_kwargs)
872+
deploy_kwargs = _add_config_name_to_deploy_kwargs(
873+
kwargs=deploy_kwargs, training_config_name=training_config_name
874+
)
875+
876+
deploy_kwargs = _add_sagemaker_session_with_custom_user_agent_to_kwargs(kwargs=deploy_kwargs)
843877

844878
deploy_kwargs = _add_model_version_to_kwargs(kwargs=deploy_kwargs)
845879

846880
deploy_kwargs = _add_endpoint_name_to_kwargs(kwargs=deploy_kwargs)
847881

848-
deploy_kwargs = _add_config_name_to_deploy_kwargs(
849-
kwargs=deploy_kwargs, training_config_name=training_config_name
850-
)
851-
852882
deploy_kwargs = _add_instance_type_to_kwargs(kwargs=deploy_kwargs)
853883

854884
deploy_kwargs.initial_instance_count = initial_instance_count or 1
@@ -1030,11 +1060,14 @@ def get_init_kwargs(
10301060
)
10311061

10321062
model_init_kwargs = _add_vulnerable_and_deprecated_status_to_kwargs(kwargs=model_init_kwargs)
1063+
model_init_kwargs = _add_config_name_to_init_kwargs(kwargs=model_init_kwargs)
1064+
model_init_kwargs = _add_model_version_to_kwargs(kwargs=model_init_kwargs)
10331065

1034-
model_init_kwargs = _add_sagemaker_session_to_kwargs(kwargs=model_init_kwargs)
1066+
model_init_kwargs = _add_sagemaker_session_with_custom_user_agent_to_kwargs(
1067+
kwargs=model_init_kwargs
1068+
)
10351069
model_init_kwargs = _add_region_to_kwargs(kwargs=model_init_kwargs)
10361070

1037-
model_init_kwargs = _add_model_version_to_kwargs(kwargs=model_init_kwargs)
10381071
model_init_kwargs = _add_model_name_to_kwargs(kwargs=model_init_kwargs)
10391072

10401073
model_init_kwargs = _add_instance_type_to_kwargs(
@@ -1062,8 +1095,6 @@ def get_init_kwargs(
10621095

10631096
model_init_kwargs = _add_resources_to_kwargs(kwargs=model_init_kwargs)
10641097

1065-
model_init_kwargs = _add_config_name_to_init_kwargs(kwargs=model_init_kwargs)
1066-
10671098
model_init_kwargs = _add_additional_model_data_sources_to_kwargs(kwargs=model_init_kwargs)
10681099

10691100
return model_init_kwargs

src/sagemaker/jumpstart/utils.py

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1109,11 +1109,16 @@ def get_jumpstart_configs(
11091109

11101110

11111111
def get_jumpstart_user_agent_extra_suffix(
1112-
model_id: Optional[str], model_version: Optional[str], is_hub_content: Optional[bool]
1112+
model_id: Optional[str],
1113+
model_version: Optional[str],
1114+
config_name: Optional[str],
1115+
is_hub_content: Optional[bool],
11131116
) -> str:
11141117
"""Returns the model-specific user agent string to be added to requests."""
11151118
sagemaker_python_sdk_headers = get_user_agent_extra_suffix()
11161119
jumpstart_specific_suffix = f"md/js_model_id#{model_id} md/js_model_ver#{model_version}"
1120+
config_specific_suffix = f"md/js_config#{config_name}"
1121+
print(config_name)
11171122
hub_specific_suffix = f"md/js_is_hub_content#{is_hub_content}"
11181123

11191124
if os.getenv(constants.ENV_VARIABLE_DISABLE_JUMPSTART_TELEMETRY, None):
@@ -1128,19 +1133,66 @@ def get_jumpstart_user_agent_extra_suffix(
11281133
else:
11291134
headers = f"{sagemaker_python_sdk_headers} {jumpstart_specific_suffix}"
11301135

1136+
if config_name:
1137+
headers = f"{headers} {config_specific_suffix}"
1138+
11311139
return headers
11321140

11331141

1142+
def get_top_ranked_config_name(
1143+
region: str,
1144+
model_id: str,
1145+
model_version: str,
1146+
sagemaker_session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
1147+
scope: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE,
1148+
model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_WEIGHTS,
1149+
tolerate_deprecated_model: bool = False,
1150+
tolerate_vulnerable_model: bool = False,
1151+
hub_arn: Optional[str] = None,
1152+
) -> Optional[str]:
1153+
"""Returns the top ranked config name for the given model ID and region.
1154+
1155+
Raises:
1156+
ValueError: If the script scope is not supported by JumpStart.
1157+
"""
1158+
model_specs = verify_model_region_and_return_specs(
1159+
model_id=model_id,
1160+
version=model_version,
1161+
scope=scope,
1162+
region=region,
1163+
hub_arn=hub_arn,
1164+
tolerate_vulnerable_model=tolerate_vulnerable_model,
1165+
tolerate_deprecated_model=tolerate_deprecated_model,
1166+
sagemaker_session=sagemaker_session,
1167+
model_type=model_type,
1168+
)
1169+
1170+
if scope == enums.JumpStartScriptScope.INFERENCE:
1171+
return (
1172+
model_specs.inference_configs.get_top_config_from_ranking().config_name
1173+
if model_specs.inference_configs
1174+
else None
1175+
)
1176+
if scope == enums.JumpStartScriptScope.TRAINING:
1177+
return (
1178+
model_specs.training_configs.get_top_config_from_ranking().config_name
1179+
if model_specs.training_configs
1180+
else None
1181+
)
1182+
raise ValueError(f"Unsupported script scope: {scope}.")
1183+
1184+
11341185
def get_default_jumpstart_session_with_user_agent_suffix(
11351186
model_id: Optional[str] = None,
11361187
model_version: Optional[str] = None,
1188+
config_name: Optional[str] = None,
11371189
is_hub_content: Optional[bool] = False,
11381190
) -> Session:
11391191
"""Returns default JumpStart SageMaker Session with model-specific user agent suffix."""
11401192
botocore_session = botocore.session.get_session()
11411193
botocore_config = botocore.config.Config(
11421194
user_agent_extra=get_jumpstart_user_agent_extra_suffix(
1143-
model_id, model_version, is_hub_content
1195+
model_id, model_version, config_name, is_hub_content
11441196
),
11451197
)
11461198
botocore_session.set_default_client_config(botocore_config)

tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,3 +396,20 @@ def test_jumpstart_model_with_deployment_configs(setup):
396396
response = predictor.predict(payload, custom_attributes="accept_eula=true")
397397

398398
assert response is not None
399+
400+
401+
def test_jumpstart_session_with_config_name():
402+
model = JumpStartModel(model_id="meta-textgeneration-llama-2-7b", model_version="*")
403+
assert model.config_name != None
404+
session = model.sagemaker_session
405+
406+
with mock.patch("botocore.client.BaseClient._make_request") as mock_make_request:
407+
try:
408+
session.sagemaker_client.list_endpoints()
409+
except Exception:
410+
pass
411+
412+
assert (
413+
"md/js_model_id#meta-textgeneration-llama-2-7b md/js_model_ver#* md/js_config#tgi"
414+
in mock_make_request.call_args[0][1]["headers"]["User-Agent"]
415+
)

tests/unit/sagemaker/jumpstart/estimator/test_estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1637,7 +1637,7 @@ def test_training_passes_role_to_deploy(
16371637
@mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type")
16381638
@mock.patch(
16391639
"sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix",
1640-
sagemaker_session,
1640+
lambda *largs, **kwargs: sagemaker_session,
16411641
)
16421642
@mock.patch(
16431643
"sagemaker.jumpstart.factory.estimator.get_default_jumpstart_session_with_user_agent_suffix",

0 commit comments

Comments
 (0)