@@ -313,7 +313,8 @@ def test_id():
313
313
314
314
315
315
@pytest .mark .parametrize ("predictions" , [True , False ])
316
- def test_predict_respects_predictions_flag (fitted_model_instance , predictions ):
316
+ @pytest .mark .parametrize ("predict_method" , ["predict" , "predict_posterior" ])
317
+ def test_predict_method_respects_predictions_flag (fitted_model_instance , predictions , predict_method ):
317
318
x_pred = np .random .uniform (0 , 1 , 100 )
318
319
prediction_data = pd .DataFrame ({"input" : x_pred })
319
320
output_var = fitted_model_instance .output_var
@@ -325,43 +326,18 @@ def test_predict_respects_predictions_flag(fitted_model_instance, predictions):
325
326
assert "predictions" not in fitted_model_instance .idata .groups ()
326
327
327
328
# Run prediction with predictions=True or False
328
- fitted_model_instance .predict (
329
- X_pred = prediction_data [["input" ]],
330
- extend_idata = True ,
331
- predictions = predictions ,
332
- )
333
-
334
- pp_after = fitted_model_instance .idata .posterior_predictive [output_var ].values
335
-
336
- # Check predictions group presence
337
- if predictions :
338
- assert "predictions" in fitted_model_instance .idata .groups ()
339
- # Posterior predictive should remain unchanged
340
- np .testing .assert_array_equal (pp_before , pp_after )
341
- else :
342
- assert "predictions" not in fitted_model_instance .idata .groups ()
343
- # Posterior predictive should be updated
344
- assert not np .array_equal (pp_before , pp_after )
345
-
346
- @pytest .mark .parametrize ("predictions" , [True , False ])
347
- def test_predict_posterior_respects_predictions_flag (fitted_model_instance , predictions ):
348
- x_pred = np .random .uniform (0 , 1 , 100 )
349
- prediction_data = pd .DataFrame ({"input" : x_pred })
350
- output_var = fitted_model_instance .output_var
351
-
352
- # Snapshot the original posterior_predictive values
353
- pp_before = fitted_model_instance .idata .posterior_predictive [output_var ].values .copy ()
354
-
355
- # Ensure 'predictions' group is not present initially
356
- assert "predictions" not in fitted_model_instance .idata .groups ()
357
-
358
- # Run prediction with predictions=True or False
359
- fitted_model_instance .predict_posterior (
360
- X_pred = prediction_data [["input" ]],
361
- extend_idata = True ,
362
- combined = True ,
363
- predictions = predictions ,
364
- )
329
+ if predict_method == "predict" :
330
+ fitted_model_instance .predict (
331
+ X_pred = prediction_data [["input" ]],
332
+ extend_idata = True ,
333
+ predictions = predictions ,
334
+ )
335
+ else :# predict_method == "predict_posterior":
336
+ fitted_model_instance .predict_posterior (
337
+ X_pred = prediction_data [["input" ]],
338
+ extend_idata = True ,
339
+ predictions = predictions ,
340
+ )
365
341
366
342
pp_after = fitted_model_instance .idata .posterior_predictive [output_var ].values
367
343
@@ -374,3 +350,4 @@ def test_predict_posterior_respects_predictions_flag(fitted_model_instance, pred
374
350
assert "predictions" not in fitted_model_instance .idata .groups ()
375
351
# Posterior predictive should be updated
376
352
assert not np .array_equal (pp_before , pp_after )
353
+
0 commit comments