Skip to content

Commit 091e88c

Browse files
authored
Merge pull request #6998 from janezd/create-class-re
Create Class: Add regular expressions
2 parents 083c53d + 9bd18c7 commit 091e88c

File tree

3 files changed

+290
-80
lines changed

3 files changed

+290
-80
lines changed

Orange/widgets/data/owcreateclass.py

Lines changed: 148 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Widget for creating classes from non-numeric attribute by substrings"""
22
import re
33
from itertools import count
4+
from typing import Optional, Sequence
45

56
import numpy as np
67

@@ -19,39 +20,71 @@
1920
from Orange.widgets.widget import Msg, Input, Output
2021

2122

22-
def map_by_substring(
        a: np.ndarray,
        patterns: list[str],
        case_sensitive: bool, match_beginning: bool, regular_expressions: bool,
        map_values: Optional[Sequence[int]] = None) -> np.ndarray:
    """
    Map values in a using a list of patterns. The patterns are considered in
    order of appearance: if several patterns match the same string, the one
    that appears first in `patterns` determines the result.

    Flags `match_beginning` and `regular_expressions` are incompatible.

    Args:
        a (np.array): input array of `dtype` `str`; any sequence of strings
            (including an empty one) is accepted and converted
        patterns (list of str): list of strings
        case_sensitive (bool): case sensitive match
        match_beginning (bool): match only at the beginning of the string
        regular_expressions (bool): use regular expressions
        map_values (list of int, optional):
            list of len(patterns); return values for each pattern

    Returns:
        np.array of floats representing indices of matched patterns (or the
        corresponding `map_values`); `nan` where no pattern matches
    """
    assert not (regular_expressions and match_beginning)
    # An empty sequence would otherwise become a float array, on which the
    # np.char functions below raise a TypeError.
    a = np.asarray(a, dtype=str)
    if map_values is None:
        map_values = np.arange(len(patterns))
    else:
        map_values = np.array(map_values, dtype=int)
    res = np.full(len(a), np.nan)
    re_flags = 0
    if regular_expressions:
        # Case is handled by the regex engine, so strings stay untouched
        if not case_sensitive:
            re_flags = re.IGNORECASE
    elif not case_sensitive:
        a = np.char.lower(a)
        patterns = [pattern.lower() for pattern in patterns]
    # Go through patterns in reverse, so that earlier patterns overwrite
    # the values set by the later ones.
    for val_idx, pattern in reversed(list(enumerate(patterns))):
        # Note that similar code repeats in update_counts. Any changes here
        # should be reflected there.
        if regular_expressions:
            re_pattern = re.compile(pattern, re_flags)
            matches = np.array([bool(re_pattern.search(s)) for s in a],
                               dtype=bool)
        else:
            indices = np.char.find(a, pattern)
            matches = indices == 0 if match_beginning else indices != -1
        res[matches] = map_values[val_idx]
    return res
5268

5369

54-
class ValueFromStringSubstring(Transformation):
70+
class _EqHashMixin:
    """
    Mixin with `__eq__` and `__hash__` shared by pattern-based transformations.

    Two instances are equal if the base class considers them equal and they
    have the same patterns, flags and `map_values`. The class that uses the
    mixin must provide attributes `variable`, `patterns`, `case_sensitive`,
    `match_beginning` and `map_values`; `regular_expressions` defaults to
    `False` when missing.
    """
    # NOTE(review): presumably instances restored from pickles created before
    # `regular_expressions` was introduced lack the attribute - confirm.
    # It is read with a default so that comparing or hashing does not fail.

    @staticmethod
    def _same_map_values(first, second):
        """Compare two `map_values`, each of which may be None or array-like."""
        if first is None or second is None:
            return first is None and second is None
        # Unlike `np.all(first == second)`, this returns a plain bool and
        # gives False (instead of failing) for sequences of different length
        return np.array_equal(first, second)

    def __eq__(self, other):
        base_eq = super().__eq__(other)
        # Pass on NotImplemented or False from the base class as it is;
        # `other` may not even have the attributes compared below
        if base_eq is NotImplemented or not base_eq:
            return base_eq
        return self.patterns == other.patterns \
            and self.case_sensitive == other.case_sensitive \
            and self.match_beginning == other.match_beginning \
            and getattr(self, "regular_expressions", False) \
            == getattr(other, "regular_expressions", False) \
            and self._same_map_values(self.map_values, other.map_values)

    def __hash__(self):
        return hash((type(self), self.variable,
                     tuple(self.patterns),
                     self.case_sensitive, self.match_beginning,
                     getattr(self, "regular_expressions", False),
                     None if self.map_values is None else tuple(self.map_values)
                     ))
86+
87+
class ValueFromStringSubstring(_EqHashMixin, Transformation):
5588
"""
5689
Transformation that computes a discrete variable from a string variable by
5790
pattern matching.
@@ -67,15 +100,28 @@ class ValueFromStringSubstring(Transformation):
67100
sensitive
68101
match_beginning (bool, optional): if set to `True`, the pattern must
69102
appear at the beginning of the string
103+
map_values (list of int, optional): return values for each pattern
104+
regular_expressions (bool, optional): if set to `True`, the patterns are
treated as regular expressions
70105
"""
71-
def __init__(self, variable, patterns,
72-
case_sensitive=False, match_beginning=False, map_values=None):
106+
# regular_expressions was added later and at the end (instead of with other
107+
# flags) for compatibility with older existing pickles
108+
def __init__(
109+
self,
110+
variable: StringVariable,
111+
patterns: list[str],
112+
case_sensitive: bool = False,
113+
match_beginning: bool = False,
114+
map_values: Optional[Sequence[int]] = None,
115+
regular_expressions: bool = False):
73116
super().__init__(variable)
74117
self.patterns = patterns
75118
self.case_sensitive = case_sensitive
76119
self.match_beginning = match_beginning
120+
self.regular_expressions = regular_expressions
77121
self.map_values = map_values
78122

123+
InheritEq = True
124+
79125
def transform(self, c):
80126
"""
81127
Transform the given data.
@@ -90,26 +136,14 @@ def transform(self, c):
90136
c = c.astype(str)
91137
c[nans] = ""
92138
res = map_by_substring(
93-
c, self.patterns, self.case_sensitive, self.match_beginning,
139+
c, self.patterns,
140+
self.case_sensitive, self.match_beginning, self.regular_expressions,
94141
self.map_values)
95142
res[nans] = np.nan
96143
return res
97144

