@@ -30,27 +30,39 @@ def _make_config(
3030 )
3131
3232
33- def _expected_trace_path (config : PerfTracerConfig ) -> Path :
33+ def _expected_trace_path (
34+ config : PerfTracerConfig ,
35+ * ,
36+ rank : int ,
37+ ) -> Path :
3438 base_dir = Path (os .path .expanduser (config .fileroot ))
39+ filename = f"traces-r{ rank } .jsonl"
3540 return (
3641 base_dir
3742 / "logs"
3843 / getpass .getuser ()
3944 / config .experiment_name
4045 / config .trial_name
41- / "traces.jsonl"
46+ / "perf_tracer"
47+ / filename
4248 )
4349
4450
45- def _expected_request_trace_path (config : PerfTracerConfig ) -> Path :
51+ def _expected_request_trace_path (
52+ config : PerfTracerConfig ,
53+ * ,
54+ rank : int ,
55+ ) -> Path :
4656 base_dir = Path (os .path .expanduser (config .fileroot ))
57+ filename = f"requests-r{ rank } .jsonl"
4758 return (
4859 base_dir
4960 / "logs"
5061 / getpass .getuser ()
5162 / config .experiment_name
5263 / config .trial_name
53- / "requests.jsonl"
64+ / "request_tracer"
65+ / filename
5466 )
5567
5668
@@ -70,18 +82,10 @@ def test_module_level_helpers_require_configuration():
7082 perf_tracer .get_tracer ()
7183
7284
73- @pytest .mark .parametrize ("override_rank" , [None , 1 ])
74- def test_perf_tracer_records_events_and_save (tmp_path , override_rank ):
85+ def test_perf_tracer_records_events_and_save (tmp_path ):
7586 config = _make_config (tmp_path , experiment = "unit" , trial = "scope" )
76- base_rank = 0
77- tracer = perf_tracer .PerfTracer (config , rank = base_rank )
78- if override_rank is None :
79- expected_rank = base_rank
80- else :
81- tracer .set_rank (override_rank )
82- expected_rank = override_rank
83-
84- assert tracer ._rank == expected_rank # noqa: SLF001
87+ tracer = perf_tracer .PerfTracer (config , rank = 0 )
88+ assert tracer ._rank == 0 # noqa: SLF001
8589
8690 with tracer .trace_scope (
8791 "unit-block" ,
@@ -92,15 +96,15 @@ def test_perf_tracer_records_events_and_save(tmp_path, override_rank):
9296 tracer .instant ("outer-mark" )
9397
9498 tracer .save ()
95- saved_path = _expected_trace_path (config )
99+ saved_path = _expected_trace_path (config , rank = 0 )
96100 assert saved_path .exists ()
97101
98102 events = _load_trace_events (saved_path )
99103 event_names = {evt ["name" ] for evt in events if evt ["ph" ] != "M" }
100104 assert {"unit-block" , "inner-mark" , "outer-mark" }.issubset (event_names )
101105
102106
103- def test_perf_tracer_aggregate_combines_ranks (tmp_path ):
107+ def test_perf_tracer_emits_separate_rank_logs (tmp_path ):
104108 config0 = _make_config (
105109 tmp_path ,
106110 enabled = True ,
@@ -112,8 +116,8 @@ def test_perf_tracer_aggregate_combines_ranks(tmp_path):
112116 pass
113117 tracer0 .instant ("rank0-mark" , args = {"rank" : 0 })
114118 tracer0 .save ()
115- saved_path = _expected_trace_path (config0 )
116- assert saved_path .exists ()
119+ saved_path_rank0 = _expected_trace_path (config0 , rank = 0 )
120+ assert saved_path_rank0 .exists ()
117121
118122 config1 = _make_config (
119123 tmp_path ,
@@ -127,36 +131,48 @@ def test_perf_tracer_aggregate_combines_ranks(tmp_path):
127131 tracer1_thread .clear ()
128132 tracer1 .instant ("rank1-mark" , args = {"rank" : 1 })
129133 tracer1 .save ()
130- saved_path_rank1 = _expected_trace_path (config1 )
131- assert saved_path_rank1 == saved_path
134+ saved_path_rank1 = _expected_trace_path (config1 , rank = 1 )
135+ assert saved_path_rank1 .exists ()
136+
137+ events_rank0 = _load_trace_events (saved_path_rank0 )
138+ events_rank1 = _load_trace_events (saved_path_rank1 )
139+
140+ def _non_meta (events : list [dict ]) -> list [dict ]:
141+ return [evt for evt in events if evt .get ("ph" ) != "M" ]
142+
143+ def _meta (events : list [dict ], name : str ) -> list [dict ]:
144+ return [
145+ evt for evt in events if evt .get ("ph" ) == "M" and evt .get ("name" ) == name
146+ ]
147+
148+ event_names_rank0 = {evt ["name" ] for evt in _non_meta (events_rank0 )}
149+ event_names_rank1 = {evt ["name" ] for evt in _non_meta (events_rank1 )}
150+ assert {"rank0-step" , "rank0-mark" }.issubset (event_names_rank0 )
151+ assert {"rank1-mark" }.issubset (event_names_rank1 )
152+
153+ ranks_rank0 = {evt ["args" ].get ("rank" ) for evt in _non_meta (events_rank0 )}
154+ ranks_rank1 = {evt ["args" ].get ("rank" ) for evt in _non_meta (events_rank1 )}
155+ assert ranks_rank0 == {0 }
156+ assert ranks_rank1 == {1 }
157+
158+ pid_rank0 = {evt ["pid" ] for evt in _non_meta (events_rank0 )}
159+ pid_rank1 = {evt ["pid" ] for evt in _non_meta (events_rank1 )}
160+ assert pid_rank0 == {tracer0 ._pid } # noqa: SLF001
161+ assert pid_rank1 == {tracer1 ._pid } # noqa: SLF001
162+
163+ process_name_rank0 = _meta (events_rank0 , "process_name" )
164+ process_name_rank1 = _meta (events_rank1 , "process_name" )
165+ assert any (
166+ evt ["args" ].get ("name" ) == "Rank 0, Process" for evt in process_name_rank0
167+ )
168+ assert any (
169+ evt ["args" ].get ("name" ) == "Rank 1, Process" for evt in process_name_rank1
170+ )
132171
133- events = _load_trace_events (saved_path )
134- event_names = {evt ["name" ] for evt in events if evt ["ph" ] != "M" }
135- assert {"rank0-step" , "rank0-mark" , "rank1-mark" }.issubset (event_names )
136- pid_values = {evt ["pid" ] for evt in events if evt ["ph" ] != "M" }
137- assert pid_values == {tracer0 ._pid , tracer1 ._pid } # noqa: SLF001
138- rank_values = {evt ["args" ].get ("rank" ) for evt in events if evt ["ph" ] != "M" }
139- assert {0 , 1 }.issubset (rank_values )
140- meta_by_pid = {
141- (evt ["pid" ], evt ["args" ].get ("name" ))
142- for evt in events
143- if evt ["ph" ] == "M" and evt ["name" ] == "process_name"
144- }
145- assert (
146- tracer0 ._pid ,
147- "Rank 0, Process" ,
148- ) in meta_by_pid # noqa: SLF001
149- assert (
150- tracer1 ._pid ,
151- "Rank 1, Process" ,
152- ) in meta_by_pid # noqa: SLF001
153- sort_meta = {
154- evt ["pid" ]: evt ["args" ].get ("sort_index" )
155- for evt in events
156- if evt ["ph" ] == "M" and evt ["name" ] == "process_sort_index"
157- }
158- assert sort_meta [tracer0 ._pid ] == 0 # noqa: SLF001
159- assert sort_meta [tracer1 ._pid ] == 1 # noqa: SLF001
172+ sort_meta_rank0 = _meta (events_rank0 , "process_sort_index" )
173+ sort_meta_rank1 = _meta (events_rank1 , "process_sort_index" )
174+ assert any (evt ["args" ].get ("sort_index" ) == 0 for evt in sort_meta_rank0 )
175+ assert any (evt ["args" ].get ("sort_index" ) == 1 for evt in sort_meta_rank1 )
160176
161177
162178@pytest .mark .asyncio
@@ -181,7 +197,7 @@ async def test_global_tracer_configure_roundtrip(tmp_path):
181197 pass
182198
183199 tracer .save ()
184- saved_path = _expected_trace_path (config )
200+ saved_path = _expected_trace_path (config , rank = 0 )
185201 assert saved_path .exists ()
186202 events = _load_trace_events (saved_path )
187203 event_names = {evt ["name" ] for evt in events if evt ["ph" ] != "M" }
@@ -232,7 +248,7 @@ async def run_request(
232248 )
233249
234250 tracer .save ()
235- saved_path = _expected_trace_path (config )
251+ saved_path = _expected_trace_path (config , rank = 0 )
236252 assert saved_path .exists ()
237253 events = [evt for evt in _load_trace_events (saved_path ) if evt .get ("ph" ) != "M" ]
238254
@@ -263,23 +279,17 @@ async def run_request(
263279 assert overlap
264280
265281
266- def test_configure_preserves_output_path_when_rank_changes (tmp_path ):
282+ def test_configure_rejects_repeated_calls (tmp_path ):
267283 config = _make_config (tmp_path , experiment = "ranked" , trial = "zero" )
268- tracer = perf_tracer .configure (
269- config ,
270- rank = 0 ,
271- )
272- expected_path = _expected_trace_path (config )
273- first_path = Path (tracer ._output_path or "" ) # noqa: SLF001
274- assert first_path == expected_path
275-
276284 perf_tracer .configure (
277285 config ,
278- rank = 1 ,
286+ rank = 0 ,
279287 )
280- second_path = Path (tracer ._output_path or "" ) # noqa: SLF001
281- assert second_path == expected_path
282- assert tracer ._rank == 1 # noqa: SLF001
288+ with pytest .raises (RuntimeError ):
289+ perf_tracer .configure (
290+ config ,
291+ rank = 1 ,
292+ )
283293
284294
285295def test_module_level_save_helper (tmp_path ):
@@ -291,9 +301,9 @@ def test_module_level_save_helper(tmp_path):
291301 perf_tracer .instant ("module-level-mark" , args = {"flag" : True })
292302
293303 perf_tracer .save ()
294- saved_path = _expected_trace_path (config )
304+ saved_path = _expected_trace_path (config , rank = 0 )
295305 assert saved_path .exists ()
296- assert saved_path == _expected_trace_path (config )
306+ assert saved_path == _expected_trace_path (config , rank = 0 )
297307 events = _load_trace_events (saved_path )
298308 event_names = {evt ["name" ] for evt in events if evt .get ("ph" ) != "M" }
299309 assert "module-level-mark" in event_names
@@ -303,7 +313,7 @@ def test_perf_tracer_respects_save_interval(tmp_path):
303313 config = _make_config (tmp_path , experiment = "interval" , trial = "steps" )
304314 config .save_interval = 3
305315 tracer = perf_tracer .PerfTracer (config , rank = 0 )
306- trace_path = _expected_trace_path (config )
316+ trace_path = _expected_trace_path (config , rank = 0 )
307317
308318 for step in (0 , 1 ):
309319 tracer .instant (f"mark-{ step } " , args = {"step" : step })
@@ -346,16 +356,15 @@ def test_request_tracer_configuration(tmp_path):
346356 request_tracer .mark_consumed (request_id )
347357 tracer .save (force = True )
348358
349- request_path = _expected_request_trace_path (config )
359+ request_path = _expected_request_trace_path (config , rank = 0 )
350360 assert request_path .exists ()
351361 payload = [json .loads (line ) for line in request_path .read_text ().splitlines ()]
352362 assert any (entry ["status" ] == "accepted" for entry in payload )
353363
354364 updated = _make_config (tmp_path , experiment = "request" , trial = "enabled" )
355365 updated .request_tracer = RequestTracerConfig (enabled = False )
356- tracer .apply_config (updated , rank = 1 )
357- assert tracer .request_tracer is None
358- assert tracer ._rank == 1 # noqa: SLF001
366+ tracer_disabled = perf_tracer .PerfTracer (updated , rank = 1 )
367+ assert tracer_disabled .request_tracer is None
359368
360369
361370def _run_perf_tracer_torchrun (tmp_path : Path , world_size : int ) -> None :
@@ -396,9 +405,15 @@ def test_perf_tracer_torchrun_multi_rank(tmp_path, world_size):
396405 fileroot = str (tmp_path ),
397406 enabled = True ,
398407 )
399- trace_path = _expected_trace_path (config )
400- assert trace_path .exists ()
401- payload = _load_trace_events (trace_path )
408+ trace_paths = [
409+ _expected_trace_path (config , rank = rank ) for rank in range (world_size )
410+ ]
411+ for path in trace_paths :
412+ assert path .exists ()
413+
414+ payload : list [dict ] = []
415+ for path in trace_paths :
416+ payload .extend (_load_trace_events (path ))
402417 ranks_seen = {
403418 evt ["args" ].get ("rank" ) for evt in payload if evt ["name" ] == "torchrun-step"
404419 }
0 commit comments