Skip to content

Commit facaf27

Browse files
author
Uemit Yoldas
committed
fix: code reformat using black
1 parent 2289f64 commit facaf27

File tree

6 files changed

+178
-162
lines changed

6 files changed

+178
-162
lines changed

src/sagemaker/amtviz/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,5 @@
2323
from __future__ import absolute_import
2424

2525
from sagemaker.amtviz.visualization import visualize_tuning_job
26-
__all__ = ['visualize_tuning_job']
26+
27+
__all__ = ["visualize_tuning_job"]

src/sagemaker/amtviz/job_metrics.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,7 @@ def _metric_data_query_tpl(metric_name: str, dim_name: str, dim_value: str) -> D
100100

101101

102102
def _get_metric_data(
103-
queries: List[Dict[str, Any]],
104-
start_time: datetime,
105-
end_time: datetime
103+
queries: List[Dict[str, Any]], start_time: datetime, end_time: datetime
106104
) -> pd.DataFrame:
107105
"""Fetches CloudWatch metrics between timestamps, returns a DataFrame with selected columns."""
108106
start_time = start_time - timedelta(hours=1)
@@ -131,9 +129,7 @@ def _get_metric_data(
131129

132130
@disk_cache
133131
def _collect_metrics(
134-
dimensions: List[Tuple[str, str]],
135-
start_time: datetime,
136-
end_time: Optional[datetime]
132+
dimensions: List[Tuple[str, str]], start_time: datetime, end_time: Optional[datetime]
137133
) -> pd.DataFrame:
138134
"""Collects SageMaker training job metrics from CloudWatch for dimensions and time range."""
139135
df = pd.DataFrame()
@@ -159,9 +155,7 @@ def _collect_metrics(
159155

160156

161157
def get_cw_job_metrics(
162-
job_name: str,
163-
start_time: Optional[datetime] = None,
164-
end_time: Optional[datetime] = None
158+
job_name: str, start_time: Optional[datetime] = None, end_time: Optional[datetime] = None
165159
) -> pd.DataFrame:
166160
"""Retrieves CloudWatch metrics for a SageMaker training job.
167161

src/sagemaker/amtviz/visualization.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ def visualize_tuning_job(
100100

101101
try:
102102
from IPython import get_ipython, display
103+
103104
if get_ipython():
104105
# Running in a Jupyter Notebook
105106
display(trials_df.head(10))
@@ -110,9 +111,7 @@ def visualize_tuning_job(
110111
# Not running in a Jupyter Notebook
111112
logger.info(trials_df.head(10).to_string())
112113

113-
full_df = (
114-
_prepare_consolidated_df(trials_df) if not trials_only else pd.DataFrame()
115-
)
114+
full_df = _prepare_consolidated_df(trials_df) if not trials_only else pd.DataFrame()
116115

117116
trials_df.columns = trials_df.columns.map(_clean_parameter_name)
118117
full_df.columns = full_df.columns.map(_clean_parameter_name)
@@ -216,9 +215,11 @@ def create_charts(
216215
jobs_props["stroke"] = alt.condition(
217216
job_highlight_selection,
218217
alt.StrokeValue("gold"),
219-
alt.Stroke("TuningJobName:N", legend=None)
220-
if multiple_tuning_jobs
221-
else alt.StrokeValue("white"),
218+
(
219+
alt.Stroke("TuningJobName:N", legend=None)
220+
if multiple_tuning_jobs
221+
else alt.StrokeValue("white")
222+
),
222223
)
223224

224225
opacity = alt.condition(brush, alt.value(1.0), alt.value(0.35))
@@ -759,9 +760,11 @@ def get_job_analytics_data(tuning_job_names):
759760

760761
# Ensure to create a list of tuning job names (strings)
761762
tuning_job_names = [
762-
tuning_job.describe()["HyperParameterTuningJobName"]
763-
if isinstance(tuning_job, sagemaker.tuner.HyperparameterTuner)
764-
else tuning_job
763+
(
764+
tuning_job.describe()["HyperParameterTuningJobName"]
765+
if isinstance(tuning_job, sagemaker.tuner.HyperparameterTuner)
766+
else tuning_job
767+
)
765768
for tuning_job in tuning_job_names
766769
]
767770

src/sagemaker/tuner.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2120,13 +2120,14 @@ def _add_estimator(
21202120
@staticmethod
21212121
def visualize_jobs(
21222122
tuning_jobs: Union[
2123-
str, 'sagemaker.tuner.HyperparameterTuner',
2124-
List[Union[str, 'sagemaker.tuner.HyperparameterTuner']]
2123+
str,
2124+
"sagemaker.tuner.HyperparameterTuner",
2125+
List[Union[str, "sagemaker.tuner.HyperparameterTuner"]],
21252126
],
21262127
return_dfs: bool = False,
21272128
job_metrics: Optional[List[str]] = None,
21282129
trials_only: bool = False,
2129-
advanced: bool = False
2130+
advanced: bool = False,
21302131
):
21312132
"""Create interactive visualization via altair charts using the sagemaker.amtviz package.
21322133
@@ -2144,7 +2145,7 @@ def visualize_jobs(
21442145
"""
21452146
try:
21462147
# Check if altair is installed
2147-
importlib.import_module('altair')
2148+
importlib.import_module("altair")
21482149

21492150
except ImportError:
21502151
print("Altair is not installed. Install Altair to use the visualization feature:")
@@ -2164,10 +2165,11 @@ def visualize_jobs(
21642165
)
21652166

21662167
def visualize_job(
2167-
self, return_dfs: bool = False,
2168+
self,
2169+
return_dfs: bool = False,
21682170
job_metrics: Optional[List[str]] = None,
21692171
trials_only: bool = False,
2170-
advanced: bool = False
2172+
advanced: bool = False,
21712173
):
21722174
"""Convenience method on instance level for visualize_jobs().
21732175

tests/unit/test_tuner_visualize.py

Lines changed: 41 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,13 @@
1919
import sagemaker
2020
from sagemaker.estimator import Estimator
2121
from sagemaker.session_settings import SessionSettings
22-
from sagemaker.tuner import (
23-
HyperparameterTuner
24-
)
22+
from sagemaker.tuner import HyperparameterTuner
2523
from tests.unit.tuner_test_utils import (
2624
OBJECTIVE_METRIC_NAME,
2725
HYPERPARAMETER_RANGES,
28-
METRIC_DEFINITIONS
26+
METRIC_DEFINITIONS,
2927
)
28+
3029
# Visualization specific imports
3130
from sagemaker.amtviz.visualization import visualize_tuning_job, get_job_analytics_data
3231
from tests.unit.tuner_visualize_test_utils import (
@@ -44,7 +43,7 @@
4443
TRIALS_DF_TRAINING_JOB_STATUSES,
4544
TRIALS_DF_VALID_F1_VALUES,
4645
FILTERED_TUNING_JOB_DF_DATA,
47-
TUNING_RANGES
46+
TUNING_RANGES,
4847
)
4948
import altair as alt
5049

@@ -56,7 +55,7 @@ def create_sagemaker_session():
5655
boto_session=boto_mock,
5756
config=None,
5857
local_mode=False,
59-
settings=SessionSettings()
58+
settings=SessionSettings(),
6059
)
6160
sms.sagemaker_config = {}
6261
return sms
@@ -103,12 +102,7 @@ def mock_visualize_tuning_job():
103102
@pytest.fixture
104103
def mock_get_job_analytics_data():
105104
with patch("sagemaker.amtviz.visualization.get_job_analytics_data") as mock:
106-
mock.return_value = (
107-
pd.DataFrame(TRIALS_DF_DATA),
108-
TUNED_PARAMETERS,
109-
OBJECTIVE_NAME,
110-
True
111-
)
105+
mock.return_value = (pd.DataFrame(TRIALS_DF_DATA), TUNED_PARAMETERS, OBJECTIVE_NAME, True)
112106
yield mock
113107

114108

@@ -144,21 +138,22 @@ def test_visualize_jobs(mock_visualize_tuning_job):
144138
result = HyperparameterTuner.visualize_jobs(TUNING_JOB_NAMES)
145139
assert result == "mock_chart"
146140
mock_visualize_tuning_job.assert_called_once_with(
147-
TUNING_JOB_NAMES,
148-
return_dfs=False,
149-
job_metrics=None,
150-
trials_only=False,
151-
advanced=False
141+
TUNING_JOB_NAMES, return_dfs=False, job_metrics=None, trials_only=False, advanced=False
152142
)
153143
# Vary the parameters and check if they have been passed correctly
154144
result = HyperparameterTuner.visualize_jobs(
155-
[TUNING_JOB_NAME_1], return_dfs=True, job_metrics="job_metrics", trials_only=True, advanced=True)
145+
[TUNING_JOB_NAME_1],
146+
return_dfs=True,
147+
job_metrics="job_metrics",
148+
trials_only=True,
149+
advanced=True,
150+
)
156151
mock_visualize_tuning_job.assert_called_with(
157152
[TUNING_JOB_NAME_1],
158153
return_dfs=True,
159154
job_metrics="job_metrics",
160155
trials_only=True,
161-
advanced=True
156+
advanced=True,
162157
)
163158

164159

@@ -168,21 +163,15 @@ def test_visualize_job(tuner, mock_visualize_tuning_job):
168163
result = tuner.visualize_job()
169164
assert result == "mock_chart"
170165
mock_visualize_tuning_job.assert_called_once_with(
171-
tuner,
172-
return_dfs=False,
173-
job_metrics=None,
174-
trials_only=False,
175-
advanced=False
166+
tuner, return_dfs=False, job_metrics=None, trials_only=False, advanced=False
176167
)
177168
# With varying parameters
178-
result = tuner.visualize_job(return_dfs=True, job_metrics="job_metrics", trials_only=True, advanced=True)
169+
result = tuner.visualize_job(
170+
return_dfs=True, job_metrics="job_metrics", trials_only=True, advanced=True
171+
)
179172
assert result == "mock_chart"
180173
mock_visualize_tuning_job.assert_called_with(
181-
tuner,
182-
return_dfs=True,
183-
job_metrics="job_metrics",
184-
trials_only=True,
185-
advanced=True
174+
tuner, return_dfs=True, job_metrics="job_metrics", trials_only=True, advanced=True
186175
)
187176

188177

@@ -191,21 +180,22 @@ def test_visualize_multiple_jobs(tuner, tuner2, mock_visualize_tuning_job):
191180
result = HyperparameterTuner.visualize_jobs([tuner, tuner2])
192181
assert result == "mock_chart"
193182
mock_visualize_tuning_job.assert_called_once_with(
194-
[tuner, tuner2],
195-
return_dfs=False,
196-
job_metrics=None,
197-
trials_only=False,
198-
advanced=False
183+
[tuner, tuner2], return_dfs=False, job_metrics=None, trials_only=False, advanced=False
199184
)
200185
# Vary the parameters and check if they have been passed correctly
201186
result = HyperparameterTuner.visualize_jobs(
202-
[[tuner, tuner2]], return_dfs=True, job_metrics="job_metrics", trials_only=True, advanced=True)
187+
[[tuner, tuner2]],
188+
return_dfs=True,
189+
job_metrics="job_metrics",
190+
trials_only=True,
191+
advanced=True,
192+
)
203193
mock_visualize_tuning_job.assert_called_with(
204194
[[tuner, tuner2]],
205195
return_dfs=True,
206196
job_metrics="job_metrics",
207197
trials_only=True,
208-
advanced=True
198+
advanced=True,
209199
)
210200

