
Commit 9e92724

[Testing] Add a div0 check in the benchmarking function (#6868)
At Meta we try to reuse the Triton benchmarking infrastructure when comparing our Triton kernels against native baselines. We have found a [rare case where comparing to a CK baseline registers as "0ms"](https://github.com/pytorch-labs/tritonbench/blob/a13002697ff55096f495cd132d35cdc414ce36bf/tritonbench/operators/fp8_gemm_rowwise/operator.py#L204). This crashes our workstream, so this change adds a simple division-by-zero check to prevent the issue. The default of 1000 repeats is chosen arbitrarily.
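For illustration, a standalone sketch of the guarded repeat-count computation described above (not the library code itself; the names mirror the diff below, and the 1000-repeat fallback is the arbitrary default mentioned):

```python
def compute_n_repeat(rep: float, estimate_ms: float) -> int:
    """Pick how many times to replay the kernel inside the CUDA graph."""
    if estimate_ms == 0:
        # Extremely fast (or mis-measured) kernels would otherwise hit a
        # ZeroDivisionError in rep / estimate_ms; fall back to a fixed count.
        return 1000
    return max(1, int(rep / estimate_ms))

# With the default rep=20 ms and a 0.05 ms estimate this yields 400 repeats;
# with a 0 ms estimate it now returns 1000 instead of raising.
assert compute_n_repeat(20, 0.05) == 400
assert compute_n_repeat(20, 0) == 1000
```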
1 parent 74d7e4d commit 9e92724

python/triton/testing.py

Lines changed: 5 additions & 1 deletion
@@ -95,7 +95,11 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mod
         end_event.record()
         torch.cuda.synchronize()
         estimate_ms = start_event.elapsed_time(end_event) / 5
-        n_repeat = max(1, int(rep / estimate_ms))
+        # Rewrite to avoid possible division by 0 issues with fast benchmarks
+        if estimate_ms == 0:
+            n_repeat = 1000
+        else:
+            n_repeat = max(1, int(rep / estimate_ms))
         # step 2 - construct a cuda graph with `n_repeat` unrolled function calls to minimize
         # host overhead
         g = torch.cuda.CUDAGraph()
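For context, a minimal usage sketch of the patched helper (assumes a CUDA-capable GPU; the matmul workload and tensor shapes are illustrative, not part of the commit):

```python
import torch
from triton.testing import do_bench_cudagraph

# Illustrative workload: benchmark a small fp16 matmul under CUDA graphs.
# With the patched helper, a run whose 5-iteration estimate rounds to 0 ms
# falls back to 1000 repeats instead of raising ZeroDivisionError.
a = torch.randn(128, 128, device="cuda", dtype=torch.float16)
b = torch.randn(128, 128, device="cuda", dtype=torch.float16)
ms = do_bench_cudagraph(lambda: torch.matmul(a, b), rep=20)
print(f"mean kernel time: {ms:.4f} ms")
```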
