Skip to content

Commit 67e9629

Browse files
authored
Merge pull request #5975 from ales-erjavec/feature-constructor-opt
[ENH] Feature constructor optimization
2 parents 25ba55a + 92b2864 commit 67e9629

File tree

6 files changed

+141
-54
lines changed

6 files changed

+141
-54
lines changed

Orange/data/variable.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import re
22
import warnings
33
from collections.abc import Iterable
4+
from typing import Sequence
45

56
from datetime import datetime, timedelta, timezone
67
from numbers import Number, Real, Integral
@@ -168,6 +169,44 @@ def __new__(cls, variable, value=Unknown):
168169
self._value = value
169170
return self
170171

172+
@staticmethod
def _as_values_primitive(variable, data) -> Sequence['Value']:
    """
    Build one `Value` per element of `data` for a primitive `variable`.

    Each instance is created through `float.__new__` directly, which
    bypasses `Value.__new__`; only the `variable` attribute is assigned
    on the new instances.
    """
    assert variable.is_primitive()
    # Bind to locals once; both are used on every iteration of the loop.
    make_float = float.__new__
    value_cls = Value
    # Pre-size the output with a shared NaN placeholder; every slot is
    # overwritten below.
    out = [Value(variable, np.nan)] * len(data)
    for pos, raw in enumerate(data):
        item = make_float(value_cls, raw)
        item.variable = variable
        out[pos] = item
    return out
183+
184+
@staticmethod
def _as_values_non_primitive(variable, data) -> Sequence['Value']:
    """
    Build one `Value` per element of `data` for a non-primitive `variable`.

    The float part of each instance is NaN where the element equals
    `variable.Unknown` and `np.finfo(float).min` otherwise; the original
    element is stored in `_value`.
    """
    assert not variable.is_primitive()
    value_cls = Value
    make_float = float.__new__
    # Element-wise comparison on an object array marks the missing entries.
    is_missing = np.array(data, dtype=object) == variable.Unknown
    float_parts = np.full(len(data), np.finfo(float).min)
    float_parts[is_missing] = np.nan
    # Pre-size the output with a shared placeholder; every slot is
    # overwritten below.
    out = [Value(variable, Variable.Unknown)] * len(data)
    for pos, (raw, float_part) in enumerate(zip(data, float_parts)):
        item = make_float(value_cls, float_part)
        item.variable = variable
        item._value = raw
        out[pos] = item
    return out
200+
201+
@staticmethod
def _as_values(variable, data):
    """
    Equivalent to, but faster than, `[Value(variable, v) for v in data]`.

    Dispatches to the primitive or non-primitive builder depending on
    `variable.is_primitive()`.
    """
    build = (Value._as_values_primitive if variable.is_primitive()
             else Value._as_values_non_primitive)
    return build(variable, data)
209+
171210
def __init__(self, _, __=Unknown):
    """Do nothing: the instance is fully constructed in `__new__`."""
    # pylint: disable=super-init-not-called

Orange/tests/test_value.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,13 @@ def test_hash(self):
6464
self.assertTrue(val == v and hash(val) == hash(v))
6565
val = Value(DiscreteVariable("var", ["red", "green", "blue"]), 1)
6666
self.assertRaises(TypeError, hash, val)
67+
68+
def test_as_values(self):
    """`Value._as_values` yields `Value` instances that compare equal to the input."""
    # pylint: disable=protected-access
    numeric = Value._as_values(ContinuousVariable("x"), [0., 1., 2.])
    self.assertIsInstance(numeric[0], Value)
    self.assertEqual(numeric[0], 0)

    strings = Value._as_values(StringVariable("s"), ["a", "b", ""])
    self.assertIsInstance(strings[0], Value)
    self.assertEqual(strings[0], "a")

Orange/widgets/data/oweditdomain.py

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040

4141
from Orange.preprocess.transformation import Transformation, Identity, Lookup
4242
from Orange.widgets import widget, gui, settings
43-
from Orange.widgets.utils import itemmodels
43+
from Orange.widgets.utils import itemmodels, ftry
4444
from Orange.widgets.utils.buttons import FixedSizeButton
4545
from Orange.widgets.utils.itemmodels import signal_blocking
4646
from Orange.widgets.utils.widgetpreview import WidgetPreview
@@ -50,8 +50,6 @@
5050
MArray = np.ma.MaskedArray
5151
DType = Union[np.dtype, type]
5252

53-
A = TypeVar("A") # pylint: disable=invalid-name
54-
B = TypeVar("B") # pylint: disable=invalid-name
5553
V = TypeVar("V", bound=Orange.data.Variable) # pylint: disable=invalid-name
5654
H = TypeVar("H", bound=Hashable) # pylint: disable=invalid-name
5755

@@ -2631,21 +2629,6 @@ def apply_transform_string(var, trs):
26312629
return variable
26322630

26332631

