Skip to content

Commit ab84fc2

Browse files
committed
Merge commit '210477d7' into matthias.update_to_torch_2.6.0
2 parents 67d293a + 210477d commit ab84fc2

File tree

11 files changed

+518
-342
lines changed

11 files changed

+518
-342
lines changed

include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,25 @@
1212

1313
#include "mlir/Dialect/Func/IR/FuncOps.h"
1414
#include "mlir/Pass/Pass.h"
15+
#include "mlir/Transforms/DialectConversion.h"
16+
1517
#include <memory>
1618

1719
namespace mlir {
1820
namespace torch {
21+
22+
/// Collect a set of legal/illegal ops for converting Torch operations to Tosa
23+
/// dialect.
24+
void populateTorchToTosaConversionLegalOps(ConversionTarget &target);
25+
26+
/// Collect a set of patterns to convert Torch operations to Tosa dialect +
27+
/// return the set of illegalOps
28+
std::set<StringRef>
29+
populateTorchToTosaConversionPatternsAndIllegalOps(TypeConverter &typeConverter,
30+
RewritePatternSet &patterns);
31+
1932
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToTosaPass();
20-
}
33+
} // namespace torch
2134
} // namespace mlir
2235

2336
#endif // TORCHMLIR_CONVERSION_TORCHTOTOSA_TORCHTOTOSA_H

include/torch-mlir/Conversion/Utils/Utils.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,15 @@ Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc,
9797
Value torchOptionalInt, Value builtinInt,
9898
Value defaultValue, Value dimSize);
9999

100+
// Helper function to unsqueeze the input tensor at given dim.
101+
// Returns the unsqueezed tensor or failure.
102+
FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
103+
Value input, int64_t dim);
104+
105+
// Helper function to squeeze the input tensor at given dim.
106+
// Returns the squeezed tensor or failure.
107+
FailureOr<Value> squeezeTensor(PatternRewriter &rewriter, Operation *op,
108+
Value input, int64_t dim);
100109
} // namespace Torch
101110
} // namespace torch
102111
} // namespace mlir

lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1093,18 +1093,35 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
10931093
rewriter.replaceOp(binder.op, nllLoss);
10941094
return success();
10951095
});
1096-
patterns.onOp("NonZero", 13,
1097-
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
1098-
Torch::ValueTensorType resultType;
1099-
Value operand;
1100-
if (binder.tensorOperand(operand) ||
1101-
binder.tensorResultType(resultType)) {
1102-
return failure();
1103-
}
1104-
rewriter.replaceOpWithNewOp<Torch::AtenNonzeroOp>(
1105-
binder.op, resultType, operand);
1106-
return success();
1107-
});
1096+
patterns.onOp(
1097+
"NonZero", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
1098+
Torch::ValueTensorType resultType;
1099+
Value operand;
1100+
if (binder.tensorOperand(operand) ||
1101+
binder.tensorResultType(resultType)) {
1102+
return failure();
1103+
}
1104+
Value zero = rewriter.create<Torch::ConstantIntOp>(
1105+
binder.getLoc(), rewriter.getType<Torch::IntType>(),
1106+
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
1107+
Value one = rewriter.create<Torch::ConstantIntOp>(
1108+
binder.getLoc(), rewriter.getType<Torch::IntType>(),
1109+
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1));
1110+
auto rawSize = resultType.getSizes();
1111+
SmallVector<int64_t> torchResultSize(rawSize.rbegin(), rawSize.rend());
1112+
auto torchResultType = rewriter.getType<Torch::ValueTensorType>(
1113+
torchResultSize, resultType.getDtype());
1114+
auto nonZero = rewriter.create<Torch::AtenNonzeroOp>(
1115+
binder.getLoc(), torchResultType, operand);
1116+
// The output tensor has a shape of (n, z), where n is the
1117+
// number of dimensions in the input tensor and z is the
1118+
// number of non-zero elements. This is different from
1119+
// PyTorch's default behavior, where the dimensions are
1120+
// reversed.
1121+
rewriter.replaceOpWithNewOp<Torch::AtenTransposeIntOp>(
1122+
binder.op, resultType, nonZero, zero, one);
1123+
return success();
1124+
});
11081125
patterns.onOp(
11091126
"MaxPool", 12, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
11101127
std::string autoPad;

lib/Conversion/TorchToLinalg/DataMovement.cpp

Lines changed: 13 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -1642,69 +1642,18 @@ class ConvertAtenSqueezeDimOp : public OpConversionPattern<AtenSqueezeDimOp> {
16421642
ConversionPatternRewriter &rewriter) const override {
16431643
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
16441644
return failure();
1645-
Value input = adaptor.getSelf();
1646-
auto inputType = cast<RankedTensorType>(input.getType());
1647-
int64_t inputRank = inputType.getRank();
1648-
1649-
if (inputRank == 0) {
1650-
return rewriter.notifyMatchFailure(
1651-
op, "zero input rank should have been handled by the folder");
1652-
}
1653-
16541645
int64_t dim;
16551646
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
16561647
return rewriter.notifyMatchFailure(op, "dim must be constant");
1657-
dim = toPositiveDim(dim, inputRank);
1658-
if (!isValidDim(dim, inputRank))
1659-
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
1660-
1661-
// assert dynamic squeeze dim size == 1
1662-
if (inputType.isDynamicDim(dim)) {
1663-
Value cstDim = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), dim);
1664-
Value dimVal = rewriter.create<tensor::DimOp>(op.getLoc(), input, cstDim);
1665-
Value cstOne = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 1);
1666-
Value cmp = rewriter.create<arith::CmpIOp>(
1667-
op.getLoc(), arith::CmpIPredicate::eq, dimVal, cstOne);
1668-
rewriter.create<cf::AssertOp>(
1669-
op.getLoc(), cmp,
1670-
rewriter.getStringAttr(
1671-
"Expected dynamic squeeze dim size to be statically 1"));
1672-
}
1673-
1674-
const TypeConverter *typeConverter = getTypeConverter();
1675-
auto resultType =
1676-
cast<RankedTensorType>(typeConverter->convertType(op.getType()));
1677-
int64_t resultRank = resultType.getRank();
16781648

