Skip to content

Commit 9bbc4cb

Browse files
committed
resolving conflicts
1 parent 47cf44e commit 9bbc4cb

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

causalpy/pymc_models.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1050,9 +1050,19 @@ def score(self, X, y) -> pd.Series:
10501050
Score the Bayesian :math:`R^2` given inputs ``X`` and outputs ``y``.
10511051
"""
10521052
mu_ts = self.predict(X)
1053-
mu_ts = az.extract(mu_ts, group="posterior_predictive", var_names="mu_ts").T
1054-
# Note: First argument must be a 1D array
1055-
return r2_score(y.data, mu_ts.data)
1053+
mu_data = az.extract(mu_ts, group="posterior_predictive", var_names="mu_ts")
1054+
1055+
scores = {}
1056+
1057+
# Always iterate over treated_units dimension - no branching needed!
1058+
for i, unit in enumerate(mu_data.coords["treated_units"].values):
1059+
unit_mu = mu_data.sel(treated_units=unit).T # (sample, obs_ind)
1060+
unit_y = y.sel(treated_units=unit).data
1061+
unit_score = r2_score(unit_y, unit_mu.data)
1062+
scores[f"unit_{i}_r2"] = unit_score["r2"]
1063+
scores[f"unit_{i}_r2_std"] = unit_score["r2_std"]
1064+
1065+
return pd.Series(scores)
10561066

10571067
def set_time_range(self, time_range, data):
10581068
"""

0 commit comments

Comments
 (0)