Skip to content

Commit 61a93c1

Browse files
committed
fix test failed among Python version
1 parent 34c1b2e commit 61a93c1

File tree

1 file changed

+42
-4
lines changed

1 file changed

+42
-4
lines changed

hbbrain/tests/test_matrix_transformation.py

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,14 +49,52 @@ def test_split_matrix_non_sort_max():
4949

5050
def test_split_matrix_sorted_min():
5151
split_max = split_matrix(sim_matrix, asimil_type='min', is_sort=True)
52-
expected_res = np.array([[1. , 4. , 0.6], [0. , 3. , 0.6], [3. , 4. , 0.4], [0. , 2. , 0.4], [2. , 3. , 0.3], [1. , 3. , 0.2], [0. , 1. , 0.2], [2. , 4. , 0.1], [1. , 2. , 0.1], [0. , 4. , 0.1]])
53-
np.testing.assert_array_equal(split_max, expected_res)
52+
expected_res = np.array([[1 , 4 , 0.6],
53+
[0 , 3 , 0.6],
54+
[3 , 4 , 0.4],
55+
[0 , 2 , 0.4],
56+
[2 , 3 , 0.3],
57+
[1 , 3 , 0.2],
58+
[0 , 1 , 0.2],
59+
[2 , 4 , 0.1],
60+
[1 , 2 , 0.1],
61+
[0 , 4 , 0.1]])
62+
expected_res_1 = np.array([[0 , 3 , 0.6],
63+
[1 , 4 , 0.6],
64+
[0 , 2 , 0.4],
65+
[3 , 4 , 0.4],
66+
[2 , 3 , 0.3],
67+
[0 , 1 , 0.2],
68+
[1 , 3 , 0.2],
69+
[1 , 2 , 0.1],
70+
[2 , 4 , 0.1],
71+
[0 , 4 , 0.1]])
72+
assert np.array_equal(split_max, expected_res) or np.array_equal(split_max, expected_res_1)
5473

5574

5675
def test_split_matrix_sorted_max():
5776
split_max = split_matrix(sim_matrix, asimil_type='max', is_sort=True)
58-
expected_res = np.array([[1. , 4. , 0.8], [3. , 4. , 0.7], [2. , 3. , 0.6], [0. , 3. , 0.6], [0. , 2. , 0.5], [2. , 4. , 0.4], [1. , 3. , 0.4], [1. , 2. , 0.3], [0. , 1. , 0.3], [0. , 4. , 0.2]])
59-
np.testing.assert_array_equal(split_max, expected_res)
77+
expected_res = np.array([[1 , 4 , 0.8],
78+
[3 , 4 , 0.7],
79+
[2 , 3 , 0.6],
80+
[0 , 3 , 0.6],
81+
[0 , 2 , 0.5],
82+
[2 , 4 , 0.4],
83+
[1 , 3 , 0.4],
84+
[1 , 2 , 0.3],
85+
[0 , 1 , 0.3],
86+
[0 , 4 , 0.2]])
87+
expected_res_1 = np.array([[1 , 4 , 0.8],
88+
[3 , 4 , 0.7],
89+
[0 , 3 , 0.6],
90+
[2 , 3 , 0.6],
91+
[0 , 2 , 0.5],
92+
[1 , 3 , 0.4],
93+
[2 , 4 , 0.4],
94+
[0 , 1 , 0.3],
95+
[1 , 2 , 0.3],
96+
[0 , 4 , 0.2]])
97+
assert np.array_equal(split_max, expected_res) or np.array_equal(split_max, expected_res_1)
6098

6199

62100
def test_hashing():

0 commit comments

Comments
 (0)