diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index bbf2c71..3cd2d14 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -34,7 +34,7 @@ jobs: 3.10 3.11 3.12 - # Leave out 3.13 on aarch due to an issue in pyo3/rust-numpy 0.23.4 + 3.13 - name: Build wheels uses: PyO3/maturin-action@v1 if: ${{ matrix.platform.target == 'aarch64' }} @@ -50,8 +50,7 @@ jobs: if: ${{ matrix.platform.target == 'x86_64' }} with: target: ${{ matrix.platform.target }} - # No py3.13 yet... - args: --release --out dist --interpreter 3.10 3.11 3.12 --zig + args: --release --out dist --interpreter 3.10 3.11 3.12 3.13 --zig sccache: ${{ !startsWith(github.ref, 'refs/tags/') }} manylinux: auto before-script-linux: | @@ -73,7 +72,7 @@ jobs: pytest - name: pytest if: ${{ !startsWith(matrix.platform.target, 'x86') && matrix.platform.target != 'ppc64' }} - uses: uraimo/run-on-arch-action@v3 + uses: uraimo/run-on-arch-action@v2 with: arch: ${{ matrix.platform.target }} distro: ubuntu22.04 @@ -88,7 +87,7 @@ jobs: source $HOME/.local/bin/env uv pip install --system -U pip pytest uv pip install --system 'nutpie[all]' --find-links dist --force-reinstall - pytest + pytest -m "not slow" # pyarrow doesn't currently seem to work on musllinux #musllinux: @@ -141,7 +140,7 @@ jobs: # pytest # - name: pytest # if: ${{ !startsWith(matrix.platform.target, 'x86') }} - # uses: uraimo/run-on-arch-action@v3 + # uses: uraimo/run-on-arch-action@v2 # with: # arch: ${{ matrix.platform.target }} # distro: alpine_latest @@ -175,7 +174,7 @@ jobs: 3.10 3.11 3.12 - # 3.13 leave out 3.13 due to a segfault + 3.13 architecture: ${{ matrix.platform.target }} - name: Install uv uses: astral-sh/setup-uv@v5 @@ -230,6 +229,7 @@ jobs: 3.10 3.11 3.12 + 3.13 - name: Install uv uses: astral-sh/setup-uv@v5 - uses: maxim-lobanov/setup-xcode@v1 diff --git a/.gitignore b/.gitignore index f0721f7..bffa4f4 100644 --- a/.gitignore +++ b/.gitignore @@ -8,7 +8,20 @@ tvm_libs/* notebooks/*.stan notebooks/*.csv notebooks/*.hpp +notebooks/radon* perf.data* wheels .vscode/ *~ +.zed +.cargo +*traces* +.pyrightconfig.json +*.zarr +book +docs/_site +.quarto +example-iree +posteriordb +.quarto +docs/.quarto diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a3220e8..f095381 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,13 +7,14 @@ repos: hooks: - id: debug-statements - id: check-merge-conflict - - id: check-toml - - id: check-yaml - - id: debug-statements - - id: end-of-file-fixer - - id: no-commit-to-branch - args: [--branch, main] - - id: trailing-whitespace + - id: check-toml + - id: check-yaml + - id: debug-statements + - id: end-of-file-fixer + exclude: "docs/_freeze" + - id: no-commit-to-branch + args: [--branch, main] + - id: trailing-whitespace - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.9.9 diff --git a/CHANGELOG.md b/CHANGELOG.md index bdf5a0e..1f49519 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,39 @@ All notable changes to this project will be documented in this file. +## [0.14.0] - 2025-03-05 + +### Bug Fixes + +- Set 'make_initial_point_fn' in 'from_pyfunc' to None by default (#175) (Tomás Capretto) + + +### Documentation + +- Add nutpie website source (Adrian Seyboldt) + +- Include frozen cell output in docs (Adrian Seyboldt) + + +### Features + +- Add normalizing flow adaptation (Adrian Seyboldt) + + +### Miscellaneous Tasks + +- Bump actions/attest-build-provenance from 1 to 2 (dependabot[bot]) + +- Bump softprops/action-gh-release from 1 to 2 (dependabot[bot]) + +- Bump uraimo/run-on-arch-action from 2 to 3 (dependabot[bot]) + + +### Ci + +- Run python 3.13 in ci (Adrian Seyboldt) + + ## [0.13.4] - 2025-02-18 ### Bug Fixes @@ -13,6 +46,8 @@ All notable changes to this project will be documented in this file. - Make sure all python versions are available in the builds (Adrian Seyboldt) +- Skip python 3.13 for now (Adrian Seyboldt) + ## [0.13.3] - 2025-02-12 @@ -48,8 +83,6 @@ All notable changes to this project will be documented in this file. - Update pre-commit versions (Adrian Seyboldt) -- Update version and changelog (Adrian Seyboldt) - ### Styling @@ -202,6 +235,8 @@ All notable changes to this project will be documented in this file. ### Ci +- Fix uploads of releases (Adrian Seyboldt) + - Fix architectures in CI (Adrian Seyboldt) @@ -249,11 +284,6 @@ All notable changes to this project will be documented in this file. - Set the number of parallel chains dynamically (Adrian Seyboldt) -### Ci - -- Fix uploads of releases (Adrian Seyboldt) - - ## [0.9.2] - 2024-02-19 ### Bug Fixes diff --git a/Cargo.lock b/Cargo.lock index 4fe15a5..d0ee235 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,23 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "adler2" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" + +[[package]] +name = "aes" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0" +dependencies = [ + "cfg-if", + "cipher", + "cpufeatures", +] + [[package]] name = "ahash" version = "0.8.11" @@ -10,10 +27,10 @@ checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" dependencies = [ "cfg-if", "const-random", - "getrandom", + "getrandom 0.2.15", "once_cell", "version_check", - "zerocopy", + "zerocopy 0.7.35", ] [[package]] @@ -54,15 +71,15 @@ checksum = "55cc3b69f167a1ef2e161439aa98aed94e6028e5f9a59be9a6ffb47aef1651f9" [[package]] name = "anyhow" -version = "1.0.95" +version = "1.0.97" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34ac096ce696dc2fcabef30516bb13c0a68a11d30131d3df6f04711467681b04" +checksum = "dcfed56ad506cb2c684a14971b8861fdc3baaaae314b9e5f9bb532cbe3ba7a4f" [[package]] name = "arrow" -version = "52.2.0" +version = "54.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05048a8932648b63f21c37d88b552ccc8a65afb6dfe9fc9f30ce79174c2e7a85" +checksum = "dc208515aa0151028e464cc94a692156e945ce5126abd3537bb7fd6ba2143ed1" dependencies = [ "arrow-arith", "arrow-array", @@ -78,24 +95,23 @@ dependencies = [ [[package]] name = "arrow-arith" -version = "52.2.0" +version = "54.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d8a57966e43bfe9a3277984a14c24ec617ad874e4c0e1d2a1b083a39cfbf22c" +checksum = "e07e726e2b3f7816a85c6a45b6ec118eeeabf0b2a8c208122ad949437181f49a" dependencies = [ "arrow-array", "arrow-buffer", "arrow-data", "arrow-schema", "chrono", - "half", "num", ] [[package]] name = "arrow-array" -version = "52.2.0" +version = "54.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16f4a9468c882dc66862cef4e1fd8423d47e67972377d85d80e022786427768c" +checksum = "a2262eba4f16c78496adfd559a29fe4b24df6088efc9985a873d58e92be022d5" dependencies = [ "ahash", "arrow-buffer", @@ -109,9 +125,9 @@ dependencies = [ [[package]] name = "arrow-buffer" -version = "52.2.0" +version = "54.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c975484888fc95ec4a632cdc98be39c085b1bb518531b0c80c5d462063e5daa1" +checksum = "4e899dade2c3b7f5642eb8366cfd898958bcca099cde6dfea543c7e8d3ad88d4" dependencies = [ "bytes", "half", @@ -120,9 +136,9 @@ dependencies = [ [[package]] name = "arrow-cast" -version = "52.2.0" +version = "54.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da26719e76b81d8bc3faad1d4dbdc1bcc10d14704e63dc17fc9f3e7e1e567c8e" +checksum = "4103d88c5b441525ed4ac23153be7458494c2b0c9a11115848fdb9b81f6f886a" dependencies = [ "arrow-array", "arrow-buffer", @@ -140,9 +156,9 @@ dependencies = [ [[package]] name = "arrow-data" -version = "52.2.0" +version = "54.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd9d6f18c65ef7a2573ab498c374d8ae364b4a4edf67105357491c031f716ca5" +checksum = "0a329fb064477c9ec5f0870d2f5130966f91055c7c5bce2b3a084f116bc28c3b" dependencies = [ "arrow-buffer", "arrow-schema", @@ -152,26 +168,23 @@ dependencies = [ [[package]] name = "arrow-ord" -version = "52.2.0" +version = "54.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42745f86b1ab99ef96d1c0bcf49180848a64fe2c7a7a0d945bc64fa2b21ba9bc" +checksum = "f841bfcc1997ef6ac48ee0305c4dfceb1f7c786fe31e67c1186edf775e1f1160" dependencies = [ "arrow-array", "arrow-buffer", "arrow-data", "arrow-schema", "arrow-select", - "half", - "num", ] [[package]] name = "arrow-row" -version = "52.2.0" +version = "54.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4cd09a518c602a55bd406bcc291a967b284cfa7a63edfbf8b897ea4748aad23c" +checksum = "1eeb55b0a0a83851aa01f2ca5ee5648f607e8506ba6802577afdda9d75cdedcd" dependencies = [ - "ahash", "arrow-array", "arrow-buffer", "arrow-data", @@ -181,18 +194,18 @@ dependencies = [ [[package]] name = "arrow-schema" -version = "52.2.0" +version = "54.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e972cd1ff4a4ccd22f86d3e53e835c2ed92e0eea6a3e8eadb72b4f1ac802cf8" +checksum = "85934a9d0261e0fa5d4e2a5295107d743b543a6e0484a835d4b8db2da15306f9" dependencies = [ "bitflags", ] [[package]] name = "arrow-select" -version = "52.2.0" +version = "54.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "600bae05d43483d216fb3494f8c32fdbefd8aa4e1de237e790dbb3d9f44690a3" +checksum = "7e2932aece2d0c869dd2125feb9bd1709ef5c445daa3838ac4112dcfa0fda52c" dependencies = [ "ahash", "arrow-array", @@ -204,9 +217,9 @@ dependencies = [ [[package]] name = "arrow-string" -version = "52.2.0" +version = "54.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0dc1985b67cb45f6606a248ac2b4a288849f196bab8c657ea5589f47cdd55e6" +checksum = "912e38bd6a7a7714c1d9b61df80315685553b7455e8a6045c27531d8ecd5b458" dependencies = [ "arrow-array", "arrow-buffer", @@ -240,6 +253,12 @@ version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" +[[package]] +name = "base64ct" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" + [[package]] name = "bindgen" version = "0.71.1" @@ -255,16 +274,16 @@ dependencies = [ "proc-macro2", "quote", "regex", - "rustc-hash 2.1.1", + "rustc-hash", "shlex", - "syn 2.0.98", + "syn", ] [[package]] name = "bitflags" -version = "2.8.0" +version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f68f53c83ab957f72c32642f3868eec03eb974d1fb82e453128456482613d36" +checksum = "5c8214115b7bf84099f1309324e63141d4c5d7cc26862f97a0a857dbefe165bd" [[package]] name = "block-buffer" @@ -285,7 +304,7 @@ dependencies = [ "libloading", "log", "path-absolutize", - "thiserror 2.0.11", + "thiserror 2.0.12", ] [[package]] @@ -296,9 +315,9 @@ checksum = "1628fb46dfa0b37568d12e5edd512553eccf6a22a78e8bde00bb4aed84d5bdbf" [[package]] name = "bytemuck" -version = "1.21.0" +version = "1.22.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef657dfab802224e671f5818e9a4935f9b1957ed18e58292690cc39e7a4092a3" +checksum = "b6b1fc10dbac614ebc03540c9dbd60e83887fda27794998c6528f1782047d540" dependencies = [ "bytemuck_derive", ] @@ -311,7 +330,7 @@ checksum = "3fa76293b4f7bb636ab88fd78228235b5248b4d05cc589aed610f954af5d7c7a" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn", ] [[package]] @@ -322,9 +341,29 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.10.0" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" + +[[package]] +name = "bzip2" +version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f61dac84819c6588b558454b194026eb1f09c293b9036ae9b159e74e73ab6cf9" +checksum = "bdb116a6ef3f6c3698828873ad02c3014b3c85cadb88496095628e3ef1e347f8" +dependencies = [ + "bzip2-sys", + "libc", +] + +[[package]] +name = "bzip2-sys" +version = "0.1.13+1.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "225bff33b2141874fe80d71e07d6eec4f85c5c216453dd96388240f96e1acc14" +dependencies = [ + "cc", + "pkg-config", +] [[package]] name = "cast" @@ -334,10 +373,12 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.2.14" +version = "1.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c3d1b2e905a3a7b00a6141adb0e4c0bb941d11caf55349d863942a1cc44e3c9" +checksum = "be714c154be609ec7f5dad223a33bf1482fff90472de28f7362806e6d4832b8c" dependencies = [ + "jobserver", + "libc", "shlex", ] @@ -395,6 +436,16 @@ dependencies = [ "half", ] +[[package]] +name = "cipher" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad" +dependencies = [ + "crypto-common", + "inout", +] + [[package]] name = "clang-sys" version = "1.8.1" @@ -408,18 +459,18 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.30" +version = "4.5.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92b7b18d71fad5313a1e320fa9897994228ce274b60faa4d694fe0ea89cd9e6d" +checksum = "027bb0d98429ae334a8698531da7077bdf906419543a35a55c2cb1b66437d767" dependencies = [ "clap_builder", ] [[package]] name = "clap_builder" -version = "4.5.30" +version = "4.5.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a35db2071778a7344791a4fb4f95308b5673d219dee3ae348b86642574ecc90c" +checksum = "5589e0cba072e0f3d23791efac0fd8627b49c829c196a492e88168e6a669d863" dependencies = [ "anstyle", "clap_lex", @@ -431,17 +482,11 @@ version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6" -[[package]] -name = "coe-rs" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e8f1e641542c07631228b1e0dc04b69ae3c1d58ef65d5691a439711d805c698" - [[package]] name = "console" -version = "0.15.10" +version = "0.15.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea3c6ecd8059b57859df5c69830340ed3c41d30e3da0c1cbed90a96ac853041b" +checksum = "054ccb5b10f9f2cbf51eb355ca1d05c2d279ce1804688d0db74b4733a5aeafd8" dependencies = [ "encode_unicode", "libc", @@ -465,11 +510,17 @@ version = "0.1.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f9d839f2a20b0aee515dc581a6172f2321f96cab76c1a38a4c584a194955390e" dependencies = [ - "getrandom", + "getrandom 0.2.15", "once_cell", "tiny-keccak", ] +[[package]] +name = "constant_time_eq" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "245097e9a4535ee1e3e3931fcfcd55a796a44c643e8596ff6566d68f09b87bbc" + [[package]] name = "core-foundation-sys" version = "0.8.7" @@ -485,6 +536,15 @@ dependencies = [ "libc", ] +[[package]] +name = "crc32fast" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a97769d94ddab943e4510d138150169a2758b5ef3eb191a9ee688de3e23ef7b3" +dependencies = [ + "cfg-if", +] + [[package]] name = "criterion" version = "0.5.1" @@ -563,10 +623,13 @@ dependencies = [ ] [[package]] -name = "dbgf" -version = "0.1.2" +name = "deranged" +version = "0.3.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6ca96b45ca70b8045e0462f191bd209fcb3c3bfe8dbfb1257ada54c4dd59169" +checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4" +dependencies = [ + "powerfmt", +] [[package]] name = "digest" @@ -576,16 +639,7 @@ checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ "block-buffer", "crypto-common", -] - -[[package]] -name = "dyn-stack" -version = "0.10.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56e53799688f5632f364f8fb387488dd05db9fe45db7011be066fc20e7027f8b" -dependencies = [ - "bytemuck", - "reborrow", + "subtle", ] [[package]] @@ -599,9 +653,9 @@ dependencies = [ [[package]] name = "either" -version = "1.13.0" +version = "1.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" +checksum = "b7914353092ddf589ad78f25c5c1c21b7f80b0ff8621e7c814c3485b5306da9d" [[package]] name = "encode_unicode" @@ -615,10 +669,10 @@ version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a1e6a265c649f3f5979b601d26f1d05ada116434c87741c9493cb56218f76cbc" dependencies = [ - "heck 0.5.0", + "heck", "proc-macro2", "quote", - "syn 2.0.98", + "syn", ] [[package]] @@ -647,7 +701,7 @@ checksum = "3bf679796c0322556351f287a51b49e48f7c4986e727b5dd78c972d30e2e16cc" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn", ] [[package]] @@ -658,59 +712,76 @@ checksum = "44f23cf4b44bfce11a86ace86f8a73ffdec849c9fd00a386a53d278bd9e81fb3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn", ] [[package]] name = "faer" -version = "0.19.4" +version = "0.21.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64bc4855cb2792ae3520e8af22051a47a6d6dc8300ebc0ddf51ad73f65bd0dc9" +checksum = "d671941ab57443f46ebe3f153a9fc3ed6cce777926c14e5fdf5da178a35ea476" dependencies = [ "bytemuck", - "coe-rs", - "dbgf", - "dyn-stack 0.10.0", + "dyn-stack", "equator 0.4.2", - "faer-entity", + "faer-macros", + "faer-traits", "gemm", + "generativity", "libm", - "matrixcompare", - "matrixcompare-core", "nano-gemm", "npyz", "num-complex", "num-traits", - "paste", - "rand", - "rand_distr", - "rayon", + "pulp", "reborrow", - "serde", ] [[package]] -name = "faer-entity" -version = "0.19.2" +name = "faer-macros" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d0a255d1442b5825c61812a7eafda9034ec53d969c98555251085e148428e6a" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "faer-traits" +version = "0.21.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c9c752ab2bff6f0b9597c6a1adc0112f7fd41fb343bc5a009a6274ae9d32fd03" +checksum = "a2d0172aefb5f869561e558d5390657f1aa98ca3c51a09be69a4687064ebfb9a" dependencies = [ "bytemuck", - "coe-rs", + "dyn-stack", + "faer-macros", + "generativity", "libm", "num-complex", "num-traits", - "pulp 0.18.22", + "pulp", "reborrow", ] +[[package]] +name = "flate2" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11faaf5a5236997af9848be0bef4db95824b1d534ebc64d0f0c6cf3e67bd38dc" +dependencies = [ + "crc32fast", + "miniz_oxide", +] + [[package]] name = "gemm" version = "0.18.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ab96b703d31950f1aeddded248bc95543c9efc7ac9c4a21fda8703a83ee35451" dependencies = [ - "dyn-stack 0.13.0", + "dyn-stack", "gemm-c32", "gemm-c64", "gemm-common", @@ -730,7 +801,7 @@ version = "0.18.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f6db9fd9f40421d00eea9dd0770045a5603b8d684654816637732463f4073847" dependencies = [ - "dyn-stack 0.13.0", + "dyn-stack", "gemm-common", "num-complex", "num-traits", @@ -745,7 +816,7 @@ version = "0.18.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dfcad8a3d35a43758330b635d02edad980c1e143dc2f21e6fd25f9e4eada8edf" dependencies = [ - "dyn-stack 0.13.0", + "dyn-stack", "gemm-common", "num-complex", "num-traits", @@ -761,16 +832,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a352d4a69cbe938b9e2a9cb7a3a63b7e72f9349174a2752a558a8a563510d0f3" dependencies = [ "bytemuck", - "dyn-stack 0.13.0", + "dyn-stack", "half", "libm", "num-complex", "num-traits", "once_cell", "paste", - "pulp 0.21.4", + "pulp", "raw-cpuid", - "rayon", "seq-macro", "sysctl", ] @@ -781,7 +851,7 @@ version = "0.18.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cff95ae3259432f3c3410eaa919033cd03791d81cebd18018393dc147952e109" dependencies = [ - "dyn-stack 0.13.0", + "dyn-stack", "gemm-common", "gemm-f32", "half", @@ -789,7 +859,6 @@ dependencies = [ "num-traits", "paste", "raw-cpuid", - "rayon", "seq-macro", ] @@ -799,7 +868,7 @@ version = "0.18.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bc8d3d4385393304f407392f754cd2dc4b315d05063f62cf09f47b58de276864" dependencies = [ - "dyn-stack 0.13.0", + "dyn-stack", "gemm-common", "num-complex", "num-traits", @@ -814,7 +883,7 @@ version = "0.18.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "35b2a4f76ce4b8b16eadc11ccf2e083252d8237c1b589558a49b0183545015bd" dependencies = [ - "dyn-stack 0.13.0", + "dyn-stack", "gemm-common", "num-complex", "num-traits", @@ -823,6 +892,12 @@ dependencies = [ "seq-macro", ] +[[package]] +name = "generativity" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5881e4c3c2433fe4905bb19cfd2b5d49d4248274862b68c27c33d9ba4e13f9ec" + [[package]] name = "generic-array" version = "0.14.7" @@ -841,7 +916,19 @@ checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" dependencies = [ "cfg-if", "libc", - "wasi", + "wasi 0.11.0+wasi-snapshot-preview1", +] + +[[package]] +name = "getrandom" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43a49c392881ce6d5c3b8cb70f98717b7c07aabbdff06687b9030dbfbe2725f8" +dependencies = [ + "cfg-if", + "libc", + "wasi 0.13.3+wasi-0.2.2", + "windows-targets", ] [[package]] @@ -864,15 +951,9 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.14.5" +version = "0.15.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" - -[[package]] -name = "heck" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" +checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" [[package]] name = "heck" @@ -886,6 +967,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fbf6a919d6cf397374f7dfeeea91d974c7c0a7221d0d0f4f20d859d329e53fcc" +[[package]] +name = "hmac" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" +dependencies = [ + "digest", +] + [[package]] name = "iana-time-zone" version = "0.1.61" @@ -924,9 +1014,18 @@ dependencies = [ [[package]] name = "indoc" -version = "2.0.5" +version = "2.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5" +checksum = "f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd" + +[[package]] +name = "inout" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "879f10e63c20629ecabbb64a8010319738c66a5cd0c29b02d63d272b03751d01" +dependencies = [ + "generic-array", +] [[package]] name = "is-terminal" @@ -968,9 +1067,18 @@ dependencies = [ [[package]] name = "itoa" -version = "1.0.14" +version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674" +checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" + +[[package]] +name = "jobserver" +version = "0.1.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48d1dbcbbeb6a7fec7e059840aa538bd62aaccf972c7346c4d9d2059312853d0" +dependencies = [ + "libc", +] [[package]] name = "js-sys" @@ -982,11 +1090,17 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "lazy_static" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" + [[package]] name = "lexical-core" -version = "0.8.5" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2cde5de06e8d4c2faabc400238f9ae1c74d5412d03a7bd067645ccbc47070e46" +checksum = "b765c31809609075565a70b4b71402281283aeda7ecaf4818ac14a7b2ade8958" dependencies = [ "lexical-parse-float", "lexical-parse-integer", @@ -997,9 +1111,9 @@ dependencies = [ [[package]] name = "lexical-parse-float" -version = "0.8.5" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "683b3a5ebd0130b8fb52ba0bdc718cc56815b6a097e28ae5a6997d0ad17dc05f" +checksum = "de6f9cb01fb0b08060209a057c048fcbab8717b4c1ecd2eac66ebfe39a65b0f2" dependencies = [ "lexical-parse-integer", "lexical-util", @@ -1008,9 +1122,9 @@ dependencies = [ [[package]] name = "lexical-parse-integer" -version = "0.8.6" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d0994485ed0c312f6d965766754ea177d07f9c00c9b82a5ee62ed5b47945ee9" +checksum = "72207aae22fc0a121ba7b6d479e42cbfea549af1479c3f3a4f12c70dd66df12e" dependencies = [ "lexical-util", "static_assertions", @@ -1018,18 +1132,18 @@ dependencies = [ [[package]] name = "lexical-util" -version = "0.8.5" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5255b9ff16ff898710eb9eb63cb39248ea8a5bb036bea8085b1a767ff6c4e3fc" +checksum = "5a82e24bf537fd24c177ffbbdc6ebcc8d54732c35b50a3f28cc3f4e4c949a0b3" dependencies = [ "static_assertions", ] [[package]] name = "lexical-write-float" -version = "0.8.5" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "accabaa1c4581f05a3923d1b4cfd124c329352288b7b9da09e766b0668116862" +checksum = "c5afc668a27f460fb45a81a757b6bf2f43c2d7e30cb5a2dcd3abf294c78d62bd" dependencies = [ "lexical-util", "lexical-write-integer", @@ -1038,9 +1152,9 @@ dependencies = [ [[package]] name = "lexical-write-integer" -version = "0.8.5" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1b6f3d1f4422866b68192d62f77bc5c700bee84f3069f2469d7bc8c77852446" +checksum = "629ddff1a914a836fb245616a7888b62903aae58fa771e1d83943035efa0f978" dependencies = [ "lexical-util", "static_assertions", @@ -1048,9 +1162,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.169" +version = "0.2.170" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a" +checksum = "875b3680cb2f8f71bdcf9a30f38d48282f5d3c95cbf9b3fa57269bb5d5c06828" [[package]] name = "libloading" @@ -1068,37 +1182,11 @@ version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8355be11b20d696c8f18f6cc018c4e372165b1fa8126cef092399c9951984ffa" -[[package]] -name = "lock_api" -version = "0.4.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" -dependencies = [ - "autocfg", - "scopeguard", -] - [[package]] name = "log" -version = "0.4.25" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15e29fff3f6c077fd9cd9f" - -[[package]] -name = "matrixcompare" -version = "0.3.0" +version = "0.4.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37832ba820e47c93d66b4360198dccb004b43c74abc3ac1ce1fed54e65a80445" -dependencies = [ - "matrixcompare-core", - "num-traits", -] - -[[package]] -name = "matrixcompare-core" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b0bdabb30db18805d5290b3da7ceaccbddba795620b86c02145d688e04900a73" +checksum = "30bde2b3dc3671ae49d8e2e9f044c7c005836e7a023ee57cffa25ab82764bb9e" [[package]] name = "matrixmultiply" @@ -1131,11 +1219,20 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" +[[package]] +name = "miniz_oxide" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e3e04debbb59698c15bacbb6d93584a8c0ca9cc3213cb423d31f760d8843ce5" +dependencies = [ + "adler2", +] + [[package]] name = "multiversion" -version = "0.7.4" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4851161a11d3ad0bf9402d90ffc3967bf231768bfd7aeb61755ad06dbf1a142" +checksum = "7edb7f0ff51249dfda9ab96b5823695e15a052dc15074c9dbf3d118afaf2c201" dependencies = [ "multiversion-macros", "target-features", @@ -1143,13 +1240,13 @@ dependencies = [ [[package]] name = "multiversion-macros" -version = "0.7.4" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79a74ddee9e0c27d2578323c13905793e91622148f138ba29738f9dddb835e90" +checksum = "b093064383341eb3271f42e381cb8f10a01459478446953953c75d24bd339fc0" dependencies = [ "proc-macro2", "quote", - "syn 1.0.109", + "syn", "target-features", ] @@ -1225,14 +1322,16 @@ dependencies = [ [[package]] name = "ndarray" -version = "0.15.6" +version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32" +checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841" dependencies = [ "matrixmultiply", "num-complex", "num-integer", "num-traits", + "portable-atomic", + "portable-atomic-util", "rawpointer", ] @@ -1289,9 +1388,14 @@ checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" dependencies = [ "bytemuck", "num-traits", - "rand", ] +[[package]] +name = "num-conv" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" + [[package]] name = "num-integer" version = "0.1.46" @@ -1341,9 +1445,9 @@ checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" [[package]] name = "numpy" -version = "0.21.0" +version = "0.23.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec170733ca37175f5d75a5bea5911d6ff45d2cd52849ce98b685394e4f2f37f4" +checksum = "b94caae805f998a07d33af06e6a3891e38556051b8045c615470a71590e13e78" dependencies = [ "libc", "ndarray", @@ -1351,12 +1455,12 @@ dependencies = [ "num-integer", "num-traits", "pyo3", - "rustc-hash 1.1.0", + "rustc-hash", ] [[package]] name = "nutpie" -version = "0.13.4" +version = "0.14.0" dependencies = [ "anyhow", "arrow", @@ -1367,33 +1471,34 @@ dependencies = [ "numpy", "nuts-rs", "pyo3", - "rand", - "rand_chacha", + "rand 0.9.0", + "rand_chacha 0.9.0", "rand_distr", "rayon", "smallvec", - "thiserror 1.0.69", + "tch", + "thiserror 2.0.12", "time-humanize", "upon", ] [[package]] name = "nuts-rs" -version = "0.12.1" +version = "0.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8573e3b5c83e8ec0570ebbd75dd6fdc7dfcfa5da9b5f9d9d63fedefebbd9cf8" +checksum = "10e87924d332fce1202087bc67db7ed8f7ef9229da5ec74a5130568f5b7f6ac7" dependencies = [ "anyhow", "arrow", "faer", - "itertools 0.13.0", + "itertools 0.14.0", "multiversion", - "pulp 0.18.22", - "rand", - "rand_chacha", + "pulp", + "rand 0.9.0", + "rand_chacha 0.9.0", "rand_distr", "rayon", - "thiserror 1.0.69", + "thiserror 2.0.12", ] [[package]] @@ -1409,26 +1514,14 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9" [[package]] -name = "parking_lot" -version = "0.12.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" -dependencies = [ - "lock_api", - "parking_lot_core", -] - -[[package]] -name = "parking_lot_core" -version = "0.9.10" +name = "password-hash" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" +checksum = "7676374caaee8a325c9e7a2ae557f216c5563a171d6997b0ef8a65af35147700" dependencies = [ - "cfg-if", - "libc", - "redox_syscall", - "smallvec", - "windows-targets", + "base64ct", + "rand_core 0.6.4", + "subtle", ] [[package]] @@ -1455,6 +1548,18 @@ dependencies = [ "once_cell", ] +[[package]] +name = "pbkdf2" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83a0692ec44e4cf1ef28ca317f14f8f07da2d95ec3fa01f86e4467b725e60917" +dependencies = [ + "digest", + "hmac", + "password-hash", + "sha2", +] + [[package]] name = "pest" version = "2.7.15" @@ -1462,7 +1567,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b7cafe60d6cf8e62e1b9b2ea516a089c008945bb5a275416789e7db0bc199dc" dependencies = [ "memchr", - "thiserror 2.0.11", + "thiserror 2.0.12", "ucd-trie", ] @@ -1486,7 +1591,7 @@ dependencies = [ "pest_meta", "proc-macro2", "quote", - "syn 2.0.98", + "syn", ] [[package]] @@ -1500,6 +1605,12 @@ dependencies = [ "sha2", ] +[[package]] +name = "pkg-config" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" + [[package]] name = "plotters" version = "0.3.7" @@ -1530,9 +1641,24 @@ dependencies = [ [[package]] name = "portable-atomic" -version = "1.10.0" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "350e9b48cbc6b0e028b0473b114454c6316e57336ee184ceab6e53f72c178b3e" + +[[package]] +name = "portable-atomic-util" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8a2f0d8d040d7848a709caf78912debcc3f33ee4b3cac47d73d1e1069e83507" +dependencies = [ + "portable-atomic", +] + +[[package]] +name = "powerfmt" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "280dc24453071f1b63954171985a0b0d30058d287960968b9b2aca264c8d4ee6" +checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" [[package]] name = "ppv-lite86" @@ -1540,40 +1666,28 @@ version = "0.2.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" dependencies = [ - "zerocopy", + "zerocopy 0.7.35", ] [[package]] name = "prettyplease" -version = "0.2.29" +version = "0.2.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6924ced06e1f7dfe3fa48d57b9f74f55d8915f5036121bef647ef4b204895fac" +checksum = "f1ccf34da56fc294e7d4ccf69a85992b7dfb826b7cf57bac6a70bba3494cc08a" dependencies = [ "proc-macro2", - "syn 2.0.98", + "syn", ] [[package]] name = "proc-macro2" -version = "1.0.93" +version = "1.0.94" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60946a68e5f9d28b0dc1c21bb8a97ee7d018a8b322fa57838ba31cc878e22d99" +checksum = "a31971752e70b8b2686d7e46ec17fb38dad4051d94024c88df49b667caea9c84" dependencies = [ "unicode-ident", ] -[[package]] -name = "pulp" -version = "0.18.22" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a0a01a0dc67cf4558d279f0c25b0962bd08fc6dec0137699eae304103e882fe6" -dependencies = [ - "bytemuck", - "libm", - "num-complex", - "reborrow", -] - [[package]] name = "pulp" version = "0.21.4" @@ -1603,16 +1717,16 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.21.2" +version = "0.23.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5e00b96a521718e08e03b1a622f01c8a8deb50719335de3f60b3b3950f069d8" +checksum = "7778bffd85cf38175ac1f545509665d0b9b92a198ca7941f131f85f7a4f9a872" dependencies = [ "anyhow", "cfg-if", "indoc", "libc", "memoffset", - "parking_lot", + "once_cell", "portable-atomic", "pyo3-build-config", "pyo3-ffi", @@ -1622,9 +1736,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.21.2" +version = "0.23.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7883df5835fafdad87c0d888b266c8ec0f4c9ca48a5bed6bbb592e8dedee1b50" +checksum = "94f6cbe86ef3bf18998d9df6e0f3fc1050a8c5efa409bf712e661a4366e010fb" dependencies = [ "once_cell", "target-lexicon", @@ -1632,9 +1746,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.21.2" +version = "0.23.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01be5843dc60b916ab4dad1dca6d20b9b4e6ddc8e15f50c47fe6d85f1fb97403" +checksum = "e9f1b4c431c0bb1c8fb0a338709859eed0d030ff6daa34368d3b152a63dfdd8d" dependencies = [ "libc", "pyo3-build-config", @@ -1642,34 +1756,34 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.21.2" +version = "0.23.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77b34069fc0682e11b31dbd10321cbf94808394c56fd996796ce45217dfac53c" +checksum = "fbc2201328f63c4710f68abdf653c89d8dbc2858b88c5d88b0ff38a75288a9da" dependencies = [ "proc-macro2", "pyo3-macros-backend", "quote", - "syn 2.0.98", + "syn", ] [[package]] name = "pyo3-macros-backend" -version = "0.21.2" +version = "0.23.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08260721f32db5e1a5beae69a55553f56b99bd0e1c3e6e0a5e8851a9d0f5a85c" +checksum = "fca6726ad0f3da9c9de093d6f116a93c1a38e417ed73bf138472cf4064f72028" dependencies = [ - "heck 0.4.1", + "heck", "proc-macro2", "pyo3-build-config", "quote", - "syn 2.0.98", + "syn", ] [[package]] name = "quote" -version = "1.0.38" +version = "1.0.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e4dccaaaf89514f546c693ddc140f729f958c247918a13380cccc6078391acc" +checksum = "c1f1914ce909e1658d9907913b4b91947430c7d9be598b15a1912935b8c04801" dependencies = [ "proc-macro2", ] @@ -1681,8 +1795,19 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" dependencies = [ "libc", - "rand_chacha", - "rand_core", + "rand_chacha 0.3.1", + "rand_core 0.6.4", +] + +[[package]] +name = "rand" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3779b94aeb87e8bd4e834cee3650289ee9e0d5677f976ecdb6d219e5f4f6cd94" +dependencies = [ + "rand_chacha 0.9.0", + "rand_core 0.9.3", + "zerocopy 0.8.21", ] [[package]] @@ -1692,7 +1817,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" dependencies = [ "ppv-lite86", - "rand_core", + "rand_core 0.6.4", +] + +[[package]] +name = "rand_chacha" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" +dependencies = [ + "ppv-lite86", + "rand_core 0.9.3", ] [[package]] @@ -1701,24 +1836,33 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "getrandom", + "getrandom 0.2.15", +] + +[[package]] +name = "rand_core" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38" +dependencies = [ + "getrandom 0.3.1", ] [[package]] name = "rand_distr" -version = "0.4.3" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31" +checksum = "6a8615d50dcf34fa31f7ab52692afec947c4dd0ab803cc87cb3b0b4570ff7463" dependencies = [ "num-traits", - "rand", + "rand 0.9.0", ] [[package]] name = "raw-cpuid" -version = "11.4.0" +version = "11.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "529468c1335c1c03919960dfefdb1b3648858c20d7ec2d0663e728e4a717efbc" +checksum = "c6df7ab838ed27997ba19a4664507e6f82b41fe6e20be42929332156e5e85146" dependencies = [ "bitflags", ] @@ -1755,15 +1899,6 @@ version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "03251193000f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430" -[[package]] -name = "redox_syscall" -version = "0.5.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03a862b389f93e68874fbf580b9de08dd02facb9a788ebadaf4a3fd33cf58834" -dependencies = [ - "bitflags", -] - [[package]] name = "regex" version = "1.11.1" @@ -1793,12 +1928,6 @@ version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" -[[package]] -name = "rustc-hash" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" - [[package]] name = "rustc-hash" version = "2.1.1" @@ -1807,15 +1936,25 @@ checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" [[package]] name = "rustversion" -version = "1.0.19" +version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7c45b9784283f1b2e7fb61b42047c2fd678ef0960d4f6f1eba131594cc369d4" +checksum = "eded382c5f5f786b989652c49544c4877d9f015cc22e145a5ea8ea66c2921cd2" [[package]] name = "ryu" -version = "1.0.19" +version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ea1a2d0a644769cc99faa24c3ad26b379b786fe7c36fd3c546254801650e6dd" +checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" + +[[package]] +name = "safetensors" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d93279b86b3de76f820a8854dd06cbc33cfa57a417b19c47f6a25280112fb1df" +dependencies = [ + "serde", + "serde_json", +] [[package]] name = "same-file" @@ -1826,43 +1965,37 @@ dependencies = [ "winapi-util", ] -[[package]] -name = "scopeguard" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" - [[package]] name = "seq-macro" -version = "0.3.5" +version = "0.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" +checksum = "1bc711410fbe7399f390ca1c3b60ad0f53f80e95c5eb935e52268a0e2cd49acc" [[package]] name = "serde" -version = "1.0.217" +version = "1.0.218" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02fc4265df13d6fa1d00ecff087228cc0a2b5f3c0e87e258d8b94a156e984c70" +checksum = "e8dfc9d19bdbf6d17e22319da49161d5d0108e4188e8b680aef6299eed22df60" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.217" +version = "1.0.218" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a9bf7cf98d04a2b28aead066b7496853d4779c9cc183c440dbac457641e19a0" +checksum = "f09503e191f4e797cb8aac08e9a4a4695c5edf6a2e70e376d961ddd5c969f82b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn", ] [[package]] name = "serde_json" -version = "1.0.138" +version = "1.0.140" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d434192e7da787e94a6ea7e9670b26a036d0ca41e0b7efb2676dd32bae872949" +checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373" dependencies = [ "itoa", "memchr", @@ -1870,6 +2003,17 @@ dependencies = [ "serde", ] +[[package]] +name = "sha1" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sha2" version = "0.10.8" @@ -1900,21 +2044,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" [[package]] -name = "syn" -version = "1.0.109" +name = "subtle" +version = "2.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" -dependencies = [ - "proc-macro2", - "quote", - "unicode-ident", -] +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" -version = "2.0.98" +version = "2.0.99" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "36147f1a48ae0ec2b5b3bc5b537d267457555a10dc06f3dbc8cb11ba3006d3b1" +checksum = "e02e925281e18ffd9d640e234264753c43edc62d64b2d4cf898f1bc5e75f3fc2" dependencies = [ "proc-macro2", "quote", @@ -1947,6 +2086,23 @@ version = "0.12.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" +[[package]] +name = "tch" +version = "0.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa1ed622c8f13b0c42f8b1afa0e5e9ccccd82ecb6c0e904120722ab52fdc5234" +dependencies = [ + "half", + "lazy_static", + "libc", + "ndarray", + "rand 0.8.5", + "safetensors", + "thiserror 1.0.69", + "torch-sys", + "zip", +] + [[package]] name = "thiserror" version = "1.0.69" @@ -1958,11 +2114,11 @@ dependencies = [ [[package]] name = "thiserror" -version = "2.0.11" +version = "2.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d452f284b73e6d76dd36758a0c8684b1d5be31f92b89d07fd5822175732206fc" +checksum = "567b8a2dae586314f7be2a752ec7474332959c6460e02bde30d702a66d488708" dependencies = [ - "thiserror-impl 2.0.11", + "thiserror-impl 2.0.12", ] [[package]] @@ -1973,20 +2129,39 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn", ] [[package]] name = "thiserror-impl" -version = "2.0.11" +version = "2.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26afc1baea8a989337eeb52b6e72a039780ce45c3edfcc9c5b9d112feeb173c2" +checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn", +] + +[[package]] +name = "time" +version = "0.3.38" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb041120f25f8fbe8fd2dbe4671c7c2ed74d83be2e7a77529bf7e0790ae3f472" +dependencies = [ + "deranged", + "num-conv", + "powerfmt", + "serde", + "time-core", ] +[[package]] +name = "time-core" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "765c97a5b985b7c11d7bc27fa927dc4fe6af3a6dfb021d28deb60d3bf51e76ef" + [[package]] name = "time-humanize" version = "0.1.3" @@ -2012,6 +2187,18 @@ dependencies = [ "serde_json", ] +[[package]] +name = "torch-sys" +version = "0.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef14f5d239e3d60f4919f536a5dfe1d4f71b27b7abf6fe6875fd3a4b22c2dcd5" +dependencies = [ + "anyhow", + "cc", + "libc", + "zip", +] + [[package]] name = "typenum" version = "1.18.0" @@ -2026,9 +2213,9 @@ checksum = "2896d95c02a80c6d6a5d6e953d479f5ddf2dfdb6a244441010e373ac0fb88971" [[package]] name = "unicode-ident" -version = "1.0.16" +version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a210d160f08b701c8721ba1c726c11662f877ea6b7094007e1ca9a1041945034" +checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512" [[package]] name = "unicode-width" @@ -2038,15 +2225,15 @@ checksum = "1fc81956842c57dac11422a97c3b8195a1ff727f06e85c84ed2e8aa277c9a0fd" [[package]] name = "unindent" -version = "0.2.3" +version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce" +checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3" [[package]] name = "upon" -version = "0.8.1" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fe29601d1624f104fa9a35ea71a5f523dd8bd1cfc8c31f8124ad2b829f013c0" +checksum = "cc1243af2969e332d5b9b99087eddd44d04a41da8630ed53e06df497b7f5c747" [[package]] name = "version_check" @@ -2070,6 +2257,15 @@ version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +[[package]] +name = "wasi" +version = "0.13.3+wasi-0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26816d2e1a4a36a2940b96c5296ce403917633dff8f3440e9b236ed6f6bacad2" +dependencies = [ + "wit-bindgen-rt", +] + [[package]] name = "wasm-bindgen" version = "0.2.100" @@ -2092,7 +2288,7 @@ dependencies = [ "log", "proc-macro2", "quote", - "syn 2.0.98", + "syn", "wasm-bindgen-shared", ] @@ -2114,7 +2310,7 @@ checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -2239,6 +2435,15 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" +[[package]] +name = "wit-bindgen-rt" +version = "0.33.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3268f3d866458b787f390cf61f4bbb563b922d091359f9608842999eaee3943c" +dependencies = [ + "bitflags", +] + [[package]] name = "zerocopy" version = "0.7.35" @@ -2246,7 +2451,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" dependencies = [ "byteorder", - "zerocopy-derive", + "zerocopy-derive 0.7.35", +] + +[[package]] +name = "zerocopy" +version = "0.8.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dcf01143b2dd5d134f11f545cf9f1431b13b749695cb33bcce051e7568f99478" +dependencies = [ + "zerocopy-derive 0.8.21", ] [[package]] @@ -2257,5 +2471,65 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn", +] + +[[package]] +name = "zerocopy-derive" +version = "0.8.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712c8386f4f4299382c9abee219bee7084f78fb939d88b6840fcc1320d5f6da2" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "zip" +version = "0.6.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "760394e246e4c28189f19d488c058bf16f564016aefac5d32bb1f3b51d5e9261" +dependencies = [ + "aes", + "byteorder", + "bzip2", + "constant_time_eq", + "crc32fast", + "crossbeam-utils", + "flate2", + "hmac", + "pbkdf2", + "sha1", + "time", + "zstd", +] + +[[package]] +name = "zstd" +version = "0.11.2+zstd.1.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "20cc960326ece64f010d2d2107537f26dc589a6573a316bd5b1dba685fa5fde4" +dependencies = [ + "zstd-safe", +] + +[[package]] +name = "zstd-safe" +version = "5.0.2+zstd.1.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d2a5585e04f9eea4b2a3d1eca508c4dee9592a89ef6f450c11719da0726f4db" +dependencies = [ + "libc", + "zstd-sys", +] + +[[package]] +name = "zstd-sys" +version = "2.0.14+zstd.1.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fb060d4926e4ac3a3ad15d864e99ceb5f343c6b34f5bd6d81ae6ed417311be5" +dependencies = [ + "cc", + "pkg-config", ] diff --git a/Cargo.toml b/Cargo.toml index cc452c4..587aa2f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "nutpie" -version = "0.13.4" +version = "0.14.0" authors = [ "Adrian Seyboldt ", "PyMC Developers ", @@ -22,25 +22,26 @@ name = "_lib" crate-type = ["cdylib"] [dependencies] -nuts-rs = "0.12.1" -numpy = "0.21.0" -rand = "0.8.5" -thiserror = "1.0.44" -rand_chacha = "0.3.1" -rayon = "1.9.0" +nuts-rs = "0.15.0" +numpy = "0.23.0" +rand = "0.9.0" +thiserror = "2.0.3" +rand_chacha = "0.9.0" +rayon = "1.10.0" # Keep arrow in sync with nuts-rs requirements -arrow = { version = "52.0.0", default-features = false, features = ["ffi"] } +arrow = { version = "54.1.0", default-features = false, features = ["ffi"] } anyhow = "1.0.72" itertools = "0.14.0" bridgestan = "2.6.1" -rand_distr = "0.4.3" -smallvec = "1.11.0" -upon = { version = "0.8.1", default-features = false, features = [] } +rand_distr = "0.5.0" +smallvec = "1.13.0" +upon = { version = "0.9.0", default-features = false, features = [] } time-humanize = { version = "0.1.3", default-features = false } indicatif = "0.17.8" +tch = { version = "0.19.0", optional = true } [dependencies.pyo3] -version = "0.21.0" +version = "0.23.4" features = ["extension-module", "anyhow"] [dev-dependencies] diff --git a/cliff.toml b/cliff.toml index a4e8993..01b9142 100644 --- a/cliff.toml +++ b/cliff.toml @@ -45,7 +45,6 @@ split_commits = false commit_preprocessors = [ # { pattern = '\((\w+\s)?#([0-9]+)\)', replace = "([#${2}](https://github.com/orhun/git-cliff/issues/${2}))"}, # replace issue numbers ] -# regex for parsing and grouping commits commit_parsers = [ { message = "^feat", group = "Features" }, { message = "^fix", group = "Bug Fixes" }, @@ -55,9 +54,10 @@ commit_parsers = [ { message = "^style", group = "Styling" }, { message = "^test", group = "Testing" }, { message = "^chore: Prepare", skip = true }, + { message = "^chore\\(release\\)", skip = true }, { message = "^chore", group = "Miscellaneous Tasks" }, { body = ".*security", group = "Security" }, -] +] # regex for parsing and grouping commits # protect breaking changes from being skipped due to matching a skipping commit_parser protect_breaking_commits = false # filter out the commits that are not matched by commit parsers diff --git a/docs/_freeze/index/execute-results/html.json b/docs/_freeze/index/execute-results/html.json new file mode 100644 index 0000000..37d29d5 --- /dev/null +++ b/docs/_freeze/index/execute-results/html.json @@ -0,0 +1,16 @@ +{ + "hash": "94e4388705073729b94725a15410d650", + "result": { + "engine": "jupyter", + "markdown": "---\ntitle: Nutpie Documentation\n---\n\n\n\n`nutpie` is a high-performance library designed for Bayesian inference, that\nprovides efficient sampling algorithms for probabilistic models. It can sample\nmodels that are defined in PyMC or Stan (numpyro and custom hand-coded\nlikelihoods with gradient are coming soon).\n\n- Faster sampling than either the PyMC or Stan default samplers. (An average\n ~2x speedup on `posteriordb` compared to Stan)\n- All the diagnostic information of PyMC and Stan and some more.\n- GPU support for PyMC models through jax.\n- A more informative progress bar.\n- Access to the incomplete trace during sampling.\n- *Experimental* normalizing flow adaptation for more efficient sampling of\n difficult posteriors.\n\n## Quickstart: PyMC\n\nInstall `nutpie` with pip, uv, pixi, or conda:\n\nFor usage with pymc:\n\n```bash\n# One of\npip install \"nutpie[pymc]\"\nuv add \"nutpie[pymc]\"\npixi add nutpie pymc numba\nconda install -c conda-forge nutpie pymc numba\n```\n\nAnd then sample with\n\n\n::: {#1c2d97ba .cell execution_count=1}\n``` {.python .cell-code}\nimport nutpie\nimport pymc as pm\n\nwith pm.Model() as model:\n mu = pm.Normal(\"mu\", mu=0, sigma=1)\n obs = pm.Normal(\"obs\", mu=mu, sigma=1, observed=[1, 2, 3])\n\ncompiled = nutpie.compile_pymc_model(model)\ntrace = nutpie.sample(compiled)\n```\n\n::: {.cell-output .cell-output-display}\n```{=html}\n\n\n```\n:::\n\n::: {.cell-output .cell-output-display}\n```{=html}\n\n
\n

