Skip to content

Commit 8609be4

Browse files
committed
Add predict() to simple tree and simple RF models
This avoid creating intermediate tables.
1 parent 3835959 commit 8609be4

File tree

3 files changed

+22
-6
lines changed

3 files changed

+22
-6
lines changed

Orange/classification/simple_random_forest.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,16 @@ def learn(self, learner, data):
7272
self.estimators_.append(tree(data))
7373

7474
def predict_storage(self, data):
75-
p = np.zeros((data.X.shape[0], self.cls_vals))
75+
return self.predict(data.X)
76+
77+
def predict(self, X):
78+
p = np.zeros((X.shape[0], self.cls_vals))
79+
X = np.ascontiguousarray(X) # so that it is a no-op for individual trees
7680
for tree in self.estimators_:
77-
p += tree(data, tree.Probs)
81+
# SimpleTrees do not have preprocessors and domain conversion
82+
# was already handled within this class so we can call tree.predict() directly
83+
# instead of going through tree.__call__
84+
_, pt = tree.predict(X)
85+
p += pt
7886
p /= len(self.estimators_)
7987
return p.argmax(axis=1), p

Orange/classification/simple_tree.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,10 @@ def __init__(self, learner, data):
158158
learner.seed)
159159

160160
def predict_storage(self, data):
161-
X = np.ascontiguousarray(data.X)
161+
return self.predict(data.X)
162+
163+
def predict(self, X):
164+
X = np.ascontiguousarray(X)
162165
if self.type == Classification:
163166
p = np.zeros((X.shape[0], self.cls_vals))
164167
_tree.predict_classification(

Orange/regression/simple_random_forest.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,14 @@ def __init__(self, learner, data):
6262
self.estimators_ = []
6363
self.learn(learner, data)
6464

65-
def predict_storage(self, data):
66-
p = np.zeros(data.X.shape[0])
65+
def predict(self, X):
66+
p = np.zeros(X.shape[0])
67+
X = np.ascontiguousarray(X) # so that it is a no-op for individual trees
6768
for tree in self.estimators_:
68-
p += tree(data)
69+
# SimpleTrees do not have preprocessors and domain conversion
70+
# was already handled within this class so we can call tree.predict() directly
71+
# instead of going through tree.__call__
72+
pt = tree.predict(X)
73+
p += pt
6974
p /= len(self.estimators_)
7075
return p

0 commit comments

Comments
 (0)