Skip to content

Commit f4d04c2

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent b3f9a6c commit f4d04c2

File tree

2 files changed

+15
-12
lines changed

2 files changed

+15
-12
lines changed

pymc_extras/model_builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -561,7 +561,7 @@ def predict(
561561
>>> prediction_data = pd.DataFrame({'input':x_pred})
562562
>>> pred_mean = model.predict(prediction_data)
563563
"""
564-
564+
565565
X_pred = self._validate_data(X_pred)
566566

567567
posterior_predictive_samples = self.sample_posterior_predictive(

tests/test_model_builder.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -125,16 +125,14 @@ def output_var(self):
125125
return "output"
126126

127127
def _data_setter(self, X: pd.Series | np.ndarray, y: pd.Series | np.ndarray = None):
128-
129128
with self.model:
130-
131129
X = X.values if isinstance(X, pd.Series) else X.ravel()
132-
130+
133131
pm.set_data({"x": X})
134-
132+
135133
if y is not None:
136134
y = y.values if isinstance(y, pd.Series) else y.ravel()
137-
135+
138136
pm.set_data({"y_data": y})
139137

140138
@property
@@ -263,12 +261,16 @@ def test_sample_xxx_extend_idata_param(fitted_model_instance, group, extend_idat
263261

264262
prediction_data = pd.DataFrame({"input": x_pred})
265263
if group == "prior_predictive":
266-
pred = fitted_model_instance.sample_prior_predictive(prediction_data["input"], combined=False, extend_idata=extend_idata)
264+
pred = fitted_model_instance.sample_prior_predictive(
265+
prediction_data["input"], combined=False, extend_idata=extend_idata
266+
)
267267
else: # group == "posterior_predictive":
268-
pred = fitted_model_instance.sample_posterior_predictive(prediction_data["input"], combined=False, predictions=False, extend_idata=extend_idata)
268+
pred = fitted_model_instance.sample_posterior_predictive(
269+
prediction_data["input"], combined=False, predictions=False, extend_idata=extend_idata
270+
)
269271

270272
pred_unstacked = pred[output_var].values
271-
273+
272274
idata_now = fitted_model_instance.idata[group][output_var].values
273275

274276
if extend_idata:
@@ -314,7 +316,9 @@ def test_id():
314316

315317
@pytest.mark.parametrize("predictions", [True, False])
316318
@pytest.mark.parametrize("predict_method", ["predict", "predict_posterior"])
317-
def test_predict_method_respects_predictions_flag(fitted_model_instance, predictions, predict_method):
319+
def test_predict_method_respects_predictions_flag(
320+
fitted_model_instance, predictions, predict_method
321+
):
318322
x_pred = np.random.uniform(0, 1, 100)
319323
prediction_data = pd.DataFrame({"input": x_pred})
320324
output_var = fitted_model_instance.output_var
@@ -332,7 +336,7 @@ def test_predict_method_respects_predictions_flag(fitted_model_instance, predict
332336
extend_idata=True,
333337
predictions=predictions,
334338
)
335-
else:# predict_method == "predict_posterior":
339+
else: # predict_method == "predict_posterior":
336340
fitted_model_instance.predict_posterior(
337341
X_pred=prediction_data[["input"]],
338342
extend_idata=True,
@@ -350,4 +354,3 @@ def test_predict_method_respects_predictions_flag(fitted_model_instance, predict
350354
assert "predictions" not in fitted_model_instance.idata.groups()
351355
# Posterior predictive should be updated
352356
assert not np.array_equal(pp_before, pp_after)
353-

0 commit comments

Comments
 (0)