Skip to content

Commit f899e95

Browse files
authored
Merge pull request #4922 from janezd/transformations-eq
[FIX] Edit Domain (and perhaps other widgets) could cause missing data later in the workflow
2 parents 6193d32 + ba3e1c4 commit f899e95

File tree

12 files changed

+384
-33
lines changed

12 files changed

+384
-33
lines changed

Orange/preprocess/discretize.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,12 @@ def fmt(val):
8383
dvar.to_sql = to_sql
8484
return dvar
8585

86+
def __eq__(self, other):
87+
return super().__eq__(other) and self.points == other.points
88+
89+
def __hash__(self):
90+
return hash((type(self), self.variable, tuple(self.points)))
91+
8692

8793
class BinSql:
8894
def __init__(self, var, points):

Orange/preprocess/impute.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,12 @@ def transform(self, c):
3232
else:
3333
return np.where(np.isnan(c), self.value, c)
3434

35+
def __eq__(self, other):
36+
return super().__eq__(other) and self.value == other.value
37+
38+
def __hash__(self):
39+
return hash((type(self), self.variable, float(self.value)))
40+
3541

3642
class BaseImputeMethod(Reprable):
3743
name = ""
@@ -316,6 +322,12 @@ def transform(self, c):
316322
c[nanindices] = sample
317323
return c
318324

325+
def __eq__(self, other):
326+
return super().__eq__(other) and self.distribution == other.distribution
327+
328+
def __hash__(self):
329+
return hash((type(self), self.variable, self.distribution))
330+
319331

320332
class Random(BaseImputeMethod):
321333
name = "Random values"

Orange/preprocess/tests/test_discretize.py

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@
55
from time import struct_time, mktime
66

77
import numpy as np
8+
9+
from Orange.data import ContinuousVariable
810
from Orange.preprocess.discretize import \
9-
_time_binnings, time_binnings, BinDefinition
11+
_time_binnings, time_binnings, BinDefinition, Discretizer
1012

1113

1214
# pylint: disable=redefined-builtin
@@ -17,12 +19,12 @@ def create(year=1970, month=1, day=1, hour=0, min=0, sec=0):
1719
class TestTimeBinning(unittest.TestCase):
1820
def setUp(self):
1921
self.dates = [mktime(x) for x in
20-
[(1975, 6, 9, 10, 0, 0, 0, 161, 0),
21-
(1975, 6, 9, 10, 50, 0, 0, 161, 0),
22-
(1975, 6, 9, 11, 40, 0, 0, 161, 0),
23-
(1975, 6, 9, 12, 30, 0, 0, 161, 0),
24-
(1975, 6, 9, 13, 20, 0, 0, 161, 0),
25-
(1975, 6, 9, 14, 10, 0, 0, 161, 0)]]
22+
[(1975, 6, 9, 10, 0, 0, 0, 161, 0),
23+
(1975, 6, 9, 10, 50, 0, 0, 161, 0),
24+
(1975, 6, 9, 11, 40, 0, 0, 161, 0),
25+
(1975, 6, 9, 12, 30, 0, 0, 161, 0),
26+
(1975, 6, 9, 13, 20, 0, 0, 161, 0),
27+
(1975, 6, 9, 14, 10, 0, 0, 161, 0)]]
2628

2729
def test_binning(self):
2830
def tr1(s):
@@ -752,5 +754,28 @@ def test_thresholds(self):
752754
self.assertEqual(bindef.nbins, 2)
753755

754756

