File tree Expand file tree Collapse file tree 2 files changed +6
-4
lines changed Expand file tree Collapse file tree 2 files changed +6
-4
lines changed Original file line number Diff line number Diff line change @@ -125,10 +125,8 @@ def __init__(
125125
126126 # calculate the counterfactual
127127 self .post_pred = self .model .predict (X = self .post_X )
128- self .pre_impact = self .model .calculate_impact (self .pre_y [:, 0 ], self .pre_pred )
129- self .post_impact = self .model .calculate_impact (
130- self .post_y [:, 0 ], self .post_pred
131- )
128+ self .pre_impact = self .model .calculate_impact (self .pre_y , self .pre_pred )
129+ self .post_impact = self .model .calculate_impact (self .post_y , self .post_pred )
132130 self .post_impact_cumulative = self .model .calculate_cumulative_impact (
133131 self .post_impact
134132 )
Original file line number Diff line number Diff line change 1616from functools import partial
1717
1818import numpy as np
19+ import xarray as xr
1920from scipy .optimize import fmin_slsqp
2021from sklearn .base import RegressorMixin
2122from sklearn .linear_model ._base import LinearModel
@@ -28,6 +29,9 @@ class ScikitLearnAdaptor:
2829
2930 def calculate_impact (self , y_true , y_pred ):
3031 """Calculate the causal impact of the intervention."""
32+ if isinstance (y_true , np .ndarray ):
33+ y_true = xr .DataArray (y_true , dims = ["obs_ind" ])
34+
3135 return y_true - y_pred
3236
3337 def calculate_cumulative_impact (self , impact ):
You can’t perform that action at this time.
0 commit comments