Skip to content

Commit 8860ee1

Browse files
committed
fix: failing config tests
1 parent fc160a5 commit 8860ee1

File tree

3 files changed

+32
-29
lines changed

3 files changed

+32
-29
lines changed

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,9 @@ def get_init_kwargs(
207207
enable_session_tag_chaining=enable_session_tag_chaining,
208208
)
209209

210-
estimator_init_kwargs = _set_temp_sagemaker_session_if_not_set(kwargs=estimator_init_kwargs)
210+
estimator_init_kwargs, orig_session = _set_temp_sagemaker_session_if_not_set(
211+
kwargs=estimator_init_kwargs
212+
)
211213
estimator_init_kwargs.specs = verify_model_region_and_return_specs(
212214
**get_model_info_default_kwargs(
213215
estimator_init_kwargs, include_model_version=False, include_tolerate_flags=False
@@ -223,7 +225,7 @@ def get_init_kwargs(
223225
estimator_init_kwargs = _add_model_version_to_kwargs(estimator_init_kwargs)
224226
estimator_init_kwargs = _add_vulnerable_and_deprecated_status_to_kwargs(estimator_init_kwargs)
225227
estimator_init_kwargs = _add_sagemaker_session_with_custom_user_agent_to_kwargs(
226-
estimator_init_kwargs
228+
estimator_init_kwargs, orig_session
227229
)
228230
estimator_init_kwargs = _add_region_to_kwargs(estimator_init_kwargs)
229231
estimator_init_kwargs = _add_instance_type_and_count_to_kwargs(estimator_init_kwargs)
@@ -280,7 +282,7 @@ def get_fit_kwargs(
280282
config_name=config_name,
281283
)
282284

283-
estimator_fit_kwargs = _set_temp_sagemaker_session_if_not_set(kwargs=estimator_fit_kwargs)
285+
estimator_fit_kwargs, _ = _set_temp_sagemaker_session_if_not_set(kwargs=estimator_fit_kwargs)
284286
estimator_fit_kwargs.specs = verify_model_region_and_return_specs(
285287
**get_model_info_default_kwargs(
286288
estimator_fit_kwargs, include_model_version=False, include_tolerate_flags=False
@@ -472,17 +474,14 @@ def _add_region_to_kwargs(kwargs: JumpStartKwargs) -> JumpStartKwargs:
472474

473475

474476
def _add_sagemaker_session_with_custom_user_agent_to_kwargs(
475-
kwargs: JumpStartKwargs,
477+
kwargs: JumpStartKwargs, orig_session: Optional[Session]
476478
) -> JumpStartKwargs:
477479
"""Sets session in kwargs based on default or override, returns full kwargs."""
478-
kwargs.sagemaker_session = (
479-
kwargs.sagemaker_session
480-
or get_default_jumpstart_session_with_user_agent_suffix(
481-
model_id=kwargs.model_id,
482-
model_version=kwargs.model_version,
483-
config_name=None,
484-
is_hub_content=kwargs.hub_arn is not None,
485-
)
480+
kwargs.sagemaker_session = orig_session or get_default_jumpstart_session_with_user_agent_suffix(
481+
model_id=kwargs.model_id,
482+
model_version=kwargs.model_version,
483+
config_name=None,
484+
is_hub_content=kwargs.hub_arn is not None,
486485
)
487486
return kwargs
488487

src/sagemaker/jumpstart/factory/model.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -161,18 +161,16 @@ def _add_region_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelIni
161161

162162

163163
def _add_sagemaker_session_with_custom_user_agent_to_kwargs(
164-
kwargs: Union[JumpStartModelInitKwargs, JumpStartModelDeployKwargs]
164+
kwargs: Union[JumpStartModelInitKwargs, JumpStartModelDeployKwargs],
165+
orig_session: Optional[Session],
165166
) -> JumpStartModelInitKwargs:
166167
"""Sets session in kwargs based on default or override, returns full kwargs."""
167168

168-
kwargs.sagemaker_session = (
169-
kwargs.sagemaker_session
170-
or get_default_jumpstart_session_with_user_agent_suffix(
171-
model_id=kwargs.model_id,
172-
model_version=kwargs.model_version,
173-
config_name=kwargs.config_name,
174-
is_hub_content=kwargs.hub_arn is not None,
175-
)
169+
kwargs.sagemaker_session = orig_session or get_default_jumpstart_session_with_user_agent_suffix(
170+
model_id=kwargs.model_id,
171+
model_version=kwargs.model_version,
172+
config_name=kwargs.config_name,
173+
is_hub_content=kwargs.hub_arn is not None,
176174
)
177175

178176
return kwargs
@@ -686,7 +684,7 @@ def get_deploy_kwargs(
686684
config_name=config_name,
687685
routing_config=routing_config,
688686
)
689-
deploy_kwargs = _set_temp_sagemaker_session_if_not_set(kwargs=deploy_kwargs)
687+
deploy_kwargs, orig_session = _set_temp_sagemaker_session_if_not_set(kwargs=deploy_kwargs)
690688
deploy_kwargs.specs = verify_model_region_and_return_specs(
691689
**get_model_info_default_kwargs(
692690
deploy_kwargs, include_model_version=False, include_tolerate_flags=False
@@ -705,7 +703,9 @@ def get_deploy_kwargs(
705703

706704
deploy_kwargs = _add_model_version_to_kwargs(kwargs=deploy_kwargs)
707705

708-
deploy_kwargs = _add_sagemaker_session_with_custom_user_agent_to_kwargs(kwargs=deploy_kwargs)
706+
deploy_kwargs = _add_sagemaker_session_with_custom_user_agent_to_kwargs(
707+
kwargs=deploy_kwargs, orig_session=orig_session
708+
)
709709

710710
deploy_kwargs = _add_endpoint_name_to_kwargs(kwargs=deploy_kwargs)
711711

@@ -890,7 +890,9 @@ def get_init_kwargs(
890890
config_name=config_name,
891891
additional_model_data_sources=additional_model_data_sources,
892892
)
893-
model_init_kwargs = _set_temp_sagemaker_session_if_not_set(kwargs=model_init_kwargs)
893+
model_init_kwargs, orig_session = _set_temp_sagemaker_session_if_not_set(
894+
kwargs=model_init_kwargs
895+
)
894896
model_init_kwargs.specs = verify_model_region_and_return_specs(
895897
**get_model_info_default_kwargs(
896898
model_init_kwargs, include_model_version=False, include_tolerate_flags=False
@@ -908,7 +910,7 @@ def get_init_kwargs(
908910
model_init_kwargs = _add_config_name_to_init_kwargs(kwargs=model_init_kwargs)
909911

910912
model_init_kwargs = _add_sagemaker_session_with_custom_user_agent_to_kwargs(
911-
kwargs=model_init_kwargs
913+
kwargs=model_init_kwargs, orig_session=orig_session
912914
)
913915
model_init_kwargs = _add_region_to_kwargs(kwargs=model_init_kwargs)
914916

src/sagemaker/jumpstart/factory/utils.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
"""This module stores JumpStart factory utilities."""
1414

1515
from __future__ import absolute_import
16-
from typing import Union
16+
from typing import Tuple, Union
1717

1818
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
1919
from sagemaker.jumpstart.types import (
@@ -23,6 +23,7 @@
2323
JumpStartModelDeployKwargs,
2424
JumpStartModelInitKwargs,
2525
)
26+
from sagemaker.session import Session
2627

2728
KwargsType = Union[
2829
JumpStartModelDeployKwargs,
@@ -65,13 +66,14 @@ def get_model_info_default_kwargs(
6566
return kwargs_dict
6667

6768

68-
def _set_temp_sagemaker_session_if_not_set(kwargs: KwargsType) -> KwargsType:
69-
"""Sets a temporary sagemaker session if one is not set.
69+
def _set_temp_sagemaker_session_if_not_set(kwargs: KwargsType) -> Tuple[KwargsType, Session]:
70+
"""Sets a temporary sagemaker session if one is not set, and returns original session.
7071
7172
We need to create a default JS session (without custom user agent)
7273
in order to retrieve config name info.
7374
"""
7475

76+
orig_session = kwargs.sagemaker_session
7577
if kwargs.sagemaker_session is None:
7678
kwargs.sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION
77-
return kwargs
79+
return kwargs, orig_session

0 commit comments

Comments
 (0)