757+
class TestDiscretizer(unittest.TestCase):
758+
def test_equality(self):
759+
v1 = ContinuousVariable("x")
760+
v2 = ContinuousVariable("x", number_of_decimals=42)
761+
v3 = ContinuousVariable("y")
762+
assert v1 == v2
763+
764+
t1 = Discretizer(v1, [0, 2, 1])
765+
t1a = Discretizer(v2, [0, 2, 1])
766+
t2 = Discretizer(v3, [0, 2, 1])
767+
self.assertEqual(t1, t1)
768+
self.assertEqual(t1, t1a)
769+
self.assertNotEqual(t1, t2)
770+
771+
self.assertEqual(hash(t1), hash(t1a))
772+
self.assertNotEqual(hash(t1), hash(t2))
773+
774+
t1 = Discretizer(v1, [0, 2, 1])
775+
t1a = Discretizer(v2, [1, 2, 0])
776+
self.assertNotEqual(t1, t1a)
777+
self.assertNotEqual(hash(t1), hash(t1a))
778+
779+
755780
if __name__ == '__main__':
756781
unittest.main()
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import unittest
2+
3+
from Orange.data import DiscreteVariable, ContinuousVariable
4+
from Orange.preprocess.impute import ReplaceUnknownsRandom, ReplaceUnknowns
5+
from Orange.statistics.distribution import Discrete
6+
7+
8+
class TestReplaceUnknowns(unittest.TestCase):
9+
def test_equality(self):
10+
v1 = ContinuousVariable("x")
11+
v2 = ContinuousVariable("x")
12+
v3 = ContinuousVariable("y")
13+
14+
t1 = ReplaceUnknowns(v1, 0)
15+
t1a = ReplaceUnknowns(v2, 0)
16+
t2 = ReplaceUnknowns(v3, 0)
17+
self.assertEqual(t1, t1)
18+
self.assertEqual(t1, t1a)
19+
self.assertNotEqual(t1, t2)
20+
21+
self.assertEqual(hash(t1), hash(t1a))
22+
self.assertNotEqual(hash(t1), hash(t2))
23+
24+
t1 = ReplaceUnknowns(v1, 0)
25+
t1a = ReplaceUnknowns(v1, 1)
26+
self.assertNotEqual(t1, t1a)
27+
self.assertNotEqual(hash(t1), hash(t1a))
28+
29+
30+
class TestReplaceUnknownsRandom(unittest.TestCase):
31+
def test_equality(self):
32+
v1 = DiscreteVariable("x", tuple("abc"))
33+
v2 = DiscreteVariable("x", tuple("abc"))
34+
v3 = DiscreteVariable("y", tuple("abc"))
35+
36+
d1 = Discrete([1, 2, 3], v1)
37+
d2 = Discrete([1, 2, 3], v2)
38+
d3 = Discrete([1, 2, 3], v3)
39+
40+
t1 = ReplaceUnknownsRandom(v1, d1)
41+
t1a = ReplaceUnknownsRandom(v2, d2)
42+
t2 = ReplaceUnknownsRandom(v3, d3)
43+
self.assertEqual(t1, t1)
44+
self.assertEqual(t1, t1a)
45+
self.assertNotEqual(t1, t2)
46+
47+
self.assertEqual(hash(t1), hash(t1a))
48+
self.assertNotEqual(hash(t1), hash(t2))
49+
50+
d1[1] += 1
51+
self.assertNotEqual(t1, t1a)
52+
self.assertNotEqual(hash(t1), hash(t1a))
53+
54+
55+
if __name__ == "__main__":
56+
unittest.main()
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import unittest
2+
3+
import numpy as np
4+
5+
from Orange.data import DiscreteVariable
6+
from Orange.preprocess.transformation import \
7+
Transformation, _Indicator, Normalizer, Lookup
8+
9+
10+
class TestTransformEquality(unittest.TestCase):
11+
def setUp(self):
12+
self.disc1 = DiscreteVariable("d1", values=tuple("abc"))
13+
self.disc1a = DiscreteVariable("d1", values=tuple("abc"))
14+
self.disc2 = DiscreteVariable("d2", values=tuple("abc"))
15+
assert self.disc1 == self.disc1a
16+
17+
def test_transformation(self):
18+
t1 = Transformation(self.disc1)
19+
t1a = Transformation(self.disc1a)
20+
t2 = Transformation(self.disc2)
21+
self.assertEqual(t1, t1)
22+
self.assertEqual(t1, t1a)
23+
self.assertNotEqual(t1, t2)
24+
25+
self.assertEqual(hash(t1), hash(t1a))
26+
self.assertNotEqual(hash(t1), hash(t2))
27+
28+
def test_indicator(self):
29+
t1 = _Indicator(self.disc1, 0)
30+
t1a = _Indicator(self.disc1a, 0)
31+
t2 = _Indicator(self.disc2, 0)
32+
self.assertEqual(t1, t1)
33+
self.assertEqual(t1, t1a)
34+
self.assertNotEqual(t1, t2)
35+
36+
self.assertEqual(hash(t1), hash(t1a))
37+
self.assertNotEqual(hash(t1), hash(t2))
38+
39+
t1 = _Indicator(self.disc1, 0)
40+
t1a = _Indicator(self.disc1a, 1)
41+
self.assertNotEqual(t1, t1a)
42+
self.assertNotEqual(hash(t1), hash(t1a))
43+
44+
def test_normalizer(self):
45+
t1 = Normalizer(self.disc1, 0, 1)
46+
t1a = Normalizer(self.disc1a, 0, 1)
47+
t2 = Normalizer(self.disc2, 0, 1)
48+
self.assertEqual(t1, t1)
49+
self.assertEqual(t1, t1a)
50+
self.assertNotEqual(t1, t2)
51+
52+
self.assertEqual(hash(t1), hash(t1a))
53+
self.assertNotEqual(hash(t1), hash(t2))
54+
55+
t1 = Normalizer(self.disc1, 0, 1)
56+
t1a = Normalizer(self.disc1a, 1, 1)
57+
self.assertNotEqual(t1, t1a)
58+
self.assertNotEqual(hash(t1), hash(t1a))
59+
60+
t1 = Normalizer(self.disc1, 0, 1)
61+
t1a = Normalizer(self.disc1a, 0, 2)
62+
self.assertNotEqual(t1, t1a)
63+
self.assertNotEqual(hash(t1), hash(t1a))
64+
65+
def test_lookup(self):
66+
t1 = Lookup(self.disc1, np.array([0, 2, 1]), 1)
67+
t1a = Lookup(self.disc1a, np.array([0, 2, 1]), 1)
68+
t2 = Lookup(self.disc2, np.array([0, 2, 1]), 1)
69+
self.assertEqual(t1, t1)
70+
self.assertEqual(t1, t1a)
71+
self.assertNotEqual(t1, t2)
72+
73+
self.assertEqual(hash(t1), hash(t1a))
74+
self.assertNotEqual(hash(t1), hash(t2))
75+
76+
t1 = Lookup(self.disc1, np.array([0, 2, 1]), 1)
77+
t1a = Lookup(self.disc1a, np.array([1, 2, 0]), 1)
78+
self.assertNotEqual(t1, t1a)
79+
self.assertNotEqual(hash(t1), hash(t1a))
80+
81+
t1 = Lookup(self.disc1, np.array([0, 2, 1]), 1)
82+
t1a = Lookup(self.disc1a, np.array([0, 2, 1]), 2)
83+
self.assertNotEqual(t1, t1a)
84+
self.assertNotEqual(hash(t1), hash(t1a))
85+
86+
87+
if __name__ == '__main__':
88+
unittest.main()

