Skip to content

Commit 69abf80

Browse files
authored
fix: use per-rank jsonl instead of file lock in case that NFS does not support it (#513)
1 parent 972d287 commit 69abf80

File tree

2 files changed

+54
-14
lines changed

2 files changed

+54
-14
lines changed

areal/tools/perf_trace_converter.py

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import os
66
import sys
77
from collections.abc import Sequence
8+
from glob import glob
89
from pathlib import Path
910

1011

@@ -24,19 +25,41 @@ def _load_events(path: Path) -> list[dict]:
2425
return events
2526

2627

28+
def _resolve_trace_files(source: Path) -> list[Path]:
29+
if source.is_file():
30+
return [source]
31+
if source.is_dir():
32+
return sorted(p for p in source.glob("*.jsonl") if p.is_file())
33+
matches = [Path(p) for p in glob(str(source), recursive=True)]
34+
files = [p for p in matches if p.is_file()]
35+
return sorted(files)
36+
37+
2738
def convert_jsonl_to_chrome_trace(
2839
input_path: str | os.PathLike[str],
2940
output_path: str | os.PathLike[str] | None = None,
3041
*,
3142
display_time_unit: str = "ms",
3243
) -> dict:
33-
"""Convert newline-delimited trace events into Chrome Trace JSON."""
44+
"""Convert newline-delimited trace events into Chrome Trace JSON.
45+
46+
The ``input_path`` may point to a single JSONL file, a directory containing
47+
per-rank JSONL files, or a glob pattern. All matching files are concatenated
48+
in lexical order before emitting the Chrome trace payload.
49+
"""
50+
51+
sources = _resolve_trace_files(Path(input_path))
52+
if not sources:
53+
raise FileNotFoundError(f"No trace files matched input path: {input_path}")
54+
55+
events: list[dict] = []
56+
for path in sources:
57+
events.extend(_load_events(path))
3458

35-
source = Path(input_path)
36-
if not source.is_file(): # pragma: no cover - defensive guard
37-
raise FileNotFoundError(f"Input trace file not found: {source}")
59+
events.sort(
60+
key=lambda event: (event.get("ts", 0), event.get("pid", 0), event.get("tid", 0))
61+
)
3862

39-
events = _load_events(source)
4063
chrome_trace = {
4164
"traceEvents": events,
4265
"displayTimeUnit": display_time_unit,
@@ -55,7 +78,14 @@ def _parse_args(argv: Sequence[str] | None) -> argparse.Namespace:
5578
parser = argparse.ArgumentParser(
5679
description="Convert PerfTracer JSONL output into Chrome Trace JSON format.",
5780
)
58-
parser.add_argument("input", type=str, help="Path to the PerfTracer JSONL file")
81+
parser.add_argument(
82+
"input",
83+
type=str,
84+
help=(
85+
"Path, directory, or glob pattern for PerfTracer JSONL files "
86+
"(per-rank outputs allowed)"
87+
),
88+
)
5989
parser.add_argument(
6090
"output",
6191
type=str,

areal/utils/perf_tracer.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,13 @@ class PerfTraceCategory(str, Enum):
6464
_REQUEST_TRACE_FILENAME = "requests.jsonl"
6565

6666

67+
def _rank_qualified_filename(filename: str, rank: int | None) -> str:
68+
if rank is None:
69+
return filename
70+
root, ext = os.path.splitext(filename)
71+
return f"{root}-r{rank}{ext}"
72+
73+
6774
def _maybe_duration(start: float | None, end: float | None) -> float | None:
6875
if start is None or end is None:
6976
return None
@@ -92,6 +99,7 @@ def _default_trace_path(
9299
config: PerfTracerConfig,
93100
*,
94101
filename: str = _PERF_TRACE_FILENAME,
102+
rank: int | None = None,
95103
) -> str:
96104
base_dir = os.path.join(
97105
os.path.expanduser(os.path.expandvars(config.fileroot)),
@@ -100,7 +108,7 @@ def _default_trace_path(
100108
config.experiment_name,
101109
config.trial_name,
102110
)
103-
return os.path.join(base_dir, filename)
111+
return os.path.join(base_dir, _rank_qualified_filename(filename, rank))
104112

105113

106114
def _normalize_flush_threshold(config: RequestTracerConfig) -> int:
@@ -315,8 +323,7 @@ def flush(self, force: bool = False) -> None:
315323
with _acquire_file_lock(self._output_path):
316324
with open(self._output_path, "a", encoding="utf-8") as fout:
317325
for line in lines:
318-
fout.write(line)
319-
fout.write("\n")
326+
fout.write(f"{line}\n")
320327
fout.flush()
321328
os.fsync(fout.fileno())
322329
except OSError as exc: # pragma: no cover - depends on filesystem
@@ -455,7 +462,7 @@ def __init__(self, config: PerfTracerConfig, *, rank: int) -> None:
455462
self._origin_ns = time.perf_counter_ns()
456463
self._thread_meta_emitted: set[int] = set()
457464
self._process_meta_emitted: set[int] = set()
458-
self._output_path = _default_trace_path(config)
465+
self._output_path = _default_trace_path(config, rank=rank)
459466
self._save_interval = _normalize_save_interval(config)
460467
self._request_tracer: RequestTracer | None = None
461468
self._configure_request_tracer(config, rank=rank)
@@ -479,7 +486,11 @@ def _configure_request_tracer(self, config: PerfTracerConfig, *, rank: int) -> N
479486
request_cfg = getattr(config, "request_tracer", None)
480487
enabled = bool(request_cfg and getattr(request_cfg, "enabled", False))
481488
if enabled:
482-
output_path = _default_trace_path(config, filename=_REQUEST_TRACE_FILENAME)
489+
output_path = _default_trace_path(
490+
config,
491+
filename=_REQUEST_TRACE_FILENAME,
492+
rank=rank,
493+
)
483494
if self._request_tracer is None:
484495
self._request_tracer = RequestTracer(
485496
request_cfg,
@@ -500,7 +511,7 @@ def _configure_request_tracer(self, config: PerfTracerConfig, *, rank: int) -> N
500511
def apply_config(self, config: PerfTracerConfig, *, rank: int) -> None:
501512
self._config = config
502513
self.set_rank(rank)
503-
self._output_path = _default_trace_path(config)
514+
self._output_path = _default_trace_path(config, rank=rank)
504515
self.set_enabled(config.enabled)
505516
self._save_interval = _normalize_save_interval(config)
506517
self._configure_request_tracer(config, rank=rank)
@@ -601,8 +612,7 @@ def save(self, *, step: int | None = None, force: bool = False) -> None:
601612
with _acquire_file_lock(output_path):
602613
with open(output_path, "a", encoding="utf-8") as fout:
603614
for line in serialized_events:
604-
fout.write(line)
605-
fout.write("\n")
615+
fout.write(f"{line}\n")
606616
fout.flush()
607617
os.fsync(fout.fileno())
608618
self._events = []

0 commit comments

Comments
 (0)