@@ -124,11 +124,18 @@ def _save_input_params(self, idata):
124
124
def output_var (self ):
125
125
return "output"
126
126
127
- def _data_setter (self , x : pd .Series , y : pd .Series = None ):
127
+ def _data_setter (self , X : pd .Series | np .ndarray , y : pd .Series | np .ndarray = None ):
128
+
128
129
with self .model :
129
- pm .set_data ({"x" : x .values })
130
+
131
+ X = X .values if isinstance (X , pd .Series ) else X .ravel ()
132
+
133
+ pm .set_data ({"x" : X })
134
+
130
135
if y is not None :
131
- pm .set_data ({"y_data" : y .values })
136
+ y = y .values if isinstance (y , pd .Series ) else y .ravel ()
137
+
138
+ pm .set_data ({"y_data" : y })
132
139
133
140
@property
134
141
def _serializable_model_config (self ):
@@ -177,8 +184,8 @@ def test_save_load(fitted_model_instance):
177
184
assert fitted_model_instance .id == test_builder2 .id
178
185
x_pred = np .random .uniform (low = 0 , high = 1 , size = 100 )
179
186
prediction_data = pd .DataFrame ({"input" : x_pred })
180
- pred1 = fitted_model_instance .predict (prediction_data ["input" ])
181
- pred2 = test_builder2 .predict (prediction_data ["input" ])
187
+ pred1 = fitted_model_instance .predict (prediction_data [[ "input" ] ])
188
+ pred2 = test_builder2 .predict (prediction_data [[ "input" ] ])
182
189
assert pred1 .shape == pred2 .shape
183
190
temp .close ()
184
191
@@ -205,7 +212,7 @@ def test_empty_sampler_config_fit(toy_X, toy_y):
205
212
206
213
def test_fit (fitted_model_instance ):
207
214
prediction_data = pd .DataFrame ({"input" : np .random .uniform (low = 0 , high = 1 , size = 100 )})
208
- pred = fitted_model_instance .predict (prediction_data ["input" ])
215
+ pred = fitted_model_instance .predict (prediction_data [[ "input" ] ])
209
216
post_pred = fitted_model_instance .sample_posterior_predictive (
210
217
prediction_data ["input" ], extend_idata = True , combined = True
211
218
)
@@ -223,7 +230,7 @@ def test_fit_no_y(toy_X):
223
230
def test_predict (fitted_model_instance ):
224
231
x_pred = np .random .uniform (low = 0 , high = 1 , size = 100 )
225
232
prediction_data = pd .DataFrame ({"input" : x_pred })
226
- pred = fitted_model_instance .predict (prediction_data ["input" ])
233
+ pred = fitted_model_instance .predict (prediction_data [[ "input" ] ])
227
234
# Perform elementwise comparison using numpy
228
235
assert isinstance (pred , np .ndarray )
229
236
assert len (pred ) > 0
@@ -256,13 +263,12 @@ def test_sample_xxx_extend_idata_param(fitted_model_instance, group, extend_idat
256
263
257
264
prediction_data = pd .DataFrame ({"input" : x_pred })
258
265
if group == "prior_predictive" :
259
- prediction_method = fitted_model_instance .sample_prior_predictive
266
+ pred = fitted_model_instance .sample_prior_predictive ( prediction_data [ "input" ], combined = False , extend_idata = extend_idata )
260
267
else : # group == "posterior_predictive":
261
- prediction_method = fitted_model_instance .sample_posterior_predictive
262
-
263
- pred = prediction_method (prediction_data ["input" ], combined = False , extend_idata = extend_idata )
268
+ pred = fitted_model_instance .sample_posterior_predictive (prediction_data ["input" ], combined = False , predictions = False , extend_idata = extend_idata )
264
269
265
270
pred_unstacked = pred [output_var ].values
271
+
266
272
idata_now = fitted_model_instance .idata [group ][output_var ].values
267
273
268
274
if extend_idata :
@@ -320,9 +326,40 @@ def test_predict_respects_predictions_flag(fitted_model_instance, predictions):
320
326
321
327
# Run prediction with predictions=True or False
322
328
fitted_model_instance .predict (
323
- prediction_data ["input" ],
329
+ X_pred = prediction_data [["input" ]],
330
+ extend_idata = True ,
331
+ predictions = predictions ,
332
+ )
333
+
334
+ pp_after = fitted_model_instance .idata .posterior_predictive [output_var ].values
335
+
336
+ # Check predictions group presence
337
+ if predictions :
338
+ assert "predictions" in fitted_model_instance .idata .groups ()
339
+ # Posterior predictive should remain unchanged
340
+ np .testing .assert_array_equal (pp_before , pp_after )
341
+ else :
342
+ assert "predictions" not in fitted_model_instance .idata .groups ()
343
+ # Posterior predictive should be updated
344
+ assert not np .array_equal (pp_before , pp_after )
345
+
346
+ @pytest .mark .parametrize ("predictions" , [True , False ])
347
+ def test_predict_posterior_respects_predictions_flag (fitted_model_instance , predictions ):
348
+ x_pred = np .random .uniform (0 , 1 , 100 )
349
+ prediction_data = pd .DataFrame ({"input" : x_pred })
350
+ output_var = fitted_model_instance .output_var
351
+
352
+ # Snapshot the original posterior_predictive values
353
+ pp_before = fitted_model_instance .idata .posterior_predictive [output_var ].values .copy ()
354
+
355
+ # Ensure 'predictions' group is not present initially
356
+ assert "predictions" not in fitted_model_instance .idata .groups ()
357
+
358
+ # Run prediction with predictions=True or False
359
+ fitted_model_instance .predict_posterior (
360
+ X_pred = prediction_data [["input" ]],
324
361
extend_idata = True ,
325
- combined = False ,
362
+ combined = True ,
326
363
predictions = predictions ,
327
364
)
328
365
@@ -336,4 +373,4 @@ def test_predict_respects_predictions_flag(fitted_model_instance, predictions):
336
373
else :
337
374
assert "predictions" not in fitted_model_instance .idata .groups ()
338
375
# Posterior predictive should be updated
339
- np .testing . assert_array_not_equal (pp_before , pp_after )
376
+ assert not np .array_equal (pp_before , pp_after )
0 commit comments