File tree Expand file tree Collapse file tree 6 files changed +21
-3
lines changed Expand file tree Collapse file tree 6 files changed +21
-3
lines changed Original file line number Diff line number Diff line change @@ -203,6 +203,9 @@ def create_model(
203203 sagemaker.chainer.model.ChainerModel: A SageMaker ``ChainerModel``
204204 object. See :func:`~sagemaker.chainer.model.ChainerModel` for full details.
205205 """
206+ if "image" not in kwargs :
207+ kwargs ["image" ] = self .image_name
208+
206209 return ChainerModel (
207210 self .model_data ,
208211 role or self .role ,
@@ -215,10 +218,10 @@ def create_model(
215218 py_version = self .py_version ,
216219 framework_version = self .framework_version ,
217220 model_server_workers = model_server_workers ,
218- image = kwargs ["image" ] if "image" in kwargs else self .image_name ,
219221 sagemaker_session = self .sagemaker_session ,
220222 vpc_config = self .get_vpc_config (vpc_config_override ),
221223 dependencies = (dependencies or self .dependencies ),
224+ ** kwargs
222225 )
223226
224227 @classmethod
Original file line number Diff line number Diff line change @@ -206,6 +206,9 @@ def create_model(
206206 sagemaker.mxnet.model.MXNetModel: A SageMaker ``MXNetModel`` object.
207207 See :func:`~sagemaker.mxnet.model.MXNetModel` for full details.
208208 """
209+ if "image" not in kwargs :
210+ kwargs ["image" ] = image_name or self .image_name
211+
209212 return MXNetModel (
210213 self .model_data ,
211214 role or self .role ,
@@ -217,11 +220,11 @@ def create_model(
217220 code_location = self .code_location ,
218221 py_version = self .py_version ,
219222 framework_version = self .framework_version ,
220- image = kwargs ["image" ] if "image" in kwargs else (image_name or self .image_name ),
221223 model_server_workers = model_server_workers ,
222224 sagemaker_session = self .sagemaker_session ,
223225 vpc_config = self .get_vpc_config (vpc_config_override ),
224226 dependencies = (dependencies or self .dependencies ),
227+ ** kwargs
225228 )
226229
227230 @classmethod
Original file line number Diff line number Diff line change @@ -164,6 +164,9 @@ def create_model(
164164 sagemaker.pytorch.model.PyTorchModel: A SageMaker ``PyTorchModel``
165165 object. See :func:`~sagemaker.pytorch.model.PyTorchModel` for full details.
166166 """
167+ if "image" not in kwargs :
168+ kwargs ["image" ] = self .image_name
169+
167170 return PyTorchModel (
168171 self .model_data ,
169172 role or self .role ,
@@ -175,11 +178,11 @@ def create_model(
175178 code_location = self .code_location ,
176179 py_version = self .py_version ,
177180 framework_version = self .framework_version ,
178- image = kwargs ["image" ] if "image" in kwargs else self .image_name ,
179181 model_server_workers = model_server_workers ,
180182 sagemaker_session = self .sagemaker_session ,
181183 vpc_config = self .get_vpc_config (vpc_config_override ),
182184 dependencies = (dependencies or self .dependencies ),
185+ ** kwargs
183186 )
184187
185188 @classmethod
Original file line number Diff line number Diff line change 3131SCRIPT_PATH = os .path .join (DATA_DIR , "dummy_script.py" )
3232SERVING_SCRIPT_FILE = "another_dummy_script.py"
3333MODEL_DATA = "s3://some/data.tar.gz"
34+ ENV = {"DUMMY_ENV_VAR" : "dummy_value" }
3435TIMESTAMP = "2017-11-06-14:14:15.672"
3536TIME = 1507167947
3637BUCKET_NAME = "mybucket"
@@ -326,12 +327,14 @@ def test_create_model_with_optional_params(sagemaker_session):
326327 model_server_workers = model_server_workers ,
327328 vpc_config_override = vpc_config ,
328329 entry_point = SERVING_SCRIPT_FILE ,
330+ env = ENV ,
329331 )
330332
331333 assert model .role == new_role
332334 assert model .model_server_workers == model_server_workers
333335 assert model .vpc_config == vpc_config
334336 assert model .entry_point == SERVING_SCRIPT_FILE
337+ assert model .env == ENV
335338
336339
337340def test_create_model_with_custom_image (sagemaker_session ):
Original file line number Diff line number Diff line change 3030SCRIPT_PATH = os .path .join (DATA_DIR , "dummy_script.py" )
3131SERVING_SCRIPT_FILE = "another_dummy_script.py"
3232MODEL_DATA = "s3://mybucket/model"
33+ ENV = {"DUMMY_ENV_VAR" : "dummy_value" }
3334TIMESTAMP = "2017-11-06-14:14:15.672"
3435TIME = 1507167947
3536BUCKET_NAME = "mybucket"
@@ -231,12 +232,14 @@ def test_create_model_with_optional_params(sagemaker_session):
231232 model_server_workers = model_server_workers ,
232233 vpc_config_override = vpc_config ,
233234 entry_point = SERVING_SCRIPT_FILE ,
235+ env = ENV ,
234236 )
235237
236238 assert model .role == new_role
237239 assert model .model_server_workers == model_server_workers
238240 assert model .vpc_config == vpc_config
239241 assert model .entry_point == SERVING_SCRIPT_FILE
242+ assert model .env == ENV
240243
241244
242245def test_create_model_with_custom_image (sagemaker_session ):
Original file line number Diff line number Diff line change 2828SCRIPT_PATH = os .path .join (DATA_DIR , "dummy_script.py" )
2929SERVING_SCRIPT_FILE = "another_dummy_script.py"
3030MODEL_DATA = "s3://some/data.tar.gz"
31+ ENV = {"DUMMY_ENV_VAR" : "dummy_value" }
3132TIMESTAMP = "2017-11-06-14:14:15.672"
3233TIME = 1507167947
3334BUCKET_NAME = "mybucket"
@@ -212,12 +213,14 @@ def test_create_model_with_optional_params(sagemaker_session):
212213 model_server_workers = model_server_workers ,
213214 vpc_config_override = vpc_config ,
214215 entry_point = SERVING_SCRIPT_FILE ,
216+ env = ENV ,
215217 )
216218
217219 assert model .role == new_role
218220 assert model .model_server_workers == model_server_workers
219221 assert model .vpc_config == vpc_config
220222 assert model .entry_point == SERVING_SCRIPT_FILE
223+ assert model .env == ENV
221224
222225
223226def test_create_model_with_custom_image (sagemaker_session ):
You can’t perform that action at this time.
0 commit comments