diff --git a/CHANGELOG.md b/CHANGELOG.md index c76ed4f..813bfe7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,7 +2,26 @@ All notable changes to this project will be documented in this file. -## [0.16.0] - 2025-10-08 +## [0.16.1] - 2025-10-10 + +### Bug Fixes + +- Sampler.abort was broken in last release (Adrian Seyboldt) + +- Don't store unconstrained_draw if not requested (Adrian Seyboldt) + + +### Features + +- Expose step_size_jitter option (Adrian Seyboldt) + + +### Miscellaneous Tasks + +- Cargo update (Adrian Seyboldt) + + +## [0.16.0] - 2025-10-09 ### Bug Fixes diff --git a/Cargo.lock b/Cargo.lock index 11948d3..c71b5c1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -518,9 +518,9 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.2.40" +version = "1.2.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1d05d92f4b1fd76aad469d46cdd858ca761576082cd37df81416691e50199fb" +checksum = "ac9fe6cdbb24b6ade63616c0a0688e45bb56732262c158df3c0c4bea4ca47cb7" dependencies = [ "find-msvc-tools", "jobserver", @@ -1051,9 +1051,9 @@ dependencies = [ [[package]] name = "find-msvc-tools" -version = "0.1.3" +version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0399f9d26e5191ce32c498bebd31e7a3ceabc2745f0ac54af3f335126c3f24b3" +checksum = "52051878f80a721bb68ebfbc930e07b65ba72f2da88968ea5c06fd6ca3d3a127" [[package]] name = "flatbuffers" @@ -1358,14 +1358,15 @@ dependencies = [ [[package]] name = "half" -version = "2.6.0" +version = "2.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "459196ed295495a68f7d7fe1d84f6c4b7ff0e21fe3017b2f283c6fac3ad803c9" +checksum = "e54c115d4f30f52c67202f079c5f9d8b49db4691f460fdb0b4c2e838261b2ba5" dependencies = [ "bytemuck", "cfg-if", "crunchy", "num-traits", + "zerocopy", ] [[package]] @@ -1823,9 +1824,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.176" +version = "0.2.177" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "58f929b4d672ea937a23a1ab494143d968337a5f47e56d0815df1e0890ddf174" +checksum = "2874a2af47a2325c2001a6e6fad9b16a53b802102b528163885171cf92b15976" [[package]] name = "libloading" @@ -2210,7 +2211,7 @@ dependencies = [ [[package]] name = "nutpie" -version = "0.16.0" +version = "0.16.1" dependencies = [ "anyhow", "arrow", @@ -3321,9 +3322,9 @@ dependencies = [ [[package]] name = "stable_deref_trait" -version = "1.2.0" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" +checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596" [[package]] name = "strum" diff --git a/Cargo.toml b/Cargo.toml index b8410b4..a50a79d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "nutpie" -version = "0.16.0" +version = "0.16.1" authors = [ "Adrian Seyboldt ", "PyMC Developers ", diff --git a/python/nutpie/sample.py b/python/nutpie/sample.py index edc80f0..64d35f9 100644 --- a/python/nutpie/sample.py +++ b/python/nutpie/sample.py @@ -538,7 +538,7 @@ def _extract(self, results): skip_vars = [] skips = { "store_gradient": ["gradient"], - "store_unconstrained": ["unconstrained"], + "store_unconstrained": ["unconstrained_draw"], "store_mass_matrix": [ "mass_matrix_inv", "mass_matrix_eigvals", @@ -590,7 +590,7 @@ def is_finished(self): def abort(self): """Abort sampling and return the trace produced so far.""" self._sampler.abort() - results = self._sampler.extract_results() + results = self._sampler.take_results() return self._extract(results) def cancel(self): diff --git a/src/wrapper.rs b/src/wrapper.rs index 37322d6..ecab863 100644 --- a/src/wrapper.rs +++ b/src/wrapper.rs @@ -784,6 +784,35 @@ impl PyNutsSettings { }; Ok(()) } + + #[getter(step_size_jitter)] + fn step_size_jitter(&self) -> Option { + match &self.inner { + Settings::LowRank(inner) => inner.adapt_options.step_size_settings.jitter, + Settings::Diag(inner) => inner.adapt_options.step_size_settings.jitter, + Settings::Transforming(inner) => inner.adapt_options.step_size_settings.jitter, + } + } + + #[setter(step_size_jitter)] + fn set_step_size_jitter(&mut self, mut val: Option) -> PyResult<()> { + if let Some(val) = val { + if val < 0.0 { + return Err(PyValueError::new_err("step_size_jitter must be positive")); + } + } + if let Some(jitter) = val { + if jitter == 0.0 { + val = None; + } + } + match &mut self.inner { + Settings::LowRank(inner) => inner.adapt_options.step_size_settings.jitter = val, + Settings::Diag(inner) => inner.adapt_options.step_size_settings.jitter = val, + Settings::Transforming(inner) => inner.adapt_options.step_size_settings.jitter = val, + } + Ok(()) + } } pub(crate) enum SamplerState { diff --git a/tests/test_pymc.py b/tests/test_pymc.py index 8d697b0..f5bd567 100644 --- a/tests/test_pymc.py +++ b/tests/test_pymc.py @@ -192,7 +192,7 @@ def test_wait_timeout(backend, gradient_backend): @pytest.mark.timeout(20) def test_pause(backend, gradient_backend): with pm.Model() as model: - pm.Normal("a", shape=100_000) + pm.Normal("a", shape=10_000) compiled = nutpie.compile_pymc_model( model, backend=backend, gradient_backend=gradient_backend ) @@ -204,6 +204,23 @@ def test_pause(backend, gradient_backend): assert start - time.time() < 5 +@pytest.mark.pymc +@parameterize_backends +@pytest.mark.timeout(20) +def test_abort(backend, gradient_backend): + with pm.Model() as model: + pm.Normal("a", shape=10_000) + compiled = nutpie.compile_pymc_model( + model, backend=backend, gradient_backend=gradient_backend + ) + start = time.time() + sampler = nutpie.sample(compiled, chains=1, blocking=False) + sampler.pause() + sampler.resume() + sampler.abort() + assert start - time.time() < 5 + + @pytest.mark.pymc @parameterize_backends def test_pymc_model_with_coordinate(backend, gradient_backend): @@ -421,7 +438,6 @@ def test_missing(backend, gradient_backend): model, backend=backend, gradient_backend=gradient_backend ) tr = nutpie.sample(compiled, chains=1, seed=1) - print(tr.posterior) assert hasattr(tr.posterior, "y_unobserved")