17 | 17 | #include "mlir/Dialect/Tosa/IR/TosaOps.h" |
18 | 18 | #include "mlir/Dialect/Traits.h" |
19 | 19 | #include "mlir/IR/Matchers.h" |
| 20 | +#include "torch-mlir/Dialect/Torch/Utils/Utils.h" |
20 | 21 | #include "mlir/Transforms/DialectConversion.h" |
21 | 22 | #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" |
22 | 23 | #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" |
@@ -1928,6 +1929,64 @@ LogicalResult ConvertAtenOp<ValueTensorLiteralOp>::matchAndRewrite( |
1928 | 1929 | return success(); |
1929 | 1930 | } |
1930 | 1931 |
| 1932 | +template <> |
| 1933 | +LogicalResult ConvertAtenOp<AtenFlattenUsingIntsOp>::matchAndRewrite( |
| 1934 | + AtenFlattenUsingIntsOp op, OpAdaptor adaptor, |
| 1935 | + ConversionPatternRewriter &rewriter) const { |
| 1936 | + |
| 1937 | +  // Reject inputs that are not ranked tensors with a static shape |
| 1938 | + auto selfType = adaptor.self().getType().dyn_cast<RankedTensorType>(); |
| 1939 | + if (!selfType || !selfType.hasStaticShape()) |
| 1940 | + return op.emitError( |
| 1941 | + "Only ranked tensor types with static shapes are currently supported"); |
| 1942 | + |
| 1943 | + int64_t selfRank = selfType.getRank(); |
| 1944 | + |
| 1945 | + int64_t start_dim, end_dim; |
| 1946 | + |
| 1947 | + if (!matchPattern(op.start_dim(), m_TorchConstantInt(&start_dim))) |
| 1948 | + return op.emitError("start_dim must be a Scalar constant"); |
| 1949 | + start_dim = toPositiveDim(start_dim, selfRank); |
| 1950 | + |
| 1951 | + if (!matchPattern(op.end_dim(), m_TorchConstantInt(&end_dim))) |
| 1952 | + return op.emitError("end_dim must be a Scalar constant"); |
| 1953 | + end_dim = toPositiveDim(end_dim, selfRank); |
| 1954 | + |
| 1955 | + if (selfRank > 0 && !isValidDim(start_dim, selfRank)) |
| 1956 | + return op.emitError("start_dim is statically invalid"); |
| 1957 | + if (selfRank > 0 && !isValidDim(end_dim, selfRank)) |
| 1958 | + return op.emitError("end_dim is statically invalid"); |
| 1959 | + if (end_dim < start_dim) |
| 1960 | +    return op.emitError("end_dim must be greater than or equal to start_dim"); |
| 1961 | +  // Copy dims outside [start_dim, end_dim]; fold the rest into one extent |
| 1962 | + SmallVector<int64_t> newShape; |
| 1963 | + for (auto s : llvm::enumerate(selfType.getShape())) { |
| 1964 | + int64_t idx = s.index(); |
| 1965 | + if (idx < start_dim || idx > end_dim) { |
| 1966 | + newShape.push_back(s.value()); |
| 1967 | + } else { |
| 1968 | + if (idx == start_dim) |
| 1969 | + newShape.push_back(s.value()); |
| 1970 | + else |
| 1971 | + newShape.back() *= s.value(); |
| 1972 | + } |
| 1973 | + } |
| 1974 | + |
| 1975 | +  // Flattening a rank-0 (scalar) input produces a rank-1 tensor of size 1 |
| 1976 | +  if (newShape.empty()) |
| 1977 | + newShape.push_back(1); |
| 1978 | + |
| 1979 | + auto newType = RankedTensorType::get(newShape, selfType.getElementType()); |
| 1980 | + auto reshapeOp = |
| 1981 | + rewriter.create<tosa::ReshapeOp>(op->getLoc(), newType, adaptor.self(), |
| 1982 | + rewriter.getI64ArrayAttr(newShape)); |
| 1983 | + |
| 1984 | + rewriter.replaceOpWithNewOp<tensor::CastOp>( |
| 1985 | + op, getTypeConverter()->convertType(op.getType()), reshapeOp); |
| 1986 | + |
| 1987 | + return success(); |
| 1988 | +} |
| 1989 | + |
1931 | 1990 | } // namespace |
1932 | 1991 |
1933 | 1992 | // ----------------------------------------------------------------------------- |
@@ -2085,6 +2144,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> { |
2085 | 2144 | INSERT_ATENOP_PATTERN(ValueTensorLiteralOp); |
2086 | 2145 | INSERT_ATENOP_PATTERN(AtenReshapeOp); |
2087 | 2146 | INSERT_ATENOP_PATTERN(AtenBatchNormOp); |
| 2147 | + INSERT_ATENOP_PATTERN(AtenFlattenUsingIntsOp); |
2088 | 2148 | #undef INSERT_ATENOP_PATTERN |
2089 | 2149 |
2090 | 2150 | if (failed(applyPartialConversion(getOperation(), target, |
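For context, a minimal sketch of the rewrite this pattern performs, shown on a rank-3 input with dims 1 and 2 flattened. The IR below is illustrative (SSA names and surrounding setup are assumptions, not taken from this commit); it uses the ArrayAttr-based `tosa.reshape` form that the C++ above emits:

```mlir
// Before: flatten dims [1, 2] of a statically shaped tensor.
%out = torch.aten.flatten.using_ints %self, %int1, %int2
         : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,12],f32>

// After: one tosa.reshape whose new_shape folds the flattened extents
// (3 * 4 = 12); a tensor.cast then reconciles the reshape result with the
// converted result type, per the replaceOpWithNewOp<tensor::CastOp> above.
%0 = "tosa.reshape"(%arg0) {new_shape = [2, 12]} : (tensor<2x3x4xf32>) -> tensor<2x12xf32>
```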
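One edge case handled explicitly above is a rank-0 input: the shape loop leaves `newShape` empty, so it is padded to `[1]`, matching PyTorch, where flattening a 0-d tensor yields a 1-d tensor with one element. A hypothetical FileCheck-style regression test for that path could look like this sketch (function name and exact type syntax are assumptions for whichever MLIR vintage the tree is on):

```mlir
// CHECK-LABEL: @flatten_rank0
// CHECK: "tosa.reshape"(%{{.*}}) {new_shape = [1]}
func @flatten_rank0(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
  %int0 = torch.constant.int 0
  %0 = torch.aten.flatten.using_ints %arg0, %int0, %int0
         : !torch.vtensor<[],f32>, !torch.int, !torch.int -> !torch.vtensor<[1],f32>
  return %0 : !torch.vtensor<[1],f32>
}
```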