@@ -115,6 +115,19 @@ def __init__(
115115 self ._enable_network_isolation = enable_network_isolation
116116 self .model_kms_key = model_kms_key
117117
118+ def _init_sagemaker_session_if_does_not_exist (self , instance_type ):
119+ """Set ``self.sagemaker_session`` to be a ``LocalSession`` or
120+ ``Session`` if it is not already. The type of session object is
121+ determined by the instance type.
122+ """
123+ if self .sagemaker_session :
124+ return
125+
126+ if instance_type in ("local" , "local_gpu" ):
127+ self .sagemaker_session = local .LocalSession ()
128+ else :
129+ self .sagemaker_session = session .Session ()
130+
118131 def prepare_container_def (
119132 self , instance_type , accelerator_type = None
120133 ): # pylint: disable=unused-argument
@@ -164,6 +177,8 @@ def _create_sagemaker_model(self, instance_type, accelerator_type=None, tags=Non
164177 container_def = self .prepare_container_def (instance_type , accelerator_type = accelerator_type )
165178 self .name = self .name or utils .name_from_image (container_def ["Image" ])
166179 enable_network_isolation = self .enable_network_isolation ()
180+
181+ self ._init_sagemaker_session_if_does_not_exist (instance_type )
167182 self .sagemaker_session .create_model (
168183 self .name ,
169184 self .role ,
@@ -324,6 +339,7 @@ def compile(
324339 framework = framework .upper ()
325340 framework_version = self ._get_framework_version () or framework_version
326341
342+ self ._init_sagemaker_session_if_does_not_exist (target_instance_family )
327343 config = self ._compilation_job_config (
328344 target_instance_family ,
329345 input_shape ,
@@ -413,11 +429,7 @@ def deploy(
413429 ``self.predictor_cls`` on the created endpoint name, if ``self.predictor_cls``
414430 is not None. Otherwise, return None.
415431 """
416- if not self .sagemaker_session :
417- if instance_type in ("local" , "local_gpu" ):
418- self .sagemaker_session = local .LocalSession ()
419- else :
420- self .sagemaker_session = session .Session ()
432+ self ._init_sagemaker_session_if_does_not_exist (instance_type )
421433
422434 if self .role is None :
423435 raise ValueError ("Role can not be null for deploying a model" )
@@ -514,6 +526,8 @@ def transformer(
514526 volume_kms_key (str): Optional. KMS key ID for encrypting the volume
515527 attached to the ML compute instance (default: None).
516528 """
529+ self ._init_sagemaker_session_if_does_not_exist (instance_type )
530+
517531 self ._create_sagemaker_model (instance_type , tags = tags )
518532 if self .enable_network_isolation ():
519533 env = None
0 commit comments