@@ -56,14 +56,22 @@ def __init__(
 
     def get_input_iter(self):
         def args(m, n, k):
-            a = torch.randn(m, k, device=self.device).to(torch.float8_e4m3fn)
-            b = (
-                torch.randn(k, n, device=self.device)
-                .to(torch.float8_e4m3fn)
-                .T.contiguous()
-                .T
-            )
-            return (a, b)
+            a = torch.randn(m, k, device=self.device).to(torch.float16)
+            b = torch.randn(k, n, device=self.device).to(torch.float16).T.contiguous().T
+
+            if self.extra_args.scaling_rowwise:
+                M, N = a.shape[0], b.shape[1]
+                scale_a = torch.ones((M, 1), dtype=torch.float32, device=a.device)
+                scale_b = torch.ones((1, N), dtype=torch.float32, device=b.device)
+            else:
+                scale_a = torch.tensor(1.0, device=a.device)
+                scale_b = torch.tensor(1.0, device=a.device)
+
+            # Kernels expect dtype=float8_e4m3fn
+            a = a.to(torch.float8_e4m3fn)
+            b = b.to(torch.float8_e4m3fn)
+
+            return (a, b, scale_a, scale_b)
 
         if (
            hasattr(self, "external_shapes") and self.external_shapes
@@ -90,62 +98,54 @@ def args(m, n, k):
                     yield args(m, n, k)
 
     def get_x_val(self, example_inputs) -> float:
-        a, b = example_inputs
+        a, b, _, _ = example_inputs
         m, k = a.size()
         _, n = b.size()
         return (m, n, k)
 
-    @register_benchmark(baseline=True)
-    def torch_fp8_gemm(self, a, b):
+    def _get_out_dtype(self):
         if self.extra_args.scaling_rowwise:
-            M, N = a.shape[0], b.shape[1]
-            scale_a = torch.ones((M, 1), dtype=torch.float32, device=a.device)
-            scale_b = torch.ones((1, N), dtype=torch.float32, device=b.device)
-            out_dtype = torch.bfloat16
+            return torch.bfloat16
         else:
-            scale_a = torch.tensor(1.0, device=a.device)
-            scale_b = torch.tensor(1.0, device=a.device)
-            out_dtype = torch.float16
 
+    @register_benchmark(baseline=True)
+    def torch_fp8_gemm(self, a, b, scale_a, scale_b):
         return lambda: torch._scaled_mm(
-            a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=out_dtype
+            a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=self._get_out_dtype()
        )
 
     @register_benchmark()
-    def pt2_fp8_gemm(self, a, b) -> Callable:
+    def pt2_fp8_gemm(self, a, b, scale_a, scale_b) -> Callable:
         torch._dynamo.reset()
         with inductor_config.patch(
             max_autotune=True,
             max_autotune_gemm_backends="TRITON",
             autotune_fallback_to_aten=False,
         ):
-            if self.extra_args.scaling_rowwise:
-                M, N = a.shape[0], b.shape[1]
-                scale_a = torch.ones((M, 1), dtype=torch.float32, device=a.device)
-                scale_b = torch.ones((1, N), dtype=torch.float32, device=b.device)
-                out_dtype = torch.bfloat16
-            else:
-                scale_a = torch.tensor(1.0, device=a.device)
-                scale_b = torch.tensor(1.0, device=b.device)
-                out_dtype = torch.float16
             f = lambda a, b: torch._scaled_mm(
-                a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=out_dtype
+                a,
+                b,
+                scale_a,
+                scale_b,
+                use_fast_accum=True,
+                out_dtype=self._get_out_dtype(),
             )
             compiled = torch.compile(f, dynamic=False)
             compiled(a, b)
 
         return lambda: compiled(a, b)
 
     @register_benchmark()
-    def triton_fp8_gemm(self, a, b):
+    def triton_fp8_gemm(self, a, b, scale_a, scale_b):
         return lambda: tutorial_matmul(a, b)
 
     @register_benchmark(enabled=HAS_TMA)
-    def triton_persistent_fp8_gemm(self, a, b):
+    def triton_persistent_fp8_gemm(self, a, b, scale_a, scale_b):
         return lambda: matmul_persistent(a, b)
 
     @register_benchmark(enabled=HAS_TMA)
-    def triton_tma_persistent_fp8_gemm(self, a, b):
+    def triton_tma_persistent_fp8_gemm(self, a, b, scale_a, scale_b):
         b = b.T.contiguous()
         c, desc_a, desc_b, desc_c = allocate_matmul_tma(a, b)
         return lambda: matmul_tma_persistent(a, b, c, desc_a, desc_b, desc_c)
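Annotation (not part of the diff): with args() now returning (a, b, scale_a, scale_b), every registered benchmark receives the scales explicitly and the bfloat16/float16 output choice is centralized in _get_out_dtype(). A hedged correctness sketch, not from the PR and with an ad hoc helper name, that dequantizes the fp8 operands the same way the scales broadcast (scale_a over rows of a, scale_b over columns of b, or both as scalars) and compares against torch._scaled_mm:

    import torch

    def ref_check(a_fp8, b_fp8, scale_a, scale_b, out_dtype):
        # Dequantize with the same broadcasting the scaled GEMM applies.
        a_ref = a_fp8.to(torch.float32) * scale_a
        b_ref = b_fp8.to(torch.float32) * scale_b
        expected = a_ref @ b_ref
        actual = torch._scaled_mm(
            a_fp8, b_fp8, scale_a, scale_b, use_fast_accum=True, out_dtype=out_dtype
        )
        # Loose tolerances: accumulation order and fast-accum differ from the fp32 reference.
        torch.testing.assert_close(
            actual.to(torch.float32), expected, rtol=5e-2, atol=5e-2
        )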
@@ -155,7 +155,7 @@ def gbps(self, fn, example_inputs: Any, metrics: BenchmarkOperatorMetrics) -> fl
         def nbytes(t):
             return t.numel() * t.element_size()
 
-        a, b = example_inputs
+        a, b, _, _ = example_inputs
         c = fn()
         c = c[0] if isinstance(c, tuple) else c
 
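Annotation (not part of the diff): the gbps metric only needs the operand byte counts plus the measured latency; the return expression falls outside the visible hunk, but a plausible reading, assuming latency in milliseconds, is the following sketch:

    # Assumed shape of the gbps computation; nbytes matches the helper above.
    def gbps_from(a, b, c, latency_ms):
        def nbytes(t):
            return t.numel() * t.element_size()
        # Bytes moved: two fp8 inputs (1 byte/element) plus the fp16/bf16 output.
        total_bytes = nbytes(a) + nbytes(b) + nbytes(c)
        return total_bytes / (latency_ms * 1e-3) / 1e9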
@@ -168,7 +168,7 @@ def nbytes(t):
     def flops(
         self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
     ) -> float:
-        a, b = example_inputs
+        a, b, _, _ = example_inputs
         m, k = a.size()
         _, n = b.size()
         flops = 2 * m * n * k
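Annotation (not part of the diff): flops = 2 * m * n * k is the usual GEMM operation count (one multiply and one add per k-step per output element); turning it into a rate only needs the measured latency. Illustrative arithmetic with hypothetical numbers, not taken from the benchmark:

    m, n, k = 4096, 4096, 4096
    latency_ms = 0.85                        # hypothetical measurement
    flops = 2 * m * n * k                    # ~137.4 GFLOP for this shape
    tflops = flops / (latency_ms * 1e-3) / 1e12
    print(f"{tflops:.1f} TFLOPS")            # ~161.7 TFLOPS at 0.85 ms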