Commit 81d2871

Merge pull request #466 from Xilinx/jrickert.backport

Do not assert if a node with a subgraph does not implement HasOnnxSubgraphOpInterface

2 parents 974795f + 6247d07

File tree

2 files changed: +87 -3 lines changed

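The patch turns a hard importer abort into a recoverable warning: when an imported node carries a GRAPH attribute but its op type does not implement HasOnnxSubgraphOpInterface, the subgraph attribute is now dropped with a warning and import continues; a debug-build assert is kept for every op other than ONNXCustomOp (see the first hunk below). What follows is a minimal stand-alone sketch of that probe-before-cast pattern, using plain C++ inheritance and dynamic_cast in place of MLIR's op-interface machinery; every name here is illustrative, not an onnx-mlir API.

#include <iostream>

// Stands in for mlir::Operation.
struct Op {
  virtual ~Op() = default;
  virtual const char *name() const = 0;
};

// Stands in for HasOnnxSubgraphOpInterface.
struct HasSubgraphInterface {
  virtual ~HasSubgraphInterface() = default;
  virtual void importSubgraph() = 0;
};

// An op that genuinely owns subgraph regions, e.g. a loop.
struct LoopOp : Op, HasSubgraphInterface {
  const char *name() const override { return "onnx.Loop"; }
  void importSubgraph() override { std::cout << "importing body graph\n"; }
};

// A custom op that may carry a GRAPH attribute without supporting regions.
struct CustomOp : Op {
  const char *name() const override { return "onnx.Custom"; }
};

void processGraphAttribute(Op &op) {
  // Before the fix, an unconditional assert aborted the importer here.
  auto *withSubgraph = dynamic_cast<HasSubgraphInterface *>(&op);
  if (!withSubgraph) {
    std::cerr << "Warning: Node " << op.name()
              << " contains subgraph attributes but does not implement the "
                 "subgraph interface. The subgraph will be dropped.\n";
    return;  // drop the subgraph, keep importing the rest of the model
  }
  withSubgraph->importSubgraph();
}

int main() {
  LoopOp loop;
  CustomOp custom;
  processGraphAttribute(loop);    // imports the body graph
  processGraphAttribute(custom);  // warns, drops the subgraph, continues
}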

src/Builder/FrontendDialectTransformer.cpp

Lines changed: 12 additions & 3 deletions
@@ -911,9 +911,18 @@ class FrontendGenImpl {
     for (const auto &attr : node.attribute()) {
       if (attr.type() == onnx::AttributeProto_AttributeType_GRAPH) {
         OperationName opName = op->getName();
-        assert(opName.hasInterface<HasOnnxSubgraphOpInterface>() &&
-               "Op contains subgraph attributes but does not "
-               "implement HasOnnxSubgraphOpInterface interface.");
+        if (!opName.hasInterface<HasOnnxSubgraphOpInterface>()) {
+          llvm::errs() << "\nWarning: Node " << op
+                       << " contains subgraph attributes but does not "
+                          "implement HasOnnxSubgraphOpInterface interface. The "
+                          "subgraph will be dropped.\n";
+          if constexpr (!std::is_same_v<T, ONNXCustomOp>) {
+            assert(false && "Non-custom ops must implement "
+                            "HasOnnxSubgraphOpInterface if they have "
+                            "subgraph attributes.");
+          }
+          continue;
+        }
         auto opWithSubgraph =
             mlir::cast<HasOnnxSubgraphOpInterface>(op.getOperation());
         auto regionIdx = opWithSubgraph.getSubgraphRegionIdx(attr.name());
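The retained assert sits behind `if constexpr`, so it is compiled only into template instantiations where T is not ONNXCustomOp; custom ops take the warn-and-continue path even in debug builds. A minimal sketch of that compile-time dispatch, with stand-in types rather than the real onnx-mlir classes:

#include <cassert>
#include <iostream>
#include <type_traits>

struct ONNXCustomOp {};  // stand-in for the onnx-mlir custom-op class
struct ONNXLoopOp {};    // stand-in for any non-custom op with subgraphs

template <typename T>
void handleMissingSubgraphInterface() {
  std::cerr << "Warning: subgraph attribute will be dropped\n";
  // The branch below is discarded at compile time for T == ONNXCustomOp,
  // so custom ops never reach the assert, even in debug builds.
  if constexpr (!std::is_same_v<T, ONNXCustomOp>) {
    assert(false && "Non-custom ops must implement the subgraph interface.");
  }
}

int main() {
  handleMissingSubgraphInterface<ONNXCustomOp>();  // warns and returns
  // handleMissingSubgraphInterface<ONNXLoopOp>(); // would assert in a debug build
}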
Lines changed: 75 additions & 0 deletions

@@ -0,0 +1,75 @@
+// RUN: onnx-mlir --EmitONNXIR --useOnnxModelTypes=false --printIR %s > output.txt 2>&1; cat output.txt | FileCheck %s
+// CHECK: Warning: Node %0 = "onnx.Custom"(%arg0) {domain_name = "my.custom", function_name = "Super", onnx_node_name = "mySuperOp"} : (tensor<1x64x112x112xbf16>) -> tensor<*xbf16> contains subgraph attributes but does not implement HasOnnxSubgraphOpInterface interface. The subgraph will be dropped.
+// CHECK-LABEL: func.func @main_graph
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x64x112x112xbf16> {onnx.name = "myInput"}) -> (tensor<*xbf16> {onnx.name = "output0"}) {
+// CHECK: [[VAR_0_:%.+]] = "onnx.Custom"([[PARAM_0_]]) {domain_name = "my.custom", function_name = "Super", onnx_node_name = "mySuperOp"} : (tensor<1x64x112x112xbf16>) -> tensor<*xbf16>
+// CHECK: return [[VAR_0_]] : tensor<*xbf16>
+// CHECK: }
+{
+  "irVersion": "8",
+  "producerName": "handwritten",
+  "producerVersion": "1.0",
+  "graph": {
+    "node": [
+      {
+        "input": [
+          "myInput"
+        ],
+        "output": [
+          "output0"
+        ],
+        "name": "mySuperOp",
+        "opType": "Super",
+        "domain": "my.custom",
+        "attribute": [
+          {
+            "name": "body",
+            "type": "GRAPH"
+          }
+        ]
+      }
+    ],
+    "name": "main_graph",
+    "input": [
+      {
+        "name": "myInput",
+        "type": {
+          "tensorType": {
+            "elemType": 16,
+            "shape": {
+              "dim": [
+                {
+                  "dimValue": "1"
+                },
+                {
+                  "dimValue": "64"
+                },
+                {
+                  "dimValue": "112"
+                },
+                {
+                  "dimValue": "112"
+                }
+              ]
+            }
+          }
+        }
+      }
+    ],
+    "output": [
+      {
+        "name": "output0",
+        "type": {
+          "tensorType": {
+            "elemType": 16
+          }
+        }
+      }
+    ]
+  },
+  "opsetImport": [
+    {
+      "version": "17"
+    }
+  ]
+}
