11import torch
22from torchstain .torch .utils import cov , percentile
3+
34"""
45Implementation of the multi-target normalizer from the paper: https://arxiv.org/pdf/2406.02077
56"""
6- class MultiMacenkoNormalizer :
7- def __init__ (self , norm_mode = ' avg-post' ):
7+ class TorchMultiMacenkoNormalizer :
8+ def __init__ (self , norm_mode = " avg-post" ):
89 self .norm_mode = norm_mode
910 self .HERef = torch .tensor ([[0.5626 , 0.2159 ],
1011 [0.7201 , 0.8012 ],
1112 [0.4062 , 0.5581 ]])
1213 self .maxCRef = torch .tensor ([1.9705 , 1.0308 ])
13- self .updated_lstsq = hasattr (torch .linalg , ' lstsq' )
14+ self .updated_lstsq = hasattr (torch .linalg , " lstsq" )
1415
1516 def __convert_rgb2od (self , I , Io , beta ):
1617 I = I .permute (1 , 2 , 0 )
@@ -59,15 +60,15 @@ def __compute_matrices_single(self, I, Io, alpha, beta):
5960 return HE , C , maxC
6061
6162 def fit (self , Is , Io = 240 , alpha = 1 , beta = 0.15 ):
62- if self .norm_mode == ' avg-post' :
63+ if self .norm_mode == " avg-post" :
6364 HEs , _ , maxCs = zip (* (
6465 self .__compute_matrices_single (I , Io , alpha , beta )
6566 for I in Is
6667 ))
6768
6869 self .HERef = torch .stack (HEs ).mean (dim = 0 )
6970 self .maxCRef = torch .stack (maxCs ).mean (dim = 0 )
70- elif self .norm_mode == ' concat' :
71+ elif self .norm_mode == " concat" :
7172 ODs , ODhats = zip (* (
7273 self .__convert_rgb2od (I , Io , beta )
7374 for I in Is
@@ -83,7 +84,7 @@ def fit(self, Is, Io=240, alpha=1, beta=0.15):
8384 maxCs = torch .stack ([percentile (C [0 , :], 99 ), percentile (C [1 , :], 99 )])
8485 self .HERef = HE
8586 self .maxCRef = maxCs
86- elif self .norm_mode == ' avg-pre' :
87+ elif self .norm_mode == " avg-pre" :
8788 ODs , ODhats = zip (* (
8889 self .__convert_rgb2od (I , Io , beta )
8990 for I in Is
@@ -100,7 +101,7 @@ def fit(self, Is, Io=240, alpha=1, beta=0.15):
100101 maxCs = torch .stack ([percentile (C [0 , :], 99 ), percentile (C [1 , :], 99 )])
101102 self .HERef = HE
102103 self .maxCRef = maxCs
103- elif self .norm_mode == ' fixed-single' or self .norm_mode == ' stochastic-single' :
104+ elif self .norm_mode == " fixed-single" or self .norm_mode == " stochastic-single" :
104105 # single img
105106 self .HERef , _ , self .maxCRef = self .__compute_matrices_single (Is [0 ], Io , alpha , beta )
106107 else :
@@ -127,4 +128,4 @@ def normalize(self, I, Io=240, alpha=1, beta=0.15, stains=True):
127128 E [E > 255 ] = 255
128129 E = E .T .reshape (h , w , c ).int ()
129130
130- return Inorm , H , E
131+ return Inorm , H , E
0 commit comments