Commit c96e922

[TOSA] MultiheadAttention legalization (#4382)
Signed-off-by: Cathal Corbett <cathal.corbett@arm.com>
1 parent 67022a8 commit c96e922

6 files changed, +334 -22 lines changed
lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 36 additions & 3 deletions
@@ -4534,9 +4534,27 @@ LogicalResult ConvertAtenOp<AtenTransposeIntOp>::matchAndRewrite(
   transposedDims[dim0] = dim1;
   transposedDims[dim1] = dim0;
 
-  rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
-      op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(),
+  Type expectedResultType = getTypeConverter()->convertType(op.getType());
+  if (!expectedResultType)
+    return rewriter.notifyMatchFailure(
+        op, "failed to convert transpose result type");
+
+  auto elementType = cast<TensorType>(selfType).getElementType();
+  auto unrankedResultType = UnrankedTensorType::get(elementType);
+  auto transpose = tosa::CreateOpAndInfer<tosa::TransposeOp>(
+      rewriter, op->getLoc(), unrankedResultType, adaptor.getSelf(),
       rewriter.getDenseI32ArrayAttr(transposedDims));
+  Value resultValue = transpose.getResult();
+  if (resultValue.getType() != expectedResultType) {
+    if (!tensor::CastOp::areCastCompatible(resultValue.getType(),
+                                           expectedResultType))
+      return rewriter.notifyMatchFailure(
+          op, "transpose result incompatible with expected type");
+    auto castOp = tensor::CastOp::create(rewriter, op->getLoc(),
+                                         expectedResultType, resultValue);
+    resultValue = castOp.getResult();
+  }
+  rewriter.replaceOp(op, resultValue);
 
   return success();
 }
@@ -8172,7 +8190,7 @@ LogicalResult ConvertAtenOp<AtenDiagonalOp>::matchAndRewrite(
       makeShapeLLVMCompatible(transposedInputShape), selfElemTy);
   SmallVector<int64_t> startSlice(selfRank, 0);
   SmallVector<int64_t> sizeSlice =
-      llvm::to_vector(makeShapeTorchCompatible(transposedInputShape));
+      makeShapeTorchCompatible(transposedInputShape);
   if (offset < 0)
     startSlice[targetDim1] = std::abs(offset);
   diagonalTensor = tosa::SliceOp::create(
@@ -10243,6 +10261,21 @@ void populateTorchToTosaConversionLegalOps(ConversionTarget &target) {
   target.addLegalOp<ConstantDeviceOp>();
   target.addLegalOp<PrimListConstructOp>();
   target.addLegalOp<PrimTupleConstructOp>();
+  target.addDynamicallyLegalOp<tensor::CastOp>([](tensor::CastOp op) -> bool {
+    auto sourceType = dyn_cast<RankedTensorType>(op.getSource().getType());
+    auto resultType = dyn_cast<RankedTensorType>(op.getType());
+    if (!sourceType || !resultType)
+      return true;
+    if (sourceType.getElementType() != resultType.getElementType())
+      return true;
+    if (!sourceType.hasStaticShape())
+      return true;
+    if (!resultType.hasStaticShape())
+      return true;
+    if (sourceType == resultType)
+      return true;
+    return false;
+  });
 }
 
 std::set<StringRef> populateTorchToTosaConversionPatternsAndIllegalOps(

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 213 additions & 9 deletions
@@ -2295,17 +2295,216 @@ class DecomposeAtenTraceOp : public OpRewritePattern<AtenTraceOp> {
 };
 } // namespace
 
+static Value getSoftmaxResult(Operation *op, Value self, Value dim,
+                              Type resultType, Type accumulatorType,
+                              PatternRewriter &rewriter);
+
+namespace {
+// Decompose scaled dot product attention into matmul/softmax pipeline when
+// there is no masking, dropout, causal, or GQA behaviour.
+class DecomposeAtenScaledDotProductAttentionOp
+    : public OpRewritePattern<AtenScaledDotProductAttentionOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenScaledDotProductAttentionOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+
+    if (!isa<Torch::NoneType>(op.getAttnMask().getType()))
+      return rewriter.notifyMatchFailure(
+          op, "attention mask decomposition not implemented");
+
+    double dropoutP;
+    if (!matchPattern(op.getDropoutP(), m_TorchConstantFloat(&dropoutP)) ||
+        dropoutP != 0.0)
+      return rewriter.notifyMatchFailure(
+          op, "expected dropout_p to be the constant 0.0");
+
+    bool isCausal;
+    if (!matchPattern(op.getIsCausal(), m_TorchConstantBool(&isCausal)) ||
+        isCausal)
+      return rewriter.notifyMatchFailure(op,
+                                         "causal attention not supported yet");
+
+    bool enableGqa;
+    if (!matchPattern(op.getEnableGqa(), m_TorchConstantBool(&enableGqa)) ||
+        enableGqa)
+      return rewriter.notifyMatchFailure(op,
+                                         "grouped-query attention unsupported");
+
+    Value query = op.getQuery();
+    Value key = op.getKey();
+    Value value = op.getValue();
+
+    auto queryValueTensorType = dyn_cast<ValueTensorType>(query.getType());
+    auto keyValueTensorType = dyn_cast<ValueTensorType>(key.getType());
+    auto valueValueTensorType = dyn_cast<ValueTensorType>(value.getType());
+    if (!queryValueTensorType || !keyValueTensorType || !valueValueTensorType)
+      return rewriter.notifyMatchFailure(op, "expected value tensor semantics");
+    if (!queryValueTensorType.hasSizes() || !keyValueTensorType.hasSizes() ||
+        !valueValueTensorType.hasSizes())
+      return rewriter.notifyMatchFailure(
+          op, "expected tensor inputs to have known shapes");
+    if (!queryValueTensorType.hasDtype() || !keyValueTensorType.hasDtype() ||
+        !valueValueTensorType.hasDtype())
+      return rewriter.notifyMatchFailure(
+          op, "expected tensor inputs to have dtypes");
+    Type queryDtype = queryValueTensorType.getDtype();
+    Type keyDtype = keyValueTensorType.getDtype();
+    Type valueDtype = valueValueTensorType.getDtype();
+    if (queryDtype != keyDtype || queryDtype != valueDtype)
+      return rewriter.notifyMatchFailure(
+          op, "expected query, key, and value to share dtype");
+
+    ArrayRef<int64_t> querySizes = queryValueTensorType.getSizes();
+    int64_t queryRank = querySizes.size();
+    if (queryRank < 3 || queryRank > 4)
+      return rewriter.notifyMatchFailure(
+          op, "expected query tensor rank to be 3 or 4");
+    ArrayRef<int64_t> keySizes = keyValueTensorType.getSizes();
+    ArrayRef<int64_t> valueSizes = valueValueTensorType.getSizes();
+    if (static_cast<int64_t>(keySizes.size()) != queryRank ||
+        static_cast<int64_t>(valueSizes.size()) != queryRank)
+      return rewriter.notifyMatchFailure(
+          op, "expected query, key, and value to share rank");
+    bool hasExplicitHeadDim = queryRank == 4;
+    Value oneInt =
+        ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(1));
+    Value zeroInt =
+        ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(0));
+    Value rank = AtenDimOp::create(rewriter, loc, query);
+    Value lastDim = AtenSubIntOp::create(rewriter, loc, rank, oneInt);
+    Value headDim = AtenSizeIntOp::create(rewriter, loc, query, lastDim);
+    Value seqDimIndex = AtenSubIntOp::create(rewriter, loc, lastDim, oneInt);
+    Value seqLen = AtenSizeIntOp::create(rewriter, loc, query, seqDimIndex);
+    Value keySeqLen = AtenSizeIntOp::create(rewriter, loc, key, seqDimIndex);
+    Value numHeadsSize =
+        hasExplicitHeadDim
+            ? (Value)AtenSizeIntOp::create(rewriter, loc, query, oneInt)
+            : oneInt;
+    Value batchSize = AtenSizeIntOp::create(rewriter, loc, query, zeroInt);
+    auto listIntType =
+        Torch::ListType::get(Torch::IntType::get(rewriter.getContext()));
+
+    auto getDimValue = [&](int64_t staticDim, Value fallback) -> Value {
+      if (staticDim != Torch::kUnknownSize)
+        return ConstantIntOp::create(rewriter, loc,
+                                     rewriter.getI64IntegerAttr(staticDim));
+      return fallback;
+    };
+
+    Value scaleFloat;
+    if (isa<Torch::NoneType>(op.getScale().getType())) {
+      Value sqrtHeadDim = AtenSqrtIntOp::create(rewriter, loc, headDim);
+      Value oneFloat =
+          ConstantFloatOp::create(rewriter, loc, rewriter.getF64FloatAttr(1.0));
+      scaleFloat = AtenDivFloatOp::create(rewriter, loc, oneFloat, sqrtHeadDim);
+    } else {
+      scaleFloat = op.getScale();
+    }
+
+    Value negTwo =
+        ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(-2));
+    Value negOne =
+        ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(-1));
+
+    SmallVector<int64_t> keyTransposedSizes(keySizes.begin(), keySizes.end());
+    std::swap(keyTransposedSizes[keyTransposedSizes.size() - 1],
+              keyTransposedSizes[keyTransposedSizes.size() - 2]);
+    ArrayRef<int64_t> keyTransposedRef(keyTransposedSizes);
+    std::optional<ArrayRef<int64_t>> keyTransposedOpt(keyTransposedRef);
+    Type keyTransposedType = keyValueTensorType.getWithSizesAndDtypeAndSparsity(
+        keyTransposedSizes, keyValueTensorType.getOptionalDtype(),
+        keyValueTensorType.getOptionalSparsity());
+    Value keyTransposed = AtenTransposeIntOp::create(
+        rewriter, loc, keyTransposedType, key, negTwo, negOne);
+    SmallVector<Value> keyDims;
+    auto getOrFallback = [&](ArrayRef<int64_t> staticDims, unsigned idx,
+                             Value fallback) -> Value {
+      return getDimValue(idx < staticDims.size() ? staticDims[idx]
+                                                 : Torch::kUnknownSize,
+                         fallback);
+    };
+    keyDims.push_back(getOrFallback(keyTransposedSizes, 0, batchSize));
+    if (hasExplicitHeadDim) {
+      keyDims.push_back(getOrFallback(keyTransposedSizes, 1, numHeadsSize));
+      keyDims.push_back(getOrFallback(keyTransposedSizes, 2, headDim));
+      keyDims.push_back(getOrFallback(keyTransposedSizes, 3, keySeqLen));
+    } else {
+      keyDims.push_back(getOrFallback(keyTransposedSizes, 1, headDim));
+      keyDims.push_back(getOrFallback(keyTransposedSizes, 2, keySeqLen));
+    }
+    Value keyTransposeShapeList =
+        PrimListConstructOp::create(rewriter, loc, listIntType, keyDims);
+    keyTransposed = AtenViewOp::create(rewriter, loc, keyTransposedType,
+                                       keyTransposed, keyTransposeShapeList);
+
+    auto getStaticDim = [](ArrayRef<int64_t> sizes, int64_t index) {
+      if (index < 0)
+        index += sizes.size();
+      if (index < 0 || index >= static_cast<int64_t>(sizes.size()))
+        return Torch::kUnknownSize;
+      return sizes[index];
+    };
+    int64_t queryBatchStatic = getStaticDim(querySizes, 0);
+    int64_t querySeqStatic = getStaticDim(querySizes, -2);
+    int64_t keySeqStatic = getStaticDim(keySizes, -2);
+    int64_t queryHeadsStatic =
+        hasExplicitHeadDim ? getStaticDim(querySizes, 1) : 1;
+    SmallVector<int64_t, 4> scoresSizes;
+    if (hasExplicitHeadDim)
+      scoresSizes.assign(
+          {queryBatchStatic, queryHeadsStatic, querySeqStatic, keySeqStatic});
+    else
+      scoresSizes.assign({queryBatchStatic, querySeqStatic, keySeqStatic});
+    Type scoresType = ValueTensorType::get(
+        op->getContext(),
+        ArrayRef<int64_t>(scoresSizes.begin(), scoresSizes.end()),
+        queryValueTensorType.getOptionalDtype(),
+        queryValueTensorType.getOptionalSparsity());
+    Value scores =
+        AtenMatmulOp::create(rewriter, loc, scoresType, query, keyTransposed);
+    SmallVector<Value> scoresDims;
+    scoresDims.push_back(getDimValue(scoresSizes[0], batchSize));
+    unsigned seqIndex = 1;
+    if (hasExplicitHeadDim) {
+      scoresDims.push_back(getDimValue(scoresSizes[1], numHeadsSize));
+      seqIndex = 2;
+    }
+    scoresDims.push_back(getDimValue(scoresSizes[seqIndex], seqLen));
+    scoresDims.push_back(getDimValue(scoresSizes.back(), keySeqLen));
+    Value scoresShapeList =
+        PrimListConstructOp::create(rewriter, loc, listIntType, scoresDims);
+    scores =
+        AtenViewOp::create(rewriter, loc, scoresType, scores, scoresShapeList);
+    Value scaledScores =
+        AtenMulScalarOp::create(rewriter, loc, scoresType, scores, scaleFloat);
+
+    Value softmax = getSoftmaxResult(op.getOperation(), scaledScores, negOne,
+                                     scoresType, scoresType, rewriter);
+    if (!softmax)
+      return rewriter.notifyMatchFailure(op,
+                                         "failed to compute softmax scores");
+
+    Value output =
+        AtenMatmulOp::create(rewriter, loc, op.getType(), softmax, value);
+
+    rewriter.replaceOp(op, output);
+    return success();
+  }
+};
+} // namespace
+
 // Calculates the softmax function on the given `input` tensor. Softmax(x) =
 // exp(x)/sum(exp(x)).
 // To avoid overflow we use the following decomposition rule:
 // x_max = max(input, dim, keepdim = True)
 // unnorm = aten.exp(input - x_max)
 // softmax = unnorm / sum(unnorm, dim, keepdim = True)
-template <typename OpTy>
-static Value getSoftmaxResult(OpTy op, Value self, Type resultType,
-                              Type accumulatorType, PatternRewriter &rewriter) {
-  Location loc = op.getLoc();
-  Value dim = op.getDim();
+static Value getSoftmaxResult(Operation *op, Value self, Value dim,
+                              Type resultType, Type accumulatorType,
+                              PatternRewriter &rewriter) {
+  Location loc = op->getLoc();
   if (resultType != accumulatorType)
     self = convertTensorToDtype(rewriter, loc, self, accumulatorType);
   Value xMax =
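
For reference, the decomposition implemented by the new pattern matches standard unmasked, non-causal, dropout-free scaled dot product attention: scores = Q @ K^T scaled by 1/sqrt(head_dim) (or the explicit scale), a softmax over the key dimension, then a matmul with V. Below is a minimal PyTorch sketch of that reference computation; the helper name and shapes are illustrative, not code from this patch.

import math
import torch

def sdpa_reference(query, key, value, scale=None):
    # scale defaults to 1/sqrt(head_dim), mirroring the AtenSqrtIntOp /
    # AtenDivFloatOp sequence emitted when `scale` is None.
    head_dim = query.shape[-1]
    if scale is None:
        scale = 1.0 / math.sqrt(head_dim)
    # scores = Q @ K^T, with K transposed over its last two dimensions.
    scores = torch.matmul(query, key.transpose(-2, -1)) * scale
    # Softmax over the key sequence dimension (dim = -1).
    attn = torch.softmax(scores, dim=-1)
    # Weighted sum of the values.
    return torch.matmul(attn, value)

# Example with an explicit head dimension (rank-4 inputs).
q = torch.randn(2, 4, 8, 16)
k = torch.randn(2, 4, 8, 16)
v = torch.randn(2, 4, 8, 16)
out = sdpa_reference(q, k, v)
expected = torch.nn.functional.scaled_dot_product_attention(q, k, v)
assert torch.allclose(out, expected, atol=1e-6)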
@@ -2362,8 +2561,9 @@ class DecomposeAtenSoftmaxIntOp : public OpRewritePattern<AtenSoftmaxIntOp> {
 
     Type accumulatorTensorType = getDefaultAccType(rewriter, resultTensorDtype);
 
-    Value result = getSoftmaxResult(op, self, resultTensorType,
-                                    accumulatorTensorType, rewriter);
+    Value result =
+        getSoftmaxResult(op.getOperation(), self, op.getDim(), resultTensorType,
+                         accumulatorTensorType, rewriter);
     if (!result)
       return failure();
     rewriter.replaceOpWithNewOp<TensorStaticInfoCastOp>(op, op.getType(),
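
Both DecomposeAtenSoftmaxIntOp (this hunk) and DecomposeAten_SoftmaxOp (next hunk) now pass the reduction dimension explicitly into the shared getSoftmaxResult helper. The overflow-safe rule documented above that helper subtracts the per-row maximum before exponentiating, which cancels in the final quotient. A small PyTorch sketch of the same rule, illustrative only and not part of the patch:

import torch

def stable_softmax(x, dim=-1):
    # x_max = max(input, dim, keepdim=True)
    x_max = x.amax(dim=dim, keepdim=True)
    # unnorm = exp(input - x_max); subtracting x_max prevents exp() overflow
    # and cancels out in the normalisation below.
    unnorm = torch.exp(x - x_max)
    # softmax = unnorm / sum(unnorm, dim, keepdim=True)
    return unnorm / unnorm.sum(dim=dim, keepdim=True)

x = torch.tensor([[1000.0, 1001.0, 1002.0]])
assert torch.allclose(stable_softmax(x), torch.softmax(x, dim=-1))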
@@ -2411,8 +2611,9 @@ class DecomposeAten_SoftmaxOp : public OpRewritePattern<Aten_SoftmaxOp> {
 
     Type accumulatorTensorType = getDefaultAccType(rewriter, resultTensorDtype);
 
-    Value result = getSoftmaxResult(op, self, resultTensorType,
-                                    accumulatorTensorType, rewriter);
+    Value result =
+        getSoftmaxResult(op.getOperation(), self, op.getDim(), resultTensorType,
+                         accumulatorTensorType, rewriter);
     if (!result)
       return op.emitError("failed to get softmax result");
     rewriter.replaceOpWithNewOp<TensorStaticInfoCastOp>(op, resultTensorType,
@@ -13103,6 +13304,9 @@ class DecomposeComplexOpsPass
     legalOpsSet.clear();
     legalOpsSet.insert(legalOps.begin(), legalOps.end());
 
+    addPatternIfTargetOpIsIllegal<DecomposeAtenScaledDotProductAttentionOp>(
+        patterns);
+
     addPatternIfTargetOpIsIllegal<DecomposeAten_WeightNormInterfaceOp>(
         patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenSoftmaxIntOp>(patterns);

lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp

Lines changed: 20 additions & 0 deletions
@@ -8,6 +8,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
 
 using namespace mlir;
@@ -40,6 +41,25 @@ static void setupValueTensorToBuiltinTensorConversion(
           return {};
         return ToBuiltinTensorOp::create(builder, loc, type, inputs[0]);
       });
+  typeConverter.addTargetMaterialization([](OpBuilder &builder, Type type,
+                                            ValueRange inputs,
+                                            Location loc) -> Value {
+    if (inputs.size() != 1)
+      return Value();
+    auto fromType = dyn_cast<RankedTensorType>(inputs[0].getType());
+    auto toType = dyn_cast<RankedTensorType>(type);
+    if (!fromType || !toType)
+      return Value();
+    if (fromType == toType)
+      return inputs[0];
+    if (fromType.getElementType() != toType.getElementType())
+      return Value();
+    if (!toType.hasStaticShape())
+      return Value();
+    if (!tensor::CastOp::areCastCompatible(inputs[0].getType(), toType))
+      return Value();
+    return tensor::CastOp::create(builder, loc, toType, inputs[0]);
+  });
   auto sourceMaterialization = [](OpBuilder &builder,
                                   Torch::ValueTensorType type,
                                   ValueRange inputs, Location loc) -> Value {

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 0 additions & 10 deletions
@@ -52,11 +52,8 @@
     "ScaledDotProductAttentionBoolMaskModule_basic",
     "ScaledDotProductAttentionDifferentCausalModule_basic",
     "ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
-    "ScaledDotProductAttentionDifferentModule_basic",
     "ScaledDotProductAttentionMaskModule_basic",
     "ScaledDotProductAttentionSameCausalModule_basic",
-    "ScaledDotProductAttentionSameDynamicModule_basic",
-    "ScaledDotProductAttentionSameModule_basic",
 }
 
 LINALG_CRASHING_SET = {
@@ -959,11 +956,8 @@
     "ScaledDotProductAttentionBoolMaskModule_basic",
     "ScaledDotProductAttentionDifferentCausalModule_basic",
     "ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
-    "ScaledDotProductAttentionDifferentModule_basic",
     "ScaledDotProductAttentionMaskModule_basic",
     "ScaledDotProductAttentionSameCausalModule_basic",
-    "ScaledDotProductAttentionSameDynamicModule_basic",
-    "ScaledDotProductAttentionSameModule_basic",
     "SubIntModule_basic",
     "TensorToIntZeroRank_basic",
     "UpSampleNearest2dDynamicFactor_basic",
@@ -3992,11 +3986,8 @@
     "ScaledDotProductAttentionBoolMaskModule_basic",
     "ScaledDotProductAttentionDifferentCausalModule_basic",
     "ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
-    "ScaledDotProductAttentionDifferentModule_basic",
     "ScaledDotProductAttentionMaskModule_basic",
     "ScaledDotProductAttentionSameCausalModule_basic",
-    "ScaledDotProductAttentionSameDynamicModule_basic",
-    "ScaledDotProductAttentionSameModule_basic",
     "ScaledDotProductAttentionGQAModule_basic",
     # error: 'tosa.scatter' op requires dimensions K >= W
     "IndexPut1DFloatNonAccumulateModule_basic",
@@ -4905,7 +4896,6 @@
     # REMOVE WHEN ENABLE_GQA IS ADDED
     "ScaledDotProductAttentionBoolMaskModule_basic",
    "ScaledDotProductAttentionSameCausalModule_basic",
-    "ScaledDotProductAttentionSameDynamicModule_basic",
     "ScatterAddDynamicModule_basic",
     "ScatterReduceFloatMaxModule",
     "ScatterReduceFloatMaxModuleIncludeSelf",
