    HAS_TMA = False
    logger.warning(f"Failed to import TMA: {e}")

+HAS_CUDA_129 = (
+    torch.cuda.is_available() and torch.version.cuda and torch.version.cuda >= "12.9"
+)
+

def parse_args(args):
    parser = argparse.ArgumentParser(description="TritonBench fp8_gemm")
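For context, a hedged, standalone sketch (not part of the diff) of how the HAS_CUDA_129 gate above resolves; has_cuda_129 and its arguments are illustrative stand-ins, not benchmark code:

    def has_cuda_129(cuda_available: bool, cuda_version) -> bool:
        # torch.version.cuda is None on CPU-only builds, so the gate short-circuits to
        # False there; otherwise it falls back to a string comparison against "12.9".
        return bool(cuda_available and cuda_version and cuda_version >= "12.9")

    assert has_cuda_129(True, "12.9")      # CUDA 12.9 runtime passes
    assert not has_cuda_129(True, "12.4")  # older CUDA fails the check
    assert not has_cuda_129(False, None)   # CPU-only build: torch.version.cuda is None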
@@ -65,6 +69,8 @@ def get_scaling_mode_int(scaling_mode: str) -> int:
        return ScalingMode.TENSOR
    elif scaling_mode == "row":
        return ScalingMode.ROW
+    elif scaling_mode == "deepseek":
+        return ScalingMode.DEEPSEEK
    else:
        raise ValueError(f"Invalid scaling mode: {scaling_mode}")

@@ -113,11 +119,40 @@ def _get_scale_per_row(
                torch.float32
            )  # For row-wise scaling, kernel requires a float32 scale tensor

+        def _get_scale_deepseek(
+            x: torch.Tensor,
+            block_outer: int,
+        ) -> tuple[torch.Tensor, torch.Tensor]:
+            """
+            DeepSeek-style scaling on matmul A @ B uses a combination of block- and tile-wise scaling:
+            - activation tensor A: 1x128 tile-wise scaling
+            - weight tensor B: 128x128 block-wise scaling
+            """
+            block_inner = 128
+            x = x.unflatten(1, (-1, block_inner)).unflatten(0, (-1, block_outer))
+            amax = x.abs().amax(dim=[1, 3], keepdim=True).float()
+            scale = torch.finfo(torch.float8_e4m3fn).max / amax
+            x = (
+                x.mul(scale).flatten(2, 3).flatten(0, 1)
+            )  # scale input up to dynamic range of float8_e4m3fn
+            scale = scale.flatten(2, 3).flatten(0, 1)
+            return x, scale.to(torch.float32)
+
        def args(m, n, k):
            a = torch.randn(m, k, device=self.device).to(self._get_dtype())
            b = torch.randn(n, k, device=self.device).to(self._get_dtype())

-            if self.scaling_mode_int == ScalingMode.ROW:
+            if self.scaling_mode_int == ScalingMode.DEEPSEEK:
+                activations_block_outer = 1
+                weights_block_outer = 128
+
+                a, scale_a = _get_scale_deepseek(a, activations_block_outer)
+                b, scale_b = _get_scale_deepseek(b, weights_block_outer)
+
+                scale_a = (
+                    scale_a.t().contiguous().t()
+                )  # 1x128 blocks need scales to be outer-dim-major
+            elif self.scaling_mode_int == ScalingMode.ROW:
                scale_a = _get_scale_per_row(a)
                scale_b = _get_scale_per_row(b)
            else:  # self.scaling_mode_int == ScalingMode.TENSOR
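As a worked illustration of the docstring above, here is a hedged, self-contained sketch (not part of the diff) of the scale shapes the DeepSeek-style helper produces; the sizes and the standalone re-statement of the helper are assumptions for demonstration only:

    import torch

    def get_scale_deepseek(x: torch.Tensor, block_outer: int):
        # Same arithmetic as the nested helper in the diff, lifted to module scope.
        block_inner = 128
        x = x.unflatten(1, (-1, block_inner)).unflatten(0, (-1, block_outer))
        amax = x.abs().amax(dim=[1, 3], keepdim=True).float()
        scale = torch.finfo(torch.float8_e4m3fn).max / amax
        x = x.mul(scale).flatten(2, 3).flatten(0, 1)
        scale = scale.flatten(2, 3).flatten(0, 1)
        return x, scale.to(torch.float32)

    m, n, k = 256, 512, 1024               # illustrative sizes, all multiples of 128
    a = torch.randn(m, k)                  # activation: 1x128 tile-wise scaling
    b = torch.randn(n, k)                  # weight: 128x128 block-wise scaling
    _, scale_a = get_scale_deepseek(a, block_outer=1)
    _, scale_b = get_scale_deepseek(b, block_outer=128)
    print(scale_a.shape)  # torch.Size([256, 8]): one scale per 1x128 tile of A
    print(scale_b.shape)  # torch.Size([4, 8]):   one scale per 128x128 block of B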
@@ -166,12 +201,20 @@ def get_x_val(self, example_inputs) -> float:

    @register_benchmark(baseline=True)
    def torch_fp8_gemm(self, a, b, scale_a, scale_b):
+        is_scaling_deepseek = self.scaling_mode_int == ScalingMode.DEEPSEEK
+
+        assert (
+            not is_scaling_deepseek or HAS_CUDA_129
+        ), "Deepseek-style scaling (BlockWise128x128) for scaled_gemm requires CUDA 12.9+"
+
+        use_fast_accum = False if is_scaling_deepseek else True  # blockwise scaled_gemm does not support use_fast_accum=True
+
        return lambda: torch._scaled_mm(
            a,
            b.t(),
            scale_a,
            scale_b.t(),
-            use_fast_accum=True,
+            use_fast_accum=use_fast_accum,
            out_dtype=self._get_dtype(),
        )

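To make the tile/block geometry concrete, a hedged pure-PyTorch sketch (not part of the diff) of a blockwise-dequantized matmul: each scale is broadcast back over the 1x128 tile or 128x128 block it covers before a plain float32 matmul. Shapes follow the sketch above; tensor names and values are illustrative, and no CUDA 12.9 runtime or torch._scaled_mm call is required here:

    import torch

    m, n, k = 256, 512, 1024
    a_scaled = torch.randn(m, k)                    # stands in for the scaled-up activation
    b_scaled = torch.randn(n, k)                    # stands in for the scaled-up weight
    scale_a = torch.rand(m, k // 128) + 0.5         # one scale per 1x128 tile of A
    scale_b = torch.rand(n // 128, k // 128) + 0.5  # one scale per 128x128 block of B

    # Undo the scaling by broadcasting each scale over the elements it covers.
    a_deq = a_scaled / scale_a.repeat_interleave(128, dim=1)
    b_deq = b_scaled / scale_b.repeat_interleave(128, dim=0).repeat_interleave(128, dim=1)

    ref = a_deq @ b_deq.t()   # float32 reference output, shape (m, n)
    print(ref.shape)          # torch.Size([256, 512])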