Skip to content

Commit 91d0d78

Browse files
Merge pull request #955 from RandySheriffH/rashuai/tf22
Enable tf22 ci
2 parents 4117092 + bf05899 commit 91d0d78

File tree

4 files changed

+52
-13
lines changed

4 files changed

+52
-13
lines changed

ci_build/azure_pipelines/unit_test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ stages:
66
- template: 'templates/job_generator.yml'
77
parameters:
88
python_versions: ['3.7']
9-
tf_versions: ['1.14.0','1.15.2','2.1.0']
9+
tf_versions: ['1.14.0','1.15.2','2.1.0','2.2.0']
1010
onnx_opsets: ['']
1111
job:
1212
steps:

tests/test_backend.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1384,11 +1384,13 @@ def func():
13841384
self._run_test_case(func, [_OUTPUT], {}, check_value=False, check_shape=True)
13851385

13861386
@skip_caffe2_backend()
1387+
@check_opset_after_tf_version("2.2", 9, "RandomUniform")
13871388
def test_randomuniform_dyn_shape(self):
13881389
# test for dynamic shape coming from a shape op
13891390
x_val = np.array([0, 1, 2, 3, 5], dtype=np.int64)
13901391
def func(x):
1391-
return random_uniform(x[3:], name=_TFOUTPUT, dtype=tf.float32)
1392+
ret = random_uniform(x[3:], dtype=tf.float32)
1393+
return tf.identity(ret, name=_TFOUTPUT)
13921394
# since results are random, compare the shapes only
13931395
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, check_value=False, check_shape=True)
13941396

@@ -3338,7 +3340,7 @@ def test_matrix_set_diag_v3(self):
33383340
[7, 7, 7, 7]]]).astype(np.int64)
33393341
diag_val = np.array([[1, 2, 3],
33403342
[4, 5, 6]]).astype(np.int64)
3341-
k_val = np.array([0])
3343+
k_val = np.array([0]).astype(np.int32)
33423344

33433345
def func(base_matrix, diag, k):
33443346
return tf.raw_ops.MatrixSetDiagV3(input=base_matrix, diagonal=diag, k=k, align='RIGHT_LEFT', name=_TFOUTPUT)

tf2onnx/onnx_opset/tensor.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1835,8 +1835,16 @@ class MatrixDiagPartV2V3:
18351835
@classmethod
18361836
def version_11(cls, ctx, node, **kwargs):
18371837
# assemble MatrixDiagPart V2&V3 by looping k diagonals with proper pads
1838+
const_zero = ctx.make_const(utils.make_name(node.name) + 'const_zero', np.array([0]).astype(np.int64))
1839+
const_one = ctx.make_const(utils.make_name(node.name) + 'const_one', np.array([1]).astype(np.int64))
1840+
const_two = ctx.make_const(utils.make_name(node.name) + 'const_two', np.array([2]).astype(np.int64))
1841+
const_neg_one = ctx.make_const(utils.make_name(node.name) + 'const_neg_one', np.array([-1]).astype(np.int64))
1842+
const_neg_two = ctx.make_const(utils.make_name(node.name) + 'const_neg_two', np.array([-2]).astype(np.int64))
1843+
def normalize():
1844+
raw_k = ctx.make_node('Cast', [node.input[1]], attr={'to': TensorProto.INT64}).output[0]
1845+
return ctx.make_node('Reshape', [raw_k, const_neg_one.output[0]]).output[0]
18381846
input_tensor = node.input[0]
1839-
k = ctx.make_node('Cast', [node.input[1]], attr={'to': TensorProto.INT64}).output[0]
1847+
k = normalize()
18401848
padding = node.input[2]
18411849
align = 'LEFT_LEFT'
18421850
if node.op.op_type == 'MatrixDiagPartV3':
@@ -1850,12 +1858,7 @@ def version_11(cls, ctx, node, **kwargs):
18501858
for out in ctx.find_output_consumers(node.output[0]):
18511859
if out.op.op_type == 'Identity':
18521860
ctx.set_shape(out.output[0], raw_output_shape)
1853-
# define constants
1854-
const_zero = ctx.make_const(utils.make_name(node.name) + 'const_zero', np.array([0]).astype(np.int64))
1855-
const_one = ctx.make_const(utils.make_name(node.name) + 'const_one', np.array([1]).astype(np.int64))
1856-
const_two = ctx.make_const(utils.make_name(node.name) + 'const_two', np.array([2]).astype(np.int64))
1857-
const_neg_one = ctx.make_const(utils.make_name(node.name) + 'const_neg_one', np.array([-1]).astype(np.int64))
1858-
const_neg_two = ctx.make_const(utils.make_name(node.name) + 'const_neg_two', np.array([-2]).astype(np.int64))
1861+
18591862
# prepare new_shape of input
18601863
input_shape = ctx.make_node('Shape', [input_tensor])
18611864
shape_input_shape = ctx.make_node('Shape', [input_shape.output[0]])
@@ -2075,7 +2078,7 @@ def mkconsts(values, dtype=np.int64):
20752078
xalign, yalign = align.split('_')
20762079

