Skip to content

Commit 4d0ea77

Browse files
committed
util.Reprable: more generic equals (handle array_like types)
1 parent 5970d21 commit 4d0ea77

File tree

2 files changed

+27
-1
lines changed

2 files changed

+27
-1
lines changed

Orange/tests/test_util.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,19 @@ def test_try_(self):
3333
self.assertFalse(try_(lambda: np.whatever()))
3434
self.assertEqual(try_(len, default=SOMETHING), SOMETHING)
3535

36+
def test_reprable(self):
37+
from Orange.data import ContinuousVariable
38+
from Orange.preprocess.impute import ReplaceUnknownsRandom
39+
from Orange.statistics.distribution import Continuous
40+
41+
var = ContinuousVariable('x')
42+
transform = ReplaceUnknownsRandom(var, Continuous(1, var))
43+
44+
self.assertEqual(repr(transform).replace('\n ', ' '),
45+
"ReplaceUnknownsRandom("
46+
"variable=ContinuousVariable(name='x', number_of_decimals=3), "
47+
"distribution=Continuous([[ 0.], [ 0.]]))")
48+
3649
def test_deepgetattr(self):
3750
class a:
3851
l = []

Orange/util.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from collections import OrderedDict
99
import warnings
1010

11+
import numpy as np
12+
1113
# Exposed here for convenience. Prefer patching to try-finally blocks
1214
from unittest.mock import patch # pylint: disable=unused-import
1315

@@ -234,7 +236,7 @@ def __repr__(self):
234236
param.kind not in (param.VAR_POSITIONAL,
235237
param.VAR_KEYWORD)):
236238
value = getattr(self, param.name)
237-
if value != param.default:
239+
if not self.__equal(value, param.default):
238240
names_values.append((param.name, value))
239241

240242
module = self._reprable_module
@@ -246,6 +248,17 @@ def __repr__(self):
246248
', '.join('{}={!r}'.format(*pair)
247249
for pair in names_values))
248250

251+
@staticmethod
252+
def __equal(obj1, obj2):
253+
try:
254+
# If the objects are broadcastable (works for array_like as
255+
# for arbitrary objects), compare them for equality (possibly
256+
# element-wise)
257+
return np.broadcast(obj1, obj2) and np.all(obj1 == obj2)
258+
except ValueError:
259+
# Broadcasting failed
260+
return False
261+
249262

250263
# For best result, keep this at the bottom
251264
__all__ = export_globals(globals(), __name__)

0 commit comments

Comments
 (0)