Commit a167af5

jananisriram authored and facebook-github-bot committed

Add fp8_gemm benchmark for deepseek-style scaling (#504)

Summary: Add `fp8_gemm` benchmark for deepseek-style scaling in TritonBench.
Reviewed By: NikhilAPatel
Differential Revision: D83689980

1 parent a0fb3ca commit a167af5
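
For context on the recipe names: deepseek-style scaling quantizes the activation operand with one scale per 1x128 block and the weight operand with one scale per 128x128 block, instead of per tensor or per row. The standalone sketch below is illustrative only (it is not part of this diff); it mirrors the block-scaling logic added in fp8_gemm.py and prints the scale shapes the two new recipes produce. The helper name blockwise_scale and the example dimensions are made up for the demonstration.

# Illustrative sketch only, not part of the commit.
import torch


def blockwise_scale(x: torch.Tensor, block_outer: int, block_inner: int):
    # Tile x into (rows/block_outer, block_outer, cols/block_inner, block_inner)
    # blocks, take the absolute max per block, and derive one scale per block
    # from the float8_e4m3fn dynamic range.
    tiles = x.unflatten(1, (-1, block_inner)).unflatten(0, (-1, block_outer))
    amax = tiles.abs().amax(dim=[1, 3], keepdim=True).float()
    scale = torch.finfo(torch.float8_e4m3fn).max / amax
    scaled = (tiles * scale).flatten(2, 3).flatten(0, 1)  # back to (rows, cols)
    return scaled, scale.flatten(2, 3).flatten(0, 1)  # one scale entry per block


a = torch.randn(256, 512)  # activation-like operand: 1x128 blocks
b = torch.randn(512, 256)  # weight-like operand: 128x128 blocks
_, scale_a = blockwise_scale(a, 1, 128)
_, scale_b = blockwise_scale(b, 128, 128)
print(scale_a.shape)  # torch.Size([256, 4])
print(scale_b.shape)  # torch.Size([4, 2])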

File tree

1 file changed (+41, -2 lines)

tritonbench/operators/fp8_gemm/fp8_gemm.py (41 additions, 2 deletions)
@@ -45,6 +45,10 @@
     HAS_TMA = False
     logger.warning(f"Failed to import TMA: {e}")
 
+HAS_CUDA_129 = (
+    torch.cuda.is_available() and torch.version.cuda and torch.version.cuda >= "12.9"
+)
+
 
 def parse_args(args):
     parser = argparse.ArgumentParser(description="TritonBench fp8_gemm")
@@ -63,6 +67,10 @@ def get_scaling_recipe_int(scaling_recipe: str) -> int:
         return ScalingType.TensorWise
     elif scaling_recipe == "RowWise":
         return ScalingType.RowWise
+    elif scaling_recipe == "BlockWise1x128":
+        return ScalingType.BlockWise1x128
+    elif scaling_recipe == "BlockWise128x128":
+        return ScalingType.BlockWise128x128
     else:
         raise ValueError(f"Invalid scaling recipe: {scaling_recipe}")
 
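
A quick usage sketch for the extended recipe parser (illustrative; assumes tritonbench and its dependencies are importable in the current environment, and that ScalingType is an enum exposing .name):

# Illustrative usage of the new recipe strings; not taken from the diff.
from tritonbench.operators.fp8_gemm.fp8_gemm import get_scaling_recipe_int

recipe_a = get_scaling_recipe_int("BlockWise1x128")    # ScalingType.BlockWise1x128
recipe_b = get_scaling_recipe_int("BlockWise128x128")  # ScalingType.BlockWise128x128
print(recipe_a.name, recipe_b.name)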

@@ -97,11 +105,25 @@ def _get_scale_per_row(x: torch.Tensor, transpose: bool = False) -> torch.Tensor
             torch.float32
         )  # For row-wise scaling, kernel requires a float32 scale tensor
 
+    def _get_scale_per_block(x: torch.Tensor, block_outer: int, block_inner: int):
+        x = x.unflatten(1, (-1, block_inner)).unflatten(0, (-1, block_outer))
+        amax = x.abs().amax(dim=[1, 3], keepdim=True).float()
+        scale = torch.finfo(torch.float8_e4m3fn).max / amax
+        x = (
+            x.mul(scale).flatten(2, 3).flatten(0, 1)
+        )  # scale input up to dynamic range of float8_e4m3fn
+        scale = scale.flatten(2, 3).flatten(0, 1)
+        return x, scale.to(torch.float32)
+
     match scaling_recipe_int:
         case ScalingType.TensorWise:
             return _get_scale_per_tensor(x, custom_scale=custom_scale)
         case ScalingType.RowWise:
             return _get_scale_per_row(x, transpose=transpose)
+        case ScalingType.BlockWise1x128:
+            return _get_scale_per_block(x, 1, 128)
+        case ScalingType.BlockWise128x128:
+            return _get_scale_per_block(x, 128, 128)
         case _:
             raise AssertionError(f"Unsupported scaling type {scaling_recipe_int}")
 
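
The comment "scale input up to dynamic range of float8_e4m3fn" can be sanity-checked with a few lines of standalone PyTorch: after each 128x128 block is multiplied by its scale, the block's absolute maximum lands at the float8_e4m3fn limit (448), which keeps quantization error low in a subsequent fp8 cast. This is an illustrative check with made-up dimensions, not part of the benchmark:

import torch

fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0
x = torch.randn(256, 512)

# Tile into 128x128 blocks and compute one scale per block, as in the diff above.
tiles = x.unflatten(1, (-1, 128)).unflatten(0, (-1, 128))
scale = fp8_max / tiles.abs().amax(dim=[1, 3], keepdim=True).float()

# After scaling, every block's absolute max sits at the fp8 e4m3 limit.
print((tiles * scale).abs().amax(dim=[1, 3]).flatten())  # ~448.0 for all 8 blocks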

@@ -127,6 +149,19 @@ def __init__(
         self.scaling_recipe_a_int = get_scaling_recipe_int(scaling_recipe_a).value
         self.scaling_recipe_b_int = get_scaling_recipe_int(scaling_recipe_b).value
 
+        blockwise_scaling_types = [
+            ScalingType.BlockWise1x128,
+            ScalingType.BlockWise128x128,
+        ]
+        self.contains_blockwise_scaling = (
+            self.scaling_recipe_a_int in blockwise_scaling_types
+            or self.scaling_recipe_b_int in blockwise_scaling_types
+        )
+
+        self.use_fast_accum = (
+            False if self.contains_blockwise_scaling else True
+        )  # BlockWise scaled_gemm does not support use_fast_accum=True
+
     def _get_dtype(self):
         if (
             self.scaling_recipe_a_int == ScalingType.TensorWise
@@ -189,12 +224,16 @@ def get_x_val(self, example_inputs) -> float:
 
     @register_benchmark(baseline=True)
     def torch_fp8_gemm(self, a, b, scale_a, scale_b):
+        assert (
+            not self.contains_blockwise_scaling or HAS_CUDA_129
+        ), "BlockWise scaling variants for scaled_gemm require CUDA 12.9+"
+
         return lambda: torch._scaled_mm(
             a,
             b.t(),
             scale_a,
             scale_b.t(),
-            use_fast_accum=True,
+            use_fast_accum=self.use_fast_accum,
             out_dtype=self._get_dtype(),
         )
 
@@ -211,7 +250,7 @@ def pt2_fp8_gemm(self, a, b, scale_a, scale_b) -> Callable:
             b.t(),
             scale_a,
             scale_b.t(),
-            use_fast_accum=True,
+            use_fast_accum=self.use_fast_accum,
             out_dtype=self._get_dtype(),
         )
         compiled = torch.compile(f, dynamic=False)
