Skip to content

Commit 283c467

Browse files
authored
Merge pull request #6115 from ales-erjavec/fixes/owfeatureconstructor-array-result
[FIX] owfeatureconstructor: Cast FeatureFunc result to array
2 parents 86c2192 + ca96c48 commit 283c467

File tree

2 files changed

+29
-8
lines changed

2 files changed

+29
-8
lines changed

Orange/widgets/data/owfeatureconstructor.py

Lines changed: 11 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@
2121
from traceback import format_exception_only
2222
from collections import namedtuple, OrderedDict
2323
from itertools import chain, count, starmap
24-
from typing import List, Dict, Any, Mapping
24+
from typing import List, Dict, Any, Mapping, Optional
2525

2626
import numpy as np
2727

@@ -1021,6 +1021,7 @@ def bind_variable(descriptor, env, data, use_values):
10211021

10221022
values = {}
10231023
cast = None
1024+
dtype = object if isinstance(descriptor, StringDescriptor) else float
10241025

10251026
if isinstance(descriptor, DiscreteDescriptor):
10261027
if not descriptor.values:
@@ -1038,7 +1039,7 @@ def bind_variable(descriptor, env, data, use_values):
10381039
cast = DateTimeCast()
10391040

10401041
func = FeatureFunc(descriptor.expression, source_vars, values, cast,
1041-
use_values=use_values)
1042+
use_values=use_values, dtype=dtype)
10421043
return descriptor, func
10431044

10441045

@@ -1216,7 +1217,10 @@ class FeatureFunc:
12161217
A function for casting the expressions result to the appropriate
12171218
type (e.g. string representation of date/time variables to floats)
12181219
"""
1219-
def __init__(self, expression, args, extra_env=None, cast=None, use_values=False):
1220+
dtype: Optional['DType'] = None
1221+
1222+
def __init__(self, expression, args, extra_env=None, cast=None, use_values=False,
1223+
dtype=None):
12201224
self.expression = expression
12211225
self.args = args
12221226
self.extra_env = dict(extra_env or {})
@@ -1225,6 +1229,7 @@ def __init__(self, expression, args, extra_env=None, cast=None, use_values=False
12251229
self.cast = cast
12261230
self.mask_exceptions = True
12271231
self.use_values = use_values
1232+
self.dtype = dtype
12281233

12291234
def __call__(self, table, *_):
12301235
if isinstance(table, Table):
@@ -1252,7 +1257,7 @@ def __call_table(self, table):
12521257
y = list(starmap(f, args))
12531258
if self.cast is not None:
12541259
y = self.cast(y)
1255-
return y
1260+
return np.asarray(y, dtype=self.dtype)
12561261

12571262
def __call_instance(self, instance: Instance):
12581263
table = Table.from_numpy(
@@ -1281,7 +1286,8 @@ def extract_column(self, table: Table, var: Variable):
12811286

12821287
def __reduce__(self):
12831288
return type(self), (self.expression, self.args,
1284-
self.extra_env, self.cast, self.use_values)
1289+
self.extra_env, self.cast, self.use_values,
1290+
self.dtype)
12851291

12861292
def __repr__(self):
12871293
return "{0.__name__}{1!r}".format(*self.__reduce__())

Orange/widgets/data/tests/test_owfeatureconstructor.py

Lines changed: 18 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88
from unittest.mock import patch, Mock
99

1010
import numpy as np
11+
from scipy import sparse as sp
1112

1213
from orangewidget.settings import Context
1314

@@ -148,6 +149,20 @@ def test_unicode_normalization():
148149
construct_variables(desc, data)))
149150
np.testing.assert_equal(data.X, data.metas)
150151

152+
def test_transform_sparse(self):
153+
domain = Domain([ContinuousVariable("A")])
154+
desc = [
155+
ContinuousDescriptor(name="X", expression="A", number_of_decimals=2)
156+
]
157+
X = sp.csc_matrix(np.arange(5).reshape(5, 1))
158+
data = Table.from_numpy(domain, X)
159+
data_ = data.transform(Domain(data.domain.attributes,
160+
[],
161+
construct_variables(desc, data)))
162+
np.testing.assert_equal(
163+
data.get_column_view(0)[0], data_.get_column_view(0)[0]
164+
)
165+
151166

152167
class TestTools(unittest.TestCase):
153168
def test_free_vars(self):
@@ -276,7 +291,7 @@ def test_reconstruct(self):
276291

277292
def test_repr(self):
278293
self.assertEqual(repr(FeatureFunc("a + 1", [("a", 2)])),
279-
"FeatureFunc('a + 1', [('a', 2)], {}, None, False)")
294+
"FeatureFunc('a + 1', [('a', 2)], {}, None, False, None)")
280295

281296
def test_call(self):
282297
iris = Table("iris")
@@ -291,7 +306,7 @@ def test_string_casting(self):
291306
f = FeatureFunc("name[0]",
292307
[("name", zoo.domain["name"])])
293308
r = f(zoo)
294-
self.assertEqual(r, [x[0] for x in zoo.metas[:, 0]])
309+
self.assertEqual(list(r), [x[0] for x in zoo.metas[:, 0]])
295310
self.assertEqual(f(zoo[0]), str(zoo[0, "name"])[0])
296311

297312
def test_missing_variable(self):
@@ -309,7 +324,7 @@ def test_time_str(self):
309324
data = Table.from_numpy(Domain([TimeVariable("T", have_date=True)]), [[0], [0]])
310325
f = FeatureFunc("str(T)", [("T", data.domain[0])])
311326
c = f(data)
312-
self.assertEqual(c, ["1970-01-01", "1970-01-01"])
327+
self.assertEqual(list(c), ["1970-01-01", "1970-01-01"])
313328

314329
def test_invalid_expression_variable(self):
315330
iris = Table("iris")

0 commit comments

Comments
 (0)