@@ -206,22 +206,28 @@ def score(self, X: xr.DataArray, y: xr.DataArray) -> pd.Series:
206206 # Always use unified labeling system: unit_0_r2, unit_1_r2, etc.
207207 scores = {}
208208
209+ # Determine units to process - always use a loop for consistency
209210 if "treated_units" in mu_data .dims :
210- # Multiple treated units - score each unit separately
211- treated_units = mu_data .coords ["treated_units" ].values
212- for i , unit in enumerate (treated_units ):
213- unit_mu = mu_data .sel (treated_units = unit ).T # (sample, obs_ind)
214- unit_y = y .sel (treated_units = unit ).data
215- unit_score = r2_score (unit_y , unit_mu .data )
216- scores [f"unit_{ i } _r2" ] = unit_score ["r2" ]
217- scores [f"unit_{ i } _r2_std" ] = unit_score ["r2_std" ]
211+ # Multiple treated units
212+ units = list (enumerate (mu_data .coords ["treated_units" ].values ))
218213 else :
219- # Single treated unit - use unit_0 for consistency
220- mu_data = mu_data .T
221- y_data = y .data .squeeze () if y .data .ndim > 1 else y .data
222- unit_score = r2_score (y_data , mu_data .data )
223- scores ["unit_0_r2" ] = unit_score ["r2" ]
224- scores ["unit_0_r2_std" ] = unit_score ["r2_std" ]
214+ # Single unit - treat as single-item list
215+ units = [(0 , None )]
216+
217+ # Process all units using the same loop logic
218+ for i , unit_selector in units :
219+ if unit_selector is not None :
220+ # Multi-unit case: select specific unit
221+ unit_mu = mu_data .sel (treated_units = unit_selector ).T
222+ unit_y = y .sel (treated_units = unit_selector ).data
223+ else :
224+ # Single unit case: use all data
225+ unit_mu = mu_data .T
226+ unit_y = y .data .squeeze () if y .data .ndim > 1 else y .data
227+
228+ unit_score = r2_score (unit_y , unit_mu .data )
229+ scores [f"unit_{ i } _r2" ] = unit_score ["r2" ]
230+ scores [f"unit_{ i } _r2_std" ] = unit_score ["r2_std" ]
225231
226232 return pd .Series (scores )
227233
0 commit comments