Skip to content

Commit c51df67

Browse files
authored
change: allow specifying model name in create_model() for Chainer, MXNet, PyTorch, and RL (#1396)
1 parent e6d8fc6 commit c51df67

File tree

9 files changed

+24
-8
lines changed

9 files changed

+24
-8
lines changed

src/sagemaker/chainer/estimator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,13 +206,15 @@ def create_model(
206206
if "image" not in kwargs:
207207
kwargs["image"] = self.image_name
208208

209+
if "name" not in kwargs:
210+
kwargs["name"] = self._current_job_name
211+
209212
return ChainerModel(
210213
self.model_data,
211214
role or self.role,
212215
entry_point or self.entry_point,
213216
source_dir=(source_dir or self._model_source_dir()),
214217
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
215-
name=self._current_job_name,
216218
container_log_level=self.container_log_level,
217219
code_location=self.code_location,
218220
py_version=self.py_version,

src/sagemaker/mxnet/estimator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,13 +209,15 @@ def create_model(
209209
if "image" not in kwargs:
210210
kwargs["image"] = image_name or self.image_name
211211

212+
if "name" not in kwargs:
213+
kwargs["name"] = self._current_job_name
214+
212215
return MXNetModel(
213216
self.model_data,
214217
role or self.role,
215218
entry_point or self.entry_point,
216219
source_dir=(source_dir or self._model_source_dir()),
217220
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
218-
name=self._current_job_name,
219221
container_log_level=self.container_log_level,
220222
code_location=self.code_location,
221223
py_version=self.py_version,

src/sagemaker/pytorch/estimator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,13 +167,15 @@ def create_model(
167167
if "image" not in kwargs:
168168
kwargs["image"] = self.image_name
169169

170+
if "name" not in kwargs:
171+
kwargs["name"] = self._current_job_name
172+
170173
return PyTorchModel(
171174
self.model_data,
172175
role or self.role,
173176
entry_point or self.entry_point,
174177
source_dir=(source_dir or self._model_source_dir()),
175178
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
176-
name=self._current_job_name,
177179
container_log_level=self.container_log_level,
178180
code_location=self.code_location,
179181
py_version=self.py_version,

src/sagemaker/rl/estimator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,8 +218,8 @@ def create_model(
218218
base_args = dict(
219219
model_data=self.model_data,
220220
role=role or self.role,
221-
image=kwargs["image"] if "image" in kwargs else self.image_name,
222-
name=self._current_job_name,
221+
image=kwargs.get("image", self.image_name),
222+
name=kwargs.get("name", self._current_job_name),
223223
container_log_level=self.container_log_level,
224224
sagemaker_session=self.sagemaker_session,
225225
vpc_config=self.get_vpc_config(vpc_config_override),

tests/unit/test_chainer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,19 +322,22 @@ def test_create_model_with_optional_params(sagemaker_session):
322322
new_role = "role"
323323
model_server_workers = 2
324324
vpc_config = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]}
325+
model_name = "model-name"
325326
model = chainer.create_model(
326327
role=new_role,
327328
model_server_workers=model_server_workers,
328329
vpc_config_override=vpc_config,
329330
entry_point=SERVING_SCRIPT_FILE,
330331
env=ENV,
332+
name=model_name,
331333
)
332334

333335
assert model.role == new_role
334336
assert model.model_server_workers == model_server_workers
335337
assert model.vpc_config == vpc_config
336338
assert model.entry_point == SERVING_SCRIPT_FILE
337339
assert model.env == ENV
340+
assert model.name == model_name
338341

339342

340343
def test_create_model_with_custom_image(sagemaker_session):

tests/unit/test_mxnet.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,19 +227,22 @@ def test_create_model_with_optional_params(sagemaker_session):
227227
new_role = "role"
228228
model_server_workers = 2
229229
vpc_config = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]}
230+
model_name = "model-name"
230231
model = mx.create_model(
231232
role=new_role,
232233
model_server_workers=model_server_workers,
233234
vpc_config_override=vpc_config,
234235
entry_point=SERVING_SCRIPT_FILE,
235236
env=ENV,
237+
name=model_name,
236238
)
237239

238240
assert model.role == new_role
239241
assert model.model_server_workers == model_server_workers
240242
assert model.vpc_config == vpc_config
241243
assert model.entry_point == SERVING_SCRIPT_FILE
242244
assert model.env == ENV
245+
assert model.name == model_name
243246

244247

245248
def test_create_model_with_custom_image(sagemaker_session):

tests/unit/test_pytorch.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,19 +208,22 @@ def test_create_model_with_optional_params(sagemaker_session):
208208
new_role = "role"
209209
model_server_workers = 2
210210
vpc_config = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]}
211+
model_name = "model-name"
211212
model = pytorch.create_model(
212213
role=new_role,
213214
model_server_workers=model_server_workers,
214215
vpc_config_override=vpc_config,
215216
entry_point=SERVING_SCRIPT_FILE,
216217
env=ENV,
218+
name=model_name,
217219
)
218220

219221
assert model.role == new_role
220222
assert model.model_server_workers == model_server_workers
221223
assert model.vpc_config == vpc_config
222224
assert model.entry_point == SERVING_SCRIPT_FILE
223225
assert model.env == ENV
226+
assert model.name == model_name
224227

225228

226229
def test_create_model_with_custom_image(sagemaker_session):

tests/unit/test_rl.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,13 +248,15 @@ def test_create_model_with_optional_params(sagemaker_session, rl_coach_mxnet_ver
248248
new_role = "role"
249249
new_entry_point = "deploy_script.py"
250250
vpc_config = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]}
251+
model_name = "model-name"
251252
model = rl.create_model(
252-
role=new_role, entry_point=new_entry_point, vpc_config_override=vpc_config
253+
role=new_role, entry_point=new_entry_point, vpc_config_override=vpc_config, name=model_name
253254
)
254255

255256
assert model.role == new_role
256257
assert model.vpc_config == vpc_config
257258
assert model.entry_point == new_entry_point
259+
assert model.name == model_name
258260

259261

260262
def test_create_model_with_custom_image(sagemaker_session):

tox.ini

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ envlist = black-format,flake8,pylint,twine,sphinx,py27,py36
88

99
skip_missing_interpreters = False
1010

11-
1211
[flake8]
1312
max-line-length = 120
1413
exclude =
@@ -59,7 +58,7 @@ passenv =
5958
# Can be used to specify which tests to run, e.g.: tox -- -s
6059
commands =
6160
coverage run --source sagemaker -m pytest {posargs}
62-
{env:IGNORE_COVERAGE:} coverage report --fail-under=84 --omit */tensorflow/tensorflow_serving/*
61+
{env:IGNORE_COVERAGE:} coverage report --fail-under=85 --omit */tensorflow/tensorflow_serving/*
6362
extras = test
6463

6564
[testenv:flake8]

0 commit comments

Comments
 (0)