Skip to content

Commit ecef8a7

Browse files
committed
add numba, small tweaks
1 parent 7844755 commit ecef8a7

File tree

2 files changed

+70
-75
lines changed

2 files changed

+70
-75
lines changed

pymc_bart/pgbart.py

Lines changed: 69 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414

1515
import logging
1616

17-
from copy import copy, deepcopy
18-
from numba import jit
17+
from copy import deepcopy
18+
from numba import njit
1919

2020
import aesara
2121
import numpy as np
@@ -57,7 +57,7 @@ class PGBART(ArrayStepShared):
5757
def __init__(
5858
self,
5959
vars=None,
60-
num_particles=40,
60+
num_particles=20,
6161
batch="auto",
6262
model=None,
6363
):
@@ -105,8 +105,6 @@ def __init__(
105105
idx_data_points=np.arange(self.num_observations, dtype="int32"),
106106
shape=self.shape,
107107
)
108-
self.mean = fast_mean()
109-
110108
self.normal = NormalSampler(mu_std, self.shape)
111109
self.uniform = UniformSampler(0.33, 0.75, self.shape)
112110
self.prior_prob_leaf_node = compute_prior_probability(self.alpha)
@@ -159,7 +157,6 @@ def astep(self, _):
159157
self.X,
160158
self.missing_data,
161159
self.sum_trees,
162-
self.mean,
163160
self.m,
164161
self.normal,
165162
self.shape,
@@ -204,7 +201,7 @@ def astep(self, _):
204201
for index in used_variates:
205202
variable_inclusion[index] += 1
206203

207-
stats = {"variable_inclusion": variable_inclusion, "bart_trees": copy(self.all_trees)}
204+
stats = {"variable_inclusion": variable_inclusion, "bart_trees": self.all_trees}
208205
return self.sum_trees, [stats]
209206

210207
def normalize(self, particles):
@@ -226,15 +223,20 @@ def normalize(self, particles):
226223
return w_t, normalized_weights
227224

228225
def resample(self, particles, normalized_weights):
229-
"""Use systematic resample for all but first two particles"""
226+
"""
227+
Use systematic resampling for all but the first two particles.
228+
229+
Ensure particles are copied only if needed.
230+
"""
231+
new_indices = systematic(normalized_weights)
230232
seen = []
231233
new_particles = []
232234
for idx in new_indices:
233235
if idx in seen:
234236
new_particles.append(deepcopy(particles[idx]))
235237
else:
236-
seen.append(idx)
237238
new_particles.append(particles[idx])
239+
seen.append(idx)
238240

239241
particles[2:] = new_particles
240242

@@ -243,15 +245,13 @@ def resample(self, particles, normalized_weights):
243245
def init_particles(self, tree_id: int) -> np.ndarray:
244246
"""Initialize particles."""
245247
p0 = self.all_particles[tree_id]
246-
p1 = copy(p0)
248+
p1 = deepcopy(p0)
247249
p1.sample_leafs(
248250
self.sum_trees,
249-
self.mean,
250251
self.m,
251252
self.normal,
252253
self.shape,
253254
)
254-
255255
# The old tree and the one with new leafs do not grow so we update the weights only once
256256
self.update_weight(p0, old=True)
257257
self.update_weight(p1, old=True)
@@ -303,7 +303,6 @@ def __init__(self, tree):
303303
self.old_likelihood_logp = 0
304304
self.kf = 0.75
305305

