@@ -188,9 +188,9 @@ def add_packet_side_effect():
188188 actual_descriptors = [
189189 {
190190 "name" : p .track_descriptor .name ,
191- "parent_uuid" : p .track_descriptor . parent_uuid if i > 0 else None ,
191+ "parent_uuid" : getattr ( p .track_descriptor , " parent_uuid" , None ) ,
192192 }
193- for i , p in enumerate ( captured_packets [:3 ])
193+ for p in captured_packets [:3 ]
194194 ]
195195 expected_descriptors = [
196196 {"name" : "overlap_timeline" , "parent_uuid" : None },
@@ -265,6 +265,71 @@ def add_packet_side_effect():
265265 self .assertEqual (actual_events , expected_events )
266266 mock_builder .serialize .assert_called_once ()
267267
268+ @mock .patch .object (trace_writer_lib , "TraceProtoBuilder" , autospec = True )
269+ def test_write_timelines_grouping (self , mock_builder_cls ):
270+ mock_builder = mock_builder_cls .return_value
271+ mock_builder .serialize .return_value = b""
272+ captured_packets = []
273+
274+ def add_packet_side_effect ():
275+ p = mock .create_autospec (TracePacket , instance = True )
276+ p .track_descriptor = mock .create_autospec (TrackDescriptor , instance = True )
277+ p .track_event = mock .create_autospec (TrackEvent , instance = True )
278+ captured_packets .append (p )
279+ return p
280+
281+ mock_builder .add_packet .side_effect = add_packet_side_effect
282+
283+ with tempfile .TemporaryDirectory () as tmp_dir :
284+ writer = trace_writer_lib .PerfettoTraceWriter (
285+ trace_dir = tmp_dir , role_to_devices = {"actor" : ["tpu0" , "tpu1" ]}
286+ )
287+
288+ t_main = tracer .Timeline ("host-1" , 1000.0 )
289+ t_main .start_span ("main_span" , 1001.0 )
290+
291+ t_rollout = tracer .Timeline ("host-2" , 1000.0 )
292+ t_rollout .start_span ("rollout" , 1002.0 )
293+
294+ t_tpu = tracer .Timeline ("tpu0" , 1000.0 )
295+ t_tpu .start_span ("compute" , 1003.0 )
296+
297+ writer .write_timelines ({
298+ "host-1" : t_main ,
299+ "host-2" : t_rollout ,
300+ "tpu0" : t_tpu ,
301+ })
302+
303+ main_group = captured_packets [0 ].track_descriptor
304+ rollout_group = captured_packets [1 ].track_descriptor
305+ tpu_group = captured_packets [2 ].track_descriptor
306+ host_1 = captured_packets [3 ].track_descriptor
307+ host_2 = captured_packets [4 ].track_descriptor
308+ tpu0 = captured_packets [5 ].track_descriptor
309+
310+ with self .subTest ("host_main_threads_group" ):
311+ self .assertEqual (main_group .name , "Host - Main threads" )
312+ self .assertEqual (main_group .uuid , 100000 )
313+
314+ with self .subTest ("host_rollout_threads_group" ):
315+ self .assertEqual (rollout_group .name , "Host - Rollout threads" )
316+ self .assertEqual (rollout_group .uuid , 100001 )
317+
318+ with self .subTest ("actor_cluster" ):
319+ self .assertEqual (tpu_group .name , "Actor Cluster" )
320+
321+ with self .subTest ("host_1" ):
322+ self .assertEqual (host_1 .name , "host-1" )
323+ self .assertEqual (host_1 .parent_uuid , 100000 )
324+
325+ with self .subTest ("host_2" ):
326+ self .assertEqual (host_2 .name , "host-2" )
327+ self .assertEqual (host_2 .parent_uuid , 100001 )
328+
329+ with self .subTest ("tpu0" ):
330+ self .assertEqual (tpu0 .name , "tpu0" )
331+ self .assertEqual (tpu0 .parent_uuid , tpu_group .uuid )
332+
268333 def test_perfetto_trace_writer_integration (self ):
269334 with tempfile .TemporaryDirectory () as tmp_dir :
270335 writer = trace_writer_lib .PerfettoTraceWriter (trace_dir = tmp_dir )
@@ -293,15 +358,16 @@ def test_perfetto_trace_writer_integration(self):
293358 self .assertLen (files , 1 )
294359
295360 if files :
361+ file_name = files [0 ]
296362 with self .subTest ("file_name_prefix" ):
297- self .assertStartsWith (files [ 0 ] , "perfetto_trace_v2_" )
363+ self .assertStartsWith (file_name , "perfetto_trace_v2_" )
298364
299365 with self .subTest ("file_name_suffix" ):
300- self .assertEndsWith (files [ 0 ] , ".pb" )
366+ self .assertEndsWith (file_name , ".pb" )
301367
302368 with self .subTest ("file_content" ):
303369 self .assertGreater (
304- os .path .getsize (os .path .join (tmp_dir , files [ 0 ] )), 0
370+ os .path .getsize (os .path .join (tmp_dir , file_name )), 0
305371 )
306372
307373 def test_perfetto_trace_writer_invalid_dir (self ):
0 commit comments