Skip to content

Commit cccfb61

Browse files
committed
Create class: Add case sensitivity, match at beginning
1 parent 86ac64f commit cccfb61

File tree

2 files changed

+199
-54
lines changed

2 files changed

+199
-54
lines changed

Orange/widgets/data/owcreateclass.py

Lines changed: 95 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -13,31 +13,52 @@
1313
from Orange.widgets.widget import Msg
1414

1515

16-
def map_by_substring(a, patterns):
16+
def map_by_substring(a, patterns, case_sensitive, at_beginning):
1717
res = np.full(len(a), np.nan)
18+
if not case_sensitive:
19+
a = np.char.lower(a)
20+
patterns = (pattern.lower() for pattern in patterns)
1821
for val_idx, pattern in reversed(list(enumerate(patterns))):
19-
res[np.char.find(a, pattern) != -1] = val_idx
22+
indices = np.char.find(a, pattern)
23+
matches = indices == 0 if at_beginning else indices != -1
24+
res[matches] = val_idx
2025
return res
2126

2227

23-
class ClassFromStringSubstring(Transformation):
24-
def __init__(self, variable, patterns):
28+
class ValueFromStringSubstring(Transformation):
29+
def __init__(self, variable, patterns,
30+
case_sensitive=False, match_beginning=False):
2531
super().__init__(variable)
2632
self.patterns = patterns
33+
self.case_sensitive = case_sensitive
34+
self.match_beginning = match_beginning
2735

2836
def transform(self, c):
2937
nans = np.equal(c, None)
3038
c = c.astype(str)
3139
c[nans] = ""
32-
res = map_by_substring(c, self.patterns)
40+
res = map_by_substring(
41+
c, self.patterns, self.case_sensitive, self.match_beginning)
3342
res[nans] = np.nan
3443
return res
3544

3645

37-
class ClassFromDiscreteSubstring(Lookup):
38-
def __init__(self, variable, patterns):
39-
lookup_table = map_by_substring(variable.values, patterns)
40-
super().__init__(variable, lookup_table)
46+
class ValueFromDiscreteSubstring(Lookup):
47+
def __init__(self, variable, patterns,
48+
case_sensitive=False, match_beginning=False):
49+
super().__init__(variable, [])
50+
self.case_sensitive = case_sensitive
51+
self.match_beginning = match_beginning
52+
self.patterns = patterns # Finally triggers computation of the lookup
53+
54+
def __setattr__(self, key, value):
55+
super().__setattr__(key, value)
56+
if hasattr(self, "patterns") and \
57+
key in ("case_sensitive", "match_beginning", "patterns",
58+
"variable"):
59+
self.lookup_table = map_by_substring(
60+
self.variable.values, self.patterns,
61+
self.case_sensitive, self.match_beginning)
4162

4263

4364
class OWCreateClass(widget.OWWidget):
@@ -55,9 +76,11 @@ class OWCreateClass(widget.OWWidget):
5576
settingsHandler = DomainContextHandler()
5677
attribute = ContextSetting(None)
5778
rules = ContextSetting({})
79+
match_beginning = ContextSetting(False)
80+
case_sensitive = ContextSetting(False)
5881

59-
TRANSFORMERS = {StringVariable: ClassFromStringSubstring,
60-
DiscreteVariable: ClassFromDiscreteSubstring}
82+
TRANSFORMERS = {StringVariable: ValueFromStringSubstring,
83+
DiscreteVariable: ValueFromDiscreteSubstring}
6184

6285
class Warning(widget.OWWidget.Warning):
6386
no_nonnumeric_vars = Msg("Data contains only numeric variables.")
@@ -68,17 +91,19 @@ def __init__(self):
6891
self.line_edits = []
6992
self.remove_buttons = []
7093
self.counts = []
94+
self.match_counts = []
7195
self.setSizePolicy(QSizePolicy.Preferred, QSizePolicy.Maximum)
7296

73-
box = gui.hBox(self.controlArea)
97+
patternbox = gui.vBox(self.controlArea, box="Patterns")
98+
box = gui.hBox(patternbox)
7499
gui.widgetLabel(box, "Class from column: ", addSpace=12)
75100
gui.comboBox(
76101
box, self, "attribute", callback=self.update_rules,
77102
model=DomainModel(valid_types=(StringVariable, DiscreteVariable)),
78103
sizePolicy=(QSizePolicy.MinimumExpanding, QSizePolicy.Fixed))
79104

