
Commit 97d267c

Re-run model using Metropolis
1 parent 3900c2c commit 97d267c

File tree

2 files changed: +1796 −1428 lines changed


examples/mixture_models/dependent_density_regression.ipynb

Lines changed: 1661 additions & 1401 deletions
Large diffs are not rendered by default.

examples/mixture_models/dependent_density_regression.myst.md

Lines changed: 135 additions & 27 deletions
@@ -10,6 +10,8 @@ kernelspec:
   name: python3
 ---
 
++++ {"id": "FDW0_THqg8LC"}
+
 (dependent_density_regression)=
 # Dependent density regression
 :::{post} 2017
@@ -23,6 +25,12 @@ In another [example](dp_mix.ipynb), we showed how to use Dirichlet processes to
 Just as Dirichlet process mixtures can be thought of as infinite mixture models that select the number of active components as part of inference, dependent density regression can be thought of as infinite [mixtures of experts](https://en.wikipedia.org/wiki/Committee_machine) that select the active experts as part of inference. Their flexibility and modularity make them powerful tools for performing nonparametric Bayesian data analysis.
 
 ```{code-cell} ipython3
+---
+colab:
+  base_uri: https://localhost:8080/
+id: wSEx-eTag8LD
+outputId: a962b5ff-d107-47f8-b413-5dc0480648bf
+---
 from io import StringIO
 
 import arviz as az
@@ -40,17 +48,27 @@ print(f"Running on PyMC v{pm.__version__}")
 ```
 
 ```{code-cell} ipython3
+:id: 0iVlIVjig8LE
+
 %config InlineBackend.figure_format = 'retina'
 plt.rc("animation", writer="ffmpeg")
 blue, *_ = sns.color_palette()
 az.style.use("arviz-darkgrid")
-SEED = 972915  # from random.org; for reproducibility
+SEED = 1972917  # from random.org; for reproducibility
 np.random.seed(SEED)
 ```
 
++++ {"id": "3VHUk32Mg8LE"}
+
 We will use the LIDAR data set from Larry Wasserman's excellent book, [_All of Nonparametric Statistics_](http://www.stat.cmu.edu/~larry/all-of-nonpar/). We standardize the data set to improve the rate of convergence of our samples.
 
 ```{code-cell} ipython3
+---
+colab:
+  base_uri: https://localhost:8080/
+id: cVuo7yrRg8LE
+outputId: bc357830-c080-453c-ff24-8154c328817b
+---
 DATA_URI = "http://www.stat.cmu.edu/~larry/all-of-nonpar/=data/lidar.dat"
 
 
@@ -65,12 +83,28 @@ df = pd.read_csv(StringIO(response.text), sep=r"\s{1,3}", engine="python").assig
 ```
 
 ```{code-cell} ipython3
+---
+colab:
+  base_uri: https://localhost:8080/
+  height: 206
+id: i30x-q2Cg8LE
+outputId: 791768de-d65e-47f8-9aa2-ffecea186946
+---
 df.head()
 ```
 
++++ {"id": "tylbzDhcg8LE"}
+
 We plot the LIDAR data below.
 
 ```{code-cell} ipython3
+---
+colab:
+  base_uri: https://localhost:8080/
+  height: 628
+id: HuFM6Wq8g8LE
+outputId: 4240b043-428a-4923-9a48-3e1f24461842
+---
 fig, ax = plt.subplots(figsize=(8, 6))
 
 ax.scatter(df.std_range, df.std_logratio, color=blue)
@@ -82,6 +116,8 @@ ax.set_yticklabels([])
 ax.set_ylabel("Standardized log ratio");
 ```
 
++++ {"id": "2mYwxtSfg8LE"}
+
 This data set has two interesting properties that make it useful for illustrating dependent density regression.
 
 1. The relationship between range and log ratio is nonlinear, but has locally linear components.
@@ -90,6 +126,8 @@ This data set has a two interesting properties that make it useful for illustrat
 The intuitive idea behind dependent density regression is to reduce the problem to many (related) density estimates, conditioned on fixed values of the predictors. The following animation illustrates this intuition.
 
 ```{code-cell} ipython3
+:id: di7x_3pvg8LE
+
 fig, (scatter_ax, hist_ax) = plt.subplots(ncols=2, figsize=(16, 6))
 
 scatter_ax.scatter(df.std_range, df.std_logratio, color=blue, zorder=2)
@@ -130,11 +168,20 @@ plt.close()
 ```
 
 ```{code-cell} ipython3
+---
+colab:
+  base_uri: https://localhost:8080/
+  height: 641
+id: SyWtHa72g8LE
+outputId: c48bcfec-aa82-41ec-ce9a-32667117125e
+---
 from IPython.display import HTML
 
 HTML(animation.to_html5_video())
 ```
 
++++ {"id": "i3B2R7-vg8LE"}
+
 As we slice the data with a window sliding along the x-axis in the left plot, the empirical distribution of the y-values of the points in the window varies in the right plot. An important aspect of this approach is that the density estimates that correspond to close values of the predictor are similar.
 
 In the previous example, we saw that a Dirichlet process estimates a probability density as a mixture model with infinitely many components. In the case of normal component distributions,
@@ -155,37 +202,43 @@ where $\Phi$ is the cumulative distribution function of the standard normal dist
 
 $$w_i\ |\ x = v_i\ |\ x \cdot \prod_{j = 1}^{i - 1} (1 - v_j\ |\ x).$$
 
-For the LIDAR data set, we use independent normal priors $\alpha_i \sim N(0, 5^2)$ and $\beta_i \sim N(0, 5^2)$. We now express this model for the conditional mixture weights using `PyMC3`.
+For the LIDAR data set, we use independent normal priors $\alpha_i \sim N(0, 5^2)$ and $\beta_i \sim N(0, 5^2)$. We now express this model for the conditional mixture weights using `PyMC`.
 
 ```{code-cell} ipython3
+:id: 5EgbxpkUg8LE
+
 def norm_cdf(z):
     return 0.5 * (1 + pt.erf(z / np.sqrt(2)))
 
 
 def stick_breaking(v):
     return v * pt.concatenate(
-        [pt.ones_like(v[:, :1]), pt.extra_ops.cumprod(1 - v, axis=1)[:, :-1]], axis=1
+        [pt.ones_like(v[:, :1]), pt.extra_ops.cumprod(1 - v[:, :-1], axis=1)], axis=1
     )
 ```
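As a quick reference for the `stick_breaking` cell above (this note and sketch are an annotation, not part of the commit), the weight construction can be checked in plain NumPy. The array `v` below is a made-up stand-in for the `norm_cdf(alpha + pt.outer(x, beta))` values in the model:

```python
import numpy as np

rng = np.random.default_rng(0)
v = rng.uniform(0.1, 0.9, size=(5, 4))  # dummy (N, K) stick fractions v_i | x

# w_i = v_i * prod_{j < i} (1 - v_j): the leading column of ones leaves the first
# weight as v_1; the cumulative product over 1 - v[:, :-1] supplies the later terms.
remaining_stick = np.concatenate(
    [np.ones((v.shape[0], 1)), np.cumprod(1 - v[:, :-1], axis=1)], axis=1
)
w = v * remaining_stick

print(w.sum(axis=1))  # each row sums to at most 1; the deficit is the truncated tail
```

Note that the old and new `cumprod` expressions compute the same leading products of $1 - v_j$; the new form simply avoids building a final column that is never used.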

 ```{code-cell} ipython3
+:id: qtZS8sing8LE
+
 N = len(df)
 K = 20
 
-std_range = df.std_range.values[:, np.newaxis]
+std_range = df.std_range.values
 std_logratio = df.std_logratio.values
 
-with pm.Model(coords={"N": np.arange(N), "K": np.arange(K) + 1, "one": [1]}) as model:
-    alpha = pm.Normal("alpha", 0.0, 5.0, dims="K")
-    beta = pm.Normal("beta", 0.0, 5.0, dims=("one", "K"))
-    x = pm.Data("x", std_range)
-    v = norm_cdf(alpha + x @ beta)
+with pm.Model(coords={"N": np.arange(N), "K": np.arange(K) + 1}) as model:
+    alpha = pm.Normal("alpha", 0, 5, dims="K")
+    beta = pm.Normal("beta", 0, 5, dims="K")
+    x = pm.Data("x", std_range, dims="N")
+    v = norm_cdf(alpha + pt.outer(x, beta))
     w = pm.Deterministic("w", stick_breaking(v), dims=["N", "K"])
 ```
 
-We have defined `x` as a `pm.Data` container in order to use `PyMC3`'s posterior prediction capabilities later.
++++ {"id": "TKt9RzIVg8LF"}
+
+We have defined `x` as a `pm.Data` container in order to use `PyMC`'s posterior prediction capabilities later.
 
-While the dependent density regression model theoretically has infinitely many components, we must truncate the model to finitely many components (in this case, twenty) in order to express it using `PyMC3`. After sampling from the model, we will verify that truncation did not unduly influence our results.
+While the dependent density regression model theoretically has infinitely many components, we must truncate the model to finitely many components (in this case, twenty) in order to express it using `PyMC`. After sampling from the model, we will verify that truncation did not unduly influence our results.
 
 Since the LIDAR data seems to have several linear components, we use the linear models
 
@@ -203,33 +256,63 @@ $$
 for the conditional component means.
 
 ```{code-cell} ipython3
+:id: qMLOhLHsg8LF
+
 with model:
-    gamma = pm.Normal("gamma", 0.0, 10.0, dims="K")
-    delta = pm.Normal("delta", 0.0, 10.0, dims=("one", "K"))
-    mu = pm.Deterministic("mu", gamma + x @ delta)
+    gamma = pm.Normal("gamma", 0, 3, dims="K")
+    delta = pm.Normal("delta", 0, 3, dims="K")
+    mu = pm.Deterministic("mu", gamma + pt.outer(x, delta), dims=("N", "K"))
 ```
 
-Finally, we place the prior $\tau_i \sim \textrm{Gamma}(1, 1)$ on the component precisions.
++++ {"id": "4dcBWBbvg8LF"}
+
+Finally, we specify a `NormalMixture` likelihood function, using the weights we have modeled above.
 
 ```{code-cell} ipython3
+---
+colab:
+  base_uri: https://localhost:8080/
+  height: 487
+id: ag8Lwc9sg8LF
+outputId: 85b8d803-d144-4073-8e5d-7f3ffd35e48a
+---
 with model:
-    tau = pm.Gamma("tau", 1.0, 1.0, dims="K")
-    y = pm.Data("y", std_logratio)
-    obs = pm.NormalMixture("obs", w, mu, tau=tau, observed=y)
+    sigma = pm.HalfNormal("sigma", 3, dims="K")
+    y = pm.Data("y", std_logratio, dims="N")
+    obs = pm.NormalMixture("obs", w, mu, sigma=sigma, observed=y, dims="N")
 
 pm.model_to_graphviz(model)
 ```
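For reference (an annotation, not text from the commit), the `NormalMixture` specified above corresponds to the conditional density

$$f(y\ |\ x) = \sum_{i = 1}^{K} (w_i\ |\ x) \cdot N(y\ |\ \gamma_i + \delta_i x,\ \sigma_i^2),$$

where the commit replaces the earlier $\textrm{Gamma}(1, 1)$ priors on the component precisions $\tau_i$ with $\textrm{HalfNormal}(3)$ priors on the component scales $\sigma_i$.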

-We now sample from the dependent density regression model.
++++ {"id": "gUPThEEEg8LF"}
+
+We now sample from the dependent density regression model using a Metropolis sampler. The default NUTS sampler has a difficult time sampling from this model, and the traceplots show poor convergence.
 
 ```{code-cell} ipython3
+---
+colab:
+  base_uri: https://localhost:8080/
+  height: 70
+  referenced_widgets: [e2c19d27c2d24df69b2570d2580009a1, 6d10b9e9b680495386f1803d8994c2fb]
+id: FSYdNHFUg8LF
+outputId: 829d4ee8-c971-4962-aa71-265f93eeb356
+---
 with model:
-    trace = pm.sample(random_seed=SEED)
+    trace = pm.sample(random_seed=SEED, step=pm.Metropolis(), draws=10_000, tune=10_000, cores=2)
 ```
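An optional follow-up (not part of the commit): one way to sanity-check the Metropolis run above is to look at standard convergence diagnostics on the resulting trace. The sketch below only uses variables already defined in the diffed notebook:

```python
import arviz as az

# r_hat close to 1 and a reasonable bulk effective sample size suggest the chains
# have mixed, which is worth confirming after switching from NUTS to Metropolis.
summary = az.summary(trace, var_names=["alpha", "beta", "gamma", "delta", "sigma"])
print(summary[["r_hat", "ess_bulk"]].describe())

az.plot_trace(trace, var_names=["gamma", "delta"]);
```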

++++ {"id": "io6KXPdgg8LF"}
+
 To verify that truncation did not unduly influence our results, we plot the largest posterior expected mixture weight for each component. (In this model, each point has a mixture weight for each component, so we plot the maximum mixture weight for each component across all data points in order to judge if the component exerts any influence on the posterior.)
 
 ```{code-cell} ipython3
+---
+colab:
+  base_uri: https://localhost:8080/
+  height: 628
+id: L_yuCm6Fg8LF
+outputId: dda7fd9e-b609-4a23-8dc1-d298353c7182
+---
 fig, ax = plt.subplots(figsize=(8, 6))
 
 max_mixture_weights = trace.posterior["w"].mean(("chain", "draw")).max("N")
@@ -242,28 +325,45 @@ ax.set_xlabel("Mixture component")
 ax.set_ylabel("Largest posterior expected\nmixture weight");
 ```
 
++++ {"id": "6Pq0WqBbg8LF"}
+
 Since only three mixture components have appreciable posterior expected weight for any data point, we can be fairly certain that truncation did not unduly influence our results. (If most components had appreciable posterior expected weight, truncation may have influenced the results, and we would have increased the number of components and sampled again.)
 
 Visually, it is reasonable that the LIDAR data has three linear components, so these posterior expected weights seem to have identified the structure of the data well. We now sample from the posterior predictive distribution to better understand the model's performance.
 
 ```{code-cell} ipython3
+---
+colab:
+  base_uri: https://localhost:8080/
+  height: 33
+  referenced_widgets: [64628338cd314dcf998fdcdec5e64a2c, 8283e2190c4d45a6926da9d95273d376]
+id: -tAIHunXg8LF
+outputId: 733df6c3-aa98-44b6-bace-cc2075cee2a9
+---
 lidar_pp_x = np.linspace(std_range.min() - 0.05, std_range.max() + 0.05, 100)
 
 with model:
-    pm.set_data({"x": lidar_pp_x[:, np.newaxis], "y": np.zeros_like(lidar_pp_x)})
+    pm.set_data(
+        {"x": lidar_pp_x, "y": np.zeros_like(lidar_pp_x)}, coords={"N": np.arange(len(lidar_pp_x))}
+    )
 
     pm.sample_posterior_predictive(
         trace, predictions=True, extend_inferencedata=True, random_seed=SEED
     )
 ```
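Not part of the commit: a short sketch of how the posterior expected value and 95% interval mentioned just below could be pulled from the `predictions` group created above, assuming `obs` is indexed by `chain`, `draw`, and the observation dimension:

```python
# Summarize the predictions group along the grid defined by lidar_pp_x.
pp_obs = trace.predictions["obs"]
pp_mean = pp_obs.mean(("chain", "draw"))
pp_quantiles = pp_obs.quantile([0.025, 0.975], dim=("chain", "draw"))
pp_low = pp_quantiles.sel(quantile=0.025)  # lower bound of the 95% interval
pp_high = pp_quantiles.sel(quantile=0.975)  # upper bound of the 95% interval
```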

-Below we plot the posterior expected value and the 95% posterior credible interval.
++++ {"id": "UecH3-jAg8LF"}
 
-```{code-cell} ipython3
-trace.predictions["obs"].mean(("chain", "draw")).shape
-```
+Below we plot the posterior expected value and the 95% posterior credible interval.
 
 ```{code-cell} ipython3
+---
+colab:
+  base_uri: https://localhost:8080/
+  height: 628
+id: m2ZWtSuQg8LF
+outputId: ade722fc-744c-4b8a-8bf6-ec4fe55ce657
+---
 fig, ax = plt.subplots(figsize=(8, 6))
 
 ax.scatter(df.std_range, df.std_logratio, color=blue, zorder=10, label=None)
291391
ax.set_title("LIDAR Data");
292392
```
293393

394+
+++ {"id": "0vFYLTpZg8LF"}
395+
294396
The model has fit the linear components of the data well, and also accommodated its heteroskedasticity. This flexibility, along with the ability to modularly specify the conditional mixture weights and conditional component densities, makes dependent density regression an extremely useful nonparametric Bayesian model.
295397

296398
To learn more about dependent density regression and related models, consult [_Bayesian Data Analysis_](http://www.stat.columbia.edu/~gelman/book/), [_Bayesian Nonparametric Data Analysis_](http://www.springer.com/us/book/9783319189673), or [_Bayesian Nonparametrics_](https://www.google.com/webhp?sourceid=chrome-instant&ion=1&espv=2&ie=UTF-8#q=bayesian+nonparametrics+book).
297399

298400
This example first appeared [here](http://austinrochford.com/posts/2017-01-18-ddp-pymc3.html).
299401

300-
+++
402+
+++ {"id": "CxDFNZDtg8LF"}
301403

302404
## Authors
303405
* authored by Austin Rochford in 2017
304406
* updated to PyMC v5 by Christopher Fonnesbeck in September2024
305407

306-
+++
408+
+++ {"id": "e41HT-6Og8LF"}
307409

308410
## References
309411

@@ -312,6 +414,12 @@ This example first appeared [here](http://austinrochford.com/posts/2017-01-18-dd
312414
:::
313415

314416
```{code-cell} ipython3
417+
---
418+
colab:
419+
base_uri: https://localhost:8080/
420+
id: NMqJeLTAg8LF
421+
outputId: 2a8b67c1-2922-4aff-82b2-392d66190951
422+
---
315423
%load_ext watermark
316424
%watermark -n -u -v -iv -w
317425
```