98-
def __eq__(self, other):
99-
return super().__eq__(other) \
100-
and self.patterns == other.patterns \
101-
and self.case_sensitive == other.case_sensitive \
102-
and self.match_beginning == other.match_beginning \
103-
and self.map_values == other.map_values
104145

105-
def __hash__(self):
106-
return hash((type(self), self.variable,
107-
tuple(self.patterns),
108-
self.case_sensitive, self.match_beginning,
109-
self.map_values))
110-
111-
112-
class ValueFromDiscreteSubstring(Lookup):
146+
class ValueFromDiscreteSubstring(_EqHashMixin, Lookup):
113147
"""
114148
Transformation that computes a discrete variable from discrete variable by
115149
pattern matching.
@@ -126,16 +160,29 @@ class ValueFromDiscreteSubstring(Lookup):
126160
sensitive
127161
match_beginning (bool, optional): if set to `True`, the pattern must
128162
appear at the beginning of the string
163+
map_values (list of int, optional): return values for each pattern
164+
regular_expressions (bool, optional): if set to `True`, the patterns are
treated as regular expressions
165+
129166
"""
130-
def __init__(self, variable, patterns,
131-
case_sensitive=False, match_beginning=False,
132-
map_values=None):
167+
# regular_expressions was added later and at the end (instead of with other
168+
# flags) for compatibility with older existing pickles
169+
def __init__(
170+
self,
171+
variable: DiscreteVariable,
172+
patterns: list[str],
173+
case_sensitive: bool = False,
174+
match_beginning: bool = False,
175+
map_values: Optional[Sequence[int]] = None,
176+
regular_expressions: bool = False):
133177
super().__init__(variable, [])
134178
self.case_sensitive = case_sensitive
135179
self.match_beginning = match_beginning
136180
self.map_values = map_values
181+
self.regular_expressions = regular_expressions
137182
self.patterns = patterns # Finally triggers computation of the lookup
138183

184+
InheritEq = True
185+
139186
def __setattr__(self, key, value):
140187
"""__setattr__ is overloaded to recompute the lookup table when the
141188
patterns, the original attribute or the flags change."""
@@ -145,10 +192,10 @@ def __setattr__(self, key, value):
145192
"variable", "map_values"):
146193
self.lookup_table = map_by_substring(
147194
self.variable.values, self.patterns,
148-
self.case_sensitive, self.match_beginning, self.map_values)
149-
195+
self.case_sensitive, self.match_beginning,
196+
self.regular_expressions, self.map_values)
150197

