30
30
aten = torch .ops .aten
31
31
quantize_dtypes = [torch .uint8 , torch .float8_e4m3fn ]
32
32
33
+
33
34
def _is_valid_qsdpa_pattern ():
34
35
def fn (match ):
35
36
assert all (k in match .kwargs for k in ("query" , "key" , "value" ))
@@ -117,7 +118,9 @@ def qsdpa(match: Match, *args, **kwargs):
117
118
return qsdpa
118
119
119
120
120
- def _generate_dequant_pattern (input_pattern , qtype , is_reduced_type , scale : str , zp : str = None ):
121
+ def _generate_dequant_pattern (
122
+ input_pattern , qtype , is_reduced_type , scale : str , zp : str = None
123
+ ):
121
124
if qtype == torch .uint8 :
122
125
assert zp is not None , "Zero point must be provided for uint8 dequantization"
123
126
return CallFunction (
@@ -146,7 +149,7 @@ def _generate_dequant_pattern(input_pattern, qtype, is_reduced_type, scale: str,
146
149
)
147
150
148
151
149
- def _generate_quant_pattern (input_pattern , qtype , scale : str , zp : str = None ):
152
+ def _generate_quant_pattern (input_pattern , qtype , scale : str , zp : str = None ):
150
153
if qtype == torch .uint8 :
151
154
assert zp is not None , "Zero point must be provided for uint8 quantization"
152
155
return CallFunction (
@@ -168,7 +171,11 @@ def _generate_quant_pattern(input_pattern, qtype, scale: str, zp: str=None):
168
171
169
172
170
173
def _get_qsdpa_qkv_pattern (
171
- qtype , is_batch_size_1 : bool , is_reduced_type : bool , has_convert : bool , input_name : str
174
+ qtype ,
175
+ is_batch_size_1 : bool ,
176
+ is_reduced_type : bool ,
177
+ has_convert : bool ,
178
+ input_name : str ,
172
179
):
173
180
assert input_name in ["query" , "key" , "value" ]
174
181
qsdpa_qkv_pattern_before_dequant = CallFunction (
@@ -221,7 +228,12 @@ def _get_qsdpa_qkv_pattern(
221
228
222
229
223
230
def _get_qsdpa_score_pattern (
224
- qtype , has_mask : bool , is_batch_size_1 : bool , is_reduced_type : bool , has_convert : bool , is_inv_scale : bool
231
+ qtype ,
232
+ has_mask : bool ,
233
+ is_batch_size_1 : bool ,
234
+ is_reduced_type : bool ,
235
+ has_convert : bool ,
236
+ is_inv_scale : bool ,
225
237
):
226
238
qsdpa_q_pattern = _get_qsdpa_qkv_pattern (
227
239
qtype , is_batch_size_1 , is_reduced_type , has_convert , "query"
@@ -276,7 +288,12 @@ def _get_qsdpa_score_pattern(
276
288
277
289
278
290
def _get_qsdpa_exp_pattern (
279
- qtype , has_mask : bool , is_batch_size_1 : bool , is_reduced_type : bool , has_convert : bool , is_inv_scale : bool
291
+ qtype ,
292
+ has_mask : bool ,
293
+ is_batch_size_1 : bool ,
294
+ is_reduced_type : bool ,
295
+ has_convert : bool ,
296
+ is_inv_scale : bool ,
280
297
):
281
298
qsdpa_score_pattern = _get_qsdpa_score_pattern (
282
299
qtype , has_mask , is_batch_size_1 , is_reduced_type , has_convert , is_inv_scale
@@ -298,15 +315,15 @@ def _get_qsdpa_exp_pattern(
298
315
_users = 2 ,
299
316
)
300
317
elif is_inv_scale :
301
- return CallFunction (
302
- aten .exp .default ,
303
- CallFunction (
304
- aten .div .Tensor ,
305
- qsdpa_exp_basic_pattern ,
306
- KeywordArg ("inv_scale" ),
307
- ),
308
- _users = 2 ,
309
- )
318
+ return CallFunction (
319
+ aten .exp .default ,
320
+ CallFunction (
321
+ aten .div .Tensor ,
322
+ qsdpa_exp_basic_pattern ,
323
+ KeywordArg ("inv_scale" ),
324
+ ),
325
+ _users = 2 ,
326
+ )
310
327
else :
311
328
return CallFunction (
312
329
aten .exp .default ,
@@ -320,7 +337,12 @@ def _get_qsdpa_exp_pattern(
320
337
321
338
322
339
def _get_qsdpa_attn_pattern (
323
- qtype , has_mask : bool , is_batch_size_1 : bool , is_reduced_type : bool , has_convert : bool , is_inv_scale : bool
340
+ qtype ,
341
+ has_mask : bool ,
342
+ is_batch_size_1 : bool ,
343
+ is_reduced_type : bool ,
344
+ has_convert : bool ,
345
+ is_inv_scale : bool ,
324
346
):
325
347
qsdpa_exp_pattern = _get_qsdpa_exp_pattern (
326
348
qtype , has_mask , is_batch_size_1 , is_reduced_type , has_convert , is_inv_scale
@@ -396,7 +418,12 @@ def _get_qsdpa_attn_pattern(
396
418
# has_convert: convert type if dequant out dtype is assigned
397
419
# is_inv_scale: whether the scale in SDPA is inverted, in which case it is multiplied instead of divided
398
420
def _get_qsdpa_final_pattern (
399
- qtype , has_mask : bool , is_batch_size_1 : bool , is_reduced_type : bool , has_convert : bool , is_inv_scale : bool
421
+ qtype ,
422
+ has_mask : bool ,
423
+ is_batch_size_1 : bool ,
424
+ is_reduced_type : bool ,
425
+ has_convert : bool ,
426
+ is_inv_scale : bool ,
400
427
):
401
428
qsdpa_v_pattern = _get_qsdpa_qkv_pattern (
402
429
qtype , is_batch_size_1 , is_reduced_type , has_convert , "value"
@@ -429,8 +456,20 @@ def _get_qsdpa_final_pattern(
429
456
430
457
431
458
def _register_qsdpa_lowerings (custom_pass_dict ):
432
- for qtype , has_mask , is_batch_size_1 , is_reduced_type , has_convert , is_inv_scale in itertools .product (
433
- quantize_dtypes , [True , False ], [True , False ], [True , False ], [True , False ], [True , False ]
459
+ for (
460
+ qtype ,
461
+ has_mask ,
462
+ is_batch_size_1 ,
463
+ is_reduced_type ,
464
+ has_convert ,
465
+ is_inv_scale ,
466
+ ) in itertools .product (
467
+ quantize_dtypes ,
468
+ [True , False ],
469
+ [True , False ],
470
+ [True , False ],
471
+ [True , False ],
472
+ [True , False ],
434
473
):
435
474
_register_qsdpa_pattern (
436
475
_get_qsdpa_final_pattern (
0 commit comments