2634-
def ftry(
        func: Callable[..., A],
        error: Union[Type[BaseException], Tuple[Type[BaseException], ...]],
        default: B
) -> Callable[..., Union[A, B]]:
    """
    Wrap `func` so that `default` is returned when it raises `error`.

    Parameters
    ----------
    func : Callable
        The function to wrap.
    error : exception type or tuple of exception types
        The exception(s) to intercept; anything accepted by an `except`
        clause. Exceptions not matching `error` propagate unchanged.
    default
        The value returned in place of an intercepted exception.

    Returns
    -------
    wrapper : Callable
        A function forwarding all positional and keyword arguments
        to `func`.
    """
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except error:
            return default
    return wrapper
2647-
2648-
26492632
class DictMissingConst(dict):
26502633
"""
26512634
`dict` with a constant for `__missing__()` value.

Orange/widgets/data/owfeatureconstructor.py

Lines changed: 66 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from traceback import format_exception_only
2020
from collections import namedtuple, OrderedDict
21-
from itertools import chain, count
21+
from itertools import chain, count, starmap
2222
from typing import List, Dict, Any
2323

2424
import numpy as np
@@ -32,15 +32,19 @@
3232
from orangewidget.utils.combobox import ComboBoxSearch
3333

3434
import Orange
35+
from Orange.data import Variable, Table, Value, Instance
3536
from Orange.data.util import get_unique_names
3637
from Orange.widgets import gui
3738
from Orange.widgets.settings import ContextSetting, DomainContextHandler
38-
from Orange.widgets.utils import itemmodels, vartype
39+
from Orange.widgets.utils import (
40+
itemmodels, vartype, ftry, unique_everseen as unique
41+
)
3942
from Orange.widgets.utils.sql import check_sql_input
4043
from Orange.widgets import report
4144
from Orange.widgets.utils.widgetpreview import WidgetPreview
4245
from Orange.widgets.widget import OWWidget, Msg, Input, Output
4346

47+
4448
FeatureDescriptor = \
4549
namedtuple("FeatureDescriptor", ["name", "expression"])
4650

@@ -729,11 +733,14 @@ def duplicateFeature(self):
729733

730734
@staticmethod
731735
def check_attrs_values(attr, data):
732-
for i in range(len(data)):
733-
for var in attr:
734-
if not math.isnan(data[i, var]) \
735-
and int(data[i, var]) >= len(var.values):
736-
return var.name
736+
for var in attr:
737+
col, _ = data.get_column_view(var)
738+
mask = ~np.isnan(col)
739+
grater_or_equal = np.greater_equal(
740+
col, len(var.values), out=mask, where=mask
741+
)
742+
if grater_or_equal.any():
743+
return var.name
737744
return None
738745

739746
def _validate_descriptors(self, desc):
@@ -1162,25 +1169,59 @@ def __init__(self, expression, args, extra_env=None, cast=None, use_values=False
11621169
self.mask_exceptions = True
11631170
self.use_values = use_values
11641171

1165-
def __call__(self, instance, *_):
1166-
if isinstance(instance, Orange.data.Table):
1167-
return [self(inst) for inst in instance]
1172+
def __call__(self, table, *_):
1173+
if isinstance(table, Table):
1174+
return self.__call_table(table)
11681175
else:
1169-
try:
1170-
args = [str(instance[var]) if var.is_string
1171-
else var.values[int(instance[var])] if var.is_discrete and not self.use_values
1172-
else instance[var]
1173-
for _, var in self.args]
1174-
y = self.func(*args)
1175-
# user's expression can contain arbitrary errors
1176-
# this also covers missing attributes
1177-
except: # pylint: disable=bare-except
1178-
if not self.mask_exceptions:
1179-
raise
1180-
return np.nan
1181-
if self.cast:
1182-
y = self.cast(y)
1183-
return y
1176+
return self.__call_instance(table)
1177+
1178+
def __call_table(self, table):
    """
    Evaluate the expression on every row of `table`.

    One column is extracted per entry of `self.args` and `self.func` is
    applied row-wise. If column extraction raises `ValueError`
    (presumably when a variable is missing from the table -- confirm
    against `get_column_view`) and `self.mask_exceptions` is set, an
    all-NaN array of length `len(table)` is returned instead.
    With `self.mask_exceptions` set, any `Exception` raised for a single
    row yields NaN for that row. `self.cast`, when not None, is applied
    to each result.
    """
    try:
        columns = [self.extract_column(table, var) for _, var in self.args]
    except ValueError:
        if not self.mask_exceptions:
            raise
        return np.full(len(table), np.nan)

    # With no arguments, still call the function once per row.
    rows = zip(*columns) if columns else [()] * len(table)
    func = self.func
    if self.mask_exceptions:
        func = ftry(func, Exception, np.nan)
    results = list(starmap(func, rows))
    cast = self.cast
    if cast is not None:
        results = [cast(res) for res in results]
    return results
1200+
1201+
def __call_instance(self, instance: Instance):
    """
    Evaluate the expression on a single `instance`.

    The instance is wrapped into a one-row table so that the table code
    path can be reused; the sole result is returned.
    """
    x_part = np.array([instance.x])
    y_part = np.array([instance.y])
    metas_part = np.array([instance.metas])
    one_row = Table.from_numpy(instance.domain, x_part, y_part, metas_part)
    return self.__call_table(one_row)[0]
1209+
1210+
def extract_column(self, table: Table, var: Variable):
    """
    Return the column of `var` in `table` as a list of expression arguments.

    The representation depends on the variable type:

    - string variables: `var.str_val` applied to every element;
    - discrete variables (when `self.use_values` is false): the value
      labels from `var.values`, with `None` for non-finite entries;
    - time variables, or any variable when `self.use_values` is true:
      `Value` instances;
    - otherwise: the plain column contents (`data.tolist()`).
    """
    data, _ = table.get_column_view(var)
    if var.is_string:
        return [var.str_val(item) for item in data]
    if var.is_discrete and not self.use_values:
        # The extra trailing `None` is the label used for missing entries.
        labels = np.array([*var.values, None], dtype=object)
        missing = ~np.isfinite(data)
        # Replace non-finite entries *before* the integer cast: casting
        # NaN/inf to int is undefined and raises a RuntimeWarning on
        # recent NumPy versions.
        idx = np.where(missing, 0, data).astype(int)
        idx[missing] = len(labels) - 1
        return labels[idx].tolist()
    if var.is_time or self.use_values:
        # time always needs Values due to str(val) formatting
        return Value._as_values(var, data.tolist())  # pylint: disable=protected-access
    return data.tolist()
11841225

11851226
def __reduce__(self):
11861227
return type(self), (self.expression, self.args,
@@ -1190,15 +1231,5 @@ def __repr__(self):
11901231
return "{0.__name__}{1!r}".format(*self.__reduce__())
11911232

11921233

1193-
def unique(seq):
    """Return the distinct elements of `seq` as a list, in first-seen order."""
    # Dict keys are unique and preserve insertion order.
    return list(dict.fromkeys(seq))
1202-
12031234
if __name__ == "__main__": # pragma: no cover
12041235
WidgetPreview(OWFeatureConstructor).run(Orange.data.Table("iris"))

Orange/widgets/data/tests/test_owfeatureconstructor.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,12 @@ def test_missing_variable(self):
305305
self.assertTrue(np.all(np.isnan(r)))
306306
self.assertTrue(np.isnan(f(data2[0])))
307307

308+
def test_time_str(self):
    """`str()` of a time variable inside an expression gives the formatted date."""
    time_var = TimeVariable("T", have_date=True)
    data = Table.from_numpy(Domain([time_var]), [[0], [0]])
    func = FeatureFunc("str(T)", [("T", data.domain[0])])
    self.assertEqual(func(data), ["1970-01-01", "1970-01-01"])
313+
308314
def test_invalid_expression_variable(self):
309315
iris = Table("iris")
310316
f = FeatureFunc("1 / petal_length",

Orange/widgets/utils/__init__.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import sys
44
from collections import deque
55
from typing import (
6-
TypeVar, Callable, Any, Iterable, Optional, Hashable, Type, Union
6+
TypeVar, Callable, Any, Iterable, Optional, Hashable, Type, Union, Tuple
77
)
88
from xml.sax.saxutils import escape
99

@@ -92,6 +92,8 @@ def qname(type_: type) -> str:
9292

9393
_T1 = TypeVar("_T1") # pylint: disable=invalid-name
9494
_E = TypeVar("_E", bound=enum.Enum) # pylint: disable=invalid-name
95+
_A = TypeVar("_A") # pylint: disable=invalid-name
96+
_B = TypeVar("_B") # pylint: disable=invalid-name
9597

9698

9799
def apply_all(seq, op):
@@ -101,6 +103,22 @@ def apply_all(seq, op):
101103
deque(map(op, seq), maxlen=0)
102104

103105

106+
def ftry(
        func: Callable[..., _A],
        error: Union[Type[BaseException], Tuple[Type[BaseException], ...]],
        default: _B
) -> Callable[..., Union[_A, _B]]:
    """
    Wrap `func` so that `default` is returned when it raises `error`.

    Parameters
    ----------
    func : Callable
        The function to wrap.
    error : exception type or tuple of exception types
        The exception(s) to intercept; anything accepted by an `except`
        clause. Exceptions not matching `error` propagate unchanged.
    default
        The value returned in place of an intercepted exception.

    Returns
    -------
    wrapper : Callable
        A function forwarding all positional and keyword arguments
        to `func`.
    """
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except error:
            return default
    return wrapper
120+
121+
104122
def unique_everseen(iterable, key=None):
105123
# type: (Iterable[_T1], Optional[Callable[[_T1], Hashable]]) -> Iterable[_T1]
106124
"""

0 commit comments

Comments
 (0)