
Commit 6063633

Merge branch 'feature/onnx-to-tosa' into koshrai2.backport.pad.lowering.changes

2 parents e859e0c + 86d29ea

4 files changed: +50 -8 lines

src/Conversion/ONNXToTOSA/Tensor/Gather.cpp

Lines changed: 3 additions & 1 deletion
@@ -17,6 +17,7 @@
 #include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp"
 #include "src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp"
 #include "src/Dialect/ONNX/ONNXOps.hpp"
+#include "src/Support/TypeUtilities.hpp"
 #include "llvm/ADT/SmallVector.h"
 
 using namespace mlir;
@@ -46,7 +47,8 @@ class ONNXGatherLoweringToTOSA : public OpConversionPattern<ONNXGatherOp> {
     if (!onnx_mlir::isRankedShapedType(inputType))
       return rewriter.notifyMatchFailure(op, "input is not a ranked tensor");
 
-    if (!hasStaticShape(result.getType()))
+    if (!hasStaticShape(inputType) || !hasStaticShape(indices.getType()) ||
+        !hasStaticShape(result.getType()))
      return rewriter.notifyMatchFailure(op, "dynamic shapes not supported");
 
     auto resultTy = dyn_cast<TensorType>(op.getType());
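This hunk widens the guard from the result type alone to all three shapes: the TOSA lowering builds reshapes and constants from concrete dimensions, so a dynamic input or indices tensor cannot be converted even when the result type happens to be static. A minimal standalone sketch of the legality condition, in plain C++ with assumed names (not the onnx-mlir API):

```cpp
// Sketch only: mirrors the guard the diff adds. `kDynamic` plays the
// role of a dynamic dimension in an mlir::ShapedType, and `isStatic`
// stands in for a helper like hasStaticShape.
#include <cassert>
#include <cstdint>
#include <vector>

constexpr int64_t kDynamic = -1;
using Shape = std::vector<int64_t>;

static bool isStatic(const Shape &shape) {
  for (int64_t dim : shape)
    if (dim == kDynamic)
      return false;
  return true;
}

// Lowering Gather to TOSA is attempted only when input, indices, and
// result all have fully static shapes; otherwise the pattern bails out.
static bool gatherLoweringIsLegal(
    const Shape &input, const Shape &indices, const Shape &result) {
  return isStatic(input) && isStatic(indices) && isStatic(result);
}

int main() {
  // Mirrors @test_gather_dynamic_input_static_output below: a dynamic
  // input (tensor<?x2xf32>) blocks the lowering even though the result
  // type (tensor<1x2xf32>) is fully static.
  assert(!gatherLoweringIsLegal({kDynamic, 2}, {kDynamic}, {1, 2}));
  return 0;
}
```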

src/Dialect/ONNX/Transforms/SimplifyShapeRelatedOps.cpp

Lines changed: 6 additions & 5 deletions
@@ -65,7 +65,6 @@ Now, it's straightforward to update the output shape of Reshape from
 #include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp"
 #include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp"
 #include "src/Pass/Passes.hpp"
-#include "src/Support/TypeUtilities.hpp"
 
 #define DEBUG_TYPE "simplify_shape_related_ops"
 
@@ -247,14 +246,16 @@ class PassThroughGatherPattern : public OpRewritePattern<ONNXGatherOp> {
 
     // Rewrite
     MultiDialectBuilder<OnnxBuilder> create(rewriter, loc);
-    int64_t inputRank = getRank(input.getType());
+    ShapedType inputType = llvm::dyn_cast<ShapedType>(input.getType());
+    if (!inputType || !inputType.hasStaticShape())
+      return failure();
 
     // Compute integer indices.
     SmallVector<int64_t, 4> indicesI64;
     for (auto element : indicesAttr.getValues<IntegerAttr>()) {
-      int64_t axis = element.getInt();
-      axis = (axis < 0) ? (axis + inputRank) : axis;
-      indicesI64.emplace_back(axis);
+      int64_t index = element.getInt();
+      index = (index < 0) ? (index + inputType.getShape()[axis]) : index;
+      indicesI64.emplace_back(index);
     }
 
     // Replace GatherOp by ConcatOp of specific dimensions.
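This hunk fixes an actual bug: a negative Gather index addresses positions from the end of the gathered dimension, so it must be normalized by adding the size of that dimension (inputType.getShape()[axis]), not the input's rank as the old code did. A minimal worked sketch in plain C++ (assumed names, not the onnx-mlir helpers):

```cpp
// Sketch only: the corrected normalization for constant Gather indices.
// A negative index i along `axis` means dims[axis] + i.
#include <cassert>
#include <cstdint>
#include <vector>

static std::vector<int64_t> normalizeGatherIndices(
    const std::vector<int64_t> &indices,
    const std::vector<int64_t> &dims, int64_t axis) {
  std::vector<int64_t> normalized;
  normalized.reserve(indices.size());
  for (int64_t index : indices)
    normalized.push_back(index < 0 ? index + dims[axis] : index);
  return normalized;
}

int main() {
  // Mirrors @test_pass_dims_through_gather_2 below: gathering with
  // indices [-3, -2] from a tensor<3xi64> shape value along axis 0
  // yields [0, 1]. The old code added the rank (1) instead of the
  // dimension size (3), which would have produced [-2, -1].
  std::vector<int64_t> out = normalizeGatherIndices({-3, -2}, {3}, 0);
  assert(out[0] == 0 && out[1] == 1);
  return 0;
}
```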

test/mlir/conversion/onnx_to_tosa/Tensor/Gather.mlir

Lines changed: 18 additions & 0 deletions
@@ -175,3 +175,21 @@ func.func @test_gather_dynamic_shape_indices_i32(%arg0 : tensor<?x4xf32>, %indic
 // CHECK-LABEL: test_gather_dynamic_shape_indices_i32
 // CHECK: onnx.Gather
 }
+
+// -----
+
+func.func @test_gather_dynamic_input_static_output(%arg0 : tensor<?x2xf32>, %indices: tensor<?xi64>) -> tensor<1x2xf32> {
+  %0 = "onnx.Gather"(%arg0, %indices) {axis = 0 : si64, onnx_node_name = "/Gather_16"} : (tensor<?x2xf32>, tensor<?xi64>) -> tensor<1x2xf32>
+  "func.return"(%0) : (tensor<1x2xf32>) -> ()
+// CHECK-LABEL: test_gather_dynamic_input_static_output
+// CHECK: onnx.Gather
+}
+
+// -----
+
+func.func @test_gather_dynamic_indices(%arg0 : tensor<1x2xf32>, %indices: tensor<?xi64>) -> tensor<1x2xf32> {
+  %0 = "onnx.Gather"(%arg0, %indices) {axis = 0 : si64, onnx_node_name = "/Gather_16"} : (tensor<1x2xf32>, tensor<?xi64>) -> tensor<1x2xf32>
+  "func.return"(%0) : (tensor<1x2xf32>) -> ()
+// CHECK-LABEL: test_gather_dynamic_indices
+// CHECK: onnx.Gather
+}

test/mlir/onnx/onnx_simplify_shape_related_ops.mlir

Lines changed: 23 additions & 2 deletions
@@ -103,7 +103,7 @@ func.func @test_pass_dims_through_concat(%arg0: tensor<?x256xi64>) -> (tensor<4x
 
 // -----
 
-func.func @test_pass_dims_through_cast_2(%arg0: tensor<?x?x200xf32>) -> tensor<2xi64> {
+func.func @test_pass_dims_through_gather(%arg0: tensor<?x?x200xf32>) -> tensor<2xi64> {
   %0 = onnx.Constant dense<[0, 1]> : tensor<2xi64>
   %1 = "onnx.Dim"(%arg0) {axis = 0 : si64} : (tensor<?x?x200xf32>) -> tensor<1xi64>
   %2 = "onnx.Dim"(%arg0) {axis = 1 : si64} : (tensor<?x?x200xf32>) -> tensor<1xi64>
@@ -113,7 +113,28 @@ func.func @test_pass_dims_through_cast_2(%arg0: tensor<?x?x200xf32>) -> tensor<2
   onnx.Return %5 : tensor<2xi64>
 
 // mlir2FileCheck.py
-// CHECK-LABEL: func.func @test_pass_dims_through_cast_2
+// CHECK-LABEL: func.func @test_pass_dims_through_gather
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x200xf32>) -> tensor<2xi64> {
+// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 0 : si64} : (tensor<?x?x200xf32>) -> tensor<1xi64>
+// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 1 : si64} : (tensor<?x?x200xf32>) -> tensor<1xi64>
+// CHECK: [[VAR_2_:%.+]] = "onnx.Concat"([[VAR_0_]], [[VAR_1_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64>
+// CHECK: onnx.Return [[VAR_2_]] : tensor<2xi64>
+// CHECK: }
+}
+
+// -----
+
+func.func @test_pass_dims_through_gather_2(%arg0: tensor<?x?x200xf32>) -> tensor<2xi64> {
+  %0 = onnx.Constant dense<[-3, -2]> : tensor<2xi64>
+  %1 = "onnx.Dim"(%arg0) {axis = 0 : si64} : (tensor<?x?x200xf32>) -> tensor<1xi64>
+  %2 = "onnx.Dim"(%arg0) {axis = 1 : si64} : (tensor<?x?x200xf32>) -> tensor<1xi64>
+  %3 = onnx.Constant dense<200> : tensor<1xi64>
+  %4 = "onnx.Concat"(%1, %2, %3) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<3xi64>
+  %5 = "onnx.Gather"(%4, %0) {axis = 0 : si64} : (tensor<3xi64>, tensor<2xi64>) -> tensor<2xi64>
+  onnx.Return %5 : tensor<2xi64>
+
+// mlir2FileCheck.py
+// CHECK-LABEL: func.func @test_pass_dims_through_gather_2
 // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x200xf32>) -> tensor<2xi64> {
 // CHECK-DAG: [[VAR_0_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 0 : si64} : (tensor<?x?x200xf32>) -> tensor<1xi64>
 // CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 1 : si64} : (tensor<?x?x200xf32>) -> tensor<1xi64>
