Skip to content

Commit c2e904d

Browse files
committed
owfeatureconstructor: Make cast functions picklable
1 parent 75400c8 commit c2e904d

File tree

2 files changed

+76
-19
lines changed

2 files changed

+76
-19
lines changed

Orange/widgets/data/owfeatureconstructor.py

Lines changed: 51 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from traceback import format_exception_only
2222
from collections import namedtuple, OrderedDict
2323
from itertools import chain, count, starmap
24-
from typing import List, Dict, Any
24+
from typing import List, Dict, Any, Mapping
2525

2626
import numpy as np
2727

@@ -31,9 +31,12 @@
3131
QPushButton, QMenu, QListView, QFrame, QLabel, QMessageBox)
3232
from AnyQt.QtGui import QKeySequence
3333
from AnyQt.QtCore import Qt, pyqtSignal as Signal, pyqtProperty as Property
34+
3435
from orangewidget.utils.combobox import ComboBoxSearch
3536

3637
import Orange
38+
from Orange.preprocess.transformation import MappingTransform
39+
from Orange.util import frompyfunc
3740
from Orange.data import Variable, Table, Value, Instance
3841
from Orange.data.util import get_unique_names
3942
from Orange.widgets import gui
@@ -63,7 +66,6 @@
6366

6467
StringDescriptor = namedtuple("StringDescriptor", ["name", "expression"])
6568

66-
#warningIcon = gui.createAttributePixmap('!', QColor((202, 0, 32)))
6769

6870
def make_variable(descriptor, compute_value):
6971
if isinstance(descriptor, ContinuousDescriptor):
@@ -1019,7 +1021,6 @@ def bind_variable(descriptor, env, data, use_values):
10191021

10201022
values = {}
10211023
cast = None
1022-
nan = float("nan")
10231024

10241025
if isinstance(descriptor, DiscreteDescriptor):
10251026
if not descriptor.values:
@@ -1028,29 +1029,62 @@ def bind_variable(descriptor, env, data, use_values):
10281029
values = sorted({str(x) for x in str_func(data)})
10291030
values = {name: i for i, name in enumerate(values)}
10301031
descriptor = descriptor._replace(values=values)
1031-
1032-
def cast(x): # pylint: disable=function-redefined
1033-
return values.get(x, nan)
1034-
1032+
cast = MappingTransformCast(values)
10351033
else:
10361034
values = [sanitized_name(v) for v in descriptor.values]
10371035
values = {name: i for i, name in enumerate(values)}
10381036

10391037
if isinstance(descriptor, DateTimeDescriptor):
1040-
parse = Orange.data.TimeVariable("_").parse
1041-
1042-
def cast(e): # pylint: disable=function-redefined
1043-
if isinstance(e, (int, float)):
1044-
return e
1045-
if e == "" or e is None:
1046-
return np.nan
1047-
return parse(e)
1038+
cast = DateTimeCast()
10481039

10491040
func = FeatureFunc(descriptor.expression, source_vars, values, cast,
10501041
use_values=use_values)
10511042
return descriptor, func
10521043

10531044

1045+
# A throwaway TimeVariable is created only to borrow its bound ``parse``.
_parse_datetime = Orange.data.TimeVariable("_").parse
# Types whose instances `cast_datetime` returns unchanged.
_cast_datetime_num_types = (int, float)


def cast_datetime(e):
    """
    Convert one expression result into a value for a datetime column.

    An instance of one of `_cast_datetime_num_types` is returned as is,
    an empty string or None gives NaN, and every other value is passed
    to `_parse_datetime`.
    """
    if isinstance(e, _cast_datetime_num_types):
        return e
    is_missing = e is None or e == ""
    return np.nan if is_missing else _parse_datetime(e)


# `cast_datetime` wrapped by `Orange.util.frompyfunc` (one input, one
# output, dtype=float); `DateTimeCast.__call__` feeds it whole sequences.
_cast_datetime = frompyfunc(cast_datetime, 1, 1, dtype=float)
1058+
1059+
1060+
class DateTimeCast:
    """
    Callable object that casts values by delegating to `_cast_datetime`.

    Instances hold no state: every `DateTimeCast` compares equal to every
    other one, and they all hash like the `cast_datetime` function.
    """

    def __call__(self, values):
        # All the work happens in the module-level `_cast_datetime`.
        return _cast_datetime(values)

    def __eq__(self, other):
        # Equality is decided by type alone; there is nothing else to compare.
        return isinstance(other, DateTimeCast)

    def __hash__(self):
        # Must agree with __eq__, so use one constant shared by all instances.
        return hash(cast_datetime)
