Skip to content

Commit 216264a

Browse files
authored
Merge pull request #6258 from markotoplak/simpletree-predict
[FIX] Implement predict() in simple tree and simple RF models
2 parents d352fe2 + 28c90fb commit 216264a

File tree

3 files changed

+18
-8
lines changed

3 files changed

+18
-8
lines changed

Orange/classification/simple_random_forest.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,14 @@ def learn(self, learner, data):
7171
tree.seed = learner.seed + i
7272
self.estimators_.append(tree(data))
7373

74-
def predict_storage(self, data):
75-
p = np.zeros((data.X.shape[0], self.cls_vals))
74+
def predict(self, X):
75+
p = np.zeros((X.shape[0], self.cls_vals))
76+
X = np.ascontiguousarray(X) # so that it is a no-op for individual trees
7677
for tree in self.estimators_:
77-
p += tree(data, tree.Probs)
78+
# SimpleTrees do not have preprocessors and domain conversion
79+
# was already handled within this class so we can call tree.predict() directly
80+
# instead of going through tree.__call__
81+
_, pt = tree.predict(X)
82+
p += pt
7883
p /= len(self.estimators_)
7984
return p.argmax(axis=1), p

Orange/classification/simple_tree.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,8 @@ def __init__(self, learner, data):
157157
learner.bootstrap,
158158
learner.seed)
159159

160-
def predict_storage(self, data):
161-
X = np.ascontiguousarray(data.X)
160+
def predict(self, X):
161+
X = np.ascontiguousarray(X)
162162
if self.type == Classification:
163163
p = np.zeros((X.shape[0], self.cls_vals))
164164
_tree.predict_classification(

Orange/regression/simple_random_forest.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,14 @@ def __init__(self, learner, data):
6262
self.estimators_ = []
6363
self.learn(learner, data)
6464

65-
def predict_storage(self, data):
66-
p = np.zeros(data.X.shape[0])
65+
def predict(self, X):
66+
p = np.zeros(X.shape[0])
67+
X = np.ascontiguousarray(X) # so that it is a no-op for individual trees
6768
for tree in self.estimators_:
68-
p += tree(data)
69+
# SimpleTrees do not have preprocessors and domain conversion
70+
# was already handled within this class so we can call tree.predict() directly
71+
# instead of going through tree.__call__
72+
pt = tree.predict(X)
73+
p += pt
6974
p /= len(self.estimators_)
7075
return p

0 commit comments

Comments
 (0)