2020import arviz as az
2121import numpy as np
2222import pandas as pd
23+ import xarray as xr
2324from matplotlib import pyplot as plt
2425from patsy import build_design_matrices , dmatrices
2526from sklearn .base import RegressorMixin
@@ -84,6 +85,8 @@ def __init__(
8485 ** kwargs ,
8586 ) -> None :
8687 super ().__init__ (model = model )
88+ # rename the index to "obs_ind"
89+ data .index .name = "obs_ind"
8790 self .input_validation (data , treatment_time )
8891 self .treatment_time = treatment_time
8992 # set experiment type - usually done in subclasses
@@ -107,6 +110,33 @@ def __init__(
107110 )
108111 self .post_X = np .asarray (new_x )
109112 self .post_y = np .asarray (new_y )
113+ # turn into xarray.DataArray's
114+ self .pre_X = xr .DataArray (
115+ self .pre_X ,
116+ dims = ["obs_ind" , "coeffs" ],
117+ coords = {
118+ "obs_ind" : self .datapre .index ,
119+ "coeffs" : self .labels ,
120+ },
121+ )
122+ self .pre_y = xr .DataArray (
123+ self .pre_y [:, 0 ],
124+ dims = ["obs_ind" ],
125+ coords = {"obs_ind" : self .datapre .index },
126+ )
127+ self .post_X = xr .DataArray (
128+ self .post_X ,
129+ dims = ["obs_ind" , "coeffs" ],
130+ coords = {
131+ "obs_ind" : self .datapost .index ,
132+ "coeffs" : self .labels ,
133+ },
134+ )
135+ self .post_y = xr .DataArray (
136+ self .post_y [:, 0 ],
137+ dims = ["obs_ind" ],
138+ coords = {"obs_ind" : self .datapost .index },
139+ )
110140
111141 # fit the model to the observed (pre-intervention) data
112142 if isinstance (self .model , PyMCModel ):
@@ -125,10 +155,8 @@ def __init__(
125155
126156 # calculate the counterfactual
127157 self .post_pred = self .model .predict (X = self .post_X )
128- self .pre_impact = self .model .calculate_impact (self .pre_y [:, 0 ], self .pre_pred )
129- self .post_impact = self .model .calculate_impact (
130- self .post_y [:, 0 ], self .post_pred
131- )
158+ self .pre_impact = self .model .calculate_impact (self .pre_y , self .pre_pred )
159+ self .post_impact = self .model .calculate_impact (self .post_y , self .post_pred )
132160 self .post_impact_cumulative = self .model .calculate_cumulative_impact (
133161 self .post_impact
134162 )
0 commit comments