1069+
1070+
1071+
class MappingTransformCast:
    """
    Callable object that casts values through a `MappingTransform`.

    The transform is built from the given mapping and stored in `t`;
    calling, equality and hashing are all forwarded to it.
    """

    def __init__(self, mapping: Mapping):
        # The transform is constructed with None in place of a variable;
        # this class only calls its `transform` and reads its `mapping`.
        self.t = MappingTransform(None, mapping)

    def __reduce_ex__(self, protocol):
        # Pickle as "call the class again with the stored mapping".
        cls = type(self)
        return cls, (self.t.mapping,)

    def __call__(self, values):
        return self.t.transform(values)

    def __eq__(self, other):
        if not isinstance(other, MappingTransformCast):
            return False
        return self.t == other.t

    def __hash__(self):
        # Kept in step with __eq__, which compares the wrapped transforms.
        return hash(self.t)
1086+
1087+
10541088
def make_lambda(expression, args, env=None):
10551089
# type: (ast.Expression, List[str], Dict[str, Any]) -> types.FunctionType
10561090
"""
@@ -1217,8 +1251,7 @@ def __call_table(self, table):
12171251
else:
12181252
y = list(starmap(f, args))
12191253
if self.cast is not None:
1220-
cast = self.cast
1221-
y = [cast(y_) for y_ in y]
1254+
y = self.cast(y)
12221255
return y
12231256

12241257
def __call_instance(self, instance: Instance):
@@ -1248,7 +1281,7 @@ def extract_column(self, table: Table, var: Variable):
12481281

12491282
def __reduce__(self):
12501283
return type(self), (self.expression, self.args,
1251-
self.extra_env, self.cast)
1284+
self.extra_env, self.cast, self.use_values)
12521285

12531286
def __repr__(self):
12541287
return "{0.__name__}{1!r}".format(*self.__reduce__())

Orange/widgets/data/tests/test_owfeatureconstructor.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ def test_reconstruct(self):
276276

277277
def test_repr(self):
278278
self.assertEqual(repr(FeatureFunc("a + 1", [("a", 2)])),
279-
"FeatureFunc('a + 1', [('a', 2)], {}, None)")
279+
"FeatureFunc('a + 1', [('a', 2)], {}, None, False)")
280280

281281
def test_call(self):
282282
iris = Table("iris")
@@ -524,6 +524,30 @@ def test_report(self):
524524
args = w.report_items.call_args[0][1]
525525
self.assertEqual(list(args), list("abcdefg"))
526526

527+
def test_output_domain_picklable(self):
    """
    The output domain must survive a pickle round trip: every constructed
    variable has to compare equal to, and hash like, its restored copy.
    """
    widget = self.widget
    self.send_signal(widget.Inputs.data, Table("iris")[::5])
    descriptors = [
        ContinuousDescriptor("X1", "max(0, sepal_width - 5)", 2),
        DiscreteDescriptor("D1", "HIGH if sepal_width > 5 else LOW",
                           ("HIGH", "LOW"), False),
        DiscreteDescriptor("D2", "'HIGH' if sepal_length > 5 else 'LOW'",
                           (), False),
        DateTimeDescriptor("T1", "0"),
        DateTimeDescriptor("T2", "'1900-01-01'"),
    ]
    for descriptor in descriptors:
        widget.addFeature(descriptor)
    widget.apply()

    original = self.get_output(widget.Outputs.data).domain
    restored = pickle.loads(pickle.dumps(original))
    for name in ("X1", "D1", "D2", "T1", "T2"):
        before = original[name]
        after = restored[name]
        self.assertEqual(before, after)
        self.assertEqual(hash(before), hash(after))
550+
527551

528552
class TestFeatureEditor(unittest.TestCase):
529553
def test_has_functions(self):

0 commit comments

Comments
 (0)