Skip to content

Commit fd70c60

Browse files
authored
Merge pull request #6114 from markotoplak/compute-value-warn
Extend warnings for missing __eq__ within Variables
2 parents b0485ba + bdaee7d commit fd70c60

File tree

5 files changed

+136
-3
lines changed

5 files changed

+136
-3
lines changed

Orange/data/tests/test_util.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from Orange.data import Domain, ContinuousVariable
55
from Orange.data.util import get_unique_names, get_unique_names_duplicates, \
6-
get_unique_names_domain, one_hot, sanitized_name
6+
get_unique_names_domain, one_hot, sanitized_name, redefines_eq_and_hash
77

88

99
class TestGetUniqueNames(unittest.TestCase):
@@ -309,5 +309,39 @@ def test_sanitized_name(self):
309309
self.assertEqual(sanitized_name("1 Foo Bar"), "_1_Foo_Bar")
310310

311311

312+
class TestRedefinesEqAndHash(unittest.TestCase):
313+
314+
class Valid:
315+
def __eq__(self, other):
316+
pass
317+
318+
def __hash__(self):
319+
pass
320+
321+
class Subclass(Valid):
322+
pass
323+
324+
class OnlyEq:
325+
def __eq__(self, other):
326+
pass
327+
328+
class OnlyHash:
329+
def __hash__(self):
330+
pass
331+
332+
def test_valid(self):
333+
self.assertTrue(redefines_eq_and_hash(self.Valid))
334+
self.assertTrue(redefines_eq_and_hash(self.Valid()))
335+
336+
def test_subclass(self):
337+
self.assertFalse(redefines_eq_and_hash(self.Subclass))
338+
339+
def test_only_eq(self):
340+
self.assertFalse(redefines_eq_and_hash(self.OnlyEq))
341+
342+
def test_only_hash(self):
343+
self.assertFalse(redefines_eq_and_hash(self.OnlyHash))
344+
345+
312346
if __name__ == "__main__":
313347
unittest.main()

Orange/data/tests/test_variable.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,24 @@ class Invalid:
258258
ContinuousVariable("x", compute_value=Invalid())
259259
self.assertNotEqual(warns, [])
260260

261+
with warnings.catch_warnings(record=True) as warns:
262+
263+
class MissingHash:
264+
def __eq__(self, other):
265+
return self is other
266+
267+
ContinuousVariable("x", compute_value=MissingHash())
268+
self.assertNotEqual(warns, [])
269+
270+
with warnings.catch_warnings(record=True) as warns:
271+
272+
class MissingEq:
273+
def __hash__(self):
274+
return super().__hash__(self)
275+
276+
ContinuousVariable("x", compute_value=MissingEq())
277+
self.assertNotEqual(warns, [])
278+
261279

262280
def variabletest(varcls):
263281
def decorate(cls):

Orange/data/util.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
Data-manipulation utilities.
33
"""
44
import re
5+
import types
6+
import warnings
57
from collections import Counter
68
from itertools import chain, count
79
from typing import Callable, Union, List, Type
@@ -72,6 +74,12 @@ class SharedComputeValue:
7274
def __init__(self, compute_shared, variable=None):
7375
self.compute_shared = compute_shared
7476
self.variable = variable
77+
if compute_shared is not None \
78+
and not isinstance(compute_shared, (types.BuiltinFunctionType,
79+
types.FunctionType)) \
80+
and not redefines_eq_and_hash(compute_shared):
81+
warnings.warn(f"{type(compute_shared).__name__} should define"
82+
f"__eq__ and __hash__ to be used for compute_shared")
7583

7684
def __call__(self, data, shared_data=None):
7785
"""Fallback if common parts are not passed."""
@@ -85,6 +93,14 @@ def compute(self, data, shared_data):
8593
Subclasses need to implement this function."""
8694
raise NotImplementedError
8795

96+
def __eq__(self, other):
97+
return type(self) is type(other) \
98+
and self.compute_shared == other.compute_shared \
99+
and self.variable == other.variable
100+
101+
def __hash__(self):
102+
return hash((type(self), self.compute_shared, self.variable))
103+
88104

89105
def vstack(arrays):
90106
"""vstack that supports sparse and dense arrays
@@ -307,3 +323,20 @@ def sanitized_name(name: str) -> str:
307323
if sanitized[0].isdigit():
308324
sanitized = "_" + sanitized
309325
return sanitized
326+
327+
328+
def redefines_eq_and_hash(this):
329+
"""
330+
Check if the passed object (or class) redefines __eq__ and __hash__.
331+
332+
Args:
333+
this: class or object
334+
"""
335+
if not isinstance(this, type):
336+
this = type(this)
337+
338+
# if only __eq__ is defined, __hash__ is set to None
339+
if this.__hash__ is None:
340+
return False
341+
342+
return "__hash__" in this.__dict__ and "__eq__" in this.__dict__

Orange/data/variable.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import scipy.sparse as sp
1414

1515
from Orange.data import _variable
16+
from Orange.data.util import redefines_eq_and_hash
1617
from Orange.util import Registry, Reprable, OrangeDeprecationWarning
1718

1819

@@ -368,11 +369,11 @@ def __init__(self, name="", compute_value=None, *, sparse=False):
368369
warnings.warn("Variable must have a name", OrangeDeprecationWarning,
369370
stacklevel=3)
370371
self._name = name
372+
371373
if compute_value is not None \
372374
and not isinstance(compute_value, (types.BuiltinFunctionType,
373375
types.FunctionType)) \
374-
and (type(compute_value).__eq__ is object.__eq__
375-
or compute_value.__hash__ is object.__hash__):
376+
and not redefines_eq_and_hash(compute_value):
376377
warnings.warn(f"{type(compute_value).__name__} should define"
377378
f"__eq__ and __hash__ to be used for compute_value")
378379

Orange/tests/test_data_util.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import unittest
2+
import warnings
23
from unittest.mock import Mock
34

45
import numpy as np
@@ -72,3 +73,49 @@ def test_single_call(self):
7273
#test with descendants of table
7374
DummyTable.from_table(c.domain, data)
7475
self.assertEqual(obj.compute_shared.call_count, 4)
76+
77+
def test_compute_shared_eq_warning(self):
78+
with warnings.catch_warnings(record=True) as warns:
79+
DummyPlus(compute_shared=lambda *_: 42)
80+
81+
class Valid:
82+
def __eq__(self, other):
83+
pass
84+
85+
def __hash__(self):
86+
pass
87+
88+
DummyPlus(compute_shared=Valid())
89+
self.assertEqual(warns, [])
90+
91+
class Invalid:
92+
pass
93+
94+
DummyPlus(compute_shared=Invalid())
95+
self.assertNotEqual(warns, [])
96+
97+
with warnings.catch_warnings(record=True) as warns:
98+
99+
class MissingHash:
100+
def __eq__(self, other):
101+
pass
102+
103+
DummyPlus(compute_shared=MissingHash())
104+
self.assertNotEqual(warns, [])
105+
106+
with warnings.catch_warnings(record=True) as warns:
107+
108+
class MissingEq:
109+
def __hash__(self):
110+
pass
111+
112+
DummyPlus(compute_shared=MissingEq())
113+
self.assertNotEqual(warns, [])
114+
115+
with warnings.catch_warnings(record=True) as warns:
116+
117+
class Subclass(Valid):
118+
pass
119+
120+
DummyPlus(compute_shared=Subclass())
121+
self.assertNotEqual(warns, [])

0 commit comments

Comments
 (0)