Skip to content

Commit 7844755

Browse files
committed
use systematic resampling and fix copy error
1 parent 55e8f05 commit 7844755

File tree

2 files changed

+69
-12
lines changed

2 files changed

+69
-12
lines changed

pymc_bart/pgbart.py

Lines changed: 62 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414

1515
import logging
1616

17-
from copy import copy
17+
from copy import copy, deepcopy
18+
from numba import jit
1819

1920
import aesara
2021
import numpy as np
@@ -173,11 +174,8 @@ def astep(self, _):
173174
# Normalize weights
174175
w_t, normalized_weights = self.normalize(particles[2:])
175176

176-
# Resample all but first two particles
177-
new_indices = np.random.choice(
178-
self.indices, size=self.len_indices, p=normalized_weights
179-
)
180-
particles[2:] = particles[new_indices]
177+
# Resample
178+
particles = self.resample(particles, normalized_weights)
181179

182180
# Set the new weight
183181
for p in particles[2:]:
@@ -196,12 +194,14 @@ def astep(self, _):
196194
self.sum_trees = self.sum_trees_noi + new_tree._predict()
197195
self.all_trees[tree_id] = new_tree.trim()
198196

197+
used_variates = new_tree.get_split_variables()
198+
199199
if self.tune:
200200
self.ssv = SampleSplittingVariable(self.alpha_vec)
201-
for index in new_particle.used_variates:
201+
for index in used_variates:
202202
self.alpha_vec[index] += 1
203203
else:
204-
for index in new_particle.used_variates:
204+
for index in used_variates:
205205
variable_inclusion[index] += 1
206206

207207
stats = {"variable_inclusion": variable_inclusion, "bart_trees": copy(self.all_trees)}
@@ -225,6 +225,21 @@ def normalize(self, particles):
225225

226226
return w_t, normalized_weights
227227

228+
def resample(self, particles, normalized_weights):
    """Use systematic resampling for all but the first two particles.

    Parameters
    ----------
    particles : sequence
        Indexable collection of particles that supports slice assignment.
        Entries 0 and 1 are never replaced.
    normalized_weights : ndarray
        Normalized weights used to draw the replacement particles.
        NOTE(review): the caller computes them from ``particles[2:]``, so
        one weight per resampled particle is expected -- confirm.

    Returns
    -------
    The same ``particles`` object, with ``particles[2:]`` replaced by the
    resampled particles.
    """
    # Draw the indices of the surviving particles. `systematic` already
    # offsets them by 2, so they index directly into `particles`.
    new_indices = systematic(normalized_weights)

    seen = set()
    new_particles = []
    for idx in new_indices:
        if idx in seen:
            # An index drawn more than once gets an independent copy, so the
            # duplicates do not share state with the first occurrence.
            new_particles.append(deepcopy(particles[idx]))
        else:
            seen.add(idx)
            new_particles.append(particles[idx])

    particles[2:] = new_particles

    return particles
242+
228243
def init_particles(self, tree_id: int) -> np.ndarray:
229244
"""Initialize particles."""
230245
p0 = self.all_particles[tree_id]
@@ -286,9 +301,9 @@ def __init__(self, tree):
286301
self.expansion_nodes = [0]
287302
self.log_weight = 0
288303
self.old_likelihood_logp = 0
289-
self.used_variates = []
290304
self.kf = 0.75
291305

306+
292307
def sample_tree(
293308
self,
294309
ssv,
@@ -326,7 +341,6 @@ def sample_tree(
326341
if index_selected_predictor is not None:
327342
new_indexes = self.tree.idx_leaf_nodes[-2:]
328343
self.expansion_nodes.extend(new_indexes)
329-
self.used_variates.append(index_selected_predictor)
330344
tree_grew = True
331345

332346
return tree_grew
@@ -524,7 +538,7 @@ def mean(a):
524538

525539
return mean
526540

527-
541+
@jit()
528542
def discrete_uniform_sampler(upper_value):
529543
"""Draw from the uniform distribution with bounds [0, upper_value).
530544
@@ -541,14 +555,15 @@ def __init__(self, scale, shape):
541555
self.scale = scale
542556
self.shape = shape
543557
self.update()
544-
558+
545559
def random(self):
    """Return the next cached column of draws.

    When the read position ``self.idx`` has reached ``self.size`` the cache
    is refilled through ``self.update()`` before reading. The read position
    is advanced by one after the column has been taken.
    """
    exhausted = self.idx == self.size
    if exhausted:
        self.update()
    position = self.idx
    draw = self.cache[:, position]
    self.idx = position + 1
    return draw
551565

566+
552567
def update(self):
553568
self.idx = 0
554569
self.cache = np.random.normal(loc=0.0, scale=self.scale, size=(self.shape, self.size))
@@ -571,12 +586,47 @@ def random(self):
571586
self.idx += 1
572587
return pop
573588

589+
574590
def update(self):
    """Reset the read position and refill the cache with uniform draws.

    The cache becomes an array of shape ``(self.shape, self.size)`` drawn
    with ``np.random.uniform`` between ``self.lower_bound`` and
    ``self.upper_bound``.
    """
    self.idx = 0
    cache_shape = (self.shape, self.size)
    self.cache = np.random.uniform(self.lower_bound, self.upper_bound, size=cache_shape)
579595

596+
@jit()
597+
def systematic(W):
    """Systematic resampling.

    A single uniform draw is shared by ``len(W)`` evenly spaced points in
    [0, 1); each point is mapped to an index through ``inverse_cdf``.

    Parameters
    ----------
    W : (N,) ndarray
        Normalized weights.

    Returns
    -------
    (N,) ndarray
        Resampled indices, each shifted by 2.
    """
    n_weights = len(W)
    offset = np.random.rand(1)
    grid = (offset + np.arange(n_weights)) / n_weights
    indices = inverse_cdf(grid, W)
    return indices + 2
603+
604+
605+
@jit(nopython=True)
606+
def inverse_cdf(su, W):
    """Inverse CDF algorithm for a finite distribution.

    Parameters
    ----------
    su : (M,) ndarray
        M sorted uniform variates (i.e. M ordered points in [0,1]).
    W : (N,) ndarray
        a vector of N normalized weights (>=0 and sum to one)

    Returns
    -------
    A : (M,) ndarray
        a vector of M indices in range 0, ..., N-1
    """
    N = len(W)
    M = su.shape[0]
    A = np.empty(M, dtype=np.int64)
    j = 0
    s = W[0]  # running cumulative weight up to and including index j
    for n in range(M):
        # Stop at the last weight: floating-point rounding can leave the
        # cumulative sum marginally below su[n], and stepping past N-1 would
        # index outside W.
        while j < N - 1 and su[n] > s:
            j += 1
            s += W[j]
        A[n] = j
    return A
629+
580630

581631
def logp(point, out_vars, vars, shared):
582632
"""Compile Aesara function of the model and the input and output variables.

pymc_bart/tree.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,13 @@ def trim(self):
8383
del current_node.idx_data_points
8484
return a_tree
8585

86+
def get_split_variables(self):
    """Return the splitting-variable index of every split node in the tree.

    Walks the values of ``self.tree_structure`` in iteration order and
    collects ``idx_split_variable`` from each node that is a ``SplitNode``;
    all other nodes are ignored.
    """
    split_variables = []
    for node in self.tree_structure.values():
        if not isinstance(node, SplitNode):
            continue
        split_variables.append(node.idx_split_variable)
    return split_variables
92+
8693
def _predict(self):
8794
output = self.output
8895
for node_index in self.idx_leaf_nodes:

0 commit comments

Comments
 (0)