@@ -601,10 +601,17 @@ def _create_tfs_model(
601601 ** kwargs
602602 ):
603603 """Placeholder docstring"""
604+ # remove image kwarg
605+ if "image" in kwargs :
606+ image = kwargs ["image" ]
607+ kwargs = {k : v for k , v in kwargs .items () if k != "image" }
608+ else :
609+ image = None
610+
604611 return Model (
605612 model_data = self .model_data ,
606613 role = role ,
607- image = self .image_name ,
614+ image = ( image or self .image_name ) ,
608615 name = self ._current_job_name ,
609616 container_log_level = self .container_log_level ,
610617 framework_version = utils .get_short_version (self .framework_version ),
@@ -628,14 +635,21 @@ def _create_default_model(
628635 ** kwargs
629636 ):
630637 """Placeholder docstring"""
638+ # remove image kwarg
639+ if "image" in kwargs :
640+ image = kwargs ["image" ]
641+ kwargs = {k : v for k , v in kwargs .items () if k != "image" }
642+ else :
643+ image = None
644+
631645 return TensorFlowModel (
632646 self .model_data ,
633647 role ,
634648 entry_point or self .entry_point ,
635649 source_dir = source_dir or self ._model_source_dir (),
636650 enable_cloudwatch_metrics = self .enable_cloudwatch_metrics ,
637651 env = {"SAGEMAKER_REQUIREMENTS" : self .requirements_file },
638- image = self .image_name ,
652+ image = ( image or self .image_name ) ,
639653 name = self ._current_job_name ,
640654 container_log_level = self .container_log_level ,
641655 code_location = self .code_location ,
0 commit comments