
Commit b39e28e (parent 6938569)

Add Inductor benchmark for scaled Blackwell persistent + TMA

Differential Revision: D82699362
Pull Request resolved: #453

tritonbench/operators/fp8_gemm/fp8_gemm.py (21 additions, 0 deletions)
@@ -189,6 +189,27 @@ def blackwell_persistent_tma_fp8_gemm(self, a, b, scale_a, scale_b):
             self.extra_args.scaling_rowwise,
         )
 
+    @register_benchmark(enabled=True)
+    def blackwell_pt2_fp8_gemm(self, a, b, scale_a, scale_b):
+        torch._dynamo.reset()
+        with inductor_config.patch(
+            max_autotune=True,
+            max_autotune_gemm_backends="TRITON",
+            autotune_fallback_to_aten=False,
+        ):
+            f = lambda a, b: torch._scaled_mm(
+                a,
+                b,
+                scale_a,
+                scale_b,
+                use_fast_accum=True,
+                out_dtype=self._get_dtype(),
+            )
+            compiled = torch.compile(f, dynamic=False)
+            compiled(a, b)
+
+        return lambda: compiled(a, b)
+
     @register_benchmark()
     def triton_fp8_gemm(self, a, b, scale_a, scale_b):
         return lambda: tutorial_matmul(a, b)
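The pattern in the new blackwell_pt2_fp8_gemm benchmark is worth spelling out: torch._dynamo.reset() clears any compile cache left by earlier benchmarks, torch._scaled_mm is compiled once under Inductor's max-autotune config restricted to the Triton GEMM backend, a warm-up call triggers compilation and autotuning outside the timed region, and the returned closure lets the harness time only the steady-state compiled kernel. Below is a minimal standalone sketch of the same flow outside tritonbench; the shapes, the float8_e4m3fn dtype, the tensor-wise unit scales, and the bfloat16 output dtype are illustrative assumptions, not taken from the commit.

import torch
import torch._inductor.config as inductor_config

M, K, N = 4096, 4096, 4096
# Illustrative fp8 inputs; torch._scaled_mm expects the second operand
# column-major, hence the transpose of an (N, K) tensor.
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()
# Tensor-wise scales must be float32 tensors on the same device.
scale_a = torch.tensor(1.0, device="cuda")
scale_b = torch.tensor(1.0, device="cuda")

with inductor_config.patch(
    max_autotune=True,
    max_autotune_gemm_backends="TRITON",
):
    f = lambda a, b: torch._scaled_mm(
        a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=torch.bfloat16
    )
    compiled = torch.compile(f, dynamic=False)
    compiled(a, b)  # warm-up: triggers compilation and autotuning

out = compiled(a, b)  # later calls reuse the cached compiled kernel

Assuming tritonbench's usual CLI, the new variant would be selected with something like python run.py --op fp8_gemm --only blackwell_pt2_fp8_gemm; since it is registered with enabled=True, it should also run by default for the fp8_gemm operator.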
