diff --git a/causalpy/experiments/prepostnegd.py b/causalpy/experiments/prepostnegd.py index a187b7aa..32c1ceb1 100644 --- a/causalpy/experiments/prepostnegd.py +++ b/causalpy/experiments/prepostnegd.py @@ -82,7 +82,7 @@ class PrePostNEGD(BaseExperiment): Intercept -0.5, 94% HDI [-1, 0.2] C(group)[T.1] 2, 94% HDI [2, 2] pre 1, 94% HDI [1, 1] - sigma 0.5, 94% HDI [0.5, 0.6] + y_hat_sigma 0.5, 94% HDI [0.5, 0.6] """ supports_ols = False diff --git a/causalpy/pymc_models.py b/causalpy/pymc_models.py index 50bfb0cb..d3ecbb6e 100644 --- a/causalpy/pymc_models.py +++ b/causalpy/pymc_models.py @@ -22,6 +22,7 @@ import pytensor.tensor as pt import xarray as xr from arviz import r2_score +from pymc_extras.prior import Prior from causalpy.utils import round_num @@ -89,7 +90,18 @@ class PyMCModel(pm.Model): Inference data... """ - def __init__(self, sample_kwargs: Optional[Dict[str, Any]] = None): + @property + def default_priors(self): + return {} + + def priors_from_data(self, X, y) -> Dict[str, Any]: + return {} + + def __init__( + self, + sample_kwargs: Optional[Dict[str, Any]] = None, + priors: dict[str, Any] | None = None, + ): """ :param sample_kwargs: A dictionary of kwargs that get unpacked and passed to the :func:`pymc.sample` function. Defaults to an empty dictionary. @@ -98,6 +110,8 @@ def __init__(self, sample_kwargs: Optional[Dict[str, Any]] = None): self.idata = None self.sample_kwargs = sample_kwargs if sample_kwargs is not None else {} + self.priors = {**self.default_priors, **(priors or {})} + def build_model(self, X, y, coords) -> None: """Build the model, must be implemented by subclass.""" raise NotImplementedError("This method must be implemented by a subclass") @@ -143,6 +157,8 @@ def fit(self, X, y, coords: Optional[Dict[str, Any]] = None) -> None: # sample_posterior_predictive() if provided in sample_kwargs. random_seed = self.sample_kwargs.get("random_seed", None) + self.priors = {**self.priors_from_data(X, y), **self.priors} + self.build_model(X, y, coords) with self: self.idata = pm.sample(**self.sample_kwargs) @@ -238,26 +254,34 @@ def print_coefficients_for_unit( ) -> None: """Print coefficients for a single unit""" # Determine the width of the longest label - max_label_length = max(len(name) for name in labels + ["sigma"]) + max_label_length = max(len(name) for name in labels + ["y_hat_sigma"]) for name in labels: coeff_samples = unit_coeffs.sel(coeffs=name) print_row(max_label_length, name, coeff_samples, round_to) # Add coefficient for measurement std - print_row(max_label_length, "sigma", unit_sigma, round_to) + print_row(max_label_length, "y_hat_sigma", unit_sigma, round_to) print("Model coefficients:") coeffs = az.extract(self.idata.posterior, var_names="beta") - # Always has treated_units dimension - no branching needed! + # Check if sigma or y_hat_sigma variable exists + sigma_var_name = None + if "sigma" in self.idata.posterior: + sigma_var_name = "sigma" + elif "y_hat_sigma" in self.idata.posterior: + sigma_var_name = "y_hat_sigma" + else: + raise ValueError("Neither 'sigma' nor 'y_hat_sigma' found in posterior") + treated_units = coeffs.coords["treated_units"].values for unit in treated_units: if len(treated_units) > 1: print(f"\nTreated unit: {unit}") unit_coeffs = coeffs.sel(treated_units=unit) - unit_sigma = az.extract(self.idata.posterior, var_names="sigma").sel( + unit_sigma = az.extract(self.idata.posterior, var_names=sigma_var_name).sel( treated_units=unit ) print_coefficients_for_unit(unit_coeffs, unit_sigma, labels, round_to or 2) @@ -300,6 +324,15 @@ class LinearRegression(PyMCModel): Inference data... """ # noqa: W605 + default_priors = { + "beta": Prior("Normal", mu=0, sigma=50, dims=["treated_units", "coeffs"]), + "y_hat": Prior( + "Normal", + sigma=Prior("HalfNormal", sigma=1, dims=["treated_units"]), + dims=["obs_ind", "treated_units"], + ), + } + def build_model(self, X, y, coords): """ Defines the PyMC model @@ -313,12 +346,13 @@ def build_model(self, X, y, coords): self.add_coords(coords) X = pm.Data("X", X, dims=["obs_ind", "coeffs"]) y = pm.Data("y", y, dims=["obs_ind", "treated_units"]) - beta = pm.Normal("beta", 0, 50, dims=["treated_units", "coeffs"]) - sigma = pm.HalfNormal("sigma", 1, dims="treated_units") + # beta = pm.Normal("beta", 0, 50, dims=["treated_units", "coeffs"]) + beta = self.priors["beta"].create_variable("beta") mu = pm.Deterministic( "mu", pt.dot(X, beta.T), dims=["obs_ind", "treated_units"] ) - pm.Normal("y_hat", mu, sigma, observed=y, dims=["obs_ind", "treated_units"]) + # pm.Normal("y_hat", mu, sigma, observed=y, dims=["obs_ind", "treated_units"]) + self.priors["y_hat"].create_likelihood_variable("y_hat", mu=mu, observed=y) class WeightedSumFitter(PyMCModel): @@ -361,23 +395,35 @@ class WeightedSumFitter(PyMCModel): Inference data... """ # noqa: W605 + default_priors = { + "y_hat": Prior( + "Normal", + sigma=Prior("HalfNormal", sigma=1, dims=["treated_units"]), + dims=["obs_ind", "treated_units"], + ), + } + + def priors_from_data(self, X, y) -> Dict[str, Any]: + n_predictors = X.shape[1] + return { + "beta": Prior( + "Dirichlet", a=np.ones(n_predictors), dims=["treated_units", "coeffs"] + ), + } + def build_model(self, X, y, coords): """ Defines the PyMC model """ with self: self.add_coords(coords) - n_predictors = X.sizes["coeffs"] X = pm.Data("X", X, dims=["obs_ind", "coeffs"]) y = pm.Data("y", y, dims=["obs_ind", "treated_units"]) - beta = pm.Dirichlet( - "beta", a=np.ones(n_predictors), dims=["treated_units", "coeffs"] - ) - sigma = pm.HalfNormal("sigma", 1, dims="treated_units") + beta = self.priors["beta"].create_variable("beta") mu = pm.Deterministic( "mu", pt.dot(X, beta.T), dims=["obs_ind", "treated_units"] ) - pm.Normal("y_hat", mu, sigma, observed=y, dims=["obs_ind", "treated_units"]) + self.priors["y_hat"].create_likelihood_variable("y_hat", mu=mu, observed=y) class InstrumentalVariableRegression(PyMCModel): @@ -566,13 +612,17 @@ class PropensityScore(PyMCModel): Inference... """ # noqa: W605 + default_priors = { + "b": Prior("Normal", mu=0, sigma=1, dims="coeffs"), + } + def build_model(self, X, t, coords): "Defines the PyMC propensity model" with self: self.add_coords(coords) X_data = pm.Data("X", X, dims=["obs_ind", "coeffs"]) t_data = pm.Data("t", t.flatten(), dims="obs_ind") - b = pm.Normal("b", mu=0, sigma=1, dims="coeffs") + b = self.priors["b"].create_variable("b") mu = pt.dot(X_data, b) p = pm.Deterministic("p", pm.math.invlogit(mu)) pm.Bernoulli("t_pred", p=p, observed=t_data, dims="obs_ind") diff --git a/causalpy/tests/test_pymc_models.py b/causalpy/tests/test_pymc_models.py index 22f3a045..e5fc9582 100644 --- a/causalpy/tests/test_pymc_models.py +++ b/causalpy/tests/test_pymc_models.py @@ -45,7 +45,7 @@ def build_model(self, X, y, coords): X_ = pm.Data(name="X", value=X, dims=["obs_ind", "coeffs"]) y_ = pm.Data(name="y", value=y, dims=["obs_ind", "treated_units"]) beta = pm.Normal("beta", mu=0, sigma=1, dims=["treated_units", "coeffs"]) - sigma = pm.HalfNormal("sigma", sigma=1, dims="treated_units") + sigma = pm.HalfNormal("y_hat_sigma", sigma=1, dims="treated_units") mu = pm.Deterministic( "mu", pm.math.dot(X_, beta.T), dims=["obs_ind", "treated_units"] ) @@ -159,7 +159,7 @@ def test_fit_predict(self, coords, rng, mock_pymc_sample) -> None: 2, 2 * 2, ) # (treated_units, coeffs, sample) - assert az.extract(data=model.idata, var_names=["sigma"]).shape == ( + assert az.extract(data=model.idata, var_names=["y_hat_sigma"]).shape == ( 1, 2 * 2, ) # (treated_units, sample) @@ -402,7 +402,7 @@ def test_multi_unit_coefficients(self, synthetic_control_data): # Extract coefficients beta = az.extract(wsf.idata.posterior, var_names="beta") - sigma = az.extract(wsf.idata.posterior, var_names="sigma") + sigma = az.extract(wsf.idata.posterior, var_names="y_hat_sigma") # Check beta dimensions: should be (sample, treated_units, coeffs) assert "treated_units" in beta.dims @@ -461,7 +461,7 @@ def test_print_coefficients_multi_unit(self, synthetic_control_data, capsys): assert control in output # Check that sigma is printed for each unit - assert output.count("sigma") == len(treated_units) + assert output.count("y_hat_sigma") == len(treated_units) def test_scoring_multi_unit(self, synthetic_control_data): """Test that scoring works with multiple treated units.""" diff --git a/docs/source/_static/interrogate_badge.svg b/docs/source/_static/interrogate_badge.svg index d2d886ad..5b70fde2 100644 --- a/docs/source/_static/interrogate_badge.svg +++ b/docs/source/_static/interrogate_badge.svg @@ -1,10 +1,10 @@ - interrogate: 95.4% + interrogate: 94.2% - + @@ -12,8 +12,8 @@ interrogate interrogate - 95.4% - 95.4% + 94.2% + 94.2% diff --git a/environment.yml b/environment.yml index 02b7f920..2bc8ed20 100644 --- a/environment.yml +++ b/environment.yml @@ -15,3 +15,4 @@ dependencies: - seaborn>=0.11.2 - statsmodels - xarray>=v2022.11.0 + - pymc-extras>=0.2.7 diff --git a/pyproject.toml b/pyproject.toml index 909f7969..88df3f87 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ dependencies = [ "seaborn>=0.11.2", "statsmodels", "xarray>=v2022.11.0", + "pymc-extras>=0.2.7", ] # List additional groups of dependencies here (e.g. development dependencies). Users