Skip to content

Commit 4a78a50

Browse files
committed
simplification of WeightedSumFitter.build_model
1 parent 8badc05 commit 4a78a50

File tree

1 file changed

+1
-16
lines changed

1 file changed

+1
-16
lines changed

causalpy/pymc_models.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -360,22 +360,7 @@ def build_model(self, X, y, coords):
360360
self.add_coords(coords)
361361
n_predictors = X.shape[1]
362362
X = pm.Data("X", X, dims=["obs_ind", "coeffs"])
363-
364-
# Always use treated_units dimension for consistency
365-
# Convert to numpy array if it's an xarray DataArray
366-
if hasattr(y, "values"):
367-
y_data = y.values
368-
else:
369-
y_data = np.asarray(y)
370-
371-
# Ensure y_data has treated_units dimension
372-
if y_data.ndim == 1:
373-
y_data = y_data.reshape(-1, 1) # Add treated_units dimension
374-
elif y_data.ndim > 1 and y_data.shape[1] == 1:
375-
pass # Already has correct shape
376-
# If y_data.ndim > 1 and y_data.shape[1] > 1, it's multi-unit and already correct
377-
378-
y = pm.Data("y", y_data, dims=["obs_ind", "treated_units"])
363+
y = pm.Data("y", y, dims=["obs_ind", "treated_units"])
379364
beta = pm.Dirichlet(
380365
"beta", a=np.ones(n_predictors), dims=["treated_units", "coeffs"]
381366
)

0 commit comments

Comments
 (0)