Skip to content

Commit 5681c1d

Browse files
committed
Adding tests to get coverage up to 95 %
1 parent c592534 commit 5681c1d

File tree

6 files changed

+554
-34
lines changed

6 files changed

+554
-34
lines changed

tests/basis/test_basis.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import pytest
77
from sklearn.exceptions import NotFittedError
88

9-
from pysensors.basis import SVD, Custom, Identity, RandomProjection
9+
from pysensors.basis import SVD, Custom, Identity, RandomProjection, _base
1010

1111

1212
@pytest.mark.parametrize("basis", [Identity(), SVD(), RandomProjection()])
@@ -197,3 +197,52 @@ def mock_validate_input(self, value):
197197
custom_instance.matrix_inverse(n_basis_modes=3)
198198
assert validation_called
199199
assert test_value == 3
200+
201+
202+
def test_invertible_basis_abstract_method():
203+
class TestBasis(_base.InvertibleBasis):
204+
pass
205+
206+
with pytest.raises(TypeError, match="Can't instantiate abstract class TestBasis"):
207+
TestBasis()
208+
209+
class ProperTestBasis(_base.InvertibleBasis):
210+
def matrix_inverse(self, n_basis_modes=None, **kwargs):
211+
return None
212+
213+
with pytest.raises(
214+
NotImplementedError, match="This method has not been implemented"
215+
):
216+
_base.InvertibleBasis.matrix_inverse(None)
217+
218+
219+
def test_validate_input_too_many_modes_error():
220+
basis = SVD()
221+
n_available_modes = 5
222+
basis.basis_matrix_ = np.random.rand(10, n_available_modes)
223+
basis.n_basis_modes = n_available_modes
224+
too_many_modes = 8
225+
226+
expected_error = (
227+
f"Requested number of modes {too_many_modes} exceeds"
228+
f" number available: {n_available_modes}"
229+
)
230+
231+
with pytest.raises(ValueError, match=expected_error):
232+
basis._validate_input(n_basis_modes=too_many_modes)
233+
234+
235+
def test_matrix_representation_copy():
236+
basis = SVD()
237+
n_features = 10
238+
n_basis_modes = 5
239+
basis.basis_matrix_ = np.random.rand(n_features, n_basis_modes)
240+
result_copy = basis.matrix_representation(n_basis_modes=3, copy=True)
241+
np.testing.assert_array_equal(result_copy, basis.basis_matrix_[:, :3])
242+
original_value = basis.basis_matrix_[0, 0]
243+
result_copy[0, 0] = 999
244+
assert basis.basis_matrix_[0, 0] == original_value
245+
result_view = basis.matrix_representation(n_basis_modes=3, copy=False)
246+
np.testing.assert_array_equal(result_view, basis.basis_matrix_[:, :3])
247+
result_view[0, 0] = 777
248+
assert basis.basis_matrix_[0, 0] == 777

tests/classification/test_sspoc.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,3 +395,50 @@ def test_warning_when_both_n_sensors_and_threshold_provided(sspoc_mock):
395395
match="Both n_sensors.*and threshold.*were passed so threshold will be ignored",
396396
):
397397
sspoc_mock.update_sensors(n_sensors=2, threshold=0.4)
398+
399+
400+
def test_update_sensors_too_many_sensors_error():
401+
n_available_sensors = 10
402+
model = SSPOC()
403+
model.sensor_coef_ = np.random.rand(n_available_sensors)
404+
too_many_sensors = n_available_sensors + 5
405+
406+
expected_error = (
407+
f"n_sensors\\({too_many_sensors}\\) cannot exceed number of "
408+
f"available sensors \\({n_available_sensors}\\)"
409+
)
410+
411+
with pytest.raises(ValueError, match=expected_error):
412+
model.update_sensors(n_sensors=too_many_sensors)
413+
414+
415+
def test_uninformative_sensors_warning():
416+
n_available_sensors = 10
417+
n_sensors_to_select = 6
418+
model = SSPOC()
419+
sensor_coef = np.zeros(n_available_sensors)
420+
sensor_coef[:5] = np.random.rand(5)
421+
sensor_coef[:5] = -np.sort(-np.abs(sensor_coef[:5]))
422+
model.sensor_coef_ = sensor_coef
423+
with pytest.warns(
424+
UserWarning,
425+
match="Some uninformative sensors were selected. Consider decreasing n_sensors",
426+
):
427+
model.update_sensors(n_sensors=n_sensors_to_select)
428+
429+
430+
def test_uninformative_sensors_multiclass_warning():
431+
n_available_sensors = 10
432+
n_classes = 3
433+
n_sensors_to_select = 6
434+
model = SSPOC()
435+
sensor_coef = np.zeros((n_available_sensors, n_classes))
436+
sensor_coef[:5, :] = np.random.rand(5, n_classes)
437+
for i in range(5):
438+
sensor_coef[i, :] = np.abs(sensor_coef[i, :]) + 0.5
439+
model.sensor_coef_ = sensor_coef
440+
with pytest.warns(
441+
UserWarning,
442+
match="Some uninformative sensors were selected. Consider decreasing n_sensors",
443+
):
444+
model.update_sensors(n_sensors=n_sensors_to_select, method=np.mean)

