Skip to content

Commit adf04f9

Browse files
committed
refactor PyMCModel.score
1 parent 3ee430e commit adf04f9

File tree

1 file changed

+20
-14
lines changed

1 file changed

+20
-14
lines changed

causalpy/pymc_models.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -206,22 +206,28 @@ def score(self, X: xr.DataArray, y: xr.DataArray) -> pd.Series:
206206
# Always use unified labeling system: unit_0_r2, unit_1_r2, etc.
207207
scores = {}
208208

209+
# Determine units to process - always use a loop for consistency
209210
if "treated_units" in mu_data.dims:
210-
# Multiple treated units - score each unit separately
211-
treated_units = mu_data.coords["treated_units"].values
212-
for i, unit in enumerate(treated_units):
213-
unit_mu = mu_data.sel(treated_units=unit).T # (sample, obs_ind)
214-
unit_y = y.sel(treated_units=unit).data
215-
unit_score = r2_score(unit_y, unit_mu.data)
216-
scores[f"unit_{i}_r2"] = unit_score["r2"]
217-
scores[f"unit_{i}_r2_std"] = unit_score["r2_std"]
211+
# Multiple treated units
212+
units = list(enumerate(mu_data.coords["treated_units"].values))
218213
else:
219-
# Single treated unit - use unit_0 for consistency
220-
mu_data = mu_data.T
221-
y_data = y.data.squeeze() if y.data.ndim > 1 else y.data
222-
unit_score = r2_score(y_data, mu_data.data)
223-
scores["unit_0_r2"] = unit_score["r2"]
224-
scores["unit_0_r2_std"] = unit_score["r2_std"]
214+
# Single unit - treat as single-item list
215+
units = [(0, None)]
216+
217+
# Process all units using the same loop logic
218+
for i, unit_selector in units:
219+
if unit_selector is not None:
220+
# Multi-unit case: select specific unit
221+
unit_mu = mu_data.sel(treated_units=unit_selector).T
222+
unit_y = y.sel(treated_units=unit_selector).data
223+
else:
224+
# Single unit case: use all data
225+
unit_mu = mu_data.T
226+
unit_y = y.data.squeeze() if y.data.ndim > 1 else y.data
227+
228+
unit_score = r2_score(unit_y, unit_mu.data)
229+
scores[f"unit_{i}_r2"] = unit_score["r2"]
230+
scores[f"unit_{i}_r2_std"] = unit_score["r2_std"]
225231

226232
return pd.Series(scores)
227233

0 commit comments

Comments
 (0)