Skip to content

Commit 152063c

Browse files
authored
Avoid Deepcopy on Tree and ParticleTree (#47)
* Implement a constructor of Tree that takes all the attributes (tree_structure, idx_leaf_nodes and output); create the class method new_tree with the previous behaviour of the constructor; change the implementation of copy on Tree, avoiding deepcopying everything and just deepcopying the output; improve method trim(): previously we copied everything and then deleted the attributes that we do not need, now we just create the tree by copying only the attributes that we care about; instead of creating an empty list and appending each particle, we build the list with a comprehension; implement the method copy on ParticleTree, instead of deepcopying; * Changes from code review;
1 parent cd5191f commit 152063c

File tree

2 files changed

+42
-29
lines changed

2 files changed

+42
-29
lines changed

pymc_bart/pgbart.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
import logging
1616

17-
from copy import deepcopy
1817
from numba import njit
1918

2019
import numpy as np
@@ -107,7 +106,7 @@ def __init__(
107106
config.floatX
108107
)
109108
self.sum_trees_noi = self.sum_trees - (init_mean / self.m)
110-
self.a_tree = Tree(
109+
self.a_tree = Tree.new_tree(
111110
leaf_node_value=init_mean / self.m,
112111
idx_data_points=np.arange(self.num_observations, dtype="int32"),
113112
num_observations=self.num_observations,
@@ -136,9 +135,7 @@ def __init__(
136135

137136
shared = make_shared_replacements(initial_values, vars, model)
138137
self.likelihood_logp = logp(initial_values, [model.datalogp], vars, shared)
139-
self.all_particles = []
140-
for _ in range(self.m):
141-
self.all_particles.append(ParticleTree(self.a_tree))
138+
self.all_particles = list(ParticleTree(self.a_tree) for _ in range(self.m))
142139
self.all_trees = np.array([p.tree for p in self.all_particles])
143140
super().__init__(vars, shared)
144141

@@ -239,7 +236,7 @@ def resample(self, particles, normalized_weights):
239236
new_particles = []
240237
for idx in new_indices:
241238
if idx in seen:
242-
new_particles.append(deepcopy(particles[idx]))
239+
new_particles.append(particles[idx].copy())
243240
else:
244241
new_particles.append(particles[idx])
245242
seen.append(idx)
@@ -274,7 +271,7 @@ def systematic(self, normalized_weights):
274271
def init_particles(self, tree_id: int) -> np.ndarray:
275272
"""Initialize particles."""
276273
p0 = self.all_particles[tree_id]
277-
p1 = deepcopy(p0)
274+
p1 = p0.copy()
278275
p1.sample_leafs(
279276
self.sum_trees,
280277
self.m,
@@ -328,12 +325,22 @@ class ParticleTree:
328325
__slots__ = "tree", "expansion_nodes", "log_weight", "old_likelihood_logp", "kfactor"
329326

330327
def __init__(self, tree):
331-
self.tree = tree.copy() # keeps the tree that we care at the moment
328+
self.tree = tree.copy()
332329
self.expansion_nodes = [0]
333330
self.log_weight = 0
334331
self.old_likelihood_logp = 0
335332
self.kfactor = 0.75
336333

334+
def copy(self):
335+
p = ParticleTree(self.tree)
336+
p.expansion_nodes, p.log_weight, p.old_likelihood_logp, p.kfactor = (
337+
self.expansion_nodes.copy(),
338+
self.log_weight,
339+
self.old_likelihood_logp,
340+
self.kfactor,
341+
)
342+
return p
343+
337344
def sample_tree(
338345
self,
339346
ssv,
@@ -500,7 +507,6 @@ def grow_tree(
500507

501508

502509
def get_new_idx_data_points(split_value, idx_data_points, selected_predictor, X):
503-
504510
left_idx = X[idx_data_points, selected_predictor] <= split_value
505511
left_node_idx_data_points = idx_data_points[left_idx]
506512
right_node_idx_data_points = idx_data_points[~left_idx]
@@ -509,7 +515,6 @@ def get_new_idx_data_points(split_value, idx_data_points, selected_predictor, X)
509515

510516

511517
def get_split_value(available_splitting_values, idx_data_points, missing_data):
512-
513518
if missing_data:
514519
idx_data_points = idx_data_points[~np.isnan(available_splitting_values)]
515520
available_splitting_values = available_splitting_values[

pymc_bart/tree.py

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
import math
1616

17-
from copy import deepcopy
1817
from functools import lru_cache
1918

2019
from pytensor import config
@@ -42,9 +41,9 @@ class Tree:
4241
4342
Parameters
4443
----------
45-
idx_data_points : array of integers
46-
num_observations : integer
47-
shape : int
44+
tree_structure : Dictionary of nodes
45+
idx_leaf_nodes : List with the index of the leaf nodes of the tree.
46+
output : Array of shape number of observations, shape
4847
"""
4948

5049
__slots__ = (
@@ -53,12 +52,20 @@ class Tree:
5352
"output",
5453
)
5554

56-
def __init__(self, leaf_node_value, idx_data_points, num_observations, shape):
57-
self.tree_structure = {
58-
0: Node.new_leaf_node(0, value=leaf_node_value, idx_data_points=idx_data_points)
59-
}
60-
self.idx_leaf_nodes = [0]
61-
self.output = np.zeros((num_observations, shape)).astype(config.floatX).squeeze()
55+
def __init__(self, tree_structure, idx_leaf_nodes, output):
56+
self.tree_structure = tree_structure
57+
self.idx_leaf_nodes = idx_leaf_nodes
58+
self.output = output
59+
60+
@classmethod
61+
def new_tree(cls, leaf_node_value, idx_data_points, num_observations, shape):
62+
return cls(
63+
tree_structure={
64+
0: Node.new_leaf_node(0, value=leaf_node_value, idx_data_points=idx_data_points)
65+
},
66+
idx_leaf_nodes=[0],
67+
output=np.zeros((num_observations, shape)).astype(config.floatX).squeeze(),
68+
)
6269

6370
def __getitem__(self, index):
6471
return self.get_node(index)
@@ -67,7 +74,11 @@ def __setitem__(self, index, node):
6774
self.set_node(index, node)
6875

6976
def copy(self):
70-
return deepcopy(self)
77+
tree = {
78+
k: Node(v.index, v.value, v.idx_data_points, v.idx_split_variable)
79+
for k, v in self.tree_structure.items()
80+
}
81+
return Tree(tree, self.idx_leaf_nodes.copy(), self.output.copy())
7182

7283
def get_node(self, index) -> "Node":
7384
return self.tree_structure[index]
@@ -82,14 +93,11 @@ def delete_leaf_node(self, index):
8293
del self.tree_structure[index]
8394

8495
def trim(self):
85-
a_tree = self.copy()
86-
del a_tree.output
87-
del a_tree.idx_leaf_nodes
88-
for k in a_tree.tree_structure.keys():
89-
current_node = a_tree[k]
90-
if current_node.is_leaf_node():
91-
del current_node.idx_data_points
92-
return a_tree
96+
tree = {
97+
k: Node(v.index, v.value, None, v.idx_split_variable)
98+
for k, v in self.tree_structure.items()
99+
}
100+
return Tree(tree, None, None)
93101

94102
def get_split_variables(self):
95103
return [

0 commit comments

Comments
 (0)