17 | 17 | import pymc as pm |
18 | 18 | import pytest |
19 | 19 | import xarray as xr |
| 20 | +from pymc_extras.prior import Prior |
20 | 21 |
21 | 22 | import causalpy as cp |
22 | | -from causalpy.pymc_models import PyMCModel, WeightedSumFitter |
| 23 | +from causalpy.pymc_models import LinearRegression, PyMCModel, WeightedSumFitter |
23 | 24 |
24 | 25 | sample_kwargs = {"tune": 20, "draws": 20, "chains": 2, "cores": 2} |
25 | 26 |
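The substantive change in this hunk is the new Prior import from pymc_extras, alongside LinearRegression in the causalpy.pymc_models import. As the added tests exercise it, a Prior bundles a distribution name, its parameters, and named dims, and can materialise a PyMC variable inside a model context via create_variable. A minimal sketch of that usage, assuming the pymc_extras Prior API only as far as these tests rely on it:

import pymc as pm
from pymc_extras.prior import Prior

# A Prior records the distribution name, its parameters, and its named dims.
beta_prior = Prior("Normal", mu=0, sigma=50, dims=("treated_units", "coeffs"))
assert beta_prior.distribution == "Normal"
assert beta_prior.parameters["sigma"] == 50

# Inside a pm.Model whose coords cover those dims, it can create the variable,
# which then shows up in model.named_vars["beta"].
coords = {"treated_units": ["unit_0"], "coeffs": ["x1", "x2"]}
with pm.Model(coords=coords):
    beta = beta_prior.create_variable("beta")
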
@@ -592,3 +593,249 @@ def test_r2_scores_differ_across_units(self, rng): |
592 | 593 | f"R² standard deviation is too low ({r2_std}), suggesting insufficient variation " |
593 | 594 | "between treated units. This might indicate a scoring implementation issue." |
594 | 595 | ) |
| 596 | + |
| 597 | + |
| 598 | +@pytest.fixture(scope="module") |
| 599 | +def prior_test_data(): |
| 600 | + """Generate test data for Prior integration tests (shared across all tests).""" |
| 601 | + rng = np.random.default_rng(42) |
| 602 | + X = xr.DataArray( |
| 603 | + rng.normal(loc=0, scale=1, size=(20, 2)), |
| 604 | + dims=["obs_ind", "coeffs"], |
| 605 | + coords={"obs_ind": np.arange(20), "coeffs": ["x1", "x2"]}, |
| 606 | + ) |
| 607 | + y = xr.DataArray( |
| 608 | + rng.normal(loc=0, scale=1, size=(20, 1)), |
| 609 | + dims=["obs_ind", "treated_units"], |
| 610 | + coords={"obs_ind": np.arange(20), "treated_units": ["unit_0"]}, |
| 611 | + ) |
| 612 | + coords = { |
| 613 | + "obs_ind": np.arange(20), |
| 614 | + "coeffs": ["x1", "x2"], |
| 615 | + "treated_units": ["unit_0"], |
| 616 | + } |
| 617 | + return X, y, coords |
| 618 | + |
| 619 | + |
| 620 | +class TestPriorIntegration: |
| 621 | + """ |
| 622 | + Test suite for Prior class integration with PyMC models. |
| 623 | + Tests the precedence system, data-driven priors, and Prior class usage. |
| 624 | + """ |
| 625 | + |
| 626 | + def test_default_priors_property(self): |
| 627 | + """Test that default_priors property returns correct Prior objects.""" |
| 628 | + model = LinearRegression() |
| 629 | + defaults = model.default_priors |
| 630 | + |
| 631 | + # Check that defaults is a dictionary with expected keys |
| 632 | + assert isinstance(defaults, dict) |
| 633 | + assert "beta" in defaults |
| 634 | + assert "y_hat" in defaults |
| 635 | + |
| 636 | + # Check that values are Prior objects |
| 637 | + assert isinstance(defaults["beta"], Prior) |
| 638 | + assert isinstance(defaults["y_hat"], Prior) |
| 639 | + |
| 640 | + # Check Prior configuration using correct API |
| 641 | + beta_prior = defaults["beta"] |
| 642 | + assert beta_prior.distribution == "Normal" |
| 643 | + assert beta_prior.parameters["mu"] == 0 |
| 644 | + assert beta_prior.parameters["sigma"] == 50 |
| 645 | + |
| 646 | + def test_priors_from_data_base_implementation(self, prior_test_data): |
| 647 | + """Test that base PyMCModel.priors_from_data returns empty dict.""" |
| 648 | + X, y, coords = prior_test_data |
| 649 | + model = PyMCModel() |
| 650 | + data_priors = model.priors_from_data(X, y) |
| 651 | + assert isinstance(data_priors, dict) |
| 652 | + assert len(data_priors) == 0 |
| 653 | + |
| 654 | + def test_weighted_sum_fitter_priors_from_data(self, prior_test_data): |
| 655 | + """Test WeightedSumFitter data-driven Dirichlet prior generation.""" |
| 656 | + X, y, coords = prior_test_data |
| 657 | + model = WeightedSumFitter() |
| 658 | + data_priors = model.priors_from_data(X, y) |
| 659 | + |
| 660 | + # Should return beta prior based on X shape |
| 661 | + assert "beta" in data_priors |
| 662 | + beta_prior = data_priors["beta"] |
| 663 | + |
| 664 | + # Check it's a Dirichlet prior using correct API |
| 665 | + assert isinstance(beta_prior, Prior) |
| 666 | + assert beta_prior.distribution == "Dirichlet" |
| 667 | + |
| 668 | + # Check shape matches number of predictors |
| 669 | + assert len(beta_prior.parameters["a"]) == X.shape[1] # 2 predictors |
| 670 | + assert np.allclose(beta_prior.parameters["a"], np.ones(2)) |
| 671 | + |
| 672 | + def test_prior_precedence_system(self, prior_test_data): |
| 673 | + """Test that user priors override data-driven priors override defaults.""" |
| 674 | + X, y, coords = prior_test_data |
| 675 | + # Create custom user prior |
| 676 | + user_beta_prior = Prior( |
| 677 | + "Normal", mu=100, sigma=10, dims=("treated_units", "coeffs") |
| 678 | + ) |
| 679 | + |
| 680 | + model = LinearRegression(priors={"beta": user_beta_prior}) |
| 681 | + |
| 682 | + # Before fit, should have user prior + defaults |
| 683 | + assert model.priors["beta"] == user_beta_prior |
| 684 | + assert "y_hat" in model.priors # From defaults |
| 685 | + |
| 686 | + # After calling priors_from_data, user prior should remain |
| 687 | + data_priors = model.priors_from_data(X, y) |
| 688 | + merged_priors = {**data_priors, **model.priors} |
| 689 | + |
| 690 | + # User prior should override any data-driven prior |
| 691 | + assert merged_priors["beta"] == user_beta_prior |
| 692 | + |
| 693 | + def test_prior_precedence_integration_in_fit(self, prior_test_data): |
| 694 | + """Test the complete prior precedence system during fit().""" |
| 695 | + X, y, coords = prior_test_data |
| 696 | + # Create model with custom user prior |
| 697 | + custom_prior = Prior("Normal", mu=5, sigma=2, dims=("treated_units", "coeffs")) |
| 698 | + model = LinearRegression( |
| 699 | + priors={"beta": custom_prior}, |
| 700 | + sample_kwargs={"tune": 5, "draws": 5, "chains": 1, "progressbar": False}, |
| 701 | + ) |
| 702 | + |
| 703 | + # Fit the model |
| 704 | + model.fit(X, y, coords=coords) |
| 705 | + |
| 706 | + # Check that the model was built with the custom prior |
| 707 | + # We can verify this by checking the model context |
| 708 | + assert model.idata is not None |
| 709 | + assert "beta" in model.idata.posterior |
| 710 | + |
| 711 | + def test_prior_dimensions_consistency(self): |
| 712 | + """Test that Prior dimensions are consistent with model expectations.""" |
| 713 | + model = LinearRegression() |
| 714 | + |
| 715 | + # Check default priors have correct dimensions (tuples, not lists) |
| 716 | + beta_prior = model.default_priors["beta"] |
| 717 | + assert beta_prior.dims == ("treated_units", "coeffs") |
| 718 | + |
| 719 | + y_hat_prior = model.default_priors["y_hat"] |
| 720 | + assert y_hat_prior.dims == ("obs_ind", "treated_units") |
| 721 | + |
| 722 | + # Check that sigma component has correct dims |
| 723 | + sigma_prior = y_hat_prior.parameters["sigma"] |
| 724 | + assert isinstance(sigma_prior, Prior) |
| 725 | + assert sigma_prior.dims == ("treated_units",) |
| 726 | + |
| 727 | + def test_custom_prior_with_build_model(self, prior_test_data): |
| 728 | + """Test that custom priors work correctly in build_model.""" |
| 729 | + # Create a custom Prior with different parameters |
| 730 | + custom_beta = Prior( |
| 731 | + "Normal", |
| 732 | + mu=0, |
| 733 | + sigma=10, # Different from default (50) |
| 734 | + dims=("treated_units", "coeffs"), |
| 735 | + ) |
| 736 | + custom_sigma = Prior( |
| 737 | + "HalfNormal", |
| 738 | + sigma=2, # Different from default (1) |
| 739 | + dims=("treated_units",), |
| 740 | + ) |
| 741 | + custom_y_hat = Prior( |
| 742 | + "Normal", sigma=custom_sigma, dims=("obs_ind", "treated_units") |
| 743 | + ) |
| 744 | + |
| 745 | + model = LinearRegression(priors={"beta": custom_beta, "y_hat": custom_y_hat}) |
| 746 | + |
| 747 | + # Build the model to ensure priors work |
| 748 | + X, y, coords = prior_test_data |
| 749 | + model.build_model(X, y, coords) |
| 750 | + |
| 751 | + # Check that variables were created in the model context |
| 752 | + with model: |
| 753 | + assert "beta" in model.named_vars |
| 754 | + assert "y_hat" in model.named_vars |
| 755 | + |
| 756 | + def test_prior_create_variable_integration(self, prior_test_data): |
| 757 | + """Test that Prior.create_variable works in model context.""" |
| 758 | + X, y, coords = prior_test_data |
| 759 | + model = LinearRegression() |
| 760 | + model.build_model(X, y, coords) |
| 761 | + |
| 762 | + # Verify that Prior.create_variable was called successfully |
| 763 | + # by checking the created variables exist and have expected names |
| 764 | + with model: |
| 765 | + beta_var = model.named_vars["beta"] |
| 766 | + # Check that the variable exists and is a PyMC variable |
| 767 | + assert beta_var is not None |
| 768 | + assert hasattr(beta_var, "name") |
| 769 | + assert beta_var.name == "beta" |
| 770 | + |
| 771 | + def test_weighted_sum_fitter_dirichlet_prior_shape(self, prior_test_data): |
| 772 | + """Test that WeightedSumFitter creates correct Dirichlet shape.""" |
| 773 | + _, y, _ = prior_test_data |
| 774 | + rng = np.random.default_rng(42) |
| 775 | + # Test with different numbers of control units |
| 776 | + for n_controls in [3, 5, 10]: |
| 777 | + X = xr.DataArray( |
| 778 | + rng.normal(size=(20, n_controls)), |
| 779 | + dims=["obs_ind", "coeffs"], |
| 780 | + coords={ |
| 781 | + "obs_ind": np.arange(20), |
| 782 | + "coeffs": [f"control_{i}" for i in range(n_controls)], |
| 783 | + }, |
| 784 | + ) |
| 785 | + |
| 786 | + model = WeightedSumFitter() |
| 787 | + data_priors = model.priors_from_data(X, y) |
| 788 | + |
| 789 | + beta_prior = data_priors["beta"] |
| 790 | + assert len(beta_prior.parameters["a"]) == n_controls |
| 791 | + assert np.allclose(beta_prior.parameters["a"], np.ones(n_controls)) |
| 792 | + |
| 793 | + def test_prior_none_handling(self): |
| 794 | + """Test that models handle None priors parameter correctly.""" |
| 795 | + model = LinearRegression(priors=None) |
| 796 | + |
| 797 | + # Should still have default priors |
| 798 | + assert len(model.priors) > 0 |
| 799 | + assert "beta" in model.priors |
| 800 | + assert "y_hat" in model.priors |
| 801 | + |
| 802 | + def test_empty_priors_dict(self): |
| 803 | + """Test that models handle empty priors dict correctly.""" |
| 804 | + model = LinearRegression(priors={}) |
| 805 | + |
| 806 | + # Should still have default priors |
| 807 | + assert len(model.priors) > 0 |
| 808 | + assert "beta" in model.priors |
| 809 | + assert "y_hat" in model.priors |
| 810 | + |
| 811 | + def test_priors_from_data_called_during_fit(self, prior_test_data): |
| 812 | + """Test that priors_from_data is called and integrated during fit.""" |
| 813 | + |
| 814 | + # Create a mock model that tracks priors_from_data calls |
| 815 | + class TrackingWeightedSumFitter(WeightedSumFitter): |
| 816 | + def __init__(self, *args, **kwargs): |
| 817 | + super().__init__(*args, **kwargs) |
| 818 | + self.priors_from_data_called = False |
| 819 | + self.priors_from_data_args = None |
| 820 | + |
| 821 | + def priors_from_data(self, X, y): |
| 822 | + self.priors_from_data_called = True |
| 823 | + self.priors_from_data_args = (X, y) |
| 824 | + return super().priors_from_data(X, y) |
| 825 | + |
| 826 | + model = TrackingWeightedSumFitter( |
| 827 | + sample_kwargs={"tune": 2, "draws": 2, "chains": 1, "progressbar": False} |
| 828 | + ) |
| 829 | + |
| 830 | + # Fit the model |
| 831 | + X, y, coords = prior_test_data |
| 832 | + model.fit(X, y, coords=coords) |
| 833 | + |
| 834 | + # Verify priors_from_data was called with correct arguments |
| 835 | + assert model.priors_from_data_called |
| 836 | + assert model.priors_from_data_args is not None |
| 837 | + |
| 838 | + # Verify the model has the Dirichlet prior after fitting |
| 839 | + assert "beta" in model.priors |
| 840 | + beta_prior = model.priors["beta"] |
| 841 | + assert beta_prior.distribution == "Dirichlet" |
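
Taken together, the assertions above pin down the intended precedence when a model is fitted: user-supplied priors override data-driven priors from priors_from_data(), which in turn override default_priors. A rough sketch of that merge, as implied by the tests rather than taken from the library's implementation (resolve_priors is a hypothetical helper):

def resolve_priors(model, X, y, user_priors=None):
    # Lowest precedence: the model's hard-coded default_priors.
    priors = dict(model.default_priors)
    # Data-driven priors (e.g. WeightedSumFitter's Dirichlet over the control
    # unit weights) override the defaults.
    priors.update(model.priors_from_data(X, y))
    # Anything the user supplied wins over both.
    priors.update(user_priors or {})
    return priors

For WeightedSumFitter, for example, priors_from_data is expected to return a Dirichlet prior with a = np.ones(n_controls) for beta, and test_priors_from_data_called_during_fit checks that this is what model.priors carries after fit() when no user prior for beta was supplied.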