@@ -161,18 +161,16 @@ def _add_region_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelIni
161
161
162
162
163
163
def _add_sagemaker_session_with_custom_user_agent_to_kwargs (
164
- kwargs : Union [JumpStartModelInitKwargs , JumpStartModelDeployKwargs ]
164
+ kwargs : Union [JumpStartModelInitKwargs , JumpStartModelDeployKwargs ],
165
+ orig_session : Optional [Session ],
165
166
) -> JumpStartModelInitKwargs :
166
167
"""Sets session in kwargs based on default or override, returns full kwargs."""
167
168
168
- kwargs .sagemaker_session = (
169
- kwargs .sagemaker_session
170
- or get_default_jumpstart_session_with_user_agent_suffix (
171
- model_id = kwargs .model_id ,
172
- model_version = kwargs .model_version ,
173
- config_name = kwargs .config_name ,
174
- is_hub_content = kwargs .hub_arn is not None ,
175
- )
169
+ kwargs .sagemaker_session = orig_session or get_default_jumpstart_session_with_user_agent_suffix (
170
+ model_id = kwargs .model_id ,
171
+ model_version = kwargs .model_version ,
172
+ config_name = kwargs .config_name ,
173
+ is_hub_content = kwargs .hub_arn is not None ,
176
174
)
177
175
178
176
return kwargs
@@ -686,7 +684,7 @@ def get_deploy_kwargs(
686
684
config_name = config_name ,
687
685
routing_config = routing_config ,
688
686
)
689
- deploy_kwargs = _set_temp_sagemaker_session_if_not_set (kwargs = deploy_kwargs )
687
+ deploy_kwargs , orig_session = _set_temp_sagemaker_session_if_not_set (kwargs = deploy_kwargs )
690
688
deploy_kwargs .specs = verify_model_region_and_return_specs (
691
689
** get_model_info_default_kwargs (
692
690
deploy_kwargs , include_model_version = False , include_tolerate_flags = False
@@ -705,7 +703,9 @@ def get_deploy_kwargs(
705
703
706
704
deploy_kwargs = _add_model_version_to_kwargs (kwargs = deploy_kwargs )
707
705
708
- deploy_kwargs = _add_sagemaker_session_with_custom_user_agent_to_kwargs (kwargs = deploy_kwargs )
706
+ deploy_kwargs = _add_sagemaker_session_with_custom_user_agent_to_kwargs (
707
+ kwargs = deploy_kwargs , orig_session = orig_session
708
+ )
709
709
710
710
deploy_kwargs = _add_endpoint_name_to_kwargs (kwargs = deploy_kwargs )
711
711
@@ -890,7 +890,9 @@ def get_init_kwargs(
890
890
config_name = config_name ,
891
891
additional_model_data_sources = additional_model_data_sources ,
892
892
)
893
- model_init_kwargs = _set_temp_sagemaker_session_if_not_set (kwargs = model_init_kwargs )
893
+ model_init_kwargs , orig_session = _set_temp_sagemaker_session_if_not_set (
894
+ kwargs = model_init_kwargs
895
+ )
894
896
model_init_kwargs .specs = verify_model_region_and_return_specs (
895
897
** get_model_info_default_kwargs (
896
898
model_init_kwargs , include_model_version = False , include_tolerate_flags = False
@@ -908,7 +910,7 @@ def get_init_kwargs(
908
910
model_init_kwargs = _add_config_name_to_init_kwargs (kwargs = model_init_kwargs )
909
911
910
912
model_init_kwargs = _add_sagemaker_session_with_custom_user_agent_to_kwargs (
911
- kwargs = model_init_kwargs
913
+ kwargs = model_init_kwargs , orig_session = orig_session
912
914
)
913
915
model_init_kwargs = _add_region_to_kwargs (kwargs = model_init_kwargs )
914
916
0 commit comments