Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 107 additions & 22 deletions pathwaysutils/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
"""Profiling utilites."""

import dataclasses
import json
import logging
import os
import pathlib
import tempfile
import threading
import time
from typing import Any
import urllib.parse

import fastapi
Expand All @@ -45,6 +45,7 @@ def reset(self):
self.executable = None


# True until the first trace is started. The trace-start code reads and clears
# this flag so that toy_computation() runs only once, before the first trace.
_first_profile_start = True
# Shared profiling state (lock and current executable) used by the start/stop
# wrappers in this module.
_profile_state = _ProfileState()
# References to JAX's own profiler entry points, captured at import time.
# NOTE(review): presumably kept so the wrappers here can still reach the
# originals if jax.profiler.start_trace/stop_trace are replaced -- confirm.
_original_start_trace = jax.profiler.start_trace
_original_stop_trace = jax.profiler.stop_trace
Expand All @@ -56,30 +57,115 @@ def toy_computation():
x.block_until_ready()


def start_trace(gcs_bucket: str):
"""Starts a profiler trace."""
def _create_profile_request(
log_dir: os.PathLike[str] | str,
) -> dict[str, Any]:
"""Creates a profile request dictionary from the given options."""
profile_request = {}
profile_request["traceLocation"] = str(log_dir)

return profile_request


def _start_trace_from_profile_request(
    profile_request: dict[str, Any],
    *,
    create_perfetto_link: bool = False,
    create_perfetto_trace: bool = False,
) -> None:
  """Starts a profiler trace from a profile request dictionary.

  Args:
    profile_request: A dictionary containing the profile request options. It
      must contain a ``"traceLocation"`` entry, which is used as the log
      directory for the JAX trace.
    create_perfetto_link: A boolean which, if true, creates and prints link to
      the Perfetto trace viewer UI (https://ui.perfetto.dev). The program will
      block until the link is opened and Perfetto loads the trace. This feature
      is experimental for Pathways on Cloud and may not be fully supported.
    create_perfetto_trace: A boolean which, if true, additionally dumps a
      ``perfetto_trace.json.gz`` file that is compatible for upload with the
      Perfetto trace viewer UI (https://ui.perfetto.dev). The file will also be
      generated if ``create_perfetto_link`` is true. This could be useful if you
      want to generate a Perfetto-compatible trace without blocking the process.
      This feature is experimental for Pathways on Cloud and may not be fully
      supported.

  Raises:
    KeyError: If ``profile_request`` has no ``"traceLocation"`` entry.
    ValueError: If a trace is already being taken.
  """
  global _first_profile_start

  log_dir = profile_request["traceLocation"]

  with _profile_state.lock:
    # toy_computation() runs once, before the very first trace only.
    if _first_profile_start:
      _first_profile_start = False
      toy_computation()

    if _profile_state.executable is not None:
      raise ValueError(
          "start_trace called while a trace is already being taken!"
      )
    # The request is serialized as JSON and wrapped under "profileRequest".
    _profile_state.executable = plugin_executable.PluginExecutable(
        json.dumps({"profileRequest": profile_request})
    )
    try:
      _, result_future = _profile_state.executable.call()
      result_future.result()
    except Exception:  # pylint: disable=broad-except
      # Log, clear the half-initialized state so a later start can succeed,
      # then let the original error propagate.
      _logger.exception("Failed to start trace")
      _profile_state.reset()
      raise

  # NOTE(review): the original nesting of this call relative to the lock is
  # not recoverable from the reviewed text; it is placed after the locked
  # section here -- confirm against the repository.
  _original_start_trace(
      log_dir=log_dir,
      create_perfetto_link=create_perfetto_link,
      create_perfetto_trace=create_perfetto_trace,
  )


def start_trace(
log_dir: os.PathLike[str] | str,
*,
create_perfetto_link: bool = False,
create_perfetto_trace: bool = False,
) -> None:
"""Starts a profiler trace.

The trace will capture CPU and TPU activity, including Python
functions and JAX on-device operations. Use :func:`stop_trace` to end the
trace and save the results to ``log_dir``.

start_trace._first_profile_start = True # pylint: disable=protected-access
The resulting trace can be viewed with TensorBoard. Note that TensorBoard
doesn't need to be running when collecting the trace.

Only one trace may be collected at a time. A RuntimeError will be raised if
:func:`start_trace` is called while another trace is running.

Args:
log_dir: The GCS directory to save the profiler trace to (usually the
TensorBoard log directory), e.g., "gs://my_bucket/profiles".
create_perfetto_link: A boolean which, if true, creates and prints link to
the Perfetto trace viewer UI (https://ui.perfetto.dev). The program will
block until the link is opened and Perfetto loads the trace. This feature
is experimental for Pathways on Cloud and may not be fully supported.
create_perfetto_trace: A boolean which, if true, additionally dumps a
``perfetto_trace.json.gz`` file that is compatible for upload with the
Perfetto trace viewer UI (https://ui.perfetto.dev). The file will also be
generated if ``create_perfetto_link`` is true. This could be useful if you
want to generate a Perfetto-compatible trace without blocking the process.
This feature is experimental for Pathways on Cloud and may not be fully
supported.
"""
if not str(log_dir).startswith("gs://"):
raise ValueError(f"log_dir must be a GCS bucket path, got {log_dir}")

if create_perfetto_link or create_perfetto_trace:
_logger.warning(
"create_perfetto_link and create_perfetto_trace are experimental "
"features for Pathways on Cloud and may not be fully supported."
)

_start_trace_from_profile_request(
_create_profile_request(log_dir),
create_perfetto_link=create_perfetto_link,
create_perfetto_trace=create_perfetto_trace,
)


def stop_trace():
Expand All @@ -88,11 +174,10 @@ def stop_trace():
if _profile_state.executable is None:
raise ValueError("stop_trace called before a trace is being taken!")
try:
_profile_state.executable.call()[1].result()
except:
_, result_future = _profile_state.executable.call()
result_future.result()
finally:
_profile_state.reset()
raise
_profile_state.reset()

_original_stop_trace()

Expand Down Expand Up @@ -151,7 +236,7 @@ def collect_profile(
port: int,
duration_ms: int,
host: str,
log_dir: str,
log_dir: os.PathLike[str] | str,
) -> bool:
"""Collects a JAX profile and saves it to the specified directory.

Expand All @@ -167,19 +252,19 @@ def collect_profile(
Raises:
ValueError: If the log_dir is not a GCS path.
"""
if not log_dir.startswith("gs://"):
raise ValueError("log_dir must be a GCS path.")
if not str(log_dir).startswith("gs://"):
raise ValueError(f"log_dir must be a GCS bucket path, got {log_dir}")

json = {
request_json = {
"duration_ms": duration_ms,
"repository_path": log_dir,
}
address = urllib.parse.urljoin(f"http://{host}:{port}", "profiling")
try:
response = requests.post(address, json=json)
response = requests.post(address, json=request_json)
response.raise_for_status()
except requests.exceptions.RequestException as e:
_logger.error("Failed to collect profiling data: %s", e)
except requests.exceptions.RequestException:
_logger.exception("Failed to collect profiling data")
return False

return True
Expand Down
Loading