Skip to content

Commit 4d13dab

Browse files
committed
Aggregate Columns: Add widget
1 parent 3617a5b commit 4d13dab

File tree

2 files changed

+319
-0
lines changed

2 files changed

+319
-0
lines changed
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
from typing import List
2+
3+
import numpy as np
4+
5+
from AnyQt.QtWidgets import QSizePolicy
6+
from AnyQt.QtCore import Qt
7+
from Orange.data import Variable, Table, ContinuousVariable, TimeVariable
8+
from Orange.data.util import get_unique_names
9+
from Orange.widgets import gui, widget
10+
from Orange.widgets.settings import (
11+
ContextSetting, Setting, DomainContextHandler
12+
)
13+
from Orange.widgets.utils.widgetpreview import WidgetPreview
14+
from Orange.widgets.utils.state_summary import format_summary_details
15+
from Orange.widgets.widget import Input, Output
16+
from Orange.widgets.utils.itemmodels import DomainModel
17+
18+
19+
class OWAggregateColumns(widget.OWWidget):
20+
name = "Aggregate Columns"
21+
description = "Compute a sum, max, min ... of selected columns."
22+
icon = "icons/AggregateColumns.svg"
23+
priority = 100
24+
keywords = ["aggregate", "sum", "product", "max", "min", "mean",
25+
"median", "variance"]
26+
27+
class Inputs:
28+
data = Input("Data", Table, default=True)
29+
30+
class Outputs:
31+
data = Output("Data", Table)
32+
33+
want_main_area = False
34+
35+
settingsHandler = DomainContextHandler()
36+
variables: List[Variable] = ContextSetting([])
37+
operation = Setting("Sum")
38+
var_name = Setting("agg")
39+
auto_apply = Setting(True)
40+
41+
Operations = {"Sum": np.nansum, "Product": np.nanprod,
42+
"Min": np.nanmin, "Max": np.nanmax,
43+
"Mean": np.nanmean, "Variance": np.nanvar,
44+
"Median": np.nanmedian}
45+
TimePreserving = ("Min", "Max", "Mean", "Median")
46+
47+
def __init__(self):
48+
super().__init__()
49+
self.data = None
50+
51+
box = gui.vBox(self.controlArea, box=True)
52+
53+
self.variable_model = DomainModel(
54+
order=DomainModel.MIXED, valid_types=(ContinuousVariable, ))
55+
var_list = gui.listView(
56+
box, self, "variables", model=self.variable_model,
57+
callback=self.commit)
58+
var_list.setSelectionMode(var_list.ExtendedSelection)
59+
60+
combo = gui.comboBox(
61+
box, self, "operation",
62+
label="Operator: ", orientation=Qt.Horizontal,
63+
items=list(self.Operations), sendSelectedValue=True,
64+
callback=self.commit
65+
)
66+
combo.setSizePolicy(QSizePolicy.MinimumExpanding, QSizePolicy.Fixed)
67+
68+
gui.lineEdit(
69+
box, self, "var_name",
70+
label="Variable name: ", orientation=Qt.Horizontal,
71+
callback=self.commit
72+
)
73+
74+
gui.auto_apply(self.controlArea, self)
75+
76+
@Inputs.data
77+
def set_data(self, data: Table = None):
78+
self.closeContext()
79+
self.data = data
80+
if self.data:
81+
self.variable_model.set_domain(data.domain)
82+
self.info.set_input_summary(len(self.data),
83+
format_summary_details(self.data))
84+
self.variables.clear()
85+
self.openContext(data)
86+
else:
87+
self.variable_model.set_domain(None)
88+
self.variables.clear()
89+
self.info.set_input_summary(self.info.NoInput)
90+
self.unconditional_commit()
91+
92+
def commit(self):
93+
augmented = self._compute_data()
94+
self.Outputs.data.send(augmented)
95+
if augmented is None:
96+
self.info.set_output_summary(self.info.NoOutput)
97+
else:
98+
self.info.set_output_summary(
99+
len(augmented), format_summary_details(augmented))
100+
101+
def _compute_data(self):
102+
if not self.data or not self.variables:
103+
return self.data
104+
105+
new_col = self._compute_column()
106+
new_var = self._new_var()
107+
return self.data.add_column(new_var, new_col)
108+
109+
def _compute_column(self):
110+
arr = np.empty((len(self.data), len(self.variables)))
111+
for i, var in enumerate(self.variables):
112+
arr[:, i] = self.data.get_column_view(var)[0].astype(float)
113+
func = self.Operations[self.operation]
114+
return func(arr, axis=1)
115+
116+
def _new_var_name(self):
117+
return get_unique_names(self.data.domain, self.var_name)
118+
119+
def _new_var(self):
120+
name = self._new_var_name()
121+
if self.operation in self.TimePreserving \
122+
and all(isinstance(var, TimeVariable) for var in self.variables):
123+
return TimeVariable(name)
124+
return ContinuousVariable(name)
125+
126+
def send_report(self):
127+
# fp for self.variables, pylint: disable=unsubscriptable-object
128+
if not self.data or not self.variables:
129+
return
130+
var_list = ", ".join(f"'{var.name}'"
131+
for var in self.variables[:31][:-1])
132+
if len(self.variables) > 30:
133+
var_list += f" and {len(self.variables) - 30} others"
134+
else:
135+
var_list += f" and '{self.variables[-1].name}'"
136+
self.report_items((
137+
("Output:",
138+
f"'{self._new_var_name()}' as {self.operation.lower()} of {var_list}"
139+
),
140+
))
141+
142+
143+
if __name__ == "__main__": # pragma: no cover
144+
brown = Table("brown-selected")
145+
WidgetPreview(OWAggregateColumns).run(set_data=brown)
Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
# Test methods with long descriptive names can omit docstrings
2+
# pylint: disable=missing-docstring, abstract-method, protected-access
3+
import unittest
4+
from itertools import chain
5+
6+
from unittest.mock import Mock
7+
8+
import numpy as np
9+
10+
from Orange.data import (
11+
Table, Domain,
12+
ContinuousVariable, DiscreteVariable, StringVariable, TimeVariable
13+
)
14+
from Orange.widgets.data.owaggregatecolumns import OWAggregateColumns
15+
from Orange.widgets.tests.base import WidgetTest
16+
from Orange.widgets.utils.state_summary import format_summary_details
17+
18+
19+
class TestOWAggregateColumn(WidgetTest):
20+
def setUp(self):
21+
#: OWAggregateColumns
22+
self.widget = self.create_widget(OWAggregateColumns)
23+
c1, c2, c3 = map(ContinuousVariable, "c1 c2 c3".split())
24+
t1, t2 = map(TimeVariable, "t1 t2".split())
25+
d1, d2, d3 = (DiscreteVariable(n, values=("a", "b", "c"))
26+
for n in "d1 d2 d3".split())
27+
s1 = StringVariable("s1")
28+
domain1 = Domain([c1, c2, d1, d2, t1], [d3], [s1, c3, t2])
29+
self.data1 = Table.from_list(domain1,
30+
[[0, 1, 0, 1, 2, 0, "foo", 0, 3],
31+
[3, 1, 0, 1, 42, 0, "bar", 0, 4]])
32+
33+
domain2 = Domain([ContinuousVariable("c4")])
34+
self.data2 = Table.from_list(domain2, [[4], [5]])
35+
36+
def test_no_input(self):
37+
widget = self.widget
38+
domain = self.data1.domain
39+
input_sum = widget.info.set_input_summary = Mock()
40+
output_sum = widget.info.set_output_summary = Mock()
41+
42+
self.send_signal(widget.Inputs.data, self.data1)
43+
self.assertEqual(widget.variables, [])
44+
widget.commit()
45+
output = self.get_output(self.widget.Outputs.data)
46+
self.assertIs(output, self.data1)
47+
input_sum.assert_called_with(len(self.data1),
48+
format_summary_details(self.data1))
49+
output_sum.assert_called_with(len(output),
50+
format_summary_details(output))
51+
52+
widget.variables = [domain[n] for n in "c1 c2 t2".split()]
53+
widget.commit()
54+
output = self.get_output(self.widget.Outputs.data)
55+
self.assertIsNotNone(output)
56+
output_sum.assert_called_with(len(output),
57+
format_summary_details(output))
58+
59+
self.send_signal(widget.Inputs.data, None)
60+
widget.commit()
61+
self.assertIsNone(self.get_output(self.widget.Outputs.data))
62+
input_sum.assert_called_with(widget.info.NoInput)
63+
output_sum.assert_called_with(widget.info.NoOutput)
64+
65+
def test_compute_data(self):
66+
domain = self.data1.domain
67+
self.send_signal(self.widget.Inputs.data, self.data1)
68+
self.widget.variables = [domain[n] for n in "c1 c2 t2".split()]
69+
70+
self.widget.operation = "Sum"
71+
output = self.widget._compute_data()
72+
self.assertEqual(output.domain.attributes[:-1], domain.attributes)
73+
np.testing.assert_equal(output.X[:, -1], [4, 8])
74+
75+
self.widget.operation = "Max"
76+
output = self.widget._compute_data()
77+
self.assertEqual(output.domain.attributes[:-1], domain.attributes)
78+
np.testing.assert_equal(output.X[:, -1], [3, 4])
79+
80+
def test_var_name(self):
81+
domain = self.data1.domain
82+
self.send_signal(self.widget.Inputs.data, self.data1)
83+
self.widget.variables = self.widget.variable_model[:]
84+
85+
self.widget.var_name = "test"
86+
output = self.widget._compute_data()
87+
self.assertEqual(output.domain.attributes[-1].name, "test")
88+
89+
self.widget.var_name = "d1"
90+
output = self.widget._compute_data()
91+
self.assertNotIn(
92+
output.domain.attributes[-1].name,
93+
[var.name for var in chain(domain.variables, domain.metas)])
94+
95+
def test_var_types(self):
96+
domain = self.data1.domain
97+
self.send_signal(self.widget.Inputs.data, self.data1)
98+
99+
self.widget.variables = [domain[n] for n in "t1 c2 t2".split()]
100+
for self.widget.operation in self.widget.Operations:
101+
self.assertIsInstance(self.widget._new_var(), ContinuousVariable)
102+
103+
self.widget.variables = [domain[n] for n in "t1 t2".split()]
104+
for self.widget.operation in self.widget.Operations:
105+
self.assertIsInstance(
106+
self.widget._new_var(),
107+
TimeVariable
108+
if self.widget.operation in ("Min", "Max", "Mean", "Median")
109+
else ContinuousVariable)
110+
111+
def test_operations(self):
112+
domain = self.data1.domain
113+
self.send_signal(self.widget.Inputs.data, self.data1)
114+
self.widget.variables = [domain[n] for n in "c1 c2 t2".split()]
115+
116+
m1, m2 = 4 / 3, 8 / 3
117+
for self.widget.operation, expected in {
118+
"Sum": [4, 8], "Product": [0, 12],
119+
"Min": [0, 1], "Max": [3, 4],
120+
"Mean": [m1, m2],
121+
"Variance": [(m1 ** 2 + (m1 - 1) ** 2 + (m1 - 3) ** 2) / 3,
122+
((m2 - 3) ** 2 + (m2 - 1) ** 2 + (m2 - 4) ** 2) / 3],
123+
"Median": [1, 3]}.items():
124+
np.testing.assert_equal(
125+
self.widget._compute_column(), expected,
126+
err_msg=f"error in '{self.widget.operation}'")
127+
128+
def test_operations_with_nan(self):
129+
domain = self.data1.domain
130+
self.send_signal(self.widget.Inputs.data, self.data1)
131+
self.data1.X[1, 0] = np.nan
132+
self.widget.variables = [domain[n] for n in "c1 c2 t2".split()]
133+
134+
m1, m2 = 4 / 3, 5 / 2
135+
for self.widget.operation, expected in {
136+
"Sum": [4, 5], "Product": [0, 4],
137+
"Min": [0, 1], "Max": [3, 4],
138+
"Mean": [m1, m2],
139+
"Variance": [(m1 ** 2 + (m1 - 1) ** 2 + (m1 - 3) ** 2) / 3,
140+
((m2 - 1) ** 2 + (m2 - 4) ** 2) / 2],
141+
"Median": [1, 2.5]}.items():
142+
np.testing.assert_equal(
143+
self.widget._compute_column(), expected,
144+
err_msg=f"error in '{self.widget.operation}'")
145+
146+
def test_contexts(self):
147+
domain = self.data1.domain
148+
self.send_signal(self.widget.Inputs.data, self.data1)
149+
self.widget.variables = [domain[n] for n in "c1 c2 t2".split()]
150+
saved = self.widget.variables[:]
151+
152+
self.send_signal(self.widget.Inputs.data, self.data2)
153+
self.assertEqual(self.widget.variables, [])
154+
155+
self.send_signal(self.widget.Inputs.data, self.data1)
156+
self.assertEqual(self.widget.variables, saved)
157+
158+
def test_report(self):
159+
self.widget.send_report()
160+
161+
domain = self.data1.domain
162+
self.send_signal(self.widget.Inputs.data, self.data1)
163+
self.widget.variables = [domain[n] for n in "c1 c2 t2".split()]
164+
self.widget.send_report()
165+
166+
domain3 = Domain([ContinuousVariable(f"c{i:02}") for i in range(100)])
167+
data3 = Table.from_numpy(domain3, np.zeros((2, 100)))
168+
self.send_signal(self.widget.Inputs.data, data3)
169+
self.widget.variables[:] = self.widget.variable_model[:]
170+
self.widget.send_report()
171+
172+
173+
if __name__ == "__main__":
174+
unittest.main()

0 commit comments

Comments
 (0)