Skip to content

Commit 1fc193b

Browse files
authored
Merge pull request #473 from pymc-labs/cetagostini/adding_bsts_to_causalpy
Implement Bayesian Structural Time Series (BSTS)
2 parents 0b4764d + abdbf6b commit 1fc193b

14 files changed

+4382
-84
lines changed

causalpy/experiments/__init__.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,24 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
"""Quasi-experimental designs for causal inference."""
14+
"""CausalPy experiment module"""
15+
16+
from .diff_in_diff import DifferenceInDifferences
17+
from .instrumental_variable import InstrumentalVariable
18+
from .interrupted_time_series import InterruptedTimeSeries
19+
from .inverse_propensity_weighting import InversePropensityWeighting
20+
from .prepostnegd import PrePostNEGD
21+
from .regression_discontinuity import RegressionDiscontinuity
22+
from .regression_kink import RegressionKink
23+
from .synthetic_control import SyntheticControl
24+
25+
__all__ = [
26+
"DifferenceInDifferences",
27+
"InstrumentalVariable",
28+
"InversePropensityWeighting",
29+
"PrePostNEGD",
30+
"RegressionDiscontinuity",
31+
"RegressionKink",
32+
"SyntheticControl",
33+
"InterruptedTimeSeries",
34+
]

causalpy/experiments/interrupted_time_series.py

Lines changed: 249 additions & 64 deletions
Large diffs are not rendered by default.

causalpy/plot_utils.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,23 +62,25 @@ def plot_xY(
6262
if plot_hdi_kwargs is None:
6363
plot_hdi_kwargs = {}
6464

65+
# Separate fill_kwargs for az.plot_hdi, as ax.plot doesn't accept them
66+
line_kwargs = plot_hdi_kwargs.copy()
67+
if "fill_kwargs" in line_kwargs:
68+
del line_kwargs["fill_kwargs"]
69+
6570
(h_line,) = ax.plot(
6671
x,
6772
Y.mean(dim=["chain", "draw"]),
6873
ls="-",
69-
**plot_hdi_kwargs,
70-
label=f"{label}",
74+
**line_kwargs, # Use kwargs without fill_kwargs
75+
label=label, # Use the provided label for the mean line
7176
)
7277
ax_hdi = az.plot_hdi(
7378
x,
7479
Y,
7580
hdi_prob=hdi_prob,
76-
fill_kwargs={
77-
"alpha": 0.25,
78-
"label": " ",
79-
},
80-
smooth=False,
8181
ax=ax,
82+
smooth=False, # To prevent warning about resolution with few data points
83+
# Pass original plot_hdi_kwargs which might include fill_kwargs for fill_between
8284
**plot_hdi_kwargs,
8385
)
8486
# Return handle to patch. We get a list of the children of the axis. Filter for just

0 commit comments

Comments
 (0)