55from  pathos .multiprocessing  import  ThreadPool  as  Pool 
66from  tqdm  import  tqdm 
77
8+ from  cellseg_models_pytorch .inference  import  BaseInferer 
9+ 
810from  ..metrics  import  (
911    accuracy_multiclass ,
1012    aggregated_jaccard_index ,
4345
4446class  BenchMarker :
4547    def  __init__ (
46-         self , pred_dir : str , true_dir : str , classes : Dict [str , int ] =  None 
48+         self ,
49+         true_dir : str ,
50+         pred_dir : str  =  None ,
51+         inferer : BaseInferer  =  None ,
52+         type_classes : Dict [str , int ] =  None ,
53+         sem_classes : Dict [str , int ] =  None ,
4754    ) ->  None :
4855        """Run benchmarking, given prediction and ground truth mask folders. 
4956
57+         NOTE: Can also take in an Inferer object. 
58+ 
5059        Parameters 
5160        ---------- 
52-             pred_dir : str 
53-                 Path to the prediction .mat files. The pred files have to have matching 
54-                 names to the gt filenames. 
5561            true_dir : str 
5662                Path to the ground truth .mat files. The gt files have to have matching 
5763                names to the pred filenames. 
58-             classes : Dict[str, int], optional 
59-                 Class dict. E.g. {"bg": 0, "epithelial": 1, "immmune": 2} 
64+             pred_dir : str, optional 
65+                 Path to the prediction .mat files. The pred files have to have matching 
66+                 names to the gt filenames. If None, the inferer object storing the 
67+                 predictions will be used instead. 
68+             inferer : BaseInferer, optional 
69+                 Infere object storing predictions of a model. If None, the `pred_dir` 
70+                 will be used to load the predictions instead. 
71+             type_classes : Dict[str, int], optional 
72+                 Cell type class dict. E.g. {"bg": 0, "epithelial": 1, "immmune": 2} 
73+             sem_classes : Dict[str, int], optional 
74+                 Tissue type class dict. E.g. {"bg": 0, "epithel": 1, "stroma": 2} 
6075        """ 
61-         self .pred_dir  =  Path (pred_dir )
76+         if  pred_dir  is  None  and  inferer  is  None :
77+             raise  ValueError (
78+                 "Both `inferer` and `pred_dir` cannot be set to None at the same time." 
79+             )
80+ 
6281        self .true_dir  =  Path (true_dir )
63-         self .classes  =  classes 
82+         self .type_classes  =  type_classes 
83+         self .sem_classes  =  sem_classes 
84+ 
85+         if  pred_dir  is  not   None :
86+             self .pred_dir  =  Path (pred_dir )
87+         else :
88+             self .pred_dir  =  None 
89+ 
90+         self .inferer  =  inferer 
91+ 
92+         if  inferer  is  not   None  and  pred_dir  is  None :
93+             try :
94+                 self .inferer .out_masks 
95+                 self .inferer .soft_masks 
96+             except  AttributeError :
97+                 raise  AttributeError (
98+                     "Did not find `out_masks` or `soft_masks` attributes. " 
99+                     "To get these, run inference with `inferer.infer()`. " 
100+                     "Remember to set `save_intermediate` to True for the inferer.`" 
101+                 )
64102
65103    @staticmethod  
66104    def  compute_inst_metrics (
@@ -100,16 +138,16 @@ def compute_inst_metrics(
100138                f"An illegal metric was given. Got: { metrics }  , allowed: { allowed }  " 
101139            )
102140
103-         # Skip empty GTs 
104-         if  len (np .unique (true )) >  1 :
141+         # Do not run metrics computation if there are no instances in neither of masks 
142+         res  =  {}
143+         if  len (np .unique (true )) >  1  or  len (np .unique (pred )) >  1 :
105144            true  =  remap_label (true )
106145            pred  =  remap_label (pred )
107146
108147            met  =  {}
109148            for  m  in  metrics :
110149                met [m ] =  INST_METRIC_LOOKUP [m ]
111150
112-             res  =  {}
113151            for  k , m  in  met .items ():
114152                score  =  m (true , pred )
115153
@@ -121,8 +159,19 @@ def compute_inst_metrics(
121159
122160            res ["name" ] =  name 
123161            res ["type" ] =  type 
162+         else :
163+             res ["name" ] =  name 
164+             res ["type" ] =  type 
124165
125-             return  res 
166+             for  m  in  metrics :
167+                 if  m  ==  "pq" :
168+                     res ["pq" ] =  - 1.0 
169+                     res ["sq" ] =  - 1.0 
170+                     res ["dq" ] =  - 1.0 
171+                 else :
172+                     res [m ] =  - 1.0 
173+ 
174+         return  res 
126175
127176    @staticmethod  
128177    def  compute_sem_metrics (
@@ -158,6 +207,9 @@ def compute_sem_metrics(
158207                A dictionary where metric names are mapped to metric values. 
159208                e.g. {"iou": 0.5, "f1score": 0.55, "name": "sample1"} 
160209        """ 
210+         if  not  isinstance (metrics , tuple ) and  not  isinstance (metrics , list ):
211+             raise  ValueError ("`metrics` must be either a list or tuple of values." )
212+ 
161213        allowed  =  list (SEM_METRIC_LOOKUP .keys ())
162214        if  not  all ([m  in  allowed  for  m  in  metrics ]):
163215            raise  ValueError (
@@ -227,20 +279,6 @@ def run_metrics(
227279
228280        return  metrics 
229281
230-     def  _read_files (self ) ->  List [Tuple [np .ndarray , np .ndarray , str ]]:
231-         """Read in the files from the input folders.""" 
232-         preds  =  sorted (self .pred_dir .glob ("*" ))
233-         trues  =  sorted (self .true_dir .glob ("*" ))
234- 
235-         masks  =  []
236-         for  truef , predf  in  zip (trues , preds ):
237-             true  =  FileHandler .read_mat (truef , return_all = True )
238-             pred  =  FileHandler .read_mat (predf , return_all = True )
239-             name  =  truef .name 
240-             masks .append ((true , pred , name ))
241- 
242-         return  masks 
243- 
244282    def  run_inst_benchmark (
245283        self , how : str  =  "binary" , metrics : Tuple [str , ...] =  ("pq" ,)
246284    ) ->  None :
@@ -268,17 +306,32 @@ def run_inst_benchmark(
268306        if  how  not  in   allowed :
269307            raise  ValueError (f"Illegal arg `how`. Got: { how }  , Allowed: { allowed }  " )
270308
271-         masks  =  self ._read_files ()
309+         trues  =  sorted (self .true_dir .glob ("*" ))
310+ 
311+         preds  =  None 
312+         if  self .pred_dir  is  not   None :
313+             preds  =  sorted (self .pred_dir .glob ("*" ))
314+ 
315+         ik  =  "inst"  if  self .pred_dir  is  None  else  "inst_map" 
316+         tk  =  "type"  if  self .pred_dir  is  None  else  "type_map" 
272317
273318        res  =  []
274-         if  how  ==  "multi"  and  self .classes  is  not   None :
275-             for  c , i  in  list (self .classes .items ())[1 :]:
319+         if  how  ==  "multi"  and  self .type_classes  is  not   None :
320+             for  c , i  in  list (self .type_classes .items ())[1 :]:
276321                args  =  []
277-                 for  true , pred , name  in  masks :
322+                 for  j , true_fn  in  enumerate (trues ):
323+                     name  =  true_fn .name 
324+                     true  =  FileHandler .read_mat (true_fn , return_all = True )
325+ 
326+                     if  preds  is  None :
327+                         pred  =  self .inferer .out_masks [name [:- 4 ]]
328+                     else :
329+                         pred  =  FileHandler .read_mat (preds [j ], return_all = True )
330+ 
278331                    true_inst  =  true ["inst_map" ]
279-                     pred_inst  =  pred ["inst_map" ]
280332                    true_type  =  true ["type_map" ]
281-                     pred_type  =  pred ["type_map" ]
333+                     pred_inst  =  pred [ik ]
334+                     pred_type  =  pred [tk ]
282335
283336                    pred_type  =  get_type_instances (pred_inst , pred_type , i )
284337                    true_type  =  get_type_instances (true_inst , true_type , i )
@@ -287,9 +340,17 @@ def run_inst_benchmark(
287340                res .extend ([metric  for  metric  in  met  if  metric ])
288341        else :
289342            args  =  []
290-             for  true , pred , name  in  masks :
343+             for  i , true_fn  in  enumerate (trues ):
344+                 name  =  true_fn .name 
345+                 true  =  FileHandler .read_mat (true_fn , return_all = True )
346+ 
347+                 if  preds  is  None :
348+                     pred  =  self .inferer .out_masks [name [:- 4 ]]
349+                 else :
350+                     pred  =  FileHandler .read_mat (preds [i ], return_all = True )
351+ 
291352                true  =  true ["inst_map" ]
292-                 pred  =  pred ["inst_map" ]
353+                 pred  =  pred [ik ]
293354                args .append ((true , pred , name , metrics ))
294355            met  =  self .run_metrics (args , "inst" , "binary instance seg" )
295356            res .extend ([metric  for  metric  in  met  if  metric ])
@@ -310,14 +371,40 @@ def run_sem_benchmark(self, metrics: Tuple[str, ...] = ("iou",)) -> Dict[str, An
310371            Dict[str, Any]: 
311372                Dictionary mapping the metrics to values + metadata. 
312373        """ 
313-         masks  =  self ._read_files ()
374+         trues  =  sorted (self .true_dir .glob ("*" ))
375+ 
376+         preds  =  None 
377+         if  self .pred_dir  is  not   None :
378+             preds  =  sorted (self .pred_dir .glob ("*" ))
379+ 
380+         sk  =  "sem"  if  self .pred_dir  is  None  else  "sem_map" 
314381
315382        args  =  []
316-         for  true , pred , name  in  masks :
383+         for  i , true_fn  in  enumerate (trues ):
384+             name  =  true_fn .name 
385+             true  =  FileHandler .read_mat (true_fn , return_all = True )
386+ 
387+             if  preds  is  None :
388+                 pred  =  self .inferer .out_masks [name [:- 4 ]]
389+             else :
390+                 pred  =  FileHandler .read_mat (preds [i ], return_all = True )
317391            true  =  true ["sem_map" ]
318-             pred  =  pred ["sem_map" ]
319-             args .append ((true , pred , name , len (self .classes ), metrics ))
392+             pred  =  pred [sk ]
393+             args .append ((true , pred , name , len (self .sem_classes ), metrics ))
394+ 
320395        met  =  self .run_metrics (args , "sem" , "semantic seg" )
321-         res  =  [metric  for  metric  in  met  if  metric ]
396+         ires  =  [metric  for  metric  in  met  if  metric ]
397+ 
398+         # re-format 
399+         res  =  []
400+         for  r  in  ires :
401+             for  k , val  in  self .sem_classes .items ():
402+                 cc  =  {
403+                     "name" : r ["name" ],
404+                     "type" : k ,
405+                 }
406+                 for  m  in  metrics :
407+                     cc [m ] =  r [m ][val ]
408+                 res .append (cc )
322409
323410        return  res 
0 commit comments