Skip to content

Commit 15454b2

Browse files
committed
bug fixing
1 parent a148ec3 commit 15454b2

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

causalpy/experiments/interrupted_time_series.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -125,10 +125,8 @@ def __init__(
125125

126126
# calculate the counterfactual
127127
self.post_pred = self.model.predict(X=self.post_X)
128-
self.pre_impact = self.model.calculate_impact(self.pre_y[:, 0], self.pre_pred)
129-
self.post_impact = self.model.calculate_impact(
130-
self.post_y[:, 0], self.post_pred
131-
)
128+
self.pre_impact = self.model.calculate_impact(self.pre_y, self.pre_pred)
129+
self.post_impact = self.model.calculate_impact(self.post_y, self.post_pred)
132130
self.post_impact_cumulative = self.model.calculate_cumulative_impact(
133131
self.post_impact
134132
)

causalpy/skl_models.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from functools import partial
1717

1818
import numpy as np
19+
import xarray as xr
1920
from scipy.optimize import fmin_slsqp
2021
from sklearn.base import RegressorMixin
2122
from sklearn.linear_model._base import LinearModel
@@ -28,6 +29,9 @@ class ScikitLearnAdaptor:
2829

2930
def calculate_impact(self, y_true, y_pred):
3031
"""Calculate the causal impact of the intervention."""
32+
if isinstance(y_true, np.ndarray):
33+
y_true = xr.DataArray(y_true, dims=["obs_ind"])
34+
3135
return y_true - y_pred
3236

3337
def calculate_cumulative_impact(self, impact):

0 commit comments

Comments
 (0)