@@ -112,8 +112,8 @@ def __init__(
112112 if "entry_point" in kwargs :
113113 repack_model = True
114114 entry_point = kwargs .pop ("entry_point" , None )
115- source_dir = kwargs .get ("source_dir" )
116- dependencies = kwargs .get ("dependencies" )
115+ source_dir = kwargs .pop ("source_dir" , None )
116+ dependencies = kwargs .pop ("dependencies" , None )
117117 kwargs = dict (** kwargs , output_kms_key = kwargs .pop ("model_kms_key" , None ))
118118
119119 repack_model_step = _RepackModelStep (
@@ -130,13 +130,10 @@ def __init__(
130130 steps .append (repack_model_step )
131131 model_data = repack_model_step .properties .ModelArtifacts .S3ModelArtifacts
132132
133- # remove kwargs consumed by model repacking step
134- kwargs .pop ("entry_point" , None )
135- kwargs .pop ("source_dir" , None )
136- kwargs .pop ("dependencies" , None )
137- kwargs .pop ("output_kms_key" , None )
133+ # remove kwargs consumed by model repacking step
134+ kwargs .pop ("output_kms_key" , None )
138135
139- if model is not None :
136+ elif model is not None :
140137 if isinstance (model , PipelineModel ):
141138 self .model_list = model .models
142139 self .container_def_list = model .pipeline_container_def (inference_instances [0 ])
@@ -156,7 +153,9 @@ def __init__(
156153 entry_point = model_entity .entry_point
157154 source_dir = model_entity .source_dir
158155 dependencies = model_entity .dependencies
156+ kwargs = dict (** kwargs , output_kms_key = model_entity .model_kms_key )
159157 name = model_entity .name or model_entity ._framework_name
158+
160159 repack_model_step = _RepackModelStep (
161160 name = f"{ name } RepackModel" ,
162161 depends_on = depends_on ,
@@ -166,12 +165,16 @@ def __init__(
166165 entry_point = entry_point ,
167166 source_dir = source_dir ,
168167 dependencies = dependencies ,
168+ ** kwargs ,
169169 )
170170 steps .append (repack_model_step )
171171 model_entity .model_data = (
172172 repack_model_step .properties .ModelArtifacts .S3ModelArtifacts
173173 )
174174
175+ # remove kwargs consumed by model repacking step
176+ kwargs .pop ("output_kms_key" , None )
177+
175178 register_model_step = _RegisterModelStep (
176179 name = name ,
177180 estimator = estimator ,
0 commit comments