Commit c696fde

Merge pull request #20 from CielAl/dev
1. fix a problem to draw random number for augmentation. 2. Add least…

2 parents: d2426c2 + ff371b9

File tree

16 files changed: +315 −97 lines


README.md

Lines changed: 87 additions & 42 deletions
@@ -16,34 +16,89 @@
## Documentation
Detailed documentation of the code base can be found on [GitHub Pages](https://cielal.github.io/torch-staintools/).

+## Citation
+If this toolkit helps you in your publication, please feel free to cite it with the following BibTeX entry:
+```bibtex
+@software{zhou_2024_10453807,
+  author    = {Zhou, Yufei},
+  title     = {CielAl/torch-staintools: V1.0.3 Release},
+  month     = jan,
+  year      = 2024,
+  publisher = {Zenodo},
+  version   = {v1.0.3},
+  doi       = {10.5281/zenodo.10453807},
+  url       = {https://doi.org/10.5281/zenodo.10453807}
+}
+```
## Description
* Stain Normalization (Reinhard, Macenko, and Vahadane) for PyTorch. Input tensors (for both fit and transform) must be of shape `NxCxHxW`, with values scaled to [0, 1] in torch.float32 format.
* Stain Augmentation using Macenko and Vahadane stain extraction.
* Fast normalization/augmentation on GPU with stain-matrix caching.
-* Simulate the workflow in the [StainTools library](https://github.com/Peter554/StainTools) but use the Iterative Shrinkage Thresholding Algorithm (ISTA), or optionally coordinate descent (CD), to solve the dictionary learning for stain matrix/concentration computation in the Vahadane or Macenko (stain concentration only) algorithm. The implementations of ISTA and CD are derived from Cédric Walker's [torchvahadane](https://github.com/cwlkr/torchvahadane).
+* Simulate the workflow in the [StainTools library](https://github.com/Peter554/StainTools) but use the Iterative Shrinkage Thresholding Algorithm (ISTA), or optionally coordinate descent (CD), to solve the dictionary learning for stain matrix computation in the Vahadane or Macenko (stain concentration only) algorithm. The implementations of ISTA and CD are derived from Cédric Walker's [torchvahadane](https://github.com/cwlkr/torchvahadane).
+* Stain concentration is solved via factorization of `Stain_Matrix x Concentration = Optical_Density`. For an efficient sparse solution and more robust outcomes, ISTA can be applied. Alternatively, the least-squares solver (LS) from `torch.linalg.lstsq` may be applied for a faster non-sparse solution.
* No SPAMS requirement (a dependency of StainTools).
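The `Stain_Matrix x Concentration = Optical_Density` factorization above can be sketched directly with `torch.linalg.lstsq`. This is an illustrative sketch, not the library's code: `optical_density`, `concentrations_ls`, and the H&E stain vectors below are assumptions for the example.

```python
import torch

def optical_density(img: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Beer-Lambert: OD = -log10(I), for intensities I scaled to (0, 1]
    return -torch.log10(img.clamp_min(eps))

def concentrations_ls(od_flat: torch.Tensor, stain_matrix: torch.Tensor) -> torch.Tensor:
    # Solve Stain_Matrix^T @ C = OD^T for C in the least-squares sense,
    # i.e. the non-sparse alternative to ISTA mentioned above.
    return torch.linalg.lstsq(stain_matrix.T, od_flat.T).solution

img = torch.rand(8, 8, 3)                      # toy RGB tile in [0, 1]
od = optical_density(img).reshape(-1, 3)       # (64 pixels, 3 channels)
he_matrix = torch.tensor([[0.65, 0.70, 0.29],  # rough H&E stain vectors,
                          [0.07, 0.99, 0.11]]) # illustrative values only
conc = concentrations_ls(od, he_matrix)        # (2 stains, 64 pixels)
print(conc.shape)
```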

<br />

-#### Sample Output of Torch StainTools
+#### Sample Output of Torch-StainTools Normalization
![Screenshot](https://raw.githubusercontent.com/CielAl/torch-staintools/main/showcases/sample_out.png)

#### Sample Output of StainTools
![Screenshot](https://raw.githubusercontent.com/CielAl/torch-staintools/main/showcases/sample_out_staintools.png)

-## Use case
+#### Sample Output of Torch-StainTools Augmentation (Repeat 3 times)
+![Screenshot](https://raw.githubusercontent.com/CielAl/torch-staintools/main/showcases/sample_out_augmentation.png)
+
+#### Sample Output of StainTools Augmentation (Repeat 3 times)
+![Screenshot](https://raw.githubusercontent.com/CielAl/torch-staintools/main/showcases/sample_out_augmentation_staintools.png)
+
+## Benchmark (No Stain Matrices Caching)
+* Uses the sample image under ./test_images (size `2500x2500x3`). The mean was computed from 7 runs (1 loop per run) using timeit, comparing torch-staintools in CPU/GPU mode against the StainTools implementation.
+* For consistency, ISTA is used to compute the concentration.
+
+### Transformation
+
+| Method   | CPU [s] | GPU [s] | StainTools [s] |
+|:---------|:--------|:--------|:---------------|
+| Vahadane | 119     | 7.5     | 20.9           |
+| Macenko  | 5.57    | 0.479   | 20.7           |
+| Reinhard | 0.840   | 0.024   | 0.414          |
+
+### Fitting
+
+| Method   | CPU [s] | GPU [s] | StainTools [s] |
+|:---------|:--------|:--------|:---------------|
+| Vahadane | 132     | 8.40    | 19.1           |
+| Macenko  | 6.99    | 0.064   | 20.0           |
+| Reinhard | 0.422   | 0.011   | 0.076          |
+
+### Batchified Concentration Computation
+* Split the sample image under ./test_images (size `2500x2500x3`) into 81 non-overlapping `256x256x3` tiles as a batch.
+* For the StainTools baseline, a for-loop computes the concentration of each of the 81 tiles (as numpy arrays) individually.
+
+| Method                               | CPU [s] | GPU [s]   |
+|:-------------------------------------|:--------|:----------|
+| ISTA (`concentration_method='ista'`) | 3.12    | 1.24      |
+| CD (`concentration_method='cd'`)     | 29.3    | 4.87      |
+| LS (`concentration_method='ls'`)     | 0.221   | **0.097** |
+| StainTools (SPAMS)                   | 16.6    | N/A       |
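Timings like those above can be gathered with `timeit` under the same protocol. A minimal sketch follows (the actual benchmark script is not part of this diff, and the workload below is a stand-in for a normalizer call); note that fair GPU timing needs `torch.cuda.synchronize()` so queued kernels are counted.

```python
import timeit
import torch

def mean_runtime(fn, repeat: int = 7, number: int = 1) -> float:
    # mean seconds per loop over `repeat` runs, matching the
    # "7 runs (1 loop per run)" protocol described above
    def timed():
        fn()
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # flush queued GPU kernels before the timer stops
    runs = timeit.repeat(timed, repeat=repeat, number=number)
    return sum(runs) / (repeat * number)

batch = torch.rand(4, 3, 256, 256)     # stand-in for a tile batch
workload = lambda: batch.pow(2).sum()  # stand-in for normalizer(batch)
print(f"{mean_runtime(workload):.6f} s per call")
```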
+## Use Cases and Tips
* For details, follow the example in demo.py.
* Normalizers are wrapped as `torch.nn.Module` and work similarly to a standalone neural network, which matters for workflows involving a dataloader with multiprocessing.
(Note that CUDA has poor support for multiprocessing, so it may not be best practice to perform GPU-accelerated on-the-fly stain transformation in PyTorch's dataset/dataloader.)
+* `concentration_method='ls'` (i.e., `torch.linalg.lstsq`) can be efficient for batches of many smaller inputs (e.g., `256x256` in width and height), but it may fail on GPU for a single larger input image, even when that image has fewer total pixels than the aforementioned batch of smaller inputs. Therefore, `concentration_method='ls'` is best suited to processing huge numbers of small images in batches on the fly.
```python
import cv2
import torch
from torchvision.transforms import ToTensor
-from torchvision.transforms.functional import convert_image_dtype
-from torch_staintools.normalizer.factory import NormalizerBuilder
-from torch_staintools.augmentor.factory import AugmentorBuilder
+from torch_staintools.normalizer import NormalizerBuilder
+from torch_staintools.augmentor import AugmentorBuilder
import os
seed = 0
torch.manual_seed(seed)

@@ -71,7 +126,15 @@ norm_tensor = ToTensor()(norm).unsqueeze(0).to(device)

# ######## Normalization
# create the normalizer - using vahadane. Alternatively, 'macenko' or 'reinhard' can be used.
-normalizer_vahadane = NormalizerBuilder.build('vahadane')
+# note this is equivalent to:
+# from torch_staintools.normalizer.separation import StainSeparation
+# normalizer_vahadane = StainSeparation.build('vahadane', **arguments)
+
+# we use 'ista' (the ISTA algorithm) to get the sparse solution of the factorization: STAIN_MATRIX * Concentration = OD
+# alternatively, 'cd' (coordinate descent) and 'ls' (least squares from torch.linalg) are available.
+# note that 'ls' can be much faster on batches of smaller inputs, but may fail on GPU for an individual large input
+# (in terms of width and height), regardless of the batch size
+normalizer_vahadane = NormalizerBuilder.build('vahadane', concentration_method='ista')
# move the normalizer to the device (CPU or GPU)
normalizer_vahadane = normalizer_vahadane.to(device)
# fit. For macenko and vahadane this step will compute the stain matrix and concentration

@@ -89,7 +152,8 @@ augmentor = AugmentorBuilder.build('vahadane',
                                   # the luminosity threshold to find the tissue region to augment
                                   # if set to None, all pixels are treated as tissue
                                   luminosity_threshold=0.8,
+                                  # herein we use 'ista' to compute the concentration
+                                  concentration_method='ista',
                                   sigma_alpha=0.2,
                                   sigma_beta=0.2, target_stain_idx=(0, 1),
                                   # this allows caching the stain matrix if it is too time-consuming to recompute.

@@ -117,6 +181,21 @@ for _ in range(num_augment):

# dump the cache of stain matrices for future usage
augmentor.dump_cache('./cache.pickle')
+
+# fast batch operation
+tile_size = 512
+tiles: torch.Tensor = norm_tensor.unfold(2, tile_size, tile_size)\
+    .unfold(3, tile_size, tile_size).reshape(1, 3, -1, tile_size, tile_size).squeeze(0).permute(1, 0, 2, 3).contiguous()
+print(tiles.shape)
+# use macenko normalization as an example
+normalizer_macenko = NormalizerBuilder.build('macenko', use_cache=True,
+                                             # use the least-squares solver, along with the cache,
+                                             # to perform normalization on the fly
+                                             concentration_method='ls')
+normalizer_macenko = normalizer_macenko.to(device)
+normalizer_macenko.fit(target_tensor)
+normalizer_macenko(tiles)
```
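The `unfold` recipe in the example above turns an image into a batch of tiles for the normalizer; the inverse reshape stitches the processed tiles back into the full image. A minimal round-trip sketch (the stitching steps are illustrative, not a library API):

```python
import torch

tile_size = 4
img = torch.rand(1, 3, 8, 8)  # stand-in for norm_tensor; H and W divisible by tile_size
# tile: same unfold recipe as the example above
grid = img.unfold(2, tile_size, tile_size).unfold(3, tile_size, tile_size)
n_h, n_w = grid.shape[2], grid.shape[3]  # number of tiles along H and W
batch = (grid.reshape(1, 3, -1, tile_size, tile_size)
             .squeeze(0).permute(1, 0, 2, 3).contiguous())  # (num_tiles, 3, ts, ts)
# ... a normalizer could process `batch` here ...
# inverse: regroup tiles into the (n_h, n_w) grid, then interleave back to H x W
restored = (batch.permute(1, 0, 2, 3)
                 .reshape(3, n_h, n_w, tile_size, tile_size)
                 .permute(0, 1, 3, 2, 4)
                 .reshape(3, n_h * tile_size, n_w * tile_size)
                 .unsqueeze(0))
print(torch.equal(restored, img))  # the round trip is lossless
```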
## Stain Matrix Caching
As elaborated in the running-time benchmark of fitting, computation of the stain matrix can be time-consuming.
@@ -133,40 +212,6 @@ augmentor(input_batch, cache_keys=list_of_keys_corresponding_to_input_batch)
The next time `Normalizer` or `Augmentor` processes the images, the corresponding stain matrices will be queried and fetched from the cache if they are already stored, rather than recomputed from scratch.
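The caching behavior can be pictured as a keyed store with a pickle dump, mirroring the `use_cache`, `cache_keys`, and `dump_cache` usage above. A toy sketch (`StainMatrixCache` is hypothetical, not the library's class):

```python
import pickle
import torch

class StainMatrixCache:
    # toy key -> stain-matrix store; for illustration only
    def __init__(self):
        self._store = {}

    def get_or_compute(self, key, compute_fn):
        # recompute only on a cache miss
        if key not in self._store:
            self._store[key] = compute_fn()
        return self._store[key]

    def dump(self, path):
        with open(path, 'wb') as f:
            pickle.dump(self._store, f)

calls = []
def expensive_stain_matrix():
    calls.append(1)        # stands in for a slow Vahadane/Macenko fit
    return torch.eye(3)

cache = StainMatrixCache()
m1 = cache.get_or_compute('slide_0', expensive_stain_matrix)
m2 = cache.get_or_compute('slide_0', expensive_stain_matrix)  # hit: no recompute
print(len(calls))  # the expensive fit ran once
```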
-## Benchmark
-* Use the sample images under ./test_images (size `2500x2500x3`). Mean was computed from 7 runs (1 loop per run) using timeit. Comparison between torch_stain_tools in CPU/GPU mode, as well as that of the StainTools Implementation.
-
-### Transformation
-
-| Method | CPU[s] | GPU[s] | StainTool[s] |
-|:---------|:-------|:-------|:-------------|
-| Vahadane | 119 | 7.5 | 20.9 |
-| Macenko | 5.57 | 0.479 | 20.7 |
-| Reinhard | 0.840 | 0.024 | 0.414 |
-
-### Fitting
-| Method | CPU[s] | GPU[s] | StainTool[s] |
-|:---------|:-------|:-------|:-------------|
-| Vahadane | 132 | 8.40 | 19.1 |
-| Macenko | 6.99 | 0.064 | 20.0 |
-| Reinhard | 0.422 | 0.011 | 0.076 |
-
-## Citation
-If this toolkit helps you in your publication, please feel free to cite with the following bibtex entry:
-```bibtex
-@software{zhou_2024_10453807,
-  author = {Zhou, Yufei},
-  title = {CielAl/torch-staintools: V1.0.3 Release},
-  month = jan,
-  year = 2024,
-  publisher = {Zenodo},
-  version = {v1.0.3},
-  doi = {10.5281/zenodo.10453807},
-  url = {https://doi.org/10.5281/zenodo.10453807}
-}
-```
## Acknowledgments
* Some code is derived from [torchvahadane](https://github.com/cwlkr/torchvahadane), [torchstain](https://github.com/EIDOSLAB/torchstain), and [StainTools](https://github.com/Peter554/StainTools).
* Sample images in the demo and README.md are selected from [The Cancer Genome Atlas Program (TCGA)](https://www.cancer.gov/ccg/research/genome-sequencing/tcga) dataset.

demo.py

Lines changed: 54 additions & 4 deletions
@@ -2,11 +2,12 @@
import torch
from torchvision.transforms import ToTensor
from torchvision.transforms.functional import convert_image_dtype
-from torch_staintools.normalizer.factory import NormalizerBuilder
-from torch_staintools.augmentor.factory import AugmentorBuilder
+from torch_staintools.normalizer import NormalizerBuilder
+from torch_staintools.augmentor import AugmentorBuilder
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
+import random
import os
seed = 0
torch.manual_seed(seed)
@@ -52,7 +53,7 @@ def postprocess(image_tensor): return convert_image_dtype(image_tensor, torch.ui

# ######### Vahadane
-normalizer_vahadane = NormalizerBuilder.build('vahadane', reconst_method='ista', use_cache=True,
+normalizer_vahadane = NormalizerBuilder.build('vahadane', concentration_method='ista', use_cache=True,
                                              rng=1,
                                              )
normalizer_vahadane = normalizer_vahadane.to(device)
@@ -75,7 +76,7 @@ def postprocess(image_tensor): return convert_image_dtype(image_tensor, torch.ui
# #################### Macenko

-normalizer_macenko = NormalizerBuilder.build('macenko')
+normalizer_macenko = NormalizerBuilder.build('macenko', use_cache=True, concentration_method='ls')
normalizer_macenko = normalizer_macenko.to(device)
normalizer_macenko.fit(target_tensor)
@@ -202,3 +203,52 @@ def postprocess(image_tensor): return convert_image_dtype(image_tensor, torch.ui
    ax.axis('off')
plt.savefig(os.path.join('.', 'showcases', 'sample_out_staintools.png'), bbox_inches='tight')
plt.show()
+
+algorithms = ['Vahadane', 'Macenko']
+num_repeat = 3
+
+# sample augmentation output
+fig, axs = plt.subplots(2, num_repeat + 1, figsize=(15, 8), dpi=300)
+for i, ax_alg in enumerate(axs):
+    alg = algorithms[i].lower()
+    augmentor = AugmentorBuilder.build(alg, concentration_method='ista',
+                                       sigma_alpha=0.5,
+                                       sigma_beta=0.5,
+                                       luminosity_threshold=0.8,
+                                       rng=314159, use_cache=True).to(device)
+    ax_alg[0].imshow(norm)
+    ax_alg[0].set_title("Augmentation Original")
+    ax_alg[0].axis('off')
+    for j in range(1, len(ax_alg)):
+        aug_out = augmentor(norm_tensor, cache_keys=[0])
+        ax_alg[j].imshow(postprocess(aug_out))
+        ax_alg[j].set_title(f"{alg}: {j}")
+        ax_alg[j].axis('off')
+plt.savefig(os.path.join('.', 'showcases', 'sample_out_augmentation.png'), bbox_inches='tight')
+plt.show()
+
+# sample StainTools augmentation output
+np.random.seed(314159)
+random.seed(314159)
+from staintools import StainAugmentor
+from staintools.preprocessing.luminosity_standardizer import LuminosityStandardizer
+algorithms = ['Vahadane', 'Macenko']
+fig, axs = plt.subplots(2, num_repeat + 1, figsize=(15, 8), dpi=300)
+for i, ax_alg in enumerate(axs):
+    alg = algorithms[i].lower()
+    augmentor = StainAugmentor(method=alg, sigma1=0.5, sigma2=0.5, augment_background=False)
+    standardized_norm = LuminosityStandardizer.standardize(norm)
+    augmentor.fit(standardized_norm)
+    ax_alg[0].imshow(standardized_norm)
+    ax_alg[0].set_title("Augmentation Original")
+    ax_alg[0].axis('off')
+    for j in range(1, len(ax_alg)):
+        aug_out = augmentor.pop().astype(np.uint8)
+        ax_alg[j].imshow(aug_out)
+        ax_alg[j].set_title(f"{alg} - StainTools: {j}")
+        ax_alg[j].axis('off')
+plt.savefig(os.path.join('.', 'showcases', 'sample_out_augmentation_staintools.png'), bbox_inches='tight')
+plt.show()
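The `sigma_alpha`/`sigma_beta` knobs above follow the StainTools convention: each stain row of the concentration matrix is scaled by an alpha drawn from U(1 − sigma1, 1 + sigma1) and shifted by a beta drawn from U(−sigma2, sigma2). A toy sketch of that perturbation rule (the function name and shapes are illustrative, not the library's API):

```python
import torch

def perturb_concentrations(conc: torch.Tensor, sigma_alpha: float = 0.2,
                           sigma_beta: float = 0.2, rng=None) -> torch.Tensor:
    # StainTools-style augmentation: per-stain-row scale alpha ~ U(1-sigma, 1+sigma)
    # and shift beta ~ U(-sigma, sigma); illustrative sketch only.
    num_stains = conc.shape[0]
    alpha = 1 + (torch.rand(num_stains, 1, generator=rng) * 2 - 1) * sigma_alpha
    beta = (torch.rand(num_stains, 1, generator=rng) * 2 - 1) * sigma_beta
    return conc * alpha + beta

conc = torch.rand(2, 100)  # (num_stains, num_pixels) toy concentrations
aug = perturb_concentrations(conc)
print(aug.shape)
```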
Two binary files changed (12.3 MB and 12.4 MB; not rendered).

tests/images/test_functionals.py

Lines changed: 1 addition & 1 deletion
@@ -83,7 +83,7 @@ def extract_eval_helper(tester, get_stain_mat, luminosity_threshold,
    def eval_wrapper(self, extractor):

        # all pixel
-        algorithms = ['ista', 'cd']
+        algorithms = ['ista', 'cd', 'ls']
        for alg in algorithms:
            TestFunctional.extract_eval_helper(self, extractor, luminosity_threshold=None,
                                               num_stains=2, regularizer=0.1, dict_algorithm=alg)
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+from .factory import *
