Skip to content

Commit d10b022

Browse files
committed
Permutation Plot: Add widget
1 parent 39bd0ec commit d10b022

File tree

3 files changed

+475
-0
lines changed

3 files changed

+475
-0
lines changed
Lines changed: 40 additions & 0 deletions
Loading
Lines changed: 330 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,330 @@
1+
from typing import Optional, Tuple, Callable, List, Dict
2+
3+
import numpy as np
4+
from scipy.stats import spearmanr, linregress
5+
from AnyQt.QtCore import Qt
6+
import pyqtgraph as pg
7+
8+
from orangewidget.utils.visual_settings_dlg import VisualSettingsDialog, \
9+
KeyType, ValueType
10+
from Orange.base import Learner
11+
from Orange.data import Table
12+
from Orange.data.table import DomainTransformationError
13+
from Orange.evaluation import CrossValidation, R2, TestOnTrainingData, Results
14+
from Orange.util import dummy_callback
15+
from Orange.widgets import gui
16+
from Orange.widgets.settings import Setting
17+
from Orange.widgets.utils.concurrent import ConcurrentWidgetMixin, TaskState
18+
from Orange.widgets.utils.widgetpreview import WidgetPreview
19+
from Orange.widgets.visualize.owscatterplotgraph import LegendItem
20+
from Orange.widgets.visualize.utils.customizableplot import \
21+
CommonParameterSetter, Updater
22+
from Orange.widgets.visualize.utils.plotutils import PlotWidget
23+
from Orange.widgets.widget import OWWidget, Input, Msg
24+
25+
# Number of folds used for the cross-validated (Q2) score.
N_FOLD = 7
# Layout of the tuple returned by `permutation`:
#   (correlations, R2 scores, R2 intercept, R2 slope,
#    Q2 scores, Q2 intercept, Q2 slope, target variable name)
# The trailing `str` (target name) is part of the returned tuple and is
# therefore part of the alias as well.
PermutationResults = Tuple[
    np.ndarray, List, float, float, List, float, float, str
]
27+
28+
29+
def _correlation(y: np.ndarray, y_pred: np.ndarray) -> float:
30+
return spearmanr(y, y_pred)[0] * 100
31+
32+
33+
def _validate(data: Table, learner: Learner) -> Tuple[float, float]:
    """Score `learner` on `data` and return the pair ``(R2, Q2)``.

    R2 is the R2 score of `TestOnTrainingData`; Q2 is the R2 score of an
    `N_FOLD`-fold `CrossValidation`. Only the first learner's score is
    returned from each evaluation.
    """
    # Plain fit before the evaluation: per the original author's note,
    # Validation would silence the exceptions, so let them surface here.
    learner(data)

    on_training: Results = TestOnTrainingData()(data, [learner])
    cross_validated: Results = CrossValidation(k=N_FOLD)(data, [learner])

    # pylint: disable=unsubscriptable-object
    r2 = R2(on_training)[0]
    q2 = R2(cross_validated)[0]
    return r2, q2
