65 changes: 64 additions & 1 deletion lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -4016,8 +4016,28 @@ LogicalResult ConvertAtenOp<AtenTransposeIntOp>::matchAndRewrite(
transposedDims[dim0] = dim1;
transposedDims[dim1] = dim0;

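// Refine the result type: for ranked inputs start from an all-dynamic shape
// of the same rank and, when the input shape is fully static, swap dim0/dim1.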
Type resultType = getTypeConverter()->convertType(op.getType());
if (auto rankedSelf = dyn_cast<RankedTensorType>(selfType)) {
SmallVector<int64_t> transposedShape(rankedSelf.getRank(),
ShapedType::kDynamic);
if (rankedSelf.hasStaticShape()) {
auto staticShape =
llvm::to_vector(makeShapeTorchCompatible(rankedSelf.getShape()));
auto dim0Index = static_cast<size_t>(dim0);
auto dim1Index = static_cast<size_t>(dim1);
if (dim0Index < staticShape.size() && dim1Index < staticShape.size())
std::swap(staticShape[dim0Index], staticShape[dim1Index]);
for (size_t i = 0; i < staticShape.size(); ++i)
transposedShape[i] = staticShape[i];
}
auto rankedResult = RankedTensorType::get(
makeShapeLLVMCompatible(transposedShape), rankedSelf.getElementType());
if (auto converted = getTypeConverter()->convertType(rankedResult))
resultType = converted;
}

rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
-      op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(),
+      op, resultType, adaptor.getSelf(),
rewriter.getDenseI32ArrayAttr(transposedDims));

return success();
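The refinement above only kicks in for ranked inputs. A minimal standalone sketch of the shape computation (plain C++, no MLIR dependencies; the names and the kDynamic stand-in are illustrative, not from this PR):

```cpp
#include <cstdint>
#include <limits>
#include <utility>
#include <vector>

// Stand-in for ShapedType::kDynamic.
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Mirrors the hunk above: default to an all-dynamic shape of the input rank,
// and only for a fully static input copy the shape with dim0/dim1 swapped.
std::vector<int64_t> transposedResultShape(const std::vector<int64_t> &shape,
                                           bool hasStaticShape, size_t dim0,
                                           size_t dim1) {
  std::vector<int64_t> result(shape.size(), kDynamic);
  if (!hasStaticShape)
    return result;
  result = shape;
  if (dim0 < result.size() && dim1 < result.size())
    std::swap(result[dim0], result[dim1]);
  return result;
}
```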
@@ -9387,6 +9407,32 @@ class ConvertTorchToTosa
};
} // namespace

namespace {
class FoldStaticToDynamicTensorCast
: public OpConversionPattern<tensor::CastOp> {
public:
using OpConversionPattern<tensor::CastOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto sourceType = dyn_cast<RankedTensorType>(adaptor.getSource().getType());
auto resultType = dyn_cast<RankedTensorType>(op.getType());
if (!sourceType || !resultType)
return failure();
if (sourceType.getElementType() != resultType.getElementType())
return failure();
if (!sourceType.hasStaticShape())
return failure();
if (!resultType.hasStaticShape())
return failure();
if (sourceType == resultType)
return failure();
rewriter.replaceOp(op, adaptor.getSource());
return success();
}
};
} // namespace

void populateTorchToTosaConversionLegalOps(ConversionTarget &target) {
// The following ops are never the primary reason why lowering fails.
// The backend contract only allows functions to return tensors thus there
@@ -9402,6 +9448,21 @@ void populateTorchToTosaConversionLegalOps(ConversionTarget &target) {
target.addLegalOp<ConstantDeviceOp>();
target.addLegalOp<PrimListConstructOp>();
target.addLegalOp<PrimTupleConstructOp>();
target.addDynamicallyLegalOp<tensor::CastOp>([](tensor::CastOp op) -> bool {
auto sourceType = dyn_cast<RankedTensorType>(op.getSource().getType());
auto resultType = dyn_cast<RankedTensorType>(op.getType());
if (!sourceType || !resultType)
return true;
if (sourceType.getElementType() != resultType.getElementType())
return true;
if (!sourceType.hasStaticShape())
return true;
if (!resultType.hasStaticShape())
return true;
if (sourceType == resultType)
return true;
return false;
});
}
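The legality callback above repeats the match conditions of FoldStaticToDynamicTensorCast check for check. A possible follow-up (not part of this PR; helper name hypothetical) is a shared predicate so the two cannot drift apart:

```cpp
// Hypothetical shared helper: true iff a tensor.cast between these ranked
// types is a pure static-shape refinement that can be folded to its source.
static bool isFoldableTensorCast(RankedTensorType sourceType,
                                 RankedTensorType resultType) {
  return sourceType && resultType && sourceType != resultType &&
         sourceType.getElementType() == resultType.getElementType() &&
         sourceType.hasStaticShape() && resultType.hasStaticShape();
}
```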

std::set<StringRef> populateTorchToTosaConversionPatternsAndIllegalOps(
Expand Down Expand Up @@ -9723,6 +9784,8 @@ std::set<StringRef> populateTorchToTosaConversionPatternsAndIllegalOps(
INSERT_CAST_ATENOP_PATTERN(AtenIntReprOp);
#undef INSERT_CAST_ATENOP_PATTERN

patterns.add<FoldStaticToDynamicTensorCast>(typeConverter, context);

return illegalOps;
}

214 changes: 214 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -2295,6 +2295,218 @@ class DecomposeAtenTraceOp : public OpRewritePattern<AtenTraceOp> {
};
} // namespace

namespace {
// Decompose scaled dot product attention into an explicit matmul/softmax
// pipeline when there is no attention mask, dropout, causal masking, or
// grouped-query attention (GQA).
class DecomposeAtenScaledDotProductAttentionOp
: public OpRewritePattern<AtenScaledDotProductAttentionOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenScaledDotProductAttentionOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();

if (!isa<Torch::NoneType>(op.getAttnMask().getType()))
return rewriter.notifyMatchFailure(
op, "attention mask decomposition not implemented");

double dropoutP;
if (!matchPattern(op.getDropoutP(), m_TorchConstantFloat(&dropoutP)) ||
dropoutP != 0.0)
return rewriter.notifyMatchFailure(
op, "expected dropout_p to be the constant 0.0");

bool isCausal;
if (!matchPattern(op.getIsCausal(), m_TorchConstantBool(&isCausal)) ||
isCausal)
return rewriter.notifyMatchFailure(op,
"causal attention not supported yet");

bool enableGqa;
if (!matchPattern(op.getEnableGqa(), m_TorchConstantBool(&enableGqa)) ||
enableGqa)
return rewriter.notifyMatchFailure(op,
"grouped-query attention unsupported");

Value query = op.getQuery();
Value key = op.getKey();
Value value = op.getValue();

auto queryTensorType = dyn_cast<BaseTensorType>(query.getType());
auto keyTensorType = dyn_cast<BaseTensorType>(key.getType());
auto valueTensorType = dyn_cast<BaseTensorType>(value.getType());
if (!queryTensorType || !keyTensorType || !valueTensorType)
return rewriter.notifyMatchFailure(op, "expected tensor inputs");
if (!queryTensorType.hasSizes() || !keyTensorType.hasSizes() ||
!valueTensorType.hasSizes())
return rewriter.notifyMatchFailure(
op, "expected tensor inputs to have known shapes");
auto queryValueTensorType = dyn_cast<ValueTensorType>(queryTensorType);
auto keyValueTensorType = dyn_cast<ValueTensorType>(keyTensorType);
auto valueValueTensorType = dyn_cast<ValueTensorType>(valueTensorType);
if (!queryValueTensorType || !keyValueTensorType || !valueValueTensorType)
return rewriter.notifyMatchFailure(op, "expected value tensor semantics");

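// Recover runtime sizes (batch, heads, query/key sequence lengths, head dim)
// as torch ints; query/key are laid out (B, S, E) or (B, H, S, E).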
Value oneInt =
ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(1));
Value zeroInt =
ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(0));
Value rank = AtenDimOp::create(rewriter, loc, query);
Value lastDim = AtenSubIntOp::create(rewriter, loc, rank, oneInt);
Value headDim = AtenSizeIntOp::create(rewriter, loc, query, lastDim);
Value seqDimIndex = AtenSubIntOp::create(rewriter, loc, lastDim, oneInt);
Value seqLen = AtenSizeIntOp::create(rewriter, loc, query, seqDimIndex);
Value keySeqLen = AtenSizeIntOp::create(rewriter, loc, key, seqDimIndex);
ArrayRef<int64_t> querySizes = queryValueTensorType.getSizes();
bool hasExplicitHeadDim = querySizes.size() >= 4;
Value numHeadsSize =
hasExplicitHeadDim
? (Value)AtenSizeIntOp::create(rewriter, loc, query, oneInt)
: oneInt;
Value batchSize = AtenSizeIntOp::create(rewriter, loc, query, zeroInt);
auto listIntType =
Torch::ListType::get(Torch::IntType::get(rewriter.getContext()));

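// Use a known static dim when available; otherwise fall back to the runtime
// size value.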
auto getDimValue = [&](int64_t staticDim, Value fallback) -> Value {
if (staticDim != Torch::kUnknownSize)
return ConstantIntOp::create(rewriter, loc,
rewriter.getI64IntegerAttr(staticDim));
return fallback;
};

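// The scale defaults to 1/sqrt(head_dim), matching torch's SDPA semantics.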
Value scaleFloat;
if (isa<Torch::NoneType>(op.getScale().getType())) {
Value sqrtHeadDim = AtenSqrtIntOp::create(rewriter, loc, headDim);
Value oneFloat =
ConstantFloatOp::create(rewriter, loc, rewriter.getF64FloatAttr(1.0));
scaleFloat = AtenDivFloatOp::create(rewriter, loc, oneFloat, sqrtHeadDim);
} else {
scaleFloat = op.getScale();
}

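// Swap the last two key dims, (..., S_k, E) -> (..., E, S_k), so that
// Q @ K^T contracts over the head dim.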
Value negTwo =
ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(-2));
Value negOne =
ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(-1));

ArrayRef<int64_t> keySizes = keyValueTensorType.getSizes();
SmallVector<int64_t> keyTransposedSizes(keySizes.begin(), keySizes.end());
if (keyTransposedSizes.size() < 2)
return rewriter.notifyMatchFailure(
op, "expected key tensor rank >= 2 for transpose");
std::swap(keyTransposedSizes[keyTransposedSizes.size() - 1],
keyTransposedSizes[keyTransposedSizes.size() - 2]);
Type keyTransposedType = keyValueTensorType.getWithSizesAndDtypeAndSparsity(
keyTransposedSizes, keyValueTensorType.getOptionalDtype(),
keyValueTensorType.getOptionalSparsity());
Value keyTransposed = AtenTransposeIntOp::create(
rewriter, loc, keyTransposedType, key, negTwo, negOne);
SmallVector<Value> keyDims;
auto getOrFallback = [&](ArrayRef<int64_t> staticDims, unsigned idx,
Value fallback) -> Value {
return getDimValue(idx < staticDims.size() ? staticDims[idx]
: Torch::kUnknownSize,
fallback);
};
keyDims.push_back(getOrFallback(keyTransposedSizes, 0, batchSize));
if (keyTransposedSizes.size() == 4) {
keyDims.push_back(getOrFallback(keyTransposedSizes, 1, numHeadsSize));
keyDims.push_back(getOrFallback(keyTransposedSizes, 2, headDim));
keyDims.push_back(getOrFallback(keyTransposedSizes, 3, keySeqLen));
} else {
keyDims.push_back(getOrFallback(keyTransposedSizes, 1, headDim));
keyDims.push_back(getOrFallback(keyTransposedSizes, 2, keySeqLen));
}
Value keyTransposeShapeList =
PrimListConstructOp::create(rewriter, loc, listIntType, keyDims);
keyTransposed = AtenViewOp::create(rewriter, loc, keyTransposedType,
keyTransposed, keyTransposeShapeList);

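// Recover the static scores shape (B, [H,] S_q, S_k) where known so the
// matmul result type stays as refined as its inputs allow.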
auto getStaticDim = [](ArrayRef<int64_t> sizes, int64_t index) {
if (index < 0)
index += sizes.size();
if (index < 0 || index >= static_cast<int64_t>(sizes.size()))
return Torch::kUnknownSize;
return sizes[index];
};
int64_t queryBatchStatic = getStaticDim(querySizes, 0);
int64_t querySeqStatic = getStaticDim(querySizes, -2);
int64_t keySeqStatic = getStaticDim(keySizes, -2);
int64_t queryHeadsStatic =
hasExplicitHeadDim ? getStaticDim(querySizes, 1) : 1;
SmallVector<int64_t, 4> scoresSizes;
if (hasExplicitHeadDim)
scoresSizes.assign(
{queryBatchStatic, queryHeadsStatic, querySeqStatic, keySeqStatic});
else
scoresSizes.assign({queryBatchStatic, querySeqStatic, keySeqStatic});
Type scoresType = ValueTensorType::get(
op->getContext(),
ArrayRef<int64_t>(scoresSizes.begin(), scoresSizes.end()),
queryValueTensorType.getOptionalDtype(),
queryValueTensorType.getOptionalSparsity());
Value scores =
AtenMatmulOp::create(rewriter, loc, scoresType, query, keyTransposed);
SmallVector<Value> scoresDims;
scoresDims.push_back(getDimValue(scoresSizes[0], batchSize));
unsigned seqIndex = 1;
if (hasExplicitHeadDim) {
scoresDims.push_back(getDimValue(scoresSizes[1], numHeadsSize));
seqIndex = 2;
}
scoresDims.push_back(getDimValue(scoresSizes[seqIndex], seqLen));
scoresDims.push_back(getDimValue(scoresSizes.back(), keySeqLen));
Value scoresShapeList =
PrimListConstructOp::create(rewriter, loc, listIntType, scoresDims);
scores =
AtenViewOp::create(rewriter, loc, scoresType, scores, scoresShapeList);
Value scaledScores =
AtenMulScalarOp::create(rewriter, loc, scoresType, scores, scaleFloat);

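// Numerically stable softmax along the last dim: subtract the row max,
// exponentiate, then normalize by the sum.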
SmallVector<int64_t> reducedSizes(scoresSizes.begin(), scoresSizes.end());
reducedSizes.back() = 1;
ArrayRef<int64_t> reducedSizesRef(reducedSizes);
std::optional<ArrayRef<int64_t>> reducedSizesOpt(reducedSizesRef);
Type reducedValueType =
ValueTensorType::get(op->getContext(), reducedSizesOpt,
queryValueTensorType.getOptionalDtype());
Type reducedIndexType = ValueTensorType::get(
op->getContext(), reducedSizesOpt,
IntegerType::get(op->getContext(), 64, IntegerType::Signed));
Value keepDimTrue =
ConstantBoolOp::create(rewriter, loc, rewriter.getBoolAttr(true));
auto maxOp =
AtenMaxDimOp::create(rewriter, loc, reducedValueType, reducedIndexType,
scaledScores, negOne, keepDimTrue);
Value softmaxMax = TensorStaticInfoCastOp::create(
rewriter, loc, reducedValueType, maxOp.getValues());
Value centered =
createTensorSub(rewriter, loc, scoresType, scaledScores, softmaxMax);
Value unNormalizedExp =
AtenExpOp::create(rewriter, loc, scoresType, centered);
SmallVector<Value, 1> softmaxDims{negOne};
Value dimList =
PrimListConstructOp::create(rewriter, loc, listIntType, softmaxDims);
Value noneValue = ConstantNoneOp::create(rewriter, loc);
Value softmaxDenominator = AtenSumDimIntListOp::create(
rewriter, loc, reducedValueType, unNormalizedExp, dimList, keepDimTrue,
noneValue);
softmaxDenominator = TensorStaticInfoCastOp::create(
rewriter, loc, reducedValueType, softmaxDenominator);
Value softmax = AtenDivTensorOp::create(
rewriter, loc, scoresType, unNormalizedExp, softmaxDenominator);

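// Attention output is the weighted sum of values: softmax(scores) @ V.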
Value output =
AtenMatmulOp::create(rewriter, loc, op.getType(), softmax, value);

rewriter.replaceOp(op, output);
return success();
}
};
} // namespace
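For reference, the pipeline materialized above computes standard unmasked attention with a max-subtracted (overflow-safe) softmax, where $s$ is the explicit scale or $1/\sqrt{d_k}$ by default:

$$
\mathrm{scores} = s\,(Q K^{\top}), \qquad
\mathrm{softmax}(x)_i = \frac{e^{x_i - \max_j x_j}}{\sum_k e^{x_k - \max_j x_j}}, \qquad
\mathrm{output} = \mathrm{softmax}(\mathrm{scores})\,V
$$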

// Calculates the softmax function on the given `input` tensor. Softmax(x) =
// exp(x)/sum(exp(x)).
// To avoid overflow we use the following decomposition rule:
@@ -13084,6 +13296,8 @@ class DecomposeComplexOpsPass
legalOpsSet.clear();
legalOpsSet.insert(legalOps.begin(), legalOps.end());

patterns.add<DecomposeAtenScaledDotProductAttentionOp>(context);

addPatternIfTargetOpIsIllegal<DecomposeAten_WeightNormInterfaceOp>(
patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSoftmaxIntOp>(patterns);
18 changes: 18 additions & 0 deletions lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp
@@ -8,6 +8,7 @@
//===----------------------------------------------------------------------===//

#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"

using namespace mlir;
@@ -40,6 +41,23 @@ static void setupValueTensorToBuiltinTensorConversion(
return {};
return ToBuiltinTensorOp::create(builder, loc, type, inputs[0]);
});
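// Materialize static-shape refinements between builtin tensors with a
// tensor.cast, provided the element type is unchanged and the target shape
// is fully static.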
typeConverter.addTargetMaterialization([](OpBuilder &builder, Type type,
ValueRange inputs,
Location loc) -> Value {
if (inputs.size() != 1)
return Value();
auto fromType = dyn_cast<RankedTensorType>(inputs[0].getType());
auto toType = dyn_cast<RankedTensorType>(type);
if (!fromType || !toType)
return Value();
if (fromType == toType)
return inputs[0];
if (fromType.getElementType() != toType.getElementType())
return Value();
if (!toType.hasStaticShape())
return Value();
return tensor::CastOp::create(builder, loc, toType, inputs[0]);
});
auto sourceMaterialization = [](OpBuilder &builder,
Torch::ValueTensorType type,
ValueRange inputs, Location loc) -> Value {
10 changes: 0 additions & 10 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -50,11 +50,8 @@
"ScaledDotProductAttentionBoolMaskModule_basic",
"ScaledDotProductAttentionDifferentCausalModule_basic",
"ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
- "ScaledDotProductAttentionDifferentModule_basic",
"ScaledDotProductAttentionMaskModule_basic",
"ScaledDotProductAttentionSameCausalModule_basic",
- "ScaledDotProductAttentionSameDynamicModule_basic",
- "ScaledDotProductAttentionSameModule_basic",
}

LINALG_CRASHING_SET = {
@@ -953,11 +950,8 @@
"ScaledDotProductAttentionBoolMaskModule_basic",
"ScaledDotProductAttentionDifferentCausalModule_basic",
"ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
- "ScaledDotProductAttentionDifferentModule_basic",
"ScaledDotProductAttentionMaskModule_basic",
"ScaledDotProductAttentionSameCausalModule_basic",
- "ScaledDotProductAttentionSameDynamicModule_basic",
- "ScaledDotProductAttentionSameModule_basic",
"SubIntModule_basic",
"TensorToIntZeroRank_basic",
"UpSampleNearest2dDynamicFactor_basic",
@@ -3978,11 +3972,8 @@
"ScaledDotProductAttentionBoolMaskModule_basic",
"ScaledDotProductAttentionDifferentCausalModule_basic",
"ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
- "ScaledDotProductAttentionDifferentModule_basic",
"ScaledDotProductAttentionMaskModule_basic",
"ScaledDotProductAttentionSameCausalModule_basic",
- "ScaledDotProductAttentionSameDynamicModule_basic",
- "ScaledDotProductAttentionSameModule_basic",
"ScaledDotProductAttentionGQAModule_basic",
# error: 'tosa.scatter' op requires dimensions K >= W
"IndexPut1DFloatNonAccumulateModule_basic",
@@ -4887,7 +4878,6 @@
# REMOVE WHEN ENABLE_GQA IS ADDED
"ScaledDotProductAttentionBoolMaskModule_basic",
"ScaledDotProductAttentionSameCausalModule_basic",
- "ScaledDotProductAttentionSameDynamicModule_basic",
"ScatterAddDynamicModule_basic",
"ScatterReduceFloatMaxModule",
"ScatterReduceFloatMaxModuleIncludeSelf",