@@ -136,10 +136,8 @@ def __init__(
136136 elif model is not None :
137137 if isinstance (model , PipelineModel ):
138138 self .model_list = model .models
139- self .container_def_list = model .pipeline_container_def (inference_instances [0 ])
140139 elif isinstance (model , Model ):
141140 self .model_list = [model ]
142- self .container_def_list = [model .prepare_container_def (inference_instances [0 ])]
143141
144142 for model_entity in self .model_list :
145143 if estimator is not None :
@@ -154,10 +152,10 @@ def __init__(
154152 source_dir = model_entity .source_dir
155153 dependencies = model_entity .dependencies
156154 kwargs = dict (** kwargs , output_kms_key = model_entity .model_kms_key )
157- name = model_entity .name or model_entity ._framework_name
155+ model_name = model_entity .name or model_entity ._framework_name
158156
159157 repack_model_step = _RepackModelStep (
160- name = f"{ name } RepackModel" ,
158+ name = f"{ model_name } RepackModel" ,
161159 depends_on = depends_on ,
162160 sagemaker_session = sagemaker_session ,
163161 role = role ,
@@ -171,10 +169,14 @@ def __init__(
171169 model_entity .model_data = (
172170 repack_model_step .properties .ModelArtifacts .S3ModelArtifacts
173171 )
174-
175172 # remove kwargs consumed by model repacking step
176173 kwargs .pop ("output_kms_key" , None )
177174
175+ if isinstance (model , PipelineModel ):
176+ self .container_def_list = model .pipeline_container_def (inference_instances [0 ])
177+ elif isinstance (model , Model ):
178+ self .container_def_list = [model .prepare_container_def (inference_instances [0 ])]
179+
178180 register_model_step = _RegisterModelStep (
179181 name = name ,
180182 estimator = estimator ,
0 commit comments