41+
42+
43+
def permutation(
        data: Table,
        learner: Learner,
        n_perm: int = 100,
        progress_callback: Callable = dummy_callback
) -> PermutationResults:
    """Evaluate `learner` on `data` and on `n_perm` target permutations.

    The learner is scored once on the original data and then `n_perm` times
    on a copy of the data whose `Y` is shuffled in place. The same copy is
    reshuffled on every iteration. For each permutation the Spearman
    correlation (in %) between the original and the permuted `Y` is stored
    together with the R2 and Q2 scores. Finally, a line is fitted for R2 and
    for Q2 through two points: the original model and the mean of all
    permuted models.

    Args:
        data: Data with the target values in `Y`.
        learner: Learner to evaluate.
        n_perm: Number of permutations.
        progress_callback: Called with ``(0, "Calculating...")`` once and
            then with ``i / n_perm`` before every permutation.

    Returns:
        ``(correlations, r2_scores, r2_intercept, r2_slope,
        q2_scores, q2_intercept, q2_slope, target_name)`` where
        `correlations` holds absolute values and index 0 of each sequence
        belongs to the unpermuted data.

    NOTE: the global NumPy random generator is reseeded with 0 as a side
    effect of this call.
    """
    original_r2, original_q2 = _validate(data, learner)
    r2_scores, q2_scores = [original_r2], [original_q2]
    # The unpermuted target correlates perfectly with itself.
    correlations = [100.0]
    progress_callback(0, "Calculating...")
    np.random.seed(0)

    shuffled = data.copy()
    for step in range(n_perm):
        progress_callback(step / n_perm)
        # In-place shuffle of the copy's targets; `data` itself is untouched.
        np.random.shuffle(shuffled.Y)
        perm_r2, perm_q2 = _validate(shuffled, learner)
        correlations.append(_correlation(data.Y, shuffled.Y))
        r2_scores.append(perm_r2)
        q2_scores.append(perm_q2)

    correlations = np.abs(correlations)
    # x-coordinates of the two points the regression lines pass through.
    x_points = [correlations[0], np.mean(correlations[1:])]
    r2_fit = linregress(x_points, [r2_scores[0], np.mean(r2_scores[1:])])
    q2_fit = linregress(x_points, [q2_scores[0], np.mean(q2_scores[1:])])

    target_name = data.domain.class_var.name
    return (correlations,
            r2_scores, r2_fit.intercept, r2_fit.slope,
            q2_scores, q2_fit.intercept, q2_fit.slope,
            target_name)
75+
76+
77+
def run(
        data: Table,
        learner: Learner,
        n_perm: int,
        state: TaskState
) -> PermutationResults:
    """Run `permutation` while reporting progress to the task `state`.

    The progress callback forwards the fraction (scaled to 0-100) and an
    optional status text to `state`, and raises a plain `Exception` once
    `state` reports that an interruption was requested.
    """
    def report(fraction: float, status: str = ""):
        state.set_progress_value(fraction * 100)
        if status:
            state.set_status(status)
        if not state.is_interruption_requested():
            return
        # pylint: disable=broad-exception-raised
        raise Exception

    result = permutation(data, learner, n_perm, report)
    return result
92+
93+
94+
class ParameterSetter(CommonParameterSetter):
    """Visual-settings adapter for `PermutationPlot`.

    Declares the font settings (title, axis titles, axis ticks, legend) and
    the gridline settings (visibility and opacity) and applies gridline
    changes to the master plot.
    """
    GRID_LABEL = "Gridlines"
    SHOW_GRID_LABEL = "Show"
    DEFAULT_ALPHA_GRID = 80
    DEFAULT_SHOW_GRID = True

    def __init__(self, master):
        # Current gridline state; filled in by `update_setters`.
        self.grid_settings: Dict = None
        self.master: PermutationPlot = master
        super().__init__()

    def update_setters(self):
        """Define initial settings and register the gridline setter."""
        self.grid_settings = {
            Updater.ALPHA_LABEL: self.DEFAULT_ALPHA_GRID,
            self.SHOW_GRID_LABEL: self.DEFAULT_SHOW_GRID,
        }

        label_settings = {
            self.FONT_FAMILY_LABEL: self.FONT_FAMILY_SETTING,
            self.TITLE_LABEL: self.FONT_SETTING,
            self.AXIS_TITLE_LABEL: self.FONT_SETTING,
            self.AXIS_TICKS_LABEL: self.FONT_SETTING,
            self.LEGEND_LABEL: self.FONT_SETTING,
        }
        gridline_settings = {
            self.SHOW_GRID_LABEL: (None, True),
            Updater.ALPHA_LABEL: (range(0, 255, 5),
                                  self.DEFAULT_ALPHA_GRID),
        }
        self.initial_settings = {
            self.LABELS_BOX: label_settings,
            self.PLOT_BOX: {self.GRID_LABEL: gridline_settings},
        }

        def apply_grid(**settings):
            # Merge the changed values, then redraw the grid on both axes;
            # alpha is stored on a 0-255 scale and passed on as a fraction.
            self.grid_settings.update(**settings)
            visible = self.grid_settings[self.SHOW_GRID_LABEL]
            opacity = self.grid_settings[Updater.ALPHA_LABEL] / 255
            self.master.showGrid(x=visible, y=visible, alpha=opacity)

        self._setters[self.PLOT_BOX] = {self.GRID_LABEL: apply_grid}

    @property
    def title_item(self):
        """Title label of the master's plot item."""
        plot_item = self.master.getPlotItem()
        return plot_item.titleLabel

    @property
    def axis_items(self):
        """List of the axis items of the master's plot item."""
        axes = self.master.getPlotItem().axes
        return [axis["item"] for axis in axes.values()]

    @property
    def legend_items(self):
        """Items of the master's legend."""
        legend = self.master.legend
        return legend.items
