@@ -50,6 +50,21 @@ def run_and_compare(self, output_names_with_port, onnx_feed_dict, origin_proto,
50
50
self .assertEqual (expected_val .shape , actual_val .shape )
51
51
52
52
return new_proto
53
+
54
@staticmethod
def _make_onnx_const(np_val, output_name):
    """Build an ONNX ``Constant`` node whose single output carries ``np_val``.

    Args:
        np_val: numpy array supplying the constant's dtype, shape and data.
        output_name: name given both to the node's only output and to the
            tensor stored in the node's ``value`` attribute.

    Returns:
        The node produced by ``helper.make_node('Constant', ...)``.
    """
    # The tensor payload is built first; its ONNX data type comes from
    # utils.map_numpy_to_onnx_dtype applied to the array's numpy dtype.
    tensor_value = helper.make_tensor(
        name=output_name,
        data_type=utils.map_numpy_to_onnx_dtype(np_val.dtype),
        dims=np_val.shape,
        vals=np_val.flatten().astype(np_val.dtype),
    )
    # A Constant node takes no inputs; the tensor rides along as an attribute.
    return helper.make_node(
        'Constant',
        inputs=[],
        outputs=[output_name],
        value=tensor_value,
    )
53
68
# Transpose Optimizer Tests Start
54
69
55
70
def run_transpose_compare (self , output_names_with_port , onnx_feed_dict , origin_proto ,
@@ -304,6 +319,55 @@ def test_transpose_with_squeeze4(self):
304
319
self .run_transpose_compare (["Z" ], {"X" : np .random .randn (3 , 1 , 1 , 5 ).astype (np .float32 )},
305
320
model_proto , remaining_transpose_num = 0 )
306
321
322
def test_transpose_with_loop(self):
    """Run the transpose comparison on a graph whose Loop body contains a Transpose.

    The Loop body gathers from the outer-graph input "array", transposes the
    result with perm [0, 2, 3, 1], squeezes axis 0 and emits it as the scan
    output.  The outer graph then transposes the scan output with
    perm [0, 3, 1, 2].  The model is handed to ``run_transpose_compare`` with
    ``remaining_transpose_num=0``.
    """
    def _define_loop_graph(external_inputs):
        # external_inputs: names of outer-graph values used by this subgraph.
        # The body has no meaningful loop-carried state.
        # Computation:
        # for(...){a = external_inputs[i]; b = trans(a), c = squeeze(b)}, c is scan output
        node1 = helper.make_node("Gather", [external_inputs[0], "loop_iter_num"], ["Y0"])
        node2 = helper.make_node("Transpose", ["Y0"], ["Z0"], perm=[0, 2, 3, 1])
        # graph output
        node3 = helper.make_node("Squeeze", ["Z0"], ["scan_output"], axes=[0])
        # Both the condition output and the loop-carried output simply forward
        # the incoming loop condition; the "loop_carried" input itself is unused.
        node4 = helper.make_node("Identity", ["loop_condition"], ["loop_cond_output"])
        node5 = helper.make_node("Identity", ["loop_condition"], ["loop_carried_output"])

        graph = helper.make_graph(
            [node1, node2, node3, node4, node5],
            "loop_subgraph",
            [helper.make_tensor_value_info("loop_iter_num", TensorProto.INT64, (1,)),  # iteration_num
             helper.make_tensor_value_info("loop_condition", TensorProto.BOOL, ()),  # condition
             helper.make_tensor_value_info("loop_carried", TensorProto.BOOL, ())  # loop_carried
             ],
            [helper.make_tensor_value_info("loop_cond_output", TensorProto.BOOL, ()),
             helper.make_tensor_value_info("loop_carried_output", TensorProto.BOOL, ()),
             helper.make_tensor_value_info("scan_output", TensorProto.FLOAT, ["unknown"] * 3)
             ],
        )
        return graph

    def _make_loop(external_inputs, outputs):
        # Trip count of 10 matches the leading dimension of the fed "array".
        trip_cnt = self._make_onnx_const(np.array(10, dtype=np.int64), "trip_cnt")
        # np.bool_ rather than the np.bool alias: the alias was deprecated in
        # NumPy 1.20 and removed in 1.24, where it raises AttributeError.
        cond = self._make_onnx_const(np.array(True, dtype=np.bool_), "cond")
        sub_graph = _define_loop_graph(external_inputs)
        # "cond" is passed twice: once as the loop condition and once as the
        # initial value of the (boolean) loop-carried variable.
        loop_node = helper.make_node("Loop", ["trip_cnt", "cond", "cond"], outputs,
                                     name="loop", body=sub_graph)
        return trip_cnt, cond, loop_node

    nodes = _make_loop(["array"], ["loop_carried", "scan_out"])
    res = helper.make_node("Transpose", ["scan_out"], ["Y"], perm=[0, 3, 1, 2], name="trans")

    graph = helper.make_graph(
        [*nodes, res],
        "transpose_with_loop",
        [helper.make_tensor_value_info("array", TensorProto.FLOAT, ["unknow"] * 4)],
        [helper.make_tensor_value_info("Y", TensorProto.FLOAT, ["unknow"] * 4)],
    )

    model_proto = helper.make_model(graph, producer_name="onnx-tests")
    self.run_transpose_compare(["Y"], {"array": np.random.randn(10, 3, 4, 5).astype(np.float32)},
                               model_proto, remaining_transpose_num=0)
370
+
307
371
def test_trans_output_as_graph_outputs (self ):
308
372
"""
309
373
If transpose's output is graph's output, don't optimize it.
0 commit comments