Commit cf45a2e
[TOSA] MultiheadAttention legalization
- Legalize Torch scaled_dot_product_attention into TOSA by adding the necessary patterns in TorchToTosa.cpp plus backend type-conversion hooks.
- Introduce a detailed decomposition path for multi-head attention in DecomposeComplexOps.cpp, preparing inputs for TOSA lowering.
- Expand the PT1 e2e suite with a dedicated multi-head attention MLIR/Python test and drop the corresponding xfails now that the path works.

Signed-off-by: Cathal Corbett <[email protected]>
Change-Id: I96c17aefd25b979f1cf6e897d91d5a29f0a2fa85
1 parent 244f4b6 commit cf45a2e
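
For orientation, here is a minimal PyTorch sketch (illustration only, not a file in this commit) of the attention configuration the new decomposition handles: no attention mask, dropout_p of 0.0, non-causal, and no grouped-query attention. The shapes are hypothetical.

import torch
import torch.nn.functional as F

# Hypothetical shapes: (batch, heads, seq_len, head_dim).
q = torch.randn(2, 4, 8, 16)
k = torch.randn(2, 4, 8, 16)
v = torch.randn(2, 4, 8, 16)

# The supported case: no mask, zero dropout, non-causal, no GQA; the default
# scale of 1/sqrt(head_dim) applies when no explicit scale is given.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=None,
                                     dropout_p=0.0, is_causal=False)
print(out.shape)  # torch.Size([2, 4, 8, 16])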

File tree: 6 files changed, +397 −11 lines

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 64 additions & 1 deletion

@@ -4016,8 +4016,28 @@ LogicalResult ConvertAtenOp<AtenTransposeIntOp>::matchAndRewrite(
   transposedDims[dim0] = dim1;
   transposedDims[dim1] = dim0;
 
+  Type resultType = getTypeConverter()->convertType(op.getType());
+  if (auto rankedSelf = dyn_cast<RankedTensorType>(selfType)) {
+    SmallVector<int64_t> transposedShape(rankedSelf.getRank(),
+                                         ShapedType::kDynamic);
+    if (rankedSelf.hasStaticShape()) {
+      auto staticShape =
+          llvm::to_vector(makeShapeTorchCompatible(rankedSelf.getShape()));
+      auto dim0Index = static_cast<size_t>(dim0);
+      auto dim1Index = static_cast<size_t>(dim1);
+      if (dim0Index < staticShape.size() && dim1Index < staticShape.size())
+        std::swap(staticShape[dim0Index], staticShape[dim1Index]);
+      for (size_t i = 0; i < staticShape.size(); ++i)
+        transposedShape[i] = staticShape[i];
+    }
+    auto rankedResult = RankedTensorType::get(
+        makeShapeLLVMCompatible(transposedShape), rankedSelf.getElementType());
+    if (auto converted = getTypeConverter()->convertType(rankedResult))
+      resultType = converted;
+  }
+
   rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
-      op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(),
+      op, resultType, adaptor.getSelf(),
       rewriter.getDenseI32ArrayAttr(transposedDims));
 
   return success();

@@ -9387,6 +9407,32 @@ class ConvertTorchToTosa
 };
 } // namespace
 
+namespace {
+class FoldStaticToDynamicTensorCast
+    : public OpConversionPattern<tensor::CastOp> {
+public:
+  using OpConversionPattern<tensor::CastOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto sourceType = dyn_cast<RankedTensorType>(adaptor.getSource().getType());
+    auto resultType = dyn_cast<RankedTensorType>(op.getType());
+    if (!sourceType || !resultType)
+      return failure();
+    if (sourceType.getElementType() != resultType.getElementType())
+      return failure();
+    if (!sourceType.hasStaticShape())
+      return failure();
+    if (!resultType.hasStaticShape())
+      return failure();
+    if (sourceType == resultType)
+      return failure();
+    rewriter.replaceOp(op, adaptor.getSource());
+    return success();
+  }
+};
+} // namespace
+
 void populateTorchToTosaConversionLegalOps(ConversionTarget &target) {
   // The following ops are never the primary reason why lowering fails.
   // The backend contract only allows functions to return tensors thus there

@@ -9402,6 +9448,21 @@ void populateTorchToTosaConversionLegalOps(ConversionTarget &target) {
   target.addLegalOp<ConstantDeviceOp>();
   target.addLegalOp<PrimListConstructOp>();
   target.addLegalOp<PrimTupleConstructOp>();
+  target.addDynamicallyLegalOp<tensor::CastOp>([](tensor::CastOp op) -> bool {
+    auto sourceType = dyn_cast<RankedTensorType>(op.getSource().getType());
+    auto resultType = dyn_cast<RankedTensorType>(op.getType());
+    if (!sourceType || !resultType)
+      return true;
+    if (sourceType.getElementType() != resultType.getElementType())
+      return true;
+    if (!sourceType.hasStaticShape())
+      return true;
+    if (!resultType.hasStaticShape())
+      return true;
+    if (sourceType == resultType)
+      return true;
+    return false;
+  });
 }
 
 std::set<StringRef> populateTorchToTosaConversionPatternsAndIllegalOps(

@@ -9723,6 +9784,8 @@ std::set<StringRef> populateTorchToTosaConversionPatternsAndIllegalOps(
   INSERT_CAST_ATENOP_PATTERN(AtenIntReprOp);
 #undef INSERT_CAST_ATENOP_PATTERN
 
+  patterns.add<FoldStaticToDynamicTensorCast>(typeConverter, context);
+
   return illegalOps;
 }
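
The tensor.cast legality rule added above reads: a cast stays legal (and is left alone) unless both types are ranked with fully static shapes, the element types match, and the types genuinely differ; exactly those casts are marked illegal and folded away by FoldStaticToDynamicTensorCast. A rough Python restatement of that predicate, for illustration only (dynamic dimensions written as -1, unranked shapes as None):

# Sketch of the addDynamicallyLegalOp predicate above (not code from the patch).
def tensor_cast_is_legal(src_shape, src_dtype, dst_shape, dst_dtype) -> bool:
    def fully_static(shape):
        return shape is not None and all(dim >= 0 for dim in shape)
    if src_dtype != dst_dtype:
        return True
    if not fully_static(src_shape) or not fully_static(dst_shape):
        return True
    if (src_shape, src_dtype) == (dst_shape, dst_dtype):
        return True
    # A static-to-static cast between genuinely different types is illegal, so
    # the FoldStaticToDynamicTensorCast pattern must replace it with its source.
    return False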

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 214 additions & 0 deletions

@@ -2295,6 +2295,218 @@ class DecomposeAtenTraceOp : public OpRewritePattern<AtenTraceOp> {
 };
 } // namespace
 
+namespace {
+// Decompose scaled dot product attention into matmul/softmax pipeline when
+// there is no masking, dropout, causal, or GQA behaviour.
+class DecomposeAtenScaledDotProductAttentionOp
+    : public OpRewritePattern<AtenScaledDotProductAttentionOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenScaledDotProductAttentionOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+
+    if (!isa<Torch::NoneType>(op.getAttnMask().getType()))
+      return rewriter.notifyMatchFailure(
+          op, "attention mask decomposition not implemented");
+
+    double dropoutP;
+    if (!matchPattern(op.getDropoutP(), m_TorchConstantFloat(&dropoutP)) ||
+        dropoutP != 0.0)
+      return rewriter.notifyMatchFailure(
+          op, "expected dropout_p to be the constant 0.0");
+
+    bool isCausal;
+    if (!matchPattern(op.getIsCausal(), m_TorchConstantBool(&isCausal)) ||
+        isCausal)
+      return rewriter.notifyMatchFailure(op,
+                                         "causal attention not supported yet");
+
+    bool enableGqa;
+    if (!matchPattern(op.getEnableGqa(), m_TorchConstantBool(&enableGqa)) ||
+        enableGqa)
+      return rewriter.notifyMatchFailure(op,
+                                         "grouped-query attention unsupported");
+
+    Value query = op.getQuery();
+    Value key = op.getKey();
+    Value value = op.getValue();
+
+    auto queryTensorType = dyn_cast<BaseTensorType>(query.getType());
+    auto keyTensorType = dyn_cast<BaseTensorType>(key.getType());
+    auto valueTensorType = dyn_cast<BaseTensorType>(value.getType());
+    if (!queryTensorType || !keyTensorType || !valueTensorType)
+      return rewriter.notifyMatchFailure(op, "expected tensor inputs");
+    if (!queryTensorType.hasSizes() || !keyTensorType.hasSizes() ||
+        !valueTensorType.hasSizes())
+      return rewriter.notifyMatchFailure(
+          op, "expected tensor inputs to have known shapes");
+    auto queryValueTensorType = dyn_cast<ValueTensorType>(queryTensorType);
+    auto keyValueTensorType = dyn_cast<ValueTensorType>(keyTensorType);
+    auto valueValueTensorType = dyn_cast<ValueTensorType>(valueTensorType);
+    if (!queryValueTensorType || !keyValueTensorType || !valueValueTensorType)
+      return rewriter.notifyMatchFailure(op, "expected value tensor semantics");
+
+    Value oneInt =
+        ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(1));
+    Value zeroInt =
+        ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(0));
+    Value rank = AtenDimOp::create(rewriter, loc, query);
+    Value lastDim = AtenSubIntOp::create(rewriter, loc, rank, oneInt);
+    Value headDim = AtenSizeIntOp::create(rewriter, loc, query, lastDim);
+    Value seqDimIndex = AtenSubIntOp::create(rewriter, loc, lastDim, oneInt);
+    Value seqLen = AtenSizeIntOp::create(rewriter, loc, query, seqDimIndex);
+    Value keySeqLen = AtenSizeIntOp::create(rewriter, loc, key, seqDimIndex);
+    ArrayRef<int64_t> querySizes = queryValueTensorType.getSizes();
+    bool hasExplicitHeadDim = querySizes.size() >= 4;
+    Value numHeadsSize =
+        hasExplicitHeadDim
+            ? (Value)AtenSizeIntOp::create(rewriter, loc, query, oneInt)
+            : oneInt;
+    Value batchSize = AtenSizeIntOp::create(rewriter, loc, query, zeroInt);
+    auto listIntType =
+        Torch::ListType::get(Torch::IntType::get(rewriter.getContext()));
+
+    auto getDimValue = [&](int64_t staticDim, Value fallback) -> Value {
+      if (staticDim != Torch::kUnknownSize)
+        return ConstantIntOp::create(rewriter, loc,
+                                     rewriter.getI64IntegerAttr(staticDim));
+      return fallback;
+    };
+
+    Value scaleFloat;
+    if (isa<Torch::NoneType>(op.getScale().getType())) {
+      Value sqrtHeadDim = AtenSqrtIntOp::create(rewriter, loc, headDim);
+      Value oneFloat =
+          ConstantFloatOp::create(rewriter, loc, rewriter.getF64FloatAttr(1.0));
+      scaleFloat = AtenDivFloatOp::create(rewriter, loc, oneFloat, sqrtHeadDim);
+    } else {
+      scaleFloat = op.getScale();
+    }
+
+    Value negTwo =
+        ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(-2));
+    Value negOne =
+        ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(-1));
+
+    ArrayRef<int64_t> keySizes = keyValueTensorType.getSizes();
+    SmallVector<int64_t> keyTransposedSizes(keySizes.begin(), keySizes.end());
+    if (keyTransposedSizes.size() < 2)
+      return rewriter.notifyMatchFailure(
+          op, "expected key tensor rank >= 2 for transpose");
+    std::swap(keyTransposedSizes[keyTransposedSizes.size() - 1],
+              keyTransposedSizes[keyTransposedSizes.size() - 2]);
+    ArrayRef<int64_t> keyTransposedRef(keyTransposedSizes);
+    std::optional<ArrayRef<int64_t>> keyTransposedOpt(keyTransposedRef);
+    Type keyTransposedType = keyValueTensorType.getWithSizesAndDtypeAndSparsity(
+        keyTransposedSizes, keyValueTensorType.getOptionalDtype(),
+        keyValueTensorType.getOptionalSparsity());
+    Value keyTransposed = AtenTransposeIntOp::create(
+        rewriter, loc, keyTransposedType, key, negTwo, negOne);
+    SmallVector<Value> keyDims;
+    auto getOrFallback = [&](ArrayRef<int64_t> staticDims, unsigned idx,
+                             Value fallback) -> Value {
+      return getDimValue(idx < staticDims.size() ? staticDims[idx]
+                                                 : Torch::kUnknownSize,
+                         fallback);
+    };
+    keyDims.push_back(getOrFallback(keyTransposedSizes, 0, batchSize));
+    if (keyTransposedSizes.size() == 4) {
+      keyDims.push_back(getOrFallback(keyTransposedSizes, 1, numHeadsSize));
+      keyDims.push_back(getOrFallback(keyTransposedSizes, 2, seqLen));
+      keyDims.push_back(getOrFallback(keyTransposedSizes, 3, keySeqLen));
+    } else {
+      keyDims.push_back(getOrFallback(keyTransposedSizes, 1, headDim));
+      keyDims.push_back(getOrFallback(keyTransposedSizes, 2, keySeqLen));
+    }
+    Value keyTransposeShapeList =
+        PrimListConstructOp::create(rewriter, loc, listIntType, keyDims);
+    keyTransposed = AtenViewOp::create(rewriter, loc, keyTransposedType,
+                                       keyTransposed, keyTransposeShapeList);
+
+    auto getStaticDim = [](ArrayRef<int64_t> sizes, int64_t index) {
+      if (index < 0)
+        index += sizes.size();
+      if (index < 0 || index >= static_cast<int64_t>(sizes.size()))
+        return Torch::kUnknownSize;
+      return sizes[index];
+    };
+    int64_t queryBatchStatic = getStaticDim(querySizes, 0);
+    int64_t querySeqStatic = getStaticDim(querySizes, -2);
+    int64_t keySeqStatic = getStaticDim(keySizes, -2);
+    int64_t queryHeadsStatic =
+        hasExplicitHeadDim ? getStaticDim(querySizes, 1) : 1;
+    SmallVector<int64_t, 4> scoresSizes;
+    if (hasExplicitHeadDim)
+      scoresSizes.assign(
+          {queryBatchStatic, queryHeadsStatic, querySeqStatic, keySeqStatic});
+    else
+      scoresSizes.assign({queryBatchStatic, querySeqStatic, keySeqStatic});
+    Type scoresType = ValueTensorType::get(
+        op->getContext(),
+        ArrayRef<int64_t>(scoresSizes.begin(), scoresSizes.end()),
+        queryValueTensorType.getOptionalDtype(),
+        queryValueTensorType.getOptionalSparsity());
+    Value scores =
+        AtenMatmulOp::create(rewriter, loc, scoresType, query, keyTransposed);
+    SmallVector<Value> scoresDims;
+    scoresDims.push_back(getDimValue(scoresSizes[0], batchSize));
+    unsigned seqIndex = 1;
+    if (hasExplicitHeadDim) {
+      scoresDims.push_back(getDimValue(scoresSizes[1], numHeadsSize));
+      seqIndex = 2;
+    }
+    scoresDims.push_back(getDimValue(scoresSizes[seqIndex], seqLen));
+    scoresDims.push_back(getDimValue(scoresSizes.back(), keySeqLen));
+    Value scoresShapeList =
+        PrimListConstructOp::create(rewriter, loc, listIntType, scoresDims);
+    scores =
+        AtenViewOp::create(rewriter, loc, scoresType, scores, scoresShapeList);
+    Value scaledScores =
+        AtenMulScalarOp::create(rewriter, loc, scoresType, scores, scaleFloat);
+
+    SmallVector<int64_t> reducedSizes(scoresSizes.begin(), scoresSizes.end());
+    reducedSizes.back() = 1;
+    ArrayRef<int64_t> reducedSizesRef(reducedSizes);
+    std::optional<ArrayRef<int64_t>> reducedSizesOpt(reducedSizesRef);
+    Type reducedValueType =
+        ValueTensorType::get(op->getContext(), reducedSizesOpt,
+                             queryValueTensorType.getOptionalDtype());
+    Type reducedIndexType = ValueTensorType::get(
+        op->getContext(), reducedSizesOpt,
+        IntegerType::get(op->getContext(), 64, IntegerType::Signed));
+    Value keepDimTrue =
+        ConstantBoolOp::create(rewriter, loc, rewriter.getBoolAttr(true));
+    auto maxOp =
+        AtenMaxDimOp::create(rewriter, loc, reducedValueType, reducedIndexType,
+                             scaledScores, negOne, keepDimTrue);
+    Value softmaxMax = TensorStaticInfoCastOp::create(
+        rewriter, loc, reducedValueType, maxOp.getValues());
+    Value centered =
+        createTensorSub(rewriter, loc, scoresType, scaledScores, softmaxMax);
+    Value unNormalizedExp =
+        AtenExpOp::create(rewriter, loc, scoresType, centered);
+    SmallVector<Value, 1> softmaxDims{negOne};
+    Value dimList =
+        PrimListConstructOp::create(rewriter, loc, listIntType, softmaxDims);
+    Value noneValue = ConstantNoneOp::create(rewriter, loc);
+    Value softmaxDenominator = AtenSumDimIntListOp::create(
+        rewriter, loc, reducedValueType, unNormalizedExp, dimList, keepDimTrue,
+        noneValue);
+    softmaxDenominator = TensorStaticInfoCastOp::create(
+        rewriter, loc, reducedValueType, softmaxDenominator);
+    Value softmax = AtenDivTensorOp::create(
+        rewriter, loc, scoresType, unNormalizedExp, softmaxDenominator);
+
+    Value output =
+        AtenMatmulOp::create(rewriter, loc, op.getType(), softmax, value);
+
+    rewriter.replaceOp(op, output);
+    return success();
+  }
+};
+} // namespace
+
 // Calculates the softmax function on the given `input` tensor. Softmax(x) =
 // exp(x)/sum(exp(x)).
 // To avoid overflow we use the following decomposition rule:

@@ -13084,6 +13296,8 @@ class DecomposeComplexOpsPass
     legalOpsSet.clear();
     legalOpsSet.insert(legalOps.begin(), legalOps.end());
 
+    patterns.add<DecomposeAtenScaledDotProductAttentionOp>(context);
+
     addPatternIfTargetOpIsIllegal<DecomposeAten_WeightNormInterfaceOp>(
         patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenSoftmaxIntOp>(patterns);
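
For reference, the op sequence emitted by DecomposeAtenScaledDotProductAttentionOp (matmul, scale, max-subtracted softmax, matmul) corresponds to standard scaled dot-product attention with a numerically stable softmax. A sketch of the math, assuming the default scale of 1/sqrt(head_dim) when no explicit scale is supplied:

\[
S = \operatorname{scale} \cdot Q K^{\top}, \qquad \operatorname{scale} = \frac{1}{\sqrt{d_{\text{head}}}},
\]
\[
\operatorname{softmax}(S)_{ij} = \frac{\exp\bigl(S_{ij} - \max_{j'} S_{ij'}\bigr)}{\sum_{j'} \exp\bigl(S_{ij'} - \max_{j''} S_{ij''}\bigr)}, \qquad
\text{output} = \operatorname{softmax}(S)\, V.
\]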

lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp

Lines changed: 18 additions & 0 deletions

@@ -8,6 +8,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
 
 using namespace mlir;

@@ -40,6 +41,23 @@ static void setupValueTensorToBuiltinTensorConversion(
       return {};
     return ToBuiltinTensorOp::create(builder, loc, type, inputs[0]);
   });
+  typeConverter.addTargetMaterialization([](OpBuilder &builder, Type type,
+                                            ValueRange inputs,
+                                            Location loc) -> Value {
+    if (inputs.size() != 1)
+      return Value();
+    auto fromType = dyn_cast<RankedTensorType>(inputs[0].getType());
+    auto toType = dyn_cast<RankedTensorType>(type);
+    if (!fromType || !toType)
+      return Value();
+    if (fromType == toType)
+      return inputs[0];
+    if (fromType.getElementType() != toType.getElementType())
+      return Value();
+    if (!toType.hasStaticShape())
+      return Value();
+    return tensor::CastOp::create(builder, loc, toType, inputs[0]);
+  });
   auto sourceMaterialization = [](OpBuilder &builder,
                                   Torch::ValueTensorType type,
                                   ValueRange inputs, Location loc) -> Value {

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 0 additions & 10 deletions

@@ -50,11 +50,8 @@
     "ScaledDotProductAttentionBoolMaskModule_basic",
     "ScaledDotProductAttentionDifferentCausalModule_basic",
     "ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
-    "ScaledDotProductAttentionDifferentModule_basic",
     "ScaledDotProductAttentionMaskModule_basic",
     "ScaledDotProductAttentionSameCausalModule_basic",
-    "ScaledDotProductAttentionSameDynamicModule_basic",
-    "ScaledDotProductAttentionSameModule_basic",
 }
 
 LINALG_CRASHING_SET = {

@@ -953,11 +950,8 @@
     "ScaledDotProductAttentionBoolMaskModule_basic",
     "ScaledDotProductAttentionDifferentCausalModule_basic",
     "ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
-    "ScaledDotProductAttentionDifferentModule_basic",
     "ScaledDotProductAttentionMaskModule_basic",
     "ScaledDotProductAttentionSameCausalModule_basic",
-    "ScaledDotProductAttentionSameDynamicModule_basic",
-    "ScaledDotProductAttentionSameModule_basic",
     "SubIntModule_basic",
     "TensorToIntZeroRank_basic",
     "UpSampleNearest2dDynamicFactor_basic",

@@ -3978,11 +3972,8 @@
     "ScaledDotProductAttentionBoolMaskModule_basic",
     "ScaledDotProductAttentionDifferentCausalModule_basic",
     "ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
-    "ScaledDotProductAttentionDifferentModule_basic",
     "ScaledDotProductAttentionMaskModule_basic",
     "ScaledDotProductAttentionSameCausalModule_basic",
-    "ScaledDotProductAttentionSameDynamicModule_basic",
-    "ScaledDotProductAttentionSameModule_basic",
     "ScaledDotProductAttentionGQAModule_basic",
     # error: 'tosa.scatter' op requires dimensions K >= W
     "IndexPut1DFloatNonAccumulateModule_basic",

@@ -4887,7 +4878,6 @@
     # REMOVE WHEN ENABLE_GQA IS ADDED
     "ScaledDotProductAttentionBoolMaskModule_basic",
    "ScaledDotProductAttentionSameCausalModule_basic",
-    "ScaledDotProductAttentionSameDynamicModule_basic",
     "ScatterAddDynamicModule_basic",
     "ScatterReduceFloatMaxModule",
     "ScatterReduceFloatMaxModuleIncludeSelf",