1679-
// If the dim(th) dimension of operand tensor type is not statically unit,
1680-
// `aten.squeeze` will behave as an identity operation.
1681-
if (inputType.getDimSize(dim) != 1 && !inputType.isDynamicDim(dim)) {
1682-
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, input);
1683-
return success();
1649+
auto squeezeTensorInfo =
1650+
squeezeTensor(rewriter, op, adaptor.getSelf(), dim);
1651+
if (failed(squeezeTensorInfo)) {
1652+
return rewriter.notifyMatchFailure(op,
1653+
"cannot generate squeeze tensor");
16841654
}
16851655

1686-
SmallVector<ReassociationIndices> reassociationMap(resultRank);
1687-
bool alreadyCrossedSqueezedDim = false;
1688-
for (int i = 0; i != resultRank; i++) {
1689-
if (alreadyCrossedSqueezedDim) {
1690-
reassociationMap[i].push_back(i + 1);
1691-
} else {
1692-
reassociationMap[i].push_back(i);
1693-
if (dim != 0 && i != dim - 1)
1694-
continue;
1695-
1696-
alreadyCrossedSqueezedDim = true;
1697-
if (dim == 0)
1698-
reassociationMap[0].push_back(1);
1699-
if (i == dim - 1)
1700-
reassociationMap[i].push_back(dim);
1701-
}
1702-
}
1703-
// Note: In case the operand tensor type is of unit rank and is statically
1704-
// shaped with unit dimension, the `reassociationMap` will be empty and the
1705-
// input will be collapsed to a 0-D tensor.
1706-
rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(op, resultType, input,
1707-
reassociationMap);
1656+
rewriter.replaceOp(op, squeezeTensorInfo.value());
17081657
return success();
17091658
}
17101659
};
@@ -1722,36 +1671,15 @@ class ConvertAtenUnsqueezeOp : public OpConversionPattern<AtenUnsqueezeOp> {
17221671
int64_t dim;
17231672
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
17241673
return rewriter.notifyMatchFailure(op, "dim must be constant");
1725-
auto inputRank =
1726-
cast<RankedTensorType>(adaptor.getSelf().getType()).getRank();
1727-
dim = toPositiveDim(dim, inputRank + 1);
1728-
if (!isValidDim(dim, inputRank + 1))
1729-
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
17301674

1731-
SmallVector<ReassociationIndices> reassociationMap(inputRank);
1732-
// From the perspective of the reassociation map, the situation of
1733-
// unsqueezing before or after the last dimension is symmetrical.
1734-
// Normalize it to the "before" case.
1735-
// The 0 case is special here, since there is no last dimension to insert
1736-
// before -- we simply rely on the loop below iterating 0 times.
1737-
if (dim == inputRank && inputRank != 0)
1738-
dim = inputRank - 1;
1739-
bool alreadyCrossedExpandedDim = false;
1740-
for (int i = 0; i != inputRank; i++) {
1741-
if (alreadyCrossedExpandedDim) {
1742-
reassociationMap[i].push_back(i + 1);
1743-
} else {
1744-
reassociationMap[i].push_back(i);
1745-
if (i == dim) {
1746-
reassociationMap[i].push_back(i + 1);
1747-
alreadyCrossedExpandedDim = true;
1748-
}
1749-
}
1675+
auto unsqueezeTensorInfo =
1676+
unsqueezeTensor(rewriter, op, adaptor.getSelf(), dim);
1677+
if (failed(unsqueezeTensorInfo)) {
1678+
return rewriter.notifyMatchFailure(op,
1679+
"cannot generate unsqueeze tensor");
17501680
}
1751-
auto resultType = cast<RankedTensorType>(
1752-
getTypeConverter()->convertType(op->getResult(0).getType()));
1753-
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
1754-
op, resultType, adaptor.getSelf(), reassociationMap);
1681+
1682+
rewriter.replaceOp(op, unsqueezeTensorInfo.value());
17551683
return success();
17561684
}
17571685
};

