Skip to content

Commit a52c85c

Browse files
committed
Fixed usage bug of multitarget normalizer
1 parent f268a78 commit a52c85c

File tree

3 files changed

+11
-9
lines changed

3 files changed

+11
-9
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from .he_normalizer import HENormalizer
22
from .macenko import MacenkoNormalizer
3+
from .multitarget import MultiMacenkoNormalizer
34
from .reinhard import ReinhardNormalizer
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from torchstain.torch.normalizers.macenko import TorchMacenkoNormalizer
2-
from torchstain.torch.normalizers.multitarget import MultiMacenkoNormalizer
2+
from torchstain.torch.normalizers.multitarget import TorchMultiMacenkoNormalizer
33
from torchstain.torch.normalizers.reinhard import TorchReinhardNormalizer

torchstain/torch/normalizers/multitarget.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
import torch
22
from torchstain.torch.utils import cov, percentile
3+
34
"""
45
Implementation of the multi-target normalizer from the paper: https://arxiv.org/pdf/2406.02077
56
"""
6-
class MultiMacenkoNormalizer:
7-
def __init__(self, norm_mode='avg-post'):
7+
class TorchMultiMacenkoNormalizer:
8+
def __init__(self, norm_mode="avg-post"):
89
self.norm_mode = norm_mode
910
self.HERef = torch.tensor([[0.5626, 0.2159],
1011
[0.7201, 0.8012],
1112
[0.4062, 0.5581]])
1213
self.maxCRef = torch.tensor([1.9705, 1.0308])
13-
self.updated_lstsq = hasattr(torch.linalg, 'lstsq')
14+
self.updated_lstsq = hasattr(torch.linalg, "lstsq")
1415

1516
def __convert_rgb2od(self, I, Io, beta):
1617
I = I.permute(1, 2, 0)
@@ -59,15 +60,15 @@ def __compute_matrices_single(self, I, Io, alpha, beta):
5960
return HE, C, maxC
6061

6162
def fit(self, Is, Io=240, alpha=1, beta=0.15):
62-
if self.norm_mode == 'avg-post':
63+
if self.norm_mode == "avg-post":
6364
HEs, _, maxCs = zip(*(
6465
self.__compute_matrices_single(I, Io, alpha, beta)
6566
for I in Is
6667
))
6768

6869
self.HERef = torch.stack(HEs).mean(dim=0)
6970
self.maxCRef = torch.stack(maxCs).mean(dim=0)
70-
elif self.norm_mode == 'concat':
71+
elif self.norm_mode == "concat":
7172
ODs, ODhats = zip(*(
7273
self.__convert_rgb2od(I, Io, beta)
7374
for I in Is
@@ -83,7 +84,7 @@ def fit(self, Is, Io=240, alpha=1, beta=0.15):
8384
maxCs = torch.stack([percentile(C[0, :], 99), percentile(C[1, :], 99)])
8485
self.HERef = HE
8586
self.maxCRef = maxCs
86-
elif self.norm_mode == 'avg-pre':
87+
elif self.norm_mode == "avg-pre":
8788
ODs, ODhats = zip(*(
8889
self.__convert_rgb2od(I, Io, beta)
8990
for I in Is
@@ -100,7 +101,7 @@ def fit(self, Is, Io=240, alpha=1, beta=0.15):
100101
maxCs = torch.stack([percentile(C[0, :], 99), percentile(C[1, :], 99)])
101102
self.HERef = HE
102103
self.maxCRef = maxCs
103-
elif self.norm_mode == 'fixed-single' or self.norm_mode == 'stochastic-single':
104+
elif self.norm_mode == "fixed-single" or self.norm_mode == "stochastic-single":
104105
# single img
105106
self.HERef, _, self.maxCRef = self.__compute_matrices_single(Is[0], Io, alpha, beta)
106107
else:
@@ -127,4 +128,4 @@ def normalize(self, I, Io=240, alpha=1, beta=0.15, stains=True):
127128
E[E > 255] = 255
128129
E = E.T.reshape(h, w, c).int()
129130

130-
return Inorm, H, E
131+
return Inorm, H, E

0 commit comments

Comments
 (0)