Skip to content

Commit 8f92810

Browse files
committed
refactor normalizer
1 parent 44478a2 commit 8f92810

File tree

4 files changed

+37
-61
lines changed

4 files changed

+37
-61
lines changed

torch_staintools/augmentor/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ def forward(self, target: torch.Tensor, cache_keys: Optional[List[Hashable]] = N
234234
rng=self.rng,
235235
**stain_mat_kwargs)
236236

237-
target_stain_matrix = self.tensor_from_cache(cache_keys=cache_keys, func_partial=get_stain_mat_partial,
237+
target_stain_matrix = self.tensor_from_cache(cache_keys=cache_keys, func=get_stain_mat_partial,
238238
target=target)
239239

240240
# B x num_stains x num_pixel_in_mask

torch_staintools/base_module/base.py

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ class CachedRNGModule(torch.nn.Module):
3535
device: torch.device
3636
_tensor_cache: TensorCache
3737
CACHE_FIELD: str = '_tensor_cache'
38-
rng: Optional[torch.Generator]
38+
_rng: Optional[torch.Generator]
3939

4040
def _tensor_cache_helper(self) -> Optional[TensorCache]:
4141
return getattr(self, CachedRNGModule.CACHE_FIELD)
@@ -84,27 +84,13 @@ def dump_cache(self, path: str):
8484
assert self.cache_initialized()
8585
self.tensor_cache.dump(path)
8686

87-
# @staticmethod
88-
# def _stain_mat_kwargs_helper(luminosity_threshold,
89-
# num_stains,
90-
# regularizer,
91-
# **stain_mat_kwargs):
92-
# arg_dict = {
93-
# 'luminosity_threshold': luminosity_threshold,
94-
# 'num_stains': num_stains,
95-
# 'regularizer': regularizer,
96-
# }
97-
# stain_mat_kwargs = {k: v for k, v in stain_mat_kwargs.items()}
98-
# stain_mat_kwargs.update(arg_dict)
99-
# return stain_mat_kwargs
100-
10187
@staticmethod
10288
def tensor_from_cache_helper(cache: TensorCache, *,
10389
cache_keys: List[Hashable],
104-
func_partial: Callable,
90+
func: Callable,
10591
target) -> torch.Tensor:
10692

107-
stain_mat_list = cache.get_batch(cache_keys, func_partial, target)
93+
stain_mat_list = cache.get_batch(cache_keys, func, target)
10894
if isinstance(stain_mat_list, torch.Tensor):
10995
return stain_mat_list
11096

@@ -113,21 +99,21 @@ def tensor_from_cache_helper(cache: TensorCache, *,
11399
def tensor_from_cache(self,
114100
*,
115101
cache_keys: Optional[List[Hashable]],
116-
func_partial: Callable,
102+
func: Callable,
117103
target) -> torch.Tensor:
118104
if cache_keys is not None and not self.cache_initialized():
119105
logger.warning(f"Cache keys are given but the cache is not initialized: {cache_keys[:3]} etc..")
120106

121107
if not self.cache_initialized() or cache_keys is None:
122108
logger.debug(f'{self.cache_initialized()} + {cache_keys is None} - no cache')
123-
return func_partial(target)
109+
return func(target)
124110
# if using cache
125111
assert self.cache_initialized(), f"Attempt to fetch data from cache but cache is not initialized"
126112
assert cache_keys is not None, f"Attempt to fetch data from cache but key is not given"
127113
# move fetched stain matrix to the same device of the target
128114
logger.debug(f"{cache_keys[0:3]}. cache initialized")
129115
return CachedRNGModule.tensor_from_cache_helper(cache=self.tensor_cache, cache_keys=cache_keys,
130-
func_partial=func_partial,
116+
func=func,
131117
target=target).to(target.device)
132118

133119
def __init__(self, cache: Optional[TensorCache], device: Optional[torch.device],
@@ -136,7 +122,15 @@ def __init__(self, cache: Optional[TensorCache], device: Optional[torch.device],
136122

137123
self._tensor_cache = cache
138124
self.device = default_device(device)
139-
self.rng = default_rng(rng, self.device)
125+
self._rng = default_rng(rng, self.device)
126+
127+
@property
128+
def rng(self):
129+
return self._rng
130+
131+
@rng.setter
132+
def rng(self, rng: torch.Generator):
133+
self._rng = default_rng(rng, self.device)
140134

141135
@classmethod
142136
@abstractmethod

torch_staintools/functional/stain_extraction/macenko.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Callable
1+
from typing import Callable, Optional
22

33
import torch
44
from .utils import percentile, cov
@@ -11,8 +11,11 @@ class MckCfg:
1111
Attributes:
1212
perc: Percentile number to find the minimum angular term. min angular as 1 percentile
1313
max angular as 100 - perc percentile.
14+
rng: torch.Generator for any random initializations incurred (e.g., if `init` is set to be unif)
15+
1416
"""
1517
perc: int
18+
rng: Optional[torch.Generator]
1619

1720
class MacenkoAlg(Callable):
1821
cfg: MckCfg

torch_staintools/normalizer/separation.py

Lines changed: 17 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,9 @@
33
"""
44

55
import torch
6-
from torch_staintools.functional.stain_extraction.extractor import StainExtraction
6+
from torch_staintools.functional.stain_extraction.extractor import StainExtraction, StainAlg
77
from ..functional.optimization.sparse_util import METHOD_FACTORIZE
88
from ..functional.optimization.concentration import get_concentrations
9-
from torch_staintools.functional.stain_extraction.factory import build_from_name
109
from torch_staintools.functional.stain_extraction.utils import percentile
1110
from torch_staintools.functional.utility.implementation import transpose_trailing, img_from_concentration
1211
from .base import Normalizer
@@ -33,11 +32,10 @@ class StainSeparation(Normalizer):
3332
rng: torch.Generator
3433
concentration_method: METHOD_FACTORIZE
3534

36-
def __init__(self, get_stain_matrix: StainExtraction,
35+
def __init__(self, stain_alg: StainAlg,
3736
concentration_method: METHOD_FACTORIZE = 'fista',
3837
num_stains: int = 2,
3938
luminosity_threshold: float = 0.8,
40-
regularizer: float = 0.1,
4139
rng: Optional[int | torch.Generator] = None,
4240
cache: Optional[TensorCache] = None,
4341
device: Optional[torch.device] = None):
@@ -48,8 +46,7 @@ def __init__(self, get_stain_matrix: StainExtraction,
4846
regardless of batch size. Therefore, 'ls' is better for multiple small inputs in terms of H and W.
4947
5048
Args:
51-
get_stain_matrix: the Callable to obtain stain matrix - e.g., Vahadane's dict learning or
52-
macenko's SVD
49+
stain_alg: the Callable to obtain stain matrix - e.g., Vahadane or Macenko
5350
concentration_method: How to get stain concentration from stain matrix and OD through factorization.
5451
support 'ista', 'cd', and 'ls'. 'ls' simply solves the least square problem for factorization of
5552
min||HExC - OD|| but is faster. 'ista' and 'cd' enforce the sparse penalty but are slower.
@@ -58,17 +55,15 @@ def __init__(self, get_stain_matrix: StainExtraction,
5855
In general cases it is recommended to set num_stains as 2.
5956
luminosity_threshold: luminosity threshold to ignore the background. None means all regions are considered
6057
as tissue.
61-
regularizer: Regularizer term in dict learning. Note that similar to staintools, for image
62-
reconstruction step, we also use dictionary learning to get the target stain concentration.
6358
"""
6459
super().__init__(cache=cache, device=device, rng=rng)
6560
self.concentration_method = concentration_method
66-
self.get_stain_matrix = get_stain_matrix
61+
self.get_stain_matrix = StainExtraction(stain_alg)
6762
self.num_stains = num_stains
6863
self.luminosity_threshold = luminosity_threshold
69-
self.regularizer = regularizer
7064

71-
def fit(self, target, concentration_method: Optional[METHOD_FACTORIZE] = None, **stainmat_kwargs):
65+
66+
def fit(self, target, concentration_method: Optional[METHOD_FACTORIZE] = None):
7267
"""Fit to a target image.
7368
7469
Note that the stain matrices are registered into buffers so that they are moved to the specified device
@@ -78,18 +73,13 @@ def fit(self, target, concentration_method: Optional[METHOD_FACTORIZE] = None, *
7873
target: BCHW. Assume it's cast to torch.float32 and scaled to [0, 1]
7974
concentration_method: method to obtain concentration. Use the `self.concentration_method` if not specified
8075
in the signature.
81-
**stainmat_kwargs: Extra keyword argument of stain seperator, besides the num_stains/luminosity_threshold
82-
that are set in the __init__
8376
8477
Returns:
8578
8679
"""
8780
assert target.shape[0] == 1
8881
stain_matrix_target = self.get_stain_matrix(target, num_stains=self.num_stains,
89-
regularizer=self.regularizer,
90-
luminosity_threshold=self.luminosity_threshold,
91-
rng=self.rng,
92-
**stainmat_kwargs)
82+
luminosity_threshold=self.luminosity_threshold)
9383

9484
self.register_buffer('stain_matrix_target', stain_matrix_target)
9585
target_conc = get_concentrations(target, self.stain_matrix_target, regularizer=self.regularizer,
@@ -119,7 +109,7 @@ def repeat_stain_mat(stain_mat: torch.Tensor, image: torch.Tensor) -> torch.Tens
119109

120110
def transform(self, image: torch.Tensor,
121111
cache_keys: Optional[List[Hashable]] = None,
122-
**stain_mat_kwargs) -> torch.Tensor:
112+
**kwargs) -> torch.Tensor:
123113
"""Transformation operation.
124114
125115
Stain matrix is extracted from the source image using the specified stain separator (dict learning or SVD)
@@ -131,22 +121,14 @@ def transform(self, image: torch.Tensor,
131121
image: Image input must be BxCxHxW cast to torch.float32 and rescaled to [0, 1]
132122
Check torchvision.transforms.convert_image_dtype.
133123
cache_keys: unique keys that map the input batch to the cached stain matrices. `None` means no cache.
134-
**stain_mat_kwargs: Extra keyword argument of stain seperator besides the num_stains
135-
and luminosity_threshold that was already set in __init__.
136-
For instance, in Macenko, an angular percentile argument "perc" may be selected to separate
137-
the angles of OD vector projected on SVD and the x-positive axis.
138124
139125
Returns:
140126
torch.Tensor: normalized output in BxCxHxW shape and float32 dtype. Note that some pixel value may exceed
141127
[0, 1] and therefore a clipping operation is applied.
142128
"""
143129
# one source matrix - multiple target
144-
get_stain_mat_partial = self.get_stain_matrix.get_partial(luminosity_threshold=self.luminosity_threshold,
145-
num_stains=self.num_stains,
146-
regularizer=self.regularizer,
147-
rng=self.rng,
148-
**stain_mat_kwargs)
149-
stain_matrix_source = self.tensor_from_cache(cache_keys=cache_keys, func_partial=get_stain_mat_partial,
130+
get_stain_matrix = self.get_stain_matrix
131+
stain_matrix_source = self.tensor_from_cache(cache_keys=cache_keys, func=get_stain_matrix,
150132
target=image)
151133

152134
# stain_matrix_source -- B x 2 x 3 wherein B is 1. Note that the input batch size is independent of how many
@@ -186,14 +168,14 @@ def forward(self, x: torch.Tensor,
186168
torch.Tensor: normalized output in BxCxHxW shape and float32 dtype. Note that some pixel value may exceed
187169
[0, 1] and therefore a clipping operation is applied.
188170
"""
189-
return self.transform(x, cache_keys, **stain_mat_kwargs)
171+
return self.transform(x, cache_keys)
190172

191173
@classmethod
192-
def build(cls, method: str,
174+
def build(cls,
175+
stain_alg: StainAlg,
193176
concentration_method: METHOD_FACTORIZE = 'fista',
194177
num_stains: int = 2,
195178
luminosity_threshold: float = 0.8,
196-
regularizer: float = 0.1,
197179
rng: Optional[int | torch.Generator] = None,
198180
use_cache: bool = False,
199181
cache_size_limit: int = -1,
@@ -203,16 +185,15 @@ def build(cls, method: str,
203185
"""Builder.
204186
205187
Args:
206-
method: method of stain extractor name: vadahane or macenko
188+
stain_alg: stain algorithm to use.
207189
concentration_method: method to obtain the concentration. Defaults to 'fista' for computational efficiency on GPU.
208190
support 'ista', 'cd', and 'ls'. 'ls' simply solves the least square problem for factorization of
209191
min||HExC - OD|| but is faster. 'ista' and 'cd' enforce the sparse penalty but are slower.
210192
num_stains: number of stains to separate. Currently, Macenko only supports 2. In general cases it is
211193
recommended to set num_stains as 2.
212194
luminosity_threshold: luminosity threshold to ignore the background. None means all regions are considered
213195
as tissue.
214-
regularizer: regularizer term in ista for stain separation and concentration computation.
215-
rng: seed or torch.Generator for any random initialization might incur.
196+
rng: Optional. Seed for reproducibility.
216197
use_cache: whether to use cache to save the stain matrix of input image to normalize
217198
cache_size_limit: size limit of the cache. negative means no limits.
218199
device: what device to hold the cache and the normalizer. If none the device is set to cpu.
@@ -222,10 +203,8 @@ def build(cls, method: str,
222203
Returns:
223204
StainSeparation normalizer.
224205
"""
225-
method = method.lower()
226-
extractor = build_from_name(method)
227206
cache = cls._init_cache(use_cache, cache_size_limit=cache_size_limit, device=device,
228207
load_path=load_path)
229-
return cls(extractor, concentration_method=concentration_method, num_stains=num_stains,
230-
luminosity_threshold=luminosity_threshold, regularizer=regularizer, rng=rng,
208+
return cls(stain_alg, concentration_method=concentration_method, num_stains=num_stains,
209+
luminosity_threshold=luminosity_threshold, rng=rng,
231210
cache=cache, device=device).to(device)

0 commit comments

Comments
 (0)