211201

@@ -226,10 +216,10 @@ def test_visualize_tuning_job_return_dfs(mock_get_job_analytics_data, mock_prepa
226216
assert isinstance(trials_df, pd.DataFrame)
227217
assert trials_df.shape == (2, len(TRIALS_DF_COLUMNS))
228218
assert trials_df.columns.tolist() == TRIALS_DF_COLUMNS
229-
assert trials_df['TrainingJobName'].tolist() == TRIALS_DF_TRAINING_JOB_NAMES
230-
assert trials_df['TrainingJobStatus'].tolist() == TRIALS_DF_TRAINING_JOB_STATUSES
231-
assert trials_df['TuningJobName'].tolist() == TUNING_JOB_NAMES
232-
assert trials_df['valid-f1'].tolist() == TRIALS_DF_VALID_F1_VALUES
219+
assert trials_df["TrainingJobName"].tolist() == TRIALS_DF_TRAINING_JOB_NAMES
220+
assert trials_df["TrainingJobStatus"].tolist() == TRIALS_DF_TRAINING_JOB_STATUSES
221+
assert trials_df["TuningJobName"].tolist() == TUNING_JOB_NAMES
222+
assert trials_df["valid-f1"].tolist() == TRIALS_DF_VALID_F1_VALUES
233223

234224
# Assertions for full_df
235225
assert isinstance(full_df, pd.DataFrame)
@@ -244,7 +234,7 @@ def test_visualize_tuning_job_empty_trials(mock_get_job_analytics_data):
244234
pd.DataFrame(), # empty dataframe
245235
TUNED_PARAMETERS,
246236
OBJECTIVE_NAME,
247-
True
237+
True,
248238
)
249239
charts = visualize_tuning_job("empty_job")
250240
assert charts.empty
@@ -267,9 +257,7 @@ def test_visualize_tuning_job_trials_only(mock_get_job_analytics_data):
267257
# Check if all parameters are correctly passed to the (mocked) create_charts method
268258
@patch("sagemaker.amtviz.visualization.create_charts")
269259
def test_visualize_tuning_job_with_full_df(
270-
mock_create_charts,
271-
mock_get_job_analytics_data,
272-
mock_prepare_consolidated_df
260+
mock_create_charts, mock_get_job_analytics_data, mock_prepare_consolidated_df
273261
):
274262
mock_create_charts.return_value = alt.Chart()
275263
visualize_tuning_job("dummy_job")
@@ -298,10 +286,15 @@ def test_visualize_tuning_job_with_full_df(
298286
@patch("sagemaker.HyperparameterTuningJobAnalytics")
299287
def test_get_job_analytics_data(mock_hyperparameter_tuning_job_analytics):
300288
# Mock sagemaker's describe_hyper_parameter_tuning_job and some internal methods
301-
sagemaker.amtviz.visualization.sm.describe_hyper_parameter_tuning_job = Mock(return_value=TUNING_JOB_RESULT)
289+
sagemaker.amtviz.visualization.sm.describe_hyper_parameter_tuning_job = Mock(
290+
return_value=TUNING_JOB_RESULT
291+
)
302292
sagemaker.amtviz.visualization._get_tuning_job_names_with_parents = Mock(
303-
return_value=[TUNING_JOB_NAME_1, TUNING_JOB_NAME_2])
304-
sagemaker.amtviz.visualization._get_df = Mock(return_value=pd.DataFrame(FILTERED_TUNING_JOB_DF_DATA))
293+
return_value=[TUNING_JOB_NAME_1, TUNING_JOB_NAME_2]
294+
)
295+
sagemaker.amtviz.visualization._get_df = Mock(
296+
return_value=pd.DataFrame(FILTERED_TUNING_JOB_DF_DATA)
297+
)
305298
mock_tuning_job_instance = MagicMock()
306299
mock_hyperparameter_tuning_job_analytics.return_value = mock_tuning_job_instance
307300
mock_tuning_job_instance.tuning_ranges.values.return_value = TUNING_RANGES

0 commit comments

Comments
 (0)