From cf45a2ea285f7e8dc14b41b691e0ae6c2fc11414 Mon Sep 17 00:00:00 2001
From: Cathal Corbett
Date: Tue, 18 Nov 2025 09:59:18 +0100
Subject: [PATCH] [TOSA] MultiheadAttention legalization

- Legalize Torch scaled_dot_product_attention into TOSA by adding the
  necessary patterns in TorchToTosa.cpp plus backend type-conversion hooks.
- Introduce a detailed decomposition path for multi-head attention within
  DecomposeComplexOps.cpp, preparing inputs for TOSA lowering.
- Expand the PT1 e2e suite with a dedicated multi-head attention MLIR/Python
  test and drop the corresponding xfails now that the path works.

Signed-off-by: Cathal Corbett
Change-Id: I96c17aefd25b979f1cf6e897d91d5a29f0a2fa85
---
 lib/Conversion/TorchToTosa/TorchToTosa.cpp    |  65 +++++-
 .../Torch/Transforms/DecomposeComplexOps.cpp  | 214 ++++++++++++++++++
 .../Transforms/BackendTypeConversion.cpp      |  18 ++
 projects/pt1/e2e_testing/xfail_sets.py        |  10 -
 .../scaled_dot_product_attention_lowering.py  |  72 ++++++
 .../TorchToTosa/multi_head_attention.mlir     |  29 +++
 6 files changed, 397 insertions(+), 11 deletions(-)
 create mode 100644 projects/pt1/test/python/scaled_dot_product_attention_lowering.py
 create mode 100644 test/Conversion/TorchToTosa/multi_head_attention.mlir

diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index c3dbc095a745..dda2281f0a1b 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -4016,8 +4016,28 @@ LogicalResult ConvertAtenOp<AtenTransposeIntOp>::matchAndRewrite(
   transposedDims[dim0] = dim1;
   transposedDims[dim1] = dim0;
 
+  Type resultType = getTypeConverter()->convertType(op.getType());
+  if (auto rankedSelf = dyn_cast<RankedTensorType>(selfType)) {
+    SmallVector<int64_t> transposedShape(rankedSelf.getRank(),
+                                         ShapedType::kDynamic);
+    if (rankedSelf.hasStaticShape()) {
+      auto staticShape =
+          llvm::to_vector(makeShapeTorchCompatible(rankedSelf.getShape()));
+      auto dim0Index = static_cast<size_t>(dim0);
+      auto dim1Index = static_cast<size_t>(dim1);
+      if (dim0Index < staticShape.size() && dim1Index < staticShape.size())
+        std::swap(staticShape[dim0Index], staticShape[dim1Index]);
+      for (size_t i = 0; i < staticShape.size(); ++i)
+        transposedShape[i] = staticShape[i];
+    }
+    auto rankedResult = RankedTensorType::get(
+        makeShapeLLVMCompatible(transposedShape), rankedSelf.getElementType());
+    if (auto converted = getTypeConverter()->convertType(rankedResult))
+      resultType = converted;
+  }
+
   rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
-      op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(),
+      op, resultType, adaptor.getSelf(),
       rewriter.getDenseI32ArrayAttr(transposedDims));
 
   return success();
@@ -9387,6 +9407,32 @@ class ConvertTorchToTosa
 };
 } // namespace
 
+namespace {
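+// Drops tensor.cast ops between two statically shaped tensor types with the
+// same element type (such casts can be introduced by the static-shape target
+// materialization added in BackendTypeConversion.cpp); the cast result is
+// replaced directly by its source value.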
+class FoldStaticToDynamicTensorCast
+    : public OpConversionPattern<tensor::CastOp> {
+public:
+  using OpConversionPattern<tensor::CastOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto sourceType =
+        dyn_cast<RankedTensorType>(adaptor.getSource().getType());
+    auto resultType = dyn_cast<RankedTensorType>(op.getType());
+    if (!sourceType || !resultType)
+      return failure();
+    if (sourceType.getElementType() != resultType.getElementType())
+      return failure();
+    if (!sourceType.hasStaticShape())
+      return failure();
+    if (!resultType.hasStaticShape())
+      return failure();
+    if (sourceType == resultType)
+      return failure();
+    rewriter.replaceOp(op, adaptor.getSource());
+    return success();
+  }
+};
+} // namespace
+
 void populateTorchToTosaConversionLegalOps(ConversionTarget &target) {
   // The following ops are never the primary reason why lowering fails.
   // The backend contract only allows functions to return tensors thus there
@@ -9402,6 +9448,21 @@ void populateTorchToTosaConversionLegalOps(ConversionTarget &target) {
   // is always another op using the tuple construct op that is the actual
   // reason why the lowering fails.
   target.addLegalOp<ConstantDeviceOp>();
   target.addLegalOp<PrimListConstructOp>();
   target.addLegalOp<PrimTupleConstructOp>();
+  target.addDynamicallyLegalOp<tensor::CastOp>([](tensor::CastOp op) -> bool {
+    auto sourceType = dyn_cast<RankedTensorType>(op.getSource().getType());
+    auto resultType = dyn_cast<RankedTensorType>(op.getType());
+    if (!sourceType || !resultType)
+      return true;
+    if (sourceType.getElementType() != resultType.getElementType())
+      return true;
+    if (!sourceType.hasStaticShape())
+      return true;
+    if (!resultType.hasStaticShape())
+      return true;
+    if (sourceType == resultType)
+      return true;
+    return false;
+  });
 }
 
 std::set<StringRef> populateTorchToTosaConversionPatternsAndIllegalOps(
@@ -9723,6 +9784,8 @@ std::set<StringRef> populateTorchToTosaConversionPatternsAndIllegalOps(
   INSERT_CAST_ATENOP_PATTERN(AtenIntReprOp);
 #undef INSERT_CAST_ATENOP_PATTERN
 
+  patterns.add<FoldStaticToDynamicTensorCast>(typeConverter, context);
+
   return illegalOps;
 }
 
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 08b25c9b6f60..ce58291fd57c 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -2295,6 +2295,218 @@ class DecomposeAtenTraceOp : public OpRewritePattern<AtenTraceOp> {
 };
 } // namespace
 
+namespace {
+// Decompose scaled dot product attention into a matmul/softmax pipeline when
+// there is no mask, dropout, causal, or grouped-query (GQA) behaviour.
+class DecomposeAtenScaledDotProductAttentionOp
+    : public OpRewritePattern<AtenScaledDotProductAttentionOp> {
+public:
+  using OpRewritePattern<AtenScaledDotProductAttentionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AtenScaledDotProductAttentionOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+
+    if (!isa<Torch::NoneType>(op.getAttnMask().getType()))
+      return rewriter.notifyMatchFailure(
+          op, "attention mask decomposition not implemented");
+
+    double dropoutP;
+    if (!matchPattern(op.getDropoutP(), m_TorchConstantFloat(&dropoutP)) ||
+        dropoutP != 0.0)
+      return rewriter.notifyMatchFailure(
+          op, "expected dropout_p to be the constant 0.0");
+
+    bool isCausal;
+    if (!matchPattern(op.getIsCausal(), m_TorchConstantBool(&isCausal)) ||
+        isCausal)
+      return rewriter.notifyMatchFailure(op,
+                                         "causal attention not supported yet");
+
+    bool enableGqa;
+    if (!matchPattern(op.getEnableGqa(), m_TorchConstantBool(&enableGqa)) ||
+        enableGqa)
+      return rewriter.notifyMatchFailure(
+          op, "grouped-query attention unsupported");
+
+    Value query = op.getQuery();
+    Value key = op.getKey();
+    Value value = op.getValue();
+
+    auto queryTensorType = dyn_cast<BaseTensorType>(query.getType());
+    auto keyTensorType = dyn_cast<BaseTensorType>(key.getType());
+    auto valueTensorType = dyn_cast<BaseTensorType>(value.getType());
+    if (!queryTensorType || !keyTensorType || !valueTensorType)
+      return rewriter.notifyMatchFailure(op, "expected tensor inputs");
+    if (!queryTensorType.hasSizes() || !keyTensorType.hasSizes() ||
+        !valueTensorType.hasSizes())
+      return rewriter.notifyMatchFailure(
+          op, "expected tensor inputs to have known shapes");
+    auto queryValueTensorType = dyn_cast<ValueTensorType>(queryTensorType);
+    auto keyValueTensorType = dyn_cast<ValueTensorType>(keyTensorType);
+    auto valueValueTensorType = dyn_cast<ValueTensorType>(valueTensorType);
+    if (!queryValueTensorType || !keyValueTensorType || !valueValueTensorType)
+      return rewriter.notifyMatchFailure(op, "expected value tensor semantics");
+
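+    // Query the attention dimensions (batch, heads, sequence lengths, head
+    // dim) as torch.int values so the rewrite also covers dynamic shapes;
+    // statically known sizes are re-materialized as constants below via
+    // getDimValue.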
+    Value oneInt =
+        ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(1));
+    Value zeroInt =
+        ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(0));
+    Value rank = AtenDimOp::create(rewriter, loc, query);
+    Value lastDim = AtenSubIntOp::create(rewriter, loc, rank, oneInt);
+    Value headDim = AtenSizeIntOp::create(rewriter, loc, query, lastDim);
+    Value seqDimIndex = AtenSubIntOp::create(rewriter, loc, lastDim, oneInt);
+    Value seqLen = AtenSizeIntOp::create(rewriter, loc, query, seqDimIndex);
+    Value keySeqLen = AtenSizeIntOp::create(rewriter, loc, key, seqDimIndex);
+    ArrayRef<int64_t> querySizes = queryValueTensorType.getSizes();
+    bool hasExplicitHeadDim = querySizes.size() >= 4;
+    Value numHeadsSize =
+        hasExplicitHeadDim
+            ? (Value)AtenSizeIntOp::create(rewriter, loc, query, oneInt)
+            : oneInt;
+    Value batchSize = AtenSizeIntOp::create(rewriter, loc, query, zeroInt);
+    auto listIntType =
+        Torch::ListType::get(Torch::IntType::get(rewriter.getContext()));
+
+    auto getDimValue = [&](int64_t staticDim, Value fallback) -> Value {
+      if (staticDim != Torch::kUnknownSize)
+        return ConstantIntOp::create(rewriter, loc,
+                                     rewriter.getI64IntegerAttr(staticDim));
+      return fallback;
+    };
+
+    Value scaleFloat;
+    if (isa<Torch::NoneType>(op.getScale().getType())) {
+      Value sqrtHeadDim = AtenSqrtIntOp::create(rewriter, loc, headDim);
+      Value oneFloat = ConstantFloatOp::create(rewriter, loc,
+                                               rewriter.getF64FloatAttr(1.0));
+      scaleFloat = AtenDivFloatOp::create(rewriter, loc, oneFloat, sqrtHeadDim);
+    } else {
+      scaleFloat = op.getScale();
+    }
+
+    Value negTwo =
+        ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(-2));
+    Value negOne =
+        ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(-1));
+
+    ArrayRef<int64_t> keySizes = keyValueTensorType.getSizes();
+    SmallVector<int64_t> keyTransposedSizes(keySizes.begin(), keySizes.end());
+    if (keyTransposedSizes.size() < 2)
+      return rewriter.notifyMatchFailure(
+          op, "expected key tensor rank >= 2 for transpose");
+    std::swap(keyTransposedSizes[keyTransposedSizes.size() - 1],
+              keyTransposedSizes[keyTransposedSizes.size() - 2]);
+    ArrayRef<int64_t> keyTransposedRef(keyTransposedSizes);
+    std::optional<ArrayRef<int64_t>> keyTransposedOpt(keyTransposedRef);
+    Type keyTransposedType =
+        keyValueTensorType.getWithSizesAndDtypeAndSparsity(
+            keyTransposedSizes, keyValueTensorType.getOptionalDtype(),
+            keyValueTensorType.getOptionalSparsity());
+    Value keyTransposed = AtenTransposeIntOp::create(
+        rewriter, loc, keyTransposedType, key, negTwo, negOne);
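+    // Re-assert the transposed key's shape with an explicit view, preferring
+    // static sizes and falling back to the runtime dims gathered above, so
+    // the subsequent matmul sees a fully specified operand type.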
+    SmallVector<Value> keyDims;
+    auto getOrFallback = [&](ArrayRef<int64_t> staticDims, unsigned idx,
+                             Value fallback) -> Value {
+      return getDimValue(idx < staticDims.size() ? staticDims[idx]
+                                                 : Torch::kUnknownSize,
+                         fallback);
+    };
+    keyDims.push_back(getOrFallback(keyTransposedSizes, 0, batchSize));
+    if (keyTransposedSizes.size() == 4) {
+      keyDims.push_back(getOrFallback(keyTransposedSizes, 1, numHeadsSize));
+      keyDims.push_back(getOrFallback(keyTransposedSizes, 2, seqLen));
+      keyDims.push_back(getOrFallback(keyTransposedSizes, 3, keySeqLen));
+    } else {
+      keyDims.push_back(getOrFallback(keyTransposedSizes, 1, headDim));
+      keyDims.push_back(getOrFallback(keyTransposedSizes, 2, keySeqLen));
+    }
+    Value keyTransposeShapeList =
+        PrimListConstructOp::create(rewriter, loc, listIntType, keyDims);
+    keyTransposed = AtenViewOp::create(rewriter, loc, keyTransposedType,
+                                       keyTransposed, keyTransposeShapeList);
+
+    auto getStaticDim = [](ArrayRef<int64_t> sizes, int64_t index) {
+      if (index < 0)
+        index += sizes.size();
+      if (index < 0 || index >= static_cast<int64_t>(sizes.size()))
+        return Torch::kUnknownSize;
+      return sizes[index];
+    };
+    int64_t queryBatchStatic = getStaticDim(querySizes, 0);
+    int64_t querySeqStatic = getStaticDim(querySizes, -2);
+    int64_t keySeqStatic = getStaticDim(keySizes, -2);
+    int64_t queryHeadsStatic =
+        hasExplicitHeadDim ? getStaticDim(querySizes, 1) : 1;
+    SmallVector<int64_t> scoresSizes;
+    if (hasExplicitHeadDim)
+      scoresSizes.assign(
+          {queryBatchStatic, queryHeadsStatic, querySeqStatic, keySeqStatic});
+    else
+      scoresSizes.assign({queryBatchStatic, querySeqStatic, keySeqStatic});
+    Type scoresType = ValueTensorType::get(
+        op->getContext(),
+        ArrayRef<int64_t>(scoresSizes.begin(), scoresSizes.end()),
+        queryValueTensorType.getOptionalDtype(),
+        queryValueTensorType.getOptionalSparsity());
+    Value scores =
+        AtenMatmulOp::create(rewriter, loc, scoresType, query, keyTransposed);
+    SmallVector<Value> scoresDims;
+    scoresDims.push_back(getDimValue(scoresSizes[0], batchSize));
+    unsigned seqIndex = 1;
+    if (hasExplicitHeadDim) {
+      scoresDims.push_back(getDimValue(scoresSizes[1], numHeadsSize));
+      seqIndex = 2;
+    }
+    scoresDims.push_back(getDimValue(scoresSizes[seqIndex], seqLen));
+    scoresDims.push_back(getDimValue(scoresSizes.back(), keySeqLen));
+    Value scoresShapeList =
+        PrimListConstructOp::create(rewriter, loc, listIntType, scoresDims);
+    scores =
+        AtenViewOp::create(rewriter, loc, scoresType, scores, scoresShapeList);
+    Value scaledScores =
+        AtenMulScalarOp::create(rewriter, loc, scoresType, scores, scaleFloat);
+
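+    // Inline a numerically stable softmax over the last dimension:
+    // softmax(x) = exp(x - max(x)) / sum(exp(x - max(x))), matching the
+    // overflow-avoiding decomposition rule used elsewhere in this file.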
+    SmallVector<int64_t> reducedSizes(scoresSizes.begin(), scoresSizes.end());
+    reducedSizes.back() = 1;
+    ArrayRef<int64_t> reducedSizesRef(reducedSizes);
+    std::optional<ArrayRef<int64_t>> reducedSizesOpt(reducedSizesRef);
+    Type reducedValueType =
+        ValueTensorType::get(op->getContext(), reducedSizesOpt,
+                             queryValueTensorType.getOptionalDtype());
+    Type reducedIndexType = ValueTensorType::get(
+        op->getContext(), reducedSizesOpt,
+        IntegerType::get(op->getContext(), 64, IntegerType::Signed));
+    Value keepDimTrue =
+        ConstantBoolOp::create(rewriter, loc, rewriter.getBoolAttr(true));
+    auto maxOp =
+        AtenMaxDimOp::create(rewriter, loc, reducedValueType, reducedIndexType,
+                             scaledScores, negOne, keepDimTrue);
+    Value softmaxMax = TensorStaticInfoCastOp::create(
+        rewriter, loc, reducedValueType, maxOp.getValues());
+    Value centered =
+        createTensorSub(rewriter, loc, scoresType, scaledScores, softmaxMax);
+    Value unNormalizedExp =
+        AtenExpOp::create(rewriter, loc, scoresType, centered);
+    SmallVector<Value> softmaxDims{negOne};
+    Value dimList =
+        PrimListConstructOp::create(rewriter, loc, listIntType, softmaxDims);
+    Value noneValue = ConstantNoneOp::create(rewriter, loc);
+    Value softmaxDenominator = AtenSumDimIntListOp::create(
+        rewriter, loc, reducedValueType, unNormalizedExp, dimList, keepDimTrue,
+        noneValue);
+    softmaxDenominator = TensorStaticInfoCastOp::create(
+        rewriter, loc, reducedValueType, softmaxDenominator);
+    Value softmax = AtenDivTensorOp::create(
+        rewriter, loc, scoresType, unNormalizedExp, softmaxDenominator);
+
+    Value output =
+        AtenMatmulOp::create(rewriter, loc, op.getType(), softmax, value);
+
+    rewriter.replaceOp(op, output);
+    return success();
+  }
+};
+} // namespace
+
 // Calculates the softmax function on the given `input` tensor. Softmax(x) =
 // exp(x)/sum(exp(x)).
 // To avoid overflow we use the following decomposition rule:
@@ -13084,6 +13296,8 @@ class DecomposeComplexOpsPass
     legalOpsSet.clear();
     legalOpsSet.insert(legalOps.begin(), legalOps.end());
 
+    patterns.add<DecomposeAtenScaledDotProductAttentionOp>(context);
+
     addPatternIfTargetOpIsIllegal<DecomposeAten_WeightNormInterfaceOp>(
         patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenSoftmaxIntOp>(patterns);
diff --git a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp
index f7ffb14f0602..d7394531e965 100644
--- a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp
+++ b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp
@@ -8,6 +8,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
 
 using namespace mlir;
@@ -40,6 +41,23 @@ static void setupValueTensorToBuiltinTensorConversion(
       return {};
     return ToBuiltinTensorOp::create(builder, loc, type, inputs[0]);
   });
+  typeConverter.addTargetMaterialization([](OpBuilder &builder, Type type,
+                                            ValueRange inputs,
+                                            Location loc) -> Value {
+    if (inputs.size() != 1)
+      return Value();
+    auto fromType = dyn_cast<RankedTensorType>(inputs[0].getType());
+    auto toType = dyn_cast<RankedTensorType>(type);
+    if (!fromType || !toType)
+      return Value();
+    if (fromType == toType)
+      return inputs[0];
+    if (fromType.getElementType() != toType.getElementType())
+      return Value();
+    if (!toType.hasStaticShape())
+      return Value();
+    return tensor::CastOp::create(builder, loc, toType, inputs[0]);
+  });
   auto sourceMaterialization = [](OpBuilder &builder,
                                   Torch::ValueTensorType type,
                                   ValueRange inputs, Location loc) -> Value {
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 4c8318570c7b..f4937f686113 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -50,11 +50,8 @@
     "ScaledDotProductAttentionBoolMaskModule_basic",
     "ScaledDotProductAttentionDifferentCausalModule_basic",
     "ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
-    "ScaledDotProductAttentionDifferentModule_basic",
     "ScaledDotProductAttentionMaskModule_basic",
     "ScaledDotProductAttentionSameCausalModule_basic",
-    "ScaledDotProductAttentionSameDynamicModule_basic",
-    "ScaledDotProductAttentionSameModule_basic",
 }
 
 LINALG_CRASHING_SET = {
@@ -953,11 +950,8 @@
     "ScaledDotProductAttentionBoolMaskModule_basic",
     "ScaledDotProductAttentionDifferentCausalModule_basic",
     "ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
-    "ScaledDotProductAttentionDifferentModule_basic",
     "ScaledDotProductAttentionMaskModule_basic",
     "ScaledDotProductAttentionSameCausalModule_basic",
-    "ScaledDotProductAttentionSameDynamicModule_basic",
-    "ScaledDotProductAttentionSameModule_basic",
     "SubIntModule_basic",
     "TensorToIntZeroRank_basic",
"UpSampleNearest2dDynamicFactor_basic", @@ -3978,11 +3972,8 @@ "ScaledDotProductAttentionBoolMaskModule_basic", "ScaledDotProductAttentionDifferentCausalModule_basic", "ScaledDotProductAttentionDifferentDynamicCausalModule_basic", - "ScaledDotProductAttentionDifferentModule_basic", "ScaledDotProductAttentionMaskModule_basic", "ScaledDotProductAttentionSameCausalModule_basic", - "ScaledDotProductAttentionSameDynamicModule_basic", - "ScaledDotProductAttentionSameModule_basic", "ScaledDotProductAttentionGQAModule_basic", # error: 'tosa.scatter' op requires dimensions K >= W "IndexPut1DFloatNonAccumulateModule_basic", @@ -4887,7 +4878,6 @@ # REMOVE WHEN ENABLE_GQA IS ADDED "ScaledDotProductAttentionBoolMaskModule_basic", "ScaledDotProductAttentionSameCausalModule_basic", - "ScaledDotProductAttentionSameDynamicModule_basic", "ScatterAddDynamicModule_basic", "ScatterReduceFloatMaxModule", "ScatterReduceFloatMaxModuleIncludeSelf", diff --git a/projects/pt1/test/python/scaled_dot_product_attention_lowering.py b/projects/pt1/test/python/scaled_dot_product_attention_lowering.py new file mode 100644 index 000000000000..0f9ced49d801 --- /dev/null +++ b/projects/pt1/test/python/scaled_dot_product_attention_lowering.py @@ -0,0 +1,72 @@ +# RUN: %PYTHON %s | FileCheck %s + +import torch + +from torch_mlir import ir +from torch_mlir.compiler_utils import run_pipeline_with_repro_report +from torch_mlir.dialects import torch as torch_d +from torch_mlir.extras.fx_decomp_util import get_decomposition_table +from torch_mlir.extras.fx_importer import FxImporter + + +class SdpaModule(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.ops.aten.scaled_dot_product_attention.default( + x, + x, + x, + None, + 0.0, + False, + scale=None, + enable_gqa=False, + ) + + +def lower_sdpa() -> None: + module = SdpaModule().eval() + + example_input = torch.randn(2, 4, 8) + exported = torch.export.export(module, (example_input,)) + + decomposition_table = get_decomposition_table() + if decomposition_table: + exported = exported.run_decompositions(decomposition_table) + + context = ir.Context() + torch_d.register_dialect(context) + importer = FxImporter(context=context) + importer.import_frozen_program(exported) + mlir_module = importer.module + + pipeline = """ + builtin.module( + func.func(torch-match-quantized-custom-ops), + torchdynamo-export-to-torch-backend-pipeline{extra-library= backend-legal-ops=aten.as_strided}, + torch-adjust-calling-conventions + ) + """ + run_pipeline_with_repro_report( + mlir_module, + pipeline, + "Lowering TorchFX IR -> Torch Backend IR", + enable_ir_printing=False, + ) + + module_str = str(mlir_module.operation) + if "torch.aten.scaled_dot_product_attention.default" in module_str: + raise RuntimeError( + "scaled_dot_product_attention unexpectedly survived lowering" + ) + + +if __name__ == "__main__": + torch.manual_seed(0) + if hasattr(torch, "set_deterministic_debug_mode"): + torch.set_deterministic_debug_mode("error") + + lower_sdpa() + print("lowered scaled dot product attention") + print("SUCCESS") +# CHECK: lowered scaled dot product attention +# CHECK: SUCCESS diff --git a/test/Conversion/TorchToTosa/multi_head_attention.mlir b/test/Conversion/TorchToTosa/multi_head_attention.mlir new file mode 100644 index 000000000000..b06a86a58d2a --- /dev/null +++ b/test/Conversion/TorchToTosa/multi_head_attention.mlir @@ -0,0 +1,29 @@ +// RUN: torch-mlir-opt %s -torch-decompose-complex-ops -convert-torch-to-tosa -split-input-file | FileCheck %s + +// Checks that 
scaled dot product attention (single-head configuration) lowers
+// through the decomposition pass into the expected TOSA matmul + softmax flow.
+module {
+  // CHECK-LABEL: func.func @scaled_dot_product_attention(
+  // CHECK: %[[KEY_T:.*]] = tosa.transpose %{{.*}} {perms = array<i32: 0, 2, 1>} : (tensor<1x4x8xf32>) -> tensor<1x8x4xf32>
+  // CHECK: %[[KEY_VIEW:.*]] = tosa.reshape %[[KEY_T]], %{{.*}} : (tensor<1x8x4xf32>, !tosa.shape<3>) -> tensor<1x8x4xf32>
+  // CHECK: %[[QK:.*]] = tosa.matmul %{{.*}}, %[[KEY_VIEW]], %{{.*}}, %{{.*}} : (tensor<1x4x8xf32>, tensor<1x8x4xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x4xf32>
+  // CHECK: %[[SCORES:.*]] = tosa.reshape %[[QK]], %{{.*}} : (tensor<1x4x4xf32>, !tosa.shape<3>) -> tensor<1x4x4xf32>
+  // CHECK: %[[SCALED:.*]] = tosa.mul %[[SCORES]], %{{.*}}, %{{.*}} : (tensor<1x4x4xf32>, tensor<1x1x1xf32>, tensor<1xi8>) -> tensor<1x4x4xf32>
+  // CHECK: %[[CENTERED:.*]] = tosa.sub %[[SCALED]], %{{.*}} : (tensor<1x4x4xf32>, tensor<1x4x1xf32>) -> tensor<1x4x4xf32>
+  // CHECK: %[[EXP:.*]] = tosa.exp %[[CENTERED]] : (tensor<1x4x4xf32>) -> tensor<1x4x4xf32>
+  // CHECK: %[[DENOM:.*]] = tosa.reduce_sum %[[EXP]] {axis = 2 : i32} : (tensor<1x4x4xf32>) -> tensor<1x4x1xf32>
+  // CHECK: %[[SOFTMAX:.*]] = tosa.mul %[[EXP]], %{{.*}}, %{{.*}} : (tensor<1x4x4xf32>, tensor<1x4x1xf32>, tensor<1xi8>) -> tensor<1x4x4xf32>
+  // CHECK: %[[OUTPUT:.*]] = tosa.matmul %[[SOFTMAX]], %{{.*}}, %{{.*}}, %{{.*}} : (tensor<1x4x4xf32>, tensor<1x4x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x8xf32>
+  // CHECK: return %{{.*}} : !torch.vtensor<[1,4,8],f32>
+  func.func @scaled_dot_product_attention(
+      %query: !torch.vtensor<[1,4,8],f32>,
+      %key: !torch.vtensor<[1,4,8],f32>,
+      %value: !torch.vtensor<[1,4,8],f32>) -> !torch.vtensor<[1,4,8],f32> {
+    %none = torch.constant.none
+    %zero = torch.constant.float 0.000000e+00
+    %false = torch.constant.bool false
+    %result = torch.aten.scaled_dot_product_attention %query, %key, %value, %none, %zero, %false, %none, %false :
+        !torch.vtensor<[1,4,8],f32>, !torch.vtensor<[1,4,8],f32>, !torch.vtensor<[1,4,8],f32>, !torch.none, !torch.float, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[1,4,8],f32>
+    return %result : !torch.vtensor<[1,4,8],f32>
+  }
+}
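
For reference, the restricted case accepted by the new decomposition (no
attention mask, dropout_p == 0.0, is_causal == False, enable_gqa == False)
computes softmax(Q @ K^T * scale) @ V with the numerically stable softmax
shown above. A minimal eager-mode sketch of the same computation, for
comparison against PyTorch's own SDPA (illustrative only; "sdpa_reference"
is a name local to this example, not part of the patch):

import math
import torch

def sdpa_reference(query, key, value, scale=None):
    # Guarded case of the decomposition: no attention mask, dropout_p == 0.0,
    # is_causal == False, enable_gqa == False.
    if scale is None:
        # Default SDPA scaling: 1/sqrt(head_dim).
        scale = 1.0 / math.sqrt(query.shape[-1])
    scores = torch.matmul(query, key.transpose(-2, -1)) * scale
    # Numerically stable softmax: subtract the row-wise max before exp.
    scores = scores - scores.amax(dim=-1, keepdim=True)
    weights = scores.exp()
    weights = weights / weights.sum(dim=-1, keepdim=True)
    return torch.matmul(weights, value)

q = k = v = torch.randn(1, 4, 8)
expected = torch.nn.functional.scaled_dot_product_attention(q, k, v)
assert torch.allclose(sdpa_reference(q, k, v), expected, atol=1e-6)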