lib/Conversion/TorchToLinalg/Linear.cpp

Lines changed: 64 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -850,6 +850,48 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
850850
return rewriter.notifyMatchFailure(op,
851851
"only support constant int dilations");
852852

853+
// Checks for valid group size
854+
int64_t numGroups;
855+
if (!matchPattern(op.getGroups(), m_TorchConstantInt(&numGroups)))
856+
return rewriter.notifyMatchFailure(op,
857+
"only constant group size supported.");
858+
Value groups = castIntToIndex(rewriter, loc, adaptor.getGroups());
859+
860+
// Adding support for 1d group convolution by converting the 1d-conv to
861+
// 2d-conv.
862+
// TODO: Replace this logic with the appropriate linalg op for 1-d group
863+
// convolution once that support is added.
864+
bool is1DGroupConv = (numSpatialDims == 1 && numGroups != 1);
865+
if (is1DGroupConv) {
866+
// Unsqueezing the last dim of input and weight. Also extending the
867+
// dilation, stride, padding, and output padding lists.
868+
auto unsqueezeInputInfo =
869+
unsqueezeTensor(rewriter, op, input, /*dim=*/-1);
870+
if (failed(unsqueezeInputInfo)) {
871+
return rewriter.notifyMatchFailure(op,
872+
"cannot generate unsqueeze tensor");
873+
}
874+
input = unsqueezeInputInfo.value();
875+
876+
auto unsqueezeWeightInfo =
877+
unsqueezeTensor(rewriter, op, weight, /*dim=*/-1);
878+
if (failed(unsqueezeWeightInfo)) {
879+
return rewriter.notifyMatchFailure(op,
880+
"cannot generate unsqueeze tensor");
881+
}
882+
weight = unsqueezeWeightInfo.value();
883+
884+
Value cstZero = rewriter.create<arith::ConstantOp>(
885+
loc, rewriter.getI64IntegerAttr(0));
886+
paddingIntValues.push_back(cstZero);
887+
outputPaddingIntValues.push_back(cstZero);
888+
strideInts.push_back(1);
889+
dilationInts.push_back(1);
890+
891+
inRank++;
892+
numSpatialDims++;
893+
}
894+
853895
Value inBatch = getDimOp(rewriter, loc, input, 0);
854896
Value inChannels = getDimOp(rewriter, loc, input, 1);
855897
SmallVector<Value> inDims;
@@ -861,13 +903,6 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
861903
for (size_t i = 2; i < inRank; i++)
862904
weightDims.push_back(getDimOp(rewriter, loc, weight, i));
863905

864-
// Checks for valid group size
865-
int64_t numGroups;
866-
if (!matchPattern(op.getGroups(), m_TorchConstantInt(&numGroups)))
867-
return rewriter.notifyMatchFailure(op,
868-
"only constant group size supported.");
869-
Value groups = castIntToIndex(rewriter, loc, adaptor.getGroups());
870-
871906
auto validate = [&](Value toValidate, std::string err) {
872907
Value c0 =
873908
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
@@ -1286,13 +1321,24 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
12861321
conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv,
12871322
resultElementType);
12881323
}
1324+
1325+
if (is1DGroupConv) {
1326+
// Squeezing the last dim of the result of conv.
1327+
auto squeezeOutputInfo = squeezeTensor(rewriter, op, conv, /*dim=*/-1);
1328+
if (failed(squeezeOutputInfo)) {
1329+
return rewriter.notifyMatchFailure(op,
1330+
"cannot generate squeeze tensor");
1331+
}
1332+
conv = squeezeOutputInfo.value();
1333+
}
1334+
12891335
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv);
12901336
return success();
12911337
}
12921338

12931339
if (numSpatialDims != 2)
12941340
return rewriter.notifyMatchFailure(
1295-
op, "unimplemented: only 2D grouped convolution supported");
1341+
op, "unimplemented: only 1D and 2D grouped convolution supported");
12961342

12971343
// Grouped case, use the grouped conv linalg op
12981344
auto expandGroups = [&](Value tensor, size_t dim) {
@@ -1377,6 +1423,16 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
13771423
conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv,
13781424
resultElementType);
13791425
}
1426+
1427+
if (is1DGroupConv) {
1428+
// Squeezing the last dim of the result of conv.
1429+
auto squeezeOutputInfo = squeezeTensor(rewriter, op, conv, /*dim=*/-1);
1430+
if (failed(squeezeOutputInfo)) {
1431+
return rewriter.notifyMatchFailure(op,
1432+
"cannot generate squeeze tensor");
1433+
}
1434+
conv = squeezeOutputInfo.value();
1435+
}
13801436
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv);
13811437
return success();
13821438
}

0 commit comments

Comments
 (0)