Skip to content

Commit 4b4aff7

Browse files
committed
Refactor MappingTransform out from oweditdomain
1 parent c1221b9 commit 4b4aff7

File tree

5 files changed

+205
-91
lines changed

5 files changed

+205
-91
lines changed

Orange/preprocess/tests/test_transformation.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55

66
from Orange.data import DiscreteVariable
77
from Orange.preprocess.transformation import \
8-
Transformation, _Indicator, Normalizer, Lookup, Indicator, Indicator1
8+
Transformation, _Indicator, Normalizer, Lookup, Indicator, Indicator1, \
9+
MappingTransform
910

1011

1112
class TestTransformEquality(unittest.TestCase):
@@ -84,6 +85,29 @@ def test_lookup(self):
8485
self.assertNotEqual(t1, t1a)
8586
self.assertNotEqual(hash(t1), hash(t1a))
8687

88+
def test_mapping(self):
89+
def test_equal(a, b):
90+
self.assertEqual(a, b)
91+
self.assertEqual(hash(a), hash(b))
92+
93+
t1 = MappingTransform(self.disc1, {"a": "1", "b": "2", "c":"3"})
94+
t1a = MappingTransform(self.disc1a, {"a": "1", "b": "2", "c":"3"})
95+
t2 = MappingTransform(self.disc2, {"a": "1", "b": "2", "c":"3"},
96+
unknown="")
97+
test_equal(t1, t1a)
98+
self.assertNotEqual(t1, t2)
99+
100+
t1 = MappingTransform(self.disc1, {"a": 1, "b": 2, "c": float("nan")},
101+
unknown=float("nan"))
102+
t1_ = MappingTransform(self.disc1, {"a": 1, "b": 2, "c": float("nan")},
103+
unknown=float("nan"))
104+
test_equal(t1, t1_)
105+
t1_ = MappingTransform(self.disc1, {"a": 1, "b": float("nan"), "c": 2},
106+
unknown=float("nan"))
107+
self.assertNotEqual(t1, t1_)
108+
with self.assertRaises(ValueError):
109+
MappingTransform(self.disc1, {float("nan"): 1})
110+
87111

88112
class TestIndicator(unittest.TestCase):
89113
def test_nan(self):

Orange/preprocess/transformation.py

Lines changed: 78 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
1+
from typing import TYPE_CHECKING, Mapping, Optional
2+
13
import numpy as np
24
import scipy.sparse as sp
35
from pandas import isna
46

5-
from Orange.data import Instance, Table, Domain
6-
from Orange.util import Reprable
7+
from Orange.data import Instance, Table, Domain, Variable
8+
from Orange.misc.collections import DictMissingConst
9+
from Orange.util import Reprable, nan_eq, nan_hash_stand, frompyfunc
10+
11+
if TYPE_CHECKING:
12+
from numpy.typing import DTypeLike
713

814

