Commit ecbca86

Merge remote-tracking branch 'upstream/main'
2 parents 8b2effe + fa173c7

42 files changed: +36900, -12624 lines changed

examples/case_studies/BART_quantile_regression.ipynb

Lines changed: 600 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 171 additions & 0 deletions
@@ -0,0 +1,171 @@
---
jupytext:
  text_representation:
    extension: .md
    format_name: myst
    format_version: 0.13
kernelspec:
  display_name: Python 3 (ipykernel)
  language: python
  name: python3
---

(BART_quantile)=
# Quantile Regression with BART
:::{post} Jan 25, 2023
:tags: BART, non-parametric, quantile, regression
:category: intermediate, explanation
:author: Osvaldo Martin
:::

```{code-cell} ipython3
from pathlib import Path

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
import pymc_bart as pmb

from scipy import stats

print(f"Running on PyMC v{pm.__version__}")
```

```{code-cell} ipython3
%config InlineBackend.figure_format = "retina"
RANDOM_SEED = 5781
np.random.seed(RANDOM_SEED)
az.style.use("arviz-darkgrid")
```

Usually when doing regression we model the conditional mean of some distribution. Common cases are a Normal distribution for continuous unbounded responses, a Poisson distribution for count data, etc.

Quantile regression, instead, estimates a conditional quantile of the response variable. If the quantile is 0.5, then we will be estimating the median (instead of the mean); this can be useful as a way of performing robust regression, in a similar fashion to using a Student-t distribution instead of a Normal. But for some problems we actually care about the behavior of the response away from the mean (or median). For example, in medical research, pathologies or potential health risks occur at high or low quantiles, for instance overweight and underweight. In other fields, like ecology, quantile regression is justified by the existence of complex interactions between variables, where the effect of one variable on another differs across the range of the variable.

+++

## Asymmetric Laplace distribution

At first it may seem odd to think about which distribution we should use as the likelihood for quantile regression, or how to write a Bayesian model for quantile regression. But it turns out the answer is quite simple: we just need to use the asymmetric Laplace distribution. This distribution has one parameter controlling the location, another for the scale, and a third one for the asymmetry. There are at least two alternative parameterizations of the asymmetry parameter: in terms of $\kappa$, a parameter that goes from 0 to $\infty$, or in terms of $q$, a number between 0 and 1. The latter parameterization is more intuitive for quantile regression, as we can directly interpret it as the quantile of interest.

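To make the connection with quantile regression explicit, the density in the $q$ parameterization can be written as follows (this is the usual form from the quantile regression literature, given here for reference; it is not necessarily the exact parameterization PyMC uses internally):

$$
f(y \mid \mu, \sigma, q) = \frac{q(1-q)}{\sigma}
\exp\left(-\rho_q\left(\frac{y-\mu}{\sigma}\right)\right),
\qquad
\rho_q(u) = u\left(q - \mathbb{1}[u < 0]\right)
$$

where $\rho_q$ is the check (or pinball) loss, so maximizing this likelihood with respect to $\mu$ is equivalent to minimizing the loss used in classical quantile regression. The two parameterizations are related by $\kappa = \sqrt{q/(1-q)}$, the conversion used below.
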
In the next cell we compute the pdf of 3 distributions from the Asymmetric Laplace family.

```{code-cell} ipython3
x = np.linspace(-6, 6, 2000)
for q, m in zip([0.2, 0.5, 0.8], [0, 0, -1]):
    κ = (q / (1 - q)) ** 0.5
    plt.plot(x, stats.laplace_asymmetric(κ, m, 1).pdf(x), label=f"q={q:}, μ={m}, σ=1")
plt.yticks([])
plt.legend();
```

We are going to use a simple dataset to model the Body Mass Index for Dutch kids and young men as a function of their age.

```{code-cell} ipython3
try:
    bmi = pd.read_csv(Path("..", "data", "bmi.csv"))
except FileNotFoundError:
    bmi = pd.read_csv(pm.get_data("bmi.csv"))

bmi.plot(x="age", y="bmi", kind="scatter");
```

As we can see from the previous figure, the relationship between BMI and age is far from linear, and hence we are going to use BART.

We are going to model 3 quantiles: 0.1, 0.5, and 0.9. We could compute these quantities by fitting 3 separate models, the sole difference being the value of `q` of the Asymmetric Laplace distribution (a sketch of this alternative is shown after the next cell). Or we can stack the observed values as in `y_stack` and fit a single model.

```{code-cell} ipython3
y = bmi.bmi.values
X = bmi.age.values[:, None]

y_stack = np.stack([bmi.bmi.values] * 3)
quantiles = np.array([[0.1, 0.5, 0.9]]).T
quantiles
```

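As mentioned above, an alternative to the stacked model is to fit one model per quantile. A minimal, hypothetical sketch of one such model (here for q=0.5, the median; `model_q50` and `idata_q50` are illustrative names, not part of the original notebook) could look like this:

```python
# Hypothetical single-quantile alternative: one model per quantile (here q=0.5).
# Repeating this for q=0.1 and q=0.9 would reproduce the three curves separately.
with pm.Model() as model_q50:
    μ = pmb.BART("μ", X, y)  # flexible function of age, here the conditional median
    σ = pm.HalfNormal("σ", 5)  # scale of the Asymmetric Laplace likelihood
    obs = pm.AsymmetricLaplace("obs", mu=μ, b=σ, q=0.5, observed=y)

    idata_q50 = pm.sample(compute_convergence_checks=False)
```

Stacking the observations instead, as below, lets us fit the three quantiles with a single model.
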
```{code-cell} ipython3
with pm.Model() as model:
    μ = pmb.BART("μ", X, y, shape=(3, 7294))  # 3 quantiles × 7294 observations
    σ = pm.HalfNormal("σ", 5)
    obs = pm.AsymmetricLaplace("obs", mu=μ, b=σ, q=quantiles, observed=y_stack)

    idata = pm.sample(compute_convergence_checks=False)
```

We can see the result of the 3 fitted curves in the next figure. One feature that stands out is that the gap, or distance, between the median (orange) line and the other two lines is not the same. Also, the shapes of the curves, while following a similar pattern, are not exactly the same.

```{code-cell} ipython3
plt.plot(bmi.age, bmi.bmi, ".", color="0.5")
for idx, q in enumerate(quantiles[:, 0]):
    plt.plot(
        bmi.age,
        idata.posterior["μ"].mean(("chain", "draw")).sel(μ_dim_0=idx),
        label=f"q={q:}",
        lw=3,
    )

plt.legend();
```

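As a quick, optional sanity check (hypothetical code, not part of the original notebook), we can verify that roughly a fraction $q$ of the observations falls below each fitted curve:

```python
# Hypothetical check: the empirical proportion of observations below the
# posterior-mean curve for quantile q should be close to q itself.
posterior_means = idata.posterior["μ"].mean(("chain", "draw"))  # dims: (3 quantiles, n_obs)
for idx, q in enumerate(quantiles[:, 0]):
    below = (bmi.bmi.values < posterior_means.sel(μ_dim_0=idx).values).mean()
    print(f"q={q}: proportion of observations below the curve = {below:.3f}")
```
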
To better understand these remarks, let's compute a BART regression with a Normal likelihood, and then compute the same 3 quantiles from that fit.

```{code-cell} ipython3
y = bmi.bmi.values
x = bmi.age.values[:, None]
with pm.Model() as model:
    μ = pmb.BART("μ", x, y)
    σ = pm.HalfNormal("σ", 5)
    obs = pm.Normal("obs", mu=μ, sigma=σ, observed=y)

    idata_g = pm.sample(compute_convergence_checks=False)
    idata_g.extend(pm.sample_posterior_predictive(idata_g))
```

```{code-cell} ipython3
# quantiles of the posterior predictive distribution, computed across chains and draws
idata_g_mean_quantiles = idata_g.posterior_predictive["obs"].quantile(
    quantiles[:, 0], ("chain", "draw")
)
```

```{code-cell} ipython3
plt.plot(bmi.age, bmi.bmi, ".", color="0.5")
for q in quantiles[:, 0]:
    plt.plot(bmi.age.values, idata_g_mean_quantiles.sel(quantile=q), label=f"q={q:}")

plt.legend()
plt.xlabel("Age")
plt.ylabel("BMI");
```

We can see that when we use a Normal likelihood, and compute the quantiles from that fit, the quantiles q=0.1 and q=0.9 are symmetrical with respect to q=0.5; also, the shape of the curves is essentially the same, just shifted up or down. Additionally, the Asymmetric Laplace family allows the model to account for the increased variability in BMI as age increases, whereas for the Gaussian family that variability always stays the same.

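One hypothetical way to visualize this last point (not part of the original notebook) is to plot the width of the q=0.1 to q=0.9 band as a function of age for both fits: with the Asymmetric Laplace likelihood the band widens with age, while with the Normal likelihood it stays roughly constant.

```python
# Hypothetical illustration: width of the q=0.1 to q=0.9 band for both models.
ald_curves = idata.posterior["μ"].mean(("chain", "draw"))
ald_width = ald_curves.sel(μ_dim_0=2) - ald_curves.sel(μ_dim_0=0)  # q=0.9 minus q=0.1
normal_width = idata_g_mean_quantiles.sel(quantile=0.9) - idata_g_mean_quantiles.sel(quantile=0.1)

plt.plot(bmi.age, ald_width, label="Asymmetric Laplace")
plt.plot(bmi.age, normal_width, label="Normal")
plt.xlabel("Age")
plt.ylabel("width of the 0.1 to 0.9 band")
plt.legend();
```
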
+++

## Authors
* Authored by Osvaldo Martin in Jan, 2023

+++

## References

:::{bibliography}
:filter: docname in docnames

martin2021bayesian
quiroga2022bart
:::

+++

## Watermark

```{code-cell} ipython3
%load_ext watermark
%watermark -n -u -v -iv -w -p pytensor,xarray
```

:::{include} ../page_footer.md
:::
