|
| 1 | +from collections.abc import Sequence |
1 | 2 | from typing import Optional
|
2 | 3 |
|
3 | 4 | import sympy
|
4 | 5 | import torch
|
5 | 6 | from torch._inductor.ir import ChoiceCaller, FixedLayout, TensorBox, get_fill_order
|
6 |
| -from torch._inductor.kernel.flex_attention import construct_strides, maybe_realize |
7 | 7 | from torch._inductor.lowering import register_lowering
|
8 | 8 | from torch._inductor.select_algorithm import (
|
9 | 9 | ExternKernelChoice,
|
10 | 10 | autotune_select_algorithm,
|
| 11 | + realize_inputs, |
11 | 12 | )
|
| 13 | +from torch.utils._pytree import tree_map |
12 | 14 |
|
13 | 15 | from .codegen.cpp_int8_sdpa_template import CppInt8SdpaTemplate
|
14 | 16 |
|
| 17 | + |
| 18 | +# Copied directly from https://github.com/pytorch/pytorch/commit/e221a1c853b425b8d70b36d545ccb32ddc8176bd |
| 19 | +def maybe_realize(args): |
| 20 | + """Accepts a list of optional IRNodes and returns a list of realized IRNodes""" |
| 21 | + return tree_map( |
| 22 | + lambda x: ( |
| 23 | + realize_inputs(x) |
| 24 | + if x is not None and not isinstance(x, sympy.Symbol) |
| 25 | + else x |
| 26 | + ), |
| 27 | + args, |
| 28 | + ) |
| 29 | + |
| 30 | + |
| 31 | +# Copied directly from https://github.com/pytorch/pytorch/commit/e221a1c853b425b8d70b36d545ccb32ddc8176bd |
| 32 | +def construct_strides( |
| 33 | + sizes: Sequence[int], |
| 34 | + fill_order: Sequence[int], |
| 35 | +) -> Sequence[int]: |
| 36 | + """From a list of sizes and a fill order, construct the strides of the permuted tensor.""" |
| 37 | + # Initialize strides |
| 38 | + assert len(sizes) == len(fill_order), ( |
| 39 | + "Length of sizes must match the length of the fill order" |
| 40 | + ) |
| 41 | + strides = [0] * len(sizes) |
| 42 | + |
| 43 | + # Start with stride 1 for the innermost dimension |
| 44 | + current_stride = 1 |
| 45 | + |
| 46 | + # Iterate through the fill order populating strides |
| 47 | + for dim in fill_order: |
| 48 | + strides[dim] = current_stride |
| 49 | + current_stride *= sizes[dim] |
| 50 | + |
| 51 | + return strides |
| 52 | + |
| 53 | + |
15 | 54 | op_int8_sdpa = ExternKernelChoice(
|
16 | 55 | torch.ops.torchao.qscaled_dot_product.default,
|
17 | 56 | "torchao::qscaled_dot_product",
|
|
0 commit comments