|
1 | 1 | from abc import ABC, abstractmethod |
2 | 2 | from pathlib import Path |
3 | | -from typing import Union |
4 | | - |
| 3 | +from typing import Union, Optional |
| 4 | +import numpy as np |
5 | 5 | from auxiliary.io import read_image, write_image |
6 | 6 |
|
7 | 7 |
|
8 | 8 | class Defacer(ABC): |
9 | | - """ |
10 | | - Base class for defacing medical images using brain masks. |
11 | | -
|
12 | | - Subclasses should implement the `deface` method to generate a defaced image |
13 | | - based on the provided input image and mask. |
14 | | - """ |
15 | | - |
16 | | - @abstractmethod |
17 | | - def deface( |
| 9 | + def __init__( |
18 | 10 | self, |
19 | | - input_image_path: Union[str, Path], |
20 | | - mask_image_path: Union[str, Path], |
21 | | - ) -> None: |
| 11 | + masking_value: Optional[Union[int, float]] = None, |
| 12 | + ): |
22 | 13 | """ |
23 | | - Generate a defacing mask provided an input image. |
| 14 | + Base class for defacing medical images using brain masks. |
24 | 15 |
|
25 | | - Args: |
26 | | - input_image_path (str or Path): Path to the input image (NIfTI format). |
27 | | - mask_image_path (str or Path): Path to the output mask image (NIfTI format). |
| 16 | + Subclasses should implement the `deface` method to generate a defaced image |
| 17 | + based on the provided input image and mask. |
28 | 18 | """ |
29 | | - pass |
| 19 | + # Here, masking value functions as a global value across all images and modalities |
| 20 | + # If no value is passed, the minimum of a given input image is chosen |
| 21 | + # TODO: Consider extending this to modality-specific masking values in the future, this should |
| 22 | + # probably be implemented as a property of the the specific modality |
| 23 | + self.masking_value = masking_value |
| 24 | + |
| 25 | + @abstractmethod |
| 26 | + def deface( |
| 27 | + self, |
| 28 | + input_image_path: Union[str, Path], |
| 29 | + mask_image_path: Union[str, Path], |
| 30 | + ) -> None: |
| 31 | + """ |
| 32 | + Generate a defacing mask provided an input image. |
| 33 | +
|
| 34 | + Args: |
| 35 | + input_image_path (str or Path): Path to the input image (NIfTI format). |
| 36 | + mask_image_path (str or Path): Path to the output mask image (NIfTI format). |
| 37 | + """ |
| 38 | + pass |
30 | 39 |
|
31 | 40 | def apply_mask( |
32 | 41 | self, |
@@ -63,8 +72,17 @@ def apply_mask( |
63 | 72 | if input_data.shape != mask_data.shape: |
64 | 73 | raise ValueError("Input image and mask must have the same dimensions.") |
65 | 74 |
|
66 | | - # Apply mask (element-wise multiplication) |
67 | | - masked_data = input_data * mask_data |
| 75 | + # check whether a global masking value was passed, otherwise choose minimum |
| 76 | + if self.masking_value is None: |
| 77 | + current_masking_value = np.min(input_data) |
| 78 | + else: |
| 79 | + current_masking_value = ( |
| 80 | + np.array(self.masking_value).astype(input_data.dtype).item() |
| 81 | + ) |
| 82 | + # Apply mask (element-wise either input or masking value) |
| 83 | + masked_data = np.where( |
| 84 | + mask_data.astype(bool), input_data, current_masking_value |
| 85 | + ) |
68 | 86 |
|
69 | 87 | # Save the defaced image |
70 | 88 | write_image( |
|
0 commit comments