diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..8f61a8e --- /dev/null +++ b/.gitattributes @@ -0,0 +1,2 @@ +# SCM syntax highlighting +pixi.lock linguist-language=YAML linguist-generated=true diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2b9b3c3..b3edeb2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -34,7 +34,7 @@ jobs: 3.10 3.11 3.12 - # Leave out 3.13 on aarch due to an issue in pyo3/rust-numpy 0.23.4 + 3.13 - name: Build wheels uses: PyO3/maturin-action@v1 if: ${{ matrix.platform.target == 'aarch64' }} @@ -50,8 +50,7 @@ jobs: if: ${{ matrix.platform.target == 'x86_64' }} with: target: ${{ matrix.platform.target }} - # No py3.13 yet... - args: --release --out dist --interpreter 3.10 3.11 3.12 --zig + args: --release --out dist --interpreter 3.10 3.11 3.12 3.13 --zig sccache: ${{ !startsWith(github.ref, 'refs/tags/') }} manylinux: auto before-script-linux: | @@ -175,7 +174,7 @@ jobs: 3.10 3.11 3.12 - # 3.13 leave out 3.13 due to a segfault + 3.13 architecture: ${{ matrix.platform.target }} - name: Install uv uses: astral-sh/setup-uv@v5 @@ -230,6 +229,7 @@ jobs: 3.10 3.11 3.12 + 3.13 - name: Install uv uses: astral-sh/setup-uv@v5 - uses: maxim-lobanov/setup-xcode@v1 diff --git a/.gitignore b/.gitignore index f0721f7..bffa4f4 100644 --- a/.gitignore +++ b/.gitignore @@ -8,7 +8,20 @@ tvm_libs/* notebooks/*.stan notebooks/*.csv notebooks/*.hpp +notebooks/radon* perf.data* wheels .vscode/ *~ +.zed +.cargo +*traces* +.pyrightconfig.json +*.zarr +book +docs/_site +.quarto +example-iree +posteriordb +.quarto +docs/.quarto diff --git a/Cargo.lock b/Cargo.lock index 4fe15a5..e7d9279 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,23 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "adler2" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" + +[[package]] +name = "aes" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0" +dependencies = [ + "cfg-if", + "cipher", + "cpufeatures", +] + [[package]] name = "ahash" version = "0.8.11" @@ -10,10 +27,10 @@ checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" dependencies = [ "cfg-if", "const-random", - "getrandom", + "getrandom 0.2.15", "once_cell", "version_check", - "zerocopy", + "zerocopy 0.7.35", ] [[package]] @@ -54,15 +71,15 @@ checksum = "55cc3b69f167a1ef2e161439aa98aed94e6028e5f9a59be9a6ffb47aef1651f9" [[package]] name = "anyhow" -version = "1.0.95" +version = "1.0.97" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34ac096ce696dc2fcabef30516bb13c0a68a11d30131d3df6f04711467681b04" +checksum = "dcfed56ad506cb2c684a14971b8861fdc3baaaae314b9e5f9bb532cbe3ba7a4f" [[package]] name = "arrow" -version = "52.2.0" +version = "54.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05048a8932648b63f21c37d88b552ccc8a65afb6dfe9fc9f30ce79174c2e7a85" +checksum = "dc208515aa0151028e464cc94a692156e945ce5126abd3537bb7fd6ba2143ed1" dependencies = [ "arrow-arith", "arrow-array", @@ -78,24 +95,23 @@ dependencies = [ [[package]] name = "arrow-arith" -version = "52.2.0" +version = "54.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d8a57966e43bfe9a3277984a14c24ec617ad874e4c0e1d2a1b083a39cfbf22c" +checksum = "e07e726e2b3f7816a85c6a45b6ec118eeeabf0b2a8c208122ad949437181f49a" dependencies = [ "arrow-array", "arrow-buffer", "arrow-data", "arrow-schema", "chrono", - "half", "num", ] [[package]] name = "arrow-array" -version = "52.2.0" +version = "54.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16f4a9468c882dc66862cef4e1fd8423d47e67972377d85d80e022786427768c" +checksum = "a2262eba4f16c78496adfd559a29fe4b24df6088efc9985a873d58e92be022d5" dependencies = [ "ahash", "arrow-buffer", @@ -109,9 +125,9 @@ dependencies = [ [[package]] name = "arrow-buffer" -version = "52.2.0" +version = "54.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c975484888fc95ec4a632cdc98be39c085b1bb518531b0c80c5d462063e5daa1" +checksum = "4e899dade2c3b7f5642eb8366cfd898958bcca099cde6dfea543c7e8d3ad88d4" dependencies = [ "bytes", "half", @@ -120,9 +136,9 @@ dependencies = [ [[package]] name = "arrow-cast" -version = "52.2.0" +version = "54.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da26719e76b81d8bc3faad1d4dbdc1bcc10d14704e63dc17fc9f3e7e1e567c8e" +checksum = "4103d88c5b441525ed4ac23153be7458494c2b0c9a11115848fdb9b81f6f886a" dependencies = [ "arrow-array", "arrow-buffer", @@ -140,9 +156,9 @@ dependencies = [ [[package]] name = "arrow-data" -version = "52.2.0" +version = "54.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd9d6f18c65ef7a2573ab498c374d8ae364b4a4edf67105357491c031f716ca5" +checksum = "0a329fb064477c9ec5f0870d2f5130966f91055c7c5bce2b3a084f116bc28c3b" dependencies = [ "arrow-buffer", "arrow-schema", @@ -152,26 +168,23 @@ dependencies = [ [[package]] name = "arrow-ord" -version = "52.2.0" +version = "54.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42745f86b1ab99ef96d1c0bcf49180848a64fe2c7a7a0d945bc64fa2b21ba9bc" +checksum = "f841bfcc1997ef6ac48ee0305c4dfceb1f7c786fe31e67c1186edf775e1f1160" dependencies = [ "arrow-array", "arrow-buffer", "arrow-data", "arrow-schema", "arrow-select", - "half", - "num", ] [[package]] name = "arrow-row" -version = "52.2.0" +version = "54.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4cd09a518c602a55bd406bcc291a967b284cfa7a63edfbf8b897ea4748aad23c" +checksum = "1eeb55b0a0a83851aa01f2ca5ee5648f607e8506ba6802577afdda9d75cdedcd" dependencies = [ - "ahash", "arrow-array", "arrow-buffer", "arrow-data", @@ -181,18 +194,18 @@ dependencies = [ [[package]] name = "arrow-schema" -version = "52.2.0" +version = "54.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e972cd1ff4a4ccd22f86d3e53e835c2ed92e0eea6a3e8eadb72b4f1ac802cf8" +checksum = "85934a9d0261e0fa5d4e2a5295107d743b543a6e0484a835d4b8db2da15306f9" dependencies = [ "bitflags", ] [[package]] name = "arrow-select" -version = "52.2.0" +version = "54.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "600bae05d43483d216fb3494f8c32fdbefd8aa4e1de237e790dbb3d9f44690a3" +checksum = "7e2932aece2d0c869dd2125feb9bd1709ef5c445daa3838ac4112dcfa0fda52c" dependencies = [ "ahash", "arrow-array", @@ -204,9 +217,9 @@ dependencies = [ [[package]] name = "arrow-string" -version = "52.2.0" +version = "54.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0dc1985b67cb45f6606a248ac2b4a288849f196bab8c657ea5589f47cdd55e6" +checksum = "912e38bd6a7a7714c1d9b61df80315685553b7455e8a6045c27531d8ecd5b458" dependencies = [ "arrow-array", "arrow-buffer", @@ -240,6 +253,12 @@ version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" +[[package]] +name = "base64ct" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" + [[package]] name = "bindgen" version = "0.71.1" @@ -255,16 +274,16 @@ dependencies = [ "proc-macro2", "quote", "regex", - "rustc-hash 2.1.1", + "rustc-hash", "shlex", - "syn 2.0.98", + "syn", ] [[package]] name = "bitflags" -version = "2.8.0" +version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f68f53c83ab957f72c32642f3868eec03eb974d1fb82e453128456482613d36" +checksum = "5c8214115b7bf84099f1309324e63141d4c5d7cc26862f97a0a857dbefe165bd" [[package]] name = "block-buffer" @@ -285,7 +304,7 @@ dependencies = [ "libloading", "log", "path-absolutize", - "thiserror 2.0.11", + "thiserror 2.0.12", ] [[package]] @@ -296,9 +315,9 @@ checksum = "1628fb46dfa0b37568d12e5edd512553eccf6a22a78e8bde00bb4aed84d5bdbf" [[package]] name = "bytemuck" -version = "1.21.0" +version = "1.22.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef657dfab802224e671f5818e9a4935f9b1957ed18e58292690cc39e7a4092a3" +checksum = "b6b1fc10dbac614ebc03540c9dbd60e83887fda27794998c6528f1782047d540" dependencies = [ "bytemuck_derive", ] @@ -311,7 +330,7 @@ checksum = "3fa76293b4f7bb636ab88fd78228235b5248b4d05cc589aed610f954af5d7c7a" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn", ] [[package]] @@ -326,6 +345,26 @@ version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f61dac84819c6588b558454b194026eb1f09c293b9036ae9b159e74e73ab6cf9" +[[package]] +name = "bzip2" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bdb116a6ef3f6c3698828873ad02c3014b3c85cadb88496095628e3ef1e347f8" +dependencies = [ + "bzip2-sys", + "libc", +] + +[[package]] +name = "bzip2-sys" +version = "0.1.13+1.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "225bff33b2141874fe80d71e07d6eec4f85c5c216453dd96388240f96e1acc14" +dependencies = [ + "cc", + "pkg-config", +] + [[package]] name = "cast" version = "0.3.0" @@ -334,10 +373,12 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.2.14" +version = "1.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c3d1b2e905a3a7b00a6141adb0e4c0bb941d11caf55349d863942a1cc44e3c9" +checksum = "be714c154be609ec7f5dad223a33bf1482fff90472de28f7362806e6d4832b8c" dependencies = [ + "jobserver", + "libc", "shlex", ] @@ -395,6 +436,16 @@ dependencies = [ "half", ] +[[package]] +name = "cipher" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad" +dependencies = [ + "crypto-common", + "inout", +] + [[package]] name = "clang-sys" version = "1.8.1" @@ -408,18 +459,18 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.30" +version = "4.5.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92b7b18d71fad5313a1e320fa9897994228ce274b60faa4d694fe0ea89cd9e6d" +checksum = "027bb0d98429ae334a8698531da7077bdf906419543a35a55c2cb1b66437d767" dependencies = [ "clap_builder", ] [[package]] name = "clap_builder" -version = "4.5.30" +version = "4.5.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a35db2071778a7344791a4fb4f95308b5673d219dee3ae348b86642574ecc90c" +checksum = "5589e0cba072e0f3d23791efac0fd8627b49c829c196a492e88168e6a669d863" dependencies = [ "anstyle", "clap_lex", @@ -431,17 +482,11 @@ version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6" -[[package]] -name = "coe-rs" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e8f1e641542c07631228b1e0dc04b69ae3c1d58ef65d5691a439711d805c698" - [[package]] name = "console" -version = "0.15.10" +version = "0.15.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea3c6ecd8059b57859df5c69830340ed3c41d30e3da0c1cbed90a96ac853041b" +checksum = "054ccb5b10f9f2cbf51eb355ca1d05c2d279ce1804688d0db74b4733a5aeafd8" dependencies = [ "encode_unicode", "libc", @@ -465,11 +510,17 @@ version = "0.1.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f9d839f2a20b0aee515dc581a6172f2321f96cab76c1a38a4c584a194955390e" dependencies = [ - "getrandom", + "getrandom 0.2.15", "once_cell", "tiny-keccak", ] +[[package]] +name = "constant_time_eq" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "245097e9a4535ee1e3e3931fcfcd55a796a44c643e8596ff6566d68f09b87bbc" + [[package]] name = "core-foundation-sys" version = "0.8.7" @@ -485,6 +536,15 @@ dependencies = [ "libc", ] +[[package]] +name = "crc32fast" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a97769d94ddab943e4510d138150169a2758b5ef3eb191a9ee688de3e23ef7b3" +dependencies = [ + "cfg-if", +] + [[package]] name = "criterion" version = "0.5.1" @@ -563,10 +623,13 @@ dependencies = [ ] [[package]] -name = "dbgf" -version = "0.1.2" +name = "deranged" +version = "0.3.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6ca96b45ca70b8045e0462f191bd209fcb3c3bfe8dbfb1257ada54c4dd59169" +checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4" +dependencies = [ + "powerfmt", +] [[package]] name = "digest" @@ -576,16 +639,7 @@ checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ "block-buffer", "crypto-common", -] - -[[package]] -name = "dyn-stack" -version = "0.10.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56e53799688f5632f364f8fb387488dd05db9fe45db7011be066fc20e7027f8b" -dependencies = [ - "bytemuck", - "reborrow", + "subtle", ] [[package]] @@ -599,9 +653,9 @@ dependencies = [ [[package]] name = "either" -version = "1.13.0" +version = "1.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" +checksum = "b7914353092ddf589ad78f25c5c1c21b7f80b0ff8621e7c814c3485b5306da9d" [[package]] name = "encode_unicode" @@ -615,10 +669,10 @@ version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a1e6a265c649f3f5979b601d26f1d05ada116434c87741c9493cb56218f76cbc" dependencies = [ - "heck 0.5.0", + "heck", "proc-macro2", "quote", - "syn 2.0.98", + "syn", ] [[package]] @@ -647,7 +701,7 @@ checksum = "3bf679796c0322556351f287a51b49e48f7c4986e727b5dd78c972d30e2e16cc" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn", ] [[package]] @@ -658,59 +712,76 @@ checksum = "44f23cf4b44bfce11a86ace86f8a73ffdec849c9fd00a386a53d278bd9e81fb3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn", ] [[package]] name = "faer" -version = "0.19.4" +version = "0.21.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64bc4855cb2792ae3520e8af22051a47a6d6dc8300ebc0ddf51ad73f65bd0dc9" +checksum = "d671941ab57443f46ebe3f153a9fc3ed6cce777926c14e5fdf5da178a35ea476" dependencies = [ "bytemuck", - "coe-rs", - "dbgf", - "dyn-stack 0.10.0", + "dyn-stack", "equator 0.4.2", - "faer-entity", + "faer-macros", + "faer-traits", "gemm", + "generativity", "libm", - "matrixcompare", - "matrixcompare-core", "nano-gemm", "npyz", "num-complex", "num-traits", - "paste", - "rand", - "rand_distr", - "rayon", + "pulp", "reborrow", - "serde", ] [[package]] -name = "faer-entity" -version = "0.19.2" +name = "faer-macros" +version = "0.21.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c9c752ab2bff6f0b9597c6a1adc0112f7fd41fb343bc5a009a6274ae9d32fd03" +checksum = "9d0a255d1442b5825c61812a7eafda9034ec53d969c98555251085e148428e6a" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "faer-traits" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2d0172aefb5f869561e558d5390657f1aa98ca3c51a09be69a4687064ebfb9a" dependencies = [ "bytemuck", - "coe-rs", + "dyn-stack", + "faer-macros", + "generativity", "libm", "num-complex", "num-traits", - "pulp 0.18.22", + "pulp", "reborrow", ] +[[package]] +name = "flate2" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11faaf5a5236997af9848be0bef4db95824b1d534ebc64d0f0c6cf3e67bd38dc" +dependencies = [ + "crc32fast", + "miniz_oxide", +] + [[package]] name = "gemm" version = "0.18.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ab96b703d31950f1aeddded248bc95543c9efc7ac9c4a21fda8703a83ee35451" dependencies = [ - "dyn-stack 0.13.0", + "dyn-stack", "gemm-c32", "gemm-c64", "gemm-common", @@ -730,7 +801,7 @@ version = "0.18.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f6db9fd9f40421d00eea9dd0770045a5603b8d684654816637732463f4073847" dependencies = [ - "dyn-stack 0.13.0", + "dyn-stack", "gemm-common", "num-complex", "num-traits", @@ -745,7 +816,7 @@ version = "0.18.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dfcad8a3d35a43758330b635d02edad980c1e143dc2f21e6fd25f9e4eada8edf" dependencies = [ - "dyn-stack 0.13.0", + "dyn-stack", "gemm-common", "num-complex", "num-traits", @@ -761,16 +832,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a352d4a69cbe938b9e2a9cb7a3a63b7e72f9349174a2752a558a8a563510d0f3" dependencies = [ "bytemuck", - "dyn-stack 0.13.0", + "dyn-stack", "half", "libm", "num-complex", "num-traits", "once_cell", "paste", - "pulp 0.21.4", + "pulp", "raw-cpuid", - "rayon", "seq-macro", "sysctl", ] @@ -781,7 +851,7 @@ version = "0.18.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cff95ae3259432f3c3410eaa919033cd03791d81cebd18018393dc147952e109" dependencies = [ - "dyn-stack 0.13.0", + "dyn-stack", "gemm-common", "gemm-f32", "half", @@ -789,7 +859,6 @@ dependencies = [ "num-traits", "paste", "raw-cpuid", - "rayon", "seq-macro", ] @@ -799,7 +868,7 @@ version = "0.18.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bc8d3d4385393304f407392f754cd2dc4b315d05063f62cf09f47b58de276864" dependencies = [ - "dyn-stack 0.13.0", + "dyn-stack", "gemm-common", "num-complex", "num-traits", @@ -814,7 +883,7 @@ version = "0.18.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "35b2a4f76ce4b8b16eadc11ccf2e083252d8237c1b589558a49b0183545015bd" dependencies = [ - "dyn-stack 0.13.0", + "dyn-stack", "gemm-common", "num-complex", "num-traits", @@ -823,6 +892,12 @@ dependencies = [ "seq-macro", ] +[[package]] +name = "generativity" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5881e4c3c2433fe4905bb19cfd2b5d49d4248274862b68c27c33d9ba4e13f9ec" + [[package]] name = "generic-array" version = "0.14.7" @@ -841,7 +916,19 @@ checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" dependencies = [ "cfg-if", "libc", - "wasi", + "wasi 0.11.0+wasi-snapshot-preview1", +] + +[[package]] +name = "getrandom" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43a49c392881ce6d5c3b8cb70f98717b7c07aabbdff06687b9030dbfbe2725f8" +dependencies = [ + "cfg-if", + "libc", + "wasi 0.13.3+wasi-0.2.2", + "windows-targets", ] [[package]] @@ -864,15 +951,9 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.14.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" - -[[package]] -name = "heck" -version = "0.4.1" +version = "0.15.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" +checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" [[package]] name = "heck" @@ -886,6 +967,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fbf6a919d6cf397374f7dfeeea91d974c7c0a7221d0d0f4f20d859d329e53fcc" +[[package]] +name = "hmac" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" +dependencies = [ + "digest", +] + [[package]] name = "iana-time-zone" version = "0.1.61" @@ -924,9 +1014,18 @@ dependencies = [ [[package]] name = "indoc" -version = "2.0.5" +version = "2.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd" + +[[package]] +name = "inout" +version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5" +checksum = "879f10e63c20629ecabbb64a8010319738c66a5cd0c29b02d63d272b03751d01" +dependencies = [ + "generic-array", +] [[package]] name = "is-terminal" @@ -968,9 +1067,18 @@ dependencies = [ [[package]] name = "itoa" -version = "1.0.14" +version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674" +checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" + +[[package]] +name = "jobserver" +version = "0.1.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48d1dbcbbeb6a7fec7e059840aa538bd62aaccf972c7346c4d9d2059312853d0" +dependencies = [ + "libc", +] [[package]] name = "js-sys" @@ -982,11 +1090,17 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "lazy_static" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" + [[package]] name = "lexical-core" -version = "0.8.5" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2cde5de06e8d4c2faabc400238f9ae1c74d5412d03a7bd067645ccbc47070e46" +checksum = "b765c31809609075565a70b4b71402281283aeda7ecaf4818ac14a7b2ade8958" dependencies = [ "lexical-parse-float", "lexical-parse-integer", @@ -997,9 +1111,9 @@ dependencies = [ [[package]] name = "lexical-parse-float" -version = "0.8.5" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "683b3a5ebd0130b8fb52ba0bdc718cc56815b6a097e28ae5a6997d0ad17dc05f" +checksum = "de6f9cb01fb0b08060209a057c048fcbab8717b4c1ecd2eac66ebfe39a65b0f2" dependencies = [ "lexical-parse-integer", "lexical-util", @@ -1008,9 +1122,9 @@ dependencies = [ [[package]] name = "lexical-parse-integer" -version = "0.8.6" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d0994485ed0c312f6d965766754ea177d07f9c00c9b82a5ee62ed5b47945ee9" +checksum = "72207aae22fc0a121ba7b6d479e42cbfea549af1479c3f3a4f12c70dd66df12e" dependencies = [ "lexical-util", "static_assertions", @@ -1018,18 +1132,18 @@ dependencies = [ [[package]] name = "lexical-util" -version = "0.8.5" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5255b9ff16ff898710eb9eb63cb39248ea8a5bb036bea8085b1a767ff6c4e3fc" +checksum = "5a82e24bf537fd24c177ffbbdc6ebcc8d54732c35b50a3f28cc3f4e4c949a0b3" dependencies = [ "static_assertions", ] [[package]] name = "lexical-write-float" -version = "0.8.5" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "accabaa1c4581f05a3923d1b4cfd124c329352288b7b9da09e766b0668116862" +checksum = "c5afc668a27f460fb45a81a757b6bf2f43c2d7e30cb5a2dcd3abf294c78d62bd" dependencies = [ "lexical-util", "lexical-write-integer", @@ -1038,9 +1152,9 @@ dependencies = [ [[package]] name = "lexical-write-integer" -version = "0.8.5" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1b6f3d1f4422866b68192d62f77bc5c700bee84f3069f2469d7bc8c77852446" +checksum = "629ddff1a914a836fb245616a7888b62903aae58fa771e1d83943035efa0f978" dependencies = [ "lexical-util", "static_assertions", @@ -1048,9 +1162,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.169" +version = "0.2.170" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a" +checksum = "875b3680cb2f8f71bdcf9a30f38d48282f5d3c95cbf9b3fa57269bb5d5c06828" [[package]] name = "libloading" @@ -1068,37 +1182,11 @@ version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8355be11b20d696c8f18f6cc018c4e372165b1fa8126cef092399c9951984ffa" -[[package]] -name = "lock_api" -version = "0.4.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" -dependencies = [ - "autocfg", - "scopeguard", -] - [[package]] name = "log" -version = "0.4.25" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15e29fff3f6c077fd9cd9f" - -[[package]] -name = "matrixcompare" -version = "0.3.0" +version = "0.4.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37832ba820e47c93d66b4360198dccb004b43c74abc3ac1ce1fed54e65a80445" -dependencies = [ - "matrixcompare-core", - "num-traits", -] - -[[package]] -name = "matrixcompare-core" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b0bdabb30db18805d5290b3da7ceaccbddba795620b86c02145d688e04900a73" +checksum = "30bde2b3dc3671ae49d8e2e9f044c7c005836e7a023ee57cffa25ab82764bb9e" [[package]] name = "matrixmultiply" @@ -1131,11 +1219,20 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" +[[package]] +name = "miniz_oxide" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e3e04debbb59698c15bacbb6d93584a8c0ca9cc3213cb423d31f760d8843ce5" +dependencies = [ + "adler2", +] + [[package]] name = "multiversion" -version = "0.7.4" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4851161a11d3ad0bf9402d90ffc3967bf231768bfd7aeb61755ad06dbf1a142" +checksum = "7edb7f0ff51249dfda9ab96b5823695e15a052dc15074c9dbf3d118afaf2c201" dependencies = [ "multiversion-macros", "target-features", @@ -1143,13 +1240,13 @@ dependencies = [ [[package]] name = "multiversion-macros" -version = "0.7.4" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79a74ddee9e0c27d2578323c13905793e91622148f138ba29738f9dddb835e90" +checksum = "b093064383341eb3271f42e381cb8f10a01459478446953953c75d24bd339fc0" dependencies = [ "proc-macro2", "quote", - "syn 1.0.109", + "syn", "target-features", ] @@ -1225,14 +1322,16 @@ dependencies = [ [[package]] name = "ndarray" -version = "0.15.6" +version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32" +checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841" dependencies = [ "matrixmultiply", "num-complex", "num-integer", "num-traits", + "portable-atomic", + "portable-atomic-util", "rawpointer", ] @@ -1289,9 +1388,14 @@ checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" dependencies = [ "bytemuck", "num-traits", - "rand", ] +[[package]] +name = "num-conv" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" + [[package]] name = "num-integer" version = "0.1.46" @@ -1341,9 +1445,9 @@ checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" [[package]] name = "numpy" -version = "0.21.0" +version = "0.23.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec170733ca37175f5d75a5bea5911d6ff45d2cd52849ce98b685394e4f2f37f4" +checksum = "b94caae805f998a07d33af06e6a3891e38556051b8045c615470a71590e13e78" dependencies = [ "libc", "ndarray", @@ -1351,12 +1455,12 @@ dependencies = [ "num-integer", "num-traits", "pyo3", - "rustc-hash 1.1.0", + "rustc-hash", ] [[package]] name = "nutpie" -version = "0.13.4" +version = "0.14.0" dependencies = [ "anyhow", "arrow", @@ -1367,33 +1471,34 @@ dependencies = [ "numpy", "nuts-rs", "pyo3", - "rand", - "rand_chacha", + "rand 0.9.0", + "rand_chacha 0.9.0", "rand_distr", "rayon", "smallvec", - "thiserror 1.0.69", + "tch", + "thiserror 2.0.12", "time-humanize", "upon", ] [[package]] name = "nuts-rs" -version = "0.12.1" +version = "0.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8573e3b5c83e8ec0570ebbd75dd6fdc7dfcfa5da9b5f9d9d63fedefebbd9cf8" +checksum = "10e87924d332fce1202087bc67db7ed8f7ef9229da5ec74a5130568f5b7f6ac7" dependencies = [ "anyhow", "arrow", "faer", - "itertools 0.13.0", + "itertools 0.14.0", "multiversion", - "pulp 0.18.22", - "rand", - "rand_chacha", + "pulp", + "rand 0.9.0", + "rand_chacha 0.9.0", "rand_distr", "rayon", - "thiserror 1.0.69", + "thiserror 2.0.12", ] [[package]] @@ -1409,26 +1514,14 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9" [[package]] -name = "parking_lot" -version = "0.12.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" -dependencies = [ - "lock_api", - "parking_lot_core", -] - -[[package]] -name = "parking_lot_core" -version = "0.9.10" +name = "password-hash" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" +checksum = "7676374caaee8a325c9e7a2ae557f216c5563a171d6997b0ef8a65af35147700" dependencies = [ - "cfg-if", - "libc", - "redox_syscall", - "smallvec", - "windows-targets", + "base64ct", + "rand_core 0.6.4", + "subtle", ] [[package]] @@ -1455,6 +1548,18 @@ dependencies = [ "once_cell", ] +[[package]] +name = "pbkdf2" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83a0692ec44e4cf1ef28ca317f14f8f07da2d95ec3fa01f86e4467b725e60917" +dependencies = [ + "digest", + "hmac", + "password-hash", + "sha2", +] + [[package]] name = "pest" version = "2.7.15" @@ -1462,7 +1567,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b7cafe60d6cf8e62e1b9b2ea516a089c008945bb5a275416789e7db0bc199dc" dependencies = [ "memchr", - "thiserror 2.0.11", + "thiserror 2.0.12", "ucd-trie", ] @@ -1486,7 +1591,7 @@ dependencies = [ "pest_meta", "proc-macro2", "quote", - "syn 2.0.98", + "syn", ] [[package]] @@ -1500,6 +1605,12 @@ dependencies = [ "sha2", ] +[[package]] +name = "pkg-config" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" + [[package]] name = "plotters" version = "0.3.7" @@ -1530,9 +1641,24 @@ dependencies = [ [[package]] name = "portable-atomic" -version = "1.10.0" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "350e9b48cbc6b0e028b0473b114454c6316e57336ee184ceab6e53f72c178b3e" + +[[package]] +name = "portable-atomic-util" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8a2f0d8d040d7848a709caf78912debcc3f33ee4b3cac47d73d1e1069e83507" +dependencies = [ + "portable-atomic", +] + +[[package]] +name = "powerfmt" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "280dc24453071f1b63954171985a0b0d30058d287960968b9b2aca264c8d4ee6" +checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" [[package]] name = "ppv-lite86" @@ -1540,40 +1666,28 @@ version = "0.2.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" dependencies = [ - "zerocopy", + "zerocopy 0.7.35", ] [[package]] name = "prettyplease" -version = "0.2.29" +version = "0.2.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6924ced06e1f7dfe3fa48d57b9f74f55d8915f5036121bef647ef4b204895fac" +checksum = "f1ccf34da56fc294e7d4ccf69a85992b7dfb826b7cf57bac6a70bba3494cc08a" dependencies = [ "proc-macro2", - "syn 2.0.98", + "syn", ] [[package]] name = "proc-macro2" -version = "1.0.93" +version = "1.0.94" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60946a68e5f9d28b0dc1c21bb8a97ee7d018a8b322fa57838ba31cc878e22d99" +checksum = "a31971752e70b8b2686d7e46ec17fb38dad4051d94024c88df49b667caea9c84" dependencies = [ "unicode-ident", ] -[[package]] -name = "pulp" -version = "0.18.22" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a0a01a0dc67cf4558d279f0c25b0962bd08fc6dec0137699eae304103e882fe6" -dependencies = [ - "bytemuck", - "libm", - "num-complex", - "reborrow", -] - [[package]] name = "pulp" version = "0.21.4" @@ -1603,16 +1717,16 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.21.2" +version = "0.23.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5e00b96a521718e08e03b1a622f01c8a8deb50719335de3f60b3b3950f069d8" +checksum = "7778bffd85cf38175ac1f545509665d0b9b92a198ca7941f131f85f7a4f9a872" dependencies = [ "anyhow", "cfg-if", "indoc", "libc", "memoffset", - "parking_lot", + "once_cell", "portable-atomic", "pyo3-build-config", "pyo3-ffi", @@ -1622,9 +1736,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.21.2" +version = "0.23.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7883df5835fafdad87c0d888b266c8ec0f4c9ca48a5bed6bbb592e8dedee1b50" +checksum = "94f6cbe86ef3bf18998d9df6e0f3fc1050a8c5efa409bf712e661a4366e010fb" dependencies = [ "once_cell", "target-lexicon", @@ -1632,9 +1746,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.21.2" +version = "0.23.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01be5843dc60b916ab4dad1dca6d20b9b4e6ddc8e15f50c47fe6d85f1fb97403" +checksum = "e9f1b4c431c0bb1c8fb0a338709859eed0d030ff6daa34368d3b152a63dfdd8d" dependencies = [ "libc", "pyo3-build-config", @@ -1642,34 +1756,34 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.21.2" +version = "0.23.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77b34069fc0682e11b31dbd10321cbf94808394c56fd996796ce45217dfac53c" +checksum = "fbc2201328f63c4710f68abdf653c89d8dbc2858b88c5d88b0ff38a75288a9da" dependencies = [ "proc-macro2", "pyo3-macros-backend", "quote", - "syn 2.0.98", + "syn", ] [[package]] name = "pyo3-macros-backend" -version = "0.21.2" +version = "0.23.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08260721f32db5e1a5beae69a55553f56b99bd0e1c3e6e0a5e8851a9d0f5a85c" +checksum = "fca6726ad0f3da9c9de093d6f116a93c1a38e417ed73bf138472cf4064f72028" dependencies = [ - "heck 0.4.1", + "heck", "proc-macro2", "pyo3-build-config", "quote", - "syn 2.0.98", + "syn", ] [[package]] name = "quote" -version = "1.0.38" +version = "1.0.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e4dccaaaf89514f546c693ddc140f729f958c247918a13380cccc6078391acc" +checksum = "c1f1914ce909e1658d9907913b4b91947430c7d9be598b15a1912935b8c04801" dependencies = [ "proc-macro2", ] @@ -1681,8 +1795,19 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" dependencies = [ "libc", - "rand_chacha", - "rand_core", + "rand_chacha 0.3.1", + "rand_core 0.6.4", +] + +[[package]] +name = "rand" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3779b94aeb87e8bd4e834cee3650289ee9e0d5677f976ecdb6d219e5f4f6cd94" +dependencies = [ + "rand_chacha 0.9.0", + "rand_core 0.9.3", + "zerocopy 0.8.21", ] [[package]] @@ -1692,7 +1817,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" dependencies = [ "ppv-lite86", - "rand_core", + "rand_core 0.6.4", +] + +[[package]] +name = "rand_chacha" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" +dependencies = [ + "ppv-lite86", + "rand_core 0.9.3", ] [[package]] @@ -1701,17 +1836,26 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "getrandom", + "getrandom 0.2.15", +] + +[[package]] +name = "rand_core" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38" +dependencies = [ + "getrandom 0.3.1", ] [[package]] name = "rand_distr" -version = "0.4.3" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31" +checksum = "6a8615d50dcf34fa31f7ab52692afec947c4dd0ab803cc87cb3b0b4570ff7463" dependencies = [ "num-traits", - "rand", + "rand 0.9.0", ] [[package]] @@ -1755,15 +1899,6 @@ version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "03251193000f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430" -[[package]] -name = "redox_syscall" -version = "0.5.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03a862b389f93e68874fbf580b9de08dd02facb9a788ebadaf4a3fd33cf58834" -dependencies = [ - "bitflags", -] - [[package]] name = "regex" version = "1.11.1" @@ -1793,12 +1928,6 @@ version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" -[[package]] -name = "rustc-hash" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" - [[package]] name = "rustc-hash" version = "2.1.1" @@ -1807,15 +1936,25 @@ checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" [[package]] name = "rustversion" -version = "1.0.19" +version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7c45b9784283f1b2e7fb61b42047c2fd678ef0960d4f6f1eba131594cc369d4" +checksum = "eded382c5f5f786b989652c49544c4877d9f015cc22e145a5ea8ea66c2921cd2" [[package]] name = "ryu" -version = "1.0.19" +version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ea1a2d0a644769cc99faa24c3ad26b379b786fe7c36fd3c546254801650e6dd" +checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" + +[[package]] +name = "safetensors" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d93279b86b3de76f820a8854dd06cbc33cfa57a417b19c47f6a25280112fb1df" +dependencies = [ + "serde", + "serde_json", +] [[package]] name = "same-file" @@ -1826,43 +1965,37 @@ dependencies = [ "winapi-util", ] -[[package]] -name = "scopeguard" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" - [[package]] name = "seq-macro" -version = "0.3.5" +version = "0.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" +checksum = "1bc711410fbe7399f390ca1c3b60ad0f53f80e95c5eb935e52268a0e2cd49acc" [[package]] name = "serde" -version = "1.0.217" +version = "1.0.218" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02fc4265df13d6fa1d00ecff087228cc0a2b5f3c0e87e258d8b94a156e984c70" +checksum = "e8dfc9d19bdbf6d17e22319da49161d5d0108e4188e8b680aef6299eed22df60" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.217" +version = "1.0.218" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a9bf7cf98d04a2b28aead066b7496853d4779c9cc183c440dbac457641e19a0" +checksum = "f09503e191f4e797cb8aac08e9a4a4695c5edf6a2e70e376d961ddd5c969f82b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn", ] [[package]] name = "serde_json" -version = "1.0.138" +version = "1.0.140" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d434192e7da787e94a6ea7e9670b26a036d0ca41e0b7efb2676dd32bae872949" +checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373" dependencies = [ "itoa", "memchr", @@ -1870,6 +2003,17 @@ dependencies = [ "serde", ] +[[package]] +name = "sha1" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sha2" version = "0.10.8" @@ -1900,21 +2044,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" [[package]] -name = "syn" -version = "1.0.109" +name = "subtle" +version = "2.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" -dependencies = [ - "proc-macro2", - "quote", - "unicode-ident", -] +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" -version = "2.0.98" +version = "2.0.99" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "36147f1a48ae0ec2b5b3bc5b537d267457555a10dc06f3dbc8cb11ba3006d3b1" +checksum = "e02e925281e18ffd9d640e234264753c43edc62d64b2d4cf898f1bc5e75f3fc2" dependencies = [ "proc-macro2", "quote", @@ -1947,6 +2086,23 @@ version = "0.12.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" +[[package]] +name = "tch" +version = "0.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa1ed622c8f13b0c42f8b1afa0e5e9ccccd82ecb6c0e904120722ab52fdc5234" +dependencies = [ + "half", + "lazy_static", + "libc", + "ndarray", + "rand 0.8.5", + "safetensors", + "thiserror 1.0.69", + "torch-sys", + "zip", +] + [[package]] name = "thiserror" version = "1.0.69" @@ -1958,11 +2114,11 @@ dependencies = [ [[package]] name = "thiserror" -version = "2.0.11" +version = "2.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d452f284b73e6d76dd36758a0c8684b1d5be31f92b89d07fd5822175732206fc" +checksum = "567b8a2dae586314f7be2a752ec7474332959c6460e02bde30d702a66d488708" dependencies = [ - "thiserror-impl 2.0.11", + "thiserror-impl 2.0.12", ] [[package]] @@ -1973,20 +2129,39 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn", ] [[package]] name = "thiserror-impl" -version = "2.0.11" +version = "2.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26afc1baea8a989337eeb52b6e72a039780ce45c3edfcc9c5b9d112feeb173c2" +checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn", +] + +[[package]] +name = "time" +version = "0.3.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35e7868883861bd0e56d9ac6efcaaca0d6d5d82a2a7ec8209ff492c07cf37b21" +dependencies = [ + "deranged", + "num-conv", + "powerfmt", + "serde", + "time-core", ] +[[package]] +name = "time-core" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" + [[package]] name = "time-humanize" version = "0.1.3" @@ -2012,6 +2187,18 @@ dependencies = [ "serde_json", ] +[[package]] +name = "torch-sys" +version = "0.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef14f5d239e3d60f4919f536a5dfe1d4f71b27b7abf6fe6875fd3a4b22c2dcd5" +dependencies = [ + "anyhow", + "cc", + "libc", + "zip", +] + [[package]] name = "typenum" version = "1.18.0" @@ -2026,9 +2213,9 @@ checksum = "2896d95c02a80c6d6a5d6e953d479f5ddf2dfdb6a244441010e373ac0fb88971" [[package]] name = "unicode-ident" -version = "1.0.16" +version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a210d160f08b701c8721ba1c726c11662f877ea6b7094007e1ca9a1041945034" +checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512" [[package]] name = "unicode-width" @@ -2038,15 +2225,15 @@ checksum = "1fc81956842c57dac11422a97c3b8195a1ff727f06e85c84ed2e8aa277c9a0fd" [[package]] name = "unindent" -version = "0.2.3" +version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce" +checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3" [[package]] name = "upon" -version = "0.8.1" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fe29601d1624f104fa9a35ea71a5f523dd8bd1cfc8c31f8124ad2b829f013c0" +checksum = "cc1243af2969e332d5b9b99087eddd44d04a41da8630ed53e06df497b7f5c747" [[package]] name = "version_check" @@ -2070,6 +2257,15 @@ version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +[[package]] +name = "wasi" +version = "0.13.3+wasi-0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26816d2e1a4a36a2940b96c5296ce403917633dff8f3440e9b236ed6f6bacad2" +dependencies = [ + "wit-bindgen-rt", +] + [[package]] name = "wasm-bindgen" version = "0.2.100" @@ -2092,7 +2288,7 @@ dependencies = [ "log", "proc-macro2", "quote", - "syn 2.0.98", + "syn", "wasm-bindgen-shared", ] @@ -2114,7 +2310,7 @@ checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -2239,6 +2435,15 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" +[[package]] +name = "wit-bindgen-rt" +version = "0.33.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3268f3d866458b787f390cf61f4bbb563b922d091359f9608842999eaee3943c" +dependencies = [ + "bitflags", +] + [[package]] name = "zerocopy" version = "0.7.35" @@ -2246,7 +2451,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" dependencies = [ "byteorder", - "zerocopy-derive", + "zerocopy-derive 0.7.35", +] + +[[package]] +name = "zerocopy" +version = "0.8.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dcf01143b2dd5d134f11f545cf9f1431b13b749695cb33bcce051e7568f99478" +dependencies = [ + "zerocopy-derive 0.8.21", ] [[package]] @@ -2257,5 +2471,65 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn", +] + +[[package]] +name = "zerocopy-derive" +version = "0.8.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712c8386f4f4299382c9abee219bee7084f78fb939d88b6840fcc1320d5f6da2" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "zip" +version = "0.6.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "760394e246e4c28189f19d488c058bf16f564016aefac5d32bb1f3b51d5e9261" +dependencies = [ + "aes", + "byteorder", + "bzip2", + "constant_time_eq", + "crc32fast", + "crossbeam-utils", + "flate2", + "hmac", + "pbkdf2", + "sha1", + "time", + "zstd", +] + +[[package]] +name = "zstd" +version = "0.11.2+zstd.1.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "20cc960326ece64f010d2d2107537f26dc589a6573a316bd5b1dba685fa5fde4" +dependencies = [ + "zstd-safe", +] + +[[package]] +name = "zstd-safe" +version = "5.0.2+zstd.1.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d2a5585e04f9eea4b2a3d1eca508c4dee9592a89ef6f450c11719da0726f4db" +dependencies = [ + "libc", + "zstd-sys", +] + +[[package]] +name = "zstd-sys" +version = "2.0.14+zstd.1.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fb060d4926e4ac3a3ad15d864e99ceb5f343c6b34f5bd6d81ae6ed417311be5" +dependencies = [ + "cc", + "pkg-config", ] diff --git a/Cargo.toml b/Cargo.toml index cc452c4..587aa2f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "nutpie" -version = "0.13.4" +version = "0.14.0" authors = [ "Adrian Seyboldt ", "PyMC Developers ", @@ -22,25 +22,26 @@ name = "_lib" crate-type = ["cdylib"] [dependencies] -nuts-rs = "0.12.1" -numpy = "0.21.0" -rand = "0.8.5" -thiserror = "1.0.44" -rand_chacha = "0.3.1" -rayon = "1.9.0" +nuts-rs = "0.15.0" +numpy = "0.23.0" +rand = "0.9.0" +thiserror = "2.0.3" +rand_chacha = "0.9.0" +rayon = "1.10.0" # Keep arrow in sync with nuts-rs requirements -arrow = { version = "52.0.0", default-features = false, features = ["ffi"] } +arrow = { version = "54.1.0", default-features = false, features = ["ffi"] } anyhow = "1.0.72" itertools = "0.14.0" bridgestan = "2.6.1" -rand_distr = "0.4.3" -smallvec = "1.11.0" -upon = { version = "0.8.1", default-features = false, features = [] } +rand_distr = "0.5.0" +smallvec = "1.13.0" +upon = { version = "0.9.0", default-features = false, features = [] } time-humanize = { version = "0.1.3", default-features = false } indicatif = "0.17.8" +tch = { version = "0.19.0", optional = true } [dependencies.pyo3] -version = "0.21.0" +version = "0.23.4" features = ["extension-module", "anyhow"] [dev-dependencies] diff --git a/docs/_quarto.yml b/docs/_quarto.yml new file mode 100644 index 0000000..8781265 --- /dev/null +++ b/docs/_quarto.yml @@ -0,0 +1,31 @@ +project: + type: website + +website: + title: "Nutpie" + navbar: + left: + - href: index.qmd + text: Home + - href: pymc-usage.qmd + text: Usage with PyMC + - href: stan-usage.qmd + text: Usage with Stan + - href: sampling-options.qmd + text: Sampling Options + - href: nf-adapt.qmd + text: Normalizing flow adaptation + - href: sample-stats.qmd + text: Diagnostic information + - about.qmd + tools: + - icon: github + href: https://github.com/pymc-devs/nutpie + +format: + html: + theme: + - cosmo + - brand + css: styles.css + toc: true diff --git a/docs/about.qmd b/docs/about.qmd new file mode 100644 index 0000000..16fc02f --- /dev/null +++ b/docs/about.qmd @@ -0,0 +1,17 @@ +--- +title: "About" +--- + +Nutpie is part of the PyMC organization. The PyMC organization develops and +maintains tools for Bayesian statistical modeling and probabilistic machine +learning. + +Nutpie provides a high-performance implementation of the No-U-Turn Sampler +(NUTS) that can be used with models defined in PyMC, Stan and other frameworks. +It was created to enable faster and more efficient Bayesian inference while +maintaining compatibility with existing probabilistic programming tools. + +For more information about the PyMC organization, visit the following links: + +- [PyMC Website](https://www.pymc.io) +- [PyMC GitHub Organization](https://github.com/pymc-devs) diff --git a/docs/benchmarks.md b/docs/benchmarks.md new file mode 100644 index 0000000..680d565 --- /dev/null +++ b/docs/benchmarks.md @@ -0,0 +1 @@ +# Benchmarks diff --git a/docs/index.qmd b/docs/index.qmd new file mode 100644 index 0000000..6de4796 --- /dev/null +++ b/docs/index.qmd @@ -0,0 +1,90 @@ +# Nutpie Documentation + +`nutpie` is a high-performance library designed for Bayesian inference, that +provides efficient sampling algorithms for probabilistic models. It can sample +models that are defined in PyMC or Stan (numpyro and custom hand-coded +likelihoods with gradient are coming soon). + +- Faster sampling than either the PyMC or Stan default samplers. (An average + ~2x speedup on `posteriordb` compared to Stan) +- All the diagnostic information of PyMC and Stan and some more. +- GPU support for PyMC models through jax. +- A more informative progress bar. +- Access to the incomplete trace during sampling. +- *Experimental* normalizing flow adaptation for more efficient sampling of + difficult posteriors. + +## Quickstart: PyMC + +Install `nutpie` with pip, uv, pixi, or conda: + +For usage with pymc: + +```bash +# One of +pip install "nutpie[pymc]" +uv add "nutpie[pymc]" +pixi add nutpie pymc numba +conda install -c conda-forge nutpie pymc numba +``` + +And then sample with + +```{python} +import nutpie +import pymc as pm + +with pm.Model() as model: + mu = pm.Normal("mu", mu=0, sigma=1) + obs = pm.Normal("obs", mu=mu, sigma=1, observed=[1, 2, 3]) + +compiled = nutpie.compile_pymc_model(model) +trace = nutpie.sample(compiled) +``` + +For more information, see the detailed [PyMC usage guide](pymc-usage.qmd). + +## Quickstart: Stan + +Stan needs access to a compiler toolchain, you can find instructions for those +[here](https://mc-stan.org/docs/cmdstan-guide/installation.html#cpp-toolchain). +You can then install nutpie through pip or uv: + +```bash +# One of +pip install "nutpie[stan]" +uv add "nutpie[stan]" +``` + +```{python} +#| echo: false +import os +os.environ["TBB_CXX_TYPE"] = "clang" +``` + +```{python} +import nutpie + +model = """ +data { + int N; + vector[N] y; +} +parameters { + real mu; +} +model { + mu ~ normal(0, 1); + y ~ normal(mu, 1); +} +""" + +compiled = ( + nutpie + .compile_stan_model(code=model) + .with_data(N=3, y=[1, 2, 3]) +) +trace = nutpie.sample(compiled) +``` + +For more information, see the detailed [Stan usage guide](stan-usage.qmd). diff --git a/docs/nf-adapt.qmd b/docs/nf-adapt.qmd new file mode 100644 index 0000000..61b9b44 --- /dev/null +++ b/docs/nf-adapt.qmd @@ -0,0 +1,119 @@ +# Adaptation with Normalizing Flows + +**Experimental and subject to change** + +Normalizing flow adaptation through Fisher HMC is a new sampling algorithm that +automatically reparameterizes a model. It adds some computational cost outside +model log-density evaluations, but allows sampling from much more difficult +posterior distributions. For models with expensive log-density evaluations, the +normalizing flow adaptation can also be much faster, if it can reduce the number +of log-density evaluations needed to reach a given effective sample size. + +The normalizing flow adaptation works by learning a transformation of the parameter +space that makes the posterior distribution more amenable to sampling. This is done +by fitting a sequence of invertible transformations (the "flow") that maps the +original parameter space to a space where the posterior is closer to a standard +normal distribution. The flow is trained during warmup. + +For more information about the algorithm, see the (still work in progress) paper +[If only my posterior were normal: Introducing Fisher +HMC](https://github.com/aseyboldt/covadapt-paper/releases/download/latest/main.pdf). + +Currently, a lot of time is spent on compiling various parts of the normalizing +flow, and for small models this can take a large amount of the total time. +Hopefully, we will be able to reduce this overhead in the future. + +## Requirements + +Install the optional dependencies for normalizing flow adaptation: + +``` +pip install 'nutpie[nnflow]' +``` + +If you use with PyMC, this will only work if the model is compiled using the jax +backend, and if the `gradient_backend` is also set to `jax`. + +Training of the normalizing flow can often be accelerated by using a GPU (even +if the model itself is written in Stan, without any GPU support). To enable GPU +you need to make sure your `jax` installation comes with GPU support, for +instance by installing it with `pip install 'jax[cuda12]'`, or selecting the +`jaxlib` version with GPU support, if you are using conda-forge. You can check if +your installation has GPU support by checking the output of: + +```python +import jax +jax.devices() +``` + +### Usage + +To use normalizing flow adaptation in `nutpie`, you need to enable the +`transform_adapt` option during sampling. Here is an example of how we can use +it to sample from a difficult posterior: + +```{python} +import pymc as pm +import nutpie +import numpy as np +import arviz + +# Define a 100-dimensional funnel model +with pm.Model() as model: + log_sigma = pm.Normal("log_sigma") + pm.Normal("x", mu=0, sigma=pm.math.exp(log_sigma / 2), shape=100) + +# Compile the model with the jax backend +compiled = nutpie.compile_pymc_model( + model, backend="jax", gradient_backend="jax" +) +``` + +If we sample this model without normalizing flow adaptation, we will encounter +convergence issues, often divergences and always low effective sample sizes: + +```{python} +# Sample without normalizing flow adaptation +trace_no_nf = nutpie.sample(compiled, seed=1) +assert (arviz.ess(trace_no_nf) < 100).any().to_array().any() +``` + +```{python} +# We can add further arguments for the normalizing flow: +compiled = compiled.with_transform_adapt( + num_layers=5, # Number of layers in the normalizing flow + nn_width=32, # Neural networks with 32 hidden units + num_diag_windows=6, # Number of windows with a diagonal mass matrix intead of a flow + verbose=False, # Whether to print details about the adaptation process + show_progress=False, # Whether to show a progress bar for each optimization step +) + +# Sample with normalizing flow adaptation +trace_nf = nutpie.sample( + compiled, + transform_adapt=True, # Enable the normalizing flow adaptation + seed=1, + chains=2, + cores=1, # Running chains in parallel can be slow + window_switch_freq=150, # Optimize the normalizing flow every 150 iterations +) +assert trace_nf.sample_stats.diverging.sum() == 0 +assert (arviz.ess(trace_nf) > 1000).all().to_array().all() +``` + +The flow adaptation occurs during warmup, so the number of warmup draws should +be large enough to allow the flow to converge. For more complex posteriors, you +may need to increase the number of layers (using the `num_layers` argument), or +you might want to increase the number of warmup draws. + +To monitor the progress of the flow adaptation, you can set `verbose=True`, or +`show_progress=True`, but the second should only be used if you sample just one +chain. + +All losses are on a log-scale. Negative values smaller -2 are a good sign that +the adaptation was successful. If the loss stays positive, the flow is either +not expressive enough, or the training period is too short. The sampler might +still converge, but will probably need more gradient evaluations per effective +draw. Large losses bigger than 6 tend to indicate that the posterior is too +difficult to sample with the current flow, and the sampler will probably not +converge. diff --git a/docs/pymc-usage.qmd b/docs/pymc-usage.qmd new file mode 100644 index 0000000..56d81d7 --- /dev/null +++ b/docs/pymc-usage.qmd @@ -0,0 +1,195 @@ +# Usage with PyMC models + +This document shows how to use `nutpie` with PyMC models. We will use the +`pymc` package to define a simple model and sample from it using `nutpie`. + +## Installation + +The recommended way to install `pymc` is through the `conda` ecosystem. A good +package manager for conda packages is `pixi`. See for the [pixi +documentation](https://pixi.sh) for instructions on how to install it. + +We create a new project for this example: + +```bash +pixi new pymc-example +``` + +This will create a new directory `pymc-example` with a `pixi.toml` file, that +you can edit to add meta information. + +We then add the `pymc` and `nutpie` packages to the project: + +```bash +cd pymc-example +pixi add pymc nutpie arviz +``` + +You can use Visual Studio Code (VSCode) or JupyterLab to write and run our code. +Both are excellent tools for working with Python and data science projects. + +### Using VSCode + +1. Open VSCode. +2. Open the `pymc-example` directory created earlier. +3. Create a new file named `model.ipynb`. +4. Select the pixi kernel to run the code. + +### Using JupyterLab + +1. Add jupyter labs to the project by running `pixi add jupyterlab`. +1. Open JupyterLab by running `pixi run jupyter lab` in your terminal. +3. Create a new Python notebook. + +## Defining and Sampling a Simple Model + +We will define a simple Bayesian model using `pymc` and sample from it using +`nutpie`. + +### Model Definition + +In your `model.ipypy` file or Jupyter notebook, add the following code: + +```{python} +import pymc as pm +import nutpie +import pandas as pd + +coords = {"observation": range(3)} + +with pm.Model(coords=coords) as model: + # Prior distributions for the intercept and slope + intercept = pm.Normal("intercept", mu=0, sigma=1) + slope = pm.Normal("slope", mu=0, sigma=1) + + # Likelihood (sampling distribution) of observations + x = [1, 2, 3] + + mu = intercept + slope * x + y = pm.Normal("y", mu=mu, sigma=0.1, observed=[1, 2, 3], dims="observation") +``` + +### Sampling + +We can now compile the model using the numba backend: + +```{python} +compiled = nutpie.compile_pymc_model(model) +trace = nutpie.sample(compiled) +``` + +Alternatively, we can also sample through the `pymc` API: + +```python +with model: + trace = pm.sample(model, nuts_sampler="nutpie") +``` + +While sampling, nutpie shows a progress bar for each chain. It also includes +information about how each chain is doing: + +- It shows the current number of draws +- The step size of the integrator (very small stepsizes are typically a bad + sign) +- The number of divergences (if there are divergences, that means that nutpie is + probably not sampling the posterior correctly) +- The number of gradient evaluation nutpie uses for each draw. Large numbers + (100 to 1000) are a sign that the parameterization of the model is not ideal, + and the sampler is very inefficient. + +After sampling, this returns an `arviz` InferenceData object that you can use to +analyze the trace. + +For example, we should check the effective sample size: + +```{python} +import arviz as az +az.ess(trace) +``` + +and take a look at a trace plot: + +```{python} +az.plot_trace(trace); +``` + +### Choosing the backend + +Right now, we have been using the numba backend. This is the default backend for +`nutpie`, when sampling from pymc models. It tends to have relatively long +compilation times, but samples small models very efficiently. For larger models +the `jax` backend sometimes outperforms `numba`. + +First, we need to install the `jax` package: + +```bash +pixi add jax +``` + +We can select the backend by passing the `backend` argument to the `compile_pymc_model`: + +```python +compiled_jax = nutpie.compiled_pymc_model(model, backend="jax") +trace = nutpie.sample(compiled_jax) +``` + +Or through the pymc API: + +```python +with model: + trace = pm.sample( + model, + nuts_sampler="nutpie", + nuts_sampler_kwargs={"backend": "jax"}, + ) +``` + +If you have an nvidia GPU, you can also use the `jax` backend with the `gpu`. We +will have to install the `jaxlib` package with the `cuda` option + +```bash +pixi add jaxlib --build 'cuda12' +``` + +Restart the kernel and check that the GPU is available: + +```python +import jax + +# Should list the cuda device +jax.devices() +``` + +Sampling again, should now use the GPU, which you can observe by checking the +GPU usage with `nvidia-smi` or `nvtop`. + +### Changing the dataset without recompilation + +If you want to use the same model with different datasets, you can modify +datasets after compilation. Since jax does not like changes in shapes, this is +only recommended with the numba backend. + +First, we define the model, but put our dataset in a `pm.Data` structure: + +```{python} +with pm.Model() as model: + x = pm.Data("x", [1, 2, 3]) + intercept = pm.Normal("intercept", mu=0, sigma=1) + slope = pm.Normal("slope", mu=0, sigma=1) + mu = intercept + slope * x + y = pm.Normal("y", mu=mu, sigma=0.1, observed=[1, 2, 3]) +``` + +We can now compile the model: + +```{python} +compiled = nutpie.compile_pymc_model(model) +trace = nutpie.sample(compiled) +``` + +After compilation, we can change the dataset: + +```{python} +compiled2 = compiled.with_data(x=[4, 5, 6]) +trace2 = nutpie.sample(compiled2) +``` diff --git a/docs/sample-stats.qmd b/docs/sample-stats.qmd new file mode 100644 index 0000000..7cf92c9 --- /dev/null +++ b/docs/sample-stats.qmd @@ -0,0 +1,221 @@ +# Understanding Sampler Statistics in Nutpie + +This guide explains the various statistics that nutpie collects during sampling. We'll use Neal's funnel distribution as an example, as it's a challenging model that demonstrates many important sampling concepts. + +## Example Model: Neal's Funnel + +Let's start by implementing Neal's funnel in PyMC: + +```{python} +import pymc as pm +import nutpie +import numpy as np +import matplotlib.pyplot as plt +import seaborn as sns +import pandas as pd +import arviz as az + +# Create the funnel model +with pm.Model() as model: + log_sigma = pm.Normal('log_sigma') + pm.Normal('x', sigma=pm.math.exp(log_sigma), shape=5) + +# Sample with detailed statistics +compiled = nutpie.compile_pymc_model(model) +trace = nutpie.sample( + compiled, + tune=1000, + store_mass_matrix=True, + store_gradient=True, + store_unconstrained=True, + store_divergences=True, + seed=42, +) +``` + +## Sampler Statistics Overview + +The sampler statistics can be grouped into several categories: + +### Basic HMC Statistics + +These statistics are always collected and are essential for basic diagnostics: + +```{python} +# Access through trace.sample_stats +basic_stats = [ + 'depth', # Tree depth for current draw + 'maxdepth_reached', # Whether max tree depth was hit + 'logp', # Log probability of current position + 'energy', # Hamiltonian energy + 'diverging', # Whether the transition diverged + 'step_size', # Current step size + 'step_size_bar', # Current estimate of an ideal step size + 'n_steps' # Number of leapfrog steps + +] + +# Plot step size evolution during warmup +trace.warmup_sample_stats.step_size_bar.plot.line(x="draw", yscale="log") +``` + +### Mass Matrix Adaptation + +These statistics track how the mass matrix evolves: + +```{python} +( + trace + .warmup_sample_stats + .mass_matrix_inv + .plot + .line( + x="draw", + yscale="log", + col="chain", + col_wrap=2, + ) +) +``` + +Variables that are a source of convergence issues, will often show high variance +in the final mass matrix estimate across chains. + +The mass matrix will always be fixed for 10% of draws at the end, because we +only run final step size adaptation during that time, but high variance in the +mass matrix before this final window and indicate that more tuning steps might +be needed. + +### Detailed Diagnostics + +These are only available when explicitly requested: + +```python +detailed_stats = [ + 'gradient', # Gradient at current position + 'unconstrained_draw', # Parameters in unconstrained space + 'divergence_start', # Position where divergence started + 'divergence_end', # Position where divergence ended + 'divergence_momentum', # Momentum at divergence + 'divergence_message' # Description of divergence +] +``` + +#### Idintify Divergences + +We can for instance use this to identify the sources of divergences: + +```{python} +import xarray as xr + +draws = ( + trace + .sample_stats + .unconstrained_draw + .assign_coords(kind="draw") +) +divergence_locations = ( + trace + .sample_stats + .divergence_start + .assign_coords(kind="divergence") +) + +points = xr.concat([draws, divergence_locations], dim="kind") +points.to_dataset("unconstrained_parameter").plot.scatter(x="log_sigma", y="x_0", hue="kind") +``` + +#### Covariance of gradients and draws + +TODO this section should really use the transformed gradients and draws, not the +unconstrained ones, as that avoids the manual mass matrix correction. This +is only available for the normalizing flow adaptation at the moment though. + +In models with problematic posterior correlations, the singular value +decomposition of gradients and draws can often point us to the source of the +issue. + +Let's build a little model with correlations between parameters: + +```{python} +with pm.Model() as model: + x = pm.Normal('x') + y = pm.Normal("y", mu=x, sigma=0.01) + z = pm.Normal("z", mu=y, shape=100) + +compiled = nutpie.compile_pymc_model(model) +trace = nutpie.sample( + compiled, + tune=1000, + store_gradient=True, + store_unconstrained=True, + store_mass_matrix=True, + seed=42, +) +``` + +Now we can compute eigenvalues of the covariance matrix of the gradient and +draws (using the singular value decomposition to avoid quadratic cost): + +```{python} +def covariance_eigenvalues(x, mass_matrix): + assert x.dims == ("chain", "draw", "unconstrained_parameter") + x = x.stack(sample=["draw", "chain"]) + x = (x - x.mean("sample")) / np.sqrt(mass_matrix) + u, s, v = np.linalg.svd(x.T / np.sqrt(x.shape[1]), full_matrices=False) + print(u.shape, s.shape, v.shape) + s = xr.DataArray( + s, + dims=["eigenvalue"], + coords={"eigenvalue": range(s.size)}, + ) + v = xr.DataArray( + v, + dims=["eigenvalue", "unconstrained_parameter"], + coords={ + "eigenvalue": s.eigenvalue, + "unconstrained_parameter": x.unconstrained_parameter, + }, + ) + return s ** 2, v + +mass_matrix = trace.sample_stats.mass_matrix_inv.isel(draw=-1, chain=0) +draws_eigs, draws_eigv = covariance_eigenvalues(trace.sample_stats.unconstrained_draw, mass_matrix) +grads_eigs, grads_eigv = covariance_eigenvalues(trace.sample_stats.gradient, 1 / mass_matrix) + +draws_eigs.plot.line(x="eigenvalue", yscale="log") +grads_eigs.plot.line(x="eigenvalue", yscale="log") +``` + +We can see one very large and one very small eigenvalue in both covariances. +Large eigenvalues for the draws, and small eigenvalues for the gradients prevent +the sampler from taking larger steps. Small eigenvalues in the draws, and large +eigenvalues in the grads mean, that the sampler has to move far in parameter +space to get independent draws. So both lead to problems during sampling. For +models with many parameters, typically only the large eigenvalues of each are +meaningful, because of estimation issues with the small eigenvalues. + +We can also look at the eigenvectors to see which parameters are responsible for +the correlations: + +```{python} +( + draws_eigv + .sel(eigenvalue=0) + .to_pandas() + .sort_values(key=abs) + .tail(10) + .plot.bar(x="unconstrained_parameter") +) +``` + +```{python} +( + grads_eigv + .sel(eigenvalue=0) + .to_pandas() + .sort_values(key=abs) + .tail(10) + .plot.bar(x="unconstrained_parameter") +) +``` diff --git a/docs/sampling-options.qmd b/docs/sampling-options.qmd new file mode 100644 index 0000000..270fe69 --- /dev/null +++ b/docs/sampling-options.qmd @@ -0,0 +1,156 @@ +# Sampling Configuration Guide + +This guide covers the configuration options for `nutpie.sample` and provides +practical advice for tuning your sampler. We'll start with basic usage and move +to advanced topics like mass matrix adaptation. + +## Quick Start + +For most models, don't think too much about the options of the sampler, and just +use the defaults. Most sampling problems can't easily be solved by changing the +sampler, most of the time they require model changes. So in most cases, simply use + +```python +trace = nutpie.sample(compiled_model) +``` + +## Core Sampling Parameters + +### Drawing Samples + +```python +trace = nutpie.sample( + model, + draws=1000, # Number of post-warmup draws per chain + tune=500, # Number of warmup draws for adaptation + chains=6, # Number of independent chains + cores=None, # Number chains that are allowed to run simultainiously + seed=12345 # Random seed for reproducibility +) +``` + +The number of draws affects both accuracy and computational cost: +- Too few draws (< 500) may not capture the posterior well +- Too many draws (> 10000) may waste computation time + +If a model is sampling without divergences, but with effective sample sizes that +are not as large as necessary to accieve the markov-error for your estimates, +you can increase the number of chains and/or draws. + +If the effective sample size is much smaller than the number of draws, you might +want to consider reparameterizing the model instead, to for instance remove +posterior correlations. + +## Sampler Diagnostics + +You can enable more detailed diagnostics when troubleshooting: + +```python +trace = nutpie.sample( + model, + save_warmup=True, # Keep warmup draws, default is True + store_divergences=True, # Track divergent transitions + store_unconstrained=True, # Store transformed parameters + store_gradient=True, # Store gradient information + store_mass_matrix=True # Track mass matrix adaptation +) +``` + +For each of the `store_*` arguments, additional arrays will be availbale in the +`trace.sample_stats`. + +## Non-blocking sampling + + + +### Settings for HMC and NUTS + +```python +trace = nutpie.sample( + model, + target_accept=0.8, # Target acceptance rate + maxdepth=10 # Maximum tree depth + max_energy_error=1000 # Error at witch to count the trajectory as a divergent transition +) +``` + +The `target_accept` parameter implicitly controls the step size of the leapfrog +steps in the HMC sampler. During tuning, the sampler will try to choose a step +size, such that the acceptance statistic is `target_accept`. It has to be +between 0 and 1. + +The default is 0.8. Larger values will increase the computational cost, but +might avoid divergences during sampling. In many diverging models increasing +`target_accept` will only make divergences less frequent however, and not solve +the underlying problem. + +Lowering the maximum energy error to for instance 10 will often increase the +number of divergences, and make it easier to diagnose their cause. With lower +value the divergences often are reported closer to the critical points in the +parameter space, where the model is most likely to diverge. + +## Mass Matrix Adaptation + +Nutpie offers several strategies for adapting the mass matrix, which determines +how the sampler navigates the parameter space. + +### Standard Adaptation + +By setting `use_grad_based_mass_matrix=False`, the sampling algorithm will more +closely resemble the algorithm in Stan and PyMC. Usually, this will result in +less efficient sampling, but the total number of effective samples is sometimes +higher. If this is set to `True` (the default), nutpie will use diagonal mass +matrix estimates that are based on the posterior draws and the scores at those +positions. + +```python +trace = nutpie.sample( + model, + use_grad_based_mass_matrix=False +) +``` + +### Low-Rank Updates + +For models with strong parameter correlations you can enable a low rank modified +mass matrix. The `mass_matrix_gamma` parameter is a regularization parameter. +More regularization will lead to a smaller effect of the low-rank components, +but might work better for hiegher dimensional problems. + +`mass_matrix_eigval_cutoff` should be greater than one, and controls how large +an eigenvalue of the full mass matrix has to be, to be included into the +low-rank mass matirx. + +```python +trace = nutpie.sample( + model, + low_rank_modified_mass_matrix=True, + mass_matrix_eigval_cutoff=3, + mass_matrix_gamma=1e-5 +) +``` + +### Experimental Features + +`trasform_adapt` is an experimental feature that allows sampling from many +posteriors, where current methods diverge. It is described in more detail +[here](nf-adapt.qmd). + +```python +trace = nutpie.sample( + model, + transform_adapt=True # Experimental reparameterization +) +``` + +## Progress Monitoring + +Customize the sampling progress display: + +```python +trace = nutpie.sample( + model, + progress_bar=True, + progress_rate=500, # Update every 500ms +) +``` diff --git a/docs/stan-usage.qmd b/docs/stan-usage.qmd new file mode 100644 index 0000000..7296231 --- /dev/null +++ b/docs/stan-usage.qmd @@ -0,0 +1,230 @@ +# Usage with Stan models + +This document shows how to use `nutpie` with Stan models. We will use the +`nutpie` package to define a simple model and sample from it using Stan. + +## Installation + +For Stan, it is more common to use `pip` or `uv` to install the necessary +packages. However, `conda` is also an option if you prefer. + +To install using `pip`: + +```bash +pip install "nutpie[stan]" +``` + +To install using `uv`: + +```bash +uv add "nutpie[stan]" +``` + +To install using `conda`: + +```bash +conda install -c conda-forge nutpie +``` + +## Compiler Toolchain + +Stan requires a compiler toolchain to be installed on your system. This is +necessary for compiling the Stan models. You can find detailed instructions for +setting up the compiler toolchain in the [CmdStan +Guide](https://mc-stan.org/docs/cmdstan-guide/installation.html#cpp-toolchain). + +Additionally, since Stan uses Intel's Threading Building Blocks (TBB) for +parallelism, you might need to set the `TBB_CXX_TYPE` environment variable to +specify the compiler type. Depending on your system, you can set it to either +`clang` or `gcc`. For example: + +```{python} +import os +os.environ["TBB_CXX_TYPE"] = "clang" # or 'gcc' +``` + +Make sure to set this environment variable before compiling your Stan models to ensure proper configuration. + +## Defining and Sampling a Simple Model + +We will define a simple Bayesian model using Stan and sample from it using +`nutpie`. + +### Model Definition + +In your Python script or Jupyter notebook, add the following code: + +```{python} +import nutpie + +model_code = """ +data { + int N; + vector[N] y; +} +parameters { + real mu; +} +model { + mu ~ normal(0, 1); + y ~ normal(mu, 1); +} +""" + +compiled_model = nutpie.compile_stan_model(code=model_code) +``` + +### Sampling + +We can now compile the model and sample from it: + +```{python} +compiled_model_with_data = compiled_model.with_data(N=3, y=[1, 2, 3]) +trace = nutpie.sample(compiled_model_with_data) +``` + +### Using Dimensions + +We'll use the radon model from +[this](https://mc-stan.org/learn-stan/case-studies/radon_cmdstanpy_plotnine.html) +case-study from the stan documentation, to show how we can use coordinates and +dimension names to simplify working with trace objects. + +We follow the same data preparation as in the case-study: + +```{python} +import pandas as pd +import numpy as np +import arviz as az +import seaborn as sns + +home_data = pd.read_csv( + "https://github.com/pymc-devs/pymc-examples/raw/refs/heads/main/examples/data/srrs2.dat", + index_col="idnum", +) +county_data = pd.read_csv( + "https://github.com/pymc-devs/pymc-examples/raw/refs/heads/main/examples/data/cty.dat", +) + +radon_data = ( + home_data + .rename(columns=dict(cntyfips="ctfips")) + .merge( + ( + county_data + .drop_duplicates(['stfips', 'ctfips', 'st', 'cty', 'Uppm']) + .set_index(["ctfips", "stfips"]) + ), + right_index=True, + left_on=["ctfips", "stfips"], + ) + .assign(log_radon=lambda x: np.log(np.clip(x.activity, 0.1, np.inf))) + .assign(log_uranium=lambda x: np.log(np.clip(x["Uppm"], 0.1, np.inf))) + .query("state == 'MN'") +) +``` + +And also use the partially pooled model from the case-study: + +```{python} +model_code = """ +data { + int N; // observations + int J; // counties + array[N] int county; + vector[N] x; + vector[N] y; +} +parameters { + real mu_alpha; + real sigma_alpha; + vector[J] alpha; // non-centered parameterization + real beta; + real sigma; +} +model { + y ~ normal(alpha[county] + beta * x, sigma); + alpha ~ normal(mu_alpha, sigma_alpha); // partial-pooling + beta ~ normal(0, 10); + sigma ~ normal(0, 10); + mu_alpha ~ normal(0, 10); + sigma_alpha ~ normal(0, 10); +} +generated quantities { + array[N] real y_rep = normal_rng(alpha[county] + beta * x, sigma); +} +""" +``` + +We collect the dataset in the format that the stan model requires, +and specify the dimensions of each of the non-scalar variables in the model: + +```{python} +county_idx, counties = pd.factorize(radon_data["county"], use_na_sentinel=False) +observations = radon_data.index + +coords = { + "county": counties, + "observation": observations, +} + +dims = { + "alpha": ["county"], + "y_rep": ["observation"], +} + +data = { + "N": len(observations), + "J": len(counties), + # Stan uses 1-based indexing! + "county": county_idx + 1, + "x": radon_data.log_uranium.values, + "y": radon_data.log_radon.values, +} +``` + +Then, we compile the model and provide the dimensions, coordinates and the +dataset we just defined: + +```{python} +compiled_model = ( + nutpie.compile_stan_model(code=model_code) + .with_data(**data) + .with_dims(**dims) + .with_coords(**coords) +) +``` + +```{python} +%%time +trace = nutpie.sample(compiled_model, seed=0) +``` + +As some basic convergance checking we verify that all Rhat values are smaller +than 1.02, all parameters have at least 500 effective draws and that we have no +divergences: + +```{python} +assert trace.sample_stats.diverging.sum() == 0 +assert az.ess(trace).min().min() > 500 +assert az.rhat(trace).max().max() > 1.02 +``` + +Thanks to the coordinates and dimensions we specified, the resulting trace will +now contain labeled data, so that plots based on it have properly set-up labels: + +```{python} +import arviz as az +import seaborn as sns +import xarray as xr + +sns.catplot( + data=trace.posterior.alpha.to_dataframe().reset_index(), + y="county", + x="alpha", + kind="boxen", + height=13, + aspect=1/2.5, + showfliers=False, +) +``` diff --git a/docs/styles.css b/docs/styles.css new file mode 100644 index 0000000..2ddf50c --- /dev/null +++ b/docs/styles.css @@ -0,0 +1 @@ +/* css styles */ diff --git a/pixi.toml b/pixi.toml new file mode 100644 index 0000000..e0d6eb5 --- /dev/null +++ b/pixi.toml @@ -0,0 +1,52 @@ +[project] +authors = ["Adrian Seyboldt "] +channels = ["conda-forge"] +description = "Add a short description here" +name = "nuts-py" +platforms = ["linux-64"] +version = "0.1.0" + +[tasks] +test = "pytest" +develop = "maturin develop --release" +get-posteriordb = "git clone 'https://github.com/stan-dev/posteriordb'" + +[tasks.bench] +depends-on = ["develop", "get-posteriordb"] +cmd = "python -m samplerlab -m posteriordb-fast --posteriordb posteriordb/posterior_database --save-traces --seed 12345" + +[dependencies] +python = ">=3.12.7,<3.13" +pymc = ">=5.19.0,<6" +numba = ">=0.61.0,<0.62" +pytest = ">=8.3.4,<9" +maturin = ">=1.7.7,<2" +pip = ">=24.3.1,<25" +ipykernel = ">=6.29.5,<7" +seaborn = ">=0.13.2,<0.14" +threadpoolctl = ">=3.5.0,<4" +zarr = ">=2.18.3,<3" +polars = ">=1.16.0,<2" +viztracer = ">=1.0.0,<2" +ipywidgets = ">=8.1.5,<9" +quarto = ">=1.6.40,<2" +yaml = ">=0.2.5,<0.3" +pyyaml = ">=6.0.2,<7" +nbformat = ">=5.10.4,<6" +nbclient = ">=0.10.2,<0.11" +cmdstanpy = ">=1.2.5,<2" +# The jaxlib cuda build seems to be broken around version 0.4.34 +#jax = ">=0.4.35,<0.5" +#jaxlib = { version = "*", build = "*cuda12*" } + +[pypi-dependencies] +bridgestan = ">=2.6.0, <3" +#flowjax = { git = "https://github.com/aseyboldt/flowjax.git", rev = "07e7e32217bcfcaa7d68a304f332c925d26ab76f" } +#samplerlab = { git = "https://github.com/aseyboldt/samplerlab/" } +posteriordb = ">=0.2.0, <0.3" +jax = { version = ">=0.5, <0.5.1", extras = ["cuda12"] } +watermark = ">=2.5.0, <3" +flowjax = { git = "https://github.com/danielward27/flowjax.git" } + +[system-requirements] +#cuda = "12" diff --git a/pyproject.toml b/pyproject.toml index 6ef4dff..12c5445 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,17 +2,12 @@ requires = ["maturin>=1.1,<2.0"] build-backend = "maturin" -[tool.maturin] -module-name = "nutpie._lib" -python-source = "python" -features = ["pyo3/extension-module"] - [project] name = "nutpie" description = "Sample Stan or PyMC models" authors = [{ name = "PyMC Developers", email = "pymc.devs@gmail.com" }] readme = "README.md" -requires-python = ">=3.10,<3.13" +requires-python = ">=3.10,<3.14" license = { text = "MIT" } classifiers = [ "Programming Language :: Rust", @@ -32,56 +27,41 @@ dynamic = ["version"] stan = ["bridgestan >= 2.6.1"] pymc = ["pymc >= 5.20.1", "numba >= 0.60.0"] pymc-jax = ["pymc >= 5.20.1", "jax >= 0.4.27"] +nnflow = ["flowjax >= 17.1.0", "equinox >= 0.11.12"] +dev = [ + "bridgestan >= 2.6.1", + "pymc >= 5.20.1", + "numba >= 0.60.0", + "jax >= 0.4.27", + "flowjax >= 17.0.2", + "pytest", +] all = [ "bridgestan >= 2.6.1", "pymc >= 5.20.1", "numba >= 0.60.0", "jax >= 0.4.27", + "flowjax >= 17.1.0", + "equinox >= 0.11.12", ] [tool.ruff] line-length = 88 -target-version = "py39" +target-version = "py310" show-fixes = true output-format = "full" -[tool.ruff.lint] -select = [ - "E", # pycodestyle errors - "W", # pycodestyle warnings - "F", # Pyflakes - "I", # isort - "C4", # flake8-comprehensions - "B", # flake8-bugbear - "UP", # pyupgrade - "RUF", # Ruff-specific rules - "TID", # flake8-tidy-imports - "BLE", # flake8-blind-except - "PTH", # flake8-pathlib - "A", # flake8-builtins -] -ignore = [ - "C408", # unnecessary-collection-call (allow dict(a=1, b=2); clarity over speed!) - # The following list is recommended to disable these when using ruff's formatter. - # (Not all of the following are actually enabled.) - "W191", # tab-indentation - "E111", # indentation-with-invalid-multiple - "E114", # indentation-with-invalid-multiple-comment - "E117", # over-indented - "D206", # indent-with-spaces - "D300", # triple-single-quotes - "Q000", # bad-quotes-inline-string - "Q001", # bad-quotes-multiline-string - "Q002", # bad-quotes-docstring - "Q003", # avoidable-escaped-quote - "COM812", # missing-trailing-comma - "COM819", # prohibited-trailing-comma - "ISC001", # single-line-implicit-string-concatenation - "ISC002", # multi-line-implicit-string-concatenation -] - [tool.ruff.lint.flake8-tidy-imports] ban-relative-imports = "all" [tool.ruff.lint.isort] known-first-party = ["nutpie"] + +[tool.pyright] +venvPath = ".pixi/envs/" +venv = "default" + +[tool.maturin] +module-name = "nutpie._lib" +python-source = "python" +features = ["pyo3/extension-module"] diff --git a/python/nutpie/compile_pymc.py b/python/nutpie/compile_pymc.py index 53f3176..9daec90 100644 --- a/python/nutpie/compile_pymc.py +++ b/python/nutpie/compile_pymc.py @@ -7,7 +7,7 @@ from functools import wraps from importlib.util import find_spec from math import prod -from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union, cast import numpy as np import pandas as pd @@ -274,7 +274,7 @@ def _compile_pymc_model_numba( warnings.filterwarnings( "ignore", message="Cannot cache compiled function .* as it uses dynamic globals", - category=numba.NumbaWarning, + category=numba.NumbaWarning, # type: ignore ) logp_numba = numba.cfunc(c_sig, **kwargs)(logp_numba_raw) @@ -287,7 +287,7 @@ def _compile_pymc_model_numba( warnings.filterwarnings( "ignore", message="Cannot cache compiled function .* as it uses dynamic globals", - category=numba.NumbaWarning, + category=numba.NumbaWarning, # type: ignore ) expand_numba = numba.cfunc(c_sig_expand, **kwargs)(expand_numba_raw) @@ -377,14 +377,24 @@ def _compile_pymc_model_jax( logp_fn = logp_fn_pt.vm.jit_fn expand_fn = expand_fn_pt.vm.jit_fn + logp_shared_names = [var.name for var in logp_fn_pt.get_shared()] + expand_shared_names = [var.name for var in expand_fn_pt.get_shared()] + if gradient_backend == "jax": orig_logp_fn = logp_fn._fun - @jax.jit def logp_fn_jax_grad(x, *shared): return jax.value_and_grad(lambda x: orig_logp_fn(x, *shared)[0])(x) + # static_argnums = list(range(1, len(logp_shared_names) + 1)) + logp_fn_jax_grad = jax.jit( + logp_fn_jax_grad, + # static_argnums=static_argnums, + ) + logp_fn = logp_fn_jax_grad + else: + orig_logp_fn = None shared_data = {} shared_vars = {} @@ -396,9 +406,6 @@ def logp_fn_jax_grad(x, *shared): shared_vars[val.name] = val seen.add(val) - logp_shared_names = [var.name for var in logp_fn_pt.get_shared()] - expand_shared_names = [var.name for var in expand_fn_pt.get_shared()] - def make_logp_func(): def logp(x, **shared): logp, grad = logp_fn(x, *[shared[name] for name in logp_shared_names]) @@ -407,7 +414,8 @@ def logp(x, **shared): return logp names, slices, shapes = shape_info - dtypes = [np.float64] * len(names) + # TODO do not cast to float64 + dtypes = [np.dtype("float64")] * len(names) def make_expand_func(seed1, seed2, chain): # TODO handle seeds @@ -433,6 +441,7 @@ def expand(x, **shared): shared_data=shared_data, dims=dims, coords=coords, + raw_logp_fn=orig_logp_fn, ) @@ -635,7 +644,7 @@ def _make_functions( """ import pytensor import pytensor.tensor as pt - from pymc.pytensorf import compile_pymc + from pymc.pytensorf import compile as compile_pymc shapes = _compute_shapes(model) @@ -726,7 +735,7 @@ def _make_functions( for var in remaining_rvs: all_names.append(var.name) - shape = shapes[var.name] + shape = cast(tuple[int, ...], shapes[var.name]) all_shapes.append(shape) length = prod(shape) all_slices.append(slice(count, count + length)) diff --git a/python/nutpie/compile_stan.py b/python/nutpie/compile_stan.py index 7a28052..138652d 100644 --- a/python/nutpie/compile_stan.py +++ b/python/nutpie/compile_stan.py @@ -1,6 +1,7 @@ import json import tempfile from dataclasses import dataclass, replace +from functools import partial from importlib.util import find_spec from pathlib import Path from typing import Any, Optional @@ -11,6 +12,7 @@ from nutpie import _lib from nutpie.sample import CompiledModel +from nutpie.transform_adapter import make_transform_adapter class _NumpyArrayEncoder(json.JSONEncoder): @@ -28,6 +30,7 @@ class CompiledStanModel(CompiledModel): library: Any model: Any model_name: Optional[str] = None + _transform_adapt_args: dict | None = None def with_data(self, *, seed=None, **updates): if self.data is None: @@ -42,7 +45,15 @@ def with_data(self, *, seed=None, **updates): else: data_json = None - model = _lib.StanModel(self.library, seed, data_json) + kwargs = self._transform_adapt_args + if kwargs is None: + kwargs = {} + make_adapter = partial( + make_transform_adapter(**kwargs), + logp_fn=None, + ) + + model = _lib.StanModel(self.library, seed, data_json, make_adapter) coords = self._coords if coords is None: coords = {} @@ -75,6 +86,9 @@ def with_dims(self, **dims): dims_new.update(dims) return replace(self, dims=dims_new) + def with_transform_adapt(self, **kwargs): + return replace(self, _transform_adapt_args=kwargs).with_data() + def _make_model(self, init_mean): if self.model is None: return self.with_data().model diff --git a/python/nutpie/compiled_pyfunc.py b/python/nutpie/compiled_pyfunc.py index fb6553d..304d50f 100644 --- a/python/nutpie/compiled_pyfunc.py +++ b/python/nutpie/compiled_pyfunc.py @@ -5,8 +5,9 @@ import numpy as np -from nutpie import _lib +from nutpie import _lib # type: ignore from nutpie.sample import CompiledModel +from nutpie.transform_adapter import make_transform_adapter SeedType = int @@ -20,6 +21,8 @@ class PyFuncModel(CompiledModel): _n_dim: int _variables: list[_lib.PyVariable] _coords: dict[str, Any] + _raw_logp_fn: Callable | None + _transform_adapt_args: dict | None = None @property def shapes(self) -> dict[str, tuple[int, ...]]: @@ -42,6 +45,9 @@ def with_data(self, **updates): updated.update(**updates) return dataclasses.replace(self, _shared_data=updated) + def with_transform_adapt(self, **kwargs): + return dataclasses.replace(self, _transform_adapt_args=kwargs) + def _make_sampler(self, settings, init_mean, cores, progress_type): model = self._make_model(init_mean) return _lib.PySampler.from_pyfunc( @@ -60,12 +66,24 @@ def make_expand_func(seed1, seed2, chain): expand_fn = self._make_expand_func(seed1, seed2, chain) return partial(expand_fn, **self._shared_data) + if self._raw_logp_fn is not None: + kwargs = self._transform_adapt_args + if kwargs is None: + kwargs = {} + make_adapter = partial( + make_transform_adapter(**kwargs), + logp_fn=self._raw_logp_fn, + ) + else: + make_adapter = None + return _lib.PyModel( make_logp_func, make_expand_func, self._variables, self.n_dim, - self._make_initial_points, + init_point_func=self._make_initial_points, + transform_adapter=make_adapter, ) @@ -81,6 +99,8 @@ def from_pyfunc( dims: dict[str, tuple[str, ...]] | None = None, shared_data: dict[str, Any] | None = None, make_initial_point_fn: Callable[[SeedType], np.ndarray] | None = None, + make_transform_adapter=None, + raw_logp_fn=None, ): variables = [] for name, shape, dtype in zip( @@ -111,4 +131,5 @@ def from_pyfunc( _make_initial_points=make_initial_point_fn, _variables=variables, _shared_data=shared_data, + _raw_logp_fn=raw_logp_fn, ) diff --git a/python/nutpie/normalizing_flow.py b/python/nutpie/normalizing_flow.py new file mode 100644 index 0000000..e92f7b0 --- /dev/null +++ b/python/nutpie/normalizing_flow.py @@ -0,0 +1,1064 @@ +from typing import ClassVar, Union, Literal, Callable +import math +import itertools + +from flowjax.bijections.coupling import get_ravelled_pytree_constructor +from flowjax.utils import arraylike_to_array +import jax +import jax.numpy as jnp +import equinox as eqx +from flowjax import bijections +import flowjax.distributions +import flowjax.flows +from jaxtyping import Array, ArrayLike +import numpy as np +from paramax import NonTrainable, Parameterize +from equinox.nn import Linear +from paramax.wrappers import AbstractUnwrappable + + +_NN_ACTIVATION = jax.nn.gelu + + +def _generate_sequences(k, r_vals): + """ + Generate all binary sequences of length k with exactly r 1's. + The sequences are stored in a preallocated boolean NumPy array of shape (N, k), + where N = comb(k, r). A True value represents a '1' and False represents a '0'. + + Parameters: + k (int): The length of each sequence. + r (int): The exact number of ones in each sequence. + + Returns: + A NumPy boolean array of shape (comb(k, r), k) containing all sequences. + """ + all_sequences = [] + for r in r_vals: + N = math.comb(k, r) # number of sequences + sequences = np.zeros((N, k), dtype=bool) + # Use enumerate on all combinations where ones appear. + for i, ones_positions in enumerate(itertools.combinations(range(k), r)): + sequences[i, list(ones_positions)] = True + all_sequences.append(sequences) + return np.concatenate(all_sequences, axis=0) + + +def _max_run_length(seq): + """ + Given a 1D boolean NumPy array 'seq', compute the maximum run length of consecutive + identical values (either True or False). + + Parameters: + seq (np.array): A 1D boolean array. + + Returns: + The length (int) of the longest run. + """ + # If the sequence is empty, return 0. + if seq.size == 0: + return 0 + + # Convert boolean to int (0 or 1) so we can use np.diff. + arr = seq.astype(int) + # Compute differences between consecutive elements. + diffs = np.diff(arr) + # Positions where the value changes: + change_indices = np.nonzero(diffs)[0] + + if change_indices.size == 0: + # No changes at all, so the entire sequence is one run. + return seq.size + + # To compute the run lengths, add the "start" index (-1) and the last index. + # For example, if change_indices = [i1, i2, ..., in], + # then the runs are: (i1 - (-1)), (i2 - i1), ..., (seq.size-1 - in). + boundaries = np.concatenate(([-1], change_indices, [seq.size - 1])) + run_lengths = np.diff(boundaries) + return int(run_lengths.max()) + + +def _filter_sequences(sequences, m): + """ + Filter a 2D NumPy boolean array 'sequences' (each row a binary sequence) so that + only sequences with maximum run length (of 0's or 1's) at most m are kept. + + Parameters: + sequences (np.array): A 2D boolean array of shape (N, k). + m (int): Maximum allowed run length. + + Returns: + A NumPy array containing only the rows (sequences) that pass the filter. + """ + filtered = [] + for seq in sequences: + if _max_run_length(seq) <= m: + filtered.append(seq) + return np.array(filtered) + + +def _generate_permutations(rng, n_dim, n_layers, max_run=3): + if n_layers == 1: + r = [0, 1] + elif n_layers == 2: + r = [1] + else: + if n_layers % 2 == 0: + half = n_layers // 2 + r = [half - 1, half, half + 1] + else: + half = n_layers // 2 + r = [half, half + 1] + + all_sequences = _generate_sequences(n_layers, r) + valid_sequences = _filter_sequences(all_sequences, max_run) + + valid_sequences = np.repeat( + valid_sequences, n_dim // len(valid_sequences) + 1, axis=0 + ) + rng.shuffle(valid_sequences, axis=0) + is_in_first = valid_sequences[:n_dim] + rng = np.random.default_rng(42) + permutations = (~is_in_first).argsort(axis=0, kind="stable") + return permutations.T, is_in_first.sum(0) + + +class FactoredMLP(eqx.Module, strict=True): + """Standard Multi-Layer Perceptron; also known as a feed-forward network. + + !!! faq + + If you get a TypeError saying an object is not a valid JAX type, see the + [FAQ](https://docs.kidger.site/equinox/faq/).""" + + layers: tuple[tuple[Linear, Linear], ...] + activation: tuple[Callable, ...] + final_activation: Callable + use_bias: bool = eqx.field(static=True) + use_final_bias: bool = eqx.field(static=True) + in_size: Union[int, Literal["scalar"]] = eqx.field(static=True) + out_size: Union[int, Literal["scalar"]] = eqx.field(static=True) + width_size: tuple[int, ...] = eqx.field(static=True) + depth: int = eqx.field(static=True) + + def __init__( + self, + in_size: Union[int, Literal["scalar"]], + out_size: Union[int, Literal["scalar"]], + width_size: int | tuple[int | tuple[int, int], ...], + depth: int, + activation: Callable = jax.nn.relu, + final_activation: Callable = lambda x: x, + use_bias: bool = True, + use_final_bias: bool = True, + dtype=None, + *, + key, + ): + """**Arguments**: + + - `in_size`: The input size. The input to the module should be a vector of + shape `(in_features,)` + - `out_size`: The output size. The output from the module will be a vector + of shape `(out_features,)`. + - `width_size`: The size of each hidden layer. + - `depth`: The number of hidden layers, including the output layer. + For example, `depth=2` results in an network with layers: + [`Linear(in_size, width_size)`, `Linear(width_size, width_size)`, + `Linear(width_size, out_size)`]. + - `activation`: The activation function after each hidden layer. Defaults to + ReLU. + - `final_activation`: The activation function after the output layer. Defaults + to the identity. + - `use_bias`: Whether to add on a bias to internal layers. Defaults + to `True`. + - `use_final_bias`: Whether to add on a bias to the final layer. Defaults + to `True`. + - `dtype`: The dtype to use for all the weights and biases in this MLP. + Defaults to either `jax.numpy.float32` or `jax.numpy.float64` depending + on whether JAX is in 64-bit mode. + - `key`: A `jax.random.PRNGKey` used to provide randomness for parameter + initialisation. (Keyword only argument.) + + Note that `in_size` also supports the string `"scalar"` as a special value. + In this case the input to the module should be of shape `()`. + + Likewise `out_size` can also be a string `"scalar"`, in which case the + output from the module will have shape `()`. + """ + keys = jax.random.split(key, depth + 1) + layers = [] + if isinstance(width_size, int): + width_size = (width_size,) * depth + + assert len(width_size) == depth + activations: list[Callable] = [] + + if depth == 0: + layers.append( + Linear(in_size, out_size, use_final_bias, dtype=dtype, key=keys[0]) + ) + else: + if isinstance(width_size[0], tuple): + n, k = width_size[0] + key1, key2 = jax.random.split(keys[0]) + U = Linear(in_size, n, use_bias=False, dtype=dtype, key=key1) + K = Linear(n, k, use_bias=True, dtype=dtype, key=key2) + layers.append((U, K)) + else: + k = width_size[0] + layers.append(Linear(in_size, k, use_bias, dtype=dtype, key=keys[0])) + activations.append(eqx.filter_vmap(lambda: activation, axis_size=k)()) + + for i in range(depth - 1): + if isinstance(width_size[i + 1], tuple): + n, k_new = width_size[i + 1] + key1, key2 = jax.random.split(keys[i + 1]) + U = Linear(k, n, use_bias=False, dtype=dtype, key=key1) + K = Linear(n, k_new, use_bias=True, dtype=dtype, key=key2) + layers.append((U, K)) + k = k_new + else: + layers.append( + Linear( + k, width_size[i + 1], use_bias, dtype=dtype, key=keys[i + 1] + ) + ) + k = width_size[i + 1] + activations.append(eqx.filter_vmap(lambda: activation, axis_size=k)()) + + if isinstance(out_size, tuple): + n, k_new = out_size + key1, key2 = jax.random.split(keys[-1]) + U = Linear(k, n, use_bias=False, dtype=dtype, key=key1) + K = Linear(n, k_new, use_bias=True, dtype=dtype, key=key2) + k = k_new + layers.append((U, K)) + else: + layers.append( + Linear(k, out_size, use_final_bias, dtype=dtype, key=keys[-1]) + ) + self.layers = tuple(layers) + self.in_size = in_size + self.out_size = out_size + self.width_size = width_size + self.depth = depth + # In case `activation` or `final_activation` are learnt, then make a separate + # copy of their weights for every neuron. + self.activation = tuple(activations) + if out_size == "scalar": + self.final_activation = final_activation + else: + self.final_activation = eqx.filter_vmap( + lambda: final_activation, axis_size=out_size + )() + self.use_bias = use_bias + self.use_final_bias = use_final_bias + + @jax.named_scope("eqx.nn.MLP") + def __call__(self, x: jax.Array, *, key=None) -> jax.Array: + """**Arguments:** + + - `x`: A JAX array with shape `(in_size,)`. (Or shape `()` if + `in_size="scalar"`.) + - `key`: Ignored; provided for compatibility with the rest of the Equinox API. + (Keyword only argument.) + + **Returns:** + + A JAX array with shape `(out_size,)`. (Or shape `()` if `out_size="scalar"`.) + """ + for i, (layer, act) in enumerate(zip(self.layers[:-1], self.activation)): + if isinstance(layer, tuple): + U, K = layer + x = U(x) + x = K(x) + else: + x = layer(x) + layer_activation = jax.tree.map( + lambda x: x[i] if eqx.is_array(x) else x, act + ) + x = eqx.filter_vmap(lambda a, b: a(b))(layer_activation, x) + + if isinstance(self.layers[-1], tuple): + U, K = self.layers[-1] + x = U(x) + x = K(x) + else: + x = self.layers[-1](x) + + if self.out_size == "scalar": + x = self.final_activation(x) + else: + x = eqx.filter_vmap(lambda a, b: a(b))(self.final_activation, x) + return x + + +class AsymmetricAffine(bijections.AbstractBijection): + """An asymmetric bijection that applies different scaling factors for + positive and negative inputs. + + This bijection implements a continuous, differentiable transformation that + scales positive and negative inputs differently while maintaining smoothness + at zero. It's particularly useful for modeling data with different variances + in positive and negative regions. + + The forward transformation is defined as: + y = σ θ x for x ≥ 0 + y = σ x/θ for x < 0 + where: + - σ (scale) controls the overall scaling + - θ (theta) controls the asymmetry between positive and negative regions + - μ (loc) controls the location shift + + The transformation uses a smooth transition between the two regions to + maintain differentiability. + + For θ = 0, this is exactly an affine function with the specified location + and scale. + + Attributes: + shape: The shape of the transformation parameters + cond_shape: Shape of conditional inputs (None as this bijection is + unconditional) + loc: Location parameter μ for shifting the distribution + scale: Scale parameter σ (positive) + theta: Asymmetry parameter θ (positive) + """ + + shape: tuple[int, ...] = () + cond_shape: ClassVar[None] = None + loc: Array + scale: Array | AbstractUnwrappable[Array] + theta: Array | AbstractUnwrappable[Array] + + def __init__( + self, + loc: ArrayLike = 0, + scale: ArrayLike = 1, + theta: ArrayLike = 1, + ): + self.loc, scale, theta = jnp.broadcast_arrays( + *(arraylike_to_array(a, dtype=float) for a in (loc, scale, theta)), + ) + self.shape = scale.shape + assert self.shape == () + self.scale = Parameterize(lambda x: x + jnp.sqrt(1 + x**2), jnp.zeros(())) + self.theta = Parameterize(lambda x: x + jnp.sqrt(1 + x**2), jnp.zeros(())) + + def _log_derivative_f(self, x, mu, sigma, theta): + abs_x = jnp.abs(x) + theta = jnp.log(theta) + + sinh_theta = jnp.sinh(theta) + # sinh_theta = (theta - 1 / theta) / 2 + cosh_theta = jnp.cosh(theta) + # cosh_theta = (theta + 1 / theta) / 2 + numerator = sinh_theta * x * (abs_x + 2.0) + denominator = (abs_x + 1.0) ** 2 + term = numerator / denominator + dy_dx = sigma * (cosh_theta + term) + return jnp.log(dy_dx) + + def transform_and_log_det( + self, x: ArrayLike, condition: ArrayLike | None = None + ) -> tuple[Array, Array]: + def transform(x, mu, sigma, theta): + weight = (jax.nn.soft_sign(x) + 1) / 2 + z = x * sigma + y_pos = z * theta + y_neg = z / theta + y = weight * y_pos + (1.0 - weight) * y_neg + mu + return y + + mu, sigma, theta = self.loc, self.scale, self.theta + + y = transform(x, mu, sigma, theta) + logjac = self._log_derivative_f(x, mu, sigma, theta) + return y, logjac.sum() + # y, jac = jax.value_and_grad(transform, argnums=0)(x, mu, sigma, theta) + # return y, jnp.log(jac) + + def inverse_and_log_det( + self, y: ArrayLike, condition: ArrayLike | None = None + ) -> tuple[Array, Array]: + def inverse(y, mu, sigma, theta): + delta = y - mu + inv_theta = 1 / theta + + # Case 1: y >= mu (delta >= 0) + a = sigma * (theta + inv_theta) + discriminant_pos = ( + jnp.square(a - 2.0 * delta) + 16.0 * sigma * theta * delta + ) + discriminant_pos = jnp.where(discriminant_pos < 0, 1.0, discriminant_pos) + sqrt_pos = jnp.sqrt(discriminant_pos) + numerator_pos = 2.0 * delta - a + sqrt_pos + denominator_pos = 4.0 * sigma * theta + x_pos = numerator_pos / denominator_pos + + # Case 2: y < mu (delta < 0) + sigma_part = sigma * (1.0 + theta * theta) + term2 = 2.0 * delta * theta + inside_sqrt_neg = ( + jnp.square(sigma_part + term2) - 16.0 * sigma * delta * theta + ) + inside_sqrt_neg = jnp.where(inside_sqrt_neg < 0, 1.0, inside_sqrt_neg) + sqrt_neg = jnp.sqrt(inside_sqrt_neg) + numerator_neg = sigma_part + term2 - sqrt_neg + denominator_neg = 4.0 * sigma + x_neg = numerator_neg / denominator_neg + + # Combine cases based on delta + x = jnp.where(delta >= 0.0, x_pos, x_neg) + return x + + mu, sigma, theta = self.loc, self.scale, self.theta + + x = inverse(y, mu, sigma, theta) + logjac = self._log_derivative_f(x, mu, sigma, theta) + return x, -logjac.sum() + # x, jac = jax.value_and_grad(inverse, argnums=0)(y, mu, sigma, theta) + # return x, jnp.log(jac) + + +class MvScale(bijections.AbstractBijection): + shape: tuple[int, ...] + params: Array + cond_shape = None + base_index: int + + def __init__(self, params: Array, base_index: int = 0): + self.shape = (params.shape[-1],) + self.params = params + self.base_index = base_index + + def transform_and_log_det(self, x: jnp.ndarray, condition: Array | None = None): + scale = jnp.linalg.norm(self.params) + v = self.params / scale + y = x + ((v @ x) * (scale - 1)) * v + return y, jnp.log(scale) + + def inverse_and_log_det(self, y: Array, condition: Array | None = None): + scale = jnp.linalg.norm(self.params) + v = self.params / scale + x = y + ((v @ y) * (1 / scale - 1)) * v + return x, -jnp.log(scale) + + +class Coupling(bijections.AbstractBijection): + """Coupling layer implementation (https://arxiv.org/abs/1605.08803). + + Args: + key: Jax key + transformer: Unconditional bijection with shape () to be parameterised by the + conditioner neural netork. Parameters wrapped with ``NonTrainable`` + are excluded from being parameterized. + untransformed_dim: Number of untransformed conditioning variables (e.g. dim//2). + dim: Total dimension. + cond_dim: Dimension of additional conditioning variables. Defaults to None. + nn_width: Neural network hidden layer width. + nn_depth: Neural network hidden layer size. + nn_activation: Neural network activation function. Defaults to jnn.relu. + """ + + shape: tuple[int, ...] + cond_shape: tuple[int, ...] | None + untransformed_dim: int + dim: int + transformer_constructor: Callable + requires_vmap: bool + conditioner: eqx.nn.MLP | eqx.Module + + def __init__( + self, + key, + *, + transformer: bijections.AbstractBijection, + untransformed_dim: int, + dim: int, + cond_dim: int | None = None, + nn_width: int, + nn_depth: int, + nn_activation: Callable = jax.nn.relu, + conditioner: eqx.Module | None = None, + ): + if transformer.cond_shape is not None: + raise ValueError( + "Only unconditional transformers are supported.", + ) + n_transformed = dim - untransformed_dim + if n_transformed < 0: + raise ValueError( + "The number of untransformed variables must be less than the total " + "dimension.", + ) + if transformer.shape != () and transformer.shape != (n_transformed,): + raise ValueError( + "The transformer must have shape () or (n_transformed,), " + f"got {transformer.shape}.", + ) + + constructor, num_params = get_ravelled_pytree_constructor( + transformer, + filter_spec=eqx.is_inexact_array, + is_leaf=lambda leaf: isinstance(leaf, NonTrainable), + ) + + if transformer.shape == (): + self.requires_vmap = True + conditioner_output_size = num_params * n_transformed + else: + self.requires_vmap = False + conditioner_output_size = num_params + + self.transformer_constructor = constructor + self.untransformed_dim = untransformed_dim + self.dim = dim + self.shape = (dim,) + self.cond_shape = (cond_dim,) if cond_dim is not None else None + + if conditioner is None: + conditioner = eqx.nn.MLP( + in_size=( + untransformed_dim + if cond_dim is None + else untransformed_dim + cond_dim + ), + out_size=conditioner_output_size, + width_size=nn_width, + depth=nn_depth, + activation=nn_activation, + key=key, + ) + self.conditioner = conditioner(conditioner_output_size) + + def transform_and_log_det(self, x, condition=None): + x_cond, x_trans = x[: self.untransformed_dim], x[self.untransformed_dim :] + nn_input = x_cond if condition is None else jnp.hstack((x_cond, condition)) + transformer_params = self.conditioner(nn_input) + transformer = self._flat_params_to_transformer(transformer_params) + y_trans, log_det = transformer.transform_and_log_det(x_trans) + y = jnp.hstack((x_cond, y_trans)) + return y, log_det + + def inverse_and_log_det(self, y, condition=None): + x_cond, y_trans = y[: self.untransformed_dim], y[self.untransformed_dim :] + nn_input = x_cond if condition is None else jnp.concatenate((x_cond, condition)) + transformer_params = self.conditioner(nn_input) + transformer = self._flat_params_to_transformer(transformer_params) + x_trans, log_det = transformer.inverse_and_log_det(y_trans) + x = jnp.hstack((x_cond, x_trans)) + return x, log_det + + def _flat_params_to_transformer(self, params: Array): + """Reshape to dim X params_per_dim, then vmap.""" + if self.requires_vmap: + dim = self.dim - self.untransformed_dim + transformer_params = jnp.reshape(params, (dim, -1)) + transformer = eqx.filter_vmap(self.transformer_constructor)( + transformer_params + ) + return bijections.Vmap(transformer, in_axes=eqx.if_array(0)) + else: + transformer = self.transformer_constructor(params) + return transformer + + +def make_mvscale(key, n_dim, size, randomize_base=False): + def make_single_hh(key, idx): + key1, key2 = jax.random.split(key) + params = jax.random.normal(key1, (n_dim,)) + params = params / jnp.linalg.norm(params) + mvscale = MvScale(params) + return mvscale + + keys = jax.random.split(key, size) + + if randomize_base: + key, key_base = jax.random.split(key) + indices = jax.random.randint(key_base, (size,), 0, n_dim) + else: + indices = [val % n_dim for val in range(size)] + + return bijections.Chain( + [make_single_hh(key, idx) for key, idx in zip(keys, indices)] + ) + + +def make_hh(key, n_dim, size, randomize_base=False): + def make_single_hh(key, idx): + key1, key2 = jax.random.split(key) + params = jax.random.normal(key1, (n_dim,)) * 1e-2 + return bijections.Householder(params, base_index=idx) + + keys = jax.random.split(key, size) + + if randomize_base: + key, key_base = jax.random.split(key) + indices = jax.random.randint(key_base, (size,), 0, n_dim) + else: + indices = [val % n_dim for val in range(size)] + + return bijections.Chain( + [make_single_hh(key, idx) for key, idx in zip(keys, indices)] + ) + + +def make_elemwise_trafo(key, n_dim, *, count=1): + def make_elemwise(key, loc): + key1, key2 = jax.random.split(key) + scale = Parameterize(lambda x: x + jnp.sqrt(1 + x**2), jnp.zeros(())) + theta = Parameterize(lambda x: x + jnp.sqrt(1 + x**2), jnp.zeros(())) + + affine = AsymmetricAffine( + loc, + jnp.ones(()), + jnp.ones(()), + ) + + affine = eqx.tree_at( + where=lambda aff: aff.scale, + pytree=affine, + replace=scale, + ) + affine = eqx.tree_at( + where=lambda aff: aff.theta, + pytree=affine, + replace=theta, + ) + + return bijections.Invert(affine) + + def make(key): + keys = jax.random.split(key, count + 1) + key, keys = keys[0], keys[1:] + loc = jax.random.normal(key=key, shape=(count,)) * 2 + loc = loc - loc.mean() + return bijections.Chain([make_elemwise(key, mu) for key, mu in zip(keys, loc)]) + + keys = jax.random.split(key, n_dim) + make_affine = eqx.filter_vmap(make, axis_size=n_dim)(keys) + return bijections.Vmap(make_affine, in_axes=eqx.if_array(0)) + + +def make_coupling(key, dim, n_untransformed, *, inner_mvscale=False, **kwargs): + n_transformed = dim - n_untransformed + + nn_width = kwargs.get("nn_width", None) + nn_depth = kwargs.get("nn_depth", None) + + if nn_width is None: + if dim > 128: + nn_width = (64, 2 * dim) + else: + nn_width = 2 * dim + + if nn_depth is None: + if isinstance(nn_width, int): + nn_depth = 1 + else: + nn_depth = len(nn_width) + + transformer = make_elemwise_trafo(key, n_transformed, count=3) + + if inner_mvscale: + mvscale = make_mvscale(key, n_transformed, 1, randomize_base=True) + transformer = bijections.Chain([transformer, mvscale]) + + def make_mlp(out_size): + if isinstance(nn_width, tuple): + out = (nn_width[0], out_size) + else: + out = out_size + + return FactoredMLP( + n_untransformed, + out, + nn_width, + depth=nn_depth, + key=key, + dtype=jnp.float32, + activation=_NN_ACTIVATION, + ) + + return Coupling( + key, + transformer=transformer, + untransformed_dim=n_untransformed, + dim=dim, + conditioner=make_mlp, + **kwargs, + ) + + +def make_flow( + seed, + positions, + gradients, + *, + zero_init=False, + householder_layer=False, + dct_layer=False, + untransformed_dim: int | list[int | None] | None = None, + n_layers, + nn_width=None, + nn_depth=None, +): + from flowjax import bijections + + positions = np.array(positions) + gradients = np.array(gradients) + + if len(positions) == 0: + return + + n_draws, n_dim = positions.shape + + if n_dim < 2: + n_layers = 0 + + assert positions.shape == gradients.shape + + if n_draws == 0: + raise ValueError("No draws") + elif n_draws == 1: + assert np.all(gradients != 0) + diag = np.clip(1 / jnp.sqrt(jnp.abs(gradients[0])), 1e-5, 1e5) + assert np.isfinite(diag).all() + mean = jnp.zeros_like(diag) + else: + pos_std = np.clip(positions.std(0), 1e-8, 1e8) + grad_std = np.clip(gradients.std(0), 1e-8, 1e8) + diag = jnp.sqrt(pos_std / grad_std) + mean = positions.mean(0) + gradients.mean(0) * diag * diag + + key = jax.random.PRNGKey(seed % (2**63)) + + diag_param = Parameterize( + lambda x: x + jnp.sqrt(1 + x**2), + (diag**2 - 1) / (2 * diag), + ) + diag_affine = bijections.Affine(mean, diag) + diag_affine = eqx.tree_at( + where=lambda aff: aff.scale, + pytree=diag_affine, + replace=diag_param, + ) + + flows = [ + diag_affine, + ] + + if n_layers == 0: + return bijections.Chain(flows) + + def make_layer(key, untransformed_dim: int | None, permutation=None): + key, key_couple, key_permute, key_hh = jax.random.split(key, 4) + + if untransformed_dim is None: + untransformed_dim = n_dim // 2 + + if untransformed_dim < 0: + untransformed_dim = n_dim + untransformed_dim + + coupling = make_coupling( + key_couple, + n_dim, + untransformed_dim, + nn_activation=_NN_ACTIVATION, + nn_width=nn_width, + nn_depth=nn_depth, + ) + + if zero_init: + coupling = jax.tree_util.tree_map( + lambda x: x * 1e-3 if eqx.is_inexact_array(x) else x, + coupling, + ) + + flow = coupling + + if householder_layer: + hh = make_hh(key_hh, n_dim, 1, randomize_base=False) + flow = bijections.Sandwich(flow, hh) + + def add_default_permute(bijection, dim, key): + if dim == 1: + return bijection + if dim == 2: + outer = bijections.Flip((dim,)) + else: + outer = bijections.Permute(jax.random.permutation(key, jnp.arange(dim))) + + return bijections.Sandwich(bijection, outer) + + if permutation is None: + flow = add_default_permute(flow, n_dim, key_permute) + else: + flow = bijections.Sandwich(flow, bijections.Permute(permutation)) + + mvscale = make_mvscale(key, n_dim, 1, randomize_base=True) + + flow = bijections.Chain( + [ + mvscale, + flow, + ] + ) + + return flow + + key, key_permute = jax.random.split(key) + keys = jax.random.split(key, n_layers) + + if untransformed_dim is None: + # TODO better rng? + rng = np.random.default_rng(int(jax.random.randint(key, (), 0, 2**30))) + permutation, lengths = _generate_permutations(rng, n_dim, n_layers) + layers = [] + for i, (key, p, length) in enumerate(zip(keys, permutation, lengths)): + layers.append(make_layer(key, int(length), p)) + bijection = bijections.Chain(layers) + elif isinstance(untransformed_dim, int): + make_layers = eqx.filter_vmap(make_layer) + layers = make_layers(keys, untransformed_dim) + bijection = bijections.Scan(layers) + else: + layers = [] + for i, (key, num_untrafo) in enumerate(zip(keys, untransformed_dim)): + if i % 2 == 0 or not dct_layer: + layers.append(make_layer(key, num_untrafo)) + else: + inner = make_layer(key, num_untrafo) + outer = bijections.DCT(inner.shape) + + layers.append(bijections.Sandwich(inner, outer)) + + bijection = bijections.Chain(layers) + + return bijections.Chain([bijection, *flows]) + + +def extend_flow( + key, + base, + loss_fn, + positions, + gradients, + logps, + layer: int, + *, + extension_var_count=4, + zero_init=False, + householder_layer=False, + untransformed_dim: int | list[int | None] | None = None, + dct: bool = False, + extension_var_trafo_count=2, + verbose: bool = False, + nn_width=None, + nn_depth=None, +): + n_draws, n_dim = positions.shape + + if n_dim < 2: + return base + + if n_dim <= extension_var_count: + extension_var_count = n_dim - 1 + extension_var_trafo_count = 1 + + if dct: + flow = flowjax.flows.Transformed( + flowjax.distributions.StandardNormal(base.shape), + bijections.Chain([bijections.DCT(shape=(n_dim,)), base]), + ) + else: + flow = flowjax.flows.Transformed( + flowjax.distributions.StandardNormal(base.shape), base + ) + + params, static = eqx.partition(flow, eqx.is_inexact_array) + costs = loss_fn( + params, + static, + positions, + gradients, + logps, + return_elemwise_costs=True, + ) + + if verbose: + print(max(costs), costs) + print("dct:", dct) + idxs = np.argsort(costs) + + permute = bijections.Permute(idxs) + + if True: + scale = Parameterize( + lambda x: x + jnp.sqrt(1 + x**2), + jnp.array(0.0), + ) + theta = Parameterize( + lambda x: x + jnp.sqrt(1 + x**2), + jnp.array(0.0), + ) + + affine = bijections.AsymmetricAffine(jnp.zeros(()), jnp.ones(()), jnp.ones(())) + + affine = eqx.tree_at( + where=lambda aff: aff.scale, + pytree=affine, + replace=scale, + ) + affine = eqx.tree_at( + where=lambda aff: aff.theta, + pytree=affine, + replace=theta, + ) + + do_flip = layer % 2 == 0 + + if nn_width is None: + width = 16 + else: + width = nn_width + + if do_flip: + coupling = bijections.coupling.Coupling( + key, + transformer=affine, + untransformed_dim=n_dim - extension_var_trafo_count, + dim=n_dim, + nn_activation=_NN_ACTIVATION, + nn_width=width, + nn_depth=nn_depth, + ) + + inner_permute = bijections.Permute( + jnp.concatenate( + [ + jnp.arange(n_dim - extension_var_count), + jax.random.permutation( + key, jnp.arange(n_dim - extension_var_count, n_dim) + ), + ] + ) + ) + else: + coupling = bijections.coupling.Coupling( + key, + transformer=affine, + untransformed_dim=extension_var_trafo_count, + dim=n_dim, + nn_activation=_NN_ACTIVATION, + nn_width=width, + nn_depth=nn_depth, + ) + + inner_permute = bijections.Permute( + jnp.concatenate( + [ + jax.random.permutation( + key, jnp.arange(n_dim - extension_var_count, n_dim) + ), + jnp.arange(n_dim - extension_var_count), + ] + ) + ) + + if zero_init: + coupling = jax.tree_util.tree_map( + lambda x: x * 1e-3 if eqx.is_inexact_array(x) else x, + coupling, + ) + + inner = bijections.Sandwich(coupling, inner_permute) + + if False: + scale = Parameterize( + lambda x: x + jnp.sqrt(1 + x**2), + jnp.array(0.0), + ) + affine = eqx.tree_at( + where=lambda aff: aff.scale, + pytree=flowjax.bijections.Affine(), + replace=scale, + ) + + if nn_width is None: + width = 16 + else: + width = nn_width + + coupling = flowjax.bijections.coupling.Coupling( + key, + transformer=affine, + untransformed_dim=extension_var_trafo_count, + dim=n_dim, + nn_activation=_NN_ACTIVATION, + nn_width=width, + nn_depth=nn_depth, + ) + + if zero_init: + coupling = jax.tree_util.tree_map( + lambda x: x * 1e-3 if eqx.is_inexact_array(x) else x, + coupling, + ) + + if verbose: + print(costs[permute.permutation][inner.outer.permutation]) + + inner = bijections.Sandwich( + bijections.Chain( + [ + bijections.Sandwich(coupling, bijections.Flip(shape=(n_dim,))), + inner.inner, + ] + ), + inner.outer, + ) + + if dct: + new_layer = bijections.Sandwich( + bijections.Sandwich(inner, permute), + bijections.DCT(shape=(n_dim,)), + ) + else: + new_layer = bijections.Sandwich(inner, permute) + + scale = Parameterize( + lambda x: x + jnp.sqrt(1 + x**2), + jnp.zeros(n_dim), + ) + affine = eqx.tree_at( + where=lambda aff: aff.scale, + pytree=bijections.Affine(jnp.zeros(n_dim), jnp.ones(n_dim)), + replace=scale, + ) + + pre = [] + if layer % 2 == 0: + pre.append(bijections.Neg(shape=(n_dim,))) + + nonlin_layer = bijections.Sandwich( + affine, + bijections.Chain( + [ + *pre, + bijections.Vmap(bijections.SoftPlusX(), axis_size=n_dim), + ] + ), + ) + scale = Parameterize( + lambda x: x + jnp.sqrt(1 + x**2), + jnp.zeros(n_dim), + ) + affine = eqx.tree_at( + where=lambda aff: aff.scale, + pytree=bijections.Affine(jnp.zeros(n_dim), jnp.ones(n_dim)), + replace=scale, + ) + return bijections.Chain([new_layer, nonlin_layer, affine, base]) diff --git a/python/nutpie/sample.py b/python/nutpie/sample.py index 6c5bfc6..356b8d0 100644 --- a/python/nutpie/sample.py +++ b/python/nutpie/sample.py @@ -1,13 +1,13 @@ import os from dataclasses import dataclass -from typing import Any, Literal, Optional, overload +from typing import Any, Literal, Optional, cast, overload import arviz import numpy as np import pandas as pd import pyarrow -from nutpie import _lib +from nutpie import _lib # type: ignore @dataclass(frozen=True) @@ -281,7 +281,7 @@ def in_colab(): if in_colab(): return True try: - shell = get_ipython().__class__.__name__ + shell = get_ipython().__class__.__name__ # type: ignore if shell == "ZMQInteractiveShell": # Jupyter notebook, Spyder or qtconsole try: from IPython.display import ( @@ -398,6 +398,8 @@ def _extract(self, results): dims["divergence_start_gradient"] = ["unconstrained_parameter"] dims["divergence_end"] = ["unconstrained_parameter"] dims["divergence_momentum"] = ["unconstrained_parameter"] + dims["transformed_gradient"] = ["unconstrained_parameter"] + dims["transformed_position"] = ["unconstrained_parameter"] if self._return_raw_trace: return results @@ -453,14 +455,15 @@ def _repr_html_(self): def sample( compiled_model: CompiledModel, *, - draws: int, - tune: int, + draws: int | None, + tune: int | None, chains: int, cores: Optional[int], seed: Optional[int], save_warmup: bool, progress_bar: bool, low_rank_modified_mass_matrix: bool = False, + transform_adapt: bool = False, init_mean: Optional[np.ndarray], return_raw_trace: bool, blocking: Literal[True], @@ -472,14 +475,15 @@ def sample( def sample( compiled_model: CompiledModel, *, - draws: int, - tune: int, + draws: int | None, + tune: int | None, chains: int, cores: Optional[int], seed: Optional[int], save_warmup: bool, progress_bar: bool, low_rank_modified_mass_matrix: bool = False, + transform_adapt: bool = False, init_mean: Optional[np.ndarray], return_raw_trace: bool, blocking: Literal[False], @@ -490,14 +494,15 @@ def sample( def sample( compiled_model: CompiledModel, *, - draws: int = 1000, - tune: int = 300, + draws: int | None = None, + tune: int | None = None, chains: int = 6, cores: Optional[int] = None, seed: Optional[int] = None, save_warmup: bool = True, progress_bar: bool = True, low_rank_modified_mass_matrix: bool = False, + transform_adapt: bool = False, init_mean: Optional[np.ndarray] = None, return_raw_trace: bool = False, blocking: bool = True, @@ -510,9 +515,9 @@ def sample( Parameters ---------- - draws: int + draws: int | None The number of draws after tuning in each chain. - tune: int + tune: int | None The number of tuning (warmup) draws in each chain. chains: int The number of chains to sample. @@ -585,6 +590,9 @@ def sample( mass_matrix_gamma: float > 0, default=1e-5 Regularisation parameter for the eigenvalues. Only applicable with low_rank_modified_mass_matrix=True. + transform_adapt: bool, default=False + Use the experimental transform adaptation algorithm + during tuning. **kwargs Pass additional arguments to nutpie._lib.PySamplerArgs @@ -594,12 +602,22 @@ def sample( An ArviZ ``InferenceData`` object that contains the samples. """ + if low_rank_modified_mass_matrix and transform_adapt: + raise ValueError( + "Specify only one of `low_rank_modified_mass_matrix` and `transform_adapt`" + ) + if low_rank_modified_mass_matrix: settings = _lib.PyNutsSettings.LowRank(seed) + elif transform_adapt: + settings = _lib.PyNutsSettings.Transform(seed) else: settings = _lib.PyNutsSettings.Diag(seed) - settings.num_tune = tune - settings.num_draws = draws + + if tune is not None: + settings.num_tune = tune + if draws is not None: + settings.num_draws = draws settings.num_chains = chains for name, val in kwargs.items(): @@ -608,10 +626,10 @@ def sample( if cores is None: try: # Only available in python>=3.13 - available = os.process_cpu_count() + available = os.process_cpu_count() # type: ignore except AttributeError: available = os.cpu_count() - cores = min(chains, available) + cores = min(chains, cast(int, available)) if init_mean is None: init_mean = np.zeros(compiled_model.n_dim) diff --git a/python/nutpie/transform_adapter.py b/python/nutpie/transform_adapter.py new file mode 100644 index 0000000..30ca9a6 --- /dev/null +++ b/python/nutpie/transform_adapter.py @@ -0,0 +1,924 @@ +from functools import partial +from typing import Callable +import time + +from flowjax import bijections +from jaxtyping import ArrayLike, PyTree +import numpy as np +import equinox as eqx +import jax +import jax.numpy as jnp +import jax.random as jr +import traceback +import flowjax +import flowjax.flows +import flowjax.train +from flowjax.train.losses import MaximumLikelihoodLoss, PRNGKeyArray +from flowjax.train.train_utils import ( + count_fruitless, + get_batches, + step, + train_val_split, +) +import optax +from paramax import unwrap, NonTrainable + +from nutpie.normalizing_flow import Coupling, extend_flow, make_flow +import tqdm + +_BIJECTION_TRACE = [] + + +def fit_to_data( + key: PRNGKeyArray, + dist: PyTree, # Custom losses may support broader types than AbstractDistribution + x, + *, + condition: ArrayLike | None = None, + loss_fn: Callable | None = None, + max_epochs: int = 100, + max_patience: int = 5, + batch_size: int = 100, + val_prop: float = 0.1, + learning_rate: float = 5e-4, + optimizer: optax.GradientTransformation | None = None, + return_best: bool = True, + show_progress: bool = True, + opt_state=None, + verbose: bool = False, +): + r"""Train a distribution (e.g. a flow) to samples from the target distribution. + + The distribution can be unconditional :math:`p(x)` or conditional + :math:`p(x|\text{condition})`. Note that the last batch in each epoch is dropped + if truncated (to avoid recompilation). This function can also be used to fit + non-distribution pytrees as long as a compatible loss function is provided. + + Args: + key: Jax random seed. + dist: The distribution to train. + x: Samples from target distribution. + condition: Conditioning variables. Defaults to None. + loss_fn: Loss function. Defaults to MaximumLikelihoodLoss. + max_epochs: Maximum number of epochs. Defaults to 100. + max_patience: Number of consecutive epochs with no validation loss improvement + after which training is terminated. Defaults to 5. + batch_size: Batch size. Defaults to 100. + val_prop: Proportion of data to use in validation set. Defaults to 0.1. + learning_rate: Adam learning rate. Defaults to 5e-4. + optimizer: Optax optimizer. If provided, this overrides the default Adam + optimizer, and the learning_rate is ignored. Defaults to None. + return_best: Whether the result should use the parameters where the minimum loss + was reached (when True), or the parameters after the last update (when + False). Defaults to True. + show_progress: Whether to show progress bar. Defaults to True. + + Returns: + A tuple containing the trained distribution and the losses. + """ + if not isinstance(x, tuple): + x = (x,) + data = x if condition is None else (*x, condition) + data = tuple(jnp.asarray(a) for a in data) + + if optimizer is None: + optimizer = optax.adam(learning_rate) + + if loss_fn is None: + loss_fn = MaximumLikelihoodLoss() + + params, static = eqx.partition( + dist, + eqx.is_inexact_array, + is_leaf=lambda leaf: isinstance(leaf, NonTrainable), + ) + best_params = params + + if opt_state is None: + opt_state = optimizer.init(params) + + # train val split + key, subkey = jr.split(key) + train_data, val_data = train_val_split(subkey, data, val_prop=val_prop) + losses = {"train": [], "val": []} + + loop = tqdm.tqdm(range(max_epochs), disable=not show_progress) + + for i in loop: + # Shuffle data + start = time.time() + key, *subkeys = jr.split(key, 3) + train_data = [jr.permutation(subkeys[0], a) for a in train_data] + val_data = [jr.permutation(subkeys[1], a) for a in val_data] + if verbose and i == 0: + print("shuffle timing:", time.time() - start) + + start = time.time() + + key, subkey = jr.split(key) + batches = get_batches(train_data, batch_size) + batch_losses = [] + + if verbose and i == 0: + print("batch timing:", time.time() - start) + + start = time.time() + + if True: + for batch in zip(*batches, strict=True): + key, subkey = jr.split(key) + params, opt_state, batch_loss = step( + params, + static, + *batch, + optimizer=optimizer, + opt_state=opt_state, + loss_fn=loss_fn, + key=subkey, + ) + batch_losses.append(batch_loss) + else: + params, opt_state, batch_losses = _step_batch_loop( + params, + static, + opt_state, + optimizer, + loss_fn, + subkey, + *batches, + ) + + losses["train"].append((sum(batch_losses) / len(batch_losses)).item()) + + if verbose and i == 0: + print("step timing:", time.time() - start) + + start = time.time() + # Val epoch + batch_losses = [] + for batch in zip(*get_batches(val_data, batch_size), strict=True): + key, subkey = jr.split(key) + loss_i = loss_fn(params, static, *batch, key=subkey) + batch_losses.append(loss_i) + losses["val"].append(sum(batch_losses) / len(batch_losses)) + + if verbose and i == 0: + print("val timing:", time.time() - start) + + loop.set_postfix({k: v[-1] for k, v in losses.items()}) + if losses["val"][-1] == min(losses["val"]): + best_params = params + + elif count_fruitless(losses["val"]) > max_patience: + loop.set_postfix_str(f"{loop.postfix} (Max patience reached)") + break + + params = best_params if return_best else params + dist = eqx.combine(params, static) + return dist, losses, opt_state + + +@eqx.filter_jit +def _step_batch_loop(params, static, opt_state, optimizer, loss_fn, key, *batches): + def scan_fn(carry, batch): + params, opt_state, key = carry + key, subkey = jr.split(key) + params, opt_state, loss_i = step( + params, + static, + *batch, + optimizer=optimizer, + opt_state=opt_state, + loss_fn=loss_fn, + key=subkey, + ) + return (params, opt_state, key), loss_i + + (params, opt_state, _), batch_losses = jax.lax.scan( + scan_fn, (params, opt_state, key), batches + ) + + return params, opt_state, batch_losses + + +@eqx.filter_jit +def inverse_gradient_and_val(bijection, draw, grad, logp): + if False: + x = bijection.inverse(draw) + (_, fwd_log_det), pull_grad_fn = jax.vjp( + lambda x: bijection.transform_and_log_det(x), x + ) + (x_grad,) = pull_grad_fn((grad, jnp.ones(()))) + return (x, x_grad, logp + fwd_log_det) + if isinstance(bijection, bijections.Chain): + for b in bijection.bijections[::-1]: + draw, grad, logp = inverse_gradient_and_val(b, draw, grad, logp) + return draw, grad, logp + elif isinstance(bijection, bijections.Permute): + return ( + draw[bijection.inverse_permutation], + grad[bijection.inverse_permutation], + logp, + ) + elif isinstance(bijection, bijections.Affine): + draw, logdet = bijection.inverse_and_log_det(draw) + grad = grad * bijection.scale + return (draw, grad, logp - logdet) + elif isinstance(bijection, bijections.Vmap): + + def inner(bijection, y, y_grad, y_logp): + return inverse_gradient_and_val(bijection, y, y_grad, y_logp) + + y, y_grad, log_det = eqx.filter_vmap( + inner, + in_axes=(bijection.in_axes[0], 0, 0, None), + axis_size=bijection.axis_size, + )(bijection.bijection, draw, grad, jnp.zeros(())) + return y, y_grad, jnp.sum(log_det) + logp + elif isinstance(bijection, bijections.Sandwich): + draw, grad, logp = inverse_gradient_and_val( + bijections.Invert(bijection.outer), draw, grad, logp + ) + draw, grad, logp = inverse_gradient_and_val(bijection.inner, draw, grad, logp) + draw, grad, logp = inverse_gradient_and_val(bijection.outer, draw, grad, logp) + return draw, grad, logp + # Disabeling the Coupling case for now, it slows down compile time for some reason? + elif False and isinstance(bijection, Coupling): + y, y_grad, y_logp = draw, grad, logp + y_cond, y_trans = ( + y[: bijection.untransformed_dim], + y[bijection.untransformed_dim :], + ) + x_cond = y_cond + + y_grad_cond, y_grad_trans = ( + y_grad[: bijection.untransformed_dim], + y_grad[bijection.untransformed_dim :], + ) + + def conditioner(x_cond): + return bijection.conditioner(x_cond) + + transformer_params, nn_pull = jax.vjp(conditioner, x_cond) + + def pull_transformer_grad(transformer_params): + transformer = bijection._flat_params_to_transformer(transformer_params) + + x_trans, x_grad_trans, x_logp = inverse_gradient_and_val( + transformer, y_trans, y_grad_trans, y_logp + ) + + return (x_logp, x_trans), x_grad_trans + + ((x_logp, x_trans), pull_pull_transformer_grad, x_grad_trans) = jax.vjp( + pull_transformer_grad, transformer_params, has_aux=True + ) + + (co_transformer_params,) = pull_pull_transformer_grad((1.0, -x_grad_trans)) + (co_x_cond,) = nn_pull(co_transformer_params) + + x = jnp.hstack((x_cond, x_trans)) + x_grad = jnp.hstack((y_grad_cond + co_x_cond, x_grad_trans)) + return x, x_grad, x_logp + + elif isinstance(bijection, bijections.Invert): + inner = bijection.bijection + x, _ = inner.transform_and_log_det(draw) + (_, fwd_log_det), pull_grad_fn = jax.vjp( + lambda x: inner.inverse_and_log_det(x), x + ) + (x_grad,) = pull_grad_fn((grad, jnp.ones(()))) + return (x, x_grad, logp + fwd_log_det) + else: + x, _ = bijection.inverse_and_log_det(draw) + (_, fwd_log_det), pull_grad_fn = jax.vjp( + lambda x: bijection.transform_and_log_det(x), x + ) + (x_grad,) = pull_grad_fn((grad, jnp.ones(()))) + return (x, x_grad, logp + fwd_log_det) + + +class FisherLoss: + def __init__(self, gamma=None, log_inside_batch=False): + self._gamma = gamma + self._log_inside_batch = log_inside_batch + + @eqx.filter_jit + def __call__( + self, + params, + static, + draws, + grads, + logps, + condition=None, + key=None, + return_all_costs=False, + return_elemwise_costs=False, + ): + flow = unwrap(eqx.combine(params, static, is_leaf=eqx.is_inexact_array)) + + if return_elemwise_costs: + + def compute_loss(bijection, draw, grad, logp): + draw, grad, logp = inverse_gradient_and_val(bijection, draw, grad, logp) + cost = (draw + grad) ** 2 + return cost + + costs = jax.vmap(compute_loss, [None, 0, 0, 0])( + flow.bijection, + draws, + grads, + logps, + ) + return costs.mean(0) + + if self._gamma is None: + + def compute_loss(bijection, draw, grad, logp): + draw, grad, logp = inverse_gradient_and_val(bijection, draw, grad, logp) + cost = ((draw + grad) ** 2).sum() + return cost + + costs = jax.vmap(compute_loss, [None, 0, 0, 0])( + flow.bijection, + draws, + grads, + logps, + ) + + if return_all_costs: + return costs + + if self._log_inside_batch: + return jnp.log(costs).mean() + else: + return jnp.log(costs.mean()) + + else: + + def transform(draw, grad, logp): + return inverse_gradient_and_val(flow.bijection, draw, grad, logp) + + draws, grads, logps = jax.vmap(transform, [0, 0, 0], (0, 0, 0))( + draws, grads, logps + ) + fisher_loss = ((draws + grads) ** 2).sum(1).mean(0) + normal_logps = -(draws * draws).sum(1) / 2 + var_loss = (logps - normal_logps).var() + return jnp.log(fisher_loss + self._gamma * var_loss) + + +def fit_flow(key, bijection, loss_fn, draws, grads, logps, **kwargs): + flow = flowjax.flows.Transformed( + flowjax.distributions.StandardNormal(bijection.shape), bijection + ) + + key, train_key = jax.random.split(key) + + fit, losses, opt_state = fit_to_data( + key=train_key, + dist=flow, + x=(draws, grads, logps), + loss_fn=loss_fn, + max_epochs=500, + return_best=True, + **kwargs, + ) + return fit.bijection, losses, opt_state + + +@eqx.filter_jit +def _init_from_transformed_position(logp_fn, bijection, transformed_position): + bijection = unwrap(bijection) + (untransformed_position, logdet), pull_grad = jax.vjp( + bijection.transform_and_log_det, transformed_position + ) + logp, untransformed_gradient = jax.value_and_grad(lambda x: logp_fn(x)[0])( + untransformed_position + ) + (transformed_gradient,) = pull_grad((untransformed_gradient, 1.0)) + return ( + logp, + logdet, + untransformed_position, + untransformed_gradient, + transformed_gradient, + ) + + +@eqx.filter_jit +def _init_from_transformed_position_part1(logp_fn, bijection, transformed_position): + bijection = unwrap(bijection) + (untransformed_position, logdet) = bijection.transform_and_log_det( + transformed_position + ) + + return (logdet, untransformed_position) + + +@eqx.filter_jit +def _init_from_transformed_position_part2( + bijection, + part1, + untransformed_gradient, +): + logdet, untransformed_position, transformed_position = part1 + bijection = unwrap(bijection) + _, pull_grad = jax.vjp(bijection.transform_and_log_det, transformed_position) + (transformed_gradient,) = pull_grad((untransformed_gradient, 1.0)) + return ( + logdet, + transformed_gradient, + ) + + +@eqx.filter_jit +def _init_from_untransformed_position(logp_fn, bijection, untransformed_position): + logp, untransformed_gradient = jax.value_and_grad(lambda x: logp_fn(x)[0])( + untransformed_position + ) + logdet, transformed_position, transformed_gradient = _inv_transform( + bijection, untransformed_position, untransformed_gradient + ) + return ( + logp, + logdet, + untransformed_gradient, + transformed_position, + transformed_gradient, + ) + + +@eqx.filter_jit +def _inv_transform(bijection, untransformed_position, untransformed_gradient): + bijection = unwrap(bijection) + transformed_position, transformed_gradient, logdet = inverse_gradient_and_val( + bijection, untransformed_position, untransformed_gradient, 0.0 + ) + return logdet, transformed_position, transformed_gradient + + +class TransformAdapter: + def __init__( + self, + seed, + position, + gradient, + chain, + *, + logp_fn, + make_flow_fn, + verbose=False, + window_size=2000, + show_progress=False, + num_diag_windows=10, + learning_rate=1e-3, + zero_init=True, + untransformed_dim=None, + batch_size=128, + reuse_opt_state=True, + max_patience=5, + gamma=None, + log_inside_batch=False, + initial_skip=500, + extension_windows=None, + extend_dct=False, + extension_var_count=6, + extension_var_trafo_count=4, + debug_save_bijection=False, + make_optimizer=None, + num_layers=9, + ): + self._logp_fn = logp_fn + self._make_flow_fn = make_flow_fn + self._chain = chain + self._verbose = verbose + self._window_size = window_size + self._initial_skip = initial_skip + self._num_layers = num_layers + if make_optimizer is None: + self._make_optimizer = lambda: optax.apply_if_finite( + optax.adamw(learning_rate), 50 + ) + else: + self._make_optimizer = make_optimizer + self._optimizer = self._make_optimizer() + self._loss_fn = FisherLoss(gamma, log_inside_batch) + self._show_progress = show_progress + self._num_diag_windows = num_diag_windows + self._zero_init = zero_init + self._untransformed_dim = untransformed_dim + self._batch_size = batch_size + self._reuse_opt_state = reuse_opt_state + self._opt_state = None + self._max_patience = max_patience + self._count_trace = [] + self._last_extend_dct = True + self._extend_dct = extend_dct + self._extension_var_count = extension_var_count + self._extension_var_trafo_count = extension_var_trafo_count + self._debug_save_bijection = debug_save_bijection + self._layers = 0 + + if extension_windows is None: + self._extension_windows = [] + else: + self._extension_windows = extension_windows + + try: + self._bijection = make_flow_fn(seed, [position], [gradient], n_layers=0) + except Exception as e: + print("make_flow", e) + print(traceback.format_exc()) + raise + self.index = 0 + + @property + def transformation_id(self): + return self.index + + def update(self, seed, positions, gradients, logps): + self.index += 1 + if self._verbose: + print( + f"Chain {self._chain}: Total available points: {len(positions)}, seed {seed}" + ) + n_draws = len(positions) + assert n_draws == len(positions) + assert n_draws == len(gradients) + assert n_draws == len(logps) + self._count_trace.append(n_draws) + if n_draws == 0: + return + try: + if self.index <= self._num_diag_windows: + size = len(positions) + lower_idx = -size // 5 + 3 + positions_slice = positions[lower_idx:] + gradients_slice = gradients[lower_idx:] + logp_slice = logps[lower_idx:] + + if len(positions_slice) > 0: + positions = positions_slice + gradients = gradients_slice + logps = logp_slice + + positions = np.array(positions) + gradients = np.array(gradients) + logps = np.array(logps) + + fit = self._make_flow_fn(seed, positions, gradients, n_layers=0) + + flow = flowjax.flows.Transformed( + flowjax.distributions.StandardNormal(fit.shape), fit + ) + params, static = eqx.partition(flow, eqx.is_inexact_array) + new_loss = self._loss_fn(params, static, positions, gradients, logps) + + if self._verbose: + print("loss from diag:", new_loss) + + if np.isfinite(new_loss): + self._bijection = fit + self._opt_state = None + + return + + positions = np.array(positions[self._initial_skip :][-self._window_size :]) + gradients = np.array(gradients[self._initial_skip :][-self._window_size :]) + logps = np.array(logps[self._initial_skip :][-self._window_size :]) + + if len(positions) < 10: + return + + if self._verbose and not np.isfinite(gradients).all(): + print(gradients) + print(gradients.shape) + print((~np.isfinite(gradients)).nonzero()) + + assert np.isfinite(positions).all() + assert np.isfinite(gradients).all() + assert np.isfinite(logps).all() + + # TODO don't reuse seed + key = jax.random.PRNGKey(seed % (2**63)) + + if len(self._bijection.bijections) == 1: + base = self._make_flow_fn( + seed, + positions, + gradients, + n_layers=self._num_layers, + untransformed_dim=self._untransformed_dim, + zero_init=self._zero_init, + ) + flow = flowjax.flows.Transformed( + flowjax.distributions.StandardNormal(base.shape), base + ) + params, static = eqx.partition(flow, eqx.is_inexact_array) + if self._verbose: + print( + "loss before optimization: ", + self._loss_fn( + params, + static, + positions[-128:], + gradients[-128:], + logps[-128:], + ), + ) + else: + base = self._bijection + + if self.index in self._extension_windows: + if self._verbose: + print("Extending flow...") + self._last_extend_dct = not self._last_extend_dct + dct = self._last_extend_dct and self._extend_dct + base = extend_flow( + key, + base, + self._loss_fn, + positions, + gradients, + logps, + self._layers, + dct=dct, + extension_var_count=self._extension_var_count, + extension_var_trafo_count=self._extension_var_trafo_count, + verbose=self._verbose, + ) + self._optimizer = self._make_optimizer() + self._opt_state = None + self._layers += 1 + + # make_flow might still onreturn a single trafo for 1d problems + if len(base.bijections) == 1: + self._bijection = base + self._opt_state = None + + if self._debug_save_bijection: + _BIJECTION_TRACE.append( + (self.index, fit, (positions, gradients, logps)) + ) + return + + flow = flowjax.flows.Transformed( + flowjax.distributions.StandardNormal(self._bijection.shape), + self._bijection, + ) + params, static = eqx.partition(flow, eqx.is_inexact_array) + + start = time.time() + old_loss = self._loss_fn( + params, static, positions[-128:], gradients[-128:], logps[-128:] + ) + if self._verbose: + print("loss function time: ", time.time() - start) + + if np.isfinite(old_loss) and old_loss < -5 and self.index > 10: + if self._verbose: + print(f"Loss is low ({old_loss}), skipping training") + return + + fit, _, opt_state = fit_flow( + key, + base, + self._loss_fn, + positions, + gradients, + logps, + show_progress=self._show_progress, + verbose=self._verbose, + optimizer=self._optimizer, + batch_size=self._batch_size, + opt_state=self._opt_state if self._reuse_opt_state else None, + max_patience=self._max_patience, + ) + + flow = flowjax.flows.Transformed( + flowjax.distributions.StandardNormal(fit.shape), fit + ) + params, static = eqx.partition(flow, eqx.is_inexact_array) + + start = time.time() + new_loss = self._loss_fn( + params, static, positions[-128:], gradients[-128:], logps[-128:] + ) + if self._verbose: + print("new loss function time: ", time.time() - start) + + if self._verbose: + print(f"Chain {self._chain}: New loss {new_loss}, old loss {old_loss}") + + if not np.isfinite(old_loss): + flow = flowjax.flows.Transformed( + flowjax.distributions.StandardNormal(self._bijection.shape), + self._bijection, + ) + params, static = eqx.partition(flow, eqx.is_inexact_array) + print( + self._loss_fn( + params, + static, + positions[-128:], + gradients[-128:], + logps[-128:], + return_all_costs=True, + ) + ) + + if not np.isfinite(new_loss): + flow = flowjax.flows.Transformed( + flowjax.distributions.StandardNormal(fit.shape), fit + ) + params, static = eqx.partition(flow, eqx.is_inexact_array) + print( + self._loss_fn( + params, + static, + positions[-128:], + gradients[-128:], + logps[-128:], + return_all_costs=True, + ) + ) + + if self._debug_save_bijection: + _BIJECTION_TRACE.append( + (self.index, fit, (positions, gradients, logps)) + ) + + def valid_new_logp(): + logdet, pos, grad = _inv_transform( + fit, + jnp.array(positions[-1]), + jnp.array(gradients[-1]), + ) + return ( + np.isfinite(logdet) + and np.isfinite(pos[0]).all() + and np.isfinite(grad[0]).all() + ) + + if (not np.isfinite(old_loss)) and (not np.isfinite(new_loss)): + self._bijection = self._make_flow_fn( + seed, positions, gradients, n_layers=0 + ) + self._opt_state = None + return + + if not valid_new_logp(): + if self._verbose: + print("Invalid new logp. Skipping update.") + return + + if not np.isfinite(new_loss): + if self._verbose: + print("Invalid new loss. Skipping update.") + return + + if new_loss > old_loss: + return + + self._bijection = fit + self._opt_state = opt_state + + except Exception as e: + print("update error:", e) + print(traceback.format_exc()) + raise + + def init_from_transformed_position(self, transformed_position): + try: + logp, logdet, *arrays = _init_from_transformed_position( + self._logp_fn, + self._bijection, + jnp.array(transformed_position), + ) + return ( + float(logp), + float(logdet), + *[np.array(val, dtype="float64") for val in arrays], + ) + except Exception as e: + print(e) + print(traceback.format_exc()) + raise + + def init_from_transformed_position_part1(self, transformed_position): + try: + transformed_position = jnp.array(transformed_position) + logdet, untransformed_position = _init_from_transformed_position_part1( + self._logp_fn, + self._bijection, + transformed_position, + ) + part1 = (logdet, untransformed_position, transformed_position) + return np.array(untransformed_position, dtype="float64"), part1 + except Exception as e: + print(e) + print(traceback.format_exc()) + raise + + def init_from_transformed_position_part2( + self, + part1, + untransformed_gradient, + ): + try: + # TODO We could extract the arrays from the pull_grad function + # to reuse computation from part1 + logdet, *arrays = _init_from_transformed_position_part2( + self._bijection, + part1, + untransformed_gradient, + ) + return float(logdet), *[np.array(val, dtype="float64") for val in arrays] + except Exception as e: + print(e) + print(traceback.format_exc()) + raise + + def init_from_untransformed_position(self, untransformed_position): + try: + logp, logdet, *arrays = _init_from_untransformed_position( + self._logp_fn, + self._bijection, + jnp.array(untransformed_position), + ) + arrays = [np.array(val, dtype="float64") for val in arrays] + return float(logp), float(logdet), *arrays + except Exception as e: + print(e) + print(traceback.format_exc()) + raise + + def inv_transform(self, position, gradient): + try: + logdet, *arrays = _inv_transform( + self._bijection, jnp.array(position), jnp.array(gradient) + ) + return logdet, *[np.array(val, dtype="float64") for val in arrays] + except Exception as e: + print(e) + print(traceback.format_exc()) + raise + + +def make_transform_adapter( + *, + verbose=False, + window_size=600, + show_progress=False, + nn_depth=1, + nn_width=None, + num_layers=9, + num_diag_windows=9, + learning_rate=5e-4, + untransformed_dim=None, + zero_init=True, + batch_size=128, + reuse_opt_state=False, + max_patience=20, + householder_layer=False, + dct_layer=False, + gamma=None, + log_inside_batch=False, + initial_skip=120, + extension_windows=[], + extend_dct=False, + extension_var_count=4, + extension_var_trafo_count=2, + debug_save_bijection=False, + make_optimizer=None, +): + return partial( + TransformAdapter, + verbose=verbose, + window_size=window_size, + make_flow_fn=partial( + make_flow, + householder_layer=householder_layer, + dct_layer=dct_layer, + nn_width=nn_width, + ), + show_progress=show_progress, + num_diag_windows=num_diag_windows, + learning_rate=learning_rate, + zero_init=zero_init, + untransformed_dim=untransformed_dim, + batch_size=batch_size, + reuse_opt_state=reuse_opt_state, + max_patience=max_patience, + gamma=gamma, + log_inside_batch=log_inside_batch, + initial_skip=initial_skip, + extension_windows=extension_windows, + extend_dct=extend_dct, + extension_var_count=extension_var_count, + extension_var_trafo_count=extension_var_trafo_count, + debug_save_bijection=debug_save_bijection, + make_optimizer=make_optimizer, + ) diff --git a/src/progress.rs b/src/progress.rs index 906c50d..2403e15 100644 --- a/src/progress.rs +++ b/src/progress.rs @@ -1,4 +1,4 @@ -use std::{collections::BTreeMap, time::Duration}; +use std::{collections::BTreeMap, sync::Arc, time::Duration}; use anyhow::{Context, Result}; use indicatif::ProgressBar; @@ -10,13 +10,13 @@ use upon::{Engine, Value}; pub struct ProgressHandler { engine: Engine<'static>, template: String, - callback: Py, + callback: Arc>, rate: Duration, n_cores: usize, } impl ProgressHandler { - pub fn new(callback: Py, rate: Duration, template: String, n_cores: usize) -> Self { + pub fn new(callback: Arc>, rate: Duration, template: String, n_cores: usize) -> Self { let engine = Engine::new(); Self { engine, diff --git a/src/pyfunc.rs b/src/pyfunc.rs index f914974..f23145e 100644 --- a/src/pyfunc.rs +++ b/src/pyfunc.rs @@ -8,9 +8,10 @@ use arrow::{ }, datatypes::{DataType, Field, Float32Type, Float64Type, Int64Type}, }; -use numpy::{PyArray1, PyReadonlyArray1}; +use numpy::{NotContiguousError, PyArray1, PyReadonlyArray1}; use nuts_rs::{CpuLogpFunc, CpuMath, DrawStorage, LogpError, Model}; use pyo3::{ + exceptions::PyRuntimeError, pyclass, pymethods, types::{PyAnyMethods, PyDict, PyDictMethods}, Bound, Py, PyAny, PyErr, Python, @@ -20,6 +21,8 @@ use rand_distr::{Distribution, Uniform}; use smallvec::SmallVec; use thiserror::Error; +use crate::wrapper::PyTransformAdapt; + #[pyclass] #[derive(Debug, Clone)] #[non_exhaustive] @@ -71,29 +74,33 @@ impl PyVariable { #[pyclass] #[derive(Debug, Clone)] pub struct PyModel { - make_logp_func: Py, - make_expand_func: Py, - init_point_func: Option>, - variables: Vec, + make_logp_func: Arc>, + make_expand_func: Arc>, + init_point_func: Option>>, + variables: Arc>, + transform_adapter: Option, ndim: usize, } #[pymethods] impl PyModel { #[new] + #[pyo3(signature = (make_logp_func, make_expand_func, variables, ndim, *, init_point_func=None, transform_adapter=None))] fn new<'py>( make_logp_func: Py, make_expand_func: Py, variables: Vec, ndim: usize, init_point_func: Option>, + transform_adapter: Option>, ) -> Self { Self { - make_logp_func, - make_expand_func, - init_point_func, - variables, + make_logp_func: Arc::new(make_logp_func), + make_expand_func: Arc::new(make_expand_func), + init_point_func: init_point_func.map(|x| x.into()), + variables: Arc::new(variables), ndim, + transform_adapter: transform_adapter.map(PyTransformAdapt::new), } } } @@ -103,9 +110,13 @@ pub enum PyLogpError { #[error("Bad logp value: {0}")] BadLogp(f64), #[error("Python error: {0}")] - PyError(PyErr), + PyError(#[from] PyErr), #[error("logp function must return float.")] ReturnTypeError(), + #[error("Python retured a non-contigous array")] + NotContiguousError(#[from] NotContiguousError), + #[error("Unknown error: {0}")] + Anyhow(#[from] anyhow::Error), } impl LogpError for PyLogpError { @@ -113,7 +124,7 @@ impl LogpError for PyLogpError { match self { Self::BadLogp(_) => true, Self::PyError(err) => Python::with_gil(|py| { - let Ok(attr) = err.value_bound(py).getattr("is_recoverable") else { + let Ok(attr) = err.value(py).getattr("is_recoverable") else { return false; }; return attr @@ -121,20 +132,29 @@ impl LogpError for PyLogpError { .expect("Could not access is_recoverable in error check"); }), Self::ReturnTypeError() => false, + Self::NotContiguousError(_) => false, + Self::Anyhow(_) => false, } } } pub struct PyDensity { logp: Py, + transform_adapter: Option, dim: usize, } impl PyDensity { - fn new(logp_clone_func: &Py, dim: usize) -> Result { + fn new( + logp_clone_func: &Py, + dim: usize, + transform_adapter: Option<&PyTransformAdapt>, + ) -> Result { let logp_func = Python::with_gil(|py| logp_clone_func.call0(py))?; + let transform_adapter = transform_adapter.map(|val| val.clone()); Ok(Self { logp: logp_func, + transform_adapter, dim, }) } @@ -142,10 +162,11 @@ impl PyDensity { impl CpuLogpFunc for PyDensity { type LogpError = PyLogpError; + type TransformParams = Py; fn logp(&mut self, position: &[f64], grad: &mut [f64]) -> Result { Python::with_gil(|py| { - let pos_array = PyArray1::from_slice_bound(py, position); + let pos_array = PyArray1::from_slice(py, position); let result = self.logp.call1(py, (pos_array,)); match result { Ok(val) => { @@ -172,11 +193,122 @@ impl CpuLogpFunc for PyDensity { fn dim(&self) -> usize { self.dim } + + fn inv_transform_normalize( + &mut self, + params: &Py, + untransformed_position: &[f64], + untransformed_gradient: &[f64], + transformed_position: &mut [f64], + transformed_gradient: &mut [f64], + ) -> std::result::Result { + let logdet = self + .transform_adapter + .as_mut() + .ok_or_else(|| PyRuntimeError::new_err("No transformation adapter specified"))? + .inv_transform_normalize( + params, + untransformed_position, + untransformed_gradient, + transformed_position, + transformed_gradient, + )?; + Ok(logdet) + } + + fn init_from_transformed_position( + &mut self, + params: &Py, + untransformed_position: &mut [f64], + untransformed_gradient: &mut [f64], + transformed_position: &[f64], + transformed_gradient: &mut [f64], + ) -> std::result::Result<(f64, f64), Self::LogpError> { + let (logp, logdet) = self + .transform_adapter + .as_mut() + .ok_or_else(|| PyRuntimeError::new_err("No transformation adapter specified"))? + .init_from_transformed_position( + params, + untransformed_position, + untransformed_gradient, + transformed_position, + transformed_gradient, + )?; + Ok((logp, logdet)) + } + + fn init_from_untransformed_position( + &mut self, + params: &Py, + untransformed_position: &[f64], + untransformed_gradient: &mut [f64], + transformed_position: &mut [f64], + transformed_gradient: &mut [f64], + ) -> std::result::Result<(f64, f64), Self::LogpError> { + let (logp, logdet) = self + .transform_adapter + .as_mut() + .ok_or_else(|| PyRuntimeError::new_err("No transformation adapter specified"))? + .init_from_untransformed_position( + params, + untransformed_position, + untransformed_gradient, + transformed_position, + transformed_gradient, + )?; + Ok((logp, logdet)) + } + + fn update_transformation<'a, R: rand::Rng + ?Sized>( + &'a mut self, + rng: &mut R, + untransformed_positions: impl ExactSizeIterator, + untransformed_gradients: impl ExactSizeIterator, + untransformed_logp: impl ExactSizeIterator, + params: &'a mut Py, + ) -> std::result::Result<(), Self::LogpError> { + self.transform_adapter + .as_mut() + .ok_or_else(|| PyRuntimeError::new_err("No transformation adapter specified"))? + .update_transformation( + rng, + untransformed_positions, + untransformed_gradients, + untransformed_logp, + params, + )?; + Ok(()) + } + + fn new_transformation( + &mut self, + rng: &mut R, + untransformed_position: &[f64], + untransformed_gradient: &[f64], + chain: u64, + ) -> std::result::Result, Self::LogpError> { + let trafo = self + .transform_adapter + .as_mut() + .ok_or_else(|| PyRuntimeError::new_err("No transformation adapter specified"))? + .new_transformation(rng, untransformed_position, untransformed_gradient, chain)?; + Ok(trafo) + } + + fn transformation_id(&self, params: &Py) -> std::result::Result { + let id = self + .transform_adapter + .as_ref() + .ok_or_else(|| PyRuntimeError::new_err("No transformation adapter specified"))? + .transformation_id(params)?; + Ok(id) + } } pub struct PyTrace { expand: Py, - variables: Vec, + variables: Arc>, builder: StructBuilder, } @@ -184,7 +316,7 @@ impl PyTrace { pub fn new( rng: &mut R, chain: u64, - variables: Vec, + variables: Arc>, make_expand_func: &Py, capacity: usize, ) -> Result { @@ -234,6 +366,7 @@ impl TensorShape { #[pymethods] impl TensorShape { #[new] + #[pyo3(signature = (shape, dims=None))] fn py_new(shape: Vec, dims: Option>>) -> Result { let dims = dims.unwrap_or(shape.iter().map(|_| None).collect()); if dims.len() != shape.len() { @@ -318,7 +451,7 @@ impl ExpandDtype { impl DrawStorage for PyTrace { fn append_value(&mut self, point: &[f64]) -> Result<()> { Python::with_gil(|py| { - let point = PyArray1::from_slice_bound(py, point); + let point = PyArray1::from_slice(py, point); let full_point = self .expand .call1(py, (point,)) @@ -344,36 +477,51 @@ impl DrawStorage for PyTrace { self.builder.field_builder(i).context( "Builder has incorrect type", )?; - builder.append_value(value.extract().expect("Return value from expand function could not be converted to boolean")) + let value = value + .extract() + .expect("Return value from expand function could not be converted to boolean"); + builder.append_value(value) }, ExpandDtype::Float64 {} => { let builder: &mut Float64Builder = self.builder.field_builder(i).context( "Builder has incorrect type", )?; - builder.append_value(value.extract().expect("Return value from expand function could not be converted to float64")) + builder.append_value( + value + .extract() + .expect("Return value from expand function could not be converted to float64") + ) }, ExpandDtype::Float32 {} => { let builder: &mut Float32Builder = self.builder.field_builder(i).context( "Builder has incorrect type", )?; - builder.append_value(value.extract().expect("Return value from expand function could not be converted to float32")) - + builder.append_value( + value + .extract() + .expect("Return value from expand function could not be converted to float32") + ) }, ExpandDtype::Int64 {} => { let builder: &mut Int64Builder = self.builder.field_builder(i).context( "Builder has incorrect type", )?; - builder.append_value(value.extract().expect("Return value from expand function could not be converted to int64")) + let value = value.extract().expect("Return value from expand function could not be converted to int64"); + builder.append_value(value) }, ExpandDtype::BooleanArray { tensor_type } => { let builder: &mut LargeListBuilder> = self.builder.field_builder(i).context( "Builder has incorrect type. Expected LargeListBuilder of Bool", )?; - let value_builder = builder.values().as_any_mut().downcast_mut::().context("Could not downcast builder to boolean type")?; + let value_builder = builder + .values() + .as_any_mut() + .downcast_mut::() + .context("Could not downcast builder to boolean type")?; let values: PyReadonlyArray1 = value.extract().context("Could not convert object to array")?; if values.len()? != tensor_type.size() { bail!("Extracted array has incorrect shape"); @@ -387,7 +535,11 @@ impl DrawStorage for PyTrace { self.builder.field_builder(i).context( "Builder has incorrect type. Expected LargeListBuilder of Float64", )?; - let value_builder = builder.values().as_any_mut().downcast_mut::>().context("Could not downcast builder to float64 type")?; + let value_builder = builder + .values() + .as_any_mut() + .downcast_mut::>() + .context("Could not downcast builder to float64 type")?; let values: PyReadonlyArray1 = value.extract().context("Could not convert object to array")?; if values.len()? != tensor_type.size() { bail!("Extracted array has incorrect shape"); @@ -400,7 +552,11 @@ impl DrawStorage for PyTrace { self.builder.field_builder(i).context( "Builder has incorrect type. Expected LargeListBuilder of Float32", )?; - let value_builder = builder.values().as_any_mut().downcast_mut::>().context("Could not downcast builder to float32 type")?; + let value_builder = builder + .values() + .as_any_mut() + .downcast_mut::>() + .context("Could not downcast builder to float32 type")?; let values: PyReadonlyArray1 = value.extract().context("Could not convert object to array")?; if values.len()? != tensor_type.size() { bail!("Extracted array has incorrect shape"); @@ -413,7 +569,11 @@ impl DrawStorage for PyTrace { self.builder.field_builder(i).context( "Builder has incorrect type. Expected LargeListBuilder of Int64", )?; - let value_builder = builder.values().as_any_mut().downcast_mut::>().context("Could not downcast builder to i64 type")?; + let value_builder = builder + .values() + .as_any_mut() + .downcast_mut::>() + .context("Could not downcast builder to i64 type")?; let values: PyReadonlyArray1 = value.extract().context("Could not convert object to array")?; if values.len()? != tensor_type.size() { bail!("Extracted array has incorrect shape"); @@ -471,6 +631,7 @@ impl Model for PyModel { Ok(CpuMath::new(PyDensity::new( &self.make_logp_func, self.ndim, + self.transform_adapter.as_ref(), )?)) } @@ -480,7 +641,7 @@ impl Model for PyModel { position: &mut [f64], ) -> Result<()> { let Some(init_func) = self.init_point_func.as_ref() else { - let dist = Uniform::new(-2f64, 2f64); + let dist = Uniform::new(-2f64, 2f64).expect("Could not create uniform distribution"); position.iter_mut().for_each(|x| *x = dist.sample(rng)); return Ok(()); }; diff --git a/src/pymc.rs b/src/pymc.rs index 98426c7..b33b821 100644 --- a/src/pymc.rs +++ b/src/pymc.rs @@ -39,7 +39,7 @@ type RawExpandFunc = unsafe extern "C" fn( #[derive(Clone)] pub(crate) struct LogpFunc { func: RawLogpFunc, - _keep_alive: PyObject, + _keep_alive: Arc, user_data_ptr: UserData, dim: usize, } @@ -55,7 +55,7 @@ impl LogpFunc { unsafe { std::mem::transmute::<*const c_void, RawLogpFunc>(ptr as *const c_void) }; Self { func, - _keep_alive: keep_alive, + _keep_alive: Arc::new(keep_alive), user_data_ptr: user_data_ptr as UserData, dim, } @@ -66,7 +66,7 @@ impl LogpFunc { #[derive(Clone)] pub(crate) struct ExpandFunc { func: RawExpandFunc, - _keep_alive: PyObject, + _keep_alive: Arc, user_data_ptr: UserData, dim: usize, expanded_dim: usize, @@ -87,7 +87,7 @@ impl ExpandFunc { Self { dim, expanded_dim, - _keep_alive: keep_alive, + _keep_alive: Arc::new(keep_alive), user_data_ptr: user_data_ptr as UserData, func, } @@ -114,6 +114,7 @@ impl LogpError for ErrorCode { impl<'a> CpuLogpFunc for &'a LogpFunc { type LogpError = ErrorCode; + type TransformParams = (); fn dim(&self) -> usize { self.dim @@ -233,7 +234,7 @@ pub(crate) struct PyMcModel { dim: usize, density: LogpFunc, expand: ExpandFunc, - init_func: Py, + init_func: Arc>, var_sizes: Vec, var_names: Vec, } @@ -253,7 +254,7 @@ impl PyMcModel { dim, density, expand, - init_func, + init_func: init_func.into(), var_names: var_names.extract()?, var_sizes: var_sizes.extract()?, }) diff --git a/src/stan.rs b/src/stan.rs index 0564cd2..e258096 100644 --- a/src/stan.rs +++ b/src/stan.rs @@ -7,16 +7,19 @@ use arrow::datatypes::{DataType, Field}; use bridgestan::open_library; use itertools::{izip, Itertools}; use nuts_rs::{CpuLogpFunc, CpuMath, DrawStorage, LogpError, Model, Settings}; +use pyo3::exceptions::PyRuntimeError; use pyo3::prelude::*; use pyo3::types::{PyDict, PyTuple}; use pyo3::{exceptions::PyValueError, pyclass, pymethods, PyResult}; use rand::prelude::Distribution; -use rand::{thread_rng, RngCore}; +use rand::{rng, RngCore}; use rand_distr::StandardNormal; use smallvec::{SmallVec, ToSmallVec}; use thiserror::Error; +use crate::wrapper::PyTransformAdapt; + type InnerModel = bridgestan::Model>; #[pyclass] @@ -53,8 +56,8 @@ impl StanVariable { } #[getter] - fn shape<'py>(&self, py: Python<'py>) -> Bound<'py, PyTuple> { - PyTuple::new_bound(py, self.0.shape.iter()) + fn shape<'py>(&self, py: Python<'py>) -> PyResult> { + PyTuple::new(py, self.0.shape.iter()) } #[getter] @@ -68,6 +71,7 @@ impl StanVariable { pub struct StanModel { model: Arc, variables: Vec, + transform_adapter: Option, } /// Return meta information about the constrained parameters of the model @@ -136,25 +140,41 @@ fn params( #[pymethods] impl StanModel { #[new] - pub fn new(lib: StanLibrary, seed: Option, data: Option) -> anyhow::Result { + #[pyo3(signature = (lib, seed=None, data=None, transform_adapter=None))] + pub fn new( + lib: StanLibrary, + seed: Option, + data: Option, + transform_adapter: Option>, + ) -> anyhow::Result { let seed = match seed { Some(seed) => seed, - None => thread_rng().next_u32(), + None => rng().next_u32(), }; let data: Option = data.map(CString::new).transpose()?; let model = Arc::new( bridgestan::Model::new(lib.0, data.as_ref(), seed).map_err(anyhow::Error::new)?, ); let variables = params(&model, true, true)?; - Ok(StanModel { model, variables }) + let transform_adapter = transform_adapter.map(PyTransformAdapt::new); + Ok(StanModel { + model, + variables, + transform_adapter, + }) } pub fn variables<'py>(&self, py: Python<'py>) -> PyResult> { - let out = PyDict::new_bound(py); + let out = PyDict::new(py); let results: Result, _> = self .variables .iter() - .map(|var| out.set_item(var.name.clone(), StanVariable(var.clone()).into_py(py))) + .map(|var| { + out.set_item( + var.name.clone(), + StanVariable(var.clone()).into_pyobject(py)?, + ) + }) .collect(); results?; Ok(out) @@ -195,7 +215,10 @@ impl StanModel { */ } -pub struct StanDensity<'model>(&'model InnerModel); +pub struct StanDensity<'model> { + inner: &'model InnerModel, + transform_adapter: Option, +} #[derive(Debug, Error)] pub enum StanLogpError { @@ -203,6 +226,10 @@ pub enum StanLogpError { BridgeStan(#[from] bridgestan::BridgeStanError), #[error("Bad logp value: {0}")] BadLogp(f64), + #[error("Python exception: {0}")] + PyErr(#[from] PyErr), + #[error("Unspecified Error: {0}")] + Anyhow(#[from] anyhow::Error), } impl LogpError for StanLogpError { @@ -213,9 +240,12 @@ impl LogpError for StanLogpError { impl<'model> CpuLogpFunc for StanDensity<'model> { type LogpError = StanLogpError; + type TransformParams = Py; fn logp(&mut self, position: &[f64], grad: &mut [f64]) -> Result { - let logp = self.0.log_density_gradient(position, true, true, grad)?; + let logp = self + .inner + .log_density_gradient(position, true, true, grad)?; if !logp.is_finite() { return Err(StanLogpError::BadLogp(logp)); } @@ -223,7 +253,143 @@ impl<'model> CpuLogpFunc for StanDensity<'model> { } fn dim(&self) -> usize { - self.0.param_unc_num() + self.inner.param_unc_num() + } + + fn inv_transform_normalize( + &mut self, + params: &Py, + untransformed_position: &[f64], + untransformed_gradient: &[f64], + transformed_position: &mut [f64], + transformed_gradient: &mut [f64], + ) -> std::result::Result { + let logdet = self + .transform_adapter + .as_mut() + .ok_or_else(|| PyRuntimeError::new_err("No transformation adapter specified"))? + .inv_transform_normalize( + params, + untransformed_position, + untransformed_gradient, + transformed_position, + transformed_gradient, + ) + .context("failed inv_transform_normalize")?; + Ok(logdet) + } + + fn init_from_transformed_position( + &mut self, + params: &Py, + untransformed_position: &mut [f64], + untransformed_gradient: &mut [f64], + transformed_position: &[f64], + transformed_gradient: &mut [f64], + ) -> std::result::Result<(f64, f64), Self::LogpError> { + let adapter = self + .transform_adapter + .as_mut() + .ok_or_else(|| PyRuntimeError::new_err("No transformation adapter specified"))?; + + let part1 = adapter + .init_from_transformed_position_part1( + params, + untransformed_position, + transformed_position, + ) + .context("Failed init_from_transformed_position_part1")?; + + let logp = self.logp(untransformed_position, untransformed_gradient)?; + + let adapter = self + .transform_adapter + .as_mut() + .ok_or_else(|| PyRuntimeError::new_err("No transformation adapter specified"))?; + + let logdet = adapter + .init_from_transformed_position_part2( + params, + part1, + untransformed_gradient, + transformed_gradient, + ) + .context("Failed init_from_transformed_position_part2")?; + Ok((logp, logdet)) + } + + fn init_from_untransformed_position( + &mut self, + params: &Py, + untransformed_position: &[f64], + untransformed_gradient: &mut [f64], + transformed_position: &mut [f64], + transformed_gradient: &mut [f64], + ) -> std::result::Result<(f64, f64), Self::LogpError> { + let logp = self + .logp(untransformed_position, untransformed_gradient) + .context("Failed to call stan logp function")?; + + let logdet = self + .transform_adapter + .as_mut() + .ok_or_else(|| PyRuntimeError::new_err("No transformation adapter specified"))? + .inv_transform_normalize( + params, + untransformed_position, + untransformed_gradient, + transformed_position, + transformed_gradient, + ) + .context("Failed inv_transform_normalize in stan init_from_untransformed_position")?; + Ok((logp, logdet)) + } + + fn update_transformation<'a, R: rand::Rng + ?Sized>( + &'a mut self, + rng: &mut R, + untransformed_positions: impl ExactSizeIterator, + untransformed_gradients: impl ExactSizeIterator, + untransformed_logp: impl ExactSizeIterator, + params: &'a mut Py, + ) -> std::result::Result<(), Self::LogpError> { + self.transform_adapter + .as_mut() + .ok_or_else(|| PyRuntimeError::new_err("No transformation adapter specified"))? + .update_transformation( + rng, + untransformed_positions, + untransformed_gradients, + untransformed_logp, + params, + ) + .context("Failed to update the transformation")?; + Ok(()) + } + + fn new_transformation( + &mut self, + rng: &mut R, + untransformed_position: &[f64], + untransformed_gradient: &[f64], + chain: u64, + ) -> std::result::Result, Self::LogpError> { + let trafo = self + .transform_adapter + .as_mut() + .ok_or_else(|| PyRuntimeError::new_err("No transformation adapter specified"))? + .new_transformation(rng, untransformed_position, untransformed_gradient, chain) + .context("Could not create transformation adapter")?; + Ok(trafo) + } + + fn transformation_id(&self, params: &Py) -> std::result::Result { + let id = self + .transform_adapter + .as_ref() + .ok_or_else(|| PyRuntimeError::new_err("No transformation adapter specified"))? + .transformation_id(params)?; + Ok(id) } } @@ -387,7 +553,10 @@ impl Model for StanModel { } fn math(&self) -> anyhow::Result> { - Ok(CpuMath::new(StanDensity(&self.model))) + Ok(CpuMath::new(StanDensity { + inner: &self.model, + transform_adapter: self.transform_adapter.as_ref().map(|v| v.clone()), + })) } fn init_position( diff --git a/src/wrapper.rs b/src/wrapper.rs index 6f9ad49..bb1a523 100644 --- a/src/wrapper.rs +++ b/src/wrapper.rs @@ -1,6 +1,7 @@ use std::{ fmt::Debug, - sync::Arc, + ops::{Deref, DerefMut}, + sync::{Arc, Mutex}, time::{Duration, Instant}, }; @@ -13,17 +14,19 @@ use crate::{ use anyhow::{bail, Context, Result}; use arrow::array::Array; +use numpy::{PyArray1, PyReadonlyArray1}; use nuts_rs::{ - AdaptOptions, ChainProgress, DiagAdaptExpSettings, DiagGradNutsSettings, LowRankNutsSettings, - LowRankSettings, NutsSettings, ProgressCallback, Sampler, SamplerWaitResult, Trace, + ChainProgress, DiagGradNutsSettings, LowRankNutsSettings, ProgressCallback, Sampler, + SamplerWaitResult, Trace, TransformedNutsSettings, }; use pyo3::{ exceptions::PyTimeoutError, ffi::Py_uintptr_t, + intern, prelude::*, types::{PyList, PyTuple}, }; -use rand::{thread_rng, RngCore}; +use rand::{rng, RngCore}; #[pyclass] struct PyChainProgress(ChainProgress); @@ -66,100 +69,23 @@ impl PyChainProgress { } } -#[derive(Clone)] -enum InnerSettings { - LowRank(LowRankSettings), - Diag(DiagAdaptExpSettings), -} - #[pyclass] #[derive(Clone)] pub struct PyNutsSettings { - settings: NutsSettings<()>, - adapt: InnerSettings, + inner: Settings, } +#[derive(Clone, Debug)] enum Settings { Diag(DiagGradNutsSettings), LowRank(LowRankNutsSettings), -} - -// Would be much nicer with -// https://doc.rust-lang.org/nightly/unstable-book/language-features/type-changing-struct-update.html -fn combine_settings( - inner: T, - settings: NutsSettings<()>, -) -> NutsSettings { - let adapt = AdaptOptions { - dual_average_options: settings.adapt_options.dual_average_options, - mass_matrix_options: inner, - early_window: settings.adapt_options.early_window, - step_size_window: settings.adapt_options.step_size_window, - mass_matrix_switch_freq: settings.adapt_options.mass_matrix_switch_freq, - early_mass_matrix_switch_freq: settings.adapt_options.early_mass_matrix_switch_freq, - mass_matrix_update_freq: settings.adapt_options.mass_matrix_update_freq, - }; - NutsSettings { - num_tune: settings.num_tune, - num_draws: settings.num_draws, - maxdepth: settings.maxdepth, - store_gradient: settings.store_gradient, - store_unconstrained: settings.store_unconstrained, - max_energy_error: settings.max_energy_error, - store_divergences: settings.store_divergences, - adapt_options: adapt, - check_turning: settings.check_turning, - num_chains: settings.num_chains, - seed: settings.seed, - } -} - -fn split_settings(settings: NutsSettings) -> (NutsSettings<()>, T) { - let adapt_settings = settings.adapt_options; - let mass_matrix_settings = adapt_settings.mass_matrix_options; - - let remaining: AdaptOptions<()> = AdaptOptions { - dual_average_options: adapt_settings.dual_average_options, - mass_matrix_options: (), - early_window: adapt_settings.early_window, - step_size_window: adapt_settings.step_size_window, - mass_matrix_switch_freq: adapt_settings.mass_matrix_switch_freq, - early_mass_matrix_switch_freq: adapt_settings.early_mass_matrix_switch_freq, - mass_matrix_update_freq: adapt_settings.mass_matrix_update_freq, - }; - - let settings = NutsSettings { - adapt_options: remaining, - num_tune: settings.num_tune, - num_draws: settings.num_draws, - maxdepth: settings.maxdepth, - store_gradient: settings.store_gradient, - store_unconstrained: settings.store_unconstrained, - max_energy_error: settings.max_energy_error, - store_divergences: settings.store_divergences, - check_turning: settings.check_turning, - num_chains: settings.num_chains, - seed: settings.seed, - }; - - (settings, mass_matrix_settings) + Transforming(TransformedNutsSettings), } impl PyNutsSettings { - fn into_settings(self) -> Settings { - match self.adapt { - InnerSettings::LowRank(mass_matrix) => { - Settings::LowRank(combine_settings(mass_matrix, self.settings)) - } - InnerSettings::Diag(mass_matrix) => { - Settings::Diag(combine_settings(mass_matrix, self.settings)) - } - } - } - fn new_diag(seed: Option) -> Self { let seed = seed.unwrap_or_else(|| { - let mut rng = thread_rng(); + let mut rng = rng(); rng.next_u64() }); let settings = DiagGradNutsSettings { @@ -167,17 +93,14 @@ impl PyNutsSettings { ..Default::default() }; - let (settings, inner) = split_settings(settings); - Self { - settings, - adapt: InnerSettings::Diag(inner), + inner: Settings::Diag(settings), } } fn new_low_rank(seed: Option) -> Self { let seed = seed.unwrap_or_else(|| { - let mut rng = thread_rng(); + let mut rng = rng(); rng.next_u64() }); let settings = LowRankNutsSettings { @@ -185,229 +108,467 @@ impl PyNutsSettings { ..Default::default() }; - let (settings, inner) = split_settings(settings); + Self { + inner: Settings::LowRank(settings), + } + } + + fn new_tranform_adapt(seed: Option) -> Self { + let seed = seed.unwrap_or_else(|| { + let mut rng = rng(); + rng.next_u64() + }); + let settings = TransformedNutsSettings { + seed, + ..Default::default() + }; Self { - settings, - adapt: InnerSettings::LowRank(inner), + inner: Settings::Transforming(settings), } } } +// TODO switch to serde to expose all the options... #[pymethods] impl PyNutsSettings { #[staticmethod] #[allow(non_snake_case)] + #[pyo3(signature = (seed=None))] fn Diag(seed: Option) -> Self { PyNutsSettings::new_diag(seed) } #[staticmethod] #[allow(non_snake_case)] + #[pyo3(signature = (seed=None))] fn LowRank(seed: Option) -> Self { PyNutsSettings::new_low_rank(seed) } + #[staticmethod] + #[allow(non_snake_case)] + #[pyo3(signature = (seed=None))] + fn Transform(seed: Option) -> Self { + PyNutsSettings::new_tranform_adapt(seed) + } + #[getter] fn num_tune(&self) -> u64 { - self.settings.num_tune + match &self.inner { + Settings::Diag(nuts_settings) => nuts_settings.num_tune, + Settings::LowRank(nuts_settings) => nuts_settings.num_tune, + Settings::Transforming(nuts_settings) => nuts_settings.num_tune, + } } #[setter(num_tune)] fn set_num_tune(&mut self, val: u64) { - self.settings.num_tune = val + match &mut self.inner { + Settings::Diag(nuts_settings) => nuts_settings.num_tune = val, + Settings::LowRank(nuts_settings) => nuts_settings.num_tune = val, + Settings::Transforming(nuts_settings) => nuts_settings.num_tune = val, + } } #[getter] fn num_chains(&self) -> usize { - self.settings.num_chains + match &self.inner { + Settings::Diag(nuts_settings) => nuts_settings.num_chains, + Settings::LowRank(nuts_settings) => nuts_settings.num_chains, + Settings::Transforming(nuts_settings) => nuts_settings.num_chains, + } } #[setter(num_chains)] fn set_num_chains(&mut self, val: usize) { - self.settings.num_chains = val; + match &mut self.inner { + Settings::Diag(nuts_settings) => nuts_settings.num_chains = val, + Settings::LowRank(nuts_settings) => nuts_settings.num_chains = val, + Settings::Transforming(nuts_settings) => nuts_settings.num_chains = val, + } } #[getter] fn num_draws(&self) -> u64 { - self.settings.num_draws + match &self.inner { + Settings::Diag(nuts_settings) => nuts_settings.num_draws, + Settings::LowRank(nuts_settings) => nuts_settings.num_draws, + Settings::Transforming(nuts_settings) => nuts_settings.num_draws, + } } #[setter(num_draws)] fn set_num_draws(&mut self, val: u64) { - self.settings.num_draws = val; + match &mut self.inner { + Settings::Diag(nuts_settings) => nuts_settings.num_draws = val, + Settings::LowRank(nuts_settings) => nuts_settings.num_draws = val, + Settings::Transforming(nuts_settings) => nuts_settings.num_draws = val, + } } #[getter] - fn window_switch_freq(&self) -> u64 { - self.settings.adapt_options.mass_matrix_switch_freq + fn window_switch_freq(&self) -> Result { + match &self.inner { + Settings::Diag(nuts_settings) => { + Ok(nuts_settings.adapt_options.mass_matrix_switch_freq) + } + Settings::LowRank(nuts_settings) => { + Ok(nuts_settings.adapt_options.mass_matrix_switch_freq) + } + Settings::Transforming(nuts_settings) => { + Ok(nuts_settings.adapt_options.transform_update_freq) + } + } } #[setter(window_switch_freq)] - fn set_window_switch_freq(&mut self, val: u64) { - self.settings.adapt_options.mass_matrix_switch_freq = val; + fn set_window_switch_freq(&mut self, val: u64) -> Result<()> { + match &mut self.inner { + Settings::Diag(nuts_settings) => { + nuts_settings.adapt_options.mass_matrix_switch_freq = val; + Ok(()) + } + Settings::LowRank(nuts_settings) => { + nuts_settings.adapt_options.mass_matrix_switch_freq = val; + Ok(()) + } + Settings::Transforming(nuts_settings) => { + nuts_settings.adapt_options.transform_update_freq = val; + Ok(()) + } + } } #[getter] - fn early_window_switch_freq(&self) -> u64 { - self.settings.adapt_options.early_mass_matrix_switch_freq + fn early_window_switch_freq(&self) -> Result { + match &self.inner { + Settings::Diag(nuts_settings) => { + Ok(nuts_settings.adapt_options.early_mass_matrix_switch_freq) + } + Settings::LowRank(nuts_settings) => { + Ok(nuts_settings.adapt_options.early_mass_matrix_switch_freq) + } + Settings::Transforming(_) => { + bail!("Option early_window_switch_freq not availbale for transformation adaptation") + } + } } #[setter(early_window_switch_freq)] - fn set_early_window_switch_freq(&mut self, val: u64) { - self.settings.adapt_options.early_mass_matrix_switch_freq = val; + fn set_early_window_switch_freq(&mut self, val: u64) -> Result<()> { + match &mut self.inner { + Settings::Diag(nuts_settings) => { + nuts_settings.adapt_options.early_mass_matrix_switch_freq = val; + Ok(()) + } + Settings::LowRank(nuts_settings) => { + nuts_settings.adapt_options.early_mass_matrix_switch_freq = val; + Ok(()) + } + Settings::Transforming(_) => { + bail!("Option early_window_switch_freq not availbale for transformation adaptation") + } + } } + #[getter] fn initial_step(&self) -> f64 { - self.settings - .adapt_options - .dual_average_options - .initial_step + match &self.inner { + Settings::Diag(nuts_settings) => { + nuts_settings + .adapt_options + .dual_average_options + .initial_step + } + Settings::LowRank(nuts_settings) => { + nuts_settings + .adapt_options + .dual_average_options + .initial_step + } + Settings::Transforming(nuts_settings) => { + nuts_settings + .adapt_options + .dual_average_options + .initial_step + } + } } #[setter(initial_step)] fn set_initial_step(&mut self, val: f64) { - self.settings - .adapt_options - .dual_average_options - .initial_step = val + match &mut self.inner { + Settings::Diag(nuts_settings) => { + nuts_settings + .adapt_options + .dual_average_options + .initial_step = val; + } + Settings::LowRank(nuts_settings) => { + nuts_settings + .adapt_options + .dual_average_options + .initial_step = val; + } + Settings::Transforming(nuts_settings) => { + nuts_settings + .adapt_options + .dual_average_options + .initial_step = val; + } + } } #[getter] fn maxdepth(&self) -> u64 { - self.settings.maxdepth + match &self.inner { + Settings::Diag(nuts_settings) => nuts_settings.maxdepth, + Settings::LowRank(nuts_settings) => nuts_settings.maxdepth, + Settings::Transforming(nuts_settings) => nuts_settings.maxdepth, + } } #[setter(maxdepth)] fn set_maxdepth(&mut self, val: u64) { - self.settings.maxdepth = val + match &mut self.inner { + Settings::Diag(nuts_settings) => nuts_settings.maxdepth = val, + Settings::LowRank(nuts_settings) => nuts_settings.maxdepth = val, + Settings::Transforming(nuts_settings) => nuts_settings.maxdepth = val, + } } #[getter] fn store_gradient(&self) -> bool { - self.settings.store_gradient + match &self.inner { + Settings::Diag(nuts_settings) => nuts_settings.store_gradient, + Settings::LowRank(nuts_settings) => nuts_settings.store_gradient, + Settings::Transforming(nuts_settings) => nuts_settings.store_gradient, + } } #[setter(store_gradient)] fn set_store_gradient(&mut self, val: bool) { - self.settings.store_gradient = val; + match &mut self.inner { + Settings::Diag(nuts_settings) => nuts_settings.store_gradient = val, + Settings::LowRank(nuts_settings) => nuts_settings.store_gradient = val, + Settings::Transforming(nuts_settings) => nuts_settings.store_gradient = val, + } } #[getter] fn store_unconstrained(&self) -> bool { - self.settings.store_unconstrained + match &self.inner { + Settings::Diag(nuts_settings) => nuts_settings.store_unconstrained, + Settings::LowRank(nuts_settings) => nuts_settings.store_unconstrained, + Settings::Transforming(nuts_settings) => nuts_settings.store_unconstrained, + } } #[setter(store_unconstrained)] fn set_store_unconstrained(&mut self, val: bool) { - self.settings.store_unconstrained = val; + match &mut self.inner { + Settings::Diag(nuts_settings) => nuts_settings.store_unconstrained = val, + Settings::LowRank(nuts_settings) => nuts_settings.store_unconstrained = val, + Settings::Transforming(nuts_settings) => nuts_settings.store_unconstrained = val, + } } #[getter] fn store_divergences(&self) -> bool { - self.settings.store_divergences + match &self.inner { + Settings::Diag(nuts_settings) => nuts_settings.store_divergences, + Settings::LowRank(nuts_settings) => nuts_settings.store_divergences, + Settings::Transforming(nuts_settings) => nuts_settings.store_divergences, + } } #[setter(store_divergences)] fn set_store_divergences(&mut self, val: bool) { - self.settings.store_divergences = val; + match &mut self.inner { + Settings::Diag(nuts_settings) => nuts_settings.store_divergences = val, + Settings::LowRank(nuts_settings) => nuts_settings.store_divergences = val, + Settings::Transforming(nuts_settings) => nuts_settings.store_divergences = val, + } } #[getter] fn max_energy_error(&self) -> f64 { - self.settings.max_energy_error + match &self.inner { + Settings::Diag(nuts_settings) => nuts_settings.max_energy_error, + Settings::LowRank(nuts_settings) => nuts_settings.max_energy_error, + Settings::Transforming(nuts_settings) => nuts_settings.max_energy_error, + } } #[setter(max_energy_error)] fn set_max_energy_error(&mut self, val: f64) { - self.settings.max_energy_error = val + match &mut self.inner { + Settings::Diag(nuts_settings) => nuts_settings.max_energy_error = val, + Settings::LowRank(nuts_settings) => nuts_settings.max_energy_error = val, + Settings::Transforming(nuts_settings) => nuts_settings.max_energy_error = val, + } } - #[setter(target_accept)] - fn set_target_accept(&mut self, val: f64) { - self.settings - .adapt_options - .dual_average_options - .target_accept = val; + #[getter] + fn set_target_accept(&self) -> f64 { + match &self.inner { + Settings::Diag(nuts_settings) => { + nuts_settings + .adapt_options + .dual_average_options + .target_accept + } + Settings::LowRank(nuts_settings) => { + nuts_settings + .adapt_options + .dual_average_options + .target_accept + } + Settings::Transforming(nuts_settings) => { + nuts_settings + .adapt_options + .dual_average_options + .target_accept + } + } } - #[getter] - fn target_accept(&self) -> f64 { - self.settings - .adapt_options - .dual_average_options - .target_accept + #[setter(target_accept)] + fn target_accept(&mut self, val: f64) { + match &mut self.inner { + Settings::Diag(nuts_settings) => { + nuts_settings + .adapt_options + .dual_average_options + .target_accept = val + } + Settings::LowRank(nuts_settings) => { + nuts_settings + .adapt_options + .dual_average_options + .target_accept = val + } + Settings::Transforming(nuts_settings) => { + nuts_settings + .adapt_options + .dual_average_options + .target_accept = val + } + } } #[getter] - fn store_mass_matrix(&self) -> bool { - match &self.adapt { - InnerSettings::LowRank(low_rank) => low_rank.store_mass_matrix, - InnerSettings::Diag(diag) => diag.store_mass_matrix, + fn store_mass_matrix(&self) -> Result { + match &self.inner { + Settings::LowRank(settings) => { + Ok(settings.adapt_options.mass_matrix_options.store_mass_matrix) + } + Settings::Diag(settings) => { + Ok(settings.adapt_options.mass_matrix_options.store_mass_matrix) + } + Settings::Transforming(_) => { + bail!("Option store_mass_matrix not availbale for transformation adaptation") + } } } #[setter(store_mass_matrix)] - fn set_store_mass_matrix(&mut self, val: bool) { - match &mut self.adapt { - InnerSettings::LowRank(low_rank) => { - low_rank.store_mass_matrix = val; + fn set_store_mass_matrix(&mut self, val: bool) -> Result<()> { + match &mut self.inner { + Settings::LowRank(settings) => { + settings.adapt_options.mass_matrix_options.store_mass_matrix = val; + Ok(()) + } + Settings::Diag(settings) => { + settings.adapt_options.mass_matrix_options.store_mass_matrix = val; + Ok(()) } - InnerSettings::Diag(diag) => { - diag.store_mass_matrix = val; + Settings::Transforming(_) => { + bail!("Option store_mass_matrix not availbale for transformation adaptation") } } } #[getter] fn use_grad_based_mass_matrix(&self) -> Result { - match &self.adapt { - InnerSettings::LowRank(_) => { - bail!("grad based mass matrix not available for low-rank adaptation") + match &self.inner { + Settings::LowRank(_) => { + bail!("non-grad based mass matrix not available for low-rank adaptation") + } + Settings::Transforming(_) => { + bail!("non-grad based mass matrix not available for transforming adaptation") } - InnerSettings::Diag(diag) => Ok(diag.use_grad_based_estimate), + Settings::Diag(diag) => Ok(diag + .adapt_options + .mass_matrix_options + .use_grad_based_estimate), } } #[setter(use_grad_based_mass_matrix)] fn set_use_grad_based_mass_matrix(&mut self, val: bool) -> Result<()> { - match &mut self.adapt { - InnerSettings::LowRank(_) => { - bail!("grad based mass matrix not available for low-rank adaptation") + match &mut self.inner { + Settings::LowRank(_) => { + bail!("non-grad based mass matrix not available for low-rank adaptation"); + } + Settings::Transforming(_) => { + bail!("non-grad based mass matrix not available for transforming adaptation"); } - InnerSettings::Diag(diag) => { - diag.use_grad_based_estimate = val; + Settings::Diag(diag) => { + diag.adapt_options + .mass_matrix_options + .use_grad_based_estimate = val; } } Ok(()) } #[getter] - fn mass_matrix_switch_freq(&self) -> u64 { - self.settings.adapt_options.mass_matrix_switch_freq + fn mass_matrix_switch_freq(&self) -> Result { + match &self.inner { + Settings::Diag(settings) => Ok(settings.adapt_options.mass_matrix_switch_freq), + Settings::LowRank(settings) => Ok(settings.adapt_options.mass_matrix_switch_freq), + Settings::Transforming(_) => { + bail!("mass_matrix_switch_freq not available for transforming adaptation"); + } + } } #[setter(mass_matrix_switch_freq)] - fn set_mass_matrix_switch_freq(&mut self, val: u64) { - self.settings.adapt_options.mass_matrix_switch_freq = val; + fn set_mass_matrix_switch_freq(&mut self, val: u64) -> Result<()> { + match &mut self.inner { + Settings::Diag(settings) => settings.adapt_options.mass_matrix_switch_freq = val, + Settings::LowRank(settings) => settings.adapt_options.mass_matrix_switch_freq = val, + Settings::Transforming(_) => { + bail!("mass_matrix_switch_freq not available for transforming adaptation"); + } + } + Ok(()) } #[getter] fn mass_matrix_eigval_cutoff(&self) -> Result { - match &self.adapt { - InnerSettings::LowRank(inner) => Ok(inner.eigval_cutoff), - InnerSettings::Diag(_) => { - bail!("eigenvalue cutoff not available for diag mass matrix adaptation") + match &self.inner { + Settings::LowRank(inner) => Ok(inner.adapt_options.mass_matrix_options.eigval_cutoff), + Settings::Diag(_) => { + bail!("eigenvalue cutoff not available for diag mass matrix adaptation"); + } + Settings::Transforming(_) => { + bail!("eigenvalue cutoff not available for transfor adaptation"); } } } #[setter(mass_matrix_eigval_cutoff)] fn set_mass_matrix_eigval_cutoff(&mut self, val: f64) -> Result<()> { - match &mut self.adapt { - InnerSettings::LowRank(inner) => inner.eigval_cutoff = val, - InnerSettings::Diag(_) => { - bail!("eigenvalue cutoff not available for diag mass matrix adaptation") + match &mut self.inner { + Settings::LowRank(inner) => inner.adapt_options.mass_matrix_options.eigval_cutoff = val, + Settings::Diag(_) => { + bail!("eigenvalue cutoff not available for diag mass matrix adaptation"); + } + Settings::Transforming(_) => { + bail!("eigenvalue cutoff not available for transfor adaptation"); } } Ok(()) @@ -415,21 +576,56 @@ impl PyNutsSettings { #[getter] fn mass_matrix_gamma(&self) -> Result { - match &self.adapt { - InnerSettings::LowRank(inner) => Ok(inner.gamma), - InnerSettings::Diag(_) => { - bail!("gamma not available for diag mass matrix adaptation") + match &self.inner { + Settings::LowRank(inner) => Ok(inner.adapt_options.mass_matrix_options.gamma), + Settings::Diag(_) => { + bail!("gamma not available for diag mass matrix adaptation"); + } + Settings::Transforming(_) => { + bail!("gamma not available for transform adaptation"); } } } #[setter(mass_matrix_gamma)] fn set_mass_matrix_gamma(&mut self, val: f64) -> Result<()> { - match &mut self.adapt { - InnerSettings::LowRank(inner) => inner.gamma = val, - InnerSettings::Diag(_) => { - bail!("gamma not available for diag mass matrix adaptation") + match &mut self.inner { + Settings::LowRank(inner) => { + inner.adapt_options.mass_matrix_options.gamma = val; + } + Settings::Diag(_) => { + bail!("gamma not available for diag mass matrix adaptation"); + } + Settings::Transforming(_) => { + bail!("gamma not available for transform adaptation"); + } + } + Ok(()) + } + + #[getter] + fn train_on_orbit(&self) -> Result { + match &self.inner { + Settings::LowRank(_) => { + bail!("gamma not available for low rank mass matrix adaptation"); + } + Settings::Diag(_) => { + bail!("gamma not available for diag mass matrix adaptation"); } + Settings::Transforming(inner) => Ok(inner.adapt_options.use_orbit_for_training), + } + } + + #[setter(train_on_orbit)] + fn set_train_on_orbit(&mut self, val: bool) -> Result<()> { + match &mut self.inner { + Settings::LowRank(_) => { + bail!("gamma not available for low rank mass matrix adaptation"); + } + Settings::Diag(_) => { + bail!("gamma not available for diag mass matrix adaptation"); + } + Settings::Transforming(inner) => inner.adapt_options.use_orbit_for_training = val, } Ok(()) } @@ -442,13 +638,12 @@ pub(crate) enum SamplerState { } #[derive(Clone)] -#[pyclass] -pub enum ProgressType { +enum InnerProgressType { Callback { rate: Duration, n_cores: usize, template: String, - callback: Py, + callback: Arc>, }, Indicatif { rate: Duration, @@ -456,10 +651,14 @@ pub enum ProgressType { None {}, } +#[pyclass] +#[derive(Clone)] +pub struct ProgressType(InnerProgressType); + impl ProgressType { fn into_callback(self) -> Result> { - match self { - ProgressType::Callback { + match self.0 { + InnerProgressType::Callback { callback, rate, n_cores, @@ -470,11 +669,11 @@ impl ProgressType { Ok(Some(callback)) } - ProgressType::Indicatif { rate } => { + InnerProgressType::Indicatif { rate } => { let handler = IndicatifHandler::new(rate); Ok(Some(handler.into_callback()?)) } - ProgressType::None {} => Ok(None), + InnerProgressType::None {} => Ok(None), } } } @@ -484,28 +683,28 @@ impl ProgressType { #[staticmethod] fn indicatif(rate: u64) -> Self { let rate = Duration::from_millis(rate); - ProgressType::Indicatif { rate } + ProgressType(InnerProgressType::Indicatif { rate }) } #[staticmethod] fn none() -> Self { - ProgressType::None {} + ProgressType(InnerProgressType::None {}) } #[staticmethod] fn template_callback(rate: u64, template: String, n_cores: usize, callback: Py) -> Self { let rate = Duration::from_millis(rate); - ProgressType::Callback { - callback, + ProgressType(InnerProgressType::Callback { + callback: Arc::new(callback), template, n_cores, rate, - } + }) } } #[pyclass] -struct PySampler(SamplerState); +struct PySampler(Mutex); #[pymethods] impl PySampler { @@ -517,14 +716,18 @@ impl PySampler { progress_type: ProgressType, ) -> PyResult { let callback = progress_type.into_callback()?; - match settings.into_settings() { + match settings.inner { Settings::LowRank(settings) => { let sampler = Sampler::new(model, settings, cores, callback)?; - Ok(PySampler(SamplerState::Running(sampler))) + Ok(PySampler(SamplerState::Running(sampler).into())) } Settings::Diag(settings) => { let sampler = Sampler::new(model, settings, cores, callback)?; - Ok(PySampler(SamplerState::Running(sampler))) + Ok(PySampler(SamplerState::Running(sampler).into())) + } + Settings::Transforming(settings) => { + let sampler = Sampler::new(model, settings, cores, callback)?; + Ok(PySampler(SamplerState::Running(sampler).into())) } } } @@ -537,14 +740,18 @@ impl PySampler { progress_type: ProgressType, ) -> PyResult { let callback = progress_type.into_callback()?; - match settings.into_settings() { + match settings.inner { Settings::LowRank(settings) => { let sampler = Sampler::new(model, settings, cores, callback)?; - Ok(PySampler(SamplerState::Running(sampler))) + Ok(PySampler(SamplerState::Running(sampler).into())) } Settings::Diag(settings) => { let sampler = Sampler::new(model, settings, cores, callback)?; - Ok(PySampler(SamplerState::Running(sampler))) + Ok(PySampler(SamplerState::Running(sampler).into())) + } + Settings::Transforming(settings) => { + let sampler = Sampler::new(model, settings, cores, callback)?; + Ok(PySampler(SamplerState::Running(sampler).into())) } } } @@ -557,38 +764,45 @@ impl PySampler { progress_type: ProgressType, ) -> PyResult { let callback = progress_type.into_callback()?; - match settings.into_settings() { + match settings.inner { Settings::LowRank(settings) => { let sampler = Sampler::new(model, settings, cores, callback)?; - Ok(PySampler(SamplerState::Running(sampler))) + Ok(PySampler(SamplerState::Running(sampler).into())) } Settings::Diag(settings) => { let sampler = Sampler::new(model, settings, cores, callback)?; - Ok(PySampler(SamplerState::Running(sampler))) + Ok(PySampler(SamplerState::Running(sampler).into())) + } + Settings::Transforming(settings) => { + let sampler = Sampler::new(model, settings, cores, callback)?; + Ok(PySampler(SamplerState::Running(sampler).into())) } } } fn is_finished(&mut self, py: Python<'_>) -> PyResult { py.allow_threads(|| { - let state = std::mem::replace(&mut self.0, SamplerState::Empty); + let guard = &mut self.0.lock().expect("Poisond sampler state mutex"); + let slot = guard.deref_mut(); + + let state = std::mem::replace(slot, SamplerState::Empty); let SamplerState::Running(sampler) = state else { - let _ = std::mem::replace(&mut self.0, state); + let _ = std::mem::replace(slot, state); return Ok(true); }; match sampler.wait_timeout(Duration::from_millis(1)) { SamplerWaitResult::Trace(trace) => { - let _ = std::mem::replace(&mut self.0, SamplerState::Finished(Some(trace))); + let _ = std::mem::replace(slot, SamplerState::Finished(Some(trace))); Ok(true) } SamplerWaitResult::Timeout(sampler) => { - let _ = std::mem::replace(&mut self.0, SamplerState::Running(sampler)); + let _ = std::mem::replace(slot, SamplerState::Running(sampler)); Ok(false) } SamplerWaitResult::Err(err, trace) => { - let _ = std::mem::replace(&mut self.0, SamplerState::Finished(trace)); + let _ = std::mem::replace(slot, SamplerState::Finished(trace)); Err(err.into()) } } @@ -597,7 +811,12 @@ impl PySampler { fn pause(&mut self, py: Python<'_>) -> PyResult<()> { py.allow_threads(|| { - if let SamplerState::Running(ref mut control) = self.0 { + if let SamplerState::Running(ref mut control) = self + .0 + .lock() + .expect("Poised sampler state mutex") + .deref_mut() + { control.pause()? } Ok(()) @@ -606,24 +825,33 @@ impl PySampler { fn resume(&mut self, py: Python<'_>) -> PyResult<()> { py.allow_threads(|| { - if let SamplerState::Running(ref mut control) = self.0 { + if let SamplerState::Running(ref mut control) = self + .0 + .lock() + .expect("Poisond sampler state mutex") + .deref_mut() + { control.resume()? } Ok(()) }) } + #[pyo3(signature = (timeout_seconds=None))] fn wait(&mut self, py: Python<'_>, timeout_seconds: Option) -> PyResult<()> { py.allow_threads(|| { + let guard = &mut self.0.lock().expect("Poisond sampler state mutex"); + let slot = guard.deref_mut(); + let timeout = match timeout_seconds { Some(val) => Some(Duration::try_from_secs_f64(val).context("Invalid timeout")?), None => None, }; - let state = std::mem::replace(&mut self.0, SamplerState::Empty); + let state = std::mem::replace(slot, SamplerState::Empty); let SamplerState::Running(mut control) = state else { - let _ = std::mem::replace(&mut self.0, state); + let _ = std::mem::replace(slot, state); return Ok(()); }; @@ -664,32 +892,38 @@ impl PySampler { } }; - let _ = std::mem::replace(&mut self.0, final_state); + let _ = std::mem::replace(slot, final_state); retval }) } fn abort(&mut self, py: Python<'_>) -> PyResult<()> { py.allow_threads(|| { - let state = std::mem::replace(&mut self.0, SamplerState::Empty); + let guard = &mut self.0.lock().expect("Poisond sampler state mutex"); + let slot = guard.deref_mut(); + + let state = std::mem::replace(slot, SamplerState::Empty); let SamplerState::Running(control) = state else { - let _ = std::mem::replace(&mut self.0, state); + let _ = std::mem::replace(slot, state); return Ok(()); }; let (result, trace) = control.abort(); - let _ = std::mem::replace(&mut self.0, SamplerState::Finished(trace)); + let _ = std::mem::replace(slot, SamplerState::Finished(trace)); result?; Ok(()) }) } fn extract_results<'py>(&mut self, py: Python<'py>) -> PyResult> { - let state = std::mem::replace(&mut self.0, SamplerState::Empty); + let guard = &mut self.0.lock().expect("Poisond sampler state mutex"); + let slot = guard.deref_mut(); + + let state = std::mem::replace(slot, SamplerState::Empty); let SamplerState::Finished(trace) = state else { - let _ = std::mem::replace(&mut self.0, state); + let _ = std::mem::replace(slot, state); return Err(anyhow::anyhow!("Sampler is not finished"))?; }; @@ -703,7 +937,7 @@ impl PySampler { } fn is_empty(&self) -> bool { - match self.0 { + match self.0.lock().expect("Poisoned sampler state lock").deref() { SamplerState::Running(_) => false, SamplerState::Finished(_) => false, SamplerState::Empty => true, @@ -712,7 +946,8 @@ impl PySampler { fn inspect<'py>(&mut self, py: Python<'py>) -> PyResult> { let trace = py.allow_threads(|| { - let SamplerState::Running(ref mut sampler) = self.0 else { + let mut guard = self.0.lock().unwrap(); + let SamplerState::Running(ref mut sampler) = guard.deref_mut() else { return Err(anyhow::anyhow!("Sampler is not running"))?; }; @@ -723,28 +958,28 @@ impl PySampler { } fn trace_to_list(trace: Trace, py: Python<'_>) -> PyResult> { - let list = PyList::new_bound( + let list = PyList::new( py, trace .chains .into_iter() .map(|chain| { - Ok(PyTuple::new_bound( + Ok(PyTuple::new( py, [ export_array(py, chain.draws)?, export_array(py, chain.stats)?, ] .into_iter(), - )) + )?) }) .collect::>>()?, - ); + )?; Ok(list) } fn export_array(py: Python<'_>, data: Arc) -> PyResult { - let pa = py.import_bound("pyarrow")?; + let pa = py.import("pyarrow")?; let array = pa.getattr("Array")?; let data = data.into_data(); @@ -755,12 +990,252 @@ fn export_array(py: Python<'_>, data: Arc) -> PyResult { .call_method1( "_import_from_c", ( - (&data as *const _ as Py_uintptr_t).into_py(py), - (&schema as *const _ as Py_uintptr_t).into_py(py), + (&data as *const _ as Py_uintptr_t).into_pyobject(py)?, + (&schema as *const _ as Py_uintptr_t).into_pyobject(py)?, ), ) .context("Could not import arrow trace in python")?; - Ok(data.into_py(py)) + Ok(data.unbind()) +} + +#[pyclass] +#[derive(Debug, Clone)] +pub struct PyTransformAdapt(Arc>); + +#[pymethods] +impl PyTransformAdapt { + #[new] + pub fn new(adapter: Py) -> Self { + Self(Arc::new(adapter)) + } +} + +impl PyTransformAdapt { + pub fn inv_transform_normalize( + &mut self, + params: &Py, + untransformed_position: &[f64], + untransformed_gradient: &[f64], + transformed_position: &mut [f64], + transformed_gradient: &mut [f64], + ) -> Result { + Python::with_gil(|py| { + let untransformed_position = PyArray1::from_slice(py, untransformed_position); + let untransformed_gradient = PyArray1::from_slice(py, untransformed_gradient); + + let output = params + .getattr(py, intern!(py, "inv_transform")) + .context("Could not access attribute inv_transform")? + .call1(py, (untransformed_position, untransformed_gradient)) + .context("Failed to call adapter.inv_transform")?; + let (logdet, transformed_position_out, transformed_gradient_out): ( + f64, + PyReadonlyArray1, + PyReadonlyArray1, + ) = output + .extract(py) + .context("Execpected results from adapter.inv_transform")?; + + if !transformed_position_out + .as_slice()? + .iter() + .all(|&x| x.is_finite()) + { + bail!("Transformed position is not finite"); + } + if !transformed_gradient_out + .as_slice()? + .iter() + .all(|&x| x.is_finite()) + { + bail!("Transformed position is not finite"); + } + + transformed_position.copy_from_slice( + transformed_position_out + .as_slice() + .context("Could not copy transformed_position")?, + ); + + transformed_gradient.copy_from_slice( + transformed_gradient_out + .as_slice() + .context("Could not copy transformed_gradient")?, + ); + Ok(logdet) + }) + } + + pub fn init_from_transformed_position( + &mut self, + params: &Py, + untransformed_position: &mut [f64], + untransformed_gradient: &mut [f64], + transformed_position: &[f64], + transformed_gradient: &mut [f64], + ) -> Result<(f64, f64)> { + Python::with_gil(|py| { + let transformed_position = PyArray1::from_slice(py, transformed_position); + + let output = params + .getattr(py, intern!(py, "init_from_transformed_position"))? + .call1(py, (transformed_position,))?; + let ( + logp, + logdet, + untransformed_position_out, + untransformed_gradient_out, + transformed_gradient_out, + ): ( + f64, + f64, + PyReadonlyArray1, + PyReadonlyArray1, + PyReadonlyArray1, + ) = output.extract(py)?; + + untransformed_position.copy_from_slice(untransformed_position_out.as_slice()?); + untransformed_gradient.copy_from_slice(untransformed_gradient_out.as_slice()?); + transformed_gradient.copy_from_slice(transformed_gradient_out.as_slice()?); + Ok((logp, logdet)) + }) + } + + pub fn init_from_transformed_position_part1( + &mut self, + params: &Py, + untransformed_position: &mut [f64], + transformed_position: &[f64], + ) -> Result> { + Python::with_gil(|py| { + let transformed_position = PyArray1::from_slice(py, transformed_position); + + let output = params + .getattr(py, intern!(py, "init_from_transformed_position_part1"))? + .call1(py, (transformed_position,))?; + let (untransformed_position_out, part1): (PyReadonlyArray1, Py) = + output.extract(py)?; + + untransformed_position.copy_from_slice(untransformed_position_out.as_slice()?); + Ok(part1) + }) + } + + pub fn init_from_transformed_position_part2( + &mut self, + params: &Py, + part1: Py, + untransformed_gradient: &[f64], + transformed_gradient: &mut [f64], + ) -> Result { + Python::with_gil(|py| { + let untransformed_gradient = PyArray1::from_slice(py, untransformed_gradient); + + let output = params + .getattr(py, intern!(py, "init_from_transformed_position_part2"))? + .call1(py, (part1, untransformed_gradient))?; + let (logdet, transformed_gradient_out): (f64, PyReadonlyArray1) = + output.extract(py)?; + + transformed_gradient.copy_from_slice(transformed_gradient_out.as_slice()?); + Ok(logdet) + }) + } + + pub fn init_from_untransformed_position( + &mut self, + params: &Py, + untransformed_position: &[f64], + untransformed_gradient: &mut [f64], + transformed_position: &mut [f64], + transformed_gradient: &mut [f64], + ) -> Result<(f64, f64)> { + Python::with_gil(|py| { + let untransformed_position = PyArray1::from_slice(py, untransformed_position); + + let output = params + .getattr(py, intern!(py, "init_from_untransformed_position")) + .context("No attribute init_from_untransformed_position")? + .call1(py, (untransformed_position,)) + .context("Failed adapter.init_from_untransformed_position")?; + let ( + logp, + logdet, + untransformed_gradient_out, + transformed_position_out, + transformed_gradient_out, + ): ( + f64, + f64, + PyReadonlyArray1, + PyReadonlyArray1, + PyReadonlyArray1, + ) = output + .extract(py) + .context("Unexpected return value of init_from_untransformed_position")?; + + untransformed_gradient.copy_from_slice(untransformed_gradient_out.as_slice()?); + transformed_position.copy_from_slice(transformed_position_out.as_slice()?); + transformed_gradient.copy_from_slice(transformed_gradient_out.as_slice()?); + Ok((logp, logdet)) + }) + } + + pub fn update_transformation<'a, R: rand::Rng + ?Sized>( + &'a mut self, + rng: &mut R, + untransformed_positions: impl ExactSizeIterator, + untransformed_gradients: impl ExactSizeIterator, + untransformed_logp: impl ExactSizeIterator, + params: &'a mut Py, + ) -> Result<()> { + Python::with_gil(|py| { + let positions = PyList::new( + py, + untransformed_positions.map(|pos| PyArray1::from_slice(py, pos)), + )?; + let gradients = PyList::new( + py, + untransformed_gradients.map(|grad| PyArray1::from_slice(py, grad)), + )?; + + let logps = PyArray1::from_iter(py, untransformed_logp.copied()); + let seed = rng.next_u64(); + + params + .getattr(py, intern!(py, "update"))? + .call1(py, (seed, positions, gradients, logps))?; + Ok(()) + }) + } + + pub fn new_transformation( + &mut self, + rng: &mut R, + untransformed_position: &[f64], + untransformed_gradient: &[f64], + chain: u64, + ) -> Result> { + Python::with_gil(|py| { + let position = PyArray1::from_slice(py, untransformed_position); + let gradient = PyArray1::from_slice(py, untransformed_gradient); + + let seed = rng.next_u64(); + + let transformer = self.0.call1(py, (seed, position, gradient, chain))?; + + Ok(transformer) + }) + } + + pub fn transformation_id(&self, params: &Py) -> Result { + Python::with_gil(|py| { + let id: i64 = params + .getattr(py, intern!(py, "transformation_id"))? + .extract(py)?; + Ok(id) + }) + } } /// A Python module implemented in Rust. diff --git a/tests/test_pymc.py b/tests/test_pymc.py index a586e24..9b8939f 100644 --- a/tests/test_pymc.py +++ b/tests/test_pymc.py @@ -243,6 +243,21 @@ def test_pymc_var_names(backend, gradient_backend): assert not hasattr(trace.posterior, "c") +def test_normalizing_flow(): + with pm.Model() as model: + a = pm.Uniform("a/b", shape=2) + with pm.Model("foo"): + c = pm.Data("c", np.array([2.0, 3.0])) + pm.Deterministic("b", c * a) + + compiled = nutpie.compile_pymc_model( + model, backend="jax", gradient_backend="jax" + ).with_transform_adapt() + trace = nutpie.sample(compiled, chains=1, transform_adapt=True) + assert trace.posterior["a/b"].shape[-1] == 2 + assert trace.posterior["foo::b"].shape[-1] == 2 + + @pytest.mark.parametrize( ("backend", "gradient_backend"), [