Orange/preprocess/transformation.py

Lines changed: 38 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -48,25 +48,21 @@ def transform(self, c):
4848
raise NotImplementedError(
4949
"ColumnTransformations must implement method 'transform'.")
5050

51-
52-
class Identity(Transformation):
53-
"""Return an untransformed value of `c`.
54-
"""
55-
def transform(self, c):
56-
return c
57-
5851
def __eq__(self, other):
5952
return type(other) is type(self) and self.variable == other.variable
6053

6154
def __hash__(self):
6255
return hash((type(self), self.variable))
6356

6457

65-
class Indicator(Transformation):
66-
"""
67-
Return an indicator value that equals 1 if the variable has the specified
68-
value and 0 otherwise.
58+
class Identity(Transformation):
59+
"""Return an untransformed value of `c`.
6960
"""
61+
def transform(self, c):
62+
return c
63+
64+
65+
class _Indicator(Transformation):
7066
def __init__(self, variable, value):
7167
"""
7268
:param variable: The variable whose transformed value is returned.
@@ -78,26 +74,27 @@ def __init__(self, variable, value):
7874
super().__init__(variable)
7975
self.value = value
8076

77+
def __eq__(self, other):
78+
return super().__eq__(other) and self.value == other.value
79+
80+
def __hash__(self):
81+
return hash((type(self), self.variable, self.value))
82+
83+
84+
class Indicator(_Indicator):
85+
"""
86+
Return an indicator value that equals 1 if the variable has the specified
87+
value and 0 otherwise.
88+
"""
8189
def transform(self, c):
8290
return c == self.value
8391

