@@ -35,7 +35,26 @@ def test_create_training_job(train, LocalSession):
3535 image = "my-docker-image:1.0"
3636
3737 algo_spec = {'TrainingImage' : image }
38- input_data_config = {}
38+ input_data_config = [
39+ {
40+ 'ChannelName' : 'a' ,
41+ 'DataSource' : {
42+ 'S3DataSource' : {
43+ 'S3DataDistributionType' : 'FullyReplicated' ,
44+ 'S3Uri' : 's3://my_bucket/tmp/source1'
45+ }
46+ }
47+ },
48+ {
49+ 'ChannelName' : 'b' ,
50+ 'DataSource' : {
51+ 'FileDataSource' : {
52+ 'FileDataDistributionType' : 'FullyReplicated' ,
53+ 'FileUri' : 'file:///tmp/source1'
54+ }
55+ }
56+ }
57+ ]
3958 output_data_config = {}
4059 resource_config = {'InstanceType' : 'local' , 'InstanceCount' : instance_count }
4160 hyperparameters = {'a' : 1 , 'b' : 'bee' }
@@ -61,6 +80,67 @@ def test_create_training_job(train, LocalSession):
6180 assert response ['ModelArtifacts' ]['S3ModelArtifacts' ] == expected ['ModelArtifacts' ]['S3ModelArtifacts' ]
6281
6382
83+ @patch ('sagemaker.local.image._SageMakerContainer.train' , return_value = "/some/path/to/model" )
84+ @patch ('sagemaker.local.local_session.LocalSession' )
85+ def test_create_training_job_invalid_data_source (train , LocalSession ):
86+ local_sagemaker_client = sagemaker .local .local_session .LocalSagemakerClient ()
87+
88+ instance_count = 2
89+ image = "my-docker-image:1.0"
90+
91+ algo_spec = {'TrainingImage' : image }
92+
93+ # InvalidDataSource is not supported. S3DataSource and FileDataSource are currently the only
94+ # valid Data Sources. We expect a ValueError if we pass this input data config.
95+ input_data_config = [{
96+ 'ChannelName' : 'a' ,
97+ 'DataSource' : {
98+ 'InvalidDataSource' : {
99+ 'FileDataDistributionType' : 'FullyReplicated' ,
100+ 'FileUri' : 'ftp://myserver.com/tmp/source1'
101+ }
102+ }
103+ }]
104+
105+ output_data_config = {}
106+ resource_config = {'InstanceType' : 'local' , 'InstanceCount' : instance_count }
107+ hyperparameters = {'a' : 1 , 'b' : 'bee' }
108+
109+ with pytest .raises (ValueError ):
110+ local_sagemaker_client .create_training_job ("my-training-job" , algo_spec , 'arn:my-role' , input_data_config ,
111+ output_data_config , resource_config , None , hyperparameters )
112+
113+
114+ @patch ('sagemaker.local.image._SageMakerContainer.train' , return_value = "/some/path/to/model" )
115+ @patch ('sagemaker.local.local_session.LocalSession' )
116+ def test_create_training_job_not_fully_replicated (train , LocalSession ):
117+ local_sagemaker_client = sagemaker .local .local_session .LocalSagemakerClient ()
118+
119+ instance_count = 2
120+ image = "my-docker-image:1.0"
121+
122+ algo_spec = {'TrainingImage' : image }
123+
124+ # Local Mode only supports FullyReplicated as Data Distribution type.
125+ input_data_config = [{
126+ 'ChannelName' : 'a' ,
127+ 'DataSource' : {
128+ 'S3DataSource' : {
129+ 'S3DataDistributionType' : 'ShardedByS3Key' ,
130+ 'S3Uri' : 's3://my_bucket/tmp/source1'
131+ }
132+ }
133+ }]
134+
135+ output_data_config = {}
136+ resource_config = {'InstanceType' : 'local' , 'InstanceCount' : instance_count }
137+ hyperparameters = {'a' : 1 , 'b' : 'bee' }
138+
139+ with pytest .raises (RuntimeError ):
140+ local_sagemaker_client .create_training_job ("my-training-job" , algo_spec , 'arn:my-role' , input_data_config ,
141+ output_data_config , resource_config , None , hyperparameters )
142+
143+
64144@patch ('sagemaker.local.local_session.LocalSession' )
65145def test_create_model (LocalSession ):
66146 local_sagemaker_client = sagemaker .local .local_session .LocalSagemakerClient ()
@@ -130,3 +210,34 @@ def test_create_endpoint_fails(serve, request, LocalSession):
130210
131211 with pytest .raises (RuntimeError ):
132212 local_sagemaker_client .create_endpoint ('my-endpoint' , 'some-endpoint-config' )
213+
214+
215+ def test_file_input_all_defaults ():
216+ prefix = 'pre'
217+ actual = sagemaker .local .local_session .file_input (fileUri = prefix )
218+ expected = \
219+ {
220+ 'DataSource' : {
221+ 'FileDataSource' : {
222+ 'FileDataDistributionType' : 'FullyReplicated' ,
223+ 'FileUri' : prefix
224+ }
225+ }
226+ }
227+ assert actual .config == expected
228+
229+
230+ def test_file_input_content_type ():
231+ prefix = 'pre'
232+ actual = sagemaker .local .local_session .file_input (fileUri = prefix , content_type = 'text/csv' )
233+ expected = \
234+ {
235+ 'DataSource' : {
236+ 'FileDataSource' : {
237+ 'FileDataDistributionType' : 'FullyReplicated' ,
238+ 'FileUri' : prefix
239+ }
240+ },
241+ 'ContentType' : 'text/csv'
242+ }
243+ assert actual .config == expected
0 commit comments