Skip to content

Commit 4bce4f9

Browse files
authored
OnnxToTosa conversion: add option to only convert slice to tosa when all steps are 1
1 parent 9314dbe commit 4bce4f9

File tree

4 files changed

+68
-10
lines changed

4 files changed

+68
-10
lines changed

src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ namespace onnx_mlir {
2626

2727
void populateONNXToTOSAConversionPattern(ConversionTarget &target,
2828
RewritePatternSet &patterns, TypeConverter &typeConverter, MLIRContext *ctx,
29-
int64_t groupedConvThreshold) {
29+
int64_t groupedConvThreshold, bool convertSliceOnlyWhenStepOne) {
3030
// Math
3131
populateLoweringONNXElementwiseOpToTOSAPattern(
3232
target, patterns, typeConverter, ctx);
@@ -56,7 +56,7 @@ void populateONNXToTOSAConversionPattern(ConversionTarget &target,
5656
populateLoweringONNXFlattenOpToTOSAPattern(
5757
target, patterns, typeConverter, ctx);
5858
populateLoweringONNXSliceOpToTOSAPattern(
59-
target, patterns, typeConverter, ctx);
59+
target, patterns, typeConverter, ctx, convertSliceOnlyWhenStepOne);
6060
populateLoweringONNXSplitOpToTOSAPattern(
6161
target, patterns, typeConverter, ctx);
6262
populateLoweringONNXSqueezeOpToTOSAPattern(
@@ -114,6 +114,10 @@ struct FrontendToTosaLoweringPass
114114
"into a concatenation of tosa.conv2d operations"),
115115
llvm::cl::ZeroOrMore,
116116
llvm::cl::init(std::numeric_limits<int64_t>::max())};
117+
Option<bool> convertSliceOnlyWhenStepOne{*this,
118+
"convert-slice-only-when-step-one",
119+
llvm::cl::desc("If enabled, convert onnx.slice only if all steps are 1"),
120+
llvm::cl::ZeroOrMore, llvm::cl::init(false)};
117121
};
118122