80105
self.rules_box = rules_box = QGridLayout()
81-
self.controlArea.layout().addLayout(self.rules_box)
106+
patternbox.layout().addLayout(self.rules_box)
82107
self.add_button = gui.button(None, self, "+", flat=True,
83108
callback=self.add_row,
84109
minimumSize=QSize(12, 20))
@@ -92,6 +117,14 @@ def __init__(self):
92117
rules_box.addWidget(QLabel("#Instances"), 0, 3, 1, 2)
93118
self.update_rules()
94119

120+
optionsbox = gui.vBox(self.controlArea, box=True)
121+
gui.checkBox(
122+
optionsbox, self, "match_beginning", "Match only at the beginning",
123+
callback=self.options_changed)
124+
gui.checkBox(
125+
optionsbox, self, "case_sensitive", "Case sensitive",
126+
callback=self.options_changed)
127+
95128
box = gui.hBox(self.controlArea)
96129
gui.rubber(box)
97130
gui.button(box, self, "Apply", autoDefault=False, callback=self.apply)
@@ -127,6 +160,9 @@ def update_rules(self):
127160
self.rules_to_edits()
128161
self.update_counts()
129162

163+
def options_changed(self):
164+
self.update_counts()
165+
130166
def adjust_n_rule_rows(self):
131167
def _add_line():
132168
self.line_edits.append([])
@@ -191,61 +227,81 @@ def sync_edit(self, text):
191227
self.update_counts()
192228

193229
def update_counts(self):
194-
def _set_labels(labels, matching, total_matching):
195-
n_matched = int(np.sum(matching))
196-
n_before = int(np.sum(total_matching)) - n_matched
197-
labels[0].setText("{}".format(n_matched))
198-
if n_before:
199-
labels[1].setText("+ {}".format(n_before))
200-
201-
def _string_counts(data):
230+
def _matcher(strings, pattern):
231+
if not self.case_sensitive:
232+
pattern = pattern.lower()
233+
indices = np.char.find(strings, pattern)
234+
return indices == 0 if self.match_beginning else indices != -1
235+
236+
def _lower_if_needed(strings):
237+
return strings if self.case_sensitive else np.char.lower(strings)
238+
239+
def _string_counts():
240+
nonlocal data
202241
data = data.astype(str)
203242
data = data[~np.char.equal(data, "")]
243+
data = _lower_if_needed(data)
204244
remaining = np.array(data)
205-
for labels, (_, pattern) in zip(self.counts, self.active_rules):
206-
matching = np.char.find(remaining, pattern) != -1
207-
total_matching = np.char.find(data, pattern) != -1
208-
_set_labels(labels, matching, total_matching)
245+
for _, pattern in self.active_rules:
246+
matching = _matcher(remaining, pattern)
247+
total_matching = _matcher(data, pattern)
248+
yield matching, total_matching
209249
remaining = remaining[~matching]
210250
if len(remaining) == 0:
211251
break
212252

213-
def _discrete_counts(data):
253+
def _discrete_counts():
214254
attr_vals = np.array(attr.values)
255+
attr_vals = _lower_if_needed(attr_vals)
215256
bins = bincount(data, max_val=len(attr.values) - 1)[0]
216257
remaining = np.array(bins)
217-
for labels, (_, pattern) in zip(self.counts, self.active_rules):
218-
matching = np.char.find(attr_vals, pattern) != -1
219-
_set_labels(labels, remaining[matching], bins[matching])
258+
for _, pattern in self.active_rules:
259+
matching = _matcher(attr_vals, pattern)
260+
yield remaining[matching], bins[matching]
220261
remaining[matching] = 0
221262
if not np.any(remaining):
222263
break
223264

224-
for labels in self.counts:
225-
for label in labels:
226-
label.setText("")
265+
def _clear_labels():
266+
for lab_matched, lab_total in self.counts:
267+
lab_matched.setText("")
268+
lab_total.setText("")
269+
270+
def _set_labels():
271+
for (n_matched, n_total), (lab_matched, lab_total) in \
272+
zip(self.match_counts, self.counts):
273+
n_before = n_total - n_matched
274+
lab_matched.setText("{}".format(n_matched))
275+
if n_before:
276+
lab_total.setText("+ {}".format(n_before))
277+
278+
_clear_labels()
227279
attr = self.attribute
228280
if attr is None:
229281
return
282+
counters = {StringVariable: _string_counts,
283+
DiscreteVariable: _discrete_counts}
230284
data = self.data.get_column_view(attr)[0]
231-
if isinstance(attr, StringVariable):
232-
_string_counts(data)
233-
else:
234-
_discrete_counts(data)
285+
self.match_counts = [[int(np.sum(x)) for x in matches]
286+
for matches in counters[type(attr)]()]
287+
_set_labels()
235288

