-
Notifications
You must be signed in to change notification settings - Fork 829
Expand file tree
/
Copy pathtest_cupti_fallback.py
More file actions
63 lines (52 loc) · 2.04 KB
/
test_cupti_fallback.py
File metadata and controls
63 lines (52 loc) · 2.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
"""Test that bench_gpu_time falls back gracefully to CUDA events when CUPTI is
unavailable or the driver does not support it (e.g. CUDA driver < 13.0)."""
import warnings
from unittest.mock import patch, MagicMock
import pytest
import torch
from flashinfer.testing import bench_gpu_time
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
def test_cupti_fallback_on_activity_enable_error():
    """Check that a failing ``cupti.activity_enable`` does not break benchmarking.

    A stand-in ``cupti`` package is injected whose ``activity_enable`` always
    raises, imitating a driver that is too old for CUPTI. Even with
    ``enable_cupti=True``, ``bench_gpu_time`` is expected to return a list
    with one positive timing per repeat iteration and to emit a
    ``UserWarning`` whose message contains "Falling back".
    """
    lhs = torch.randn(64, 64, device="cuda")
    rhs = torch.randn(64, 64, device="cuda")

    # Stand-in for the ``cupti`` package: its ``cupti`` attribute is a mock
    # whose ``activity_enable`` raises on every call.
    cupti_stub = MagicMock()
    cupti_stub.activity_enable.side_effect = Exception(
        "CUPTI_ERROR_NOT_SUPPORTED: driver too old"
    )
    package_stub = MagicMock(cupti=cupti_stub)

    original_import = __import__

    def _import_hook(name, *args, **kwargs):
        """Hand back the stub for 'cupti'; defer to the real import otherwise."""
        if name != "cupti":
            return original_import(name, *args, **kwargs)
        return package_stub

    # The metadata version lookup is pinned as well, so the reported
    # cupti-python version looks recent enough and only the activity_enable
    # failure is left to trigger the fallback.
    with patch("importlib.metadata.version", return_value="13.0.0"):
        with patch("builtins.__import__", side_effect=_import_hook):
            with warnings.catch_warnings(record=True) as recorded:
                # Record every warning, including repeats of the same one.
                warnings.simplefilter("always")
                timings = bench_gpu_time(
                    fn=torch.matmul,
                    input_args=(lhs, rhs),
                    repeat_iters=5,
                    dry_run_iters=2,
                    cold_l2_cache=False,
                    enable_cupti=True,
                )

    # The benchmark must still have produced one valid timing per iteration.
    assert isinstance(timings, list)
    assert len(timings) == 5
    assert all(t > 0 for t in timings)

    # At least one recorded UserWarning must announce the fallback.
    assert any(
        issubclass(record.category, UserWarning)
        and "Falling back" in str(record.message)
        for record in recorded
    )