147+
148+
149+
class PermutationPlot(PlotWidget):
    """Plot of R2/Q2 scores against the correlation with the permuted target.

    Mouse interaction, the context menu and the plot buttons are disabled;
    a legend is anchored to the bottom-right corner of the view box.
    """

    def __init__(self):
        super().__init__(enableMenu=False)
        self.legend = self._create_legend()
        self.parameter_setter = ParameterSetter(self)
        self.setMouseEnabled(False, False)
        self.hideButtons()

        self.showGrid(True, True)
        self.setLabel(
            axis="bottom",
            text="Correlation between original Y and permuted Y (%)",
        )
        self.setLabel(axis="left", text="R2, Q2")

    def _create_legend(self) -> LegendItem:
        """Create a hidden legend attached to the view box."""
        item = LegendItem()
        item.setParentItem(self.getViewBox())
        item.anchor((1, 1), (1, 1), offset=(-5, -5))
        item.hide()
        return item

    def set_data(
            self,
            corr: np.ndarray,
            r2_scores: List,
            r2_intercept: float,
            r2_slope: float,
            q2_scores: List,
            q2_intercept: float,
            q2_slope: float,
            name: str
    ):
        """Replace the plot contents with new permutation results.

        Draws a dashed regression line and a scatter of scores for both R2
        and Q2, puts the intercepts (rounded to 4 decimals) into the title
        and rebuilds the legend.

        Args:
            corr: x-coordinates of the points.
            r2_scores: R2 values, one per entry of `corr`.
            r2_intercept: Intercept of the R2 line.
            r2_slope: Slope of the R2 line.
            q2_scores: Q2 values, one per entry of `corr`.
            q2_intercept: Intercept of the Q2 line.
            q2_slope: Slope of the Q2 line.
            name: Text that prefixes the title.
        """
        self.clear()
        self.setTitle(
            f"{name} Intercepts: "
            f"R2=(0.0, {round(r2_intercept, 4)}), "
            f"Q2=(0.0, {round(q2_intercept, 4)})"
        )

        # Both lines span the whole 0-100 % correlation range.
        line_x = np.array([0, 100])
        line_pen = pg.mkPen("#000", width=2, style=Qt.DashLine)
        r2_line = pg.PlotCurveItem(
            line_x, r2_intercept + r2_slope * line_x, pen=line_pen)
        q2_line = pg.PlotCurveItem(
            line_x, q2_intercept + q2_slope * line_x, pen=line_pen)

        outline = pg.mkPen("#333")
        r2_style = {"pen": outline, "symbol": "o", "brush": "#6fa255"}
        q2_style = {"pen": outline, "symbol": "s", "brush": "#3a78b6"}

        r2_points = pg.ScatterPlotItem(corr, r2_scores, size=12, **r2_style)
        q2_points = pg.ScatterPlotItem(corr, q2_scores, size=12, **q2_style)

        for item in (r2_line, q2_line, r2_points, q2_points):
            self.addItem(item)

        # Legend samples reuse the point styles without the explicit size.
        self.legend.clear()
        for style, label in ((r2_style, "R2"), (q2_style, "Q2")):
            self.legend.addItem(pg.ScatterPlotItem(**style), label)
        self.legend.show()
