diff --git a/tests/sampling/test_mcmc.py b/tests/sampling/test_mcmc.py index 3fef87b649..e39a5612cc 100644 --- a/tests/sampling/test_mcmc.py +++ b/tests/sampling/test_mcmc.py @@ -846,7 +846,7 @@ def test_step_vars_in_model(self): class TestType: samplers = (Metropolis, Slice, HamiltonianMC, NUTS) - @pytensor.config.change_flags({"floatX": "float64", "warn_float64": "ignore"}) + @pytensor.config.change_flags(floatX="float64", warn_float64="ignore") def test_float64(self): with pm.Model() as model: x = pm.Normal("x", initval=np.array(1.0, dtype="float64")) @@ -861,7 +861,7 @@ def test_float64(self): warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning) pm.sample(draws=10, tune=10, chains=1, step=sampler()) - @pytensor.config.change_flags({"floatX": "float32", "warn_float64": "warn"}) + @pytensor.config.change_flags(floatX="float32", warn_float64="warn") def test_float32(self): with pm.Model() as model: x = pm.Normal("x", initval=np.array(1.0, dtype="float32"))