5555 TEST_RUN_DISPLAY_NAME ,
5656 TEST_ARTIFACT_BUCKET ,
5757 TEST_ARTIFACT_PREFIX ,
58+ TEST_TAGS ,
5859)
5960
6061
@@ -155,24 +156,22 @@ def test_run_init_name_length_exceed_limit(sagemaker_session):
155156
156157
157158@pytest .mark .parametrize (
158- ("kwargs" , "expected_artifact_bucket" , "expected_artifact_prefix" ),
159+ ("kwargs" , "expected_artifact_bucket" , "expected_artifact_prefix" , "expected_tags" ),
159160 [
160- ({}, None , _DEFAULT_ARTIFACT_PREFIX ),
161+ ({}, None , _DEFAULT_ARTIFACT_PREFIX , None ),
161162 (
162163 {
163164 "artifact_bucket" : TEST_ARTIFACT_BUCKET ,
164165 "artifact_prefix" : TEST_ARTIFACT_PREFIX ,
166+ "tags" : TEST_TAGS ,
165167 },
166168 TEST_ARTIFACT_BUCKET ,
167169 TEST_ARTIFACT_PREFIX ,
170+ TEST_TAGS ,
168171 ),
169172 ],
170173)
171174@patch .object (_TrialComponent , "save" , MagicMock (return_value = None ))
172- @patch (
173- "sagemaker.experiments.run.Experiment._load_or_create" ,
174- MagicMock (return_value = Experiment (experiment_name = TEST_EXP_NAME )),
175- )
176175@patch (
177176 "sagemaker.experiments.run._Trial._load_or_create" ,
178177 MagicMock (side_effect = mock_trial_load_or_create_func ),
@@ -189,6 +188,7 @@ def test_run_load_no_run_name_and_in_train_job(
189188 kwargs ,
190189 expected_artifact_bucket ,
191190 expected_artifact_prefix ,
191+ expected_tags ,
192192):
193193 client = sagemaker_session .sagemaker_client
194194 job_name = "my-train-job"
@@ -213,26 +213,32 @@ def test_run_load_no_run_name_and_in_train_job(
213213 {
214214 "TrialComponent" : {
215215 "Parents" : [
216- {"ExperimentName" : TEST_EXP_NAME , "TrialName" : exp_config [TRIAL_NAME ]}
216+ {
217+ "ExperimentName" : TEST_EXP_NAME ,
218+ "TrialName" : exp_config [TRIAL_NAME ],
219+ }
217220 ],
218221 "TrialComponentName" : expected_tc_name ,
219222 }
220223 }
221224 ]
222225 }
223- with load_run (sagemaker_session = sagemaker_session , ** kwargs ) as run_obj :
224- assert run_obj ._in_load
225- assert not run_obj ._inside_init_context
226- assert run_obj ._inside_load_context
227- assert run_obj .run_name == TEST_RUN_NAME
228- assert run_obj ._trial_component .trial_component_name == expected_tc_name
229- assert run_obj .run_group_name == Run ._generate_trial_name (TEST_EXP_NAME )
230- assert run_obj ._trial
231- assert run_obj .experiment_name == TEST_EXP_NAME
232- assert run_obj ._experiment
233- assert run_obj .experiment_config == exp_config
234- assert run_obj ._artifact_uploader .artifact_bucket == expected_artifact_bucket
235- assert run_obj ._artifact_uploader .artifact_prefix == expected_artifact_prefix
226+ expmock = MagicMock (return_value = Experiment (experiment_name = TEST_EXP_NAME , tags = expected_tags ))
227+ with patch ("sagemaker.experiments.run.Experiment._load_or_create" , expmock ):
228+ with load_run (sagemaker_session = sagemaker_session , ** kwargs ) as run_obj :
229+ assert run_obj ._in_load
230+ assert not run_obj ._inside_init_context
231+ assert run_obj ._inside_load_context
232+ assert run_obj .run_name == TEST_RUN_NAME
233+ assert run_obj ._trial_component .trial_component_name == expected_tc_name
234+ assert run_obj .run_group_name == Run ._generate_trial_name (TEST_EXP_NAME )
235+ assert run_obj ._trial
236+ assert run_obj .experiment_name == TEST_EXP_NAME
237+ assert run_obj ._experiment
238+ assert run_obj .experiment_config == exp_config
239+ assert run_obj ._artifact_uploader .artifact_bucket == expected_artifact_bucket
240+ assert run_obj ._artifact_uploader .artifact_prefix == expected_artifact_prefix
241+ assert run_obj ._experiment .tags == expected_tags
236242
237243 client .describe_training_job .assert_called_once_with (TrainingJobName = job_name )
238244 run_obj ._trial .add_trial_component .assert_not_called ()
@@ -265,7 +271,9 @@ def test_run_load_no_run_name_and_not_in_train_job(run_obj, sagemaker_session):
265271 assert run_obj == run
266272
267273
268- def test_run_load_no_run_name_and_not_in_train_job_but_no_obj_in_context (sagemaker_session ):
274+ def test_run_load_no_run_name_and_not_in_train_job_but_no_obj_in_context (
275+ sagemaker_session ,
276+ ):
269277 with pytest .raises (RuntimeError ) as err :
270278 with load_run (sagemaker_session = sagemaker_session ):
271279 pass
@@ -388,7 +396,10 @@ def test_run_load_in_sm_processing_job(mock_run_env, sagemaker_session):
388396 {
389397 "TrialComponent" : {
390398 "Parents" : [
391- {"ExperimentName" : TEST_EXP_NAME , "TrialName" : exp_config [TRIAL_NAME ]}
399+ {
400+ "ExperimentName" : TEST_EXP_NAME ,
401+ "TrialName" : exp_config [TRIAL_NAME ],
402+ }
392403 ],
393404 "TrialComponentName" : expected_tc_name ,
394405 }
@@ -442,7 +453,10 @@ def test_run_load_in_sm_transform_job(mock_run_env, sagemaker_session):
442453 {
443454 "TrialComponent" : {
444455 "Parents" : [
445- {"ExperimentName" : TEST_EXP_NAME , "TrialName" : exp_config [TRIAL_NAME ]}
456+ {
457+ "ExperimentName" : TEST_EXP_NAME ,
458+ "TrialName" : exp_config [TRIAL_NAME ],
459+ }
446460 ],
447461 "TrialComponentName" : expected_tc_name ,
448462 }
@@ -589,7 +603,10 @@ def test_log_output_artifact_outside_run_context(run_obj):
589603
590604
591605def test_log_output_artifact (run_obj ):
592- run_obj ._artifact_uploader .upload_artifact .return_value = ("s3uri_value" , "etag_value" )
606+ run_obj ._artifact_uploader .upload_artifact .return_value = (
607+ "s3uri_value" ,
608+ "etag_value" ,
609+ )
593610 with run_obj :
594611 run_obj .log_file ("foo.txt" , "name" , "whizz/bang" )
595612 run_obj ._artifact_uploader .upload_artifact .assert_called_with ("foo.txt" , extra_args = None )
@@ -608,7 +625,10 @@ def test_log_input_artifact_outside_run_context(run_obj):
608625
609626
610627def test_log_input_artifact (run_obj ):
611- run_obj ._artifact_uploader .upload_artifact .return_value = ("s3uri_value" , "etag_value" )
628+ run_obj ._artifact_uploader .upload_artifact .return_value = (
629+ "s3uri_value" ,
630+ "etag_value" ,
631+ )
612632 with run_obj :
613633 run_obj .log_file ("foo.txt" , "name" , "whizz/bang" , is_output = False )
614634 run_obj ._artifact_uploader .upload_artifact .assert_called_with ("foo.txt" , extra_args = None )
@@ -653,7 +673,10 @@ def test_log_multiple_input_artifacts(run_obj):
653673 "etag_value" + str (index ),
654674 )
655675 run_obj .log_file (
656- file_path , "name" + str (index ), "whizz/bang" + str (index ), is_output = False
676+ file_path ,
677+ "name" + str (index ),
678+ "whizz/bang" + str (index ),
679+ is_output = False ,
657680 )
658681 run_obj ._artifact_uploader .upload_artifact .assert_called_with (
659682 file_path , extra_args = None
@@ -757,7 +780,12 @@ def test_log_precision_recall_invalid_input(run_obj):
757780 with run_obj :
758781 with pytest .raises (ValueError ) as error :
759782 run_obj .log_precision_recall (
760- y_true , y_scores , 0 , title = "TestPrecisionRecall" , no_skill = no_skill , is_output = False
783+ y_true ,
784+ y_scores ,
785+ 0 ,
786+ title = "TestPrecisionRecall" ,
787+ no_skill = no_skill ,
788+ is_output = False ,
761789 )
762790 assert "Lengths mismatch between true labels and predicted probabilities" in str (error )
763791
@@ -905,7 +933,8 @@ def test_list(mock_tc_search, mock_tc_list, mock_tc_load, run_obj, sagemaker_ses
905933 display_name = "C" + str (i ),
906934 source_arn = "D" + str (i ),
907935 status = TrialComponentStatus (
908- primary_status = _TrialComponentStatusType .InProgress .value , message = "E" + str (i )
936+ primary_status = _TrialComponentStatusType .InProgress .value ,
937+ message = "E" + str (i ),
909938 ),
910939 start_time = start_time + datetime .timedelta (hours = i ),
911940 end_time = end_time + datetime .timedelta (hours = i ),
@@ -925,7 +954,8 @@ def test_list(mock_tc_search, mock_tc_list, mock_tc_load, run_obj, sagemaker_ses
925954 display_name = "C" + str (i ),
926955 source_arn = "D" + str (i ),
927956 status = TrialComponentStatus (
928- primary_status = _TrialComponentStatusType .InProgress .value , message = "E" + str (i )
957+ primary_status = _TrialComponentStatusType .InProgress .value ,
958+ message = "E" + str (i ),
929959 ),
930960 start_time = start_time + datetime .timedelta (hours = i ),
931961 end_time = end_time + datetime .timedelta (hours = i ),
0 commit comments