Skip to content

Commit d2035f8

Browse files
committed
adding some tests and a conclusion
Signed-off-by: Nathaniel <[email protected]>
1 parent 8ab06ac commit d2035f8

File tree

3 files changed

+91
-4
lines changed

3 files changed

+91
-4
lines changed

causalpy/pymc_models.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -527,6 +527,69 @@ def fit_outcome_model(
527527
normal_outcome=True,
528528
spline_component=False,
529529
):
530+
"""
531+
Fit a Bayesian outcome model using covariates and previously estimated propensity scores.
532+
533+
This function implements the second stage of a modular two-step causal inference procedure.
534+
It uses propensity scores extracted from a prior treatment model (via `self.fit()`) to adjust
535+
for confounding when estimating treatment effects on an outcome variable `y`.
536+
537+
Parameters
538+
----------
539+
X_outcome : array-like, shape (n_samples, n_covariates)
540+
Covariate matrix for the outcome model.
541+
542+
y : array-like, shape (n_samples,)
543+
Observed outcome variable.
544+
545+
coords : dict
546+
Coordinate dictionary for named dimensions in the PyMC model. Should include
547+
a key "outcome_coeffs" for `X_outcome`.
548+
549+
priors : dict, optional
550+
Dictionary specifying priors for outcome model parameters:
551+
- "b_outcome": list [mean, std] for regression coefficients.
552+
- "a_outcome": list [mean, std] for the intercept.
553+
- "sigma": standard deviation of the outcome noise (default 1).
554+
555+
noncentred : bool, default True
556+
If True, use a non-centred parameterization for the outcome coefficients.
557+
558+
normal_outcome : bool, default True
559+
If True, assume a Normal likelihood for the outcome.
560+
If False, use a Student-t likelihood with unknown degrees of freedom.
561+
562+
spline_component : bool, default False
563+
If True, include a spline basis expansion on the propensity score to allow
564+
flexible (nonlinear) adjustment. Uses B-splines with 30 internal knots.
565+
566+
Returns
567+
-------
568+
idata_outcome : arviz.InferenceData
569+
The posterior and prior predictive samples from the outcome model.
570+
571+
model_outcome : pm.Model
572+
The PyMC model object.
573+
574+
Raises
575+
------
576+
AttributeError
577+
If the `self.idata` attribute is not available, which indicates that
578+
`fit()` (i.e., the treatment model) has not been called yet.
579+
580+
Notes
581+
-----
582+
- This model uses a sampled version of the propensity score (`p`) from the
583+
posterior of the treatment model, randomly selecting one posterior draw
584+
per call.
585+
- The term `beta_ps[0] * p + beta_ps[1] * (p * treatment)` captures both
586+
main and interaction effects of the propensity score.
587+
- Including spline adjustment enables modeling nonlinear relationships
588+
between the propensity score and the outcome.
589+
- Compatible with IPW-style estimation when combined with weighted loss or
590+
diagnostics outside this function.
591+
592+
"""
530593
if not hasattr(self, "idata"):
531594
raise AttributeError("""Object is missing required attribute 'idata'
532595
so cannot proceed. Call fit() first""")
@@ -551,7 +614,6 @@ def fit_outcome_model(
551614
dims="outcome_coeffs",
552615
)
553616

554-
beta_ps_spline = pm.Normal("beta_ps_spline", 0, 1, size=34)
555617
beta_ps = pm.Normal("beta_ps", 0, 1, size=2)
556618

557619
chosen = np.random.choice(range(propensity_scores.shape[1]))

causalpy/tests/test_integration_pymc_examples.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import arviz as az
1515
import numpy as np
1616
import pandas as pd
17+
import pymc as pm
1718
import pytest
1819
from matplotlib import pyplot as plt
1920

@@ -723,5 +724,29 @@ def test_inverse_prop(mock_pymc_sample):
723724
assert isinstance(fig, plt.Figure)
724725
assert isinstance(axs, list)
725726
assert all(isinstance(ax, plt.Axes) for ax in axs)
727+
plt.close()
726728
with pytest.raises(NotImplementedError):
727729
result.get_plot_data()
730+
731+
### testing outcome model
732+
idata_normal, model_normal = result.model.fit_outcome_model(
733+
X_outcome=result.X_outcome,
734+
y=result.y,
735+
coords=result.coords,
736+
normal_outcome=True,
737+
spline_component=False,
738+
)
739+
assert isinstance(idata_normal, az.InferenceData)
740+
assert isinstance(model_normal, pm.Model)
741+
assert "beta_" in idata_normal.posterior
742+
assert "beta_ps" in idata_normal.posterior
743+
744+
# Test spline model
745+
idata_spline, _ = result.model.fit_outcome_model(
746+
X_outcome=result.X_outcome,
747+
y=result.y,
748+
coords=result.coords,
749+
normal_outcome=True,
750+
spline_component=True,
751+
)
752+
assert "spline_features" in idata_spline.posterior

docs/source/_static/interrogate_badge.svg

Lines changed: 3 additions & 3 deletions
Loading

0 commit comments

Comments
 (0)