Skip to content

Commit 45f1b1a

Browse files
committed
start embracing xarray to handle broadcasting
1 parent 2137091 commit 45f1b1a

File tree

4 files changed

+61
-10
lines changed

4 files changed

+61
-10
lines changed

causalpy/experiments/diff_in_diff.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import numpy as np
2020
import pandas as pd
2121
import seaborn as sns
22+
import xarray as xr
2223
from matplotlib import pyplot as plt
2324
from patsy import build_design_matrices, dmatrices
2425
from sklearn.base import RegressorMixin
@@ -87,7 +88,8 @@ def __init__(
8788
**kwargs,
8889
) -> None:
8990
super().__init__(model=model)
90-
91+
# rename the index to "obs_ind"
92+
data.index.name = "obs_ind"
9193
self.data = data
9294
self.expt_type = "Difference in Differences"
9395
self.formula = formula
@@ -102,6 +104,23 @@ def __init__(
102104
self.y, self.X = np.asarray(y), np.asarray(X)
103105
self.outcome_variable_name = y.design_info.column_names[0]
104106

107+
# turn into xarray.DataArray objects
108+
self.X = xr.DataArray(
109+
self.X,
110+
dims=["obs_ind", "coeffs"],
111+
coords={
112+
"obs_ind": np.arange(self.X.shape[0]),
113+
"coeffs": self.labels,
114+
},
115+
)
116+
self.y = xr.DataArray(
117+
self.y[:, 0],
118+
dims=["obs_ind"],
119+
coords={
120+
"obs_ind": self.data.index,
121+
},
122+
)
123+
105124
# fit model
106125
if isinstance(self.model, PyMCModel):
107126
COORDS = {"coeffs": self.labels, "obs_ind": np.arange(self.X.shape[0])}
@@ -190,6 +209,8 @@ def __init__(
190209
else:
191210
raise ValueError("Model type not recognized")
192211

212+
return
213+
193214
def input_validation(self):
194215
"""Validate the input data and model formula for correctness"""
195216
if "post_treatment" not in self.formula:

causalpy/experiments/interrupted_time_series.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import arviz as az
2121
import numpy as np
2222
import pandas as pd
23+
import xarray as xr
2324
from matplotlib import pyplot as plt
2425
from patsy import build_design_matrices, dmatrices
2526
from sklearn.base import RegressorMixin
@@ -84,6 +85,8 @@ def __init__(
8485
**kwargs,
8586
) -> None:
8687
super().__init__(model=model)
88+
# rename the index to "obs_ind"
89+
data.index.name = "obs_ind"
8790
self.input_validation(data, treatment_time)
8891
self.treatment_time = treatment_time
8992
# set experiment type - usually done in subclasses
@@ -107,6 +110,37 @@ def __init__(
107110
)
108111
self.post_X = np.asarray(new_x)
109112
self.post_y = np.asarray(new_y)
113+
# turn into xarray.DataArray objects
114+
self.pre_X = xr.DataArray(
115+
self.pre_X,
116+
dims=["obs_ind", "coeffs"],
117+
coords={
118+
"obs_ind": self.datapre.index,
119+
"coeffs": self.labels,
120+
},
121+
)
122+
self.pre_y = xr.DataArray(
123+
self.pre_y[:, 0],
124+
dims=["obs_ind"],
125+
coords={
126+
"obs_ind": self.datapre.index,
127+
},
128+
)
129+
self.post_X = xr.DataArray(
130+
self.post_X,
131+
dims=["obs_ind", "coeffs"],
132+
coords={
133+
"obs_ind": self.datapost.index,
134+
"coeffs": self.labels,
135+
},
136+
)
137+
self.post_y = xr.DataArray(
138+
self.post_y[:, 0],
139+
dims=["obs_ind"],
140+
coords={
141+
"obs_ind": self.datapost.index,
142+
},
143+
)
110144

111145
# fit the model to the observed (pre-intervention) data
112146
if isinstance(self.model, PyMCModel):

causalpy/experiments/synthetic_control.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def __init__(
9595
self.datapost = data[data.index >= self.treatment_time]
9696

9797
# split data into the 4 quadrants (pre/post, control/treated) and store as
98-
# xarray DataArray objects.
98+
# xarray.DataArray objects.
9999
# NOTE: if we have renamed/ensured the index is named "obs_ind", then it will
100100
# make constructing the xarray DataArray objects easier.
101101
self.datapre_control = xr.DataArray(

causalpy/pymc_models.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -163,16 +163,12 @@ def score(self, X, y) -> pd.Series:
163163
164164
"""
165165
mu = self.predict(X)
166-
mu = az.extract(mu, group="posterior_predictive", var_names="mu").T.values
167-
# Note: First argument must be a 1D array
168-
return r2_score(y.flatten(), mu)
166+
mu = az.extract(mu, group="posterior_predictive", var_names="mu").T
167+
return r2_score(y.data, mu.data)
169168

170169
def calculate_impact(
171-
self, y_true: xr.DataArray | np.ndarray, y_pred: az.InferenceData
170+
self, y_true: xr.DataArray, y_pred: az.InferenceData
172171
) -> xr.DataArray:
173-
if isinstance(y_true, np.ndarray):
174-
y_true = xr.DataArray(y_true, dims=["obs_ind"])
175-
176172
impact = y_true - y_pred["posterior_predictive"]["y_hat"]
177173
return impact.transpose(..., "obs_ind")
178174

@@ -240,7 +236,7 @@ def build_model(self, X, y, coords):
240236
with self:
241237
self.add_coords(coords)
242238
X = pm.Data("X", X, dims=["obs_ind", "coeffs"])
243-
y = pm.Data("y", y[:, 0], dims="obs_ind")
239+
y = pm.Data("y", y, dims="obs_ind")
244240
beta = pm.Normal("beta", 0, 50, dims="coeffs")
245241
sigma = pm.HalfNormal("sigma", 1)
246242
mu = pm.Deterministic("mu", pm.math.dot(X, beta), dims="obs_ind")

0 commit comments

Comments
 (0)