@@ -47,8 +47,16 @@ class PyMCModel(pm.Model):
4747 ... mu = pm.Deterministic("mu", pm.math.dot(X_, beta))
4848 ... pm.Normal("y_hat", mu=mu, sigma=sigma, observed=y_)
4949 >>> rng = np.random.default_rng(seed=42)
50- >>> X = rng.normal(loc=0, scale=1, size=(20, 2))
51- >>> y = rng.normal(loc=0, scale=1, size=(20,))
50+ >>> X = xr.DataArray(
51+ ... rng.normal(loc=0, scale=1, size=(20, 2)),
52+ ... dims=["obs_ind", "coeffs"],
53+ ... coords={"obs_ind": np.arange(20), "coeffs": ["coeff_0", "coeff_1"]},
54+ ... )
55+ >>> y = xr.DataArray(
56+ ... rng.normal(loc=0, scale=1, size=(20,)),
57+ ... dims=["obs_ind"],
58+ ... coords={"obs_ind": np.arange(20)},
59+ ... )
5260 >>> model = MyToyModel(
5361 ... sample_kwargs={
5462 ... "chains": 2,
@@ -174,7 +182,7 @@ def predict(self, X):
174182
175183 return pp
176184
177- def score (self , X , y ) -> pd .Series :
185+ def score (self , X : xr . DataArray , y : xr . DataArray ) -> pd .Series :
178186 """Score the Bayesian :math:`R^2` given inputs ``X`` and outputs ``y``.
179187
180188 Note that the score is based on a comparison of the observed data ``y`` and the
@@ -197,14 +205,7 @@ def score(self, X, y) -> pd.Series:
197205
198206 for unit in treated_units :
199207 unit_mu = mu_data .sel (treated_units = unit ).T # (sample, obs_ind)
200-
201- # Handle both xarray and numpy arrays for y
202- if hasattr (y , "sel" ): # xarray.DataArray
203- unit_y = y .sel (treated_units = unit ).data
204- else : # numpy array
205- unit_idx = list (treated_units ).index (unit )
206- unit_y = y [:, unit_idx ] if y .ndim > 1 else y
207-
208+ unit_y = y .sel (treated_units = unit ).data
208209 unit_score = r2_score (unit_y , unit_mu .data )
209210
210211 # Flatten the r2_score results into the expected format
@@ -215,24 +216,7 @@ def score(self, X, y) -> pd.Series:
215216 else :
216217 # Single treated unit - transpose to match expected format
217218 mu_data = mu_data .T
218-
219- # Handle different y types robustly
220- if hasattr (y , "data" ): # xarray.DataArray
221- y_raw = y .data
222- # Convert to numpy array if it's a memoryview
223- if isinstance (y_raw , memoryview ):
224- y_data = np .asarray (y_raw )
225- else :
226- y_data = y_raw
227- # Squeeze if needed
228- y_data = y_data if y_data .ndim == 1 else y_data .squeeze ()
229- else : # numpy array or memoryview
230- if hasattr (y , "squeeze" ): # numpy array
231- y_data = y if y .ndim == 1 else y .squeeze ()
232- else : # memoryview or other
233- y_data = np .asarray (y )
234- y_data = y_data if y_data .ndim == 1 else y_data .squeeze ()
235-
219+ y_data = y .data .squeeze () if y .data .ndim > 1 else y .data
236220 return r2_score (y_data , mu_data .data )
237221
238222 def calculate_impact (
0 commit comments