tests/reconstruction/test_sspor.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -674,3 +674,31 @@ def custom_score_function(y_true, y_pred, **kwargs):
674674
assert kwargs["beta"] == 2
675675
assert "sample_weight" in kwargs
676676
assert np.array_equal(kwargs["sample_weight"], np.ones(10))
677+
678+
679+
def test_reconstruction_error_warning():
680+
basis_mode_dim = 6
681+
n_sensors = 8
682+
model = SSPOR(n_sensors=n_sensors)
683+
model.basis_matrix_ = np.random.rand(basis_mode_dim, basis_mode_dim)
684+
model.ranked_sensors_ = np.arange(basis_mode_dim)
685+
sensor_range = np.arange(1, n_sensors + 1)
686+
x_test = np.random.rand(5, basis_mode_dim)
687+
with pytest.warns(
688+
UserWarning,
689+
match=f"Performance may be poor when using more than {basis_mode_dim} sensors",
690+
):
691+
model.reconstruction_error(x_test, sensor_range=sensor_range)
692+
693+
694+
def test_validate_n_sensors_warning():
695+
n_sensors = 10
696+
n_samples = 5
697+
model = SSPOR(optimizer=CCQR(), n_sensors=n_sensors)
698+
model.basis_matrix_ = np.random.rand(15, n_samples)
699+
with pytest.warns(
700+
UserWarning,
701+
match="Number of sensors exceeds number of samples, "
702+
"which may cause CCQR to select sensors in constrained regions.",
703+
):
704+
model._validate_n_sensors()

tests/utils/test_base.py

Lines changed: 11 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,39 @@
1-
import pytest
21
import numpy as np
2+
import pytest
3+
34
from pysensors.utils import validate_input # Adjust import path as needed
45

6+
57
def test_validate_input_value_errors():
68
"""Test that validate_input raises appropriate ValueErrors."""
7-
not_arrays = [
8-
[1, 2, 3],
9-
(1, 2, 3),
10-
{1, 2, 3},
11-
{"a": 1, "b": 2},
12-
"123",
13-
123,
14-
None
15-
]
9+
not_arrays = [[1, 2, 3], (1, 2, 3), {1, 2, 3}, {"a": 1, "b": 2}, "123", 123, None]
1610
for not_array in not_arrays:
1711
with pytest.raises(ValueError, match="x must be a numpy array"):
1812
validate_input(not_array)
1913
x_1d = np.array([1, 2, 3, 4])
20-
wrong_sensors_1d = [
21-
np.array([0, 1]),
22-
np.array([0, 1, 2, 3, 4]),
23-
np.array([])
24-
]
14+
wrong_sensors_1d = [np.array([0, 1]), np.array([0, 1, 2, 3, 4]), np.array([])]
2515
for sensors in wrong_sensors_1d:
2616
with pytest.raises(ValueError, match="x has the wrong number of features"):
2717
validate_input(x_1d, sensors)
28-
x_2d = np.array([
29-
[1, 2, 3],
30-
[4, 5, 6]
31-
])
32-
wrong_sensors_2d = [
33-
np.array([0, 1]),
34-
np.array([0, 1, 2, 3]),
35-
np.array([])
36-
]
18+
x_2d = np.array([[1, 2, 3], [4, 5, 6]])
19+
wrong_sensors_2d = [np.array([0, 1]), np.array([0, 1, 2, 3]), np.array([])]
3720
for sensors in wrong_sensors_2d:
3821
with pytest.raises(ValueError, match="x has the wrong number of features"):
3922
validate_input(x_2d, sensors)
4023

24+
4125
def test_validate_input_valid_cases():
4226
"""Test that validate_input works correctly with valid inputs."""
4327
x_1d = np.array([1, 2, 3, 4])
4428
sensors_1d = np.array([0, 1, 2, 3])
4529
result = validate_input(x_1d, sensors_1d)
4630
assert np.array_equal(result, x_1d)
47-
x_2d = np.array([
48-
[1, 2, 3],
49-
[4, 5, 6]
50-
])
31+
x_2d = np.array([[1, 2, 3], [4, 5, 6]])
5132
sensors_2d = np.array([0, 1, 2])
5233
result = validate_input(x_2d, sensors_2d)
5334
assert np.array_equal(result, x_2d)
5435
result = validate_input(x_1d, None)
5536
assert np.array_equal(result, x_1d)
56-
37+
5738
result = validate_input(x_2d, None)
58-
assert np.array_equal(result, x_2d)
39+
assert np.array_equal(result, x_2d)

0 commit comments

Comments
 (0)