Sampler Progress

\n

Total Chains: 6

\n

Active Chains: 0

\n

\n Finished Chains:\n 6\n

\n

Sampling for now

\n

\n Estimated Time to Completion:\n now\n

\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
ProgressDrawsDivergencesStep SizeGradients/Draw
\n \n \n 140001.351
\n \n \n 140001.283
\n \n \n 140001.293
\n \n \n 140001.233
\n \n \n 140001.403
\n \n \n 140001.281
\n
\n```\n:::\n:::\n\n\nFor more information, see the detailed [PyMC usage guide](pymc-usage.qmd).\n\n## Quickstart: Stan\n\nStan needs access to a compiler toolchain, you can find instructions for those\n[here](https://mc-stan.org/docs/cmdstan-guide/installation.html#cpp-toolchain).\nYou can then install nutpie through pip or uv:\n\n```bash\n# One of\npip install \"nutpie[stan]\"\nuv add \"nutpie[stan]\"\n```\n\n\n\n::: {#700ed270 .cell execution_count=3}\n``` {.python .cell-code}\nimport nutpie\n\nmodel = \"\"\"\ndata {\n int N;\n vector[N] y;\n}\nparameters {\n real mu;\n}\nmodel {\n mu ~ normal(0, 1);\n y ~ normal(mu, 1);\n}\n\"\"\"\n\ncompiled = (\n nutpie\n .compile_stan_model(code=model)\n .with_data(N=3, y=[1, 2, 3])\n)\ntrace = nutpie.sample(compiled)\n```\n\n::: {.cell-output .cell-output-display}\n```{=html}\n\n\n```\n:::\n\n::: {.cell-output .cell-output-display}\n```{=html}\n\n
\n

Sampler Progress

\n

Total Chains: 6

\n

Active Chains: 0

\n

\n Finished Chains:\n 6\n

\n

Sampling for now

\n

\n Estimated Time to Completion:\n now\n

\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
ProgressDrawsDivergencesStep SizeGradients/Draw
\n \n \n 140001.291
\n \n \n 140001.273
\n \n \n 140001.343
\n \n \n 140001.331
\n \n \n 140001.413
\n \n \n 140001.293
\n
\n```\n:::\n:::\n\n\nFor more information, see the detailed [Stan usage guide](stan-usage.qmd).\n\n", + "supporting": [ + "index_files" + ], + "filters": [], + "includes": { + "include-in-header": [ + "\n\n\n" + ] + } + } +} \ No newline at end of file diff --git a/docs/_freeze/nf-adapt/execute-results/html.json b/docs/_freeze/nf-adapt/execute-results/html.json new file mode 100644 index 0000000..fc75744 --- /dev/null +++ b/docs/_freeze/nf-adapt/execute-results/html.json @@ -0,0 +1,16 @@ +{ + "hash": "99a8749bdb41e64f77fd32de347bbc2a", + "result": { + "engine": "jupyter", + "markdown": "---\ntitle: Adaptation with Normalizing Flows\n---\n\n\n\n**Experimental and subject to change**\n\nNormalizing flow adaptation through Fisher HMC is a new sampling algorithm that\nautomatically reparameterizes a model. It adds some computational cost outside\nmodel log-density evaluations, but allows sampling from much more difficult\nposterior distributions. For models with expensive log-density evaluations, the\nnormalizing flow adaptation can also be much faster, if it can reduce the number\nof log-density evaluations needed to reach a given effective sample size.\n\nThe normalizing flow adaptation works by learning a transformation of the parameter\nspace that makes the posterior distribution more amenable to sampling. This is done\nby fitting a sequence of invertible transformations (the \"flow\") that maps the\noriginal parameter space to a space where the posterior is closer to a standard\nnormal distribution. The flow is trained during warmup.\n\nFor more information about the algorithm, see the (still work in progress) paper\n[If only my posterior were normal: Introducing Fisher\nHMC](https://github.com/aseyboldt/covadapt-paper/releases/download/latest/main.pdf).\n\nCurrently, a lot of time is spent on compiling various parts of the normalizing\nflow, and for small models this can take a large amount of the total time.\nHopefully, we will be able to reduce this overhead in the future.\n\n## Requirements\n\nInstall the optional dependencies for normalizing flow adaptation:\n\n```\npip install 'nutpie[nnflow]'\n```\n\nIf you use with PyMC, this will only work if the model is compiled using the jax\nbackend, and if the `gradient_backend` is also set to `jax`.\n\nTraining of the normalizing flow can often be accelerated by using a GPU (even\nif the model itself is written in Stan, without any GPU support). To enable GPU\nyou need to make sure your `jax` installation comes with GPU support, for\ninstance by installing it with `pip install 'jax[cuda12]'`, or selecting the\n`jaxlib` version with GPU support, if you are using conda-forge. You can check if\nyour installation has GPU support by checking the output of:\n\n```python\nimport jax\njax.devices()\n```\n\n### Usage\n\nTo use normalizing flow adaptation in `nutpie`, you need to enable the\n`transform_adapt` option during sampling. Here is an example of how we can use\nit to sample from a difficult posterior:\n\n\n::: {#1e499251 .cell execution_count=1}\n``` {.python .cell-code}\nimport pymc as pm\nimport nutpie\nimport numpy as np\nimport arviz\n\n# Define a 100-dimensional funnel model\nwith pm.Model() as model:\n log_sigma = pm.Normal(\"log_sigma\")\n pm.Normal(\"x\", mu=0, sigma=pm.math.exp(log_sigma / 2), shape=100)\n\n# Compile the model with the jax backend\ncompiled = nutpie.compile_pymc_model(\n model, backend=\"jax\", gradient_backend=\"jax\"\n)\n```\n:::\n\n\nIf we sample this model without normalizing flow adaptation, we will encounter\nconvergence issues, often divergences and always low effective sample sizes:\n\n::: {#f7faabf0 .cell execution_count=2}\n``` {.python .cell-code}\n# Sample without normalizing flow adaptation\ntrace_no_nf = nutpie.sample(compiled, seed=1)\nassert (arviz.ess(trace_no_nf) < 100).any().to_array().any()\n```\n\n::: {.cell-output .cell-output-display}\n```{=html}\n\n\n```\n:::\n\n::: {.cell-output .cell-output-display}\n```{=html}\n\n
\n

Sampler Progress

\n

Total Chains: 6

\n

Active Chains: 0

\n

\n Finished Chains:\n 6\n

\n

Sampling for 16 seconds

\n

\n Estimated Time to Completion:\n now\n

\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
ProgressDrawsDivergencesStep SizeGradients/Draw
\n \n \n 140000.457
\n \n \n 140000.3115
\n \n \n 140000.317
\n \n \n 140000.287
\n \n \n 140000.3915
\n \n \n 140000.347
\n
\n```\n:::\n:::\n\n\n::: {#6cfb99bd .cell execution_count=3}\n``` {.python .cell-code}\n# We can add further arguments for the normalizing flow:\ncompiled = compiled.with_transform_adapt(\n num_layers=5, # Number of layers in the normalizing flow\n nn_width=32, # Neural networks with 32 hidden units\n num_diag_windows=6, # Number of windows with a diagonal mass matrix intead of a flow\n verbose=False, # Whether to print details about the adaptation process\n show_progress=False, # Whether to show a progress bar for each optimization step\n)\n\n# Sample with normalizing flow adaptation\ntrace_nf = nutpie.sample(\n compiled,\n transform_adapt=True, # Enable the normalizing flow adaptation\n seed=1,\n chains=2,\n cores=1, # Running chains in parallel can be slow\n window_switch_freq=150, # Optimize the normalizing flow every 150 iterations\n)\nassert trace_nf.sample_stats.diverging.sum() == 0\nassert (arviz.ess(trace_nf) > 1000).all().to_array().all()\n```\n\n::: {.cell-output .cell-output-display}\n```{=html}\n\n\n```\n:::\n\n::: {.cell-output .cell-output-display}\n```{=html}\n\n
\n

Sampler Progress

\n

Total Chains: 2

\n

Active Chains: 0

\n

\n Finished Chains:\n 2\n

\n

Sampling for 18 minutes

\n

\n Estimated Time to Completion:\n now\n

