Skip to content

Commit 8817d27

Browse files
Check shapes and dtypes in optimizer tests (#1452)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 30ec084 commit 8817d27

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

tests/test_optimizers.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def run_and_compare(self, output_names_with_port, onnx_feed_dict, origin_proto,
3030

3131
origin_model_path = self.save_onnx_model(origin_proto, onnx_feed_dict, postfix="_origin")
3232

33-
new_proto = GraphUtil.optimize_model_proto(origin_proto, catch_errors=False)
33+
new_proto, new_graph = GraphUtil.optimize_model_proto(origin_proto, catch_errors=False, return_graph=True)
3434

3535
self.assertTrue(new_proto, msg="model proto after optimizer should not be None")
3636

@@ -52,6 +52,8 @@ def run_and_compare(self, output_names_with_port, onnx_feed_dict, origin_proto,
5252
self.assertEqual(expected_val.dtype, actual_val.dtype)
5353
self.assertEqual(expected_val.shape, actual_val.shape)
5454

55+
self.assert_shapes_correct(new_graph, allow_missing=False, run_checker=True)
56+
5557
return new_proto
5658

5759
@staticmethod

tf2onnx/graph.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1565,11 +1565,11 @@ def optimize_graph(graph, catch_errors=True):
15651565
return optimizer.optimize_graph(graph, catch_errors)
15661566

15671567
@staticmethod
1568-
def optimize_model_proto(onnx_model_proto, catch_errors=True):
1568+
def optimize_model_proto(onnx_model_proto, catch_errors=True, return_graph=False):
15691569
"""Optimize the model proto, for example: eliminating all useless Transpose pairs.
15701570
15711571
Returns:
1572-
model proto after optimization, if optimizer run successfully
1572+
model proto (and possibly graph) after optimization, if optimizer run successfully
15731573
or onnx_model_proto, if exceptions happens
15741574
"""
15751575
try:
@@ -1582,13 +1582,17 @@ def optimize_model_proto(onnx_model_proto, catch_errors=True):
15821582
if onnx_model_proto.metadata_props:
15831583
metadata_props = {p.key: p.value for p in onnx_model_proto.metadata_props}
15841584
helper.set_model_props(model_proto, metadata_props)
1585+
if return_graph:
1586+
return model_proto, graph
15851587
return model_proto
15861588
except Exception as e:
15871589
if not catch_errors:
15881590
raise e
15891591
# sometimes, onnx shape inference will fail for some reason,
15901592
# return onnx_model_proto for this case
15911593
logger.warning("Failed to optimize model proto", exc_info=1)
1594+
if return_graph:
1595+
return onnx_model_proto, None
15921596
return onnx_model_proto
15931597

15941598
@staticmethod

0 commit comments

Comments
 (0)