Skip to content

Commit b87879e

Browse files
committed
changing to linealg norm from cupy in tests
1 parent e08958f commit b87879e

File tree

1 file changed

+17
-17
lines changed

1 file changed

+17
-17
lines changed

zenodo-tests/test_prep/test_stripe.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,17 @@
2020
(
2121
"i12_dataset4",
2222
11,
23-
78.5593,
23+
78.8627,
2424
),
2525
(
2626
"i12_dataset4",
2727
21,
28-
92.4986,
28+
92.8653,
2929
),
3030
(
3131
"i12_dataset4",
3232
31,
33-
99.3062,
33+
99.6979,
3434
),
3535
],
3636
ids=["size_11", "size_21", "size_31"],
@@ -49,7 +49,7 @@ def test_remove_stripe_based_sorting_i12_dataset4(
4949
)
5050

5151
residual_calc = data_normalised - output
52-
norm_res = np.linalg.norm(residual_calc.get().flatten())
52+
norm_res = cp.linalg.norm(residual_calc.flatten())
5353

5454
assert isclose(norm_res, norm_res_expected, abs_tol=10**-4)
5555

@@ -63,17 +63,17 @@ def test_remove_stripe_based_sorting_i12_dataset4(
6363
(
6464
"i12_dataset4",
6565
0.01,
66-
321.5298,
66+
322.5501,
6767
),
6868
(
6969
"i12_dataset4",
7070
0.03,
71-
172.1631,
71+
173.1643,
7272
),
7373
(
7474
"i12_dataset4",
7575
0.06,
76-
128.0331,
76+
128.8032,
7777
),
7878
],
7979
ids=["beta_001", "beta_003", "beta_006"],
@@ -90,7 +90,7 @@ def test_remove_stripe_ti_i12_dataset4(
9090
output = remove_stripe_ti(cp.copy(data_normalised), beta=beta_val)
9191

9292
residual_calc = data_normalised - output
93-
norm_res = np.linalg.norm(residual_calc.get().flatten())
93+
norm_res = cp.linalg.norm(residual_calc.flatten())
9494

9595
assert isclose(norm_res, norm_res_expected, abs_tol=10**-4)
9696

@@ -101,10 +101,10 @@ def test_remove_stripe_ti_i12_dataset4(
101101
@pytest.mark.parametrize(
102102
"dataset_fixture, snr_val, la_size_val, sm_size_val, norm_res_expected",
103103
[
104-
("i12_dataset4", 1.0, 31, 10, 104.8582),
105-
("i12_dataset4", 2.0, 41, 17, 99.4554),
106-
("i12_dataset4", 3.0, 61, 21, 102.8688),
107-
("i12_dataset4", 4.0, 71, 31, 106.1418),
104+
("i12_dataset4", 1.0, 31, 10, 105.3542),
105+
("i12_dataset4", 2.0, 41, 17, 99.9182),
106+
("i12_dataset4", 3.0, 61, 21, 103.3776),
107+
("i12_dataset4", 4.0, 71, 31, 106.6767),
108108
],
109109
ids=["snr_1", "snr_2", "snr_3", "snr_4"],
110110
)
@@ -126,7 +126,7 @@ def test_remove_all_stripe_i12_dataset4(
126126
)
127127

128128
residual_calc = data_normalised - output
129-
norm_res = np.linalg.norm(residual_calc.get().flatten())
129+
norm_res = cp.linalg.norm(residual_calc.flatten())
130130

131131
assert isclose(norm_res, norm_res_expected, abs_tol=10**-4)
132132
assert not np.isnan(output).any(), "Output contains NaN values"
@@ -137,9 +137,9 @@ def test_remove_all_stripe_i12_dataset4(
137137
@pytest.mark.parametrize(
138138
"dataset_fixture, nvalue_val, vvalue_val, norm_res_expected",
139139
[
140-
("i12_dataset4", 2, 4, 94.0424),
141-
("i12_dataset4", 4, 2, 86.2983),
142-
("i12_dataset4", 6, 5, 111.0662),
140+
("i12_dataset4", 2, 4, 94.0996),
141+
("i12_dataset4", 4, 2, 86.3459),
142+
("i12_dataset4", 6, 5, 111.1377),
143143
],
144144
ids=["case_1", "case_2", "case_3"],
145145
)
@@ -168,7 +168,7 @@ def test_raven_filter_i12_dataset4(
168168
)
169169

170170
residual_calc = data_normalised - output
171-
norm_res = np.linalg.norm(residual_calc.get().flatten())
171+
norm_res = cp.linalg.norm(residual_calc.flatten())
172172

173173
assert isclose(norm_res, norm_res_expected, abs_tol=10**-4)
174174

0 commit comments

Comments
 (0)