    HAS_TMA = False
    logger.warning(f"Failed to import TMA: {e}")

+HAS_CUDA_129 = (
+    torch.cuda.is_available() and torch.version.cuda and torch.version.cuda >= "12.9"
+)
+

def parse_args(args):
    parser = argparse.ArgumentParser(description="TritonBench fp8_gemm")
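Aside on the gate above: a minimal sketch (not taken from this patch) of the same CUDA 12.9 check done with a numeric parse. The patch compares torch.version.cuda as a string, which is lexicographic, so a hypothetical "12.10" would sort below "12.9"; the helper name cuda_at_least is illustrative only.

import torch

# Illustrative helper (assumption, not part of the patch): compare the CUDA
# runtime version numerically instead of as a string.
def cuda_at_least(major: int, minor: int) -> bool:
    if not (torch.cuda.is_available() and torch.version.cuda):
        return False
    got_major, got_minor = (int(v) for v in torch.version.cuda.split(".")[:2])
    return (got_major, got_minor) >= (major, minor)

HAS_CUDA_129 = cuda_at_least(12, 9)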
@@ -63,6 +67,10 @@ def get_scaling_recipe_int(scaling_recipe: str) -> int:
        return ScalingType.TensorWise
    elif scaling_recipe == "RowWise":
        return ScalingType.RowWise
+    elif scaling_recipe == "BlockWise1x128":
+        return ScalingType.BlockWise1x128
+    elif scaling_recipe == "BlockWise128x128":
+        return ScalingType.BlockWise128x128
    else:
        raise ValueError(f"Invalid scaling recipe: {scaling_recipe}")

@@ -79,7 +87,7 @@ def _get_scale_per_tensor(
        # For tensor-wise scaling, kernel requires a float32 scale tensor
        if custom_scale:
            return torch.tensor(custom_scale, dtype=torch.float32, device=x.device)
-        scale = torch.finfo(torch.float8_e4m3fn).max / x.abs().max()
+        scale = (torch.finfo(torch.float8_e4m3fn).max / x.abs().max()).reciprocal()
        x *= scale
        return x, scale.to(torch.float32)

@@ -90,22 +98,46 @@ def _get_scale_per_row(
            scale = (
                torch.finfo(torch.float8_e4m3fn).max
                / x.abs().max(dim=0, keepdim=True).values
-            )
+            ).reciprocal()
        else:  # scale_a.shape should be [M, 1]
            scale = (
                torch.finfo(torch.float8_e4m3fn).max
                / x.abs().max(dim=1, keepdim=True).values
-            )
+            ).reciprocal()
        x = x.mul(scale)
        return x, scale.to(
            torch.float32
        )  # For row-wise scaling, kernel requires a float32 scale tensor

+    def _get_scale_per_block(
+        x: torch.Tensor, block_outer: int, block_inner: int
+    ) -> (torch.Tensor, torch.Tensor):
+        x = x.unflatten(1, (-1, block_inner)).unflatten(0, (-1, block_outer))
+        amax = x.abs().amax(dim=[1, 3], keepdim=True).float()
+        scale = (
+            torch.finfo(torch.float8_e4m3fn).max / amax
+        ).reciprocal()  # keeps scale small enough such that scaling doesn't cause inf values
+        x = (
+            x.mul(scale).flatten(2, 3).flatten(0, 1)
+        )  # scale input up to dynamic range of float8_e4m3fn
+        scale = scale.flatten(2, 3).flatten(0, 1)
+
+        if block_outer == 1 and block_inner == 128:
+            scale = (
+                scale.t().contiguous().t()
+            )  # 1x128 blocks need scales to be outer-dim-major
+
+        return x, scale.to(torch.float32)
+
    match scaling_recipe_int:
        case ScalingType.TensorWise:
            return _get_scale_per_tensor(x, custom_scale=custom_scale)
        case ScalingType.RowWise:
            return _get_scale_per_row(x, transpose=transpose)
+        case ScalingType.BlockWise1x128:
+            return _get_scale_per_block(x, 1, 128)
+        case ScalingType.BlockWise128x128:
+            return _get_scale_per_block(x, 128, 128)
        case _:
            raise AssertionError(f"Unsupported scaling type {scaling_recipe_int}")

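To make the block layouts above concrete, here is a minimal standalone sketch of the same amax/scale computation as _get_scale_per_block, with illustrative shapes (an assumption, not part of the patch): for an M x N input it yields one scale per 1x128 block (shape (M, N/128)) or per 128x128 block (shape (M/128, N/128)), assuming both dimensions divide evenly by the block sizes.

import torch

# Standalone sketch of the per-block scale computation (mirrors
# _get_scale_per_block above); M and N are illustrative assumptions.
M, N = 256, 512
x = torch.randn(M, N)

for block_outer, block_inner in [(1, 128), (128, 128)]:
    blocks = x.unflatten(1, (-1, block_inner)).unflatten(0, (-1, block_outer))
    amax = blocks.abs().amax(dim=[1, 3], keepdim=True).float()
    scale = (torch.finfo(torch.float8_e4m3fn).max / amax).reciprocal()
    scale = scale.flatten(2, 3).flatten(0, 1)
    print((block_outer, block_inner), scale.shape)
    # (1, 128)   -> torch.Size([256, 4])
    # (128, 128) -> torch.Size([2, 4])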
@@ -131,6 +163,19 @@ def __init__(
        self.scaling_recipe_a_int = get_scaling_recipe_int(scaling_recipe_a).value
        self.scaling_recipe_b_int = get_scaling_recipe_int(scaling_recipe_b).value

+        blockwise_scaling_types = [
+            ScalingType.BlockWise1x128,
+            ScalingType.BlockWise128x128,
+        ]
+        self.contains_blockwise_scaling = (
+            self.scaling_recipe_a_int in blockwise_scaling_types
+            or self.scaling_recipe_b_int in blockwise_scaling_types
+        )
+
+        self.use_fast_accum = (
+            False if self.contains_blockwise_scaling else True
+        )  # BlockWise scaled_gemm does not support use_fast_accum=True
+
    def _get_dtype(self):
        if (
            self.scaling_recipe_a_int == ScalingType.TensorWise
@@ -193,12 +238,16 @@ def get_x_val(self, example_inputs) -> float:

    @register_benchmark(baseline=True)
    def torch_fp8_gemm(self, a, b, scale_a, scale_b):
+        assert (
+            not self.contains_blockwise_scaling or HAS_CUDA_129
+        ), "BlockWise scaling variants for scaled_gemm require CUDA 12.9+"
+
        return lambda: torch._scaled_mm(
            a,
            b.t(),
            scale_a,
            scale_b.t(),
-            use_fast_accum=True,
+            use_fast_accum=self.use_fast_accum,
            out_dtype=self._get_dtype(),
        )

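For context, a hedged sketch of the tensor-wise path this benchmark exercises: quantize both operands to float8_e4m3fn with a single dequantization scale each, then call torch._scaled_mm the same way torch_fp8_gemm does. It assumes an fp8-capable CUDA GPU (compute capability 8.9+); shapes and the helper name are illustrative, not taken from the patch.

import torch

def quantize_tensorwise(t: torch.Tensor):
    # One dequant scale for the whole tensor: amax / fp8_max.
    scale = t.abs().max().float() / torch.finfo(torch.float8_e4m3fn).max
    return (t / scale).to(torch.float8_e4m3fn), scale

if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9):
    a = torch.randn(256, 128, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(256, 128, device="cuda", dtype=torch.bfloat16)
    a_fp8, scale_a = quantize_tensorwise(a)
    b_fp8, scale_b = quantize_tensorwise(b)
    out = torch._scaled_mm(
        a_fp8,
        b_fp8.t(),  # second operand must be column-major
        scale_a,
        scale_b,
        out_dtype=torch.bfloat16,
        use_fast_accum=True,
    )
    print(out.shape)  # torch.Size([256, 256])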
@@ -215,7 +264,7 @@ def pt2_fp8_gemm(self, a, b, scale_a, scale_b) -> Callable:
                b.t(),
                scale_a,
                scale_b.t(),
-                use_fast_accum=True,
+                use_fast_accum=self.use_fast_accum,
                out_dtype=self._get_dtype(),
            )
            compiled = torch.compile(f, dynamic=False)