33"""
44
55import torch
6- from torch_staintools .functional .stain_extraction .extractor import StainExtraction
6+ from torch_staintools .functional .stain_extraction .extractor import StainExtraction , StainAlg
77from ..functional .optimization .sparse_util import METHOD_FACTORIZE
88from ..functional .optimization .concentration import get_concentrations
9- from torch_staintools .functional .stain_extraction .factory import build_from_name
109from torch_staintools .functional .stain_extraction .utils import percentile
1110from torch_staintools .functional .utility .implementation import transpose_trailing , img_from_concentration
1211from .base import Normalizer
@@ -33,11 +32,10 @@ class StainSeparation(Normalizer):
3332 rng : torch .Generator
3433 concentration_method : METHOD_FACTORIZE
3534
36- def __init__ (self , get_stain_matrix : StainExtraction ,
35+ def __init__ (self , stain_alg : StainAlg ,
3736 concentration_method : METHOD_FACTORIZE = 'fista' ,
3837 num_stains : int = 2 ,
3938 luminosity_threshold : float = 0.8 ,
40- regularizer : float = 0.1 ,
4139 rng : Optional [int | torch .Generator ] = None ,
4240 cache : Optional [TensorCache ] = None ,
4341 device : Optional [torch .device ] = None ):
@@ -48,8 +46,7 @@ def __init__(self, get_stain_matrix: StainExtraction,
4846 regardless of batch size. Therefore, 'ls' is better for multiple small inputs in terms of H and W.
4947
5048 Args:
51- get_stain_matrix: the Callable to obtain stain matrix - e.g., Vahadane's dict learning or
52- macenko's SVD
49+ stain_alg: the Callable to obtain stain matrix - e.g., Vahadane or Macenko
5350 concentration_method: How to get stain concentration from stain matrix and OD through factorization.
5451 support 'ista', 'cd', and 'ls'. 'ls' simply solves the least square problem for factorization of
5552 min||HExC - OD|| but is faster. 'ista'/cd enforce the sparse penalty but slower.
@@ -58,17 +55,15 @@ def __init__(self, get_stain_matrix: StainExtraction,
5855 In general cases it is recommended to set num_stains as 2.
5956 luminosity_threshold: luminosity threshold to ignore the background. None means all regions are considered
6057 as tissue.
61- regularizer: Regularizer term in dict learning. Note that similar to staintools, for image
62- reconstruction step, we also use dictionary learning to get the target stain concentration.
6358 """
6459 super ().__init__ (cache = cache , device = device , rng = rng )
6560 self .concentration_method = concentration_method
66- self .get_stain_matrix = get_stain_matrix
61+ self .get_stain_matrix = StainExtraction ( stain_alg )
6762 self .num_stains = num_stains
6863 self .luminosity_threshold = luminosity_threshold
69- self .regularizer = regularizer
7064
71- def fit (self , target , concentration_method : Optional [METHOD_FACTORIZE ] = None , ** stainmat_kwargs ):
65+
66+ def fit (self , target , concentration_method : Optional [METHOD_FACTORIZE ] = None ):
7267 """Fit to a target image.
7368
7469 Note that the stain matrices are registered into buffers so that it's move to specified device
@@ -78,18 +73,13 @@ def fit(self, target, concentration_method: Optional[METHOD_FACTORIZE] = None, *
7873 target: BCHW. Assume it's cast to torch.float32 and scaled to [0, 1]
7974 concentration_method: method to obtain concentration. Use the `self.concentration_method` if not specified
8075 in the signature.
81- **stainmat_kwargs: Extra keyword argument of stain seperator, besides the num_stains/luminosity_threshold
82- that are set in the __init__
8376
8477 Returns:
8578
8679 """
8780 assert target .shape [0 ] == 1
8881 stain_matrix_target = self .get_stain_matrix (target , num_stains = self .num_stains ,
89- regularizer = self .regularizer ,
90- luminosity_threshold = self .luminosity_threshold ,
91- rng = self .rng ,
92- ** stainmat_kwargs )
82+ luminosity_threshold = self .luminosity_threshold )
9383
9484 self .register_buffer ('stain_matrix_target' , stain_matrix_target )
9585 target_conc = get_concentrations (target , self .stain_matrix_target , regularizer = self .regularizer ,
@@ -119,7 +109,7 @@ def repeat_stain_mat(stain_mat: torch.Tensor, image: torch.Tensor) -> torch.Tens
119109
120110 def transform (self , image : torch .Tensor ,
121111 cache_keys : Optional [List [Hashable ]] = None ,
122- ** stain_mat_kwargs ) -> torch .Tensor :
112+ ** kwargs ) -> torch .Tensor :
123113 """Transformation operation.
124114
125115 Stain matrix is extracted from source image use specified stain seperator (dict learning or svd)
@@ -131,22 +121,14 @@ def transform(self, image: torch.Tensor,
131121 image: Image input must be BxCxHxW cast to torch.float32 and rescaled to [0, 1]
132122 Check torchvision.transforms.convert_image_dtype.
133123 cache_keys: unique keys point the input batch to the cached stain matrices. `None` means no cache.
134- **stain_mat_kwargs: Extra keyword argument of stain seperator besides the num_stains
135- and luminosity_threshold that was already set in __init__.
136- For instance, in Macenko, an angular percentile argument "perc" may be selected to separate
137- the angles of OD vector projected on SVD and the x-positive axis.
138124
139125 Returns:
140126 torch.Tensor: normalized output in BxCxHxW shape and float32 dtype. Note that some pixel value may exceed
141127 [0, 1] and therefore a clipping operation is applied.
142128 """
143129 # one source matrix - multiple target
144- get_stain_mat_partial = self .get_stain_matrix .get_partial (luminosity_threshold = self .luminosity_threshold ,
145- num_stains = self .num_stains ,
146- regularizer = self .regularizer ,
147- rng = self .rng ,
148- ** stain_mat_kwargs )
149- stain_matrix_source = self .tensor_from_cache (cache_keys = cache_keys , func_partial = get_stain_mat_partial ,
130+ get_stain_matrix = self .get_stain_matrix
131+ stain_matrix_source = self .tensor_from_cache (cache_keys = cache_keys , func = get_stain_matrix ,
150132 target = image )
151133
152134 # stain_matrix_source -- B x 2 x 3 wherein B is 1. Note that the input batch size is independent of how many
@@ -186,14 +168,14 @@ def forward(self, x: torch.Tensor,
186168 torch.Tensor: normalized output in BxCxHxW shape and float32 dtype. Note that some pixel value may exceed
187169 [0, 1] and therefore a clipping operation is applied.
188170 """
189- return self .transform (x , cache_keys , ** stain_mat_kwargs )
171+ return self .transform (x , cache_keys )
190172
191173 @classmethod
192- def build (cls , method : str ,
174+ def build (cls ,
175+ stain_alg : StainAlg ,
193176 concentration_method : METHOD_FACTORIZE = 'fista' ,
194177 num_stains : int = 2 ,
195178 luminosity_threshold : float = 0.8 ,
196- regularizer : float = 0.1 ,
197179 rng : Optional [int | torch .Generator ] = None ,
198180 use_cache : bool = False ,
199181 cache_size_limit : int = - 1 ,
@@ -203,16 +185,15 @@ def build(cls, method: str,
203185 """Builder.
204186
205187 Args:
206- method: method of stain extractor name: vadahane or macenko
188+ stain_alg: stain algorithm to use.
207189 concentration_method: method to obtain the concentration. default ista for computational efficiency on GPU.
208190 support 'ista', 'cd', and 'ls'. 'ls' simply solves the least square problem for factorization of
209191 min||HExC - OD|| but is faster. 'ista'/cd enforce the sparse penalty but slower.
210192 num_stains: number of stains to separate. Currently, Macenko only supports 2. In general cases it is
211193 recommended to set num_stains as 2.
212194 luminosity_threshold: luminosity threshold to ignore the background. None means all regions are considered
213195 as tissue.
214- regularizer: regularizer term in ista for stain separation and concentration computation.
215- rng: seed or torch.Generator for any random initialization might incur.
196+ rng: Optional. Seed for reproducibility.
216197 use_cache: whether to use cache to save the stain matrix of input image to normalize
217198 cache_size_limit: size limit of the cache. negative means no limits.
218199 device: what device to hold the cache and the normalizer. If none the device is set to cpu.
@@ -222,10 +203,8 @@ def build(cls, method: str,
222203 Returns:
223204 StainSeparation normalizer.
224205 """
225- method = method .lower ()
226- extractor = build_from_name (method )
227206 cache = cls ._init_cache (use_cache , cache_size_limit = cache_size_limit , device = device ,
228207 load_path = load_path )
229- return cls (extractor , concentration_method = concentration_method , num_stains = num_stains ,
230- luminosity_threshold = luminosity_threshold , regularizer = regularizer , rng = rng ,
208+ return cls (stain_alg , concentration_method = concentration_method , num_stains = num_stains ,
209+ luminosity_threshold = luminosity_threshold , rng = rng ,
231210 cache = cache , device = device ).to (device )
0 commit comments