Skip to content

Commit f1849b1

Browse files
committed
clean up PyMCModel.predict + PyMCModel._data_setter
1 parent 4a78a50 commit f1849b1

File tree

1 file changed

+13
-7
lines changed

1 file changed

+13
-7
lines changed

causalpy/pymc_models.py

Lines changed: 13 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -89,7 +89,7 @@ def build_model(self, X, y, coords) -> None:
89 89
"""Build the model, must be implemented by subclass."""
90 90
raise NotImplementedError("This method must be implemented by a subclass")
91 91

92-
def _data_setter(self, X) -> None:
92+
def _data_setter(self, X: xr.DataArray) -> None:
93 93
"""
94 94
Set data for the model.
95 95
@@ -105,6 +105,9 @@ def _data_setter(self, X) -> None:
105 105
"""
106 106
new_no_of_observations = X.shape[0]
107 107

108+
# Use integer indices for obs_ind to avoid datetime compatibility issues with PyMC
109+
obs_coords = np.arange(new_no_of_observations)
110+
108 111
# Check if this model has multiple treated units
109 112
if hasattr(self, "idata") and self.idata is not None:
110 113
posterior = self.idata.posterior
@@ -125,13 +128,13 @@ def _data_setter(self, X) -> None:
125 128
# Multi-unit case or single unit with treated_units dimension
126 129
pm.set_data(
127 130
{"X": X, "y": np.zeros((new_no_of_observations, n_treated_units))},
128-
coords={"obs_ind": np.arange(new_no_of_observations)},
131+
coords={"obs_ind": obs_coords},
129 132
)
130 133
else:
131 134
# Other model types (e.g., LinearRegression) without treated_units dimension
132 135
pm.set_data(
133 136
{"X": X, "y": np.zeros(new_no_of_observations)},
134-
coords={"obs_ind": np.arange(new_no_of_observations)},
137+
coords={"obs_ind": obs_coords},
135 138
)
136 139

137 140
def fit(self, X, y, coords: Optional[Dict[str, Any]] = None) -> None:
@@ -154,7 +157,7 @@ def fit(self, X, y, coords: Optional[Dict[str, Any]] = None) -> None:
154 157
)
155 158
return self.idata
156 159

157-
def predict(self, X):
160+
def predict(self, X: xr.DataArray):
158 161
"""
159 162
Predict data given input data `X`
160 163
@@ -166,16 +169,19 @@ def predict(self, X):
166 169
# sample_posterior_predictive() if provided in sample_kwargs.
167 170
random_seed = self.sample_kwargs.get("random_seed", None)
168 171
self._data_setter(X)
169-
with self: # sample with new input data
172+
with self:
170 173
pp = pm.sample_posterior_predictive(
171 174
self.idata,
172 175
var_names=["y_hat", "mu"],
173 176
progressbar=False,
174 177
random_seed=random_seed,
175 178
)
176 179

177-
# TODO: This is a bit of a hack. Maybe it could be done properly in _data_setter?
178-
if isinstance(X, xr.DataArray):
180+
# Assign coordinates from input X to ensure xarray operations work correctly
181+
# This is necessary because PyMC uses integer indices internally, but we need
182+
# to preserve the original coordinates (e.g., datetime indices) for proper
183+
# alignment with other xarray operations like calculate_impact()
184+
if isinstance(X, xr.DataArray) and "obs_ind" in X.coords:
179 185
pp["posterior_predictive"] = pp["posterior_predictive"].assign_coords(
180 186
obs_ind=X.obs_ind
181 187
)

0 commit comments

Comments
 (0)