Skip to content

Commit e7ffc13

Browse files
lukebaumanncopybara-github
authored andcommitted
Adding some support for ProfileOptions to jax.profiler.start_trace for Pathways.
Specifically, `host_tracer_level`, `start_timestamp_ns`, and `duration_ms` are supported. Additionally - Renamed `start_trace`'s `gcs_bucket` argument to `log_dir` to match the original jax function signature to better support kwarg use. - Added experimental support for perfetto link generation PiperOrigin-RevId: 841263378
1 parent f416009 commit e7ffc13

File tree

2 files changed

+385
-22
lines changed

2 files changed

+385
-22
lines changed

pathwaysutils/profiling.py

Lines changed: 96 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@
1414
"""Profiling utilites."""
1515

1616
import dataclasses
17+
import json
1718
import logging
1819
import os
19-
import pathlib
20-
import tempfile
2120
import threading
2221
import time
22+
from typing import Any
2323
import urllib.parse
2424

2525
import fastapi
@@ -45,6 +45,7 @@ def reset(self):
4545
self.executable = None
4646

4747

48+
_first_profile_start = True
4849
_profile_state = _ProfileState()
4950
_original_start_trace = jax.profiler.start_trace
5051
_original_stop_trace = jax.profiler.stop_trace
@@ -56,30 +57,108 @@ def toy_computation():
5657
x.block_until_ready()
5758

5859

59-
def start_trace(gcs_bucket: str):
60-
"""Starts a profiler trace."""
60+
def _create_profile_request(
61+
log_dir: os.PathLike[str] | str,
62+
profiler_options: jax.profiler.ProfileOptions | None,
63+
) -> dict[str, Any]:
64+
"""Creates a profile request dictionary from the given options."""
65+
profile_request = {}
66+
profile_request["traceLocation"] = str(log_dir)
67+
68+
if profiler_options is not None:
69+
xprof_trace_options = {}
70+
if profiler_options.host_tracer_level is not None:
71+
xprof_trace_options["hostTraceLevel"] = profiler_options.host_tracer_level
72+
trace_options = {}
73+
if profiler_options.start_timestamp_ns:
74+
trace_options["profilingStartTimeNs"] = (
75+
profiler_options.start_timestamp_ns
76+
)
77+
if profiler_options.duration_ms:
78+
trace_options["profilingDurationMs"] = profiler_options.duration_ms
79+
80+
if trace_options:
81+
xprof_trace_options["traceOptions"] = trace_options
82+
if xprof_trace_options:
83+
profile_request["xprofTraceOptions"] = xprof_trace_options
84+
return profile_request
85+
86+
87+
def start_trace(
88+
log_dir: os.PathLike[str] | str,
89+
create_perfetto_link: bool = False,
90+
create_perfetto_trace: bool = False,
91+
profiler_options: jax.profiler.ProfileOptions | None = None,
92+
):
93+
"""Starts a profiler trace.
94+
95+
The trace will capture CPU and TPU activity, including Python
96+
functions and JAX on-device operations. Use :func:`stop_trace` to end the
97+
trace and save the results to ``log_dir``.
98+
99+
The resulting trace can be viewed with TensorBoard. Note that TensorBoard
100+
doesn't need to be running when collecting the trace.
101+
102+
Only one trace may be collected at a time. A RuntimeError will be raised if
103+
:func:`start_trace` is called while another trace is running.
104+
105+
Args:
106+
log_dir: The GCS directory to save the profiler trace to (usually the
107+
TensorBoard log directory), e.g., "gs://my_bucket/profiles".
108+
create_perfetto_link: A boolean which, if true, creates and prints link to
109+
the Perfetto trace viewer UI (https://ui.perfetto.dev). The program will
110+
block until the link is opened and Perfetto loads the trace. This feature
111+
is experimental for Pathways on Cloud and may not be fully supported.
112+
create_perfetto_trace: A boolean which, if true, additionally dumps a
113+
``perfetto_trace.json.gz`` file that is compatible for upload with the
114+
Perfetto trace viewer UI (https://ui.perfetto.dev). The file will also be
115+
generated if ``create_perfetto_link`` is true. This could be useful if you
116+
want to generate a Perfetto-compatible trace without blocking the process.
117+
This feature is experimental for Pathways on Cloud and may not be fully
118+
supported.
119+
profiler_options: Profiler options to configure the profiler for collection.
120+
Passing a mappable object is experimental for Pathways on Cloud and may
121+
not be fully supported.
122+
"""
123+
if log_dir is None:
124+
raise ValueError("log_dir cannot be None.")
125+
if not str(log_dir).startswith("gs://"):
126+
raise ValueError(f"log_dir must be a GCS bucket path, got {log_dir}")
127+
128+
if create_perfetto_link or create_perfetto_trace:
129+
_logger.warning(
130+
"create_perfetto_link and create_perfetto_trace are experimental "
131+
"features for Pathways on Cloud and may not be fully supported."
132+
)
133+
134+
profile_request = _create_profile_request(log_dir, profiler_options)
135+
61136
with _profile_state.lock:
62-
if start_trace._first_profile_start: # pylint: disable=protected-access, attribute-error
63-
start_trace._first_profile_start = False # pylint: disable=protected-access
137+
global _first_profile_start
138+
if _first_profile_start:
139+
_first_profile_start = False
64140
toy_computation()
65141

66142
if _profile_state.executable is not None:
67143
raise ValueError(
68144
"start_trace called while a trace is already being taken!"
69145
)
70146
_profile_state.executable = plugin_executable.PluginExecutable(
71-
f"{{profileRequest: {{traceLocation: '{gcs_bucket}'}}}}"
147+
json.dumps({"profileRequest": profile_request})
72148
)
73149
try:
74150
_profile_state.executable.call()[1].result()
75-
except:
151+
except Exception as e: # pylint: disable=broad-except
152+
_logger.exception("Failed to start trace")
76153
_profile_state.reset()
77154
raise
78155

79-
_original_start_trace(gcs_bucket)
80-
81-
82-
start_trace._first_profile_start = True # pylint: disable=protected-access
156+
_original_start_trace(
157+
log_dir=log_dir,
158+
create_perfetto_link=create_perfetto_link,
159+
create_perfetto_trace=create_perfetto_trace,
160+
profiler_options=profiler_options,
161+
)
83162

84163

85164
def stop_trace():
@@ -89,10 +168,8 @@ def stop_trace():
89168
raise ValueError("stop_trace called before a trace is being taken!")
90169
try:
91170
_profile_state.executable.call()[1].result()
92-
except:
171+
finally:
93172
_profile_state.reset()
94-
raise
95-
_profile_state.reset()
96173

97174
_original_stop_trace()
98175

@@ -170,16 +247,16 @@ def collect_profile(
170247
if not log_dir.startswith("gs://"):
171248
raise ValueError("log_dir must be a GCS path.")
172249

173-
json = {
250+
request_json = {
174251
"duration_ms": duration_ms,
175252
"repository_path": log_dir,
176253
}
177254
address = urllib.parse.urljoin(f"http://{host}:{port}", "profiling")
178255
try:
179-
response = requests.post(address, json=json)
256+
response = requests.post(address, json=request_json)
180257
response.raise_for_status()
181-
except requests.exceptions.RequestException as e:
182-
_logger.error("Failed to collect profiling data: %s", e)
258+
except requests.exceptions.RequestException:
259+
_logger.exception("Failed to collect profiling data")
183260
return False
184261

185262
return True

0 commit comments

Comments
 (0)