Commit 91e3662

jananisriram authored and facebook-github-bot committed
Move scaling logic to input generation
Summary: Move the scaling logic for FP8 benchmarks into `get_input_iter()`. This diff aligns our fp8_gemm benchmarking suite with real-world practice: input tensors are created in high-precision types (`bfloat16`, `float16`), scales are computed on those high-precision tensors, and the inputs are then cast to a lower precision (`float8_e4m3fn`). It also avoids performing operations that are unsupported on low-precision data types, such as `torch.max` and `torch.abs`.

Note: this diff is a copy of a diff I made a few days ago but couldn't land due to diff train sync issues.

Differential Revision: D80716975
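For context, here is a minimal sketch of the input-generation pattern the summary describes, using a hypothetical helper named make_fp8_inputs and an amax-based row-wise scale; the actual benchmark code below uses unit scales, so the scale math here is illustrative only. The point is the ordering: scales are derived from the float16 tensors before the cast to float8_e4m3fn, since operations like torch.abs and torch.max are not available on the float8 dtype.

import torch

def make_fp8_inputs(m, n, k, device="cuda", rowwise=True):
    # Create inputs in a high-precision dtype first.
    a_hp = torch.randn(m, k, device=device, dtype=torch.float16)
    b_hp = torch.randn(k, n, device=device, dtype=torch.float16)

    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn
    if rowwise:
        # Scales computed on the high-precision tensors
        # (per row of a, per column of b).
        scale_a = a_hp.abs().amax(dim=1, keepdim=True).float().clamp(min=1e-12) / fp8_max
        scale_b = b_hp.abs().amax(dim=0, keepdim=True).float().clamp(min=1e-12) / fp8_max
        a = (a_hp / scale_a).to(torch.float8_e4m3fn)
        b = (b_hp / scale_b).to(torch.float8_e4m3fn)
    else:
        scale_a = torch.tensor(1.0, device=device)
        scale_b = torch.tensor(1.0, device=device)
        a = a_hp.to(torch.float8_e4m3fn)
        b = b_hp.to(torch.float8_e4m3fn)

    # Only after the scales exist do the tensors drop to float8.
    return a, b, scale_a, scale_b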
1 parent 16eddff commit 91e3662

File tree

1 file changed: +35 −35 lines changed

tritonbench/operators/fp8_gemm/fp8_gemm.py

Lines changed: 35 additions & 35 deletions
@@ -56,14 +56,22 @@ def __init__(
 
     def get_input_iter(self):
         def args(m, n, k):
-            a = torch.randn(m, k, device=self.device).to(torch.float8_e4m3fn)
-            b = (
-                torch.randn(k, n, device=self.device)
-                .to(torch.float8_e4m3fn)
-                .T.contiguous()
-                .T
-            )
-            return (a, b)
+            a = torch.randn(m, k, device=self.device).to(torch.float16)
+            b = torch.randn(k, n, device=self.device).to(torch.float16).T.contiguous().T
+
+            if self.extra_args.scaling_rowwise:
+                M, N = a.shape[0], b.shape[1]
+                scale_a = torch.ones((M, 1), dtype=torch.float32, device=a.device)
+                scale_b = torch.ones((1, N), dtype=torch.float32, device=b.device)
+            else:
+                scale_a = torch.tensor(1.0, device=a.device)
+                scale_b = torch.tensor(1.0, device=a.device)
+
+            # Kernels expect dtype=float8_e4m3fn
+            a = a.to(torch.float8_e4m3fn)
+            b = b.to(torch.float8_e4m3fn)
+
+            return (a, b, scale_a, scale_b)
 
         if (
             hasattr(self, "external_shapes") and self.external_shapes
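One detail worth noting about the b construction (retained in spirit from the old code): the .T.contiguous().T round trip keeps the logical (k, n) shape but leaves the data column-major in memory, which is the layout torch._scaled_mm generally expects for its second operand. A quick illustration on CPU with float16 and hypothetical sizes:

import torch

k, n = 64, 32
b = torch.randn(k, n).to(torch.float16).T.contiguous().T
print(b.shape)            # torch.Size([64, 32]) -- logical shape unchanged
print(b.stride())         # (1, 64) -- column-major strides
print(b.is_contiguous())  # False (contiguous only in the transposed view)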
@@ -90,62 +98,54 @@ def args(m, n, k):
             yield args(m, n, k)
 
     def get_x_val(self, example_inputs) -> float:
-        a, b = example_inputs
+        a, b, _, _ = example_inputs
         m, k = a.size()
         _, n = b.size()
         return (m, n, k)
 
-    @register_benchmark(baseline=True)
-    def torch_fp8_gemm(self, a, b):
+    def _get_out_dtype(self):
         if self.extra_args.scaling_rowwise:
-            M, N = a.shape[0], b.shape[1]
-            scale_a = torch.ones((M, 1), dtype=torch.float32, device=a.device)
-            scale_b = torch.ones((1, N), dtype=torch.float32, device=b.device)
-            out_dtype = torch.bfloat16
+            return torch.bfloat16
         else:
-            scale_a = torch.tensor(1.0, device=a.device)
-            scale_b = torch.tensor(1.0, device=a.device)
-            out_dtype = torch.float16
+            return torch.float16
 
+    @register_benchmark(baseline=True)
+    def torch_fp8_gemm(self, a, b, scale_a, scale_b):
         return lambda: torch._scaled_mm(
-            a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=out_dtype
+            a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=self._get_out_dtype()
         )
 
     @register_benchmark()
-    def pt2_fp8_gemm(self, a, b) -> Callable:
+    def pt2_fp8_gemm(self, a, b, scale_a, scale_b) -> Callable:
         torch._dynamo.reset()
         with inductor_config.patch(
             max_autotune=True,
             max_autotune_gemm_backends="TRITON",
             autotune_fallback_to_aten=False,
         ):
-            if self.extra_args.scaling_rowwise:
-                M, N = a.shape[0], b.shape[1]
-                scale_a = torch.ones((M, 1), dtype=torch.float32, device=a.device)
-                scale_b = torch.ones((1, N), dtype=torch.float32, device=b.device)
-                out_dtype = torch.bfloat16
-            else:
-                scale_a = torch.tensor(1.0, device=a.device)
-                scale_b = torch.tensor(1.0, device=b.device)
-                out_dtype = torch.float16
             f = lambda a, b: torch._scaled_mm(
-                a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=out_dtype
+                a,
+                b,
+                scale_a,
+                scale_b,
+                use_fast_accum=True,
+                out_dtype=self._get_out_dtype(),
             )
             compiled = torch.compile(f, dynamic=False)
             compiled(a, b)
 
         return lambda: compiled(a, b)
 
     @register_benchmark()
-    def triton_fp8_gemm(self, a, b):
+    def triton_fp8_gemm(self, a, b, scale_a, scale_b):
         return lambda: tutorial_matmul(a, b)
 
     @register_benchmark(enabled=HAS_TMA)
-    def triton_persistent_fp8_gemm(self, a, b):
+    def triton_persistent_fp8_gemm(self, a, b, scale_a, scale_b):
         return lambda: matmul_persistent(a, b)
 
     @register_benchmark(enabled=HAS_TMA)
-    def triton_tma_persistent_fp8_gemm(self, a, b):
+    def triton_tma_persistent_fp8_gemm(self, a, b, scale_a, scale_b):
         b = b.T.contiguous()
         c, desc_a, desc_b, desc_c = allocate_matmul_tma(a, b)
         return lambda: matmul_tma_persistent(a, b, c, desc_a, desc_b, desc_c)
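To see how the new four-tuple inputs are consumed, here is a standalone call mirroring the torch_fp8_gemm baseline with tensor-wise scaling; names and sizes are hypothetical, it needs a GPU with FP8 support, and the row-wise path would instead pass float32 scales of shape (M, 1) and (1, N) with out_dtype=torch.bfloat16.

import torch

m, n, k = 256, 512, 128
a = torch.randn(m, k, device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn)
b = torch.randn(k, n, device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn).T.contiguous().T
scale_a = torch.tensor(1.0, device="cuda")
scale_b = torch.tensor(1.0, device="cuda")

out = torch._scaled_mm(
    a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=torch.float16
)
print(out.shape, out.dtype)  # torch.Size([256, 512]) torch.float16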
@@ -155,7 +155,7 @@ def gbps(self, fn, example_inputs: Any, metrics: BenchmarkOperatorMetrics) -> fl
         def nbytes(t):
             return t.numel() * t.element_size()
 
-        a, b = example_inputs
+        a, b, _, _ = example_inputs
         c = fn()
         c = c[0] if isinstance(c, tuple) else c
 
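Since the inputs are now float8, nbytes above counts one byte per element for a and b, while the output c (float16 or bfloat16) counts two. A tiny illustration of the helper with hypothetical shapes:

import torch

def nbytes(t):
    return t.numel() * t.element_size()

a = torch.zeros(4096, 4096, dtype=torch.float8_e4m3fn)  # element_size() == 1
c = torch.zeros(4096, 4096, dtype=torch.float16)        # element_size() == 2
print(nbytes(a))  # 16777216 bytes (~16.8 MB)
print(nbytes(c))  # 33554432 bytes (~33.6 MB)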
@@ -168,7 +168,7 @@ def nbytes(t):
     def flops(
         self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
     ) -> float:
-        a, b = example_inputs
+        a, b, _, _ = example_inputs
         m, k = a.size()
         _, n = b.size()
         flops = 2 * m * n * k
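The flops count is the standard 2 * m * n * k for a GEMM; converting it to a rate presumably happens elsewhere in the tritonbench harness using the measured latency, but a back-of-the-envelope check with a made-up latency looks like this:

m = n = k = 4096
flops = 2 * m * n * k           # 137,438,953,472 floating-point ops per GEMM
latency_s = 1e-3                # assume a hypothetical 1 ms kernel time
tflops = flops / latency_s / 1e12
print(f"{tflops:.1f} TFLOPS")   # 137.4 TFLOPS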
