Skip to content

Commit e9fb5ef

Browse files
VesnaTmarkotoplak
authored andcommitted
Merge pull request biolab#5759 from markotoplak/fix-transformation-upickling
[FIX] Unpickling pre-3.28.0 Transformation
1 parent 874679a commit e9fb5ef

File tree

2 files changed

+31
-3
lines changed

2 files changed

+31
-3
lines changed

Orange/preprocess/transformation.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,28 @@ def __init__(self, variable):
1616
:type variable: int or str or :obj:`~Orange.data.Variable`
1717
"""
1818
self.variable = variable
19+
self._create_cached_target_domain()
1920

21+
def _create_cached_target_domain(self):
22+
""" If the same domain is used everytime this allows better caching of
23+
domain transformations in from_table"""
2024
if self.variable is not None:
2125
if self.variable.is_primitive():
22-
self.need_domain = Domain([self.variable])
26+
self._target_domain = Domain([self.variable])
2327
else:
24-
self.need_domain = Domain([], metas=[self.variable])
28+
self._target_domain = Domain([], metas=[self.variable])
29+
30+
def __getstate__(self):
31+
# Do not pickle the cached domain; rather recreate it after unpickling
32+
state = self.__dict__.copy()
33+
state.pop("_target_domain")
34+
return state
35+
36+
def __setstate__(self, state):
37+
# Ensure that cached target domain is created after unpickling.
38+
# This solves the problem of unpickling old pickled models.
39+
self.__dict__.update(state)
40+
self._create_cached_target_domain()
2541

2642
def __call__(self, data):
2743
"""
@@ -31,7 +47,7 @@ def __call__(self, data):
3147
inst = isinstance(data, Instance)
3248
if inst:
3349
data = Table.from_list(data.domain, [data])
34-
data = data.transform(self.need_domain)
50+
data = data.transform(self._target_domain)
3551
if self.variable.is_primitive():
3652
col = data.X
3753
else:

Orange/tests/test_transformation.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import pickle
12
import unittest
23

34
import numpy as np
@@ -43,6 +44,17 @@ def test_transform_fails(self):
4344
trans = Transformation(self.data.domain[2])
4445
self.assertRaises(NotImplementedError, trans, self.data)
4546

47+
def test_pickling_target_domain(self):
48+
data = self.data
49+
trans = self.TransformationMock(data.domain[2])
50+
self.assertIn("_target_domain", trans.__dict__)
51+
# _target_domain should not be pickled
52+
state = trans.__getstate__()
53+
self.assertNotIn("_target_domain", state)
54+
# _target_domain should be recreated when unpickled
55+
unpickled = pickle.loads(pickle.dumps(trans))
56+
self.assertIn("_target_domain", unpickled.__dict__)
57+
4658

4759
class IdentityTest(unittest.TestCase):
4860
def test_identity(self):

0 commit comments

Comments
 (0)