@@ -45,6 +45,10 @@
     HAS_TMA = False
     logger.warning(f"Failed to import TMA: {e}")
 
+HAS_CUDA_129 = (
+    torch.cuda.is_available() and torch.version.cuda and torch.version.cuda >= "12.9"
+)
+
 
 def parse_args(args):
     parser = argparse.ArgumentParser(description="TritonBench fp8_gemm")
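Review note: the lexicographic comparison torch.version.cuda >= "12.9" happens to hold for the 12.x series through 12.9, but string ordering would misrank a hypothetical "12.10" (as a string it sorts before "12.9"). A more robust variant, as a sketch; parse_cuda_version is a hypothetical helper, not part of this change:

    import torch

    def parse_cuda_version(version: str) -> tuple[int, ...]:
        # "12.9" -> (12, 9); numeric tuples order correctly where strings do not
        return tuple(int(part) for part in version.split("."))

    HAS_CUDA_129 = (
        torch.cuda.is_available()
        and torch.version.cuda is not None
        and parse_cuda_version(torch.version.cuda) >= (12, 9)
    )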
@@ -63,6 +67,8 @@ def get_scaling_mode_int(scaling_mode: str) -> int:
         return ScalingMode.TENSOR
     elif scaling_mode == "row":
         return ScalingMode.ROW
+    elif scaling_mode == "deepseek":
+        return ScalingMode.DEEPSEEK
     else:
         raise ValueError(f"Invalid scaling mode: {scaling_mode}")
 
@@ -112,11 +118,40 @@ def _get_scale_per_row(
                 torch.float32
             )  # For row-wise scaling, kernel requires a float32 scale tensor
 
+        def _get_scale_deepseek(
+            x: torch.Tensor,
+            block_outer: int,
+        ) -> tuple[torch.Tensor, torch.Tensor]:
+            """
+            DeepSeek-style scaling on matmul A @ B uses a combination of block- and tile-wise scaling:
+            - activation tensor A: 1x128 tile-wise scaling
+            - weight tensor B: 128x128 block-wise scaling
+            """
+            block_inner = 128
+            x = x.unflatten(1, (-1, block_inner)).unflatten(0, (-1, block_outer))
+            amax = x.abs().amax(dim=[1, 3], keepdim=True).float()
+            scale = torch.finfo(torch.float8_e4m3fn).max / amax
+            x = (
+                x.mul(scale).flatten(2, 3).flatten(0, 1)
+            )  # scale input up to dynamic range of float8_e4m3fn
+            scale = scale.flatten(2, 3).flatten(0, 1)
+            return x, scale.to(torch.float32)
+
         def args(m, n, k):
             a = torch.randn(m, k, device=self.device).to(self._get_dtype())
             b = torch.randn(n, k, device=self.device).to(self._get_dtype())
 
-            if self.scaling_mode_int == ScalingMode.ROW:
+            if self.scaling_mode_int == ScalingMode.DEEPSEEK:
+                activations_block_outer = 1
+                weights_block_outer = 128
+
+                a, scale_a = _get_scale_deepseek(a, activations_block_outer)
+                b, scale_b = _get_scale_deepseek(b, weights_block_outer)
+
+                scale_a = (
+                    scale_a.t().contiguous().t()
+                )  # 1x128 blocks need scales to be outer-dim-major
+            elif self.scaling_mode_int == ScalingMode.ROW:
                 scale_a = _get_scale_per_row(a)
                 scale_b = _get_scale_per_row(b)
             else:  # self.scaling_mode_int == ScalingMode.TENSOR
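Review note, to make the blocking concrete: for an activation A of shape (M, K), block_outer=1 yields one scale per 1x128 tile, i.e. a (M, K/128) scale tensor; for a weight B of shape (N, K), block_outer=128 yields one scale per 128x128 block, i.e. (N/128, K/128). The standalone sketch below (illustrative sizes, CPU-only, since it only exercises the reshape/amax bookkeeping, not the fp8 kernel) checks those shapes and shows what the .t().contiguous().t() trick does to the strides:

    import torch

    M, N, K = 256, 256, 512  # illustrative; must be multiples of the block sizes

    def scale_shape(x: torch.Tensor, block_outer: int, block_inner: int = 128):
        # Same reshape/amax bookkeeping as _get_scale_deepseek above, shapes only
        x = x.unflatten(1, (-1, block_inner)).unflatten(0, (-1, block_outer))
        amax = x.abs().amax(dim=[1, 3], keepdim=True).float()
        return amax.flatten(2, 3).flatten(0, 1).shape

    print(scale_shape(torch.randn(M, K), block_outer=1))    # (256, 4): per 1x128 tile
    print(scale_shape(torch.randn(N, K), block_outer=128))  # (2, 4): per 128x128 block

    scale_a = torch.randn(M, K // 128)
    scale_a = scale_a.t().contiguous().t()  # same shape, column-major layout
    print(scale_a.stride())  # (1, 256): contiguous along M, i.e. outer-dim-major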
@@ -165,12 +200,22 @@ def get_x_val(self, example_inputs) -> float:
 
     @register_benchmark(baseline=True)
     def torch_fp8_gemm(self, a, b, scale_a, scale_b):
+        is_scaling_deepseek = self.scaling_mode_int == ScalingMode.DEEPSEEK
+
+        assert (
+            not is_scaling_deepseek or HAS_CUDA_129
+        ), "DeepSeek-style scaling (BlockWise128x128) for scaled_gemm requires CUDA 12.9+"
+
+        use_fast_accum = (
+            False if is_scaling_deepseek else True
+        )  # blockwise scaled_gemm does not support use_fast_accum=True
+
         return lambda: torch._scaled_mm(
             a,
             b.t(),
             scale_a,
             scale_b.t(),
-            use_fast_accum=True,
+            use_fast_accum=use_fast_accum,
             out_dtype=self._get_dtype(),
         )
 
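Review note: the baseline lambda above boils down to a single torch._scaled_mm call. Below is a minimal standalone version of the per-tensor path (ScalingMode.TENSOR); sizes and out_dtype are illustrative, and it needs an fp8-capable GPU (SM89+). The second operand is handed over transposed because _scaled_mm expects it column-major, and the blockwise DeepSeek path must keep use_fast_accum=False per the comment in the diff:

    import torch

    m, n, k = 1024, 1024, 1024  # illustrative sizes
    a = torch.randn(m, k, device="cuda").to(torch.float8_e4m3fn)
    b = torch.randn(n, k, device="cuda").to(torch.float8_e4m3fn)
    scale_a = torch.tensor(1.0, device="cuda")  # per-tensor scales: fp32 scalars
    scale_b = torch.tensor(1.0, device="cuda")

    # b.t() makes the second operand column-major, as _scaled_mm requires
    out = torch._scaled_mm(a, b.t(), scale_a, scale_b, out_dtype=torch.bfloat16)
    print(out.shape)  # torch.Size([1024, 1024])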