Skip to content

Commit 40bdd20

Browse files
committed
(1) Fix the tissue mask for the augmentor. (2) Add a cache for stain matrices in augmentation to accelerate re-augmentation of previously seen inputs.
1 parent 7f867b8 commit 40bdd20

File tree

10 files changed

+478
-58
lines changed

10 files changed

+478
-58
lines changed

README.md

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,11 @@ target_tensor = ToTensor()(target).unsqueeze(0).to(device)
5353
norm_tensor = ToTensor()(norm).unsqueeze(0).to(device)
5454

5555
# ######## Normalization
56-
# fit
56+
# create the normalizer - using vahadane. Alternatively can use 'macenko' or 'reinhard'.
5757
normalizer_vahadane = NormalizerBuilder.build('vahadane')
58+
# move the normalizer to the device (CPU or GPU)
5859
normalizer_vahadane = normalizer_vahadane.to(device)
60+
# fit. For macenko and vahadane this step will compute the stain matrix and concentration
5961
normalizer_vahadane.fit(target_tensor)
6062
# transform
6163
# BCHW - scaled to [0, 1] torch.float32
@@ -65,15 +67,37 @@ output = normalizer_vahadane(norm_tensor)
6567
# augment by: alpha * concentration + beta, while alpha is uniformly randomly sampled from (1 - sigma_alpha, 1 + sigma_alpha),
6668
# and beta is uniformly randomly sampled from (-sigma_beta, sigma_beta).
6769
augmentor = AugmentorBuilder.build('vahadane',
70+
# fix the random number generator seed for reproducibility.
6871
rng=314159,
72+
# the luminosity threshold to find the tissue region to augment
73+
# if set to None, all pixels are treated as tissue
74+
luminosity_threshold=0.8,
75+
6976
sigma_alpha=0.2,
70-
sigma_beta=0.2, target_stain_idx=(0, 1)
77+
sigma_beta=0.2, target_stain_idx=(0, 1),
78+
# this allows caching the stain matrix if it is too time-consuming to recompute.
79+
# e.g., if using Vahadane algorithm
80+
use_cache=True,
81+
# size limit of cache. -1 means no limit (stain matrix is often small in size, e.g., 2 x 3)
82+
cache_size_limit=-1,
83+
# if specified, the augmentor will load the cached stain matrices from the file system.
84+
load_path=None,
7185
)
7286

