11"""borrow Staintool's test cases
22"""
33import unittest
4+ from typing import Optional , cast
5+
46from tests .util import fix_seed , dummy_from_numpy , psnr
5- from torch_staintools .functional .stain_extraction .macenko import MacenkoAlg
6- from torch_staintools .functional .stain_extraction .vahadane import VahadaneAlg
7- from torch_staintools .functional .concentration import get_concentrations
7+ from torch_staintools .constants import CONFIG
8+ from torch_staintools .functional .optimization .sparse_util import METHOD_FACTORIZE
9+ from torch_staintools .functional .stain_extraction .extractor import StainExtraction
10+ from torch_staintools .functional .stain_extraction .macenko import MacenkoAlg , DEFAULT_MACENKO_CONFIG
11+ from torch_staintools .functional .stain_extraction .vahadane import VahadaneAlg , DEFAULT_VAHADANE_CONFIG
12+ from torch_staintools .functional .concentration import ConcentrationSolver , ConcentCfg
813from torch_staintools .functional .tissue_mask import get_tissue_mask , TissueMaskException
914from torch_staintools .functional .utility .implementation import transpose_trailing , img_from_concentration
1015from torchvision .transforms .functional import convert_image_dtype
@@ -24,6 +29,7 @@ class TestFunctional(unittest.TestCase):
2429 rand_img = torch .randint (0 , 255 , (1 , 3 , 256 , 256 ))
2530
2631 THRESH_PSNR = 20
32+ POSITIVE_CONC_CFG = ConcentCfg ()
2733
2834 @staticmethod
2935 def get_dummy_path ():
@@ -40,65 +46,87 @@ def new_dummy_img_tensor_ubyte():
4046 return TestFunctional .DUMMY_IMG_TENSOR .clone ()
4147
4248 @staticmethod
43- def stain_extract (dummy_tensor , get_stain_mat , luminosity_threshold , num_stains , algorithm , regularizer ):
49+ def stain_extract (dummy_tensor : torch .Tensor , get_stain_mat : StainExtraction ,
50+ conc_solver : ConcentrationSolver ,
51+ luminosity_threshold : float , num_stains : int , rng : Optional [torch .Generator ]):
4452
4553 # lab_tensor = rgb_to_lab(convert_image_dtype(dummy_tensor))
4654
47- stain_matrix = get_stain_mat (image = dummy_tensor , luminosity_threshold = luminosity_threshold ,
48- num_stains = num_stains , regularizer = regularizer )
55+ stain_matrix = get_stain_mat (image = dummy_tensor ,
56+ luminosity_threshold = luminosity_threshold ,
57+ num_stains = num_stains , rng = rng )
4958
50- concentration = get_concentrations (dummy_tensor , stain_matrix , algorithm = algorithm ,
51- regularizer = regularizer )
59+ concentration = conc_solver (dummy_tensor , stain_matrix , rng = rng )
5260 c_transposed_src = transpose_trailing (concentration )
5361 reconstructed = img_from_concentration (c_transposed_src , stain_matrix , dummy_tensor .shape , (0 , 1 ))
5462 return stain_matrix , concentration , c_transposed_src , reconstructed
5563
5664 @staticmethod
57- def extract_eval_helper (tester , get_stain_mat , luminosity_threshold ,
58- num_stains , regularizer , dict_algorithm ):
65+ def extract_eval_helper (tester , get_stain_mat : StainExtraction ,
66+ conc_solver : ConcentrationSolver ,
67+ luminosity_threshold : Optional [float ],
68+ num_stains : int , rng : Optional [torch .Generator ]):
5969 device = TestFunctional .device
6070 dummy_tensor_ubyte = TestFunctional .new_dummy_img_tensor_ubyte ().to (device )
6171 # get_stain_mat = MacenkoExtractor()
6272 result_tuple = TestFunctional .stain_extract (dummy_tensor_ubyte , get_stain_mat ,
73+ conc_solver = conc_solver ,
6374 luminosity_threshold = luminosity_threshold ,
6475 num_stains = num_stains ,
65- algorithm = dict_algorithm , regularizer = regularizer )
76+ rng = rng )
6677
6778 stain_matrix , concentration , c_transposed_src , reconstructed = result_tuple
6879 dummy_scaled = convert_image_dtype (dummy_tensor_ubyte , torch .float32 )
6980 psnr_out = psnr (dummy_scaled , reconstructed ).item ()
70- tester .assertTrue (psnr_out > TestFunctional .THRESH_PSNR )
81+ tester .assertTrue (psnr_out > TestFunctional .THRESH_PSNR ,
82+ msg = f"{ psnr_out } vs. { TestFunctional .THRESH_PSNR } . \n "
83+ f"{ get_stain_mat .stain_algorithm .cfg } \n "
84+ f"nan: { torch .isnan (reconstructed ).any ()} \n "
85+ f"Dict pos: { CONFIG .DICT_POSITIVE_DICTIONARY } " )
7186 # size
7287 batch_size , channel_size , height , width = dummy_tensor_ubyte .shape
7388 tester .assertTrue (stain_matrix .shape == (batch_size , num_stains , channel_size ))
7489
7590 # transpose
7691 tester .assertTrue ((c_transposed_src .permute (0 , 2 , 1 ) == concentration ).all ())
7792
78- # manual tissue mask
79- mask = get_tissue_mask (dummy_scaled , luminosity_threshold = luminosity_threshold )
80- tissue_count = mask .sum ()
81- tester .assertTrue (concentration .shape [- 1 ] == tissue_count )
82-
8393 def eval_wrapper (self , extractor ):
8494
8595 # all pixel
86- algorithms = ['ista' , 'cd' , 'ls' ]
87- for alg in algorithms :
88- TestFunctional .extract_eval_helper (self , extractor , luminosity_threshold = None ,
89- num_stains = 2 , regularizer = 0.1 , dict_algorithm = alg )
96+ algorithms = ['ista' , 'cd' , 'ls' , 'fista' ]
97+ dict_constraint_flag = [True ]
98+ for flag in dict_constraint_flag :
99+ CONFIG .DICT_POSITIVE_DICTIONARY = flag
100+ for alg in algorithms :
101+ cfg = TestFunctional .POSITIVE_CONC_CFG
102+ cfg .algorithm = cast (METHOD_FACTORIZE , alg )
103+ cfg .positive = True
104+ solver = ConcentrationSolver (cfg )
105+ TestFunctional .extract_eval_helper (self , extractor , luminosity_threshold = None ,
106+ num_stains = 2 , conc_solver = solver , rng = None )
107+ solver .cfg .positive = False
108+ TestFunctional .extract_eval_helper (self , extractor , luminosity_threshold = None ,
109+ num_stains = 2 , conc_solver = solver , rng = None )
90110
91111 def test_stains (self ):
92- macenko = MacenkoAlg ()
93- vahadane = VahadaneAlg ()
112+ macenko = StainExtraction ( MacenkoAlg (DEFAULT_MACENKO_CONFIG ) )
113+ vahadane = StainExtraction ( VahadaneAlg (DEFAULT_VAHADANE_CONFIG ) )
94114 # not support num_stains other than 2
95- with self .assertRaises (NotImplementedError ):
96- TestFunctional .extract_eval_helper (self , macenko , luminosity_threshold = None ,
97- num_stains = 3 , regularizer = 0.1 , dict_algorithm = 'ista' )
115+ with self .assertRaises (AssertionError ):
116+ TestFunctional .extract_eval_helper (self , macenko ,
117+ conc_solver = ConcentrationSolver (TestFunctional .POSITIVE_CONC_CFG ),
118+ luminosity_threshold = None ,
119+ num_stains = 3 , rng = None )
98120
99121 self .eval_wrapper (macenko )
100122 self .eval_wrapper (vahadane )
101123
124+ # vahadane with rng and lr
125+ vahadane .stain_algorithm .cfg .lr = 0.5
126+ TestFunctional .extract_eval_helper (self , vahadane ,
127+ conc_solver = ConcentrationSolver (TestFunctional .POSITIVE_CONC_CFG ),
128+ luminosity_threshold = None ,
129+ num_stains = 3 , rng = torch .Generator (1 ))
102130 def test_tissue_mask (self ):
103131 device = TestFunctional .device
104132 dummy_scaled = convert_image_dtype (TestFunctional .new_dummy_img_tensor_ubyte (), torch .float32 ).to (device )
0 commit comments