diff --git a/flashinfer/testing/utils.py b/flashinfer/testing/utils.py index b4175e097c..cffc404613 100644 --- a/flashinfer/testing/utils.py +++ b/flashinfer/testing/utils.py @@ -1006,7 +1006,7 @@ def bench_gpu_time_with_cupti( # Use 2x L2 size to ensure complete flush _l2_flush_size_mb = (l2_size * 2) // (1024 * 1024) - # check if CUPTI is installed and its version is >= 13.0.0 + # check if CUPTI is installed, version >= 13, and the driver supports it try: from cupti import cupti from importlib.metadata import version as importlib_metadata_version @@ -1016,6 +1016,9 @@ def bench_gpu_time_with_cupti( raise Exception( "CUPTI needs to be >= 13.0.0. Try 'pip install -U cupti-python'." ) + # Probe driver support (raises NotSupportedError on CUDA < 13.0 drivers) + cupti.activity_enable(cupti.ActivityKind.RUNTIME) + cupti.activity_disable(cupti.ActivityKind.RUNTIME) from functools import partial except (ModuleNotFoundError, Exception) as e: if isinstance(e, ModuleNotFoundError): diff --git a/tests/utils/test_cupti_fallback.py b/tests/utils/test_cupti_fallback.py new file mode 100644 index 0000000000..6faeca62bc --- /dev/null +++ b/tests/utils/test_cupti_fallback.py @@ -0,0 +1,63 @@ +"""Test that bench_gpu_time falls back gracefully to CUDA events when CUPTI is +unavailable or the driver does not support it (e.g. CUDA driver < 13.0).""" + +import warnings +from unittest.mock import patch, MagicMock + +import pytest +import torch + +from flashinfer.testing import bench_gpu_time + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +def test_cupti_fallback_on_activity_enable_error(): + """When cupti.activity_enable raises (old driver), bench_gpu_time should + fall back to CUDA events instead of crashing.""" + a = torch.randn(64, 64, device="cuda") + b = torch.randn(64, 64, device="cuda") + + # Build a fake cupti module whose activity_enable always raises + fake_cupti = MagicMock() + fake_cupti.activity_enable.side_effect = Exception( + "CUPTI_ERROR_NOT_SUPPORTED: driver too old" + ) + fake_module = MagicMock() + fake_module.cupti = fake_cupti + + real_import = __import__ + + def _patched_import(name, *args, **kwargs): + """Route 'cupti' imports to the fake module.""" + if name == "cupti": + return fake_module + return real_import(name, *args, **kwargs) + + # Also patch importlib.metadata.version to report a new-enough cupti-python + with ( + patch("importlib.metadata.version", return_value="13.0.0"), + patch("builtins.__import__", side_effect=_patched_import), + warnings.catch_warnings(record=True) as caught, + ): + warnings.simplefilter("always") + times = bench_gpu_time( + fn=torch.matmul, + input_args=(a, b), + repeat_iters=5, + dry_run_iters=2, + cold_l2_cache=False, + enable_cupti=True, + ) + + # Should have fallen back successfully + assert isinstance(times, list) + assert len(times) == 5 + assert all(t > 0 for t in times) + + # Should have emitted a fallback warning + fallback_warnings = [ + w + for w in caught + if issubclass(w.category, UserWarning) and "Falling back" in str(w.message) + ] + assert len(fallback_warnings) >= 1