@@ -21,8 +21,10 @@ import matplotlib.pyplot as plt
2121import csv
2222from natsort import natsorted , index_natsorted
2323import matplotlib .cm as cm
24+ from matplotlib .patches import Rectangle
2425from datetime import datetime
2526import socket
27+ import statistics
2628
2729repodir = os .path .dirname (os .path .dirname (os .path .realpath (__file__ )))
2830
@@ -193,9 +195,10 @@ def main():
193195 error_ratios = [float (x ) for x in FLAGS .error_ratios .split (',' )]
194196
195197 train_accuracy , train_loss , train_time , train_step , \
196- validation_precision , validation_recall , validation_accuracy , \
198+ validation_precision , validation_recall , \
199+ validation_precision_mean , validation_recall_mean , \
197200 validation_time , validation_step , \
198- _ , _ , _ , \
201+ _ , _ , _ , _ , \
199202 labels_touse , label_counts , _ , _ , batch_size , _ = \
200203 read_logs (FLAGS .logdir )
201204 training_set_size = {k : len (label_counts [k ]["training" ]) * \
@@ -231,12 +234,12 @@ def main():
231234 train_step [model ].index (y )]) \
232235 for (x ,y ) in validation_intervals ]
233236 ax4 .plot ([(x + y )/ 2 for (x ,y ) in validation_intervals ], train_loss_ave , 'm' , label = 'Loss mean' )
234- if validation_accuracy [model ]:
235- ax1 .plot (validation_step [model ], validation_accuracy [model ], 'r' , label = 'Validation' )
236- ax1 .set_title (model + " " + str (round (validation_accuracy [model ][- 1 ],1 ))+ '%' )
237- ax1 .set_ylim (bottom = min (validation_accuracy [model ]))
237+ if validation_recall_mean [model ]:
238+ ax1 .plot (validation_step [model ], validation_recall_mean [model ], 'r' , label = 'Validation' )
239+ ax1 .set_title (model + " " + str (round (validation_recall_mean [model ][- 1 ],1 ))+ '%' )
240+ ax1 .set_ylim (bottom = min (validation_recall_mean [model ]))
238241 ax1 .set_ylim (top = 100 )
239- ax1 .set_ylabel ('Overall accuracy ' )
242+ ax1 .set_ylabel ('Overall recall ' )
240243 ax1 .set_xlim (0 ,1 + len (train_step [model ]))
241244 if imodel == len (keys_to_plot )- 1 :
242245 handles1 , labels1 = ax1 .get_legend_handles_labels ()
@@ -269,7 +272,7 @@ def main():
269272 ax3 .set_xlabel ('Epoch' )
270273
271274 fig .tight_layout ()
272- plt .savefig (os .path .join (FLAGS .logdir ,'train-loss.pdf' ))
275+ plt .savefig (os .path .join (FLAGS .logdir ,'train-validation- loss.pdf' ))
273276 plt .close ()
274277
275278 nrows , ncols = layout (len (keys_to_plot ))
@@ -348,7 +351,7 @@ def main():
348351 bbox_extra_artists = (lgd ,), bbox_inches = 'tight' )
349352 plt .close ()
350353
351- if not len (validation_accuracy [keys_to_plot [0 ]]):
354+ if not len (validation_recall_mean [keys_to_plot [0 ]]):
352355 sys .exit ()
353356
354357 if len (keys_to_plot )> 1 :
@@ -358,29 +361,29 @@ def main():
358361 for model in keys_to_plot :
359362 scaled_validation_time , units = choose_units (validation_time [model ])
360363
361- line , = ax .plot (validation_step [model ], validation_accuracy [model ])
364+ line , = ax .plot (validation_step [model ], validation_recall_mean [model ])
362365 line .set_label (model )
363366 ax .set_ylim (top = 100 )
364367 ax .set_xlabel ('Step' )
365- ax .set_ylabel ('Overall accuracy ' )
368+ ax .set_ylabel ('Overall recall ' )
366369 #ax.legend(loc='lower right')
367370
368371 ax = fig .add_subplot (2 ,3 ,2 )
369372 for model in keys_to_plot :
370373 ax .plot ([x * batch_size [model ]/ training_set_size [model ] \
371- for x in validation_step [model ]], validation_accuracy [model ])
374+ for x in validation_step [model ]], validation_recall_mean [model ])
372375 ax .set_ylim (top = 100 )
373376 ax .set_xlabel ('Epoch' )
374- ax .set_ylabel ('Overall accuracy ' )
377+ ax .set_ylabel ('Overall recall ' )
375378
376379 ax = fig .add_subplot (2 ,3 ,3 )
377380 for model in keys_to_plot :
378- idx = min (len (scaled_validation_time ), len (validation_accuracy [model ]))
379- line , = ax .plot (scaled_validation_time [:idx ], validation_accuracy [model ][:idx ])
381+ idx = min (len (scaled_validation_time ), len (validation_recall_mean [model ]))
382+ line , = ax .plot (scaled_validation_time [:idx ], validation_recall_mean [model ][:idx ])
380383 line .set_label (model )
381384 ax .set_ylim (top = 100 )
382385 ax .set_xlabel ('Time (' + units + ')' )
383- ax .set_ylabel ('Overall accuracy ' )
386+ ax .set_ylabel ('Overall recall ' )
384387
385388 ax = fig .add_subplot (2 ,3 ,4 )
386389 for model in keys_to_plot :
@@ -409,45 +412,46 @@ def main():
409412
410413 fig .tight_layout ()
411414
412- plt .savefig (os .path .join (FLAGS .logdir ,'train -overlay.pdf' ))
415+ plt .savefig (os .path .join (FLAGS .logdir ,'validation -overlay.pdf' ))
413416 plt .close ()
414417
415418 fig = plt .figure (figsize = (6.4 , 4.8 ))
416419
417- model_validation_accuracy = np .zeros (len (validation_accuracy [keys_to_plot [0 ]]))
420+ model_validation_recall = np .zeros (len (validation_recall_mean [keys_to_plot [0 ]]))
418421 nmodels = 0
419422 for model in keys_to_plot :
420423 if validation_step [model ] != validation_step [keys_to_plot [0 ]]:
421424 print ("WARNING: not all checkpoint steps are the same for " + model )
422425 continue
423- model_validation_accuracy += validation_accuracy [model ]
426+ model_validation_recall += validation_recall_mean [model ]
424427 model_validation_step = validation_step [model ]
425428 nmodels += 1
426429 ax = fig .add_subplot (1 ,1 ,1 )
427- line , = ax .plot (model_validation_step , model_validation_accuracy / nmodels )
430+ line , = ax .plot (model_validation_step , model_validation_recall / nmodels )
428431 line .set_label ("model average" )
429432 ax .set_ylim (top = 100 )
430433 ax .set_xlabel ('Step' )
431- ax .set_ylabel ('Overall accuracy ' )
434+ ax .set_ylabel ('Overall recall ' )
432435 ax .legend (loc = (1.05 , 0.0 ))
433436 ax .grid (True )
434437
435438 fig .tight_layout ()
436439
437- plt .savefig (os .path .join (FLAGS .logdir ,'train-summed .pdf' ))
440+ plt .savefig (os .path .join (FLAGS .logdir ,'validation-average .pdf' ))
438441 plt .close ()
439442
440443 summed_confusion_matrix , confusion_matrices , labels = \
441444 parse_confusion_matrices (FLAGS .logdir , next (iter (keys_to_plot )).split ('_' )[0 ])
442445
443446 recall_matrices = {}
444447 precision_matrices = {}
445- accuracies = {}
448+ precisions_mean = {}
449+ recalls_mean = {}
446450 for model in keys_to_plot :
447- recall_matrices [model ], precision_matrices [model ], accuracies [model ] = \
451+ precision_matrices [ model ], recall_matrices [model ], precisions_mean [model ], recalls_mean [model ] = \
448452 normalize_confusion_matrix (confusion_matrices [model ])
449453
450- recall_summed_matrix , precision_summed_matrix , accuracy_summed = \
454+ precision_summed_matrix , recall_summed_matrix , precision_summed , recall_summed = \
451455 normalize_confusion_matrix (summed_confusion_matrix )
452456
453457 from mpl_toolkits .axes_grid1 import make_axes_locatable
@@ -468,41 +472,90 @@ def main():
468472 ax .invert_yaxis ()
469473 ax .set_xlabel ('Classification' )
470474 ax .set_ylabel ('Annotation' )
471- ax .set_title (str (round (accuracy_summed ,1 ))+ "%" )
475+ ax .set_title ("P=" + str (round (precision_summed ,1 ))+ "% " +
476+ "R=" + str (round (recall_summed ,1 ))+ "%" )
472477
473478 ax = plt .subplot (1 ,3 ,2 )
479+ precisions_all = []
480+ recalls_all = []
474481 for model in keys_to_plot :
475482 ax .set_prop_cycle (None )
476483 for (ilabel ,label ) in enumerate (labels ):
477- line , = ax . plot (100 * recall_matrices [model ][ilabel ][ilabel ],
478- 100 * precision_matrices [model ][ilabel ][ilabel ],
479- 'o' , markeredgecolor = 'k' )
484+ precisions_all . append (100 * precision_matrices [model ][ilabel ][ilabel ])
485+ recalls_all . append ( 100 * recall_matrices [model ][ilabel ][ilabel ])
486+ line , = ax . plot ( recalls_all [ - 1 ], precisions_all [ - 1 ], 'o' , markeredgecolor = 'k' )
480487 if model == keys_to_plot [0 ]:
481488 line .set_label (label )
482-
489+
490+ if len (recalls_all )> 1 :
491+ ax .autoscale_view ()
492+ ax .set_autoscale_on (False )
493+ miny = ax .get_ylim ()[0 ]
494+ minx = ax .get_xlim ()[0 ]
495+ x = statistics .mean (recalls_all )
496+ w = statistics .stdev (recalls_all )
497+ avebox = Rectangle ((x - w ,miny ),2 * w ,100 )
498+ ax .plot ([x ,x ],[miny ,100 ],'w-' )
499+ pc = PatchCollection ([avebox ], facecolor = 'lightgray' , alpha = 0.5 )
500+ ax .add_collection (pc )
501+ y = statistics .mean (precisions_all )
502+ h = statistics .stdev (precisions_all )
503+ avebox = Rectangle ((minx ,y - h ),100 ,2 * h )
504+ ax .plot ([minx ,100 ],[y ,y ],'w-' )
505+ pc = PatchCollection ([avebox ], facecolor = 'lightgray' , alpha = 0.5 )
506+ ax .add_collection (pc )
507+ else :
508+ x ,y ,w ,h = recalls_all [0 ], precisions_all [0 ], 0 , 0
509+
483510 ax .set_xlim (right = 100 )
484511 ax .set_ylim (top = 100 )
485512 ax .set_xlabel ('Recall (%)' )
486513 ax .set_ylabel ('Precision (%)' )
514+ ax .set_title ("P=" + str (round (y ,1 ))+ "+/-" + str (round (h ,1 ))+ "% " +
515+ "R=" + str (round (x ,1 ))+ "+/-" + str (round (w ,1 ))+ "%" )
487516 ax .legend (loc = (1.05 , 0.0 ))
488517
489518 ax = fig .add_subplot (1 ,3 ,3 )
490- accuracies_ordered = [accuracies [k ] for k in keys_to_plot ]
519+ precisions_mean_ordered = [precisions_mean [k ] for k in keys_to_plot ]
520+ recalls_mean_ordered = [recalls_mean [k ] for k in keys_to_plot ]
491521 print ('models=' , keys_to_plot )
492- print ('accuracies=' , accuracies_ordered )
493-
494- x = 1
495- y = statistics .mean (accuracies_ordered )
496- h = statistics .stdev (accuracies_ordered ) if len (accuracies )> 1 else 0
497- avebox = Rectangle ((y - h ,- 0.25 ),2 * h ,len (accuracies )- 1 + 0.5 )
498- ax .plot ([y ,y ],[- 0.25 ,len (accuracies )- 1 + 0.25 ],'w-' )
499- pc = PatchCollection ([avebox ], facecolor = 'lightgray' )
500- ax .add_collection (pc )
501-
502- ax .plot (accuracies_ordered , keys_to_plot , 'k.' )
503- ax .set_xlabel ('Overall accuracy (%)' )
504- ax .set_ylabel ('Model' )
505- ax .set_title (str (round (y ,1 ))+ "+/-" + str (round (h ,1 ))+ "%" )
522+ print ('precisions=' , precisions_mean_ordered )
523+ print ('recalls=' , recalls_mean_ordered )
524+
525+ for model in keys_to_plot :
526+ ax .set_prop_cycle (None )
527+ line , = ax .plot (recalls_mean [model ],
528+ precisions_mean [model ],
529+ 'o' , markeredgecolor = 'k' )
530+ line .set_label (model )
531+
532+ if len (recalls_mean_ordered )> 1 :
533+ ax .autoscale_view ()
534+ ax .set_autoscale_on (False )
535+ miny = ax .get_ylim ()[0 ]
536+ minx = ax .get_xlim ()[0 ]
537+ x = statistics .mean (recalls_mean_ordered )
538+ w = statistics .stdev (recalls_mean_ordered )
539+ avebox = Rectangle ((x - w ,miny ),2 * w ,100 )
540+ ax .plot ([x ,x ],[miny ,100 ],'w-' )
541+ pc = PatchCollection ([avebox ], facecolor = 'lightgray' , alpha = 0.5 )
542+ ax .add_collection (pc )
543+ y = statistics .mean (precisions_mean_ordered )
544+ h = statistics .stdev (precisions_mean_ordered )
545+ avebox = Rectangle ((minx ,y - h ),100 ,2 * h )
546+ ax .plot ([minx ,100 ],[y ,y ],'w-' )
547+ pc = PatchCollection ([avebox ], facecolor = 'lightgray' , alpha = 0.5 )
548+ ax .add_collection (pc )
549+ else :
550+ x ,y ,w ,h - recalls_mean_ordered [0 ], precisions_mean_ordered [0 ], 0 , 0
551+
552+ ax .set_xlim (right = 100 )
553+ ax .set_ylim (top = 100 )
554+ ax .set_xlabel ('Recall (%)' )
555+ ax .set_ylabel ('Precision (%)' )
556+ ax .set_title ("P=" + str (round (y ,1 ))+ "+/-" + str (round (h ,1 ))+ "% " +
557+ "R=" + str (round (x ,1 ))+ "+/-" + str (round (w ,1 ))+ "%" )
558+ ax .legend (loc = (1.05 , 0.0 ))
506559
507560 fig .tight_layout ()
508561 plt .savefig (os .path .join (FLAGS .logdir ,'accuracy.pdf' ))
@@ -512,7 +565,7 @@ def main():
512565 plot_confusion_matrices (confusion_matrices ,
513566 precision_matrices ,
514567 recall_matrices ,
515- labels , accuracies , keys_to_plot ,
568+ labels , precisions_mean , recalls_mean , keys_to_plot ,
516569 numbers = len (labels )< 10 )
517570 plt .savefig (os .path .join (FLAGS .logdir ,'confusion-matrices.pdf' ))
518571 plt .close ()
@@ -523,15 +576,15 @@ def main():
523576 pool = Pool (nprocs )
524577 results = []
525578
526- for key_to_plot in accuracies :
579+ for model in recalls_mean :
527580 for ckpt in [int (x .split ('-' )[1 ][:- 4 ]) for x in \
528581 filter (lambda x : 'validation' in x and x .endswith ('.npz' ), \
529- os .listdir (os .path .join (FLAGS .logdir ,key_to_plot )))]:
582+ os .listdir (os .path .join (FLAGS .logdir ,model )))]:
530583
531584 if FLAGS .parallelize != 0 :
532- results .append (pool .apply_async (doit , (FLAGS .logdir ,key_to_plot ,ckpt ,labels ,FLAGS .nprobabilities ,error_ratios )))
585+ results .append (pool .apply_async (doit , (FLAGS .logdir ,model ,ckpt ,labels ,FLAGS .nprobabilities ,error_ratios )))
533586 else :
534- doit (FLAGS .logdir , key_to_plot , ckpt , labels , FLAGS .nprobabilities , error_ratios )
587+ doit (FLAGS .logdir , model , ckpt , labels , FLAGS .nprobabilities , error_ratios )
535588
536589 if FLAGS .parallelize != 0 :
537590 for result in results :
0 commit comments