1- from torch_staintools .functional .stain_extraction .factory import build_from_name
21from torch import nn
32import torch
4- from torch_staintools .functional .optimization .dict_learning import get_concentrations
5- from torch_staintools .functional .stain_extraction .extractor import BaseExtractor
6- from torch_staintools .functional .utility .implementation import transpose_trailing , img_from_concentration
7- from torch_staintools .functional .tissue_mask import get_tissue_mask
8- from operator import mul
9- from functools import reduce
10- from typing import Optional , Sequence , Tuple
11- import multiprocessing as mp
12- import ctypes
13- import numpy as np
3+ from typing import Optional , Sequence , Tuple , Hashable , List
4+ from ..functional .utility .implementation import default_device
5+ # from operator import mul
6+ # from functools import reduce
7+ # import multiprocessing as mp
8+ # import ctypes
9+ # import numpy as np
10+ from ..functional .stain_extraction .factory import build_from_name
11+ from ..functional .optimization .dict_learning import get_concentrations
12+ from ..functional .stain_extraction .extractor import BaseExtractor
13+ from ..functional .utility .implementation import transpose_trailing , img_from_concentration
14+ from ..functional .tissue_mask import get_tissue_mask
15+ from ..cache .tensor_cache import TensorCache
16+ from ..loggers import GlobalLoggers
17+
18+ logger = GlobalLoggers .instance ().get_logger (__name__ )
1419
1520
1621class Augmentor (nn .Module ):
17- use_cache : bool
22+ device : torch .device
23+
24+ _tensor_cache : TensorCache
25+ CACHE_FIELD : str = '_tensor_cache'
26+
1827 target_stain_idx : Optional [Sequence [int ]]
1928 rng : torch .Generator
2029
@@ -29,14 +38,23 @@ class Augmentor(nn.Module):
2938 luminosity_threshold : float
3039 regularizer : float
3140
41+ @staticmethod
42+ def _init_cache (use_cache : bool , cache_size_limit : int , device : Optional [torch .device ] = None ,
43+ load_path : Optional [str ] = None ) -> Optional [TensorCache ]:
44+ if not use_cache :
45+ return None
46+ return TensorCache .build (size_limit = cache_size_limit , device = device , path = load_path )
47+
3248 def __init__ (self , get_stain_matrix : BaseExtractor , reconst_method : str = 'ista' ,
3349 rng : Optional [int | torch .Generator ] = None ,
3450 target_stain_idx : Optional [Sequence [int ]] = (0 , 1 ),
3551 sigma_alpha : float = 0.2 ,
3652 sigma_beta : float = 0.2 ,
3753 num_stains : int = 2 ,
38- luminosity_threshold : float = 0.8 ,
39- regularizer : float = 0.01 ):
54+ luminosity_threshold : Optional [float ] = 0.8 ,
55+ regularizer : float = 0.1 ,
56+ cache : Optional [TensorCache ] = None ,
57+ device : Optional [torch .device ] = None ):
4058 """Augment the stain concentration by alpha * concentration + beta
4159
4260 Args:
@@ -50,7 +68,7 @@ def __init__(self, get_stain_matrix: BaseExtractor, reconst_method: str = 'ista'
5068 luminosity_threshold: luminosity threshold to obtain tissue region and ignore brighter backgrounds.
5169 If None, all image pixels will be considered as tissue for stain matrix/concentration computation.
5270 regularizer: the regularizer to compute concentration used in ISTA or CD algorithm.
53-
71+ cache: the external cache object
5472 """
5573 super ().__init__ ()
5674 self .reconst_method = reconst_method
@@ -65,6 +83,25 @@ def __init__(self, get_stain_matrix: BaseExtractor, reconst_method: str = 'ista'
6583 self .luminosity_threshold = luminosity_threshold
6684 self .regularizer = regularizer
6785
86+ self ._tensor_cache = cache
87+ self .device = default_device (device )
88+
89+ def to (self , device : torch .device ):
90+ self .device = device
91+ if self .cache_initialized ():
92+ self .tensor_cache .to (device )
93+ return super ().to (device )
94+
95+ @property
96+ def cache_size_limit (self ) -> int :
97+ if self .cache_initialized ():
98+ return self .tensor_cache .size_limit
99+ return 0
100+
101+ def dump_cache (self , path : str ):
102+ assert self .cache_initialized ()
103+ self .tensor_cache .dump (path )
104+
68105 @staticmethod
69106 def _default_rng (rng : Optional [torch .Generator | int ]):
70107 if rng is None :
@@ -74,21 +111,21 @@ def _default_rng(rng: Optional[torch.Generator | int]):
74111 assert isinstance (rng , torch .Generator )
75112 return rng
76113
77- @staticmethod
78- def new_cache (shape ):
79- """
80- Args:
81- shape:
82-
83- Returns:
84-
85- """
86- # Todo map the key to the corresponding cached data -- cached in file or to memory?
87- #
88- shared_array_base = mp .Array (ctypes .c_float , reduce (mul , shape ))
89- shared_array = np .ctypeslib .as_array (shared_array_base .get_obj ())
90- shared_array = shared_array .reshape (* shape )
91- return shared_array
114+ # @staticmethod
115+ # def new_cache(shape):
116+ # """
117+ # Args:
118+ # shape:
119+ #
120+ # Returns:
121+ #
122+ # """
123+ # # Todo map the key to the corresponding cached data -- cached in file or to memory?
124+ # #
125+ # shared_array_base = mp.Array(ctypes.c_float, reduce(mul, shape))
126+ # shared_array = np.ctypeslib.as_array(shared_array_base.get_obj())
127+ # shared_array = shared_array.reshape(*shape)
128+ # return shared_array
92129
93130 @staticmethod
94131 def __concentration_selected (target_concentration : torch .Tensor ,
@@ -123,9 +160,13 @@ def __inplace_augment_helper(target_concentration: torch.Tensor, *,
123160 """
124161 alpha = alpha .to (target_concentration .device )
125162 beta = beta .to (target_concentration .device )
126- tissue_mask = tissue_mask .ravel ()
127- target_concentration [..., tissue_mask ] *= alpha
128- target_concentration += beta
163+
164+ tissue_mask_flattened = tissue_mask .flatten (start_dim = - 2 , end_dim = - 1 ).expand (target_concentration .shape )
165+ alpha_expanded = alpha .expand (target_concentration .shape )
166+ target_concentration [..., tissue_mask_flattened ] *= alpha_expanded [..., tissue_mask_flattened ]
167+
168+ beta_expanded = beta .expand (target_concentration .shape )
169+ target_concentration [..., tissue_mask_flattened ] += beta_expanded [..., tissue_mask_flattened ]
129170 return target_concentration
130171
131172 @staticmethod
@@ -142,7 +183,7 @@ def channel_rand(target_concentration_selected: torch.Tensor, rng: torch.Generat
142183
143184 Args:
144185 target_concentration_selected: concentrations to work on (e.g., the entire or a subset of concentration
145- matrix
186+ matrix)
146187 rng: torch.Generator object
147188 sigma_alpha: sample alpha values in range (1-sigma, 1+ sigma)
148189 sigma_beta: sample beta values in range (-sigma, sigma)
@@ -197,11 +238,80 @@ def augment(*,
197238 alpha = alpha , beta = beta )
198239 return target_concentration
199240
200- def forward (self , target : torch .Tensor , ** stain_mat_kwargs ):
241+ @staticmethod
242+ def _stain_mat_kwargs_helper (luminosity_threshold ,
243+ num_stains ,
244+ regularizer ,
245+ ** stain_mat_kwargs ):
246+ arg_dict = {
247+ 'luminosity_threshold' : luminosity_threshold ,
248+ 'num_stains' : num_stains ,
249+ 'regularizer' : regularizer ,
250+ }
251+ stain_mat_kwargs = {k : v for k , v in stain_mat_kwargs .items ()}
252+ stain_mat_kwargs .update (arg_dict )
253+ return stain_mat_kwargs
254+
255+ @staticmethod
256+ def stain_mat_from_cache (cache : TensorCache , * ,
257+ cache_keys : List [Hashable ],
258+ get_stain_matrix : BaseExtractor ,
259+ target ,
260+ luminosity_threshold ,
261+ num_stains ,
262+ regularizer ,
263+ ** stain_mat_kwargs ) -> torch .Tensor :
264+ cache_func_kwargs = Augmentor ._stain_mat_kwargs_helper (luminosity_threshold , num_stains , regularizer ,
265+ ** stain_mat_kwargs )
266+ stain_mat_list = cache .get_batch (cache_keys , get_stain_matrix , target , ** cache_func_kwargs )
267+ if isinstance (stain_mat_list , torch .Tensor ):
268+ return stain_mat_list
269+
270+ return torch .stack (stain_mat_list , dim = 0 )
271+
272+ def _tensor_cache_helper (self ) -> Optional [TensorCache ]:
273+ return getattr (self , Augmentor .CACHE_FIELD )
274+
275+ def cache_initialized (self ):
276+ return hasattr (self , Augmentor .CACHE_FIELD ) and self ._tensor_cache_helper () is not None
277+
278+ @property
279+ def tensor_cache (self ) -> Optional [TensorCache ]:
280+ return self ._tensor_cache_helper ()
281+
282+ def stain_matrix_helper (self ,
283+ * ,
284+ cache_keys : Optional [List [Hashable ]],
285+ get_stain_matrix : BaseExtractor ,
286+ target ,
287+ luminosity_threshold ,
288+ num_stains ,
289+ regularizer ,
290+ ** stain_mat_kwargs ) -> torch .Tensor :
291+ if not self .cache_initialized () or cache_keys is None :
292+ logger .debug (f'{ self .cache_initialized ()} + { cache_keys is None } - no cache' )
293+ return get_stain_matrix (target , luminosity_threshold = luminosity_threshold ,
294+ num_stains = num_stains ,
295+ regularizer = regularizer ,
296+ ** stain_mat_kwargs )
297+ # if use cache
298+ assert self .cache_initialized (), f"Attempt to fetch data from cache but cache is not initialized"
299+ assert cache_keys is not None , f"Attempt to fetch data from cache but key is not given"
300+ # move fetched stain matrix to the same device of the target
301+ logger .debug (f"{ cache_keys [0 :3 ]} . cache initialized" )
302+ return Augmentor .stain_mat_from_cache (cache = self .tensor_cache , cache_keys = cache_keys ,
303+ get_stain_matrix = get_stain_matrix ,
304+ target = target ,
305+ luminosity_threshold = luminosity_threshold , num_stains = num_stains ,
306+ regularizer = regularizer , ** stain_mat_kwargs ,
307+ ).to (target .device )
308+
309+ def forward (self , target : torch .Tensor , cache_keys : Optional [List [Hashable ]] = None , ** stain_mat_kwargs ):
201310 """
202311
203312 Args:
204313 target: input tensor to augment. Shape B x C x H x W and intensity range is [0, 1].
314+ cache_keys: a unique key point the input entry to the cached stain matrix. `None` means no cache.
205315 **stain_mat_kwargs: all extra keyword arguments other than regularizer/num_stains/luminosity_threshold set
206316 in __init__.
207317
@@ -210,10 +320,11 @@ def forward(self, target: torch.Tensor, **stain_mat_kwargs):
210320 """
211321 # stain_matrix_target -- B x num_stain x num_input_color_channel
212322 # todo cache
213- target_stain_matrix = self .get_stain_matrix (target , luminosity_threshold = self .luminosity_threshold ,
214- num_stains = self .num_stains ,
215- regularizer = self .regularizer ,
216- ** stain_mat_kwargs )
323+ target_stain_matrix = self .stain_matrix_helper (cache_keys = cache_keys , get_stain_matrix = self .get_stain_matrix ,
324+ target = target , luminosity_threshold = self .luminosity_threshold ,
325+ num_stains = self .num_stains ,
326+ regularizer = self .regularizer ,
327+ ** stain_mat_kwargs )
217328
218329 # B x num_stains x num_pixel_in_mask
219330 concentration = get_concentrations (target , target_stain_matrix , regularizer = self .regularizer ,
@@ -236,9 +347,40 @@ def build(cls,
236347 rng : Optional [int | torch .Generator ] = None ,
237348 target_stain_idx : Optional [Sequence [int ]] = (0 , 1 ),
238349 sigma_alpha : float = 0.2 ,
239- sigma_beta : float = 0.2 ):
350+ sigma_beta : float = 0.2 ,
351+ luminosity_threshold : Optional [float ] = 0.8 ,
352+ regularizer : float = 0.1 ,
353+ use_cache : bool = False ,
354+ cache_size_limit : int = - 1 ,
355+ device : Optional [torch .device ] = None ,
356+ load_path : Optional [str ] = None
357+ ):
358+ """Factory builder of the augmentor which manipulate the stain concentration by alpha * concentration + beta.
359+
360+ Args:
361+ method: algorithm name to extract stain - support 'vahadane' or 'macenko'
362+ reconst_method: algorithm to compute concentration. default ista
363+ rng: a optional seed (either an int or a torch.Generator) to determine the random number generation.
364+ target_stain_idx: what stains to augment: e.g., for HE cases, it can be either or both from [0, 1]
365+ sigma_alpha: alpha is uniformly randomly selected from (1-sigma_alpha, 1+sigma_alpha)
366+ sigma_beta: beta is uniformly randomly selected from (-sigma_beta, sigma_beta)
367+ luminosity_threshold: luminosity threshold to find tissue regions (smaller than but positive)
368+ a pixel is considered as being tissue if the intensity falls in the open interval of (0, threshold).
369+ regularizer: regularization term in ISTA algorithm
370+ use_cache: whether use cache to save the stain matrix to avoid recomputation
371+ cache_size_limit: size limit of the cache. negative means no limits.
372+ device: what device to hold the cache.
373+ load_path: If specified, then stain matrix cache will be loaded from the file path. See the `cache`
374+ module for more details.
375+
376+ Returns:
377+
378+ """
240379 method = method .lower ()
241380 extractor = build_from_name (method )
381+ cache = cls ._init_cache (use_cache , cache_size_limit = cache_size_limit , device = device ,
382+ load_path = load_path )
242383 return cls (extractor , reconst_method = reconst_method , rng = rng , target_stain_idx = target_stain_idx ,
243- sigma_alpha = sigma_alpha , sigma_beta = sigma_beta )
244-
384+ sigma_alpha = sigma_alpha , sigma_beta = sigma_beta ,
385+ luminosity_threshold = luminosity_threshold , regularizer = regularizer ,
386+ cache = cache , device = device )
0 commit comments