Skip to content

Commit f9022b4

Browse files
authored
Merge pull request #25 from niharika2999/Add_tests
Added tests to up the coverage to 95%
2 parents e30fea6 + e0d9b4d commit f9022b4

File tree

10 files changed

+4118
-80
lines changed

10 files changed

+4118
-80
lines changed

pysensors/optimizers/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from ._ccqr import CCQR
1+
from ._ccqr import CCQR, qr_reflector
22
from ._gqr import GQR
33
from ._qr import QR
44

tests/basis/test_basis.py

Lines changed: 158 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import pytest
77
from sklearn.exceptions import NotFittedError
88

9-
from pysensors.basis import SVD, Identity, RandomProjection
9+
from pysensors.basis import SVD, Custom, Identity, RandomProjection, _base
1010

1111

1212
@pytest.mark.parametrize("basis", [Identity(), SVD(), RandomProjection()])
@@ -75,7 +75,6 @@ def test_extra_basis_modes(basis, data_random):
7575
data = data_random
7676
n_basis_modes = data.shape[0] + 1
7777
b = basis(n_basis_modes=n_basis_modes)
78-
# Can't have more basis modes than the number of training examples
7978
with pytest.raises(ValueError):
8079
b.fit(data)
8180

@@ -90,3 +89,160 @@ def test_matrix_inverse_shape(basis, data_random):
9089
inverse = basis.matrix_inverse(n_basis_modes=n_basis_modes)
9190

