From 6c302aad9b9a46b76b525c09cf4ff3928e84225d Mon Sep 17 00:00:00 2001 From: sha7doww <2453900478@qq.com> Date: Thu, 19 Mar 2026 14:46:20 +0800 Subject: [PATCH 1/4] fix: fall back to CUDA events when CUPTI driver version < 13.0 Probe CUPTI activity tracing support inside the existing try block so that NotSupportedError on older drivers is caught by the existing fallback logic. Co-Authored-By: Claude Opus 4.6 (1M context) --- flashinfer/testing/utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/flashinfer/testing/utils.py b/flashinfer/testing/utils.py index b4175e097c..cffc404613 100644 --- a/flashinfer/testing/utils.py +++ b/flashinfer/testing/utils.py @@ -1006,7 +1006,7 @@ def bench_gpu_time_with_cupti( # Use 2x L2 size to ensure complete flush _l2_flush_size_mb = (l2_size * 2) // (1024 * 1024) - # check if CUPTI is installed and its version is >= 13.0.0 + # check if CUPTI is installed, version >= 13, and the driver supports it try: from cupti import cupti from importlib.metadata import version as importlib_metadata_version @@ -1016,6 +1016,9 @@ def bench_gpu_time_with_cupti( raise Exception( "CUPTI needs to be >= 13.0.0. Try 'pip install -U cupti-python'." ) + # Probe driver support (raises NotSupportedError on CUDA < 13.0 drivers) + cupti.activity_enable(cupti.ActivityKind.RUNTIME) + cupti.activity_disable(cupti.ActivityKind.RUNTIME) from functools import partial except (ModuleNotFoundError, Exception) as e: if isinstance(e, ModuleNotFoundError): From 29e0f268936cec5099e15b13d965244a6ca1ce8a Mon Sep 17 00:00:00 2001 From: sha7doww <2453900478@qq.com> Date: Thu, 19 Mar 2026 15:50:29 +0800 Subject: [PATCH 2/4] test: add regression test for CUPTI driver version fallback Verify that bench_gpu_time_with_cupti gracefully falls back to CUDA events when cupti.activity_enable raises (e.g. CUDA driver < 13.0). Co-Authored-By: Claude Opus 4.6 --- tests/utils/test_cupti_fallback.py | 56 ++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) create mode 100644 tests/utils/test_cupti_fallback.py diff --git a/tests/utils/test_cupti_fallback.py b/tests/utils/test_cupti_fallback.py new file mode 100644 index 0000000000..3d9aef6228 --- /dev/null +++ b/tests/utils/test_cupti_fallback.py @@ -0,0 +1,56 @@ +"""Test that bench_gpu_time_with_cupti falls back gracefully when CUPTI is +unavailable or the driver does not support it (e.g. CUDA driver < 13.0).""" + +import warnings +from unittest.mock import patch, MagicMock + +import pytest +import torch + +from flashinfer.testing import bench_gpu_time_with_cupti + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +def test_cupti_fallback_on_activity_enable_error(): + """When cupti.activity_enable raises (old driver), fall back to CUDA events.""" + a = torch.randn(64, 64, device="cuda") + b = torch.randn(64, 64, device="cuda") + + # Build a fake cupti module whose activity_enable always raises + fake_cupti = MagicMock() + fake_cupti.activity_enable.side_effect = Exception( + "CUPTI_ERROR_NOT_SUPPORTED: driver too old" + ) + fake_module = MagicMock() + fake_module.cupti = fake_cupti + + real_import = __builtins__.__import__ if hasattr(__builtins__, "__import__") else __import__ + + def _patched_import(name, *args, **kwargs): + if name == "cupti": + return fake_module + return real_import(name, *args, **kwargs) + + # Also patch importlib.metadata.version to report a new-enough cupti-python + with patch("importlib.metadata.version", return_value="13.0.0"), \ + patch("builtins.__import__", side_effect=_patched_import): + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + times = bench_gpu_time_with_cupti( + fn=torch.matmul, + input_args=(a, b), + repeat_iters=5, + dry_run_iters=2, + cold_l2_cache=False, + ) + + # Should have fallen back successfully + assert isinstance(times, list) + assert len(times) == 5 + assert all(t > 0 for t in times) + + # Should have emitted a fallback warning + fallback_warnings = [ + w for w in caught if issubclass(w.category, UserWarning) and "Falling back" in str(w.message) + ] + assert len(fallback_warnings) >= 1 From 2cd7e5db3b70d924227733f1ae3fcf7b236bd0e1 Mon Sep 17 00:00:00 2001 From: sha7doww <2453900478@qq.com> Date: Thu, 19 Mar 2026 16:02:41 +0800 Subject: [PATCH 3/4] style: address review feedback in CUPTI fallback test - Use bench_gpu_time(enable_cupti=True) public API instead of bench_gpu_time_with_cupti directly (CodeRabbit suggestion) - Combine nested with statements (ruff SIM117) - Add docstrings to inner helper Co-Authored-By: Claude Opus 4.6 --- tests/utils/test_cupti_fallback.py | 41 ++++++++++++++++++------------ 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/tests/utils/test_cupti_fallback.py b/tests/utils/test_cupti_fallback.py index 3d9aef6228..13e7031381 100644 --- a/tests/utils/test_cupti_fallback.py +++ b/tests/utils/test_cupti_fallback.py @@ -1,4 +1,4 @@ -"""Test that bench_gpu_time_with_cupti falls back gracefully when CUPTI is +"""Test that bench_gpu_time falls back gracefully to CUDA events when CUPTI is unavailable or the driver does not support it (e.g. CUDA driver < 13.0).""" import warnings @@ -7,12 +7,13 @@ import pytest import torch -from flashinfer.testing import bench_gpu_time_with_cupti +from flashinfer.testing import bench_gpu_time @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") def test_cupti_fallback_on_activity_enable_error(): - """When cupti.activity_enable raises (old driver), fall back to CUDA events.""" + """When cupti.activity_enable raises (old driver), bench_gpu_time should + fall back to CUDA events instead of crashing.""" a = torch.randn(64, 64, device="cuda") b = torch.randn(64, 64, device="cuda") @@ -24,25 +25,31 @@ def test_cupti_fallback_on_activity_enable_error(): fake_module = MagicMock() fake_module.cupti = fake_cupti - real_import = __builtins__.__import__ if hasattr(__builtins__, "__import__") else __import__ + real_import = ( + __builtins__.__import__ if hasattr(__builtins__, "__import__") else __import__ + ) def _patched_import(name, *args, **kwargs): + """Route 'cupti' imports to the fake module.""" if name == "cupti": return fake_module return real_import(name, *args, **kwargs) # Also patch importlib.metadata.version to report a new-enough cupti-python - with patch("importlib.metadata.version", return_value="13.0.0"), \ - patch("builtins.__import__", side_effect=_patched_import): - with warnings.catch_warnings(record=True) as caught: - warnings.simplefilter("always") - times = bench_gpu_time_with_cupti( - fn=torch.matmul, - input_args=(a, b), - repeat_iters=5, - dry_run_iters=2, - cold_l2_cache=False, - ) + with ( + patch("importlib.metadata.version", return_value="13.0.0"), + patch("builtins.__import__", side_effect=_patched_import), + warnings.catch_warnings(record=True) as caught, + ): + warnings.simplefilter("always") + times = bench_gpu_time( + fn=torch.matmul, + input_args=(a, b), + repeat_iters=5, + dry_run_iters=2, + cold_l2_cache=False, + enable_cupti=True, + ) # Should have fallen back successfully assert isinstance(times, list) @@ -51,6 +58,8 @@ def _patched_import(name, *args, **kwargs): # Should have emitted a fallback warning fallback_warnings = [ - w for w in caught if issubclass(w.category, UserWarning) and "Falling back" in str(w.message) + w + for w in caught + if issubclass(w.category, UserWarning) and "Falling back" in str(w.message) ] assert len(fallback_warnings) >= 1 From bb2ee4aba4fd84899c1e5cfa66d3d77636c52db2 Mon Sep 17 00:00:00 2001 From: sha7doww <2453900478@qq.com> Date: Thu, 19 Mar 2026 16:17:06 +0800 Subject: [PATCH 4/4] style: simplify real_import to just __import__ MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Address Gemini review suggestion — no need for __builtins__ check in standard Python 3. Co-Authored-By: Claude Opus 4.6 --- tests/utils/test_cupti_fallback.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/utils/test_cupti_fallback.py b/tests/utils/test_cupti_fallback.py index 13e7031381..6faeca62bc 100644 --- a/tests/utils/test_cupti_fallback.py +++ b/tests/utils/test_cupti_fallback.py @@ -25,9 +25,7 @@ def test_cupti_fallback_on_activity_enable_error(): fake_module = MagicMock() fake_module.cupti = fake_cupti - real_import = ( - __builtins__.__import__ if hasattr(__builtins__, "__import__") else __import__ - ) + real_import = __import__ def _patched_import(name, *args, **kwargs): """Route 'cupti' imports to the fake module."""