Skip to content

Commit b145c8d

Browse files
s-noghabi authored and The tunix Authors committed
instrument environment interactions for perf metrics
PiperOrigin-RevId: 888837273
1 parent 8ea9ef7 commit b145c8d

File tree

9 files changed

+1193
-112
lines changed

9 files changed

+1193
-112
lines changed

tests/perf/experimental/timeline_utils_test.py

Lines changed: 426 additions & 0 deletions
Large diffs are not rendered by default.

tests/perf/experimental/trace_writer_test.py

Lines changed: 71 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -188,9 +188,9 @@ def add_packet_side_effect():
188188
actual_descriptors = [
189189
{
190190
"name": p.track_descriptor.name,
191-
"parent_uuid": p.track_descriptor.parent_uuid if i > 0 else None,
191+
"parent_uuid": getattr(p.track_descriptor, "parent_uuid", None),
192192
}
193-
for i, p in enumerate(captured_packets[:3])
193+
for p in captured_packets[:3]
194194
]
195195
expected_descriptors = [
196196
{"name": "overlap_timeline", "parent_uuid": None},
@@ -265,6 +265,71 @@ def add_packet_side_effect():
265265
self.assertEqual(actual_events, expected_events)
266266
mock_builder.serialize.assert_called_once()
267267

268+
@mock.patch.object(trace_writer_lib, "TraceProtoBuilder", autospec=True)
269+
def test_write_timelines_grouping(self, mock_builder_cls):
270+
mock_builder = mock_builder_cls.return_value
271+
mock_builder.serialize.return_value = b""
272+
captured_packets = []
273+
274+
def add_packet_side_effect():
275+
p = mock.create_autospec(TracePacket, instance=True)
276+
p.track_descriptor = mock.create_autospec(TrackDescriptor, instance=True)
277+
p.track_event = mock.create_autospec(TrackEvent, instance=True)
278+
captured_packets.append(p)
279+
return p
280+
281+
mock_builder.add_packet.side_effect = add_packet_side_effect
282+
283+
with tempfile.TemporaryDirectory() as tmp_dir:
284+
writer = trace_writer_lib.PerfettoTraceWriter(
285+
trace_dir=tmp_dir, role_to_devices={"actor": ["tpu0", "tpu1"]}
286+
)
287+
288+
t_main = tracer.Timeline("host-1", 1000.0)
289+
t_main.start_span("main_span", 1001.0)
290+
291+
t_rollout = tracer.Timeline("host-2", 1000.0)
292+
t_rollout.start_span("rollout", 1002.0)
293+
294+
t_tpu = tracer.Timeline("tpu0", 1000.0)
295+
t_tpu.start_span("compute", 1003.0)
296+
297+
writer.write_timelines({
298+
"host-1": t_main,
299+
"host-2": t_rollout,
300+
"tpu0": t_tpu,
301+
})
302+
303+
main_group = captured_packets[0].track_descriptor
304+
rollout_group = captured_packets[1].track_descriptor
305+
tpu_group = captured_packets[2].track_descriptor
306+
host_1 = captured_packets[3].track_descriptor
307+
host_2 = captured_packets[4].track_descriptor
308+
tpu0 = captured_packets[5].track_descriptor
309+
310+
with self.subTest("host_main_threads_group"):
311+
self.assertEqual(main_group.name, "Host - Main threads")
312+
self.assertEqual(main_group.uuid, 100000)
313+
314+
with self.subTest("host_rollout_threads_group"):
315+
self.assertEqual(rollout_group.name, "Host - Rollout threads")
316+
self.assertEqual(rollout_group.uuid, 100001)
317+
318+
with self.subTest("actor_cluster"):
319+
self.assertEqual(tpu_group.name, "Actor Cluster")
320+
321+
with self.subTest("host_1"):
322+
self.assertEqual(host_1.name, "host-1")
323+
self.assertEqual(host_1.parent_uuid, 100000)
324+
325+
with self.subTest("host_2"):
326+
self.assertEqual(host_2.name, "host-2")
327+
self.assertEqual(host_2.parent_uuid, 100001)
328+
329+
with self.subTest("tpu0"):
330+
self.assertEqual(tpu0.name, "tpu0")
331+
self.assertEqual(tpu0.parent_uuid, tpu_group.uuid)
332+
268333
def test_perfetto_trace_writer_integration(self):
269334
with tempfile.TemporaryDirectory() as tmp_dir:
270335
writer = trace_writer_lib.PerfettoTraceWriter(trace_dir=tmp_dir)
@@ -293,15 +358,16 @@ def test_perfetto_trace_writer_integration(self):
293358
self.assertLen(files, 1)
294359

295360
if files:
361+
file_name = files[0]
296362
with self.subTest("file_name_prefix"):
297-
self.assertStartsWith(files[0], "perfetto_trace_v2_")
363+
self.assertStartsWith(file_name, "perfetto_trace_v2_")
298364

299365
with self.subTest("file_name_suffix"):
300-
self.assertEndsWith(files[0], ".pb")
366+
self.assertEndsWith(file_name, ".pb")
301367

302368
with self.subTest("file_content"):
303369
self.assertGreater(
304-
os.path.getsize(os.path.join(tmp_dir, files[0])), 0
370+
os.path.getsize(os.path.join(tmp_dir, file_name)), 0
305371
)
306372

307373
def test_perfetto_trace_writer_invalid_dir(self):

tests/perf/experimental/tracer_test.py

Lines changed: 0 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,6 @@
2121
from unittest import mock
2222

2323
from absl.testing import absltest
24-
from absl.testing import parameterized
2524
import numpy as np
2625
from tunix.perf.experimental import timeline
2726
from tunix.perf.experimental import tracer
@@ -79,45 +78,6 @@ def trigger_all(self) -> None:
7978
t.join()
8079

8180

82-
class TracerUtilsTest(parameterized.TestCase):
83-
84-
def test_generate_host_timeline_id(self):
85-
tid = tracer.generate_host_timeline_id()
86-
self.assertStartsWith(tid, "host-")
87-
self.assertIn(str(threading.get_ident()), tid)
88-
89-
@parameterized.named_parameters(
90-
("string", "tpu0", "tpu0"),
91-
("device_object", MockDevice("gpu", 7), "gpu7"),
92-
)
93-
def test_generate_device_timeline_id(self, device_id, expected_id):
94-
self.assertEqual(tracer.generate_device_timeline_id(device_id), expected_id)
95-
96-
def test_generate_device_timeline_id_error(self):
97-
with self.assertRaisesRegex(ValueError, "Unsupported id type"):
98-
tracer.generate_device_timeline_id(123)
99-
100-
@parameterized.named_parameters(
101-
("none", None, []),
102-
("mixed_list", ["dev1", MockDevice("tpu", 0)], ["dev1", "tpu0"]),
103-
(
104-
"numpy_array",
105-
np.array([MockDevice("tpu", 0), MockDevice("tpu", 1)]),
106-
["tpu0", "tpu1"],
107-
),
108-
(
109-
"numpy_array_2d",
110-
np.array([
111-
[MockDevice("tpu", 0), MockDevice("tpu", 1)],
112-
[MockDevice("tpu", 2), MockDevice("tpu", 3)],
113-
]),
114-
["tpu0", "tpu1", "tpu2", "tpu3"],
115-
),
116-
)
117-
def test_generate_device_timeline_ids(self, devices, expected_ids):
118-
self.assertEqual(tracer.generate_device_timeline_ids(devices), expected_ids)
119-
120-
12181
class PerfTracerTest(absltest.TestCase):
12282

12383
def setUp(self):

tunix/perf/experimental/constants.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@
2222
ROLE = "role"
2323
GROUP_ID = "group_id"
2424
PAIR_INDEX = "pair_index"
25+
QUEUED_SPAN = "queued_span"
2526

2627
# Common Span / Event names.
2728

@@ -33,3 +34,6 @@
3334
OLD_ACTOR_INFERENCE = "old_actor_inference"
3435
ADVANTAGE_COMPUTATION = "advantage_computation"
3536
PEFT_TRAIN = "peft_train"
37+
IDLE = "idle"
38+
QUEUE = "queue"
39+
ENVIRONMENT = "environment"

0 commit comments

Comments
 (0)