Skip to content

Commit 3c659d3

Browse files
committed
add tests
1 parent 0650644 commit 3c659d3

File tree

2 files changed

+251
-4
lines changed

2 files changed

+251
-4
lines changed

causalpy/tests/test_pymc_models.py

Lines changed: 248 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,10 @@
1717
import pymc as pm
1818
import pytest
1919
import xarray as xr
20+
from pymc_extras.prior import Prior
2021

2122
import causalpy as cp
22-
from causalpy.pymc_models import PyMCModel, WeightedSumFitter
23+
from causalpy.pymc_models import LinearRegression, PyMCModel, WeightedSumFitter
2324

2425
sample_kwargs = {"tune": 20, "draws": 20, "chains": 2, "cores": 2}
2526

@@ -592,3 +593,249 @@ def test_r2_scores_differ_across_units(self, rng):
592593
f"R² standard deviation is too low ({r2_std}), suggesting insufficient variation "
593594
"between treated units. This might indicate a scoring implementation issue."
594595
)
596+
597+
598+
@pytest.fixture(scope="module")
599+
def prior_test_data():
600+
"""Generate test data for Prior integration tests (shared across all tests)."""
601+
rng = np.random.default_rng(42)
602+
X = xr.DataArray(
603+
rng.normal(loc=0, scale=1, size=(20, 2)),
604+
dims=["obs_ind", "coeffs"],
605+
coords={"obs_ind": np.arange(20), "coeffs": ["x1", "x2"]},
606+
)
607+
y = xr.DataArray(
608+
rng.normal(loc=0, scale=1, size=(20, 1)),
609+
dims=["obs_ind", "treated_units"],
610+
coords={"obs_ind": np.arange(20), "treated_units": ["unit_0"]},
611+
)
612+
coords = {
613+
"obs_ind": np.arange(20),
614+
"coeffs": ["x1", "x2"],
615+
"treated_units": ["unit_0"],
616+
}
617+
return X, y, coords
618+
619+
620+
class TestPriorIntegration:
621+
"""
622+
Test suite for Prior class integration with PyMC models.
623+
Tests the precedence system, data-driven priors, and Prior class usage.
624+
"""
625+
626+
def test_default_priors_property(self):
627+
"""Test that default_priors property returns correct Prior objects."""
628+
model = LinearRegression()
629+
defaults = model.default_priors
630+
631+
# Check that defaults is a dictionary with expected keys
632+
assert isinstance(defaults, dict)
633+
assert "beta" in defaults
634+
assert "y_hat" in defaults
635+
636+
# Check that values are Prior objects
637+
assert isinstance(defaults["beta"], Prior)
638+
assert isinstance(defaults["y_hat"], Prior)
639+
640+
# Check Prior configuration using correct API
641+
beta_prior = defaults["beta"]
642+
assert beta_prior.distribution == "Normal"
643+
assert beta_prior.parameters["mu"] == 0
644+
assert beta_prior.parameters["sigma"] == 50
645+
646+
def test_priors_from_data_base_implementation(self, prior_test_data):
647+
"""Test that base PyMCModel.priors_from_data returns empty dict."""
648+
X, y, coords = prior_test_data
649+
model = PyMCModel()
650+
data_priors = model.priors_from_data(X, y)
651+
assert isinstance(data_priors, dict)
652+
assert len(data_priors) == 0
653+
654+
def test_weighted_sum_fitter_priors_from_data(self, prior_test_data):
655+
"""Test WeightedSumFitter data-driven Dirichlet prior generation."""
656+
X, y, coords = prior_test_data
657+
model = WeightedSumFitter()
658+
data_priors = model.priors_from_data(X, y)
659+
660+
# Should return beta prior based on X shape
661+
assert "beta" in data_priors
662+
beta_prior = data_priors["beta"]
663+
664+
# Check it's a Dirichlet prior using correct API
665+
assert isinstance(beta_prior, Prior)
666+
assert beta_prior.distribution == "Dirichlet"
667+
668+
# Check shape matches number of predictors
669+
assert len(beta_prior.parameters["a"]) == X.shape[1] # 2 predictors
670+
assert np.allclose(beta_prior.parameters["a"], np.ones(2))
671+
672+
def test_prior_precedence_system(self, prior_test_data):
673+
"""Test that user priors override data-driven priors override defaults."""
674+
X, y, coords = prior_test_data
675+
# Create custom user prior
676+
user_beta_prior = Prior(
677+
"Normal", mu=100, sigma=10, dims=("treated_units", "coeffs")
678+
)
679+
680+
model = LinearRegression(priors={"beta": user_beta_prior})
681+
682+
# Before fit, should have user prior + defaults
683+
assert model.priors["beta"] == user_beta_prior
684+
assert "y_hat" in model.priors # From defaults
685+
686+
# After calling priors_from_data, user prior should remain
687+
data_priors = model.priors_from_data(X, y)
688+
merged_priors = {**data_priors, **model.priors}
689+
690+
# User prior should override any data-driven prior
691+
assert merged_priors["beta"] == user_beta_prior
692+
693+
def test_prior_precedence_integration_in_fit(self, prior_test_data):
694+
"""Test the complete prior precedence system during fit()."""
695+
X, y, coords = prior_test_data
696+
# Create model with custom user prior
697+
custom_prior = Prior("Normal", mu=5, sigma=2, dims=("treated_units", "coeffs"))
698+
model = LinearRegression(
699+
priors={"beta": custom_prior},
700+
sample_kwargs={"tune": 5, "draws": 5, "chains": 1, "progressbar": False},
701+
)
702+
703+
# Fit the model
704+
model.fit(X, y, coords=coords)
705+
706+
# Check that the model was built with the custom prior
707+
# We can verify this by checking the model context
708+
assert model.idata is not None
709+
assert "beta" in model.idata.posterior
710+
711+
def test_prior_dimensions_consistency(self):
712+
"""Test that Prior dimensions are consistent with model expectations."""
713+
model = LinearRegression()
714+
715+
# Check default priors have correct dimensions (tuples, not lists)
716+
beta_prior = model.default_priors["beta"]
717+
assert beta_prior.dims == ("treated_units", "coeffs")
718+
719+
y_hat_prior = model.default_priors["y_hat"]
720+
assert y_hat_prior.dims == ("obs_ind", "treated_units")
721+
722+
# Check that sigma component has correct dims
723+
sigma_prior = y_hat_prior.parameters["sigma"]
724+
assert isinstance(sigma_prior, Prior)
725+
assert sigma_prior.dims == ("treated_units",)
726+
727+
def test_custom_prior_with_build_model(self, prior_test_data):
728+
"""Test that custom priors work correctly in build_model."""
729+
# Create a custom Prior with different parameters
730+
custom_beta = Prior(
731+
"Normal",
732+
mu=0,
733+
sigma=10, # Different from default (50)
734+
dims=("treated_units", "coeffs"),
735+
)
736+
custom_sigma = Prior(
737+
"HalfNormal",
738+
sigma=2, # Different from default (1)
739+
dims=("treated_units",),
740+
)
741+
custom_y_hat = Prior(
742+
"Normal", sigma=custom_sigma, dims=("obs_ind", "treated_units")
743+
)
744+
745+
model = LinearRegression(priors={"beta": custom_beta, "y_hat": custom_y_hat})
746+
747+
# Build the model to ensure priors work
748+
X, y, coords = prior_test_data
749+
model.build_model(X, y, coords)
750+
751+
# Check that variables were created in the model context
752+
with model:
753+
assert "beta" in model.named_vars
754+
assert "y_hat" in model.named_vars
755+
756+
def test_prior_create_variable_integration(self, prior_test_data):
757+
"""Test that Prior.create_variable works in model context."""
758+
X, y, coords = prior_test_data
759+
model = LinearRegression()
760+
model.build_model(X, y, coords)
761+
762+
# Verify that Prior.create_variable was called successfully
763+
# by checking the created variables exist and have expected names
764+
with model:
765+
beta_var = model.named_vars["beta"]
766+
# Check that the variable exists and is a PyMC variable
767+
assert beta_var is not None
768+
assert hasattr(beta_var, "name")
769+
assert beta_var.name == "beta"
770+
771+
def test_weighted_sum_fitter_dirichlet_prior_shape(self, prior_test_data):
772+
"""Test that WeightedSumFitter creates correct Dirichlet shape."""
773+
_, y, _ = prior_test_data
774+
rng = np.random.default_rng(42)
775+
# Test with different numbers of control units
776+
for n_controls in [3, 5, 10]:
777+
X = xr.DataArray(
778+
rng.normal(size=(20, n_controls)),
779+
dims=["obs_ind", "coeffs"],
780+
coords={
781+
"obs_ind": np.arange(20),
782+
"coeffs": [f"control_{i}" for i in range(n_controls)],
783+
},
784+
)
785+
786+
model = WeightedSumFitter()
787+
data_priors = model.priors_from_data(X, y)
788+
789+
beta_prior = data_priors["beta"]
790+
assert len(beta_prior.parameters["a"]) == n_controls
791+
assert np.allclose(beta_prior.parameters["a"], np.ones(n_controls))
792+
793+
def test_prior_none_handling(self):
794+
"""Test that models handle None priors parameter correctly."""
795+
model = LinearRegression(priors=None)
796+
797+
# Should still have default priors
798+
assert len(model.priors) > 0
799+
assert "beta" in model.priors
800+
assert "y_hat" in model.priors
801+
802+
def test_empty_priors_dict(self):
803+
"""Test that models handle empty priors dict correctly."""
804+
model = LinearRegression(priors={})
805+
806+
# Should still have default priors
807+
assert len(model.priors) > 0
808+
assert "beta" in model.priors
809+
assert "y_hat" in model.priors
810+
811+
def test_priors_from_data_called_during_fit(self, prior_test_data):
812+
"""Test that priors_from_data is called and integrated during fit."""
813+
814+
# Create a mock model that tracks priors_from_data calls
815+
class TrackingWeightedSumFitter(WeightedSumFitter):
816+
def __init__(self, *args, **kwargs):
817+
super().__init__(*args, **kwargs)
818+
self.priors_from_data_called = False
819+
self.priors_from_data_args = None
820+
821+
def priors_from_data(self, X, y):
822+
self.priors_from_data_called = True
823+
self.priors_from_data_args = (X, y)
824+
return super().priors_from_data(X, y)
825+
826+
model = TrackingWeightedSumFitter(
827+
sample_kwargs={"tune": 2, "draws": 2, "chains": 1, "progressbar": False}
828+
)
829+
830+
# Fit the model
831+
X, y, coords = prior_test_data
832+
model.fit(X, y, coords=coords)
833+
834+
# Verify priors_from_data was called with correct arguments
835+
assert model.priors_from_data_called
836+
assert model.priors_from_data_args is not None
837+
838+
# Verify the model has the Dirichlet prior after fitting
839+
assert "beta" in model.priors
840+
beta_prior = model.priors["beta"]
841+
assert beta_prior.distribution == "Dirichlet"

docs/source/_static/interrogate_badge.svg

Lines changed: 3 additions & 3 deletions
Loading

0 commit comments

Comments
 (0)