915
class Transformation(Reprable):
@@ -224,6 +230,75 @@ def __hash__(self):
224230
# to avoid different hashes for the same array change to None
225231
# issue: https://bugs.python.org/issue43475#msg388508
226232
tuple(None if isna(x) else x for x in self.lookup_table),
227-
self.unknown,
233+
nan_hash_stand(self.unknown),
228234
)
229235
)
236+
237+
238+
class MappingTransform(Transformation):
239+
"""
240+
Map values via a dictionary lookup.
241+
242+
Parameters
243+
----------
244+
variable: Variable
245+
mapping: Mapping
246+
The mapping (for the non NA values).
247+
dtype: Optional[DTypeLike]
248+
The optional target dtype.
249+
unknown: Any
250+
The constant with whitch to replace unknown values in input.
251+
"""
252+
def __init__(
253+
self,
254+
variable: Variable,
255+
mapping: Mapping,
256+
dtype: Optional['DTypeLike'] = None,
257+
unknown=np.nan,
258+
) -> None:
259+
super().__init__(variable)
260+
if any(nan_eq(k, np.nan) for k in mapping.keys()): # ill-defined mapping
261+
raise ValueError("'nan' value in mapping.keys()")
262+
self.mapping = mapping
263+
self.dtype = dtype
264+
self.unknown = unknown
265+
self._mapper = self._make_dict_mapper(
266+
DictMissingConst(unknown, mapping), dtype
267+
)
268+
269+
@staticmethod
270+
def _make_dict_mapper(mapping, dtype):
271+
return frompyfunc(mapping.__getitem__, 1, 1, dtype)
272+
273+
def transform(self, c):
274+
return self._mapper(c)
275+
276+
def __reduce_ex__(self, protocol):
277+
return type(self), (self.variable, self.mapping, self.dtype, self.unknown)
278+
279+
def __eq__(self, other):
280+
return super().__eq__(other) \
281+
and nan_mapping_eq(self.mapping, other.mapping) \
282+
and self.dtype == other.dtype \
283+
and nan_eq(self.unknown, other.unknown)
284+
285+
def __hash__(self):
286+
return hash((type(self), self.variable, nan_mapping_hash(self.mapping),
287+
self.dtype, nan_hash_stand(self.unknown)))
288+
289+
290+
def nan_mapping_hash(a: Mapping) -> int:
291+
return hash(tuple((k, nan_hash_stand(v)) for k, v in a.items()))
292+
293+
294+
def nan_mapping_eq(a: Mapping, b: Mapping) -> bool:
295+
if len(a) != len(b):
296+
return False
297+
try:
298+
for k, va in a.items():
299+
vb = b[k]
300+
if not nan_eq(va, vb):
301+
return False
302+
except LookupError:
303+
return False
304+
return True

Orange/tests/test_util.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import scipy.sparse as sp
77

88
from Orange.util import export_globals, flatten, deprecated, try_, deepgetattr, \
9-
OrangeDeprecationWarning
9+
OrangeDeprecationWarning, nan_eq, nan_hash_stand
1010
from Orange.data import Table
1111
from Orange.data.util import vstack, hstack, array_equal
1212
from Orange.statistics.util import stats
@@ -72,6 +72,21 @@ class a:
7272
self.assertTrue(deepgetattr(a, 'l.__nx__.__x__', 42), 42)
7373
self.assertRaises(AttributeError, lambda: deepgetattr(a, 'l.__nx__.__x__'))
7474

75+
def test_nan_eq(self):
76+
self.assertTrue(nan_eq(float("nan"), float("nan")))
77+
self.assertTrue(nan_eq(1, 1.0))
78+
self.assertFalse(nan_eq(float("nan"), 1))
79+
self.assertFalse(nan_eq(1, float("nan")))
80+
self.assertFalse(nan_eq(float("inf"), float("nan")))
81+
self.assertFalse(nan_eq(float("nan"), float("inf")))
82+
self.assertFalse(nan_eq(1, 2))
83+
self.assertFalse(nan_eq(None, 2))
84+
self.assertFalse(nan_eq(2, None))
85+
86+
def test_nan_hash_stand(self):
87+
self.assertEqual(hash(nan_hash_stand(float("nan"))),
88+
hash(nan_hash_stand(float("nan"))))
89+
7590
def test_vstack(self):
7691
numpy = np.array([[1., 2.], [3., 4.]])
7792
csr = sp.csr_matrix(numpy)

Orange/util.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
import os
44
import inspect
55
import datetime
6+
import math
67
from contextlib import contextmanager
8+
from typing import TYPE_CHECKING, Callable
79

810
import pkg_resources
911
from enum import Enum as _Enum
@@ -17,9 +19,14 @@
1719
# Exposed here for convenience. Prefer patching to try-finally blocks
1820
from unittest.mock import patch # pylint: disable=unused-import
1921

