@@ -41,6 +41,8 @@ def parse_args(args):
     parser.add_argument("--m", type=int)
     parser.add_argument("--k", type=int)
     parser.add_argument("--n", type=int)
+    parser.add_argument("--per-tensor-scale-a", type=float, default=None)
+    parser.add_argument("--per-tensor-scale-b", type=float, default=None)
     return parser.parse_args(args)
 
 
@@ -54,18 +56,53 @@ def __init__(
         super().__init__(tb_args, extra_args)
         self.extra_args = parse_args(extra_args)
 
+    def _get_dtype(self):
+        if self.extra_args.scaling_rowwise:
+            return torch.bfloat16
+        else:
+            return torch.float16
+
     def get_input_iter(self):
+        def _get_scale_per_tensor(
+            x: torch.Tensor, custom_scale: float = None
+        ) -> torch.Tensor:
+            # For tensor-wise scaling, kernel requires a float32 scale tensor
+            if custom_scale:
+                return torch.tensor(custom_scale, dtype=torch.float32, device=x.device)
+            scale = torch.finfo(torch.float8_e4m3fn).max / x.abs().max()
+            return scale.to(torch.float32)
+
+        def _get_scale_per_row(
+            x: torch.Tensor, transpose: bool = False
+        ) -> torch.Tensor:
+            if transpose:  # scale_b.shape should be [1, N]
+                scale = (
+                    torch.finfo(torch.float8_e4m3fn).max
+                    / x.abs().max(dim=0, keepdim=True).values
+                )
+            else:  # scale_a.shape should be [M, 1]
+                scale = (
+                    torch.finfo(torch.float8_e4m3fn).max
+                    / x.abs().max(dim=1, keepdim=True).values
+                )
+            return scale.to(
+                torch.float32
+            )  # For row-wise scaling, kernel requires a float32 scale tensor
+
         def args(m, n, k):
             a = torch.randn(m, k, device=self.device).to(torch.float16)
             b = torch.randn(k, n, device=self.device).to(torch.float16).T.contiguous().T
 
             if self.extra_args.scaling_rowwise:
-                M, N = a.shape[0], b.shape[1]
-                scale_a = torch.ones((M, 1), dtype=torch.float32, device=a.device)
-                scale_b = torch.ones((1, N), dtype=torch.float32, device=b.device)
+                scale_a = _get_scale_per_row(a)
+                scale_b = _get_scale_per_row(b, transpose=True)
             else:
-                scale_a = torch.tensor(1.0, device=a.device)
-                scale_b = torch.tensor(1.0, device=a.device)
+                scale_a = _get_scale_per_tensor(
+                    a, custom_scale=self.extra_args.per_tensor_scale_a
+                )
+                scale_b = _get_scale_per_tensor(
+                    b, custom_scale=self.extra_args.per_tensor_scale_b
+                )
 
             # Kernels expect dtype=float8_e4m3fn
             a = a.to(torch.float8_e4m3fn)
@@ -103,16 +140,10 @@ def get_x_val(self, example_inputs) -> float:
         _, n = b.size()
         return (m, n, k)
 
-    def _get_out_dtype(self):
-        if self.extra_args.scaling_rowwise:
-            return torch.bfloat16
-        else:
-            return torch.float16
-
     @register_benchmark(baseline=True)
     def torch_fp8_gemm(self, a, b, scale_a, scale_b):
         return lambda: torch._scaled_mm(
-            a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=self._get_out_dtype()
+            a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=self._get_dtype()
        )
 
     @register_benchmark()
@@ -129,7 +160,7 @@ def pt2_fp8_gemm(self, a, b, scale_a, scale_b) -> Callable:
                scale_a,
                scale_b,
                use_fast_accum=True,
-                out_dtype=self._get_out_dtype(),
+                out_dtype=self._get_dtype(),
            )
            compiled = torch.compile(f, dynamic=False)
            compiled(a, b)
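
For reviewers who want to exercise the new scaling path outside the benchmark harness, the sketch below mirrors what the diff now does for the per-tensor case: derive a float32 scale from the fp8 e4m3 range, cast the operands, and call `torch._scaled_mm`. This is a minimal illustration, not part of the PR, and it assumes a recent PyTorch where `torch._scaled_mm` takes the scales positionally and returns a single tensor, plus an fp8-capable GPU (e.g., H100); shapes and dtypes are placeholders.

```python
import torch

# Minimal sketch mirroring the benchmark's per-tensor scaling setup
# (assumes an fp8-capable GPU and a recent torch._scaled_mm signature).
m, k, n = 1024, 1024, 1024
device = "cuda"

a = torch.randn(m, k, device=device, dtype=torch.float16)
# _scaled_mm expects the second operand in column-major layout,
# hence the .T.contiguous().T trick used in the diff as well.
b = torch.randn(k, n, device=device, dtype=torch.float16).T.contiguous().T

# Per-tensor scales: map each tensor's max magnitude onto the e4m3 range,
# the same formula as _get_scale_per_tensor in the diff.
fp8_max = torch.finfo(torch.float8_e4m3fn).max
scale_a = (fp8_max / a.abs().max()).to(torch.float32)
scale_b = (fp8_max / b.abs().max()).to(torch.float32)

out = torch._scaled_mm(
    a.to(torch.float8_e4m3fn),
    b.to(torch.float8_e4m3fn),
    scale_a,
    scale_b,
    use_fast_accum=True,
    out_dtype=torch.float16,
)
print(out.shape)  # torch.Size([1024, 1024])
```

With the new flags, the benchmark can instead pin fixed scales (e.g. `--per-tensor-scale-a 1.0 --per-tensor-scale-b 1.0`), which makes `_get_scale_per_tensor` skip the amax reduction shown above and wrap the supplied value in a float32 tensor.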