Skip to content

Commit f4b13be

Browse files
authored
Merge pull request #774 from RandySheriffH/rashuai/xlnet
Rashuai/xlnet
2 parents 0aa0e89 + f30205e commit f4b13be

File tree

2 files changed

+27
-8
lines changed

2 files changed

+27
-8
lines changed

tests/test_backend.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1828,6 +1828,16 @@ def test_strided_slice_dynamic_7(self):
18281828
_ = tf.identity(x_, name=_TFOUTPUT)
18291829
self._run_test_case([_OUTPUT], {_INPUT: x_val, _INPUT1: y_val})
18301830

1831+
@check_opset_min_version(10, "Slice")
1832+
def test_new_axis_mask(self):
1833+
x_val = np.arange(5*10*10*10*10*20*30).astype("float32").reshape((5, 10, 10, 10, 10, 20, 30))
1834+
y_val = np.array(9, dtype=np.int32)
1835+
x = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
1836+
y = tf.placeholder(tf.int32, y_val.shape, name=_TFINPUT1)
1837+
x_ = x[tf.newaxis, 0:y, y::2, tf.newaxis, :, tf.newaxis, :y, tf.newaxis, ..., 9]
1838+
_ = tf.identity(x_, name=_TFOUTPUT)
1839+
self._run_test_case([_OUTPUT], {_INPUT: x_val, _INPUT1: y_val})
1840+
18311841
@skip_caffe2_backend("fails with schema error")
18321842
@check_opset_min_version(7, "batchnorm")
18331843
def test_batchnorm(self):

tf2onnx/onnx_opset/tensor.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -738,19 +738,19 @@ def version_10(cls, ctx, node, **kwargs):
738738
# @int shrink_axis_mask, @int new_axis_mask)
739739
# T output = Slice(T input, Tind starts, Tind ends, Tind axes, Tind steps)
740740
# "ends" are exclusive, "axes" and "steps" are optional, their default val are [0, ...] and 1
741+
input_x = node.inputs[0]
741742
begin = node.inputs[1]
742743
end = node.inputs[2]
743744
strides = node.inputs[3]
745+
new_axis_mask = node.get_attr("new_axis_mask")
746+
new_axis_mask = new_axis_mask.i if new_axis_mask is not None else 0
747+
744748
if begin.is_const() and end.is_const() and strides.is_const() \
745-
and all(val == 1 for val in strides.get_tensor_value()):
749+
and all(val == 1 for val in strides.get_tensor_value()) \
750+
and new_axis_mask == 0:
746751
cls.version_1(ctx, node, **kwargs)
747752
return
748753

749-
not_supported_attr = ["new_axis_mask"]
750-
for attr_name in not_supported_attr:
751-
attr = node.get_attr(attr_name)
752-
if attr is not None and attr.i != 0:
753-
raise ValueError("StridedSlice: attribute " + attr_name + " not supported")
754754
onnx_dtype = ctx.get_dtype(node.input[1])
755755
np_dtype = utils.ONNX_TO_NUMPY_DTYPE[onnx_dtype]
756756

@@ -769,6 +769,15 @@ def version_10(cls, ctx, node, **kwargs):
769769
ellipsis_mask = ellipsis_mask.i if ellipsis_mask is not None else 0
770770
shrink_axis_mask = node.get_attr("shrink_axis_mask")
771771
shrink_axis_mask = shrink_axis_mask.i if shrink_axis_mask is not None else 0
772+
if new_axis_mask != 0:
773+
unqueeze_at = []
774+
for bit in range(32):
775+
if (new_axis_mask >> bit) & 1 == 1:
776+
unqueeze_at.append(bit)
777+
begin_mask |= 1 << bit
778+
end_mask |= 1 << bit
779+
input_x = ctx.make_node("Unsqueeze", [input_x.output[0]], {"axes": unqueeze_at})
780+
772781
param_shape = ctx.get_shape(node.input[1]) or \
773782
ctx.get_shape(node.input[2]) or \
774783
ctx.get_shape(node.input[3])
@@ -789,7 +798,7 @@ def version_10(cls, ctx, node, **kwargs):
789798
ellipsis_gap = 0
790799
for idx in range(param_rank):
791800
if (ellipsis_mask >> idx) & 1:
792-
input_shape = ctx.get_shape(node.input[0])
801+
input_shape = ctx.get_shape(input_x.output[0])
793802
utils.make_sure(
794803
input_shape is not None,
795804
"StridedSlice op {} requires the shape of input".format(node.name)
@@ -886,7 +895,7 @@ def version_10(cls, ctx, node, **kwargs):
886895
axes_output = axes_const.output[0]
887896

888897
inputs_map = {
889-
"data": node.input[0],
898+
"data": input_x.output[0],
890899
"starts": begin.output[0],
891900
"ends": end_output,
892901
"steps": strides_output,

0 commit comments

Comments
 (0)