124 changes: 66 additions & 58 deletions src/forge/observability/perf_tracker.py
@@ -3,12 +3,12 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import inspect
import logging
import os
import threading
import time

from concurrent.futures import Future, ThreadPoolExecutor
from functools import lru_cache, wraps
from typing import Protocol
@@ -18,6 +18,8 @@
from forge.env_constants import DISABLE_PERF_METRICS, METRIC_TIMER_USES_GPU
from forge.observability.metrics import record_metric, Reduce

logger = logging.getLogger(__name__)

# Thread-local memory tracking state
_local = threading.local()

@@ -44,7 +46,6 @@ def _warn_nested_memory_tracking(prefix: str) -> None:

"""


class Tracer:
"""
@@ -150,10 +151,9 @@ def stop(self) -> None:
if not self._active:
raise ValueError("Tracer must be started before calling stop")

# Stop timing (always enabled)
# step("end") is dropped from steps, but included in total sum
self._timer.step("end") # pyre-ignore
self._record_timing_metrics()
# Stop timing
durations, stop_step_ms = self._timer.get_all_durations() # pyre-ignore
self._record_timing_metrics(durations, stop_step_ms)
self._timer = None

# Stop memory tracking
@@ -193,17 +193,15 @@ def _stop_memory_tracking(self) -> None:
torch.cuda.reset_max_memory_allocated()
self._memory_started = False

def _record_timing_metrics(self) -> None:
durations = self._timer.get_all_durations() # pyre-ignore

# Total: sum all recorded durations (full timeline including end)
total_ms = sum(d_ms for name, d_ms in durations)
def _record_timing_metrics(
self, durations: list[tuple[str, float]], stop_step_ms: float
) -> None:
total_ms = sum(d_ms for _, d_ms in durations) + stop_step_ms
total_s = total_ms / 1000.0
record_metric(f"{self.prefix}/total_duration_avg_s", total_s, Reduce.MEAN)
record_metric(f"{self.prefix}/total_duration_max_s", total_s, Reduce.MAX)

# Steps: record each individually (drop last "end")
for name, d_ms in durations[:-1]:
for name, d_ms in durations:
d_s = d_ms / 1000.0
record_metric(f"{self.prefix}/{name}/duration_avg_s", d_s, Reduce.MEAN)
record_metric(f"{self.prefix}/{name}/duration_max_s", d_s, Reduce.MAX)
@@ -216,7 +214,7 @@ def start(self) -> None:
def step(self, name: str) -> None:
...

def get_all_durations(self) -> list[tuple[str, float]]:
def get_all_durations(self) -> tuple[list[tuple[str, float]], float]:
...


@@ -242,13 +240,27 @@ def step(self, name: str) -> None:
self._durations.append((name, delta_ms))
self._chain_start = now

def get_all_durations(self) -> list[tuple[str, float]]:
return self._durations[:]
def get_all_durations(self) -> tuple[list[tuple[str, float]], float]:
"""Retrieve list of (step_name, duration) tuples and last step duration
between tracer.stop and the last step (or start if none)."""
stop_step_ms = 0.0
if self._chain_start is not None:
now = time.perf_counter()
stop_step_ms = (now - self._chain_start) * 1000
return self._durations[:], stop_step_ms
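A minimal usage sketch of the CPU backend under the new two-value contract. The import path is taken from this diff, the no-argument constructor is assumed (as in the unit tests further down), and the sleep durations are arbitrary:

```python
import time
from forge.observability.perf_tracker import _TimerCPU  # module path per this diff

timer = _TimerCPU()
timer.start()
time.sleep(0.01)
timer.step("load")          # ~10 ms recorded for "load"
time.sleep(0.02)
durations, stop_step_ms = timer.get_all_durations()
print(durations)            # [("load", ~10.0)]
print(stop_step_ms)         # ~20.0 -- segment between the last step and this call
```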


class _TimerCUDA(_TimerProtocol):
"""CUDA timing backend with non-blocking events and futures.
Uses a thread pool to poll CUDA events asynchronously without blocking the main thread.

Example:
timer = _TimerCUDA()
timer.start()
# torch.mm(a, b) # ~100ms GPU
timer.step("matmul")
# torch.mm(c, d) # ~200ms
durs_steps, stop_step_ms = timer.get_all_durations()  # ([("matmul", 100)], 200)
"""

def __init__(self, max_workers: int = 2) -> None:
@@ -277,74 +289,70 @@ def step(self, name: str) -> None:
Args:
name: Label for this segment's duration
"""
# Submit polling future; chain to next event.
if self._chain_start is None:
raise ValueError("Timer must be started before calling step")

stream = torch.cuda.current_stream()
end_event = torch.cuda.Event(enable_timing=True)
end_event.record(stream)

def _compute_elapsed(start_event, end_event):
# Poll with backoff: starts fast (1ms), grows to cap (50ms) for mixed workloads.
sleep_time = 0.001 # Start at 1ms
while not end_event.query():
time.sleep(sleep_time)
sleep_time = min(sleep_time * 1.5, 0.05) # Backoff, cap at 50ms
return start_event.elapsed_time(end_event)

future = self._executor.submit(_compute_elapsed, self._chain_start, end_event)
future = self._executor.submit(self._poll_elapsed, self._chain_start, end_event)
index = len(self._futures)
self._futures.append((name, future, index))

if len(self._futures) >= 5: # clean up every 5
self._collect_completed_futures()

self._chain_start = end_event
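step() records a CUDA event on the current stream and hands the (previous event, new event) pair to a worker that polls until the new event completes. A minimal sketch of the underlying event-timing pattern, blocking with synchronize() instead of polling and requiring a CUDA device:

```python
import torch

# Sketch of CUDA event timing (assumes a CUDA device is available).
if torch.cuda.is_available():
    stream = torch.cuda.current_stream()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record(stream)
    torch.mm(torch.randn(1024, 1024, device="cuda"),
             torch.randn(1024, 1024, device="cuda"))
    end.record(stream)
    end.synchronize()                      # the code above polls instead of blocking
    print(start.elapsed_time(end), "ms")   # elapsed_time returns milliseconds
```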

def _collect_completed_futures(self) -> None:
def _poll_elapsed(
self, start_event: torch.cuda.Event, end_event: torch.cuda.Event
) -> float:
"""Compute elapsed time after polling with backoff."""
# Poll until ready
sleep_time = 0.001 # Start at 1ms
while not end_event.query():
time.sleep(sleep_time)
sleep_time = min(sleep_time * 1.5, 0.05) # Backoff, cap at 50ms
return start_event.elapsed_time(end_event)
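For reference, the backoff schedule above grows geometrically from 1 ms and saturates at 50 ms after roughly ten polls; a tiny standalone illustration:

```python
# Illustration of the polling backoff schedule used in _poll_elapsed.
sleep, schedule = 0.001, []
for _ in range(12):
    schedule.append(sleep * 1000)  # ms
    sleep = min(sleep * 1.5, 0.05)
print(schedule)  # ~[1.0, 1.5, 2.25, 3.4, 5.1, 7.6, 11.4, 17.1, 25.6, 38.4, 50.0, 50.0]
```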

def _collect_completed_futures(self, wait_till_done: bool = False) -> None:
"""Drain done futures to avoid memory leak; update durations in submission order."""
completed = []
still_pending = []
for name, future, idx in self._futures:
if future.done():
try:
dur = future.result()
completed.append((idx, name, dur))
except Exception as e:
raise RuntimeError(f"Timing failed for {name}: {e}") from e
if future.done() or wait_till_done:
dur = future.result()
self._durations.append((name, dur))
else:
still_pending.append((name, future, idx))

# Sort completed by submission index to preserve order
completed.sort(key=lambda x: x[0])
for _, name, dur in completed:
self._durations.append((name, dur))

self._futures = still_pending
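Because the drain appends whichever futures happen to be done at collection time, the resulting order can differ from submission order, which is why the docstring below makes no ordering guarantee. A small standalone illustration (not forge code) of a later-submitted task finishing first:

```python
import time
from concurrent.futures import ThreadPoolExecutor

executor = ThreadPoolExecutor(max_workers=2)
futures = [("slow", executor.submit(lambda: time.sleep(0.2) or "slow")),
           ("fast", executor.submit(lambda: time.sleep(0.05) or "fast"))]
time.sleep(0.1)
done_first = [name for name, f in futures if f.done()]
print(done_first)  # ['fast'] -- the second submission completed first
executor.shutdown(wait=True)
```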

def get_all_durations(self) -> list[tuple[str, float]]:
"""Retrieve list of (name, duration) tuples in submission order after waiting for background polls to finish."""
# Wait and collect if pendings; return durations.
self._collect_completed_futures()
completed = []
for name, future, idx in self._futures:
try:
dur = future.result()
completed.append((idx, name, dur))
except Exception as e:
raise RuntimeError(f"Timing failed for {name}: {e}") from e

# Sort by submission index to preserve order
completed.sort(key=lambda x: x[0])
for _, name, dur in completed:
self._durations.append((name, dur))
def get_all_durations(self) -> tuple[list[tuple[str, float]], float]:
"""Retrieve list of (step_name, duration) tuples and last step duration
between tracer.stop and the last step (or start if none). Order of tuples is random.
"""
# Final timing since last step (or start) until this function is called
stop_step = f"_stop_step_{id(self)}"
self.step(stop_step)

# Wait on remaining futures
self._collect_completed_futures(wait_till_done=True)
self._futures.clear()
return self._durations[:]

# Extract stop_step_ms
stop_step_ms = 0.0
durations = [
(name, duration) for name, duration in self._durations if name != stop_step
]
for name, duration in self._durations:
if name == stop_step:
stop_step_ms = duration
break

return durations, stop_step_ms
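The sentinel named f"_stop_step_{id(self)}" is just an ordinary step whose duration is split back out after all futures are drained. A sketch of that filtering idea on plain data (the sentinel name and numbers below are made up):

```python
# Separate the sentinel "stop" entry from the real steps.
collected = [("matmul", 100.0), ("copy", 20.0), ("_stop_step_1234", 5.0)]
sentinel = "_stop_step_1234"
steps = [(n, d) for n, d in collected if n != sentinel]
stop_step_ms = next((d for n, d in collected if n == sentinel), 0.0)
print(steps, stop_step_ms)  # [('matmul', 100.0), ('copy', 20.0)] 5.0
```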

def __del__(self) -> None:
# Fallback cleanup in finalizer; ignores errors to avoid shutdown noise.
# Fallback cleanup in finalizer
try:
self._executor.shutdown(wait=True)
except Exception:
26 changes: 18 additions & 8 deletions tests/unit_tests/observability/test_perf_tracker.py
@@ -309,29 +309,39 @@ def test_tracer_and_timer_reuse(self, mock_record_metric_calls):
cpu_timer.start()
time.sleep(0.005)
cpu_timer.step("cpu_step1")
durations1 = cpu_timer.get_all_durations()
cpu_durations_list1, cpu_final_ms1 = cpu_timer.get_all_durations()

cpu_timer.start()
time.sleep(0.005)
cpu_timer.step("cpu_step2")
durations2 = cpu_timer.get_all_durations()
cpu_durations_list2, cpu_final_ms2 = cpu_timer.get_all_durations()

assert len(durations1) == 1 and durations1[0][0] == "cpu_step1"
assert len(durations2) == 1 and durations2[0][0] == "cpu_step2"
assert (
len(cpu_durations_list1) == 1 and cpu_durations_list1[0][0] == "cpu_step1"
)
assert (
len(cpu_durations_list2) == 1 and cpu_durations_list2[0][0] == "cpu_step2"
)

# Test CUDA timer reuse (if available)
if torch.cuda.is_available():
cuda_timer = _TimerCUDA()
cuda_timer.start()
cuda_timer.step("cuda_step1")
cuda_durations1 = cuda_timer.get_all_durations()
cuda_durations_list1, cuda_final_ms1 = cuda_timer.get_all_durations()

cuda_timer.start()
cuda_timer.step("cuda_step2")
cuda_durations2 = cuda_timer.get_all_durations()
cuda_durations_list2, cuda_final_ms2 = cuda_timer.get_all_durations()

assert len(cuda_durations1) == 1 and cuda_durations1[0][0] == "cuda_step1"
assert len(cuda_durations2) == 1 and cuda_durations2[0][0] == "cuda_step2"
assert (
len(cuda_durations_list1) == 1
and cuda_durations_list1[0][0] == "cuda_step1"
)
assert (
len(cuda_durations_list2) == 1
and cuda_durations_list2[0][0] == "cuda_step2"
)

def test_exception_handling_context_manager(self, mock_record_metric_calls):
"""Test context manager properly cleans up on exception."""