|
20 | 20 | import arviz as az |
21 | 21 | import numpy as np |
22 | 22 | import pandas as pd |
| 23 | +import xarray as xr |
23 | 24 | from matplotlib import pyplot as plt |
24 | 25 | from patsy import build_design_matrices, dmatrices |
25 | 26 | from sklearn.base import RegressorMixin |
@@ -231,7 +232,13 @@ def __init__( |
231 | 232 | **kwargs, |
232 | 233 | ) -> None: |
233 | 234 | super().__init__(model=model) |
| 235 | + |
| 236 | + # rename the index to "obs_ind" |
| 237 | + data.index.name = "obs_ind" |
234 | 238 | self.input_validation(data, treatment_time, model) |
| 239 | + self.treatment_time = treatment_time |
| 240 | + # set experiment type - usually done in subclasses |
| 241 | + self.expt_type = "Pre-Post Fit" |
235 | 242 |
|
236 | 243 | self.treatment_time = treatment_time |
237 | 244 | self.formula = formula |
@@ -285,13 +292,38 @@ def __init__( |
285 | 292 | ) |
286 | 293 | self.post_X = np.asarray(new_x) |
287 | 294 | self.post_y = np.asarray(new_y) |
| 295 | + # turn into xarray.DataArray's |
| 296 | + self.pre_X = xr.DataArray( |
| 297 | + self.pre_X, |
| 298 | + dims=["obs_ind", "coeffs"], |
| 299 | + coords={ |
| 300 | + "obs_ind": self.datapre.index, |
| 301 | + "coeffs": self.labels, |
| 302 | + }, |
| 303 | + ) |
| 304 | + self.pre_y = xr.DataArray( |
| 305 | + self.pre_y[:, 0], |
| 306 | + dims=["obs_ind"], |
| 307 | + coords={"obs_ind": self.datapre.index}, |
| 308 | + ) |
| 309 | + self.post_X = xr.DataArray( |
| 310 | + self.post_X, |
| 311 | + dims=["obs_ind", "coeffs"], |
| 312 | + coords={ |
| 313 | + "obs_ind": self.datapost.index, |
| 314 | + "coeffs": self.labels, |
| 315 | + }, |
| 316 | + ) |
| 317 | + self.post_y = xr.DataArray( |
| 318 | + self.post_y[:, 0], |
| 319 | + dims=["obs_ind"], |
| 320 | + coords={"obs_ind": self.datapost.index}, |
| 321 | + ) |
288 | 322 |
|
289 | 323 | # calculate the counterfactual |
290 | 324 | self.post_pred = self.model.predict(X=self.post_X) |
291 | | - self.pre_impact = self.model.calculate_impact(self.pre_y[:, 0], self.pre_pred) |
292 | | - self.post_impact = self.model.calculate_impact( |
293 | | - self.post_y[:, 0], self.post_pred |
294 | | - ) |
| 325 | + self.pre_impact = self.model.calculate_impact(self.pre_y, self.pre_pred) |
| 326 | + self.post_impact = self.model.calculate_impact(self.post_y, self.post_pred) |
295 | 327 | self.post_impact_cumulative = self.model.calculate_cumulative_impact( |
296 | 328 | self.post_impact |
297 | 329 | ) |
|
0 commit comments