9291
assert inverse.shape == (n_basis_modes, n_features)
92+
93+
94+
@pytest.fixture
95+
def sample_basis():
96+
"""Create a sample basis matrix for testing."""
97+
return np.eye(5)
98+
99+
100+
def test_valid_n_basis_modes(sample_basis):
101+
"""Test initialization with valid n_basis_modes."""
102+
custom = Custom(U=sample_basis, n_basis_modes=3)
103+
assert custom._n_basis_modes == 3
104+
np.testing.assert_array_equal(custom.custom_basis_, sample_basis)
105+
106+
custom = Custom(U=sample_basis, n_basis_modes=1)
107+
assert custom._n_basis_modes == 1
108+
109+
110+
@pytest.mark.parametrize("value", [3.5, "3", [3], (3,), None])
111+
def test_n_basis_modes_not_integer(sample_basis, value):
112+
"""Test initialization with non-integer n_basis_modes."""
113+
with pytest.raises(ValueError) as excinfo:
114+
Custom(U=sample_basis, n_basis_modes=value)
115+
assert "n_basis_modes must be a positive integer" in str(excinfo.value)
116+
117+
118+
@pytest.mark.parametrize("value", [0, -1, -10])
119+
def test_n_basis_modes_not_positive(sample_basis, value):
120+
"""Test initialization with non-positive n_basis_modes."""
121+
with pytest.raises(ValueError) as excinfo:
122+
Custom(U=sample_basis, n_basis_modes=value)
123+
assert "n_basis_modes must be a positive integer" in str(excinfo.value)
124+
125+
126+
def test_with_keyword_arguments(sample_basis):
127+
"""Test initialization with additional keyword arguments."""
128+
custom = Custom(
129+
U=sample_basis, n_basis_modes=3, extra_param=True, another_param="value"
130+
)
131+
assert custom._n_basis_modes == 3
132+
133+
134+
@pytest.fixture
135+
def custom_instance(sample_basis):
136+
"""Create an initialized Custom instance for testing."""
137+
return Custom(U=sample_basis, n_basis_modes=3)
138+
139+
140+
def test_fit_method(custom_instance, sample_basis):
141+
"""Test that fit correctly sets basis_matrix_ and returns self."""
142+
result = custom_instance.fit()
143+
expected_basis_matrix = sample_basis[:, :3]
144+
np.testing.assert_array_equal(custom_instance.basis_matrix_, expected_basis_matrix)
145+
assert result is custom_instance
146+
147+
148+
def test_matrix_inverse_default(custom_instance):
149+
"""Test matrix_inverse with default n_basis_modes."""
150+
custom_instance.fit()
151+
result = custom_instance.matrix_inverse()
152+
expected_result = custom_instance.basis_matrix_.T
153+
np.testing.assert_array_equal(result, expected_result)
154+
155+
156+
@pytest.mark.parametrize("n_modes", [1, 2])
157+
def test_matrix_inverse_with_n_basis_modes(custom_instance, n_modes):
158+
"""Test matrix_inverse with specified n_basis_modes."""
159+
custom_instance.fit()
160+
result = custom_instance.matrix_inverse(n_basis_modes=n_modes)
161+
expected_result = custom_instance.basis_matrix_[:, :n_modes].T
162+
np.testing.assert_array_equal(result, expected_result)
163+
assert result.shape == (n_modes, 5)
164+
165+
166+
def test_n_basis_modes_getter(custom_instance):
167+
"""Test n_basis_modes property getter."""
168+
assert custom_instance.n_basis_modes == 3
169+
custom_instance._n_basis_modes = 4
170+
assert custom_instance.n_basis_modes == 4
171+
172+
173+
def test_n_basis_modes_setter(custom_instance):
174+
"""Test n_basis_modes property setter."""
175+
custom_instance.n_basis_modes = 2
176+
assert custom_instance._n_basis_modes == 2
177+
assert custom_instance.n_components == 2
178+
179+
custom_instance.n_basis_modes = 4
180+
assert custom_instance._n_basis_modes == 4
181+
assert custom_instance.n_components == 4
182+
183+
184+
def test_matrix_inverse_calls_validate_input(custom_instance, monkeypatch):
185+
"""Test that matrix_inverse calls _validate_input."""
186+
custom_instance.fit()
187+
validation_called = False
188+
test_value = None
189+
190+
def mock_validate_input(self, value):
191+
nonlocal validation_called, test_value
192+
validation_called = True
193+
test_value = value
194+
return 2
195+
196+
monkeypatch.setattr(Custom, "_validate_input", mock_validate_input)
197+
custom_instance.matrix_inverse(n_basis_modes=3)
198+
assert validation_called
199+
assert test_value == 3
200+
201+
202+
def test_invertible_basis_abstract_method():
203+
class TestBasis(_base.InvertibleBasis):
204+
pass
205+
206+
with pytest.raises(TypeError, match="Can't instantiate abstract class TestBasis"):
207+
TestBasis()
208+
209+
class ProperTestBasis(_base.InvertibleBasis):
210+
def matrix_inverse(self, n_basis_modes=None, **kwargs):
211+
return None
212+
213+
with pytest.raises(
214+
NotImplementedError, match="This method has not been implemented"
215+
):
216+
_base.InvertibleBasis.matrix_inverse(None)
217+
218+
219+
def test_validate_input_too_many_modes_error():
220+
basis = SVD()
221+
n_available_modes = 5
222+
basis.basis_matrix_ = np.random.rand(10, n_available_modes)
223+
basis.n_basis_modes = n_available_modes
224+
too_many_modes = 8
225+
226+
expected_error = (
227+
f"Requested number of modes {too_many_modes} exceeds"
228+
f" number available: {n_available_modes}"
229+
)
230+
231+
with pytest.raises(ValueError, match=expected_error):
232+
basis._validate_input(n_basis_modes=too_many_modes)
233+
234+
235+
def test_matrix_representation_copy():
236+
basis = SVD()
237+
n_features = 10
238+
n_basis_modes = 5
239+
basis.basis_matrix_ = np.random.rand(n_features, n_basis_modes)
240+
result_copy = basis.matrix_representation(n_basis_modes=3, copy=True)
241+
np.testing.assert_array_equal(result_copy, basis.basis_matrix_[:, :3])
242+
original_value = basis.basis_matrix_[0, 0]
243+
result_copy[0, 0] = 999
244+
assert basis.basis_matrix_[0, 0] == original_value
245+
result_view = basis.matrix_representation(n_basis_modes=3, copy=False)
246+
np.testing.assert_array_equal(result_view, basis.basis_matrix_[:, :3])
247+
result_view[0, 0] = 777
248+
assert basis.basis_matrix_[0, 0] == 777

tests/classification/test_sspoc.py

