@@ -304,9 +304,13 @@ def test_check_output():
304304@patch ('sagemaker.local.local_session.LocalSession' , Mock ())
305305@patch ('sagemaker.local.image._stream_output' , Mock ())
306306@patch ('sagemaker.local.image._SageMakerContainer._cleanup' , Mock ())
307- @patch ('sagemaker.local.data.get_data_source_instance' , Mock () )
307+ @patch ('sagemaker.local.data.get_data_source_instance' )
308308@patch ('subprocess.Popen' )
309- def test_train (popen , tmpdir , sagemaker_session ):
309+ def test_train (popen , get_data_source_instance , tmpdir , sagemaker_session ):
310+ data_source = Mock ()
311+ data_source .get_root_dir .return_value = 'foo'
312+ get_data_source_instance .return_value = data_source
313+
310314 directories = [str (tmpdir .mkdir ('container-root' )), str (tmpdir .mkdir ('data' ))]
311315 with patch ('sagemaker.local.image._SageMakerContainer._create_tmp_folder' ,
312316 side_effect = directories ):
@@ -342,8 +346,12 @@ def test_train(popen, tmpdir, sagemaker_session):
342346@patch ('sagemaker.local.local_session.LocalSession' , Mock ())
343347@patch ('sagemaker.local.image._stream_output' , Mock ())
344348@patch ('sagemaker.local.image._SageMakerContainer._cleanup' , Mock ())
345- @patch ('sagemaker.local.data.get_data_source_instance' , Mock ())
346- def test_train_with_hyperparameters_without_job_name (tmpdir , sagemaker_session ):
349+ @patch ('sagemaker.local.data.get_data_source_instance' )
350+ def test_train_with_hyperparameters_without_job_name (get_data_source_instance , tmpdir , sagemaker_session ):
351+ data_source = Mock ()
352+ data_source .get_root_dir .return_value = 'foo'
353+ get_data_source_instance .return_value = data_source
354+
347355 directories = [str (tmpdir .mkdir ('container-root' )), str (tmpdir .mkdir ('data' ))]
348356 with patch ('sagemaker.local.image._SageMakerContainer._create_tmp_folder' ,
349357 side_effect = directories ):
@@ -364,11 +372,14 @@ def test_train_with_hyperparameters_without_job_name(tmpdir, sagemaker_session):
364372@patch ('sagemaker.local.local_session.LocalSession' , Mock ())
365373@patch ('sagemaker.local.image._stream_output' , side_effect = RuntimeError ('this is expected' ))
366374@patch ('sagemaker.local.image._SageMakerContainer._cleanup' , Mock ())
367- @patch ('sagemaker.local.data.get_data_source_instance' , Mock () )
375+ @patch ('sagemaker.local.data.get_data_source_instance' )
368376@patch ('subprocess.Popen' , Mock ())
369- def test_train_error (_stream_output , tmpdir , sagemaker_session ):
370- directories = [str (tmpdir .mkdir ('container-root' )), str (tmpdir .mkdir ('data' ))]
377+ def test_train_error (get_data_source_instance , _stream_output , tmpdir , sagemaker_session ):
378+ data_source = Mock ()
379+ data_source .get_root_dir .return_value = 'foo'
380+ get_data_source_instance .return_value = data_source
371381
382+ directories = [str (tmpdir .mkdir ('container-root' )), str (tmpdir .mkdir ('data' ))]
372383 with patch ('sagemaker.local.image._SageMakerContainer._create_tmp_folder' , side_effect = directories ):
373384 instance_count = 2
374385 image = 'my-image'
@@ -384,9 +395,13 @@ def test_train_error(_stream_output, tmpdir, sagemaker_session):
384395@patch ('sagemaker.local.local_session.LocalSession' , Mock ())
385396@patch ('sagemaker.local.image._stream_output' , Mock ())
386397@patch ('sagemaker.local.image._SageMakerContainer._cleanup' , Mock ())
387- @patch ('sagemaker.local.data.get_data_source_instance' , Mock () )
398+ @patch ('sagemaker.local.data.get_data_source_instance' )
388399@patch ('subprocess.Popen' , Mock ())
389- def test_train_local_code (tmpdir , sagemaker_session ):
400+ def test_train_local_code (get_data_source_instance , tmpdir , sagemaker_session ):
401+ data_source = Mock ()
402+ data_source .get_root_dir .return_value = 'foo'
403+ get_data_source_instance .return_value = data_source
404+
390405 directories = [str (tmpdir .mkdir ('container-root' )), str (tmpdir .mkdir ('data' ))]
391406 with patch ('sagemaker.local.image._SageMakerContainer._create_tmp_folder' ,
392407 side_effect = directories ):
@@ -422,9 +437,13 @@ def test_train_local_code(tmpdir, sagemaker_session):
422437@patch ('sagemaker.local.local_session.LocalSession' , Mock ())
423438@patch ('sagemaker.local.image._stream_output' , Mock ())
424439@patch ('sagemaker.local.image._SageMakerContainer._cleanup' , Mock ())
425- @patch ('sagemaker.local.data.get_data_source_instance' , Mock () )
440+ @patch ('sagemaker.local.data.get_data_source_instance' )
426441@patch ('subprocess.Popen' , Mock ())
427- def test_train_local_intermediate_output (tmpdir , sagemaker_session ):
442+ def test_train_local_intermediate_output (get_data_source_instance , tmpdir , sagemaker_session ):
443+ data_source = Mock ()
444+ data_source .get_root_dir .return_value = 'foo'
445+ get_data_source_instance .return_value = data_source
446+
428447 directories = [str (tmpdir .mkdir ('container-root' )), str (tmpdir .mkdir ('data' ))]
429448 with patch ('sagemaker.local.image._SageMakerContainer._create_tmp_folder' ,
430449 side_effect = directories ):
0 commit comments