1+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+ # SPDX-License-Identifier: MIT-0
3+
4+ # Permission is hereby granted, free of charge, to any person obtaining a copy
5+ # of this software and associated documentation files (the "Software"), to deal
6+ # in the Software without restriction, including without limitation the rights
7+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8+ # copies of the Software, and to permit persons to whom the Software is
9+ # furnished to do so.
10+
11+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
12+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
13+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
15+ # IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
16+
17+ from datetime import datetime , timedelta
18+ from typing import Callable , List , Optional , Tuple , Dict , Any
19+ import hashlib
20+ import os
21+ from pathlib import Path
22+
23+ import pandas as pd
24+ import numpy as np
25+ import boto3
26+ import logging
27+
28+ logger = logging .getLogger (__name__ )
29+
30+ cw = boto3 .client ("cloudwatch" )
31+ sm = boto3 .client ("sagemaker" )
32+
33+
def disk_cache(outer: Callable) -> Callable:
    """A decorator that implements disk-based caching for CloudWatch metrics data.

    This decorator caches the output of the wrapped function to disk in JSON Lines format.
    It creates a cache key using MD5 hash of the function arguments and stores the data
    in the user's home directory under .amtviz/cw_metrics_cache/.

    Args:
        outer (Callable): The function to be wrapped. Must return a pandas DataFrame
            containing CloudWatch metrics data.

    Returns:
        Callable: A wrapper function that implements the caching logic.
    """

    def inner(*args: Any, **kwargs: Any) -> pd.DataFrame:
        key_input = str(args) + str(kwargs)
        # nosec b303 - Not used for cryptography, but to create lookup key
        key = hashlib.md5(key_input.encode("utf-8")).hexdigest()
        cache_dir = Path.home().joinpath(".amtviz/cw_metrics_cache")
        fn = f"{cache_dir}/req_{key}.jsonl.gz"
        if Path(fn).exists():
            try:
                df = pd.read_json(fn, lines=True)
                # "H" marks a cache hit. Note: Logger.debug takes no print-style
                # "end" kwarg; passing one raises TypeError and would silently
                # defeat the cache via the except clause below.
                logger.debug("H")
                # Timestamps round-trip through JSON as strings; parse them back
                # and normalize to tz-naive, matching what the wrapped function
                # originally produced.
                df["ts"] = pd.to_datetime(df["ts"])
                df["ts"] = df["ts"].dt.tz_localize(None)
                df["rel_ts"] = pd.to_datetime(df["rel_ts"])  # pyright: ignore [reportIndexIssue, reportOptionalSubscript]
                df["rel_ts"] = df["rel_ts"].dt.tz_localize(None)
                return df
            except KeyError:
                # Empty file leads to empty df, hence no df['ts'] possible
                pass
            # nosec b110 - doesn't matter why we could not load it.
            except Exception as e:
                # Use lazy %-style args (logging convention) rather than passing
                # extra positional values that have no format placeholders.
                logger.error("Could not load cache file %s: %s: %s", fn, type(e).__name__, e)
                # continue with calling the outer function

        # "M" marks a cache miss.
        logger.debug("M")
        df = outer(*args, **kwargs)
        assert isinstance(df, pd.DataFrame), "Only caching Pandas DataFrames."

        os.makedirs(cache_dir, exist_ok=True)
        df.to_json(fn, orient="records", date_format="iso", lines=True)

        return df

    return inner
