@@ -1,9 +1,10 @@
 import argparse
 import logging
 
-from typing import Any, List, Optional
+from typing import Any, Callable, List, Optional
 
 import torch
+import torch._inductor.config as inductor_config
 import triton
 
 from tritonbench.utils.triton_op import (
@@ -90,6 +91,24 @@ def torch_fp8_gemm(self, a, b):
             a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=torch.float16
         )
 
+    @register_benchmark()
+    def pt2_fp8_gemm(self, a, b) -> Callable:
+        torch._dynamo.reset()
+        with inductor_config.patch(
+            max_autotune=True,
+            max_autotune_gemm_backends="TRITON",
+            autotune_fallback_to_aten=False,
+        ):
+            scale_a = torch.tensor(1.0, device=a.device)
+            scale_b = torch.tensor(1.0, device=a.device)
+            f = lambda a, b: torch._scaled_mm(
+                a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=torch.float16
+            )
+            compiled = torch.compile(f, dynamic=False)
+            compiled(a, b)
+
+        return lambda: compiled(a, b)
+
     @register_benchmark()
     def triton_fp8_gemm(self, a, b):
         return lambda: tutorial_matmul(a, b)
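For context, the new pt2_fp8_gemm variant compiles the same torch._scaled_mm call through Inductor with max-autotune restricted to the Triton GEMM backend, and issues one warm-up call so the lambda returned to the benchmark harness times only the already-compiled kernel. Below is a minimal standalone sketch of that pattern outside tritonbench. The device, shapes, and FP8 dtype are illustrative assumptions; it needs a GPU with FP8 support, and torch._scaled_mm is a private API whose signature (positional scales, single-tensor return) matches recent PyTorch releases. The diff's autotune_fallback_to_aten flag is omitted here since it is not available in every PyTorch version.

# Minimal standalone sketch (assumes a CUDA GPU with FP8 support and a recent
# PyTorch where torch._scaled_mm takes positional scale tensors).
import torch
import torch._inductor.config as inductor_config

M, K, N = 1024, 1024, 1024  # illustrative shapes, not taken from the benchmark

# torch._scaled_mm expects a row-major and b column-major FP8 operands.
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()
scale_a = torch.tensor(1.0, device="cuda")
scale_b = torch.tensor(1.0, device="cuda")

torch._dynamo.reset()  # drop any previously cached compiled graphs
with inductor_config.patch(
    max_autotune=True,                    # autotune GEMM configurations
    max_autotune_gemm_backends="TRITON",  # consider only Triton templates
):
    f = lambda x, y: torch._scaled_mm(
        x, y, scale_a, scale_b, use_fast_accum=True, out_dtype=torch.float16
    )
    compiled = torch.compile(f, dynamic=False)
    compiled(a, b)  # warm-up call triggers compilation and autotuning

out = compiled(a, b)  # later calls reuse the compiled Triton kernel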