22+
import numpy as np
23+
2024
# Backwards-compat
2125
from Orange.data.util import scale # pylint: disable=unused-import
2226

27+
if TYPE_CHECKING:
28+
from numpy.typing import DTypeLike
29+
2330

2431
log = logging.getLogger(__name__)
2532

@@ -531,6 +538,78 @@ def utc_from_timestamp(timestamp) -> datetime.datetime:
531538
datetime.timedelta(seconds=float(timestamp))
532539

533540

541+
def frompyfunc(func: Callable, nin: int, nout: int, dtype: 'DTypeLike'):
542+
"""
543+
Wrap an `func` callable into a ufunc-like function with `out`, `dtype`,
544+
`where`, ... parameters. The `dtype` is uses as the default.
545+
546+
Unlike numpy.frompyfunc this function always returns output array of
547+
the specified `dtype`. Note that the conversion is space efficient.
548+
"""
549+
func_ = np.frompyfunc(func, nin, nout)
550+
551+
@wraps(func)
552+
def funcv(*args, out=None, dtype=dtype, **kwargs):
553+
if not args:
554+
raise TypeError
555+
args = [np.asanyarray(a) for a in args]
556+
args = np.broadcast_arrays(*args)
557+
shape = args[0].shape
558+
have_out = out is not None
559+
if out is None and dtype is not None:
560+
out = np.empty(shape, dtype)
561+
562+
res = func_(*args, out, dtype=dtype, casting="unsafe", **kwargs)
563+
if res.shape == () and not have_out:
564+
return res.item()
565+
else:
566+
return res
567+
568+
return funcv
569+
570+
571+
_isnan = math.isnan
572+
573+
574+
def nan_eq(a, b) -> bool:
575+
"""
576+
Same as `a == b` except where both `a` and `b` are NaN values in which
577+
case `True` is returned.
578+
579+
.. seealso:: nan_hash_stand
580+
"""
581+
try:
582+
both_nan = _isnan(a) and _isnan(b)
583+
except TypeError:
584+
return a == b
585+
else:
586+
return both_nan or a == b
587+
588+
589+
def nan_hash_stand(value):
590+
"""
591+
If `value` is a NaN then return a singular global *standin* NaN instance,
592+
otherwise return `value` unchanged.
593+
594+
Use this where a hash of `value` is needed and `value` might be a NaN
595+
to account for distinct hashes of NaN instances.
596+
597+
E.g. the folowing `__eq__` and `__hash__` pairs would be ill-defined for
598+
`A(float("nan"))` instances if `nan_hash_stand` and `nan_eq` were not
599+
used.
600+
>>> class A:
601+
... def __init__(self, v): self.v = v
602+
... def __hash__(self): return hash(nan_hash_stand(self.v))
603+
... def __eq__(self, other): return nan_eq(self.v, other.v)
604+
"""
605+
try:
606+
if _isnan(value):
607+
return math.nan
608+
except TypeError:
609+
pass
610+
return value
611+
612+
534613
# For best result, keep this at the bottom
535614
__all__ = export_globals(globals(), __name__)
536615

Orange/widgets/data/oweditdomain.py

Lines changed: 7 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
A widget for manual editing of a domain's attributes.
66
77
"""
8-
import math
98
import warnings
109
from xml.sax.saxutils import escape
1110
from itertools import zip_longest, repeat, chain
@@ -39,8 +38,11 @@
3938

4039
import Orange.data
4140

42-
from Orange.preprocess.transformation import Transformation, Identity, Lookup
41+
from Orange.preprocess.transformation import (
42+
Transformation, Identity, Lookup, MappingTransform
43+
)
4344
from Orange.misc.collections import DictMissingConst
45+
from Orange.util import frompyfunc
4446
from Orange.widgets import widget, gui, settings
4547
from Orange.widgets.utils import itemmodels, ftry
4648
from Orange.widgets.utils.buttons import FixedSizeButton
@@ -2640,14 +2642,7 @@ def make_dict_mapper(
26402642
`make_dict_mapper` it is used as a the default return dtype,
26412643
otherwise the default dtype is `object`.
26422644
"""
2643-
_vmapper = np.frompyfunc(mapping.__getitem__, 1, 1)
2644-
2645-
def mapper(arr, out=None, dtype=dtype, **kwargs):
2646-
arr = np.asanyarray(arr)
2647-
if out is None and dtype is not None and arr.shape != ():
2648-
out = np.empty_like(arr, dtype)
2649-
return _vmapper(arr, out, dtype=dtype, casting="unsafe", **kwargs)
2650-
return mapper
2645+
return frompyfunc(mapping.__getitem__, 1, 1, dtype)
26512646

