@@ -64,6 +64,13 @@ class PerfTraceCategory(str, Enum):
6464_REQUEST_TRACE_FILENAME = "requests.jsonl"
6565
6666
67+ def _rank_qualified_filename (filename : str , rank : int | None ) -> str :
68+ if rank is None :
69+ return filename
70+ root , ext = os .path .splitext (filename )
71+ return f"{ root } -r{ rank } { ext } "
72+
73+
6774def _maybe_duration (start : float | None , end : float | None ) -> float | None :
6875 if start is None or end is None :
6976 return None
@@ -92,6 +99,7 @@ def _default_trace_path(
9299 config : PerfTracerConfig ,
93100 * ,
94101 filename : str = _PERF_TRACE_FILENAME ,
102+ rank : int | None = None ,
95103) -> str :
96104 base_dir = os .path .join (
97105 os .path .expanduser (os .path .expandvars (config .fileroot )),
@@ -100,7 +108,7 @@ def _default_trace_path(
100108 config .experiment_name ,
101109 config .trial_name ,
102110 )
103- return os .path .join (base_dir , filename )
111+ return os .path .join (base_dir , _rank_qualified_filename ( filename , rank ) )
104112
105113
106114def _normalize_flush_threshold (config : RequestTracerConfig ) -> int :
@@ -315,8 +323,7 @@ def flush(self, force: bool = False) -> None:
315323 with _acquire_file_lock (self ._output_path ):
316324 with open (self ._output_path , "a" , encoding = "utf-8" ) as fout :
317325 for line in lines :
318- fout .write (line )
319- fout .write ("\n " )
326+ fout .write (f"{ line } \n " )
320327 fout .flush ()
321328 os .fsync (fout .fileno ())
322329 except OSError as exc : # pragma: no cover - depends on filesystem
@@ -455,7 +462,7 @@ def __init__(self, config: PerfTracerConfig, *, rank: int) -> None:
455462 self ._origin_ns = time .perf_counter_ns ()
456463 self ._thread_meta_emitted : set [int ] = set ()
457464 self ._process_meta_emitted : set [int ] = set ()
458- self ._output_path = _default_trace_path (config )
465+ self ._output_path = _default_trace_path (config , rank = rank )
459466 self ._save_interval = _normalize_save_interval (config )
460467 self ._request_tracer : RequestTracer | None = None
461468 self ._configure_request_tracer (config , rank = rank )
@@ -479,7 +486,11 @@ def _configure_request_tracer(self, config: PerfTracerConfig, *, rank: int) -> N
479486 request_cfg = getattr (config , "request_tracer" , None )
480487 enabled = bool (request_cfg and getattr (request_cfg , "enabled" , False ))
481488 if enabled :
482- output_path = _default_trace_path (config , filename = _REQUEST_TRACE_FILENAME )
489+ output_path = _default_trace_path (
490+ config ,
491+ filename = _REQUEST_TRACE_FILENAME ,
492+ rank = rank ,
493+ )
483494 if self ._request_tracer is None :
484495 self ._request_tracer = RequestTracer (
485496 request_cfg ,
@@ -500,7 +511,7 @@ def _configure_request_tracer(self, config: PerfTracerConfig, *, rank: int) -> N
500511 def apply_config (self , config : PerfTracerConfig , * , rank : int ) -> None :
501512 self ._config = config
502513 self .set_rank (rank )
503- self ._output_path = _default_trace_path (config )
514+ self ._output_path = _default_trace_path (config , rank = rank )
504515 self .set_enabled (config .enabled )
505516 self ._save_interval = _normalize_save_interval (config )
506517 self ._configure_request_tracer (config , rank = rank )
@@ -601,8 +612,7 @@ def save(self, *, step: int | None = None, force: bool = False) -> None:
601612 with _acquire_file_lock (output_path ):
602613 with open (output_path , "a" , encoding = "utf-8" ) as fout :
603614 for line in serialized_events :
604- fout .write (line )
605- fout .write ("\n " )
615+ fout .write (f"{ line } \n " )
606616 fout .flush ()
607617 os .fsync (fout .fileno ())
608618 self ._events = []
0 commit comments