|
1 | 1 | """Tests for SSPOC class.""" |
2 | 2 |
|
| 3 | +from unittest.mock import Mock |
| 4 | + |
3 | 5 | import numpy as np |
4 | 6 | import pytest |
5 | 7 | from pytest_lazyfixture import lazy_fixture |
@@ -31,8 +33,6 @@ def data_multiclass_classification(): |
31 | 33 | def test_not_fitted(data_binary_classification): |
32 | 34 | x, y, _ = data_binary_classification |
33 | 35 | model = SSPOC() |
34 | | - |
35 | | - # Shouldn't be able to call any of these methods before fitting |
36 | 36 | with pytest.raises(NotFittedError): |
37 | 37 | model.predict(x) |
38 | 38 | with pytest.raises(NotFittedError): |
@@ -271,3 +271,176 @@ def test_sspoc_selector_equivalence(data_multiclass_classification): |
271 | 271 | model = SSPOC().fit(x, y) |
272 | 272 |
|
273 | 273 | np.testing.assert_array_equal(model.get_selected_sensors(), model.selected_sensors) |
| 274 | + |
| 275 | + |
| 276 | +@pytest.fixture |
| 277 | +def sspoc_instance(): |
| 278 | + """Create a mock SSPOC instance with refit=False.""" |
| 279 | + sspoc = SSPOC() |
| 280 | + sspoc.refit_ = False |
| 281 | + sspoc.n_sensors = 3 |
| 282 | + sspoc.basis_matrix_inverse_ = np.array( |
| 283 | + [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]] |
| 284 | + ) |
| 285 | + sspoc.classifier = Mock() |
| 286 | + sspoc.classifier.predict.return_value = np.array([0, 1, 0]) |
| 287 | + sspoc.sensor_coef_ = np.array([1, 2, 3]) |
| 288 | + |
| 289 | + return sspoc |
| 290 | + |
| 291 | + |
| 292 | +def test_predict_with_refit_false(sspoc_instance): |
| 293 | + """Test predict method when refit is False.""" |
| 294 | + X_test = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]) |
| 295 | + expected_transformed_input = np.dot(X_test, sspoc_instance.basis_matrix_inverse_.T) |
| 296 | + result = sspoc_instance.predict(X_test) |
| 297 | + sspoc_instance.classifier.predict.assert_called_once() |
| 298 | + actual_arg = sspoc_instance.classifier.predict.call_args[0][0] |
| 299 | + np.testing.assert_array_almost_equal(actual_arg, expected_transformed_input) |
| 300 | + assert np.array_equal(result, np.array([0, 1, 0])) |
| 301 | + |
| 302 | + |
| 303 | +def test_predict_with_refit_true(sspoc_instance): |
| 304 | + """Test predict method when refit is True.""" |
| 305 | + sspoc_instance.refit_ = True |
| 306 | + X_test = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]) |
| 307 | + result = sspoc_instance.predict(X_test) |
| 308 | + sspoc_instance.classifier.predict.assert_called_once_with(X_test) |
| 309 | + assert np.array_equal(result, np.array([0, 1, 0])) |
| 310 | + |
| 311 | + |
| 312 | +def test_predict_with_zero_sensors(): |
| 313 | + """Test predict method when n_sensors is 0.""" |
| 314 | + sspoc = SSPOC() |
| 315 | + sspoc.n_sensors = 0 |
| 316 | + sspoc.sensor_coef_ = np.array([]) |
| 317 | + sspoc.dummy_ = Mock() |
| 318 | + sspoc.dummy_.predict.return_value = np.array([1, 1, 1]) |
| 319 | + |
| 320 | + X_test = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]) |
| 321 | + with pytest.warns(UserWarning, match="SSPOC model has no selected sensors"): |
| 322 | + sspoc.predict(X_test) |
| 323 | + sspoc.dummy_.predict.assert_called_once() |
| 324 | + np.testing.assert_array_equal(sspoc.dummy_.predict.call_args[0][0], X_test[:, 0]) |
| 325 | + |
| 326 | + |
| 327 | +@pytest.fixture |
| 328 | +def sspoc_mock(): |
| 329 | + """Create a mock SSPOC instance that will actually trigger the warning.""" |
| 330 | + sspoc = SSPOC() |
| 331 | + sspoc.sensor_coef_ = np.array([0.9, 0.8, 0.0, 0.3, 0.2]) |
| 332 | + sspoc.classifier = Mock() |
| 333 | + |
| 334 | + def custom_update( |
| 335 | + n_sensors=None, |
| 336 | + threshold=None, |
| 337 | + xy=None, |
| 338 | + quiet=False, |
| 339 | + method=np.max, |
| 340 | + **method_kws, |
| 341 | + ): |
| 342 | + if n_sensors is not None: |
| 343 | + sorted_indices = np.argsort(-np.abs(sspoc.sensor_coef_)) |
| 344 | + print(f"Sorted indices: {sorted_indices}") |
| 345 | + print(f"n_sensors-1 index: {sorted_indices[n_sensors - 1]}") |
| 346 | + print( |
| 347 | + f"Value at index: {sspoc.sensor_coef_[sorted_indices[n_sensors - 1]]}" |
| 348 | + ) |
| 349 | + print( |
| 350 | + f"Is 0?{np.abs(sspoc.sensor_coef_[sorted_indices[n_sensors - 1]]) == 0}" |
| 351 | + ) |
| 352 | + original = sspoc.update_sensors |
| 353 | + return original(n_sensors, threshold, xy, quiet, method, **method_kws) |
| 354 | + |
| 355 | + return sspoc |
| 356 | + |
| 357 | + |
| 358 | +@pytest.fixture |
| 359 | +def sspoc_multiclass_mock(): |
| 360 | + """Create a mock SSPOC instance for multiclass case that will trigger warning.""" |
| 361 | + sspoc = SSPOC() |
| 362 | + sspoc.sensor_coef_ = np.array( |
| 363 | + [ |
| 364 | + [0.9, 0.8, 0.7], |
| 365 | + [0.6, 0.5, 0.4], |
| 366 | + [0.0, 0.0, 0.0], |
| 367 | + [0.2, 0.1, 0.3], |
| 368 | + [0.05, 0.04, 0.03], |
| 369 | + ] |
| 370 | + ) |
| 371 | + sspoc.classifier = Mock() |
| 372 | + return sspoc |
| 373 | + |
| 374 | + |
| 375 | +def test_warning_when_threshold_too_high(sspoc_mock): |
| 376 | + """Test warning when threshold is set too high and no sensors are selected.""" |
| 377 | + with pytest.warns(UserWarning, match="Threshold set too high.*no sensors selected"): |
| 378 | + sspoc_mock.update_sensors(threshold=1.0) |
| 379 | + assert len(sspoc_mock.sparse_sensors_) == 0 |
| 380 | + assert sspoc_mock.n_sensors == 0 |
| 381 | + |
| 382 | + |
| 383 | +def test_warning_when_no_sensors_selected_for_refit(sspoc_mock): |
| 384 | + """Test warning when trying to refit with no sensors selected.""" |
| 385 | + X = np.random.rand(10, 5) |
| 386 | + y = np.random.randint(0, 2, 10) |
| 387 | + |
| 388 | + with pytest.warns(UserWarning, match="No selected sensors; model was not refit"): |
| 389 | + sspoc_mock.update_sensors(threshold=1.0, xy=(X, y)) |
| 390 | + sspoc_mock.classifier.fit.assert_not_called() |
| 391 | + |
| 392 | + |
| 393 | +def test_warning_when_both_n_sensors_and_threshold_provided(sspoc_mock): |
| 394 | + """Test that warning is issued when both n_sensors and threshold are provided.""" |
| 395 | + with pytest.warns( |
| 396 | + UserWarning, |
| 397 | + match="Both n_sensors.*and threshold.*were passed so threshold will be ignored", |
| 398 | + ): |
| 399 | + sspoc_mock.update_sensors(n_sensors=2, threshold=0.4) |
| 400 | + |
| 401 | + |
| 402 | +def test_update_sensors_too_many_sensors_error(): |
| 403 | + n_available_sensors = 10 |
| 404 | + model = SSPOC() |
| 405 | + model.sensor_coef_ = np.random.rand(n_available_sensors) |
| 406 | + too_many_sensors = n_available_sensors + 5 |
| 407 | + |
| 408 | + expected_error = ( |
| 409 | + f"n_sensors\\({too_many_sensors}\\) cannot exceed number of " |
| 410 | + f"available sensors \\({n_available_sensors}\\)" |
| 411 | + ) |
| 412 | + |
| 413 | + with pytest.raises(ValueError, match=expected_error): |
| 414 | + model.update_sensors(n_sensors=too_many_sensors) |
| 415 | + |
| 416 | + |
| 417 | +def test_uninformative_sensors_warning(): |
| 418 | + n_available_sensors = 10 |
| 419 | + n_sensors_to_select = 6 |
| 420 | + model = SSPOC() |
| 421 | + sensor_coef = np.zeros(n_available_sensors) |
| 422 | + sensor_coef[:5] = np.random.rand(5) |
| 423 | + sensor_coef[:5] = -np.sort(-np.abs(sensor_coef[:5])) |
| 424 | + model.sensor_coef_ = sensor_coef |
| 425 | + with pytest.warns( |
| 426 | + UserWarning, |
| 427 | + match="Some uninformative sensors were selected. Consider decreasing n_sensors", |
| 428 | + ): |
| 429 | + model.update_sensors(n_sensors=n_sensors_to_select) |
| 430 | + |
| 431 | + |
| 432 | +def test_uninformative_sensors_multiclass_warning(): |
| 433 | + n_available_sensors = 10 |
| 434 | + n_classes = 3 |
| 435 | + n_sensors_to_select = 6 |
| 436 | + model = SSPOC() |
| 437 | + sensor_coef = np.zeros((n_available_sensors, n_classes)) |
| 438 | + sensor_coef[:5, :] = np.random.rand(5, n_classes) |
| 439 | + for i in range(5): |
| 440 | + sensor_coef[i, :] = np.abs(sensor_coef[i, :]) + 0.5 |
| 441 | + model.sensor_coef_ = sensor_coef |
| 442 | + with pytest.warns( |
| 443 | + UserWarning, |
| 444 | + match="Some uninformative sensors were selected. Consider decreasing n_sensors", |
| 445 | + ): |
| 446 | + model.update_sensors(n_sensors=n_sensors_to_select, method=np.mean) |
0 commit comments