Skip to content

Commit b3f9a6c

Browse files
committed
Consolidate test with pytest paramterize
1 parent e698292 commit b3f9a6c

File tree

1 file changed

+15
-38
lines changed

1 file changed

+15
-38
lines changed

tests/test_model_builder.py

Lines changed: 15 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,8 @@ def test_id():
313313

314314

315315
@pytest.mark.parametrize("predictions", [True, False])
316-
def test_predict_respects_predictions_flag(fitted_model_instance, predictions):
316+
@pytest.mark.parametrize("predict_method", ["predict", "predict_posterior"])
317+
def test_predict_method_respects_predictions_flag(fitted_model_instance, predictions, predict_method):
317318
x_pred = np.random.uniform(0, 1, 100)
318319
prediction_data = pd.DataFrame({"input": x_pred})
319320
output_var = fitted_model_instance.output_var
@@ -325,43 +326,18 @@ def test_predict_respects_predictions_flag(fitted_model_instance, predictions):
325326
assert "predictions" not in fitted_model_instance.idata.groups()
326327

327328
# Run prediction with predictions=True or False
328-
fitted_model_instance.predict(
329-
X_pred=prediction_data[["input"]],
330-
extend_idata=True,
331-
predictions=predictions,
332-
)
333-
334-
pp_after = fitted_model_instance.idata.posterior_predictive[output_var].values
335-
336-
# Check predictions group presence
337-
if predictions:
338-
assert "predictions" in fitted_model_instance.idata.groups()
339-
# Posterior predictive should remain unchanged
340-
np.testing.assert_array_equal(pp_before, pp_after)
341-
else:
342-
assert "predictions" not in fitted_model_instance.idata.groups()
343-
# Posterior predictive should be updated
344-
assert not np.array_equal(pp_before, pp_after)
345-
346-
@pytest.mark.parametrize("predictions", [True, False])
347-
def test_predict_posterior_respects_predictions_flag(fitted_model_instance, predictions):
348-
x_pred = np.random.uniform(0, 1, 100)
349-
prediction_data = pd.DataFrame({"input": x_pred})
350-
output_var = fitted_model_instance.output_var
351-
352-
# Snapshot the original posterior_predictive values
353-
pp_before = fitted_model_instance.idata.posterior_predictive[output_var].values.copy()
354-
355-
# Ensure 'predictions' group is not present initially
356-
assert "predictions" not in fitted_model_instance.idata.groups()
357-
358-
# Run prediction with predictions=True or False
359-
fitted_model_instance.predict_posterior(
360-
X_pred=prediction_data[["input"]],
361-
extend_idata=True,
362-
combined=True,
363-
predictions=predictions,
364-
)
329+
if predict_method == "predict":
330+
fitted_model_instance.predict(
331+
X_pred=prediction_data[["input"]],
332+
extend_idata=True,
333+
predictions=predictions,
334+
)
335+
else:# predict_method == "predict_posterior":
336+
fitted_model_instance.predict_posterior(
337+
X_pred=prediction_data[["input"]],
338+
extend_idata=True,
339+
predictions=predictions,
340+
)
365341

366342
pp_after = fitted_model_instance.idata.posterior_predictive[output_var].values
367343

@@ -374,3 +350,4 @@ def test_predict_posterior_respects_predictions_flag(fitted_model_instance, pred
374350
assert "predictions" not in fitted_model_instance.idata.groups()
375351
# Posterior predictive should be updated
376352
assert not np.array_equal(pp_before, pp_after)
353+

0 commit comments

Comments
 (0)