1414"""Profiling utilities."""
1515
1616import dataclasses
17+ import json
1718import logging
1819import os
19- import pathlib
20- import tempfile
2120import threading
2221import time
22+ from typing import Any
2323import urllib .parse
2424
2525import fastapi
@@ -45,6 +45,7 @@ def reset(self):
4545 self .executable = None
4646
4747
# Module-level profiling state.
#
# _first_profile_start: True until the first trace is started; the trace-start
#   path flips it to False and calls toy_computation() exactly once.
# _profile_state: shared state object; the start/stop paths use its `lock`,
#   its `executable` attribute, and its `reset()` method.
# _original_start_trace / _original_stop_trace: references to the
#   jax.profiler functions captured at import time and invoked by this
#   module's start/stop wrappers.
#   NOTE(review): presumably captured because jax.profiler is patched
#   elsewhere in this file -- confirm.
48+ _first_profile_start = True
4849_profile_state = _ProfileState ()
4950_original_start_trace = jax .profiler .start_trace
5051_original_stop_trace = jax .profiler .stop_trace
@@ -56,30 +57,115 @@ def toy_computation():
5657 x .block_until_ready ()
5758
5859
59- def start_trace (gcs_bucket : str ):
60- """Starts a profiler trace."""
60+ def _create_profile_request (
61+ log_dir : os .PathLike [str ] | str ,
62+ ) -> dict [str , Any ]:
63+ """Creates a profile request dictionary from the given options."""
64+ profile_request = {}
65+ profile_request ["traceLocation" ] = str (log_dir )
66+
67+ return profile_request
68+
69+
def _start_trace_from_profile_request(
    profile_request: dict[str, Any],
    *,
    create_perfetto_link: bool = False,
    create_perfetto_trace: bool = False,
) -> None:
    """Starts a profiler trace from a profile request dictionary.

    The request is serialized to JSON under a top-level ``"profileRequest"``
    key and handed to a ``PluginExecutable``; once that call completes, the
    original ``jax.profiler.start_trace`` is invoked with the same log
    directory.

    Args:
        profile_request: A dictionary containing the profile request options.
            Must contain a ``"traceLocation"`` entry.
        create_perfetto_link: A boolean which, if true, creates and prints
            link to the Perfetto trace viewer UI (https://ui.perfetto.dev).
            The program will block until the link is opened and Perfetto
            loads the trace. This feature is experimental for Pathways on
            Cloud and may not be fully supported.
        create_perfetto_trace: A boolean which, if true, additionally dumps a
            ``perfetto_trace.json.gz`` file that is compatible for upload
            with the Perfetto trace viewer UI (https://ui.perfetto.dev). The
            file will also be generated if ``create_perfetto_link`` is true.
            This could be useful if you want to generate a
            Perfetto-compatible trace without blocking the process. This
            feature is experimental for Pathways on Cloud and may not be
            fully supported.

    Raises:
        KeyError: If ``profile_request`` has no ``"traceLocation"`` entry.
        ValueError: If a trace is already being taken.
    """
    global _first_profile_start

    log_dir = profile_request["traceLocation"]

    with _profile_state.lock:
        if _first_profile_start:
            # Only the very first trace start runs the toy computation.
            _first_profile_start = False
            toy_computation()

        if _profile_state.executable is not None:
            raise ValueError(
                "start_trace called while a trace is already being taken!"
            )
        _profile_state.executable = plugin_executable.PluginExecutable(
            json.dumps({"profileRequest": profile_request})
        )
        try:
            _, result_future = _profile_state.executable.call()
            result_future.result()
        except Exception:  # pylint: disable=broad-except
            # Log with traceback, clear the half-initialised state so a later
            # start_trace can succeed, then let the caller see the error.
            _logger.exception("Failed to start trace")
            _profile_state.reset()
            raise

    # NOTE(review): indentation was lost in the extracted source; the local
    # JAX trace is assumed to start after the lock is released -- confirm.
    _original_start_trace(
        log_dir=log_dir,
        create_perfetto_link=create_perfetto_link,
        create_perfetto_trace=create_perfetto_trace,
    )
