Skip to content

Commit 8693dbc

Browse files
committed
update unittest signatures; remove one invalid test case (concentration output is never masked)
1 parent 88a890c commit 8693dbc

File tree

13 files changed

+165
-122
lines changed

13 files changed

+165
-122
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ augmentor = AugmentorBuilder.build('vahadane',
153153
# if set to None means all pixels are treated as tissue
154154
luminosity_threshold=0.8,
155155
# herein we use 'ista' to compute the concentration
156-
concentration_method='ista',
156+
concentration_solver='ista',
157157
sigma_alpha=0.2,
158158
sigma_beta=0.2, target_stain_idx=(0, 1),
159159
# this allows to cache the stain matrix if it's too time-consuming to recompute.

demo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ def postprocess(image_tensor): return convert_image_dtype(image_tensor, torch.ui
217217
fig, axs = plt.subplots(2, num_repeat + 1, figsize=(15, 8), dpi=300)
218218
for i, ax_alg in enumerate(axs):
219219
alg = algorithms[i].lower()
220-
augmentor = AugmentorBuilder.build(alg, concentration_method='ista',
220+
augmentor = AugmentorBuilder.build(alg, concentration_solver='ista',
221221
sigma_alpha=0.5,
222222
sigma_beta=0.5,
223223
luminosity_threshold=0.8,

tests/images/test_functionals.py

Lines changed: 54 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
"""borrow Staintool's test cases
22
"""
33
import unittest
4+
from typing import Optional, cast
5+
46
from tests.util import fix_seed, dummy_from_numpy, psnr
5-
from torch_staintools.functional.stain_extraction.macenko import MacenkoAlg
6-
from torch_staintools.functional.stain_extraction.vahadane import VahadaneAlg
7-
from torch_staintools.functional.concentration import get_concentrations
7+
from torch_staintools.constants import CONFIG
8+
from torch_staintools.functional.optimization.sparse_util import METHOD_FACTORIZE
9+
from torch_staintools.functional.stain_extraction.extractor import StainExtraction
10+
from torch_staintools.functional.stain_extraction.macenko import MacenkoAlg, DEFAULT_MACENKO_CONFIG
11+
from torch_staintools.functional.stain_extraction.vahadane import VahadaneAlg, DEFAULT_VAHADANE_CONFIG
12+
from torch_staintools.functional.concentration import ConcentrationSolver, ConcentCfg
813
from torch_staintools.functional.tissue_mask import get_tissue_mask, TissueMaskException
914
from torch_staintools.functional.utility.implementation import transpose_trailing, img_from_concentration
1015
from torchvision.transforms.functional import convert_image_dtype
@@ -24,6 +29,7 @@ class TestFunctional(unittest.TestCase):
2429
rand_img = torch.randint(0, 255, (1, 3, 256, 256))
2530

2631
THRESH_PSNR = 20
32+
POSITIVE_CONC_CFG = ConcentCfg()
2733

2834
@staticmethod
2935
def get_dummy_path():
@@ -40,65 +46,87 @@ def new_dummy_img_tensor_ubyte():
4046
return TestFunctional.DUMMY_IMG_TENSOR.clone()
4147

4248
@staticmethod
43-
def stain_extract(dummy_tensor, get_stain_mat, luminosity_threshold, num_stains, algorithm, regularizer):
49+
def stain_extract(dummy_tensor: torch.Tensor, get_stain_mat: StainExtraction,
50+
conc_solver: ConcentrationSolver,
51+
luminosity_threshold: float, num_stains: int, rng: Optional[torch.Generator]):
4452

4553
# lab_tensor = rgb_to_lab(convert_image_dtype(dummy_tensor))
4654

47-
stain_matrix = get_stain_mat(image=dummy_tensor, luminosity_threshold=luminosity_threshold,
48-
num_stains=num_stains, regularizer=regularizer)
55+
stain_matrix = get_stain_mat(image=dummy_tensor,
56+
luminosity_threshold=luminosity_threshold,
57+
num_stains=num_stains, rng=rng)
4958

50-
concentration = get_concentrations(dummy_tensor, stain_matrix, algorithm=algorithm,
51-
regularizer=regularizer)
59+
concentration = conc_solver(dummy_tensor, stain_matrix, rng=rng)
5260
c_transposed_src = transpose_trailing(concentration)
5361
reconstructed = img_from_concentration(c_transposed_src, stain_matrix, dummy_tensor.shape, (0, 1))
5462
return stain_matrix, concentration, c_transposed_src, reconstructed
5563

5664
@staticmethod
57-
def extract_eval_helper(tester, get_stain_mat, luminosity_threshold,
58-
num_stains, regularizer, dict_algorithm):
65+
def extract_eval_helper(tester, get_stain_mat: StainExtraction,
66+
conc_solver: ConcentrationSolver,
67+
luminosity_threshold: Optional[float],
68+
num_stains: int, rng: Optional[torch.Generator]):
5969
device = TestFunctional.device
6070
dummy_tensor_ubyte = TestFunctional.new_dummy_img_tensor_ubyte().to(device)
6171
# get_stain_mat = MacenkoExtractor()
6272
result_tuple = TestFunctional.stain_extract(dummy_tensor_ubyte, get_stain_mat,
73+
conc_solver=conc_solver,
6374
luminosity_threshold=luminosity_threshold,
6475
num_stains=num_stains,
65-
algorithm=dict_algorithm, regularizer=regularizer)
76+
rng=rng)
6677

6778
stain_matrix, concentration, c_transposed_src, reconstructed = result_tuple
6879
dummy_scaled = convert_image_dtype(dummy_tensor_ubyte, torch.float32)
6980
psnr_out = psnr(dummy_scaled, reconstructed).item()
70-
tester.assertTrue(psnr_out > TestFunctional.THRESH_PSNR)
81+
tester.assertTrue(psnr_out > TestFunctional.THRESH_PSNR,
82+
msg=f"{psnr_out} vs. {TestFunctional.THRESH_PSNR}. \n"
83+
f"{get_stain_mat.stain_algorithm.cfg} \n"
84+
f"nan: {torch.isnan(reconstructed).any()} \n"
85+
f"Dict pos: {CONFIG.DICT_POSITIVE_DICTIONARY}")
7186
# size
7287
batch_size, channel_size, height, width = dummy_tensor_ubyte.shape
7388
tester.assertTrue(stain_matrix.shape == (batch_size, num_stains, channel_size))
7489

7590
# transpose
7691
tester.assertTrue((c_transposed_src.permute(0, 2, 1) == concentration).all())
7792

78-
# manual tissue mask
79-
mask = get_tissue_mask(dummy_scaled, luminosity_threshold=luminosity_threshold)
80-
tissue_count = mask.sum()
81-
tester.assertTrue(concentration.shape[-1] == tissue_count)
82-
8393
def eval_wrapper(self, extractor):
8494

8595
# all pixel
86-
algorithms = ['ista', 'cd', 'ls']
87-
for alg in algorithms:
88-
TestFunctional.extract_eval_helper(self, extractor, luminosity_threshold=None,
89-
num_stains=2, regularizer=0.1, dict_algorithm=alg)
96+
algorithms = ['ista', 'cd', 'ls', 'fista']
97+
dict_constraint_flag = [True]
98+
for flag in dict_constraint_flag:
99+
CONFIG.DICT_POSITIVE_DICTIONARY = flag
100+
for alg in algorithms:
101+
cfg = TestFunctional.POSITIVE_CONC_CFG
102+
cfg.algorithm = cast(METHOD_FACTORIZE, alg)
103+
cfg.positive = True
104+
solver = ConcentrationSolver(cfg)
105+
TestFunctional.extract_eval_helper(self, extractor, luminosity_threshold=None,
106+
num_stains=2, conc_solver=solver, rng=None)
107+
solver.cfg.positive = False
108+
TestFunctional.extract_eval_helper(self, extractor, luminosity_threshold=None,
109+
num_stains=2, conc_solver=solver, rng=None)
90110

91111
def test_stains(self):
92-
macenko = MacenkoAlg()
93-
vahadane = VahadaneAlg()
112+
macenko = StainExtraction(MacenkoAlg(DEFAULT_MACENKO_CONFIG))
113+
vahadane = StainExtraction(VahadaneAlg(DEFAULT_VAHADANE_CONFIG))
94114
# not support num_stains other than 2
95-
with self.assertRaises(NotImplementedError):
96-
TestFunctional.extract_eval_helper(self, macenko, luminosity_threshold=None,
97-
num_stains=3, regularizer=0.1, dict_algorithm='ista')
115+
with self.assertRaises(AssertionError):
116+
TestFunctional.extract_eval_helper(self, macenko,
117+
conc_solver=ConcentrationSolver(TestFunctional.POSITIVE_CONC_CFG),
118+
luminosity_threshold=None,
119+
num_stains=3, rng=None)
98120

99121
self.eval_wrapper(macenko)
100122
self.eval_wrapper(vahadane)
101123

124+
# vahadane with rng and lr
125+
vahadane.stain_algorithm.cfg.lr = 0.5
126+
TestFunctional.extract_eval_helper(self, vahadane,
127+
conc_solver=ConcentrationSolver(TestFunctional.POSITIVE_CONC_CFG),
128+
luminosity_threshold=None,
129+
num_stains=3, rng=torch.Generator(1))
102130
def test_tissue_mask(self):
103131
device = TestFunctional.device
104132
dummy_scaled = convert_image_dtype(TestFunctional.new_dummy_img_tensor_ubyte(), torch.float32).to(device)

torch_staintools/augmentor/base.py

Lines changed: 23 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1+
from functools import partial
2+
13
import torch
24
from typing import Optional, Sequence, Tuple, Hashable, List
3-
from ..functional.stain_extraction.factory import build_from_name
4-
from ..functional.optimization.sparse_util import METHOD_FACTORIZE
5-
from torch_staintools.functional.concentration import get_concentrations
5+
from torch_staintools.functional.concentration import ConcentrationSolver
66
from ..functional.stain_extraction.extractor import StainExtraction, StainAlg
77
from ..functional.utility.implementation import transpose_trailing, img_from_concentration
88
from ..functional.tissue_mask import get_tissue_mask, TissueMaskException
@@ -23,10 +23,10 @@ class Augmentor(CachedRNGModule):
2323

2424
# _tensor_cache: TensorCache
2525
# CACHE_FIELD: str = '_tensor_cache'
26-
26+
rng: Optional[torch.Generator]
2727
target_stain_idx: Optional[Sequence[int]]
2828

29-
concentration_method: METHOD_FACTORIZE
29+
concentration_solver: ConcentrationSolver
3030
get_stain_matrix: StainExtraction # can be any callable following the signature of BaseExtractor's __call__
3131
target_concentrations: torch.Tensor
3232

@@ -38,7 +38,7 @@ class Augmentor(CachedRNGModule):
3838
regularizer: float
3939

4040
def __init__(self, stain_alg: StainAlg,
41-
concentration_method: METHOD_FACTORIZE = 'ista',
41+
concentration_solver: ConcentrationSolver,
4242
rng: TYPE_RNG = None,
4343
target_stain_idx: Optional[Sequence[int]] = (0, 1),
4444
sigma_alpha: float = 0.2,
@@ -56,7 +56,7 @@ def __init__(self, stain_alg: StainAlg,
5656
Args:
5757
stain_alg: the Callable to obtain stain matrix - e.g., Vahadane's dict learning or
5858
Macenko's SVD
59-
concentration_method: How to get stain concentration from stain matrix
59+
concentration_solver: How to get stain concentration from stain matrix
6060
rng: the specified torch.Generator or int (as seed) for reproducing the results
6161
sigma_alpha: bound of alpha (mean 1). Sampled from (1-sigma, 1+sigma)
6262
sigma_beta: bound of beta (mean 0). Sampled from (-sigma, sigma)
@@ -67,7 +67,7 @@ def __init__(self, stain_alg: StainAlg,
6767
6868
"""
6969
super().__init__(cache, device, rng)
70-
self.concentration_method = concentration_method
70+
self.concentration_solver = concentration_solver
7171
self.get_stain_matrix = StainExtraction(stain_alg)
7272

7373
self.target_stain_idx = target_stain_idx
@@ -212,32 +212,27 @@ def augment(*,
212212
alpha=alpha, beta=beta)
213213
return target_concentration
214214

215-
def forward(self, target: torch.Tensor, cache_keys: Optional[List[Hashable]] = None, **stain_mat_kwargs):
215+
def forward(self, target: torch.Tensor, cache_keys: Optional[List[Hashable]] = None):
216216
"""
217217
218218
Args:
219219
target: input tensor to augment. Shape B x C x H x W and intensity range is [0, 1].
220220
cache_keys: unique keys point the input batch to the cached stain matrices. `None` means no cache.
221-
**stain_mat_kwargs: all extra keyword arguments other than regularizer/num_stains/luminosity_threshold set
222-
in __init__.
223221
224222
Returns:
225223
Augmented output.
226224
"""
227225
# stain_matrix_target -- B x num_stain x num_input_color_channel
228226
# todo cache
229-
get_stain_mat_partial = self.get_stain_matrix.get_partial(luminosity_threshold=self.luminosity_threshold,
230-
num_stains=self.num_stains,
231-
regularizer=self.regularizer,
232-
rng=self.rng,
233-
**stain_mat_kwargs)
227+
get_stain_partial = partial(self.get_stain_matrix,
228+
luminosity_threshold=self.luminosity_threshold,
229+
num_stains=self.num_stains, rng=self.rng)
234230

235-
target_stain_matrix = self.tensor_from_cache(cache_keys=cache_keys, func=get_stain_mat_partial,
231+
target_stain_matrix = self.tensor_from_cache(cache_keys=cache_keys, func=get_stain_partial,
236232
target=target)
237233

238234
# B x num_stains x num_pixel_in_mask
239-
concentration = get_concentrations(target, target_stain_matrix, regularizer=self.regularizer,
240-
algorithm=self.concentration_method, rng=self.rng)
235+
concentration = self.concentration_solver(target, target_stain_matrix, rng=self.rng)
241236
try:
242237
tissue_mask = get_tissue_mask(target, luminosity_threshold=self.luminosity_threshold, throw_error=True,
243238
true_when_empty=False)
@@ -257,13 +252,14 @@ def forward(self, target: torch.Tensor, cache_keys: Optional[List[Hashable]] = N
257252

258253
@classmethod
259254
def build(cls,
260-
method: str, *, concentration_method: METHOD_FACTORIZE = 'ista',
255+
stain_alg: StainAlg, *,
256+
concentration_solver: ConcentrationSolver,
261257
rng: TYPE_RNG = None,
262258
target_stain_idx: Optional[Sequence[int]] = (0, 1),
263259
sigma_alpha: float = 0.2,
264260
sigma_beta: float = 0.2,
261+
num_stains: int = 2,
265262
luminosity_threshold: Optional[float] = 0.8,
266-
regularizer: float = 0.1,
267263
use_cache: bool = False,
268264
cache_size_limit: int = -1,
269265
device: Optional[torch.device] = None,
@@ -272,18 +268,18 @@ def build(cls,
272268
"""Factory builder of the augmentor which manipulate the stain concentration by alpha * concentration + beta.
273269
274270
Args:
275-
method: algorithm name to extract stain - support 'vahadane' or 'macenko'
276-
concentration_method: method to obtain the concentration. Default 'ista' for fast sparse solution on GPU
271+
stain_alg: algorithm name to extract stain - support 'vahadane' or 'macenko'
272+
concentration_solver: method to obtain the concentration. Default 'ista' for fast sparse solution on GPU
277273
only applied for StainSeparation-based approaches (macenko and vahadane).
278274
support 'ista', 'cd', and 'ls'. 'ls' simply solves the least square problem for factorization of
279275
min||HExC - OD|| but is faster. 'ista'/cd enforce the sparse penalty but slower.
280276
rng: an optional seed (either an int or a torch.Generator) to determine the random number generation.
281277
target_stain_idx: what stains to augment: e.g., for HE cases, it can be either or both from [0, 1]
282278
sigma_alpha: alpha is uniformly randomly selected from (1-sigma_alpha, 1+sigma_alpha)
283279
sigma_beta: beta is uniformly randomly selected from (-sigma_beta, sigma_beta)
280+
num_stains: number of stains to separate.
284281
luminosity_threshold: luminosity threshold to find tissue regions (smaller than but positive)
285282
a pixel is considered as being tissue if the intensity falls in the open interval of (0, threshold).
286-
regularizer: regularization term in ISTA algorithm
287283
use_cache: whether to use cache to save the stain matrix to avoid re-computation
288284
cache_size_limit: size limit of the cache. negative means no limits.
289285
device: what device to hold the cache.
@@ -293,11 +289,9 @@ def build(cls,
293289
Returns:
294290
Augmentor.
295291
"""
296-
method = method.lower()
297-
extractor = build_from_name(method)
298292
cache = cls._init_cache(use_cache, cache_size_limit=cache_size_limit, device=device,
299293
load_path=load_path)
300-
return cls(extractor, concentration_method=concentration_method, rng=rng, target_stain_idx=target_stain_idx,
301-
sigma_alpha=sigma_alpha, sigma_beta=sigma_beta,
302-
luminosity_threshold=luminosity_threshold, regularizer=regularizer,
294+
return cls(stain_alg, concentration_solver=concentration_solver, rng=rng, target_stain_idx=target_stain_idx,
295+
sigma_alpha=sigma_alpha, sigma_beta=sigma_beta, num_stains=num_stains,
296+
luminosity_threshold=luminosity_threshold,
303297
cache=cache, device=device).to(device)

0 commit comments

Comments
 (0)