Skip to content

Commit 4c5423f

Browse files
committed
add test for transpose opt with loop
1 parent c68e2c5 commit 4c5423f

File tree

1 file changed

+64
-0
lines changed

1 file changed

+64
-0
lines changed

tests/test_optimizers.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,21 @@ def run_and_compare(self, output_names_with_port, onnx_feed_dict, origin_proto,
5050
self.assertEqual(expected_val.shape, actual_val.shape)
5151

5252
return new_proto
53+
54+
@staticmethod
55+
def _make_onnx_const(np_val, output_name):
56+
node = helper.make_node(
57+
'Constant',
58+
inputs=[],
59+
outputs=[output_name],
60+
value=helper.make_tensor(
61+
name=output_name,
62+
data_type=utils.map_numpy_to_onnx_dtype(np_val.dtype),
63+
dims=np_val.shape,
64+
vals=np_val.flatten().astype(np_val.dtype),
65+
),
66+
)
67+
return node
5368
# Tranpose Optimizer Tests Start
5469

5570
def run_transpose_compare(self, output_names_with_port, onnx_feed_dict, origin_proto,
@@ -304,6 +319,55 @@ def test_transpose_with_squeeze4(self):
304319
self.run_transpose_compare(["Z"], {"X": np.random.randn(3, 1, 1, 5).astype(np.float32)},
305320
model_proto, remaining_transpose_num=0)
306321

322+
def test_transpose_with_loop(self):
323+
def _define_loop_graph(external_inputs):
324+
# external_inputs: external node which will be used by this graph
325+
# graph without loop carried
326+
# computation
327+
# for(...){a = external_inputs[i]; b = trans(a), c = squeeze(b)}, c is scan output
328+
node1 = helper.make_node("Gather", [external_inputs[0], "loop_iter_num"], ["Y0"])
329+
node2 = helper.make_node("Transpose", ["Y0"], ["Z0"], perm=[0, 2, 3, 1])
330+
# graph output
331+
node3 = helper.make_node("Squeeze", ["Z0"], ["scan_output"], axes=[0])
332+
node4 = helper.make_node("Identity", ["loop_condition"], ["loop_cond_output"])
333+
node5 = helper.make_node("Identity", ["loop_condition"], ["loop_carried_output"])
334+
335+
graph = helper.make_graph(
336+
[node1, node2, node3, node4, node5],
337+
"loop_subgraph",
338+
[helper.make_tensor_value_info("loop_iter_num", TensorProto.INT64, (1,)), # iteration_num
339+
helper.make_tensor_value_info("loop_condition", TensorProto.BOOL, ()), # condition
340+
helper.make_tensor_value_info("loop_carried", TensorProto.BOOL, ()) # loop_carried
341+
],
342+
[helper.make_tensor_value_info("loop_cond_output", TensorProto.BOOL, ()),
343+
helper.make_tensor_value_info("loop_carried_output", TensorProto.BOOL, ()),
344+
helper.make_tensor_value_info("scan_output", TensorProto.FLOAT, ["unknown"] * 3)
345+
],
346+
)
347+
return graph
348+
349+
def _make_loop(external_inputs, outputs):
350+
trip_cnt = self._make_onnx_const(np.array(10, dtype=np.int64), "trip_cnt")
351+
cond = self._make_onnx_const(np.array(True, dtype=np.bool), "cond")
352+
sub_graph = _define_loop_graph(external_inputs)
353+
loop_node = helper.make_node("Loop", ["trip_cnt", "cond", "cond"], outputs,
354+
name="loop", body=sub_graph)
355+
return trip_cnt, cond, loop_node
356+
357+
nodes = _make_loop(["array"], ["loop_carried", "scan_out"])
358+
res = helper.make_node("Transpose", ["scan_out"], ["Y"], perm=[0, 3, 1, 2], name="trans")
359+
360+
graph = helper.make_graph(
361+
[*nodes, res],
362+
"transpose_with_loop",
363+
[helper.make_tensor_value_info("array", TensorProto.FLOAT, ["unknow"] * 4)],
364+
[helper.make_tensor_value_info("Y", TensorProto.FLOAT, ["unknow"] * 4)],
365+
)
366+
367+
model_proto = helper.make_model(graph, producer_name="onnx-tests")
368+
self.run_transpose_compare(["Y"], {"array": np.random.randn(10, 3, 4, 5).astype(np.float32)},
369+
model_proto, remaining_transpose_num=0)
370+
307371
def test_trans_output_as_graph_outputs(self):
308372
"""
309373
If transpose's output is graph's output, don't optimize it.

0 commit comments

Comments
 (0)