Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion flashinfer/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1006,7 +1006,7 @@ def bench_gpu_time_with_cupti(
# Use 2x L2 size to ensure complete flush
_l2_flush_size_mb = (l2_size * 2) // (1024 * 1024)

# check if CUPTI is installed and its version is >= 13.0.0
# check if CUPTI is installed, version >= 13, and the driver supports it
try:
from cupti import cupti
from importlib.metadata import version as importlib_metadata_version
Expand All @@ -1016,6 +1016,9 @@ def bench_gpu_time_with_cupti(
raise Exception(
"CUPTI needs to be >= 13.0.0. Try 'pip install -U cupti-python'."
)
# Probe driver support (raises NotSupportedError on CUDA < 13.0 drivers)
cupti.activity_enable(cupti.ActivityKind.RUNTIME)
cupti.activity_disable(cupti.ActivityKind.RUNTIME)
from functools import partial
except (ModuleNotFoundError, Exception) as e:
if isinstance(e, ModuleNotFoundError):
Expand Down
63 changes: 63 additions & 0 deletions tests/utils/test_cupti_fallback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""Test that bench_gpu_time falls back gracefully to CUDA events when CUPTI is
unavailable or the driver does not support it (e.g. CUDA driver < 13.0)."""

import warnings
from unittest.mock import patch, MagicMock

import pytest
import torch

from flashinfer.testing import bench_gpu_time


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
def test_cupti_fallback_on_activity_enable_error():
    """When cupti.activity_enable raises (old driver), bench_gpu_time should
    fall back to CUDA events instead of crashing.

    A fake ``cupti`` module whose ``activity_enable`` always raises is routed
    in through a patched ``builtins.__import__``, and
    ``importlib.metadata.version`` is patched to report ``13.0.0`` so the
    cupti-python version check does not reject the fake first. The test then
    checks that:

    * the fake ``activity_enable`` probe was actually invoked, so the
      fallback is attributable to the probe and not to some other failure;
    * one positive timing is returned per requested iteration;
    * a ``UserWarning`` containing "Falling back" was emitted.
    """
    a = torch.randn(64, 64, device="cuda")
    b = torch.randn(64, 64, device="cuda")

    # Build a fake cupti module whose activity_enable always raises
    fake_cupti = MagicMock()
    fake_cupti.activity_enable.side_effect = Exception(
        "CUPTI_ERROR_NOT_SUPPORTED: driver too old"
    )
    # `from cupti import cupti` resolves the inner name as an attribute of
    # whatever object __import__ returns, so expose the fake there.
    fake_module = MagicMock()
    fake_module.cupti = fake_cupti

    # Capture the genuine importer before patching so every other import
    # issued while the patch is active still works.
    real_import = __import__

    def _patched_import(name, *args, **kwargs):
        """Route 'cupti' imports to the fake module."""
        if name == "cupti":
            return fake_module
        return real_import(name, *args, **kwargs)

    # Also patch importlib.metadata.version to report a new-enough cupti-python
    with (
        patch("importlib.metadata.version", return_value="13.0.0"),
        patch("builtins.__import__", side_effect=_patched_import),
        warnings.catch_warnings(record=True) as caught,
    ):
        # Record every warning, even ones an earlier filter would suppress.
        warnings.simplefilter("always")
        times = bench_gpu_time(
            fn=torch.matmul,
            input_args=(a, b),
            repeat_iters=5,
            dry_run_iters=2,
            cold_l2_cache=False,
            enable_cupti=True,
        )

    # The probe must have been reached; otherwise a fallback triggered by
    # something else would let this test pass without covering the probe.
    assert fake_cupti.activity_enable.called, (
        "cupti.activity_enable was never called; the driver-support probe "
        "was not exercised"
    )

    # Should have fallen back successfully
    assert isinstance(times, list)
    assert len(times) == 5
    assert all(t > 0 for t in times)

    # Should have emitted a fallback warning
    fallback_warnings = [
        w
        for w in caught
        if issubclass(w.category, UserWarning) and "Falling back" in str(w.message)
    ]
    assert len(fallback_warnings) >= 1