Skip to content

Commit 3bbabee

Browse files
committed
Merge branch 'main' into sc-api-change
2 parents 82d041d + 273daa2 commit 3bbabee

File tree

3 files changed

+12
-3
lines changed

3 files changed

+12
-3
lines changed

causalpy/pymc_models.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,20 @@ def _data_setter(self, X) -> None:
8787
8888
This method is used internally to register new data for the model for
8989
prediction.
90+
91+
NOTE: We are actively changing the `X`. Often, this matrix will have a different
92+
number of rows than the original data. So to make the shapes work, we need to
93+
update all data nodes in the model to have the correct shape. The values are not
94+
used, so we set them to 0. In our case, we just have data nodes X and y, but if
95+
in the future we get more complex models with more data nodes, then we'll need
96+
to update all of them - ideally programmatically.
9097
"""
98+
new_no_of_observations = X.shape[0]
9199
with self:
92-
# TODO: update coords
93-
pm.set_data({"X": X})
100+
pm.set_data(
101+
{"X": X, "y": np.zeros(new_no_of_observations)},
102+
coords={"obs_ind": np.arange(new_no_of_observations)},
103+
)
94104

95105
def fit(self, X, y, coords: Optional[Dict[str, Any]] = None) -> None:
96106
"""Draw samples from posterior, prior predictive, and posterior predictive
@@ -118,7 +128,6 @@ def predict(self, X):
118128
119129
.. caution::
120130
Results in KeyError if model hasn't been fit.
121-
122131
"""
123132

124133
# Ensure random_seed is used in sample_prior_predictive() and

docs/source/_static/classes.png

136 KB
Loading

docs/source/_static/packages.png

25.9 KB
Loading

0 commit comments

Comments
 (0)