|
1 | 1 | from collections.abc import Callable |
| 2 | +from unittest.mock import MagicMock |
2 | 3 |
|
3 | 4 | import numpy as np |
4 | 5 | import pytest |
@@ -320,6 +321,47 @@ def test_apply_with_extra_kwargs_on_array_raises(labelled_boolean_mask: Callable |
320 | 321 | _ = collection.apply(array, sample_frequency="test") |
321 | 322 |
|
322 | 323 |
|
| 324 | +def test_apply_unsupported_type(): |
| 325 | + pm = PixelMask([[True, False], [False, True]]) |
| 326 | + collection = PixelMaskCollection([pm]) |
| 327 | + |
| 328 | + # Passing a string should trigger the `case _:` TypeError |
| 329 | + with pytest.raises(TypeError) as excinfo: |
| 330 | + collection.apply("not a valid type") |
| 331 | + assert "Unsupported data type" in str(excinfo.value) |
| 332 | + |
| 333 | + |
| 334 | +def test_apply_to_eitdata_branch(): |
| 335 | + # Create a MagicMock that is recognized as an instance of EITData |
| 336 | + eit_data_mock = MagicMock(spec=EITData) |
| 337 | + eit_data_mock.pixel_impedance = np.array([[1.0, 2.0], [3.0, 4.0]]) |
| 338 | + |
| 339 | + # Ensure .apply() on a PixelMask returns an EITData instance (or mock) |
| 340 | + def mock_apply( |
| 341 | + _self: "PixelMaskCollection", |
| 342 | + _data: "EITData | np.ndarray | PixelMap", |
| 343 | + *, |
| 344 | + _label: str | None = None, |
| 345 | + **_kwargs: object, |
| 346 | + ) -> EITData: |
| 347 | + """Mock apply method for PixelMaskCollection.""" |
| 348 | + return MagicMock(spec=EITData) |
| 349 | + |
| 350 | + # Patch PixelMask.apply for this test |
| 351 | + PixelMask.apply = mock_apply |
| 352 | + |
| 353 | + pm1 = PixelMask([[True, False], [False, True]], label="mask1") |
| 354 | + pm2 = PixelMask([[True, True], [False, False]], label="mask2") |
| 355 | + |
| 356 | + collection = PixelMaskCollection([pm1, pm2]) |
| 357 | + |
| 358 | + result = collection.apply(eit_data_mock) |
| 359 | + |
| 360 | + assert isinstance(result, dict) |
| 361 | + assert set(result.keys()) == {"mask1", "mask2"} |
| 362 | + assert all(isinstance(v, EITData) for v in result.values()) |
| 363 | + |
| 364 | + |
323 | 365 | def test_empty_collection_behavior(): |
324 | 366 | # Allow emtpy collection initialization |
325 | 367 | _ = PixelMaskCollection() # No masks provided |
@@ -488,3 +530,22 @@ def test_combine_weighted(): |
488 | 530 | multiplied_mask = collection.combine(method="product", label="combined_product") |
489 | 531 | assert multiplied_mask.label == "combined_product" |
490 | 532 | assert np.array_equal(multiplied_mask.mask, np.array([[np.nan, 0.1], [0.06, 0.2]]), equal_nan=True) |
| 533 | + |
| 534 | + |
| 535 | +def test_combine_method_argument(): |
| 536 | + pm1 = PixelMask([[True, False], [False, True]]) |
| 537 | + pm2 = PixelMask([[True, True], [False, False]]) |
| 538 | + collection = PixelMaskCollection([pm1, pm2]) |
| 539 | + |
| 540 | + # Test sum method |
| 541 | + summed = collection.combine(method="sum") |
| 542 | + assert isinstance(summed, PixelMask) |
| 543 | + |
| 544 | + # Test product method |
| 545 | + product = collection.combine(method="product") |
| 546 | + assert isinstance(product, PixelMask) |
| 547 | + |
| 548 | + # Test unsupported method raises ValueError |
| 549 | + with pytest.raises(ValueError) as excinfo: |
| 550 | + collection.combine(method="invalid") |
| 551 | + assert "Unsupported method" in str(excinfo.value) |
0 commit comments