Skip to content

Commit 8d4a5d8

Browse files
authored
fix bc breakage flex path (#2652)
* fix bc breakage flex path * lint * Update int8_sdpa_lowering.py * Driss feedback * Update int8_sdpa_lowering.py
1 parent 7c5c0b5 commit 8d4a5d8

File tree

1 file changed

+40
-1
lines changed

1 file changed

+40
-1
lines changed

torchao/prototype/inductor/int8_sdpa_lowering.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,56 @@
1+
from collections.abc import Sequence
12
from typing import Optional
23

34
import sympy
45
import torch
56
from torch._inductor.ir import ChoiceCaller, FixedLayout, TensorBox, get_fill_order
6-
from torch._inductor.kernel.flex_attention import construct_strides, maybe_realize
77
from torch._inductor.lowering import register_lowering
88
from torch._inductor.select_algorithm import (
99
ExternKernelChoice,
1010
autotune_select_algorithm,
11+
realize_inputs,
1112
)
13+
from torch.utils._pytree import tree_map
1214

1315
from .codegen.cpp_int8_sdpa_template import CppInt8SdpaTemplate
1416

17+
18+
# Copied directly from https://github.com/pytorch/pytorch/commit/e221a1c853b425b8d70b36d545ccb32ddc8176bd
19+
def maybe_realize(args):
    """Realize the IR nodes contained in ``args``.

    Every leaf visited by ``tree_map`` is passed through ``realize_inputs``,
    except leaves that are ``None`` or a ``sympy.Symbol``, which are handed
    back unchanged. The result has whatever container structure ``tree_map``
    produces for ``args``.
    """

    def _realize_leaf(node):
        # Absent (None) entries and symbolic values are not realized.
        if node is None or isinstance(node, sympy.Symbol):
            return node
        return realize_inputs(node)

    return tree_map(_realize_leaf, args)
29+
30+
31+
# Copied directly from https://github.com/pytorch/pytorch/commit/e221a1c853b425b8d70b36d545ccb32ddc8176bd
32+
def construct_strides(
    sizes: Sequence[int],
    fill_order: Sequence[int],
) -> Sequence[int]:
    """Build the strides of a tensor laid out according to ``fill_order``.

    ``fill_order`` lists dimension indices starting with the one that gets
    stride 1; each following dimension gets the product of the sizes of all
    dimensions listed before it.

    Args:
        sizes: Size of each dimension, indexed by dimension.
        fill_order: Dimension indices in the order their strides are assigned.

    Returns:
        A list with one stride per dimension, indexed by dimension.

    Raises:
        AssertionError: If ``sizes`` and ``fill_order`` differ in length.
    """
    rank = len(sizes)
    assert rank == len(fill_order), (
        "Length of sizes must match the length of the fill order"
    )

    # Every slot starts at 0 and is overwritten as its dimension is visited.
    result = [0] * rank

    # Running product of the sizes of the dimensions already placed.
    running = 1
    for axis in fill_order:
        result[axis] = running
        running = running * sizes[axis]

    return result
52+
53+
1554
op_int8_sdpa = ExternKernelChoice(
1655
torch.ops.torchao.qscaled_dot_product.default,
1756
"torchao::qscaled_dot_product",

0 commit comments

Comments
 (0)