Skip to content

Commit a472027

Browse files
Merge pull request #39 from ajinkya-kulkarni/patch-1
Speed optimizations
2 parents 64622cd + da63f5a commit a472027

File tree

10 files changed

+64
-51
lines changed

10 files changed

+64
-51
lines changed

.github/workflows/build.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ on:
66
jobs:
77
build_wheels:
88
name: Build release
9-
runs-on: ubuntu-18.04
9+
runs-on: ubuntu-20.04
1010

1111
steps:
1212
- uses: actions/checkout@v3
@@ -27,7 +27,7 @@ jobs:
2727

2828
upload_pypi:
2929
needs: build_wheels
30-
runs-on: ubuntu-18.04
30+
runs-on: ubuntu-20.04
3131

3232
if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v')
3333

.github/workflows/tests_full.yml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ on:
1010

1111
jobs:
1212
build:
13-
runs-on: ubuntu-18.04
13+
runs-on: ubuntu-20.04
1414

1515
if: startsWith(github.ref, 'refs/tags/v') != true
1616

@@ -39,7 +39,7 @@ jobs:
3939
runs-on: ${{ matrix.os }}
4040
strategy:
4141
matrix:
42-
os: [ windows-2019, ubuntu-18.04, macos-11 ]
42+
os: [ windows-2019, ubuntu-20.04, macos-11 ]
4343
python-version: [ 3.7, 3.8, 3.9 ]
4444
tf-version: [2.7.0, 2.8.0, 2.9.0]
4545

@@ -71,14 +71,16 @@ jobs:
7171
runs-on: ${{ matrix.os }}
7272
strategy:
7373
matrix:
74-
os: [ windows-2019, ubuntu-18.04, macos-11 ]
74+
os: [ windows-2019, ubuntu-20.04, macos-11 ]
7575
python-version: [ 3.6, 3.7, 3.8, 3.9 ]
7676
pytorch-version: [1.8.0, 1.9.0, 1.10.0, 1.11.0, 1.12.0, 1.13.0]
7777
exclude:
7878
- python-version: 3.6
7979
pytorch-version: 1.11.0
8080
- python-version: 3.6
8181
pytorch-version: 1.12.0
82+
- python-version: 3.6
83+
pytorch-version: 1.13.0
8284

8385
steps:
8486
- uses: actions/checkout@v1

.github/workflows/tests_quick.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ on:
1010

1111
jobs:
1212
build:
13-
runs-on: ubuntu-18.04
13+
runs-on: ubuntu-20.04
1414
steps:
1515
- uses: actions/checkout@v1
1616
- name: Set up Python 3.6
@@ -33,7 +33,7 @@ jobs:
3333

3434
test-tf:
3535
needs: build
36-
runs-on: ubuntu-18.04
36+
runs-on: ubuntu-20.04
3737

