1515
1616from __future__ import annotations
1717
18+ import threading
19+ import time
1820from typing import Any , Sequence , Union
1921
2022from absl import logging
23+ from grain ._src .core import monitoring as grain_monitoring
2124from grain ._src .core import sharding
2225from grain ._src .python import options
2326from grain ._src .python .dataset import base
2427from grain ._src .python .dataset import dataset
2528from grain ._src .python .dataset import stats as dataset_stats
2629import numpy as np
2730
31+ from grain ._src .core import monitoring
32+
33+
# Histogram metric of source read time in nanoseconds. Each sample carries a
# single string field, "source", and samples are bucketed by powers of 2.
_source_read_time_ns_histogram = monitoring.EventMetric(
    "/grain/python/dataset/source_read_time_ns",
    metadata=monitoring.Metadata(
        description="Histogram of source read time in nanoseconds.",
        units=monitoring.Units.NANOSECONDS,
    ),
    root=grain_monitoring.get_monitoring_root(),
    fields=[("source", str)],
    bucketer=monitoring.Bucketer.PowersOf(2.0),
)

# Guards calls to `_source_read_time_ns_histogram.Record`. It is acquired
# non-blocking in `_maybe_record_source_read_time`, so a busy lock means the
# sample is dropped rather than waited for.
_metric_lock = threading.Lock()
46+
47+
def _maybe_record_source_read_time(
    elapsed_time_ns: int, source_name: str
) -> None:
  """Records the source read time in nanoseconds if metric lock is available.

  To avoid contention and potential slowness, we only record the time if the
  lock is immediately available. If another thread currently holds the lock,
  the sample is dropped and this function returns without recording.

  Args:
    elapsed_time_ns: The elapsed time in nanoseconds.
    source_name: The name of the source.
  """
  # Best-effort recording: never wait on the lock.
  if not _metric_lock.acquire(blocking=False):
    return
  try:
    _source_read_time_ns_histogram.Record(elapsed_time_ns, source_name)
  finally:
    # Release even if `Record` raises; otherwise the lock would stay held
    # forever and every later sample would be silently dropped.
    _metric_lock.release()
64+
2865
2966class SourceMapDataset (dataset .MapDataset ):
3067 """Simple wrapper for random access data sources."""
@@ -45,7 +82,13 @@ def __getitem__(self, index):
4582 if isinstance (index , slice ):
4683 return self .slice (index )
4784 with self ._stats .record_self_time ():
48- return self ._stats .record_output_spec (self ._source [index % len (self )])
85+ start_time = time .perf_counter_ns ()
86+ result = self ._stats .record_output_spec (self ._source [index % len (self )])
87+ stop_time = time .perf_counter_ns ()
88+ _maybe_record_source_read_time (
89+ stop_time - start_time , self ._source .__class__ .__name__
90+ )
91+ return result
4992
5093 def _getitems (self , indices : Sequence [int ]):
5194 if not isinstance (
0 commit comments