Skip to content

Commit 1861124

Browse files
committed
clean and update
1 parent f32bc75 commit 1861124

File tree

3 files changed

+9
-16
lines changed

3 files changed

+9
-16
lines changed

pymc_bart/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from pymc_bart.utils import plot_dependence, plot_variable_importance
1919

2020
__all__ = ["BART", "PGBART"]
21-
__version__ = "0.1.0"
21+
__version__ = "0.2.0"
2222

2323

2424
pm.STEP_METHODS = list(pm.STEP_METHODS) + [PGBART]

pymc_bart/pgbart.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -80,47 +80,43 @@ def __init__(
8080
else:
8181
self.X = self.bart.X
8282

83-
self.Y = self.bart.Y
8483
self.missing_data = np.any(np.isnan(self.X))
8584
self.m = self.bart.m
86-
self.alpha = self.bart.alpha
8785
shape = initial_values[value_bart.name].shape
8886
if len(shape) == 1:
8987
self.shape = 1
9088
else:
9189
self.shape = shape[0]
9290

93-
# self.alpha_vec = self.bart.split_prior
9491
if self.bart.split_prior:
9592
self.alpha_vec = self.bart.split_prior
9693
else:
9794
self.alpha_vec = np.ones(self.X.shape[1])
98-
self.init_mean = self.Y.mean()
95+
init_mean = self.bart.Y.mean()
9996
# if data is binary
100-
y_unique = np.unique(self.Y)
97+
y_unique = np.unique(self.bart.Y)
10198
if y_unique.size == 2 and np.all(y_unique == [0, 1]):
10299
mu_std = 3 / self.m**0.5
103-
# maybe we need to check for count data
104100
else:
105-
mu_std = self.Y.std() / self.m**0.5
101+
mu_std = self.bart.Y.std() / self.m**0.5
106102

107103
self.num_observations = self.X.shape[0]
108104
self.num_variates = self.X.shape[1]
109105
self.available_predictors = list(range(self.num_variates))
110106

111-
self.sum_trees = np.full((self.shape, self.Y.shape[0]), self.init_mean).astype(
107+
self.sum_trees = np.full((self.shape, self.bart.Y.shape[0]), init_mean).astype(
112108
config.floatX
113109
)
114-
self.sum_trees_noi = self.sum_trees - (self.init_mean / self.m)
110+
self.sum_trees_noi = self.sum_trees - (init_mean / self.m)
115111
self.a_tree = Tree(
116-
leaf_node_value=self.init_mean / self.m,
112+
leaf_node_value=init_mean / self.m,
117113
idx_data_points=np.arange(self.num_observations, dtype="int32"),
118114
num_observations=self.num_observations,
119115
shape=self.shape,
120116
)
121117
self.normal = NormalSampler(mu_std, self.shape)
122118
self.uniform = UniformSampler(0.33, 0.75, self.shape)
123-
self.prior_prob_leaf_node = compute_prior_probability(self.alpha)
119+
self.prior_prob_leaf_node = compute_prior_probability(self.bart.alpha)
124120
self.ssv = SampleSplittingVariable(self.alpha_vec)
125121

126122
self.tune = True
@@ -143,7 +139,7 @@ def __init__(
143139
self.likelihood_logp = logp(initial_values, [model.datalogp], vars, shared)
144140
self.all_particles = []
145141
for _ in range(self.m):
146-
self.a_tree.leaf_node_value = self.init_mean / self.m
142+
self.a_tree.leaf_node_value = init_mean / self.m
147143
p = ParticleTree(self.a_tree)
148144
self.all_particles.append(p)
149145
self.all_trees = np.array([p.tree for p in self.all_particles])

setup.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,6 @@ def get_version():
7575
long_description=LONG_DESCRIPTION,
7676
long_description_content_type="text/markdown",
7777
packages=find_packages(),
78-
# because of an upload-size limit by PyPI, we're temporarily removing docs from the tarball.
79-
# Also see MANIFEST.in
80-
# package_data={'docs': ['*']},
8178
include_package_data=True,
8279
classifiers=classifiers,
8380
python_requires=">=3.8",

0 commit comments

Comments
 (0)