Skip to content

Commit 1a5f9bd

Browse files
committed
code simplification related to PyMCModel._data_setter
1 parent eac1ef3 commit 1a5f9bd

File tree

1 file changed

+7
-9
lines changed

1 file changed

+7
-9
lines changed

causalpy/pymc_models.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -109,20 +109,18 @@ def _data_setter(self, X) -> None:
109109
has_treated_units = False
110110

111111
with self:
112-
if has_treated_units:
113-
# Get the number of treated units from the model coordinates
114-
treated_units_coord = getattr(self, "coords", {}).get(
115-
"treated_units", []
116-
)
117-
n_treated_units = (
118-
len(treated_units_coord) if treated_units_coord is not None else 1
119-
)
112+
# Get the number of treated units from the model coordinates
113+
treated_units_coord = getattr(self, "coords", {}).get("treated_units", [])
114+
n_treated_units = len(treated_units_coord) if treated_units_coord else 1
115+
116+
if n_treated_units > 1 or has_treated_units:
117+
# Multi-unit case or single unit with treated_units dimension
120118
pm.set_data(
121119
{"X": X, "y": np.zeros((new_no_of_observations, n_treated_units))},
122120
coords={"obs_ind": np.arange(new_no_of_observations)},
123121
)
124122
else:
125-
# Legacy case - this shouldn't happen with new WeightedSumFitter
123+
# Other model types (e.g., LinearRegression) without treated_units dimension
126124
pm.set_data(
127125
{"X": X, "y": np.zeros(new_no_of_observations)},
128126
coords={"obs_ind": np.arange(new_no_of_observations)},

0 commit comments

Comments
 (0)