 # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
-# SPDX-License-Identifier: MIT-0
-
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so.
-
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-# SOFTWARE.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Helper functions to retrieve job metrics from CloudWatch."""
+from __future__ import absolute_import

 from datetime import datetime, timedelta
 from typing import Callable, List, Optional, Tuple, Dict, Any
 import hashlib
 import os
 from pathlib import Path

+import logging
 import pandas as pd
 import numpy as np
 import boto3
-import logging

 logger = logging.getLogger(__name__)

@@ -58,16 +57,16 @@ def inner(*args: Any, **kwargs: Any) -> pd.DataFrame:
             logger.debug("H", end="")
             df["ts"] = pd.to_datetime(df["ts"])
             df["ts"] = df["ts"].dt.tz_localize(None)
-            df["rel_ts"] = pd.to_datetime(df["rel_ts"])  # pyright: ignore [reportIndexIssue, reportOptionalSubscript]
+            # pyright: ignore [reportIndexIssue, reportOptionalSubscript]
+            df["rel_ts"] = pd.to_datetime(df["rel_ts"])
             df["rel_ts"] = df["rel_ts"].dt.tz_localize(None)
             return df
         except KeyError:
             # Empty file leads to empty df, hence no df['ts'] possible
             pass
         # nosec b110 - doesn't matter why we could not load it.
         except BaseException as e:
-            logger.error("\nException", type(e), e)
-            pass  # continue with calling the outer function
+            logger.error("\nException: %s - %s", type(e), e)

         logger.debug("M", end="")
         df = outer(*args, **kwargs)
@@ -82,6 +81,7 @@ def inner(*args: Any, **kwargs: Any) -> pd.DataFrame:


 def _metric_data_query_tpl(metric_name: str, dim_name: str, dim_value: str) -> Dict[str, Any]:
+    """Returns a CloudWatch metric data query template."""
     return {
         "Id": metric_name.lower().replace(":", "_").replace("-", "_"),
         "MetricStat": {
@@ -100,18 +100,19 @@ def _metric_data_query_tpl(metric_name: str, dim_name: str, dim_value: str) -> D


 def _get_metric_data(
-    queries: List[Dict[str, Any]], 
-    start_time: datetime, 
+    queries: List[Dict[str, Any]],
+    start_time: datetime,
     end_time: datetime
 ) -> pd.DataFrame:
+    """Fetches CloudWatch metrics between timestamps and returns a DataFrame with selected columns."""
     start_time = start_time - timedelta(hours=1)
     end_time = end_time + timedelta(hours=1)
     response = cw.get_metric_data(MetricDataQueries=queries, StartTime=start_time, EndTime=end_time)

     df = pd.DataFrame()
     if "MetricDataResults" not in response:
         return df
- 
+
     for metric_data in response["MetricDataResults"]:
         values = metric_data["Values"]
         ts = np.array(metric_data["Timestamps"], dtype=np.datetime64)
@@ -130,11 +131,11 @@ def _get_metric_data(

 @disk_cache
 def _collect_metrics(
-    dimensions: List[Tuple[str, str]], 
-    start_time: datetime, 
+    dimensions: List[Tuple[str, str]],
+    start_time: datetime,
     end_time: Optional[datetime]
 ) -> pd.DataFrame:
-
+    """Collects SageMaker training job metrics from CloudWatch based on given dimensions and time range."""
     df = pd.DataFrame()
     for dim_name, dim_value in dimensions:
         response = cw.list_metrics(
@@ -158,8 +159,8 @@ def _collect_metrics(


 def get_cw_job_metrics(
-    job_name: str, 
-    start_time: Optional[datetime] = None, 
+    job_name: str,
+    start_time: Optional[datetime] = None,
     end_time: Optional[datetime] = None
 ) -> pd.DataFrame:
     """Retrieves CloudWatch metrics for a SageMaker training job.
@@ -182,4 +183,4 @@
     # If not given, use reasonable defaults for start and end time
     start_time = start_time or datetime.now() - timedelta(hours=4)
     end_time = end_time or start_time + timedelta(hours=4)
-    return _collect_metrics(dimensions, start_time, end_time)
+    return _collect_metrics(dimensions, start_time, end_time)
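
For reference, here is a minimal usage sketch of the public helper after this change. It is an illustration only, not part of the commit: the job name is a placeholder, the import path depends on where this module lives, and AWS credentials with CloudWatch read access plus an existing SageMaker training job are assumed.

from datetime import datetime, timedelta
# from job_metrics import get_cw_job_metrics  # hypothetical import path; adjust to this module's location

# Placeholder job name - substitute a real SageMaker training job from your account.
job_name = "example-training-job"

# Fetch the job's CloudWatch metrics. Both time arguments are optional and
# default to a four-hour window starting four hours before "now".
df = get_cw_job_metrics(
    job_name,
    start_time=datetime.now() - timedelta(hours=2),
    end_time=datetime.now(),
)
print(df.head())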