236289
def apply(self):
237290
if not self.attribute or not self.active_rules:
238291
self.send("Data", None)
239292
return
240293
domain = self.data.domain
294+
# Transposition + stripping
241295
names, patterns = \
242296
zip(*((name.strip(), pattern)
243297
for name, pattern in self.active_rules if name.strip()))
244298
transformer = self.TRANSFORMERS[type(self.attribute)]
299+
compute_value = transformer(
300+
self.attribute, patterns, self.case_sensitive, self.match_beginning)
245301
new_class = DiscreteVariable(
246-
"class", names, compute_value=transformer(self.attribute, patterns))
247-
new_domain = Domain(domain.attributes, new_class,
248-
domain.metas + domain.class_vars)
302+
"class", names, compute_value=compute_value)
303+
new_domain = Domain(
304+
domain.attributes, new_class, domain.metas + domain.class_vars)
249305
new_data = Table(new_domain, self.data)
250306
self.send("Data", new_data)
251307

Orange/widgets/data/tests/test_owcreateclass.py

Lines changed: 104 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from Orange.data import Table, StringVariable, DiscreteVariable
99
from Orange.widgets.data.owcreateclass import (
1010
OWCreateClass,
11-
map_by_substring, ClassFromStringSubstring, ClassFromDiscreteSubstring)
11+
map_by_substring, ValueFromStringSubstring, ValueFromDiscreteSubstring)
1212
from Orange.widgets.tests.base import WidgetTest
1313

1414

@@ -19,34 +19,108 @@ def setUpClass(cls):
1919
cls.arr = np.array(["abcd", "aa", "bcd", "rabc", "x"])
2020

2121
def test_map_by_substring(self):
22-
np.testing.assert_equal(map_by_substring(self.arr, self.patterns),
23-
[0, 1, 2, 0, 3])
24-
np.testing.assert_equal(map_by_substring(self.arr, ["", ""]), 0)
25-
self.assertTrue(np.all(np.isnan(map_by_substring(self.arr, []))))
26-
27-
def test_class_from_string_substring(self):
28-
trans = ClassFromStringSubstring(StringVariable(), self.patterns)
22+
np.testing.assert_equal(
23+
map_by_substring(self.arr,
24+
["abc", "a", "bc", ""],
25+
case_sensitive=True, at_beginning=False),
26+
[0, 1, 2, 0, 3])
27+
np.testing.assert_equal(
28+
map_by_substring(self.arr,
29+
["abc", "a", "Bc", ""],
30+
case_sensitive=True, at_beginning=False),
31+
[0, 1, 3, 0, 3])
32+
np.testing.assert_equal(
33+
map_by_substring(self.arr,
34+
["abc", "a", "Bc", ""],
35+
case_sensitive=False, at_beginning=False),
36+
[0, 1, 2, 0, 3])
37+
np.testing.assert_equal(
38+
map_by_substring(self.arr,
39+
["abc", "a", "bc", ""],
40+
case_sensitive=False, at_beginning=True),
41+
[0, 1, 2, 3, 3])
42+
np.testing.assert_equal(
43+
map_by_substring(self.arr, ["", ""], False, False),
44+
0)
45+
self.assertTrue(np.all(np.isnan(
46+
map_by_substring(self.arr, [], False, False))))
47+
48+
def test_value_from_string_substring(self):
49+
trans = ValueFromStringSubstring(StringVariable(), self.patterns)
2950
arr2 = np.hstack((self.arr.astype(object), [None]))
3051

3152
with patch('Orange.widgets.data.owcreateclass.map_by_substring') as mbs:
3253
trans.transform(self.arr)
33-
arg1, arg2 = mbs.call_args[0]
34-
np.testing.assert_equal(arg1, self.arr)
35-
self.assertEqual(arg2, self.patterns)
54+
a, patterns, case_sensitive, match_beginning = mbs.call_args[0]
55+
np.testing.assert_equal(a, self.arr)
56+
self.assertEqual(patterns, self.patterns)
57+
self.assertFalse(case_sensitive)
58+
self.assertFalse(match_beginning)
3659

