Skip to content

Commit 2037e53

Browse files
authored
Filter out solutions that have cuda streams (#56)
1 parent 56237fc commit 2037e53

File tree

4 files changed

+226
-0
lines changed

4 files changed

+226
-0
lines changed

BackendBench/eval.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import triton.testing
66

77

8+
from BackendBench.utils import uses_cuda_stream
9+
810
logger = logging.getLogger(__name__)
911

1012
EXC_MSG = """
@@ -101,6 +103,11 @@ def eval_performance(op, impl, tests):
101103

102104
def eval_one_op(op, impl, correctness_tests, performance_tests):
103105
"""Evaluate impl of op against correctness_tests and performance_tests."""
106+
# TODO: We should have proper error reporting instead of just saying this is 0,
107+
# but that should be a separate PR.
108+
if uses_cuda_stream(impl):
109+
logger.warning(f"Skipping {op.__name__} because it uses CUDA stream")
110+
return 0, 0
104111
return eval_correctness(op, impl, correctness_tests), eval_performance(
105112
op, impl, performance_tests
106113
)

BackendBench/utils.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import ast
2+
import inspect
3+
import re
4+
import textwrap
5+
6+
7+
def uses_cuda_stream(func) -> bool:
    """
    Detect whether a Python function creates CUDA streams.

    The check is purely static: the function's source text is scanned with
    regex patterns for known stream constructors, then parsed to an AST to
    catch any remaining ``Stream(...)`` call. The function is never executed.

    Args:
        func: The Python function to analyze.

    Returns:
        bool: True if CUDA streams are created, False otherwise.
    """
    try:
        source = inspect.getsource(func)
    except (TypeError, OSError):
        # Builtin functions, OpOverload objects, and other callables
        # without source code. These cannot create CUDA streams.
        return False

    # Fast path: textual patterns for common stream constructors.
    patterns = [
        r"torch\.cuda\.Stream\(",  # torch.cuda.Stream() constructor
        r"cupy\.cuda\.Stream\(",  # cupy.cuda.Stream() constructor
        r"cuda\.Stream\(",  # Generic cuda.Stream() constructor
        r"pycuda.*Stream\(",  # PyCUDA stream creation
        r"\bStream\(",  # Stream() constructor calls
        r"make_stream\(",  # make_stream() factory function
        r"create_stream\(",  # create_stream() factory function
    ]
    if any(re.search(p, source, re.IGNORECASE) for p in patterns):
        return True

    # Slow path: AST scan for any Stream(...) call the regexes missed
    # (e.g. aliased imports bound to the name "Stream").
    try:
        tree = ast.parse(textwrap.dedent(source))
    except SyntaxError:
        # getsource() can return a fragment that is not parseable on its
        # own (e.g. a lambda embedded in a larger expression). Fall back
        # to the regex verdict above rather than crashing the caller.
        return False

    for node in ast.walk(tree):
        if isinstance(node, ast.Call):
            callee = node.func
            # Covers both attribute calls (x.Stream()) and bare names
            # (Stream()); return as soon as one is found.
            if isinstance(callee, ast.Attribute) and callee.attr == "Stream":
                return True
            if isinstance(callee, ast.Name) and callee.id == "Stream":
                return True
    return False

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ dev = [
3333
"pytest-mock",
3434
"pytest-timeout",
3535
"ruff==0.12.1",
36+
"torch",
37+
"numpy",
38+
"cupy-cuda12x",
3639
]
3740
flaggems = [
3841
"flag_gems",

test/test_utils.py

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
import pytest
2+
from BackendBench.utils import uses_cuda_stream
3+
4+
# Check if CUDA is available
5+
import torch
6+
7+
HAS_CUDA = torch.cuda.is_available()
8+
9+
10+
class TestCudaStreamDetection:
    """Tests for uses_cuda_stream.

    NOTE: uses_cuda_stream only inspects a function's *source text*
    (regex + AST); it never executes the function or touches the GPU.
    These tests therefore run on CPU-only machines and need no CUDA
    skip guards.
    """

    def test_pytorch_stream_creation(self):
        """Test detection of PyTorch CUDA stream creation."""

        def func_with_pytorch_stream():
            import torch

            stream = torch.cuda.Stream()
            return stream

        assert uses_cuda_stream(func_with_pytorch_stream)

    def test_cupy_stream_creation(self):
        """Test detection of CuPy CUDA stream creation."""

        # The inner function is never called, so cupy does not need to be
        # imported (or even installed): only its source text is analyzed.
        def func_with_cupy_stream():
            stream = cupy.cuda.Stream()  # noqa: F821
            return stream

        assert uses_cuda_stream(func_with_cupy_stream)

    def test_generic_stream_creation(self):
        """Test detection of generic Stream() calls."""

        def func_with_generic_stream():
            from torch.cuda import Stream

            stream = Stream()
            return stream

        assert uses_cuda_stream(func_with_generic_stream)

    def test_stream_with_device_id(self):
        """Test detection of Stream with device ID."""

        def func_with_device_stream():
            from torch.cuda import Stream

            stream = Stream(0)
            return stream

        assert uses_cuda_stream(func_with_device_stream)

    def test_no_stream_creation(self):
        """Test functions without stream creation return False."""

        def func_without_stream():
            import torch

            x = torch.randn(100, 100)
            y = x @ x.T
            return y

        assert not uses_cuda_stream(func_without_stream)

    def test_lambda_function(self):
        """Test detection in locally defined helper functions."""

        def func_lambda_with_stream():
            return torch.cuda.Stream()

        def func_lambda_without(x):
            return x * 2

        assert uses_cuda_stream(func_lambda_with_stream)
        assert not uses_cuda_stream(func_lambda_without)

    def test_nested_function(self):
        """Test detection in nested functions."""

        def outer_function():
            def inner_with_stream():
                import torch

                return torch.cuda.Stream()

            return inner_with_stream

        inner = outer_function()
        assert uses_cuda_stream(inner)

    def test_class_method(self):
        """Test detection in class methods."""

        class StreamClass:
            def method_with_stream(self):
                import torch

                self.stream = torch.cuda.Stream()
                return self.stream

            def method_without_stream(self):
                return "no stream here"

        obj = StreamClass()
        assert uses_cuda_stream(obj.method_with_stream)
        assert not uses_cuda_stream(obj.method_without_stream)

    def test_various_formats(self):
        """Test various formatting of stream creation."""

        def func_spaces():
            stream = torch.cuda.Stream()
            return stream

        def func_multiline():
            stream = torch.cuda.Stream(device=0)
            return stream

        def func_chained():
            result = torch.cuda.Stream().query()
            return result

        assert uses_cuda_stream(func_spaces)
        assert uses_cuda_stream(func_multiline)
        assert uses_cuda_stream(func_chained)

    def test_case_sensitivity(self):
        """Test case-insensitive detection."""

        def func_lowercase():
            stream = torch.cuda.stream()  # lowercase (if it existed)
            return stream

        def func_uppercase():
            stream = torch.cuda.STREAM()  # uppercase (if it existed)
            return stream

        # These should still be detected due to case-insensitive regex
        assert uses_cuda_stream(func_lowercase)
        assert uses_cuda_stream(func_uppercase)

    def test_opoverload_callables(self):
        """Test that OpOverload objects don't raise exceptions."""
        import torch

        # Test OpOverload (torch operators)
        assert not uses_cuda_stream(torch.add)
        assert not uses_cuda_stream(torch.ops.aten.add)
160+
161+
162+
if __name__ == "__main__":
    # Propagate pytest's exit status so running this file directly reports
    # failures via the process exit code (the bare call discarded it).
    raise SystemExit(pytest.main([__file__]))

0 commit comments

Comments
 (0)