Skip to content

Commit 4117092

Browse files
Merge pull request #935 from RandySheriffH/rashuai/FixMatrixDiagV3
Rashuai/MatrixDiagV1&V2&V3&MatrixSetDiagV3
2 parents 4a3c491 + 0aab188 commit 4117092

File tree

2 files changed

+724
-4
lines changed

2 files changed

+724
-4
lines changed

tests/test_backend.py

Lines changed: 112 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3182,10 +3182,10 @@ def test_matrix_diag_part_v3(self):
31823182

31833183
def func(X, K):
31843184
v2 = tf.raw_ops.MatrixDiagPartV2(input=X, k=K, padding_value=0.123, name=_TFOUTPUT)
3185-
v3 = tf.raw_ops.MatrixDiagPartV3(input=X, k=K, padding_value=0.123, align='RIGHT_LEFT', name=_TFOUTPUT1)
3185+
v3 = tf.raw_ops.MatrixDiagPartV3(input=X, k=K, padding_value=0.123, align='LEFT_RIGHT', name=_TFOUTPUT1)
31863186
return v2, v3
31873187

3188-
for x_shape in ([4, 5], [2, 3, 4, 5]):
3188+
for x_shape in ([4, 5], [2, 3, 4, 5], [5, 4], [7, 5]):
31893189
x_val = np.random.random(x_shape).astype(np.float32)
31903190
for raw_k in ([0], [1], [3], [-1], [-3], [1, 2], [-2, -1], [-1, 1]):
31913191
k_val = np.array(raw_k).astype(np.int32)
@@ -3235,6 +3235,116 @@ def func(x):
32353235
return tf.identity(y, name=_TFOUTPUT)
32363236
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
32373237

