@@ -41,6 +41,8 @@ def parse_args(args):
     parser.add_argument("--m", type=int)
     parser.add_argument("--k", type=int)
     parser.add_argument("--n", type=int)
+    parser.add_argument("--per-tensor-scale-a", type=float, default=None)
+    parser.add_argument("--per-tensor-scale-b", type=float, default=None)
     return parser.parse_args(args)
 
 
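The two new flags map onto the argparse attributes per_tensor_scale_a and per_tensor_scale_b (argparse converts hyphens to underscores). A minimal sketch of how they flow through parse_args, assuming the module is imported standalone; the numeric values are made up:

    args = parse_args(["--m", "4096", "--k", "4096", "--n", "4096",
                       "--per-tensor-scale-a", "448.0"])
    assert args.per_tensor_scale_a == 448.0   # parsed as float
    assert args.per_tensor_scale_b is None    # default when the flag is omitted
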
@@ -54,12 +56,25 @@ def __init__(
         super().__init__(tb_args, extra_args)
         self.extra_args = parse_args(extra_args)
 
+    def _get_dtype(self):
+        if self.extra_args.scaling_rowwise:
+            return torch.bfloat16
+        else:
+            return torch.float16
+
     def get_input_iter(self):
+        def _get_scale_per_tensor(x: torch.Tensor, custom_scale: float = None) -> torch.Tensor:
+            # For tensor-wise scaling, kernel requires a float32 scale tensor
+            if custom_scale:
+                return torch.tensor(custom_scale, dtype=torch.float32, device=x.device)
+            scale = torch.finfo(torch.float8_e4m3fn).max / x.abs().max()
+            return scale.to(torch.float32)
+
         def args(m, n, k):
-            a = torch.randn(m, k, device=self.device).to(torch.float16)
+            a = torch.randn(m, k, device=self.device).to(self._get_dtype())
             b = (
                 torch.randn(k, n, device=self.device)
-                .to(torch.float16)
+                .to(self._get_dtype())
                 .T.contiguous()
                 .T
             )
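
For reference, a self-contained sketch (not part of the diff) of the per-tensor scale math that _get_scale_per_tensor implements: when no custom value is supplied on the command line, the scale is float8_e4m3fn's maximum representable value (448.0) divided by the tensor's absolute maximum.

    import torch

    x = torch.randn(64, 64) * 3.0
    fp8_max = torch.finfo(torch.float8_e4m3fn).max       # 448.0 for e4m3fn
    scale = (fp8_max / x.abs().max()).to(torch.float32)  # same formula as the helper above
    print(scale)  # float32 scalar tensor, later passed to the kernels as scale_a / scale_b
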
@@ -69,8 +84,8 @@ def args(m, n, k):
                 scale_a = torch.ones((M, 1), dtype=torch.float32, device=a.device)
                 scale_b = torch.ones((1, N), dtype=torch.float32, device=b.device)
             else:
-                scale_a = torch.tensor(1.0, device=a.device)
-                scale_b = torch.tensor(1.0, device=a.device)
+                scale_a = _get_scale_per_tensor(a, custom_scale=self.extra_args.per_tensor_scale_a)
+                scale_b = _get_scale_per_tensor(b, custom_scale=self.extra_args.per_tensor_scale_b)
 
             # Kernels expect dtype=float8_e4m3fn
             a = a.to(torch.float8_e4m3fn)
@@ -108,16 +123,10 @@ def get_x_val(self, example_inputs) -> float:
         _, n = b.size()
         return (m, n, k)
 
-    def _get_out_dtype(self):
-        if self.extra_args.scaling_rowwise:
-            return torch.bfloat16
-        else:
-            return torch.float16
-
     @register_benchmark(baseline=True)
     def torch_fp8_gemm(self, a, b, scale_a, scale_b):
         return lambda: torch._scaled_mm(
-            a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=self._get_out_dtype()
+            a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=self._get_dtype()
         )
 
     @register_benchmark()
@@ -129,7 +138,7 @@ def pt2_fp8_gemm(self, a, b, scale_a, scale_b) -> Callable:
             autotune_fallback_to_aten=False,
         ):
             f = lambda a, b: torch._scaled_mm(
-                a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=self._get_out_dtype()
+                a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=self._get_dtype()
             )
             compiled = torch.compile(f, dynamic=False)
             compiled(a, b)
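
Putting the pieces together, a rough standalone sketch of the call both benchmarks now issue. It mirrors the positional torch._scaled_mm call in this diff; the private _scaled_mm API has shifted between PyTorch releases and fp8 matmul needs a GPU with float8 support, so treat this as an illustration rather than a guaranteed recipe.

    import torch

    m = k = n = 128
    device = "cuda"  # assumption: an fp8-capable GPU

    a = torch.randn(m, k, device=device, dtype=torch.float16)
    # Second operand laid out column-major, matching the .T.contiguous().T pattern above
    b = torch.randn(k, n, device=device, dtype=torch.float16).T.contiguous().T

    scale_a = (torch.finfo(torch.float8_e4m3fn).max / a.abs().max()).to(torch.float32)
    scale_b = (torch.finfo(torch.float8_e4m3fn).max / b.abs().max()).to(torch.float32)

    out = torch._scaled_mm(
        a.to(torch.float8_e4m3fn), b.to(torch.float8_e4m3fn),
        scale_a, scale_b,
        use_fast_accum=True,
        out_dtype=torch.float16,  # what _get_dtype() returns when scaling_rowwise is off
    )
    print(out.shape)  # torch.Size([128, 128])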