-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional
 
 import numpy as np
 import pytorch_lightning as pl
 import torch
 import torch.nn.functional as F
+from skimage.color import label2rgb
+from tqdm import tqdm
 
 try:
     import wandb
 except ImportError:
     raise ImportError("wandb required. `pip install wandb`")
 
+from ...inference import PostProcessor
+from ...metrics.functional import iou_multiclass, panoptic_quality
+from ...utils import get_type_instances, remap_label
 from ..functional import iou
 
-__all__ = ["WandbImageCallback", "WandbClassBarCallback", "WandbClassLineCallback"]
+__all__ = ["WandbImageCallback", "WandbClassLineCallback", "WandbGetExamplesCallback"]
 
 
 class WandbImageCallback(pl.Callback):
@@ -135,7 +140,7 @@ def compute(
         met = iou(pred, target).mean(dim=0)
         return met.to("cpu").numpy()
 
-    def on_train_batch_end(
+    def train_batch_end(
         self,
         trainer: pl.Trainer,
         pl_module: pl.LightningModule,
@@ -147,7 +152,7 @@ def on_train_batch_end(
147152 """Log the inputs and outputs of the model to wandb."""
148153 self .batch_end (trainer , outputs ["soft_masks" ], batch , batch_idx , phase = "train" )
149154
150- def on_validation_batch_end (
155+ def validation_batch_end (
151156 self ,
152157 trainer : pl .Trainer ,
153158 pl_module : pl .LightningModule ,
@@ -159,47 +164,17 @@ def on_validation_batch_end(
159164 """Log the inputs and outputs of the model to wandb."""
160165 self .batch_end (trainer , outputs ["soft_masks" ], batch , batch_idx , phase = "val" )
161166
162-
163- class WandbClassBarCallback (WandbIoUCallback ):
164- def __init__ (
165- self ,
166- type_classes : Dict [str , int ],
167- sem_classes : Optional [Dict [str , int ]],
168- freq : int = 100 ,
169- ) -> None :
170- """Create a wandb callback that logs per-class mIoU at batch ends."""
171- super ().__init__ (type_classes , sem_classes , freq )
172-
173- def get_bar (self , iou : np .ndarray , classes : Dict [int , str ], title : str ) -> Any :
174- """Return a wandb bar plot object of the current per class iou values."""
175- batch_data = [[lab , val ] for lab , val in zip (list (classes .values ()), iou )]
176- table = wandb .Table (data = batch_data , columns = ["label" , "value" ])
177- return wandb .plot .bar (table , "label" , "value" , title = title )
178-
179- def batch_end (
167+ def test_batch_end (
180168 self ,
181169 trainer : pl .Trainer ,
170+ pl_module : pl .LightningModule ,
182171 outputs : Dict [str , torch .Tensor ],
183172 batch : Dict [str , torch .Tensor ],
184173 batch_idx : int ,
185- phase : str ,
174+ dataloader_idx : int ,
186175 ) -> None :
187- """Log metrics at every 100th step to wandb."""
188- if batch_idx % self .freq == 0 :
189- log_dict = {}
190- if "type" in list (batch .keys ()):
191- iou = self .compute ("type" , outputs , batch )
192- log_dict [f"{ phase } /type_ious_bar" ] = self .get_bar (
193- list (iou ), self .type_classes , title = "Cell class mIoUs"
194- )
195-
196- if "sem" in list (batch .keys ()):
197- iou = self .compute ("sem" , outputs , batch )
198- log_dict [f"{ phase } /sem_ious_bar" ] = self .get_bar (
199- list (iou ), self .sem_classes , title = "Sem class mIoUs"
200- )
201-
202- trainer .logger .experiment .log (log_dict )
176+ """Log the inputs and outputs of the model to wandb."""
177+ self .batch_end (trainer , outputs ["soft_masks" ], batch , batch_idx , phase = "test" )
203178
204179
205180class WandbClassLineCallback (WandbIoUCallback ):
@@ -240,3 +215,245 @@ def batch_end(
                 )
 
             trainer.logger.experiment.log(log_dict)
+
+    def on_validation_batch_end(
+        self,
+        trainer: pl.Trainer,
+        pl_module: pl.LightningModule,
+        outputs: Dict[str, torch.Tensor],
+        batch: Dict[str, torch.Tensor],
+        batch_idx: int,
+        dataloader_idx: int,
+    ) -> None:
+        """Call the callback at val time."""
+        self.validation_batch_end(
+            trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
+        )
+
+    def on_train_batch_end(
+        self,
+        trainer: pl.Trainer,
+        pl_module: pl.LightningModule,
+        outputs: Dict[str, torch.Tensor],
+        batch: Dict[str, torch.Tensor],
+        batch_idx: int,
+        dataloader_idx: int,
+    ) -> None:
+        """Call the callback at train time."""
+        self.train_batch_end(
+            trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
+        )
+
+
+class WandbGetExamplesCallback(pl.Callback):
+    def __init__(
+        self,
+        type_classes: Dict[str, int],
+        sem_classes: Optional[Dict[str, int]],
+        instance_postproc: str,
+        inst_key: str,
+        aux_key: str,
+        inst_act: str = "softmax",
+        aux_act: Optional[str] = None,
+    ) -> None:
+        """Create a wandb callback that logs image examples at test time."""
+        super().__init__()
+        self.type_classes = type_classes
+        self.sem_classes = sem_classes
+        self.inst_key = inst_key
+        self.aux_key = aux_key
+
+        self.inst_act = inst_act
+        self.aux_act = aux_act
+
+        self.postprocessor = PostProcessor(
+            instance_postproc=instance_postproc, inst_key=inst_key, aux_key=aux_key
+        )
+
+    def post_proc(
+        self, outputs: Dict[str, torch.Tensor]
+    ) -> List[Dict[str, np.ndarray]]:
+        """Apply post-processing to the outputs."""
+        B, _, _, _ = outputs[self.inst_key].shape
+
+        # activate the instance branch
+        inst = outputs[self.inst_key].detach()
+        if self.inst_act == "softmax":
+            inst = F.softmax(inst, dim=1)
+        elif self.inst_act == "sigmoid":
+            inst = torch.sigmoid(inst)
+
+        # activate the auxiliary regression branch
+        aux = outputs[self.aux_key].detach()
+        if self.aux_act == "tanh":
+            aux = torch.tanh(aux)
+
+        sem = None
+        if "sem" in outputs.keys():
+            sem = outputs["sem"].detach()
+            sem = F.softmax(sem, dim=1).cpu().numpy()
+
+        typ = None
+        if "type" in outputs.keys():
+            typ = outputs["type"].detach()
+            typ = F.softmax(typ, dim=1).cpu().numpy()
+
+        # run the post-processing pipeline image by image
+        inst = inst.cpu().numpy()
+        aux = aux.cpu().numpy()
+        outmaps = []
+        for i in range(B):
+            maps = {
+                self.inst_key: inst[i],
+                self.aux_key: aux[i],
+            }
+            if sem is not None:
+                maps["sem"] = sem[i]
+            if typ is not None:
+                maps["type"] = typ[i]
+
+            out = self.postprocessor.post_proc_pipeline(maps)
+            outmaps.append(out)
+
+        return outmaps
+
+    def count_pixels(self, img: np.ndarray, shape: int) -> List[float]:
+        """Compute pixel proportions per class."""
+        # e.g. a 2x2 mask [[0, 0], [1, 2]] (shape=2) -> bincount [2, 1, 1]
+        # -> proportions [0.5, 0.25, 0.25]
+        return [float(p) / shape**2 for p in np.bincount(img.astype(int).flatten())]
+
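+    # NOTE: `epoch_end` below iterates the test dataloader directly and logs one
+    # `wandb.Table` row per test image, wrapped in a wandb Artifact. It relies on
+    # `trainer.datamodule.test_dataloader()` being available and on `pl_module.heads`
+    # mapping decoder names to dicts of output heads.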
+    def epoch_end(self, trainer, pl_module) -> None:
+        """Log metrics and example images at the end of the test epoch."""
+        decs = [list(it.keys()) for it in pl_module.heads.values()]
+        outheads = [item for sublist in decs for item in sublist]
+
+        loader = trainer.datamodule.test_dataloader()
+
+        # Create the artifact that will hold the result table
+        runid = trainer.logger.experiment.id
+        test_res_at = wandb.Artifact("test_pred_" + runid, "test_preds")
+
+        # Init the wandb table columns
+        cols = ["id", "inst_gt", "inst_pred", "bPQ"]
+
+        if "type" in outheads:
+            cols += [
+                "cell_types",
+                *[f"pq_{c}" for c in self.type_classes.values() if c != "bg"],
+            ]
+        if "sem" in outheads:
+            cols += [
+                "tissue_types",
+                *[f"iou_{c}" for c in self.sem_classes.values() if c != "bg"],
+            ]
+
+        model_res_table = wandb.Table(columns=cols)
+
+        # Run inference over the test set and collect per-image rows
+        with tqdm(loader, unit="batch") as loader:
+            with torch.no_grad():
+                for batch_idx, batch in enumerate(loader):
+                    soft_masks = pl_module.forward(batch["image"].to(pl_module.device))
+                    imgs = list(batch["image"].detach().cpu().numpy())  # [(C, H, W)*B]
+                    inst_targets = list(batch["inst_map"].detach().cpu().numpy())
+
+                    outmaps = self.post_proc(soft_masks)
+
+                    type_targets = None
+                    if "type" in list(batch.keys()):
+                        type_targets = list(
+                            batch["type"].detach().cpu().numpy()
+                        )  # [(C, H, W)*B]
+
+                    sem_targets = None
+                    if "sem" in list(batch.keys()):
+                        sem_targets = list(
+                            batch["sem"].detach().cpu().numpy()
+                        )  # [(C, H, W)*B]
+
+                    # loop over the images in the batch
+                    for i, (pred, im, inst_target) in enumerate(
+                        zip(outmaps, imgs, inst_targets)
+                    ):
+                        inst_targ = remap_label(inst_target)
+                        inst_pred = remap_label(pred["inst"])
+
+                        wb_inst_gt = wandb.Image(label2rgb(inst_targ, bg_label=0))
+                        wb_inst_pred = wandb.Image(label2rgb(inst_pred, bg_label=0))
+                        pq_inst = panoptic_quality(inst_targ, inst_pred)["pq"]
+
+                        row = [
+                            f"test_batch_{batch_idx}_{i}",
+                            wb_inst_gt,
+                            wb_inst_pred,
+                            pq_inst,
+                        ]
+
+                        if type_targets is not None:
+                            # per-class PQ over the cell-type instances
+                            per_class_pq = [
+                                panoptic_quality(
+                                    remap_label(
+                                        get_type_instances(
+                                            inst_targ, type_targets[i], j
+                                        )
+                                    ),
+                                    remap_label(
+                                        get_type_instances(inst_pred, pred["type"], j)
+                                    ),
+                                )["pq"]
+                                for j in self.type_classes.keys()
+                                if j != 0
+                            ]
+
+                            type_classes_set = wandb.Classes(
+                                [
+                                    {"name": name, "id": id}
+                                    for id, name in self.type_classes.items()
+                                    if id != 0
+                                ]
+                            )
+                            wb_type = wandb.Image(
+                                im.transpose(1, 2, 0),
+                                classes=type_classes_set,
+                                masks={
+                                    "ground_truth": {"mask_data": type_targets[i]},
+                                    "pred": {"mask_data": pred["type"]},
+                                },
+                            )
+
+                            row += [wb_type, *per_class_pq]
+
+                        if sem_targets is not None:
+                            # per-class IoU over the tissue-type masks
+                            per_class_iou = list(
+                                iou_multiclass(
+                                    sem_targets[i], pred["sem"], len(self.sem_classes)
+                                )
+                            )
+
+                            sem_classes_set = wandb.Classes(
+                                [
+                                    {"name": name, "id": id}
+                                    for id, name in self.sem_classes.items()
+                                    if id != 0
+                                ]
+                            )
+                            wb_sem = wandb.Image(
+                                im.transpose(1, 2, 0),
+                                classes=sem_classes_set,
+                                masks={
+                                    "ground_truth": {"mask_data": sem_targets[i]},
+                                    "pred": {"mask_data": pred["sem"]},
+                                },
+                            )
+                            row += [wb_sem, *per_class_iou[1:]]
+
+                        model_res_table.add_data(*row)
+
+        test_res_at.add(model_res_table, "model_batch_results")
+        trainer.logger.experiment.log_artifact(test_res_at)
+
+    def on_test_epoch_end(
+        self,
+        trainer: pl.Trainer,
+        pl_module: pl.LightningModule,
+    ) -> None:
+        """Call the callback at test time."""
+        self.epoch_end(trainer, pl_module)
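
A minimal usage sketch for the new callbacks, assuming an existing LightningModule (`litmodule`) and datamodule. The class dictionaries map label id to class name (matching how the callbacks iterate `.items()` as id, name pairs), the `WandbClassLineCallback` signature is assumed to mirror the removed bar callback, and the `instance_postproc`, `inst_key` and `aux_key` values are illustrative placeholders rather than fixed names:

    from pytorch_lightning.loggers import WandbLogger

    type_classes = {0: "bg", 1: "neoplastic", 2: "inflammatory"}
    sem_classes = {0: "bg", 1: "stroma", 2: "tumor"}

    callbacks = [
        # per-class IoU line plots during fit
        WandbClassLineCallback(type_classes=type_classes, sem_classes=sem_classes),
        # example-image table logged as a wandb artifact at test time
        WandbGetExamplesCallback(
            type_classes=type_classes,
            sem_classes=sem_classes,
            instance_postproc="cellpose",  # placeholder post-processing method
            inst_key="inst",               # placeholder output-head keys
            aux_key="cellpose",
            inst_act="softmax",
            aux_act="tanh",
        ),
    ]

    trainer = pl.Trainer(logger=WandbLogger(project="example"), callbacks=callbacks)
    trainer.fit(litmodule, datamodule=datamodule)
    trainer.test(litmodule, datamodule=datamodule)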