306-
307306
def sample_tree(
308307
self,
309308
ssv,
@@ -312,7 +311,6 @@ def sample_tree(
312311
X,
313312
missing_data,
314313
sum_trees,
315-
mean,
316314
m,
317315
normal,
318316
shape,
@@ -332,7 +330,6 @@ def sample_tree(
332330
X,
333331
missing_data,
334332
sum_trees,
335-
mean,
336333
m,
337334
normal,
338335
self.kf,
@@ -345,15 +342,14 @@ def sample_tree(
345342

346343
return tree_grew
347344

348-
def sample_leafs(self, sum_trees, mean, m, normal, shape):
345+
def sample_leafs(self, sum_trees, m, normal, shape):
349346

350347
for idx in self.tree.idx_leaf_nodes:
351348
if idx > 0:
352349
leaf = self.tree[idx]
353350
idx_data_points = leaf.idx_data_points
354351
node_value = draw_leaf_value(
355352
sum_trees[:, idx_data_points],
356-
mean,
357353
m,
358354
normal,
359355
self.kf,
@@ -414,7 +410,6 @@ def grow_tree(
414410
X,
415411
missing_data,
416412
sum_trees,
417-
mean,
418413
m,
419414
normal,
420415
kf,
@@ -443,7 +438,6 @@ def grow_tree(
443438
idx_data_point = new_idx_data_points[idx]
444439
node_value = draw_leaf_value(
445440
sum_trees[:, idx_data_point],
446-
mean,
447441
m,
448442
normal,
449443
kf,
@@ -496,7 +490,7 @@ def get_split_value(available_splitting_values, idx_data_points, missing_data):
496490
return split_value
497491

498492

499-
def draw_leaf_value(Y_mu_pred, mean, m, normal, kf, shape):
493+
def draw_leaf_value(Y_mu_pred, m, normal, kf, shape):
500494
"""Draw Gaussian distributed leaf values."""
501495
if Y_mu_pred.size == 0:
502496
return np.zeros(shape)
@@ -505,40 +499,31 @@ def draw_leaf_value(Y_mu_pred, mean, m, normal, kf, shape):
505499
if Y_mu_pred.size == 1:
506500
mu_mean = np.full(shape, Y_mu_pred.item() / m)
507501
else:
508-
mu_mean = mean(Y_mu_pred) / m
502+
mu_mean = fast_mean(Y_mu_pred) / m
509503

510504
draw = norm + mu_mean
511505
return draw
512506

513507

514-
def fast_mean():
515-
"""If available use Numba to speed up the computation of the mean."""
516-
try:
517-
from numba import jit
518-
except ImportError:
519-
from functools import partial
508+
@njit
509+
def fast_mean(a):
510+
"""Use Numba to speed up the computation of the mean."""
511+
512+
if a.ndim == 1:
513+
count = a.shape[0]
514+
suma = 0
515+
for i in range(count):
516+
suma += a[i]
517+
return suma / count
518+
elif a.ndim == 2:
519+
res = np.zeros(a.shape[0])
520+
count = a.shape[1]
521+
for j in range(a.shape[0]):
522+
for i in range(count):
523+
res[j] += a[j, i]
524+
return res / count
520525

521-
return partial(np.mean, axis=1)
522526

523-
@jit
524-
def mean(a):
525-
if a.ndim == 1:
526-
count = a.shape[0]
527-
suma = 0
528-
for i in range(count):
529-
suma += a[i]
530-
return suma / count
531-
elif a.ndim == 2:
532-
res = np.zeros(a.shape[0])
533-
count = a.shape[1]
534-
for j in range(a.shape[0]):
535-
for i in range(count):
536-
res[j] += a[j, i]
537-
return res / count
538-
539-
return mean
540-
541-
@jit()
542527
def discrete_uniform_sampler(upper_value):
543528
"""Draw from the uniform distribution with bounds [0, upper_value).
544529
@@ -555,15 +540,14 @@ def __init__(self, scale, shape):
555540
self.scale = scale
556541
self.shape = shape
557542
self.update()
558-
543+
559544
def random(self):
560545
if self.idx == self.size:
561546
self.update()
562547
pop = self.cache[:, self.idx]
563548
self.idx += 1
564549
return pop
565550

566-
567551
def update(self):
568552
self.idx = 0
569553
self.cache = np.random.normal(loc=0.0, scale=self.scale, size=(self.shape, self.size))
@@ -586,44 +570,54 @@ def random(self):
586570
self.idx += 1
587571
return pop
588572

589-
590573
def update(self):
591574
self.idx = 0
592575
self.cache = np.random.uniform(
593576
self.lower_bound, self.upper_bound, size=(self.shape, self.size)
594577
)
595578

596-
@jit()
597-
def systematic(W):
598-
"""Systematic resampling.
579+
580+
def systematic(normalized_weights):
599581
"""
600-
M = len(W)
601-
su = (np.random.rand(1) + np.arange(M)) / M
602-
return inverse_cdf(su, W) + 2
603-
604-
605-
@jit(nopython=True)
606-
def inverse_cdf(su, W):
607-
"""Inverse CDF algorithm for a finite distribution.
608-
Parameters
609-
----------
610-
su: (M,) ndarray
611-
M sorted uniform variates (i.e. M ordered points in [0,1]).
612-
W: (N,) ndarray
613-
a vector of N normalized weights (>=0 and sum to one)
614-
Returns
615-
-------
616-
A: (M,) ndarray
617-
a vector of M indices in range 0, ..., N-1
582+
Systematic resampling.
583+
584+
Return indices in the range 2, ..., len(normalized_weights)+1
585+
586+
Note: adapted from https://github.com/nchopin/particles
587+
"""
588+
lnw = len(normalized_weights)
589+
single_uniform = (np.random.rand(1) + np.arange(lnw)) / lnw
590+
return inverse_cdf(single_uniform, normalized_weights) + 2
591+
592+
593+
@njit
594+
def inverse_cdf(single_uniform, normalized_weights):
595+
"""
596+
Inverse CDF algorithm for a finite distribution.
597+
598+
Parameters
599+
----------
600+
single_uniform: ndarray
601+
ordered points in [0,1]
602+
603+
normalized_weights: ndarray
604+
normalized weights
605+
606+
Returns
607+
-------
608+
A: ndarray
609+
a vector of indices in range 0, ..., len(normalized_weights)-1
610+
611+
Note: adapted from https://github.com/nchopin/particles
618612
"""
619613
j = 0
620-
s = W[0]
621-
M = su.shape[0]
614+
s = normalized_weights[0]
615+
M = single_uniform.shape[0]
622616
A = np.empty(M, dtype=np.int64)
623617
for n in range(M):
624-
while su[n] > s:
618+
while single_uniform[n] > s:
625619
j += 1
626-
s += W[j]
620+
s += normalized_weights[j]
627621
A[n] = j
628622
return A
629623

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
pymc>=4.1.7
22
arviz>=0.12.1
3+
numba>=0.55.1

0 commit comments

Comments
 (0)