3232REGION = 'us-west-2'
3333BUCKET_NAME = 'mybucket'
3434EXPANDED_ROLE = 'arn:aws:iam::111111111111:role/ExpandedRole'
35+ TRAINING_JOB_NAME = 'my-job'
3536INPUT_DATA_CONFIG = [
3637 {
3738 'ChannelName' : 'a' ,
5556]
5657HYPERPARAMETERS = {'a' : 1 ,
5758 'b' : json .dumps ('bee' ),
58- 'sagemaker_submit_directory' : json .dumps ('s3://my_bucket/code' ),
59- 'sagemaker_job_name' : json . dumps ( 'my-job' )}
59+ 'sagemaker_submit_directory' : json .dumps ('s3://my_bucket/code' )}
60+
6061
6162LOCAL_CODE_HYPERPARAMETERS = {'a' : 1 ,
6263 'b' : 2 ,
63- 'sagemaker_submit_directory' : json .dumps ('file:///tmp/code' ),
64- 'sagemaker_job_name' : json .dumps ('my-job' )}
64+ 'sagemaker_submit_directory' : json .dumps ('file:///tmp/code' )}
6565
6666
6767@pytest .fixture ()
@@ -230,7 +230,7 @@ def test_train(_download_folder, _cleanup, popen, _stream_output, LocalSession,
230230 instance_count = 2
231231 image = 'my-image'
232232 sagemaker_container = _SageMakerContainer ('local' , instance_count , image , sagemaker_session = sagemaker_session )
233- sagemaker_container .train (INPUT_DATA_CONFIG , HYPERPARAMETERS )
233+ sagemaker_container .train (INPUT_DATA_CONFIG , HYPERPARAMETERS , TRAINING_JOB_NAME )
234234
235235 channel_dir = os .path .join (directories [1 ], 'b' )
236236 download_folder_calls = [call ('my-own-bucket' , 'prefix' , channel_dir )]
@@ -252,13 +252,36 @@ def test_train(_download_folder, _cleanup, popen, _stream_output, LocalSession,
252252 assert config ['services' ][h ]['image' ] == image
253253 assert config ['services' ][h ]['command' ] == 'train'
254254 assert 'AWS_REGION={}' .format (REGION ) in config ['services' ][h ]['environment' ]
255- assert 'TRAINING_JOB_NAME=my-job' in config ['services' ][h ]['environment' ]
255+ assert 'TRAINING_JOB_NAME={}' . format ( TRAINING_JOB_NAME ) in config ['services' ][h ]['environment' ]
256256
257257 # assert that expected by sagemaker container output directories exist
258258 assert os .path .exists (os .path .join (sagemaker_container .container_root , 'output' ))
259259 assert os .path .exists (os .path .join (sagemaker_container .container_root , 'output/data' ))
260260
261261
262+ @patch ('sagemaker.local.local_session.LocalSession' )
263+ @patch ('sagemaker.local.image._stream_output' )
264+ @patch ('sagemaker.local.image._SageMakerContainer._cleanup' )
265+ @patch ('sagemaker.local.image._SageMakerContainer._download_folder' )
266+ def test_train_with_hyperparameters_without_job_name (_download_folder , _cleanup , _stream_output , LocalSession , tmpdir ):
267+
268+ directories = [str (tmpdir .mkdir ('container-root' )), str (tmpdir .mkdir ('data' ))]
269+ with patch ('sagemaker.local.image._SageMakerContainer._create_tmp_folder' ,
270+ side_effect = directories ):
271+
272+ instance_count = 2
273+ image = 'my-image'
274+ sagemaker_container = _SageMakerContainer ('local' , instance_count , image , sagemaker_session = LocalSession )
275+ sagemaker_container .train (INPUT_DATA_CONFIG , HYPERPARAMETERS , TRAINING_JOB_NAME )
276+
277+ docker_compose_file = os .path .join (sagemaker_container .container_root , 'docker-compose.yaml' )
278+
279+ with open (docker_compose_file , 'r' ) as f :
280+ config = yaml .load (f )
281+ for h in sagemaker_container .hosts :
282+ assert 'TRAINING_JOB_NAME={}' .format (TRAINING_JOB_NAME ) in config ['services' ][h ]['environment' ]
283+
284+
262285@patch ('sagemaker.local.local_session.LocalSession' )
263286@patch ('sagemaker.local.image._stream_output' , side_effect = RuntimeError ('this is expected' ))
264287@patch ('subprocess.Popen' )
@@ -273,7 +296,7 @@ def test_train_error(_download_folder, _cleanup, popen, _stream_output, LocalSes
273296 sagemaker_container = _SageMakerContainer ('local' , instance_count , image , sagemaker_session = sagemaker_session )
274297
275298 with pytest .raises (RuntimeError ) as e :
276- sagemaker_container .train (INPUT_DATA_CONFIG , HYPERPARAMETERS )
299+ sagemaker_container .train (INPUT_DATA_CONFIG , HYPERPARAMETERS , TRAINING_JOB_NAME )
277300
278301 assert 'this is expected' in str (e )
279302
@@ -293,7 +316,7 @@ def test_train_local_code(_download_folder, _cleanup, popen, _stream_output,
293316 sagemaker_container = _SageMakerContainer ('local' , instance_count , image ,
294317 sagemaker_session = sagemaker_session )
295318
296- sagemaker_container .train (INPUT_DATA_CONFIG , LOCAL_CODE_HYPERPARAMETERS )
319+ sagemaker_container .train (INPUT_DATA_CONFIG , LOCAL_CODE_HYPERPARAMETERS , TRAINING_JOB_NAME )
297320
298321 docker_compose_file = os .path .join (sagemaker_container .container_root ,
299322 'docker-compose.yaml' )
0 commit comments