Skip to content

Commit 163bffd

Browse files
nadiayajesterhazy
authored andcommitted
Make InputDataConfig optional for training. (#459)
* Make InputDataConfig optional for training. * Update boto3 dependency to make sure boto support no InputDataConfig. * Update changelog. * Add missing assertion for chainer failure script test.
1 parent 868f81b commit 163bffd

File tree

12 files changed

+82
-33
lines changed

12 files changed

+82
-33
lines changed

CHANGELOG.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@
22
CHANGELOG
33
=========
44

5+
1.13.1.dev
6+
==========
7+
8+
* feature: Estimator: make input channels optional
9+
10+
511
1.13.0
612
======
713

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def read(fname):
5353
],
5454

5555
# Declare minimal set for installation
56-
install_requires=['boto3>=1.4.8', 'numpy>=1.9.0', 'protobuf>=3.1', 'scipy>=0.19.0',
56+
install_requires=['boto3>=1.9.38', 'numpy>=1.9.0', 'protobuf>=3.1', 'scipy>=0.19.0',
5757
'urllib3 >=1.21, <1.23',
5858
'PyYAML>=3.2', 'protobuf3-to-dict>=0.1.5', 'docker-compose>=1.21.0'],
5959

src/sagemaker/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def _prepare_for_training(self, job_name=None):
176176
else:
177177
self.output_path = 's3://{}/'.format(self.sagemaker_session.default_bucket())
178178

179-
def fit(self, inputs, wait=True, logs=True, job_name=None):
179+
def fit(self, inputs=None, wait=True, logs=True, job_name=None):
180180
"""Train a model using the input training dataset.
181181
182182
The API calls the Amazon SageMaker CreateTrainingJob API to start model training.

src/sagemaker/job.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def _load_config(inputs, estimator):
6464

6565
model_channel = _Job._prepare_model_channel(input_config, estimator.model_uri, estimator.model_channel_name)
6666
if model_channel:
67+
input_config = [] if input_config is None else input_config
6768
input_config.append(model_channel)
6869

6970
return {'input_config': input_config,
@@ -75,6 +76,9 @@ def _load_config(inputs, estimator):
7576

7677
@staticmethod
7778
def _format_inputs_to_input_config(inputs):
79+
if inputs is None:
80+
return None
81+
7882
# Deferred import due to circular dependency
7983
from sagemaker.amazon.amazon_estimator import RecordSet
8084
if isinstance(inputs, RecordSet):
@@ -130,9 +134,10 @@ def _prepare_model_channel(input_config, model_uri=None, model_channel_name=None
130134
elif not model_channel_name:
131135
raise ValueError('Expected a pre-trained model channel name if a model URL is specified.')
132136

133-
for channel in input_config:
134-
if channel['ChannelName'] == model_channel_name:
135-
raise ValueError('Duplicate channels not allowed.')
137+
if input_config:
138+
for channel in input_config:
139+
if channel['ChannelName'] == model_channel_name:
140+
raise ValueError('Duplicate channels not allowed.')
136141

137142
model_input = _Job._format_model_uri_input(model_uri)
138143
model_channel = _Job._convert_input_to_channel(model_channel_name, model_input)

src/sagemaker/session.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,14 +257,16 @@ def train(self, image, input_mode, input_config, role, job_name, output_config,
257257
'TrainingImage': image,
258258
'TrainingInputMode': input_mode
259259
},
260-
'InputDataConfig': input_config,
261260
'OutputDataConfig': output_config,
262261
'TrainingJobName': job_name,
263262
'StoppingCondition': stop_condition,
264263
'ResourceConfig': resource_config,
265264
'RoleArn': role,
266265
}
267266

267+
if input_config is not None:
268+
train_request['InputDataConfig'] = input_config
269+
268270
if hyperparameters and len(hyperparameters) > 0:
269271
train_request['HyperParameters'] = hyperparameters
270272

src/sagemaker/tensorflow/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def _validate_requirements_file(self, requirements_file):
207207
if not os.path.exists(os.path.join(self.source_dir, requirements_file)):
208208
raise ValueError('Requirements file {} does not exist.'.format(requirements_file))
209209

210-
def fit(self, inputs, wait=True, logs=True, job_name=None, run_tensorboard_locally=False):
210+
def fit(self, inputs=None, wait=True, logs=True, job_name=None, run_tensorboard_locally=False):
211211
"""Train a model using the input training dataset.
212212
213213
See :func:`~sagemaker.estimator.EstimatorBase.fit` for more details.

tests/integ/test_chainer_train.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -105,18 +105,15 @@ def test_async_fit(sagemaker_session):
105105
def test_failed_training_job(sagemaker_session, chainer_full_version):
106106
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
107107
script_path = os.path.join(DATA_DIR, 'chainer_mnist', 'failure_script.py')
108-
data_path = os.path.join(DATA_DIR, 'chainer_mnist')
109108

110109
chainer = Chainer(entry_point=script_path, role='SageMakerRole',
111110
framework_version=chainer_full_version, py_version=PYTHON_VERSION,
112111
train_instance_count=1, train_instance_type='ml.c4.xlarge',
113112
sagemaker_session=sagemaker_session)
114113

115-
train_input = chainer.sagemaker_session.upload_data(path=os.path.join(data_path, 'train'),
116-
key_prefix='integ-test-data/chainer_mnist/train')
117-
118-
with pytest.raises(ValueError):
119-
chainer.fit(train_input)
114+
with pytest.raises(ValueError) as e:
115+
chainer.fit()
116+
assert 'This failure is expected' in str(e.value)
120117

121118

122119
def _run_mnist_training_job(sagemaker_session, instance_type, instance_count,

tests/integ/test_mxnet_train.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -105,15 +105,11 @@ def test_async_fit(sagemaker_session, mxnet_full_version):
105105
def test_failed_training_job(sagemaker_session, mxnet_full_version):
106106
with timeout():
107107
script_path = os.path.join(DATA_DIR, 'mxnet_mnist', 'failure_script.py')
108-
data_path = os.path.join(DATA_DIR, 'mxnet_mnist')
109108

110109
mx = MXNet(entry_point=script_path, role='SageMakerRole', framework_version=mxnet_full_version,
111110
py_version=PYTHON_VERSION, train_instance_count=1, train_instance_type='ml.c4.xlarge',
112111
sagemaker_session=sagemaker_session)
113112

114-
train_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'train'),
115-
key_prefix='integ-test-data/mxnet_mnist/train-failure')
116-
117113
with pytest.raises(ValueError) as e:
118-
mx.fit(train_input)
114+
mx.fit()
119115
assert 'This failure is expected' in str(e.value)

tests/integ/test_pytorch_train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def test_failed_training_job(sagemaker_session, pytorch_full_version):
106106
pytorch = _get_pytorch_estimator(sagemaker_session, pytorch_full_version, entry_point=script_path)
107107

108108
with pytest.raises(ValueError) as e:
109-
pytorch.fit(_upload_training_data(pytorch))
109+
pytorch.fit()
110110
assert 'This failure is expected' in str(e.value)
111111

112112

tests/integ/test_tf.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,8 +160,6 @@ def test_failed_tf_training(sagemaker_session, tf_full_version):
160160
train_instance_type='ml.c4.xlarge',
161161
sagemaker_session=sagemaker_session)
162162

163-
inputs = estimator.sagemaker_session.upload_data(path=DATA_PATH, key_prefix='integ-test-data/tf-failure')
164-
165163
with pytest.raises(ValueError) as e:
166-
estimator.fit(inputs)
164+
estimator.fit()
167165
assert 'This failure is expected' in str(e.value)

0 commit comments

Comments
 (0)