Skip to content

Commit cc62438

Browse files
committed
better model/experiment compatability
1 parent 121fe46 commit cc62438

11 files changed

+157
-5
lines changed

causalpy/experiments/base.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,21 @@
2424
class BaseExperiment:
2525
"""Base class for quasi experimental designs."""
2626

27+
supports_bayes: bool
28+
supports_ols: bool
29+
2730
def __init__(self, model=None):
2831
if model is not None:
2932
self.model = model
33+
34+
if isinstance(self.model, PyMCModel) and not self.supports_bayes:
35+
raise ValueError("Bayesian models not supported.")
36+
37+
if isinstance(self.model, ScikitLearnModel) and not self.supports_ols:
38+
raise ValueError("OLS models not supported.")
39+
3040
if self.model is None:
31-
raise ValueError("fitting_model not set or passed.")
41+
raise ValueError("model not set or passed.")
3242

3343
@property
3444
def idata(self):

causalpy/experiments/diff_in_diff.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,9 @@ class DifferenceInDifferences(BaseExperiment):
7474
... )
7575
"""
7676

77+
supports_ols = True
78+
supports_bayes = True
79+
7780
def __init__(
7881
self,
7982
data: pd.DataFrame,

causalpy/experiments/instrumental_variable.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,9 @@ class InstrumentalVariable(BaseExperiment):
8888
... )
8989
"""
9090

91+
supports_ols = False
92+
supports_bayes = True
93+
9194
def __init__(
9295
self,
9396
instruments_data: pd.DataFrame,

causalpy/experiments/inverse_propensity_weighting.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,9 @@ class InversePropensityWeighting(BaseExperiment):
6969
... )
7070
"""
7171

72+
supports_ols = False
73+
supports_bayes = True
74+
7275
def __init__(
7376
self,
7477
data: pd.DataFrame,

causalpy/experiments/prepostfit.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,8 @@ class InterruptedTimeSeries(PrePostFit):
342342
"""
343343

344344
expt_type = "Interrupted Time Series"
345+
supports_ols = True
346+
supports_bayes = True
345347

346348

347349
class SyntheticControl(PrePostFit):
@@ -377,6 +379,8 @@ class SyntheticControl(PrePostFit):
377379
"""
378380

379381
expt_type = "SyntheticControl"
382+
supports_ols = True
383+
supports_bayes = True
380384

381385
def bayesian_plot(self, *args, **kwargs) -> tuple[plt.Figure, List[plt.Axes]]:
382386
"""

causalpy/experiments/prepostnegd.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,9 @@ class PrePostNEGD(BaseExperiment):
8484
sigma 0.5, 94% HDI [0.5, 0.6]
8585
"""
8686

87+
supports_ols = False
88+
supports_bayes = True
89+
8790
def __init__(
8891
self,
8992
data: pd.DataFrame,

causalpy/experiments/regression_discontinuity.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,9 @@ class RegressionDiscontinuity(BaseExperiment):
7878
... )
7979
"""
8080

