Skip to content

Commit acae8a3

Browse files
committed
feat: add folder for onnx.Slice.
1 parent b742d71 commit acae8a3

File tree

4 files changed

+31
-1
lines changed

4 files changed

+31
-1
lines changed

src/Dialect/ONNX/ONNXOps.td.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8946,6 +8946,7 @@ def ONNXSliceOp:ONNX_Op<"Slice",
89468946
return sh;
89478947
}
89488948
}];
8949+
let hasFolder = 1;
89498950
}
89508951

89518952
def ONNXSoftmaxOp:ONNX_Op<"Softmax",

src/Dialect/ONNX/ONNXOps/Tensor/Slice.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,3 +201,16 @@ LogicalResult ONNXSliceOp::inferShapes(
201201
ONNXSliceOpShapeHelper shapeHelper(getOperation(), {});
202202
return shapeHelper.computeShapeAndUpdateType(elementType);
203203
}
204+
205+
//===----------------------------------------------------------------------===//
206+
// Folder
207+
//===----------------------------------------------------------------------===//
208+
OpFoldResult ONNXSliceOp::fold(FoldAdaptor adaptor) {
209+
210+
auto inputTy = llvm::dyn_cast<RankedTensorType>(getData().getType());
211+
auto outputTy = llvm::dyn_cast<RankedTensorType>(getOutput().getType());
212+
if (inputTy && inputTy == outputTy && inputTy.hasStaticShape()) {
213+
return getData();
214+
}
215+
return nullptr;
216+
}

test/mlir/onnx/onnx_fold.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,4 +43,20 @@ func.func @test_reduceMeanIsNoopWithEmptyAxes(%arg0: tensor<4x512x256x8xf32>) ->
4343
// CHECK-LABEL: @test_reduceMeanIsNoopWithEmptyAxes
4444
// CHECK-SAME: (%[[VAL_0:.*]]: tensor<4x512x256x8xf32>) -> tensor<4x512x256x8xf32> {
4545
// CHECK: return %[[VAL_0]] : tensor<4x512x256x8xf32>
46+
// CHECK: }
47+
48+
// -----
49+
50+
func.func @test_slice(%arg0: tensor<16x1x2500x4xf32>) -> tensor<16x1x2500x4xf32> {
51+
%0 = onnx.Constant dense<0> : tensor<1xi64>
52+
%1 = onnx.Constant dense<4> : tensor<1xi64>
53+
%2 = onnx.Constant dense<3> : tensor<1xi64>
54+
%3 = onnx.Constant dense<1> : tensor<1xi64>
55+
%4 = "onnx.Slice"(%arg0, %0, %1, %2, %3) : (tensor<16x1x2500x4xf32>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<16x1x2500x4xf32>
56+
return %4 : tensor<16x1x2500x4xf32>
57+
}
58+
59+
// CHECK-LABEL: @test_slice
60+
// CHECK-SAME: (%[[VAL_0:.*]]: tensor<16x1x2500x4xf32>) -> tensor<16x1x2500x4xf32> {
61+
// CHECK: return %[[VAL_0]] : tensor<16x1x2500x4xf32>
4662
// CHECK: }

utils/gen_onnx_mlir.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,7 @@
447447
]
448448

449449
# Op with fold function
450-
OpsWithFolder = ["Constant", "Squeeze", "SqueezeV11", "ReduceMean"]
450+
OpsWithFolder = ["Constant", "Squeeze", "SqueezeV11", "ReduceMean", "Slice"]
451451

452452
# Op with ConstantLike trait
453453
OpsWithConstantLike = ["Constant"]

0 commit comments

Comments
 (0)