Skip to content

Commit f23aa7f

Browse files
committed
vectorize Macenko
1 parent b98ff71 commit f23aa7f

File tree

6 files changed

+122
-17
lines changed

6 files changed

+122
-17
lines changed

torch_staintools/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .version import __version__

torch_staintools/constants/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ class _Config:
1414
# Whether to enable torch.compile (currently only the dictionary learning is affected)
1515
ENABLE_COMPILE: bool = True
1616

17+
STAIN_MAT_BATCHIFY: bool = True
18+
1719
CONFIG: _Config = _Config()
1820

1921

torch_staintools/functional/stain_extraction/macenko.py

Lines changed: 88 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
from typing import Callable, Optional
22

33
import torch
4-
from .utils import percentile, cov
4+
from .utils import percentile, batch_masked_cov, batch_masked_perc, cov
55
from dataclasses import dataclass
66

7+
from ..compile import lazy_compile
8+
from ...constants import CONFIG
9+
10+
711
@dataclass(frozen=False)
812
class MckCfg:
913
"""Configration of Macenko Stain Estimation.
@@ -23,38 +27,77 @@ def __init__(self, cfg: MckCfg):
2327
super().__init__()
2428
self.cfg = cfg
2529

26-
@staticmethod
27-
def cov(x):
28-
"""Covariance matrix for eigen decomposition.
29-
https://en.wikipedia.org/wiki/Covariance_matrix
30-
"""
31-
E_x = x.mean(dim=1)
32-
x = x - E_x[:, None]
33-
return torch.mm(x, x.T) / (x.size(1) - 1)
3430

3531
@staticmethod
3632
def angular_helper(t_hat, ):
3733
# todo deal with multi-dimensional scenario
3834
raise NotImplementedError
3935

4036
@staticmethod
41-
def stain_matrix_helper(t_hat: torch.Tensor, perc: int, eig_vecs: torch.Tensor):
37+
def stain_matrix_helper(t_hat: torch.Tensor, mask_flatten: torch.Tensor,
38+
perc: int, eig_vecs: torch.Tensor):
4239
"""Helper function to compute the stain matrix.
4340
4441
Separate the projected OD vectors on singular vectors (SVD of OD in Macenko paper, which is also the
4542
eigen vector of the covariance matrix of the OD)
4643
4744
Args:
4845
t_hat: projection of OD on the plane of most significant singular vectors of OD.
49-
perc: perc --> min angular term, 100 - perc --> max angular term
46+
B x num_pixel. Not masked.
47+
mask_flatten: the flattened mask. B x num_pixel x 1.
48+
perc: perc --> min angular term, 100 - perc --> max angular term. integer [0, 100].
5049
eig_vecs: eigen vectors of the cov(OD), which may also be the singular vectors of OD.
50+
B x C x num_stains
5151
5252
Returns:
5353
sorted stain matrix in shape of B x num_stains x num_input_color_channel. For H&E cases, the first row
5454
in dimension of num_stains is H and the second is E (only num_stains=2 supported for now).
5555
"""
56-
phi = torch.atan2(t_hat[:, 1], t_hat[:, 0])
56+
# batchified. t_hat as B x num_pixel x num_stains
57+
# phi as B x num_pixels. Unmasked at this point.
58+
phi = torch.atan2(t_hat[..., 1], t_hat[..., 0])
59+
# phi -> num_pix
60+
# requires mask and phi has the same number of dimension.
61+
# therefore collapse the final dim
62+
min_phi = batch_masked_perc(phi, mask_flatten.squeeze(-1), perc, dim=1)
63+
max_phi = batch_masked_perc(phi, mask_flatten.squeeze(-1), 100 - perc, dim=1)
64+
65+
# B x 2 x 1
66+
rot_min = torch.stack([torch.cos(min_phi), torch.sin(min_phi)], dim=-1).unsqueeze(-1)
67+
rot_max = torch.stack([torch.cos(max_phi), torch.sin(max_phi)], dim=-1).unsqueeze(-1)
68+
# B x C x num_stain @ B x num_stain x 1
69+
# = B x C x 1
70+
v_min = torch.bmm(eig_vecs, rot_min)
71+
v_max = torch.bmm(eig_vecs, rot_max)
72+
73+
# a heuristic to make the vector corresponding to hematoxylin first and the
74+
# one corresponding to eosin second. (OD_red)
5775

76+
flag: torch.Tensor = v_min[:, 0: 1, :] > v_max[:, 0: 1, :]
77+
stain_mat = torch.where(flag,
78+
torch.cat((v_min, v_max), dim=-1),
79+
torch.cat((v_max, v_min), dim=-1))
80+
return stain_mat
81+
82+
83+
@staticmethod
84+
def stain_matrix_helper_original(t_hat: torch.Tensor, perc: int, eig_vecs: torch.Tensor):
85+
"""Helper function to compute the stain matrix.
86+
87+
Separate the projected OD vectors on singular vectors (SVD of OD in Macenko paper, which is also the
88+
eigen vector of the covariance matrix of the OD)
89+
90+
Args:
91+
t_hat: projection of OD on the plane of most significant singular vectors of OD.
92+
perc: perc --> min angular term, 100 - perc --> max angular term
93+
eig_vecs: eigen vectors of the cov(OD), which may also be the singular vectors of OD.
94+
95+
Returns:
96+
sorted stain matrix in shape of B x num_stains x num_input_color_channel. For H&E cases, the first row
97+
in dimension of num_stains is H and the second is E (only num_stains=2 supported for now).
98+
"""
99+
phi = torch.atan2(t_hat[..., 1], t_hat[..., 0])
100+
# phi -> num_pix
58101
min_phi = percentile(phi, perc, dim=0)
59102
max_phi = percentile(phi, 100 - perc, dim=0)
60103
v_min = torch.matmul(eig_vecs, torch.stack((torch.cos(min_phi), torch.sin(min_phi)))).unsqueeze(1)
@@ -97,13 +140,42 @@ def __call__(self, od: torch.Tensor,
97140
assert num_stains == 2, f"Num stains: {num_stains} not currently supported in Macenko. Only support: 2"
98141
# B x (HxWx1)
99142
tissue_mask_flatten = tissue_mask.flatten(start_dim=1, end_dim=-1).to(device)
143+
# add dim
144+
100145
# B x (H*W) x C
146+
#
101147
od_flatten = od.flatten(start_dim=2, end_dim=-1).permute(0, 2, 1)
102148
max_stains = od_flatten.shape[-1]
103149
assert num_stains <= max_stains, f"number of stains exceeds maximum stains allowed." \
104150
f" {num_stains} vs {max_stains}"
151+
if CONFIG.STAIN_MAT_BATCHIFY:
152+
return self.stain_mat_vectorize(od_flatten,
153+
tissue_mask_flatten, num_stains, perc)
154+
else:
155+
return self.stain_mat_loop(od_flatten, tissue_mask_flatten, num_stains, perc)
156+
157+
# the actual overhead seems to be the eigh. compilation's impact is minimal.
158+
# maybe don't need it at all?
159+
# @lazy_compile
160+
def stain_mat_vectorize(self, od_flatten: torch.Tensor,
161+
tissue_mask_flatten: torch.Tensor,
162+
num_stains: int, perc: int,):
163+
# add a singleton dim for batchification
164+
tissue_mask_flatten = tissue_mask_flatten[..., None]
165+
cov_mat = batch_masked_cov(od_flatten, tissue_mask_flatten)
166+
_, eig_vecs = torch.linalg.eigh(cov_mat)
167+
eig_vecs = eig_vecs[:, :, -num_stains:]
168+
# unmasked. handle masking later
169+
t_hat = torch.bmm(od_flatten, eig_vecs)
170+
stain_mat = MacenkoAlg.stain_matrix_helper(t_hat, tissue_mask_flatten,
171+
perc, eig_vecs)
172+
stain_mat = stain_mat.transpose(1, 2)
173+
return stain_mat
174+
175+
def stain_mat_loop(self, od_flatten: torch.Tensor, tissue_mask_flatten: torch.Tensor,
176+
num_stains: int, perc: int,
177+
):
105178
stain_mat_list = []
106-
# todo, batchify
107179
for od_single, mask_single in zip(od_flatten, tissue_mask_flatten):
108180
x = od_single[mask_single]
109181

@@ -114,7 +186,9 @@ def __call__(self, od: torch.Tensor,
114186
# HW * C x C x num_stains --> HW x num_stains
115187
t_hat = torch.matmul(x, eig_vecs)
116188
# HW
117-
stain_mat = MacenkoAlg.stain_matrix_helper(t_hat, perc, eig_vecs)
189+
# t_hat -> num_pixels x num_stain
190+
# eig_vecs -> C x num_stain
191+
stain_mat = MacenkoAlg.stain_matrix_helper_original(t_hat, perc, eig_vecs)
118192
stain_mat = stain_mat.T
119193
stain_mat_list.append(stain_mat)
120194
return torch.stack(stain_mat_list)

torch_staintools/functional/stain_extraction/utils.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,28 @@ def normalize_matrix_rows(a: torch.Tensor) -> torch.Tensor:
1212
return a / torch.linalg.norm(a, dim=1)[:, None]
1313

1414

15-
def cov(x):
15+
def cov(x: torch.Tensor) -> torch.Tensor:
1616
"""Covariance matrix for eigen decomposition.
1717
https://en.wikipedia.org/wiki/Covariance_matrix
1818
"""
19+
# x: C x num_pixel
1920
E_x = x.mean(dim=1)
2021
x = x - E_x[:, None]
2122
return torch.mm(x, x.T) / (x.size(1) - 1)
2223

24+
25+
@torch.no_grad()
26+
def batch_masked_cov(od_flatten: torch.Tensor, mask_flatten: torch.Tensor) -> torch.Tensor:
27+
# mask B x num_pixel x 1
28+
# clamp to avoid division by zero in the mean and covariance computations
29+
size_masked = mask_flatten.sum(dim=1).clamp_min(2).unsqueeze(-1)
30+
# B x 1 x C (keepdim=True retains the pixel dim for broadcasting)
31+
mean = (od_flatten * mask_flatten).sum(dim=1, keepdim=True) / size_masked
32+
# B x num_pix x C
33+
x = (od_flatten - mean) * mask_flatten
34+
return torch.bmm(x.transpose(1, 2), x) / (size_masked - 1)
35+
36+
2337
def percentile(t: torch.Tensor, q: float, dim: int) -> torch.Tensor:
2438
"""Author: adapted from https://gist.github.com/spezold/42a451682422beb42bc43ad0c0967a30
2539
@@ -45,3 +59,17 @@ def percentile(t: torch.Tensor, q: float, dim: int) -> torch.Tensor:
4559
return t.kthvalue(k, dim=dim).values
4660

