Commit e01dd2d

[FA] Fix tutorial blackwell perf

Differential Revision: D81343267
Pull Request resolved: #377

1 parent: dec6bad

File tree

1 file changed: +3 −1 lines


tritonbench/kernels/blackwell_triton_fused_attention.py

Lines changed: 3 additions & 1 deletion
@@ -402,7 +402,9 @@ def grid(META):
         ctx.grid = grid
         warp_specialize = baseVariant == "ws"
         if is_blackwell() and warp_specialize:
-            if HEAD_DIM_K == 128 and q.dtype == torch.float16:
+            if HEAD_DIM_K == 128 and (
+                q.dtype == torch.float16 or q.dtype == torch.bfloat16
+            ):
                 extra_kern_args["maxnreg"] = 168
             else:
                 extra_kern_args["maxnreg"] = 80
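The change above widens the fast-path condition so that bfloat16 inputs get the same 168-register budget as float16 when the head dimension is 128. A minimal standalone sketch of that selection logic is below; `pick_maxnreg` is a hypothetical helper (not in the patched file), and dtypes are represented as strings here to avoid a torch dependency, whereas the real code compares `q.dtype` against `torch.float16` / `torch.bfloat16`.

```python
def pick_maxnreg(head_dim_k: int, dtype: str) -> int:
    """Sketch of the maxnreg choice after this commit, assuming the
    is_blackwell() and warp_specialize checks have already passed.

    After the fix, bfloat16 takes the same branch as float16 for
    HEAD_DIM_K == 128; previously it fell through to the 80-register
    budget, hurting tutorial perf on Blackwell.
    """
    if head_dim_k == 128 and dtype in ("float16", "bfloat16"):
        return 168
    return 80


# Before the fix, the bfloat16 call here would have returned 80.
print(pick_maxnreg(128, "bfloat16"))  # 168
print(pick_maxnreg(64, "float16"))    # 80
```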
