Skip to content

Commit 16f7445

Browse files
committed
refactor dict learning; add const
1 parent 0cc1b20 commit 16f7445

File tree

15 files changed

+516
-350
lines changed

15 files changed

+516
-350
lines changed

demo.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def postprocess(image_tensor): return convert_image_dtype(image_tensor, torch.ui
7777
plt.title(f"Vahadane: {idx}")
7878
plt.show()
7979

80-
# %timeit normalizer_vahadane(norm_tensor, constrained=True, verbose=False)
80+
# %timeit normalizer_vahadane(norm_tensor, positive_dict=True)
8181

8282
# #################### Macenko
8383

@@ -95,7 +95,7 @@ def postprocess(image_tensor): return convert_image_dtype(image_tensor, torch.ui
9595
plt.imshow(test_out)
9696
plt.title(f"Macenko: {idx}")
9797
plt.show()
98-
# # %timeit normalizer_macenko(norm_tensor, algorithm='ista', constrained=True, verbose=False)
98+
# # %timeit normalizer_macenko(norm_tensor, algorithm='ista', positive_dict=True,)
9999

100100
# ###################### Reinhard
101101

tests/images/test_functionals.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from tests.util import fix_seed, dummy_from_numpy, psnr
55
from torch_staintools.functional.stain_extraction.macenko import MacenkoExtractor
66
from torch_staintools.functional.stain_extraction.vahadane import VahadaneExtractor
7-
from torch_staintools.functional.optimization.dict_learning import get_concentrations
7+
from torch_staintools.functional.optimization.concentration import get_concentrations
88
from torch_staintools.functional.tissue_mask import get_tissue_mask, TissueMaskException
99
from torch_staintools.functional.utility.implementation import transpose_trailing, img_from_concentration
1010
from torchvision.transforms.functional import convert_image_dtype

torch_staintools/augmentor/base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import torch
22
from typing import Optional, Sequence, Tuple, Hashable, List
33
from ..functional.stain_extraction.factory import build_from_name
4-
from ..functional.optimization.dict_learning import get_concentrations, METHOD_FACTORIZE
4+
from ..functional.optimization.sparse_util import METHOD_FACTORIZE
5+
from ..functional.optimization.concentration import get_concentrations
56
from ..functional.stain_extraction.extractor import BaseExtractor
67
from ..functional.utility.implementation import transpose_trailing, img_from_concentration
78
from ..functional.tissue_mask import get_tissue_mask, TissueMaskException

torch_staintools/augmentor/factory.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from typing import Literal, Callable, Optional, Sequence
22
import torch
33
from .base import Augmentor
4-
from ..functional.optimization.dict_learning import METHOD_FACTORIZE
4+
from ..functional.optimization.sparse_util import METHOD_FACTORIZE
5+
56
AUG_TYPE_VAHADANE = Literal['vahadane']
67
AUG_TYPE_MACENKO = Literal['macenko']
78

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .config import CONST
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from dataclasses import dataclass
2+
3+
__all__ = ['CONST']
4+
5+
@dataclass
6+
class _Config:
7+
# L2 penalty for ridge initialization of codes Z
8+
INIT_RIDGE_L2: float = 1e-4
9+
# L2 weight decay term of dictionary W in the lasso loss of dictionary learning
10+
DICT_WEIGHT_DECAY: float = 10e-10
11+
# Whether the code is persisted in the iterative procedure of dictionary learning
12+
DICT_PERSIST_CODE: bool = True
13+
# Whether to Enforce Positive Dictionary / Stain Matrix
14+
DICT_POSITIVE_DICTIONARY: bool = True
15+
# Whether to Enforce Positive Code / Concentration
16+
DICT_POSITIVE_CODE: bool = True
17+
18+
19+
CONST = _Config()
20+
21+
22+
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
import torch
2+
3+
from torch_staintools.functional.conversion.od import rgb2od
4+
from torch_staintools.functional.optimization.solver import coord_descent, ista, fista
5+
from torch_staintools.functional.optimization.sparse_util import initialize_code, METHOD_FACTORIZE, _batch_supported
6+
from torch_staintools.functional.utility import transpose_trailing
7+
8+
9+
def get_concentrations_single(od_flatten, stain_matrix, regularizer=0.01,
10+
method: METHOD_FACTORIZE = 'fista',
11+
rng: torch.Generator = None,
12+
positive: bool = False,
13+
):
14+
"""Helper function to estimate concentration matrix given an image and stain matrix with shape: 2 x (H*W)
15+
16+
For solvers without batch support. Inputs are individual data points from a batch
17+
18+
Args:
19+
od_flatten: Flattened optical density vectors in shape of (H*W) x C (H and W dimensions flattened).
20+
stain_matrix: the computed stain matrices in shape of num_stain x input channel
21+
regularizer: regularization term if ISTA algorithm is used
22+
method: which method to compute the concentration: coordinate descent ('cd') or iterative-shrinkage soft
23+
thresholding algorithm ('ista')
24+
rng: torch.Generator for random initializations
25+
positive: enforce positive concentration
26+
Returns:
27+
computed concentration: num_stains x num_pixel_in_tissue_mask
28+
"""
29+
z0 = initialize_code(od_flatten, stain_matrix.T, 'zero', rng=rng)
30+
match method:
31+
case 'cd':
32+
return coord_descent(od_flatten, z0, stain_matrix.T, alpha=regularizer, positive_code=positive).T
33+
case 'ista':
34+
return ista(od_flatten, z0, stain_matrix.T, alpha=regularizer, positive_code=positive).T
35+
case 'fista':
36+
return fista(od_flatten, z0, stain_matrix.T, alpha=regularizer, positive_code=positive).T
37+
case 'ls':
38+
return torch.linalg.lstsq(stain_matrix.T, od_flatten.T)[0].T
39+
40+
raise NotImplementedError(f"{method} is not a valid optimizer")
41+
42+
43+
def get_concentration_one_by_one(od_flatten, stain_matrix, regularizer, algorithm, rng):
44+
result = list()
45+
for od_single, stain_mat_single in zip(od_flatten, stain_matrix):
46+
result.append(get_concentrations_single(od_single, stain_mat_single, regularizer, algorithm, rng=rng))
47+
# get_concentrations_helper(od_flatten, stain_matrix, regularizer, method)
48+
return torch.stack(result)
49+
50+
51+
def _ls_batch(od_flatten, stain_matrix):
52+
"""Use least square to solve the factorization for concentration.
53+
54+
Warnings:
55+
May fail on GPU for individual large input in cuSolver backend (e.g., 1000 x 1000), regardless of batch size.
56+
Better for multiple small inputs in terms of H and W.
57+
Magma backend may work: torch.backends.cuda.preferred_linalg_library('magma')
58+
59+
Args:
60+
od_flatten: B * (HW) x num_input_channel
61+
stain_matrix: B x num_stains x num_input_channel
62+
63+
Returns:
64+
concentration B x num_stains x (HW)
65+
"""
66+
return torch.linalg.lstsq(transpose_trailing(stain_matrix), transpose_trailing(od_flatten))[0]
67+
68+
69+
def get_concentration_batch(od_flatten, stain_matrix, regularizer, algorithm, rng):
70+
assert algorithm in _batch_supported
71+
if not _batch_supported[algorithm]:
72+
return get_concentration_one_by_one(od_flatten, stain_matrix, regularizer, algorithm, rng)
73+
match algorithm:
74+
case 'ls':
75+
return _ls_batch(od_flatten, stain_matrix)
76+
case _:
77+
...
78+
79+
raise NotImplementedError('Currently only least-square (ls) is implemented as batch concentration solver')
80+
81+
82+
def get_concentrations(image, stain_matrix, regularizer=0.01,
83+
algorithm: METHOD_FACTORIZE = 'fista',
84+
rng: torch.Generator = None):
85+
"""Estimate concentration matrix given an image and stain matrix.
86+
87+
Warnings:
88+
algorithm = 'ls' May fail on GPU for individual large input (e.g., 1000 x 1000), regardless of batch size.
89+
Better for multiple small inputs in terms of H and W.
90+
Args:
91+
image: batched image(s) in shape of BxCxHxW
92+
stain_matrix: B x num_stain x input channel
93+
regularizer: regularization term if ISTA algorithm is used
94+
algorithm: which method to compute the concentration: Solve min||HExC - OD||p
95+
support 'ista', 'cd', and 'ls'. 'ls' simply solves the least square problem for factorization of
96+
min||HExC - OD||F (Frobenius norm) but is faster. 'ista'/cd enforce the sparse penalty (L1 norm) but slower.
97+
rng: torch.Generator for random initializations
98+
Returns:
99+
concentration matrix: B x num_stains x num_pixel_in_tissue_mask
100+
"""
101+
device = image.device
102+
stain_matrix = stain_matrix.to(device)
103+
# BCHW
104+
od = rgb2od(image).to(device)
105+
# B (H*W) C
106+
od_flatten = od.flatten(start_dim=2, end_dim=-1).permute(0, 2, 1)
107+
return get_concentration_batch(od_flatten, stain_matrix, regularizer, algorithm, rng)

0 commit comments

Comments
 (0)