Skip to content

Commit 3246e1c

Browse files
lukebaumann authored and copybara-github committed
Refactor profiling and add experimental Perfetto support.
PiperOrigin-RevId: 842402871
1 parent f416009 commit 3246e1c

File tree

2 files changed

+382
-25
lines changed

2 files changed

+382
-25
lines changed

pathwaysutils/profiling.py

Lines changed: 107 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -14,12 +14,12 @@
1414
"""Profiling utilities."""
1515

1616
import dataclasses
17+
import json
1718
import logging
1819
import os
19-
import pathlib
20-
import tempfile
2120
import threading
2221
import time
22+
from typing import Any
2323
import urllib.parse
2424

2525
import fastapi
@@ -45,6 +45,7 @@ def reset(self):
4545
self.executable = None
4646

4747

48+
_first_profile_start = True
4849
_profile_state = _ProfileState()
4950
_original_start_trace = jax.profiler.start_trace
5051
_original_stop_trace = jax.profiler.stop_trace
@@ -56,30 +57,115 @@ def toy_computation():
5657
x.block_until_ready()
5758

5859

59-
def start_trace(gcs_bucket: str):
60-
"""Starts a profiler trace."""
60+
def _create_profile_request(
61+
log_dir: os.PathLike[str] | str,
62+
) -> dict[str, Any]:
63+
"""Creates a profile request dictionary from the given options."""
64+
profile_request = {}
65+
profile_request["traceLocation"] = str(log_dir)
66+
67+
return profile_request
68+
69+
70+
def _start_trace_from_profile_request(
71+
profile_request: dict[str, Any],
72+
*,
73+
create_perfetto_link: bool = False,
74+
create_perfetto_trace: bool = False,
75+
) -> None:
76+
"""Starts a profiler trace from a profile request dictionary.
77+
78+
Args:
79+
profile_request: A dictionary containing the profile request options.
80+
create_perfetto_link: A boolean which, if true, creates and prints a link to
81+
the Perfetto trace viewer UI (https://ui.perfetto.dev). The program will
82+
block until the link is opened and Perfetto loads the trace. This feature
83+
is experimental for Pathways on Cloud and may not be fully supported.
84+
create_perfetto_trace: A boolean which, if true, additionally dumps a
85+
``perfetto_trace.json.gz`` file that is compatible for upload with the
86+
Perfetto trace viewer UI (https://ui.perfetto.dev). The file will also be
87+
generated if ``create_perfetto_link`` is true. This could be useful if you
88+
want to generate a Perfetto-compatible trace without blocking the process.
89+
This feature is experimental for Pathways on Cloud and may not be fully
90+
supported.
91+
"""
92+
log_dir = profile_request["traceLocation"]
93+
6194
with _profile_state.lock:
62-
if start_trace._first_profile_start: # pylint: disable=protected-access, attribute-error
63-
start_trace._first_profile_start = False # pylint: disable=protected-access
95+
global _first_profile_start
96+
if _first_profile_start:
97+
_first_profile_start = False
6498
toy_computation()
6599

66100
if _profile_state.executable is not None:
67101
raise ValueError(
68102
"start_trace called while a trace is already being taken!"
69103
)
70104
_profile_state.executable = plugin_executable.PluginExecutable(
71-
f"{{profileRequest: {{traceLocation: '{gcs_bucket}'}}}}"
105+
json.dumps({"profileRequest": profile_request})
72106
)
73107
try:
74-
_profile_state.executable.call()[1].result()
75-
except:
108+
_, result_future = _profile_state.executable.call()
109+
result_future.result()
110+
except Exception as e: # pylint: disable=broad-except
111+
_logger.exception("Failed to start trace")
76112
_profile_state.reset()
77113
raise
78114

