Commit 8e43e71

[AUTOTUNER] Don't cache benchmarking stream (intel#3993)
This caching seems to be responsible for some CUDA OOMs we encountered in Meta-internal builds. I haven't got a reduced repro, but this change does seem to fix things. My hypothesis is that the cached stream is causing the memory allocated for the graph to be retained.
1 parent e9480e1 commit 8e43e71

File tree

1 file changed: +1 addition, -2 deletions
python/triton/runtime/autotuner.py

Lines changed: 1 addition & 2 deletions
@@ -92,7 +92,6 @@ def _post_hook(args, exception):
         self.num_reps = rep
         import torch
         self.use_cuda_graph = use_cuda_graph and torch.cuda.is_available()
-        self.benchmarkig_stream = torch.cuda.Stream() if self.use_cuda_graph else None
 
     def _bench(self, *args, config, **meta):
         from ..compiler.errors import CompileTimeAssertionFailure
@@ -128,7 +127,7 @@ def kernel_call():
         try:
             if self.use_cuda_graph:
                 import torch
-                with torch.cuda.stream(self.benchmarkig_stream):
+                with torch.cuda.stream(torch.cuda.Stream()):
                     bench_res = do_bench_cudagraph(kernel_call, rep=self.num_reps, return_mode="median")
                 return bench_res
             return do_bench(kernel_call, warmup=self.num_warmups, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8))
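For context, a minimal sketch of the pattern this change adopts is below: a throwaway torch.cuda.Stream() is created for each benchmarking call instead of being cached as an attribute on the tuner, so nothing long-lived pins the captured CUDA graph's memory. The helper name benchmark_under_fresh_stream and the torch.add workload are illustrative stand-ins, not part of the commit.

# A minimal sketch (not part of the commit) of the pattern adopted here:
# build a throwaway CUDA stream per benchmarking call rather than caching
# one on the autotuner, so no long-lived attribute keeps the captured
# graph's memory alive.
import torch
from triton.testing import do_bench_cudagraph


def benchmark_under_fresh_stream(kernel_call, rep=100):
    # The stream is constructed here and dropped when the function returns;
    # nothing equivalent to the old self.benchmarkig_stream attribute
    # survives past the benchmark.
    with torch.cuda.stream(torch.cuda.Stream()):
        return do_bench_cudagraph(kernel_call, rep=rep, return_mode="median")


if torch.cuda.is_available():
    x = torch.randn(1 << 20, device="cuda")
    y = torch.randn_like(x)
    # Stand-in workload; in the autotuner this would be the compiled Triton kernel launch.
    print(benchmark_under_fresh_stream(lambda: torch.add(x, y), rep=50))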
