
Commit de4f902

Merge branch 'feature/onnx-to-tosa' into chaitany.fix_convtranpose_4x4_mismatch
2 parents 8f70a00 + 8798e80 commit de4f902

File tree: 6 files changed, +116 -42 lines


src/Conversion/ONNXToTOSA/Tensor/Gather.cpp

Lines changed: 3 additions & 1 deletion
@@ -17,6 +17,7 @@
 #include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp"
 #include "src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp"
 #include "src/Dialect/ONNX/ONNXOps.hpp"
+#include "src/Support/TypeUtilities.hpp"
 #include "llvm/ADT/SmallVector.h"

 using namespace mlir;
@@ -46,7 +47,8 @@ class ONNXGatherLoweringToTOSA : public OpConversionPattern<ONNXGatherOp> {
     if (!onnx_mlir::isRankedShapedType(inputType))
       return rewriter.notifyMatchFailure(op, "input is not a ranked tensor");

-    if (!hasStaticShape(result.getType()))
+    if (!hasStaticShape(inputType) || !hasStaticShape(indices.getType()) ||
+        !hasStaticShape(result.getType()))
       return rewriter.notifyMatchFailure(op, "dynamic shapes not supported");

     auto resultTy = dyn_cast<TensorType>(op.getType());
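
The tightened guard means the lowering now requires static shapes on the input and the indices as well, not only on the result. A minimal sketch of IR that is now rejected (hypothetical function name; the behavior mirrors the new Gather.mlir tests below): the result type is fully static, but the dynamic indices keep the op in the ONNX dialect instead of lowering it to tosa.gather.

// Hypothetical example: static data and result, dynamic indices.
func.func @gather_stays_onnx(%data: tensor<4x2xf32>, %idx: tensor<?xi64>) -> tensor<1x2xf32> {
  // After --convert-onnx-to-tosa this onnx.Gather survives unchanged, because
  // hasStaticShape(indices.getType()) must now hold for the pattern to fire.
  %0 = "onnx.Gather"(%data, %idx) {axis = 0 : si64} : (tensor<4x2xf32>, tensor<?xi64>) -> tensor<1x2xf32>
  return %0 : tensor<1x2xf32>
}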

src/Conversion/ONNXToTOSA/Tensor/PaddingOp.cpp

Lines changed: 9 additions & 33 deletions
@@ -22,8 +22,8 @@
 #include "src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp"
 #include "src/Dialect/ONNX/ONNXOps.hpp"
 #include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/SmallVector.h"
-#include "llvm/Support/Casting.h"

 using namespace mlir;

@@ -61,11 +61,9 @@ class ONNXPadOpLoweringToTOSA : public OpConversionPattern<ONNXPadOp> {
           op, "only 'constant' mode is supported");
     }

-    if (!pads.getDefiningOp<mlir::tosa::ConstOp>() ||
-        !(constValue.getDefiningOp<mlir::tosa::ConstOp>() ||
-            constValue.getDefiningOp<ONNXNoneOp>())) {
+    if (!pads.getDefiningOp<mlir::tosa::ConstOp>()) {
       return rewriter.notifyMatchFailure(
-          op, "only tosa.const operands are supported");
+          op, "only tosa.const 'padding' values are supported");
     }
     // creating the DenseElementsAttr using pads values.
     auto denseAttr = tosa::getValueFromTosaConst<ElementsAttr>(pads);
@@ -90,33 +88,7 @@ class ONNXPadOpLoweringToTOSA : public OpConversionPattern<ONNXPadOp> {
     mlir::Type resultType =
         getTypeConverter()->convertType(op.getResult().getType());

-    if (!isa<NoneType>(constValue.getType())) {
-      auto valueAttr = tosa::getValueFromTosaConst<ElementsAttr>(constValue);
-      TosaBuilder tosaBuilder(rewriter, loc);
-
-      Value constTosaTensor;
-      if (isa<FloatType>(valueAttr.getElementType())) {
-        auto valueIt = valueAttr.getValues<FloatAttr>().begin();
-        const float valueFloat = cast<FloatAttr>(*valueIt).getValueAsDouble();
-        constTosaTensor = tosaBuilder.getSplattedConst(
-            valueFloat, valueAttr.getElementType(), 0);
-      } else {
-        assert(isTOSAInt(elementDtype) && "Already validated");
-        auto valueIt = valueAttr.getValues<IntegerAttr>().begin();
-        auto valueAsAPInt = cast<IntegerAttr>(*valueIt).getValue();
-        auto asIntegerTy = cast<IntegerType>(valueAttr.getElementType());
-        if (asIntegerTy.isUnsigned()) {
-          constTosaTensor = tosaBuilder.getSplattedConst(
-              valueAsAPInt.getZExtValue(), asIntegerTy, 0);
-        } else {
-          constTosaTensor = tosaBuilder.getSplattedConst(
-              valueAsAPInt.getSExtValue(), asIntegerTy, 0);
-        }
-      }
-      rewriter.replaceOpWithNewOp<mlir::tosa::PadOp>(
-          op, resultType, data, padsList1, constTosaTensor);
-
-    } else {
+    if (isa<NoneType>(constValue.getType())) {
       auto constType = RankedTensorType::get({}, elementDtype);

       DenseElementsAttr constAttr;
@@ -134,8 +106,12 @@ class ONNXPadOpLoweringToTOSA : public OpConversionPattern<ONNXPadOp> {
           padsList1,
           rewriter.create<mlir::tosa::ConstOp>(
               op->getLoc(), constType, constAttr));
+    } else {
+      TosaBuilder tosaBuilder(rewriter, loc);
+      Value reshapeToSplattedConst = tosaBuilder.reshape(constValue, {});
+      rewriter.replaceOpWithNewOp<mlir::tosa::PadOp>(
+          op, resultType, data, padsList1, reshapeToSplattedConst);
     }
-
     return success();
   }
 };
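
The rewritten branch replaces the old scalar-extraction-and-splat logic with a single rank-0 reshape of the pad value, which also lifts the old requirement that the pad value be defined by a tosa.const. A rough sketch of the IR the new path emits for a runtime 1-D pad value (value names are illustrative; the exact output is pinned down by the new Padding.mlir tests below):

// Hypothetical lowering of onnx.Pad with a non-constant pad value %pv : tensor<1xf32>.
%shape = tosa.const_shape {value = dense<[0, 4, 1, 5, 2, 6, 3, 7]> : tensor<8xindex>} : () -> !tosa.shape<8>
// tosaBuilder.reshape(constValue, {}) collapses the pad value to rank 0 ...
%scalar = tosa.reshape %pv {new_shape = array<i64>} : (tensor<1xf32>) -> tensor<f32>
// ... and feeds it to tosa.pad directly; no splatted constant is materialized.
%res = tosa.pad %data, %shape, %scalar : (tensor<20x16x44x32xf32>, !tosa.shape<8>, tensor<f32>) -> tensor<24x22x52x42xf32>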

src/Dialect/ONNX/Transforms/SimplifyShapeRelatedOps.cpp

Lines changed: 6 additions & 5 deletions
@@ -65,7 +65,6 @@ Now, it's straightforward to update the output shape of Reshape from
 #include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp"
 #include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp"
 #include "src/Pass/Passes.hpp"
-#include "src/Support/TypeUtilities.hpp"

 #define DEBUG_TYPE "simplify_shape_related_ops"

@@ -247,14 +246,16 @@ class PassThroughGatherPattern : public OpRewritePattern<ONNXGatherOp> {

     // Rewrite
     MultiDialectBuilder<OnnxBuilder> create(rewriter, loc);
-    int64_t inputRank = getRank(input.getType());
+    ShapedType inputType = llvm::dyn_cast<ShapedType>(input.getType());
+    if (!inputType || !inputType.hasStaticShape())
+      return failure();

     // Compute integer indices.
     SmallVector<int64_t, 4> indicesI64;
     for (auto element : indicesAttr.getValues<IntegerAttr>()) {
-      int64_t axis = element.getInt();
-      axis = (axis < 0) ? (axis + inputRank) : axis;
-      indicesI64.emplace_back(axis);
+      int64_t index = element.getInt();
+      index = (index < 0) ? (index + inputType.getShape()[axis]) : index;
+      indicesI64.emplace_back(index);
     }

     // Replace GatherOp by ConcatOp of specific dimensions.
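
The rename from axis to index is not cosmetic: a negative Gather index must be normalized by the size of the gathered dimension, not by the input's rank, and the pattern now bails out when that size is not statically known. A worked example matching test_pass_dims_through_gather_2 below, where the gathered input is the rank-1 shape tensor %4 : tensor<3xi64> and axis = 0:

// Normalizing indices dense<[-3, -2]> against tensor<3xi64>, axis = 0:
//   old code: -3 + inputRank(1) = -2,  -2 + inputRank(1) = -1   (still negative: wrong)
//   new code: -3 + shape[0](3) =  0,  -2 + shape[0](3) =  1     (dims 0 and 1: correct)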

test/mlir/conversion/onnx_to_tosa/Tensor/Gather.mlir

Lines changed: 18 additions & 0 deletions
@@ -175,3 +175,21 @@ func.func @test_gather_dynamic_shape_indices_i32(%arg0 : tensor<?x4xf32>, %indic
 // CHECK-LABEL: test_gather_dynamic_shape_indices_i32
 // CHECK: onnx.Gather
 }
+
+// -----
+
+func.func @test_gather_dynamic_input_static_output(%arg0 : tensor<?x2xf32>, %indices: tensor<?xi64>) -> tensor<1x2xf32> {
+  %0 = "onnx.Gather"(%arg0, %indices) {axis = 0 : si64, onnx_node_name = "/Gather_16"} : (tensor<?x2xf32>, tensor<?xi64>) -> tensor<1x2xf32>
+  "func.return"(%0) : (tensor<1x2xf32>) -> ()
+// CHECK-LABEL: test_gather_dynamic_input_static_output
+// CHECK: onnx.Gather
+}
+
+// -----
+
+func.func @test_gather_dynamic_indices(%arg0 : tensor<1x2xf32>, %indices: tensor<?xi64>) -> tensor<1x2xf32> {
+  %0 = "onnx.Gather"(%arg0, %indices) {axis = 0 : si64, onnx_node_name = "/Gather_16"} : (tensor<1x2xf32>, tensor<?xi64>) -> tensor<1x2xf32>
+  "func.return"(%0) : (tensor<1x2xf32>) -> ()
+// CHECK-LABEL: test_gather_dynamic_indices
+// CHECK: onnx.Gather
+}

test/mlir/conversion/onnx_to_tosa/Tensor/Padding.mlir

Lines changed: 57 additions & 1 deletion
@@ -1,4 +1,4 @@
-// RUN: onnx-mlir-opt --shape-inference --convert-onnx-to-tosa --cse %s -split-input-file | FileCheck %s
+// RUN: onnx-mlir-opt --shape-inference --convert-onnx-to-tosa --canonicalize --cse %s -split-input-file | FileCheck %s

 func.func @test_pad_f32(%arg0: tensor<20x16x44x32xf32>) -> tensor<24x22x52x42xf32> {
   %noval = "onnx.NoValue"() {value} : () -> none
@@ -160,3 +160,59 @@ func.func @test_pad_f16_constant_none(%arg0: tensor<256x1x1x5x1xf16>) -> tensor<
 // CHECK: %[[VAR2:.*]] = tosa.pad %arg0, %[[VAR0]], %[[VAR1]] : (tensor<256x1x1x5x1xf16>, !tosa.shape<10>, tensor<f16>) -> tensor<256x1x1x5x2xf16>
 // CHECK: return %[[VAR2]] : tensor<256x1x1x5x2xf16>
 }
+
+// -----
+
+func.func @test_pad_f32_non_constant_padval(%arg0: tensor<20x16x44x32xf32>, %arg1: tensor<f32>) -> tensor<24x22x52x42xf32> {
+  %noval = "onnx.NoValue"() {value} : () -> none
+  %0 = "onnx.Constant"() {value = dense<[0, 1, 2, 3, 4, 5, 6, 7]> : tensor<8xi64>} : () -> tensor<8xi64>
+  %2 = "onnx.Pad"(%arg0, %0, %arg1, %noval) {mode = "constant"} : (tensor<20x16x44x32xf32>, tensor<8xi64>, tensor<f32>, none) -> tensor<24x22x52x42xf32>
+  return %2 : tensor<24x22x52x42xf32>
+// CHECK-LABEL: func.func @test_pad_f32_non_constant_padval
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<20x16x44x32xf32>, [[PARAM_1_:%.+]]: tensor<f32>) -> tensor<24x22x52x42xf32> {
+// CHECK: [[VAR_0_:%.+]] = tosa.const_shape {value = dense<[0, 4, 1, 5, 2, 6, 3, 7]> : tensor<8xindex>} : () -> !tosa.shape<8>
+// CHECK: [[VAR_1_:%.+]] = tosa.pad [[PARAM_0_]], [[VAR_0_]], [[PARAM_1_]] : (tensor<20x16x44x32xf32>, !tosa.shape<8>, tensor<f32>) -> tensor<24x22x52x42xf32>
+// CHECK: return [[VAR_1_]] : tensor<24x22x52x42xf32>
+}
+
+// -----
+
+func.func @test_pad_f32_non_constant_1Dpadval(%arg0: tensor<20x16x44x32xf32>, %arg1: tensor<1xf32>) -> tensor<24x22x52x42xf32> {
+  %noval = "onnx.NoValue"() {value} : () -> none
+  %0 = "onnx.Constant"() {value = dense<[0, 1, 2, 3, 4, 5, 6, 7]> : tensor<8xi64>} : () -> tensor<8xi64>
+  %2 = "onnx.Pad"(%arg0, %0, %arg1, %noval) {mode = "constant"} : (tensor<20x16x44x32xf32>, tensor<8xi64>, tensor<1xf32>, none) -> tensor<24x22x52x42xf32>
+  return %2 : tensor<24x22x52x42xf32>
+// CHECK-LABEL: func.func @test_pad_f32_non_constant_1Dpadval
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<20x16x44x32xf32>, [[PARAM_1_:%.+]]: tensor<1xf32>) -> tensor<24x22x52x42xf32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = tosa.const_shape {value = dense<[0, 4, 1, 5, 2, 6, 3, 7]> : tensor<8xindex>} : () -> !tosa.shape<8>
+// CHECK-DAG: [[VAL_1_:%.+]] = tosa.reshape [[PARAM_1_]] {new_shape = array<i64>} : (tensor<1xf32>) -> tensor<f32>
+// CHECK: [[VAR_2_:%.+]] = tosa.pad [[PARAM_0_]], [[VAR_0_]], [[VAL_1_]] : (tensor<20x16x44x32xf32>, !tosa.shape<8>, tensor<f32>) -> tensor<24x22x52x42xf32>
+// CHECK: return [[VAR_2_]] : tensor<24x22x52x42xf32>
+}
+
+// -----
+
+func.func @test_pad_i64_non_constant_padval(%arg0: tensor<20x16x44x32xi64>, %arg1: tensor<i64>) -> tensor<24x22x52x42xi64> {
+  %noval = "onnx.NoValue"() {value} : () -> none
+  %0 = "onnx.Constant"() {value = dense<[0, 1, 2, 3, 4, 5, 6, 7]> : tensor<8xi64>} : () -> tensor<8xi64>
+  %2 = "onnx.Pad"(%arg0, %0, %arg1, %noval) {mode = "constant"} : (tensor<20x16x44x32xi64>, tensor<8xi64>, tensor<i64>, none) -> tensor<24x22x52x42xi64>
+  return %2 : tensor<24x22x52x42xi64>
+// CHECK-LABEL: func.func @test_pad_i64_non_constant_padval
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<20x16x44x32xi64>, [[PARAM_1_:%.+]]: tensor<i64>) -> tensor<24x22x52x42xi64> {
+// CHECK: [[VAR_0_:%.+]] = tosa.const_shape {value = dense<[0, 4, 1, 5, 2, 6, 3, 7]> : tensor<8xindex>} : () -> !tosa.shape<8>
+// CHECK: [[VAR_1_:%.+]] = tosa.pad [[PARAM_0_]], [[VAR_0_]], [[PARAM_1_]] : (tensor<20x16x44x32xi64>, !tosa.shape<8>, tensor<i64>) -> tensor<24x22x52x42xi64>
+// CHECK: return [[VAR_1_]] : tensor<24x22x52x42xi64>
+}
+
+// -----
+func.func @test_pad_f16_non_constant_padval(%arg0: tensor<20x16x44x32xf16>, %arg1: tensor<f16>) -> tensor<24x22x52x42xf16> {
+  %noval = "onnx.NoValue"() {value} : () -> none
+  %0 = "onnx.Constant"() {value = dense<[0, 1, 2, 3, 4, 5, 6, 7]> : tensor<8xi64>} : () -> tensor<8xi64>
+  %2 = "onnx.Pad"(%arg0, %0, %arg1, %noval) {mode = "constant"} : (tensor<20x16x44x32xf16>, tensor<8xi64>, tensor<f16>, none) -> tensor<24x22x52x42xf16>
+  return %2 : tensor<24x22x52x42xf16>
+// CHECK-LABEL: func.func @test_pad_f16_non_constant_padval
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<20x16x44x32xf16>, [[PARAM_1_:%.+]]: tensor<f16>) -> tensor<24x22x52x42xf16> {
+// CHECK: [[VAR_0_:%.+]] = tosa.const_shape {value = dense<[0, 4, 1, 5, 2, 6, 3, 7]> : tensor<8xindex>} : () -> !tosa.shape<8>
+// CHECK: [[VAR_1_:%.+]] = tosa.pad [[PARAM_0_]], [[VAR_0_]], [[PARAM_1_]] : (tensor<20x16x44x32xf16>, !tosa.shape<8>, tensor<f16>) -> tensor<24x22x52x42xf16>
+// CHECK: return [[VAR_1_]] : tensor<24x22x52x42xf16>
+}

test/mlir/onnx/onnx_simplify_shape_related_ops.mlir

Lines changed: 23 additions & 2 deletions
@@ -103,7 +103,7 @@ func.func @test_pass_dims_through_concat(%arg0: tensor<?x256xi64>) -> (tensor<4x

 // -----

-func.func @test_pass_dims_through_cast_2(%arg0: tensor<?x?x200xf32>) -> tensor<2xi64> {
+func.func @test_pass_dims_through_gather(%arg0: tensor<?x?x200xf32>) -> tensor<2xi64> {
   %0 = onnx.Constant dense<[0, 1]> : tensor<2xi64>
   %1 = "onnx.Dim"(%arg0) {axis = 0 : si64} : (tensor<?x?x200xf32>) -> tensor<1xi64>
   %2 = "onnx.Dim"(%arg0) {axis = 1 : si64} : (tensor<?x?x200xf32>) -> tensor<1xi64>
@@ -113,7 +113,28 @@ func.func @test_pass_dims_through_cast_2(%arg0: tensor<?x?x200xf32>) -> tensor<2
   onnx.Return %5 : tensor<2xi64>

 // mlir2FileCheck.py
-// CHECK-LABEL: func.func @test_pass_dims_through_cast_2
+// CHECK-LABEL: func.func @test_pass_dims_through_gather
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x200xf32>) -> tensor<2xi64> {
+// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 0 : si64} : (tensor<?x?x200xf32>) -> tensor<1xi64>
+// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 1 : si64} : (tensor<?x?x200xf32>) -> tensor<1xi64>
+// CHECK: [[VAR_2_:%.+]] = "onnx.Concat"([[VAR_0_]], [[VAR_1_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64>
+// CHECK: onnx.Return [[VAR_2_]] : tensor<2xi64>
+// CHECK: }
+}
+
+// -----
+
+func.func @test_pass_dims_through_gather_2(%arg0: tensor<?x?x200xf32>) -> tensor<2xi64> {
+  %0 = onnx.Constant dense<[-3, -2]> : tensor<2xi64>
+  %1 = "onnx.Dim"(%arg0) {axis = 0 : si64} : (tensor<?x?x200xf32>) -> tensor<1xi64>
+  %2 = "onnx.Dim"(%arg0) {axis = 1 : si64} : (tensor<?x?x200xf32>) -> tensor<1xi64>
+  %3 = onnx.Constant dense<200> : tensor<1xi64>
+  %4 = "onnx.Concat"(%1, %2, %3) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<3xi64>
+  %5 = "onnx.Gather"(%4, %0) {axis = 0 : si64} : (tensor<3xi64>, tensor<2xi64>) -> tensor<2xi64>
+  onnx.Return %5 : tensor<2xi64>
+
+// mlir2FileCheck.py
+// CHECK-LABEL: func.func @test_pass_dims_through_gather_2
 // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x200xf32>) -> tensor<2xi64> {
 // CHECK-DAG: [[VAR_0_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 0 : si64} : (tensor<?x?x200xf32>) -> tensor<1xi64>
 // CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 1 : si64} : (tensor<?x?x200xf32>) -> tensor<1xi64>
