Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion flashinfer/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1006,7 +1006,7 @@ def bench_gpu_time_with_cupti(
# Use 2x L2 size to ensure complete flush
_l2_flush_size_mb = (l2_size * 2) // (1024 * 1024)

# check if CUPTI is installed and its version is >= 13.0.0
# check if CUPTI is installed, version >= 13, and the driver supports it
try:
from cupti import cupti
from importlib.metadata import version as importlib_metadata_version
Expand All @@ -1016,6 +1016,9 @@ def bench_gpu_time_with_cupti(
raise Exception(
"CUPTI needs to be >= 13.0.0. Try 'pip install -U cupti-python'."
)
# Probe driver support (raises NotSupportedError on CUDA < 13.0 drivers)
cupti.activity_enable(cupti.ActivityKind.RUNTIME)
cupti.activity_disable(cupti.ActivityKind.RUNTIME)
from functools import partial
except (ModuleNotFoundError, Exception) as e:
if isinstance(e, ModuleNotFoundError):
Expand Down
65 changes: 65 additions & 0 deletions tests/utils/test_cupti_fallback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""Test that bench_gpu_time falls back gracefully to CUDA events when CUPTI is
unavailable or the driver does not support it (e.g. CUDA driver < 13.0)."""

import warnings
from unittest.mock import patch, MagicMock

import pytest
import torch

from flashinfer.testing import bench_gpu_time


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
def test_cupti_fallback_on_activity_enable_error():
"""When cupti.activity_enable raises (old driver), bench_gpu_time should
fall back to CUDA events instead of crashing."""
a = torch.randn(64, 64, device="cuda")
b = torch.randn(64, 64, device="cuda")

# Build a fake cupti module whose activity_enable always raises
fake_cupti = MagicMock()
fake_cupti.activity_enable.side_effect = Exception(
"CUPTI_ERROR_NOT_SUPPORTED: driver too old"
)
fake_module = MagicMock()
fake_module.cupti = fake_cupti

real_import = (
__builtins__.__import__ if hasattr(__builtins__, "__import__") else __import__
)

def _patched_import(name, *args, **kwargs):
"""Route 'cupti' imports to the fake module."""
if name == "cupti":
return fake_module
return real_import(name, *args, **kwargs)

# Also patch importlib.metadata.version to report a new-enough cupti-python
with (
patch("importlib.metadata.version", return_value="13.0.0"),
patch("builtins.__import__", side_effect=_patched_import),
warnings.catch_warnings(record=True) as caught,
):
warnings.simplefilter("always")
times = bench_gpu_time(
fn=torch.matmul,
input_args=(a, b),
repeat_iters=5,
dry_run_iters=2,
cold_l2_cache=False,
enable_cupti=True,
)

# Should have fallen back successfully
assert isinstance(times, list)
assert len(times) == 5
assert all(t > 0 for t in times)

# Should have emitted a fallback warning
fallback_warnings = [
w
for w in caught
if issubclass(w.category, UserWarning) and "Falling back" in str(w.message)
]
assert len(fallback_warnings) >= 1