Skip to content

Commit b6f5ca8

Browse files
committed
PyMCModel.score always to get xr.DataArray arguments
1 parent d0fc0d3 commit b6f5ca8

File tree

2 files changed

+15
-31
lines changed

2 files changed

+15
-31
lines changed

causalpy/experiments/synthetic_control.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,8 @@ def __init__(
156156

157157
# score the goodness of fit to the pre-intervention data
158158
self.score = self.model.score(
159-
X=self.datapre_control.to_numpy(),
160-
y=self.datapre_treated.to_numpy(),
159+
X=self.datapre_control,
160+
y=self.datapre_treated,
161161
)
162162

163163
# get the model predictions of the observed (pre-intervention) data

causalpy/pymc_models.py

Lines changed: 13 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,16 @@ class PyMCModel(pm.Model):
4747
... mu = pm.Deterministic("mu", pm.math.dot(X_, beta))
4848
... pm.Normal("y_hat", mu=mu, sigma=sigma, observed=y_)
4949
>>> rng = np.random.default_rng(seed=42)
50-
>>> X = rng.normal(loc=0, scale=1, size=(20, 2))
51-
>>> y = rng.normal(loc=0, scale=1, size=(20,))
50+
>>> X = xr.DataArray(
51+
... rng.normal(loc=0, scale=1, size=(20, 2)),
52+
... dims=["obs_ind", "coeffs"],
53+
... coords={"obs_ind": np.arange(20), "coeffs": ["coeff_0", "coeff_1"]},
54+
... )
55+
>>> y = xr.DataArray(
56+
... rng.normal(loc=0, scale=1, size=(20,)),
57+
... dims=["obs_ind"],
58+
... coords={"obs_ind": np.arange(20)},
59+
... )
5260
>>> model = MyToyModel(
5361
... sample_kwargs={
5462
... "chains": 2,
@@ -174,7 +182,7 @@ def predict(self, X):
174182

175183
return pp
176184

177-
def score(self, X, y) -> pd.Series:
185+
def score(self, X: xr.DataArray, y: xr.DataArray) -> pd.Series:
178186
"""Score the Bayesian :math:`R^2` given inputs ``X`` and outputs ``y``.
179187
180188
Note that the score is based on a comparison of the observed data ``y`` and the
@@ -197,14 +205,7 @@ def score(self, X, y) -> pd.Series:
197205

198206
for unit in treated_units:
199207
unit_mu = mu_data.sel(treated_units=unit).T # (sample, obs_ind)
200-
201-
# Handle both xarray and numpy arrays for y
202-
if hasattr(y, "sel"): # xarray.DataArray
203-
unit_y = y.sel(treated_units=unit).data
204-
else: # numpy array
205-
unit_idx = list(treated_units).index(unit)
206-
unit_y = y[:, unit_idx] if y.ndim > 1 else y
207-
208+
unit_y = y.sel(treated_units=unit).data
208209
unit_score = r2_score(unit_y, unit_mu.data)
209210

210211
# Flatten the r2_score results into the expected format
@@ -215,24 +216,7 @@ def score(self, X, y) -> pd.Series:
215216
else:
216217
# Single treated unit - transpose to match expected format
217218
mu_data = mu_data.T
218-
219-
# Handle different y types robustly
220-
if hasattr(y, "data"): # xarray.DataArray
221-
y_raw = y.data
222-
# Convert to numpy array if it's a memoryview
223-
if isinstance(y_raw, memoryview):
224-
y_data = np.asarray(y_raw)
225-
else:
226-
y_data = y_raw
227-
# Squeeze if needed
228-
y_data = y_data if y_data.ndim == 1 else y_data.squeeze()
229-
else: # numpy array or memoryview
230-
if hasattr(y, "squeeze"): # numpy array
231-
y_data = y if y.ndim == 1 else y.squeeze()
232-
else: # memoryview or other
233-
y_data = np.asarray(y)
234-
y_data = y_data if y_data.ndim == 1 else y_data.squeeze()
235-
219+
y_data = y.data.squeeze() if y.data.ndim > 1 else y.data
236220
return r2_score(y_data, mu_data.data)
237221

238222
def calculate_impact(

0 commit comments

Comments
 (0)