Skip to content

Commit 871bd29

Browse files
committed
support MatrixDiagV1&V2
1 parent d371be9 commit 871bd29

File tree

2 files changed

+25
-13
lines changed

2 files changed

+25
-13
lines changed

tests/test_backend.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3251,10 +3251,13 @@ def test_matrix_diag_v3_multi_dim(self):
32513251

32523252
def func(diag, k, row, col):
32533253
return tf.raw_ops.MatrixDiagV3(diagonal=diag, k=k, num_rows=row, num_cols=col,
3254-
padding_value=0.123, align='RIGHT_RIGHT', name=_TFOUTPUT)
3254+
padding_value=0.123, align='RIGHT_RIGHT', name=_TFOUTPUT), \
3255+
tf.raw_ops.MatrixDiagV2(diagonal=diag, k=k, num_rows=row, num_cols=col,
3256+
padding_value=0.123, name=_TFOUTPUT1)
32553257

3256-
self._run_test_case(func, [_OUTPUT], {_INPUT: diag_val, _INPUT1: k_val,
3257-
_INPUT2: row_val, _INPUT3: col_val})
3258+
self._run_test_case(func, [_OUTPUT, _OUTPUT1],
3259+
{_INPUT: diag_val, _INPUT1: k_val,
3260+
_INPUT2: row_val, _INPUT3: col_val})
32583261

32593262
@check_opset_min_version(12)
32603263
@check_tf_min_version("2.2")
@@ -3302,10 +3305,12 @@ def test_matrix_diag_v3_2single_dim_row_col(self):
33023305

33033306
def func(diag, k, row, col):
33043307
return tf.raw_ops.MatrixDiagV3(diagonal=diag, k=k, num_rows=row, num_cols=col,
3305-
padding_value=7, align='LEFT_RIGHT', name=_TFOUTPUT)
3308+
padding_value=7, align='LEFT_RIGHT', name=_TFOUTPUT), \
3309+
tf.raw_ops.MatrixDiag(diagonal=diag, name=_TFOUTPUT1)
33063310

3307-
self._run_test_case(func, [_OUTPUT], {_INPUT: diag_val, _INPUT1: k_val,
3308-
_INPUT2: row_val, _INPUT3: col_val})
3311+
self._run_test_case(func, [_OUTPUT, _OUTPUT1],
3312+
{_INPUT: diag_val, _INPUT1: k_val,
3313+
_INPUT2: row_val, _INPUT3: col_val})
33093314

33103315
@check_opset_min_version(12)
33113316
@check_tf_min_version("2.2")

tf2onnx/onnx_opset/tensor.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2207,11 +2207,13 @@ def compute_out_shape(k0_k1_same=False):
22072207
ctx.set_shape(consumer.output[0], shapes)
22082208

22092209

2210-
@tf_op("MatrixDiagV3")
2211-
class MatrixDiagV3:
2210+
@tf_op(["MatrixDiag", "MatrixDiagV2", "MatrixDiagV3"])
2211+
class MatrixDiag:
22122212
@classmethod
22132213
def version_12(cls, ctx, node, **kwargs):
22142214
# Assemble MatrixDiagV3 by ReverseSequence
2215+
argc = len(node.input)
2216+
22152217
def mkconsts(values):
22162218
return [ctx.make_const(utils.make_name('const'), \
22172219
np.array(value).astype(np.int64)).output[0] for value in values]
@@ -2230,6 +2232,9 @@ def normalize(name):
22302232
reshaped = mknode("Reshape", [casted, minus_one])
22312233
return reshaped
22322234

2235+
def cast(name):
2236+
return mknode("Cast", [name], attr={"to": ctx.get_dtype(node.input[0])})
2237+
22332238
def processdiag():
22342239
# unsqueeze diag if necessary
22352240
diag = node.input[0]
@@ -2241,7 +2246,7 @@ def processdiag():
22412246

22422247
diag_shape = mknode("Shape", [diag])
22432248
diag_depth = mknode("Slice", [diag_shape, minus_two, minus_one])
2244-
k = normalize(node.input[1])
2249+
k = normalize(node.input[1]) if argc > 1 else zeo
22452250
k_min, k_max = mknode("ReduceMin", [k]), mknode("ReduceMax", [k])
22462251
k_max_nxt = mknode("Add", [k_max, one])
22472252
k_depth = mknode("Sub", [k_max_nxt, k_min])
@@ -2272,8 +2277,10 @@ def squeeze(name):
22722277

22732278
# gather inputs
22742279
diag, k, k_min, k_max, k_max_nxt = processdiag()
2275-
row, col, pad, align = normalize(node.input[2]), normalize(node.input[3]), \
2276-
node.input[4], node.get_attr_str("align")
2280+
row, col, pad, align = normalize(node.input[2]) if argc > 2 else minus_one, \
2281+
normalize(node.input[3]) if argc > 3 else minus_one, \
2282+
node.input[4] if argc > 4 else cast(zeo), \
2283+
node.get_attr_str("align") if "align" in node.attr else "LEFT_LEFT"
22772284

22782285
diag_shape = mknode("Shape", [diag])
22792286
diag_rank = mknode("Shape", [diag_shape])
@@ -2580,12 +2587,12 @@ def normalize():
25802587
# make matrix of bool
25812588
ctx.set_dtype(ones_diag.output[0], TensorProto.INT64)
25822589
ones_matrix = ctx.make_node("MatrixDiagV3", [ones_diag.output[0], k, row, col, zeo], attr)
2583-
MatrixDiagV3.version_12(ctx, ones_matrix)
2590+
MatrixDiag.version_12(ctx, ones_matrix)
25842591
ones_bool = mknode("Equal", [ones_matrix.output[0], one])
25852592

25862593
# make matrix out of diag
25872594
diag_matrix = ctx.make_node("MatrixDiagV3", [diag, k, row, col, cast(zeo)], attr)
2588-
MatrixDiagV3.version_12(ctx, diag_matrix)
2595+
MatrixDiag.version_12(ctx, diag_matrix)
25892596

25902597
shapes = node.output_shapes
25912598
dtypes = node.output_dtypes

0 commit comments

Comments
 (0)