@@ -24,8 +24,11 @@ def test_did():
24
24
assert len (result .idata .posterior .coords ["draw" ]) == sample_kwargs ["draws" ]
25
25
26
26
27
+ # TODO: set up fixture for the banks dataset
28
+
29
+
27
30
@pytest .mark .integration
28
- def test_did_banks ():
31
+ def test_did_banks_simple ():
29
32
treatment_time = 1930.5
30
33
df = (
31
34
cp .load_data ("banks" )
@@ -60,6 +63,42 @@ def test_did_banks():
60
63
assert len (result .idata .posterior .coords ["draw" ]) == sample_kwargs ["draws" ]
61
64
62
65
66
+ @pytest .mark .integration
67
+ def test_did_banks_multi ():
68
+ treatment_time = 1930.5
69
+ df = (
70
+ cp .load_data ("banks" )
71
+ .filter (items = ["bib6" , "bib8" , "year" ])
72
+ .rename (columns = {"bib6" : "Sixth District" , "bib8" : "Eighth District" })
73
+ .groupby ("year" )
74
+ .median ()
75
+ )
76
+ df .reset_index (level = 0 , inplace = True )
77
+ df_long = pd .melt (
78
+ df ,
79
+ id_vars = ["year" ],
80
+ value_vars = ["Sixth District" , "Eighth District" ],
81
+ var_name = "district" ,
82
+ value_name = "bib" ,
83
+ ).sort_values ("year" )
84
+ df_long ["district" ] = df_long ["district" ].astype ("category" )
85
+ df_long ["unit" ] = df_long ["district" ]
86
+ df_long ["post_treatment" ] = df_long .year >= treatment_time
87
+ result = cp .pymc_experiments .DifferenceInDifferences (
88
+ df_long ,
89
+ formula = "bib ~ 1 + district + year + district:post_treatment" ,
90
+ time_variable_name = "year" ,
91
+ group_variable_name = "district" ,
92
+ treated = "Sixth District" ,
93
+ untreated = "Eighth District" ,
94
+ prediction_model = cp .pymc_models .LinearRegression (sample_kwargs = sample_kwargs ),
95
+ )
96
+ assert isinstance (df , pd .DataFrame )
97
+ assert isinstance (result , cp .pymc_experiments .DifferenceInDifferences )
98
+ assert len (result .idata .posterior .coords ["chain" ]) == sample_kwargs ["chains" ]
99
+ assert len (result .idata .posterior .coords ["draw" ]) == sample_kwargs ["draws" ]
100
+
101
+
63
102
@pytest .mark .integration
64
103
def test_rd ():
65
104
df = cp .load_data ("rd" )
0 commit comments