Lines changed: 175 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""Tests for SSPOC class."""
22

3+
from unittest.mock import Mock
4+
35
import numpy as np
46
import pytest
57
from pytest_lazyfixture import lazy_fixture
@@ -31,8 +33,6 @@ def data_multiclass_classification():
3133
def test_not_fitted(data_binary_classification):
3234
x, y, _ = data_binary_classification
3335
model = SSPOC()
34-
35-
# Shouldn't be able to call any of these methods before fitting
3636
with pytest.raises(NotFittedError):
3737
model.predict(x)
3838
with pytest.raises(NotFittedError):
@@ -271,3 +271,176 @@ def test_sspoc_selector_equivalence(data_multiclass_classification):
271271
model = SSPOC().fit(x, y)
272272

273273
np.testing.assert_array_equal(model.get_selected_sensors(), model.selected_sensors)
274+
275+
276+
@pytest.fixture
277+
def sspoc_instance():
278+
"""Create a mock SSPOC instance with refit=False."""
279+
sspoc = SSPOC()
280+
sspoc.refit_ = False
281+
sspoc.n_sensors = 3
282+
sspoc.basis_matrix_inverse_ = np.array(
283+
[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]
284+
)
285+
sspoc.classifier = Mock()
286+
sspoc.classifier.predict.return_value = np.array([0, 1, 0])
287+
sspoc.sensor_coef_ = np.array([1, 2, 3])
288+
289+
return sspoc
290+
291+
292+
def test_predict_with_refit_false(sspoc_instance):
293+
"""Test predict method when refit is False."""
294+
X_test = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])
295+
expected_transformed_input = np.dot(X_test, sspoc_instance.basis_matrix_inverse_.T)
296+
result = sspoc_instance.predict(X_test)
297+
sspoc_instance.classifier.predict.assert_called_once()
298+
actual_arg = sspoc_instance.classifier.predict.call_args[0][0]
299+
np.testing.assert_array_almost_equal(actual_arg, expected_transformed_input)
300+
assert np.array_equal(result, np.array([0, 1, 0]))
301+
302+
303+
def test_predict_with_refit_true(sspoc_instance):
304+
"""Test predict method when refit is True."""
305+
sspoc_instance.refit_ = True
306+
X_test = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])
307+
result = sspoc_instance.predict(X_test)
308+
sspoc_instance.classifier.predict.assert_called_once_with(X_test)
309+
assert np.array_equal(result, np.array([0, 1, 0]))
310+
311+
312+
def test_predict_with_zero_sensors():
313+
"""Test predict method when n_sensors is 0."""
314+
sspoc = SSPOC()
315+
sspoc.n_sensors = 0
316+
sspoc.sensor_coef_ = np.array([])
317+
sspoc.dummy_ = Mock()
318+
sspoc.dummy_.predict.return_value = np.array([1, 1, 1])
319+
320+
X_test = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])
321+
with pytest.warns(UserWarning, match="SSPOC model has no selected sensors"):
322+
sspoc.predict(X_test)
323+
sspoc.dummy_.predict.assert_called_once()
324+
np.testing.assert_array_equal(sspoc.dummy_.predict.call_args[0][0], X_test[:, 0])
325+
326+
327+
@pytest.fixture
328+
def sspoc_mock():
329+
"""Create a mock SSPOC instance that will actually trigger the warning."""
330+
sspoc = SSPOC()
331+
sspoc.sensor_coef_ = np.array([0.9, 0.8, 0.0, 0.3, 0.2])
332+
sspoc.classifier = Mock()
333+
334+
def custom_update(
335+
n_sensors=None,
336+
threshold=None,
337+
xy=None,
338+
quiet=False,
339+
method=np.max,
340+
**method_kws,
341+
):
342+
if n_sensors is not None:
343+
sorted_indices = np.argsort(-np.abs(sspoc.sensor_coef_))
344+
print(f"Sorted indices: {sorted_indices}")
345+
print(f"n_sensors-1 index: {sorted_indices[n_sensors - 1]}")
346+
print(
347+
f"Value at index: {sspoc.sensor_coef_[sorted_indices[n_sensors - 1]]}"
348+
)
349+
print(
350+
f"Is 0?{np.abs(sspoc.sensor_coef_[sorted_indices[n_sensors - 1]]) == 0}"
351+
)
352+
original = sspoc.update_sensors
353+
return original(n_sensors, threshold, xy, quiet, method, **method_kws)
354+
355+
return sspoc
356+
357+
358+
@pytest.fixture
359+
def sspoc_multiclass_mock():
360+
"""Create a mock SSPOC instance for multiclass case that will trigger warning."""
361+
sspoc = SSPOC()
362+
sspoc.sensor_coef_ = np.array(
363+
[
364+
[0.9, 0.8, 0.7],
365+
[0.6, 0.5, 0.4],
366+
[0.0, 0.0, 0.0],
367+
[0.2, 0.1, 0.3],
368+
[0.05, 0.04, 0.03],
369+
]
370+
)
371+
sspoc.classifier = Mock()
372+
return sspoc
373+
374+
375+
def test_warning_when_threshold_too_high(sspoc_mock):
376+
"""Test warning when threshold is set too high and no sensors are selected."""
377+
with pytest.warns(UserWarning, match="Threshold set too high.*no sensors selected"):
378+
sspoc_mock.update_sensors(threshold=1.0)
379+
assert len(sspoc_mock.sparse_sensors_) == 0
380+
assert sspoc_mock.n_sensors == 0
381+
382+
383+
def test_warning_when_no_sensors_selected_for_refit(sspoc_mock):
384+
"""Test warning when trying to refit with no sensors selected."""
385+
X = np.random.rand(10, 5)
386+
y = np.random.randint(0, 2, 10)
387+
388+
with pytest.warns(UserWarning, match="No selected sensors; model was not refit"):
389+
sspoc_mock.update_sensors(threshold=1.0, xy=(X, y))
390+
sspoc_mock.classifier.fit.assert_not_called()
391+
392+
393+
def test_warning_when_both_n_sensors_and_threshold_provided(sspoc_mock):
394+
"""Test that warning is issued when both n_sensors and threshold are provided."""
395+
with pytest.warns(
396+
UserWarning,
397+
match="Both n_sensors.*and threshold.*were passed so threshold will be ignored",
398+
):
399+
sspoc_mock.update_sensors(n_sensors=2, threshold=0.4)
400+
401+
402+
def test_update_sensors_too_many_sensors_error():
403+
n_available_sensors = 10
404+
model = SSPOC()
405+
model.sensor_coef_ = np.random.rand(n_available_sensors)
406+
too_many_sensors = n_available_sensors + 5
407+
408+
expected_error = (
409+
f"n_sensors\\({too_many_sensors}\\) cannot exceed number of "
410+
f"available sensors \\({n_available_sensors}\\)"
411+
)
412+
413+
with pytest.raises(ValueError, match=expected_error):
414+
model.update_sensors(n_sensors=too_many_sensors)
415+
416+
417+
def test_uninformative_sensors_warning():
418+
n_available_sensors = 10
419+
n_sensors_to_select = 6
420+
model = SSPOC()
421+
sensor_coef = np.zeros(n_available_sensors)
422+
sensor_coef[:5] = np.random.rand(5)
423+
sensor_coef[:5] = -np.sort(-np.abs(sensor_coef[:5]))
424+
model.sensor_coef_ = sensor_coef
425+
with pytest.warns(
426+
UserWarning,
427+
match="Some uninformative sensors were selected. Consider decreasing n_sensors",
428+
):
429+
model.update_sensors(n_sensors=n_sensors_to_select)
430+
431+
432+
def test_uninformative_sensors_multiclass_warning():
433+
n_available_sensors = 10
434+
n_classes = 3
435+
n_sensors_to_select = 6
436+
model = SSPOC()
437+
sensor_coef = np.zeros((n_available_sensors, n_classes))
438+
sensor_coef[:5, :] = np.random.rand(5, n_classes)
439+
for i in range(5):
440+
sensor_coef[i, :] = np.abs(sensor_coef[i, :]) + 0.5
441+
model.sensor_coef_ = sensor_coef
442+
with pytest.warns(
443+
UserWarning,
444+
match="Some uninformative sensors were selected. Consider decreasing n_sensors",
445+
):
446+
model.update_sensors(n_sensors=n_sensors_to_select, method=np.mean)

0 commit comments

Comments
 (0)