Skip to content

Commit 532fd9a

Browse files
committed
updated
1 parent a67c5d8 commit 532fd9a

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

src/prediction_utils/traverse_treePredict.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,15 @@ def traverse_tree_predict(tree, X):
2121
X = np.dot(np.subtract(X, tree["rotDetails"]["muX"]), tree["rotDetails"]["R"])
2222

2323
if ('featureExpansion' in tree.keys()):
24+
if len(tree["decisionProjection"].shape) < 2:
25+
decisionProjection = np.expand_dims(tree["decisionProjection"], axis=1)
26+
else:
27+
decisionProjection = tree["decisionProjection"]
28+
# Check if the function exists
2429
if inspect.isfunction(tree["featureExpansion"]):
25-
bLessChild = np.dot(tree["featureExpansion"](X[:, tree["iIn"]]), tree["decisionProjection"]) <= tree["paritionPoint"]
30+
bLessChild = np.dot(tree["featureExpansion"](X[:, tree["iIn"]]), decisionProjection) <= tree["paritionPoint"]
2631
else:
27-
bLessChild = np.dot((X[:, tree["iIn"]]), tree["decisionProjection"]) <= tree["paritionPoint"]
32+
bLessChild = np.dot(X[:, tree["iIn"]], decisionProjection) <= tree["paritionPoint"]
2833
else:
2934
if len(tree["decisionProjection"].shape) < 2:
3035
decisionProjection = np.expand_dims(tree["decisionProjection"], axis=1)

0 commit comments

Comments
 (0)