Skip to content

Commit 898b13d

Browse files
authored
Merge pull request #385 from Xilinx/tiagot.add_folder_slice
feat: add folder for onnx.Slice.
2 parents a4bf486 + acae8a3 commit 898b13d

File tree

4 files changed

+31
-1
lines changed

4 files changed

+31
-1
lines changed

src/Dialect/ONNX/ONNXOps.td.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8972,6 +8972,7 @@ def ONNXSliceOp:ONNX_Op<"Slice",
89728972
return sh;
89738973
}
89748974
}];
8975+
let hasFolder = 1;
89758976
}
89768977

89778978
def ONNXSoftmaxOp:ONNX_Op<"Softmax",

src/Dialect/ONNX/ONNXOps/Tensor/Slice.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,3 +201,16 @@ LogicalResult ONNXSliceOp::inferShapes(
201201
ONNXSliceOpShapeHelper shapeHelper(getOperation(), {});
202202
return shapeHelper.computeShapeAndUpdateType(elementType);
203203
}
204+
205+
//===----------------------------------------------------------------------===//
206+
// Folder
207+
//===----------------------------------------------------------------------===//
208+
OpFoldResult ONNXSliceOp::fold(FoldAdaptor adaptor) {
209+
210+
auto inputTy = llvm::dyn_cast<RankedTensorType>(getData().getType());
211+
auto outputTy = llvm::dyn_cast<RankedTensorType>(getOutput().getType());
212+
if (inputTy && inputTy == outputTy && inputTy.hasStaticShape()) {
213+
return getData();
214+
}
215+
return nullptr;
216+
}

test/mlir/onnx/onnx_fold.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,4 +43,20 @@ func.func @test_reduceMeanIsNoopWithEmptyAxes(%arg0: tensor<4x512x256x8xf32>) ->
4343
// CHECK-LABEL: @test_reduceMeanIsNoopWithEmptyAxes
4444
// CHECK-SAME: (%[[VAL_0:.*]]: tensor<4x512x256x8xf32>) -> tensor<4x512x256x8xf32> {
4545
// CHECK: return %[[VAL_0]] : tensor<4x512x256x8xf32>
46+
// CHECK: }
47+
48+
// -----
49+
50+
func.func @test_slice(%arg0: tensor<16x1x2500x4xf32>) -> tensor<16x1x2500x4xf32> {
51+
%0 = onnx.Constant dense<0> : tensor<1xi64>
52+
%1 = onnx.Constant dense<4> : tensor<1xi64>
53+
%2 = onnx.Constant dense<3> : tensor<1xi64>
54+
%3 = onnx.Constant dense<1> : tensor<1xi64>
55+
%4 = "onnx.Slice"(%arg0, %0, %1, %2, %3) : (tensor<16x1x2500x4xf32>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<16x1x2500x4xf32>
56+
return %4 : tensor<16x1x2500x4xf32>
57+
}
58+
59+
// CHECK-LABEL: @test_slice
60+
// CHECK-SAME: (%[[VAL_0:.*]]: tensor<16x1x2500x4xf32>) -> tensor<16x1x2500x4xf32> {
61+
// CHECK: return %[[VAL_0]] : tensor<16x1x2500x4xf32>
4662
// CHECK: }

utils/gen_onnx_mlir.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,7 @@
454454
]
455455

456456
# Op with fold function
457-
OpsWithFolder = ["Constant", "Squeeze", "SqueezeV11", "ReduceMean"]
457+
OpsWithFolder = ["Constant", "Squeeze", "SqueezeV11", "ReduceMean", "Slice"]
458458

459459
# Op with ConstantLike trait
460460
OpsWithConstantLike = ["Constant"]

0 commit comments

Comments
 (0)