@@ -68,8 +68,8 @@ class PyMCModel(pm.Model):
6868 >>> model.fit(X, y)
6969 Inference data...
7070 >>> model.score(X, y) # doctest: +ELLIPSIS
71- r2 ...
72- r2_std ...
71+ unit_r2 ...
72+ unit_r2_std ...
7373 dtype: float64
7474 >>> X_new = rng.normal(loc=0, scale=1, size=(20, 2))
7575 >>> model.predict(X_new)
@@ -203,27 +203,32 @@ def score(self, X: xr.DataArray, y: xr.DataArray) -> pd.Series:
203203 mu = self .predict (X )
204204 mu_data = az .extract (mu , group = "posterior_predictive" , var_names = "mu" )
205205
206- # Handle both single and multiple treated units
206+ # Always use the multiple treated unit convention for consistency
207+ scores = {}
208+
207209 if "treated_units" in mu_data .dims :
208- # Multiple treated units - we need to score each unit separately
210+ # Multiple treated units - score each unit separately
209211 treated_units = mu_data .coords ["treated_units" ].values
210- scores = {}
211-
212212 for unit in treated_units :
213213 unit_mu = mu_data .sel (treated_units = unit ).T # (sample, obs_ind)
214214 unit_y = y .sel (treated_units = unit ).data
215215 unit_score = r2_score (unit_y , unit_mu .data )
216-
217- # Flatten the r2_score results into the expected format
218216 scores [f"{ unit } _r2" ] = unit_score ["r2" ]
219217 scores [f"{ unit } _r2_std" ] = unit_score ["r2_std" ]
220-
221- return pd .Series (scores )
222218 else :
223- # Single treated unit - transpose to match expected format
219+ # Single treated unit - determine unit name and use same format
220+ if hasattr (y , "coords" ) and "treated_units" in y .coords :
221+ unit_name = y .coords ["treated_units" ].values [0 ]
222+ else :
223+ unit_name = "unit" # Fallback for backwards compatibility
224+
224225 mu_data = mu_data .T
225226 y_data = y .data .squeeze () if y .data .ndim > 1 else y .data
226- return r2_score (y_data , mu_data .data )
227+ unit_score = r2_score (y_data , mu_data .data )
228+ scores [f"{ unit_name } _r2" ] = unit_score ["r2" ]
229+ scores [f"{ unit_name } _r2_std" ] = unit_score ["r2_std" ]
230+
231+ return pd .Series (scores )
227232
228233 def calculate_impact (
229234 self , y_true : xr .DataArray , y_pred : az .InferenceData
0 commit comments