@@ -1291,3 +1291,157 @@ def version_10(cls, ctx, node, **kwargs):
1291
1291
target_dtype = TensorProto .INT64
1292
1292
if seq_len_dtype != target_dtype :
1293
1293
ctx .insert_new_node_on_input (node , "Cast" , node .input [1 ], to = target_dtype )
1294
+
1295
+
1296
+ @tf_op ("ReverseV2" )
1297
+ class ReverseV2 :
1298
+ @classmethod
1299
+ def version_10 (cls , ctx , node , ** kwargs ):
1300
+ # T output = ReverseV2(T input, int32|int64 seq_lengths, @int seq_dim, @int batch_dim)
1301
+ # Implement tensorflow ReverseV2 op using multiple ReverseSequence (for each axis)
1302
+ # and Transpose ops. We sort the axis vector (if non-empty) at the start. Each axis can
1303
+ # be reversed only once (in tf) and so we can compute the transpose for each axis
1304
+ # (other than 0), feed the tensor to a ReverseSequence node and finally transpose again
1305
+ # to get back the original shape.
1306
+
1307
+ axes_node = node .inputs [1 ]
1308
+ axes = axes_node .get_tensor_value (as_list = False )
1309
+ # Current support is for when axis is a 1D tensor.
1310
+ utils .make_sure (len (axes .shape ) == 1 \
1311
+ , "Currently no support for reverseV2 tensor axis" )
1312
+
1313
+ axes = axes .tolist ()
1314
+ len_axes = len (axes )
1315
+
1316
+ # Store input and output parameters of the ReverseV2 node.
1317
+ rv2_in_names = [node .input [0 ]]
1318
+
1319
+ input_shape = ctx .get_shape (node .input [0 ])
1320
+ # Make sure input shape is not None
1321
+ utils .make_sure (input_shape is not None , "shape of {} is None" .format (node .input [0 ]))
1322
+
1323
+ input_rank = len (input_shape )
1324
+
1325
+ rv2_node_name = node .name
1326
+ # ReverseV2 has a single output.
1327
+ rv2_output_dtypes = node .output_dtypes
1328
+ rv2_output_shapes = node .output_shapes
1329
+
1330
+ const_name_root = rv2_node_name + '_Const'
1331
+
1332
+ # Remove ReverseV2 node from graph.
1333
+ ctx .remove_node (rv2_node_name )
1334
+
1335
+ # Variable to store input names for the next node.
1336
+ inputs = rv2_in_names
1337
+
1338
+ new_node = None
1339
+
1340
+ # Empty axis vector.
1341
+ if len_axes == 0 :
1342
+ # Replace ReverseV2 with an identity block.
1343
+ new_node = ctx .make_node (
1344
+ "Identity" ,
1345
+ inputs = inputs ,
1346
+ outputs = node .output ,
1347
+ shapes = rv2_output_shapes ,
1348
+ dtypes = rv2_output_dtypes ,
1349
+ op_name_scope = rv2_node_name ,
1350
+ )
1351
+
1352
+ else :
1353
+ # For negative indices use the positive counterpart.
1354
+ for i , ax in enumerate (axes ):
1355
+ if ax < 0 :
1356
+ axes [i ] += input_rank
1357
+
1358
+ axes = sorted (axes )
1359
+
1360
+ orig_perm = list (range (input_rank ))
1361
+ curr_perm = []
1362
+
1363
+ # Add ReverseSequence nodes for each element of axis.
1364
+ for i in range (len_axes ):
1365
+
1366
+ axis = axes [i ]
1367
+
1368
+ curr_perm = orig_perm .copy ()
1369
+ # Permutation indices relative to original tensor.
1370
+ curr_perm [axis ], curr_perm [0 ] = curr_perm [0 ], curr_perm [axis ]
1371
+
1372
+ # Add a Transpose node if the axis != 0 (finish first due to sort).
1373
+ if axis != 0 :
1374
+ # Permutation indices for the transpose node relative to IN tensor shape.
1375
+ new_node = ctx .make_node (
1376
+ "Transpose" ,
1377
+ inputs = inputs ,
1378
+ op_name_scope = rv2_node_name ,
1379
+ dtypes = rv2_output_dtypes ,
1380
+ attr = {"perm" : curr_perm }
1381
+ )
1382
+
1383
+ inputs = [new_node .output [0 ]]
1384
+
1385
+ # Add a Constant node (seq_len) for ReverseSequence.
1386
+
1387
+ # Index 1 for the shape should not return 0
1388
+ # since the input must have rank >= 2.
1389
+ rs_batch_size = ctx .get_shape (inputs [- 1 ])[1 ]
1390
+
1391
+ # Make sure rs_batch_size and input_shape[axis] are not -1 each
1392
+ utils .make_sure (input_shape [axis ] is not - 1 \
1393
+ , "shape of axis {} is unknown" .format (axis ))
1394
+ utils .make_sure (rs_batch_size is not - 1 \
1395
+ , "ReverseSequence batch size for axis {} is unknown" .format (axis ))
1396
+
1397
+ seq_list = [input_shape [axis ]] * rs_batch_size
1398
+ seq_array = np .asarray (seq_list , dtype = np .int64 ) # dtype should be int64
1399
+
1400
+ const_seq_name = utils .make_name (const_name_root )
1401
+ new_node = ctx .make_const (name = const_seq_name , np_val = seq_array )
1402
+ inputs .append (new_node .output [0 ])
1403
+
1404
+ # Add a ReverseSequence node.
1405
+
1406
+ # If processing for the final axis and the tensor shape permutation is
1407
+ # original then the output is fed to the output of the ReverseV2 node.
1408
+ #
1409
+ # Else a new output is created which is fed to a Transpose node.
1410
+ rs_out_name = node .output if \
1411
+ ((i == len_axes - 1 ) and (curr_perm == orig_perm )) \
1412
+ else None
1413
+
1414
+ rs_out_shapes = None if rs_out_name is None else rv2_output_shapes
1415
+
1416
+ new_node = ctx .make_node (
1417
+ "ReverseSequence" ,
1418
+ inputs = inputs ,
1419
+ op_name_scope = rv2_node_name ,
1420
+ outputs = rs_out_name ,
1421
+ shapes = rs_out_shapes ,
1422
+ dtypes = rv2_output_dtypes ,
1423
+ attr = {"batch_axis" : 1 , "time_axis" : 0 }
1424
+ )
1425
+
1426
+ inputs = [new_node .output [0 ]]
1427
+
1428
+ # Additional transpose block is required if the current
1429
+ # permutation list is not the original one.
1430
+ if curr_perm != orig_perm :
1431
+
1432
+ # Compute the required permutation list.
1433
+ if len_axes != 1 :
1434
+ for i , ax in enumerate (axes [::- 1 ][1 :]):
1435
+ curr_perm [0 ], curr_perm [ax ] = \
1436
+ curr_perm [ax ], curr_perm [0 ]
1437
+
1438
+ # Add a Transpose node to restore shape.
1439
+ new_node = ctx .make_node (
1440
+ "Transpose" ,
1441
+ inputs = inputs ,
1442
+ op_name_scope = rv2_node_name ,
1443
+ outputs = node .output ,
1444
+ shapes = rv2_output_shapes ,
1445
+ dtypes = rv2_output_dtypes ,
1446
+ attr = {"perm" : curr_perm }
1447
+ )
0 commit comments