Skip to content

Commit 4807bec

Browse files
authored
Merge branch 'main' into cate_example
2 parents 762137b + 9a88b01 commit 4807bec

17 files changed

+864
-555
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ repos:
2525
exclude: &exclude_pattern 'iv_weak_instruments.ipynb'
2626
args: ["--maxkb=1500"]
2727
- repo: https://github.com/astral-sh/ruff-pre-commit
28-
rev: v0.13.3
28+
rev: v0.14.1
2929
hooks:
3030
# Run the linter
3131
- id: ruff

Makefile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@ check_lint:
1313
interrogate .
1414

1515
doctest:
16-
pytest --doctest-modules --ignore=causalpy/tests/ causalpy/ --config-file=causalpy/tests/conftest.py
16+
python -m pytest --doctest-modules --ignore=causalpy/tests/ causalpy/ --config-file=causalpy/tests/conftest.py
1717

1818
test:
19-
pytest
19+
python -m pytest
2020

2121
uml:
2222
pyreverse -o png causalpy --output-directory docs/source/_static --ignore tests

causalpy/__init__.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
import arviz as az
1514

1615
import causalpy.pymc_models as pymc_models
1716
import causalpy.skl_models as skl_models
@@ -28,8 +27,6 @@
2827
from .experiments.regression_kink import RegressionKink
2928
from .experiments.synthetic_control import SyntheticControl
3029

31-
az.style.use("arviz-darkgrid")
32-
3330
__all__ = [
3431
"__version__",
3532
"DifferenceInDifferences",

causalpy/experiments/base.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
from abc import abstractmethod
1919

20+
import arviz as az
21+
import matplotlib.pyplot as plt
2022
import pandas as pd
2123
from sklearn.base import RegressorMixin
2224

@@ -63,12 +65,14 @@ def plot(self, *args, **kwargs) -> tuple:
6365
Internally, this function dispatches to either `_bayesian_plot` or `_ols_plot`
6466
depending on the model type.
6567
"""
66-
if isinstance(self.model, PyMCModel):
67-
return self._bayesian_plot(*args, **kwargs)
68-
elif isinstance(self.model, RegressorMixin):
69-
return self._ols_plot(*args, **kwargs)
70-
else:
71-
raise ValueError("Unsupported model type")
68+
# Apply arviz-darkgrid style only during plotting, then revert
69+
with plt.style.context(az.style.library["arviz-darkgrid"]):
70+
if isinstance(self.model, PyMCModel):
71+
return self._bayesian_plot(*args, **kwargs)
72+
elif isinstance(self.model, RegressorMixin):
73+
return self._ols_plot(*args, **kwargs)
74+
else:
75+
raise ValueError("Unsupported model type")
7276

7377
@abstractmethod
7478
def _bayesian_plot(self, *args, **kwargs):

causalpy/experiments/interrupted_time_series.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,14 @@ class InterruptedTimeSeries(BaseExperiment):
7070
... }
7171
... ),
7272
... )
73+
74+
Notes
75+
-----
76+
For Bayesian models, the causal impact is calculated using the posterior expectation
77+
(``mu``) rather than the posterior predictive (``y_hat``). This means the impact and
78+
its uncertainty represent the systematic causal effect, excluding observation-level
79+
noise. The uncertainty bands in the plots reflect parameter uncertainty and
80+
counterfactual prediction uncertainty, but not individual observation variability.
7381
"""
7482

7583
expt_type = "Interrupted Time Series"

causalpy/experiments/synthetic_control.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,14 @@ class SyntheticControl(BaseExperiment):
6767
... }
6868
... ),
6969
... )
70+
71+
Notes
72+
-----
73+
For Bayesian models, the causal impact is calculated using the posterior expectation
74+
(``mu``) rather than the posterior predictive (``y_hat``). This means the impact and
75+
its uncertainty represent the systematic causal effect, excluding observation-level
76+
noise. The uncertainty bands in the plots reflect parameter uncertainty and
77+
counterfactual prediction uncertainty, but not individual observation variability.
7078
"""
7179

7280
supports_ols = True

causalpy/plot_utils.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,17 @@ def get_hdi_to_df(
9393
:param hdi_prob:
9494
The size of the HDI, default is 0.94
9595
"""
96-
hdi = (
97-
az.hdi(x, hdi_prob=hdi_prob)
98-
.to_dataframe()
99-
.unstack(level="hdi")
100-
.droplevel(0, axis=1)
101-
)
102-
return hdi
96+
hdi_result = az.hdi(x, hdi_prob=hdi_prob)
97+
98+
# Get the data variable name (typically 'mu' or 'x')
99+
# We select only the data variable column to exclude coordinates like 'treated_units'
100+
data_var = list(hdi_result.data_vars)[0]
101+
102+
# Convert to DataFrame, select only the data variable column, then unstack
103+
# This prevents coordinate values (like 'treated_agg') from appearing as columns
104+
hdi_df = hdi_result[data_var].to_dataframe()[[data_var]].unstack(level="hdi")
105+
106+
# Remove the top level of column MultiIndex to get just 'lower' and 'higher'
107+
hdi_df.columns = hdi_df.columns.droplevel(0)
108+
109+
return hdi_df

