
Commit 454fa9d

[tosa] Support for AtenFlattenUsingIntsOp (#548)

1 parent: 8bc028a

File tree: 3 files changed, 82 insertions(+), 0 deletions(-)

e2e_testing/torchscript/xfail_sets.py (3 additions, 0 deletions)
@@ -86,4 +86,7 @@
     "BatchNorm1DModule_basic",
     "BatchNorm2DModule_basic",
     "BatchNorm3DModule_basic",
+    "FlattenStaticModule_basic",
+    "FlattenRank0Module_basic",
+    "ElementwiseFlattenBroadcastModule_basic",
 }

lib/Conversion/TorchToTosa/TorchToTosa.cpp (60 additions, 0 deletions)
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Traits.h"
 #include "mlir/IR/Matchers.h"
+#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
@@ -1928,6 +1929,64 @@ LogicalResult ConvertAtenOp<ValueTensorLiteralOp>::matchAndRewrite(
   return success();
 }
 
+template <>
+LogicalResult ConvertAtenOp<AtenFlattenUsingIntsOp>::matchAndRewrite(
+    AtenFlattenUsingIntsOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+
+  // Not a ranked tensor type
+  auto selfType = adaptor.self().getType().dyn_cast<RankedTensorType>();
+  if (!selfType || !selfType.hasStaticShape())
+    return op.emitError(
+        "Only ranked tensor types with static shapes are currently supported");
+
+  int64_t selfRank = selfType.getRank();
+
+  int64_t start_dim, end_dim;
+
+  if (!matchPattern(op.start_dim(), m_TorchConstantInt(&start_dim)))
+    return op.emitError("start_dim must be a Scalar constant");
+  start_dim = toPositiveDim(start_dim, selfRank);
+
+  if (!matchPattern(op.end_dim(), m_TorchConstantInt(&end_dim)))
+    return op.emitError("end_dim must be a Scalar constant");
+  end_dim = toPositiveDim(end_dim, selfRank);
+
+  if (selfRank > 0 && !isValidDim(start_dim, selfRank))
+    return op.emitError("start_dim is statically invalid");
+  if (selfRank > 0 && !isValidDim(end_dim, selfRank))
+    return op.emitError("end_dim is statically invalid");
+  if (end_dim < start_dim)
+    return op.emitError("end_dim must be larger than start_dim");
+
+  SmallVector<int64_t> newShape;
+  for (auto s : llvm::enumerate(selfType.getShape())) {
+    int64_t idx = s.index();
+    if (idx < start_dim || idx > end_dim) {
+      newShape.push_back(s.value());
+    } else {
+      if (idx == start_dim)
+        newShape.push_back(s.value());
+      else
+        newShape.back() *= s.value();
+    }
+  }
+
+  // Handle the Scalar case
+  if (newShape.size() == 0)
+    newShape.push_back(1);
+
+  auto newType = RankedTensorType::get(newShape, selfType.getElementType());
+  auto reshapeOp =
+      rewriter.create<tosa::ReshapeOp>(op->getLoc(), newType, adaptor.self(),
+                                       rewriter.getI64ArrayAttr(newShape));
+
+  rewriter.replaceOpWithNewOp<tensor::CastOp>(
+      op, getTypeConverter()->convertType(op.getType()), reshapeOp);
+
+  return success();
+}
+
 } // namespace
 
 // -----------------------------------------------------------------------------
@@ -2085,6 +2144,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
     INSERT_ATENOP_PATTERN(ValueTensorLiteralOp);
     INSERT_ATENOP_PATTERN(AtenReshapeOp);
     INSERT_ATENOP_PATTERN(AtenBatchNormOp);
+    INSERT_ATENOP_PATTERN(AtenFlattenUsingIntsOp);
 #undef INSERT_ATENOP_PATTERN
 
     if (failed(applyPartialConversion(getOperation(), target,
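For reference, a minimal Python sketch of the shape arithmetic the new pattern performs. The flattened_shape helper is illustrative only, not part of the patch, and it assumes start_dim and end_dim are already non-negative, as toPositiveDim guarantees in the C++ code above:

def flattened_shape(shape, start_dim, end_dim):
    """Collapse dims [start_dim, end_dim] of a static shape into one,
    mirroring the loop in ConvertAtenOp<AtenFlattenUsingIntsOp>."""
    new_shape = []
    for idx, size in enumerate(shape):
        if idx < start_dim or idx > end_dim:
            # Dimensions outside the flattened range are kept as-is.
            new_shape.append(size)
        elif idx == start_dim:
            # The first dimension in the range starts the collapsed entry.
            new_shape.append(size)
        else:
            # Every later dimension in the range folds into it.
            new_shape[-1] *= size
    # Rank-0 input: the pattern reshapes to [1].
    return new_shape if new_shape else [1]

# Matches the new test in basic.mlir: dims 2..4 of [10, 3, 8, 9, 3, 4]
# collapse to 8 * 9 * 3 = 216, giving [10, 3, 216, 4].
assert flattened_shape([10, 3, 8, 9, 3, 4], 2, 4) == [10, 3, 216, 4]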

test/Conversion/TorchToTosa/basic.mlir (19 additions, 0 deletions)
@@ -525,3 +525,22 @@ func @forward(%arg0: !torch.vtensor<[10,4,3],f32> ) -> !torch.vtensor<[10,4,3],f
   %2 = torch.aten.batch_norm %arg0, %1, %0, %0, %1, %false, %float1.000000e-01, %float1.000000e-05, %true : !torch.vtensor<[10,4,3],f32>, !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>, !torch.bool, !torch.float, !torch.float, !torch.bool -> !torch.vtensor<[10,4,3],f32>
   return %2 : !torch.vtensor<[10,4,3],f32>
 }
+
+// -----
+
+// CHECK-LABEL: func @forward(
+// CHECK-SAME:    %[[VAL_0:.*]]: !torch.vtensor<[10,3,8,9,3,4],f32>) -> !torch.vtensor<[10,3,?,4],f32> {
+// CHECK:         %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[10,3,8,9,3,4],f32> -> tensor<10x3x8x9x3x4xf32>
+// CHECK:         %[[VAL_2:.*]] = torch.constant.int 4
+// CHECK:         %[[VAL_3:.*]] = torch.constant.int 2
+// CHECK:         %[[VAL_4:.*]] = "tosa.reshape"(%[[VAL_1]]) {new_shape = [10, 3, 216, 4]} : (tensor<10x3x8x9x3x4xf32>) -> tensor<10x3x216x4xf32>
+// CHECK:         %[[VAL_5:.*]] = tensor.cast %[[VAL_4]] : tensor<10x3x216x4xf32> to tensor<10x3x?x4xf32>
+// CHECK:         %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<10x3x?x4xf32> -> !torch.vtensor<[10,3,?,4],f32>
+// CHECK:         return %[[VAL_6]] : !torch.vtensor<[10,3,?,4],f32>
+// CHECK:       }
+func @forward(%arg0: !torch.vtensor<[10,3,8,9,3,4],f32> ) -> !torch.vtensor<[10,3,?,4],f32> {
+  %int4 = torch.constant.int 4
+  %int2 = torch.constant.int 2
+  %0 = torch.aten.flatten.using_ints %arg0, %int2, %int4 : !torch.vtensor<[10,3,8,9,3,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[10,3,?,4],f32>
+  return %0 : !torch.vtensor<[10,3,?,4],f32>
+}
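Note that the result type in this test carries a dynamic dimension (!torch.vtensor<[10,3,?,4],f32>), which is why the lowering emits a tensor.cast from the statically shaped tosa.reshape result to the converted result type. A quick eager-mode PyTorch check of the same flatten, with shapes taken from the test above (this snippet is illustrative and not part of the patch):

import torch

# torch.flatten(x, 2, 4) collapses dims 2 through 4, matching the
# torch.aten.flatten.using_ints call in the test above.
x = torch.zeros(10, 3, 8, 9, 3, 4)
assert torch.flatten(x, 2, 4).shape == (10, 3, 216, 4)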
