@@ -423,7 +423,10 @@ class GatherV2:
423
423
@classmethod
424
424
def version_1 (cls , ctx , node , ** kwargs ):
425
425
# for GatherV2 axis come as input
426
+ err_msg = "Opset 12 required for batch_dims attribute of GatherV2"
427
+ utils .make_sure (node .get_attr_value ("batch_dims" , 0 ) == 0 , err_msg )
426
428
node .type = "Gather"
429
+ utils .make_sure (node .inputs [2 ].is_const (), "Axis of GatherV2 node must be constant" )
427
430
axis = node .inputs [2 ].get_tensor_value ()
428
431
ctx .remove_input (node , node .input [2 ], 2 )
429
432
node .set_attr ("axis" , axis )
@@ -433,6 +436,42 @@ def version_11(cls, ctx, node, **kwargs):
433
436
# no change
434
437
cls .version_1 (ctx , node , ** kwargs )
435
438
439
+ @classmethod
440
+ def version_12 (cls , ctx , node , ** kwargs ):
441
+ batch_dims = node .get_attr_value ("batch_dims" , 0 )
442
+ if batch_dims == 0 :
443
+ cls .version_1 (ctx , node , ** kwargs )
444
+ return
445
+ # If batch_dims is not zero, use GatherND to simulate Gather with batch dims.
446
+ data_inp , indices_inp , axis_inp = node .input
447
+ utils .make_sure (node .inputs [2 ].is_const (), "Axis of GatherV2 node must be constant" )
448
+ axis = node .inputs [2 ].get_tensor_value ()
449
+ ctx .remove_input (node , axis_inp , 2 )
450
+ if ctx .get_dtype (indices_inp ) != TensorProto .INT64 :
451
+ indices_inp = ctx .make_node ("Cast" , [indices_inp ], attr = {'to' : TensorProto .INT64 }).output [0 ]
452
+ unperm = None
453
+ # GatherND doesn't take an axis so we have to transpose stuff around
454
+ if axis != batch_dims :
455
+ data_rank = ctx .get_rank (data_inp )
456
+ indices_rank = ctx .get_rank (indices_inp )
457
+ result_rank = data_rank + indices_rank - 1 - batch_dims
458
+ shift_amt = axis - batch_dims
459
+ err_msg = "Cannot convert GatherV2 with batch dims since inputs have unknown ranks."
460
+ utils .make_sure (data_rank is not None and indices_rank is not None , err_msg )
461
+ perm = list (range (data_rank ))
462
+ perm = perm [:batch_dims ] + perm [axis :axis + 1 ] + perm [batch_dims :axis ] + perm [axis + 1 :]
463
+ data_inp = ctx .make_node ("Transpose" , [data_inp ], attr = {'perm' : perm }).output [0 ]
464
+ ctx .replace_input (node , node .input [0 ], data_inp , 0 )
465
+ unperm = list (range (result_rank ))
466
+ j = indices_rank + shift_amt
467
+ unperm = unperm [:batch_dims ] + unperm [indices_rank :j ] + unperm [batch_dims :indices_rank ] + unperm [j :]
468
+ node .type = "GatherND"
469
+ unsqueeze_node = GraphBuilder (ctx ).make_unsqueeze ({'data' : indices_inp , 'axes' : [- 1 ]})
470
+ ctx .replace_input (node , node .input [1 ], unsqueeze_node , 1 )
471
+ if unperm is not None :
472
+ ctx .update_node_shape_dtype (node , override = True )
473
+ ctx .insert_new_node_on_output ("Transpose" , node .output [0 ], perm = unperm )
474
+
436
475
437
476
def _make_gathernd_inner_loop (ctx , params , index , dtype ):
438
477
"""create the inner loop for GatherNd."""
0 commit comments