Skip to content

Commit c23ef70

Browse files
Add support for GatherV2 batch_dims attr (#1329)
* Add support for GatherV2 batch_dims attr Signed-off-by: Tom Wildenhain <[email protected]> * Fix tests Signed-off-by: Tom Wildenhain <[email protected]>
1 parent fea121d commit c23ef70

File tree

2 files changed

+59
-0
lines changed

2 files changed

+59
-0
lines changed

tests/test_backend.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -871,6 +871,26 @@ def func(x):
871871
return tf.identity(x_, name=_TFOUTPUT)
872872
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
873873

874+
@check_tf_min_version("1.14")
875+
@check_opset_min_version(12, "GatherND with batch_dims")
876+
def test_gather_batch_dims_no_trans(self):
877+
x_val = np.arange(2 * 2 * 3 * 5 * 4, dtype=np.float32).reshape((2, 2, 3, 5, 4))
878+
idx_val = np.array([[[1, 0, 2, 0], [1, 1, 1, 0]], [[0, 0, 0, 0], [2, 1, 1, 0]]], dtype=np.int32)
879+
def func(x, idx):
880+
x_ = tf.gather(x, idx, batch_dims=2, axis=2)
881+
return tf.identity(x_, name=_TFOUTPUT)
882+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: idx_val})
883+
884+
@check_tf_min_version("1.14")
885+
@check_opset_min_version(12, "GatherND with batch_dims")
886+
def test_gather_batch_dims(self):
887+
x_val = np.arange(2 * 2 * 3 * 5 * 4, dtype=np.float32).reshape((2, 2, 3, 5, 4))
888+
idx_val = np.array([[[1, 0, 2, 0], [1, 1, 1, 0]], [[0, 0, 0, 0], [2, 1, 1, 0]]], dtype=np.int32)
889+
def func(x, idx):
890+
x_ = tf.gather(x, idx, batch_dims=2, axis=3)
891+
return tf.identity(x_, name=_TFOUTPUT)
892+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: idx_val})
893+
874894
@check_opset_min_version(10, "Slice")
875895
def test_roll_axis_scalar(self):
876896
x_val = np.arange(4 * 3 * 5 * 2, dtype=np.float32).reshape((4, 3, 5, 2))

tf2onnx/onnx_opset/tensor.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,7 +423,10 @@ class GatherV2:
423423
@classmethod
424424
def version_1(cls, ctx, node, **kwargs):
425425
# for GatherV2 axis come as input
426+
err_msg = "Opset 12 required for batch_dims attribute of GatherV2"
427+
utils.make_sure(node.get_attr_value("batch_dims", 0) == 0, err_msg)
426428
node.type = "Gather"
429+
utils.make_sure(node.inputs[2].is_const(), "Axis of GatherV2 node must be constant")
427430
axis = node.inputs[2].get_tensor_value()
428431
ctx.remove_input(node, node.input[2], 2)
429432
node.set_attr("axis", axis)
@@ -433,6 +436,42 @@ def version_11(cls, ctx, node, **kwargs):
433436
# no change
434437
cls.version_1(ctx, node, **kwargs)
435438

439+
@classmethod
440+
def version_12(cls, ctx, node, **kwargs):
441+
batch_dims = node.get_attr_value("batch_dims", 0)
442+
if batch_dims == 0:
443+
cls.version_1(ctx, node, **kwargs)
444+
return
445+
# If batch_dims is not zero, use GatherND to simulate Gather with batch dims.
446+
data_inp, indices_inp, axis_inp = node.input
447+
utils.make_sure(node.inputs[2].is_const(), "Axis of GatherV2 node must be constant")
448+
axis = node.inputs[2].get_tensor_value()
449+
ctx.remove_input(node, axis_inp, 2)
450+
if ctx.get_dtype(indices_inp) != TensorProto.INT64:
451+
indices_inp = ctx.make_node("Cast", [indices_inp], attr={'to': TensorProto.INT64}).output[0]
452+
unperm = None
453+
# GatherND doesn't take an axis so we have to transpose stuff around
454+
if axis != batch_dims:
455+
data_rank = ctx.get_rank(data_inp)
456+
indices_rank = ctx.get_rank(indices_inp)
457+
result_rank = data_rank + indices_rank - 1 - batch_dims
458+
shift_amt = axis - batch_dims
459+
err_msg = "Cannot convert GatherV2 with batch dims since inputs have unknown ranks."
460+
utils.make_sure(data_rank is not None and indices_rank is not None, err_msg)
461+
perm = list(range(data_rank))
462+
perm = perm[:batch_dims] + perm[axis:axis+1] + perm[batch_dims:axis] + perm[axis+1:]
463+
data_inp = ctx.make_node("Transpose", [data_inp], attr={'perm': perm}).output[0]
464+
ctx.replace_input(node, node.input[0], data_inp, 0)
465+
unperm = list(range(result_rank))
466+
j = indices_rank+shift_amt
467+
unperm = unperm[:batch_dims] + unperm[indices_rank:j] + unperm[batch_dims:indices_rank] + unperm[j:]
468+
node.type = "GatherND"
469+
unsqueeze_node = GraphBuilder(ctx).make_unsqueeze({'data': indices_inp, 'axes': [-1]})
470+
ctx.replace_input(node, node.input[1], unsqueeze_node, 1)
471+
if unperm is not None:
472+
ctx.update_node_shape_dtype(node, override=True)
473+
ctx.insert_new_node_on_output("Transpose", node.output[0], perm=unperm)
474+
436475

437476
def _make_gathernd_inner_loop(ctx, params, index, dtype):
438477
"""create the inner loop for GatherNd."""

0 commit comments

Comments
 (0)