1919import sagemaker
2020from sagemaker .estimator import Estimator
2121from sagemaker .session_settings import SessionSettings
22- from sagemaker .tuner import (
23- HyperparameterTuner
24- )
22+ from sagemaker .tuner import HyperparameterTuner
2523from tests .unit .tuner_test_utils import (
2624 OBJECTIVE_METRIC_NAME ,
2725 HYPERPARAMETER_RANGES ,
28- METRIC_DEFINITIONS
26+ METRIC_DEFINITIONS ,
2927)
28+
3029# Visualization specific imports
3130from sagemaker .amtviz .visualization import visualize_tuning_job , get_job_analytics_data
3231from tests .unit .tuner_visualize_test_utils import (
4443 TRIALS_DF_TRAINING_JOB_STATUSES ,
4544 TRIALS_DF_VALID_F1_VALUES ,
4645 FILTERED_TUNING_JOB_DF_DATA ,
47- TUNING_RANGES
46+ TUNING_RANGES ,
4847)
4948import altair as alt
5049
@@ -56,7 +55,7 @@ def create_sagemaker_session():
5655 boto_session = boto_mock ,
5756 config = None ,
5857 local_mode = False ,
59- settings = SessionSettings ()
58+ settings = SessionSettings (),
6059 )
6160 sms .sagemaker_config = {}
6261 return sms
@@ -103,12 +102,7 @@ def mock_visualize_tuning_job():
103102@pytest .fixture
104103def mock_get_job_analytics_data ():
105104 with patch ("sagemaker.amtviz.visualization.get_job_analytics_data" ) as mock :
106- mock .return_value = (
107- pd .DataFrame (TRIALS_DF_DATA ),
108- TUNED_PARAMETERS ,
109- OBJECTIVE_NAME ,
110- True
111- )
105+ mock .return_value = (pd .DataFrame (TRIALS_DF_DATA ), TUNED_PARAMETERS , OBJECTIVE_NAME , True )
112106 yield mock
113107
114108
@@ -144,21 +138,22 @@ def test_visualize_jobs(mock_visualize_tuning_job):
144138 result = HyperparameterTuner .visualize_jobs (TUNING_JOB_NAMES )
145139 assert result == "mock_chart"
146140 mock_visualize_tuning_job .assert_called_once_with (
147- TUNING_JOB_NAMES ,
148- return_dfs = False ,
149- job_metrics = None ,
150- trials_only = False ,
151- advanced = False
141+ TUNING_JOB_NAMES , return_dfs = False , job_metrics = None , trials_only = False , advanced = False
152142 )
153143 # Vary the parameters and check if they have been passed correctly
154144 result = HyperparameterTuner .visualize_jobs (
155- [TUNING_JOB_NAME_1 ], return_dfs = True , job_metrics = "job_metrics" , trials_only = True , advanced = True )
145+ [TUNING_JOB_NAME_1 ],
146+ return_dfs = True ,
147+ job_metrics = "job_metrics" ,
148+ trials_only = True ,
149+ advanced = True ,
150+ )
156151 mock_visualize_tuning_job .assert_called_with (
157152 [TUNING_JOB_NAME_1 ],
158153 return_dfs = True ,
159154 job_metrics = "job_metrics" ,
160155 trials_only = True ,
161- advanced = True
156+ advanced = True ,
162157 )
163158
164159
@@ -168,21 +163,15 @@ def test_visualize_job(tuner, mock_visualize_tuning_job):
168163 result = tuner .visualize_job ()
169164 assert result == "mock_chart"
170165 mock_visualize_tuning_job .assert_called_once_with (
171- tuner ,
172- return_dfs = False ,
173- job_metrics = None ,
174- trials_only = False ,
175- advanced = False
166+ tuner , return_dfs = False , job_metrics = None , trials_only = False , advanced = False
176167 )
177168 # With varying parameters
178- result = tuner .visualize_job (return_dfs = True , job_metrics = "job_metrics" , trials_only = True , advanced = True )
169+ result = tuner .visualize_job (
170+ return_dfs = True , job_metrics = "job_metrics" , trials_only = True , advanced = True
171+ )
179172 assert result == "mock_chart"
180173 mock_visualize_tuning_job .assert_called_with (
181- tuner ,
182- return_dfs = True ,
183- job_metrics = "job_metrics" ,
184- trials_only = True ,
185- advanced = True
174+ tuner , return_dfs = True , job_metrics = "job_metrics" , trials_only = True , advanced = True
186175 )
187176
188177
@@ -191,21 +180,22 @@ def test_visualize_multiple_jobs(tuner, tuner2, mock_visualize_tuning_job):
191180 result = HyperparameterTuner .visualize_jobs ([tuner , tuner2 ])
192181 assert result == "mock_chart"
193182 mock_visualize_tuning_job .assert_called_once_with (
194- [tuner , tuner2 ],
195- return_dfs = False ,
196- job_metrics = None ,
197- trials_only = False ,
198- advanced = False
183+ [tuner , tuner2 ], return_dfs = False , job_metrics = None , trials_only = False , advanced = False
199184 )
200185 # Vary the parameters and check if they have been passed correctly
201186 result = HyperparameterTuner .visualize_jobs (
202- [[tuner , tuner2 ]], return_dfs = True , job_metrics = "job_metrics" , trials_only = True , advanced = True )
187+ [[tuner , tuner2 ]],
188+ return_dfs = True ,
189+ job_metrics = "job_metrics" ,
190+ trials_only = True ,
191+ advanced = True ,
192+ )
203193 mock_visualize_tuning_job .assert_called_with (
204194 [[tuner , tuner2 ]],
205195 return_dfs = True ,
206196 job_metrics = "job_metrics" ,
207197 trials_only = True ,
208- advanced = True
198+ advanced = True ,
209199 )
210200
211201
@@ -226,10 +216,10 @@ def test_visualize_tuning_job_return_dfs(mock_get_job_analytics_data, mock_prepa
226216 assert isinstance (trials_df , pd .DataFrame )
227217 assert trials_df .shape == (2 , len (TRIALS_DF_COLUMNS ))
228218 assert trials_df .columns .tolist () == TRIALS_DF_COLUMNS
229- assert trials_df [' TrainingJobName' ].tolist () == TRIALS_DF_TRAINING_JOB_NAMES
230- assert trials_df [' TrainingJobStatus' ].tolist () == TRIALS_DF_TRAINING_JOB_STATUSES
231- assert trials_df [' TuningJobName' ].tolist () == TUNING_JOB_NAMES
232- assert trials_df [' valid-f1' ].tolist () == TRIALS_DF_VALID_F1_VALUES
219+ assert trials_df [" TrainingJobName" ].tolist () == TRIALS_DF_TRAINING_JOB_NAMES
220+ assert trials_df [" TrainingJobStatus" ].tolist () == TRIALS_DF_TRAINING_JOB_STATUSES
221+ assert trials_df [" TuningJobName" ].tolist () == TUNING_JOB_NAMES
222+ assert trials_df [" valid-f1" ].tolist () == TRIALS_DF_VALID_F1_VALUES
233223
234224 # Assertions for full_df
235225 assert isinstance (full_df , pd .DataFrame )
@@ -244,7 +234,7 @@ def test_visualize_tuning_job_empty_trials(mock_get_job_analytics_data):
244234 pd .DataFrame (), # empty dataframe
245235 TUNED_PARAMETERS ,
246236 OBJECTIVE_NAME ,
247- True
237+ True ,
248238 )
249239 charts = visualize_tuning_job ("empty_job" )
250240 assert charts .empty
@@ -267,9 +257,7 @@ def test_visualize_tuning_job_trials_only(mock_get_job_analytics_data):
267257# Check if all parameters are correctly passed to the (mocked) create_charts method
268258@patch ("sagemaker.amtviz.visualization.create_charts" )
269259def test_visualize_tuning_job_with_full_df (
270- mock_create_charts ,
271- mock_get_job_analytics_data ,
272- mock_prepare_consolidated_df
260+ mock_create_charts , mock_get_job_analytics_data , mock_prepare_consolidated_df
273261):
274262 mock_create_charts .return_value = alt .Chart ()
275263 visualize_tuning_job ("dummy_job" )
@@ -298,10 +286,15 @@ def test_visualize_tuning_job_with_full_df(
298286@patch ("sagemaker.HyperparameterTuningJobAnalytics" )
299287def test_get_job_analytics_data (mock_hyperparameter_tuning_job_analytics ):
300288 # Mock sagemaker's describe_hyper_parameter_tuning_job and some internal methods
301- sagemaker .amtviz .visualization .sm .describe_hyper_parameter_tuning_job = Mock (return_value = TUNING_JOB_RESULT )
289+ sagemaker .amtviz .visualization .sm .describe_hyper_parameter_tuning_job = Mock (
290+ return_value = TUNING_JOB_RESULT
291+ )
302292 sagemaker .amtviz .visualization ._get_tuning_job_names_with_parents = Mock (
303- return_value = [TUNING_JOB_NAME_1 , TUNING_JOB_NAME_2 ])
304- sagemaker .amtviz .visualization ._get_df = Mock (return_value = pd .DataFrame (FILTERED_TUNING_JOB_DF_DATA ))
293+ return_value = [TUNING_JOB_NAME_1 , TUNING_JOB_NAME_2 ]
294+ )
295+ sagemaker .amtviz .visualization ._get_df = Mock (
296+ return_value = pd .DataFrame (FILTERED_TUNING_JOB_DF_DATA )
297+ )
305298 mock_tuning_job_instance = MagicMock ()
306299 mock_hyperparameter_tuning_job_analytics .return_value = mock_tuning_job_instance
307300 mock_tuning_job_instance .tuning_ranges .values .return_value = TUNING_RANGES
0 commit comments