2626
2727from causalpy .custom_exceptions import (
2828 DataException ,
29- FormulaException ,
3029)
3130from causalpy .plot_utils import plot_xY
3231from causalpy .pymc_models import PyMCModel
@@ -84,6 +83,7 @@ def __init__(
8483 formula : str ,
8584 time_variable_name : str ,
8685 group_variable_name : str ,
86+ post_treatment_variable_name : str = "post_treatment" ,
8787 model = None ,
8888 ** kwargs ,
8989 ) -> None :
@@ -95,6 +95,7 @@ def __init__(
9595 self .formula = formula
9696 self .time_variable_name = time_variable_name
9797 self .group_variable_name = group_variable_name
98+ self .post_treatment_variable_name = post_treatment_variable_name
9899 self .input_validation ()
99100
100101 y , X = dmatrices (formula , self .data )
@@ -128,6 +129,12 @@ def __init__(
128129 }
129130 self .model .fit (X = self .X , y = self .y , coords = COORDS )
130131 elif isinstance (self .model , RegressorMixin ):
132+ # For scikit-learn models, automatically set fit_intercept=False
133+ # This ensures the intercept is included in the coefficients array rather than being a separate intercept_ attribute
134+ # without this, the intercept is not included in the coefficients array hence would be displayed as 0 in the model summary
135+ # TODO: later, this should be handled in ScikitLearnAdaptor itself
136+ if hasattr (self .model , "fit_intercept" ):
137+ self .model .fit_intercept = False
131138 self .model .fit (X = self .X , y = self .y )
132139 else :
133140 raise ValueError ("Model type not recognized" )
@@ -173,7 +180,7 @@ def __init__(
173180 # just the treated group
174181 .query (f"{ self .group_variable_name } == 1" )
175182 # just the treatment period(s)
176- .query ("post_treatment == True" )
183+ .query (f" { self . post_treatment_variable_name } == True" )
177184 # drop the outcome variable
178185 .drop (self .outcome_variable_name , axis = 1 )
179186 # We may have multiple units per time point, we only want one time point
@@ -189,7 +196,10 @@ def __init__(
189196 # INTERVENTION: set the interaction term between the group and the
190197 # post_treatment variable to zero. This is the counterfactual.
191198 for i , label in enumerate (self .labels ):
192- if "post_treatment" in label and self .group_variable_name in label :
199+ if (
200+ self .post_treatment_variable_name in label
201+ and self .group_variable_name in label
202+ ):
193203 new_x .iloc [:, i ] = 0
194204 self .y_pred_counterfactual = self .model .predict (np .asarray (new_x ))
195205
@@ -198,32 +208,53 @@ def __init__(
198208 # This is the coefficient on the interaction term
199209 coeff_names = self .model .idata .posterior .coords ["coeffs" ].data
200210 for i , label in enumerate (coeff_names ):
201- if "post_treatment" in label and self .group_variable_name in label :
211+ if (
212+ self .post_treatment_variable_name in label
213+ and self .group_variable_name in label
214+ ):
202215 self .causal_impact = self .model .idata .posterior ["beta" ].isel (
203216 {"coeffs" : i }
204217 )
205218 elif isinstance (self .model , RegressorMixin ):
206219 # This is the coefficient on the interaction term
207- # TODO: CHECK FOR CORRECTNESS
208- self .causal_impact = (
209- self .y_pred_treatment [1 ] - self .y_pred_counterfactual [0 ]
210- ).item ()
220+ # Store the coefficient into dictionary {intercept:value}
221+ coef_map = dict (zip (self .labels , self .model .get_coeffs ()))
222+ # Create and find the interaction term based on the values user provided
223+ interaction_term = (
224+ f"{ self .group_variable_name } :{ self .post_treatment_variable_name } "
225+ )
226+ matched_key = next ((k for k in coef_map if interaction_term in k ), None )
227+ att = coef_map .get (matched_key )
228+ self .causal_impact = att
211229 else :
212230 raise ValueError ("Model type not recognized" )
213231
214232 return
215233
216234 def input_validation (self ):
217235 """Validate the input data and model formula for correctness"""
218- if "post_treatment" not in self .formula :
219- raise FormulaException (
220- "A predictor called `post_treatment` should be in the formula"
221- )
222-
223- if "post_treatment" not in self .data .columns :
224- raise DataException (
225- "Require a boolean column labelling observations which are `treated`"
226- )
236+ if (
237+ self .post_treatment_variable_name not in self .formula
238+ or self .post_treatment_variable_name not in self .data .columns
239+ ):
240+ if self .post_treatment_variable_name == "post_treatment" :
241+ # Default case - user didn't specify custom name, so guide them to use "post_treatment"
242+ raise DataException (
243+ "Missing 'post_treatment' in formula or dataset.\n "
244+ "Note: post_treatment_variable_name might have been set to 'post_treatment' by default.\n "
245+ "1) Add 'post_treatment' to formula (e.g., 'y ~ 1 + group*post_treatment')\n "
246+ "2) and ensure dataset has boolean column 'post_treatment'.\n "
247+ "To use custom name, provide additional argument post_treatment_variable_name='your_post_treatment_variable_name'."
248+ )
249+ else :
250+ # Custom case - user specified custom name, so remind them what they specified
251+ raise DataException (
252+ f"Missing required variable '{ self .post_treatment_variable_name } ' in formula or dataset.\n \n "
253+ f"Since you specified post_treatment_variable_name='{ self .post_treatment_variable_name } ', "
254+ f"please ensure:\n "
255+ f"1) formula includes '{ self .post_treatment_variable_name } '\n "
256+ f"2) dataset has boolean column named '{ self .post_treatment_variable_name } '"
257+ )
227258
228259 if "unit" not in self .data .columns :
229260 raise DataException (
0 commit comments