7387
num_augment = 5
88+
# multiple copies of different random augmentation of the same tile may be generated
7489
for _ in range(num_augment):
7590
# B x C x H x W
76-
aug_out = augmentor(norm_tensor)
91+
# use a list of Hashable keys (e.g., str) to map each batch input to its corresponding stain matrix in the cache.
92+
# this key should be unique, e.g., using the filename of the input tile.
93+
# leave it as None if no caching is intended, even if use_cache is enabled.
94+
# note that since the inputs are all batchified, the cache_keys are given in the form of a list, with each element in the
95+
# list corresponding to a data point in the batch.
96+
aug_out = augmentor(norm_tensor, cache_keys=['some unique key'])
97+
# do anything to the augmentation output
98+
99+
# dump the cache of stain matrices for future usage
100+
augmentor.dump_cache('./cache.pickle')
77101
```
78102

79103
## Installation

demo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def postprocess(image_tensor): return convert_image_dtype(image_tensor, torch.ui
115115
for idx, tile_single in enumerate(tqdm(tiles)):
116116
tile_single = tile_single.unsqueeze(0).contiguous()
117117
# BCHW - scaled to [0 1] torch.float32
118-
test_out_tensor = augmentor(tile_single, regularizer=0.01, )
118+
test_out_tensor = augmentor(tile_single)
119119
test_out = postprocess(test_out_tensor)
120120
plt.imshow(test_out)
121121
plt.title(f"Augmented: {idx}")

torch_staintools/augmentor/base.py

Lines changed: 184 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,29 @@
1-
from torch_staintools.functional.stain_extraction.factory import build_from_name
21
from torch import nn
32
import torch
4-
from torch_staintools.functional.optimization.dict_learning import get_concentrations
5-
from torch_staintools.functional.stain_extraction.extractor import BaseExtractor
6-
from torch_staintools.functional.utility.implementation import transpose_trailing, img_from_concentration
7-
from torch_staintools.functional.tissue_mask import get_tissue_mask
8-
from operator import mul
9-
from functools import reduce
10-
from typing import Optional, Sequence, Tuple
11-
import multiprocessing as mp
12-
import ctypes
13-
import numpy as np
3+
from typing import Optional, Sequence, Tuple, Hashable, List
4+
from ..functional.utility.implementation import default_device
5+
# from operator import mul
6+
# from functools import reduce
7+
# import multiprocessing as mp
8+
# import ctypes
9+
# import numpy as np
10+
from ..functional.stain_extraction.factory import build_from_name
11+
from ..functional.optimization.dict_learning import get_concentrations
12+
from ..functional.stain_extraction.extractor import BaseExtractor
13+
from ..functional.utility.implementation import transpose_trailing, img_from_concentration
14+
from ..functional.tissue_mask import get_tissue_mask
15+
from ..cache.tensor_cache import TensorCache
16+
from ..loggers import GlobalLoggers
17+
18+
logger = GlobalLoggers.instance().get_logger(__name__)
1419

1520

1621
class Augmentor(nn.Module):
17-
use_cache: bool
22+
device: torch.device
23+
24+
_tensor_cache: TensorCache
25+
CACHE_FIELD: str = '_tensor_cache'
26+
1827
target_stain_idx: Optional[Sequence[int]]
1928
rng: torch.Generator
2029

@@ -29,14 +38,23 @@ class Augmentor(nn.Module):
2938
luminosity_threshold: float
3039
regularizer: float
3140

41+
@staticmethod
42+
def _init_cache(use_cache: bool, cache_size_limit: int, device: Optional[torch.device] = None,
43+
load_path: Optional[str] = None) -> Optional[TensorCache]:
44+
if not use_cache:
45+
return None
46+
return TensorCache.build(size_limit=cache_size_limit, device=device, path=load_path)
47+
3248
def __init__(self, get_stain_matrix: BaseExtractor, reconst_method: str = 'ista',
3349
rng: Optional[int | torch.Generator] = None,
3450
target_stain_idx: Optional[Sequence[int]] = (0, 1),
3551
sigma_alpha: float = 0.2,
3652
sigma_beta: float = 0.2,
3753
num_stains: int = 2,
38-
luminosity_threshold: float = 0.8,
39-
regularizer: float = 0.01):
54+
luminosity_threshold: Optional[float] = 0.8,
55+
regularizer: float = 0.1,
56+
cache: Optional[TensorCache] = None,
57+
device: Optional[torch.device] = None):
4058
"""Augment the stain concentration by alpha * concentration + beta
4159
4260
Args:
@@ -50,7 +68,7 @@ def __init__(self, get_stain_matrix: BaseExtractor, reconst_method: str = 'ista'
5068
luminosity_threshold: luminosity threshold to obtain tissue region and ignore brighter backgrounds.
5169
If None, all image pixels will be considered as tissue for stain matrix/concentration computation.
5270
regularizer: the regularizer to compute concentration used in ISTA or CD algorithm.
53-
71+
cache: the external cache object
5472
"""
5573
super().__init__()
5674
self.reconst_method = reconst_method
@@ -65,6 +83,25 @@ def __init__(self, get_stain_matrix: BaseExtractor, reconst_method: str = 'ista'
6583
self.luminosity_threshold = luminosity_threshold
6684
self.regularizer = regularizer
6785

86+
self._tensor_cache = cache
87+
self.device = default_device(device)
88+
89+
def to(self, device: torch.device):
90+
self.device = device
91+
if self.cache_initialized():
92+
self.tensor_cache.to(device)
93+
return super().to(device)
94+
95+
@property
96+
def cache_size_limit(self) -> int:
97+
if self.cache_initialized():
98+
return self.tensor_cache.size_limit
99+
return 0
100+
101+
def dump_cache(self, path: str):
102+
assert self.cache_initialized()
103+
self.tensor_cache.dump(path)
104+
68105
@staticmethod
69106
def _default_rng(rng: Optional[torch.Generator | int]):
70107
if rng is None:
@@ -74,21 +111,21 @@ def _default_rng(rng: Optional[torch.Generator | int]):
74111
assert isinstance(rng, torch.Generator)
75112
return rng
76113

77-
@staticmethod
78-
def new_cache(shape):
79-
"""
80-
Args:
81-
shape:
82-
83-
Returns:
84-
85-
"""
86-
# Todo map the key to the corresponding cached data -- cached in file or to memory?
87-
#
88-
shared_array_base = mp.Array(ctypes.c_float, reduce(mul, shape))
89-
shared_array = np.ctypeslib.as_array(shared_array_base.get_obj())
90-
shared_array = shared_array.reshape(*shape)
91-
return shared_array
114+
# @staticmethod
115+
# def new_cache(shape):
116+
# """
117+
# Args:
118+
# shape:
119+
#
120+
# Returns:
121+
#
122+
# """
123+
# # Todo map the key to the corresponding cached data -- cached in file or to memory?
124+
# #
125+
# shared_array_base = mp.Array(ctypes.c_float, reduce(mul, shape))
126+
# shared_array = np.ctypeslib.as_array(shared_array_base.get_obj())
127+
# shared_array = shared_array.reshape(*shape)
128+
# return shared_array
92129

93130
@staticmethod
94131
def __concentration_selected(target_concentration: torch.Tensor,
@@ -123,9 +160,13 @@ def __inplace_augment_helper(target_concentration: torch.Tensor, *,
123160
"""
124161
alpha = alpha.to(target_concentration.device)
125162
beta = beta.to(target_concentration.device)
126-
tissue_mask = tissue_mask.ravel()
127-
target_concentration[..., tissue_mask] *= alpha
128-
target_concentration += beta
163+
164+
tissue_mask_flattened = tissue_mask.flatten(start_dim=-2, end_dim=-1).expand(target_concentration.shape)
165+
alpha_expanded = alpha.expand(target_concentration.shape)
166+
target_concentration[..., tissue_mask_flattened] *= alpha_expanded[..., tissue_mask_flattened]
167+
168+
beta_expanded = beta.expand(target_concentration.shape)
169+
target_concentration[..., tissue_mask_flattened] += beta_expanded[..., tissue_mask_flattened]
129170
return target_concentration
130171

131172
@staticmethod
@@ -142,7 +183,7 @@ def channel_rand(target_concentration_selected: torch.Tensor, rng: torch.Generat
142183
143184
Args:
144185
target_concentration_selected: concentrations to work on (e.g., the entire or a subset of concentration
145-
matrix
186+
matrix)
146187
rng: torch.Generator object
147188
sigma_alpha: sample alpha values in range (1-sigma, 1+ sigma)
148189
sigma_beta: sample beta values in range (-sigma, sigma)
@@ -197,11 +238,80 @@ def augment(*,
197238
alpha=alpha, beta=beta)
198239
return target_concentration
199240

200-
def forward(self, target: torch.Tensor, **stain_mat_kwargs):
241+
@staticmethod
242+
def _stain_mat_kwargs_helper(luminosity_threshold,
243+
num_stains,
244+
regularizer,
245+
**stain_mat_kwargs):
246+
arg_dict = {
247+
'luminosity_threshold': luminosity_threshold,
248+
'num_stains': num_stains,
249+
'regularizer': regularizer,
250+
}
251+
stain_mat_kwargs = {k: v for k, v in stain_mat_kwargs.items()}
252+
stain_mat_kwargs.update(arg_dict)
253+
return stain_mat_kwargs
254+
255+
@staticmethod
256+
def stain_mat_from_cache(cache: TensorCache, *,
257+
cache_keys: List[Hashable],
258+
get_stain_matrix: BaseExtractor,
259+
target,
260+
luminosity_threshold,
261+
num_stains,
262+
regularizer,
263+
**stain_mat_kwargs) -> torch.Tensor:
264+
cache_func_kwargs = Augmentor._stain_mat_kwargs_helper(luminosity_threshold, num_stains, regularizer,
265+
**stain_mat_kwargs)
266+
stain_mat_list = cache.get_batch(cache_keys, get_stain_matrix, target, **cache_func_kwargs)
267+
if isinstance(stain_mat_list, torch.Tensor):
268+
return stain_mat_list
269+
270+
return torch.stack(stain_mat_list, dim=0)
271+
272+
def _tensor_cache_helper(self) -> Optional[TensorCache]:
273+
return getattr(self, Augmentor.CACHE_FIELD)
274+
275+
def cache_initialized(self):
276+
return hasattr(self, Augmentor.CACHE_FIELD) and self._tensor_cache_helper() is not None
277+
278+
@property
279+
def tensor_cache(self) -> Optional[TensorCache]:
280+
return self._tensor_cache_helper()
281+
282+
def stain_matrix_helper(self,
283+
*,
284+
cache_keys: Optional[List[Hashable]],
285+
get_stain_matrix: BaseExtractor,
286+
target,
287+
luminosity_threshold,
288+
num_stains,
289+
regularizer,
290+
**stain_mat_kwargs) -> torch.Tensor:
291+
if not self.cache_initialized() or cache_keys is None:
292+
logger.debug(f'{self.cache_initialized()} + {cache_keys is None} - no cache')
293+
return get_stain_matrix(target, luminosity_threshold=luminosity_threshold,
294+
num_stains=num_stains,
295+
regularizer=regularizer,
296+
**stain_mat_kwargs)
297+
# if use cache
298+
assert self.cache_initialized(), f"Attempt to fetch data from cache but cache is not initialized"
299+
assert cache_keys is not None, f"Attempt to fetch data from cache but key is not given"
300+
# move fetched stain matrix to the same device of the target
301+
logger.debug(f"{cache_keys[0:3]}. cache initialized")
302+
return Augmentor.stain_mat_from_cache(cache=self.tensor_cache, cache_keys=cache_keys,
303+
get_stain_matrix=get_stain_matrix,
304+
target=target,
305+
luminosity_threshold=luminosity_threshold, num_stains=num_stains,
306+
regularizer=regularizer, **stain_mat_kwargs,
307+
).to(target.device)
308+
309+
def forward(self, target: torch.Tensor, cache_keys: Optional[List[Hashable]] = None, **stain_mat_kwargs):
201310
"""
202311
203312
Args:
204313
target: input tensor to augment. Shape B x C x H x W and intensity range is [0, 1].
314+
cache_keys: a list of unique keys mapping each input entry to its cached stain matrix. `None` means no caching.
205315
**stain_mat_kwargs: all extra keyword arguments other than regularizer/num_stains/luminosity_threshold set
206316
in __init__.
207317
@@ -210,10 +320,11 @@ def forward(self, target: torch.Tensor, **stain_mat_kwargs):
210320
"""
211321
# stain_matrix_target -- B x num_stain x num_input_color_channel
212322
# todo cache
213-
target_stain_matrix = self.get_stain_matrix(target, luminosity_threshold=self.luminosity_threshold,
214-
num_stains=self.num_stains,
215-
regularizer=self.regularizer,
216-
**stain_mat_kwargs)
323+
target_stain_matrix = self.stain_matrix_helper(cache_keys=cache_keys, get_stain_matrix=self.get_stain_matrix,
324+
target=target, luminosity_threshold=self.luminosity_threshold,
325+
num_stains=self.num_stains,
326+
regularizer=self.regularizer,
327+
**stain_mat_kwargs)
217328

218329
# B x num_stains x num_pixel_in_mask
219330
concentration = get_concentrations(target, target_stain_matrix, regularizer=self.regularizer,
@@ -236,9 +347,40 @@ def build(cls,
236347
rng: Optional[int | torch.Generator] = None,
237348
target_stain_idx: Optional[Sequence[int]] = (0, 1),
238349
sigma_alpha: float = 0.2,
239-
sigma_beta: float = 0.2):
350+
sigma_beta: float = 0.2,
351+
luminosity_threshold: Optional[float] = 0.8,
352+
regularizer: float = 0.1,
353+
use_cache: bool = False,
354+
cache_size_limit: int = -1,
355+
device: Optional[torch.device] = None,
356+
load_path: Optional[str] = None
357+
):
358+
"""Factory builder of the augmentor which manipulate the stain concentration by alpha * concentration + beta.
359+
360+
Args:
361+
method: algorithm name to extract stain - support 'vahadane' or 'macenko'
362+
reconst_method: algorithm to compute concentration. default ista
363+
rng: an optional seed (either an int or a torch.Generator) to determine the random number generation.
364+
target_stain_idx: what stains to augment: e.g., for HE cases, it can be either or both from [0, 1]
365+
sigma_alpha: alpha is uniformly randomly selected from (1-sigma_alpha, 1+sigma_alpha)
366+
sigma_beta: beta is uniformly randomly selected from (-sigma_beta, sigma_beta)
367+
luminosity_threshold: luminosity threshold to find tissue regions (smaller than but positive)
368+
a pixel is considered as being tissue if the intensity falls in the open interval of (0, threshold).
369+
regularizer: regularization term in ISTA algorithm
370+
use_cache: whether use cache to save the stain matrix to avoid recomputation
371+
cache_size_limit: size limit of the cache. negative means no limits.
372+
device: what device to hold the cache.
373+
load_path: If specified, then stain matrix cache will be loaded from the file path. See the `cache`
374+
module for more details.
375+
376+
Returns:
377+
378+
"""
240379
method = method.lower()
241380
extractor = build_from_name(method)
381+
cache = cls._init_cache(use_cache, cache_size_limit=cache_size_limit, device=device,
382+
load_path=load_path)
242383
return cls(extractor, reconst_method=reconst_method, rng=rng, target_stain_idx=target_stain_idx,
243-
sigma_alpha=sigma_alpha, sigma_beta=sigma_beta)
244-
384+
sigma_alpha=sigma_alpha, sigma_beta=sigma_beta,
385+
luminosity_threshold=luminosity_threshold, regularizer=regularizer,
386+
cache=cache, device=device)

0 commit comments

Comments
 (0)