@@ -62,7 +62,7 @@ def __init__(
6262 self ,
6363 summary_writer : Optional [SummaryWriter ] = None ,
6464 log_dir : str = "./runs" ,
65- tag_name = "val_acc " ,
65+ tag_name = "val " ,
6666 interval : int = 1 ,
6767 batch_transform : Callable = lambda x : x ,
6868 output_transform : Callable = lambda x : x ,
@@ -88,7 +88,8 @@ def __init__(
8888 self .class_y_pred : List [Any ] = []
8989
9090 def attach (self , engine : Engine ) -> None :
91- engine .add_event_handler (Events .ITERATION_COMPLETED (every = self .interval ), self , "iteration" )
91+ if self .interval == 1 :
92+ engine .add_event_handler (Events .ITERATION_COMPLETED (every = self .interval ), self , "iteration" )
9293 engine .add_event_handler (Events .EPOCH_COMPLETED (every = self .interval ), self , "epoch" )
9394
9495 def __call__ (self , engine : Engine , action ) -> None :
@@ -130,39 +131,37 @@ def write_images(self, batch_data, output_data, epoch):
130131 image = batch_data [bidx ]["image" ].detach ().cpu ().numpy ()
131132 y = output_data [bidx ]["label" ].detach ().cpu ().numpy ()
132133
133- tag_prefix = f"b{ bidx } - " if self .batch_limit != 1 else ""
134- img_tensor = make_grid (torch .from_numpy (image [:3 ] * 128 + 128 ), normalize = True )
135- self .writer .add_image (tag = f"{ tag_prefix } Image" , img_tensor = img_tensor , global_step = epoch )
136-
137134 if self .class_names :
138135 sig_np = image [:3 ] * 128 + 128
139136 sig_np [0 , :, :] = np .where (image [3 ] > 0 , 1 , sig_np [0 , :, :])
140- sig_tensor = make_grid (torch .from_numpy (sig_np ), normalize = True )
141- self .writer .add_image (tag = f"{ tag_prefix } Signal" , img_tensor = sig_tensor , global_step = epoch )
142137 if np .count_nonzero (image [3 ]) == 0 :
143- self .logger .info (" +++++++++ BUG (Signal is ZERO)" )
138+ self .logger .info (f" { self . tag_name } => +++++++++ BUG (Signal is ZERO)" )
144139
145140 y_pred = output_data [bidx ]["pred" ].detach ().cpu ().numpy ()
146141
147142 y_c = np .argmax (y )
148143 y_pred_c = np .argmax (y_pred )
149144
150- tag_prefix = f"b{ bidx } - " if self .batch_limit != 1 else " "
151- label_pred_tag = f"{ tag_prefix } Label vs Pred:"
145+ tag_prefix = f"{ self . tag_name } - b{ bidx } - " if self .batch_limit != 1 else f" { self . tag_name } - "
146+ label_pred_tag = f"{ tag_prefix } Image/Signal/ Label/ Pred:"
152147
153- y_img = Image .new ("RGB" , ( 200 , 100 ) )
148+ y_img = Image .new ("RGB" , image . shape [ - 2 :] )
154149 draw = ImageDraw .Draw (y_img )
155150 draw .text ((10 , 50 ), self .class_names .get (f"{ y_c } " , f"{ y_c } " ))
156151
157- y_pred_img = Image .new ("RGB" , ( 200 , 100 ) , "green" if y_c == y_pred_c else "red" )
152+ y_pred_img = Image .new ("RGB" , image . shape [ - 2 :] , "green" if y_c == y_pred_c else "red" )
158153 draw = ImageDraw .Draw (y_pred_img )
159154 draw .text ((10 , 50 ), self .class_names .get (f"{ y_pred_c } " , f"{ y_pred_c } " ))
160155
161- label_pred = [np .moveaxis (np .array (y_img ), - 1 , 0 ), np .moveaxis (np .array (y_pred_img ), - 1 , 0 )]
162156 img_tensor = make_grid (
163- tensor = torch .from_numpy (np .array (label_pred )),
164- nrow = 3 ,
165- normalize = False ,
157+ tensor = [
158+ torch .from_numpy (sig_np ),
159+ torch .from_numpy (np .stack ((np .where (image [3 ] > 0 , 255 , 0 ),) * 3 )),
160+ torch .from_numpy (np .moveaxis (np .array (y_img ), - 1 , 0 )),
161+ torch .from_numpy (np .moveaxis (np .array (y_pred_img ), - 1 , 0 )),
162+ ],
163+ nrow = 4 ,
164+ normalize = True ,
166165 pad_value = 10 ,
167166 )
168167 self .writer .add_image (tag = label_pred_tag , img_tensor = img_tensor , global_step = epoch )
@@ -171,35 +170,60 @@ def write_images(self, batch_data, output_data, epoch):
171170 if self .batch_limit == 1 and bidx < (len (batch_data ) - 1 ) and np .sum (y ) == 0 :
172171 continue
173172
173+ tag_prefix = f"{ self .tag_name } - b{ bidx } - " if self .batch_limit != 1 else ""
174+ img_np = image [:3 ] * 128 + 128
175+ if image .shape [0 ] > 3 :
176+ img_np [0 , :, :] = np .where (image [3 ] > 0 , 1 , img_np [0 , :, :])
177+ img_tensor = make_grid (torch .from_numpy (img_np ), normalize = True )
178+ self .writer .add_image (tag = f"{ tag_prefix } Image" , img_tensor = img_tensor , global_step = epoch )
179+
174180 y_pred = output_data [bidx ]["pred" ].detach ().cpu ().numpy ()
175181
176182 for region in range (y_pred .shape [0 ]):
177183 if region == 0 and y_pred .shape [0 ] > 1 : # one-hot; background
178184 continue
179185
186+ cl = np .count_nonzero (y [region ])
187+ cp = np .count_nonzero (y_pred [region ])
180188 self .logger .info (
181- "{} - {} - Image: {};"
189+ "{} => {} - {} - Image: {};"
182190 " Label: {} (nz: {});"
183191 " Pred: {} (nz: {});"
184- " Sig: (pos-nz: {}, neg-nz: {})" .format (
192+ " Diff: {:.2f}%; "
193+ "{}" .format (
194+ self .tag_name ,
185195 bidx ,
186196 region ,
187197 image .shape ,
188198 y .shape ,
189- np . count_nonzero ( y [ region ]) ,
199+ cl ,
190200 y_pred .shape ,
191- np .count_nonzero (y_pred [region ]),
192- np .count_nonzero (image [3 ]) if image .shape [0 ] == 5 else 0 ,
193- np .count_nonzero (image [4 ]) if image .shape [0 ] == 5 else 0 ,
201+ cp ,
202+ 100 * (cp - cl ) / (cl + 1 ),
203+ " Sig: (pos-nz: {}, neg-nz: {})" .format (
204+ np .count_nonzero (image [3 ]) if image .shape [0 ] == 5 else 0 ,
205+ np .count_nonzero (image [4 ]) if image .shape [0 ] == 5 else 0 ,
206+ )
207+ if image .shape [0 ] == 5
208+ else "" ,
194209 )
195210 )
196211
197- tag_prefix = f"b{ bidx } :l{ region } - " if self .batch_limit != 1 else f"l{ region } - "
212+ tag_prefix = (
213+ f"{ self .tag_name } - b{ bidx } :l{ region } - "
214+ if self .batch_limit != 1
215+ else f"{ self .tag_name } - l{ region } - "
216+ )
198217
199218 label_pred = [y [region ][None ], y_pred [region ][None ]]
200219 label_pred_tag = f"{ tag_prefix } Label vs Pred:"
201220 if image .shape [0 ] == 5 :
202- label_pred = [y [region ][None ], y_pred [region ][None ], image [3 ][None ], image [4 ][None ]]
221+ label_pred = [
222+ y [region ][None ],
223+ y_pred [region ][None ],
224+ image [3 ][None ] > 0 ,
225+ image [4 ][None ] > 0 ,
226+ ]
203227 label_pred_tag = f"{ tag_prefix } Label vs Pred vs Pos vs Neg"
204228
205229 img_tensor = make_grid (
@@ -222,12 +246,12 @@ def write_region_metrics(self, epoch):
222246 for n , m in v .items ():
223247 ltext .append (f"{ n } => { m :.4f} " )
224248 cname = self .class_names .get (k , k )
225- self .writer .add_scalar (f"cr_ { k } _{ n } " , m , epoch )
249+ self .writer .add_scalar (f"{ self . tag_name } _cr_ { k } _{ n } " , m , epoch )
226250
227- self .logger .info (f"Epoch[{ epoch } ] Metrics -- Class: { cname } ; { '; ' .join (ltext )} " )
251+ self .logger .info (f"{ self . tag_name } => Epoch[{ epoch } ] Metrics -- Class: { cname } ; { '; ' .join (ltext )} " )
228252 else :
229- self .logger .info (f"Epoch[{ epoch } ] Metrics -- { k } => { v :.4f} " )
230- self .writer .add_scalar (f"cr_ { k } " , v , epoch )
253+ self .logger .info (f"{ self . tag_name } => Epoch[{ epoch } ] Metrics -- { k } => { v :.4f} " )
254+ self .writer .add_scalar (f"{ self . tag_name } _cr_ { k } " , v , epoch )
231255
232256 self .class_y = []
233257 self .class_y_pred = []
@@ -237,13 +261,15 @@ def write_region_metrics(self, epoch):
237261 metric_sum = 0
238262 for region in self .metric_data :
239263 metric = self .metric_data [region ].mean ()
240- self .logger .info (f"Epoch[{ epoch } ] Metrics -- Region: { region :0>2d} , { self .tag_name } : { metric :.4f} " )
264+ self .logger .info (
265+ f"{ self .tag_name } => Epoch[{ epoch } ] Metrics (Dice) -- Region: { region :0>2d} : { metric :.4f} "
266+ )
241267
242- self .writer .add_scalar (f"dice_ { region :0>2d} " , metric , epoch )
268+ self .writer .add_scalar (f"{ self . tag_name } _dice_ { region :0>2d} " , metric , epoch )
243269 metric_sum += metric
244270
245271 metric_avg = metric_sum / len (self .metric_data )
246- self .writer .add_scalar ("dice_regions_avg " , metric_avg , epoch )
272+ self .writer .add_scalar (f" { self . tag_name } _dice_regions_avg " , metric_avg , epoch )
247273
248274 self .writer .flush ()
249275 self .metric_data = {}
0 commit comments