119123
void FrontendToTosaLoweringPass::runOnOperation() {
@@ -144,8 +148,8 @@ void FrontendToTosaLoweringPass::runOnOperation() {
144148
mlir::arith::ArithDialect, mlir::shape::ShapeDialect>();
145149

146150
// Define patterns
147-
populateONNXToTOSAConversionPattern(
148-
target, patterns, typeConverter, context, groupedConvThreshold);
151+
populateONNXToTOSAConversionPattern(target, patterns, typeConverter, context,
152+
groupedConvThreshold, convertSliceOnlyWhenStepOne);
149153

150154
if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
151155
signalPassFailure();

src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,8 @@ void populateLoweringONNXPadOpToTOSAPattern(mlir::ConversionTarget &,
163163
void populateLoweringONNXFlattenOpToTOSAPattern(mlir::ConversionTarget &,
164164
mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *);
165165
void populateLoweringONNXSliceOpToTOSAPattern(mlir::ConversionTarget &,
166-
mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *);
166+
mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *,
167+
bool);
167168
void populateLoweringONNXSplitOpToTOSAPattern(mlir::ConversionTarget &,
168169
mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *);
169170
void populateLoweringONNXSqueezeOpToTOSAPattern(mlir::ConversionTarget &,

src/Conversion/ONNXToTOSA/Tensor/Slice.cpp

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
//===----------------------------------------------------------------------===//
1414

1515
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
16-
#include "mlir/Interfaces/InferTypeOpInterface.h"
1716
#include "src/Conversion/ONNXToTOSA/DialectBuilder.hpp"
1817
#include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp"
1918
#include "src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp"
@@ -30,6 +29,11 @@ class ONNXSliceLoweringToTOSA : public OpConversionPattern<ONNXSliceOp> {
3029
public:
3130
using OpConversionPattern::OpConversionPattern;
3231
using OpAdaptor = typename ONNXSliceOp::Adaptor;
32+
33+
ONNXSliceLoweringToTOSA(MLIRContext *ctx, bool convertSliceOnlyWhenStepOne)
34+
: OpConversionPattern(ctx),
35+
convertSliceOnlyWhenStepOne(convertSliceOnlyWhenStepOne) {}
36+
3337
LogicalResult matchAndRewrite(ONNXSliceOp op, OpAdaptor adaptor,
3438
ConversionPatternRewriter &rewriter) const override {
3539

@@ -74,6 +78,17 @@ class ONNXSliceLoweringToTOSA : public OpConversionPattern<ONNXSliceOp> {
7478
llvm::SmallVector<int64_t, 4> steps;
7579
IndexExpr::getLiteral(shapeHelper.steps, steps);
7680

81+
size_t nbSlicedDims = 0;
82+
for (auto [in, out] : llvm::zip(inShape, outShape)) {
83+
if (out != in)
84+
nbSlicedDims++;
85+
}
86+
// TODO: remove the check nbSlicedDims == 1 when possible
87+
if (convertSliceOnlyWhenStepOne && nbSlicedDims == 1 &&
88+
llvm::any_of(steps, [](int64_t step) { return step > 1; })) {
89+
return rewriter.notifyMatchFailure(op, "step > 1 are not supported.");
90+
}
91+
7792
if (llvm::any_of(steps, [](int64_t step) { return step < 0; })) {
7893
return rewriter.notifyMatchFailure(op, "negative step not supported.");
7994
}
@@ -161,14 +176,17 @@ class ONNXSliceLoweringToTOSA : public OpConversionPattern<ONNXSliceOp> {
161176
rewriter.replaceOp(op, val);
162177
return success();
163178
}
179+
180+
private:
181+
bool convertSliceOnlyWhenStepOne;
164182
};
165183

166184
} // namespace
167185

168-
void populateLoweringONNXSliceOpToTOSAPattern(ConversionTarget &target,
169-
RewritePatternSet &patterns, TypeConverter &typeConverter,
170-
MLIRContext *ctx) {
171-
patterns.insert<ONNXSliceLoweringToTOSA>(ctx);
186+
void populateLoweringONNXSliceOpToTOSAPattern(ConversionTarget & /*target*/,
187+
RewritePatternSet &patterns, TypeConverter & /*typeConverter*/,
188+
MLIRContext *ctx, bool convertSliceOnlyWhenStepOne) {
189+
patterns.insert<ONNXSliceLoweringToTOSA>(ctx, convertSliceOnlyWhenStepOne);
172190
}
173191

174192
} // namespace onnx_mlir

test/mlir/conversion/onnx_to_tosa/Tensor/Slice.mlir

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// RUN: onnx-mlir-opt --shape-inference --convert-onnx-to-tosa -cse %s -split-input-file | FileCheck %s
2+
// RUN: onnx-mlir-opt --shape-inference --convert-onnx-to-tosa="convert-slice-only-when-step-one=true" -cse %s -split-input-file | FileCheck %s --check-prefix=ONLY-STEP1
23

34

45
func.func @test_slice_constant_default_steps(%arg0 : tensor<2x4xf32>) -> tensor<1x3xf32> {
@@ -74,6 +75,9 @@ func.func @slice_just_steps(%arg0: tensor<100x200xf32>) -> tensor<20x20xf32> {
7475
// CHECK: %2 = tosa.reshape %1 {new_shape = array<i64: 20, 20>} : (tensor<20x1x20x1xf32>) -> tensor<20x20xf32>
7576
// CHECK: return %2 : tensor<20x20xf32>
7677

78+
// ONLY-STEP1-LABEL: func @slice_just_steps
79+
// ONLY-STEP1: tosa.slice
80+
7781
// -----
7882

7983
func.func @slice_steps_and_edges(%arg0: tensor<100x200xf32>) -> tensor<16x17xf32> {
@@ -91,6 +95,9 @@ func.func @slice_steps_and_edges(%arg0: tensor<100x200xf32>) -> tensor<16x17xf32
9195
// CHECK: %3 = tosa.reshape %2 {new_shape = array<i64: 16, 17>} : (tensor<16x1x17x1xf32>) -> tensor<16x17xf32>
9296
// CHECK: return %3 : tensor<16x17xf32>
9397

98+
// ONLY-STEP1-LABEL: func @slice_steps_and_edges
99+
// ONLY-STEP1: tosa.slice
100+
94101
// -----
95102

96103
func.func @slice_steps_and_edges_with_padding(%arg0: tensor<99x195xf32>) -> tensor<19x19xf32> {
@@ -111,6 +118,9 @@ func.func @slice_steps_and_edges_with_padding(%arg0: tensor<99x195xf32>) -> tens
111118
// CHECK: %6 = tosa.reshape %5 {new_shape = array<i64: 19, 19>} : (tensor<19x1x19x1xf32>) -> tensor<19x19xf32>
112119
// CHECK: return %6 : tensor<19x19xf32>
113120

121+
// ONLY-STEP1-LABEL: func @slice_steps_and_edges_with_padding
122+
// ONLY-STEP1: tosa.slice
123+
114124
// -----
115125

116126
func.func @slice_just_steps_with_padding(%arg0: tensor<99x195xf32>) -> tensor<20x20xf32> {
@@ -130,6 +140,9 @@ func.func @slice_just_steps_with_padding(%arg0: tensor<99x195xf32>) -> tensor<20
130140
// CHECK: %5 = tosa.reshape %4 {new_shape = array<i64: 20, 20>} : (tensor<20x1x20x1xf32>) -> tensor<20x20xf32>
131141
// CHECK: return %5 : tensor<20x20xf32>
132142

143+
// ONLY-STEP1-LABEL: func @slice_just_steps_with_padding
144+
// ONLY-STEP1: tosa.slice
145+
133146
// -----
134147

135148
func.func @slice_negative_steps(%arg0: tensor<100x200xf32>) -> tensor<20x20xf32> {
@@ -195,3 +208,25 @@ func.func @slice_4d(%arg0: tensor<1x56x56x92xf32>) -> tensor<1x28x28x92xf32> {
195208
// CHECK: %5 = tosa.slice %4 {size = array<i64: 1, 28, 1, 28, 1, 92>, start = array<i64: 0, 0, 0, 0, 0, 0>} : (tensor<1x28x2x28x2x92xf32>) -> tensor<1x28x1x28x1x92xf32>
196209
// CHECK: %6 = tosa.reshape %5 {new_shape = array<i64: 1, 28, 28, 92>} : (tensor<1x28x1x28x1x92xf32>) -> tensor<1x28x28x92xf32>
197210
// CHECK: return %6 : tensor<1x28x28x92xf32>
211+
212+
// ONLY-STEP1-LABEL: func @slice_4d
213+
// ONLY-STEP1: tosa.slice
214+
215+
// -----
216+
217+
func.func @slice_4d_step2_1sliced_dim(%arg0: tensor<1x3x640x640xbf16>) -> tensor<1x3x320x640xbf16> {
218+
%starts = "tosa.const"() <{value = dense<0> : tensor<1xi64>}> : () -> tensor<1xi64>
219+
%ends = "tosa.const"() <{value = dense<9223372036854775807> : tensor<1xi64>}> : () -> tensor<1xi64>
220+
%axes_steps = "tosa.const"() <{value = dense<2> : tensor<1xi64>}> : () -> tensor<1xi64>
221+
%3 = "onnx.Slice"(%arg0, %starts, %ends, %axes_steps, %axes_steps) : (tensor<1x3x640x640xbf16>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<1x3x320x640xbf16>
222+
return %3 : tensor<1x3x320x640xbf16>
223+
}
224+
// CHECK-LABEL: func @slice_4d_step2_1sliced_dim
225+
// CHECK-SAME: (%[[ARG_0:.*]]: tensor<1x3x640x640xbf16>) -> tensor<1x3x320x640xbf16>
226+
// CHECK: %[[VAL_0:.*]] = tosa.reshape %[[ARG_0]] {new_shape = array<i64: 1, 3, 320, 2, 640>} : (tensor<1x3x640x640xbf16>) -> tensor<1x3x320x2x640xbf16>
227+
// CHECK: %[[VAL_1:.*]] = tosa.slice %[[VAL_0]] {size = array<i64: 1, 3, 320, 1, 640>, start = array<i64: 0, 0, 0, 0, 0>} : (tensor<1x3x320x2x640xbf16>) -> tensor<1x3x320x1x640xbf16>
228+
// CHECK: %[[VAL_2:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array<i64: 1, 3, 320, 640>} : (tensor<1x3x320x1x640xbf16>) -> tensor<1x3x320x640xbf16>
229+
// CHECK: return %[[VAL_2]] : tensor<1x3x320x640xbf16>
230+
231+
// ONLY-STEP1-LABEL: func @slice_4d_step2_1sliced_dim
232+
// ONLY-STEP1-NOT: tosa.slice

0 commit comments

Comments
 (0)