
Commit b1bc716

replace accuracy nomenclature in plots with recall, and add precision where possible

1 parent ac147a6
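For any one label, precision is the fraction of classifications of that label which are correct, TP/(TP+FP), while recall is the fraction of annotations of that label which are recovered, TP/(TP+FN). Reporting both conveys strictly more than the single "accuracy" number these plots showed before.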

File tree: 4 files changed (+221, -175 lines)


src/accuracy

Lines changed: 103 additions & 50 deletions
@@ -21,8 +21,10 @@ import matplotlib.pyplot as plt
 import csv
 from natsort import natsorted, index_natsorted
 import matplotlib.cm as cm
+from matplotlib.patches import Rectangle
 from datetime import datetime
 import socket
+import statistics
 
 repodir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
 
@@ -193,9 +195,10 @@ def main():
   error_ratios = [float(x) for x in FLAGS.error_ratios.split(',')]
 
   train_accuracy, train_loss, train_time, train_step, \
-      validation_precision, validation_recall, validation_accuracy, \
+      validation_precision, validation_recall, \
+      validation_precision_mean, validation_recall_mean, \
       validation_time, validation_step, \
-      _, _, _, \
+      _, _, _, _, \
       labels_touse, label_counts, _, _, batch_size, _ = \
       read_logs(FLAGS.logdir)
   training_set_size = {k: len(label_counts[k]["training"]) * \
@@ -231,12 +234,12 @@ def main():
                            train_step[model].index(y)]) \
                           for (x,y) in validation_intervals]
     ax4.plot([(x+y)/2 for (x,y) in validation_intervals], train_loss_ave, 'm', label='Loss mean')
-    if validation_accuracy[model]:
-      ax1.plot(validation_step[model], validation_accuracy[model], 'r', label='Validation')
-      ax1.set_title(model+" "+str(round(validation_accuracy[model][-1],1))+'%')
-      ax1.set_ylim(bottom=min(validation_accuracy[model]))
+    if validation_recall_mean[model]:
+      ax1.plot(validation_step[model], validation_recall_mean[model], 'r', label='Validation')
+      ax1.set_title(model+" "+str(round(validation_recall_mean[model][-1],1))+'%')
+      ax1.set_ylim(bottom=min(validation_recall_mean[model]))
       ax1.set_ylim(top=100)
-      ax1.set_ylabel('Overall accuracy')
+      ax1.set_ylabel('Overall recall')
     ax1.set_xlim(0,1+len(train_step[model]))
     if imodel==len(keys_to_plot)-1:
       handles1, labels1 = ax1.get_legend_handles_labels()
@@ -269,7 +272,7 @@ def main():
     ax3.set_xlabel('Epoch')
 
   fig.tight_layout()
-  plt.savefig(os.path.join(FLAGS.logdir,'train-loss.pdf'))
+  plt.savefig(os.path.join(FLAGS.logdir,'train-validation-loss.pdf'))
   plt.close()
 
   nrows, ncols = layout(len(keys_to_plot))
@@ -348,7 +351,7 @@ def main():
               bbox_extra_artists=(lgd,), bbox_inches='tight')
   plt.close()
 
-  if not len(validation_accuracy[keys_to_plot[0]]):
+  if not len(validation_recall_mean[keys_to_plot[0]]):
     sys.exit()
 
   if len(keys_to_plot)>1:
@@ -358,29 +361,29 @@ def main():
     for model in keys_to_plot:
       scaled_validation_time, units = choose_units(validation_time[model])
 
-      line, = ax.plot(validation_step[model], validation_accuracy[model])
+      line, = ax.plot(validation_step[model], validation_recall_mean[model])
       line.set_label(model)
       ax.set_ylim(top=100)
       ax.set_xlabel('Step')
-      ax.set_ylabel('Overall accuracy')
+      ax.set_ylabel('Overall recall')
       #ax.legend(loc='lower right')
 
     ax = fig.add_subplot(2,3,2)
     for model in keys_to_plot:
       ax.plot([x*batch_size[model]/training_set_size[model] \
-               for x in validation_step[model]], validation_accuracy[model])
+               for x in validation_step[model]], validation_recall_mean[model])
       ax.set_ylim(top=100)
       ax.set_xlabel('Epoch')
-      ax.set_ylabel('Overall accuracy')
+      ax.set_ylabel('Overall recall')
 
     ax = fig.add_subplot(2,3,3)
     for model in keys_to_plot:
-      idx = min(len(scaled_validation_time), len(validation_accuracy[model]))
-      line, = ax.plot(scaled_validation_time[:idx], validation_accuracy[model][:idx])
+      idx = min(len(scaled_validation_time), len(validation_recall_mean[model]))
+      line, = ax.plot(scaled_validation_time[:idx], validation_recall_mean[model][:idx])
       line.set_label(model)
       ax.set_ylim(top=100)
       ax.set_xlabel('Time ('+units+')')
-      ax.set_ylabel('Overall accuracy')
+      ax.set_ylabel('Overall recall')
 
     ax = fig.add_subplot(2,3,4)
     for model in keys_to_plot:
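The Epoch panel above rescales the step axis by batch_size[model]/training_set_size[model]: each optimizer step consumes one batch, so epoch = step × batch_size / training_set_size. With a batch size of 32 and 3,200 training samples, for example, step 500 corresponds to epoch 5.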
@@ -409,45 +412,46 @@ def main():
 
     fig.tight_layout()
 
-    plt.savefig(os.path.join(FLAGS.logdir,'train-overlay.pdf'))
+    plt.savefig(os.path.join(FLAGS.logdir,'validation-overlay.pdf'))
     plt.close()
 
     fig = plt.figure(figsize=(6.4, 4.8))
 
-    model_validation_accuracy = np.zeros(len(validation_accuracy[keys_to_plot[0]]))
+    model_validation_recall = np.zeros(len(validation_recall_mean[keys_to_plot[0]]))
     nmodels = 0
     for model in keys_to_plot:
       if validation_step[model] != validation_step[keys_to_plot[0]]:
         print("WARNING: not all checkpoint steps are the same for "+model)
         continue
-      model_validation_accuracy += validation_accuracy[model]
+      model_validation_recall += validation_recall_mean[model]
       model_validation_step = validation_step[model]
       nmodels += 1
     ax = fig.add_subplot(1,1,1)
-    line, = ax.plot(model_validation_step, model_validation_accuracy / nmodels)
+    line, = ax.plot(model_validation_step, model_validation_recall / nmodels)
     line.set_label("model average")
     ax.set_ylim(top=100)
     ax.set_xlabel('Step')
-    ax.set_ylabel('Overall accuracy')
+    ax.set_ylabel('Overall recall')
     ax.legend(loc=(1.05, 0.0))
     ax.grid(True)
 
     fig.tight_layout()
 
-    plt.savefig(os.path.join(FLAGS.logdir,'train-summed.pdf'))
+    plt.savefig(os.path.join(FLAGS.logdir,'validation-average.pdf'))
     plt.close()
 
   summed_confusion_matrix, confusion_matrices, labels = \
       parse_confusion_matrices(FLAGS.logdir, next(iter(keys_to_plot)).split('_')[0])
 
   recall_matrices={}
   precision_matrices={}
-  accuracies={}
+  precisions_mean={}
+  recalls_mean={}
   for model in keys_to_plot:
-    recall_matrices[model], precision_matrices[model], accuracies[model] = \
+    precision_matrices[model], recall_matrices[model], precisions_mean[model], recalls_mean[model] = \
         normalize_confusion_matrix(confusion_matrices[model])
 
-  recall_summed_matrix, precision_summed_matrix, accuracy_summed = \
+  precision_summed_matrix, recall_summed_matrix, precision_summed, recall_summed = \
       normalize_confusion_matrix(summed_confusion_matrix)
 
   from mpl_toolkits.axes_grid1 import make_axes_locatable
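The reordered four-value return of normalize_confusion_matrix is the crux of the commit: callers now get per-label precision and recall matrices plus their means instead of one overall accuracy. The function body is not part of this diff, so the following is only a minimal sketch of what such a routine plausibly computes, assuming rows are annotations and columns are classifications (matching the 'Annotation'/'Classification' axis labels below) and that the means come back in percent, as the plotting code suggests; the name is hypothetical:

import numpy as np

def normalize_confusion_matrix_sketch(cm):
    # cm[i][j] counts samples annotated as label i and classified as label j.
    cm = np.array(cm, dtype=float)
    precision_matrix = cm / cm.sum(axis=0, keepdims=True)  # column-normalized
    recall_matrix = cm / cm.sum(axis=1, keepdims=True)     # row-normalized
    # The diagonals hold the per-label precision and recall; real code would
    # also guard against labels with zero counts before dividing.
    precision_mean = 100 * np.mean(np.diag(precision_matrix))
    recall_mean = 100 * np.mean(np.diag(recall_matrix))
    return precision_matrix, recall_matrix, precision_mean, recall_mean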
@@ -468,41 +472,90 @@ def main():
   ax.invert_yaxis()
   ax.set_xlabel('Classification')
   ax.set_ylabel('Annotation')
-  ax.set_title(str(round(accuracy_summed,1))+"%")
+  ax.set_title("P="+str(round(precision_summed,1))+"% "+
+               "R="+str(round(recall_summed,1))+"%")
 
   ax = plt.subplot(1,3,2)
+  precisions_all = []
+  recalls_all = []
   for model in keys_to_plot:
     ax.set_prop_cycle(None)
     for (ilabel,label) in enumerate(labels):
-      line, = ax.plot(100*recall_matrices[model][ilabel][ilabel],
-                      100*precision_matrices[model][ilabel][ilabel],
-                      'o', markeredgecolor='k')
+      precisions_all.append(100*precision_matrices[model][ilabel][ilabel])
+      recalls_all.append(100*recall_matrices[model][ilabel][ilabel])
+      line, = ax.plot(recalls_all[-1], precisions_all[-1], 'o', markeredgecolor='k')
       if model==keys_to_plot[0]:
         line.set_label(label)
-
+
+  if len(recalls_all)>1:
+    ax.autoscale_view()
+    ax.set_autoscale_on(False)
+    miny = ax.get_ylim()[0]
+    minx = ax.get_xlim()[0]
+    x = statistics.mean(recalls_all)
+    w = statistics.stdev(recalls_all)
+    avebox = Rectangle((x-w,miny),2*w,100)
+    ax.plot([x,x],[miny,100],'w-')
+    pc = PatchCollection([avebox], facecolor='lightgray', alpha=0.5)
+    ax.add_collection(pc)
+    y = statistics.mean(precisions_all)
+    h = statistics.stdev(precisions_all)
+    avebox = Rectangle((minx,y-h),100,2*h)
+    ax.plot([minx,100],[y,y],'w-')
+    pc = PatchCollection([avebox], facecolor='lightgray', alpha=0.5)
+    ax.add_collection(pc)
+  else:
+    x,y,w,h = recalls_all[0], precisions_all[0], 0, 0
+
   ax.set_xlim(right=100)
   ax.set_ylim(top=100)
   ax.set_xlabel('Recall (%)')
   ax.set_ylabel('Precision (%)')
+  ax.set_title("P="+str(round(y,1))+"+/-"+str(round(h,1))+"% "+
+               "R="+str(round(x,1))+"+/-"+str(round(w,1))+"%")
   ax.legend(loc=(1.05, 0.0))
 
   ax = fig.add_subplot(1,3,3)
-  accuracies_ordered = [accuracies[k] for k in keys_to_plot]
+  precisions_mean_ordered = [precisions_mean[k] for k in keys_to_plot]
+  recalls_mean_ordered = [recalls_mean[k] for k in keys_to_plot]
   print('models=', keys_to_plot)
-  print('accuracies=', accuracies_ordered)
-
-  x = 1
-  y = statistics.mean(accuracies_ordered)
-  h = statistics.stdev(accuracies_ordered) if len(accuracies)>1 else 0
-  avebox = Rectangle((y-h,-0.25),2*h,len(accuracies)-1+0.5)
-  ax.plot([y,y],[-0.25,len(accuracies)-1+0.25],'w-')
-  pc = PatchCollection([avebox], facecolor='lightgray')
-  ax.add_collection(pc)
-
-  ax.plot(accuracies_ordered, keys_to_plot, 'k.')
-  ax.set_xlabel('Overall accuracy (%)')
-  ax.set_ylabel('Model')
-  ax.set_title(str(round(y,1))+"+/-"+str(round(h,1))+"%")
+  print('precisions=', precisions_mean_ordered)
+  print('recalls=', recalls_mean_ordered)
+
+  for model in keys_to_plot:
+    ax.set_prop_cycle(None)
+    line, = ax.plot(recalls_mean[model],
+                    precisions_mean[model],
+                    'o', markeredgecolor='k')
+    line.set_label(model)
+
+  if len(recalls_mean_ordered)>1:
+    ax.autoscale_view()
+    ax.set_autoscale_on(False)
+    miny = ax.get_ylim()[0]
+    minx = ax.get_xlim()[0]
+    x = statistics.mean(recalls_mean_ordered)
+    w = statistics.stdev(recalls_mean_ordered)
+    avebox = Rectangle((x-w,miny),2*w,100)
+    ax.plot([x,x],[miny,100],'w-')
+    pc = PatchCollection([avebox], facecolor='lightgray', alpha=0.5)
+    ax.add_collection(pc)
+    y = statistics.mean(precisions_mean_ordered)
+    h = statistics.stdev(precisions_mean_ordered)
+    avebox = Rectangle((minx,y-h),100,2*h)
+    ax.plot([minx,100],[y,y],'w-')
+    pc = PatchCollection([avebox], facecolor='lightgray', alpha=0.5)
+    ax.add_collection(pc)
+  else:
+    x,y,w,h = recalls_mean_ordered[0], precisions_mean_ordered[0], 0, 0
+
+  ax.set_xlim(right=100)
+  ax.set_ylim(top=100)
+  ax.set_xlabel('Recall (%)')
+  ax.set_ylabel('Precision (%)')
+  ax.set_title("P="+str(round(y,1))+"+/-"+str(round(h,1))+"% "+
+               "R="+str(round(x,1))+"+/-"+str(round(w,1))+"%")
+  ax.legend(loc=(1.05, 0.0))
 
   fig.tight_layout()
   plt.savefig(os.path.join(FLAGS.logdir,'accuracy.pdf'))
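Both new scatter panels shade mean ± 1 standard deviation with a white line and a translucent gray Rectangle wrapped in a PatchCollection, once for recall on the x axis and once for precision on the y axis. A self-contained sketch of that shading idiom, with made-up numbers standing in for the per-model values:

import statistics
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from matplotlib.collections import PatchCollection

recalls = [88.0, 91.5, 90.2, 93.1]  # hypothetical per-model recalls, in percent

fig, ax = plt.subplots()
ax.plot(recalls, range(len(recalls)), 'k.')
ax.autoscale_view()
ax.set_autoscale_on(False)  # freeze the limits so the band does not rescale them
miny, maxy = ax.get_ylim()
x = statistics.mean(recalls)
w = statistics.stdev(recalls)
band = Rectangle((x-w, miny), 2*w, maxy-miny)  # spans mean +/- 1 stdev
ax.plot([x, x], [miny, maxy], 'w-')            # white line at the mean
ax.add_collection(PatchCollection([band], facecolor='lightgray', alpha=0.5))
plt.savefig('band-sketch.pdf')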
@@ -512,7 +565,7 @@ def main():
   plot_confusion_matrices(confusion_matrices,
                           precision_matrices,
                           recall_matrices,
-                          labels, accuracies, keys_to_plot,
+                          labels, precisions_mean, recalls_mean, keys_to_plot,
                           numbers=len(labels)<10)
   plt.savefig(os.path.join(FLAGS.logdir,'confusion-matrices.pdf'))
   plt.close()
@@ -523,15 +576,15 @@ def main():
   pool = Pool(nprocs)
   results = []
 
-  for key_to_plot in accuracies:
+  for model in recalls_mean:
     for ckpt in [int(x.split('-')[1][:-4]) for x in \
                  filter(lambda x: 'validation' in x and x.endswith('.npz'), \
-                        os.listdir(os.path.join(FLAGS.logdir,key_to_plot)))]:
+                        os.listdir(os.path.join(FLAGS.logdir,model)))]:
 
       if FLAGS.parallelize!=0:
-        results.append(pool.apply_async(doit, (FLAGS.logdir,key_to_plot,ckpt,labels,FLAGS.nprobabilities,error_ratios)))
+        results.append(pool.apply_async(doit, (FLAGS.logdir,model,ckpt,labels,FLAGS.nprobabilities,error_ratios)))
       else:
-        doit(FLAGS.logdir, key_to_plot, ckpt, labels, FLAGS.nprobabilities, error_ratios)
+        doit(FLAGS.logdir, model, ckpt, labels, FLAGS.nprobabilities, error_ratios)
 
   if FLAGS.parallelize!=0:
     for result in results:

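The last hunk only renames the loop variable; the multiprocessing pattern is unchanged: per-checkpoint work is queued with pool.apply_async when FLAGS.parallelize is nonzero, run inline otherwise, and the trailing loop (truncated above) presumably collects the async results. A generic, self-contained sketch of that pattern, with a placeholder doit and made-up arguments:

from multiprocessing import Pool

def doit(logdir, model, ckpt):
    print(logdir, model, ckpt)  # stands in for the real per-checkpoint analysis

if __name__ == '__main__':
    parallelize = 4  # stands in for FLAGS.parallelize
    pool = Pool(parallelize)
    results = []
    for model in ['model1', 'model2']:
        for ckpt in [100, 200]:
            if parallelize != 0:
                results.append(pool.apply_async(doit, ('logs', model, ckpt)))
            else:
                doit('logs', model, ckpt)
    if parallelize != 0:
        for result in results:
            result.get()  # blocks until done, re-raising any worker exception
    pool.close()
    pool.join()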