Commit 2a254fe

jananisriram authored and facebook-github-bot committed
Add fp8_gemm benchmark for deepseek-style scaling
Summary: Add `fp8_gemm` benchmark for deepseek-style scaling in TritonBench.

Differential Revision: D83689980
1 parent: be516f4

File tree: 1 file changed, +45 -2 lines changed

tritonbench/operators/fp8_gemm/fp8_gemm.py

Lines changed: 45 additions & 2 deletions
@@ -47,6 +47,10 @@
     HAS_TMA = False
     logger.warning(f"Failed to import TMA: {e}")

+HAS_CUDA_129 = (
+    torch.cuda.is_available() and torch.version.cuda and torch.version.cuda >= "12.9"
+)
+

 def parse_args(args):
     parser = argparse.ArgumentParser(description="TritonBench fp8_gemm")
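HAS_CUDA_129 gates the new scaling mode on the toolkit version; note that torch.version.cuda is a string, so the comparison above is lexicographic. For reference, a minimal numeric equivalent can be sketched as follows (the helper name cuda_version_at_least is hypothetical and not part of this patch):

import torch

def cuda_version_at_least(major: int, minor: int) -> bool:
    # Hypothetical helper (not in the patch): parse torch.version.cuda, e.g. "12.9",
    # into integers so that a version like "12.10" compares as newer than "12.9".
    if not (torch.cuda.is_available() and torch.version.cuda):
        return False
    ver_major, ver_minor = (int(p) for p in torch.version.cuda.split(".")[:2])
    return (ver_major, ver_minor) >= (major, minor)

HAS_CUDA_129 = cuda_version_at_least(12, 9)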
@@ -65,6 +69,8 @@ def get_scaling_mode_int(scaling_mode: str) -> int:
         return ScalingMode.TENSOR
     elif scaling_mode == "row":
         return ScalingMode.ROW
+    elif scaling_mode == "deepseek":
+        return ScalingMode.DEEPSEEK
     else:
         raise ValueError(f"Invalid scaling mode: {scaling_mode}")


@@ -113,11 +119,40 @@ def _get_scale_per_row(
                 torch.float32
             )  # For row-wise scaling, kernel requires a float32 scale tensor

+        def _get_scale_deepseek(
+            x: torch.Tensor,
+            block_outer: int,
+        ) -> tuple[torch.Tensor, torch.Tensor]:
+            """
+            DeepSeek-style scaling on matmul A @ B uses a combination of block- and tile-wise scaling:
+                - activation tensor A: 1x128 tile-wise scaling
+                - weight tensor B: 128x128 block-wise scaling
+            """
+            block_inner = 128
+            x = x.unflatten(1, (-1, block_inner)).unflatten(0, (-1, block_outer))
+            amax = x.abs().amax(dim=[1, 3], keepdim=True).float()
+            scale = torch.finfo(torch.float8_e4m3fn).max / amax
+            x = (
+                x.mul(scale).flatten(2, 3).flatten(0, 1)
+            )  # scale input up to dynamic range of float8_e4m3fn
+            scale = scale.flatten(2, 3).flatten(0, 1)
+            return x, scale.to(torch.float32)
+
         def args(m, n, k):
             a = torch.randn(m, k, device=self.device).to(self._get_dtype())
             b = torch.randn(n, k, device=self.device).to(self._get_dtype())

-            if self.scaling_mode_int == ScalingMode.ROW:
+            if self.scaling_mode_int == ScalingMode.DEEPSEEK:
+                activations_block_outer = 1
+                weights_block_outer = 128
+
+                a, scale_a = _get_scale_deepseek(a, activations_block_outer)
+                b, scale_b = _get_scale_deepseek(b, weights_block_outer)
+
+                scale_a = (
+                    scale_a.t().contiguous().t()
+                )  # 1x128 blocks need scales to be outer-dim-major
+            elif self.scaling_mode_int == ScalingMode.ROW:
                 scale_a = _get_scale_per_row(a)
                 scale_b = _get_scale_per_row(b)
             else:  # self.scaling_mode_int == ScalingMode.TENSOR
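To make the tiling concrete, here is a standalone sketch of the reshaping that _get_scale_deepseek performs, run outside the benchmark class on CPU float32 inputs (the function name scale_deepseek and the example sizes are illustrative only):

import torch

def scale_deepseek(x: torch.Tensor, block_outer: int):
    # Mirrors _get_scale_deepseek above: split x into (block_outer x 128) blocks,
    # take the absolute max per block, and scale each block up to the fp8 e4m3 range.
    block_inner = 128
    x = x.unflatten(1, (-1, block_inner)).unflatten(0, (-1, block_outer))
    amax = x.abs().amax(dim=[1, 3], keepdim=True).float()
    scale = torch.finfo(torch.float8_e4m3fn).max / amax
    x = x.mul(scale).flatten(2, 3).flatten(0, 1)
    scale = scale.flatten(2, 3).flatten(0, 1)
    return x, scale.to(torch.float32)

m, n, k = 256, 512, 1024
a = torch.randn(m, k)  # activations: 1x128 tile-wise scaling
b = torch.randn(n, k)  # weights: 128x128 block-wise scaling
_, scale_a = scale_deepseek(a, block_outer=1)
_, scale_b = scale_deepseek(b, block_outer=128)
print(scale_a.shape)  # torch.Size([256, 8]) -> one scale per 1x128 tile of A
print(scale_b.shape)  # torch.Size([4, 8])   -> one scale per 128x128 block of B

The activation thus carries k/128 scales per row, while the weight carries one scale per 128x128 block, matching the recipe described in the helper's docstring.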
@@ -166,12 +201,20 @@ def get_x_val(self, example_inputs) -> float:

     @register_benchmark(baseline=True)
     def torch_fp8_gemm(self, a, b, scale_a, scale_b):
+        is_scaling_deepseek = self.scaling_mode_int == ScalingMode.DEEPSEEK
+
+        assert (
+            not is_scaling_deepseek or HAS_CUDA_129
+        ), "Deepseek-style scaling (BlockWise128x128) for scaled_gemm requires CUDA 12.9+"
+
+        use_fast_accum = False if is_scaling_deepseek else True  # blockwise scaled_gemm does not support use_fast_accum=True
+
         return lambda: torch._scaled_mm(
             a,
             b.t(),
             scale_a,
             scale_b.t(),
-            use_fast_accum=True,
+            use_fast_accum=use_fast_accum,
             out_dtype=self._get_dtype(),
         )
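For context, the snippet below sketches the same torch._scaled_mm call outside the benchmark harness, reusing the scale_deepseek helper sketched earlier. It assumes an fp8-capable GPU and, per the assertion above, CUDA 12.9+ for the blockwise mode; the wrapper name and the bfloat16 output dtype are illustrative, not part of the patch:

import torch

def deepseek_scaled_mm(a_hp: torch.Tensor, b_hp: torch.Tensor) -> torch.Tensor:
    # a_hp: (m, k) activations, b_hp: (n, k) weights, both high precision on CUDA.
    # scale_deepseek is the helper from the earlier sketch.
    a_scaled, scale_a = scale_deepseek(a_hp, block_outer=1)    # 1x128 tiles
    b_scaled, scale_b = scale_deepseek(b_hp, block_outer=128)  # 128x128 blocks
    scale_a = scale_a.t().contiguous().t()  # outer-dim-major, as in the diff
    a_fp8 = a_scaled.to(torch.float8_e4m3fn)
    b_fp8 = b_scaled.to(torch.float8_e4m3fn)
    return torch._scaled_mm(
        a_fp8,
        b_fp8.t(),  # pass the (n, k) weight transposed, as torch_fp8_gemm does above
        scale_a,
        scale_b.t(),
        use_fast_accum=False,  # blockwise scaled_gemm does not support fast accum
        out_dtype=torch.bfloat16,
    )

As in torch_fp8_gemm, fast accumulation is disabled for this mode because blockwise scaled_gemm does not support use_fast_accum=True.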
