Skip to content

Commit 809277e

Browse files
committed
Examples for pymc experiments and most models
1 parent 6bee429 commit 809277e

File tree

5 files changed

+328
-29
lines changed

5 files changed

+328
-29
lines changed

causalpy/data/simulate_data.py

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,13 @@ def generate_synthetic_control_data(
2929
lowess_kwargs=default_lowess_kwargs,
3030
):
3131
"""
32-
Example:
33-
>> import pathlib
34-
>> df, weightings_true = generate_synthetic_control_data(
35-
treatment_time=treatment_time
36-
)
37-
>> df.to_csv(pathlib.Path.cwd() / 'synthetic_control.csv', index=False)
32+
Example
33+
--------
34+
>>> import pathlib
35+
>>> df, weightings_true = generate_synthetic_control_data(
36+
... treatment_time=treatment_time
37+
... )
38+
>>> df.to_csv(pathlib.Path.cwd() / 'synthetic_control.csv', index=False)
3839
"""
3940

4041
# 1. Generate non-treated variables
@@ -73,6 +74,7 @@ def generate_synthetic_control_data(
7374
def generate_time_series_data(
7475
N=100, treatment_time=70, beta_temp=-1, beta_linear=0.5, beta_intercept=3
7576
):
77+
"""Generate synthetic time series example data as a pandas DataFrame."""
7678
x = np.arange(0, 100, 1)
7779
df = pd.DataFrame(
7880
{
@@ -102,6 +104,7 @@ def generate_time_series_data(
102104

103105

104106
def generate_time_series_data_seasonal(treatment_time):
107+
"""Generate synthetic seasonal time series example data on a monthly date range."""
105108
dates = pd.date_range(
106109
start=pd.to_datetime("2010-01-01"), end=pd.to_datetime("2020-01-01"), freq="M"
107110
)
@@ -149,6 +152,13 @@ def generate_time_series_data_simple(treatment_time, slope=0.0):
149152

150153

151154
def generate_did():
155+
"""
156+
Generate Difference in Differences data
157+
158+
Example
159+
--------
160+
>>> df = generate_did()
161+
"""
152162
# true parameters
153163
control_intercept = 1
154164
treat_intercept_delta = 0.25
@@ -194,10 +204,13 @@ def generate_regression_discontinuity_data(
194204
N=100, true_causal_impact=0.5, true_treatment_threshold=0.0
195205
):
196206
"""
197-
Example use:
198-
>> import pathlib
199-
>> df = generate_regression_discontinuity_data(true_treatment_threshold=0.5)
200-
>> df.to_csv(pathlib.Path.cwd() / 'regression_discontinuity.csv', index=False)
207+
Generate regression discontinuity example data
208+
209+
Example
210+
--------
211+
>>> import pathlib
212+
>>> df = generate_regression_discontinuity_data(true_treatment_threshold=0.5)
213+
>>> df.to_csv(pathlib.Path.cwd() / 'regression_discontinuity.csv', index=False)
201214
"""
202215

203216
def is_treated(x):
@@ -217,6 +230,20 @@ def impact(x):
217230
def generate_ancova_data(
218231
N=200, pre_treatment_means=np.array([10, 12]), treatment_effect=2, sigma=1
219232
):
233+
"""
234+
Generate ANCOVA example data
235+
236+
Example
237+
--------
238+
>>> import pathlib
239+
>>> df = generate_ancova_data(
240+
... N=200,
241+
...     pre_treatment_means=np.array([10, 12]),
242+
... treatment_effect=2,
243+
... sigma=1
244+
... )
245+
>>> df.to_csv(pathlib.Path.cwd() / 'ancova_data.csv', index=False)
246+
"""
220247
group = np.random.choice(2, size=N)
221248
pre = np.random.normal(loc=pre_treatment_means[group])
222249
post = pre + treatment_effect * group + np.random.normal(size=N) * sigma

causalpy/plot_utils.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,22 @@ def plot_xY(
2121
hdi_prob: float = 0.94,
2222
label: Union[str, None] = None,
2323
) -> Tuple[Line2D, PolyCollection]:
24-
"""Utility function to plot HDI intervals."""
24+
"""
25+
Utility function to plot HDI intervals.
26+
27+
:param x:
28+
Pandas datetime index or numpy array of x-axis values
29+
:param y:
30+
Xarray data array of y-axis data
31+
:param ax:
32+
Matplotlib ax object
33+
:param plot_hdi_kwargs:
34+
Dictionary of keyword arguments passed to ax.plot()
35+
:param hdi_prob:
36+
The size of the HDI, default is 0.94
37+
:param label:
38+
The plot label
39+
"""
2540

2641
if plot_hdi_kwargs is None:
2742
plot_hdi_kwargs = {}

causalpy/pymc_experiments.py

Lines changed: 142 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030

3131

3232
class ExperimentalDesign:
33-
"""Base class"""
33+
"""Base class for other experiment types"""
3434

3535
model = None
3636
expt_type = None
@@ -43,7 +43,7 @@ def __init__(self, model=None, **kwargs):
4343

4444
@property
4545
def idata(self):
46-
"""Access to the InferenceData object"""
46+
"""Access to the model's InferenceData object"""
4747
return self.model.idata
4848

4949
def print_coefficients(self) -> None:
@@ -66,8 +66,32 @@ def print_coefficients(self) -> None:
6666

6767

6868
class PrePostFit(ExperimentalDesign):
69-
"""A class to analyse quasi-experiments where parameter estimation is based on just
70-
the pre-intervention data."""
69+
"""
70+
A class to analyse quasi-experiments where parameter estimation is based on just
71+
the pre-intervention data.
72+
73+
:param data:
74+
A pandas data frame
75+
:param treatment_time:
76+
The time when treatment occurred, should be in reference to the data index
77+
:param formula:
78+
A statistical model formula
79+
:param model:
80+
A PyMC model
81+
82+
Example
83+
--------
84+
>>> sc = cp.load_data("sc")
>>> treatment_time = 70
85+
>>> seed = 42
86+
>>> result = cp.pymc_experiments.PrePostFit(
87+
... sc,
88+
... treatment_time,
89+
... formula="actual ~ 0 + a + b + c + d + e + f + g",
90+
... model=cp.pymc_models.WeightedSumFitter(
91+
... sample_kwargs={"target_accept": 0.95, "random_seed": seed}
92+
... ),
93+
... )
94+
"""
7195

7296
def __init__(
7397
self,
@@ -256,13 +280,64 @@ def summary(self) -> None:
256280

257281

258282
class InterruptedTimeSeries(PrePostFit):
259-
"""Interrupted time series analysis"""
283+
"""
284+
Interrupted time series analysis, a wrapper around the PrePostFit class
285+
286+
:param data:
287+
A pandas data frame
288+
:param treatment_time:
289+
The time when treatment occurred, should be in reference to the data index
290+
:param formula:
291+
A statistical model formula
292+
:param model:
293+
A PyMC model
294+
295+
Example
296+
--------
297+
>>> df = (
298+
... cp.load_data("its")
299+
... .assign(date=lambda x: pd.to_datetime(x["date"]))
300+
... .set_index("date")
301+
... )
302+
>>> treatment_time = pd.to_datetime("2017-01-01")
303+
>>> seed = 42
304+
>>> result = cp.pymc_experiments.InterruptedTimeSeries(
305+
... df,
306+
... treatment_time,
307+
... formula="y ~ 1 + t + C(month)",
308+
... model=cp.pymc_models.LinearRegression(sample_kwargs={"random_seed": seed}),
309+
... )
310+
"""
260311

261312
expt_type = "Interrupted Time Series"
262313

263314

264315
class SyntheticControl(PrePostFit):
265-
"""A wrapper around the PrePostFit class"""
316+
"""A wrapper around the PrePostFit class
317+
318+
:param data:
319+
A pandas data frame
320+
:param treatment_time:
321+
The time when treatment occurred, should be in reference to the data index
322+
:param formula:
323+
A statistical model formula
324+
:param model:
325+
A PyMC model
326+
327+
Example
328+
--------
329+
>>> df = cp.load_data("sc")
330+
>>> treatment_time = 70
331+
>>> seed = 42
332+
>>> result = cp.pymc_experiments.SyntheticControl(
333+
... df,
334+
... treatment_time,
335+
... formula="actual ~ 0 + a + b + c + d + e + f + g",
336+
... model=cp.pymc_models.WeightedSumFitter(
337+
... sample_kwargs={"target_accept": 0.95, "random_seed": seed}
338+
... ),
339+
... )
340+
"""
266341

267342
expt_type = "Synthetic Control"
268343

@@ -285,6 +360,28 @@ class DifferenceInDifferences(ExperimentalDesign):
285360
286361
There is no pre/post intervention data distinction for DiD, we fit all the
287362
data available.
363+
:param data:
364+
A pandas data frame
365+
:param formula:
366+
A statistical model formula
367+
:param time_variable_name:
368+
Name of the data column for the time variable
369+
:param group_variable_name:
370+
Name of the data column for the group variable
371+
:param model:
372+
A PyMC model for difference in differences
373+
374+
Example
375+
--------
376+
>>> df = cp.load_data("did")
377+
>>> seed = 42
378+
>>> result = cp.pymc_experiments.DifferenceInDifferences(
379+
... df,
380+
... formula="y ~ 1 + group*post_treatment",
381+
... time_variable_name="t",
382+
... group_variable_name="group",
383+
... model=cp.pymc_models.LinearRegression(sample_kwargs={"random_seed": seed}),
384+
... )
288385
289386
"""
290387

@@ -572,6 +669,18 @@ class RegressionDiscontinuity(ExperimentalDesign):
572669
:param bandwidth:
573670
Data outside of the bandwidth (relative to the discontinuity) is not used to fit
574671
the model.
672+
673+
Example
674+
--------
675+
>>> df = cp.load_data("rd")
676+
>>> seed = 42
677+
>>> result = cp.pymc_experiments.RegressionDiscontinuity(
678+
... df,
679+
... formula="y ~ 1 + x + treated + x:treated",
680+
... model=cp.pymc_models.LinearRegression(sample_kwargs={"random_seed": seed}),
681+
... treatment_threshold=0.5,
682+
... )
683+
575684
"""
576685

577686
def __init__(
@@ -742,7 +851,33 @@ def summary(self) -> None:
742851

743852

744853
class PrePostNEGD(ExperimentalDesign):
745-
"""A class to analyse data from pretest/posttest designs"""
854+
"""
855+
A class to analyse data from pretest/posttest designs
856+
857+
:param data:
858+
A pandas data frame
859+
:param formula:
860+
A statistical model formula
861+
:param group_variable_name:
862+
Name of the column in data for the group variable
863+
:param pretreatment_variable_name:
864+
Name of the column in data for the pretreatment variable
865+
:param model:
866+
A PyMC model
867+
868+
Example
869+
--------
870+
>>> df = cp.load_data("anova1")
871+
>>> seed = 42
872+
>>> result = cp.pymc_experiments.PrePostNEGD(
873+
... df,
874+
... formula="post ~ 1 + C(group) + pre",
875+
... group_variable_name="group",
876+
... pretreatment_variable_name="pre",
877+
... model=cp.pymc_models.LinearRegression(sample_kwargs={"random_seed": seed}),
878+
... )
879+
880+
"""
746881

747882
def __init__(
748883
self,

0 commit comments

Comments
 (0)