Skip to content

Commit 68c1b27

Browse files
committed
shape inference support LogicalOr
1 parent 0b4ce93 commit 68c1b27

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

tf2onnx/shape_inference.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,13 @@
3636
"GreaterEqual",
3737
"Less",
3838
"LogicalAnd",
39+
"LogicalOr",
3940
"Mul",
4041
"RealDiv",
4142
"Sub"
4243
]
4344

45+
4446
def infer_shape_for_graph(g):
4547
no_shape_updated = True
4648
while no_shape_updated:
@@ -135,7 +137,7 @@ def infer_shape_for_node(g, node):
135137
axis += len(s1)
136138
new_shape = s1[:axis] + [val]
137139
if axis < len(s1) - 1:
138-
new_shape += s1[axis+1:]
140+
new_shape += s1[axis + 1:]
139141

140142
g.set_shape(node.output[0], new_shape)
141143
log.debug("set ConcatV2 node [%s] with new shape %s", node.output[0], new_shape)
@@ -148,7 +150,7 @@ def infer_shape_for_node(g, node):
148150
shape_indices = g.get_shape(node.input[1])
149151
axis = node.input[2].get_tensor_value()
150152

151-
shape = shape_params[:axis] + shape_indices + shape_indices[axis+1:]
153+
shape = shape_params[:axis] + shape_indices + shape_indices[axis + 1:]
152154
g.set_shape(node.output[0], shape)
153155
return True
154156

0 commit comments

Comments
 (0)