@@ -144,6 +144,11 @@ def log_hyperparams(self, params: Dict[str, Any]) -> None:
144144 self .run .config .update (params )
145145
146146
class GpuMetricSnapshot(TypedDict):
    """One buffered set of GPU metrics awaiting a flush to the parent logger."""

    # Step value handed to the parent logger's ``log_metrics`` for this entry.
    step: int
    # Metric name -> value mapping logged for that step.
    metrics: Dict[str, Any]
150+
151+
147152class RayGpuMonitorLogger :
148153 """Monitor GPU utilization across a Ray cluster and log metrics to a parent logger."""
149154
@@ -163,7 +168,9 @@ def __init__(
163168 self .collection_interval = collection_interval
164169 self .flush_interval = flush_interval
165170 self .parent_logger = parent_logger
166- self .metrics_buffer = [] # Store metrics with timestamps
171+ self .metrics_buffer : list [
172+ GpuMetricSnapshot
173+ ] = [] # Store metrics with timestamps
167174 self .last_flush_time = time .time ()
168175 self .is_running = False
169176 self .collection_thread = None
@@ -228,7 +235,9 @@ def _collection_loop(self):
228235
229236 time .sleep (self .collection_interval )
230237 except Exception as e :
231- print (f"Error in GPU monitoring collection loop: { e } " )
238+ print (
239+ f"Error in GPU monitoring collection loop or stopped abruptly: { e } "
240+ )
232241 time .sleep (self .collection_interval ) # Continue despite errors
233242
234243 def _parse_gpu_metric (self , sample : Sample , node_idx : int ) -> Dict [str , Any ]:
@@ -241,7 +250,6 @@ def _parse_gpu_metric(self, sample: Sample, node_idx: int) -> Dict[str, Any]:
241250 Returns:
242251 Dictionary with metric name and value
243252 """
244- # TODO: Consider plumbing {'GpuDeviceName': 'NVIDIA H100 80GB HBM3'}
245253 # Expected labels for GPU metrics
246254 expected_labels = ["GpuIndex" ]
247255 for label in expected_labels :
@@ -266,12 +274,72 @@ def _parse_gpu_metric(self, sample: Sample, node_idx: int) -> Dict[str, Any]:
266274 metric_name = f"node.{ node_idx } .gpu.{ index } .{ metric_name } "
267275 return {metric_name : value }
268276
277+ def _parse_gpu_sku (self , sample : Sample , node_idx : int ) -> Dict [str , str ]:
278+ """Parse a GPU metric sample into a standardized format.
279+
280+ Args:
281+ sample: Prometheus metric sample
282+ node_idx: Index of the node
283+
284+ Returns:
285+ Dictionary with metric name and value
286+ """
287+ # TODO: Consider plumbing {'GpuDeviceName': 'NVIDIA H100 80GB HBM3'}
288+ # Expected labels for GPU metrics
289+ expected_labels = ["GpuIndex" , "GpuDeviceName" ]
290+ for label in expected_labels :
291+ if label not in sample .labels :
292+ # This is probably a CPU node
293+ return {}
294+
295+ metric_name = sample .name
296+ # Only return SKU if the metric is one of these which publish these metrics
297+ if (
298+ metric_name != "ray_node_gpus_utilization"
299+ and metric_name != "ray_node_gram_used"
300+ ):
301+ # Skip unexpected metrics
302+ return {}
303+
304+ labels = sample .labels
305+ index = labels ["GpuIndex" ]
306+ value = labels ["GpuDeviceName" ]
307+
308+ metric_name = f"node.{ node_idx } .gpu.{ index } .type"
309+ return {metric_name : value }
310+
311+ def _collect_gpu_sku (self ) -> Dict [str , str ]:
312+ """Collect GPU SKU from all Ray nodes.
313+
314+ Note: This is an internal API and users are not expected to call this.
315+
316+ Returns:
317+ Dictionary of SKU types on all Ray nodes
318+ """
319+ # TODO: We can re-use the same path for metrics because even though both utilization and memory metrics duplicate
320+ # the GPU metadata information; since the metadata is the same for each node, we can overwrite it and expect them to
321+ # be the same
322+ return self ._collect (sku = True )
323+
269324 def _collect_metrics (self ) -> Dict [str , Any ]:
270325 """Collect GPU metrics from all Ray nodes.
271326
272327 Returns:
273328 Dictionary of collected metrics
274329 """
330+ return self ._collect (metrics = True )
331+
332+ def _collect (self , metrics : bool = False , sku : bool = False ) -> Dict [str , Any ]:
333+ """Collect GPU metrics from all Ray nodes.
334+
335+ Returns:
336+ Dictionary of collected metrics
337+ """
338+ assert metrics ^ sku , (
339+ f"Must collect either metrics or sku, not both: { metrics = } , { sku = } "
340+ )
341+ parser_fn = self ._parse_gpu_metric if metrics else self ._parse_gpu_sku
342+
275343 if not ray .is_initialized ():
276344 print ("Ray is not initialized. Cannot collect GPU metrics." )
277345 return {}
@@ -295,7 +363,9 @@ def _collect_metrics(self) -> Dict[str, Any]:
295363 # Process each node's metrics
296364 collected_metrics = {}
297365 for node_idx , metric_address in enumerate (unique_metric_addresses ):
298- gpu_metrics = self ._fetch_and_parse_metrics (node_idx , metric_address )
366+ gpu_metrics = self ._fetch_and_parse_metrics (
367+ node_idx , metric_address , parser_fn
368+ )
299369 collected_metrics .update (gpu_metrics )
300370
301371 return collected_metrics
@@ -304,7 +374,7 @@ def _collect_metrics(self) -> Dict[str, Any]:
304374 print (f"Error collecting GPU metrics: { e } " )
305375 return {}
306376
307- def _fetch_and_parse_metrics (self , node_idx , metric_address ):
377+ def _fetch_and_parse_metrics (self , node_idx , metric_address , parser_fn ):
308378 """Fetch metrics from a node and parse GPU metrics.
309379
310380 Args:
@@ -335,7 +405,7 @@ def _fetch_and_parse_metrics(self, node_idx, metric_address):
335405 continue
336406
337407 for sample in family .samples :
338- metrics = self . _parse_gpu_metric (sample , node_idx )
408+ metrics = parser_fn (sample , node_idx )
339409 gpu_metrics .update (metrics )
340410
341411 return gpu_metrics
@@ -346,18 +416,16 @@ def _fetch_and_parse_metrics(self, node_idx, metric_address):
346416
347417 def flush (self ):
348418 """Flush collected metrics to the parent logger."""
349- if not self .parent_logger :
350- return
351-
352419 with self .lock :
353420 if not self .metrics_buffer :
354421 return
355422
356- # Log each set of metrics with its original step
357- for entry in self .metrics_buffer :
358- step = entry ["step" ]
359- metrics = entry ["metrics" ]
360- self .parent_logger .log_metrics (metrics , step , prefix = "ray" )
423+ if self .parent_logger :
424+ # Log each set of metrics with its original step
425+ for entry in self .metrics_buffer :
426+ step = entry ["step" ]
427+ metrics = entry ["metrics" ]
428+ self .parent_logger .log_metrics (metrics , step , prefix = "ray" )
361429
362430 # Clear buffer after logging
363431 self .metrics_buffer = []
0 commit comments