    HAS_TMA = False
    logger.warning(f"Failed to import TMA: {e}")

+HAS_CUDA_129 = (
+    torch.cuda.is_available() and torch.version.cuda and torch.version.cuda >= "12.9"
+)
+

def parse_args(args):
    parser = argparse.ArgumentParser(description="TritonBench fp8_gemm")
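Not part of the diff: a minimal sketch of the same CUDA-version gate written as a numeric comparison rather than a string comparison, so that a hypothetical version such as "12.10" would still compare correctly against "12.9". The helper name cuda_at_least is illustrative only; the PR itself keeps the simpler string comparison above.

import torch

def _version_tuple(v: str):
    # torch.version.cuda is a dotted string such as "12.4" (None on CPU-only builds).
    return tuple(int(part) for part in v.split("."))

def cuda_at_least(required: str) -> bool:
    cuda = torch.version.cuda
    if not (torch.cuda.is_available() and cuda):
        return False
    return _version_tuple(cuda) >= _version_tuple(required)

HAS_CUDA_129 = cuda_at_least("12.9")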
@@ -63,6 +67,10 @@ def get_scaling_recipe_int(scaling_recipe: str) -> int:
        return ScalingType.TensorWise
    elif scaling_recipe == "RowWise":
        return ScalingType.RowWise
+    elif scaling_recipe == "BlockWise1x128":
+        return ScalingType.BlockWise1x128
+    elif scaling_recipe == "BlockWise128x128":
+        return ScalingType.BlockWise128x128
    else:
        raise ValueError(f"Invalid scaling recipe: {scaling_recipe}")

@@ -97,11 +105,25 @@ def _get_scale_per_row(x: torch.Tensor, transpose: bool = False) -> torch.Tensor
            torch.float32
        )  # For row-wise scaling, kernel requires a float32 scale tensor

+    def _get_scale_per_block(x: torch.Tensor, block_outer: int, block_inner: int):
+        x = x.unflatten(1, (-1, block_inner)).unflatten(0, (-1, block_outer))
+        amax = x.abs().amax(dim=[1, 3], keepdim=True).float()
+        scale = torch.finfo(torch.float8_e4m3fn).max / amax
+        x = (
+            x.mul(scale).flatten(2, 3).flatten(0, 1)
+        )  # scale input up to dynamic range of float8_e4m3fn
+        scale = scale.flatten(2, 3).flatten(0, 1)
+        return x, scale.to(torch.float32)
+
    match scaling_recipe_int:
        case ScalingType.TensorWise:
            return _get_scale_per_tensor(x, custom_scale=custom_scale)
        case ScalingType.RowWise:
            return _get_scale_per_row(x, transpose=transpose)
+        case ScalingType.BlockWise1x128:
+            return _get_scale_per_block(x, 1, 128)
+        case ScalingType.BlockWise128x128:
+            return _get_scale_per_block(x, 128, 128)
        case _:
            raise AssertionError(f"Unsupported scaling type {scaling_recipe_int}")

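Not part of the diff: a small, standalone sketch of the block-wise scaling added above, showing the scale-tensor shapes the two new recipes produce. The 256x512 input size is an illustrative assumption, and the helper is restated at module level so the snippet runs on its own.

import torch

def get_scale_per_block(x: torch.Tensor, block_outer: int, block_inner: int):
    # Reshape (M, K) into (M/block_outer, block_outer, K/block_inner, block_inner) tiles.
    x = x.unflatten(1, (-1, block_inner)).unflatten(0, (-1, block_outer))
    # One amax, and hence one scale, per tile.
    amax = x.abs().amax(dim=[1, 3], keepdim=True).float()
    scale = torch.finfo(torch.float8_e4m3fn).max / amax
    # Scale the input up to the fp8 e4m3 dynamic range and restore the 2-D layout.
    x = x.mul(scale).flatten(2, 3).flatten(0, 1)
    scale = scale.flatten(2, 3).flatten(0, 1)
    return x, scale.to(torch.float32)

x = torch.randn(256, 512)
_, s_1x128 = get_scale_per_block(x, 1, 128)      # one scale per 1x128 strip
_, s_128x128 = get_scale_per_block(x, 128, 128)  # one scale per 128x128 tile
print(s_1x128.shape)    # torch.Size([256, 4])
print(s_128x128.shape)  # torch.Size([2, 4])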
@@ -127,6 +149,19 @@ def __init__(
        self.scaling_recipe_a_int = get_scaling_recipe_int(scaling_recipe_a).value
        self.scaling_recipe_b_int = get_scaling_recipe_int(scaling_recipe_b).value

+        blockwise_scaling_types = [
+            ScalingType.BlockWise1x128,
+            ScalingType.BlockWise128x128,
+        ]
+        self.contains_blockwise_scaling = (
+            self.scaling_recipe_a_int in blockwise_scaling_types
+            or self.scaling_recipe_b_int in blockwise_scaling_types
+        )
+
+        self.use_fast_accum = (
+            False if self.contains_blockwise_scaling else True
+        )  # BlockWise scaled_gemm does not support use_fast_accum=True
+
    def _get_dtype(self):
        if (
            self.scaling_recipe_a_int == ScalingType.TensorWise
@@ -189,12 +224,16 @@ def get_x_val(self, example_inputs) -> float:

    @register_benchmark(baseline=True)
    def torch_fp8_gemm(self, a, b, scale_a, scale_b):
+        assert (
+            not self.contains_blockwise_scaling or HAS_CUDA_129
+        ), "BlockWise scaling variants for scaled_gemm require CUDA 12.9+"
+
        return lambda: torch._scaled_mm(
            a,
            b.t(),
            scale_a,
            scale_b.t(),
-            use_fast_accum=True,
+            use_fast_accum=self.use_fast_accum,
            out_dtype=self._get_dtype(),
        )

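Not part of the diff: a hedged, self-contained sketch of a torch._scaled_mm call with tensor-wise scales, mirroring the argument order used by torch_fp8_gemm above. The shapes, the dequantization-scale convention, and the fp8-capable-GPU requirement are assumptions for illustration, not taken from this PR.

import torch

M, K, N = 256, 512, 128  # divisible by 16, as fp8 GEMM requires
a = torch.randn(M, K, device="cuda")
b = torch.randn(N, K, device="cuda")

# Tensor-wise quantization: scale inputs up to the fp8 e4m3 dynamic range.
q_a = torch.finfo(torch.float8_e4m3fn).max / a.abs().max().float()
q_b = torch.finfo(torch.float8_e4m3fn).max / b.abs().max().float()
a_fp8 = (a * q_a).to(torch.float8_e4m3fn)
b_fp8 = (b * q_b).to(torch.float8_e4m3fn)

# _scaled_mm takes dequantization scales (reciprocals of the quantization scales)
# and expects the second operand in column-major layout, hence .t().
# Requires a GPU with fp8 support (e.g. sm89/sm90).
out = torch._scaled_mm(
    a_fp8,
    b_fp8.t(),
    1.0 / q_a,
    1.0 / q_b,
    out_dtype=torch.bfloat16,
    use_fast_accum=True,
)
print(out.shape)  # torch.Size([256, 128])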
@@ -211,7 +250,7 @@ def pt2_fp8_gemm(self, a, b, scale_a, scale_b) -> Callable:
            b.t(),
            scale_a,
            scale_b.t(),
-            use_fast_accum=True,
+            use_fast_accum=self.use_fast_accum,
            out_dtype=self._get_dtype(),
        )
        compiled = torch.compile(f, dynamic=False)