Skip to content

Commit dba60d5

Browse files
authored
Merge pull request #390 from nbcsm/fix_qv
enhance tf shape inference for LessEqual and Max
2 parents 7b78248 + ba32f8c commit dba60d5

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

tests/run_pretrained_models.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def run_tensorflow(self, sess, inputs):
153153

154154
def to_onnx(self, tf_graph, opset=None, shape_override=None, input_names=None):
155155
"""Convert graph to tensorflow."""
156-
return process_tf_graph(tf_graph, continue_on_error=True, verbose=True, opset=opset,
156+
return process_tf_graph(tf_graph, continue_on_error=False, verbose=True, opset=opset,
157157
target=Test.target, shape_override=shape_override,
158158
input_names=input_names, output_names=self.output_names)
159159

@@ -186,7 +186,6 @@ def run_onnxruntime(self, name, model_proto, inputs):
186186
"""Run test against msrt-next backend."""
187187
import onnxruntime as rt
188188
model_path = utils.save_onnx_model(TEMP_DIR, name, inputs, model_proto, include_test_data=True)
189-
utils.save_onnx_model(TEMP_DIR, name, inputs, model_proto, include_test_data=False, as_text=True)
190189
print("\t\t" + model_path)
191190
m = rt.InferenceSession(model_path)
192191
results = m.run(self.output_names, inputs)

tf2onnx/shape_inference.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
"Greater",
3636
"GreaterEqual",
3737
"Less",
38+
"LessEqual",
3839
"LogicalAnd",
3940
"LogicalOr",
4041
"Mul",
@@ -154,7 +155,7 @@ def infer_shape_for_node(g, node):
154155
g.set_shape(node.output[0], shape)
155156
return True
156157

157-
if node.type in ["All", "Any", "Min"]:
158+
if node.type in ["All", "Any", "Max", "Min"]:
158159
axis_node = node.inputs[1]
159160
axis = axis_node.get_tensor_value()
160161
if not isinstance(axis, list):

0 commit comments

Comments
 (0)