Skip to content

Commit 810e3f9

Browse files
committed
WiP: Nomogram support on sparse
1 parent 6bf05ae commit 810e3f9

File tree

2 files changed

+63
-8
lines changed

2 files changed

+63
-8
lines changed

Orange/data/util.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,58 @@
22
Data-manipulation utilities.
33
"""
44
import numpy as np
5+
import scipy.sparse as sp
56
import bottleneck as bn
67

78

9+
def _nan_min_max(x, axis=0, func=None):
10+
if not sp.issparse(x):
11+
return func(x, axis=axis)
12+
else:
13+
r = []
14+
if axis == 0:
15+
x = x.T
16+
17+
# TODO check & transform to correct format
18+
19+
for row in x:
20+
values = row.data
21+
have_zeros = np.prod(row.shape) != values.size
22+
extreme = func(values)
23+
if have_zeros:
24+
extreme = func([0, extreme])
25+
r.append(extreme)
26+
return np.array(r)
27+
28+
29+
def nan_min(x, axis):
30+
return _nan_min_max(x, axis, np.nanmin)
31+
32+
33+
def nan_max(x, axis):
34+
return _nan_min_max(x, axis, np.nanmax)
35+
36+
37+
def nan_average(x):
38+
if not sp.issparse(x):
39+
return np.average(x)
40+
else:
41+
n_values = np.prod(x.shape) - np.sum(np.isnan(x.data))
42+
return np.nansum(x.data) / n_values
43+
44+
45+
def unique(x, return_counts=True):
46+
if not sp.issparse(x):
47+
return np.unique(x, return_counts=return_counts)
48+
else:
49+
n_zeros = np.prod(x.shape) - x.data.size
50+
r = np.unique(x.data, return_counts=return_counts)
51+
if return_counts:
52+
return np.insert(r[0], 0, 0), np.insert(r[1], 0, n_zeros)
53+
else:
54+
return np.insert(r, 0, 0)
55+
56+
857
def one_hot(values, dtype=float):
958
"""Return a one-hot transform of values
1059

Orange/widgets/visualize/ownomogram.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import time
22
from enum import IntEnum
3+
from collections import OrderedDict
34

45
import numpy as np
56

@@ -12,6 +13,7 @@
1213
from AnyQt.QtCore import Qt, QEvent, QRectF, QSize
1314

1415
from Orange.data import Table, Domain
16+
from Orange.data.util import nan_min, nan_max, nan_average, unique
1517
from Orange.classification import Model
1618
from Orange.classification.naive_bayes import NaiveBayesModel
1719
from Orange.classification.logistic_regression import \
@@ -865,11 +867,13 @@ def calculate_log_reg_coefficients(self):
865867
self.log_reg_coeffs = [coeffs[:, ranges[i]] for i in range(len(attrs))]
866868
self.log_reg_coeffs_orig = self.log_reg_coeffs.copy()
867869

868-
for i in range(len(self.log_reg_coeffs)):
870+
min_values = nan_min(self.data.X, axis=0)
871+
max_values = nan_max(self.data.X, axis=0)
872+
873+
for i, min_t, max_t in zip(range(len(self.log_reg_coeffs)),
874+
min_values, max_values):
869875
if self.log_reg_coeffs[i].shape[1] == 1:
870876
coef = self.log_reg_coeffs[i]
871-
min_t = np.nanmin(self.data.X, axis=0)[i]
872-
max_t = np.nanmax(self.data.X, axis=0)[i]
873877
self.log_reg_coeffs[i] = np.hstack((coef * min_t, coef * max_t))
874878
self.log_reg_cont_data_extremes.append(
875879
[sorted([min_t, max_t], reverse=(c < 0)) for c in coef])
@@ -1076,10 +1080,10 @@ def _init_feature_marker_values(self):
10761080
value, feature_val = 0, None
10771081
if len(self.log_reg_coeffs):
10781082
if attr.is_discrete:
1079-
ind, n = np.unique(self.data.X[:, i], return_counts=True)
1083+
ind, n = unique(self.data.X[:, i], return_counts=True)
10801084
feature_val = np.nan_to_num(ind[np.argmax(n)])
10811085
else:
1082-
feature_val = np.average(self.data.X[:, i])
1086+
feature_val = nan_average(self.data.X[:, i])
10831087
inst_in_dom = instances and attr in instances.domain
10841088
if inst_in_dom and not np.isnan(instances[0][attr]):
10851089
feature_val = instances[0][attr]
@@ -1104,13 +1108,15 @@ def send_report(self):
11041108

11051109
@staticmethod
11061110
def reconstruct_domain(original, preprocessed):
1107-
attrs = []
1111+
# abuse dict to make "in" comparisons faster
1112+
attrs = OrderedDict()
11081113
for attr in preprocessed.attributes:
11091114
cv = attr._compute_value.variable._compute_value
11101115
var = cv.variable if cv else original[attr.name]
1111-
if var in attrs:
1116+
if var in attrs: # the reason for OrderedDict
11121117
continue
1113-
attrs.append(var)
1118+
attrs[var] = None # we only need keys
1119+
attrs = list(attrs.keys())
11141120
return Domain(attrs, original.class_var, original.metas)
11151121

11161122
@staticmethod

0 commit comments

Comments
 (0)