1+ import math
2+ from typing import List , Callable
13from xml .sax .saxutils import escape
24
35import numpy as np
6+ import scipy .stats as ss
47from scipy .stats import linregress
58from sklearn .neighbors import NearestNeighbors
69from sklearn .metrics import r2_score
@@ -104,6 +107,8 @@ def update_lines(**settings):
104107 self .reg_line_settings .update (** settings )
105108 Updater .update_inf_lines (self .reg_line_items ,
106109 ** self .reg_line_settings )
110+ Updater .update_lines (self .ellipse_items ,
111+ ** self .reg_line_settings )
107112 self .master .update_reg_line_label_colors ()
108113
109114 def update_line_label (** settings ):
@@ -129,20 +134,27 @@ def reg_line_label_items(self):
129134 return [line .label for line in self .master .reg_line_items
130135 if hasattr (line , "label" )]
131136
@property
def ellipse_items(self):
    """Return the confidence-ellipse curve items held by the master graph."""
    return self.master.ellipse_items
140+
132141
133142class OWScatterPlotGraph (OWScatterPlotBase ):
134143 show_reg_line = Setting (False )
135144 orthonormal_regression = Setting (False )
145+ show_ellipse = Setting (False )
136146 jitter_continuous = Setting (False )
137147
def __init__(self, scatter_widget, parent):
    """Initialise the base graph and the containers for fitted curves.

    Besides delegating to the base class, this creates the parameter
    setter bound to this graph and two initially empty lists holding the
    regression-line items and the confidence-ellipse items.
    """
    super().__init__(scatter_widget, parent)
    self.parameter_setter = ParameterSetter(self)
    self.reg_line_items = []
    self.ellipse_items: List[pg.PlotCurveItem] = []
142153
def clear(self):
    """Clear the base graph and forget all regression lines and ellipses."""
    super().clear()
    self.reg_line_items.clear()
    self.ellipse_items.clear()
146158
147159 def update_coordinates (self ):
148160 super ().update_coordinates ()
@@ -153,6 +165,7 @@ def update_coordinates(self):
def update_colors(self):
    """Update point colors, then redraw the regression lines and ellipses.

    The per-group curves take their color from the palette, so they are
    rebuilt whenever the colors change.
    """
    super().update_colors()
    self.update_regression_line()
    self.update_ellipse()
156169
157170 def jitter_coordinates (self , x , y ):
158171 def get_span (attr ):
@@ -255,17 +268,28 @@ def update_density(self):
255268 self .update_reg_line_label_colors ()
256269
def update_regression_line(self):
    """Rebuild the regression lines and refresh their label colors.

    Removing the old items, checking whether lines may be drawn and adding
    the overall and per-group lines is delegated to ``_update_curve``;
    the label colors are refreshed afterwards in every case.
    """
    self._update_curve(self.reg_line_items,
                       self.show_reg_line,
                       self._add_line)
    self.update_reg_line_label_colors()
275+
def update_ellipse(self):
    """Rebuild the confidence ellipses according to ``show_ellipse``.

    Removing the old items and adding the overall and per-group ellipses
    is delegated to ``_update_curve`` with ``_add_ellipse`` as the drawer.
    """
    self._update_curve(self.ellipse_items,
                       self.show_ellipse,
                       self._add_ellipse)
