Skip to content

Commit b5d295f

Browse files
s-noghabiThe tunix Authors
authored andcommitted
make async timeline failure log errors
PiperOrigin-RevId: 889305628
1 parent 2c90667 commit b5d295f

File tree

19 files changed

+1565
-132
lines changed

19 files changed

+1565
-132
lines changed

examples/deepscaler/train_deepscaler_nb.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -578,8 +578,9 @@ def get_lora_model(base_model, model_mesh):
578578

579579
# Perf Metrics logging
580580
perf_metrics_config = PerfMetricsConfig(
581-
custom_export_fn_v2=PerfMetricsExport(
582-
trace_dir="/tmp/agentic_perf"
581+
custom_export_fn_v2=PerfMetricsExport.from_cluster_config(
582+
cluster_config=cluster_config,
583+
trace_dir="/tmp/agentic_perf",
583584
).export_metrics
584585
)
585586

tests/perf/experimental/export_v2_test.py

Lines changed: 66 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from absl.testing import parameterized
2121
from tunix.perf.experimental import export
2222
from tunix.perf.experimental import tracer
23+
import jax
24+
from tunix.rl import rl_cluster
2325

2426

2527
class ExportTest(parameterized.TestCase):
@@ -63,19 +65,24 @@ def test_basic_metrics_export(self):
6365
),
6466
)
6567
@mock.patch.object(
66-
export.trace_writer_lib, "PerfettoTraceWriter", autospec=True
68+
export.trace_writer_lib,
69+
"PerfettoTraceWriter",
70+
autospec=True,
71+
spec_set=True,
6772
)
6873
def test_perf_metrics_export_initialization_with_trace_writer_enabled(
6974
self, mock_writer_cls, trace_dir, expected_dir
7075
):
7176
with export.PerfMetricsExport(
7277
enable_trace_writer=True, trace_dir=trace_dir
7378
) as exporter:
74-
mock_writer_cls.assert_called_once_with(expected_dir)
79+
mock_writer_cls.assert_called_once_with(expected_dir, role_to_devices=None)
7580
# export_metrics shouldn't crash
7681
exporter.export_metrics({})
7782

78-
@mock.patch.object(export.trace_writer_lib, "NoopTraceWriter", autospec=True)
83+
@mock.patch.object(
84+
export.trace_writer_lib, "NoopTraceWriter", autospec=True, spec_set=True
85+
)
7986
def test_perf_metrics_export_initialization_with_trace_writer_disabled(
8087
self, mock_noop_cls
8188
):
@@ -109,6 +116,62 @@ def test_perf_metrics_export_shutdown_can_be_called_manually(
109116
mock_executor_instance.shutdown.assert_called_once_with(wait=False)
110117
self.assertIsNone(exporter._executor)
111118

119+
@mock.patch.object(
120+
export.trace_writer_lib,
121+
"PerfettoTraceWriter",
122+
autospec=True,
123+
)
124+
def test_from_cluster_config(self, mock_writer_cls):
125+
import numpy as np
126+
mock_mesh_1 = mock.create_autospec(jax.sharding.Mesh, instance=True)
127+
mock_mesh_1.devices = np.array([["tpu0", "tpu1"], ["tpu2", "tpu3"]])
128+
129+
mock_mesh_2 = mock.create_autospec(jax.sharding.Mesh, instance=True)
130+
mock_mesh_2.devices = np.array([["tpu4", "tpu5"], ["tpu6", "tpu7"]])
131+
132+
mock_cluster_config = mock.create_autospec(
133+
rl_cluster.ClusterConfig, instance=True
134+
)
135+
mock_cluster_config.role_to_mesh = {
136+
rl_cluster.Role.ACTOR: mock_mesh_1,
137+
rl_cluster.Role.ROLLOUT: mock_mesh_2,
138+
}
139+
140+
exporter = export.PerfMetricsExport.from_cluster_config(
141+
mock_cluster_config,
142+
enable_trace_writer=True,
143+
trace_dir="/test/dir",
144+
)
145+
146+
expected_role_to_devices = {
147+
"actor": ["tpu0", "tpu1", "tpu2", "tpu3"],
148+
"rollout": ["tpu4", "tpu5", "tpu6", "tpu7"],
149+
}
150+
mock_writer_cls.assert_called_once_with(
151+
"/test/dir", role_to_devices=expected_role_to_devices
152+
)
153+
self.assertIs(exporter._writer, mock_writer_cls.return_value)
154+
155+
@mock.patch.object(
156+
export.trace_writer_lib,
157+
"PerfettoTraceWriter",
158+
autospec=True,
159+
spec_set=True,
160+
)
161+
def test_from_cluster_config_no_role_to_mesh(self, mock_writer_cls):
162+
mock_cluster_config = mock.create_autospec(
163+
rl_cluster.ClusterConfig, instance=True, spec_set=True
164+
)
165+
del mock_cluster_config.role_to_mesh
166+
167+
exporter = export.PerfMetricsExport.from_cluster_config(
168+
mock_cluster_config,
169+
enable_trace_writer=True,
170+
trace_dir="/test/dir",
171+
)
172+
173+
mock_writer_cls.assert_called_once_with("/test/dir", role_to_devices={})
174+
self.assertIs(exporter._writer, mock_writer_cls.return_value)
112175

113176
if __name__ == "__main__":
114177
absltest.main()

tests/perf/experimental/timeline_test.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -234,8 +234,19 @@ def fail_wait(waitlist, success, failure):
234234

235235
self.mock_async_wait.side_effect = fail_wait
236236

237-
with self.assertRaisesRegex(RuntimeError, "failed"):
238-
t.span("failed", 1.0, ["wait"])
237+
with mock.patch.object(timeline.logging, "error") as mock_log_error:
238+
t.span("failed_op", 1.0, ["wait"])
239+
# Exception is caught and logged, no exception is raised to the caller.
240+
mock_log_error.assert_called_once()
241+
args, kwargs = mock_log_error.call_args
242+
format_str, name, span_id, err = args
243+
self.assertEqual(format_str, "Timeline span '%s' (id=%d) failed: %s")
244+
self.assertEqual(name, "failed_op")
245+
self.assertEqual(span_id, 0)
246+
self.assertIsInstance(err, RuntimeError)
247+
self.assertEqual(str(err), "failed")
248+
self.assertIn("exc_info", kwargs)
249+
self.assertIsInstance(kwargs["exc_info"], RuntimeError)
239250

240251
def test_wait_pending_spans_clears_threads(self):
241252
t = timeline.AsyncTimeline("test_tl", 0.0)

0 commit comments

Comments
 (0)