    HAS_TMA = False
    logger.warning(f"Failed to import TMA: {e}")

+HAS_CUDA_129 = (
+    torch.cuda.is_available() and torch.version.cuda and torch.version.cuda >= "12.9"
+)
+

def parse_args(args):
    parser = argparse.ArgumentParser(description="TritonBench fp8_gemm")
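A note on the gate above: torch.version.cuda is a version string (or None on CPU-only builds), so `>= "12.9"` is a lexicographic comparison. It happens to order the 12.x series correctly, but a hypothetical "12.10" would sort before "12.9". A minimal sketch of a numeric alternative, for illustration only (the helper name cuda_at_least is hypothetical, not part of this PR):

import torch

def cuda_at_least(major: int, minor: int) -> bool:
    # hypothetical helper: torch.version.cuda is e.g. "12.9" on CUDA builds, None otherwise
    if not (torch.cuda.is_available() and torch.version.cuda):
        return False
    ver = tuple(int(p) for p in torch.version.cuda.split(".")[:2])
    return ver >= (major, minor)

HAS_CUDA_129 = cuda_at_least(12, 9)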
@@ -63,6 +67,8 @@ def get_scaling_mode_int(scaling_mode: str) -> int:
        return ScalingMode.TENSOR
    elif scaling_mode == "row":
        return ScalingMode.ROW
+    elif scaling_mode == "deepseek":
+        return ScalingMode.DEEPSEEK
    else:
        raise ValueError(f"Invalid scaling mode: {scaling_mode}")

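For reference, a sketch of the ScalingMode enum this mapping assumes (the member names come from the diff; the numeric values and IntEnum base are illustrative assumptions, not taken from the benchmark source):

from enum import IntEnum

class ScalingMode(IntEnum):  # sketch; actual definition lives in the benchmark module
    TENSOR = 0    # one scale for the whole tensor
    ROW = 1       # one scale per row
    DEEPSEEK = 2  # 1x128 tile / 128x128 block scaling, added in this PR

# get_scaling_mode_int("deepseek") would then return ScalingMode.DEEPSEEK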
@@ -111,11 +117,40 @@ def _get_scale_per_row(
                torch.float32
            )  # For row-wise scaling, kernel requires a float32 scale tensor

+        def _get_scale_deepseek(
+            x: torch.Tensor,
+            block_outer: int,
+        ) -> tuple[torch.Tensor, torch.Tensor]:
+            """
+            DeepSeek-style scaling on matmul A @ B uses a combination of block- and tile-wise scaling:
+              - activation tensor A: 1x128 tile-wise scaling
+              - weight tensor B: 128x128 block-wise scaling
+            """
+            block_inner = 128
+            x = x.unflatten(1, (-1, block_inner)).unflatten(0, (-1, block_outer))
+            amax = x.abs().amax(dim=[1, 3], keepdim=True).float()
+            scale = torch.finfo(torch.float8_e4m3fn).max / amax
+            x = (
+                x.mul(scale).flatten(2, 3).flatten(0, 1)
+            )  # scale input up to dynamic range of float8_e4m3fn
+            scale = scale.flatten(2, 3).flatten(0, 1)
+            return x, scale.to(torch.float32)
+

        def args(m, n, k):
            a = torch.randn(m, k, device=self.device).to(self._get_dtype())
            b = torch.randn(n, k, device=self.device).to(self._get_dtype())
-            if self.scaling_mode_int == ScalingMode.ROW:
+            if self.scaling_mode_int == ScalingMode.DEEPSEEK:
+                activations_block_outer = 1
+                weights_block_outer = 128
+
+                a, scale_a = _get_scale_deepseek(a, activations_block_outer)
+                b, scale_b = _get_scale_deepseek(b, weights_block_outer)
+
+                scale_a = (
+                    scale_a.t().contiguous().t()
+                )  # 1x128 blocks need scales to be outer-dim-major
+            elif self.scaling_mode_int == ScalingMode.ROW:
                scale_a = _get_scale_per_row(a)
                scale_b = _get_scale_per_row(b)
            else:  # self.scaling_mode_int == ScalingMode.TENSOR
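To make the tiling concrete, here is a standalone shape check mirroring the helper above (a CPU-only sketch; sizes are illustrative and must be multiples of the block sizes). For a (256, 512) activation with block_outer=1 the scale tensor holds one entry per 1x128 tile, i.e. shape (256, 4); for a weight with block_outer=128 it holds one entry per 128x128 block, i.e. shape (2, 4):

import torch

def _get_scale_deepseek(x, block_outer, block_inner=128):
    # tile x into (outer blocks, block_outer, inner blocks, block_inner)
    x = x.unflatten(1, (-1, block_inner)).unflatten(0, (-1, block_outer))
    amax = x.abs().amax(dim=[1, 3], keepdim=True).float()
    scale = torch.finfo(torch.float8_e4m3fn).max / amax
    x = x.mul(scale).flatten(2, 3).flatten(0, 1)
    scale = scale.flatten(2, 3).flatten(0, 1)
    return x, scale.to(torch.float32)

a = torch.randn(256, 512)
_, scale_a = _get_scale_deepseek(a, block_outer=1)    # activations: 1x128 tiles
print(scale_a.shape)                                  # torch.Size([256, 4])

b = torch.randn(256, 512)
_, scale_b = _get_scale_deepseek(b, block_outer=128)  # weights: 128x128 blocks
print(scale_b.shape)                                  # torch.Size([2, 4])

The scale_a.t().contiguous().t() step in the diff keeps the logical (M, K/128) shape but rewrites the storage column-major, which is what "outer-dim-major" means for the 1x128 scale layout.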
@@ -164,12 +199,22 @@ def get_x_val(self, example_inputs) -> float:

    @register_benchmark(baseline=True)
    def torch_fp8_gemm(self, a, b, scale_a, scale_b):
+        is_scaling_deepseek = self.scaling_mode_int == ScalingMode.DEEPSEEK
+
+        assert (
+            not is_scaling_deepseek or HAS_CUDA_129
+        ), "Deepseek-style scaling (BlockWise128x128) for scaled_gemm requires CUDA 12.9+"
+
+        use_fast_accum = (
+            False if is_scaling_deepseek else True
+        )  # blockwise scaled_gemm does not support use_fast_accum=True
+
        return lambda: torch._scaled_mm(
            a,
            b.t(),
            scale_a,
            scale_b.t(),
-            use_fast_accum=True,
+            use_fast_accum=use_fast_accum,
            out_dtype=self._get_dtype(),
        )

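For context, a minimal tensor-wise example of the torch._scaled_mm call pattern used by this benchmark. This is a sketch under stated assumptions: an fp8-capable GPU (e.g. SM89+), and a recent PyTorch where _scaled_mm (a private API whose signature has changed across releases) takes the scales positionally:

import torch

m, n, k = 64, 64, 64
a = torch.randn(m, k, device="cuda").to(torch.float8_e4m3fn)
b = torch.randn(n, k, device="cuda").to(torch.float8_e4m3fn)

# tensor-wise scaling: one float32 scalar scale per operand
scale_a = torch.tensor(1.0, device="cuda")
scale_b = torch.tensor(1.0, device="cuda")

out = torch._scaled_mm(
    a,
    b.t(),                # second operand must be column-major, hence the transpose
    scale_a,
    scale_b,
    use_fast_accum=True,  # fine for tensor-wise; blockwise scaling requires False
    out_dtype=torch.bfloat16,
)
print(out.shape)          # torch.Size([64, 64])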