@@ -125,16 +125,14 @@ def output_var(self):
125
125
return "output"
126
126
127
127
def _data_setter (self , X : pd .Series | np .ndarray , y : pd .Series | np .ndarray = None ):
128
-
129
128
with self .model :
130
-
131
129
X = X .values if isinstance (X , pd .Series ) else X .ravel ()
132
-
130
+
133
131
pm .set_data ({"x" : X })
134
-
132
+
135
133
if y is not None :
136
134
y = y .values if isinstance (y , pd .Series ) else y .ravel ()
137
-
135
+
138
136
pm .set_data ({"y_data" : y })
139
137
140
138
@property
@@ -263,12 +261,16 @@ def test_sample_xxx_extend_idata_param(fitted_model_instance, group, extend_idat
263
261
264
262
prediction_data = pd .DataFrame ({"input" : x_pred })
265
263
if group == "prior_predictive" :
266
- pred = fitted_model_instance .sample_prior_predictive (prediction_data ["input" ], combined = False , extend_idata = extend_idata )
264
+ pred = fitted_model_instance .sample_prior_predictive (
265
+ prediction_data ["input" ], combined = False , extend_idata = extend_idata
266
+ )
267
267
else : # group == "posterior_predictive":
268
- pred = fitted_model_instance .sample_posterior_predictive (prediction_data ["input" ], combined = False , predictions = False , extend_idata = extend_idata )
268
+ pred = fitted_model_instance .sample_posterior_predictive (
269
+ prediction_data ["input" ], combined = False , predictions = False , extend_idata = extend_idata
270
+ )
269
271
270
272
pred_unstacked = pred [output_var ].values
271
-
273
+
272
274
idata_now = fitted_model_instance .idata [group ][output_var ].values
273
275
274
276
if extend_idata :
@@ -314,7 +316,9 @@ def test_id():
314
316
315
317
@pytest .mark .parametrize ("predictions" , [True , False ])
316
318
@pytest .mark .parametrize ("predict_method" , ["predict" , "predict_posterior" ])
317
- def test_predict_method_respects_predictions_flag (fitted_model_instance , predictions , predict_method ):
319
+ def test_predict_method_respects_predictions_flag (
320
+ fitted_model_instance , predictions , predict_method
321
+ ):
318
322
x_pred = np .random .uniform (0 , 1 , 100 )
319
323
prediction_data = pd .DataFrame ({"input" : x_pred })
320
324
output_var = fitted_model_instance .output_var
@@ -332,7 +336,7 @@ def test_predict_method_respects_predictions_flag(fitted_model_instance, predict
332
336
extend_idata = True ,
333
337
predictions = predictions ,
334
338
)
335
- else :# predict_method == "predict_posterior":
339
+ else : # predict_method == "predict_posterior":
336
340
fitted_model_instance .predict_posterior (
337
341
X_pred = prediction_data [["input" ]],
338
342
extend_idata = True ,
@@ -350,4 +354,3 @@ def test_predict_method_respects_predictions_flag(fitted_model_instance, predict
350
354
assert "predictions" not in fitted_model_instance .idata .groups ()
351
355
# Posterior predictive should be updated
352
356
assert not np .array_equal (pp_before , pp_after )
353
-
0 commit comments