210+
211+
212+
class OWPermutationPlot(OWWidget, ConcurrentWidgetMixin):
    """Widget that runs a permutation analysis and plots R2 and Q2.

    Takes a data table and a learner, runs `run` as a background task and
    shows the result in a `PermutationPlot`.
    """
    name = "Permutation Plot"
    description = "Permutation analysis plotting R2 and Q2"
    icon = "icons/PermutationPlot.svg"
    priority = 1100

    n_permutations = Setting(100)
    visual_settings = Setting({}, schema_only=True)
    graph_name = "graph.plotItem"

    class Inputs:
        data = Input("Data", Table)
        # The signal used to be published under the misspelled name
        # "Lerner"; `replaces` keeps that old name recognised.
        learner = Input("Learner", Learner, replaces=["Lerner"])

    class Error(OWWidget.Error):
        domain_transform_err = Msg("{}")
        unknown_err = Msg("{}")
        not_enough_data = Msg(f"At least {N_FOLD} instances are needed.")
        incompatible_learner = Msg("{}")

    def __init__(self):
        OWWidget.__init__(self)
        ConcurrentWidgetMixin.__init__(self)
        self._data: Optional[Table] = None
        self._learner: Optional[Learner] = None
        self.graph: PermutationPlot = None
        self.setup_gui()
        VisualSettingsDialog(
            self, self.graph.parameter_setter.initial_settings
        )

    def setup_gui(self):
        """Build the plot in the main area and the controls next to it."""
        self._add_plot()
        self._add_controls()

    def _add_plot(self):
        box = gui.vBox(self.mainArea)
        self.graph = PermutationPlot()
        box.layout().addWidget(self.graph)

    def _add_controls(self):
        box = gui.vBox(self.controlArea, "Settings")
        # Changing the number of permutations restarts the computation.
        gui.spin(box, self, "n_permutations", label="Permutations:",
                 minv=1, maxv=1000, callback=self._run)
        gui.rubber(self.controlArea)

    @Inputs.data
    def set_data(self, data: Table):
        """Store the input data, rejecting tables smaller than `N_FOLD`."""
        self.Error.not_enough_data.clear()
        self._data = data
        # Compare against None explicitly: an object of length 0 is falsy
        # (unless it overrides __bool__), so a truthiness test would let a
        # table with no rows slip past the size check.
        if self._data is not None and len(self._data) < N_FOLD:
            self.Error.not_enough_data()
            self._data = None

    @Inputs.learner
    def set_learner(self, learner: Learner):
        """Store the input learner."""
        self._learner = learner

    def handleNewSignals(self):
        """Reset the widget and start a new run if the inputs allow it."""
        self.Error.incompatible_learner.clear()
        self.Error.unknown_err.clear()
        self.Error.domain_transform_err.clear()
        self.clear()
        if self._data is None or self._learner is None:
            return

        reason = self._learner.incompatibility_reason(self._data.domain)
        if reason:
            self.Error.incompatible_learner(reason)
            return

        self._run()

    def clear(self):
        """Cancel a running task and empty the plot and its title."""
        self.cancel()
        self.graph.clear()
        self.graph.setTitle()

    def _run(self):
        """Start the background computation when both inputs are present."""
        if self._data is None or self._learner is None:
            return
        # `_run` is also the spin-box callback; drop errors reported by a
        # previous run so they do not linger next to fresh results.
        self.Error.unknown_err.clear()
        self.Error.domain_transform_err.clear()
        self.start(run, self._data, self._learner, self.n_permutations)

    def on_done(self, result: PermutationResults):
        """Show the finished permutation results in the plot."""
        self.graph.set_data(*result)

    def on_exception(self, ex: Exception):
        """Report a failed run as a widget error."""
        if isinstance(ex, DomainTransformationError):
            self.Error.domain_transform_err(ex)
        else:
            self.Error.unknown_err(ex)

    def on_partial_result(self, _):
        # The task reports no partial results.
        pass

    def onDeleteWidget(self):
        self.shutdown()
        super().onDeleteWidget()

    def send_report(self):
        """Add the plot to the report when both inputs are present."""
        if self._data is None or self._learner is None:
            return
        self.report_plot()

    def set_visual_settings(self, key: KeyType, value: ValueType):
        """Apply one visual setting to the plot and remember it."""
        self.graph.parameter_setter.set_parameter(key, value)
        # pylint: disable=unsupported-assignment-operation
        self.visual_settings[key] = value
320+
321+
322+
if __name__ == "__main__":
    from Orange.regression import LinearRegressionLearner

    # Manual preview: feed the widget the "housing" table and a
    # linear-regression learner.
    preview_data = Table("housing")
    preview_learner = LinearRegressionLearner()

    preview = WidgetPreview(OWPermutationPlot)
    preview.run(set_data=preview_data, set_learner=preview_learner)

0 commit comments

Comments
 (0)