1414"""Profiling utilites."""
1515
1616import dataclasses
17+ import json
1718import logging
1819import os
19- import pathlib
20- import tempfile
2120import threading
2221import time
22+ from typing import Any
2323import urllib .parse
2424
2525import fastapi
@@ -45,6 +45,7 @@ def reset(self):
4545 self .executable = None
4646
4747
48+ _first_profile_start = True
4849_profile_state = _ProfileState ()
4950_original_start_trace = jax .profiler .start_trace
5051_original_stop_trace = jax .profiler .stop_trace
@@ -56,30 +57,108 @@ def toy_computation():
5657 x .block_until_ready ()
5758
5859
59- def start_trace (gcs_bucket : str ):
60- """Starts a profiler trace."""
60+ def _create_profile_request (
61+ log_dir : os .PathLike [str ] | str ,
62+ profiler_options : jax .profiler .ProfileOptions | None ,
63+ ) -> dict [str , Any ]:
64+ """Creates a profile request dictionary from the given options."""
65+ profile_request = {}
66+ profile_request ["traceLocation" ] = str (log_dir )
67+
68+ if profiler_options is not None :
69+ xprof_trace_options = {}
70+ if profiler_options .host_tracer_level is not None :
71+ xprof_trace_options ["hostTraceLevel" ] = profiler_options .host_tracer_level
72+ trace_options = {}
73+ if profiler_options .start_timestamp_ns :
74+ trace_options ["profilingStartTimeNs" ] = (
75+ profiler_options .start_timestamp_ns
76+ )
77+ if profiler_options .duration_ms :
78+ trace_options ["profilingDurationMs" ] = profiler_options .duration_ms
79+
80+ if trace_options :
81+ xprof_trace_options ["traceOptions" ] = trace_options
82+ if xprof_trace_options :
83+ profile_request ["xprofTraceOptions" ] = xprof_trace_options
84+ return profile_request
85+
86+
87+ def start_trace (
88+ log_dir : os .PathLike [str ] | str ,
89+ create_perfetto_link : bool = False ,
90+ create_perfetto_trace : bool = False ,
91+ profiler_options : jax .profiler .ProfileOptions | None = None ,
92+ ):
93+ """Starts a profiler trace.
94+
95+ The trace will capture CPU and TPU activity, including Python
96+ functions and JAX on-device operations. Use :func:`stop_trace` to end the
97+ trace and save the results to ``log_dir``.
98+
99+ The resulting trace can be viewed with TensorBoard. Note that TensorBoard
100+ doesn't need to be running when collecting the trace.
101+
102+ Only one trace may be collected at a time. A RuntimeError will be raised if
103+ :func:`start_trace` is called while another trace is running.
104+
105+ Args:
106+ log_dir: The GCS directory to save the profiler trace to (usually the
107+ TensorBoard log directory), e.g., "gs://my_bucket/profiles".
108+ create_perfetto_link: A boolean which, if true, creates and prints link to
109+ the Perfetto trace viewer UI (https://ui.perfetto.dev). The program will
110+ block until the link is opened and Perfetto loads the trace. This feature
111+ is experimental for Pathways on Cloud and may not be fully supported.
112+ create_perfetto_trace: A boolean which, if true, additionally dumps a
113+ ``perfetto_trace.json.gz`` file that is compatible for upload with the
114+ Perfetto trace viewer UI (https://ui.perfetto.dev). The file will also be
115+ generated if ``create_perfetto_link`` is true. This could be useful if you
116+ want to generate a Perfetto-compatible trace without blocking the process.
117+ This feature is experimental for Pathways on Cloud and may not be fully
118+ supported.
119+ profiler_options: Profiler options to configure the profiler for collection.
120+ Passing a mappable object is experimental for Pathways on Cloud and may
121+ not be fully supported.
122+ """
123+ if log_dir is None :
124+ raise ValueError ("log_dir cannot be None." )
125+ if not str (log_dir ).startswith ("gs://" ):
126+ raise ValueError (f"log_dir must be a GCS bucket path, got { log_dir } " )
127+
128+ if create_perfetto_link or create_perfetto_trace :
129+ _logger .warning (
130+ "create_perfetto_link and create_perfetto_trace are experimental "
131+ "features for Pathways on Cloud and may not be fully supported."
132+ )
133+
134+ profile_request = _create_profile_request (log_dir , profiler_options )
135+
61136 with _profile_state .lock :
62- if start_trace ._first_profile_start : # pylint: disable=protected-access, attribute-error
63- start_trace ._first_profile_start = False # pylint: disable=protected-access
137+ global _first_profile_start
138+ if _first_profile_start :
139+ _first_profile_start = False
64140 toy_computation ()
65141
66142 if _profile_state .executable is not None :
67143 raise ValueError (
68144 "start_trace called while a trace is already being taken!"
69145 )
70146 _profile_state .executable = plugin_executable .PluginExecutable (
71- f"{{ profileRequest: {{traceLocation: ' { gcs_bucket } '}}}}"
147+ json . dumps ({ " profileRequest" : profile_request })
72148 )
73149 try :
74150 _profile_state .executable .call ()[1 ].result ()
75- except :
151+ except Exception as e : # pylint: disable=broad-except
152+ _logger .exception ("Failed to start trace" )
76153 _profile_state .reset ()
77154 raise
78155
79- _original_start_trace (gcs_bucket )
80-
81-
82- start_trace ._first_profile_start = True # pylint: disable=protected-access
156+ _original_start_trace (
157+ log_dir = log_dir ,
158+ create_perfetto_link = create_perfetto_link ,
159+ create_perfetto_trace = create_perfetto_trace ,
160+ profiler_options = profiler_options ,
161+ )
83162
84163
85164def stop_trace ():
@@ -89,10 +168,8 @@ def stop_trace():
89168 raise ValueError ("stop_trace called before a trace is being taken!" )
90169 try :
91170 _profile_state .executable .call ()[1 ].result ()
92- except :
171+ finally :
93172 _profile_state .reset ()
94- raise
95- _profile_state .reset ()
96173
97174 _original_stop_trace ()
98175
@@ -170,16 +247,16 @@ def collect_profile(
170247 if not log_dir .startswith ("gs://" ):
171248 raise ValueError ("log_dir must be a GCS path." )
172249
173- json = {
250+ request_json = {
174251 "duration_ms" : duration_ms ,
175252 "repository_path" : log_dir ,
176253 }
177254 address = urllib .parse .urljoin (f"http://{ host } :{ port } " , "profiling" )
178255 try :
179- response = requests .post (address , json = json )
256+ response = requests .post (address , json = request_json )
180257 response .raise_for_status ()
181- except requests .exceptions .RequestException as e :
182- _logger .error ("Failed to collect profiling data: %s" , e )
258+ except requests .exceptions .RequestException :
259+ _logger .exception ("Failed to collect profiling data" )
183260 return False
184261
185262 return True
0 commit comments