79-
_original_start_trace(gcs_bucket)
115+
_original_start_trace(
116+
log_dir=log_dir,
117+
create_perfetto_link=create_perfetto_link,
118+
create_perfetto_trace=create_perfetto_trace,
119+
)
120+
121+
122+
def start_trace(
123+
log_dir: os.PathLike[str] | str,
124+
*,
125+
create_perfetto_link: bool = False,
126+
create_perfetto_trace: bool = False,
127+
) -> None:
128+
"""Starts a profiler trace.
80129
130+
The trace will capture CPU and TPU activity, including Python
131+
functions and JAX on-device operations. Use :func:`stop_trace` to end the
132+
trace and save the results to ``log_dir``.
81133
82-
start_trace._first_profile_start = True # pylint: disable=protected-access
134+
The resulting trace can be viewed with TensorBoard. Note that TensorBoard
135+
doesn't need to be running when collecting the trace.
136+
137+
Only one trace may be collected at a time. A ValueError will be raised if
138+
:func:`start_trace` is called while another trace is running.
139+
140+
Args:
141+
log_dir: The GCS directory to save the profiler trace to (usually the
142+
TensorBoard log directory), e.g., "gs://my_bucket/profiles".
143+
create_perfetto_link: A boolean which, if true, creates and prints a link to
144+
the Perfetto trace viewer UI (https://ui.perfetto.dev). The program will
145+
block until the link is opened and Perfetto loads the trace. This feature
146+
is experimental for Pathways on Cloud and may not be fully supported.
147+
create_perfetto_trace: A boolean which, if true, additionally dumps a
148+
``perfetto_trace.json.gz`` file that is compatible for upload with the
149+
Perfetto trace viewer UI (https://ui.perfetto.dev). The file will also be
150+
generated if ``create_perfetto_link`` is true. This could be useful if you
151+
want to generate a Perfetto-compatible trace without blocking the process.
152+
This feature is experimental for Pathways on Cloud and may not be fully
153+
supported.
154+
"""
155+
if not str(log_dir).startswith("gs://"):
156+
raise ValueError(f"log_dir must be a GCS bucket path, got {log_dir}")
157+
158+
if create_perfetto_link or create_perfetto_trace:
159+
_logger.warning(
160+
"create_perfetto_link and create_perfetto_trace are experimental "
161+
"features for Pathways on Cloud and may not be fully supported."
162+
)
163+
164+
_start_trace_from_profile_request(
165+
_create_profile_request(log_dir),
166+
create_perfetto_link=create_perfetto_link,
167+
create_perfetto_trace=create_perfetto_trace,
168+
)
83169

84170

85171
def stop_trace():
@@ -88,11 +174,10 @@ def stop_trace():
88174
if _profile_state.executable is None:
89175
raise ValueError("stop_trace called before a trace is being taken!")
90176
try:
91-
_profile_state.executable.call()[1].result()
92-
except:
177+
_, result_future = _profile_state.executable.call()
178+
result_future.result()
179+
finally:
93180
_profile_state.reset()
94-
raise
95-
_profile_state.reset()
96181

97182
_original_stop_trace()
98183

@@ -151,7 +236,7 @@ def collect_profile(
151236
port: int,
152237
duration_ms: int,
153238
host: str,
154-
log_dir: str,
239+
log_dir: os.PathLike[str] | str,
155240
) -> bool:
156241
"""Collects a JAX profile and saves it to the specified directory.
157242
@@ -167,19 +252,19 @@ def collect_profile(
167252
Raises:
168253
ValueError: If the log_dir is not a GCS path.
169254
"""
170-
if not log_dir.startswith("gs://"):
171-
raise ValueError("log_dir must be a GCS path.")
255+
if not str(log_dir).startswith("gs://"):
256+
raise ValueError(f"log_dir must be a GCS bucket path, got {log_dir}")
172257

173-
json = {
258+
request_json = {
174259
"duration_ms": duration_ms,
175260
"repository_path": log_dir,
176261
}
177262
address = urllib.parse.urljoin(f"http://{host}:{port}", "profiling")
178263
try:
179-
response = requests.post(address, json=json)
264+
response = requests.post(address, json=request_json)
180265
response.raise_for_status()
181-
except requests.exceptions.RequestException as e:
182-
_logger.error("Failed to collect profiling data: %s", e)
266+
except requests.exceptions.RequestException:
267+
_logger.exception("Failed to collect profiling data")
183268
return False
184269

185270
return True

0 commit comments

Comments
 (0)