Skip to content

Commit f23dd64

Browse files
Merge pull request jax-ml#26853 from jeffcarp:scalar-event
PiperOrigin-RevId: 744807700
2 parents e1e37f8 + 123ce52 commit f23dd64

File tree

3 files changed

+71
-2
lines changed

3 files changed

+71
-2
lines changed

jax/_src/monitoring.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,18 @@ def __call__(
4646
) -> None:
4747
...
4848

49+
class ScalarListenerWithMetadata(Protocol):
50+
51+
def __call__(
52+
self, event: str, value: float | int, **kwargs: str | int,
53+
) -> None:
54+
...
55+
4956

5057
_event_listeners: list[EventListenerWithMetadata] = []
5158
_event_duration_secs_listeners: list[EventDurationListenerWithMetadata] = []
5259
_event_time_span_listeners: list[EventTimeSpanListenerWithMetadata] = []
60+
_scalar_listeners: list[ScalarListenerWithMetadata] = []
5361

5462

5563
def record_event(event: str, **kwargs: str | int) -> None:
@@ -81,6 +89,14 @@ def record_event_time_span(
8189
callback(event, start_time, end_time, **kwargs)
8290

8391

92+
def record_scalar(
93+
event: str, value: float | int, **kwargs: str | int
94+
) -> None:
95+
"""Record a scalar summary value."""
96+
for callback in _scalar_listeners:
97+
callback(event, value, **kwargs)
98+
99+
84100
def register_event_listener(
85101
callback: EventListenerWithMetadata,
86102
) -> None:
@@ -100,6 +116,14 @@ def register_event_duration_secs_listener(
100116
"""Register a callback to be invoked during record_event_duration_secs()."""
101117
_event_duration_secs_listeners.append(callback)
102118

119+
120+
def register_scalar_listener(
121+
callback : ScalarListenerWithMetadata,
122+
) -> None:
123+
"""Register a callback to be invoked during record_scalar()."""
124+
_scalar_listeners.append(callback)
125+
126+
103127
def get_event_duration_listeners() -> list[EventDurationListenerWithMetadata]:
104128
"""Get event duration listeners."""
105129
return list(_event_duration_secs_listeners)
@@ -114,12 +138,20 @@ def get_event_listeners() -> list[EventListenerWithMetadata]:
114138
"""Get event listeners."""
115139
return list(_event_listeners)
116140

141+
142+
def get_scalar_listeners() -> list[ScalarListenerWithMetadata]:
143+
"""Get scalar event listeners."""
144+
return list(_scalar_listeners)
145+
146+
117147
def clear_event_listeners():
118148
"""Clear event listeners."""
119149
global _event_listeners, _event_duration_secs_listeners, _event_time_span_listeners
120150
_event_listeners = []
121151
_event_duration_secs_listeners = []
122152
_event_time_span_listeners = []
153+
_scalar_listeners = []
154+
123155

124156
def _unregister_event_duration_listener_by_callback(
125157
callback: EventDurationListenerWithMetadata) -> None:
@@ -159,3 +191,14 @@ def _unregister_event_listener_by_callback(
159191
"""
160192
assert callback in _event_listeners
161193
_event_listeners.remove(callback)
194+
195+
196+
def _unregister_scalar_listener_by_callback(
197+
callback: ScalarListenerWithMetadata,
198+
) -> None:
199+
"""Unregister a scalar event listener by callback.
200+
201+
This function is supposed to be called for testing only.
202+
"""
203+
assert callback in _scalar_listeners
204+
_scalar_listeners.remove(callback)

jax/monitoring.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@
2626
record_event_duration_secs as record_event_duration_secs,
2727
record_event_time_span as record_event_time_span,
2828
record_event as record_event,
29+
record_scalar as record_scalar,
2930
register_event_duration_secs_listener as register_event_duration_secs_listener,
3031
register_event_listener as register_event_listener,
3132
register_event_time_span_listener as register_event_time_span_listener,
33+
register_scalar_listener as register_scalar_listener,
3234
)

tests/monitoring_test.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def tearDown(self):
2929

3030
def test_record_event(self):
3131
events = []
32-
counters = {} # Map event names to frequency.
32+
counters = {} # Map event names to frequency.
3333
def increment_event_counter(event):
3434
if event not in counters:
3535
counters[event] = 0
@@ -48,7 +48,7 @@ def increment_event_counter(event):
4848
"test_common_event": 2})
4949

5050
def test_record_event_durations(self):
51-
durations = {} # Map event names to frequency.
51+
durations = {} # Map event names to frequency.
5252
def increment_event_duration(event, duration):
5353
if event not in durations:
5454
durations[event] = 0.
@@ -62,6 +62,30 @@ def increment_event_duration(event, duration):
6262
self.assertDictEqual(durations, {"test_short_event": 3,
6363
"test_long_event": 10})
6464

65+
def test_record_scalar(self):
66+
observed_keys = []
67+
observed_values = []
68+
69+
monitoring.register_scalar_listener(
70+
lambda key, _: observed_keys.append(key),
71+
)
72+
monitoring.register_scalar_listener(
73+
lambda _, value: observed_values.append(value),
74+
)
75+
76+
monitoring.record_scalar("test_unique_event", 1)
77+
monitoring.record_scalar("test_common_event", 2.5)
78+
monitoring.record_scalar("test_common_event", 5e5)
79+
80+
self.assertListEqual(
81+
observed_keys,
82+
["test_unique_event", "test_common_event", "test_common_event"],
83+
)
84+
self.assertListEqual(
85+
observed_values,
86+
[1, 2.5, 5e5],
87+
)
88+
6589
def test_unregister_exist_callback_success(self):
6690
original_duration_listeners = jax_src_monitoring.get_event_duration_listeners()
6791
callback = lambda event, durations: None

0 commit comments

Comments
 (0)