82+
83+
84+ def _metric_data_query_tpl (metric_name : str , dim_name : str , dim_value : str ) -> Dict [str , Any ]:
85+ return {
86+ "Id" : metric_name .lower ().replace (":" , "_" ).replace ("-" , "_" ),
87+ "MetricStat" : {
88+ "Stat" : "Average" ,
89+ "Metric" : {
90+ "Namespace" : "/aws/sagemaker/TrainingJobs" ,
91+ "MetricName" : metric_name ,
92+ "Dimensions" : [
93+ {"Name" : dim_name , "Value" : dim_value },
94+ ],
95+ },
96+ "Period" : 60 ,
97+ },
98+ "ReturnData" : True ,
99+ }
100+
101+
def _get_metric_data(
    queries: List[Dict[str, Any]],
    start_time: datetime,
    end_time: datetime
) -> pd.DataFrame:
    """Execute CloudWatch GetMetricData queries and return a tidy DataFrame.

    Args:
        queries (List[Dict[str, Any]]): MetricDataQuery entries (see
            _metric_data_query_tpl).
        start_time (datetime): Start of the window of interest.
        end_time (datetime): End of the window of interest.

    Returns:
        pd.DataFrame: Columns "value", "ts", "label" and, when non-empty,
            "rel_ts" (time relative to the first observed data point).
    """
    # Pad the window by an hour on both sides so data points recorded slightly
    # outside the job's reported start/end times are not clipped.
    start_time = start_time - timedelta(hours=1)
    end_time = end_time + timedelta(hours=1)
    # NOTE(review): a single call is made; results beyond one page (NextToken)
    # are not fetched here — confirm this is acceptable for the metric volume.
    response = cw.get_metric_data(MetricDataQueries=queries, StartTime=start_time, EndTime=end_time)

    if "MetricDataResults" not in response:
        return pd.DataFrame()

    # Collect one frame per metric and concatenate once, instead of growing
    # the DataFrame inside the loop (which copies all prior rows each time).
    frames = []
    for metric_data in response["MetricDataResults"]:
        values = metric_data["Values"]
        ts = np.array(metric_data["Timestamps"], dtype=np.datetime64)
        labels = [metric_data["Label"]] * len(values)
        frames.append(pd.DataFrame({"value": values, "ts": ts, "label": labels}))

    df = pd.concat(frames) if frames else pd.DataFrame()

    # We now calculate the relative time based on the first actual observed
    # time stamps, not the potentially start time that we used to scope our CW
    # API call. The difference could be for example startup times or waiting
    # for Spot.
    if not df.empty:
        df["rel_ts"] = datetime.fromtimestamp(1) + (df["ts"] - df["ts"].min())  # pyright: ignore
    return df
129+
130+
@disk_cache
def _collect_metrics(
    dimensions: List[Tuple[str, str]],
    start_time: datetime,
    end_time: Optional[datetime]
) -> pd.DataFrame:
    """Collect all CloudWatch metrics published under the given dimensions.

    For each (dimension name, dimension value) pair, lists the available
    metrics in the /aws/sagemaker/TrainingJobs namespace and fetches their
    data points via _get_metric_data. Results are cached to disk by the
    disk_cache decorator.

    Args:
        dimensions (List[Tuple[str, str]]): (name, value) dimension pairs.
        start_time (datetime): Start of the window of interest.
        end_time (Optional[datetime]): End of the window of interest.

    Returns:
        pd.DataFrame: Concatenated metric data for all dimensions; empty if
            no metrics were found.
    """
    frames = []
    for dim_name, dim_value in dimensions:
        response = cw.list_metrics(
            Namespace="/aws/sagemaker/TrainingJobs",
            Dimensions=[
                {"Name": dim_name, "Value": dim_value},
            ],
        )
        # A single emptiness check suffices: the comprehension is empty exactly
        # when response["Metrics"] is (the original checked both redundantly).
        metric_names = [metric["MetricName"] for metric in response["Metrics"]]
        if not metric_names:
            # No metric data yet, or not any longer, because the data were aged out
            continue
        metric_data_queries = [
            _metric_data_query_tpl(metric_name, dim_name, dim_value) for metric_name in metric_names
        ]
        frames.append(_get_metric_data(metric_data_queries, start_time, end_time))

    # Concatenate once at the end rather than growing a DataFrame per iteration.
    return pd.concat(frames) if frames else pd.DataFrame()
158+
159+
def get_cw_job_metrics(
    job_name: str,
    start_time: Optional[datetime] = None,
    end_time: Optional[datetime] = None
) -> pd.DataFrame:
    """Retrieves CloudWatch metrics for a SageMaker training job.

    Args:
        job_name (str): Name of the SageMaker training job.
        start_time (datetime, optional): Start time for metrics collection.
            Defaults to now - 4 hours.
        end_time (datetime, optional): End time for metrics collection.
            Defaults to start_time + 4 hours.

    Returns:
        pd.DataFrame: Metrics data with columns for value, timestamp, and metric name.
            Results are cached to disk for improved performance.
    """
    # Metrics are published under two dimensions: the job itself and its
    # first training host.
    dimensions = [
        ("TrainingJobName", job_name),
        ("Host", job_name + "/algo-1"),
    ]
    # If not given, use reasonable defaults for start and end time
    if start_time is None:
        start_time = datetime.now() - timedelta(hours=4)
    if end_time is None:
        end_time = start_time + timedelta(hours=4)
    return _collect_metrics(dimensions, start_time, end_time)