Skip to content

Commit 47503bf

Browse files
authored
pgbart step arguments (#5458)
* pgbart step arguments * fix predictions with one chain
1 parent cb7995f commit 47503bf

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

pymc/bart/pgbart.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ class PGBART(ArrayStepShared):
5050
Optional model for sampling step. Defaults to None (taken from context).
5151
"""
5252

53-
name = "bartsampler"
53+
name = "pgbart"
5454
default_blocked = False
5555
generates_stats = True
5656
stats_dtypes = [{"variable_inclusion": np.ndarray, "bart_trees": np.ndarray}]
@@ -59,7 +59,12 @@ def __init__(self, vars=None, num_particles=40, max_stages=100, batch="auto", mo
5959
_log.warning("BART is experimental. Use with caution.")
6060
model = modelcontext(model)
6161
initial_values = model.compute_initial_point()
62-
value_bart = inputvars(vars)[0]
62+
if vars is None:
63+
vars = model.value_vars
64+
else:
65+
vars = [model.rvs_to_values.get(var, var) for var in vars]
66+
vars = inputvars(vars)
67+
value_bart = vars[0]
6368
self.bart = model.values_to_rvs[value_bart].owner.op
6469

6570
self.X = self.bart.X
@@ -200,7 +205,7 @@ def astep(self, _):
200205
for index in new_particle.used_variates:
201206
variable_inclusion[index] += 1
202207

203-
stats = {"variable_inclusion": variable_inclusion, "bart_trees": self.all_trees}
208+
stats = {"variable_inclusion": variable_inclusion, "bart_trees": copy(self.all_trees)}
204209
return self.sum_trees, [stats]
205210

206211
def normalize(self, particles):

pymc/sampling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,7 @@ def sample(
403403
404404
``nuts``, ``hmc``, ``metropolis``, ``binary_metropolis``,
405405
``binary_gibbs_metropolis``, ``categorical_gibbs_metropolis``,
406-
``DEMetropolis``, ``DEMetropolisZ``, ``slice``
406+
``DEMetropolis``, ``DEMetropolisZ``, ``slice``, ``pgbart``
407407
408408
B. If you manually declare the ``step_method``\ s, within the ``step``
409409
kwarg, then you can address the ``step_method`` kwargs directly.

0 commit comments

Comments
 (0)