4761

62+
@torch.no_grad()
63+
def batch_masked_perc(phi: torch.Tensor, mask: torch.Tensor, q: int, dim: int) -> torch.Tensor:
64+
# fill background (masked-out) entries with inf so they sort to the end
65+
# mask = mask.squeeze(-1)
66+
phi_filled = torch.where(mask.bool(), phi, torch.tensor(torch.inf, device=phi.device))
67+
# inf-filled entries sort to the end and are excluded by the percentile index cutoff.
68+
phi_sorted, _ = torch.sort(phi_filled, dim=dim)
69+
size_masked = mask.sum(dim=dim)
70+
q_float = q / 100.0
71+
target_indices = (q_float * (size_masked - 1)).long().clamp(min=0)
72+
73+
# not friendly to torch.compile
74+
# torch.nanquantile(phi_masked, q_float, dim=dim, interpolation='nearest') # B
75+
return phi_sorted.gather(dim, target_indices.unsqueeze(dim)).squeeze(dim)

torch_staintools/normalizer/factory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def build(method: TYPE_SUPPORTED,
2929
dict_init: MODE_INIT = 'transpose',
3030
concentration_solver: METHOD_FACTORIZE = 'fista',
3131
num_stains: int = 2,
32-
luminosity_threshold: float = 0.8,
32+
luminosity_threshold: Optional[float] = 0.8,
3333
perc: int = 1,
3434
regularizer: float = PARAM.OPTIM_DEFAULT_SPARSE_LAMBDA, # 1e-2
3535
maxiter: int = PARAM.OPTIM_SPARSE_DEFAULT_MAX_ITER, # 50

torch_staintools/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '1.0.5'
1+
__version__ = '1.0.6a'

0 commit comments

Comments
 (0)