280+
281+ def _update_curve (self , items : List , show : bool , add : Callable ):
282+ for item in items :
283+ self .plot_widget .removeItem (item )
284+ items .clear ()
285+ if not (show and self .master .can_draw_regression_line ()):
263286 return
264287 x , y = self .master .get_coordinates_data ()
265- if x is None :
288+ if x is None or len ( x ) < 2 :
266289 return
267- self ._add_line (x , y , QColor ("#505050" ))
268- if self .master .is_continuous_color () or self .palette is None :
290+ add (x , y , QColor ("#505050" ))
291+ if self .master .is_continuous_color () or self .palette is None \
292+ or len (self .palette ) == 0 :
269293 return
270294 c_data = self .master .get_color_data ()
271295 if c_data is None :
@@ -274,8 +298,40 @@ def update_regression_line(self):
274298 for val in range (c_data .max () + 1 ):
275299 mask = c_data == val
276300 if mask .sum () > 1 :
277- self ._add_line (x [mask ], y [mask ], self .palette [val ].darker (135 ))
278- self .update_reg_line_label_colors ()
301+ add (x [mask ], y [mask ], self .palette [val ].darker (135 ))
302+
def _add_ellipse(self, x: np.ndarray, y: np.ndarray, color: QColor) -> None:
    """Draw a 95% Hotelling's T² confidence ellipse around the points.

    The ellipse is centred on the mean of the points. Its axes follow the
    eigenvectors of the sample covariance matrix and each semi-axis is
    ``sqrt(eigenvalue * 2 (n - 1) / (n - 2) * F(0.95; 2, n - 2))`` long.
    The curve is drawn with the pen settings of the regression line,
    added to the plot widget and appended to ``self.ellipse_items``.

    Nothing is drawn when fewer than three points are given (the F
    distribution then has no denominator degrees of freedom) or when the
    covariance matrix is not finite.

    Args:
        x: x coordinates of the points.
        y: y coordinates of the points, same length as ``x``.
        color: color of the curve; the passed object is not modified.
    """
    # Formula taken from
    # https://github.com/ChristianGoueguel/HotellingEllipse/blob/master/R/ellipseCoord.R
    n = len(x)
    if n < 3:
        # With n == 2 both F(2, n - 2) and the division by n - 2 are
        # undefined and every coordinate of the curve would be NaN.
        return

    points = np.column_stack([x, y])
    mu = points.mean(axis=0)
    cov = np.cov(points, rowvar=False)
    if not np.all(np.isfinite(cov)):
        return

    # The covariance matrix is symmetric, so eigh applies: it yields real
    # eigenvalues and orthonormal eigenvectors.
    vals, vects = np.linalg.eigh(cov)
    # Rounding may push the eigenvalue of a degenerate (collinear) cloud
    # slightly below zero; clip so that the square root stays defined.
    vals = np.clip(vals, 0, None)

    f = ss.f.ppf(0.95, 2, n - 2) * 2 * (n - 1) / (n - 2)
    t = np.linspace(0, 2 * np.pi, 201)
    unit_circle = np.vstack([np.cos(t), np.sin(t)])
    # Scale the unit circle to the semi-axes, then map it onto the
    # eigenvector basis; this rotates the ellipse into place.
    pts = vects @ (np.sqrt(vals * f)[:, None] * unit_circle)
    cx = pts[0] + mu[0]
    cy = pts[1] + mu[1]

    settings = self.parameter_setter.reg_line_settings
    width = settings[Updater.WIDTH_LABEL]
    alpha = settings[Updater.ALPHA_LABEL]
    style = Updater.LINE_STYLES[settings[Updater.STYLE_LABEL]]
    # Work on a copy so that the caller's color object keeps its alpha.
    color = QColor(color)
    color.setAlpha(alpha)

    pen = pg.mkPen(color=color, width=width, style=style)
    ellipse = pg.PlotCurveItem(cx, cy, pen=pen)
    self.plot_widget.addItem(ellipse)
    self.ellipse_items.append(ellipse)
279335
280336
281337class OWScatterPlot (OWDataProjectionWidget , VizRankMixin (ScatterPlotVizRank )):
@@ -353,6 +409,12 @@ def _add_controls(self):
353409 "If checked, fit line to group (minimize distance from points);\n "
354410 "otherwise fit y as a function of x (minimize vertical distances)" ,
355411 disabledBy = self .cb_reg_line )
412+ gui .checkBox (
413+ self ._plot_box , self ,
414+ value = "graph.show_ellipse" ,
415+ label = "Show confidence ellipse" ,
416+ tooltip = "Hotelling's T² confidence ellipse (α=95%)" ,
417+ callback = self .graph .update_ellipse )
356418
357419 def _add_controls_axis (self ):
358420 common_options = dict (
@@ -492,7 +554,7 @@ def _point_tooltip(self, point_id, skip_attrs=()):
492554 text = "<b>{}</b><br/><br/>{}" .format (text , others )
493555 return text
494556
495- def can_draw_regresssion_line (self ):
557+ def can_draw_regression_line (self ):
496558 return self .data is not None and \
497559 self .data .domain is not None and \
498560 self .attr_x is not None and self .attr_y is not None and \
@@ -552,7 +614,9 @@ def handleNewSignals(self):
552614 if self ._domain_invalidated :
553615 self .graph .update_axes ()
554616 self ._domain_invalidated = False
555- self .cb_reg_line .setEnabled (self .can_draw_regresssion_line ())
617+ can_plot = self .can_draw_regression_line ()
618+ self .cb_reg_line .setEnabled (can_plot )
619+ self .graph .controls .show_ellipse .setEnabled (can_plot )
556620
557621 @Inputs .features
558622 def set_shown_attributes (self , attributes ):
@@ -578,7 +642,9 @@ def set_attr_from_combo(self):
578642 self .vizrankAutoSelect .emit ([self .attr_x , self .attr_y ])
579643
def attr_changed(self):
    """React to a change of the plotted attributes.

    The regression-line and confidence-ellipse check boxes are enabled
    only when ``can_draw_regression_line`` allows drawing; afterwards the
    plot is set up again and a deferred commit is requested.
    """
    can_plot = self.can_draw_regression_line()
    self.cb_reg_line.setEnabled(can_plot)
    self.graph.controls.show_ellipse.setEnabled(can_plot)
    self.setup_plot()
    self.commit.deferred()
584650
0 commit comments