Skip to content

Commit 757c679

Browse files
yangustc07 authored and copybara-github committed
Add metrics for grain source per-example read.
PiperOrigin-RevId: 836322135
1 parent bff9985 commit 757c679

File tree

1 file changed

+44
-1
lines changed
  • grain/_src/python/dataset/transformations

1 file changed

+44
-1
lines changed

grain/_src/python/dataset/transformations/source.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,53 @@
1515

1616
from __future__ import annotations
1717

18+
import threading
19+
import time
1820
from typing import Any, Sequence, Union
1921

2022
from absl import logging
23+
from grain._src.core import monitoring as grain_monitoring
2124
from grain._src.core import sharding
2225
from grain._src.python import options
2326
from grain._src.python.dataset import base
2427
from grain._src.python.dataset import dataset
2528
from grain._src.python.dataset import stats as dataset_stats
2629
import numpy as np
2730

31+
from grain._src.core import monitoring
32+
33+
34+
# Distribution of source read latency in nanoseconds. Each sample is labelled
# through the single "source" string field, and bucket boundaries grow by
# powers of two.
_source_read_time_ns_histogram = monitoring.EventMetric(
    "/grain/python/dataset/source_read_time_ns",
    metadata=monitoring.Metadata(
        description="Histogram of source read time in nanoseconds.",
        units=monitoring.Units.NANOSECONDS,
    ),
    root=grain_monitoring.get_monitoring_root(),
    fields=[("source", str)],
    bucketer=monitoring.Bucketer.PowersOf(2.0),
)

# Guards writes to the histogram above. `_maybe_record_source_read_time` takes
# it without blocking, so a contended sample is dropped instead of waited for.
_metric_lock = threading.Lock()
46+
47+
48+
def _maybe_record_source_read_time(
    elapsed_time_ns: int, source_name: str
) -> None:
  """Records the source read time in nanoseconds if metric lock is available.

  To avoid contention and potential slowness, we only record the time if the
  lock is immediately available; otherwise the sample is dropped.

  Args:
    elapsed_time_ns: The elapsed time in nanoseconds.
    source_name: The name of the source, used as the metric's "source" field.
  """
  # Non-blocking acquire: never stall the read path just to emit a metric.
  if not _metric_lock.acquire(blocking=False):
    return
  try:
    _source_read_time_ns_histogram.Record(elapsed_time_ns, source_name)
  finally:
    # Always release, even if `Record` raises. A leaked lock would make every
    # later non-blocking acquire fail and silently disable the metric.
    _metric_lock.release()
64+
2865

2966
class SourceMapDataset(dataset.MapDataset):
3067
"""Simple wrapper for random access data sources."""
@@ -45,7 +82,13 @@ def __getitem__(self, index):
4582
if isinstance(index, slice):
4683
return self.slice(index)
4784
with self._stats.record_self_time():
48-
return self._stats.record_output_spec(self._source[index % len(self)])
85+
start_time = time.perf_counter_ns()
86+
result = self._stats.record_output_spec(self._source[index % len(self)])
87+
stop_time = time.perf_counter_ns()
88+
_maybe_record_source_read_time(
89+
stop_time - start_time, self._source.__class__.__name__
90+
)
91+
return result
4992

5093
def _getitems(self, indices: Sequence[int]):
5194
if not isinstance(

0 commit comments

Comments
 (0)