Skip to content

Commit e698292

Browse files
committed
update predict calls to handle validate data and predictions group
1 parent e7fe9a2 commit e698292

File tree

1 file changed

+51
-14
lines changed

1 file changed

+51
-14
lines changed

tests/test_model_builder.py

Lines changed: 51 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -124,11 +124,18 @@ def _save_input_params(self, idata):
124124
def output_var(self):
125125
return "output"
126126

127-
def _data_setter(self, x: pd.Series, y: pd.Series = None):
127+
def _data_setter(self, X: pd.Series | np.ndarray, y: pd.Series | np.ndarray = None):
128+
128129
with self.model:
129-
pm.set_data({"x": x.values})
130+
131+
X = X.values if isinstance(X, pd.Series) else X.ravel()
132+
133+
pm.set_data({"x": X})
134+
130135
if y is not None:
131-
pm.set_data({"y_data": y.values})
136+
y = y.values if isinstance(y, pd.Series) else y.ravel()
137+
138+
pm.set_data({"y_data": y})
132139

133140
@property
134141
def _serializable_model_config(self):
@@ -177,8 +184,8 @@ def test_save_load(fitted_model_instance):
177184
assert fitted_model_instance.id == test_builder2.id
178185
x_pred = np.random.uniform(low=0, high=1, size=100)
179186
prediction_data = pd.DataFrame({"input": x_pred})
180-
pred1 = fitted_model_instance.predict(prediction_data["input"])
181-
pred2 = test_builder2.predict(prediction_data["input"])
187+
pred1 = fitted_model_instance.predict(prediction_data[["input"]])
188+
pred2 = test_builder2.predict(prediction_data[["input"]])
182189
assert pred1.shape == pred2.shape
183190
temp.close()
184191

@@ -205,7 +212,7 @@ def test_empty_sampler_config_fit(toy_X, toy_y):
205212

206213
def test_fit(fitted_model_instance):
207214
prediction_data = pd.DataFrame({"input": np.random.uniform(low=0, high=1, size=100)})
208-
pred = fitted_model_instance.predict(prediction_data["input"])
215+
pred = fitted_model_instance.predict(prediction_data[["input"]])
209216
post_pred = fitted_model_instance.sample_posterior_predictive(
210217
prediction_data["input"], extend_idata=True, combined=True
211218
)
@@ -223,7 +230,7 @@ def test_fit_no_y(toy_X):
223230
def test_predict(fitted_model_instance):
224231
x_pred = np.random.uniform(low=0, high=1, size=100)
225232
prediction_data = pd.DataFrame({"input": x_pred})
226-
pred = fitted_model_instance.predict(prediction_data["input"])
233+
pred = fitted_model_instance.predict(prediction_data[["input"]])
227234
# Perform elementwise comparison using numpy
228235
assert isinstance(pred, np.ndarray)
229236
assert len(pred) > 0
@@ -256,13 +263,12 @@ def test_sample_xxx_extend_idata_param(fitted_model_instance, group, extend_idat
256263

257264
prediction_data = pd.DataFrame({"input": x_pred})
258265
if group == "prior_predictive":
259-
prediction_method = fitted_model_instance.sample_prior_predictive
266+
pred = fitted_model_instance.sample_prior_predictive(prediction_data["input"], combined=False, extend_idata=extend_idata)
260267
else: # group == "posterior_predictive":
261-
prediction_method = fitted_model_instance.sample_posterior_predictive
262-
263-
pred = prediction_method(prediction_data["input"], combined=False, extend_idata=extend_idata)
268+
pred = fitted_model_instance.sample_posterior_predictive(prediction_data["input"], combined=False, predictions=False, extend_idata=extend_idata)
264269

265270
pred_unstacked = pred[output_var].values
271+
266272
idata_now = fitted_model_instance.idata[group][output_var].values
267273

268274
if extend_idata:
@@ -320,9 +326,40 @@ def test_predict_respects_predictions_flag(fitted_model_instance, predictions):
320326

321327
# Run prediction with predictions=True or False
322328
fitted_model_instance.predict(
323-
prediction_data["input"],
329+
X_pred=prediction_data[["input"]],
330+
extend_idata=True,
331+
predictions=predictions,
332+
)
333+
334+
pp_after = fitted_model_instance.idata.posterior_predictive[output_var].values
335+
336+
# Check predictions group presence
337+
if predictions:
338+
assert "predictions" in fitted_model_instance.idata.groups()
339+
# Posterior predictive should remain unchanged
340+
np.testing.assert_array_equal(pp_before, pp_after)
341+
else:
342+
assert "predictions" not in fitted_model_instance.idata.groups()
343+
# Posterior predictive should be updated
344+
assert not np.array_equal(pp_before, pp_after)
345+
346+
@pytest.mark.parametrize("predictions", [True, False])
347+
def test_predict_posterior_respects_predictions_flag(fitted_model_instance, predictions):
348+
x_pred = np.random.uniform(0, 1, 100)
349+
prediction_data = pd.DataFrame({"input": x_pred})
350+
output_var = fitted_model_instance.output_var
351+
352+
# Snapshot the original posterior_predictive values
353+
pp_before = fitted_model_instance.idata.posterior_predictive[output_var].values.copy()
354+
355+
# Ensure 'predictions' group is not present initially
356+
assert "predictions" not in fitted_model_instance.idata.groups()
357+
358+
# Run prediction with predictions=True or False
359+
fitted_model_instance.predict_posterior(
360+
X_pred=prediction_data[["input"]],
324361
extend_idata=True,
325-
combined=False,
362+
combined=True,
326363
predictions=predictions,
327364
)
328365

@@ -336,4 +373,4 @@ def test_predict_respects_predictions_flag(fitted_model_instance, predictions):
336373
else:
337374
assert "predictions" not in fitted_model_instance.idata.groups()
338375
# Posterior predictive should be updated
339-
np.testing.assert_array_not_equal(pp_before, pp_after)
376+
assert not np.array_equal(pp_before, pp_after)

0 commit comments

Comments
 (0)