11from typing import Callable , Optional
22
33import torch
4- from .utils import percentile , cov
4+ from .utils import percentile , batch_masked_cov , batch_masked_perc , cov
55from dataclasses import dataclass
66
7+ from ..compile import lazy_compile
8+ from ...constants import CONFIG
9+
10+
711@dataclass (frozen = False )
812class MckCfg :
913 """Configuration of Macenko Stain Estimation.
@@ -23,38 +27,77 @@ def __init__(self, cfg: MckCfg):
2327 super ().__init__ ()
2428 self .cfg = cfg
2529
26- @staticmethod
27- def cov (x ):
28- """Covariance matrix for eigen decomposition.
29- https://en.wikipedia.org/wiki/Covariance_matrix
30- """
31- E_x = x .mean (dim = 1 )
32- x = x - E_x [:, None ]
33- return torch .mm (x , x .T ) / (x .size (1 ) - 1 )
3430
@staticmethod
def angular_helper(t_hat):
    """Placeholder for the angular step on multi-dimensional input.

    Args:
        t_hat: currently unused. Presumably the projected OD, as in the
            sibling stain-matrix helpers -- TODO confirm once implemented.

    Raises:
        NotImplementedError: always; the multi-dimensional scenario has
            not been handled yet.
    """
    # TODO: deal with multi-dimensional scenario
    raise NotImplementedError()
3935
@staticmethod
def stain_matrix_helper(t_hat: torch.Tensor, mask_flatten: torch.Tensor,
                        perc: int, eig_vecs: torch.Tensor):
    """Batched construction of the stain vectors from angular extremes.

    The projected OD is converted to an angle per pixel; the ``perc`` and
    ``100 - perc`` percentiles of those angles (restricted by the mask via
    ``batch_masked_perc``) define two directions in the eigenvector plane,
    which are then mapped back through ``eig_vecs``.

    Args:
        t_hat: projection of OD on the plane spanned by the selected
            eigenvectors. Only indices 0 and 1 of the last dimension are
            read. Per the original inline notes the layout is
            B x num_pixel x num_stains, and the tensor is not masked.
        mask_flatten: flattened mask, B x num_pixel x 1 per the original
            notes; the trailing singleton dimension is squeezed before use.
        perc: lower percentile of the angle; ``100 - perc`` is used as the
            upper percentile. Integer in [0, 100].
        eig_vecs: eigenvectors of cov(OD), B x C x num_stains. Only
            num_stains == 2 is supported, since each direction is a
            (cos, sin) pair.

    Returns:
        Tensor with the two stain vectors concatenated along the last
        dimension (B x C x 2 when the percentile helper yields one angle
        per batch item -- NOTE(review): confirm ``batch_masked_perc``
        output shape). For each batch item the vector whose first channel
        component is larger is placed first; this is the heuristic that
        puts hematoxylin before eosin.
    """
    # Angle of each projected pixel inside the eigenvector plane (unmasked).
    angles = torch.atan2(t_hat[..., 1], t_hat[..., 0])

    # The percentile helper needs mask and angles to have matching rank,
    # so drop the trailing singleton dimension of the mask once.
    mask_2d = mask_flatten.squeeze(-1)
    phi_low = batch_masked_perc(angles, mask_2d, perc, dim=1)
    phi_high = batch_masked_perc(angles, mask_2d, 100 - perc, dim=1)

    def _back_project(phi: torch.Tensor) -> torch.Tensor:
        # (cos, sin) direction as a column vector, then mapped through the
        # eigenvectors: (B x C x num_stains) @ (B x num_stains x 1).
        direction = torch.stack([torch.cos(phi), torch.sin(phi)], dim=-1).unsqueeze(-1)
        return torch.bmm(eig_vecs, direction)

    vec_low = _back_project(phi_low)
    vec_high = _back_project(phi_high)

    # Ordering heuristic on the first channel (OD_red): the vector with the
    # larger first component goes first.
    low_goes_first: torch.Tensor = vec_low[:, 0:1, :] > vec_high[:, 0:1, :]
    return torch.where(low_goes_first,
                       torch.cat((vec_low, vec_high), dim=-1),
                       torch.cat((vec_high, vec_low), dim=-1))
81+
82+
83+ @staticmethod
84+ def stain_matrix_helper_original (t_hat : torch .Tensor , perc : int , eig_vecs : torch .Tensor ):
85+ """Helper function to compute the stain matrix.
86+
87+ Separate the projected OD vectors on singular vectors (SVD of OD in Macenko paper, which is also the
88+ eigen vector of the covariance matrix of the OD)
89+
90+ Args:
91+ t_hat: projection of OD on the plane of most significant singular vectors of OD.
92+ perc: perc --> min angular term, 100 - perc --> max angular term
93+ eig_vecs: eigen vectors of the cov(OD), which may also be the singular vectors of OD.
94+
95+ Returns:
96+ sorted stain matrix in shape of B x num_stains x num_input_color_channel. For H&E cases, the first row
97+ in dimension of num_stains is H and the second is E (only num_stains=2 supported for now).
98+ """
99+ phi = torch .atan2 (t_hat [..., 1 ], t_hat [..., 0 ])
100+ # phi -> num_pix
58101 min_phi = percentile (phi , perc , dim = 0 )
59102 max_phi = percentile (phi , 100 - perc , dim = 0 )
60103 v_min = torch .matmul (eig_vecs , torch .stack ((torch .cos (min_phi ), torch .sin (min_phi )))).unsqueeze (1 )
@@ -97,13 +140,42 @@ def __call__(self, od: torch.Tensor,
97140 assert num_stains == 2 , f"Num stains: { num_stains } not currently supported in Macenko. Only support: 2"
98141 # B x (HxWx1)
99142 tissue_mask_flatten = tissue_mask .flatten (start_dim = 1 , end_dim = - 1 ).to (device )
143+ # add dim
144+
100145 # B x (H*W) x C
146+ #
101147 od_flatten = od .flatten (start_dim = 2 , end_dim = - 1 ).permute (0 , 2 , 1 )
102148 max_stains = od_flatten .shape [- 1 ]
103149 assert num_stains <= max_stains , f"number of stains exceeds maximum stains allowed." \
104150 f" { num_stains } vs { max_stains } "
151+ if CONFIG .STAIN_MAT_BATCHIFY :
152+ return self .stain_mat_vectorize (od_flatten ,
153+ tissue_mask_flatten , num_stains , perc )
154+ else :
155+ return self .stain_mat_loop (od_flatten , tissue_mask_flatten , num_stains , perc )
156+
157+ # the actual overhead seems to be the eigh. compilation's impact is minimal.
158+ # maybe don't need it at all?
159+ # @lazy_compile
def stain_mat_vectorize(self, od_flatten: torch.Tensor,
                        tissue_mask_flatten: torch.Tensor,
                        num_stains: int, perc: int):
    """Estimate the stain matrix for a whole batch without a Python loop.

    Steps: masked covariance of the OD, eigen-decomposition, projection of
    the (unmasked) OD onto the last ``num_stains`` eigenvectors, then the
    angular-percentile step in ``MacenkoAlg.stain_matrix_helper``.

    Args:
        od_flatten: flattened optical density, B x (H*W) x C as produced
            by the caller.
        tissue_mask_flatten: flattened tissue mask, B x (H*W); a trailing
            singleton dimension is appended here before it is passed on.
        num_stains: number of eigenvectors (stains) to keep.
        perc: percentile forwarded to the angular step.

    Returns:
        The helper's output with dimensions 1 and 2 swapped.
    """
    # Trailing singleton dim so the mask lines up with the batched OD.
    mask = tissue_mask_flatten.unsqueeze(-1)
    covariance = batch_masked_cov(od_flatten, mask)

    # eigh returns (eigenvalues, eigenvectors); keep the last `num_stains`
    # eigenvector columns.
    eig_vecs = torch.linalg.eigh(covariance).eigenvectors[:, :, -num_stains:]

    # Projection is done on the unmasked OD; masking is applied inside the
    # percentile computation of the helper.
    projected = torch.bmm(od_flatten, eig_vecs)
    stain_vectors = MacenkoAlg.stain_matrix_helper(projected, mask, perc, eig_vecs)
    return stain_vectors.transpose(1, 2)
174+
175+ def stain_mat_loop (self , od_flatten : torch .Tensor , tissue_mask_flatten : torch .Tensor ,
176+ num_stains : int , perc : int ,
177+ ):
105178 stain_mat_list = []
106- # todo, batchify
107179 for od_single , mask_single in zip (od_flatten , tissue_mask_flatten ):
108180 x = od_single [mask_single ]
109181
@@ -114,7 +186,9 @@ def __call__(self, od: torch.Tensor,
114186 # HW * C x C x num_stains --> HW x num_stains
115187 t_hat = torch .matmul (x , eig_vecs )
116188 # HW
117- stain_mat = MacenkoAlg .stain_matrix_helper (t_hat , perc , eig_vecs )
189+ # t_hat -> num_pixels x num_stain
190+ # eig_vecs -> C x num_stain
191+ stain_mat = MacenkoAlg .stain_matrix_helper_original (t_hat , perc , eig_vecs )
118192 stain_mat = stain_mat .T
119193 stain_mat_list .append (stain_mat )
120194 return torch .stack (stain_mat_list )
0 commit comments