Skip to content

Commit 9cfa737

Browse files
committed
refactor case
1 parent 6096b4f commit 9cfa737

File tree

1 file changed

+5
-35
lines changed

1 file changed

+5
-35
lines changed

tests/test_backend.py

Lines changed: 5 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -3117,41 +3117,11 @@ def func(X, K):
31173117
v3 = tf.raw_ops.MatrixDiagPartV3(input=X, k=K, padding_value=0.123, align='RIGHT_LEFT', name=_TFOUTPUT1)
31183118
return v2, v3
31193119

3120-
x_val = np.random.random([4, 5]).astype(np.float32)
3121-
k_val = np.array([0]).astype(np.int32)
3122-
self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: x_val, _INPUT1: k_val})
3123-
k_val = np.array([1]).astype(np.int32)
3124-
self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: x_val, _INPUT1: k_val})
3125-
k_val = np.array([3]).astype(np.int32)
3126-
self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: x_val, _INPUT1: k_val})
3127-
k_val = np.array([-1]).astype(np.int32)
3128-
self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: x_val, _INPUT1: k_val})
3129-
k_val = np.array([-3]).astype(np.int32)
3130-
self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: x_val, _INPUT1: k_val})
3131-
k_val = np.array([1, 2]).astype(np.int32)
3132-
self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: x_val, _INPUT1: k_val})
3133-
k_val = np.array([-2, -1]).astype(np.int32)
3134-
self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: x_val, _INPUT1: k_val})
3135-
k_val = np.array([-1, 1]).astype(np.int32)
3136-
self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: x_val, _INPUT1: k_val})
3137-
3138-
x_val = np.random.random([2, 3, 4, 5]).astype(np.float32)
3139-
k_val = np.array([0]).astype(np.int32)
3140-
self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: x_val, _INPUT1: k_val})
3141-
k_val = np.array([1]).astype(np.int32)
3142-
self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: x_val, _INPUT1: k_val})
3143-
k_val = np.array([3]).astype(np.int32)
3144-
self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: x_val, _INPUT1: k_val})
3145-
k_val = np.array([-1]).astype(np.int32)
3146-
self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: x_val, _INPUT1: k_val})
3147-
k_val = np.array([-3]).astype(np.int32)
3148-
self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: x_val, _INPUT1: k_val})
3149-
k_val = np.array([1, 2]).astype(np.int32)
3150-
self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: x_val, _INPUT1: k_val})
3151-
k_val = np.array([-2, -1]).astype(np.int32)
3152-
self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: x_val, _INPUT1: k_val})
3153-
k_val = np.array([-1, 1]).astype(np.int32)
3154-
self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: x_val, _INPUT1: k_val})
3120+
for x_shape in ([4, 5], [2, 3, 4, 5]):
3121+
x_val = np.random.random(x_shape).astype(np.float32)
3122+
for raw_k in ([0], [1], [3], [-1], [-3], [1, 2], [-2, -1], [-1, 1]):
3123+
k_val = np.array(raw_k).astype(np.int32)
3124+
self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: x_val, _INPUT1: k_val})
31553125

31563126

31573127
if __name__ == '__main__':

0 commit comments

Comments
 (0)