@@ -154,51 +154,101 @@ def _scaled_dot_product_int8_op_ref(
        out = torch.clamp(torch.round(out / o_scale) + o_zp, min=0, max=255)
        return out.to(torch.uint8)

+    def _scaled_dot_product_fp8_op_ref(
+        self,
+        q,
+        k,
+        v,
+        attn_mask=None,
+        dropout_p=0,
+        is_causal=False,
+        q_scale=1.0,
+        k_scale=1.0,
+        v_scale=1.0,
+        a_scale=1.0,
+        o_scale=1.0,
+    ):
+        q = q.to(torch.float) * q_scale
+        k = k.to(torch.float) * k_scale
+        v = v.to(torch.float) * v_scale
+        scale_factor = 1 / math.sqrt(q.size(-1))
+        attn = q @ k.transpose(-2, -1)
+
+        attn = attn * scale_factor
+        if attn_mask is not None:
+            attn = attn + attn_mask.to(torch.float)
+        attn_max = attn.max(dim=-1, keepdim=True).values
+        attn = attn - attn_max
+        attn = torch.exp(attn)
+        attn_sum = torch.sum(attn, dim=-1, keepdim=True)
+        attn = attn / attn_sum
+        attn = torch.clamp(attn / a_scale, min=-448, max=448)
+        attn = attn.to(torch.float8_e4m3fn).to(torch.float)
+        attn = attn * a_scale
+        out = attn @ v
+        out = torch.clamp(out / o_scale, min=-448, max=448)
+        return out.to(torch.float8_e4m3fn)
+
    @pytest.mark.skipif(
        not torch_version_at_least("2.7.0"),
-        reason="int8 sdpa requires torch 2.7 or later",
+        reason="quantized sdpa requires torch 2.7 or later",
    )
    @pytest.mark.skipif(not IS_LINUX, reason="only support on linux")
    @pytest.mark.skipif(
        "CPU" not in torch._C._dispatch_dump("torchao::qscaled_dot_product"),
        reason="cpp kernels not built",
    )
+    @parametrize("input_dtype", [torch.uint8, torch.float8_e4m3fn])
    @parametrize("batch_size", [56, 120])
    @parametrize("n_head", [2, 16])
    @parametrize("q_seq_len", [18, 89])
    @parametrize("kv_seq_len", [100, 253])
    @parametrize("head_dim", [32, 64])
    @parametrize("mask_dtype", [None, torch.float32, torch.bfloat16])
-    def test_scaled_dot_product_int8_op(
-        self, batch_size, n_head, q_seq_len, kv_seq_len, head_dim, mask_dtype
+    def test_quantized_scaled_dot_product_op(
+        self,
+        input_dtype,
+        batch_size,
+        n_head,
+        q_seq_len,
+        kv_seq_len,
+        head_dim,
+        mask_dtype,
    ):
        torch.manual_seed(1234)
        device = "cpu"
-        q_scale = float(1.7907238006591797)
-        q_zp = int(127)
-        k_scale = float(1.8039721250534058)
-        k_zp = int(125)
-        v_scale = float(1.839004635810852)
-        v_zp = int(127)
-        a_scale = float(0.003919653594493866)
-        a_zp = int(120)
-        o_scale = float(1.8191684484481812)
-        o_zp = int(128)
+        if input_dtype == torch.uint8:
+            q_scale = float(1.7907238006591797)
+            k_scale = float(1.8039721250534058)
+            v_scale = float(1.839004635810852)
+            a_scale = float(0.003919653594493866)
+            o_scale = float(1.8191684484481812)
+            q_zp = int(127)
+            k_zp = int(125)
+            v_zp = int(127)
+            a_zp = int(120)
+            o_zp = int(128)
+            atol, rtol = 1.0, 5e-6
+        else:
+            q_scale = float(5.96875)
+            k_scale = float(5.78125)
+            v_scale = float(0.98046875)
+            a_scale = float(4.84375)
+            o_scale = float(3.171875)
+            atol, rtol = 0.125, 5e-6
        q_shape = [batch_size, q_seq_len, n_head, head_dim]
        kv_shape = [batch_size, kv_seq_len, n_head, head_dim]
        mask_shape = [batch_size, 1, 1, kv_seq_len]
-        q = torch.randn(q_shape, dtype=torch.float, device=device).transpose(1, 2) * 100
-        k = (
-            torch.randn(kv_shape, dtype=torch.float, device=device).transpose(1, 2)
-            * 100
-        )
-        v = (
-            torch.randn(kv_shape, dtype=torch.float, device=device).transpose(1, 2)
-            * 100
-        )
-        q = q.to(torch.uint8)
-        k = k.to(torch.uint8)
-        v = v.to(torch.uint8)
+        q = torch.randn(q_shape, dtype=torch.float, device=device).transpose(1, 2)
+        k = torch.randn(kv_shape, dtype=torch.float, device=device).transpose(1, 2)
+        v = torch.randn(kv_shape, dtype=torch.float, device=device).transpose(1, 2)
+        if input_dtype == torch.uint8:
+            q *= 100
+            k *= 100
+            v *= 100
+        q = q.to(input_dtype)
+        k = k.to(input_dtype)
+        v = v.to(input_dtype)
        attn_mask = (
            torch.randn(mask_shape, dtype=mask_dtype, device=device)
            if mask_dtype is not None
@@ -211,44 +261,71 @@ def test_scaled_dot_product_int8_op(
            attn_mask.clone() if mask_dtype is not None else None,
        )

-        math_ref = self._scaled_dot_product_int8_op_ref(
-            q2,
-            k2,
-            v2,
-            attn_mask=attn_mask,
-            dropout_p=0.0,
-            is_causal=False,
-            q_scale=q_scale,
-            q_zp=q_zp,
-            k_scale=k_scale,
-            k_zp=k_zp,
-            v_scale=v_scale,
-            v_zp=v_zp,
-            a_scale=a_scale,
-            a_zp=a_zp,
-            o_scale=o_scale,
-            o_zp=o_zp,
-        )
-        actual = torch.ops.torchao.qscaled_dot_product(
-            q,
-            k,
-            v,
-            attn_mask=attn_mask_2,
-            dropout_p=0.0,
-            is_causal=False,
-            q_scale=q_scale,
-            q_zp=q_zp,
-            k_scale=k_scale,
-            k_zp=k_zp,
-            v_scale=v_scale,
-            v_zp=v_zp,
-            a_scale=a_scale,
-            a_zp=a_zp,
-            o_scale=o_scale,
-            o_zp=o_zp,
-        )
-
-        self.assertEqual(actual, math_ref, atol=1.0, rtol=5e-6)
+        if input_dtype == torch.uint8:
+            math_ref = self._scaled_dot_product_int8_op_ref(
+                q2,
+                k2,
+                v2,
+                attn_mask=attn_mask,
+                dropout_p=0.0,
+                is_causal=False,
+                q_scale=q_scale,
+                q_zp=q_zp,
+                k_scale=k_scale,
+                k_zp=k_zp,
+                v_scale=v_scale,
+                v_zp=v_zp,
+                a_scale=a_scale,
+                a_zp=a_zp,
+                o_scale=o_scale,
+                o_zp=o_zp,
+            )
+            actual = torch.ops.torchao.qscaled_dot_product(
+                q,
+                k,
+                v,
+                attn_mask=attn_mask_2,
+                dropout_p=0.0,
+                is_causal=False,
+                q_scale=q_scale,
+                q_zp=q_zp,
+                k_scale=k_scale,
+                k_zp=k_zp,
+                v_scale=v_scale,
+                v_zp=v_zp,
+                a_scale=a_scale,
+                a_zp=a_zp,
+                o_scale=o_scale,
+                o_zp=o_zp,
+            )
+        else:
+            math_ref = self._scaled_dot_product_fp8_op_ref(
+                q2,
+                k2,
+                v2,
+                attn_mask=attn_mask,
+                dropout_p=0.0,
+                is_causal=False,
+                q_scale=q_scale,
+                k_scale=k_scale,
+                v_scale=v_scale,
+                a_scale=a_scale,
+                o_scale=o_scale,
+            )
+            actual = torch.ops.torchao.qscaled_dot_product(
+                q,
+                k,
+                v,
+                attn_mask=attn_mask_2,
+                dropout_p=0.0,
+                is_causal=False,
+                q_scale=q_scale,
+                k_scale=k_scale,
+                v_scale=v_scale,
+                a_scale=a_scale,
+                o_scale=o_scale,
+            )
+        self.assertEqual(actual.float(), math_ref.float(), atol=atol, rtol=rtol)


instantiate_parametrized_tests(TestOps)
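
Note: the 448 used in the clamps of the new fp8 reference is the largest finite value representable in torch.float8_e4m3fn; the reference simply scales, clamps to that range, and round-trips through fp8 before rescaling. A minimal standalone sketch of that quantize/dequantize step (illustrative only; the helper name below is not part of this patch):

import torch

def fp8_quant_dequant(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Scale down, clamp to the finite float8_e4m3fn range (+/-448),
    # cast to fp8 and back to float, then rescale.
    x = torch.clamp(x / scale, min=-448, max=448)
    return x.to(torch.float8_e4m3fn).to(torch.float) * scale

x = torch.randn(4, 8)
print((x - fp8_quant_dequant(x, scale=1.0)).abs().max())  # small fp8 rounding error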