3760
trans.transform(arr2)
38-
arg1, arg2 = mbs.call_args[0]
39-
np.testing.assert_equal(arg1,
61+
a, patterns, *_ = mbs.call_args[0]
62+
np.testing.assert_equal(a,
4063
np.hstack((self.arr.astype(str), "")))
4164

4265
np.testing.assert_equal(trans.transform(arr2),
4366
[0, 1, 2, 0, 3, np.nan])
4467

45-
def test_class_from_discrete_substring(self):
46-
trans = ClassFromDiscreteSubstring(
68+
def test_value_string_substring_flags(self):
69+
trans = ValueFromStringSubstring(StringVariable(), self.patterns)
70+
with patch('Orange.widgets.data.owcreateclass.map_by_substring') as mbs:
71+
trans.case_sensitive = True
72+
trans.transform(self.arr)
73+
case_sensitive, match_beginning = mbs.call_args[0][-2:]
74+
self.assertTrue(case_sensitive)
75+
self.assertFalse(match_beginning)
76+
77+
trans.case_sensitive = False
78+
trans.match_beginning = True
79+
trans.transform(self.arr)
80+
case_sensitive, match_beginning = mbs.call_args[0][-2:]
81+
self.assertFalse(case_sensitive)
82+
self.assertTrue(match_beginning)
83+
84+
def test_value_from_discrete_substring(self):
85+
trans = ValueFromDiscreteSubstring(
4786
DiscreteVariable(values=self.arr), self.patterns)
4887
np.testing.assert_equal(trans.lookup_table, [0, 1, 2, 0, 3])
4988

89+
def test_value_from_discrete_substring_flags(self):
90+
trans = ValueFromDiscreteSubstring(
91+
DiscreteVariable(values=self.arr), self.patterns)
92+
with patch('Orange.widgets.data.owcreateclass.map_by_substring') as mbs:
93+
trans.case_sensitive = True
94+
a, patterns, case_sensitive, match_beginning = mbs.call_args[0]
95+
np.testing.assert_equal(a, self.arr)
96+
self.assertEqual(patterns, self.patterns)
97+
self.assertTrue(case_sensitive)
98+
self.assertFalse(match_beginning)
99+
100+
trans.case_sensitive = False
101+
trans.match_beginning = True
102+
a, patterns, case_sensitive, match_beginning = mbs.call_args[0]
103+
np.testing.assert_equal(a, self.arr)
104+
self.assertEqual(patterns, self.patterns)
105+
self.assertFalse(case_sensitive)
106+
self.assertTrue(match_beginning)
107+
108+
arr2 = self.arr[::-1]
109+
trans.variable = DiscreteVariable(values=arr2)
110+
a, patterns, case_sensitive, match_beginning = mbs.call_args[0]
111+
np.testing.assert_equal(a, arr2)
112+
self.assertEqual(patterns, self.patterns)
113+
self.assertFalse(case_sensitive)
114+
self.assertTrue(match_beginning)
115+
116+
patt2 = self.patterns[::-1]
117+
trans.patterns = patt2
118+
a, patterns, case_sensitive, match_beginning = mbs.call_args[0]
119+
np.testing.assert_equal(a, arr2)
120+
self.assertEqual(patterns, patt2)
121+
self.assertFalse(case_sensitive)
122+
self.assertTrue(match_beginning)
123+
50124

51125
class TestOWCreateClass(WidgetTest):
52126
def setUp(self):
@@ -249,6 +323,21 @@ def test_add_remove_lines(self):
249323
outdata = self.get_output("Data")
250324
self.assertIsNone(outdata)
251325

326+
def test_options(self):
327+
def _transformer_flags():
328+
widget.apply()
329+
outdata = self.get_output("Data")
330+
transformer = outdata.domain.class_var.compute_value
331+
return transformer.case_sensitive, transformer.match_beginning
332+
333+
widget = self.widget
334+
self.send_signal("Data", self.heart)
335+
self.assertEqual(_transformer_flags(), (False, False))
336+
widget.controls.case_sensitive.click()
337+
self.assertEqual(_transformer_flags(), (True, False))
338+
widget.controls.case_sensitive.click()
339+
widget.controls.match_beginning.click()
340+
self.assertEqual(_transformer_flags(), (False, True))
252341

253342
if __name__ == "__main__":
254343
unittest.main()

0 commit comments

Comments
 (0)