File tree Expand file tree Collapse file tree 2 files changed +6
-4
lines changed Expand file tree Collapse file tree 2 files changed +6
-4
lines changed Original file line number Diff line number Diff line change @@ -125,10 +125,8 @@ def __init__(
125
125
126
126
# calculate the counterfactual
127
127
self .post_pred = self .model .predict (X = self .post_X )
128
- self .pre_impact = self .model .calculate_impact (self .pre_y [:, 0 ], self .pre_pred )
129
- self .post_impact = self .model .calculate_impact (
130
- self .post_y [:, 0 ], self .post_pred
131
- )
128
+ self .pre_impact = self .model .calculate_impact (self .pre_y , self .pre_pred )
129
+ self .post_impact = self .model .calculate_impact (self .post_y , self .post_pred )
132
130
self .post_impact_cumulative = self .model .calculate_cumulative_impact (
133
131
self .post_impact
134
132
)
Original file line number Diff line number Diff line change 16
16
from functools import partial
17
17
18
18
import numpy as np
19
+ import xarray as xr
19
20
from scipy .optimize import fmin_slsqp
20
21
from sklearn .base import RegressorMixin
21
22
from sklearn .linear_model ._base import LinearModel
@@ -28,6 +29,9 @@ class ScikitLearnAdaptor:
28
29
29
30
def calculate_impact (self , y_true , y_pred ):
30
31
"""Calculate the causal impact of the intervention."""
32
+ if isinstance (y_true , np .ndarray ):
33
+ y_true = xr .DataArray (y_true , dims = ["obs_ind" ])
34
+
31
35
return y_true - y_pred
32
36
33
37
def calculate_cumulative_impact (self , impact ):
You can’t perform that action at this time.
0 commit comments