Skip to content

Commit 67efaa5

Browse files
committed
add same corner test cases for merge duplicated optimizer
1 parent c348562 commit 67efaa5

File tree

1 file changed

+62
-2
lines changed

1 file changed

+62
-2
lines changed

tests/test_optimizers.py

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ def test_duplicated_duplicated_input(self):
306306

307307
graph = helper.make_graph(
308308
[node0, node1, node2, node3, node4],
309-
"transpose-merge-test",
309+
"test_duplicated_duplicated_input",
310310
[helper.make_tensor_value_info("X", TensorProto.FLOAT, (5, 5))],
311311
[helper.make_tensor_value_info("OUT", TensorProto.FLOAT, (5, 5))],
312312
)
@@ -325,14 +325,74 @@ def test_duplicated_duplicated_attributes(self):
325325

326326
graph = helper.make_graph(
327327
[node0, node1, node2, node3, node4],
328-
"transpose-merge-test",
328+
"test_duplicated_duplicated_attributes",
329329
[helper.make_tensor_value_info("X", TensorProto.FLOAT, (5, 5))],
330330
[helper.make_tensor_value_info("OUT", TensorProto.FLOAT, (5,))],
331331
)
332332

333333
model_proto = helper.make_model(graph, producer_name="onnx-tests")
334334
self.run_merge_duplicated_nodes_compare(["OUT"], {"X": np.random.randn(5, 5).astype(np.float32)}, model_proto,
335335
op_type="ReduceSum", remaining_op_num=2)
336+
337+
def test_duplicated_node_is_graph_output(self):
338+
node0 = helper.make_node('Add', inputs=["X", "X"], outputs=["value0"])
339+
node1 = helper.make_node('Add', inputs=["X", "X"], outputs=["value1"])
340+
node2 = helper.make_node('Add', inputs=["value1", "X"], outputs=["value2"])
341+
342+
graph = helper.make_graph(
343+
[node0, node1, node2],
344+
"test_duplicated_node_is_graph_output",
345+
[helper.make_tensor_value_info("X", TensorProto.FLOAT, (5, 5))],
346+
[helper.make_tensor_value_info("value1", TensorProto.FLOAT, (5, 5)),
347+
helper.make_tensor_value_info("value2", TensorProto.FLOAT, (5, 5))],
348+
)
349+
350+
model_proto = helper.make_model(graph, producer_name="onnx-tests")
351+
self.run_merge_duplicated_nodes_compare(["value1", "value2"], {"X": np.random.randn(5, 5).astype(np.float32)}, model_proto,
352+
op_type="Add", remaining_op_num=2)
353+
354+
def test_duplicated_different_output_length(self):
355+
node0 = helper.make_node('Dropout', inputs=["X"], outputs=["value0"])
356+
node1 = helper.make_node('Dropout', inputs=["X"], outputs=["value1", "mask"])
357+
node2 = helper.make_node('Dropout', inputs=["value1"], outputs=["value2"])
358+
359+
graph = helper.make_graph(
360+
[node0, node1, node2],
361+
"test_duplicated_different_output_length",
362+
[helper.make_tensor_value_info("X", TensorProto.FLOAT, (5,))],
363+
[helper.make_tensor_value_info("value1", TensorProto.FLOAT, (5,)),
364+
helper.make_tensor_value_info("mask", TensorProto.BOOL, (5,)),
365+
helper.make_tensor_value_info("value2", TensorProto.FLOAT, (5,))],
366+
)
367+
368+
model_proto = helper.make_model(graph, producer_name="onnx-tests")
369+
self.run_merge_duplicated_nodes_compare(["value1", "mask", "value2"],
370+
{"X": np.random.randn(5,).astype(np.float32)},
371+
model_proto,
372+
op_type="Dropout", remaining_op_num=2)
373+
374+
def test_duplicated_need_multiple_run(self):
375+
node00 = helper.make_node('Log', inputs=["X"], outputs=["value00"])
376+
node01 = helper.make_node('Log', inputs=["value00"], outputs=["value01"])
377+
node02 = helper.make_node('Log', inputs=["value01"], outputs=["value02"])
378+
379+
node10 = helper.make_node('Log', inputs=["X"], outputs=["value10"])
380+
node11 = helper.make_node('Log', inputs=["value10"], outputs=["value11"])
381+
node12 = helper.make_node('Log', inputs=["value11"], outputs=["value12"])
382+
383+
res = helper.make_node('Add', inputs=["value02", "value12"], outputs=["res"])
384+
385+
graph = helper.make_graph(
386+
[node00, node01, node02, node10, node11, node12, res],
387+
"test_duplicated_node_is_graph_output",
388+
[helper.make_tensor_value_info("X", TensorProto.FLOAT, (5,))],
389+
[helper.make_tensor_value_info("res", TensorProto.FLOAT, (5,))],
390+
)
391+
392+
model_proto = helper.make_model(graph, producer_name="onnx-tests")
393+
self.run_merge_duplicated_nodes_compare(["res"], {"X": np.random.randn(5,).astype(np.float32)},
394+
model_proto,
395+
op_type="Log", remaining_op_num=3)
336396
# Merge Duplicated Nodes Optimizer Tests End
337397

338398

0 commit comments

Comments
 (0)