1- """
2- ROC Analysis Widget
3- -------------------
4-
5- """
61import operator
72from functools import reduce , wraps
83from collections import namedtuple , deque , OrderedDict
116import sklearn .metrics as skl_metrics
127
138from AnyQt .QtWidgets import QListView , QLabel , QGridLayout , QFrame , QAction , \
14- QToolTip , QSizePolicy
9+ QToolTip
1510from AnyQt .QtGui import QColor , QPen , QBrush , QPainter , QPalette , QFont , \
1611 QCursor , QFontMetrics
1712from AnyQt .QtCore import Qt , QSize
2116from Orange .widgets import widget , gui , settings
2217from Orange .widgets .evaluate .contexthandlers import \
2318 EvaluationResultsContextHandler
24- from Orange .widgets .evaluate .utils import \
25- check_results_adequacy , results_for_preview
19+ from Orange .widgets .evaluate .utils import check_results_adequacy
2620from Orange .widgets .utils import colorpalettes
2721from Orange .widgets .utils .widgetpreview import WidgetPreview
2822from Orange .widgets .widget import Input
2923from Orange .widgets import report
3024
25+ from Orange .widgets .evaluate .utils import results_for_preview
26+ from Orange .evaluation .testing import Results
27+
3128
3229#: Points on a ROC curve
3330ROCPoints = namedtuple (
@@ -93,11 +90,11 @@ def roc_data_from_results(results, clf_index, target):
9390 :rval ROCData:
9491 A instance holding the computed curves.
9592 """
96- merged = roc_curve_for_fold (results , slice ( 0 , - 1 ) , clf_index , target )
93+ merged = roc_curve_for_fold (results , ... , clf_index , target )
9794 merged_curve = ROCCurve (ROCPoints (* merged ),
9895 ROCPoints (* roc_curve_convex_hull (merged )))
9996
100- folds = results .folds if results .folds is not None else [slice ( 0 , - 1 ) ]
97+ folds = results .folds if results .folds is not None else [... ]
10198 fold_curves = []
10299 for fold in folds :
103100 points = roc_curve_for_fold (results , fold , clf_index , target )
@@ -413,11 +410,13 @@ def __init__(self):
413410 axis .setTickFont (tickfont )
414411 axis .setPen (pen )
415412 axis .setLabel ("FP Rate (1-Specificity)" )
413+ axis .setGrid (16 )
416414
417415 axis = self .plot .getAxis ("left" )
418416 axis .setTickFont (tickfont )
419417 axis .setPen (pen )
420418 axis .setLabel ("TP Rate (Sensitivity)" )
419+ axis .setGrid (16 )
421420
422421 self .plot .showGrid (True , True , alpha = 0.1 )
423422 self .plot .setRange (xRange = (0.0 , 1.0 ), yRange = (0.0 , 1.0 ), padding = 0.05 )
@@ -621,6 +620,8 @@ def no_averaging():
621620 if self .roc_averaging == OWROCAnalysis .Merge :
622621 self ._update_perf_line ()
623622
623+ self ._update_axes_ticks ()
624+
624625 warning = ""
625626 if not all (c .is_valid for c in hull_curves ):
626627 if any (c .is_valid for c in hull_curves ):
@@ -629,6 +630,22 @@ def no_averaging():
629630 warning = "All ROC curves are undefined"
630631 self .warning (warning )
631632
633+ def _update_axes_ticks (self ):
634+ def enumticks (a ):
635+ a = np .unique (a )
636+ if len (a ) > 15 :
637+ return None
638+ return [[(x , f"{ x :.2f} " ) for x in a [::- 1 ]]]
639+
640+ data = self .curve_data (self .target_index , self .selected_classifiers [0 ])
641+ points = data .merged .points
642+
643+ axis = self .plot .getAxis ("bottom" )
644+ axis .setTicks (enumticks (points .fpr ))
645+
646+ axis = self .plot .getAxis ("left" )
647+ axis .setTicks (enumticks (points .tpr ))
648+
632649 def _on_mouse_moved (self , pos ):
633650 target = self .target_index
634651 selected = self .selected_classifiers
@@ -802,10 +819,19 @@ def roc_curve_for_fold(res, fold, clf_idx, target):
802819 return np .array ([]), np .array ([]), np .array ([])
803820
804821 fold_probs = res .probabilities [clf_idx ][fold ][:, target ]
805- return skl_metrics .roc_curve (
806- fold_actual , fold_probs , pos_label = target
822+ drop_intermediate = len (fold_actual ) > 20
823+ fpr , tpr , thresholds = skl_metrics .roc_curve (
824+ fold_actual , fold_probs , pos_label = target ,
825+ drop_intermediate = drop_intermediate
807826 )
808827
828+ # skl sets the first threshold to the highest threshold in the data + 1
829+ # since we deal with probabilities, we (carefully) set it to 1
830+ # Unrelated comparisons, thus pylint: disable=chained-comparison
831+ if len (thresholds ) > 1 and thresholds [1 ] <= 1 :
832+ thresholds [0 ] = 1
833+ return fpr , tpr , thresholds
834+
809835
810836def roc_curve_vertical_average (curves , samples = 10 ):
811837 if not curves :
@@ -969,5 +995,19 @@ def roc_iso_performance_slope(fp_cost, fn_cost, p):
969995 return (fp_cost * (1. - p )) / (fn_cost * p )
970996
971997
998+ def _create_results (): # pragma: no cover
999+ probs1 = [0.984 , 0.907 , 0.881 , 0.865 , 0.815 , 0.741 , 0.735 , 0.635 ,
1000+ 0.582 , 0.554 , 0.413 , 0.317 , 0.287 , 0.225 , 0.216 , 0.183 ]
1001+ probs = np .array ([[[1 - x , x ] for x in probs1 ]])
1002+ preds = (probs > 0.5 ).astype (float )
1003+ return Results (
1004+ data = Orange .data .Table ("heart_disease" )[:16 ],
1005+ row_indices = np .arange (16 ),
1006+ actual = np .array (list (map (int , "1100111001001000" ))),
1007+ probabilities = probs , predicted = preds
1008+ )
1009+
1010+
9721011if __name__ == "__main__" : # pragma: no cover
1012+ # WidgetPreview(OWROCAnalysis).run(_create_results())
9731013 WidgetPreview (OWROCAnalysis ).run (results_for_preview ())
0 commit comments