Skip to content

Commit 3c980df

Browse files
authored
Merge pull request #454 from Xilinx/jrickert.mixed_dialects
Adjust IndexExprBuilderForAnalysis::getConst to not only handle onnx.…
2 parents ac97f19 + c58cf70 commit 3c980df

File tree

5 files changed

+47
-3
lines changed

5 files changed

+47
-3
lines changed

src/Dialect/ONNX/DialectBuilder.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -887,7 +887,7 @@ Value OnnxBuilder::getOrCastToI8(Value val, bool simpleCast) {
887887

888888
// Return null if none is found.
889889
ElementsAttr IndexExprBuilderForAnalysis::getConst(Value value) {
890-
return getElementAttributeFromONNXValue(value);
890+
return getElementAttributeFromConstLikeValue(value);
891891
}
892892

893893
// Return null if the value at index i is not a constant.

src/Dialect/ONNX/ONNXOps/OpHelper.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,29 @@ ElementsAttr getElementAttributeFromONNXValue(Value value) {
316316
return nullptr;
317317
}
318318

319+
ElementsAttr getElementAttributeFromConstLikeValue(Value value) {
320+
auto *definingOp = value.getDefiningOp();
321+
if (!isConstLikeOperation(definingOp)) {
322+
return nullptr;
323+
}
324+
SmallVector<OpFoldResult, 1> foldResults;
325+
[[maybe_unused]] const LogicalResult folded = definingOp->fold(foldResults);
326+
assert(succeeded(folded) && "ConstantLike op failed to fold");
327+
assert(foldResults.size() == 1 &&
328+
"ConstantLike op fold produced more results than expected");
329+
auto foldAttr = dyn_cast<Attribute>(foldResults[0]);
330+
assert(foldAttr && "ConstantLike op fold did not return an Attribute");
331+
return dyn_cast<ElementsAttr>(foldAttr);
332+
}
333+
334+
bool isConstLikeValue(Value value) {
335+
return isConstLikeOperation(value.getDefiningOp());
336+
}
337+
338+
bool isConstLikeOperation(mlir::Operation *op) {
339+
return op && op->hasTrait<OpTrait::ConstantLike>();
340+
}
341+
319342
// compare two ElementsAttr, except for their internal buffer size
320343
bool compareValueFromElementAttribute(
321344
ElementsAttr &attr1, ElementsAttr &attr2) {

src/Dialect/ONNX/ONNXOps/OpHelper.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,14 @@ void ArrayAttrIntVals(mlir::ArrayAttr a, mlir::SmallVectorImpl<int64_t> &i);
174174

175175
mlir::ElementsAttr getElementAttributeFromONNXValue(mlir::Value value);
176176

177+
// Get a ElementsAttr from a value that is defined by a ConstantLike op that
178+
// folds to an ElementsAttr
179+
mlir::ElementsAttr getElementAttributeFromConstLikeValue(mlir::Value value);
180+
181+
[[nodiscard]] bool isConstLikeValue(mlir::Value value);
182+
183+
[[nodiscard]] bool isConstLikeOperation(mlir::Operation *op);
184+
177185
bool compareValueFromElementAttribute(
178186
mlir::ElementsAttr &attr1, mlir::ElementsAttr &attr2);
179187

src/Dialect/ONNX/ONNXOps/Tensor/Slice.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ LogicalResult ONNXSliceOp::inferShapes(
160160

161161
// Cannot infer shape if axes is not a constant. It can be a constant after
162162
// several rounds of shape-inference and constant propagation.
163-
if (!isNoneValue(axes) && !getONNXConstantOp(axes))
163+
if (!isNoneValue(axes) && !isConstLikeValue(axes))
164164
return success();
165165

166166
const auto startsType =

test/mlir/onnx/onnx_shape_inference.mlir

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4454,4 +4454,17 @@ func.func @test_slice_negative_steps(%arg0: tensor<100x200xf32>) -> tensor<*xf32
44544454
}
44554455
// CHECK-LABEL: func.func @test_slice_negative_steps
44564456
// CHECK: "onnx.Slice"
4457-
// CHECK-SAME: (tensor<100x200xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<16x16xf32>
4457+
// CHECK-SAME: (tensor<100x200xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<16x16xf32>
4458+
4459+
// -----
4460+
func.func @test_slice_negative_steps_mixed_dialects(%arg0: tensor<100x200xf32>) -> tensor<*xf32> {
4461+
%axes = "tosa.const"() {value = dense<[0, 1]> : tensor<2xi64> } : () -> tensor<2xi64>
4462+
%starts = "tosa.const"() {value = dense<[-10, -20]> : tensor<2xi64> } : () -> tensor<2xi64>
4463+
%ends = "tosa.const"() {value = dense<[10, 20]> : tensor<2xi64> } : () -> tensor<2xi64>
4464+
%steps = "tosa.const"() {value = dense<[-5, -10]> : tensor<2xi64> } : () -> tensor<2xi64>
4465+
%1 = "onnx.Slice"(%arg0, %starts, %ends, %axes, %steps) : (tensor<100x200xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<*xf32>
4466+
return %1 : tensor<*xf32>
4467+
}
4468+
// CHECK-LABEL: func.func @test_slice_negative_steps_mixed_dialects
4469+
// CHECK: "onnx.Slice"
4470+
// CHECK-SAME: (tensor<100x200xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<16x16xf32>

0 commit comments

Comments
 (0)