Skip to content

Commit ba32f8c

Browse files
committed
enhance tf shape inference for LessEqual and Max
1 parent 043475e commit ba32f8c

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

tf2onnx/shape_inference.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
"Greater",
3636
"GreaterEqual",
3737
"Less",
38+
"LessEqual",
3839
"LogicalAnd",
3940
"LogicalOr",
4041
"Mul",
@@ -154,7 +155,7 @@ def infer_shape_for_node(g, node):
154155
g.set_shape(node.output[0], shape)
155156
return True
156157

157-
if node.type in ["All", "Any", "Min"]:
158+
if node.type in ["All", "Any", "Max", "Min"]:
158159
axis_node = node.inputs[1]
159160
axis = axis_node.get_tensor_value()
160161
if not isinstance(axis, list):

0 commit comments

Comments
 (0)