|
1 | 1 | from collections.abc import Callable |
| 2 | +from unittest.mock import MagicMock |
2 | 3 |
|
3 | 4 | import numpy as np |
4 | 5 | import pytest |
@@ -320,6 +321,40 @@ def test_apply_with_extra_kwargs_on_array_raises(labelled_boolean_mask: Callable |
320 | 321 | _ = collection.apply(array, sample_frequency="test") |
321 | 322 |
|
322 | 323 |
|
| 324 | +def test_apply_unsupported_type(): |
| 325 | + pm = PixelMask([[True, False], [False, True]]) |
| 326 | + collection = PixelMaskCollection([pm]) |
| 327 | + |
| 328 | + # Passing a string should trigger the `case _:` TypeError |
| 329 | + with pytest.raises(TypeError) as excinfo: |
| 330 | + collection.apply("not a valid type") |
| 331 | + assert "Unsupported data type" in str(excinfo.value) |
| 332 | + |
| 333 | + |
| 334 | +def test_apply_to_eitdata_branch(): |
| 335 | + # Create a MagicMock that is recognized as an instance of EITData |
| 336 | + eit_data_mock = MagicMock(spec=EITData) |
| 337 | + eit_data_mock.pixel_impedance = np.array([[1.0, 2.0], [3.0, 4.0]]) |
| 338 | + |
| 339 | + # Ensure .apply() on a PixelMask returns an EITData instance (or mock) |
| 340 | + def mock_apply(self, data, **kwargs): |
| 341 | + return MagicMock(spec=EITData) |
| 342 | + |
| 343 | + # Patch PixelMask.apply for this test |
| 344 | + PixelMask.apply = mock_apply |
| 345 | + |
| 346 | + pm1 = PixelMask([[True, False], [False, True]], label="mask1") |
| 347 | + pm2 = PixelMask([[True, True], [False, False]], label="mask2") |
| 348 | + |
| 349 | + collection = PixelMaskCollection([pm1, pm2]) |
| 350 | + |
| 351 | + result = collection.apply(eit_data_mock) |
| 352 | + |
| 353 | + assert isinstance(result, dict) |
| 354 | + assert set(result.keys()) == {"mask1", "mask2"} |
| 355 | + assert all(isinstance(v, EITData) for v in result.values()) |
| 356 | + |
| 357 | + |
323 | 358 | def test_empty_collection_behavior(): |
324 | 359 | # Allow emtpy collection initialization |
325 | 360 | _ = PixelMaskCollection() # No masks provided |
@@ -488,3 +523,22 @@ def test_combine_weighted(): |
488 | 523 | multiplied_mask = collection.combine(method="product", label="combined_product") |
489 | 524 | assert multiplied_mask.label == "combined_product" |
490 | 525 | assert np.array_equal(multiplied_mask.mask, np.array([[np.nan, 0.1], [0.06, 0.2]]), equal_nan=True) |
| 526 | + |
| 527 | + |
| 528 | +def test_combine_method_argument(): |
| 529 | + pm1 = PixelMask([[True, False], [False, True]]) |
| 530 | + pm2 = PixelMask([[True, True], [False, False]]) |
| 531 | + collection = PixelMaskCollection([pm1, pm2]) |
| 532 | + |
| 533 | + # Test sum method |
| 534 | + summed = collection.combine(method="sum") |
| 535 | + assert isinstance(summed, PixelMask) |
| 536 | + |
| 537 | + # Test product method |
| 538 | + product = collection.combine(method="product") |
| 539 | + assert isinstance(product, PixelMask) |
| 540 | + |
| 541 | + # Test unsupported method raises ValueError |
| 542 | + with pytest.raises(ValueError) as excinfo: |
| 543 | + collection.combine(method="invalid") |
| 544 | + assert "Unsupported method" in str(excinfo.value) |
0 commit comments