Skip to content

Commit 10676b6

Browse files
authored
Merge pull request #321 from lucienwang1009/infer_square_shape
Infer shape for Square op
2 parents 5c5e54a + 03ba8c2 commit 10676b6

File tree

1 file changed

+24
-2
lines changed

1 file changed

+24
-2
lines changed

tf2onnx/shape_inference.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,28 @@
1717
logging.basicConfig(level=logging.INFO)
1818
log = logging.getLogger("tf2onnx.shape_inference")
1919

20+
direct_ops = [
21+
"Cast",
22+
"Enter",
23+
"Exit",
24+
"Floor",
25+
"Identity",
26+
"LogicalNot",
27+
"ReverseSequence",
28+
"Sigmoid",
29+
"Square",
30+
"Tanh"
31+
]
32+
broadcast_ops = [
33+
"Add",
34+
"Greater",
35+
"GreaterEqual",
36+
"Less",
37+
"LogicalAnd",
38+
"Mul",
39+
"RealDiv",
40+
"Sub"
41+
]
2042

2143
def infer_shape_for_graph(g):
2244
no_shape_updated = True
@@ -58,10 +80,10 @@ def infer_shape_for_node(g, node):
5880
log.debug("node %s has inputs don't have shape specified, they are: %s", node.name, no_shape)
5981
return False
6082

61-
if node.type in ["Cast", "Enter", "Exit", "Floor", "LogicalNot", "ReverseSequence", "Sigmoid", "Tanh", "Identity"]:
83+
if node.type in direct_ops:
6284
return set_shape_from_input(g, node.input[0], node.output[0])
6385

64-
if node.type in ["Add", "Greater", "GreaterEqual", "Less", "LogicalAnd", "Mul", "RealDiv", "Sub"]:
86+
if node.type in broadcast_ops:
6587
return set_shape_from_inputs_broadcast(g, node.input, node.output[0])
6688

6789
if node.type == "Placeholder":

0 commit comments

Comments
 (0)