120+
121+
122+ def start_trace (
123+ log_dir : os .PathLike [str ] | str ,
124+ * ,
125+ create_perfetto_link : bool = False ,
126+ create_perfetto_trace : bool = False ,
127+ ) -> None :
128+ """Starts a profiler trace.
80129
130+ The trace will capture CPU and TPU activity, including Python
131+ functions and JAX on-device operations. Use :func:`stop_trace` to end the
132+ trace and save the results to ``log_dir``.
81133
82- start_trace ._first_profile_start = True # pylint: disable=protected-access
134+ The resulting trace can be viewed with TensorBoard. Note that TensorBoard
135+ doesn't need to be running when collecting the trace.
136+
137+ Only one trace may be collected at a time. A RuntimeError will be raised if
138+ :func:`start_trace` is called while another trace is running.
139+
140+ Args:
141+ log_dir: The GCS directory to save the profiler trace to (usually the
142+ TensorBoard log directory), e.g., "gs://my_bucket/profiles".
143+ create_perfetto_link: A boolean which, if true, creates and prints link to
144+ the Perfetto trace viewer UI (https://ui.perfetto.dev). The program will
145+ block until the link is opened and Perfetto loads the trace. This feature
146+ is experimental for Pathways on Cloud and may not be fully supported.
147+ create_perfetto_trace: A boolean which, if true, additionally dumps a
148+ ``perfetto_trace.json.gz`` file that is compatible for upload with the
149+ Perfetto trace viewer UI (https://ui.perfetto.dev). The file will also be
150+ generated if ``create_perfetto_link`` is true. This could be useful if you
151+ want to generate a Perfetto-compatible trace without blocking the process.
152+ This feature is experimental for Pathways on Cloud and may not be fully
153+ supported.
154+ """
155+ if not str (log_dir ).startswith ("gs://" ):
156+ raise ValueError (f"log_dir must be a GCS bucket path, got { log_dir } " )
157+
158+ if create_perfetto_link or create_perfetto_trace :
159+ _logger .warning (
160+ "create_perfetto_link and create_perfetto_trace are experimental "
161+ "features for Pathways on Cloud and may not be fully supported."
162+ )
163+
164+ _start_trace_from_profile_request (
165+ _create_profile_request (log_dir ),
166+ create_perfetto_link = create_perfetto_link ,
167+ create_perfetto_trace = create_perfetto_trace ,
168+ )
83169
84170
85171def stop_trace ():
@@ -88,11 +174,10 @@ def stop_trace():
88174 if _profile_state .executable is None :
89175 raise ValueError ("stop_trace called before a trace is being taken!" )
90176 try :
91- _profile_state .executable .call ()[1 ].result ()
92- except :
177+ _ , result_future = _profile_state .executable .call ()
178+ result_future .result ()
179+ finally :
93180 _profile_state .reset ()
94- raise
95- _profile_state .reset ()
96181
97182 _original_stop_trace ()
98183
@@ -151,7 +236,7 @@ def collect_profile(
151236 port : int ,
152237 duration_ms : int ,
153238 host : str ,
154- log_dir : str ,
239+ log_dir : os . PathLike [ str ] | str ,
155240) -> bool :
156241 """Collects a JAX profile and saves it to the specified directory.
157242
@@ -167,19 +252,19 @@ def collect_profile(
167252 Raises:
168253 ValueError: If the log_dir is not a GCS path.
169254 """
170- if not log_dir .startswith ("gs://" ):
171- raise ValueError ("log_dir must be a GCS path. " )
255+ if not str ( log_dir ) .startswith ("gs://" ):
256+ raise ValueError (f "log_dir must be a GCS bucket path, got { log_dir } " )
172257
173- json = {
258+ request_json = {
174259 "duration_ms" : duration_ms ,
175260 "repository_path" : log_dir ,
176261 }
177262 address = urllib .parse .urljoin (f"http://{ host } :{ port } " , "profiling" )
178263 try :
179- response = requests .post (address , json = json )
264+ response = requests .post (address , json = request_json )
180265 response .raise_for_status ()
181- except requests .exceptions .RequestException as e :
182- _logger .error ("Failed to collect profiling data: %s" , e )
266+ except requests .exceptions .RequestException :
267+ _logger .exception ("Failed to collect profiling data" )
183268 return False
184269
185270 return True
0 commit comments