151-
def unique_in_order_mapping(a):
198+
def unique_in_order_mapping(a: Sequence[str]) -> tuple[list[str], list[int]]:
152199
""" Return
153200
- unique elements of the input list (in the order of appearance)
154201
- indices of the input list onto the returned uniques
@@ -187,6 +234,7 @@ class Outputs:
187234
rules = ContextSetting({})
188235
match_beginning = ContextSetting(False)
189236
case_sensitive = ContextSetting(False)
237+
regular_expressions = ContextSetting(False)
190238

191239
TRANSFORMERS = {StringVariable: ValueFromStringSubstring,
192240
DiscreteVariable: ValueFromDiscreteSubstring}
@@ -202,6 +250,7 @@ class Warning(widget.OWWidget.Warning):
202250
class Error(widget.OWWidget.Error):
203251
class_name_duplicated = Msg("Class name duplicated.")
204252
class_name_empty = Msg("Class name should not be empty.")
253+
invalid_regular_expression = Msg("Invalid regular expression: {}")
205254

206255
def __init__(self):
207256
super().__init__()
@@ -252,9 +301,9 @@ def __init__(self):
252301
rules_box.addWidget(QLabel("Count"), 0, 3, 1, 2)
253302
self.update_rules()
254303

255-
widget = QWidget(patternbox)
256-
widget.setLayout(rules_box)
257-
patternbox.layout().addWidget(widget)
304+
widg = QWidget(patternbox)
305+
widg.setLayout(rules_box)
306+
patternbox.layout().addWidget(widg)
258307

259308
box = gui.hBox(patternbox)
260309
gui.rubber(box)
@@ -264,8 +313,12 @@ def __init__(self):
264313
QSizePolicy.Maximum))
265314

266315
optionsbox = gui.vBox(self.controlArea, "Options")
316+
gui.checkBox(
317+
optionsbox, self, "regular_expressions", "Use regular expressions",
318+
callback=self.options_changed)
267319
gui.checkBox(
268320
optionsbox, self, "match_beginning", "Match only at the beginning",
321+
stateWhenDisabled=False,
269322
callback=self.options_changed)
270323
gui.checkBox(
271324
optionsbox, self, "case_sensitive", "Case sensitive",
@@ -322,6 +375,7 @@ def update_rules(self):
322375
# TODO: Indicator that changes need to be applied
323376

324377
def options_changed(self):
    """
    React to a change of matching options.

    The control for `match_beginning` is enabled only when regular
    expressions are not used; after that, the counts of matches are
    recomputed.
    """
    uses_regex = self.regular_expressions
    self.controls.match_beginning.setEnabled(not uses_regex)
    self.update_counts()
326380

327381
def adjust_n_rule_rows(self):
@@ -344,8 +398,8 @@ def _add_line():
344398
self.rules_box.addWidget(button, n_lines, 0)
345399
self.counts.append([])
346400
for coli, kwargs in enumerate(
347-
(dict(),
348-
dict(styleSheet="color: gray"))):
401+
({},
402+
{"styleSheet": "color: gray"})):
349403
label = QLabel(alignment=Qt.AlignCenter, **kwargs)
350404
self.counts[-1].append(label)
351405
self.rules_box.addWidget(label, n_lines, 3 + coli)
@@ -401,23 +455,48 @@ def class_labels(self):
401455
if re.match("^C\\d+", label)),
402456
default=0)
403457
class_count = count(largest_c + 1)
404-
return [label_edit.text() or "C{}".format(next(class_count))
458+
return [label_edit.text() or f"C{next(class_count)}"
405459
for label_edit, _ in self.line_edits]
406460

461+
def invalid_patterns(self):
462+
if not self.regular_expressions:
463+
return None
464+
for _, pattern in self.active_rules:
465+
try:
466+
re.compile(pattern)
467+
except re.error:
468+
return pattern
469+
return None
470+
407471
def update_counts(self):
408472
"""Recompute and update the counts of matches."""
409-
def _matcher(strings, pattern):
410-
"""Return indices of strings into patterns; consider case
411-
sensitivity and matching at the beginning. The given strings are
412-
assumed to be in lower case if match is case insensitive. Patterns
413-
are fixed on the fly."""
414-
if not self.case_sensitive:
415-
pattern = pattern.lower()
416-
indices = np.char.find(strings, pattern.strip())
417-
return indices == 0 if self.match_beginning else indices != -1
418-
419-
def _lower_if_needed(strings):
420-
return strings if self.case_sensitive else np.char.lower(strings)
473+
if self.regular_expressions:
474+
def _matcher(strings, pattern):
475+
# Note that similar code repeats in map_by_substring.
476+
# Any changes here should be reflected there.
477+
re_pattern = re.compile(
478+
pattern,
479+
re.IGNORECASE if not self.case_sensitive else 0)
480+
return np.array([bool(re_pattern.search(s)) for s in strings],
481+
dtype=bool)
482+
483+
def _lower_if_needed(strings):
484+
return strings
485+
else:
486+
def _matcher(strings, pattern):
487+
"""Return indices of strings into patterns; consider case
488+
sensitivity and matching at the beginning. The given strings are
489+
assumed to be in lower case if match is case insensitive. Patterns
490+
are fixed on the fly."""
491+
# Note that similar code repeats in map_by_substring.
492+
# Any changes here should be reflected there.
493+
if not self.case_sensitive:
494+
pattern = pattern.lower()
495+
indices = np.char.find(strings, pattern.strip())
496+
return indices == 0 if self.match_beginning else indices != -1
497+
498+
def _lower_if_needed(strings):
499+
return strings if self.case_sensitive else np.char.lower(strings)
421500

422501
def _string_counts():
423502
"""
@@ -469,9 +548,9 @@ def _set_labels():
469548
for (n_matched, n_total), (lab_matched, lab_total), (lab, patt) in \
470549
zip(self.match_counts, self.counts, self.active_rules):
471550
n_before = n_total - n_matched
472-
lab_matched.setText("{}".format(n_matched))
551+
lab_matched.setText(f"{n_matched}")
473552
if n_before and (lab or patt):
474-
lab_total.setText("+ {}".format(n_before))
553+
lab_total.setText(f"+ {n_before}")
475554
if n_matched:
476555
tip = f"{n_before} o" \
477556
f"f {n_total} matching {pl(n_total, 'instance')} " \
@@ -496,6 +575,11 @@ def _set_placeholders():
496575
lab_edit.setPlaceholderText(label)
497576