3838
steps:
3939
- uses: actions/checkout@v1

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
setup(
88
name='torchstain',
9-
version='1.2.0',
9+
version='1.3.0',
1010
description='Stain normalization tools for histological analysis and computational pathology',
1111
long_description=README,
1212
long_description_content_type='text/markdown',

tests/test_color_conv.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,16 @@
44
import cv2
55
import os
66

7-
def test_rgb_to_lab():
7+
def test_rgb_lab():
88
size = 1024
99
curr_file_path = os.path.dirname(os.path.realpath(__file__))
1010
img = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/source.png")), cv2.COLOR_BGR2RGB), (size, size))
11+
12+
# rgb2lab expects data to be float32 in range [0, 1]
13+
img = img / 255
1114

15+
# convert from RGB to LAB and back again to RGB
1216
reconstructed_img = lab2rgb(rgb2lab(img))
13-
val = np.mean(np.abs(reconstructed_img - img))
14-
print("MAE:", val)
15-
assert val < 0.1
17+
18+
# assess if the reconstructed image is similar to the original image
19+
np.testing.assert_almost_equal(np.mean(np.abs(reconstructed_img - img)), 0.0, decimal=4, verbose=True)

tests/test_tf.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
import torchstain
44
import torchstain.tf
55
import tensorflow as tf
6-
import time
7-
from skimage.metrics import structural_similarity as ssim
86
import numpy as np
97

108
def test_cov():
@@ -44,11 +42,11 @@ def test_macenko_tf():
4442
result_tf, _, _ = tf_normalizer.normalize(I=t_to_transform, stains=True)
4543

4644
# convert to numpy and set dtype
47-
result_numpy = result_numpy.astype("float32")
48-
result_tf = result_tf.numpy().astype("float32")
45+
result_numpy = result_numpy.astype("float32") / 255.
46+
result_tf = result_tf.numpy().astype("float32") / 255.
4947

5048
# assess whether the normalized images are identical across backends
51-
np.testing.assert_almost_equal(ssim(result_numpy.flatten(), result_tf.flatten()), 1.0, decimal=4, verbose=True)
49+
np.testing.assert_almost_equal(result_numpy.flatten(), result_tf.flatten(), decimal=2, verbose=True)
5250

5351
def test_reinhard_tf():
5452
size = 1024
@@ -72,8 +70,8 @@ def test_reinhard_tf():
7270
result_tf = tf_normalizer.normalize(I=t_to_transform)
7371

7472
# convert to numpy and set dtype
75-
result_numpy = result_numpy.astype("float32")
76-
result_tf = result_tf.numpy().astype("float32")
73+
result_numpy = result_numpy.astype("float32") / 255.
74+
result_tf = result_tf.numpy().astype("float32") / 255.
7775

7876
# assess whether the normalized images are identical across backends
79-
np.testing.assert_almost_equal(ssim(result_numpy.flatten(), result_tf.flatten()), 1.0, decimal=4, verbose=True)
77+
np.testing.assert_almost_equal(result_numpy.flatten(), result_tf.flatten(), decimal=2, verbose=True)

tests/test_torch.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,9 @@
44
import torchstain.torch
55
import torch
66
import torchvision
7-
import time
87
import numpy as np
98
from torchvision import transforms
10-
from skimage.metrics import structural_similarity as ssim
9+
1110

1211
def setup_function(fn):
1312
print("torch version:", torch.__version__, "torchvision version:", torchvision.__version__)
@@ -52,11 +51,11 @@ def test_macenko_torch():
5251
result_torch, _, _ = torch_normalizer.normalize(I=t_to_transform, stains=True)
5352

5453
# convert to numpy and set dtype
55-
result_numpy = result_numpy.astype("float32")
56-
result_torch = result_torch.numpy().astype("float32")
54+
result_numpy = result_numpy.astype("float32") / 255.
55+
result_torch = result_torch.numpy().astype("float32") / 255.
5756

5857
# assess whether the normalized images are identical across backends
59-
np.testing.assert_almost_equal(ssim(result_numpy.flatten(), result_torch.flatten()), 1.0, decimal=4, verbose=True)
58+
np.testing.assert_almost_equal(result_numpy.flatten(), result_torch.flatten(), decimal=2, verbose=True)
6059

6160
def test_reinhard_torch():
6261
size = 1024
@@ -83,8 +82,9 @@ def test_reinhard_torch():
8382
result_torch = torch_normalizer.normalize(I=t_to_transform)
8483

8584
# convert to numpy and set dtype
86-
result_numpy = result_numpy.astype("float32")
87-
result_torch = result_torch.numpy().astype("float32")
85+
result_numpy = result_numpy.astype("float32") / 255.
86+
result_torch = result_torch.numpy().astype("float32") / 255.
8887

88+
8989
# assess whether the normalized images are identical across backends
90-
np.testing.assert_almost_equal(ssim(result_numpy.flatten(), result_torch.flatten()), 1.0, decimal=4, verbose=True)
90+
np.testing.assert_almost_equal(result_numpy.flatten(), result_torch.flatten(), decimal=2, verbose=True)

torchstain/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
__version__ = '1.2.0'
1+
__version__ = '1.3.0'
22

33
from torchstain.base import normalizers

torchstain/numpy/utils/lab2rgb.py

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,32 +7,39 @@
77
Implementation is based on:
88
https://github.com/scikit-image/scikit-image/blob/00177e14097237ef20ed3141ed454bc81b308f82/skimage/color/colorconv.py#L704
99
"""
10-
def lab2rgb(lab):
11-
lab = lab.astype("float32")
10+
def lab2rgb(lab: np.ndarray) -> np.ndarray:
11+
"""
12+
Convert an array of LAB values to RGB values.
13+
14+
Args:
15+
lab (np.ndarray): An array of shape (..., 3) containing LAB values.
16+
17+
Returns:
18+
np.ndarray: An array of shape (..., 3) containing RGB values.
19+
"""
1220
# first rescale back from OpenCV format
1321
lab[..., 0] /= 2.55
14-
lab[..., 1] -= 128
15-
lab[..., 2] -= 128
22+
lab[..., 1:] -= 128
1623

1724
# convert LAB -> XYZ color domain
18-
L, a, b = lab[..., 0], lab[..., 1], lab[..., 2]
19-
y = (L + 16.) / 116.
20-
x = (a / 500.) + y
21-
z = y - (b / 200.)
25+
y = (lab[..., 0] + 16.) / 116.
26+
x = (lab[..., 1] / 500.) + y
27+
z = y - (lab[..., 2] / 200.)
2228

23-
out = np.stack([x, y, z], axis=-1)
29+
xyz = np.stack([x, y, z], axis=-1)
2430

25-
mask = out > 0.2068966
26-
out[mask] = np.power(out[mask], 3.)
27-
out[~mask] = (out[~mask] - 16.0 / 116.) / 7.787
31+
mask = xyz > 0.2068966
32+
xyz[mask] = np.power(xyz[mask], 3.)
33+
xyz[~mask] = (xyz[~mask] - 16.0 / 116.) / 7.787
2834

2935
# rescale to the reference white (illuminant)
30-
out *= np.array((0.95047, 1., 1.08883), dtype=out.dtype)
31-
36+
xyz *= np.array((0.95047, 1., 1.08883), dtype=xyz.dtype)
37+
3238
# convert XYZ -> RGB color domain
33-
arr = out.copy()
34-
arr = np.dot(arr, _xyz2rgb.T)
35-
mask = arr > 0.0031308
36-
arr[mask] = 1.055 * np.power(arr[mask], 1 / 2.4) - 0.055
37-
arr[~mask] *= 12.92
38-
return np.clip(arr, 0, 1)
39+
rgb = np.matmul(xyz, _xyz2rgb.T)
40+
41+
mask = rgb > 0.0031308
42+
rgb[mask] = 1.055 * np.power(rgb[mask], 1 / 2.4) - 0.055
43+
rgb[~mask] *= 12.92
44+
45+
return np.clip(rgb, 0, 1)

torchstain/torch/normalizers/macenko.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@ def __init__(self):
1414
[0.7201, 0.8012],
1515
[0.4062, 0.5581]])
1616
self.maxCRef = torch.tensor([1.9705, 1.0308])
17-
self.deprecated_torch = torch.__version__ < (1,9,0)
17+
18+
# Avoid using deprecated torch.lstsq (since 1.9.0)
19+
self.updated_lstsq = hasattr(torch.linalg, 'lstsq')
1820

1921
def __convert_rgb2od(self, I, Io, beta):
2022
I = I.permute(1, 2, 0)
@@ -50,7 +52,7 @@ def __find_concentration(self, OD, HE):
5052
Y = OD.T
5153

5254
# determine concentrations of the individual stains
53-
if self.deprecated_torch:
55+
if not self.updated_lstsq:
5456
return torch.lstsq(Y, HE)[0][:2]
5557

5658
return torch.linalg.lstsq(HE, Y)[0]

0 commit comments

Comments
 (0)