Commit 1c9f28b

convert negative indices correctly
1 parent 189c863 commit 1c9f28b

2 files changed: 29 additions, 6 deletions

src/Dialect/ONNX/Transforms/SimplifyShapeRelatedOps.cpp

Lines changed: 6 additions & 4 deletions
@@ -247,14 +247,16 @@ class PassThroughGatherPattern : public OpRewritePattern<ONNXGatherOp> {
 
     // Rewrite
     MultiDialectBuilder<OnnxBuilder> create(rewriter, loc);
-    int64_t inputRank = getRank(input.getType());
+    ShapedType inputType = llvm::dyn_cast<ShapedType>(input.getType());
+    if (!inputType || !inputType.hasStaticShape())
+      return failure();
 
     // Compute integer indices.
     SmallVector<int64_t, 4> indicesI64;
     for (auto element : indicesAttr.getValues<IntegerAttr>()) {
-      int64_t axis = element.getInt();
-      axis = (axis < 0) ? (axis + inputRank) : axis;
-      indicesI64.emplace_back(axis);
+      int64_t index = element.getInt();
+      index = (index < 0) ? (index + inputType.getShape()[axis]) : index;
+      indicesI64.emplace_back(index);
     }
 
     // Replace GatherOp by ConcatOp of specific dimensions.
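
This hunk is what the commit message refers to: a negative Gather index is resolved against the size of the gathered dimension (the data shape along the gather axis), whereas the old code offset it by the input's rank. Below is a minimal standalone C++ sketch of that normalization; the helper name and shapes are illustrative only and are not code from the pass.

#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper: per the ONNX Gather semantics, a negative index counts
// back from the end of the gathered dimension, so the offset is the dimension
// size, not the tensor rank.
int64_t normalizeGatherIndex(int64_t index, const std::vector<int64_t> &dataShape, int64_t axis) {
  int64_t dimSize = dataShape[axis];
  return (index < 0) ? (index + dimSize) : index;
}

int main() {
  std::vector<int64_t> dataShape = {5}; // a 1-D tensor with 5 elements
  assert(normalizeGatherIndex(-1, dataShape, 0) == 4);
  assert(normalizeGatherIndex(-5, dataShape, 0) == 0);
  // The old rank-based offset (rank == 1 here) would have produced 0 and -4.
  assert(normalizeGatherIndex(2, dataShape, 0) == 2); // non-negative indices are unchanged
  return 0;
}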

test/mlir/onnx/onnx_simplify_shape_related_ops.mlir

Lines changed: 23 additions & 2 deletions
@@ -103,7 +103,7 @@ func.func @test_pass_dims_through_concat(%arg0: tensor<?x256xi64>) -> (tensor<4x
 
 // -----
 
-func.func @test_pass_dims_through_cast_2(%arg0: tensor<?x?x200xf32>) -> tensor<2xi64> {
+func.func @test_pass_dims_through_gather(%arg0: tensor<?x?x200xf32>) -> tensor<2xi64> {
   %0 = onnx.Constant dense<[0, 1]> : tensor<2xi64>
   %1 = "onnx.Dim"(%arg0) {axis = 0 : si64} : (tensor<?x?x200xf32>) -> tensor<1xi64>
   %2 = "onnx.Dim"(%arg0) {axis = 1 : si64} : (tensor<?x?x200xf32>) -> tensor<1xi64>
@@ -113,7 +113,28 @@ func.func @test_pass_dims_through_cast_2(%arg0: tensor<?x?x200xf32>) -> tensor<2
   onnx.Return %5 : tensor<2xi64>
 
 // mlir2FileCheck.py
-// CHECK-LABEL: func.func @test_pass_dims_through_cast_2
+// CHECK-LABEL: func.func @test_pass_dims_through_gather
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x200xf32>) -> tensor<2xi64> {
+// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 0 : si64} : (tensor<?x?x200xf32>) -> tensor<1xi64>
+// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 1 : si64} : (tensor<?x?x200xf32>) -> tensor<1xi64>
+// CHECK: [[VAR_2_:%.+]] = "onnx.Concat"([[VAR_0_]], [[VAR_1_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64>
+// CHECK: onnx.Return [[VAR_2_]] : tensor<2xi64>
+// CHECK: }
+}
+
+// -----
+
+func.func @test_pass_dims_through_gather_2(%arg0: tensor<?x?x200xf32>) -> tensor<2xi64> {
+  %0 = onnx.Constant dense<[-3, -2]> : tensor<2xi64>
+  %1 = "onnx.Dim"(%arg0) {axis = 0 : si64} : (tensor<?x?x200xf32>) -> tensor<1xi64>
+  %2 = "onnx.Dim"(%arg0) {axis = 1 : si64} : (tensor<?x?x200xf32>) -> tensor<1xi64>
+  %3 = onnx.Constant dense<200> : tensor<1xi64>
+  %4 = "onnx.Concat"(%1, %2, %3) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<3xi64>
+  %5 = "onnx.Gather"(%4, %0) {axis = 0 : si64} : (tensor<3xi64>, tensor<2xi64>) -> tensor<2xi64>
+  onnx.Return %5 : tensor<2xi64>
+
+// mlir2FileCheck.py
+// CHECK-LABEL: func.func @test_pass_dims_through_gather_2
 // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x200xf32>) -> tensor<2xi64> {
 // CHECK-DAG: [[VAR_0_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 0 : si64} : (tensor<?x?x200xf32>) -> tensor<1xi64>
 // CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 1 : si64} : (tensor<?x?x200xf32>) -> tensor<1xi64>
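As a rough check of the new @test_pass_dims_through_gather_2 case: the gathered data is the 3-element Concat of Dim(axis=0), Dim(axis=1), and the constant 200, so with the fix the indices [-3, -2] normalize to 0 and 1 and the Gather should again fold into a Concat of the two onnx.Dim values, which is what the shared CHECK-DAG lines above expect. A tiny compile-time sketch of that arithmetic (illustrative only, not part of the test):

// The gathered tensor has 3 elements, so negative indices are offset by 3.
static_assert(-3 + 3 == 0, "index -3 selects Dim(axis = 0)");
static_assert(-2 + 3 == 1, "index -2 selects Dim(axis = 1)");
// Under the old rank-based offset (rank == 1) both would have stayed negative.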