498577
_clear_labels()
578+
if (invalid := self.invalid_patterns()) is not None:
579+
self.Error.invalid_regular_expression(invalid)
580+
return
581+
self.Error.invalid_regular_expression.clear()
582+
499583
attr = self.attribute
500584
if attr is None:
501585
return
@@ -510,6 +594,11 @@ def _set_placeholders():
510594
def apply(self):
511595
"""Output the transformed data."""
512596
self.Error.clear()
597+
if (invalid := self.invalid_patterns()) is not None:
598+
self.Error.invalid_regular_expression(invalid)
599+
self.Outputs.data.send(None)
600+
return
601+
513602
self.class_name = self.class_name.strip()
514603
if not self.attribute:
515604
self.Outputs.data.send(None)
@@ -541,19 +630,21 @@ def _create_variable(self):
541630
if valid)
542631
transformer = self.TRANSFORMERS[type(self.attribute)]
543632

544-
# join patters with the same names
633+
# join patterns with the same names
545634
names, map_values = unique_in_order_mapping(names)
546635
names = tuple(str(a) for a in names)
547636
map_values = tuple(map_values)
548637

549638
var_key = (self.attribute, self.class_name, names,
550-
patterns, self.case_sensitive, self.match_beginning, map_values)
639+
patterns, self.case_sensitive, self.match_beginning,
640+
self.regular_expressions, map_values)
551641
if var_key in self.cached_variables:
552642
return self.cached_variables[var_key]
553643

554644
compute_value = transformer(
555-
self.attribute, patterns, self.case_sensitive, self.match_beginning,
556-
map_values)
645+
self.attribute, patterns, self.case_sensitive,
646+
self.match_beginning and not self.regular_expressions,
647+
map_values, self.regular_expressions)
557648
new_var = DiscreteVariable(
558649
self.class_name, names, compute_value=compute_value)
559650
self.cached_variables[var_key] = new_var
@@ -597,10 +688,10 @@ def _count_part():
597688
for (n_matched, n_total), class_name, (lab, patt) in \
598689
zip(self.match_counts, names, self.active_rules):
599690
if lab or patt or n_total:
600-
output += "<li>{}; {}</li>".format(_cond_part(), _count_part())
691+
output += f"<li>{_cond_part()}; {_count_part()}</li>"
601692
if output:
602693
self.report_items("Output", [("Class name", self.class_name)])
603-
self.report_raw("<ol>{}</ol>".format(output))
694+
self.report_raw(f"<ol>{output}</ol>")
604695

605696

606697
if __name__ == "__main__": # pragma: no cover

0 commit comments

Comments
 (0)