Skip to content

Commit 1771270

Browse files
committed
[ENH] added the win tie loss labels as params
1 parent 8cea406 commit 1771270

File tree

2 files changed

+50
-20
lines changed

2 files changed

+50
-20
lines changed

MCM/MCM.py

Lines changed: 41 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def default(self, obj):
2525

2626
def get_analysis(df_results,
2727
output_dir='./',
28-
used_statistic='Accuracy',
28+
used_statistic='Score',
2929
save_as_json=True,
3030
plot_1v1_comparisons=False,
3131
order_WinTieLoss='higher',
@@ -50,7 +50,7 @@ def get_analysis(df_results,
5050
5151
df_results : pandas DataFrame, the csv file containing results
5252
output_dir : str, default = './', the output directory for the results
53-
used_statistic : str, default = 'Accuracy', one can imagine using error, time, memory etc. instead
53+
used_statistic : str, default = 'Score', one can imagine using error, time, memory etc. instead
5454
save_as_json : bool, default = True, whether or not to save the python analysis dict
5555
into a json file format
5656
plot_1v1_comparisons : bool, default = True, whether or not to plot the 1v1 scatter results
@@ -199,7 +199,10 @@ def get_heatmap(analysis=None,
199199
pixels_per_clf_width=3.5,
200200
show_symetry=True,
201201
colorbar_orientation='vertical',
202-
colorbar_value=None):
202+
colorbar_value=None,
203+
win_label='r>c',
204+
tie_label='r=c',
205+
loss_label='r<c'):
203206

204207
"""
205208
@@ -245,8 +248,8 @@ def get_heatmap(analysis=None,
245248
start_index = 0
246249

247250
string_to_add = ''
248-
string_to_add = string_to_add + analysis['used-mean'] + '\n'
249-
string_to_add = string_to_add + 'Win/Tie/Loss ' + analysis['order-WinTieLoss'] + '\n'
251+
string_to_add = string_to_add + capitalize_label(analysis['used-mean']) + '\n'
252+
string_to_add = string_to_add + win_label+'/'+tie_label+'/'+loss_label+' ' + '\n'
250253
if analysis['include-pvalue']:
251254
string_to_add = string_to_add + analysis['pvalue-test'].capitalize() + ' p-value'
252255

@@ -320,7 +323,10 @@ def get_heatmap(analysis=None,
320323
if show_symetry:
321324
pairwise_matrix[i,i] = 0.0
322325
if i > 0:
323-
dict_to_add[analysis['ordered-classifier-names'][i]] = '-'
326+
if i == 1 and win_label == 'r>c':
327+
dict_to_add[analysis['ordered-classifier-names'][i]] = 'r: row\nc: column'
328+
else:
329+
dict_to_add[analysis['ordered-classifier-names'][i]] = '-'
324330

325331
df_annotations = df_annotations.append(dict_to_add, ignore_index=True)
326332

@@ -352,7 +358,10 @@ def get_heatmap(analysis=None,
352358
plt.rcParams["figure.autolayout"] = True
353359
fig, ax = plt.subplots(1, 1, figsize=(figsize[0], figsize[1]))
354360

355-
min_value, max_value = get_limits(pairwise_matrix=pairwise_matrix)
361+
_can_be_negative = False
362+
if colorbar_value is None or colorbar_value == 'mean-difference':
363+
_can_be_negative = True
364+
min_value, max_value = get_limits(pairwise_matrix=pairwise_matrix, can_be_negative=_can_be_negative)
356365

357366
if colormap is None:
358367
_colormap = 'coolwarm'
@@ -363,6 +372,10 @@ def get_heatmap(analysis=None,
363372
_vmin = min_value + 0.2*min_value
364373
_vmax = max_value + 0.2*max_value
365374

375+
if colorbar_value is None:
376+
_colorbar_value = capitalize_label('mean-difference')
377+
else:
378+
_colorbar_value = capitalize_label(colorbar_value)
366379

367380
im = ax.imshow(pairwise_matrix,
368381
cmap=_colormap,
@@ -373,7 +386,7 @@ def get_heatmap(analysis=None,
373386
if colormap is not None:
374387
cbar = ax.figure.colorbar(im, ax=ax, orientation=colorbar_orientation)
375388
cbar.ax.tick_params(labelsize=font_size)
376-
cbar.set_label(label=analysis['used-mean'], size=font_size_colorbar_label)
389+
cbar.set_label(label=capitalize_label(_colorbar_value), size=font_size_colorbar_label)
377390

378391
xticks, yticks = get_ticks(analysis)
379392

@@ -398,7 +411,7 @@ def get_heatmap(analysis=None,
398411
kw.update(fontweight='normal')
399412

400413
if analysis['order-stats'] == 'average-statistic':
401-
ordering = 'average-'+analysis['used-statistics']
414+
ordering = 'Average-'+capitalize_label(analysis['used-statistics'])
402415
else:
403416
ordering = analysis['order-stats']
404417

@@ -434,7 +447,7 @@ def get_line_heatmap(proposed_methods,
434447
pixels_per_clf_hieght=7,
435448
pixels_per_clf_width=1.5,
436449
colorbar_orientation='horizontal',
437-
used_statistic='Accuracy',
450+
used_statistic='Score',
438451
order_WinTieLoss='higher',
439452
include_ProbaWinTieLoss=False,
440453
bayesian_rope=0.01,
@@ -445,7 +458,10 @@ def get_line_heatmap(proposed_methods,
445458
used_mean='mean-difference',
446459
order_stats='average-statistic',
447460
order_better='decreasing',
448-
dataset_column='dataset_name',):
461+
dataset_column='dataset_name',
462+
win_label='row>col',
463+
tie_label='row=col',
464+
loss_label='row<col'):
449465

450466
"""
451467
@@ -539,7 +555,10 @@ def get_line_heatmap(proposed_methods,
539555
pvalue_correction=pvalue_correction,
540556
pvalue_test=pvalue_test,
541557
pvalue_threshhold=pvalue_threshhold,
542-
dataset_column=dataset_column)
558+
dataset_column=dataset_column,
559+
win_label=win_label,
560+
tie_label=tie_label,
561+
loss_label=loss_label)
543562