\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
ProgressDrawsDivergencesStep SizeGradients/Draw
\n \n \n 250000.527
\n \n \n 250000.537
\n
\n```\n:::\n:::\n\n\nThe sampler used fewer gradient evaluations with the normalizing flow adaptation,\nbut still converged, and produce a good effective sample size:\n\n::: {#78aaecea .cell execution_count=4}\n``` {.python .cell-code}\nn_steps = int(trace_nf.sample_stats.n_steps.sum() + trace_nf.warmup_sample_stats.n_steps.sum())\ness = float(arviz.ess(trace_nf).min().to_array().min())\nprint(f\"Number of gradient evaluations: {n_steps}\")\nprint(f\"Minimum effective sample size: {ess}\")\n```\n\n::: {.cell-output .cell-output-stdout}\n```\nNumber of gradient evaluations: 42527\nMinimum effective sample size: 1835.9674640023168\n```\n:::\n:::\n\n\nWithout normalizing flow, it used more gradient evaluations, and still wasn't able\nto get a good effective sample size:\n\n::: {#820fea9f .cell execution_count=5}\n``` {.python .cell-code}\nn_steps = int(trace_no_nf.sample_stats.n_steps.sum() + trace_no_nf.warmup_sample_stats.n_steps.sum())\ness = float(arviz.ess(trace_no_nf).min().to_array().min())\nprint(f\"Number of gradient evaluations: {n_steps}\")\nprint(f\"Minimum effective sample size: {ess}\")\n```\n\n::: {.cell-output .cell-output-stdout}\n```\nNumber of gradient evaluations: 124219\nMinimum effective sample size: 31.459420094540565\n```\n:::\n:::\n\n\nThe flow adaptation occurs during warmup, so the number of warmup draws should\nbe large enough to allow the flow to converge. For more complex posteriors, you\nmay need to increase the number of layers (using the `num_layers` argument), or\nyou might want to increase the number of warmup draws.\n\nTo monitor the progress of the flow adaptation, you can set `verbose=True`, or\n`show_progress=True`, but the second should only be used if you sample just one\nchain.\n\nAll losses are on a log-scale. Negative values smaller -2 are a good sign that\nthe adaptation was successful. If the loss stays positive, the flow is either\nnot expressive enough, or the training period is too short. The sampler might\nstill converge, but will probably need more gradient evaluations per effective\ndraw. Large losses bigger than 6 tend to indicate that the posterior is too\ndifficult to sample with the current flow, and the sampler will probably not\nconverge.\n\n", + "supporting": [ + "nf-adapt_files/figure-html" + ], + "filters": [], + "includes": { + "include-in-header": [ + "\n\n\n" + ] + } + } +} \ No newline at end of file diff --git a/docs/_freeze/pymc-usage/execute-results/html.json b/docs/_freeze/pymc-usage/execute-results/html.json new file mode 100644 index 0000000..48635bb --- /dev/null +++ b/docs/_freeze/pymc-usage/execute-results/html.json @@ -0,0 +1,16 @@ +{ + "hash": "fbae7cbc3710a3ccd22e1de17bdfdb36", + "result": { + "engine": "jupyter", + "markdown": "---\ntitle: Usage with PyMC models\n---\n\n\n\nThis document shows how to use `nutpie` with PyMC models. We will use the\n`pymc` package to define a simple model and sample from it using `nutpie`.\n\n## Installation\n\nThe recommended way to install `pymc` is through the `conda` ecosystem. A good\npackage manager for conda packages is `pixi`. See for the [pixi\ndocumentation](https://pixi.sh) for instructions on how to install it.\n\nWe create a new project for this example:\n\n```bash\npixi new pymc-example\n```\n\nThis will create a new directory `pymc-example` with a `pixi.toml` file, that\nyou can edit to add meta information.\n\nWe then add the `pymc` and `nutpie` packages to the project:\n\n```bash\ncd pymc-example\npixi add pymc nutpie arviz\n```\n\nYou can use Visual Studio Code (VSCode) or JupyterLab to write and run our code.\nBoth are excellent tools for working with Python and data science projects.\n\n### Using VSCode\n\n1. Open VSCode.\n2. Open the `pymc-example` directory created earlier.\n3. Create a new file named `model.ipynb`.\n4. Select the pixi kernel to run the code.\n\n### Using JupyterLab\n\n1. Add jupyter labs to the project by running `pixi add jupyterlab`.\n1. Open JupyterLab by running `pixi run jupyter lab` in your terminal.\n3. Create a new Python notebook.\n\n## Defining and Sampling a Simple Model\n\nWe will define a simple Bayesian model using `pymc` and sample from it using\n`nutpie`.\n\n### Model Definition\n\nIn your `model.ipypy` file or Jupyter notebook, add the following code:\n\n\n::: {#dbca6234 .cell execution_count=1}\n``` {.python .cell-code}\nimport pymc as pm\nimport nutpie\nimport pandas as pd\n\ncoords = {\"observation\": range(3)}\n\nwith pm.Model(coords=coords) as model:\n # Prior distributions for the intercept and slope\n intercept = pm.Normal(\"intercept\", mu=0, sigma=1)\n slope = pm.Normal(\"slope\", mu=0, sigma=1)\n\n # Likelihood (sampling distribution) of observations\n x = [1, 2, 3]\n\n mu = intercept + slope * x\n y = pm.Normal(\"y\", mu=mu, sigma=0.1, observed=[1, 2, 3], dims=\"observation\")\n```\n:::\n\n\n### Sampling\n\nWe can now compile the model using the numba backend:\n\n::: {#74540a7e .cell execution_count=2}\n``` {.python .cell-code}\ncompiled = nutpie.compile_pymc_model(model)\ntrace = nutpie.sample(compiled)\n```\n\n::: {.cell-output .cell-output-display}\n```{=html}\n\n\n```\n:::\n\n::: {.cell-output .cell-output-display}\n```{=html}\n\n
\n

Sampler Progress

\n

Total Chains: 6

\n

Active Chains: 0

\n

\n Finished Chains:\n 6\n

\n

Sampling for now

\n

\n Estimated Time to Completion:\n now\n

\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
ProgressDrawsDivergencesStep SizeGradients/Draw
\n \n \n 140000.593
\n \n \n 140000.659
\n \n \n 140000.551
\n \n \n 140000.5815
\n \n \n 140000.677
\n \n \n 140000.583
\n
\n```\n:::\n:::\n\n\nAlternatively, we can also sample through the `pymc` API:\n\n```python\nwith model:\n trace = pm.sample(model, nuts_sampler=\"nutpie\")\n```\n\nWhile sampling, nutpie shows a progress bar for each chain. It also includes\ninformation about how each chain is doing:\n\n- It shows the current number of draws\n- The step size of the integrator (very small stepsizes are typically a bad\n sign)\n- The number of divergences (if there are divergences, that means that nutpie is\n probably not sampling the posterior correctly)\n- The number of gradient evaluation nutpie uses for each draw. Large numbers\n (100 to 1000) are a sign that the parameterization of the model is not ideal,\n and the sampler is very inefficient.\n\nAfter sampling, this returns an `arviz` InferenceData object that you can use to\nanalyze the trace.\n\nFor example, we should check the effective sample size:\n\n::: {#7a0b20fe .cell execution_count=3}\n``` {.python .cell-code}\nimport arviz as az\naz.ess(trace)\n```\n\n::: {.cell-output .cell-output-display execution_count=3}\n```{=html}\n
\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n
<xarray.Dataset> Size: 16B\nDimensions:    ()\nData variables:\n    intercept  float64 8B 1.517e+03\n    slope      float64 8B 1.517e+03
\n```\n:::\n:::\n\n\nand take a look at a trace plot:\n\n::: {#181126a5 .cell execution_count=4}\n``` {.python .cell-code}\naz.plot_trace(trace);\n```\n\n::: {.cell-output .cell-output-display}\n![](pymc-usage_files/figure-html/cell-5-output-1.png){}\n:::\n:::\n\n\n### Choosing the backend\n\nRight now, we have been using the numba backend. This is the default backend for\n`nutpie`, when sampling from pymc models. It tends to have relatively long\ncompilation times, but samples small models very efficiently. For larger models\nthe `jax` backend sometimes outperforms `numba`.\n\nFirst, we need to install the `jax` package:\n\n```bash\npixi add jax\n```\n\nWe can select the backend by passing the `backend` argument to the `compile_pymc_model`:\n\n```python\ncompiled_jax = nutpie.compiled_pymc_model(model, backend=\"jax\")\ntrace = nutpie.sample(compiled_jax)\n```\n\nOr through the pymc API:\n\n```python\nwith model:\n trace = pm.sample(\n model,\n nuts_sampler=\"nutpie\",\n nuts_sampler_kwargs={\"backend\": \"jax\"},\n )\n```\n\nIf you have an nvidia GPU, you can also use the `jax` backend with the `gpu`. We\nwill have to install the `jaxlib` package with the `cuda` option\n\n```bash\npixi add jaxlib --build 'cuda12'\n```\n\nRestart the kernel and check that the GPU is available:\n\n```python\nimport jax\n\n# Should list the cuda device\njax.devices()\n```\n\nSampling again, should now use the GPU, which you can observe by checking the\nGPU usage with `nvidia-smi` or `nvtop`.\n\n### Changing the dataset without recompilation\n\nIf you want to use the same model with different datasets, you can modify\ndatasets after compilation. Since jax does not like changes in shapes, this is\nonly recommended with the numba backend.\n\nFirst, we define the model, but put our dataset in a `pm.Data` structure:\n\n::: {#629172a7 .cell execution_count=5}\n``` {.python .cell-code}\nwith pm.Model() as model:\n x = pm.Data(\"x\", [1, 2, 3])\n intercept = pm.Normal(\"intercept\", mu=0, sigma=1)\n slope = pm.Normal(\"slope\", mu=0, sigma=1)\n mu = intercept + slope * x\n y = pm.Normal(\"y\", mu=mu, sigma=0.1, observed=[1, 2, 3])\n```\n:::\n\n\nWe can now compile the model:\n\n::: {#e865b2bd .cell execution_count=6}\n``` {.python .cell-code}\ncompiled = nutpie.compile_pymc_model(model)\ntrace = nutpie.sample(compiled)\n```\n\n::: {.cell-output .cell-output-display}\n```{=html}\n\n\n```\n:::\n\n::: {.cell-output .cell-output-display}\n```{=html}\n\n
\n

Sampler Progress

\n

Total Chains: 6

\n

Active Chains: 0

\n

\n Finished Chains:\n 6\n

\n

Sampling for now

\n

\n Estimated Time to Completion:\n now\n

\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
ProgressDrawsDivergencesStep SizeGradients/Draw
\n \n \n 140000.641
\n \n \n 140000.6111
\n \n \n 140000.683
\n \n \n 140000.559
\n \n \n 140000.579
\n \n \n 140000.653
\n
\n```\n:::\n:::\n\n\nAfter compilation, we can change the dataset:\n\n::: {#070d3016 .cell execution_count=7}\n``` {.python .cell-code}\ncompiled2 = compiled.with_data(x=[4, 5, 6])\ntrace2 = nutpie.sample(compiled2)\n```\n\n::: {.cell-output .cell-output-display}\n```{=html}\n\n\n```\n:::\n\n::: {.cell-output .cell-output-display}\n```{=html}\n\n
\n

Sampler Progress

\n

Total Chains: 6

\n

Active Chains: 0

\n

\n Finished Chains:\n 6\n

\n

Sampling for now

\n

\n Estimated Time to Completion:\n now\n

\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
ProgressDrawsDivergencesStep SizeGradients/Draw
\n \n \n 140000.427
\n \n \n 140000.3527
\n \n \n 140000.4213
\n \n \n 140000.453
\n \n \n 140000.403
\n \n \n 140000.3919
\n
\n```\n:::\n:::\n\n\n", + "supporting": [ + "pymc-usage_files" + ], + "filters": [], + "includes": { + "include-in-header": [ + "\n\n\n" + ] + } + } +} \ No newline at end of file diff --git a/docs/_freeze/pymc-usage/figure-html/cell-5-output-1.png b/docs/_freeze/pymc-usage/figure-html/cell-5-output-1.png new file mode 100644 index 0000000..57cea90 Binary files /dev/null and b/docs/_freeze/pymc-usage/figure-html/cell-5-output-1.png differ diff --git a/docs/_freeze/sample-stats/execute-results/html.json b/docs/_freeze/sample-stats/execute-results/html.json new file mode 100644 index 0000000..a656155 --- /dev/null +++ b/docs/_freeze/sample-stats/execute-results/html.json @@ -0,0 +1,16 @@ +{ + "hash": "1bb39e60c4e9c979109a04b06b7d02e7", + "result": { + "engine": "jupyter", + "markdown": "---\ntitle: Understanding Sampler Statistics in Nutpie\n---\n\n\n\nThis guide explains the various statistics that nutpie collects during sampling. We'll use Neal's funnel distribution as an example, as it's a challenging model that demonstrates many important sampling concepts.\n\n## Example Model: Neal's Funnel\n\nLet's start by implementing Neal's funnel in PyMC:\n\n\n::: {#9ef0aa6e .cell execution_count=1}\n``` {.python .cell-code}\nimport pymc as pm\nimport nutpie\nimport numpy as np\nimport matplotlib.pyplot as plt\nimport seaborn as sns\nimport pandas as pd\nimport arviz as az\n\n# Create the funnel model\nwith pm.Model() as model:\n log_sigma = pm.Normal('log_sigma')\n pm.Normal('x', sigma=pm.math.exp(log_sigma), shape=5)\n\n# Sample with detailed statistics\ncompiled = nutpie.compile_pymc_model(model)\ntrace = nutpie.sample(\n compiled,\n tune=1000,\n store_mass_matrix=True,\n store_gradient=True,\n store_unconstrained=True,\n store_divergences=True,\n seed=42,\n)\n```\n\n::: {.cell-output .cell-output-display}\n```{=html}\n\n\n```\n:::\n\n::: {.cell-output .cell-output-display}\n```{=html}\n\n
\n

Sampler Progress

\n

Total Chains: 6

\n

Active Chains: 0

\n

\n Finished Chains:\n 6\n

\n

Sampling for now

\n

\n Estimated Time to Completion:\n now\n

\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
ProgressDrawsDivergencesStep SizeGradients/Draw
\n \n \n 2000210.417
\n \n \n 2000240.357
\n \n \n 200000.357
\n \n \n 2000130.447
\n \n \n 200000.3415
\n \n \n 200050.537
\n
\n```\n:::\n:::\n\n\n## Sampler Statistics Overview\n\nThe sampler statistics can be grouped into several categories:\n\n### Basic HMC Statistics\n\nThese statistics are always collected and are essential for basic diagnostics:\n\n::: {#ffa51d43 .cell execution_count=2}\n``` {.python .cell-code}\n# Access through trace.sample_stats\nbasic_stats = [\n 'depth', # Tree depth for current draw\n 'maxdepth_reached', # Whether max tree depth was hit\n 'logp', # Log probability of current position\n 'energy', # Hamiltonian energy\n 'diverging', # Whether the transition diverged\n 'step_size', # Current step size\n 'step_size_bar', # Current estimate of an ideal step size\n 'n_steps' # Number of leapfrog steps\n\n]\n\n# Plot step size evolution during warmup\ntrace.warmup_sample_stats.step_size_bar.plot.line(x=\"draw\", yscale=\"log\")\n```\n\n::: {.cell-output .cell-output-display}\n![](sample-stats_files/figure-html/cell-3-output-1.png){}\n:::\n:::\n\n\n### Mass Matrix Adaptation\n\nThese statistics track how the mass matrix evolves:\n\n::: {#146dc128 .cell execution_count=3}\n``` {.python .cell-code}\n(\n trace\n .warmup_sample_stats\n .mass_matrix_inv\n .plot\n .line(\n x=\"draw\",\n yscale=\"log\",\n col=\"chain\",\n col_wrap=2,\n )\n)\n```\n\n::: {.cell-output .cell-output-display}\n![](sample-stats_files/figure-html/cell-4-output-1.png){}\n:::\n:::\n\n\nVariables that are a source of convergence issues, will often show high variance\nin the final mass matrix estimate across chains.\n\nThe mass matrix will always be fixed for 10% of draws at the end, because we\nonly run final step size adaptation during that time, but high variance in the\nmass matrix before this final window and indicate that more tuning steps might\nbe needed.\n\n### Detailed Diagnostics\n\nThese are only available when explicitly requested:\n\n```python\ndetailed_stats = [\n 'gradient', # Gradient at current position\n 'unconstrained_draw', # Parameters in unconstrained space\n 'divergence_start', # Position where divergence started\n 'divergence_end', # Position where divergence ended\n 'divergence_momentum', # Momentum at divergence\n 'divergence_message' # Description of divergence\n]\n```\n\n#### Idintify Divergences\n\nWe can for instance use this to identify the sources of divergences:\n\n::: {#44f63008 .cell execution_count=4}\n``` {.python .cell-code}\nimport xarray as xr\n\ndraws = (\n trace\n .sample_stats\n .unconstrained_draw\n .assign_coords(kind=\"draw\")\n)\ndivergence_locations = (\n trace\n .sample_stats\n .divergence_start\n .assign_coords(kind=\"divergence\")\n)\n\npoints = xr.concat([draws, divergence_locations], dim=\"kind\")\npoints.to_dataset(\"unconstrained_parameter\").plot.scatter(x=\"log_sigma\", y=\"x_0\", hue=\"kind\")\n```\n\n::: {.cell-output .cell-output-display}\n![](sample-stats_files/figure-html/cell-5-output-1.png){}\n:::\n:::\n\n\n#### Covariance of gradients and draws\n\nTODO this section should really use the transformed gradients and draws, not the\nunconstrained ones, as that avoids the manual mass matrix correction. This\nis only available for the normalizing flow adaptation at the moment though.\n\nIn models with problematic posterior correlations, the singular value\ndecomposition of gradients and draws can often point us to the source of the\nissue.\n\nLet's build a little model with correlations between parameters:\n\n::: {#32313889 .cell execution_count=5}\n``` {.python .cell-code}\nwith pm.Model() as model:\n x = pm.Normal('x')\n y = pm.Normal(\"y\", mu=x, sigma=0.01)\n z = pm.Normal(\"z\", mu=y, shape=100)\n\ncompiled = nutpie.compile_pymc_model(model)\ntrace = nutpie.sample(\n compiled,\n tune=1000,\n store_gradient=True,\n store_unconstrained=True,\n store_mass_matrix=True,\n seed=42,\n)\n```\n\n::: {.cell-output .cell-output-display}\n```{=html}\n\n\n```\n:::\n\n::: {.cell-output .cell-output-display}\n```{=html}\n\n
\n

Sampler Progress

\n

Total Chains: 6

\n

Active Chains: 0

\n

\n Finished Chains:\n 6\n

\n

Sampling for now

\n

\n Estimated Time to Completion:\n now\n

\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
ProgressDrawsDivergencesStep SizeGradients/Draw
\n \n \n 200000.1331
\n \n \n 200000.1631
\n \n \n 200000.2215
\n \n \n 200000.1431
\n \n \n 200000.1531
\n \n \n 200000.1531
\n
\n```\n:::\n:::\n\n\nNow we can compute eigenvalues of the covariance matrix of the gradient and\ndraws (using the singular value decomposition to avoid quadratic cost):\n\n::: {#be02c1fa .cell execution_count=6}\n``` {.python .cell-code}\ndef covariance_eigenvalues(x, mass_matrix):\n assert x.dims == (\"chain\", \"draw\", \"unconstrained_parameter\")\n x = x.stack(sample=[\"draw\", \"chain\"])\n x = (x - x.mean(\"sample\")) / np.sqrt(mass_matrix)\n u, s, v = np.linalg.svd(x.T / np.sqrt(x.shape[1]), full_matrices=False)\n print(u.shape, s.shape, v.shape)\n s = xr.DataArray(\n s,\n dims=[\"eigenvalue\"],\n coords={\"eigenvalue\": range(s.size)},\n )\n v = xr.DataArray(\n v,\n dims=[\"eigenvalue\", \"unconstrained_parameter\"],\n coords={\n \"eigenvalue\": s.eigenvalue,\n \"unconstrained_parameter\": x.unconstrained_parameter,\n },\n )\n return s ** 2, v\n\nmass_matrix = trace.sample_stats.mass_matrix_inv.isel(draw=-1, chain=0)\ndraws_eigs, draws_eigv = covariance_eigenvalues(trace.sample_stats.unconstrained_draw, mass_matrix)\ngrads_eigs, grads_eigv = covariance_eigenvalues(trace.sample_stats.gradient, 1 / mass_matrix)\n\ndraws_eigs.plot.line(x=\"eigenvalue\", yscale=\"log\")\ngrads_eigs.plot.line(x=\"eigenvalue\", yscale=\"log\")\n```\n\n::: {.cell-output .cell-output-stdout}\n```\n(6000, 102) (102,) (102, 102)\n(6000, 102) (102,) (102, 102)\n```\n:::\n\n::: {.cell-output .cell-output-display}\n![](sample-stats_files/figure-html/cell-7-output-2.png){}\n:::\n:::\n\n\nWe can see one very large and one very small eigenvalue in both covariances.\nLarge eigenvalues for the draws, and small eigenvalues for the gradients prevent\nthe sampler from taking larger steps. Small eigenvalues in the draws, and large\neigenvalues in the grads mean, that the sampler has to move far in parameter\nspace to get independent draws. So both lead to problems during sampling. For\nmodels with many parameters, typically only the large eigenvalues of each are\nmeaningful, because of estimation issues with the small eigenvalues.\n\nWe can also look at the eigenvectors to see which parameters are responsible for\nthe correlations:\n\n::: {#202af166 .cell execution_count=7}\n``` {.python .cell-code}\n(\n draws_eigv\n .sel(eigenvalue=0)\n .to_pandas()\n .sort_values(key=abs)\n .tail(10)\n .plot.bar(x=\"unconstrained_parameter\")\n)\n```\n\n::: {.cell-output .cell-output-display}\n![](sample-stats_files/figure-html/cell-8-output-1.png){}\n:::\n:::\n\n\n::: {#52404958 .cell execution_count=8}\n``` {.python .cell-code}\n(\n grads_eigv\n .sel(eigenvalue=0)\n .to_pandas()\n .sort_values(key=abs)\n .tail(10)\n .plot.bar(x=\"unconstrained_parameter\")\n)\n```\n\n::: {.cell-output .cell-output-display}\n![](sample-stats_files/figure-html/cell-9-output-1.png){}\n:::\n:::\n\n\n", + "supporting": [ + "sample-stats_files" + ], + "filters": [], + "includes": { + "include-in-header": [ + "\n\n\n" + ] + } + } +} \ No newline at end of file diff --git a/docs/_freeze/sample-stats/figure-html/cell-3-output-1.png b/docs/_freeze/sample-stats/figure-html/cell-3-output-1.png new file mode 100644 index 0000000..f9351aa Binary files /dev/null and b/docs/_freeze/sample-stats/figure-html/cell-3-output-1.png differ diff --git a/docs/_freeze/sample-stats/figure-html/cell-4-output-1.png b/docs/_freeze/sample-stats/figure-html/cell-4-output-1.png new file mode 100644 index 0000000..2cea4b6 Binary files /dev/null and b/docs/_freeze/sample-stats/figure-html/cell-4-output-1.png differ diff --git a/docs/_freeze/sample-stats/figure-html/cell-5-output-1.png b/docs/_freeze/sample-stats/figure-html/cell-5-output-1.png new file mode 100644 index 0000000..e9016f4 Binary files /dev/null and b/docs/_freeze/sample-stats/figure-html/cell-5-output-1.png differ diff --git a/docs/_freeze/sample-stats/figure-html/cell-7-output-2.png b/docs/_freeze/sample-stats/figure-html/cell-7-output-2.png new file mode 100644 index 0000000..4f8213c Binary files /dev/null and b/docs/_freeze/sample-stats/figure-html/cell-7-output-2.png differ diff --git a/docs/_freeze/sample-stats/figure-html/cell-8-output-1.png b/docs/_freeze/sample-stats/figure-html/cell-8-output-1.png new file mode 100644 index 0000000..185762a Binary files /dev/null and b/docs/_freeze/sample-stats/figure-html/cell-8-output-1.png differ diff --git a/docs/_freeze/sample-stats/figure-html/cell-9-output-1.png b/docs/_freeze/sample-stats/figure-html/cell-9-output-1.png new file mode 100644 index 0000000..053bc20 Binary files /dev/null and b/docs/_freeze/sample-stats/figure-html/cell-9-output-1.png differ diff --git a/docs/_freeze/site_libs/clipboard/clipboard.min.js b/docs/_freeze/site_libs/clipboard/clipboard.min.js new file mode 100644 index 0000000..1103f81 --- /dev/null +++ b/docs/_freeze/site_libs/clipboard/clipboard.min.js @@ -0,0 +1,7 @@ +/*! + * clipboard.js v2.0.11 + * https://clipboardjs.com/ + * + * Licensed MIT © Zeno Rocha + */ +!function(t,e){"object"==typeof exports&&"object"==typeof module?module.exports=e():"function"==typeof define&&define.amd?define([],e):"object"==typeof exports?exports.ClipboardJS=e():t.ClipboardJS=e()}(this,function(){return n={686:function(t,e,n){"use strict";n.d(e,{default:function(){return b}});var e=n(279),i=n.n(e),e=n(370),u=n.n(e),e=n(817),r=n.n(e);function c(t){try{return document.execCommand(t)}catch(t){return}}var a=function(t){t=r()(t);return c("cut"),t};function o(t,e){var n,o,t=(n=t,o="rtl"===document.documentElement.getAttribute("dir"),(t=document.createElement("textarea")).style.fontSize="12pt",t.style.border="0",t.style.padding="0",t.style.margin="0",t.style.position="absolute",t.style[o?"right":"left"]="-9999px",o=window.pageYOffset||document.documentElement.scrollTop,t.style.top="".concat(o,"px"),t.setAttribute("readonly",""),t.value=n,t);return e.container.appendChild(t),e=r()(t),c("copy"),t.remove(),e}var f=function(t){var e=1 N;\n vector[N] y;\n}\nparameters {\n real mu;\n}\nmodel {\n mu ~ normal(0, 1);\n y ~ normal(mu, 1);\n}\n\"\"\"\n\ncompiled_model = nutpie.compile_stan_model(code=model_code)\n```\n:::\n\n\n### Sampling\n\nWe can now compile the model and sample from it:\n\n::: {#60b965bf .cell execution_count=3}\n``` {.python .cell-code}\ncompiled_model_with_data = compiled_model.with_data(N=3, y=[1, 2, 3])\ntrace = nutpie.sample(compiled_model_with_data)\n```\n\n::: {.cell-output .cell-output-display}\n```{=html}\n\n\n```\n:::\n\n::: {.cell-output .cell-output-display}\n```{=html}\n\n
\n

Sampler Progress

\n

Total Chains: 6

\n

Active Chains: 0

\n

\n Finished Chains:\n 6\n

\n

Sampling for now

\n

\n Estimated Time to Completion:\n now\n

\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
ProgressDrawsDivergencesStep SizeGradients/Draw
\n \n \n 140001.333
\n \n \n 140001.391
\n \n \n 140001.373
\n \n \n 140001.381
\n \n \n 140001.353
\n \n \n 140001.333
\n
\n```\n:::\n:::\n\n\n### Using Dimensions\n\nWe'll use the radon model from\n[this](https://mc-stan.org/learn-stan/case-studies/radon_cmdstanpy_plotnine.html)\ncase-study from the stan documentation, to show how we can use coordinates and\ndimension names to simplify working with trace objects.\n\nWe follow the same data preparation as in the case-study:\n\n::: {#92d854e3 .cell execution_count=4}\n``` {.python .cell-code}\nimport pandas as pd\nimport numpy as np\nimport arviz as az\nimport seaborn as sns\n\nhome_data = pd.read_csv(\n \"https://github.com/pymc-devs/pymc-examples/raw/refs/heads/main/examples/data/srrs2.dat\",\n index_col=\"idnum\",\n)\ncounty_data = pd.read_csv(\n \"https://github.com/pymc-devs/pymc-examples/raw/refs/heads/main/examples/data/cty.dat\",\n)\n\nradon_data = (\n home_data\n .rename(columns=dict(cntyfips=\"ctfips\"))\n .merge(\n (\n county_data\n .drop_duplicates(['stfips', 'ctfips', 'st', 'cty', 'Uppm'])\n .set_index([\"ctfips\", \"stfips\"])\n ),\n right_index=True,\n left_on=[\"ctfips\", \"stfips\"],\n )\n .assign(log_radon=lambda x: np.log(np.clip(x.activity, 0.1, np.inf)))\n .assign(log_uranium=lambda x: np.log(np.clip(x[\"Uppm\"], 0.1, np.inf)))\n .query(\"state == 'MN'\")\n)\n```\n:::\n\n\nAnd also use the partially pooled model from the case-study:\n\n::: {#ce581edd .cell execution_count=5}\n``` {.python .cell-code}\nmodel_code = \"\"\"\ndata {\n int N; // observations\n int J; // counties\n array[N] int county;\n vector[N] x;\n vector[N] y;\n}\nparameters {\n real mu_alpha;\n real sigma_alpha;\n vector[J] alpha; // non-centered parameterization\n real beta;\n real sigma;\n}\nmodel {\n y ~ normal(alpha[county] + beta * x, sigma);\n alpha ~ normal(mu_alpha, sigma_alpha); // partial-pooling\n beta ~ normal(0, 10);\n sigma ~ normal(0, 10);\n mu_alpha ~ normal(0, 10);\n sigma_alpha ~ normal(0, 10);\n}\ngenerated quantities {\n array[N] real y_rep = normal_rng(alpha[county] + beta * x, sigma);\n}\n\"\"\"\n```\n:::\n\n\nWe collect the dataset in the format that the stan model requires,\nand specify the dimensions of each of the non-scalar variables in the model:\n\n::: {#9a29bf02 .cell execution_count=6}\n``` {.python .cell-code}\ncounty_idx, counties = pd.factorize(radon_data[\"county\"], use_na_sentinel=False)\nobservations = radon_data.index\n\ncoords = {\n \"county\": counties,\n \"observation\": observations,\n}\n\ndims = {\n \"alpha\": [\"county\"],\n \"y_rep\": [\"observation\"],\n}\n\ndata = {\n \"N\": len(observations),\n \"J\": len(counties),\n # Stan uses 1-based indexing!\n \"county\": county_idx + 1,\n \"x\": radon_data.log_uranium.values,\n \"y\": radon_data.log_radon.values,\n}\n```\n:::\n\n\nThen, we compile the model and provide the dimensions, coordinates and the\ndataset we just defined:\n\n::: {#fe0286f3 .cell execution_count=7}\n``` {.python .cell-code}\ncompiled_model = (\n nutpie.compile_stan_model(code=model_code)\n .with_data(**data)\n .with_dims(**dims)\n .with_coords(**coords)\n)\n```\n:::\n\n\n::: {#7a704cbf .cell execution_count=8}\n``` {.python .cell-code}\n%%time\ntrace = nutpie.sample(compiled_model, seed=0)\n```\n\n::: {.cell-output .cell-output-display}\n```{=html}\n\n\n```\n:::\n\n::: {.cell-output .cell-output-display}\n```{=html}\n\n
\n

Sampler Progress

\n

Total Chains: 6

\n

Active Chains: 0

\n

\n Finished Chains:\n 6\n

\n

Sampling for now

\n

\n Estimated Time to Completion:\n now\n

\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
ProgressDrawsDivergencesStep SizeGradients/Draw
\n \n \n 140000.3931
\n \n \n 140000.477
\n \n \n 140000.457
\n \n \n 140000.467
\n \n \n 140000.457
\n \n \n 140000.457
\n
\n```\n:::\n\n::: {.cell-output .cell-output-stdout}\n```\nCPU times: user 2.27 s, sys: 39.2 ms, total: 2.31 s\nWall time: 547 ms\n```\n:::\n:::\n\n\nAs some basic convergance checking we verify that all Rhat values are smaller\nthan 1.02, all parameters have at least 500 effective draws and that we have no\ndivergences:\n\n::: {#013fe62f .cell execution_count=9}\n``` {.python .cell-code}\nassert trace.sample_stats.diverging.sum() == 0\nassert az.ess(trace).min().min() > 500\nassert az.rhat(trace).max().max() > 1.02\n```\n:::\n\n\nThanks to the coordinates and dimensions we specified, the resulting trace will\nnow contain labeled data, so that plots based on it have properly set-up labels:\n\n::: {#34452909 .cell execution_count=10}\n``` {.python .cell-code}\nimport arviz as az\nimport seaborn as sns\nimport xarray as xr\n\nsns.catplot(\n data=trace.posterior.alpha.to_dataframe().reset_index(),\n y=\"county\",\n x=\"alpha\",\n kind=\"boxen\",\n height=13,\n aspect=1/2.5,\n showfliers=False,\n)\n```\n\n::: {.cell-output .cell-output-display}\n![](stan-usage_files/figure-html/cell-11-output-1.png){}\n:::\n:::\n\n\n", + "supporting": [ + "stan-usage_files" + ], + "filters": [], + "includes": { + "include-in-header": [ + "\n\n\n" + ] + } + } +} \ No newline at end of file diff --git a/docs/_freeze/stan-usage/figure-html/cell-11-output-1.png b/docs/_freeze/stan-usage/figure-html/cell-11-output-1.png new file mode 100644 index 0000000..4b219d9 Binary files /dev/null and b/docs/_freeze/stan-usage/figure-html/cell-11-output-1.png differ diff --git a/docs/_quarto.yml b/docs/_quarto.yml new file mode 100644 index 0000000..bf0f442 --- /dev/null +++ b/docs/_quarto.yml @@ -0,0 +1,34 @@ +project: + type: website + +website: + title: "Nutpie" + navbar: + left: + - href: index.qmd + text: Home + - href: pymc-usage.qmd + text: Usage with PyMC + - href: stan-usage.qmd + text: Usage with Stan + - href: sampling-options.qmd + text: Sampling Options + - href: nf-adapt.qmd + text: Normalizing flow adaptation + - href: sample-stats.qmd + text: Diagnostic information + - about.qmd + tools: + - icon: github + href: https://github.com/pymc-devs/nutpie + +format: + html: + theme: + - cosmo + - brand + css: styles.css + toc: true + +execute: + freeze: auto diff --git a/docs/about.qmd b/docs/about.qmd new file mode 100644 index 0000000..16fc02f --- /dev/null +++ b/docs/about.qmd @@ -0,0 +1,17 @@ +--- +title: "About" +--- + +Nutpie is part of the PyMC organization. The PyMC organization develops and +maintains tools for Bayesian statistical modeling and probabilistic machine +learning. + +Nutpie provides a high-performance implementation of the No-U-Turn Sampler +(NUTS) that can be used with models defined in PyMC, Stan and other frameworks. +It was created to enable faster and more efficient Bayesian inference while +maintaining compatibility with existing probabilistic programming tools. + +For more information about the PyMC organization, visit the following links: + +- [PyMC Website](https://www.pymc.io) +- [PyMC GitHub Organization](https://github.com/pymc-devs) diff --git a/docs/index.qmd b/docs/index.qmd new file mode 100644 index 0000000..6de4796 --- /dev/null +++ b/docs/index.qmd @@ -0,0 +1,90 @@ +# Nutpie Documentation + +`nutpie` is a high-performance library designed for Bayesian inference, that +provides efficient sampling algorithms for probabilistic models. It can sample +models that are defined in PyMC or Stan (numpyro and custom hand-coded +likelihoods with gradient are coming soon). + +- Faster sampling than either the PyMC or Stan default samplers. (An average + ~2x speedup on `posteriordb` compared to Stan) +- All the diagnostic information of PyMC and Stan and some more. +- GPU support for PyMC models through jax. +- A more informative progress bar. +- Access to the incomplete trace during sampling. +- *Experimental* normalizing flow adaptation for more efficient sampling of + difficult posteriors. + +## Quickstart: PyMC + +Install `nutpie` with pip, uv, pixi, or conda: + +For usage with pymc: + +```bash +# One of +pip install "nutpie[pymc]" +uv add "nutpie[pymc]" +pixi add nutpie pymc numba +conda install -c conda-forge nutpie pymc numba +``` + +And then sample with + +```{python} +import nutpie +import pymc as pm + +with pm.Model() as model: + mu = pm.Normal("mu", mu=0, sigma=1) + obs = pm.Normal("obs", mu=mu, sigma=1, observed=[1, 2, 3]) + +compiled = nutpie.compile_pymc_model(model) +trace = nutpie.sample(compiled) +``` + +For more information, see the detailed [PyMC usage guide](pymc-usage.qmd). + +## Quickstart: Stan + +Stan needs access to a compiler toolchain, you can find instructions for those +[here](https://mc-stan.org/docs/cmdstan-guide/installation.html#cpp-toolchain). +You can then install nutpie through pip or uv: + +```bash +# One of +pip install "nutpie[stan]" +uv add "nutpie[stan]" +``` + +```{python} +#| echo: false +import os +os.environ["TBB_CXX_TYPE"] = "clang" +``` + +```{python} +import nutpie + +model = """ +data { + int N; + vector[N] y; +} +parameters { + real mu; +} +model { + mu ~ normal(0, 1); + y ~ normal(mu, 1); +} +""" + +compiled = ( + nutpie + .compile_stan_model(code=model) + .with_data(N=3, y=[1, 2, 3]) +) +trace = nutpie.sample(compiled) +``` + +For more information, see the detailed [Stan usage guide](stan-usage.qmd). diff --git a/docs/nf-adapt.qmd b/docs/nf-adapt.qmd new file mode 100644 index 0000000..7900c5d --- /dev/null +++ b/docs/nf-adapt.qmd @@ -0,0 +1,139 @@ +# Adaptation with Normalizing Flows + +**Experimental and subject to change** + +Normalizing flow adaptation through Fisher HMC is a new sampling algorithm that +automatically reparameterizes a model. It adds some computational cost outside +model log-density evaluations, but allows sampling from much more difficult +posterior distributions. For models with expensive log-density evaluations, the +normalizing flow adaptation can also be much faster, if it can reduce the number +of log-density evaluations needed to reach a given effective sample size. + +The normalizing flow adaptation works by learning a transformation of the parameter +space that makes the posterior distribution more amenable to sampling. This is done +by fitting a sequence of invertible transformations (the "flow") that maps the +original parameter space to a space where the posterior is closer to a standard +normal distribution. The flow is trained during warmup. + +For more information about the algorithm, see the (still work in progress) paper +[If only my posterior were normal: Introducing Fisher +HMC](https://github.com/aseyboldt/covadapt-paper/releases/download/latest/main.pdf). + +Currently, a lot of time is spent on compiling various parts of the normalizing +flow, and for small models this can take a large amount of the total time. +Hopefully, we will be able to reduce this overhead in the future. + +## Requirements + +Install the optional dependencies for normalizing flow adaptation: + +``` +pip install 'nutpie[nnflow]' +``` + +If you use with PyMC, this will only work if the model is compiled using the jax +backend, and if the `gradient_backend` is also set to `jax`. + +Training of the normalizing flow can often be accelerated by using a GPU (even +if the model itself is written in Stan, without any GPU support). To enable GPU +you need to make sure your `jax` installation comes with GPU support, for +instance by installing it with `pip install 'jax[cuda12]'`, or selecting the +`jaxlib` version with GPU support, if you are using conda-forge. You can check if +your installation has GPU support by checking the output of: + +```python +import jax +jax.devices() +``` + +### Usage + +To use normalizing flow adaptation in `nutpie`, you need to enable the +`transform_adapt` option during sampling. Here is an example of how we can use +it to sample from a difficult posterior: + +```{python} +import pymc as pm +import nutpie +import numpy as np +import arviz + +# Define a 100-dimensional funnel model +with pm.Model() as model: + log_sigma = pm.Normal("log_sigma") + pm.Normal("x", mu=0, sigma=pm.math.exp(log_sigma / 2), shape=100) + +# Compile the model with the jax backend +compiled = nutpie.compile_pymc_model( + model, backend="jax", gradient_backend="jax" +) +``` + +If we sample this model without normalizing flow adaptation, we will encounter +convergence issues, often divergences and always low effective sample sizes: + +```{python} +# Sample without normalizing flow adaptation +trace_no_nf = nutpie.sample(compiled, seed=1) +assert (arviz.ess(trace_no_nf) < 100).any().to_array().any() +``` + +```{python} +# We can add further arguments for the normalizing flow: +compiled = compiled.with_transform_adapt( + num_layers=5, # Number of layers in the normalizing flow + nn_width=32, # Neural networks with 32 hidden units + num_diag_windows=6, # Number of windows with a diagonal mass matrix intead of a flow + verbose=False, # Whether to print details about the adaptation process + show_progress=False, # Whether to show a progress bar for each optimization step +) + +# Sample with normalizing flow adaptation +trace_nf = nutpie.sample( + compiled, + transform_adapt=True, # Enable the normalizing flow adaptation + seed=1, + chains=2, + cores=1, # Running chains in parallel can be slow + window_switch_freq=150, # Optimize the normalizing flow every 150 iterations +) +assert trace_nf.sample_stats.diverging.sum() == 0 +assert (arviz.ess(trace_nf) > 1000).all().to_array().all() +``` + +The sampler used fewer gradient evaluations with the normalizing flow adaptation, +but still converged, and produce a good effective sample size: + +```{python} +n_steps = int(trace_nf.sample_stats.n_steps.sum() + trace_nf.warmup_sample_stats.n_steps.sum()) +ess = float(arviz.ess(trace_nf).min().to_array().min()) +print(f"Number of gradient evaluations: {n_steps}") +print(f"Minimum effective sample size: {ess}") +``` + +Without normalizing flow, it used more gradient evaluations, and still wasn't able +to get a good effective sample size: + +```{python} +n_steps = int(trace_no_nf.sample_stats.n_steps.sum() + trace_no_nf.warmup_sample_stats.n_steps.sum()) +ess = float(arviz.ess(trace_no_nf).min().to_array().min()) +print(f"Number of gradient evaluations: {n_steps}") +print(f"Minimum effective sample size: {ess}") +``` + +The flow adaptation occurs during warmup, so the number of warmup draws should +be large enough to allow the flow to converge. For more complex posteriors, you +may need to increase the number of layers (using the `num_layers` argument), or +you might want to increase the number of warmup draws. + +To monitor the progress of the flow adaptation, you can set `verbose=True`, or +`show_progress=True`, but the second should only be used if you sample just one +chain. + +All losses are on a log-scale. Negative values smaller -2 are a good sign that +the adaptation was successful. If the loss stays positive, the flow is either +not expressive enough, or the training period is too short. The sampler might +still converge, but will probably need more gradient evaluations per effective +draw. Large losses bigger than 6 tend to indicate that the posterior is too +difficult to sample with the current flow, and the sampler will probably not +converge. diff --git a/docs/pymc-usage.qmd b/docs/pymc-usage.qmd new file mode 100644 index 0000000..56d81d7 --- /dev/null +++ b/docs/pymc-usage.qmd @@ -0,0 +1,195 @@ +# Usage with PyMC models + +This document shows how to use `nutpie` with PyMC models. We will use the +`pymc` package to define a simple model and sample from it using `nutpie`. + +## Installation + +The recommended way to install `pymc` is through the `conda` ecosystem. A good +package manager for conda packages is `pixi`. See for the [pixi +documentation](https://pixi.sh) for instructions on how to install it. + +We create a new project for this example: + +```bash +pixi new pymc-example +``` + +This will create a new directory `pymc-example` with a `pixi.toml` file, that +you can edit to add meta information. + +We then add the `pymc` and `nutpie` packages to the project: + +```bash +cd pymc-example +pixi add pymc nutpie arviz +``` + +You can use Visual Studio Code (VSCode) or JupyterLab to write and run our code. +Both are excellent tools for working with Python and data science projects. + +### Using VSCode + +1. Open VSCode. +2. Open the `pymc-example` directory created earlier. +3. Create a new file named `model.ipynb`. +4. Select the pixi kernel to run the code. + +### Using JupyterLab + +1. Add jupyter labs to the project by running `pixi add jupyterlab`. +1. Open JupyterLab by running `pixi run jupyter lab` in your terminal. +3. Create a new Python notebook. + +## Defining and Sampling a Simple Model + +We will define a simple Bayesian model using `pymc` and sample from it using +`nutpie`. + +### Model Definition + +In your `model.ipypy` file or Jupyter notebook, add the following code: + +```{python} +import pymc as pm +import nutpie +import pandas as pd + +coords = {"observation": range(3)} + +with pm.Model(coords=coords) as model: + # Prior distributions for the intercept and slope + intercept = pm.Normal("intercept", mu=0, sigma=1) + slope = pm.Normal("slope", mu=0, sigma=1) + + # Likelihood (sampling distribution) of observations + x = [1, 2, 3] + + mu = intercept + slope * x + y = pm.Normal("y", mu=mu, sigma=0.1, observed=[1, 2, 3], dims="observation") +``` + +### Sampling + +We can now compile the model using the numba backend: + +```{python} +compiled = nutpie.compile_pymc_model(model) +trace = nutpie.sample(compiled) +``` + +Alternatively, we can also sample through the `pymc` API: + +```python +with model: + trace = pm.sample(model, nuts_sampler="nutpie") +``` + +While sampling, nutpie shows a progress bar for each chain. It also includes +information about how each chain is doing: + +- It shows the current number of draws +- The step size of the integrator (very small stepsizes are typically a bad + sign) +- The number of divergences (if there are divergences, that means that nutpie is + probably not sampling the posterior correctly) +- The number of gradient evaluation nutpie uses for each draw. Large numbers + (100 to 1000) are a sign that the parameterization of the model is not ideal, + and the sampler is very inefficient. + +After sampling, this returns an `arviz` InferenceData object that you can use to +analyze the trace. + +For example, we should check the effective sample size: + +```{python} +import arviz as az +az.ess(trace) +``` + +and take a look at a trace plot: + +```{python} +az.plot_trace(trace); +``` + +### Choosing the backend + +Right now, we have been using the numba backend. This is the default backend for +`nutpie`, when sampling from pymc models. It tends to have relatively long +compilation times, but samples small models very efficiently. For larger models +the `jax` backend sometimes outperforms `numba`. + +First, we need to install the `jax` package: + +```bash +pixi add jax +``` + +We can select the backend by passing the `backend` argument to the `compile_pymc_model`: + +```python +compiled_jax = nutpie.compiled_pymc_model(model, backend="jax") +trace = nutpie.sample(compiled_jax) +``` + +Or through the pymc API: + +```python +with model: + trace = pm.sample( + model, + nuts_sampler="nutpie", + nuts_sampler_kwargs={"backend": "jax"}, + ) +``` + +If you have an nvidia GPU, you can also use the `jax` backend with the `gpu`. We +will have to install the `jaxlib` package with the `cuda` option + +```bash +pixi add jaxlib --build 'cuda12' +``` + +Restart the kernel and check that the GPU is available: + +```python +import jax + +# Should list the cuda device +jax.devices() +``` + +Sampling again, should now use the GPU, which you can observe by checking the +GPU usage with `nvidia-smi` or `nvtop`. + +### Changing the dataset without recompilation + +If you want to use the same model with different datasets, you can modify +datasets after compilation. Since jax does not like changes in shapes, this is +only recommended with the numba backend. + +First, we define the model, but put our dataset in a `pm.Data` structure: + +```{python} +with pm.Model() as model: + x = pm.Data("x", [1, 2, 3]) + intercept = pm.Normal("intercept", mu=0, sigma=1) + slope = pm.Normal("slope", mu=0, sigma=1) + mu = intercept + slope * x + y = pm.Normal("y", mu=mu, sigma=0.1, observed=[1, 2, 3]) +``` + +We can now compile the model: + +```{python} +compiled = nutpie.compile_pymc_model(model) +trace = nutpie.sample(compiled) +``` + +After compilation, we can change the dataset: + +```{python} +compiled2 = compiled.with_data(x=[4, 5, 6]) +trace2 = nutpie.sample(compiled2) +``` diff --git a/docs/sample-stats.qmd b/docs/sample-stats.qmd new file mode 100644 index 0000000..7cf92c9 --- /dev/null +++ b/docs/sample-stats.qmd @@ -0,0 +1,221 @@ +# Understanding Sampler Statistics in Nutpie + +This guide explains the various statistics that nutpie collects during sampling. We'll use Neal's funnel distribution as an example, as it's a challenging model that demonstrates many important sampling concepts. + +## Example Model: Neal's Funnel + +Let's start by implementing Neal's funnel in PyMC: + +```{python} +import pymc as pm +import nutpie +import numpy as np +import matplotlib.pyplot as plt +import seaborn as sns +import pandas as pd +import arviz as az + +# Create the funnel model +with pm.Model() as model: + log_sigma = pm.Normal('log_sigma') + pm.Normal('x', sigma=pm.math.exp(log_sigma), shape=5) + +# Sample with detailed statistics +compiled = nutpie.compile_pymc_model(model) +trace = nutpie.sample( + compiled, + tune=1000, + store_mass_matrix=True, + store_gradient=True, + store_unconstrained=True, + store_divergences=True, + seed=42, +) +``` + +## Sampler Statistics Overview + +The sampler statistics can be grouped into several categories: + +### Basic HMC Statistics + +These statistics are always collected and are essential for basic diagnostics: + +```{python} +# Access through trace.sample_stats +basic_stats = [ + 'depth', # Tree depth for current draw + 'maxdepth_reached', # Whether max tree depth was hit + 'logp', # Log probability of current position + 'energy', # Hamiltonian energy + 'diverging', # Whether the transition diverged + 'step_size', # Current step size + 'step_size_bar', # Current estimate of an ideal step size + 'n_steps' # Number of leapfrog steps + +] + +# Plot step size evolution during warmup +trace.warmup_sample_stats.step_size_bar.plot.line(x="draw", yscale="log") +``` + +### Mass Matrix Adaptation + +These statistics track how the mass matrix evolves: + +```{python} +( + trace + .warmup_sample_stats + .mass_matrix_inv + .plot + .line( + x="draw", + yscale="log", + col="chain", + col_wrap=2, + ) +) +``` + +Variables that are a source of convergence issues, will often show high variance +in the final mass matrix estimate across chains. + +The mass matrix will always be fixed for 10% of draws at the end, because we +only run final step size adaptation during that time, but high variance in the +mass matrix before this final window and indicate that more tuning steps might +be needed. + +### Detailed Diagnostics + +These are only available when explicitly requested: + +```python +detailed_stats = [ + 'gradient', # Gradient at current position + 'unconstrained_draw', # Parameters in unconstrained space + 'divergence_start', # Position where divergence started + 'divergence_end', # Position where divergence ended + 'divergence_momentum', # Momentum at divergence + 'divergence_message' # Description of divergence +] +``` + +#### Idintify Divergences + +We can for instance use this to identify the sources of divergences: + +```{python} +import xarray as xr + +draws = ( + trace + .sample_stats + .unconstrained_draw + .assign_coords(kind="draw") +) +divergence_locations = ( + trace + .sample_stats + .divergence_start + .assign_coords(kind="divergence") +) + +points = xr.concat([draws, divergence_locations], dim="kind") +points.to_dataset("unconstrained_parameter").plot.scatter(x="log_sigma", y="x_0", hue="kind") +``` + +#### Covariance of gradients and draws + +TODO this section should really use the transformed gradients and draws, not the +unconstrained ones, as that avoids the manual mass matrix correction. This +is only available for the normalizing flow adaptation at the moment though. + +In models with problematic posterior correlations, the singular value +decomposition of gradients and draws can often point us to the source of the +issue. + +Let's build a little model with correlations between parameters: + +```{python} +with pm.Model() as model: + x = pm.Normal('x') + y = pm.Normal("y", mu=x, sigma=0.01) + z = pm.Normal("z", mu=y, shape=100) + +compiled = nutpie.compile_pymc_model(model) +trace = nutpie.sample( + compiled, + tune=1000, + store_gradient=True, + store_unconstrained=True, + store_mass_matrix=True, + seed=42, +) +``` + +Now we can compute eigenvalues of the covariance matrix of the gradient and +draws (using the singular value decomposition to avoid quadratic cost): + +```{python} +def covariance_eigenvalues(x, mass_matrix): + assert x.dims == ("chain", "draw", "unconstrained_parameter") + x = x.stack(sample=["draw", "chain"]) + x = (x - x.mean("sample")) / np.sqrt(mass_matrix) + u, s, v = np.linalg.svd(x.T / np.sqrt(x.shape[1]), full_matrices=False) + print(u.shape, s.shape, v.shape) + s = xr.DataArray( + s, + dims=["eigenvalue"], + coords={"eigenvalue": range(s.size)}, + ) + v = xr.DataArray( + v, + dims=["eigenvalue", "unconstrained_parameter"], + coords={ + "eigenvalue": s.eigenvalue, + "unconstrained_parameter": x.unconstrained_parameter, + }, + ) + return s ** 2, v + +mass_matrix = trace.sample_stats.mass_matrix_inv.isel(draw=-1, chain=0) +draws_eigs, draws_eigv = covariance_eigenvalues(trace.sample_stats.unconstrained_draw, mass_matrix) +grads_eigs, grads_eigv = covariance_eigenvalues(trace.sample_stats.gradient, 1 / mass_matrix) + +draws_eigs.plot.line(x="eigenvalue", yscale="log") +grads_eigs.plot.line(x="eigenvalue", yscale="log") +``` + +We can see one very large and one very small eigenvalue in both covariances. +Large eigenvalues for the draws, and small eigenvalues for the gradients prevent +the sampler from taking larger steps. Small eigenvalues in the draws, and large +eigenvalues in the grads mean, that the sampler has to move far in parameter +space to get independent draws. So both lead to problems during sampling. For +models with many parameters, typically only the large eigenvalues of each are +meaningful, because of estimation issues with the small eigenvalues. + +We can also look at the eigenvectors to see which parameters are responsible for +the correlations: + +```{python} +( + draws_eigv + .sel(eigenvalue=0) + .to_pandas() + .sort_values(key=abs) + .tail(10) + .plot.bar(x="unconstrained_parameter") +) +``` + +```{python} +( + grads_eigv + .sel(eigenvalue=0) + .to_pandas() + .sort_values(key=abs) + .tail(10) + .plot.bar(x="unconstrained_parameter") +) +``` diff --git a/docs/sampling-options.qmd b/docs/sampling-options.qmd new file mode 100644 index 0000000..270fe69 --- /dev/null +++ b/docs/sampling-options.qmd @@ -0,0 +1,156 @@ +# Sampling Configuration Guide + +This guide covers the configuration options for `nutpie.sample` and provides +practical advice for tuning your sampler. We'll start with basic usage and move +to advanced topics like mass matrix adaptation. + +## Quick Start + +For most models, don't think too much about the options of the sampler, and just +use the defaults. Most sampling problems can't easily be solved by changing the +sampler, most of the time they require model changes. So in most cases, simply use + +```python +trace = nutpie.sample(compiled_model) +``` + +## Core Sampling Parameters + +### Drawing Samples + +```python +trace = nutpie.sample( + model, + draws=1000, # Number of post-warmup draws per chain + tune=500, # Number of warmup draws for adaptation + chains=6, # Number of independent chains + cores=None, # Number chains that are allowed to run simultainiously + seed=12345 # Random seed for reproducibility +) +``` + +The number of draws affects both accuracy and computational cost: +- Too few draws (< 500) may not capture the posterior well +- Too many draws (> 10000) may waste computation time + +If a model is sampling without divergences, but with effective sample sizes that +are not as large as necessary to accieve the markov-error for your estimates, +you can increase the number of chains and/or draws. + +If the effective sample size is much smaller than the number of draws, you might +want to consider reparameterizing the model instead, to for instance remove +posterior correlations. + +## Sampler Diagnostics + +You can enable more detailed diagnostics when troubleshooting: + +```python +trace = nutpie.sample( + model, + save_warmup=True, # Keep warmup draws, default is True + store_divergences=True, # Track divergent transitions + store_unconstrained=True, # Store transformed parameters + store_gradient=True, # Store gradient information + store_mass_matrix=True # Track mass matrix adaptation +) +``` + +For each of the `store_*` arguments, additional arrays will be availbale in the +`trace.sample_stats`. + +## Non-blocking sampling + + + +### Settings for HMC and NUTS + +```python +trace = nutpie.sample( + model, + target_accept=0.8, # Target acceptance rate + maxdepth=10 # Maximum tree depth + max_energy_error=1000 # Error at witch to count the trajectory as a divergent transition +) +``` + +The `target_accept` parameter implicitly controls the step size of the leapfrog +steps in the HMC sampler. During tuning, the sampler will try to choose a step +size, such that the acceptance statistic is `target_accept`. It has to be +between 0 and 1. + +The default is 0.8. Larger values will increase the computational cost, but +might avoid divergences during sampling. In many diverging models increasing +`target_accept` will only make divergences less frequent however, and not solve +the underlying problem. + +Lowering the maximum energy error to for instance 10 will often increase the +number of divergences, and make it easier to diagnose their cause. With lower +value the divergences often are reported closer to the critical points in the +parameter space, where the model is most likely to diverge. + +## Mass Matrix Adaptation + +Nutpie offers several strategies for adapting the mass matrix, which determines +how the sampler navigates the parameter space. + +### Standard Adaptation + +By setting `use_grad_based_mass_matrix=False`, the sampling algorithm will more +closely resemble the algorithm in Stan and PyMC. Usually, this will result in +less efficient sampling, but the total number of effective samples is sometimes +higher. If this is set to `True` (the default), nutpie will use diagonal mass +matrix estimates that are based on the posterior draws and the scores at those +positions. + +```python +trace = nutpie.sample( + model, + use_grad_based_mass_matrix=False +) +``` + +### Low-Rank Updates + +For models with strong parameter correlations you can enable a low rank modified +mass matrix. The `mass_matrix_gamma` parameter is a regularization parameter. +More regularization will lead to a smaller effect of the low-rank components, +but might work better for hiegher dimensional problems. + +`mass_matrix_eigval_cutoff` should be greater than one, and controls how large +an eigenvalue of the full mass matrix has to be, to be included into the +low-rank mass matirx. + +```python +trace = nutpie.sample( + model, + low_rank_modified_mass_matrix=True, + mass_matrix_eigval_cutoff=3, + mass_matrix_gamma=1e-5 +) +``` + +### Experimental Features + +`trasform_adapt` is an experimental feature that allows sampling from many +posteriors, where current methods diverge. It is described in more detail +[here](nf-adapt.qmd). + +```python +trace = nutpie.sample( + model, + transform_adapt=True # Experimental reparameterization +) +``` + +## Progress Monitoring + +Customize the sampling progress display: + +```python +trace = nutpie.sample( + model, + progress_bar=True, + progress_rate=500, # Update every 500ms +) +``` diff --git a/docs/stan-usage.qmd b/docs/stan-usage.qmd new file mode 100644 index 0000000..7296231 --- /dev/null +++ b/docs/stan-usage.qmd @@ -0,0 +1,230 @@ +# Usage with Stan models + +This document shows how to use `nutpie` with Stan models. We will use the +`nutpie` package to define a simple model and sample from it using Stan. + +## Installation + +For Stan, it is more common to use `pip` or `uv` to install the necessary +packages. However, `conda` is also an option if you prefer. + +To install using `pip`: + +```bash +pip install "nutpie[stan]" +``` + +To install using `uv`: + +```bash +uv add "nutpie[stan]" +``` + +To install using `conda`: + +```bash +conda install -c conda-forge nutpie +``` + +## Compiler Toolchain + +Stan requires a compiler toolchain to be installed on your system. This is +necessary for compiling the Stan models. You can find detailed instructions for +setting up the compiler toolchain in the [CmdStan +Guide](https://mc-stan.org/docs/cmdstan-guide/installation.html#cpp-toolchain). + +Additionally, since Stan uses Intel's Threading Building Blocks (TBB) for +parallelism, you might need to set the `TBB_CXX_TYPE` environment variable to +specify the compiler type. Depending on your system, you can set it to either +`clang` or `gcc`. For example: + +```{python} +import os +os.environ["TBB_CXX_TYPE"] = "clang" # or 'gcc' +``` + +Make sure to set this environment variable before compiling your Stan models to ensure proper configuration. + +## Defining and Sampling a Simple Model + +We will define a simple Bayesian model using Stan and sample from it using +`nutpie`. + +### Model Definition + +In your Python script or Jupyter notebook, add the following code: + +```{python} +import nutpie + +model_code = """ +data { + int N; + vector[N] y; +} +parameters { + real mu; +} +model { + mu ~ normal(0, 1); + y ~ normal(mu, 1); +} +""" + +compiled_model = nutpie.compile_stan_model(code=model_code) +``` + +### Sampling + +We can now compile the model and sample from it: + +```{python} +compiled_model_with_data = compiled_model.with_data(N=3, y=[1, 2, 3]) +trace = nutpie.sample(compiled_model_with_data) +``` + +### Using Dimensions + +We'll use the radon model from +[this](https://mc-stan.org/learn-stan/case-studies/radon_cmdstanpy_plotnine.html) +case-study from the stan documentation, to show how we can use coordinates and +dimension names to simplify working with trace objects. + +We follow the same data preparation as in the case-study: + +```{python} +import pandas as pd +import numpy as np +import arviz as az +import seaborn as sns + +home_data = pd.read_csv( + "https://github.com/pymc-devs/pymc-examples/raw/refs/heads/main/examples/data/srrs2.dat", + index_col="idnum", +) +county_data = pd.read_csv( + "https://github.com/pymc-devs/pymc-examples/raw/refs/heads/main/examples/data/cty.dat", +) + +radon_data = ( + home_data + .rename(columns=dict(cntyfips="ctfips")) + .merge( + ( + county_data + .drop_duplicates(['stfips', 'ctfips', 'st', 'cty', 'Uppm']) + .set_index(["ctfips", "stfips"]) + ), + right_index=True, + left_on=["ctfips", "stfips"], + ) + .assign(log_radon=lambda x: np.log(np.clip(x.activity, 0.1, np.inf))) + .assign(log_uranium=lambda x: np.log(np.clip(x["Uppm"], 0.1, np.inf))) + .query("state == 'MN'") +) +``` + +And also use the partially pooled model from the case-study: + +```{python} +model_code = """ +data { + int N; // observations + int J; // counties + array[N] int county; + vector[N] x; + vector[N] y; +} +parameters { + real mu_alpha; + real sigma_alpha; + vector[J] alpha; // non-centered parameterization + real beta; + real sigma; +} +model { + y ~ normal(alpha[county] + beta * x, sigma); + alpha ~ normal(mu_alpha, sigma_alpha); // partial-pooling + beta ~ normal(0, 10); + sigma ~ normal(0, 10); + mu_alpha ~ normal(0, 10); + sigma_alpha ~ normal(0, 10); +} +generated quantities { + array[N] real y_rep = normal_rng(alpha[county] + beta * x, sigma); +} +""" +``` + +We collect the dataset in the format that the stan model requires, +and specify the dimensions of each of the non-scalar variables in the model: + +```{python} +county_idx, counties = pd.factorize(radon_data["county"], use_na_sentinel=False) +observations = radon_data.index + +coords = { + "county": counties, + "observation": observations, +} + +dims = { + "alpha": ["county"], + "y_rep": ["observation"], +} + +data = { + "N": len(observations), + "J": len(counties), + # Stan uses 1-based indexing! + "county": county_idx + 1, + "x": radon_data.log_uranium.values, + "y": radon_data.log_radon.values, +} +``` + +Then, we compile the model and provide the dimensions, coordinates and the +dataset we just defined: + +```{python} +compiled_model = ( + nutpie.compile_stan_model(code=model_code) + .with_data(**data) + .with_dims(**dims) + .with_coords(**coords) +) +``` + +```{python} +%%time +trace = nutpie.sample(compiled_model, seed=0) +``` + +As some basic convergance checking we verify that all Rhat values are smaller +than 1.02, all parameters have at least 500 effective draws and that we have no +divergences: + +```{python} +assert trace.sample_stats.diverging.sum() == 0 +assert az.ess(trace).min().min() > 500 +assert az.rhat(trace).max().max() > 1.02 +``` + +Thanks to the coordinates and dimensions we specified, the resulting trace will +now contain labeled data, so that plots based on it have properly set-up labels: + +```{python} +import arviz as az +import seaborn as sns +import xarray as xr + +sns.catplot( + data=trace.posterior.alpha.to_dataframe().reset_index(), + y="county", + x="alpha", + kind="boxen", + height=13, + aspect=1/2.5, + showfliers=False, +) +``` diff --git a/docs/styles.css b/docs/styles.css new file mode 100644 index 0000000..2ddf50c --- /dev/null +++ b/docs/styles.css @@ -0,0 +1 @@ +/* css styles */ diff --git a/pyproject.toml b/pyproject.toml index 6ef4dff..12c5445 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,17 +2,12 @@ requires = ["maturin>=1.1,<2.0"] build-backend = "maturin" -[tool.maturin] -module-name = "nutpie._lib" -python-source = "python" -features = ["pyo3/extension-module"] - [project] name = "nutpie" description = "Sample Stan or PyMC models" authors = [{ name = "PyMC Developers", email = "pymc.devs@gmail.com" }] readme = "README.md" -requires-python = ">=3.10,<3.13" +requires-python = ">=3.10,<3.14" license = { text = "MIT" } classifiers = [ "Programming Language :: Rust", @@ -32,56 +27,41 @@ dynamic = ["version"] stan = ["bridgestan >= 2.6.1"] pymc = ["pymc >= 5.20.1", "numba >= 0.60.0"] pymc-jax = ["pymc >= 5.20.1", "jax >= 0.4.27"] +nnflow = ["flowjax >= 17.1.0", "equinox >= 0.11.12"] +dev = [ + "bridgestan >= 2.6.1", + "pymc >= 5.20.1", + "numba >= 0.60.0", + "jax >= 0.4.27", + "flowjax >= 17.0.2", + "pytest", +] all = [ "bridgestan >= 2.6.1", "pymc >= 5.20.1", "numba >= 0.60.0", "jax >= 0.4.27", + "flowjax >= 17.1.0", + "equinox >= 0.11.12", ] [tool.ruff] line-length = 88 -target-version = "py39" +target-version = "py310" show-fixes = true output-format = "full" -[tool.ruff.lint] -select = [ - "E", # pycodestyle errors - "W", # pycodestyle warnings - "F", # Pyflakes - "I", # isort - "C4", # flake8-comprehensions - "B", # flake8-bugbear - "UP", # pyupgrade - "RUF", # Ruff-specific rules - "TID", # flake8-tidy-imports - "BLE", # flake8-blind-except - "PTH", # flake8-pathlib - "A", # flake8-builtins -] -ignore = [ - "C408", # unnecessary-collection-call (allow dict(a=1, b=2); clarity over speed!) - # The following list is recommended to disable these when using ruff's formatter. - # (Not all of the following are actually enabled.) - "W191", # tab-indentation - "E111", # indentation-with-invalid-multiple - "E114", # indentation-with-invalid-multiple-comment - "E117", # over-indented - "D206", # indent-with-spaces - "D300", # triple-single-quotes - "Q000", # bad-quotes-inline-string - "Q001", # bad-quotes-multiline-string - "Q002", # bad-quotes-docstring - "Q003", # avoidable-escaped-quote - "COM812", # missing-trailing-comma - "COM819", # prohibited-trailing-comma - "ISC001", # single-line-implicit-string-concatenation - "ISC002", # multi-line-implicit-string-concatenation -] - [tool.ruff.lint.flake8-tidy-imports] ban-relative-imports = "all" [tool.ruff.lint.isort] known-first-party = ["nutpie"] + +[tool.pyright] +venvPath = ".pixi/envs/" +venv = "default" + +[tool.maturin] +module-name = "nutpie._lib" +python-source = "python" +features = ["pyo3/extension-module"] diff --git a/python/nutpie/compile_pymc.py b/python/nutpie/compile_pymc.py index 53f3176..9daec90 100644 --- a/python/nutpie/compile_pymc.py +++ b/python/nutpie/compile_pymc.py @@ -7,7 +7,7 @@ from functools import wraps from importlib.util import find_spec from math import prod -from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union, cast import numpy as np import pandas as pd @@ -274,7 +274,7 @@ def _compile_pymc_model_numba( warnings.filterwarnings( "ignore", message="Cannot cache compiled function .* as it uses dynamic globals", - category=numba.NumbaWarning, + category=numba.NumbaWarning, # type: ignore ) logp_numba = numba.cfunc(c_sig, **kwargs)(logp_numba_raw) @@ -287,7 +287,7 @@ def _compile_pymc_model_numba( warnings.filterwarnings( "ignore", message="Cannot cache compiled function .* as it uses dynamic globals", - category=numba.NumbaWarning, + category=numba.NumbaWarning, # type: ignore ) expand_numba = numba.cfunc(c_sig_expand, **kwargs)(expand_numba_raw) @@ -377,14 +377,24 @@ def _compile_pymc_model_jax( logp_fn = logp_fn_pt.vm.jit_fn expand_fn = expand_fn_pt.vm.jit_fn + logp_shared_names = [var.name for var in logp_fn_pt.get_shared()] + expand_shared_names = [var.name for var in expand_fn_pt.get_shared()] + if gradient_backend == "jax": orig_logp_fn = logp_fn._fun - @jax.jit def logp_fn_jax_grad(x, *shared): return jax.value_and_grad(lambda x: orig_logp_fn(x, *shared)[0])(x) + # static_argnums = list(range(1, len(logp_shared_names) + 1)) + logp_fn_jax_grad = jax.jit( + logp_fn_jax_grad, + # static_argnums=static_argnums, + ) + logp_fn = logp_fn_jax_grad + else: + orig_logp_fn = None shared_data = {} shared_vars = {} @@ -396,9 +406,6 @@ def logp_fn_jax_grad(x, *shared): shared_vars[val.name] = val seen.add(val) - logp_shared_names = [var.name for var in logp_fn_pt.get_shared()] - expand_shared_names = [var.name for var in expand_fn_pt.get_shared()] - def make_logp_func(): def logp(x, **shared): logp, grad = logp_fn(x, *[shared[name] for name in logp_shared_names]) @@ -407,7 +414,8 @@ def logp(x, **shared): return logp names, slices, shapes = shape_info - dtypes = [np.float64] * len(names) + # TODO do not cast to float64 + dtypes = [np.dtype("float64")] * len(names) def make_expand_func(seed1, seed2, chain): # TODO handle seeds @@ -433,6 +441,7 @@ def expand(x, **shared): shared_data=shared_data, dims=dims, coords=coords, + raw_logp_fn=orig_logp_fn, ) @@ -635,7 +644,7 @@ def _make_functions( """ import pytensor import pytensor.tensor as pt - from pymc.pytensorf import compile_pymc + from pymc.pytensorf import compile as compile_pymc shapes = _compute_shapes(model) @@ -726,7 +735,7 @@ def _make_functions( for var in remaining_rvs: all_names.append(var.name) - shape = shapes[var.name] + shape = cast(tuple[int, ...], shapes[var.name]) all_shapes.append(shape) length = prod(shape) all_slices.append(slice(count, count + length)) diff --git a/python/nutpie/compile_stan.py b/python/nutpie/compile_stan.py index 7a28052..138652d 100644 --- a/python/nutpie/compile_stan.py +++ b/python/nutpie/compile_stan.py @@ -1,6 +1,7 @@ import json import tempfile from dataclasses import dataclass, replace +from functools import partial from importlib.util import find_spec from pathlib import Path from typing import Any, Optional @@ -11,6 +12,7 @@ from nutpie import _lib from nutpie.sample import CompiledModel +from nutpie.transform_adapter import make_transform_adapter class _NumpyArrayEncoder(json.JSONEncoder): @@ -28,6 +30,7 @@ class CompiledStanModel(CompiledModel): library: Any model: Any model_name: Optional[str] = None + _transform_adapt_args: dict | None = None def with_data(self, *, seed=None, **updates): if self.data is None: @@ -42,7 +45,15 @@ def with_data(self, *, seed=None, **updates): else: data_json = None - model = _lib.StanModel(self.library, seed, data_json) + kwargs = self._transform_adapt_args + if kwargs is None: + kwargs = {} + make_adapter = partial( + make_transform_adapter(**kwargs), + logp_fn=None, + ) + + model = _lib.StanModel(self.library, seed, data_json, make_adapter) coords = self._coords if coords is None: coords = {} @@ -75,6 +86,9 @@ def with_dims(self, **dims): dims_new.update(dims) return replace(self, dims=dims_new) + def with_transform_adapt(self, **kwargs): + return replace(self, _transform_adapt_args=kwargs).with_data() + def _make_model(self, init_mean): if self.model is None: return self.with_data().model diff --git a/python/nutpie/compiled_pyfunc.py b/python/nutpie/compiled_pyfunc.py index fb6553d..304d50f 100644 --- a/python/nutpie/compiled_pyfunc.py +++ b/python/nutpie/compiled_pyfunc.py @@ -5,8 +5,9 @@ import numpy as np -from nutpie import _lib +from nutpie import _lib # type: ignore from nutpie.sample import CompiledModel +from nutpie.transform_adapter import make_transform_adapter SeedType = int @@ -20,6 +21,8 @@ class PyFuncModel(CompiledModel): _n_dim: int _variables: list[_lib.PyVariable] _coords: dict[str, Any] + _raw_logp_fn: Callable | None + _transform_adapt_args: dict | None = None @property def shapes(self) -> dict[str, tuple[int, ...]]: @@ -42,6 +45,9 @@ def with_data(self, **updates): updated.update(**updates) return dataclasses.replace(self, _shared_data=updated) + def with_transform_adapt(self, **kwargs): + return dataclasses.replace(self, _transform_adapt_args=kwargs) + def _make_sampler(self, settings, init_mean, cores, progress_type): model = self._make_model(init_mean) return _lib.PySampler.from_pyfunc( @@ -60,12 +66,24 @@ def make_expand_func(seed1, seed2, chain): expand_fn = self._make_expand_func(seed1, seed2, chain) return partial(expand_fn, **self._shared_data) + if self._raw_logp_fn is not None: + kwargs = self._transform_adapt_args + if kwargs is None: + kwargs = {} + make_adapter = partial( + make_transform_adapter(**kwargs), + logp_fn=self._raw_logp_fn, + ) + else: + make_adapter = None + return _lib.PyModel( make_logp_func, make_expand_func, self._variables, self.n_dim, - self._make_initial_points, + init_point_func=self._make_initial_points, + transform_adapter=make_adapter, ) @@ -81,6 +99,8 @@ def from_pyfunc( dims: dict[str, tuple[str, ...]] | None = None, shared_data: dict[str, Any] | None = None, make_initial_point_fn: Callable[[SeedType], np.ndarray] | None = None, + make_transform_adapter=None, + raw_logp_fn=None, ): variables = [] for name, shape, dtype in zip( @@ -111,4 +131,5 @@ def from_pyfunc( _make_initial_points=make_initial_point_fn, _variables=variables, _shared_data=shared_data, + _raw_logp_fn=raw_logp_fn, ) diff --git a/python/nutpie/normalizing_flow.py b/python/nutpie/normalizing_flow.py new file mode 100644 index 0000000..e92f7b0 --- /dev/null +++ b/python/nutpie/normalizing_flow.py @@ -0,0 +1,1064 @@ +from typing import ClassVar, Union, Literal, Callable +import math +import itertools + +from flowjax.bijections.coupling import get_ravelled_pytree_constructor +from flowjax.utils import arraylike_to_array +import jax +import jax.numpy as jnp +import equinox as eqx +from flowjax import bijections +import flowjax.distributions +import flowjax.flows +from jaxtyping import Array, ArrayLike +import numpy as np +from paramax import NonTrainable, Parameterize +from equinox.nn import Linear +from paramax.wrappers import AbstractUnwrappable + + +_NN_ACTIVATION = jax.nn.gelu + + +def _generate_sequences(k, r_vals): + """ + Generate all binary sequences of length k with exactly r 1's. + The sequences are stored in a preallocated boolean NumPy array of shape (N, k), + where N = comb(k, r). A True value represents a '1' and False represents a '0'. + + Parameters: + k (int): The length of each sequence. + r (int): The exact number of ones in each sequence. + + Returns: + A NumPy boolean array of shape (comb(k, r), k) containing all sequences. + """ + all_sequences = [] + for r in r_vals: + N = math.comb(k, r) # number of sequences + sequences = np.zeros((N, k), dtype=bool) + # Use enumerate on all combinations where ones appear. + for i, ones_positions in enumerate(itertools.combinations(range(k), r)): + sequences[i, list(ones_positions)] = True + all_sequences.append(sequences) + return np.concatenate(all_sequences, axis=0) + + +def _max_run_length(seq): + """ + Given a 1D boolean NumPy array 'seq', compute the maximum run length of consecutive + identical values (either True or False). + + Parameters: + seq (np.array): A 1D boolean array. + + Returns: + The length (int) of the longest run. + """ + # If the sequence is empty, return 0. + if seq.size == 0: + return 0 + + # Convert boolean to int (0 or 1) so we can use np.diff. + arr = seq.astype(int) + # Compute differences between consecutive elements. + diffs = np.diff(arr) + # Positions where the value changes: + change_indices = np.nonzero(diffs)[0] + + if change_indices.size == 0: + # No changes at all, so the entire sequence is one run. + return seq.size + + # To compute the run lengths, add the "start" index (-1) and the last index. + # For example, if change_indices = [i1, i2, ..., in], + # then the runs are: (i1 - (-1)), (i2 - i1), ..., (seq.size-1 - in). + boundaries = np.concatenate(([-1], change_indices, [seq.size - 1])) + run_lengths = np.diff(boundaries) + return int(run_lengths.max()) + + +def _filter_sequences(sequences, m): + """ + Filter a 2D NumPy boolean array 'sequences' (each row a binary sequence) so that + only sequences with maximum run length (of 0's or 1's) at most m are kept. + + Parameters: + sequences (np.array): A 2D boolean array of shape (N, k). + m (int): Maximum allowed run length. + + Returns: + A NumPy array containing only the rows (sequences) that pass the filter. + """ + filtered = [] + for seq in sequences: + if _max_run_length(seq) <= m: + filtered.append(seq) + return np.array(filtered) + + +def _generate_permutations(rng, n_dim, n_layers, max_run=3): + if n_layers == 1: + r = [0, 1] + elif n_layers == 2: + r = [1] + else: + if n_layers % 2 == 0: + half = n_layers // 2 + r = [half - 1, half, half + 1] + else: + half = n_layers // 2 + r = [half, half + 1] + + all_sequences = _generate_sequences(n_layers, r) + valid_sequences = _filter_sequences(all_sequences, max_run) + + valid_sequences = np.repeat( + valid_sequences, n_dim // len(valid_sequences) + 1, axis=0 + ) + rng.shuffle(valid_sequences, axis=0) + is_in_first = valid_sequences[:n_dim] + rng = np.random.default_rng(42) + permutations = (~is_in_first).argsort(axis=0, kind="stable") + return permutations.T, is_in_first.sum(0) + + +class FactoredMLP(eqx.Module, strict=True): + """Standard Multi-Layer Perceptron; also known as a feed-forward network. + + !!! faq + + If you get a TypeError saying an object is not a valid JAX type, see the + [FAQ](https://docs.kidger.site/equinox/faq/).""" + + layers: tuple[tuple[Linear, Linear], ...] + activation: tuple[Callable, ...] + final_activation: Callable + use_bias: bool = eqx.field(static=True) + use_final_bias: bool = eqx.field(static=True) + in_size: Union[int, Literal["scalar"]] = eqx.field(static=True) + out_size: Union[int, Literal["scalar"]] = eqx.field(static=True) + width_size: tuple[int, ...] = eqx.field(static=True) + depth: int = eqx.field(static=True) + + def __init__( + self, + in_size: Union[int, Literal["scalar"]], + out_size: Union[int, Literal["scalar"]], + width_size: int | tuple[int | tuple[int, int], ...], + depth: int, + activation: Callable = jax.nn.relu, + final_activation: Callable = lambda x: x, + use_bias: bool = True, + use_final_bias: bool = True, + dtype=None, + *, + key, + ): + """**Arguments**: + + - `in_size`: The input size. The input to the module should be a vector of + shape `(in_features,)` + - `out_size`: The output size. The output from the module will be a vector + of shape `(out_features,)`. + - `width_size`: The size of each hidden layer. + - `depth`: The number of hidden layers, including the output layer. + For example, `depth=2` results in an network with layers: + [`Linear(in_size, width_size)`, `Linear(width_size, width_size)`, + `Linear(width_size, out_size)`]. + - `activation`: The activation function after each hidden layer. Defaults to + ReLU. + - `final_activation`: The activation function after the output layer. Defaults + to the identity. + - `use_bias`: Whether to add on a bias to internal layers. Defaults + to `True`. + - `use_final_bias`: Whether to add on a bias to the final layer. Defaults + to `True`. + - `dtype`: The dtype to use for all the weights and biases in this MLP. + Defaults to either `jax.numpy.float32` or `jax.numpy.float64` depending + on whether JAX is in 64-bit mode. + - `key`: A `jax.random.PRNGKey` used to provide randomness for parameter + initialisation. (Keyword only argument.) + + Note that `in_size` also supports the string `"scalar"` as a special value. + In this case the input to the module should be of shape `()`. + + Likewise `out_size` can also be a string `"scalar"`, in which case the + output from the module will have shape `()`. + """ + keys = jax.random.split(key, depth + 1) + layers = [] + if isinstance(width_size, int): + width_size = (width_size,) * depth + + assert len(width_size) == depth + activations: list[Callable] = [] + + if depth == 0: + layers.append( + Linear(in_size, out_size, use_final_bias, dtype=dtype, key=keys[0]) + ) + else: + if isinstance(width_size[0], tuple): + n, k = width_size[0] + key1, key2 = jax.random.split(keys[0]) + U = Linear(in_size, n, use_bias=False, dtype=dtype, key=key1) + K = Linear(n, k, use_bias=True, dtype=dtype, key=key2) + layers.append((U, K)) + else: + k = width_size[0] + layers.append(Linear(in_size, k, use_bias, dtype=dtype, key=keys[0])) + activations.append(eqx.filter_vmap(lambda: activation, axis_size=k)()) + + for i in range(depth - 1): + if isinstance(width_size[i + 1], tuple): + n, k_new = width_size[i + 1] + key1, key2 = jax.random.split(keys[i + 1]) + U = Linear(k, n, use_bias=False, dtype=dtype, key=key1) + K = Linear(n, k_new, use_bias=True, dtype=dtype, key=key2) + layers.append((U, K)) + k = k_new + else: + layers.append( + Linear( + k, width_size[i + 1], use_bias, dtype=dtype, key=keys[i + 1] + ) + ) + k = width_size[i + 1] + activations.append(eqx.filter_vmap(lambda: activation, axis_size=k)()) + + if isinstance(out_size, tuple): + n, k_new = out_size + key1, key2 = jax.random.split(keys[-1]) + U = Linear(k, n, use_bias=False, dtype=dtype, key=key1) + K = Linear(n, k_new, use_bias=True, dtype=dtype, key=key2) + k = k_new + layers.append((U, K)) + else: + layers.append( + Linear(k, out_size, use_final_bias, dtype=dtype, key=keys[-1]) + ) + self.layers = tuple(layers) + self.in_size = in_size + self.out_size = out_size + self.width_size = width_size + self.depth = depth + # In case `activation` or `final_activation` are learnt, then make a separate + # copy of their weights for every neuron. + self.activation = tuple(activations) + if out_size == "scalar": + self.final_activation = final_activation + else: + self.final_activation = eqx.filter_vmap( + lambda: final_activation, axis_size=out_size + )() + self.use_bias = use_bias + self.use_final_bias = use_final_bias + + @jax.named_scope("eqx.nn.MLP") + def __call__(self, x: jax.Array, *, key=None) -> jax.Array: + """**Arguments:** + + - `x`: A JAX array with shape `(in_size,)`. (Or shape `()` if + `in_size="scalar"`.) + - `key`: Ignored; provided for compatibility with the rest of the Equinox API. + (Keyword only argument.) + + **Returns:** + + A JAX array with shape `(out_size,)`. (Or shape `()` if `out_size="scalar"`.) + """ + for i, (layer, act) in enumerate(zip(self.layers[:-1], self.activation)): + if isinstance(layer, tuple): + U, K = layer + x = U(x) + x = K(x) + else: + x = layer(x) + layer_activation = jax.tree.map( + lambda x: x[i] if eqx.is_array(x) else x, act + ) + x = eqx.filter_vmap(lambda a, b: a(b))(layer_activation, x) + + if isinstance(self.layers[-1], tuple): + U, K = self.layers[-1] + x = U(x) + x = K(x) + else: + x = self.layers[-1](x) + + if self.out_size == "scalar": + x = self.final_activation(x) + else: + x = eqx.filter_vmap(lambda a, b: a(b))(self.final_activation, x) + return x + + +class AsymmetricAffine(bijections.AbstractBijection): + """An asymmetric bijection that applies different scaling factors for + positive and negative inputs. + + This bijection implements a continuous, differentiable transformation that + scales positive and negative inputs differently while maintaining smoothness + at zero. It's particularly useful for modeling data with different variances + in positive and negative regions. + + The forward transformation is defined as: + y = σ θ x for x ≥ 0 + y = σ x/θ for x < 0 + where: + - σ (scale) controls the overall scaling + - θ (theta) controls the asymmetry between positive and negative regions + - μ (loc) controls the location shift + + The transformation uses a smooth transition between the two regions to + maintain differentiability. + + For θ = 0, this is exactly an affine function with the specified location + and scale. + + Attributes: + shape: The shape of the transformation parameters + cond_shape: Shape of conditional inputs (None as this bijection is + unconditional) + loc: Location parameter μ for shifting the distribution + scale: Scale parameter σ (positive) + theta: Asymmetry parameter θ (positive) + """ + + shape: tuple[int, ...] = () + cond_shape: ClassVar[None] = None + loc: Array + scale: Array | AbstractUnwrappable[Array] + theta: Array | AbstractUnwrappable[Array] + + def __init__( + self, + loc: ArrayLike = 0, + scale: ArrayLike = 1, + theta: ArrayLike = 1, + ): + self.loc, scale, theta = jnp.broadcast_arrays( + *(arraylike_to_array(a, dtype=float) for a in (loc, scale, theta)), + ) + self.shape = scale.shape + assert self.shape == () + self.scale = Parameterize(lambda x: x + jnp.sqrt(1 + x**2), jnp.zeros(())) + self.theta = Parameterize(lambda x: x + jnp.sqrt(1 + x**2), jnp.zeros(())) + + def _log_derivative_f(self, x, mu, sigma, theta): + abs_x = jnp.abs(x) + theta = jnp.log(theta) + + sinh_theta = jnp.sinh(theta) + # sinh_theta = (theta - 1 / theta) / 2 + cosh_theta = jnp.cosh(theta) + # cosh_theta = (theta + 1 / theta) / 2 + numerator = sinh_theta * x * (abs_x + 2.0) + denominator = (abs_x + 1.0) ** 2 + term = numerator / denominator + dy_dx = sigma * (cosh_theta + term) + return jnp.log(dy_dx) + + def transform_and_log_det( + self, x: ArrayLike, condition: ArrayLike | None = None + ) -> tuple[Array, Array]: + def transform(x, mu, sigma, theta): + weight = (jax.nn.soft_sign(x) + 1) / 2 + z = x * sigma + y_pos = z * theta + y_neg = z / theta + y = weight * y_pos + (1.0 - weight) * y_neg + mu + return y + + mu, sigma, theta = self.loc, self.scale, self.theta + + y = transform(x, mu, sigma, theta) + logjac = self._log_derivative_f(x, mu, sigma, theta) + return y, logjac.sum() + # y, jac = jax.value_and_grad(transform, argnums=0)(x, mu, sigma, theta) + # return y, jnp.log(jac) + + def inverse_and_log_det( + self, y: ArrayLike, condition: ArrayLike | None = None + ) -> tuple[Array, Array]: + def inverse(y, mu, sigma, theta): + delta = y - mu + inv_theta = 1 / theta + + # Case 1: y >= mu (delta >= 0) + a = sigma * (theta + inv_theta) + discriminant_pos = ( + jnp.square(a - 2.0 * delta) + 16.0 * sigma * theta * delta + ) + discriminant_pos = jnp.where(discriminant_pos < 0, 1.0, discriminant_pos) + sqrt_pos = jnp.sqrt(discriminant_pos) + numerator_pos = 2.0 * delta - a + sqrt_pos + denominator_pos = 4.0 * sigma * theta + x_pos = numerator_pos / denominator_pos + + # Case 2: y < mu (delta < 0) + sigma_part = sigma * (1.0 + theta * theta) + term2 = 2.0 * delta * theta + inside_sqrt_neg = ( + jnp.square(sigma_part + term2) - 16.0 * sigma * delta * theta + ) + inside_sqrt_neg = jnp.where(inside_sqrt_neg < 0, 1.0, inside_sqrt_neg) + sqrt_neg = jnp.sqrt(inside_sqrt_neg) + numerator_neg = sigma_part + term2 - sqrt_neg + denominator_neg = 4.0 * sigma + x_neg = numerator_neg / denominator_neg + + # Combine cases based on delta + x = jnp.where(delta >= 0.0, x_pos, x_neg) + return x + + mu, sigma, theta = self.loc, self.scale, self.theta + + x = inverse(y, mu, sigma, theta) + logjac = self._log_derivative_f(x, mu, sigma, theta) + return x, -logjac.sum() + # x, jac = jax.value_and_grad(inverse, argnums=0)(y, mu, sigma, theta) + # return x, jnp.log(jac) + + +class MvScale(bijections.AbstractBijection): + shape: tuple[int, ...] + params: Array + cond_shape = None + base_index: int + + def __init__(self, params: Array, base_index: int = 0): + self.shape = (params.shape[-1],) + self.params = params + self.base_index = base_index + + def transform_and_log_det(self, x: jnp.ndarray, condition: Array | None = None): + scale = jnp.linalg.norm(self.params) + v = self.params / scale + y = x + ((v @ x) * (scale - 1)) * v + return y, jnp.log(scale) + + def inverse_and_log_det(self, y: Array, condition: Array | None = None): + scale = jnp.linalg.norm(self.params) + v = self.params / scale + x = y + ((v @ y) * (1 / scale - 1)) * v + return x, -jnp.log(scale) + + +class Coupling(bijections.AbstractBijection): + """Coupling layer implementation (https://arxiv.org/abs/1605.08803). + + Args: + key: Jax key + transformer: Unconditional bijection with shape () to be parameterised by the + conditioner neural netork. Parameters wrapped with ``NonTrainable`` + are excluded from being parameterized. + untransformed_dim: Number of untransformed conditioning variables (e.g. dim//2). + dim: Total dimension. + cond_dim: Dimension of additional conditioning variables. Defaults to None. + nn_width: Neural network hidden layer width. + nn_depth: Neural network hidden layer size. + nn_activation: Neural network activation function. Defaults to jnn.relu. + """ + + shape: tuple[int, ...] + cond_shape: tuple[int, ...] | None + untransformed_dim: int + dim: int + transformer_constructor: Callable + requires_vmap: bool + conditioner: eqx.nn.MLP | eqx.Module + + def __init__( + self, + key, + *, + transformer: bijections.AbstractBijection, + untransformed_dim: int, + dim: int, + cond_dim: int | None = None, + nn_width: int, + nn_depth: int, + nn_activation: Callable = jax.nn.relu, + conditioner: eqx.Module | None = None, + ): + if transformer.cond_shape is not None: + raise ValueError( + "Only unconditional transformers are supported.", + ) + n_transformed = dim - untransformed_dim + if n_transformed < 0: + raise ValueError( + "The number of untransformed variables must be less than the total " + "dimension.", + ) + if transformer.shape != () and transformer.shape != (n_transformed,): + raise ValueError( + "The transformer must have shape () or (n_transformed,), " + f"got {transformer.shape}.", + ) + + constructor, num_params = get_ravelled_pytree_constructor( + transformer, + filter_spec=eqx.is_inexact_array, + is_leaf=lambda leaf: isinstance(leaf, NonTrainable), + ) + + if transformer.shape == (): + self.requires_vmap = True + conditioner_output_size = num_params * n_transformed + else: + self.requires_vmap = False + conditioner_output_size = num_params + + self.transformer_constructor = constructor + self.untransformed_dim = untransformed_dim + self.dim = dim + self.shape = (dim,) + self.cond_shape = (cond_dim,) if cond_dim is not None else None + + if conditioner is None: + conditioner = eqx.nn.MLP( + in_size=( + untransformed_dim + if cond_dim is None + else untransformed_dim + cond_dim + ), + out_size=conditioner_output_size, + width_size=nn_width, + depth=nn_depth, + activation=nn_activation, + key=key, + ) + self.conditioner = conditioner(conditioner_output_size) + + def transform_and_log_det(self, x, condition=None): + x_cond, x_trans = x[: self.untransformed_dim], x[self.untransformed_dim :] + nn_input = x_cond if condition is None else jnp.hstack((x_cond, condition)) + transformer_params = self.conditioner(nn_input) + transformer = self._flat_params_to_transformer(transformer_params) + y_trans, log_det = transformer.transform_and_log_det(x_trans) + y = jnp.hstack((x_cond, y_trans)) + return y, log_det + + def inverse_and_log_det(self, y, condition=None): + x_cond, y_trans = y[: self.untransformed_dim], y[self.untransformed_dim :] + nn_input = x_cond if condition is None else jnp.concatenate((x_cond, condition)) + transformer_params = self.conditioner(nn_input) + transformer = self._flat_params_to_transformer(transformer_params) + x_trans, log_det = transformer.inverse_and_log_det(y_trans) + x = jnp.hstack((x_cond, x_trans)) + return x, log_det + + def _flat_params_to_transformer(self, params: Array): + """Reshape to dim X params_per_dim, then vmap.""" + if self.requires_vmap: + dim = self.dim - self.untransformed_dim + transformer_params = jnp.reshape(params, (dim, -1)) + transformer = eqx.filter_vmap(self.transformer_constructor)( + transformer_params + ) + return bijections.Vmap(transformer, in_axes=eqx.if_array(0)) + else: + transformer = self.transformer_constructor(params) + return transformer + + +def make_mvscale(key, n_dim, size, randomize_base=False): + def make_single_hh(key, idx): + key1, key2 = jax.random.split(key) + params = jax.random.normal(key1, (n_dim,)) + params = params / jnp.linalg.norm(params) + mvscale = MvScale(params) + return mvscale + + keys = jax.random.split(key, size) + + if randomize_base: + key, key_base = jax.random.split(key) + indices = jax.random.randint(key_base, (size,), 0, n_dim) + else: + indices = [val % n_dim for val in range(size)] + + return bijections.Chain( + [make_single_hh(key, idx) for key, idx in zip(keys, indices)] + ) + + +def make_hh(key, n_dim, size, randomize_base=False): + def make_single_hh(key, idx): + key1, key2 = jax.random.split(key) + params = jax.random.normal(key1, (n_dim,)) * 1e-2 + return bijections.Householder(params, base_index=idx) + + keys = jax.random.split(key, size) + + if randomize_base: + key, key_base = jax.random.split(key) + indices = jax.random.randint(key_base, (size,), 0, n_dim) + else: + indices = [val % n_dim for val in range(size)] + + return bijections.Chain( + [make_single_hh(key, idx) for key, idx in zip(keys, indices)] + ) + + +def make_elemwise_trafo(key, n_dim, *, count=1): + def make_elemwise(key, loc): + key1, key2 = jax.random.split(key) + scale = Parameterize(lambda x: x + jnp.sqrt(1 + x**2), jnp.zeros(())) + theta = Parameterize(lambda x: x + jnp.sqrt(1 + x**2), jnp.zeros(())) + + affine = AsymmetricAffine( + loc, + jnp.ones(()), + jnp.ones(()), + ) + + affine = eqx.tree_at( + where=lambda aff: aff.scale, + pytree=affine, + replace=scale, + ) + affine = eqx.tree_at( + where=lambda aff: aff.theta, + pytree=affine, + replace=theta, + ) + + return bijections.Invert(affine) + + def make(key): + keys = jax.random.split(key, count + 1) + key, keys = keys[0], keys[1:] + loc = jax.random.normal(key=key, shape=(count,)) * 2 + loc = loc - loc.mean() + return bijections.Chain([make_elemwise(key, mu) for key, mu in zip(keys, loc)]) + + keys = jax.random.split(key, n_dim) + make_affine = eqx.filter_vmap(make, axis_size=n_dim)(keys) + return bijections.Vmap(make_affine, in_axes=eqx.if_array(0)) + + +def make_coupling(key, dim, n_untransformed, *, inner_mvscale=False, **kwargs): + n_transformed = dim - n_untransformed + + nn_width = kwargs.get("nn_width", None) + nn_depth = kwargs.get("nn_depth", None) + + if nn_width is None: + if dim > 128: + nn_width = (64, 2 * dim) + else: + nn_width = 2 * dim + + if nn_depth is None: + if isinstance(nn_width, int): + nn_depth = 1 + else: + nn_depth = len(nn_width) + + transformer = make_elemwise_trafo(key, n_transformed, count=3) + + if inner_mvscale: + mvscale = make_mvscale(key, n_transformed, 1, randomize_base=True) + transformer = bijections.Chain([transformer, mvscale]) + + def make_mlp(out_size): + if isinstance(nn_width, tuple): + out = (nn_width[0], out_size) + else: + out = out_size + + return FactoredMLP( + n_untransformed, + out, + nn_width, + depth=nn_depth, + key=key, + dtype=jnp.float32, + activation=_NN_ACTIVATION, + ) + + return Coupling( + key, + transformer=transformer, + untransformed_dim=n_untransformed, + dim=dim, + conditioner=make_mlp, + **kwargs, + ) + + +def make_flow( + seed, + positions, + gradients, + *, + zero_init=False, + householder_layer=False, + dct_layer=False, + untransformed_dim: int | list[int | None] | None = None, + n_layers, + nn_width=None, + nn_depth=None, +): + from flowjax import bijections + + positions = np.array(positions) + gradients = np.array(gradients) + + if len(positions) == 0: + return + + n_draws, n_dim = positions.shape + + if n_dim < 2: + n_layers = 0 + + assert positions.shape == gradients.shape + + if n_draws == 0: + raise ValueError("No draws") + elif n_draws == 1: + assert np.all(gradients != 0) + diag = np.clip(1 / jnp.sqrt(jnp.abs(gradients[0])), 1e-5, 1e5) + assert np.isfinite(diag).all() + mean = jnp.zeros_like(diag) + else: + pos_std = np.clip(positions.std(0), 1e-8, 1e8) + grad_std = np.clip(gradients.std(0), 1e-8, 1e8) + diag = jnp.sqrt(pos_std / grad_std) + mean = positions.mean(0) + gradients.mean(0) * diag * diag + + key = jax.random.PRNGKey(seed % (2**63)) + + diag_param = Parameterize( + lambda x: x + jnp.sqrt(1 + x**2), + (diag**2 - 1) / (2 * diag), + ) + diag_affine = bijections.Affine(mean, diag) + diag_affine = eqx.tree_at( + where=lambda aff: aff.scale, + pytree=diag_affine, + replace=diag_param, + ) + + flows = [ + diag_affine, + ] + + if n_layers == 0: + return bijections.Chain(flows) + + def make_layer(key, untransformed_dim: int | None, permutation=None): + key, key_couple, key_permute, key_hh = jax.random.split(key, 4) + + if untransformed_dim is None: + untransformed_dim = n_dim // 2 + + if untransformed_dim < 0: + untransformed_dim = n_dim + untransformed_dim + + coupling = make_coupling( + key_couple, + n_dim, + untransformed_dim, + nn_activation=_NN_ACTIVATION, + nn_width=nn_width, + nn_depth=nn_depth, + ) + + if zero_init: + coupling = jax.tree_util.tree_map( + lambda x: x * 1e-3 if eqx.is_inexact_array(x) else x, + coupling, + ) + + flow = coupling + + if householder_layer: + hh = make_hh(key_hh, n_dim, 1, randomize_base=False) + flow = bijections.Sandwich(flow, hh) + + def add_default_permute(bijection, dim, key): + if dim == 1: + return bijection + if dim == 2: + outer = bijections.Flip((dim,)) + else: + outer = bijections.Permute(jax.random.permutation(key, jnp.arange(dim))) + + return bijections.Sandwich(bijection, outer) + + if permutation is None: + flow = add_default_permute(flow, n_dim, key_permute) + else: + flow = bijections.Sandwich(flow, bijections.Permute(permutation)) + + mvscale = make_mvscale(key, n_dim, 1, randomize_base=True) + + flow = bijections.Chain( + [ + mvscale, + flow, + ] + ) + + return flow + + key, key_permute = jax.random.split(key) + keys = jax.random.split(key, n_layers) + + if untransformed_dim is None: + # TODO better rng? + rng = np.random.default_rng(int(jax.random.randint(key, (), 0, 2**30))) + permutation, lengths = _generate_permutations(rng, n_dim, n_layers) + layers = [] + for i, (key, p, length) in enumerate(zip(keys, permutation, lengths)): + layers.append(make_layer(key, int(length), p)) + bijection = bijections.Chain(layers) + elif isinstance(untransformed_dim, int): + make_layers = eqx.filter_vmap(make_layer) + layers = make_layers(keys, untransformed_dim) + bijection = bijections.Scan(layers) + else: + layers = [] + for i, (key, num_untrafo) in enumerate(zip(keys, untransformed_dim)): + if i % 2 == 0 or not dct_layer: + layers.append(make_layer(key, num_untrafo)) + else: + inner = make_layer(key, num_untrafo) + outer = bijections.DCT(inner.shape) + + layers.append(bijections.Sandwich(inner, outer)) + + bijection = bijections.Chain(layers) + + return bijections.Chain([bijection, *flows]) + + +def extend_flow( + key, + base, + loss_fn, + positions, + gradients, + logps, + layer: int, + *, + extension_var_count=4, + zero_init=False, + householder_layer=False, + untransformed_dim: int | list[int | None] | None = None, + dct: bool = False, + extension_var_trafo_count=2, + verbose: bool = False, + nn_width=None, + nn_depth=None, +): + n_draws, n_dim = positions.shape + + if n_dim < 2: + return base + + if n_dim <= extension_var_count: + extension_var_count = n_dim - 1 + extension_var_trafo_count = 1 + + if dct: + flow = flowjax.flows.Transformed( + flowjax.distributions.StandardNormal(base.shape), + bijections.Chain([bijections.DCT(shape=(n_dim,)), base]), + ) + else: + flow = flowjax.flows.Transformed( + flowjax.distributions.StandardNormal(base.shape), base + ) + + params, static = eqx.partition(flow, eqx.is_inexact_array) + costs = loss_fn( + params, + static, + positions, + gradients, + logps, + return_elemwise_costs=True, + ) + + if verbose: + print(max(costs), costs) + print("dct:", dct) + idxs = np.argsort(costs) + + permute = bijections.Permute(idxs) + + if True: + scale = Parameterize( + lambda x: x + jnp.sqrt(1 + x**2), + jnp.array(0.0), + ) + theta = Parameterize( + lambda x: x + jnp.sqrt(1 + x**2), + jnp.array(0.0), + ) + + affine = bijections.AsymmetricAffine(jnp.zeros(()), jnp.ones(()), jnp.ones(())) + + affine = eqx.tree_at( + where=lambda aff: aff.scale, + pytree=affine, + replace=scale, + ) + affine = eqx.tree_at( + where=lambda aff: aff.theta, + pytree=affine, + replace=theta, + ) + + do_flip = layer % 2 == 0 + + if nn_width is None: + width = 16 + else: + width = nn_width + + if do_flip: + coupling = bijections.coupling.Coupling( + key, + transformer=affine, + untransformed_dim=n_dim - extension_var_trafo_count, + dim=n_dim, + nn_activation=_NN_ACTIVATION, + nn_width=width, + nn_depth=nn_depth, + ) + + inner_permute = bijections.Permute( + jnp.concatenate( + [ + jnp.arange(n_dim - extension_var_count), + jax.random.permutation( + key, jnp.arange(n_dim - extension_var_count, n_dim) + ), + ] + ) + ) + else: + coupling = bijections.coupling.Coupling( + key, + transformer=affine, + untransformed_dim=extension_var_trafo_count, + dim=n_dim, + nn_activation=_NN_ACTIVATION, + nn_width=width, + nn_depth=nn_depth, + ) + + inner_permute = bijections.Permute( + jnp.concatenate( + [ + jax.random.permutation( + key, jnp.arange(n_dim - extension_var_count, n_dim) + ), + jnp.arange(n_dim - extension_var_count), + ] + ) + ) + + if zero_init: + coupling = jax.tree_util.tree_map( + lambda x: x * 1e-3 if eqx.is_inexact_array(x) else x, + coupling, + ) + + inner = bijections.Sandwich(coupling, inner_permute) + + if False: + scale = Parameterize( + lambda x: x + jnp.sqrt(1 + x**2), + jnp.array(0.0), + ) + affine = eqx.tree_at( + where=lambda aff: aff.scale, + pytree=flowjax.bijections.Affine(), + replace=scale, + ) + + if nn_width is None: + width = 16 + else: + width = nn_width + + coupling = flowjax.bijections.coupling.Coupling( + key, + transformer=affine, + untransformed_dim=extension_var_trafo_count, + dim=n_dim, + nn_activation=_NN_ACTIVATION, + nn_width=width, + nn_depth=nn_depth, + ) + + if zero_init: + coupling = jax.tree_util.tree_map( + lambda x: x * 1e-3 if eqx.is_inexact_array(x) else x, + coupling, + ) + + if verbose: + print(costs[permute.permutation][inner.outer.permutation]) + + inner = bijections.Sandwich( + bijections.Chain( + [ + bijections.Sandwich(coupling, bijections.Flip(shape=(n_dim,))), + inner.inner, + ] + ), + inner.outer, + ) + + if dct: + new_layer = bijections.Sandwich( + bijections.Sandwich(inner, permute), + bijections.DCT(shape=(n_dim,)), + ) + else: + new_layer = bijections.Sandwich(inner, permute) + + scale = Parameterize( + lambda x: x + jnp.sqrt(1 + x**2), + jnp.zeros(n_dim), + ) + affine = eqx.tree_at( + where=lambda aff: aff.scale, + pytree=bijections.Affine(jnp.zeros(n_dim), jnp.ones(n_dim)), + replace=scale, + ) + + pre = [] + if layer % 2 == 0: + pre.append(bijections.Neg(shape=(n_dim,))) + + nonlin_layer = bijections.Sandwich( + affine, + bijections.Chain( + [ + *pre, + bijections.Vmap(bijections.SoftPlusX(), axis_size=n_dim), + ] + ), + ) + scale = Parameterize( + lambda x: x + jnp.sqrt(1 + x**2), + jnp.zeros(n_dim), + ) + affine = eqx.tree_at( + where=lambda aff: aff.scale, + pytree=bijections.Affine(jnp.zeros(n_dim), jnp.ones(n_dim)), + replace=scale, + ) + return bijections.Chain([new_layer, nonlin_layer, affine, base]) diff --git a/python/nutpie/sample.py b/python/nutpie/sample.py index 6c5bfc6..356b8d0 100644 --- a/python/nutpie/sample.py +++ b/python/nutpie/sample.py @@ -1,13 +1,13 @@ import os from dataclasses import dataclass -from typing import Any, Literal, Optional, overload +from typing import Any, Literal, Optional, cast, overload import arviz import numpy as np import pandas as pd import pyarrow -from nutpie import _lib +from nutpie import _lib # type: ignore @dataclass(frozen=True) @@ -281,7 +281,7 @@ def in_colab(): if in_colab(): return True try: - shell = get_ipython().__class__.__name__ + shell = get_ipython().__class__.__name__ # type: ignore if shell == "ZMQInteractiveShell": # Jupyter notebook, Spyder or qtconsole try: from IPython.display import ( @@ -398,6 +398,8 @@ def _extract(self, results): dims["divergence_start_gradient"] = ["unconstrained_parameter"] dims["divergence_end"] = ["unconstrained_parameter"] dims["divergence_momentum"] = ["unconstrained_parameter"] + dims["transformed_gradient"] = ["unconstrained_parameter"] + dims["transformed_position"] = ["unconstrained_parameter"] if self._return_raw_trace: return results @@ -453,14 +455,15 @@ def _repr_html_(self): def sample( compiled_model: CompiledModel, *, - draws: int, - tune: int, + draws: int | None, + tune: int | None, chains: int, cores: Optional[int], seed: Optional[int], save_warmup: bool, progress_bar: bool, low_rank_modified_mass_matrix: bool = False, + transform_adapt: bool = False, init_mean: Optional[np.ndarray], return_raw_trace: bool, blocking: Literal[True], @@ -472,14 +475,15 @@ def sample( def sample( compiled_model: CompiledModel, *, - draws: int, - tune: int, + draws: int | None, + tune: int | None, chains: int, cores: Optional[int], seed: Optional[int], save_warmup: bool, progress_bar: bool, low_rank_modified_mass_matrix: bool = False, + transform_adapt: bool = False, init_mean: Optional[np.ndarray], return_raw_trace: bool, blocking: Literal[False], @@ -490,14 +494,15 @@ def sample( def sample( compiled_model: CompiledModel, *, - draws: int = 1000, - tune: int = 300, + draws: int | None = None, + tune: int | None = None, chains: int = 6, cores: Optional[int] = None, seed: Optional[int] = None, save_warmup: bool = True, progress_bar: bool = True, low_rank_modified_mass_matrix: bool = False, + transform_adapt: bool = False, init_mean: Optional[np.ndarray] = None, return_raw_trace: bool = False, blocking: bool = True, @@ -510,9 +515,9 @@ def sample( Parameters ---------- - draws: int + draws: int | None The number of draws after tuning in each chain. - tune: int + tune: int | None The number of tuning (warmup) draws in each chain. chains: int The number of chains to sample. @@ -585,6 +590,9 @@ def sample( mass_matrix_gamma: float > 0, default=1e-5 Regularisation parameter for the eigenvalues. Only applicable with low_rank_modified_mass_matrix=True. + transform_adapt: bool, default=False + Use the experimental transform adaptation algorithm + during tuning. **kwargs Pass additional arguments to nutpie._lib.PySamplerArgs @@ -594,12 +602,22 @@ def sample( An ArviZ ``InferenceData`` object that contains the samples. """ + if low_rank_modified_mass_matrix and transform_adapt: + raise ValueError( + "Specify only one of `low_rank_modified_mass_matrix` and `transform_adapt`" + ) + if low_rank_modified_mass_matrix: settings = _lib.PyNutsSettings.LowRank(seed) + elif transform_adapt: + settings = _lib.PyNutsSettings.Transform(seed) else: settings = _lib.PyNutsSettings.Diag(seed) - settings.num_tune = tune - settings.num_draws = draws + + if tune is not None: + settings.num_tune = tune + if draws is not None: + settings.num_draws = draws settings.num_chains = chains for name, val in kwargs.items(): @@ -608,10 +626,10 @@ def sample( if cores is None: try: # Only available in python>=3.13 - available = os.process_cpu_count() + available = os.process_cpu_count() # type: ignore except AttributeError: available = os.cpu_count() - cores = min(chains, available) + cores = min(chains, cast(int, available)) if init_mean is None: init_mean = np.zeros(compiled_model.n_dim) diff --git a/python/nutpie/transform_adapter.py b/python/nutpie/transform_adapter.py new file mode 100644 index 0000000..30ca9a6 --- /dev/null +++ b/python/nutpie/transform_adapter.py @@ -0,0 +1,924 @@ +from functools import partial +from typing import Callable +import time + +from flowjax import bijections +from jaxtyping import ArrayLike, PyTree +import numpy as np +import equinox as eqx +import jax +import jax.numpy as jnp +import jax.random as jr +import traceback +import flowjax +import flowjax.flows +import flowjax.train +from flowjax.train.losses import MaximumLikelihoodLoss, PRNGKeyArray +from flowjax.train.train_utils import ( + count_fruitless, + get_batches, + step, + train_val_split, +) +import optax +from paramax import unwrap, NonTrainable + +from nutpie.normalizing_flow import Coupling, extend_flow, make_flow +import tqdm + +_BIJECTION_TRACE = [] + + +def fit_to_data( + key: PRNGKeyArray, + dist: PyTree, # Custom losses may support broader types than AbstractDistribution + x, + *, + condition: ArrayLike | None = None, + loss_fn: Callable | None = None, + max_epochs: int = 100, + max_patience: int = 5, + batch_size: int = 100, + val_prop: float = 0.1, + learning_rate: float = 5e-4, + optimizer: optax.GradientTransformation | None = None, + return_best: bool = True, + show_progress: bool = True, + opt_state=None, + verbose: bool = False, +): + r"""Train a distribution (e.g. a flow) to samples from the target distribution. + + The distribution can be unconditional :math:`p(x)` or conditional + :math:`p(x|\text{condition})`. Note that the last batch in each epoch is dropped + if truncated (to avoid recompilation). This function can also be used to fit + non-distribution pytrees as long as a compatible loss function is provided. + + Args: + key: Jax random seed. + dist: The distribution to train. + x: Samples from target distribution. + condition: Conditioning variables. Defaults to None. + loss_fn: Loss function. Defaults to MaximumLikelihoodLoss. + max_epochs: Maximum number of epochs. Defaults to 100. + max_patience: Number of consecutive epochs with no validation loss improvement + after which training is terminated. Defaults to 5. + batch_size: Batch size. Defaults to 100. + val_prop: Proportion of data to use in validation set. Defaults to 0.1. + learning_rate: Adam learning rate. Defaults to 5e-4. + optimizer: Optax optimizer. If provided, this overrides the default Adam + optimizer, and the learning_rate is ignored. Defaults to None. + return_best: Whether the result should use the parameters where the minimum loss + was reached (when True), or the parameters after the last update (when + False). Defaults to True. + show_progress: Whether to show progress bar. Defaults to True. + + Returns: + A tuple containing the trained distribution and the losses. + """ + if not isinstance(x, tuple): + x = (x,) + data = x if condition is None else (*x, condition) + data = tuple(jnp.asarray(a) for a in data) + + if optimizer is None: + optimizer = optax.adam(learning_rate) + + if loss_fn is None: + loss_fn = MaximumLikelihoodLoss() + + params, static = eqx.partition( + dist, + eqx.is_inexact_array, + is_leaf=lambda leaf: isinstance(leaf, NonTrainable), + ) + best_params = params + + if opt_state is None: + opt_state = optimizer.init(params) + + # train val split + key, subkey = jr.split(key) + train_data, val_data = train_val_split(subkey, data, val_prop=val_prop) + losses = {"train": [], "val": []} + + loop = tqdm.tqdm(range(max_epochs), disable=not show_progress) + + for i in loop: + # Shuffle data + start = time.time() + key, *subkeys = jr.split(key, 3) + train_data = [jr.permutation(subkeys[0], a) for a in train_data] + val_data = [jr.permutation(subkeys[1], a) for a in val_data] + if verbose and i == 0: + print("shuffle timing:", time.time() - start) + + start = time.time() + + key, subkey = jr.split(key) + batches = get_batches(train_data, batch_size) + batch_losses = [] + + if verbose and i == 0: + print("batch timing:", time.time() - start) + + start = time.time() + + if True: + for batch in zip(*batches, strict=True): + key, subkey = jr.split(key) + params, opt_state, batch_loss = step( + params, + static, + *batch, + optimizer=optimizer, + opt_state=opt_state, + loss_fn=loss_fn, + key=subkey, + ) + batch_losses.append(batch_loss) + else: + params, opt_state, batch_losses = _step_batch_loop( + params, + static, + opt_state, + optimizer, + loss_fn, + subkey, + *batches, + ) + + losses["train"].append((sum(batch_losses) / len(batch_losses)).item()) + + if verbose and i == 0: + print("step timing:", time.time() - start) + + start = time.time() + # Val epoch + batch_losses = [] + for batch in zip(*get_batches(val_data, batch_size), strict=True): + key, subkey = jr.split(key) + loss_i = loss_fn(params, static, *batch, key=subkey) + batch_losses.append(loss_i) + losses["val"].append(sum(batch_losses) / len(batch_losses)) + + if verbose and i == 0: + print("val timing:", time.time() - start) + + loop.set_postfix({k: v[-1] for k, v in losses.items()}) + if losses["val"][-1] == min(losses["val"]): + best_params = params + + elif count_fruitless(losses["val"]) > max_patience: + loop.set_postfix_str(f"{loop.postfix} (Max patience reached)") + break + + params = best_params if return_best else params + dist = eqx.combine(params, static) + return dist, losses, opt_state + + +@eqx.filter_jit +def _step_batch_loop(params, static, opt_state, optimizer, loss_fn, key, *batches): + def scan_fn(carry, batch): + params, opt_state, key = carry + key, subkey = jr.split(key) + params, opt_state, loss_i = step( + params, + static, + *batch, + optimizer=optimizer, + opt_state=opt_state, + loss_fn=loss_fn, + key=subkey, + ) + return (params, opt_state, key), loss_i + + (params, opt_state, _), batch_losses = jax.lax.scan( + scan_fn, (params, opt_state, key), batches + ) + + return params, opt_state, batch_losses + + +@eqx.filter_jit +def inverse_gradient_and_val(bijection, draw, grad, logp): + if False: + x = bijection.inverse(draw) + (_, fwd_log_det), pull_grad_fn = jax.vjp( + lambda x: bijection.transform_and_log_det(x), x + ) + (x_grad,) = pull_grad_fn((grad, jnp.ones(()))) + return (x, x_grad, logp + fwd_log_det) + if isinstance(bijection, bijections.Chain): + for b in bijection.bijections[::-1]: + draw, grad, logp = inverse_gradient_and_val(b, draw, grad, logp) + return draw, grad, logp + elif isinstance(bijection, bijections.Permute): + return ( + draw[bijection.inverse_permutation], + grad[bijection.inverse_permutation], + logp, + ) + elif isinstance(bijection, bijections.Affine): + draw, logdet = bijection.inverse_and_log_det(draw) + grad = grad * bijection.scale + return (draw, grad, logp - logdet) + elif isinstance(bijection, bijections.Vmap): + + def inner(bijection, y, y_grad, y_logp): + return inverse_gradient_and_val(bijection, y, y_grad, y_logp) + + y, y_grad, log_det = eqx.filter_vmap( + inner, + in_axes=(bijection.in_axes[0], 0, 0, None), + axis_size=bijection.axis_size, + )(bijection.bijection, draw, grad, jnp.zeros(())) + return y, y_grad, jnp.sum(log_det) + logp + elif isinstance(bijection, bijections.Sandwich): + draw, grad, logp = inverse_gradient_and_val( + bijections.Invert(bijection.outer), draw, grad, logp + ) + draw, grad, logp = inverse_gradient_and_val(bijection.inner, draw, grad, logp) + draw, grad, logp = inverse_gradient_and_val(bijection.outer, draw, grad, logp) + return draw, grad, logp + # Disabeling the Coupling case for now, it slows down compile time for some reason? + elif False and isinstance(bijection, Coupling): + y, y_grad, y_logp = draw, grad, logp + y_cond, y_trans = ( + y[: bijection.untransformed_dim], + y[bijection.untransformed_dim :], + ) + x_cond = y_cond + + y_grad_cond, y_grad_trans = ( + y_grad[: bijection.untransformed_dim], + y_grad[bijection.untransformed_dim :], + ) + + def conditioner(x_cond): + return bijection.conditioner(x_cond) + + transformer_params, nn_pull = jax.vjp(conditioner, x_cond) + + def pull_transformer_grad(transformer_params): + transformer = bijection._flat_params_to_transformer(transformer_params) + + x_trans, x_grad_trans, x_logp = inverse_gradient_and_val( + transformer, y_trans, y_grad_trans, y_logp + ) + + return (x_logp, x_trans), x_grad_trans + + ((x_logp, x_trans), pull_pull_transformer_grad, x_grad_trans) = jax.vjp( + pull_transformer_grad, transformer_params, has_aux=True + ) + + (co_transformer_params,) = pull_pull_transformer_grad((1.0, -x_grad_trans)) + (co_x_cond,) = nn_pull(co_transformer_params) + + x = jnp.hstack((x_cond, x_trans)) + x_grad = jnp.hstack((y_grad_cond + co_x_cond, x_grad_trans)) + return x, x_grad, x_logp + + elif isinstance(bijection, bijections.Invert): + inner = bijection.bijection + x, _ = inner.transform_and_log_det(draw) + (_, fwd_log_det), pull_grad_fn = jax.vjp( + lambda x: inner.inverse_and_log_det(x), x + ) + (x_grad,) = pull_grad_fn((grad, jnp.ones(()))) + return (x, x_grad, logp + fwd_log_det) + else: + x, _ = bijection.inverse_and_log_det(draw) + (_, fwd_log_det), pull_grad_fn = jax.vjp( + lambda x: bijection.transform_and_log_det(x), x + ) + (x_grad,) = pull_grad_fn((grad, jnp.ones(()))) + return (x, x_grad, logp + fwd_log_det) + + +class FisherLoss: + def __init__(self, gamma=None, log_inside_batch=False): + self._gamma = gamma + self._log_inside_batch = log_inside_batch + + @eqx.filter_jit + def __call__( + self, + params, + static, + draws, + grads, + logps, + condition=None, + key=None, + return_all_costs=False, + return_elemwise_costs=False, + ): + flow = unwrap(eqx.combine(params, static, is_leaf=eqx.is_inexact_array)) + + if return_elemwise_costs: + + def compute_loss(bijection, draw, grad, logp): + draw, grad, logp = inverse_gradient_and_val(bijection, draw, grad, logp) + cost = (draw + grad) ** 2 + return cost + + costs = jax.vmap(compute_loss, [None, 0, 0, 0])( + flow.bijection, + draws, + grads, + logps, + ) + return costs.mean(0) + + if self._gamma is None: + + def compute_loss(bijection, draw, grad, logp): + draw, grad, logp = inverse_gradient_and_val(bijection, draw, grad, logp) + cost = ((draw + grad) ** 2).sum() + return cost + + costs = jax.vmap(compute_loss, [None, 0, 0, 0])( + flow.bijection, + draws, + grads, + logps, + ) + + if return_all_costs: + return costs + + if self._log_inside_batch: + return jnp.log(costs).mean() + else: + return jnp.log(costs.mean()) + + else: + + def transform(draw, grad, logp): + return inverse_gradient_and_val(flow.bijection, draw, grad, logp) + + draws, grads, logps = jax.vmap(transform, [0, 0, 0], (0, 0, 0))( + draws, grads, logps + ) + fisher_loss = ((draws + grads) ** 2).sum(1).mean(0) + normal_logps = -(draws * draws).sum(1) / 2 + var_loss = (logps - normal_logps).var() + return jnp.log(fisher_loss + self._gamma * var_loss) + + +def fit_flow(key, bijection, loss_fn, draws, grads, logps, **kwargs): + flow = flowjax.flows.Transformed( + flowjax.distributions.StandardNormal(bijection.shape), bijection + ) + + key, train_key = jax.random.split(key) + + fit, losses, opt_state = fit_to_data( + key=train_key, + dist=flow, + x=(draws, grads, logps), + loss_fn=loss_fn, + max_epochs=500, + return_best=True, + **kwargs, + ) + return fit.bijection, losses, opt_state + + +@eqx.filter_jit +def _init_from_transformed_position(logp_fn, bijection, transformed_position): + bijection = unwrap(bijection) + (untransformed_position, logdet), pull_grad = jax.vjp( + bijection.transform_and_log_det, transformed_position + ) + logp, untransformed_gradient = jax.value_and_grad(lambda x: logp_fn(x)[0])( + untransformed_position + ) + (transformed_gradient,) = pull_grad((untransformed_gradient, 1.0)) + return ( + logp, + logdet, + untransformed_position, + untransformed_gradient, + transformed_gradient, + ) + + +@eqx.filter_jit +def _init_from_transformed_position_part1(logp_fn, bijection, transformed_position): + bijection = unwrap(bijection) + (untransformed_position, logdet) = bijection.transform_and_log_det( + transformed_position + ) + + return (logdet, untransformed_position) + + +@eqx.filter_jit +def _init_from_transformed_position_part2( + bijection, + part1, + untransformed_gradient, +): + logdet, untransformed_position, transformed_position = part1 + bijection = unwrap(bijection) + _, pull_grad = jax.vjp(bijection.transform_and_log_det, transformed_position) + (transformed_gradient,) = pull_grad((untransformed_gradient, 1.0)) + return ( + logdet, + transformed_gradient, + ) + + +@eqx.filter_jit +def _init_from_untransformed_position(logp_fn, bijection, untransformed_position): + logp, untransformed_gradient = jax.value_and_grad(lambda x: logp_fn(x)[0])( + untransformed_position + ) + logdet, transformed_position, transformed_gradient = _inv_transform( + bijection, untransformed_position, untransformed_gradient + ) + return ( + logp, + logdet, + untransformed_gradient, + transformed_position, + transformed_gradient, + ) + + +@eqx.filter_jit +def _inv_transform(bijection, untransformed_position, untransformed_gradient): + bijection = unwrap(bijection) + transformed_position, transformed_gradient, logdet = inverse_gradient_and_val( + bijection, untransformed_position, untransformed_gradient, 0.0 + ) + return logdet, transformed_position, transformed_gradient + + +class TransformAdapter: + def __init__( + self, + seed, + position, + gradient, + chain, + *, + logp_fn, + make_flow_fn, + verbose=False, + window_size=2000, + show_progress=False, + num_diag_windows=10, + learning_rate=1e-3, + zero_init=True, + untransformed_dim=None, + batch_size=128, + reuse_opt_state=True, + max_patience=5, + gamma=None, + log_inside_batch=False, + initial_skip=500, + extension_windows=None, + extend_dct=False, + extension_var_count=6, + extension_var_trafo_count=4, + debug_save_bijection=False, + make_optimizer=None, + num_layers=9, + ): + self._logp_fn = logp_fn + self._make_flow_fn = make_flow_fn + self._chain = chain + self._verbose = verbose + self._window_size = window_size + self._initial_skip = initial_skip + self._num_layers = num_layers + if make_optimizer is None: + self._make_optimizer = lambda: optax.apply_if_finite( + optax.adamw(learning_rate), 50 + ) + else: + self._make_optimizer = make_optimizer + self._optimizer = self._make_optimizer() + self._loss_fn = FisherLoss(gamma, log_inside_batch) + self._show_progress = show_progress + self._num_diag_windows = num_diag_windows + self._zero_init = zero_init + self._untransformed_dim = untransformed_dim + self._batch_size = batch_size + self._reuse_opt_state = reuse_opt_state + self._opt_state = None + self._max_patience = max_patience + self._count_trace = [] + self._last_extend_dct = True + self._extend_dct = extend_dct + self._extension_var_count = extension_var_count + self._extension_var_trafo_count = extension_var_trafo_count + self._debug_save_bijection = debug_save_bijection + self._layers = 0 + + if extension_windows is None: + self._extension_windows = [] + else: + self._extension_windows = extension_windows + + try: + self._bijection = make_flow_fn(seed, [position], [gradient], n_layers=0) + except Exception as e: + print("make_flow", e) + print(traceback.format_exc()) + raise + self.index = 0 + + @property + def transformation_id(self): + return self.index + + def update(self, seed, positions, gradients, logps): + self.index += 1 + if self._verbose: + print( + f"Chain {self._chain}: Total available points: {len(positions)}, seed {seed}" + ) + n_draws = len(positions) + assert n_draws == len(positions) + assert n_draws == len(gradients) + assert n_draws == len(logps) + self._count_trace.append(n_draws) + if n_draws == 0: + return + try: + if self.index <= self._num_diag_windows: + size = len(positions) + lower_idx = -size // 5 + 3 + positions_slice = positions[lower_idx:] + gradients_slice = gradients[lower_idx:] + logp_slice = logps[lower_idx:] + + if len(positions_slice) > 0: + positions = positions_slice + gradients = gradients_slice + logps = logp_slice + + positions = np.array(positions) + gradients = np.array(gradients) + logps = np.array(logps) + + fit = self._make_flow_fn(seed, positions, gradients, n_layers=0) + + flow = flowjax.flows.Transformed( + flowjax.distributions.StandardNormal(fit.shape), fit + ) + params, static = eqx.partition(flow, eqx.is_inexact_array) + new_loss = self._loss_fn(params, static, positions, gradients, logps) + + if self._verbose: + print("loss from diag:", new_loss) + + if np.isfinite(new_loss): + self._bijection = fit + self._opt_state = None + + return + + positions = np.array(positions[self._initial_skip :][-self._window_size :]) + gradients = np.array(gradients[self._initial_skip :][-self._window_size :]) + logps = np.array(logps[self._initial_skip :][-self._window_size :]) + + if len(positions) < 10: + return + + if self._verbose and not np.isfinite(gradients).all(): + print(gradients) + print(gradients.shape) + print((~np.isfinite(gradients)).nonzero()) + + assert np.isfinite(positions).all() + assert np.isfinite(gradients).all() + assert np.isfinite(logps).all() + + # TODO don't reuse seed + key = jax.random.PRNGKey(seed % (2**63)) + + if len(self._bijection.bijections) == 1: + base = self._make_flow_fn( + seed, + positions, + gradients, + n_layers=self._num_layers, + untransformed_dim=self._untransformed_dim, + zero_init=self._zero_init, + ) + flow = flowjax.flows.Transformed( + flowjax.distributions.StandardNormal(base.shape), base + ) + params, static = eqx.partition(flow, eqx.is_inexact_array) + if self._verbose: + print( + "loss before optimization: ", + self._loss_fn( + params, + static, + positions[-128:], + gradients[-128:], + logps[-128:], + ), + ) + else: + base = self._bijection + + if self.index in self._extension_windows: + if self._verbose: + print("Extending flow...") + self._last_extend_dct = not self._last_extend_dct + dct = self._last_extend_dct and self._extend_dct + base = extend_flow( + key, + base, + self._loss_fn, + positions, + gradients, + logps, + self._layers, + dct=dct, + extension_var_count=self._extension_var_count, + extension_var_trafo_count=self._extension_var_trafo_count, + verbose=self._verbose, + ) + self._optimizer = self._make_optimizer() + self._opt_state = None + self._layers += 1 + + # make_flow might still onreturn a single trafo for 1d problems + if len(base.bijections) == 1: + self._bijection = base + self._opt_state = None + + if self._debug_save_bijection: + _BIJECTION_TRACE.append( + (self.index, fit, (positions, gradients, logps)) + ) + return + + flow = flowjax.flows.Transformed( + flowjax.distributions.StandardNormal(self._bijection.shape), + self._bijection, + ) + params, static = eqx.partition(flow, eqx.is_inexact_array) + + start = time.time() + old_loss = self._loss_fn( + params, static, positions[-128:], gradients[-128:], logps[-128:] + ) + if self._verbose: + print("loss function time: ", time.time() - start) + + if np.isfinite(old_loss) and old_loss < -5 and self.index > 10: + if self._verbose: + print(f"Loss is low ({old_loss}), skipping training") + return + + fit, _, opt_state = fit_flow( + key, + base, + self._loss_fn, + positions, + gradients, + logps, + show_progress=self._show_progress, + verbose=self._verbose, + optimizer=self._optimizer, + batch_size=self._batch_size, + opt_state=self._opt_state if self._reuse_opt_state else None, + max_patience=self._max_patience, + ) + + flow = flowjax.flows.Transformed( + flowjax.distributions.StandardNormal(fit.shape), fit + ) + params, static = eqx.partition(flow, eqx.is_inexact_array) + + start = time.time() + new_loss = self._loss_fn( + params, static, positions[-128:], gradients[-128:], logps[-128:] + ) + if self._verbose: + print("new loss function time: ", time.time() - start) + + if self._verbose: + print(f"Chain {self._chain}: New loss {new_loss}, old loss {old_loss}") + + if not np.isfinite(old_loss): + flow = flowjax.flows.Transformed( + flowjax.distributions.StandardNormal(self._bijection.shape), + self._bijection, + ) + params, static = eqx.partition(flow, eqx.is_inexact_array) + print( + self._loss_fn( + params, + static, + positions[-128:], + gradients[-128:], + logps[-128:], + return_all_costs=True, + ) + ) + + if not np.isfinite(new_loss): + flow = flowjax.flows.Transformed( + flowjax.distributions.StandardNormal(fit.shape), fit + ) + params, static = eqx.partition(flow, eqx.is_inexact_array) + print( + self._loss_fn( + params, + static, + positions[-128:], + gradients[-128:], + logps[-128:], + return_all_costs=True, + ) + ) + + if self._debug_save_bijection: + _BIJECTION_TRACE.append( + (self.index, fit, (positions, gradients, logps)) + ) + + def valid_new_logp(): + logdet, pos, grad = _inv_transform( + fit, + jnp.array(positions[-1]), + jnp.array(gradients[-1]), + ) + return ( + np.isfinite(logdet) + and np.isfinite(pos[0]).all() + and np.isfinite(grad[0]).all() + ) + + if (not np.isfinite(old_loss)) and (not np.isfinite(new_loss)): + self._bijection = self._make_flow_fn( + seed, positions, gradients, n_layers=0 + ) + self._opt_state = None + return + + if not valid_new_logp(): + if self._verbose: + print("Invalid new logp. Skipping update.") + return + + if not np.isfinite(new_loss): + if self._verbose: + print("Invalid new loss. Skipping update.") + return + + if new_loss > old_loss: + return + + self._bijection = fit + self._opt_state = opt_state + + except Exception as e: + print("update error:", e) + print(traceback.format_exc()) + raise + + def init_from_transformed_position(self, transformed_position): + try: + logp, logdet, *arrays = _init_from_transformed_position( + self._logp_fn, + self._bijection, + jnp.array(transformed_position), + ) + return ( + float(logp), + float(logdet), + *[np.array(val, dtype="float64") for val in arrays], + ) + except Exception as e: + print(e) + print(traceback.format_exc()) + raise + + def init_from_transformed_position_part1(self, transformed_position): + try: + transformed_position = jnp.array(transformed_position) + logdet, untransformed_position = _init_from_transformed_position_part1( + self._logp_fn, + self._bijection, + transformed_position, + ) + part1 = (logdet, untransformed_position, transformed_position) + return np.array(untransformed_position, dtype="float64"), part1 + except Exception as e: + print(e) + print(traceback.format_exc()) + raise + + def init_from_transformed_position_part2( + self, + part1, + untransformed_gradient, + ): + try: + # TODO We could extract the arrays from the pull_grad function + # to reuse computation from part1 + logdet, *arrays = _init_from_transformed_position_part2( + self._bijection, + part1, + untransformed_gradient, + ) + return float(logdet), *[np.array(val, dtype="float64") for val in arrays] + except Exception as e: + print(e) + print(traceback.format_exc()) + raise + + def init_from_untransformed_position(self, untransformed_position): + try: + logp, logdet, *arrays = _init_from_untransformed_position( + self._logp_fn, + self._bijection, + jnp.array(untransformed_position), + ) + arrays = [np.array(val, dtype="float64") for val in arrays] + return float(logp), float(logdet), *arrays + except Exception as e: + print(e) + print(traceback.format_exc()) + raise + + def inv_transform(self, position, gradient): + try: + logdet, *arrays = _inv_transform( + self._bijection, jnp.array(position), jnp.array(gradient) + ) + return logdet, *[np.array(val, dtype="float64") for val in arrays] + except Exception as e: + print(e) + print(traceback.format_exc()) + raise + + +def make_transform_adapter( + *, + verbose=False, + window_size=600, + show_progress=False, + nn_depth=1, + nn_width=None, + num_layers=9, + num_diag_windows=9, + learning_rate=5e-4, + untransformed_dim=None, + zero_init=True, + batch_size=128, + reuse_opt_state=False, + max_patience=20, + householder_layer=False, + dct_layer=False, + gamma=None, + log_inside_batch=False, + initial_skip=120, + extension_windows=[], + extend_dct=False, + extension_var_count=4, + extension_var_trafo_count=2, + debug_save_bijection=False, + make_optimizer=None, +): + return partial( + TransformAdapter, + verbose=verbose, + window_size=window_size, + make_flow_fn=partial( + make_flow, + householder_layer=householder_layer, + dct_layer=dct_layer, + nn_width=nn_width, + ), + show_progress=show_progress, + num_diag_windows=num_diag_windows, + learning_rate=learning_rate, + zero_init=zero_init, + untransformed_dim=untransformed_dim, + batch_size=batch_size, + reuse_opt_state=reuse_opt_state, + max_patience=max_patience, + gamma=gamma, + log_inside_batch=log_inside_batch, + initial_skip=initial_skip, + extension_windows=extension_windows, + extend_dct=extend_dct, + extension_var_count=extension_var_count, + extension_var_trafo_count=extension_var_trafo_count, + debug_save_bijection=debug_save_bijection, + make_optimizer=make_optimizer, + ) diff --git a/src/progress.rs b/src/progress.rs index 906c50d..2403e15 100644 --- a/src/progress.rs +++ b/src/progress.rs @@ -1,4 +1,4 @@ -use std::{collections::BTreeMap, time::Duration}; +use std::{collections::BTreeMap, sync::Arc, time::Duration}; use anyhow::{Context, Result}; use indicatif::ProgressBar; @@ -10,13 +10,13 @@ use upon::{Engine, Value}; pub struct ProgressHandler { engine: Engine<'static>, template: String, - callback: Py, + callback: Arc>, rate: Duration, n_cores: usize, } impl ProgressHandler { - pub fn new(callback: Py, rate: Duration, template: String, n_cores: usize) -> Self { + pub fn new(callback: Arc>, rate: Duration, template: String, n_cores: usize) -> Self { let engine = Engine::new(); Self { engine, diff --git a/src/pyfunc.rs b/src/pyfunc.rs index f914974..f23145e 100644 --- a/src/pyfunc.rs +++ b/src/pyfunc.rs @@ -8,9 +8,10 @@ use arrow::{ }, datatypes::{DataType, Field, Float32Type, Float64Type, Int64Type}, }; -use numpy::{PyArray1, PyReadonlyArray1}; +use numpy::{NotContiguousError, PyArray1, PyReadonlyArray1}; use nuts_rs::{CpuLogpFunc, CpuMath, DrawStorage, LogpError, Model}; use pyo3::{ + exceptions::PyRuntimeError, pyclass, pymethods, types::{PyAnyMethods, PyDict, PyDictMethods}, Bound, Py, PyAny, PyErr, Python, @@ -20,6 +21,8 @@ use rand_distr::{Distribution, Uniform}; use smallvec::SmallVec; use thiserror::Error; +use crate::wrapper::PyTransformAdapt; + #[pyclass] #[derive(Debug, Clone)] #[non_exhaustive] @@ -71,29 +74,33 @@ impl PyVariable { #[pyclass] #[derive(Debug, Clone)] pub struct PyModel { - make_logp_func: Py, - make_expand_func: Py, - init_point_func: Option>, - variables: Vec, + make_logp_func: Arc>, + make_expand_func: Arc>, + init_point_func: Option>>, + variables: Arc>, + transform_adapter: Option, ndim: usize, } #[pymethods] impl PyModel { #[new] + #[pyo3(signature = (make_logp_func, make_expand_func, variables, ndim, *, init_point_func=None, transform_adapter=None))] fn new<'py>( make_logp_func: Py, make_expand_func: Py, variables: Vec, ndim: usize, init_point_func: Option>, + transform_adapter: Option>, ) -> Self { Self { - make_logp_func, - make_expand_func, - init_point_func, - variables, + make_logp_func: Arc::new(make_logp_func), + make_expand_func: Arc::new(make_expand_func), + init_point_func: init_point_func.map(|x| x.into()), + variables: Arc::new(variables), ndim, + transform_adapter: transform_adapter.map(PyTransformAdapt::new), } } } @@ -103,9 +110,13 @@ pub enum PyLogpError { #[error("Bad logp value: {0}")] BadLogp(f64), #[error("Python error: {0}")] - PyError(PyErr), + PyError(#[from] PyErr), #[error("logp function must return float.")] ReturnTypeError(), + #[error("Python retured a non-contigous array")] + NotContiguousError(#[from] NotContiguousError), + #[error("Unknown error: {0}")] + Anyhow(#[from] anyhow::Error), } impl LogpError for PyLogpError { @@ -113,7 +124,7 @@ impl LogpError for PyLogpError { match self { Self::BadLogp(_) => true, Self::PyError(err) => Python::with_gil(|py| { - let Ok(attr) = err.value_bound(py).getattr("is_recoverable") else { + let Ok(attr) = err.value(py).getattr("is_recoverable") else { return false; }; return attr @@ -121,20 +132,29 @@ impl LogpError for PyLogpError { .expect("Could not access is_recoverable in error check"); }), Self::ReturnTypeError() => false, + Self::NotContiguousError(_) => false, + Self::Anyhow(_) => false, } } } pub struct PyDensity { logp: Py, + transform_adapter: Option, dim: usize, } impl PyDensity { - fn new(logp_clone_func: &Py, dim: usize) -> Result { + fn new( + logp_clone_func: &Py, + dim: usize, + transform_adapter: Option<&PyTransformAdapt>, + ) -> Result { let logp_func = Python::with_gil(|py| logp_clone_func.call0(py))?; + let transform_adapter = transform_adapter.map(|val| val.clone()); Ok(Self { logp: logp_func, + transform_adapter, dim, }) } @@ -142,10 +162,11 @@ impl PyDensity { impl CpuLogpFunc for PyDensity { type LogpError = PyLogpError; + type TransformParams = Py; fn logp(&mut self, position: &[f64], grad: &mut [f64]) -> Result { Python::with_gil(|py| { - let pos_array = PyArray1::from_slice_bound(py, position); + let pos_array = PyArray1::from_slice(py, position); let result = self.logp.call1(py, (pos_array,)); match result { Ok(val) => { @@ -172,11 +193,122 @@ impl CpuLogpFunc for PyDensity { fn dim(&self) -> usize { self.dim } + + fn inv_transform_normalize( + &mut self, + params: &Py, + untransformed_position: &[f64], + untransformed_gradient: &[f64], + transformed_position: &mut [f64], + transformed_gradient: &mut [f64], + ) -> std::result::Result { + let logdet = self + .transform_adapter + .as_mut() + .ok_or_else(|| PyRuntimeError::new_err("No transformation adapter specified"))? + .inv_transform_normalize( + params, + untransformed_position, + untransformed_gradient, + transformed_position, + transformed_gradient, + )?; + Ok(logdet) + } + + fn init_from_transformed_position( + &mut self, + params: &Py, + untransformed_position: &mut [f64], + untransformed_gradient: &mut [f64], + transformed_position: &[f64], + transformed_gradient: &mut [f64], + ) -> std::result::Result<(f64, f64), Self::LogpError> { + let (logp, logdet) = self + .transform_adapter + .as_mut() + .ok_or_else(|| PyRuntimeError::new_err("No transformation adapter specified"))? + .init_from_transformed_position( + params, + untransformed_position, + untransformed_gradient, + transformed_position, + transformed_gradient, + )?; + Ok((logp, logdet)) + } + + fn init_from_untransformed_position( + &mut self, + params: &Py, + untransformed_position: &[f64], + untransformed_gradient: &mut [f64], + transformed_position: &mut [f64], + transformed_gradient: &mut [f64], + ) -> std::result::Result<(f64, f64), Self::LogpError> { + let (logp, logdet) = self + .transform_adapter + .as_mut() + .ok_or_else(|| PyRuntimeError::new_err("No transformation adapter specified"))? + .init_from_untransformed_position( + params, + untransformed_position, + untransformed_gradient, + transformed_position, + transformed_gradient, + )?; + Ok((logp, logdet)) + } + + fn update_transformation<'a, R: rand::Rng + ?Sized>( + &'a mut self, + rng: &mut R, + untransformed_positions: impl ExactSizeIterator, + untransformed_gradients: impl ExactSizeIterator, + untransformed_logp: impl ExactSizeIterator, + params: &'a mut Py, + ) -> std::result::Result<(), Self::LogpError> { + self.transform_adapter + .as_mut() + .ok_or_else(|| PyRuntimeError::new_err("No transformation adapter specified"))? + .update_transformation( + rng, + untransformed_positions, + untransformed_gradients, + untransformed_logp, + params, + )?; + Ok(()) + } + + fn new_transformation( + &mut self, + rng: &mut R, + untransformed_position: &[f64], + untransformed_gradient: &[f64], + chain: u64, + ) -> std::result::Result, Self::LogpError> { + let trafo = self + .transform_adapter + .as_mut() + .ok_or_else(|| PyRuntimeError::new_err("No transformation adapter specified"))? + .new_transformation(rng, untransformed_position, untransformed_gradient, chain)?; + Ok(trafo) + } + + fn transformation_id(&self, params: &Py) -> std::result::Result { + let id = self + .transform_adapter + .as_ref() + .ok_or_else(|| PyRuntimeError::new_err("No transformation adapter specified"))? + .transformation_id(params)?; + Ok(id) + } } pub struct PyTrace { expand: Py, - variables: Vec, + variables: Arc>, builder: StructBuilder, } @@ -184,7 +316,7 @@ impl PyTrace { pub fn new( rng: &mut R, chain: u64, - variables: Vec, + variables: Arc>, make_expand_func: &Py, capacity: usize, ) -> Result { @@ -234,6 +366,7 @@ impl TensorShape { #[pymethods] impl TensorShape { #[new] + #[pyo3(signature = (shape, dims=None))] fn py_new(shape: Vec, dims: Option>>) -> Result { let dims = dims.unwrap_or(shape.iter().map(|_| None).collect()); if dims.len() != shape.len() { @@ -318,7 +451,7 @@ impl ExpandDtype { impl DrawStorage for PyTrace { fn append_value(&mut self, point: &[f64]) -> Result<()> { Python::with_gil(|py| { - let point = PyArray1::from_slice_bound(py, point); + let point = PyArray1::from_slice(py, point); let full_point = self .expand .call1(py, (point,)) @@ -344,36 +477,51 @@ impl DrawStorage for PyTrace { self.builder.field_builder(i).context( "Builder has incorrect type", )?; - builder.append_value(value.extract().expect("Return value from expand function could not be converted to boolean")) + let value = value + .extract() + .expect("Return value from expand function could not be converted to boolean"); + builder.append_value(value) }, ExpandDtype::Float64 {} => { let builder: &mut Float64Builder = self.builder.field_builder(i).context( "Builder has incorrect type", )?; - builder.append_value(value.extract().expect("Return value from expand function could not be converted to float64")) + builder.append_value( + value + .extract() + .expect("Return value from expand function could not be converted to float64") + ) }, ExpandDtype::Float32 {} => { let builder: &mut Float32Builder = self.builder.field_builder(i).context( "Builder has incorrect type", )?; - builder.append_value(value.extract().expect("Return value from expand function could not be converted to float32")) - + builder.append_value( + value + .extract() + .expect("Return value from expand function could not be converted to float32") + ) }, ExpandDtype::Int64 {} => { let builder: &mut Int64Builder = self.builder.field_builder(i).context( "Builder has incorrect type", )?; - builder.append_value(value.extract().expect("Return value from expand function could not be converted to int64")) + let value = value.extract().expect("Return value from expand function could not be converted to int64"); + builder.append_value(value) }, ExpandDtype::BooleanArray { tensor_type } => { let builder: &mut LargeListBuilder> = self.builder.field_builder(i).context( "Builder has incorrect type. Expected LargeListBuilder of Bool", )?; - let value_builder = builder.values().as_any_mut().downcast_mut::().context("Could not downcast builder to boolean type")?; + let value_builder = builder + .values() + .as_any_mut() + .downcast_mut::() + .context("Could not downcast builder to boolean type")?; let values: PyReadonlyArray1 = value.extract().context("Could not convert object to array")?; if values.len()? != tensor_type.size() { bail!("Extracted array has incorrect shape"); @@ -387,7 +535,11 @@ impl DrawStorage for PyTrace { self.builder.field_builder(i).context( "Builder has incorrect type. Expected LargeListBuilder of Float64", )?; - let value_builder = builder.values().as_any_mut().downcast_mut::>().context("Could not downcast builder to float64 type")?; + let value_builder = builder + .values() + .as_any_mut() + .downcast_mut::>() + .context("Could not downcast builder to float64 type")?; let values: PyReadonlyArray1 = value.extract().context("Could not convert object to array")?; if values.len()? != tensor_type.size() { bail!("Extracted array has incorrect shape"); @@ -400,7 +552,11 @@ impl DrawStorage for PyTrace { self.builder.field_builder(i).context( "Builder has incorrect type. Expected LargeListBuilder of Float32", )?; - let value_builder = builder.values().as_any_mut().downcast_mut::>().context("Could not downcast builder to float32 type")?; + let value_builder = builder + .values() + .as_any_mut() + .downcast_mut::>() + .context("Could not downcast builder to float32 type")?; let values: PyReadonlyArray1 = value.extract().context("Could not convert object to array")?; if values.len()? != tensor_type.size() { bail!("Extracted array has incorrect shape"); @@ -413,7 +569,11 @@ impl DrawStorage for PyTrace { self.builder.field_builder(i).context( "Builder has incorrect type. Expected LargeListBuilder of Int64", )?; - let value_builder = builder.values().as_any_mut().downcast_mut::>().context("Could not downcast builder to i64 type")?; + let value_builder = builder + .values() + .as_any_mut() + .downcast_mut::>() + .context("Could not downcast builder to i64 type")?; let values: PyReadonlyArray1 = value.extract().context("Could not convert object to array")?; if values.len()? != tensor_type.size() { bail!("Extracted array has incorrect shape"); @@ -471,6 +631,7 @@ impl Model for PyModel { Ok(CpuMath::new(PyDensity::new( &self.make_logp_func, self.ndim, + self.transform_adapter.as_ref(), )?)) } @@ -480,7 +641,7 @@ impl Model for PyModel { position: &mut [f64], ) -> Result<()> { let Some(init_func) = self.init_point_func.as_ref() else { - let dist = Uniform::new(-2f64, 2f64); + let dist = Uniform::new(-2f64, 2f64).expect("Could not create uniform distribution"); position.iter_mut().for_each(|x| *x = dist.sample(rng)); return Ok(()); }; diff --git a/src/pymc.rs b/src/pymc.rs index 98426c7..b33b821 100644 --- a/src/pymc.rs +++ b/src/pymc.rs @@ -39,7 +39,7 @@ type RawExpandFunc = unsafe extern "C" fn( #[derive(Clone)] pub(crate) struct LogpFunc { func: RawLogpFunc, - _keep_alive: PyObject, + _keep_alive: Arc, user_data_ptr: UserData, dim: usize, } @@ -55,7 +55,7 @@ impl LogpFunc { unsafe { std::mem::transmute::<*const c_void, RawLogpFunc>(ptr as *const c_void) }; Self { func, - _keep_alive: keep_alive, + _keep_alive: Arc::new(keep_alive), user_data_ptr: user_data_ptr as UserData, dim, } @@ -66,7 +66,7 @@ impl LogpFunc { #[derive(Clone)] pub(crate) struct ExpandFunc { func: RawExpandFunc, - _keep_alive: PyObject, + _keep_alive: Arc, user_data_ptr: UserData, dim: usize, expanded_dim: usize, @@ -87,7 +87,7 @@ impl ExpandFunc { Self { dim, expanded_dim, - _keep_alive: keep_alive, + _keep_alive: Arc::new(keep_alive), user_data_ptr: user_data_ptr as UserData, func, } @@ -114,6 +114,7 @@ impl LogpError for ErrorCode { impl<'a> CpuLogpFunc for &'a LogpFunc { type LogpError = ErrorCode; + type TransformParams = (); fn dim(&self) -> usize { self.dim @@ -233,7 +234,7 @@ pub(crate) struct PyMcModel { dim: usize, density: LogpFunc, expand: ExpandFunc, - init_func: Py, + init_func: Arc>, var_sizes: Vec, var_names: Vec, } @@ -253,7 +254,7 @@ impl PyMcModel { dim, density, expand, - init_func, + init_func: init_func.into(), var_names: var_names.extract()?, var_sizes: var_sizes.extract()?, }) diff --git a/src/stan.rs b/src/stan.rs index 0564cd2..e258096 100644 --- a/src/stan.rs +++ b/src/stan.rs @@ -7,16 +7,19 @@ use arrow::datatypes::{DataType, Field}; use bridgestan::open_library; use itertools::{izip, Itertools}; use nuts_rs::{CpuLogpFunc, CpuMath, DrawStorage, LogpError, Model, Settings}; +use pyo3::exceptions::PyRuntimeError; use pyo3::prelude::*; use pyo3::types::{PyDict, PyTuple}; use pyo3::{exceptions::PyValueError, pyclass, pymethods, PyResult}; use rand::prelude::Distribution; -use rand::{thread_rng, RngCore}; +use rand::{rng, RngCore}; use rand_distr::StandardNormal; use smallvec::{SmallVec, ToSmallVec}; use thiserror::Error; +use crate::wrapper::PyTransformAdapt; + type InnerModel = bridgestan::Model>; #[pyclass] @@ -53,8 +56,8 @@ impl StanVariable { } #[getter] - fn shape<'py>(&self, py: Python<'py>) -> Bound<'py, PyTuple> { - PyTuple::new_bound(py, self.0.shape.iter()) + fn shape<'py>(&self, py: Python<'py>) -> PyResult> { + PyTuple::new(py, self.0.shape.iter()) } #[getter] @@ -68,6 +71,7 @@ impl StanVariable { pub struct StanModel { model: Arc, variables: Vec, + transform_adapter: Option, } /// Return meta information about the constrained parameters of the model @@ -136,25 +140,41 @@ fn params( #[pymethods] impl StanModel { #[new] - pub fn new(lib: StanLibrary, seed: Option, data: Option) -> anyhow::Result { + #[pyo3(signature = (lib, seed=None, data=None, transform_adapter=None))] + pub fn new( + lib: StanLibrary, + seed: Option, + data: Option, + transform_adapter: Option>, + ) -> anyhow::Result { let seed = match seed { Some(seed) => seed, - None => thread_rng().next_u32(), + None => rng().next_u32(), }; let data: Option = data.map(CString::new).transpose()?; let model = Arc::new( bridgestan::Model::new(lib.0, data.as_ref(), seed).map_err(anyhow::Error::new)?, ); let variables = params(&model, true, true)?; - Ok(StanModel { model, variables }) + let transform_adapter = transform_adapter.map(PyTransformAdapt::new); + Ok(StanModel { + model, + variables, + transform_adapter, + }) } pub fn variables<'py>(&self, py: Python<'py>) -> PyResult> { - let out = PyDict::new_bound(py); + let out = PyDict::new(py); let results: Result, _> = self .variables .iter() - .map(|var| out.set_item(var.name.clone(), StanVariable(var.clone()).into_py(py))) + .map(|var| { + out.set_item( + var.name.clone(), + StanVariable(var.clone()).into_pyobject(py)?, + ) + }) .collect(); results?; Ok(out) @@ -195,7 +215,10 @@ impl StanModel { */ } -pub struct StanDensity<'model>(&'model InnerModel); +pub struct StanDensity<'model> { + inner: &'model InnerModel, + transform_adapter: Option, +} #[derive(Debug, Error)] pub enum StanLogpError { @@ -203,6 +226,10 @@ pub enum StanLogpError { BridgeStan(#[from] bridgestan::BridgeStanError), #[error("Bad logp value: {0}")] BadLogp(f64), + #[error("Python exception: {0}")] + PyErr(#[from] PyErr), + #[error("Unspecified Error: {0}")] + Anyhow(#[from] anyhow::Error), } impl LogpError for StanLogpError { @@ -213,9 +240,12 @@ impl LogpError for StanLogpError { impl<'model> CpuLogpFunc for StanDensity<'model> { type LogpError = StanLogpError; + type TransformParams = Py; fn logp(&mut self, position: &[f64], grad: &mut [f64]) -> Result { - let logp = self.0.log_density_gradient(position, true, true, grad)?; + let logp = self + .inner + .log_density_gradient(position, true, true, grad)?; if !logp.is_finite() { return Err(StanLogpError::BadLogp(logp)); } @@ -223,7 +253,143 @@ impl<'model> CpuLogpFunc for StanDensity<'model> { } fn dim(&self) -> usize { - self.0.param_unc_num() + self.inner.param_unc_num() + } + + fn inv_transform_normalize( + &mut self, + params: &Py, + untransformed_position: &[f64], + untransformed_gradient: &[f64], + transformed_position: &mut [f64], + transformed_gradient: &mut [f64], + ) -> std::result::Result { + let logdet = self + .transform_adapter + .as_mut() + .ok_or_else(|| PyRuntimeError::new_err("No transformation adapter specified"))? + .inv_transform_normalize( + params, + untransformed_position, + untransformed_gradient, + transformed_position, + transformed_gradient, + ) + .context("failed inv_transform_normalize")?; + Ok(logdet) + } + + fn init_from_transformed_position( + &mut self, + params: &Py, + untransformed_position: &mut [f64], + untransformed_gradient: &mut [f64], + transformed_position: &[f64], + transformed_gradient: &mut [f64], + ) -> std::result::Result<(f64, f64), Self::LogpError> { + let adapter = self + .transform_adapter + .as_mut() + .ok_or_else(|| PyRuntimeError::new_err("No transformation adapter specified"))?; + + let part1 = adapter + .init_from_transformed_position_part1( + params, + untransformed_position, + transformed_position, + ) + .context("Failed init_from_transformed_position_part1")?; + + let logp = self.logp(untransformed_position, untransformed_gradient)?; + + let adapter = self + .transform_adapter + .as_mut() + .ok_or_else(|| PyRuntimeError::new_err("No transformation adapter specified"))?; + + let logdet = adapter + .init_from_transformed_position_part2( + params, + part1, + untransformed_gradient, + transformed_gradient, + ) + .context("Failed init_from_transformed_position_part2")?; + Ok((logp, logdet)) + } + + fn init_from_untransformed_position( + &mut self, + params: &Py, + untransformed_position: &[f64], + untransformed_gradient: &mut [f64], + transformed_position: &mut [f64], + transformed_gradient: &mut [f64], + ) -> std::result::Result<(f64, f64), Self::LogpError> { + let logp = self + .logp(untransformed_position, untransformed_gradient) + .context("Failed to call stan logp function")?; + + let logdet = self + .transform_adapter + .as_mut() + .ok_or_else(|| PyRuntimeError::new_err("No transformation adapter specified"))? + .inv_transform_normalize( + params, + untransformed_position, + untransformed_gradient, + transformed_position, + transformed_gradient, + ) + .context("Failed inv_transform_normalize in stan init_from_untransformed_position")?; + Ok((logp, logdet)) + } + + fn update_transformation<'a, R: rand::Rng + ?Sized>( + &'a mut self, + rng: &mut R, + untransformed_positions: impl ExactSizeIterator, + untransformed_gradients: impl ExactSizeIterator, + untransformed_logp: impl ExactSizeIterator, + params: &'a mut Py, + ) -> std::result::Result<(), Self::LogpError> { + self.transform_adapter + .as_mut() + .ok_or_else(|| PyRuntimeError::new_err("No transformation adapter specified"))? + .update_transformation( + rng, + untransformed_positions, + untransformed_gradients, + untransformed_logp, + params, + ) + .context("Failed to update the transformation")?; + Ok(()) + } + + fn new_transformation( + &mut self, + rng: &mut R, + untransformed_position: &[f64], + untransformed_gradient: &[f64], + chain: u64, + ) -> std::result::Result, Self::LogpError> { + let trafo = self + .transform_adapter + .as_mut() + .ok_or_else(|| PyRuntimeError::new_err("No transformation adapter specified"))? + .new_transformation(rng, untransformed_position, untransformed_gradient, chain) + .context("Could not create transformation adapter")?; + Ok(trafo) + } + + fn transformation_id(&self, params: &Py) -> std::result::Result { + let id = self + .transform_adapter + .as_ref() + .ok_or_else(|| PyRuntimeError::new_err("No transformation adapter specified"))? + .transformation_id(params)?; + Ok(id) } } @@ -387,7 +553,10 @@ impl Model for StanModel { } fn math(&self) -> anyhow::Result> { - Ok(CpuMath::new(StanDensity(&self.model))) + Ok(CpuMath::new(StanDensity { + inner: &self.model, + transform_adapter: self.transform_adapter.as_ref().map(|v| v.clone()), + })) } fn init_position( diff --git a/src/wrapper.rs b/src/wrapper.rs index 6f9ad49..bb1a523 100644 --- a/src/wrapper.rs +++ b/src/wrapper.rs @@ -1,6 +1,7 @@ use std::{ fmt::Debug, - sync::Arc, + ops::{Deref, DerefMut}, + sync::{Arc, Mutex}, time::{Duration, Instant}, }; @@ -13,17 +14,19 @@ use crate::{ use anyhow::{bail, Context, Result}; use arrow::array::Array; +use numpy::{PyArray1, PyReadonlyArray1}; use nuts_rs::{ - AdaptOptions, ChainProgress, DiagAdaptExpSettings, DiagGradNutsSettings, LowRankNutsSettings, - LowRankSettings, NutsSettings, ProgressCallback, Sampler, SamplerWaitResult, Trace, + ChainProgress, DiagGradNutsSettings, LowRankNutsSettings, ProgressCallback, Sampler, + SamplerWaitResult, Trace, TransformedNutsSettings, }; use pyo3::{ exceptions::PyTimeoutError, ffi::Py_uintptr_t, + intern, prelude::*, types::{PyList, PyTuple}, }; -use rand::{thread_rng, RngCore}; +use rand::{rng, RngCore}; #[pyclass] struct PyChainProgress(ChainProgress); @@ -66,100 +69,23 @@ impl PyChainProgress { } } -#[derive(Clone)] -enum InnerSettings { - LowRank(LowRankSettings), - Diag(DiagAdaptExpSettings), -} - #[pyclass] #[derive(Clone)] pub struct PyNutsSettings { - settings: NutsSettings<()>, - adapt: InnerSettings, + inner: Settings, } +#[derive(Clone, Debug)] enum Settings { Diag(DiagGradNutsSettings), LowRank(LowRankNutsSettings), -} - -// Would be much nicer with -// https://doc.rust-lang.org/nightly/unstable-book/language-features/type-changing-struct-update.html -fn combine_settings( - inner: T, - settings: NutsSettings<()>, -) -> NutsSettings { - let adapt = AdaptOptions { - dual_average_options: settings.adapt_options.dual_average_options, - mass_matrix_options: inner, - early_window: settings.adapt_options.early_window, - step_size_window: settings.adapt_options.step_size_window, - mass_matrix_switch_freq: settings.adapt_options.mass_matrix_switch_freq, - early_mass_matrix_switch_freq: settings.adapt_options.early_mass_matrix_switch_freq, - mass_matrix_update_freq: settings.adapt_options.mass_matrix_update_freq, - }; - NutsSettings { - num_tune: settings.num_tune, - num_draws: settings.num_draws, - maxdepth: settings.maxdepth, - store_gradient: settings.store_gradient, - store_unconstrained: settings.store_unconstrained, - max_energy_error: settings.max_energy_error, - store_divergences: settings.store_divergences, - adapt_options: adapt, - check_turning: settings.check_turning, - num_chains: settings.num_chains, - seed: settings.seed, - } -} - -fn split_settings(settings: NutsSettings) -> (NutsSettings<()>, T) { - let adapt_settings = settings.adapt_options; - let mass_matrix_settings = adapt_settings.mass_matrix_options; - - let remaining: AdaptOptions<()> = AdaptOptions { - dual_average_options: adapt_settings.dual_average_options, - mass_matrix_options: (), - early_window: adapt_settings.early_window, - step_size_window: adapt_settings.step_size_window, - mass_matrix_switch_freq: adapt_settings.mass_matrix_switch_freq, - early_mass_matrix_switch_freq: adapt_settings.early_mass_matrix_switch_freq, - mass_matrix_update_freq: adapt_settings.mass_matrix_update_freq, - }; - - let settings = NutsSettings { - adapt_options: remaining, - num_tune: settings.num_tune, - num_draws: settings.num_draws, - maxdepth: settings.maxdepth, - store_gradient: settings.store_gradient, - store_unconstrained: settings.store_unconstrained, - max_energy_error: settings.max_energy_error, - store_divergences: settings.store_divergences, - check_turning: settings.check_turning, - num_chains: settings.num_chains, - seed: settings.seed, - }; - - (settings, mass_matrix_settings) + Transforming(TransformedNutsSettings), } impl PyNutsSettings { - fn into_settings(self) -> Settings { - match self.adapt { - InnerSettings::LowRank(mass_matrix) => { - Settings::LowRank(combine_settings(mass_matrix, self.settings)) - } - InnerSettings::Diag(mass_matrix) => { - Settings::Diag(combine_settings(mass_matrix, self.settings)) - } - } - } - fn new_diag(seed: Option) -> Self { let seed = seed.unwrap_or_else(|| { - let mut rng = thread_rng(); + let mut rng = rng(); rng.next_u64() }); let settings = DiagGradNutsSettings { @@ -167,17 +93,14 @@ impl PyNutsSettings { ..Default::default() }; - let (settings, inner) = split_settings(settings); - Self { - settings, - adapt: InnerSettings::Diag(inner), + inner: Settings::Diag(settings), } } fn new_low_rank(seed: Option) -> Self { let seed = seed.unwrap_or_else(|| { - let mut rng = thread_rng(); + let mut rng = rng(); rng.next_u64() }); let settings = LowRankNutsSettings { @@ -185,229 +108,467 @@ impl PyNutsSettings { ..Default::default() }; - let (settings, inner) = split_settings(settings); + Self { + inner: Settings::LowRank(settings), + } + } + + fn new_tranform_adapt(seed: Option) -> Self { + let seed = seed.unwrap_or_else(|| { + let mut rng = rng(); + rng.next_u64() + }); + let settings = TransformedNutsSettings { + seed, + ..Default::default() + }; Self { - settings, - adapt: InnerSettings::LowRank(inner), + inner: Settings::Transforming(settings), } } } +// TODO switch to serde to expose all the options... #[pymethods] impl PyNutsSettings { #[staticmethod] #[allow(non_snake_case)] + #[pyo3(signature = (seed=None))] fn Diag(seed: Option) -> Self { PyNutsSettings::new_diag(seed) } #[staticmethod] #[allow(non_snake_case)] + #[pyo3(signature = (seed=None))] fn LowRank(seed: Option) -> Self { PyNutsSettings::new_low_rank(seed) } + #[staticmethod] + #[allow(non_snake_case)] + #[pyo3(signature = (seed=None))] + fn Transform(seed: Option) -> Self { + PyNutsSettings::new_tranform_adapt(seed) + } + #[getter] fn num_tune(&self) -> u64 { - self.settings.num_tune + match &self.inner { + Settings::Diag(nuts_settings) => nuts_settings.num_tune, + Settings::LowRank(nuts_settings) => nuts_settings.num_tune, + Settings::Transforming(nuts_settings) => nuts_settings.num_tune, + } } #[setter(num_tune)] fn set_num_tune(&mut self, val: u64) { - self.settings.num_tune = val + match &mut self.inner { + Settings::Diag(nuts_settings) => nuts_settings.num_tune = val, + Settings::LowRank(nuts_settings) => nuts_settings.num_tune = val, + Settings::Transforming(nuts_settings) => nuts_settings.num_tune = val, + } } #[getter] fn num_chains(&self) -> usize { - self.settings.num_chains + match &self.inner { + Settings::Diag(nuts_settings) => nuts_settings.num_chains, + Settings::LowRank(nuts_settings) => nuts_settings.num_chains, + Settings::Transforming(nuts_settings) => nuts_settings.num_chains, + } } #[setter(num_chains)] fn set_num_chains(&mut self, val: usize) { - self.settings.num_chains = val; + match &mut self.inner { + Settings::Diag(nuts_settings) => nuts_settings.num_chains = val, + Settings::LowRank(nuts_settings) => nuts_settings.num_chains = val, + Settings::Transforming(nuts_settings) => nuts_settings.num_chains = val, + } } #[getter] fn num_draws(&self) -> u64 { - self.settings.num_draws + match &self.inner { + Settings::Diag(nuts_settings) => nuts_settings.num_draws, + Settings::LowRank(nuts_settings) => nuts_settings.num_draws, + Settings::Transforming(nuts_settings) => nuts_settings.num_draws, + } } #[setter(num_draws)] fn set_num_draws(&mut self, val: u64) { - self.settings.num_draws = val; + match &mut self.inner { + Settings::Diag(nuts_settings) => nuts_settings.num_draws = val, + Settings::LowRank(nuts_settings) => nuts_settings.num_draws = val, + Settings::Transforming(nuts_settings) => nuts_settings.num_draws = val, + } } #[getter] - fn window_switch_freq(&self) -> u64 { - self.settings.adapt_options.mass_matrix_switch_freq + fn window_switch_freq(&self) -> Result { + match &self.inner { + Settings::Diag(nuts_settings) => { + Ok(nuts_settings.adapt_options.mass_matrix_switch_freq) + } + Settings::LowRank(nuts_settings) => { + Ok(nuts_settings.adapt_options.mass_matrix_switch_freq) + } + Settings::Transforming(nuts_settings) => { + Ok(nuts_settings.adapt_options.transform_update_freq) + } + } } #[setter(window_switch_freq)] - fn set_window_switch_freq(&mut self, val: u64) { - self.settings.adapt_options.mass_matrix_switch_freq = val; + fn set_window_switch_freq(&mut self, val: u64) -> Result<()> { + match &mut self.inner { + Settings::Diag(nuts_settings) => { + nuts_settings.adapt_options.mass_matrix_switch_freq = val; + Ok(()) + } + Settings::LowRank(nuts_settings) => { + nuts_settings.adapt_options.mass_matrix_switch_freq = val; + Ok(()) + } + Settings::Transforming(nuts_settings) => { + nuts_settings.adapt_options.transform_update_freq = val; + Ok(()) + } + } } #[getter] - fn early_window_switch_freq(&self) -> u64 { - self.settings.adapt_options.early_mass_matrix_switch_freq + fn early_window_switch_freq(&self) -> Result { + match &self.inner { + Settings::Diag(nuts_settings) => { + Ok(nuts_settings.adapt_options.early_mass_matrix_switch_freq) + } + Settings::LowRank(nuts_settings) => { + Ok(nuts_settings.adapt_options.early_mass_matrix_switch_freq) + } + Settings::Transforming(_) => { + bail!("Option early_window_switch_freq not availbale for transformation adaptation") + } + } } #[setter(early_window_switch_freq)] - fn set_early_window_switch_freq(&mut self, val: u64) { - self.settings.adapt_options.early_mass_matrix_switch_freq = val; + fn set_early_window_switch_freq(&mut self, val: u64) -> Result<()> { + match &mut self.inner { + Settings::Diag(nuts_settings) => { + nuts_settings.adapt_options.early_mass_matrix_switch_freq = val; + Ok(()) + } + Settings::LowRank(nuts_settings) => { + nuts_settings.adapt_options.early_mass_matrix_switch_freq = val; + Ok(()) + } + Settings::Transforming(_) => { + bail!("Option early_window_switch_freq not availbale for transformation adaptation") + } + } } + #[getter] fn initial_step(&self) -> f64 { - self.settings - .adapt_options - .dual_average_options - .initial_step + match &self.inner { + Settings::Diag(nuts_settings) => { + nuts_settings + .adapt_options + .dual_average_options + .initial_step + } + Settings::LowRank(nuts_settings) => { + nuts_settings + .adapt_options + .dual_average_options + .initial_step + } + Settings::Transforming(nuts_settings) => { + nuts_settings + .adapt_options + .dual_average_options + .initial_step + } + } } #[setter(initial_step)] fn set_initial_step(&mut self, val: f64) { - self.settings - .adapt_options - .dual_average_options - .initial_step = val + match &mut self.inner { + Settings::Diag(nuts_settings) => { + nuts_settings + .adapt_options + .dual_average_options + .initial_step = val; + } + Settings::LowRank(nuts_settings) => { + nuts_settings + .adapt_options + .dual_average_options + .initial_step = val; + } + Settings::Transforming(nuts_settings) => { + nuts_settings + .adapt_options + .dual_average_options + .initial_step = val; + } + } } #[getter] fn maxdepth(&self) -> u64 { - self.settings.maxdepth + match &self.inner { + Settings::Diag(nuts_settings) => nuts_settings.maxdepth, + Settings::LowRank(nuts_settings) => nuts_settings.maxdepth, + Settings::Transforming(nuts_settings) => nuts_settings.maxdepth, + } } #[setter(maxdepth)] fn set_maxdepth(&mut self, val: u64) { - self.settings.maxdepth = val + match &mut self.inner { + Settings::Diag(nuts_settings) => nuts_settings.maxdepth = val, + Settings::LowRank(nuts_settings) => nuts_settings.maxdepth = val, + Settings::Transforming(nuts_settings) => nuts_settings.maxdepth = val, + } } #[getter] fn store_gradient(&self) -> bool { - self.settings.store_gradient + match &self.inner { + Settings::Diag(nuts_settings) => nuts_settings.store_gradient, + Settings::LowRank(nuts_settings) => nuts_settings.store_gradient, + Settings::Transforming(nuts_settings) => nuts_settings.store_gradient, + } } #[setter(store_gradient)] fn set_store_gradient(&mut self, val: bool) { - self.settings.store_gradient = val; + match &mut self.inner { + Settings::Diag(nuts_settings) => nuts_settings.store_gradient = val, + Settings::LowRank(nuts_settings) => nuts_settings.store_gradient = val, + Settings::Transforming(nuts_settings) => nuts_settings.store_gradient = val, + } } #[getter] fn store_unconstrained(&self) -> bool { - self.settings.store_unconstrained + match &self.inner { + Settings::Diag(nuts_settings) => nuts_settings.store_unconstrained, + Settings::LowRank(nuts_settings) => nuts_settings.store_unconstrained, + Settings::Transforming(nuts_settings) => nuts_settings.store_unconstrained, + } } #[setter(store_unconstrained)] fn set_store_unconstrained(&mut self, val: bool) { - self.settings.store_unconstrained = val; + match &mut self.inner { + Settings::Diag(nuts_settings) => nuts_settings.store_unconstrained = val, + Settings::LowRank(nuts_settings) => nuts_settings.store_unconstrained = val, + Settings::Transforming(nuts_settings) => nuts_settings.store_unconstrained = val, + } } #[getter] fn store_divergences(&self) -> bool { - self.settings.store_divergences + match &self.inner { + Settings::Diag(nuts_settings) => nuts_settings.store_divergences, + Settings::LowRank(nuts_settings) => nuts_settings.store_divergences, + Settings::Transforming(nuts_settings) => nuts_settings.store_divergences, + } } #[setter(store_divergences)] fn set_store_divergences(&mut self, val: bool) { - self.settings.store_divergences = val; + match &mut self.inner { + Settings::Diag(nuts_settings) => nuts_settings.store_divergences = val, + Settings::LowRank(nuts_settings) => nuts_settings.store_divergences = val, + Settings::Transforming(nuts_settings) => nuts_settings.store_divergences = val, + } } #[getter] fn max_energy_error(&self) -> f64 { - self.settings.max_energy_error + match &self.inner { + Settings::Diag(nuts_settings) => nuts_settings.max_energy_error, + Settings::LowRank(nuts_settings) => nuts_settings.max_energy_error, + Settings::Transforming(nuts_settings) => nuts_settings.max_energy_error, + } } #[setter(max_energy_error)] fn set_max_energy_error(&mut self, val: f64) { - self.settings.max_energy_error = val + match &mut self.inner { + Settings::Diag(nuts_settings) => nuts_settings.max_energy_error = val, + Settings::LowRank(nuts_settings) => nuts_settings.max_energy_error = val, + Settings::Transforming(nuts_settings) => nuts_settings.max_energy_error = val, + } } - #[setter(target_accept)] - fn set_target_accept(&mut self, val: f64) { - self.settings - .adapt_options - .dual_average_options - .target_accept = val; + #[getter] + fn set_target_accept(&self) -> f64 { + match &self.inner { + Settings::Diag(nuts_settings) => { + nuts_settings + .adapt_options + .dual_average_options + .target_accept + } + Settings::LowRank(nuts_settings) => { + nuts_settings + .adapt_options + .dual_average_options + .target_accept + } + Settings::Transforming(nuts_settings) => { + nuts_settings + .adapt_options + .dual_average_options + .target_accept + } + } } - #[getter] - fn target_accept(&self) -> f64 { - self.settings - .adapt_options - .dual_average_options - .target_accept + #[setter(target_accept)] + fn target_accept(&mut self, val: f64) { + match &mut self.inner { + Settings::Diag(nuts_settings) => { + nuts_settings + .adapt_options + .dual_average_options + .target_accept = val + } + Settings::LowRank(nuts_settings) => { + nuts_settings + .adapt_options + .dual_average_options + .target_accept = val + } + Settings::Transforming(nuts_settings) => { + nuts_settings + .adapt_options + .dual_average_options + .target_accept = val + } + } } #[getter] - fn store_mass_matrix(&self) -> bool { - match &self.adapt { - InnerSettings::LowRank(low_rank) => low_rank.store_mass_matrix, - InnerSettings::Diag(diag) => diag.store_mass_matrix, + fn store_mass_matrix(&self) -> Result { + match &self.inner { + Settings::LowRank(settings) => { + Ok(settings.adapt_options.mass_matrix_options.store_mass_matrix) + } + Settings::Diag(settings) => { + Ok(settings.adapt_options.mass_matrix_options.store_mass_matrix) + } + Settings::Transforming(_) => { + bail!("Option store_mass_matrix not availbale for transformation adaptation") + } } } #[setter(store_mass_matrix)] - fn set_store_mass_matrix(&mut self, val: bool) { - match &mut self.adapt { - InnerSettings::LowRank(low_rank) => { - low_rank.store_mass_matrix = val; + fn set_store_mass_matrix(&mut self, val: bool) -> Result<()> { + match &mut self.inner { + Settings::LowRank(settings) => { + settings.adapt_options.mass_matrix_options.store_mass_matrix = val; + Ok(()) + } + Settings::Diag(settings) => { + settings.adapt_options.mass_matrix_options.store_mass_matrix = val; + Ok(()) } - InnerSettings::Diag(diag) => { - diag.store_mass_matrix = val; + Settings::Transforming(_) => { + bail!("Option store_mass_matrix not availbale for transformation adaptation") } } } #[getter] fn use_grad_based_mass_matrix(&self) -> Result { - match &self.adapt { - InnerSettings::LowRank(_) => { - bail!("grad based mass matrix not available for low-rank adaptation") + match &self.inner { + Settings::LowRank(_) => { + bail!("non-grad based mass matrix not available for low-rank adaptation") + } + Settings::Transforming(_) => { + bail!("non-grad based mass matrix not available for transforming adaptation") } - InnerSettings::Diag(diag) => Ok(diag.use_grad_based_estimate), + Settings::Diag(diag) => Ok(diag + .adapt_options + .mass_matrix_options + .use_grad_based_estimate), } } #[setter(use_grad_based_mass_matrix)] fn set_use_grad_based_mass_matrix(&mut self, val: bool) -> Result<()> { - match &mut self.adapt { - InnerSettings::LowRank(_) => { - bail!("grad based mass matrix not available for low-rank adaptation") + match &mut self.inner { + Settings::LowRank(_) => { + bail!("non-grad based mass matrix not available for low-rank adaptation"); + } + Settings::Transforming(_) => { + bail!("non-grad based mass matrix not available for transforming adaptation"); } - InnerSettings::Diag(diag) => { - diag.use_grad_based_estimate = val; + Settings::Diag(diag) => { + diag.adapt_options + .mass_matrix_options + .use_grad_based_estimate = val; } } Ok(()) } #[getter] - fn mass_matrix_switch_freq(&self) -> u64 { - self.settings.adapt_options.mass_matrix_switch_freq + fn mass_matrix_switch_freq(&self) -> Result { + match &self.inner { + Settings::Diag(settings) => Ok(settings.adapt_options.mass_matrix_switch_freq), + Settings::LowRank(settings) => Ok(settings.adapt_options.mass_matrix_switch_freq), + Settings::Transforming(_) => { + bail!("mass_matrix_switch_freq not available for transforming adaptation"); + } + } } #[setter(mass_matrix_switch_freq)] - fn set_mass_matrix_switch_freq(&mut self, val: u64) { - self.settings.adapt_options.mass_matrix_switch_freq = val; + fn set_mass_matrix_switch_freq(&mut self, val: u64) -> Result<()> { + match &mut self.inner { + Settings::Diag(settings) => settings.adapt_options.mass_matrix_switch_freq = val, + Settings::LowRank(settings) => settings.adapt_options.mass_matrix_switch_freq = val, + Settings::Transforming(_) => { + bail!("mass_matrix_switch_freq not available for transforming adaptation"); + } + } + Ok(()) } #[getter] fn mass_matrix_eigval_cutoff(&self) -> Result { - match &self.adapt { - InnerSettings::LowRank(inner) => Ok(inner.eigval_cutoff), - InnerSettings::Diag(_) => { - bail!("eigenvalue cutoff not available for diag mass matrix adaptation") + match &self.inner { + Settings::LowRank(inner) => Ok(inner.adapt_options.mass_matrix_options.eigval_cutoff), + Settings::Diag(_) => { + bail!("eigenvalue cutoff not available for diag mass matrix adaptation"); + } + Settings::Transforming(_) => { + bail!("eigenvalue cutoff not available for transfor adaptation"); } } } #[setter(mass_matrix_eigval_cutoff)] fn set_mass_matrix_eigval_cutoff(&mut self, val: f64) -> Result<()> { - match &mut self.adapt { - InnerSettings::LowRank(inner) => inner.eigval_cutoff = val, - InnerSettings::Diag(_) => { - bail!("eigenvalue cutoff not available for diag mass matrix adaptation") + match &mut self.inner { + Settings::LowRank(inner) => inner.adapt_options.mass_matrix_options.eigval_cutoff = val, + Settings::Diag(_) => { + bail!("eigenvalue cutoff not available for diag mass matrix adaptation"); + } + Settings::Transforming(_) => { + bail!("eigenvalue cutoff not available for transfor adaptation"); } } Ok(()) @@ -415,21 +576,56 @@ impl PyNutsSettings { #[getter] fn mass_matrix_gamma(&self) -> Result { - match &self.adapt { - InnerSettings::LowRank(inner) => Ok(inner.gamma), - InnerSettings::Diag(_) => { - bail!("gamma not available for diag mass matrix adaptation") + match &self.inner { + Settings::LowRank(inner) => Ok(inner.adapt_options.mass_matrix_options.gamma), + Settings::Diag(_) => { + bail!("gamma not available for diag mass matrix adaptation"); + } + Settings::Transforming(_) => { + bail!("gamma not available for transform adaptation"); } } } #[setter(mass_matrix_gamma)] fn set_mass_matrix_gamma(&mut self, val: f64) -> Result<()> { - match &mut self.adapt { - InnerSettings::LowRank(inner) => inner.gamma = val, - InnerSettings::Diag(_) => { - bail!("gamma not available for diag mass matrix adaptation") + match &mut self.inner { + Settings::LowRank(inner) => { + inner.adapt_options.mass_matrix_options.gamma = val; + } + Settings::Diag(_) => { + bail!("gamma not available for diag mass matrix adaptation"); + } + Settings::Transforming(_) => { + bail!("gamma not available for transform adaptation"); + } + } + Ok(()) + } + + #[getter] + fn train_on_orbit(&self) -> Result { + match &self.inner { + Settings::LowRank(_) => { + bail!("gamma not available for low rank mass matrix adaptation"); + } + Settings::Diag(_) => { + bail!("gamma not available for diag mass matrix adaptation"); } + Settings::Transforming(inner) => Ok(inner.adapt_options.use_orbit_for_training), + } + } + + #[setter(train_on_orbit)] + fn set_train_on_orbit(&mut self, val: bool) -> Result<()> { + match &mut self.inner { + Settings::LowRank(_) => { + bail!("gamma not available for low rank mass matrix adaptation"); + } + Settings::Diag(_) => { + bail!("gamma not available for diag mass matrix adaptation"); + } + Settings::Transforming(inner) => inner.adapt_options.use_orbit_for_training = val, } Ok(()) } @@ -442,13 +638,12 @@ pub(crate) enum SamplerState { } #[derive(Clone)] -#[pyclass] -pub enum ProgressType { +enum InnerProgressType { Callback { rate: Duration, n_cores: usize, template: String, - callback: Py, + callback: Arc>, }, Indicatif { rate: Duration, @@ -456,10 +651,14 @@ pub enum ProgressType { None {}, } +#[pyclass] +#[derive(Clone)] +pub struct ProgressType(InnerProgressType); + impl ProgressType { fn into_callback(self) -> Result> { - match self { - ProgressType::Callback { + match self.0 { + InnerProgressType::Callback { callback, rate, n_cores, @@ -470,11 +669,11 @@ impl ProgressType { Ok(Some(callback)) } - ProgressType::Indicatif { rate } => { + InnerProgressType::Indicatif { rate } => { let handler = IndicatifHandler::new(rate); Ok(Some(handler.into_callback()?)) } - ProgressType::None {} => Ok(None), + InnerProgressType::None {} => Ok(None), } } } @@ -484,28 +683,28 @@ impl ProgressType { #[staticmethod] fn indicatif(rate: u64) -> Self { let rate = Duration::from_millis(rate); - ProgressType::Indicatif { rate } + ProgressType(InnerProgressType::Indicatif { rate }) } #[staticmethod] fn none() -> Self { - ProgressType::None {} + ProgressType(InnerProgressType::None {}) } #[staticmethod] fn template_callback(rate: u64, template: String, n_cores: usize, callback: Py) -> Self { let rate = Duration::from_millis(rate); - ProgressType::Callback { - callback, + ProgressType(InnerProgressType::Callback { + callback: Arc::new(callback), template, n_cores, rate, - } + }) } } #[pyclass] -struct PySampler(SamplerState); +struct PySampler(Mutex); #[pymethods] impl PySampler { @@ -517,14 +716,18 @@ impl PySampler { progress_type: ProgressType, ) -> PyResult { let callback = progress_type.into_callback()?; - match settings.into_settings() { + match settings.inner { Settings::LowRank(settings) => { let sampler = Sampler::new(model, settings, cores, callback)?; - Ok(PySampler(SamplerState::Running(sampler))) + Ok(PySampler(SamplerState::Running(sampler).into())) } Settings::Diag(settings) => { let sampler = Sampler::new(model, settings, cores, callback)?; - Ok(PySampler(SamplerState::Running(sampler))) + Ok(PySampler(SamplerState::Running(sampler).into())) + } + Settings::Transforming(settings) => { + let sampler = Sampler::new(model, settings, cores, callback)?; + Ok(PySampler(SamplerState::Running(sampler).into())) } } } @@ -537,14 +740,18 @@ impl PySampler { progress_type: ProgressType, ) -> PyResult { let callback = progress_type.into_callback()?; - match settings.into_settings() { + match settings.inner { Settings::LowRank(settings) => { let sampler = Sampler::new(model, settings, cores, callback)?; - Ok(PySampler(SamplerState::Running(sampler))) + Ok(PySampler(SamplerState::Running(sampler).into())) } Settings::Diag(settings) => { let sampler = Sampler::new(model, settings, cores, callback)?; - Ok(PySampler(SamplerState::Running(sampler))) + Ok(PySampler(SamplerState::Running(sampler).into())) + } + Settings::Transforming(settings) => { + let sampler = Sampler::new(model, settings, cores, callback)?; + Ok(PySampler(SamplerState::Running(sampler).into())) } } } @@ -557,38 +764,45 @@ impl PySampler { progress_type: ProgressType, ) -> PyResult { let callback = progress_type.into_callback()?; - match settings.into_settings() { + match settings.inner { Settings::LowRank(settings) => { let sampler = Sampler::new(model, settings, cores, callback)?; - Ok(PySampler(SamplerState::Running(sampler))) + Ok(PySampler(SamplerState::Running(sampler).into())) } Settings::Diag(settings) => { let sampler = Sampler::new(model, settings, cores, callback)?; - Ok(PySampler(SamplerState::Running(sampler))) + Ok(PySampler(SamplerState::Running(sampler).into())) + } + Settings::Transforming(settings) => { + let sampler = Sampler::new(model, settings, cores, callback)?; + Ok(PySampler(SamplerState::Running(sampler).into())) } } } fn is_finished(&mut self, py: Python<'_>) -> PyResult { py.allow_threads(|| { - let state = std::mem::replace(&mut self.0, SamplerState::Empty); + let guard = &mut self.0.lock().expect("Poisond sampler state mutex"); + let slot = guard.deref_mut(); + + let state = std::mem::replace(slot, SamplerState::Empty); let SamplerState::Running(sampler) = state else { - let _ = std::mem::replace(&mut self.0, state); + let _ = std::mem::replace(slot, state); return Ok(true); }; match sampler.wait_timeout(Duration::from_millis(1)) { SamplerWaitResult::Trace(trace) => { - let _ = std::mem::replace(&mut self.0, SamplerState::Finished(Some(trace))); + let _ = std::mem::replace(slot, SamplerState::Finished(Some(trace))); Ok(true) } SamplerWaitResult::Timeout(sampler) => { - let _ = std::mem::replace(&mut self.0, SamplerState::Running(sampler)); + let _ = std::mem::replace(slot, SamplerState::Running(sampler)); Ok(false) } SamplerWaitResult::Err(err, trace) => { - let _ = std::mem::replace(&mut self.0, SamplerState::Finished(trace)); + let _ = std::mem::replace(slot, SamplerState::Finished(trace)); Err(err.into()) } } @@ -597,7 +811,12 @@ impl PySampler { fn pause(&mut self, py: Python<'_>) -> PyResult<()> { py.allow_threads(|| { - if let SamplerState::Running(ref mut control) = self.0 { + if let SamplerState::Running(ref mut control) = self + .0 + .lock() + .expect("Poised sampler state mutex") + .deref_mut() + { control.pause()? } Ok(()) @@ -606,24 +825,33 @@ impl PySampler { fn resume(&mut self, py: Python<'_>) -> PyResult<()> { py.allow_threads(|| { - if let SamplerState::Running(ref mut control) = self.0 { + if let SamplerState::Running(ref mut control) = self + .0 + .lock() + .expect("Poisond sampler state mutex") + .deref_mut() + { control.resume()? } Ok(()) }) } + #[pyo3(signature = (timeout_seconds=None))] fn wait(&mut self, py: Python<'_>, timeout_seconds: Option) -> PyResult<()> { py.allow_threads(|| { + let guard = &mut self.0.lock().expect("Poisond sampler state mutex"); + let slot = guard.deref_mut(); + let timeout = match timeout_seconds { Some(val) => Some(Duration::try_from_secs_f64(val).context("Invalid timeout")?), None => None, }; - let state = std::mem::replace(&mut self.0, SamplerState::Empty); + let state = std::mem::replace(slot, SamplerState::Empty); let SamplerState::Running(mut control) = state else { - let _ = std::mem::replace(&mut self.0, state); + let _ = std::mem::replace(slot, state); return Ok(()); }; @@ -664,32 +892,38 @@ impl PySampler { } }; - let _ = std::mem::replace(&mut self.0, final_state); + let _ = std::mem::replace(slot, final_state); retval }) } fn abort(&mut self, py: Python<'_>) -> PyResult<()> { py.allow_threads(|| { - let state = std::mem::replace(&mut self.0, SamplerState::Empty); + let guard = &mut self.0.lock().expect("Poisond sampler state mutex"); + let slot = guard.deref_mut(); + + let state = std::mem::replace(slot, SamplerState::Empty); let SamplerState::Running(control) = state else { - let _ = std::mem::replace(&mut self.0, state); + let _ = std::mem::replace(slot, state); return Ok(()); }; let (result, trace) = control.abort(); - let _ = std::mem::replace(&mut self.0, SamplerState::Finished(trace)); + let _ = std::mem::replace(slot, SamplerState::Finished(trace)); result?; Ok(()) }) } fn extract_results<'py>(&mut self, py: Python<'py>) -> PyResult> { - let state = std::mem::replace(&mut self.0, SamplerState::Empty); + let guard = &mut self.0.lock().expect("Poisond sampler state mutex"); + let slot = guard.deref_mut(); + + let state = std::mem::replace(slot, SamplerState::Empty); let SamplerState::Finished(trace) = state else { - let _ = std::mem::replace(&mut self.0, state); + let _ = std::mem::replace(slot, state); return Err(anyhow::anyhow!("Sampler is not finished"))?; }; @@ -703,7 +937,7 @@ impl PySampler { } fn is_empty(&self) -> bool { - match self.0 { + match self.0.lock().expect("Poisoned sampler state lock").deref() { SamplerState::Running(_) => false, SamplerState::Finished(_) => false, SamplerState::Empty => true, @@ -712,7 +946,8 @@ impl PySampler { fn inspect<'py>(&mut self, py: Python<'py>) -> PyResult> { let trace = py.allow_threads(|| { - let SamplerState::Running(ref mut sampler) = self.0 else { + let mut guard = self.0.lock().unwrap(); + let SamplerState::Running(ref mut sampler) = guard.deref_mut() else { return Err(anyhow::anyhow!("Sampler is not running"))?; }; @@ -723,28 +958,28 @@ impl PySampler { } fn trace_to_list(trace: Trace, py: Python<'_>) -> PyResult> { - let list = PyList::new_bound( + let list = PyList::new( py, trace .chains .into_iter() .map(|chain| { - Ok(PyTuple::new_bound( + Ok(PyTuple::new( py, [ export_array(py, chain.draws)?, export_array(py, chain.stats)?, ] .into_iter(), - )) + )?) }) .collect::>>()?, - ); + )?; Ok(list) } fn export_array(py: Python<'_>, data: Arc) -> PyResult { - let pa = py.import_bound("pyarrow")?; + let pa = py.import("pyarrow")?; let array = pa.getattr("Array")?; let data = data.into_data(); @@ -755,12 +990,252 @@ fn export_array(py: Python<'_>, data: Arc) -> PyResult { .call_method1( "_import_from_c", ( - (&data as *const _ as Py_uintptr_t).into_py(py), - (&schema as *const _ as Py_uintptr_t).into_py(py), + (&data as *const _ as Py_uintptr_t).into_pyobject(py)?, + (&schema as *const _ as Py_uintptr_t).into_pyobject(py)?, ), ) .context("Could not import arrow trace in python")?; - Ok(data.into_py(py)) + Ok(data.unbind()) +} + +#[pyclass] +#[derive(Debug, Clone)] +pub struct PyTransformAdapt(Arc>); + +#[pymethods] +impl PyTransformAdapt { + #[new] + pub fn new(adapter: Py) -> Self { + Self(Arc::new(adapter)) + } +} + +impl PyTransformAdapt { + pub fn inv_transform_normalize( + &mut self, + params: &Py, + untransformed_position: &[f64], + untransformed_gradient: &[f64], + transformed_position: &mut [f64], + transformed_gradient: &mut [f64], + ) -> Result { + Python::with_gil(|py| { + let untransformed_position = PyArray1::from_slice(py, untransformed_position); + let untransformed_gradient = PyArray1::from_slice(py, untransformed_gradient); + + let output = params + .getattr(py, intern!(py, "inv_transform")) + .context("Could not access attribute inv_transform")? + .call1(py, (untransformed_position, untransformed_gradient)) + .context("Failed to call adapter.inv_transform")?; + let (logdet, transformed_position_out, transformed_gradient_out): ( + f64, + PyReadonlyArray1, + PyReadonlyArray1, + ) = output + .extract(py) + .context("Execpected results from adapter.inv_transform")?; + + if !transformed_position_out + .as_slice()? + .iter() + .all(|&x| x.is_finite()) + { + bail!("Transformed position is not finite"); + } + if !transformed_gradient_out + .as_slice()? + .iter() + .all(|&x| x.is_finite()) + { + bail!("Transformed position is not finite"); + } + + transformed_position.copy_from_slice( + transformed_position_out + .as_slice() + .context("Could not copy transformed_position")?, + ); + + transformed_gradient.copy_from_slice( + transformed_gradient_out + .as_slice() + .context("Could not copy transformed_gradient")?, + ); + Ok(logdet) + }) + } + + pub fn init_from_transformed_position( + &mut self, + params: &Py, + untransformed_position: &mut [f64], + untransformed_gradient: &mut [f64], + transformed_position: &[f64], + transformed_gradient: &mut [f64], + ) -> Result<(f64, f64)> { + Python::with_gil(|py| { + let transformed_position = PyArray1::from_slice(py, transformed_position); + + let output = params + .getattr(py, intern!(py, "init_from_transformed_position"))? + .call1(py, (transformed_position,))?; + let ( + logp, + logdet, + untransformed_position_out, + untransformed_gradient_out, + transformed_gradient_out, + ): ( + f64, + f64, + PyReadonlyArray1, + PyReadonlyArray1, + PyReadonlyArray1, + ) = output.extract(py)?; + + untransformed_position.copy_from_slice(untransformed_position_out.as_slice()?); + untransformed_gradient.copy_from_slice(untransformed_gradient_out.as_slice()?); + transformed_gradient.copy_from_slice(transformed_gradient_out.as_slice()?); + Ok((logp, logdet)) + }) + } + + pub fn init_from_transformed_position_part1( + &mut self, + params: &Py, + untransformed_position: &mut [f64], + transformed_position: &[f64], + ) -> Result> { + Python::with_gil(|py| { + let transformed_position = PyArray1::from_slice(py, transformed_position); + + let output = params + .getattr(py, intern!(py, "init_from_transformed_position_part1"))? + .call1(py, (transformed_position,))?; + let (untransformed_position_out, part1): (PyReadonlyArray1, Py) = + output.extract(py)?; + + untransformed_position.copy_from_slice(untransformed_position_out.as_slice()?); + Ok(part1) + }) + } + + pub fn init_from_transformed_position_part2( + &mut self, + params: &Py, + part1: Py, + untransformed_gradient: &[f64], + transformed_gradient: &mut [f64], + ) -> Result { + Python::with_gil(|py| { + let untransformed_gradient = PyArray1::from_slice(py, untransformed_gradient); + + let output = params + .getattr(py, intern!(py, "init_from_transformed_position_part2"))? + .call1(py, (part1, untransformed_gradient))?; + let (logdet, transformed_gradient_out): (f64, PyReadonlyArray1) = + output.extract(py)?; + + transformed_gradient.copy_from_slice(transformed_gradient_out.as_slice()?); + Ok(logdet) + }) + } + + pub fn init_from_untransformed_position( + &mut self, + params: &Py, + untransformed_position: &[f64], + untransformed_gradient: &mut [f64], + transformed_position: &mut [f64], + transformed_gradient: &mut [f64], + ) -> Result<(f64, f64)> { + Python::with_gil(|py| { + let untransformed_position = PyArray1::from_slice(py, untransformed_position); + + let output = params + .getattr(py, intern!(py, "init_from_untransformed_position")) + .context("No attribute init_from_untransformed_position")? + .call1(py, (untransformed_position,)) + .context("Failed adapter.init_from_untransformed_position")?; + let ( + logp, + logdet, + untransformed_gradient_out, + transformed_position_out, + transformed_gradient_out, + ): ( + f64, + f64, + PyReadonlyArray1, + PyReadonlyArray1, + PyReadonlyArray1, + ) = output + .extract(py) + .context("Unexpected return value of init_from_untransformed_position")?; + + untransformed_gradient.copy_from_slice(untransformed_gradient_out.as_slice()?); + transformed_position.copy_from_slice(transformed_position_out.as_slice()?); + transformed_gradient.copy_from_slice(transformed_gradient_out.as_slice()?); + Ok((logp, logdet)) + }) + } + + pub fn update_transformation<'a, R: rand::Rng + ?Sized>( + &'a mut self, + rng: &mut R, + untransformed_positions: impl ExactSizeIterator, + untransformed_gradients: impl ExactSizeIterator, + untransformed_logp: impl ExactSizeIterator, + params: &'a mut Py, + ) -> Result<()> { + Python::with_gil(|py| { + let positions = PyList::new( + py, + untransformed_positions.map(|pos| PyArray1::from_slice(py, pos)), + )?; + let gradients = PyList::new( + py, + untransformed_gradients.map(|grad| PyArray1::from_slice(py, grad)), + )?; + + let logps = PyArray1::from_iter(py, untransformed_logp.copied()); + let seed = rng.next_u64(); + + params + .getattr(py, intern!(py, "update"))? + .call1(py, (seed, positions, gradients, logps))?; + Ok(()) + }) + } + + pub fn new_transformation( + &mut self, + rng: &mut R, + untransformed_position: &[f64], + untransformed_gradient: &[f64], + chain: u64, + ) -> Result> { + Python::with_gil(|py| { + let position = PyArray1::from_slice(py, untransformed_position); + let gradient = PyArray1::from_slice(py, untransformed_gradient); + + let seed = rng.next_u64(); + + let transformer = self.0.call1(py, (seed, position, gradient, chain))?; + + Ok(transformer) + }) + } + + pub fn transformation_id(&self, params: &Py) -> Result { + Python::with_gil(|py| { + let id: i64 = params + .getattr(py, intern!(py, "transformation_id"))? + .extract(py)?; + Ok(id) + }) + } } /// A Python module implemented in Rust. diff --git a/tests/test_pymc.py b/tests/test_pymc.py index a586e24..8816967 100644 --- a/tests/test_pymc.py +++ b/tests/test_pymc.py @@ -1,6 +1,7 @@ import numpy as np import pymc as pm import pytest +from scipy import stats import nutpie import nutpie.compile_pymc @@ -243,6 +244,34 @@ def test_pymc_var_names(backend, gradient_backend): assert not hasattr(trace.posterior, "c") +@pytest.mark.slow +def test_normalizing_flow(): + with pm.Model() as model: + pm.HalfNormal("x", shape=2) + + compiled = nutpie.compile_pymc_model( + model, backend="jax", gradient_backend="jax" + ).with_transform_adapt( + num_diag_windows=6, + verbose=True, + ) + trace = nutpie.sample( + compiled, + chains=1, + transform_adapt=True, + window_switch_freq=150, + tune=600, + seed=1, + ) + draws = trace.posterior.x.isel(x_dim_0=0, chain=0) + kstest = stats.ks_1samp(draws, stats.halfnorm.cdf) + assert kstest.pvalue > 0.01 + + draws = trace.posterior.x.isel(x_dim_0=1, chain=0) + kstest = stats.ks_1samp(draws, stats.halfnorm.cdf) + assert kstest.pvalue > 0.01 + + @pytest.mark.parametrize( ("backend", "gradient_backend"), [