3238+
@check_opset_min_version(12)
3239+
@check_tf_min_version("2.2")
3240+
def test_matrix_diag_v3_multi_dim(self):
3241+
raw_diag = [[[1.0, 2.0, 3.0],
3242+
[4.0, 5.0, 6.0],
3243+
[7.0, 8.0, 9.0]],
3244+
[[10.0, 11.0, 12.0],
3245+
[13.0, 14.0, 15.0],
3246+
[16.0, 17.0, 18.0]]]
3247+
diag_val = np.array(raw_diag).astype(np.float32)
3248+
k_val = np.array([-1, 1]).astype(np.int32)
3249+
row_val = np.array(-1).astype(np.int32)
3250+
col_val = np.array(-1).astype(np.int32)
3251+
3252+
def func(diag, k, row, col):
3253+
return tf.raw_ops.MatrixDiagV3(diagonal=diag, k=k, num_rows=row, num_cols=col,
3254+
padding_value=0.123, align='RIGHT_RIGHT', name=_TFOUTPUT), \
3255+
tf.raw_ops.MatrixDiagV2(diagonal=diag, k=k, num_rows=row, num_cols=col,
3256+
padding_value=0.123, name=_TFOUTPUT1)
3257+
3258+
self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: diag_val, _INPUT1: k_val,
3259+
_INPUT2: row_val, _INPUT3: col_val})
3260+
3261+
@check_opset_min_version(12)
3262+
@check_tf_min_version("2.2")
3263+
def test_matrix_diag_v3_multi_dim_min_row(self):
3264+
raw_diag = [[[1.0, 2.0, 3.0],
3265+
[4.0, 5.0, 6.0]],
3266+
[[7.0, 8.0, 9.0],
3267+
[10.0, 11.0, 12.0]]]
3268+
diag_val = np.array(raw_diag).astype(np.float32)
3269+
k_val = np.array([2, 3]).astype(np.int32)
3270+
row_val = np.array(-1).astype(np.int32)
3271+
col_val = np.array(6).astype(np.int32)
3272+
3273+
def func(diag, k, row, col):
3274+
return tf.raw_ops.MatrixDiagV3(diagonal=diag, k=k, num_rows=row, num_cols=col,
3275+
padding_value=0.456, align='LEFT_LEFT', name=_TFOUTPUT)
3276+
3277+
self._run_test_case(func, [_OUTPUT], {_INPUT: diag_val, _INPUT1: k_val,
3278+
_INPUT2: row_val, _INPUT3: col_val})
3279+
3280+
@check_opset_min_version(12)
3281+
@check_tf_min_version("2.2")
3282+
def test_matrix_diag_v3_single_dim_min_col(self):
3283+
raw_diag = [1.0, 2.0, 3.0]
3284+
diag_val = np.array(raw_diag).astype(np.float32)
3285+
k_val = np.array(-1).astype(np.int32)
3286+
row_val = np.array(5).astype(np.int32)
3287+
col_val = np.array(-1).astype(np.int32)
3288+
3289+
def func(diag, k, row, col):
3290+
return tf.raw_ops.MatrixDiagV3(diagonal=diag, k=k, num_rows=row, num_cols=col,
3291+
padding_value=0.789, align='LEFT_RIGHT', name=_TFOUTPUT)
3292+
3293+
self._run_test_case(func, [_OUTPUT], {_INPUT: diag_val, _INPUT1: k_val,
3294+
_INPUT2: row_val, _INPUT3: col_val})
3295+
3296+
@check_opset_min_version(12)
3297+
@check_tf_min_version("2.2")
3298+
def test_matrix_diag_v3_2single_dim_row_col(self):
3299+
raw_diag = [[1, 2, 3], [4, 5, 6]]
3300+
diag_val = np.array(raw_diag).astype(np.int64)
3301+
k_val = np.array(0).astype(np.int32)
3302+
row_val = np.array(3).astype(np.int32)
3303+
col_val = np.array(4).astype(np.int32)
3304+
3305+
def func(diag, k, row, col):
3306+
return tf.raw_ops.MatrixDiagV3(diagonal=diag, k=k, num_rows=row, num_cols=col,
3307+
padding_value=7, align='LEFT_RIGHT', name=_TFOUTPUT), \
3308+
tf.raw_ops.MatrixDiag(diagonal=diag, name=_TFOUTPUT1)
3309+
3310+
self._run_test_case(func, [_OUTPUT, _OUTPUT1],
3311+
{_INPUT: diag_val, _INPUT1: k_val,
3312+
_INPUT2: row_val, _INPUT3: col_val})
3313+
3314+
@check_opset_min_version(12)
3315+
@check_tf_min_version("2.2")
3316+
def test_matrix_diag_v3_1single_dim_row_col(self):
3317+
raw_diag = [1, 2, 3, 4, 5]
3318+
diag_val = np.array(raw_diag).astype(np.int64)
3319+
k_val = np.array(0).astype(np.int32)
3320+
row_val = np.array(5).astype(np.int32)
3321+
col_val = np.array(10).astype(np.int32)
3322+
3323+
def func(diag, k, row, col):
3324+
return tf.raw_ops.MatrixDiagV3(diagonal=diag, k=k, num_rows=row, num_cols=col,
3325+
padding_value=7, align='LEFT_RIGHT', name=_TFOUTPUT)
3326+
3327+
self._run_test_case(func, [_OUTPUT], {_INPUT: diag_val, _INPUT1: k_val,
3328+
_INPUT2: row_val, _INPUT3: col_val})
3329+
3330+
@check_opset_min_version(12)
3331+
@check_tf_min_version("2.2")
3332+
def test_matrix_set_diag_v3(self):
3333+
input_val = np.array([[[7, 7, 7, 7],
3334+
[7, 7, 7, 7],
3335+
[7, 7, 7, 7]],
3336+
[[7, 7, 7, 7],
3337+
[7, 7, 7, 7],
3338+
[7, 7, 7, 7]]]).astype(np.int64)
3339+
diag_val = np.array([[1, 2, 3],
3340+
[4, 5, 6]]).astype(np.int64)
3341+
k_val = np.array([0])
3342+
3343+
def func(base_matrix, diag, k):
3344+
return tf.raw_ops.MatrixSetDiagV3(input=base_matrix, diagonal=diag, k=k, align='RIGHT_LEFT', name=_TFOUTPUT)
3345+
3346+
self._run_test_case(func, [_OUTPUT], {_INPUT: input_val, _INPUT1: diag_val, _INPUT2: k_val})
3347+
32383348

32393349
if __name__ == '__main__':
32403350
unittest_main()

0 commit comments

Comments
 (0)