Skip to content

Commit 9467104

Browse files
authored
reverse transpose (#57)
1 parent ff22efe commit 9467104

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

pymc_bart/tree.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def new_tree(cls, leaf_node_value, idx_data_points, num_observations, shape):
6464
0: Node.new_leaf_node(0, value=leaf_node_value, idx_data_points=idx_data_points)
6565
},
6666
idx_leaf_nodes=[0],
67-
output=np.zeros((shape, num_observations)).astype(config.floatX).squeeze(),
67+
output=np.zeros((num_observations, shape)).astype(config.floatX).squeeze(),
6868
)
6969

7070
def __getitem__(self, index):
@@ -111,7 +111,7 @@ def _predict(self):
111111
for node_index in self.idx_leaf_nodes:
112112
leaf_node = self.get_node(node_index)
113113
output[leaf_node.idx_data_points] = leaf_node.value
114-
return output
114+
return output.T
115115

116116
def predict(self, x, excluded=None):
117117
"""

0 commit comments

Comments
 (0)