
Commit 7c0bf15

jananisriram authored and facebook-github-bot committed
Add fp8_gemm benchmark for deepseek-style scaling (#504)
Summary: Add `fp8_gemm` benchmark for deepseek-style scaling in TritonBench.
Reviewed By: NikhilAPatel
Differential Revision: D83689980
1 parent 5d628e6 commit 7c0bf15

File tree

1 file changed (+47, -2 lines)


tritonbench/operators/fp8_gemm/fp8_gemm.py

Lines changed: 47 additions & 2 deletions
@@ -45,6 +45,10 @@
     HAS_TMA = False
     logger.warning(f"Failed to import TMA: {e}")

+HAS_CUDA_129 = (
+    torch.cuda.is_available() and torch.version.cuda and torch.version.cuda >= "12.9"
+)
+

 def parse_args(args):
     parser = argparse.ArgumentParser(description="TritonBench fp8_gemm")
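
Note on the `HAS_CUDA_129` guard above: `torch.version.cuda` is a string, so the `>= "12.9"` check is a lexicographic comparison. A minimal sketch of an equivalent numeric check (the helper name `_cuda_at_least` is illustrative, not part of this commit):

import torch


def _cuda_at_least(major: int, minor: int) -> bool:
    # Parse torch.version.cuda into integers so that e.g. "12.10"
    # compares correctly against (12, 9).
    if not (torch.cuda.is_available() and torch.version.cuda):
        return False
    parts = torch.version.cuda.split(".")
    version = (int(parts[0]), int(parts[1]) if len(parts) > 1 else 0)
    return version >= (major, minor)


HAS_CUDA_129 = _cuda_at_least(12, 9)
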
@@ -63,6 +67,8 @@ def get_scaling_mode_int(scaling_mode: str) -> int:
         return ScalingMode.TENSOR
     elif scaling_mode == "row":
         return ScalingMode.ROW
+    elif scaling_mode == "deepseek":
+        return ScalingMode.DEEPSEEK
     else:
         raise ValueError(f"Invalid scaling mode: {scaling_mode}")

@@ -111,11 +117,40 @@ def _get_scale_per_row(
                 torch.float32
             )  # For row-wise scaling, kernel requires a float32 scale tensor

+        def _get_scale_deepseek(
+            x: torch.Tensor,
+            block_outer: int,
+        ) -> tuple[torch.Tensor, torch.Tensor]:
+            """
+            DeepSeek-style scaling on matmul A @ B uses a combination of block- and tile-wise scaling:
+            - activation tensor A: 1x128 tile-wise scaling
+            - weight tensor B: 128x128 block-wise scaling
+            """
+            block_inner = 128
+            x = x.unflatten(1, (-1, block_inner)).unflatten(0, (-1, block_outer))
+            amax = x.abs().amax(dim=[1, 3], keepdim=True).float()
+            scale = torch.finfo(torch.float8_e4m3fn).max / amax
+            x = (
+                x.mul(scale).flatten(2, 3).flatten(0, 1)
+            )  # scale input up to dynamic range of float8_e4m3fn
+            scale = scale.flatten(2, 3).flatten(0, 1)
+            return x, scale.to(torch.float32)
+
         def args(m, n, k):
             a = torch.randn(m, k, device=self.device).to(self._get_dtype())
             b = torch.randn(n, k, device=self.device).to(self._get_dtype())

-            if self.scaling_mode_int == ScalingMode.ROW:
+            if self.scaling_mode_int == ScalingMode.DEEPSEEK:
+                activations_block_outer = 1
+                weights_block_outer = 128
+
+                a, scale_a = _get_scale_deepseek(a, activations_block_outer)
+                b, scale_b = _get_scale_deepseek(b, weights_block_outer)
+
+                scale_a = (
+                    scale_a.t().contiguous().t()
+                )  # 1x128 blocks need scales to be outer-dim-major
+            elif self.scaling_mode_int == ScalingMode.ROW:
                 scale_a = _get_scale_per_row(a)
                 scale_b = _get_scale_per_row(b)
             else:  # self.scaling_mode_int == ScalingMode.TENSOR
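
For reference, a standalone sketch of the scale computation that `_get_scale_deepseek` performs, run outside TritonBench (the function body mirrors the diff above; the example sizes are arbitrary and assumed divisible by 128):

import torch


def get_scale_deepseek(x: torch.Tensor, block_outer: int):
    block_inner = 128
    # View x as (outer_blocks, block_outer, inner_blocks, block_inner)
    # and compute one scale per (block_outer x block_inner) block.
    x = x.unflatten(1, (-1, block_inner)).unflatten(0, (-1, block_outer))
    amax = x.abs().amax(dim=[1, 3], keepdim=True).float()
    scale = torch.finfo(torch.float8_e4m3fn).max / amax
    x = x.mul(scale).flatten(2, 3).flatten(0, 1)  # scale up to fp8 dynamic range
    scale = scale.flatten(2, 3).flatten(0, 1)
    return x, scale.to(torch.float32)


m, n, k = 256, 512, 384
a = torch.randn(m, k)  # activations: 1x128 tile-wise scaling
b = torch.randn(n, k)  # weights: 128x128 block-wise scaling

a_scaled, scale_a = get_scale_deepseek(a, block_outer=1)
b_scaled, scale_b = get_scale_deepseek(b, block_outer=128)

print(a_scaled.shape, scale_a.shape)  # (256, 384), (256, 3): one scale per 1x128 tile
print(b_scaled.shape, scale_b.shape)  # (512, 384), (4, 3): one scale per 128x128 block
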
@@ -164,12 +199,22 @@ def get_x_val(self, example_inputs) -> float:

     @register_benchmark(baseline=True)
     def torch_fp8_gemm(self, a, b, scale_a, scale_b):
+        is_scaling_deepseek = self.scaling_mode_int == ScalingMode.DEEPSEEK
+
+        assert (
+            not is_scaling_deepseek or HAS_CUDA_129
+        ), "Deepseek-style scaling (BlockWise128x128) for scaled_gemm requires CUDA 12.9+"
+
+        use_fast_accum = (
+            False if is_scaling_deepseek else True
+        )  # blockwise scaled_gemm does not support use_fast_accum=True
+
         return lambda: torch._scaled_mm(
             a,
             b.t(),
             scale_a,
             scale_b.t(),
-            use_fast_accum=True,
+            use_fast_accum=use_fast_accum,
             out_dtype=self._get_dtype(),
         )
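
A rough standalone sketch of the baseline call under deepseek-style scaling (assumes a PyTorch build with CUDA 12.9+ and an fp8-capable GPU; the all-ones scales and the bfloat16 output dtype are placeholders for illustration, not what the benchmark uses):

import torch

device = "cuda"
m, n, k = 256, 512, 384
a = torch.randn(m, k, device=device).to(torch.float8_e4m3fn)
b = torch.randn(n, k, device=device).to(torch.float8_e4m3fn)

# In the benchmark these come from the deepseek scale helper; all-ones scales
# keep this sketch self-contained. Shapes follow the diff: (m, k//128) for
# 1x128 activation tiles, (n//128, k//128) for 128x128 weight blocks.
scale_a = torch.ones(m, k // 128, device=device, dtype=torch.float32)
scale_b = torch.ones(n // 128, k // 128, device=device, dtype=torch.float32)

scale_a = scale_a.t().contiguous().t()  # 1x128 blocks: scales outer-dim-major

out = torch._scaled_mm(
    a,
    b.t(),
    scale_a,
    scale_b.t(),
    use_fast_accum=False,  # blockwise scaled_gemm does not support fast accum
    out_dtype=torch.bfloat16,
)
print(out.shape)  # (256, 512)
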
