
Commit 5039fda

fixing merging issues

1 parent 020f679 commit 5039fda

2 files changed: +222 -3 lines changed

causalpy/pymc_models.py

Lines changed: 219 additions & 0 deletions
@@ -19,6 +19,7 @@
 import numpy as np
 import pandas as pd
 import pymc as pm
+import pytensor.tensor as pt
 import xarray as xr
 from arviz import r2_score

@@ -290,6 +291,224 @@ def build_model(self, X, y, coords):
             pm.Normal("y_hat", mu, sigma, observed=y, dims="obs_ind")


+class InstrumentalVariableRegression(PyMCModel):
+    """Custom PyMC model for instrumental linear regression
+
+    Example
+    --------
+    >>> import causalpy as cp
+    >>> import numpy as np
+    >>> from causalpy.pymc_models import InstrumentalVariableRegression
+    >>> N = 10
+    >>> e1 = np.random.normal(0, 3, N)
+    >>> e2 = np.random.normal(0, 1, N)
+    >>> Z = np.random.uniform(0, 1, N)
+    >>> ## Ensure the endogeneity of the treatment variable
+    >>> X = -1 + 4 * Z + e2 + 2 * e1
+    >>> y = 2 + 3 * X + 3 * e1
+    >>> t = X.reshape(10, 1)
+    >>> y = y.reshape(10, 1)
+    >>> Z = np.asarray([[1, Z[i]] for i in range(0, 10)])
+    >>> X = np.asarray([[1, X[i]] for i in range(0, 10)])
+    >>> COORDS = {"instruments": ["Intercept", "Z"], "covariates": ["Intercept", "X"]}
+    >>> sample_kwargs = {
+    ...     "tune": 5,
+    ...     "draws": 10,
+    ...     "chains": 2,
+    ...     "cores": 2,
+    ...     "target_accept": 0.95,
+    ...     "progressbar": False,
+    ... }
+    >>> iv_reg = InstrumentalVariableRegression(sample_kwargs=sample_kwargs)
+    >>> iv_reg.fit(
+    ...     X,
+    ...     Z,
+    ...     y,
+    ...     t,
+    ...     COORDS,
+    ...     {
+    ...         "mus": [[-2, 4], [0.5, 3]],
+    ...         "sigmas": [1, 1],
+    ...         "eta": 2,
+    ...         "lkj_sd": 1,
+    ...     },
+    ...     None,
+    ... )
+    Inference data...
+    """
+
+    def build_model(self, X, Z, y, t, coords, priors):
+        """Specify model with treatment regression and focal regression data and priors
+
+        :param X: A pandas dataframe used to predict our outcome y
+        :param Z: A pandas dataframe used to predict our treatment variable t
+        :param y: An array of values representing our focal outcome y
+        :param t: An array of values representing the treatment t whose
+                  causal impact we are interested in estimating
+        :param coords: A dictionary with the coordinate names for our
+                       instruments and covariates
+        :param priors: An optional dictionary of priors for the mus and
+                       sigmas of both regressions,
+                       :code:`priors = {"mus": [0, 0], "sigmas": [1, 1],
+                       "eta": 2, "lkj_sd": 2}`
+        """
+
+        # --- Priors ---
+        with self:
+            self.add_coords(coords)
+            beta_t = pm.Normal(
+                name="beta_t",
+                mu=priors["mus"][0],
+                sigma=priors["sigmas"][0],
+                dims="instruments",
+            )
+            beta_z = pm.Normal(
+                name="beta_z",
+                mu=priors["mus"][1],
+                sigma=priors["sigmas"][1],
+                dims="covariates",
+            )
+            sd_dist = pm.Exponential.dist(priors["lkj_sd"], shape=2)
+            chol, corr, sigmas = pm.LKJCholeskyCov(
+                name="chol_cov",
+                eta=priors["eta"],
+                n=2,
+                sd_dist=sd_dist,
+            )
+            # compute and store the covariance matrix
+            pm.Deterministic(name="cov", var=pt.dot(l=chol, r=chol.T))
+
+            # --- Parameterization ---
+            # focal regression
+            mu_y = pm.Deterministic(name="mu_y", var=pm.math.dot(X, beta_z))
+            # instrumental regression
+            mu_t = pm.Deterministic(name="mu_t", var=pm.math.dot(Z, beta_t))
+            mu = pm.Deterministic(name="mu", var=pt.stack(tensors=(mu_y, mu_t), axis=1))
+
+            # --- Likelihood ---
+            pm.MvNormal(
+                name="likelihood",
+                mu=mu,
+                chol=chol,
+                observed=np.stack(arrays=(y.flatten(), t.flatten()), axis=1),
+                shape=(X.shape[0], 2),
+            )
+
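A quick sketch of what this likelihood encodes (illustrative only): the focal and instrumental regressions share correlated errors, which is what allows the instrument Z to identify the causal effect of t on y.

    # focal regression:        y_i = X_i @ beta_z + e_y_i
    # instrumental regression: t_i = Z_i @ beta_t + e_t_i
    # joint errors:            (e_y_i, e_t_i) ~ MvNormal(0, Sigma),
    #                          with Sigma = chol @ chol.T (stored as "cov")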
+    def sample_predictive_distribution(self, ppc_sampler="jax"):
+        """Sample the multivariate normal posterior predictive of the
+        likelihood term in the IV class. This can be slow without the
+        JAX compilation mode. If using the JAX sampler, it will sample
+        only the posterior predictive distribution. If using the PyMC
+        sampler, it will sample both the prior and posterior predictive
+        distributions."""
+        random_seed = self.sample_kwargs.get("random_seed", None)
+
+        if ppc_sampler == "jax":
+            with self:
+                self.idata.extend(
+                    pm.sample_posterior_predictive(
+                        self.idata,
+                        random_seed=random_seed,
+                        compile_kwargs={"mode": "JAX"},
+                    )
+                )
+        elif ppc_sampler == "pymc":
+            with self:
+                self.idata.extend(pm.sample_prior_predictive(random_seed=random_seed))
+                self.idata.extend(
+                    pm.sample_posterior_predictive(
+                        self.idata,
+                        random_seed=random_seed,
+                    )
+                )
+
+    def fit(self, X, Z, y, t, coords, priors, ppc_sampler=None):
+        """Draw samples from the posterior distribution and potentially
+        from the prior and posterior predictive distributions. The fit
+        call accepts ppc_sampler = ['jax', 'pymc', None]. We default to
+        None so the user can decide whether to spend time sampling the
+        posterior predictive distribution independently.
+        """
+
+        # Ensure random_seed is used in sample_prior_predictive() and
+        # sample_posterior_predictive() if provided in sample_kwargs.
+        # Use JAX for ppc sampling of the multivariate likelihood.
+
+        self.build_model(X, Z, y, t, coords, priors)
+        with self:
+            self.idata = pm.sample(**self.sample_kwargs)
+        self.sample_predictive_distribution(ppc_sampler=ppc_sampler)
+        return self.idata
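A minimal usage sketch of the deferred path (illustrative; reuses X, Z, y, t, COORDS, and sample_kwargs from the class docstring, with the priors dict bound to a name `priors` for readability):

    iv_reg = InstrumentalVariableRegression(sample_kwargs=sample_kwargs)
    # ppc_sampler=None: fit draws only the posterior, no predictive sampling
    idata = iv_reg.fit(X, Z, y, t, COORDS, priors, ppc_sampler=None)
    # sample the posterior predictive later; "pymc" also draws the prior
    # predictive, while "jax" compiles the MvNormal ppc for speed
    iv_reg.sample_predictive_distribution(ppc_sampler="pymc")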
+
+
+class PropensityScore(PyMCModel):
+    r"""
+    Custom PyMC model for inverse propensity score models
+
+    .. note::
+        Generally, the `.fit()` method should be used rather than
+        calling `.build_model()` directly.
+
+    Defines the PyMC model
+
+    .. math::
+        \beta &\sim \mathrm{Normal}(0, 1) \\
+        \mu &= X \cdot \beta \\
+        p &= \text{logit}^{-1}(\mu) \\
+        t &\sim \mathrm{Bernoulli}(p)
+
+    Example
+    --------
+    >>> import causalpy as cp
+    >>> import numpy as np
+    >>> from causalpy.pymc_models import PropensityScore
+    >>> df = cp.load_data('nhefs')
+    >>> X = df[["age", "race"]]
+    >>> t = np.asarray(df["trt"])
+    >>> ps = PropensityScore(sample_kwargs={"progressbar": False})
+    >>> ps.fit(X, t, coords={
+    ...     'coeffs': ['age', 'race'],
+    ...     'obs_ind': np.arange(df.shape[0])
+    ... },
+    ... )
+    Inference...
+    """  # noqa: W605
+
+    def build_model(self, X, t, coords):
+        """Defines the PyMC propensity model"""
+        with self:
+            self.add_coords(coords)
+            X_data = pm.Data("X", X, dims=["obs_ind", "coeffs"])
+            t_data = pm.Data("t", t.flatten(), dims="obs_ind")
+            b = pm.Normal("b", mu=0, sigma=1, dims="coeffs")
+            mu = pm.math.dot(X_data, b)
+            p = pm.Deterministic("p", pm.math.invlogit(mu))
+            pm.Bernoulli("t_pred", p=p, observed=t_data, dims="obs_ind")
+
+    def fit(self, X, t, coords):
+        """Draw samples from the posterior, prior predictive, and posterior
+        predictive distributions. We overwrite the base method because the
+        base method assumes a variable y, and here we use t to indicate the
+        treatment variable.
+        """
+        # Ensure random_seed is used in sample_prior_predictive() and
+        # sample_posterior_predictive() if provided in sample_kwargs.
+        random_seed = self.sample_kwargs.get("random_seed", None)
+
+        self.build_model(X, t, coords)
+        with self:
+            self.idata = pm.sample(**self.sample_kwargs)
+            self.idata.extend(pm.sample_prior_predictive(random_seed=random_seed))
+            self.idata.extend(
+                pm.sample_posterior_predictive(
+                    self.idata, progressbar=False, random_seed=random_seed
+                )
+            )
+        return self.idata
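A short follow-on sketch (illustrative, building on the nhefs doctest above): the per-observation propensity scores are exposed through the Deterministic variable `p` in the returned inference data.

    idata = ps.idata  # populated by the fit() call above
    # posterior mean propensity score for each observation
    p_mean = idata.posterior["p"].mean(dim=["chain", "draw"])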
+
+
 class InterventionTimeEstimator(PyMCModel):
     r"""
     Custom PyMC model to estimate the time an intervention took place.

docs/source/_static/interrogate_badge.svg

Lines changed: 3 additions & 3 deletions