causalpy/pymc_models.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,44 @@ def score(self, X: xr.DataArray, y: xr.DataArray) -> pd.Series:
305305
def calculate_impact(
306306
self, y_true: xr.DataArray, y_pred: az.InferenceData
307307
) -> xr.DataArray:
308-
impact = y_true - y_pred["posterior_predictive"]["y_hat"]
308+
"""
309+
Calculate the causal impact as the difference between observed and predicted values.
310+
311+
The impact is calculated using the posterior expectation (`mu`) rather than the
312+
posterior predictive (`y_hat`). This means the causal impact represents the
313+
difference from the expected value of the model, excluding observation noise.
314+
This approach provides a cleaner measure of the causal effect by focusing on
315+
the systematic difference rather than including sampling variability from the
316+
observation noise term.
317+
318+
Parameters
319+
----------
320+
y_true : xr.DataArray
321+
The observed outcome values with dimensions ["obs_ind", "treated_units"].
322+
y_pred : az.InferenceData
323+
The posterior predictive samples containing the "mu" variable, which
324+
represents the expected value (mean) of the outcome.
325+
326+
Returns
327+
-------
328+
xr.DataArray
329+
The causal impact with dimensions ending in "obs_ind". The impact includes
330+
posterior uncertainty from the model parameters but excludes observation noise.
331+
332+
Notes
333+
-----
334+
By using `mu` (the posterior expectation) rather than `y_hat` (the posterior
335+
predictive with observation noise), the uncertainty in the impact reflects:
336+
- Parameter uncertainty in the fitted model
337+
- Uncertainty in the counterfactual prediction
338+
339+
But excludes:
340+
- Observation-level noise (sigma)
341+
342+
This makes the impact plots focus on the systematic causal effect rather than
343+
individual observation variability.
344+
"""
345+
impact = y_true - y_pred["posterior_predictive"]["mu"]
309346
return impact.transpose(..., "obs_ind")
310347

311348
def calculate_cumulative_impact(self, impact):

causalpy/tests/test_plot_utils.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# Copyright 2025 - 2025 The PyMC Labs Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""
15+
Tests for plot utility functions
16+
"""
17+
18+
import numpy as np
19+
import pandas as pd
20+
import pytest
21+
import xarray as xr
22+
23+
from causalpy.plot_utils import get_hdi_to_df
24+
25+
26+
@pytest.mark.integration
27+
def test_get_hdi_to_df_with_coordinate_dimensions():
28+
"""
29+
Regression test for bug where get_hdi_to_df returned string coordinate values
30+
instead of numeric HDI values when xarray had named coordinate dimensions.
31+
32+
This bug manifested in multi-cell synthetic control experiments where columns
33+
like 'pred_hdi_upper_94' contained the string "treated_agg" instead of
34+
numeric upper bound values.
35+
36+
See: https://github.com/pymc-labs/CausalPy/issues/532
37+
"""
38+
# Create a mock xarray DataArray similar to what's produced in synthetic control
39+
# with a coordinate dimension like 'treated_units'
40+
np.random.seed(42)
41+
n_chains = 2
42+
n_draws = 100
43+
n_obs = 10
44+
45+
# Simulate posterior samples with a named coordinate
46+
data = np.random.normal(loc=5.0, scale=0.5, size=(n_chains, n_draws, n_obs))
47+
48+
xr_data = xr.DataArray(
49+
data,
50+
dims=["chain", "draw", "obs_ind"],
51+
coords={
52+
"chain": np.arange(n_chains),
53+
"draw": np.arange(n_draws),
54+
"obs_ind": np.arange(n_obs),
55+
"treated_units": "treated_agg", # This coordinate caused the bug
56+
},
57+
)
58+
59+
# Call get_hdi_to_df
60+
result = get_hdi_to_df(xr_data, hdi_prob=0.94)
61+
62+
# Assertions to verify the bug is fixed
63+
assert isinstance(result, pd.DataFrame), "Result should be a DataFrame"
64+
65+
# Check that we have exactly 2 columns (lower and higher)
66+
assert result.shape[1] == 2, f"Expected 2 columns, got {result.shape[1]}"
67+
68+
# Check column names
69+
assert "lower" in result.columns, "Should have 'lower' column"
70+
assert "higher" in result.columns, "Should have 'higher' column"
71+
72+
# CRITICAL: Check that columns contain numeric data, not strings
73+
assert result["lower"].dtype in [
74+
np.float64,
75+
np.float32,
76+
], f"'lower' column should be numeric, got {result['lower'].dtype}"
77+
assert result["higher"].dtype in [
78+
np.float64,
79+
np.float32,
80+
], f"'higher' column should be numeric, got {result['higher'].dtype}"
81+
82+
# Check that no string values like 'treated_agg' appear in the data
83+
assert not (result["lower"].astype(str).str.contains("treated_agg").any()), (
84+
"'lower' column should not contain coordinate string values"
85+
)
86+
assert not (result["higher"].astype(str).str.contains("treated_agg").any()), (
87+
"'higher' column should not contain coordinate string values"
88+
)
89+
90+
# Verify HDI ordering
91+
assert (result["lower"] <= result["higher"]).all(), (
92+
"'lower' should be <= 'higher' for all rows"
93+
)
94+
95+
# Verify reasonable HDI values (should be around the mean of 5.0)
96+
assert result["lower"].min() > 3.0, "HDI lower bounds should be reasonable"
97+
assert result["higher"].max() < 7.0, "HDI upper bounds should be reasonable"

docs/source/_static/interrogate_badge.svg

Lines changed: 3 additions & 3 deletions
Loading

0 commit comments

Comments
 (0)