81+
supports_ols = True
82+
supports_bayes = True
83+
8184
def __init__(
8285
self,
8386
data: pd.DataFrame,

causalpy/experiments/regression_kink.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@
4141
class RegressionKink(BaseExperiment):
4242
"""Regression Kink experiment class."""
4343

44+
supports_ols = False
45+
supports_bayes = True
46+
4447
def __init__(
4548
self,
4649
data: pd.DataFrame,
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
# Copyright 2024 The PyMC Labs Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""
15+
Test exceptions are raised when an experiment object is provided a model type (e.g.
16+
`PyMCModel` or `ScikitLearnModel`) that is not supported by the experiment object.
17+
"""
18+
19+
import numpy as np
20+
import pandas as pd
21+
import pytest
22+
from sklearn.linear_model import LinearRegression
23+
24+
import causalpy as cp
25+
26+
CustomLinearRegression = cp.create_causalpy_compatible_class(LinearRegression)
27+
28+
29+
# TODO: THE TWO FUNCTIONS BELOW ARE COPIED FROM causalpy/tests/test_regression_kink.py
30+
31+
32+
def setup_regression_kink_data(kink):
33+
"""Set up data for regression kink design tests"""
34+
# define parameters for data generation
35+
seed = 42
36+
rng = np.random.default_rng(seed)
37+
N = 50
38+
kink = 0.5
39+
beta = [0, -1, 0, 2, 0]
40+
sigma = 0.05
41+
# generate data
42+
x = rng.uniform(-1, 1, N)
43+
y = reg_kink_function(x, beta, kink) + rng.normal(0, sigma, N)
44+
return pd.DataFrame({"x": x, "y": y, "treated": x >= kink})
45+
46+
47+
def reg_kink_function(x, beta, kink):
48+
"""Utility function for regression kink design. Returns a piecewise linear function
49+
evaluated at x with a kink at kink and parameters beta"""
50+
return (
51+
beta[0]
52+
+ beta[1] * x
53+
+ beta[2] * x**2
54+
+ beta[3] * (x - kink) * (x >= kink)
55+
+ beta[4] * (x - kink) ** 2 * (x >= kink)
56+
)
57+
58+
59+
# Test that a ValueError is raised when a ScikitLearnModel is provided to a RegressionKink object
60+
def test_olsmodel_and_regressionkink():
61+
"""RegressionKink does not support OLS models, so a ValueError should be raised"""
62+
63+
with pytest.raises(ValueError):
64+
kink = 0.5
65+
df = setup_regression_kink_data(kink)
66+
_ = cp.RegressionKink(
67+
df,
68+
formula=f"y ~ 1 + x + I((x-{kink})*treated)",
69+
model=CustomLinearRegression(),
70+
kink_point=kink,
71+
)
72+
73+
74+
# Test that a ValueError is raised when a ScikitLearnModel is provided to a InstrumentalVariable object
75+
def test_olsmodel_and_iv():
76+
"""InstrumentalVariable does not support OLS models, so a ValueError should be raised"""
77+
78+
with pytest.raises(ValueError):
79+
df = cp.load_data("risk")
80+
instruments_formula = "risk ~ 1 + logmort0"
81+
formula = "loggdp ~ 1 + risk"
82+
instruments_data = df[["risk", "logmort0"]]
83+
data = df[["loggdp", "risk"]]
84+
_ = cp.InstrumentalVariable(
85+
instruments_data=instruments_data,
86+
data=data,
87+
instruments_formula=instruments_formula,
88+
formula=formula,
89+
model=CustomLinearRegression(),
90+
)
91+
92+
93+
# Test that a ValueError is raised when a ScikitLearnModel is provided to a PrePostNEGD object
94+
def test_olsmodel_and_prepostnegd():
95+
"""PrePostNEGD does not support OLS models, so a ValueError should be raised"""
96+
97+
with pytest.raises(ValueError):
98+
df = cp.load_data("anova1")
99+
_ = cp.PrePostNEGD(
100+
df,
101+
formula="post ~ 1 + C(group) + pre",
102+
group_variable_name="group",
103+
pretreatment_variable_name="pre",
104+
model=CustomLinearRegression(),
105+
)
106+
107+
108+
# Test that a ValueError is raised when a ScikitLearnModel is provided to a InversePropensityWeighting object
109+
def test_olsmodel_and_ipw():
110+
"""InversePropensityWeighting does not support OLS models, so a ValueError should be raised"""
111+
112+
with pytest.raises(ValueError):
113+
df = cp.load_data("nhefs")
114+
_ = cp.InversePropensityWeighting(
115+
df,
116+
formula="trt ~ 1 + age + race",
117+
outcome_variable="outcome",
118+
weighting_scheme="robust",
119+
model=CustomLinearRegression(),
120+
)

docs/source/_static/classes.png

36.3 KB
Loading

0 commit comments

Comments
 (0)