Skip to content

Commit 9af455c

Browse files
authored
small fix, unnecesary transpose (#52)
1 parent 152063c commit 9af455c

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

pymc_bart/tree.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def new_tree(cls, leaf_node_value, idx_data_points, num_observations, shape):
6464
0: Node.new_leaf_node(0, value=leaf_node_value, idx_data_points=idx_data_points)
6565
},
6666
idx_leaf_nodes=[0],
67-
output=np.zeros((num_observations, shape)).astype(config.floatX).squeeze(),
67+
output=np.zeros((shape, num_observations)).astype(config.floatX).squeeze(),
6868
)
6969

7070
def __getitem__(self, index):
@@ -109,7 +109,7 @@ def _predict(self):
109109
for node_index in self.idx_leaf_nodes:
110110
leaf_node = self.get_node(node_index)
111111
output[leaf_node.idx_data_points] = leaf_node.value
112-
return output.T
112+
return output
113113

114114
def predict(self, x, excluded=None):
115115
"""

0 commit comments

Comments
 (0)