544563
def _get_line_heatmap(proposed_method,
545564
excluded_methods=None,
@@ -554,7 +573,7 @@ def _get_line_heatmap(proposed_method,
554573
pixels_per_clf_hieght=7,
555574
pixels_per_clf_width=2.5,
556575
colorbar_orientation='horizontal',
557-
used_statistic='Accuracy',
576+
used_statistic='Score',
558577
order_WinTieLoss='higher',
559578
include_ProbaWinTieLoss=False,
560579
bayesian_rope=0.01,
@@ -565,7 +584,10 @@ def _get_line_heatmap(proposed_method,
565584
used_mean='mean-difference',
566585
order_stats='average-statistic',
567586
order_better='decreasing',
568-
dataset_column='dataset_name',):
587+
dataset_column='dataset_name',
588+
win_label='row>column',
589+
tie_label='row=column',
590+
loss_label='row<column'):
569591

570592
"""
571593
@@ -747,9 +769,9 @@ def _get_line_heatmap(proposed_method,
747769
_vmax = max_value + 0.8*max_value
748770

749771
if colorbar_value is None:
750-
_colorbar_value = 'mean-difference'
772+
_colorbar_value = capitalize_label('mean-difference')
751773
else:
752-
_colorbar_value = colorbar_value
774+
_colorbar_value = capitalize_label(colorbar_value)
753775

754776
im = ax.imshow(pairwise_line,
755777
cmap=_colormap,
@@ -786,15 +808,15 @@ def _get_line_heatmap(proposed_method,
786808
kw.update(fontweight='normal')
787809

788810
if analysis['order-stats'] == 'average-statistic':
789-
ordering = 'average-'+analysis['used-statistics']
811+
ordering = 'Average-'+capitalize_label(analysis['used-statistics'])
790812
else:
791813
ordering = analysis['order-stats']
792814

793815
im.axes.text(-0.7,-0.7, ordering, fontsize=font_size, **{"horizontalalignment":"center", "verticalalignment":"center"})
794816

795817
string_to_add = ''
796-
string_to_add = string_to_add + analysis['used-mean'] + '\n'
797-
string_to_add = string_to_add + 'Win/Tie/Loss ' + analysis['order-WinTieLoss'] + '\n'
818+
string_to_add = string_to_add + capitalize_label(analysis['used-mean']) + '\n'
819+
string_to_add = string_to_add + win_label+'/'+tie_label+'/'+loss_label+' ' + '\n'
798820
if analysis['include-pvalue']:
799821
string_to_add = string_to_add + analysis['pvalue-test'].capitalize() + ' p-value < ' + str(analysis['pvalue-threshold'])
800822

MCM/utils.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -483,4 +483,12 @@ def plot_1v1(x,
483483

484484
if not os.path.exists(output_directory+'1v1_plots/'):
485485
os.mkdir(output_directory+'1v1_plots/')
486-
plt.savefig(output_directory + '1v1_plots/'+name_x+'_vs_'+name_y+'.pdf')
486+
plt.savefig(output_directory + '1v1_plots/'+name_x+'_vs_'+name_y+'.pdf')
487+
488+
def capitalize_label(s):
489+
490+
if len(s.split('-')) == 1:
491+
return s.capitalize()
492+
493+
else:
494+
return '-'.join(ss.capitalize() for ss in s.split('-'))

0 commit comments

Comments
 (0)