import sagemaker
from sagemaker.estimator import Estimator
from sagemaker.session_settings import SessionSettings
-from sagemaker.tuner import (
-    HyperparameterTuner
-)
+from sagemaker.tuner import HyperparameterTuner
from tests.unit.tuner_test_utils import (
    OBJECTIVE_METRIC_NAME,
    HYPERPARAMETER_RANGES,
-    METRIC_DEFINITIONS
+    METRIC_DEFINITIONS,
)
+
# Visualization specific imports
from sagemaker.amtviz.visualization import visualize_tuning_job, get_job_analytics_data
from tests.unit.tuner_visualize_test_utils import (
    TRIALS_DF_TRAINING_JOB_STATUSES,
    TRIALS_DF_VALID_F1_VALUES,
    FILTERED_TUNING_JOB_DF_DATA,
-    TUNING_RANGES
+    TUNING_RANGES,
)
import altair as alt
@@ -56,7 +55,7 @@ def create_sagemaker_session():
        boto_session=boto_mock,
        config=None,
        local_mode=False,
-        settings=SessionSettings()
+        settings=SessionSettings(),
    )
    sms.sagemaker_config = {}
    return sms
@@ -103,12 +102,7 @@ def mock_visualize_tuning_job():
@pytest.fixture
def mock_get_job_analytics_data():
    with patch("sagemaker.amtviz.visualization.get_job_analytics_data") as mock:
-        mock.return_value = (
-            pd.DataFrame(TRIALS_DF_DATA),
-            TUNED_PARAMETERS,
-            OBJECTIVE_NAME,
-            True
-        )
+        mock.return_value = (pd.DataFrame(TRIALS_DF_DATA), TUNED_PARAMETERS, OBJECTIVE_NAME, True)
        yield mock
@@ -144,21 +138,22 @@ def test_visualize_jobs(mock_visualize_tuning_job):
    result = HyperparameterTuner.visualize_jobs(TUNING_JOB_NAMES)
    assert result == "mock_chart"
    mock_visualize_tuning_job.assert_called_once_with(
-        TUNING_JOB_NAMES,
-        return_dfs=False,
-        job_metrics=None,
-        trials_only=False,
-        advanced=False
+        TUNING_JOB_NAMES, return_dfs=False, job_metrics=None, trials_only=False, advanced=False
    )
    # Vary the parameters and check if they have been passed correctly
    result = HyperparameterTuner.visualize_jobs(
-        [TUNING_JOB_NAME_1], return_dfs=True, job_metrics="job_metrics", trials_only=True, advanced=True)
+        [TUNING_JOB_NAME_1],
+        return_dfs=True,
+        job_metrics="job_metrics",
+        trials_only=True,
+        advanced=True,
+    )
    mock_visualize_tuning_job.assert_called_with(
        [TUNING_JOB_NAME_1],
        return_dfs=True,
        job_metrics="job_metrics",
        trials_only=True,
-        advanced=True
+        advanced=True,
    )
@@ -168,21 +163,15 @@ def test_visualize_job(tuner, mock_visualize_tuning_job):
    result = tuner.visualize_job()
    assert result == "mock_chart"
    mock_visualize_tuning_job.assert_called_once_with(
-        tuner,
-        return_dfs=False,
-        job_metrics=None,
-        trials_only=False,
-        advanced=False
+        tuner, return_dfs=False, job_metrics=None, trials_only=False, advanced=False
    )
    # With varying parameters
-    result = tuner.visualize_job(return_dfs=True, job_metrics="job_metrics", trials_only=True, advanced=True)
+    result = tuner.visualize_job(
+        return_dfs=True, job_metrics="job_metrics", trials_only=True, advanced=True
+    )
    assert result == "mock_chart"
    mock_visualize_tuning_job.assert_called_with(
-        tuner,
-        return_dfs=True,
-        job_metrics="job_metrics",
-        trials_only=True,
-        advanced=True
+        tuner, return_dfs=True, job_metrics="job_metrics", trials_only=True, advanced=True
    )
@@ -191,21 +180,22 @@ def test_visualize_multiple_jobs(tuner, tuner2, mock_visualize_tuning_job):
    result = HyperparameterTuner.visualize_jobs([tuner, tuner2])
    assert result == "mock_chart"
    mock_visualize_tuning_job.assert_called_once_with(
-        [tuner, tuner2],
-        return_dfs=False,
-        job_metrics=None,
-        trials_only=False,
-        advanced=False
+        [tuner, tuner2], return_dfs=False, job_metrics=None, trials_only=False, advanced=False
    )
    # Vary the parameters and check if they have been passed correctly
    result = HyperparameterTuner.visualize_jobs(
-        [[tuner, tuner2]], return_dfs=True, job_metrics="job_metrics", trials_only=True, advanced=True)
+        [[tuner, tuner2]],
+        return_dfs=True,
+        job_metrics="job_metrics",
+        trials_only=True,
+        advanced=True,
+    )
    mock_visualize_tuning_job.assert_called_with(
        [[tuner, tuner2]],
        return_dfs=True,
        job_metrics="job_metrics",
        trials_only=True,
-        advanced=True
+        advanced=True,
    )
@@ -226,10 +216,10 @@ def test_visualize_tuning_job_return_dfs(mock_get_job_analytics_data, mock_prepa
    assert isinstance(trials_df, pd.DataFrame)
    assert trials_df.shape == (2, len(TRIALS_DF_COLUMNS))
    assert trials_df.columns.tolist() == TRIALS_DF_COLUMNS
-    assert trials_df['TrainingJobName'].tolist() == TRIALS_DF_TRAINING_JOB_NAMES
-    assert trials_df['TrainingJobStatus'].tolist() == TRIALS_DF_TRAINING_JOB_STATUSES
-    assert trials_df['TuningJobName'].tolist() == TUNING_JOB_NAMES
-    assert trials_df['valid-f1'].tolist() == TRIALS_DF_VALID_F1_VALUES
+    assert trials_df["TrainingJobName"].tolist() == TRIALS_DF_TRAINING_JOB_NAMES
+    assert trials_df["TrainingJobStatus"].tolist() == TRIALS_DF_TRAINING_JOB_STATUSES
+    assert trials_df["TuningJobName"].tolist() == TUNING_JOB_NAMES
+    assert trials_df["valid-f1"].tolist() == TRIALS_DF_VALID_F1_VALUES

    # Assertions for full_df
    assert isinstance(full_df, pd.DataFrame)
@@ -244,7 +234,7 @@ def test_visualize_tuning_job_empty_trials(mock_get_job_analytics_data):
        pd.DataFrame(),  # empty dataframe
        TUNED_PARAMETERS,
        OBJECTIVE_NAME,
-        True
+        True,
    )
    charts = visualize_tuning_job("empty_job")
    assert charts.empty
@@ -267,9 +257,7 @@ def test_visualize_tuning_job_trials_only(mock_get_job_analytics_data):
# Check if all parameters are correctly passed to the (mocked) create_charts method
@patch("sagemaker.amtviz.visualization.create_charts")
def test_visualize_tuning_job_with_full_df(
-    mock_create_charts,
-    mock_get_job_analytics_data,
-    mock_prepare_consolidated_df
+    mock_create_charts, mock_get_job_analytics_data, mock_prepare_consolidated_df
):
    mock_create_charts.return_value = alt.Chart()
    visualize_tuning_job("dummy_job")
@@ -298,10 +286,15 @@ def test_visualize_tuning_job_with_full_df(
@patch("sagemaker.HyperparameterTuningJobAnalytics")
def test_get_job_analytics_data(mock_hyperparameter_tuning_job_analytics):
    # Mock sagemaker's describe_hyper_parameter_tuning_job and some internal methods
-    sagemaker.amtviz.visualization.sm.describe_hyper_parameter_tuning_job = Mock(return_value=TUNING_JOB_RESULT)
+    sagemaker.amtviz.visualization.sm.describe_hyper_parameter_tuning_job = Mock(
+        return_value=TUNING_JOB_RESULT
+    )
    sagemaker.amtviz.visualization._get_tuning_job_names_with_parents = Mock(
-        return_value=[TUNING_JOB_NAME_1, TUNING_JOB_NAME_2])
-    sagemaker.amtviz.visualization._get_df = Mock(return_value=pd.DataFrame(FILTERED_TUNING_JOB_DF_DATA))
+        return_value=[TUNING_JOB_NAME_1, TUNING_JOB_NAME_2]
+    )
+    sagemaker.amtviz.visualization._get_df = Mock(
+        return_value=pd.DataFrame(FILTERED_TUNING_JOB_DF_DATA)
+    )
    mock_tuning_job_instance = MagicMock()
    mock_hyperparameter_tuning_job_analytics.return_value = mock_tuning_job_instance
    mock_tuning_job_instance.tuning_ranges.values.return_value = TUNING_RANGES
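
For orientation, a minimal sketch of the API surface these tests exercise. The job names are hypothetical, and only keyword arguments asserted in the tests above are shown:

    import sagemaker
    from sagemaker.tuner import HyperparameterTuner
    from sagemaker.amtviz.visualization import visualize_tuning_job

    # Attach to an already completed tuning job (hypothetical job name).
    tuner = HyperparameterTuner.attach("my-tuning-job")

    # Single-job chart, as exercised by test_visualize_job.
    chart = tuner.visualize_job(advanced=True)

    # Several jobs by name, as exercised by test_visualize_jobs.
    chart = HyperparameterTuner.visualize_jobs(["job-a", "job-b"])

    # Module-level entry point called directly by the visualization tests.
    chart = visualize_tuning_job("my-tuning-job", trials_only=True)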