@@ -25,7 +25,7 @@ def default(self, obj):
2525
2626def get_analysis (df_results ,
2727 output_dir = './' ,
28- used_statistic = 'Accuracy ' ,
28+ used_statistic = 'Score ' ,
2929 save_as_json = True ,
3030 plot_1v1_comparisons = False ,
3131 order_WinTieLoss = 'higher' ,
@@ -50,7 +50,7 @@ def get_analysis(df_results,
5050
5151 df_results : pandas DataFrame, the csv file containing results
5252 output_dir : str, default = './', the output directory for the results
53- used_statistic : str, default = 'Accuracy ', one can imagine using error, time, memory etc. instead
53+ used_statistic : str, default = 'Score ', one can imagine using error, time, memory etc. instead
5454 save_as_json : bool, default = True, whether or not to save the python analysis dict
5555 into a json file format
5656 plot_1v1_comparisons : bool, default = True, whether or not to plot the 1v1 scatter results
@@ -199,7 +199,10 @@ def get_heatmap(analysis=None,
199199 pixels_per_clf_width = 3.5 ,
200200 show_symetry = True ,
201201 colorbar_orientation = 'vertical' ,
202- colorbar_value = None ):
202+ colorbar_value = None ,
203+ win_label = 'r>c' ,
204+ tie_label = 'r=c' ,
205+ loss_label = 'r<c' ):
203206
204207 """
205208
@@ -245,8 +248,8 @@ def get_heatmap(analysis=None,
245248 start_index = 0
246249
247250 string_to_add = ''
248- string_to_add = string_to_add + analysis ['used-mean' ] + '\n '
249- string_to_add = string_to_add + 'Win/Tie/Loss ' + analysis [ 'order-WinTieLoss' ] + '\n '
251+ string_to_add = string_to_add + capitalize_label ( analysis ['used-mean' ]) + '\n '
252+ string_to_add = string_to_add + win_label + '/' + tie_label + '/' + loss_label + ' ' + '\n '
250253 if analysis ['include-pvalue' ]:
251254 string_to_add = string_to_add + analysis ['pvalue-test' ].capitalize () + ' p-value'
252255
@@ -320,7 +323,10 @@ def get_heatmap(analysis=None,
320323 if show_symetry :
321324 pairwise_matrix [i ,i ] = 0.0
322325 if i > 0 :
323- dict_to_add [analysis ['ordered-classifier-names' ][i ]] = '-'
326+ if i == 1 and win_label == 'r>c' :
327+ dict_to_add [analysis ['ordered-classifier-names' ][i ]] = 'r: row\n c: column'
328+ else :
329+ dict_to_add [analysis ['ordered-classifier-names' ][i ]] = '-'
324330
325331 df_annotations = df_annotations .append (dict_to_add , ignore_index = True )
326332
@@ -352,7 +358,10 @@ def get_heatmap(analysis=None,
352358 plt .rcParams ["figure.autolayout" ] = True
353359 fig , ax = plt .subplots (1 , 1 , figsize = (figsize [0 ], figsize [1 ]))
354360
355- min_value , max_value = get_limits (pairwise_matrix = pairwise_matrix )
361+ _can_be_negative = False
362+ if colorbar_value is None or colorbar_value == 'mean-difference' :
363+ _can_be_negative = True
364+ min_value , max_value = get_limits (pairwise_matrix = pairwise_matrix , can_be_negative = _can_be_negative )
356365
357366 if colormap is None :
358367 _colormap = 'coolwarm'
@@ -363,6 +372,10 @@ def get_heatmap(analysis=None,
363372 _vmin = min_value + 0.2 * min_value
364373 _vmax = max_value + 0.2 * max_value
365374
375+ if colorbar_value is None :
376+ _colorbar_value = capitalize_label ('mean-difference' )
377+ else :
378+ _colorbar_value = capitalize_label (colorbar_value )
366379
367380 im = ax .imshow (pairwise_matrix ,
368381 cmap = _colormap ,
@@ -373,7 +386,7 @@ def get_heatmap(analysis=None,
373386 if colormap is not None :
374387 cbar = ax .figure .colorbar (im , ax = ax , orientation = colorbar_orientation )
375388 cbar .ax .tick_params (labelsize = font_size )
376- cbar .set_label (label = analysis [ 'used-mean' ] , size = font_size_colorbar_label )
389+ cbar .set_label (label = capitalize_label ( _colorbar_value ) , size = font_size_colorbar_label )
377390
378391 xticks , yticks = get_ticks (analysis )
379392
@@ -398,7 +411,7 @@ def get_heatmap(analysis=None,
398411 kw .update (fontweight = 'normal' )
399412
400413 if analysis ['order-stats' ] == 'average-statistic' :
401- ordering = 'average -' + analysis ['used-statistics' ]
414+ ordering = 'Average -' + capitalize_label ( analysis ['used-statistics' ])
402415 else :
403416 ordering = analysis ['order-stats' ]
404417
@@ -434,7 +447,7 @@ def get_line_heatmap(proposed_methods,
434447 pixels_per_clf_hieght = 7 ,
435448 pixels_per_clf_width = 1.5 ,
436449 colorbar_orientation = 'horizontal' ,
437- used_statistic = 'Accuracy ' ,
450+ used_statistic = 'Score ' ,
438451 order_WinTieLoss = 'higher' ,
439452 include_ProbaWinTieLoss = False ,
440453 bayesian_rope = 0.01 ,
@@ -445,7 +458,10 @@ def get_line_heatmap(proposed_methods,
445458 used_mean = 'mean-difference' ,
446459 order_stats = 'average-statistic' ,
447460 order_better = 'decreasing' ,
448- dataset_column = 'dataset_name' ,):
461+ dataset_column = 'dataset_name' ,
462+ win_label = 'row>col' ,
463+ tie_label = 'row=col' ,
464+ loss_label = 'row<col' ):
449465
450466 """
451467
@@ -539,7 +555,10 @@ def get_line_heatmap(proposed_methods,
539555 pvalue_correction = pvalue_correction ,
540556 pvalue_test = pvalue_test ,
541557 pvalue_threshhold = pvalue_threshhold ,
542- dataset_column = dataset_column )
558+ dataset_column = dataset_column ,
559+ win_label = win_label ,
560+ tie_label = tie_label ,
561+ loss_label = loss_label )
543562
544563def _get_line_heatmap (proposed_method ,
545564 excluded_methods = None ,
@@ -554,7 +573,7 @@ def _get_line_heatmap(proposed_method,
554573 pixels_per_clf_hieght = 7 ,
555574 pixels_per_clf_width = 2.5 ,
556575 colorbar_orientation = 'horizontal' ,
557- used_statistic = 'Accuracy ' ,
576+ used_statistic = 'Score ' ,
558577 order_WinTieLoss = 'higher' ,
559578 include_ProbaWinTieLoss = False ,
560579 bayesian_rope = 0.01 ,
@@ -565,7 +584,10 @@ def _get_line_heatmap(proposed_method,
565584 used_mean = 'mean-difference' ,
566585 order_stats = 'average-statistic' ,
567586 order_better = 'decreasing' ,
568- dataset_column = 'dataset_name' ,):
587+ dataset_column = 'dataset_name' ,
588+ win_label = 'row>column' ,
589+ tie_label = 'row=column' ,
590+ loss_label = 'row<column' ):
569591
570592 """
571593
@@ -747,9 +769,9 @@ def _get_line_heatmap(proposed_method,
747769 _vmax = max_value + 0.8 * max_value
748770
749771 if colorbar_value is None :
750- _colorbar_value = 'mean-difference'
772+ _colorbar_value = capitalize_label ( 'mean-difference' )
751773 else :
752- _colorbar_value = colorbar_value
774+ _colorbar_value = capitalize_label ( colorbar_value )
753775
754776 im = ax .imshow (pairwise_line ,
755777 cmap = _colormap ,
@@ -786,15 +808,15 @@ def _get_line_heatmap(proposed_method,
786808 kw .update (fontweight = 'normal' )
787809
788810 if analysis ['order-stats' ] == 'average-statistic' :
789- ordering = 'average -' + analysis ['used-statistics' ]
811+ ordering = 'Average -' + capitalize_label ( analysis ['used-statistics' ])
790812 else :
791813 ordering = analysis ['order-stats' ]
792814
793815 im .axes .text (- 0.7 ,- 0.7 , ordering , fontsize = font_size , ** {"horizontalalignment" :"center" , "verticalalignment" :"center" })
794816
795817 string_to_add = ''
796- string_to_add = string_to_add + analysis ['used-mean' ] + '\n '
797- string_to_add = string_to_add + 'Win/Tie/Loss ' + analysis [ 'order-WinTieLoss' ] + '\n '
818+ string_to_add = string_to_add + capitalize_label ( analysis ['used-mean' ]) + '\n '
819+ string_to_add = string_to_add + win_label + '/' + tie_label + '/' + loss_label + ' ' + '\n '
798820 if analysis ['include-pvalue' ]:
799821 string_to_add = string_to_add + analysis ['pvalue-test' ].capitalize () + ' p-value < ' + str (analysis ['pvalue-threshold' ])
800822
0 commit comments