Skip to content

Commit 6bc0d20

Browse files
committed
Dataset.dims -> Dataset.sizes
1 parent 774ddff commit 6bc0d20

File tree

3 files changed

+5
-5
lines changed

3 files changed

+5
-5
lines changed

pymc_experimental/tests/test_blackjax_smc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def test_arviz_from_particles():
181181
with model:
182182
inference_data = arviz_from_particles(model, particles)
183183

184-
assert inference_data.posterior.dims == Frozen({"chain": 1, "draw": 3, "x_dim_0": 2})
184+
assert inference_data.posterior.sizes == Frozen({"chain": 1, "draw": 3, "x_dim_0": 2})
185185
assert inference_data.posterior.data_vars.dtypes == Frozen(
186186
{"x": dtype("float64"), "z": dtype("float64")}
187187
)

pymc_experimental/tests/test_linearmodel.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,8 @@ def test_predict_posterior(fitted_linear_model_instance, combined):
142142
n_pred = 150
143143
X_pred = pd.DataFrame({"input": np.random.uniform(low=0, high=1, size=n_pred)})
144144
pred = model.predict_posterior(X_pred, combined=combined)
145-
chains = model.idata.sample_stats.dims["chain"]
146-
draws = model.idata.sample_stats.dims["draw"]
145+
chains = model.idata.sample_stats.sizes["chain"]
146+
draws = model.idata.sample_stats.sizes["draw"]
147147
expected_shape = (n_pred, chains * draws) if combined else (chains, draws, n_pred)
148148
assert pred.shape == expected_shape
149149
assert np.issubdtype(pred.dtype, np.floating)

pymc_experimental/tests/test_model_builder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -238,8 +238,8 @@ def test_sample_posterior_predictive(fitted_model_instance, combined):
238238
pred = fitted_model_instance.sample_posterior_predictive(
239239
prediction_data["input"], combined=combined, extend_idata=True
240240
)
241-
chains = fitted_model_instance.idata.sample_stats.dims["chain"]
242-
draws = fitted_model_instance.idata.sample_stats.dims["draw"]
241+
chains = fitted_model_instance.idata.sample_stats.sizes["chain"]
242+
draws = fitted_model_instance.idata.sample_stats.sizes["draw"]
243243
expected_shape = (n_pred, chains * draws) if combined else (chains, draws, n_pred)
244244
assert pred[fitted_model_instance.output_var].shape == expected_shape
245245
assert np.issubdtype(pred[fitted_model_instance.output_var].dtype, np.floating)

0 commit comments

Comments
 (0)