Skip to content

Commit ed88b44

Browse files
committed
Add shape inference for microsoft QuickGelu
1 parent 8283c06 commit ed88b44

File tree

2 files changed

+124
-1
lines changed

2 files changed

+124
-1
lines changed

src/Builder/FrontendDialectTransformer.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1566,7 +1566,7 @@ class FrontendGenImpl {
15661566
} else if (opName == "QuantizeLinear") {
15671567
outElementType =
15681568
cast<ShapedType>(inputs.at(2).getType()).getElementType();
1569-
} else if (opName == "Gelu") {
1569+
} else if (opName == "Gelu" || opName == "QuickGelu") {
15701570
outElementType =
15711571
cast<ShapedType>(inputs.at(0).getType()).getElementType();
15721572
}
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
// RUN: onnx-mlir --EmitONNXIR --useOnnxModelTypes=false --printIR %s | FileCheck %s
2+
// Semi hand-written model.
3+
// CHECK-LABEL: func.func @main_graph
4+
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x3x1x1xf32> {onnx.name = "x"}, [[PARAM_1_:%.+]]: tensor<1x3x1x3xf32> {onnx.name = "add_in"}) -> (tensor<1x3x1x3xf32> {onnx.name = "add_out"}) {
5+
// CHECK: [[VAR_0_:%.+]] = "onnx.Custom"([[PARAM_0_]]) {domain_name = "com.microsoft", function_name = "QuickGelu", onnx_node_name = "onnx.Custom_0", output_element_type = f32, shape_infer_pattern = "SameAs"} : (tensor<1x3x1x1xf32>) -> tensor<1x3x1x1xf32>
6+
// CHECK: [[VAR_1_:%.+]] = "onnx.Add"([[VAR_0_]], [[PARAM_1_]]) {onnx_node_name = "onnx.Add_1"} : (tensor<1x3x1x1xf32>, tensor<1x3x1x3xf32>) -> tensor<1x3x1x3xf32>
7+
// CHECK: return [[VAR_1_]] : tensor<1x3x1x3xf32>
8+
// CHECK: }
9+
{
10+
"irVersion": "10",
11+
"producerName": "onnx-custom-op-example",
12+
"graph": {
13+
"node": [
14+
{
15+
"input": [
16+
"x"
17+
],
18+
"output": [
19+
"y"
20+
],
21+
"opType": "QuickGelu",
22+
"domain": "com.microsoft"
23+
},
24+
{
25+
"input": [
26+
"y",
27+
"add_in"
28+
],
29+
"output": [
30+
"add_out"
31+
],
32+
"opType": "Add"
33+
}
34+
],
35+
"name": "CustomQuantizeLinearGraph",
36+
"input": [
37+
{
38+
"name": "x",
39+
"type": {
40+
"tensorType": {
41+
"elemType": 1,
42+
"shape": {
43+
"dim": [
44+
{
45+
"dimValue": "1"
46+
},
47+
{
48+
"dimValue": "3"
49+
},
50+
{
51+
"dimValue": "1"
52+
},
53+
{
54+
"dimValue": "1"
55+
}
56+
]
57+
}
58+
}
59+
}
60+
},
61+
{
62+
"name": "add_in",
63+
"type": {
64+
"tensorType": {
65+
"elemType": 1,
66+
"shape": {
67+
"dim": [
68+
{
69+
"dimValue": "1"
70+
},
71+
{
72+
"dimValue": "3"
73+
},
74+
{
75+
"dimValue": "1"
76+
},
77+
{
78+
"dimValue": "3"
79+
}
80+
]
81+
}
82+
}
83+
}
84+
}
85+
],
86+
"output": [
87+
{
88+
"name": "add_out",
89+
"type": {
90+
"tensorType": {
91+
"elemType": 1,
92+
"shape": {
93+
"dim": [
94+
{
95+
"dimValue": "1"
96+
},
97+
{
98+
"dimValue": "3"
99+
},
100+
{
101+
"dimValue": "1"
102+
},
103+
{
104+
"dimValue": "3"
105+
}
106+
]
107+
}
108+
}
109+
}
110+
}
111+
]
112+
},
113+
"opsetImport": [
114+
{
115+
"domain": "",
116+
"version": "17"
117+
},
118+
{
119+
"domain": "com.microsoft",
120+
"version": "1"
121+
}
122+
]
123+
}

0 commit comments

Comments
 (0)