
Commit fd94232

jananisriram authored and facebook-github-bot committed
Add fp8_gemm benchmark for deepseek-style scaling (#504)
Summary: Add `fp8_gemm` benchmark for deepseek-style scaling in TritonBench.

Reviewed By: NikhilAPatel

Differential Revision: D83689980
1 parent 301525f commit fd94232
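
"DeepSeek-style" scaling refers to fine-grained block-wise fp8 quantization: activations get one scale per 1x128 tile and weights one scale per 128x128 block, which is what the new `BlockWise1x128` and `BlockWise128x128` recipes in this diff implement. As a rough standalone illustration of the per-block scale math added in `_get_scale_per_block` (the shapes and random input below are made up, not taken from the benchmark):

    import torch

    # Mirror of the block-wise amax/scale computation added in this diff;
    # M, K and x are illustrative values only.
    M, K = 256, 512
    x = torch.randn(M, K)

    block_outer, block_inner = 1, 128  # BlockWise1x128, i.e. activation-style scaling
    blocks = x.unflatten(1, (-1, block_inner)).unflatten(0, (-1, block_outer))
    amax = blocks.abs().amax(dim=[1, 3], keepdim=True).float()
    scale = (torch.finfo(torch.float8_e4m3fn).max / amax).reciprocal()
    scale = scale.flatten(2, 3).flatten(0, 1)

    print(scale.shape)  # torch.Size([256, 4]) -- one scale per 1x128 tile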

File tree: 1 file changed (+54, -5 lines)


tritonbench/operators/fp8_gemm/fp8_gemm.py

Lines changed: 54 additions & 5 deletions
@@ -45,6 +45,10 @@
     HAS_TMA = False
     logger.warning(f"Failed to import TMA: {e}")
 
+HAS_CUDA_129 = (
+    torch.cuda.is_available() and torch.version.cuda and torch.version.cuda >= "12.9"
+)
+
 
 def parse_args(args):
     parser = argparse.ArgumentParser(description="TritonBench fp8_gemm")

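The new `HAS_CUDA_129` flag is a plain runtime probe of the CUDA toolkit version bundled with PyTorch. To see what it evaluates to on a given machine (purely illustrative; both attributes exist in any recent PyTorch build, and `torch.version.cuda` is None on CPU-only wheels):

    import torch

    print(torch.cuda.is_available(), torch.version.cuda)
    print(bool(torch.cuda.is_available() and torch.version.cuda and torch.version.cuda >= "12.9"))
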
@@ -63,6 +67,10 @@ def get_scaling_recipe_int(scaling_recipe: str) -> int:
         return ScalingType.TensorWise
     elif scaling_recipe == "RowWise":
         return ScalingType.RowWise
+    elif scaling_recipe == "BlockWise1x128":
+        return ScalingType.BlockWise1x128
+    elif scaling_recipe == "BlockWise128x128":
+        return ScalingType.BlockWise128x128
     else:
         raise ValueError(f"Invalid scaling recipe: {scaling_recipe}")

@@ -79,7 +87,7 @@ def _get_scale_per_tensor(
         # For tensor-wise scaling, kernel requires a float32 scale tensor
         if custom_scale:
             return torch.tensor(custom_scale, dtype=torch.float32, device=x.device)
-        scale = torch.finfo(torch.float8_e4m3fn).max / x.abs().max()
+        scale = (torch.finfo(torch.float8_e4m3fn).max / x.abs().max()).reciprocal()
         x *= scale
         return x, scale.to(torch.float32)

@@ -90,22 +98,46 @@ def _get_scale_per_row(
             scale = (
                 torch.finfo(torch.float8_e4m3fn).max
                 / x.abs().max(dim=0, keepdim=True).values
-            )
+            ).reciprocal()
         else:  # scale_a.shape should be [M, 1]
             scale = (
                 torch.finfo(torch.float8_e4m3fn).max
                 / x.abs().max(dim=1, keepdim=True).values
-            )
+            ).reciprocal()
         x = x.mul(scale)
         return x, scale.to(
             torch.float32
         )  # For row-wise scaling, kernel requires a float32 scale tensor
 
+    def _get_scale_per_block(
+        x: torch.Tensor, block_outer: int, block_inner: int
+    ) -> (torch.Tensor, torch.Tensor):
+        x = x.unflatten(1, (-1, block_inner)).unflatten(0, (-1, block_outer))
+        amax = x.abs().amax(dim=[1, 3], keepdim=True).float()
+        scale = (
+            torch.finfo(torch.float8_e4m3fn).max / amax
+        ).reciprocal()  # keeps scale small enough such that scaling doesn't cause inf values
+        x = (
+            x.mul(scale).flatten(2, 3).flatten(0, 1)
+        )  # scale input up to dynamic range of float8_e4m3fn
+        scale = scale.flatten(2, 3).flatten(0, 1)
+
+        if block_outer == 1 and block_inner == 128:
+            scale = (
+                scale.t().contiguous().t()
+            )  # 1x128 blocks need scales to be outer-dim-major
+
+        return x, scale.to(torch.float32)
+
     match scaling_recipe_int:
         case ScalingType.TensorWise:
             return _get_scale_per_tensor(x, custom_scale=custom_scale)
         case ScalingType.RowWise:
             return _get_scale_per_row(x, transpose=transpose)
+        case ScalingType.BlockWise1x128:
+            return _get_scale_per_block(x, 1, 128)
+        case ScalingType.BlockWise128x128:
+            return _get_scale_per_block(x, 128, 128)
         case _:
             raise AssertionError(f"Unsupported scaling type {scaling_recipe_int}")

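Note that the `scale.t().contiguous().t()` step above changes only the memory layout, not the shape: it gives the 1x128 scales outer-dim-major (column-major) strides, which is the layout the in-code comment says the scaled-GEMM kernel expects for that recipe. A small standalone illustration (the shape is made up):

    import torch

    scale = torch.rand(256, 4)          # e.g. 1x128 scales for a (256, 512) input
    print(scale.stride())               # (4, 1)  -- row-major
    scale = scale.t().contiguous().t()
    print(scale.shape, scale.stride())  # torch.Size([256, 4]) (1, 256) -- outer-dim-major
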
@@ -131,6 +163,19 @@ def __init__(
         self.scaling_recipe_a_int = get_scaling_recipe_int(scaling_recipe_a).value
         self.scaling_recipe_b_int = get_scaling_recipe_int(scaling_recipe_b).value
 
+        blockwise_scaling_types = [
+            ScalingType.BlockWise1x128,
+            ScalingType.BlockWise128x128,
+        ]
+        self.contains_blockwise_scaling = (
+            self.scaling_recipe_a_int in blockwise_scaling_types
+            or self.scaling_recipe_b_int in blockwise_scaling_types
+        )
+
+        self.use_fast_accum = (
+            False if self.contains_blockwise_scaling else True
+        )  # BlockWise scaled_gemm does not support use_fast_accum=True
+
     def _get_dtype(self):
         if (
             self.scaling_recipe_a_int == ScalingType.TensorWise

@@ -193,12 +238,16 @@ def get_x_val(self, example_inputs) -> float:
 
     @register_benchmark(baseline=True)
     def torch_fp8_gemm(self, a, b, scale_a, scale_b):
+        assert (
+            not self.contains_blockwise_scaling or HAS_CUDA_129
+        ), "BlockWise scaling variants for scaled_gemm require CUDA 12.9+"
+
         return lambda: torch._scaled_mm(
             a,
             b.t(),
             scale_a,
             scale_b.t(),
-            use_fast_accum=True,
+            use_fast_accum=self.use_fast_accum,
             out_dtype=self._get_dtype(),
         )

@@ -215,7 +264,7 @@ def pt2_fp8_gemm(self, a, b, scale_a, scale_b) -> Callable:
                 b.t(),
                 scale_a,
                 scale_b.t(),
-                use_fast_accum=True,
+                use_fast_accum=self.use_fast_accum,
                 out_dtype=self._get_dtype(),
             )
         compiled = torch.compile(f, dynamic=False)
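
Taken together, the baseline path above quantizes both operands, builds matching block-wise scales, and hands everything to `torch._scaled_mm` with fast accumulation disabled. A condensed sketch of that flow (illustrative only: it assumes a CUDA 12.9+ PyTorch build whose `_scaled_mm` accepts 1x128/128x128 block-wise scales, inlines a simplified quantization helper that is not part of the diff, and hard-codes `bfloat16` output where the operator uses `self._get_dtype()`):

    import torch

    def quantize_blockwise(x, block_outer, block_inner):
        # Simplified stand-in for the operator's scale/quantize path (not from the diff).
        blocks = x.unflatten(1, (-1, block_inner)).unflatten(0, (-1, block_outer))
        amax = blocks.abs().amax(dim=[1, 3], keepdim=True).float()
        scale = (torch.finfo(torch.float8_e4m3fn).max / amax).reciprocal()
        x_scaled = blocks.mul(scale).flatten(2, 3).flatten(0, 1)
        scale = scale.flatten(2, 3).flatten(0, 1)
        if block_outer == 1 and block_inner == 128:
            scale = scale.t().contiguous().t()  # outer-dim-major for 1x128 scales
        return x_scaled.to(torch.float8_e4m3fn), scale

    if torch.cuda.is_available() and torch.version.cuda and torch.version.cuda >= "12.9":
        a = torch.randn(1024, 1024, device="cuda")
        b = torch.randn(1024, 1024, device="cuda")
        a_fp8, scale_a = quantize_blockwise(a, 1, 128)    # BlockWise1x128 (activations)
        b_fp8, scale_b = quantize_blockwise(b, 128, 128)  # BlockWise128x128 (weights)
        out = torch._scaled_mm(
            a_fp8,
            b_fp8.t(),
            scale_a,
            scale_b.t(),
            use_fast_accum=False,  # block-wise scaled_gemm does not support fast accum
            out_dtype=torch.bfloat16,
        )
        print(out.shape)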
