Skip to content

Commit b00fba4

Browse files
committed
add warning for customized rng due to potential graph breaking behavior
1 parent 8693dbc commit b00fba4

File tree

9 files changed

+159
-31
lines changed

9 files changed

+159
-31
lines changed

README.md

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ If this toolkit helps you in your publication, please feel free to cite with the
3636
* Stain Augmentation using Macenko and Vahadane as stain extraction.
3737
* Fast normalization/augmentation on GPU with stain matrices caching.
3838
* Simulate the workflow in [StainTools library](https://github.com/Peter554/StainTools) but use the Iterative Shrinkage Thresholding Algorithm (ISTA), or optionally, the coordinate descent (CD) to solve the dictionary learning for stain matrix computation in Vahadane or Macenko (stain concentration only) algorithm. The implementation of ISTA and CD are derived from Cédric Walker's [torchvahadane](https://github.com/cwlkr/torchvahadane)
39-
* Stain Concentration is solved via factorization of `Stain_Matrix x Concentration = Optical_Density`. For efficient sparse solution and more robust outcomes, ISTA can be applied. Alternatively, Least Square solver (LS) from `torch.linalg.lstsq` might be applied for faster non-sparse solution.
39+
* Stain Concentration is solved via factorization of `Stain_Matrix x Concentration = Optical_Density`. For efficient sparse solution and more robust outcomes, ISTA can be applied. Alternatively, the Least Square solver (LS) from `torch.linalg.lstsq` might be applied for faster non-sparse solution.
4040
* No SPAMS requirement (which is a dependency in StainTools).
4141

4242
<br />
@@ -90,7 +90,7 @@ timeit. Comparison between torch_stain_tools in CPU/GPU mode, as well as that of
9090
* Normalizers are wrapped as `torch.nn.Module`, working similarly to a standalone neural network. This means that for a workflow involving dataloader with multiprocessing, the normalizer
9191
(Note that CUDA has poor support in multiprocessing, and therefore it may not be the best practice to perform GPU-accelerated on-the-fly stain transformation in pytorch's dataset/dataloader)
9292

93-
* `concentration_method='ls'` (i.e., `torch.linalg.lstsq`) can be efficient for batches of many smaller input (e.g., `256x256`) in terms of width and height. However, it may fail on GPU for a single larger input image (width and height). This happens even if the
93+
* `concentration_method='ls'` (i.e., `torch.linalg.lstsq`) can be efficient for batches of many smaller input (e.g., `256x256`) in terms of width and height. However, it may fail on GPU for a single larger input image (width and height). This happens even if
9494
the total number of pixels of the image is fewer than the aforementioned batch of multiple smaller input. Therefore, `concentration_method='ls'` could be suitable to deal with huge amount of small images in batches on the fly.
9595

9696
```python
@@ -184,8 +184,7 @@ augmentor.dump_cache('./cache.pickle')
184184

185185
# fast batch operation
186186
tile_size = 512
187-
tiles: torch.Tensor = norm_tensor.unfold(2, tile_size, tile_size)
188-
.unfold(3, tile_size, tile_size).reshape(1, 3, -1, tile_size, tile_size).squeeze(0).permute(1, 0, 2, 3).contiguous()
187+
tiles: torch.Tensor = norm_tensor.unfold(2, tile_size, tile_size).unfold(3, tile_size, tile_size).reshape(1, 3, -1, tile_size, tile_size).squeeze(0).permute(1, 0, 2, 3).contiguous()
189188
print(tiles.shape)
190189
# use macenko normalization as example
191190
normalizer_macenko = NormalizerBuilder.build('macenko', use_cache=True,

demo.py

Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
"""Demo prerequisite:
2-
tqdm
2+
tqdm (progress bar)
33
staintools (for comparison)
4+
cv2 (read and process images)
45
"""
56
import cv2
67
import torch
78
from torchvision.transforms import ToTensor
89
from torchvision.transforms.functional import convert_image_dtype
910
from torch_staintools.normalizer import NormalizerBuilder
1011
from torch_staintools.augmentor import AugmentorBuilder
12+
from torch_staintools.constants import CONFIG
1113
import matplotlib.pyplot as plt
1214
import numpy as np
1315
from tqdm import tqdm
@@ -41,6 +43,7 @@
4143

4244
# test with multiple smaller regions from the sample image
4345
tile_size = 1024
46+
# split the sample images into a batch of patches.
4447
tiles: torch.Tensor = norm_tensor.unfold(2, tile_size, tile_size)\
4548
.unfold(3, tile_size, tile_size).reshape(1, 3, -1, tile_size, tile_size).squeeze(0).permute(1, 0, 2, 3).contiguous()
4649

@@ -53,24 +56,36 @@
5356
plt.show()
5457

5558

59+
# helper function to convert tensor back to numpy arrays for visualization purposes.
5660
def postprocess(image_tensor): return convert_image_dtype(image_tensor, torch.uint8)\
5761
.squeeze().detach().cpu().permute(1, 2, 0).numpy()
5862

59-
63+
# We enable the torch.compile (note this is True by default)
64+
CONFIG.ENABLE_COMPILE = True
6065
# ######### Vahadane
6166
normalizer_vahadane = NormalizerBuilder.build('vahadane',
62-
concentration_solver='ista', use_cache=True,
63-
rng=1,
67+
# use fista (fast iterative shrinkage-thresholding algorithm)
68+
# for dictionary learning to
69+
# estimate the stain matrix (sparse constraints)
70+
# alternative: 'cd' (coordinate descent);
71+
# 'ista' (iterative shrinkage-thresholding algorithm)
72+
sparse_stain_solver='fista',
73+
concentration_solver='fista',
74+
# whether to cache the stain matrix.
75+
# must pair the input with an identifier (e.g. filename)
76+
# otherwise cache will be ignored.
77+
use_cache=True
6478
)
6579
normalizer_vahadane = normalizer_vahadane.to(device)
6680
normalizer_vahadane.fit(target_tensor)
6781
# the normalizer has no parameters so torch.no_grad() has no effect. Leave it here for future demo of models
6882
# that may enclose parameters.
6983
with torch.no_grad():
7084
for idx, tile_single in enumerate(tqdm(tiles, disable=False)):
71-
85+
tile_single: torch.Tensor
7286
tile_single = tile_single.unsqueeze(0)
7387
# BCHW - scaled to [0 1] torch.float32
88+
# cache key herein is the index of data points.
7489
test_out = normalizer_vahadane(tile_single, cache_keys=[idx])
7590
test_out = postprocess(test_out)
7691
plt.imshow(test_out)
@@ -80,9 +95,11 @@ def postprocess(image_tensor): return convert_image_dtype(image_tensor, torch.ui
8095
# %timeit normalizer_vahadane(norm_tensor, positive_dict=True)
8196

8297
# #################### Macenko
83-
84-
85-
normalizer_macenko = NormalizerBuilder.build('macenko', use_cache=True, concentration_solver='ls')
98+
# if using cusolver, 'ls' (least square) will fail on single large images.
99+
# try magma backend if 'ls' is still preferred as the concentration estimator (see below)
100+
# torch.backends.cuda.preferred_linalg_library('magma')
101+
normalizer_macenko = NormalizerBuilder.build('macenko', use_cache=True,
102+
concentration_solver='fista') # 'ls'
86103
normalizer_macenko = normalizer_macenko.to(device)
87104
normalizer_macenko.fit(target_tensor)
88105

@@ -117,9 +134,14 @@ def postprocess(image_tensor): return convert_image_dtype(image_tensor, torch.ui
117134
# Augmentation
118135

119136
augmentor = AugmentorBuilder.build('vahadane',
120-
rng=314159,
137+
sparse_stain_solver='fista',
138+
concentration_solver='fista',
139+
num_stains=2,
140+
rng=314159, # None if globally managing the seeds
121141
sigma_alpha=0.2,
122-
sigma_beta=0.2, target_stain_idx=(0, 1),
142+
sigma_beta=0.2,
143+
# for two stains (herein, H&E), augment both H and E.
144+
target_stain_idx=(0, 1),
123145
use_cache=True,
124146
)
125147
# move augmentor to the device
@@ -142,7 +164,8 @@ def postprocess(image_tensor): return convert_image_dtype(image_tensor, torch.ui
142164
tiles_np = tiles.permute(0, 2, 3, 1).detach().cpu().contiguous().numpy()
143165

144166
for idx, tile_single in enumerate(tqdm(tiles_np)):
145-
tile_single = (tile_single * 255).astype(np.uint8)
167+
tile_single: np.ndarray
168+
tile_single: np.ndarray = (tile_single * 255).astype(np.uint8)
146169
test_out = st_vahadane.transform(tile_single)
147170
plt.imshow(test_out)
148171
plt.title(f"Vahadane StainTools: {idx}")
@@ -153,10 +176,11 @@ def postprocess(image_tensor): return convert_image_dtype(image_tensor, torch.ui
153176
from staintools.stain_normalizer import StainNormalizer
154177
st_macenko = StainNormalizer(method='macenko')
155178
st_macenko.fit(target)
156-
tiles_np = tiles.permute(0, 2, 3, 1).detach().cpu().contiguous().numpy()
179+
tiles_np: np.ndarray = tiles.permute(0, 2, 3, 1).detach().cpu().contiguous().numpy()
157180
# timeit st_macenko.transform(norm)
158181
for idx, tile_single in enumerate(tqdm(tiles_np)):
159-
tile_single = (tile_single * 255).astype(np.uint8)
182+
tile_single: np.ndarray
183+
tile_single: np.ndarray = (tile_single * 255).astype(np.uint8)
160184
test_out = st_macenko.transform(tile_single)
161185
plt.imshow(test_out)
162186
plt.title(f"Vahadane StainTools: {idx}")
@@ -170,7 +194,8 @@ def postprocess(image_tensor): return convert_image_dtype(image_tensor, torch.ui
170194
tiles_np = tiles.permute(0, 2, 3, 1).detach().cpu().contiguous().numpy()
171195
# %timeit st_reinhard.transform(norm)
172196
for idx, tile_single in enumerate(tqdm(tiles_np)):
173-
tile_single = (tile_single * 255).astype(np.uint8)
197+
tile_single: np.ndarray
198+
tile_single: np.ndarray = (tile_single * 255).astype(np.uint8)
174199
test_out = st_reinhard.transform(tile_single)
175200
plt.imshow(test_out)
176201
plt.title(f"Reinhard ST: {idx}")
@@ -217,7 +242,9 @@ def postprocess(image_tensor): return convert_image_dtype(image_tensor, torch.ui
217242
fig, axs = plt.subplots(2, num_repeat + 1, figsize=(15, 8), dpi=300)
218243
for i, ax_alg in enumerate(axs):
219244
alg = algorithms[i].lower()
220-
augmentor = AugmentorBuilder.build(alg, concentration_solver='ista',
245+
# noinspection PyTypeChecker
246+
augmentor = AugmentorBuilder.build(alg,
247+
concentration_solver='ista',
221248
sigma_alpha=0.5,
222249
sigma_beta=0.5,
223250
luminosity_threshold=0.8,

tests/images/test_functionals.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -121,12 +121,15 @@ def test_stains(self):
121121
self.eval_wrapper(macenko)
122122
self.eval_wrapper(vahadane)
123123

124-
# vahadane with rng and lr
125-
vahadane.stain_algorithm.cfg.lr = 0.5
126-
TestFunctional.extract_eval_helper(self, vahadane,
127-
conc_solver=ConcentrationSolver(TestFunctional.POSITIVE_CONC_CFG),
128-
luminosity_threshold=None,
129-
num_stains=3, rng=torch.Generator(1))
124+
# github remote end fails due to driver issues. Test it locally.
125+
# # vahadane with rng and lr
126+
# vahadane.stain_algorithm.cfg.lr = 0.5
127+
# TestFunctional.extract_eval_helper(self, vahadane,
128+
# conc_solver=ConcentrationSolver(TestFunctional.POSITIVE_CONC_CFG),
129+
# luminosity_threshold=None,
130+
# num_stains=3, rng=torch.Generator(1))
131+
132+
130133
def test_tissue_mask(self):
131134
device = TestFunctional.device
132135
dummy_scaled = convert_image_dtype(TestFunctional.new_dummy_img_tensor_ubyte(), torch.float32).to(device)

torch_staintools/base_module/base.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import warnings
2+
13
from ..cache.tensor_cache import TensorCache
24
import torch
35
from typing import Optional, List, Hashable, Callable
@@ -123,6 +125,9 @@ def __init__(self, cache: Optional[TensorCache], device: Optional[torch.device],
123125
self._tensor_cache = cache
124126
self.device = default_device(device)
125127
self._rng = default_rng(rng, self.device)
128+
if self._rng is not None:
129+
warnings.warn(f"A custom RNG is passed and may cause graph break if torch.compile is used."
130+
f"Consider fixing random states globally instead.")
126131

127132
@property
128133
def rng(self):

torch_staintools/constants/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ class _Config:
1111
# Whether to Enforce Positive Code / Concentration
1212
DICT_POSITIVE_CODE: bool = True
1313

14+
# Whether to enable torch.compile (currently only the dictionary learning is affected)
15+
ENABLE_COMPILE: bool = True
1416

1517
CONFIG: _Config = _Config()
1618

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import torch
2+
import functools
3+
import warnings
4+
from typing import Callable, Any, Optional, Protocol, cast
5+
from torch_staintools.constants import CONFIG
6+
7+
_FIELD_COMPILED_ATTR = 'compiled_fn'
8+
9+
10+
class CompiledWrapper(Protocol):
11+
compiled_fn: Optional[Callable]
12+
13+
def reset_cache(self) -> None:
14+
...
15+
16+
def __call__(self, *args, **kwargs):
17+
...
18+
19+
20+
def lazy_compile(func: Callable) -> CompiledWrapper:
21+
"""Enable or disable torch.compile by torch_staintools.constants.CONFIG.ENABLE_COMPILE.
22+
23+
If True, function will be compiled and cached. Otherwise, it will be executed in eager mode.
24+
25+
Args:
26+
func: The function to compile.
27+
28+
Returns:
29+
CompiledWrapper: The compiled function or the original function.
30+
"""
31+
32+
@functools.wraps(func)
33+
def wrapper(*args, **kwargs) -> Any:
34+
enable_compile = getattr(CONFIG, "ENABLE_COMPILE", False)
35+
if not enable_compile:
36+
# if disabled execute it in eager mode
37+
return func(*args, **kwargs)
38+
39+
if not hasattr(wrapper, _FIELD_COMPILED_ATTR) or wrapper.compiled_fn is None:
40+
try:
41+
wrapper.compiled_fn = torch.compile(func)
42+
except Exception as e:
43+
warnings.warn(f"torch.compile failed for '{func.__name__}': {e}. "
44+
f"Falling back to eager execution.")
45+
wrapper.compiled_fn = func
46+
47+
return wrapper.compiled_fn(*args, **kwargs)
48+
49+
wrapper = cast(CompiledWrapper, wrapper)
50+
# init the attribute
51+
wrapper.compiled_fn = None
52+
# clear the compiled cache --> future use
53+
def reset_cache():
54+
wrapper.compiled_fn = None
55+
wrapper.reset_cache = reset_cache
56+
return wrapper
57+
58+
59+
def static_compile(func: Callable) -> Callable:
60+
"""Import-time wrapper.
61+
62+
CONFIG.ENABLE_COMPILE must be modified before importing any compiled functions.
63+
64+
Args:
65+
func: The function to compile.
66+
67+
Returns:
68+
The compiled function or the original function.
69+
"""
70+
71+
if getattr(CONFIG, "ENABLE_COMPILE", False):
72+
try:
73+
return torch.compile(func)
74+
except Exception as e:
75+
warnings.warn(f"torch.compile failed for '{func.__name__}': {e}. "
76+
f"Falling back to eager execution.")
77+
return func
78+
79+
return func

torch_staintools/functional/optimization/dict_learning.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,15 @@
66
import torch
77
import torch.nn.functional as F
88
from typing import Optional, cast, Tuple
9+
10+
from ..compile import lazy_compile
911
from ..eps import get_eps
1012
from torch_staintools.constants import CONFIG
1113

1214

13-
@torch.compile
15+
# @torch.compile
16+
# @static_compile
17+
@lazy_compile
1418
def update_dict_cd(dictionary: torch.Tensor, x: torch.Tensor, code: torch.Tensor,
1519
positive: bool = True,
1620
dead_thresh=1e-7,
@@ -86,11 +90,14 @@ def update_dict_cd(dictionary: torch.Tensor, x: torch.Tensor, code: torch.Tensor
8690
return dictionary, code
8791

8892

89-
@torch.compile
93+
# @torch.compile
94+
# @static_compile
95+
@lazy_compile
9096
def update_dict_ridge(x: torch.Tensor, code: torch.Tensor, lambd: float) -> Tuple[torch.Tensor, torch.Tensor]:
9197
"""Update an (unconstrained) dictionary with ridge regression
9298
93-
This is equivalent to a Newton step with the (L2-regularized) squared
99+
This is equivalent to a Newton step with the (L2-regularized) squared.
100+
May have severe numerical stability issues compared to update_dict_cd.
94101
error objective:
95102
f(V) = (1/2N) * ||Vz - x||_2^2 + (lambd/2) * ||V||_2^2
96103

torch_staintools/functional/optimization/solver.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import torch
55
import torch.nn.functional as F
66

7+
from torch_staintools.functional.compile import lazy_compile
8+
79

810
def coord_descent(x: torch.Tensor, z0: torch.Tensor, weight: torch.Tensor,
911
alpha: torch.Tensor,
@@ -137,7 +139,9 @@ def fista_step(
137139

138140

139141

140-
@torch.compile
142+
# @torch.compile
143+
# @static_compile
144+
@lazy_compile
141145
def ista_loop(z: torch.Tensor, hessian: torch.Tensor, b: torch.Tensor,
142146
alpha: torch.Tensor, lr: torch.Tensor,
143147
tol: float, maxiter: int, positive_code: bool):
@@ -153,7 +157,9 @@ def ista_loop(z: torch.Tensor, hessian: torch.Tensor, b: torch.Tensor,
153157
return z
154158

155159

156-
@torch.compile
160+
# @torch.compile
161+
# @static_compile
162+
@lazy_compile
157163
def fista_loop(
158164
z: torch.Tensor,
159165
hessian: torch.Tensor,

torch_staintools/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '1.0.5a'
1+
__version__ = '1.0.5'

0 commit comments

Comments
 (0)