8492

85-
class Indicator1(Transformation):
93+
class Indicator1(_Indicator):
8694
"""
8795
Return an indicator value that equals 1 if the variable has the specified
8896
value and -1 otherwise.
8997
"""
90-
def __init__(self, variable, value):
91-
"""
92-
:param variable: The variable whose transformed value is returned.
93-
:type variable: int or str or :obj:`~Orange.data.Variable`
94-
95-
:param value: The value to which the indicator refers
96-
:type value: int or float
97-
"""
98-
super().__init__(variable)
99-
self.value = value
100-
10198
def transform(self, c):
10299
return (c == self.value) * 2 - 1
103100

@@ -129,6 +126,13 @@ def transform(self, c):
129126
else:
130127
return (c - self.offset) * self.factor
131128

129+
def __eq__(self, other):
130+
return super().__eq__(other) \
131+
and self.offset == other.offset and self.factor == other.factor
132+
133+
def __hash__(self):
134+
return hash((type(self), self.variable, self.offset, self.factor))
135+
132136

133137
class Lookup(Transformation):
134138
"""
@@ -139,7 +143,7 @@ def __init__(self, variable, lookup_table, unknown=np.nan):
139143
:param variable: The variable whose transformed value is returned.
140144
:type variable: int or str or :obj:`~Orange.data.DiscreteVariable`
141145
:param lookup_table: transformations for each value of `self.variable`
142-
:type lookup_table: np.array or list or tuple
146+
:type lookup_table: np.array
143147
:param unknown: The value to be used as unknown value.
144148
:type unknown: float or int
145149
"""
@@ -156,3 +160,13 @@ def transform(self, column):
156160
column[mask] = 0
157161
values = self.lookup_table[column]
158162
return np.where(mask, self.unknown, values)
163+
164+
def __eq__(self, other):
165+
return super().__eq__(other) \
166+
and np.allclose(self.lookup_table, other.lookup_table,
167+
equal_nan=True) \
168+
and np.allclose(self.unknown, other.unknown, equal_nan=True)
169+
170+
def __hash__(self):
171+
return hash((type(self), self.variable,
172+
tuple(self.lookup_table), self.unknown))

Orange/widgets/data/owcontinuize.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,12 @@ def transform(self, c):
189189
t *= self.weight
190190
return t
191191

192+
def __eq__(self, other):
193+
return super().__eq__(other) and self.weight == other.weight
194+
195+
def __hash__(self):
196+
return hash((type(self), self.variable, self.value, self.weight))
197+
192198

193199
def make_indicator_var(source, value_ind, weight=None):
194200
if weight is None:

Orange/widgets/data/owcreateclass.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,17 @@ def transform(self, c):
8686
res[nans] = np.nan
8787
return res
8888

89+
def __eq__(self, other):
90+
return super().__eq__(other) \
91+
and self.patterns == other.patterns \
92+
and self.case_sensitive == other.case_sensitive \
93+
and self.match_beginning == other.match_beginning
94+
95+
def __hash__(self):
96+
return hash((type(self), self.variable,
97+
tuple(self.patterns),
98+
self.case_sensitive, self.match_beginning))
99+
89100

90101
class ValueFromDiscreteSubstring(Lookup):
91102
"""

0 commit comments

Comments
 (0)