diff --git a/pints/_mcmc/_mala.py b/pints/_mcmc/_mala.py index 784a3c487..fabcfc4ff 100644 --- a/pints/_mcmc/_mala.py +++ b/pints/_mcmc/_mala.py @@ -84,7 +84,6 @@ def __init__(self, x0, sigma0=None): # Set initial state self._running = False - self._ready_for_tell = False # Current point and proposed point self._current = None @@ -225,9 +224,8 @@ def tell(self, reply): """ See :meth:`pints.SingleChainMCMC.tell()`. """ # Check if we had a proposal - if not self._ready_for_tell: + if self._proposed is None: raise RuntimeError('Tell called before proposal was set.') - self._ready_for_tell = False # Unpack reply fx, log_gradient = reply diff --git a/pints/tests/test_mcmc_mala.py b/pints/tests/test_mcmc_mala.py index 6b13a690b..a33d33dc8 100755 --- a/pints/tests/test_mcmc_mala.py +++ b/pints/tests/test_mcmc_mala.py @@ -20,7 +20,8 @@ class TestMALAMCMC(unittest.TestCase): Tests the basic methods of the MALA MCMC routine. """ - def test_method(self): + def test_short_run(self): + # Test a short run with MALA # Create log pdf log_pdf = pints.toy.GaussianLogPDF([5, 5], [[4, 1], [1, 3]]) @@ -30,9 +31,6 @@ def test_method(self): sigma = [[3, 0], [0, 3]] mcmc = pints.MALAMCMC(x0, sigma) - # This method needs sensitivities - self.assertTrue(mcmc.needs_sensitivities()) - # Perform short run chain = [] for i in range(100): @@ -47,11 +45,13 @@ def test_method(self): chain = np.array(chain) self.assertEqual(chain.shape[0], 50) self.assertEqual(chain.shape[1], len(x0)) - self.assertTrue(mcmc.acceptance_rate() >= 0.0 and - mcmc.acceptance_rate() <= 1.0) + self.assertTrue(0 <= mcmc.acceptance_rate() <= 1.0) + + def test_needs_sensitivities(self): + # This method needs sensitivities - mcmc._proposed = [1, 3] - self.assertRaises(RuntimeError, mcmc.tell, (fx, gr)) + mcmc = pints.MALAMCMC(np.array([2, 2])) + self.assertTrue(mcmc.needs_sensitivities()) def test_logging(self): # Test logging includes name and custom fields. @@ -59,29 +59,35 @@ def test_logging(self): log_pdf = pints.toy.GaussianLogPDF([5, 5], [[4, 1], [1, 3]]) x0 = [np.array([2, 2]), np.array([8, 8])] - mcmc = pints.MCMCSampling(log_pdf, 2, x0, method=pints.MALAMCMC) + mcmc = pints.MCMCController(log_pdf, 2, x0, method=pints.MALAMCMC) mcmc.set_max_iterations(5) with StreamCapture() as c: mcmc.run() text = c.text() - self.assertIn('Metropolis-Adjusted Langevin Algorithm (MALA)', - text) + self.assertIn('Metropolis-Adjusted Langevin Algorithm (MALA)', text) self.assertIn(' Accept.', text) def test_flow(self): + # Test the ask-and-tell flow log_pdf = pints.toy.GaussianLogPDF([5, 5], [[4, 1], [1, 3]]) x0 = np.array([2, 2]) # Test initial proposal is first point mcmc = pints.MALAMCMC(x0) - self.assertTrue(np.all(mcmc.ask() == mcmc._x0)) - - # Repeated asks - self.assertRaises(RuntimeError, mcmc.ask) - - # Tell without ask + self.assertTrue(np.all(mcmc.ask() == x0)) + + # Repeated asks return same point + self.assertTrue(np.all(mcmc.ask() == x0)) + self.assertTrue(np.all(mcmc.ask() == x0)) + self.assertTrue(np.all(mcmc.ask() == x0)) + for i in range(5): + mcmc.tell(log_pdf.evaluateS1(mcmc.ask())) + x1 = mcmc.ask() + self.assertTrue(np.all(mcmc.ask() == x1)) + + # Tell without ask should fail mcmc = pints.MALAMCMC(x0) self.assertRaises(RuntimeError, mcmc.tell, 0) @@ -101,8 +107,8 @@ def test_flow(self): mcmc._running = True self.assertRaises(RuntimeError, mcmc._initialise) - def test_set_hyper_parameters(self): - # Tests the parameter interface for this sampler. + def test_hyper_parameters(self): + # Tests the hyper parameter interface for this sampler. x0 = np.array([2, 2]) mcmc = pints.MALAMCMC(x0) @@ -113,17 +119,24 @@ def test_set_hyper_parameters(self): self.assertTrue(np.array_equal(mcmc.epsilon(), 0.2 * np.diag(mcmc._sigma0))) + mcmc = pints.MALAMCMC(np.array([2, 2])) self.assertEqual(mcmc.n_hyper_parameters(), 1) mcmc.set_hyper_parameters([[3, 2]]) self.assertTrue(np.array_equal(mcmc.epsilon(), [3, 2])) + mcmc.set_hyper_parameters([[5, 5]]) + self.assertTrue(np.array_equal(mcmc.epsilon(), [5, 5])) - mcmc._step_size = 5 - mcmc._scale_vector = np.array([3, 7]) - mcmc._epsilon = None + def test_epsilon(self): + # Test the epsilon methods + + mcmc = pints.MALAMCMC(np.array([2, 2]), np.array([3, 3])) mcmc.set_epsilon() - self.assertTrue(np.array_equal(mcmc.epsilon(), [15, 35])) + x = mcmc.epsilon() + self.assertAlmostEqual(x[0], 0.6) + self.assertAlmostEqual(x[1], 0.6) mcmc.set_epsilon([0.4, 0.5]) - self.assertTrue(np.array_equal(mcmc.epsilon(), [0.4, 0.5])) + self.assertTrue(np.all(mcmc.epsilon() == [0.4, 0.5])) + self.assertRaises(ValueError, mcmc.set_epsilon, 3.0) self.assertRaises(ValueError, mcmc.set_epsilon, [-2.0, 1])