1919import numpy as np
2020import pandas as pd
2121import seaborn as sns
22+ from formulae import design_matrices
2223from matplotlib import pyplot as plt
23- from patsy import build_design_matrices , dmatrices
2424from sklearn .base import RegressorMixin
2525
2626from causalpy .custom_exceptions import (
@@ -91,16 +91,18 @@ def __init__(
9191 self .data = data
9292 self .expt_type = "Difference in Differences"
9393 self .formula = formula
94+ self .rhs_formula = formula .split ("~" , 1 )[1 ].strip ()
9495 self .time_variable_name = time_variable_name
9596 self .group_variable_name = group_variable_name
9697 self .input_validation ()
9798
98- y , X = dmatrices (formula , self .data )
99- self ._y_design_info = y .design_info
100- self ._x_design_info = X .design_info
101- self .labels = X .design_info .column_names
102- self .y , self .X = np .asarray (y ), np .asarray (X )
103- self .outcome_variable_name = y .design_info .column_names [0 ]
99+ dm = design_matrices (self .formula , self .data )
100+ self .labels = list (dm .common .terms .keys ())
101+ self .y , self .X = (
102+ np .asarray (dm .response .design_matrix ).reshape (- 1 , 1 ),
103+ np .asarray (dm .common .design_matrix ),
104+ )
105+ self .outcome_variable_name = dm .response .name
104106
105107 # fit model
106108 if isinstance (self .model , PyMCModel ):
@@ -125,8 +127,8 @@ def __init__(
125127 )
126128 if self .x_pred_control .empty :
127129 raise ValueError ("x_pred_control is empty" )
128- ( new_x ,) = build_design_matrices ([ self ._x_design_info ] , self .x_pred_control )
129- self .y_pred_control = self .model .predict (np . asarray ( new_x ) )
130+ new_x = np . array ( design_matrices ( self .rhs_formula , self .x_pred_control ). common )
131+ self .y_pred_control = self .model .predict (new_x )
130132
131133 # predicted outcome for treatment group
132134 self .x_pred_treatment = (
@@ -142,8 +144,10 @@ def __init__(
142144 )
143145 if self .x_pred_treatment .empty :
144146 raise ValueError ("x_pred_treatment is empty" )
145- (new_x ,) = build_design_matrices ([self ._x_design_info ], self .x_pred_treatment )
146- self .y_pred_treatment = self .model .predict (np .asarray (new_x ))
147+ new_x = np .array (
148+ design_matrices (self .rhs_formula , self .x_pred_treatment ).common
149+ )
150+ self .y_pred_treatment = self .model .predict (new_x )
147151
148152 # predicted outcome for counterfactual. This is given by removing the influence
149153 # of the interaction term between the group and the post_treatment variable
@@ -162,15 +166,15 @@ def __init__(
162166 )
163167 if self .x_pred_counterfactual .empty :
164168 raise ValueError ("x_pred_counterfactual is empty" )
165- ( new_x ,) = build_design_matrices (
166- [ self ._x_design_info ] , self .x_pred_counterfactual , return_type = "dataframe"
169+ new_x = np . array (
170+ design_matrices ( self .rhs_formula , self .x_pred_counterfactual ). common
167171 )
168172 # INTERVENTION: set the interaction term between the group and the
169173 # post_treatment variable to zero. This is the counterfactual.
170174 for i , label in enumerate (self .labels ):
171175 if "post_treatment" in label and self .group_variable_name in label :
172- new_x . iloc [:, i ] = 0
173- self .y_pred_counterfactual = self .model .predict (np . asarray ( new_x ) )
176+ new_x [:, i ] = 0
177+ self .y_pred_counterfactual = self .model .predict (new_x )
174178
175179 # calculate causal impact
176180 if isinstance (self .model , PyMCModel ):
0 commit comments