20772080
# consts
2078-
const_zero_float, const_neg_one_float = mkconsts([[0], [-1]], np.float32)
2081+
const_zero_float, const_neg_one_float = mkconsts([0, -1], np.float32)
20792082
const_zero, const_one, const_neg_one, const_neg_two, const_pad_vals, const_t = \
20802083
mkconsts([[0], [1], [-1], [-2], pads, [-1, 1]])
20812084
const_zero_scalar, const_one_scalar, const_neg_one_scalar = mkconsts([0, 1, -1])
@@ -2098,8 +2101,12 @@ def mkconsts(values, dtype=np.int64):
20982101
m2_shape = ctx.make_node('Concat', [partial_shape, const_neg_one], attr={'axis': 0}).output[0]
20992102
gather_shape = ctx.make_node('Concat', [partial_shape, const_one], attr={'axis': 0}).output[0]
21002103

2104+
def normalize():
2105+
raw_input1 = ctx.make_node('Cast', [node.input[1]], attr={'to': TensorProto.INT64}).output[0]
2106+
return ctx.make_node('Reshape', [raw_input1, const_neg_one])
2107+
21012108
# get k0, k1 values. diags to be extracted
2102-
input1 = ctx.make_node('Cast', [node.input[1]], attr={'to': TensorProto.INT64})
2109+
input1 = normalize()
21032110
k0 = ctx.make_node('ReduceMin', [input1.output[0]]).output[0]
21042111
k1 = ctx.make_node('ReduceMax', [input1.output[0]]).output[0]
21052112
k0_scalar = ctx.make_node('Squeeze', [k0]).output[0]

tf2onnx/rewriter/eye_rewriter.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,38 @@ def rewrite_eye(g, ops):
7979
OpTypePattern("Const", name="fill_value"),
8080
]),
8181
])
82+
pattern5 = \
83+
OpTypePattern("MatrixDiagV3", name="output_eye_matrix", inputs=[
84+
OpTypePattern("Fill", inputs=[
85+
OpTypePattern("ConcatV2", inputs=[
86+
"*",
87+
OpTypePattern("ExpandDims", inputs=[
88+
OpTypePattern("Minimum|Cast", name="min_or_cast"),
89+
"*"
90+
]),
91+
"*",
92+
]),
93+
OpTypePattern("Const", name="fill_value"),
94+
]),
95+
"*", "*", "*", "*",
96+
])
97+
pattern6 = \
98+
OpTypePattern("MatrixSetDiagV3", name="output_eye_matrix", inputs=[
99+
OpTypePattern("Fill"),
100+
OpTypePattern("Fill", inputs=[
101+
OpTypePattern("ConcatV2", inputs=[
102+
"*",
103+
OpTypePattern("ExpandDims", inputs=[
104+
OpTypePattern("Minimum|Cast", name="min_or_cast"),
105+
"*"
106+
]),
107+
"*",
108+
]),
109+
OpTypePattern("Const", name="fill_value"),
110+
]), "*"
111+
])
82112

83-
for pattern in [pattern1, pattern2, pattern3, pattern4]:
113+
for pattern in [pattern1, pattern2, pattern3, pattern4, pattern5, pattern6]:
84114
matcher = GraphMatcher(pattern, allow_reorder=True)
85115
match_results = list(matcher.match_ops(ops))
86116
for match_result in match_results:

0 commit comments

Comments
 (0)