3030
3131
3232def  test_did_validation_post_treatment_formula ():
33-     """Test that we get a FormulaException if do not include post_treatment in the 
34-     formula""" 
33+     """Test that we get a FormulaException for invalid formulas and missing post_treatment variables""" 
3534    df  =  pd .DataFrame (
3635        {
3736            "group" : [0 , 0 , 1 , 1 ],
3837            "t" : [0 , 1 , 0 , 1 ],
3938            "unit" : [0 , 0 , 1 , 1 ],
4039            "post_treatment" : [0 , 1 , 0 , 1 ],
40+             "male" : [0 , 1 , 0 , 1 ],  # Additional variable for testing 
4141            "y" : [1 , 2 , 3 , 4 ],
4242        }
4343    )
4444
45+     df_with_custom  =  pd .DataFrame (
46+         {
47+             "group" : [0 , 0 , 1 , 1 ],
48+             "t" : [0 , 1 , 0 , 1 ],
49+             "unit" : [0 , 0 , 1 , 1 ],
50+             "custom_post" : [0 , 1 , 0 , 1 ],  # Custom column name 
51+             "y" : [1 , 2 , 3 , 4 ],
52+         }
53+     )
54+ 
55+     # Test 1: Missing post_treatment variable in formula 
4556    with  pytest .raises (FormulaException ):
4657        _  =  cp .DifferenceInDifferences (
4758            df ,
@@ -51,6 +62,7 @@ def test_did_validation_post_treatment_formula():
5162            model = cp .pymc_models .LinearRegression (sample_kwargs = sample_kwargs ),
5263        )
5364
65+     # Test 2: Missing post_treatment variable in formula (duplicate test) 
5466    with  pytest .raises (FormulaException ):
5567        _  =  cp .DifferenceInDifferences (
5668            df ,
@@ -60,6 +72,88 @@ def test_did_validation_post_treatment_formula():
6072            model = cp .pymc_models .LinearRegression (sample_kwargs = sample_kwargs ),
6173        )
6274
75+     # Test 3: Custom post_treatment_variable_name but formula uses different name 
76+     with  pytest .raises (FormulaException ):
77+         _  =  cp .DifferenceInDifferences (
78+             df_with_custom ,
79+             formula = "y ~ 1 + group*post_treatment" ,  # Formula uses 'post_treatment' 
80+             time_variable_name = "t" ,
81+             group_variable_name = "group" ,
82+             post_treatment_variable_name = "custom_post" ,  # But user specifies 'custom_post' 
83+             model = cp .pymc_models .LinearRegression (sample_kwargs = sample_kwargs ),
84+         )
85+ 
86+     # Test 4: Default post_treatment_variable_name but formula uses different name 
87+     with  pytest .raises (FormulaException ):
88+         _  =  cp .DifferenceInDifferences (
89+             df ,
90+             formula = "y ~ 1 + group*custom_post" ,  # Formula uses 'custom_post' 
91+             time_variable_name = "t" ,
92+             group_variable_name = "group" ,
93+             # post_treatment_variable_name defaults to "post_treatment" 
94+             model = cp .pymc_models .LinearRegression (sample_kwargs = sample_kwargs ),
95+         )
96+ 
97+     # Test 5: Repeated interaction terms (should be invalid) 
98+     with  pytest .raises (FormulaException ):
99+         _  =  cp .DifferenceInDifferences (
100+             df ,
101+             formula = "y ~ 1 + group + group*post_treatment + group*post_treatment" ,
102+             time_variable_name = "t" ,
103+             group_variable_name = "group" ,
104+             model = cp .pymc_models .LinearRegression (sample_kwargs = sample_kwargs ),
105+         )
106+ 
107+     # Test 6: Three-way interactions using * (should be invalid) 
108+     with  pytest .raises (FormulaException ):
109+         _  =  cp .DifferenceInDifferences (
110+             df ,
111+             formula = "y ~ 1 + group + group*post_treatment*male" ,
112+             time_variable_name = "t" ,
113+             group_variable_name = "group" ,
114+             model = cp .pymc_models .LinearRegression (sample_kwargs = sample_kwargs ),
115+         )
116+ 
117+     # Test 7: Three-way interactions using : (should be invalid) 
118+     with  pytest .raises (FormulaException ):
119+         _  =  cp .DifferenceInDifferences (
120+             df ,
121+             formula = "y ~ 1 + group + group:post_treatment:male" ,
122+             time_variable_name = "t" ,
123+             group_variable_name = "group" ,
124+             model = cp .pymc_models .LinearRegression (sample_kwargs = sample_kwargs ),
125+         )
126+ 
127+     # Test 8: Multiple different interaction terms using * (should be invalid) 
128+     with  pytest .raises (FormulaException ):
129+         _  =  cp .DifferenceInDifferences (
130+             df ,
131+             formula = "y ~ 1 + group + group*post_treatment + group*male" ,
132+             time_variable_name = "t" ,
133+             group_variable_name = "group" ,
134+             model = cp .pymc_models .LinearRegression (sample_kwargs = sample_kwargs ),
135+         )
136+ 
137+     # Test 9: Multiple different interaction terms using : (should be invalid) 
138+     with  pytest .raises (FormulaException ):
139+         _  =  cp .DifferenceInDifferences (
140+             df ,
141+             formula = "y ~ 1 + group + group:post_treatment + group:male" ,
142+             time_variable_name = "t" ,
143+             group_variable_name = "group" ,
144+             model = cp .pymc_models .LinearRegression (sample_kwargs = sample_kwargs ),
145+         )
146+ 
147+     # Test 10: Mixed issues - multiple terms + three-way interaction (should be invalid) 
148+     with  pytest .raises (FormulaException ):
149+         _  =  cp .DifferenceInDifferences (
150+             df ,
151+             formula = "y ~ 1 + group + group*post_treatment + group:post_treatment:male" ,
152+             time_variable_name = "t" ,
153+             group_variable_name = "group" ,
154+             model = cp .pymc_models .LinearRegression (sample_kwargs = sample_kwargs ),
155+         )
156+ 
63157
64158def  test_did_validation_post_treatment_data ():
65159    """Test that we get a DataException if do not include post_treatment in the data""" 
@@ -91,6 +185,27 @@ def test_did_validation_post_treatment_data():
91185            model = cp .pymc_models .LinearRegression (sample_kwargs = sample_kwargs ),
92186        )
93187
188+     # Test 2: Custom post_treatment_variable_name but column doesn't exist in data 
189+     df_with_post  =  pd .DataFrame (
190+         {
191+             "group" : [0 , 0 , 1 , 1 ],
192+             "t" : [0 , 1 , 0 , 1 ],
193+             "unit" : [0 , 0 , 1 , 1 ],
194+             "post_treatment" : [0 , 1 , 0 , 1 ],  # Data has 'post_treatment' 
195+             "y" : [1 , 2 , 3 , 4 ],
196+         }
197+     )
198+ 
199+     with  pytest .raises (DataException ):
200+         _  =  cp .DifferenceInDifferences (
201+             df_with_post ,
202+             formula = "y ~ 1 + group*custom_post" ,  # Formula uses 'custom_post' 
203+             time_variable_name = "t" ,
204+             group_variable_name = "group" ,
205+             post_treatment_variable_name = "custom_post" ,  # User specifies 'custom_post' 
206+             model = cp .pymc_models .LinearRegression (sample_kwargs = sample_kwargs ),
207+         )
208+ 
94209
95210def  test_did_validation_unit_data ():
96211    """Test that we get a DataException if do not include unit in the data""" 
0 commit comments