1919import numpy as np
2020import pandas as pd
2121import seaborn as sns
22+ import xarray as xr
2223from matplotlib import pyplot as plt
2324from patsy import build_design_matrices , dmatrices
2425from sklearn .base import RegressorMixin
@@ -87,7 +88,8 @@ def __init__(
8788 ** kwargs ,
8889 ) -> None :
8990 super ().__init__ (model = model )
90-
91+ # rename the index to "obs_ind"
92+ data .index .name = "obs_ind"
9193 self .data = data
9294 self .expt_type = "Difference in Differences"
9395 self .formula = formula
@@ -102,6 +104,21 @@ def __init__(
102104 self .y , self .X = np .asarray (y ), np .asarray (X )
103105 self .outcome_variable_name = y .design_info .column_names [0 ]
104106
107+ # turn into xarray.DataArrays
108+ self .X = xr .DataArray (
109+ self .X ,
110+ dims = ["obs_ind" , "coeffs" ],
111+ coords = {
112+ "obs_ind" : np .arange (self .X .shape [0 ]),
113+ "coeffs" : self .labels ,
114+ },
115+ )
116+ self .y = xr .DataArray (
117+ self .y [:, 0 ],
118+ dims = ["obs_ind" ],
119+ coords = {"obs_ind" : np .arange (self .y .shape [0 ])},
120+ )
121+
105122 # fit model
106123 if isinstance (self .model , PyMCModel ):
107124 COORDS = {"coeffs" : self .labels , "obs_ind" : np .arange (self .X .shape [0 ])}
@@ -183,13 +200,15 @@ def __init__(
183200 )
184201 elif isinstance (self .model , RegressorMixin ):
185202 # This is the coefficient on the interaction term
186- # TODO: THIS IS NOT YET CORRECT ?????
203+ # TODO: CHECK FOR CORRECTNESS
187204 self .causal_impact = (
188205 self .y_pred_treatment [1 ] - self .y_pred_counterfactual [0 ]
189- )[ 0 ]
206+ )
190207 else :
191208 raise ValueError ("Model type not recognized" )
192209
210+ return
211+
193212 def input_validation (self ):
194213 """Validate the input data and model formula for correctness"""
195214 if "post_treatment" not in self .formula :
0 commit comments