
Commit e004da9

jananisriram authored and facebook-github-bot committed
Add fp8_gemm benchmark for deepseek-style scaling (#504)
Summary: Add `fp8_gemm` benchmark for deepseek-style scaling in TritonBench.
Reviewed By: NikhilAPatel
Differential Revision: D83689980
1 parent 7360f14 commit e004da9


tritonbench/operators/fp8_gemm/fp8_gemm.py

Lines changed: 47 additions & 2 deletions
@@ -45,6 +45,10 @@
     HAS_TMA = False
     logger.warning(f"Failed to import TMA: {e}")
 
+HAS_CUDA_129 = (
+    torch.cuda.is_available() and torch.version.cuda and torch.version.cuda >= "12.9"
+)
+
 
 def parse_args(args):
     parser = argparse.ArgumentParser(description="TritonBench fp8_gemm")
@@ -63,6 +67,8 @@ def get_scaling_mode_int(scaling_mode: str) -> int:
         return ScalingMode.TENSOR
     elif scaling_mode == "row":
         return ScalingMode.ROW
+    elif scaling_mode == "deepseek":
+        return ScalingMode.DEEPSEEK
     else:
         raise ValueError(f"Invalid scaling mode: {scaling_mode}")
 
@@ -112,11 +118,40 @@ def _get_scale_per_row(
                 torch.float32
             )  # For row-wise scaling, kernel requires a float32 scale tensor
 
+        def _get_scale_deepseek(
+            x: torch.Tensor,
+            block_outer: int,
+        ) -> tuple[torch.Tensor, torch.Tensor]:
+            """
+            DeepSeek-style scaling on matmul A @ B uses a combination of block- and tile-wise scaling:
+            - activation tensor A: 1x128 tile-wise scaling
+            - weight tensor B: 128x128 block-wise scaling
+            """
+            block_inner = 128
+            x = x.unflatten(1, (-1, block_inner)).unflatten(0, (-1, block_outer))
+            amax = x.abs().amax(dim=[1, 3], keepdim=True).float()
+            scale = torch.finfo(torch.float8_e4m3fn).max / amax
+            x = (
+                x.mul(scale).flatten(2, 3).flatten(0, 1)
+            )  # scale input up to dynamic range of float8_e4m3fn
+            scale = scale.flatten(2, 3).flatten(0, 1)
+            return x, scale.to(torch.float32)
+
         def args(m, n, k):
             a = torch.randn(m, k, device=self.device).to(self._get_dtype())
             b = torch.randn(n, k, device=self.device).to(self._get_dtype())
 
-            if self.scaling_mode_int == ScalingMode.ROW:
+            if self.scaling_mode_int == ScalingMode.DEEPSEEK:
+                activations_block_outer = 1
+                weights_block_outer = 128
+
+                a, scale_a = _get_scale_deepseek(a, activations_block_outer)
+                b, scale_b = _get_scale_deepseek(b, weights_block_outer)
+
+                scale_a = (
+                    scale_a.t().contiguous().t()
+                )  # 1x128 blocks need scales to be outer-dim-major
+            elif self.scaling_mode_int == ScalingMode.ROW:
                 scale_a = _get_scale_per_row(a)
                 scale_b = _get_scale_per_row(b)
             else:  # self.scaling_mode_int == ScalingMode.TENSOR
@@ -165,12 +200,22 @@ def get_x_val(self, example_inputs) -> float:
 
     @register_benchmark(baseline=True)
     def torch_fp8_gemm(self, a, b, scale_a, scale_b):
+        is_scaling_deepseek = self.scaling_mode_int == ScalingMode.DEEPSEEK
+
+        assert (
+            not is_scaling_deepseek or HAS_CUDA_129
+        ), "Deepseek-style scaling (BlockWise128x128) for scaled_gemm requires CUDA 12.9+"
+
+        use_fast_accum = (
+            False if is_scaling_deepseek else True
+        )  # blockwise scaled_gemm does not support use_fast_accum=True
+
         return lambda: torch._scaled_mm(
             a,
             b.t(),
             scale_a,
             scale_b.t(),
-            use_fast_accum=True,
+            use_fast_accum=use_fast_accum,
             out_dtype=self._get_dtype(),
         )
 

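For reference, below is a minimal, standalone sketch of the shape bookkeeping behind the new deepseek scaling path. It mirrors the `_get_scale_deepseek` helper added in the diff but is not part of the commit: the `deepseek_scale` name and the tiny `m, n, k` sizes are illustrative only, and it deliberately stops before calling `torch._scaled_mm`, since the blockwise (128x128) scaled-gemm path asserted above requires CUDA 12.9+.

import torch


def deepseek_scale(x: torch.Tensor, block_outer: int):
    """Per-block absmax scaling, mirroring the _get_scale_deepseek helper above."""
    block_inner = 128
    # View x as (rows // block_outer, block_outer, cols // 128, 128) blocks.
    blocks = x.unflatten(1, (-1, block_inner)).unflatten(0, (-1, block_outer))
    amax = blocks.abs().amax(dim=[1, 3], keepdim=True).float()
    scale = torch.finfo(torch.float8_e4m3fn).max / amax
    # Scale each block up to the fp8 e4m3 dynamic range, then restore the 2D shape.
    scaled = blocks.mul(scale).flatten(2, 3).flatten(0, 1)
    return scaled, scale.flatten(2, 3).flatten(0, 1).to(torch.float32)


m, n, k = 4, 256, 256             # illustrative sizes (k and n divisible by 128)
a = torch.randn(m, k)             # activations: 1x128 tile-wise scaling
b = torch.randn(n, k)             # weights: 128x128 block-wise scaling
a_scaled, scale_a = deepseek_scale(a, block_outer=1)
b_scaled, scale_b = deepseek_scale(b, block_outer=128)
print(scale_a.shape)  # torch.Size([4, 2]) -> one scale per 1x128 tile of A
print(scale_b.shape)  # torch.Size([2, 2]) -> one scale per 128x128 block of B

The torch_fp8_gemm baseline then hands scale tensors of exactly these shapes to torch._scaled_mm with use_fast_accum=False, since, as the diff notes, blockwise scaled_gemm does not support fast accumulation.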