Skip to content

Commit f5e4a43

Browse files
committed
fix format
1 parent c5d1d59 commit f5e4a43

File tree

2 files changed

+58
-18
lines changed

2 files changed

+58
-18
lines changed

torchao/prototype/inductor/fx_passes/qsdpa_fusion.py

Lines changed: 57 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
aten = torch.ops.aten
3131
quantize_dtypes = [torch.uint8, torch.float8_e4m3fn]
3232

33+
3334
def _is_valid_qsdpa_pattern():
3435
def fn(match):
3536
assert all(k in match.kwargs for k in ("query", "key", "value"))
@@ -117,7 +118,9 @@ def qsdpa(match: Match, *args, **kwargs):
117118
return qsdpa
118119

119120

120-
def _generate_dequant_pattern(input_pattern, qtype, is_reduced_type, scale: str, zp: str=None):
121+
def _generate_dequant_pattern(
122+
input_pattern, qtype, is_reduced_type, scale: str, zp: str = None
123+
):
121124
if qtype == torch.uint8:
122125
assert zp is not None, "Zero point must be provided for uint8 dequantization"
123126
return CallFunction(
@@ -146,7 +149,7 @@ def _generate_dequant_pattern(input_pattern, qtype, is_reduced_type, scale: str,
146149
)
147150

148151

149-
def _generate_quant_pattern(input_pattern, qtype, scale: str, zp: str=None):
152+
def _generate_quant_pattern(input_pattern, qtype, scale: str, zp: str = None):
150153
if qtype == torch.uint8:
151154
assert zp is not None, "Zero point must be provided for uint8 quantization"
152155
return CallFunction(
@@ -168,7 +171,11 @@ def _generate_quant_pattern(input_pattern, qtype, scale: str, zp: str=None):
168171

169172

170173
def _get_qsdpa_qkv_pattern(
171-
qtype, is_batch_size_1: bool, is_reduced_type: bool, has_convert: bool, input_name: str
174+
qtype,
175+
is_batch_size_1: bool,
176+
is_reduced_type: bool,
177+
has_convert: bool,
178+
input_name: str,
172179
):
173180
assert input_name in ["query", "key", "value"]
174181
qsdpa_qkv_pattern_before_dequant = CallFunction(
@@ -221,7 +228,12 @@ def _get_qsdpa_qkv_pattern(
221228

222229

223230
def _get_qsdpa_score_pattern(
224-
qtype, has_mask: bool, is_batch_size_1: bool, is_reduced_type: bool, has_convert: bool, is_inv_scale: bool
231+
qtype,
232+
has_mask: bool,
233+
is_batch_size_1: bool,
234+
is_reduced_type: bool,
235+
has_convert: bool,
236+
is_inv_scale: bool,
225237
):
226238
qsdpa_q_pattern = _get_qsdpa_qkv_pattern(
227239
qtype, is_batch_size_1, is_reduced_type, has_convert, "query"
@@ -276,7 +288,12 @@ def _get_qsdpa_score_pattern(
276288

277289

278290
def _get_qsdpa_exp_pattern(
279-
qtype, has_mask: bool, is_batch_size_1: bool, is_reduced_type: bool, has_convert: bool, is_inv_scale: bool
291+
qtype,
292+
has_mask: bool,
293+
is_batch_size_1: bool,
294+
is_reduced_type: bool,
295+
has_convert: bool,
296+
is_inv_scale: bool,
280297
):
281298
qsdpa_score_pattern = _get_qsdpa_score_pattern(
282299
qtype, has_mask, is_batch_size_1, is_reduced_type, has_convert, is_inv_scale
@@ -298,15 +315,15 @@ def _get_qsdpa_exp_pattern(
298315
_users=2,
299316
)
300317
elif is_inv_scale:
301-
return CallFunction(
302-
aten.exp.default,
303-
CallFunction(
304-
aten.div.Tensor,
305-
qsdpa_exp_basic_pattern,
306-
KeywordArg("inv_scale"),
307-
),
308-
_users=2,
309-
)
318+
return CallFunction(
319+
aten.exp.default,
320+
CallFunction(
321+
aten.div.Tensor,
322+
qsdpa_exp_basic_pattern,
323+
KeywordArg("inv_scale"),
324+
),
325+
_users=2,
326+
)
310327
else:
311328
return CallFunction(
312329
aten.exp.default,
@@ -320,7 +337,12 @@ def _get_qsdpa_exp_pattern(
320337

321338

322339
def _get_qsdpa_attn_pattern(
323-
qtype, has_mask: bool, is_batch_size_1: bool, is_reduced_type: bool, has_convert: bool, is_inv_scale: bool
340+
qtype,
341+
has_mask: bool,
342+
is_batch_size_1: bool,
343+
is_reduced_type: bool,
344+
has_convert: bool,
345+
is_inv_scale: bool,
324346
):
325347
qsdpa_exp_pattern = _get_qsdpa_exp_pattern(
326348
qtype, has_mask, is_batch_size_1, is_reduced_type, has_convert, is_inv_scale
@@ -396,7 +418,12 @@ def _get_qsdpa_attn_pattern(
396418
# has_convert: convert type if dequant out dtype is assigned
397419
# is_inv_scale: if the scale in SDPA is inversed, in which case it is multiplied instead of divided
398420
def _get_qsdpa_final_pattern(
399-
qtype, has_mask: bool, is_batch_size_1: bool, is_reduced_type: bool, has_convert: bool, is_inv_scale: bool
421+
qtype,
422+
has_mask: bool,
423+
is_batch_size_1: bool,
424+
is_reduced_type: bool,
425+
has_convert: bool,
426+
is_inv_scale: bool,
400427
):
401428
qsdpa_v_pattern = _get_qsdpa_qkv_pattern(
402429
qtype, is_batch_size_1, is_reduced_type, has_convert, "value"
@@ -429,8 +456,20 @@ def _get_qsdpa_final_pattern(
429456

430457

431458
def _register_qsdpa_lowerings(custom_pass_dict):
432-
for qtype, has_mask, is_batch_size_1, is_reduced_type, has_convert, is_inv_scale in itertools.product(
433-
quantize_dtypes, [True, False], [True, False], [True, False], [True, False], [True, False]
459+
for (
460+
qtype,
461+
has_mask,
462+
is_batch_size_1,
463+
is_reduced_type,
464+
has_convert,
465+
is_inv_scale,
466+
) in itertools.product(
467+
quantize_dtypes,
468+
[True, False],
469+
[True, False],
470+
[True, False],
471+
[True, False],
472+
[True, False],
434473
):
435474
_register_qsdpa_pattern(
436475
_get_qsdpa_final_pattern(

torchao/prototype/inductor/qsdpa_lowering.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
)
2222
quantize_dtypes = [torch.uint8, torch.float8_e4m3fn]
2323

24+
2425
def register_qsdpa():
2526
@register_lowering(
2627
torch.ops.torchao.qscaled_dot_product.default, type_promotion_kind=None

Comments: 0 commit comments