Commit 905b152

Run CUDA synchronize after each _do_bench call, to surface error sooner (#544)
1 parent 3cce791 commit 905b152

File tree

1 file changed: +4 −7 lines

tritonbench/utils/triton_op.py

Lines changed: 4 additions & 7 deletions
@@ -1031,13 +1031,8 @@ def run(
             # before we hand them to the captured graph. Otherwise we can
             # read partially initialized values (e.g. from torch.randint)
             # and hit device-side asserts in the baseline kernels.
-            if (
-                self.use_cuda_graphs
-                and self.device
-                and self.device.startswith("cuda")
-                and torch.cuda.is_available()
-            ):
-                torch.cuda.synchronize()
+            if self.use_cuda_graphs:
+                torch.accelerator.synchronize()
             self.baseline_fn = None
             self.baseline_metrics = None
             self._op_flops = {}
@@ -1108,6 +1103,8 @@ def _reduce_benchmarks(acc, bm_name: str):
                 quantiles=quantiles,
                 baseline=baseline,
             )
+            # Synchronize after each benchmark to make errors surface sooner
+            torch.accelerator.synchronize()
             if baseline:
                 self.baseline_metrics = acc[bm_name]
             if sleep:
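
For context, a minimal sketch (not from the repository) of why synchronizing after each benchmark call surfaces errors sooner: GPU kernels are launched asynchronously, so a device-side assert or illegal memory access is typically only reported at the next synchronization point. Without a per-benchmark sync, the error would show up inside a later, unrelated benchmark. The run_benchmarks helper and the benchmarks list below are hypothetical; only torch.accelerator.synchronize() comes from the commit.

    import torch

    def run_benchmarks(benchmarks):
        """Hypothetical driver illustrating the per-benchmark synchronize pattern."""
        for name, fn in benchmarks:
            fn()  # kernels are enqueued asynchronously; a failure may not raise here
            # Force pending device work to finish so any device-side error is
            # raised now and attributed to `name`, not to a later benchmark.
            torch.accelerator.synchronize()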