26522647

26532648
as_string = np.frompyfunc(str, 1, 1)
@@ -2855,82 +2850,8 @@ def transform(self, c):
28552850
return np.nan
28562851

28572852

2858-
class LookupMappingTransform(Transformation):
2859-
"""
2860-
Map values via a dictionary lookup.
2861-
"""
2862-
def __init__(
2863-
self,
2864-
variable: Orange.data.Variable,
2865-
mapping: Mapping,
2866-
dtype: Optional[DType] = None,
2867-
unknown=np.nan,
2868-
) -> None:
2869-
super().__init__(variable)
2870-
self.mapping = mapping
2871-
self.dtype = dtype
2872-
self.unknown = unknown
2873-
self._mapper = make_dict_mapper(
2874-
DictMissingConst(unknown, mapping), dtype
2875-
)
2876-
2877-
def transform(self, c):
2878-
return self._mapper(c)
2879-
2880-
def __reduce_ex__(self, protocol):
2881-
return type(self), (self.variable, self.mapping, self.dtype, self.unknown)
2882-
2883-
def __eq__(self, other):
2884-
return self.variable == other.variable \
2885-
and self.mapping == other.mapping \
2886-
and self.dtype == other.dtype \
2887-
and nan_eq(self.unknown, other.unknown)
2888-
2889-
def __hash__(self):
2890-
return hash((type(self), self.variable, frozenset(self.mapping.items()),
2891-
self.dtype, nan_hash_stand(self.unknown)))
2892-
2893-
2894-
_isnan = math.isnan
2895-
2896-
2897-
def nan_eq(a, b) -> bool:
2898-
"""
2899-
Same as `a == b` except where both `a` and `b` are NaN values in which
2900-
case `True` is returned.
2901-
2902-
.. seealso:: nan_hash_stand
2903-
"""
2904-
try:
2905-
both_nan = _isnan(a) and _isnan(b)
2906-
except TypeError:
2907-
return a == b
2908-
else:
2909-
return both_nan or a == b
2910-
2911-
2912-
def nan_hash_stand(value):
2913-
"""
2914-
If `value` is a NaN then return a singular global *standin* NaN instance,
2915-
otherwise return `value` unchanged.
2916-
2917-
Use this where a hash of `value` is needed and `value` might be a NaN
2918-
to account for distinct hashes of NaN instances.
2919-
2920-
E.g. the folowing `__eq__` and `__hash__` pairs would be ill-defined for
2921-
`A(float("nan"))` instances if `nan_hash_stand` and `nan_eq` were not
2922-
used.
2923-
>>> class A:
2924-
... def __init__(self, v): self.v = v
2925-
... def __hash__(self): return hash(nan_hash_stand(self.v))
2926-
... def __eq__(self, other): return nan_eq(self.v, other.v)
2927-
"""
2928-
try:
2929-
if _isnan(value):
2930-
return math.nan
2931-
except TypeError:
2932-
pass
2933-
return value
2853+
# Alias for back compatibility (unpickling transforms)
2854+
LookupMappingTransform = MappingTransform
29342855

29352856

29362857
@singledispatch

0 commit comments

Comments
 (0)