diff --git a/Cargo.lock b/Cargo.lock index 6dd8489..e702c9c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -48,21 +48,21 @@ checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" [[package]] name = "anstyle" -version = "1.0.7" +version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "038dfcf04a5feb68e9c60b21c9625a54c2c0616e79b72b0fd87075a056ae1d1b" +checksum = "55cc3b69f167a1ef2e161439aa98aed94e6028e5f9a59be9a6ffb47aef1651f9" [[package]] name = "anyhow" -version = "1.0.86" +version = "1.0.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3d1d046238990b9cf5bcde22a3fb3584ee5cf65fb2765f454ed428c7a0063da" +checksum = "4c95c10ba0b00a02636238b814946408b1322d5ac4760326e6fb8ec956d85775" [[package]] name = "arrow" -version = "52.1.0" +version = "52.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6127ea5e585a12ec9f742232442828ebaf264dfa5eefdd71282376c599562b77" +checksum = "05048a8932648b63f21c37d88b552ccc8a65afb6dfe9fc9f30ce79174c2e7a85" dependencies = [ "arrow-arith", "arrow-array", @@ -78,9 +78,9 @@ dependencies = [ [[package]] name = "arrow-arith" -version = "52.1.0" +version = "52.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7add7f39210b7d726e2a8efc0083e7bf06e8f2d15bdb4896b564dce4410fbf5d" +checksum = "1d8a57966e43bfe9a3277984a14c24ec617ad874e4c0e1d2a1b083a39cfbf22c" dependencies = [ "arrow-array", "arrow-buffer", @@ -93,9 +93,9 @@ dependencies = [ [[package]] name = "arrow-array" -version = "52.1.0" +version = "52.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81c16ec702d3898c2f5cfdc148443c6cd7dbe5bac28399859eb0a3d38f072827" +checksum = "16f4a9468c882dc66862cef4e1fd8423d47e67972377d85d80e022786427768c" dependencies = [ "ahash", "arrow-buffer", @@ -109,9 +109,9 @@ dependencies = [ [[package]] name = "arrow-buffer" -version = "52.1.0" +version = "52.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cae6970bab043c4fbc10aee1660ceb5b306d0c42c8cc5f6ae564efcd9759b663" +checksum = "c975484888fc95ec4a632cdc98be39c085b1bb518531b0c80c5d462063e5daa1" dependencies = [ "bytes", "half", @@ -120,9 +120,9 @@ dependencies = [ [[package]] name = "arrow-cast" -version = "52.1.0" +version = "52.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1c7ef44f26ef4f8edc392a048324ed5d757ad09135eff6d5509e6450d39e0398" +checksum = "da26719e76b81d8bc3faad1d4dbdc1bcc10d14704e63dc17fc9f3e7e1e567c8e" dependencies = [ "arrow-array", "arrow-buffer", @@ -140,9 +140,9 @@ dependencies = [ [[package]] name = "arrow-data" -version = "52.1.0" +version = "52.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a769666ffac256dd301006faca1ca553d0ae7cffcf4cd07095f73f95eb226514" +checksum = "dd9d6f18c65ef7a2573ab498c374d8ae364b4a4edf67105357491c031f716ca5" dependencies = [ "arrow-buffer", "arrow-schema", @@ -152,9 +152,9 @@ dependencies = [ [[package]] name = "arrow-ord" -version = "52.1.0" +version = "52.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8008370e624e8e3c68174faaf793540287106cfda8ad1da862fdc53d8e096b4" +checksum = "42745f86b1ab99ef96d1c0bcf49180848a64fe2c7a7a0d945bc64fa2b21ba9bc" dependencies = [ "arrow-array", "arrow-buffer", @@ -167,9 +167,9 @@ dependencies = [ [[package]] name = "arrow-row" -version = "52.1.0" +version = "52.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca5e3a6b7fda8d9fe03f3b18a2d946354ea7f3c8e4076dbdb502ad50d9d44824" +checksum = "4cd09a518c602a55bd406bcc291a967b284cfa7a63edfbf8b897ea4748aad23c" dependencies = [ "ahash", "arrow-array", @@ -177,23 +177,22 @@ dependencies = [ "arrow-data", "arrow-schema", "half", - "hashbrown", ] [[package]] name = "arrow-schema" -version = "52.1.0" +version = "52.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dab1c12b40e29d9f3b699e0203c2a73ba558444c05e388a4377208f8f9c97eee" +checksum = "9e972cd1ff4a4ccd22f86d3e53e835c2ed92e0eea6a3e8eadb72b4f1ac802cf8" dependencies = [ "bitflags 2.6.0", ] [[package]] name = "arrow-select" -version = "52.1.0" +version = "52.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e80159088ffe8c48965cb9b1a7c968b2729f29f37363df7eca177fc3281fe7c3" +checksum = "600bae05d43483d216fb3494f8c32fdbefd8aa4e1de237e790dbb3d9f44690a3" dependencies = [ "ahash", "arrow-array", @@ -205,9 +204,9 @@ dependencies = [ [[package]] name = "arrow-string" -version = "52.1.0" +version = "52.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fd04a6ea7de183648edbcb7a6dd925bbd04c210895f6384c780e27a9b54afcd" +checksum = "f0dc1985b67cb45f6606a248ac2b4a288849f196bab8c657ea5589f47cdd55e6" dependencies = [ "arrow-array", "arrow-buffer", @@ -231,9 +230,9 @@ dependencies = [ [[package]] name = "autocfg" -version = "1.3.0" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" +checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" [[package]] name = "base64" @@ -243,9 +242,9 @@ checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" [[package]] name = "bindgen" -version = "0.69.4" +version = "0.69.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a00dc851838a2120612785d195287475a3ac45514741da670b735818822129a0" +checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088" dependencies = [ "bitflags 2.6.0", "cexpr", @@ -260,7 +259,7 @@ dependencies = [ "regex", "rustc-hash", "shlex", - "syn 2.0.72", + "syn 2.0.87", "which", ] @@ -306,22 +305,22 @@ checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" [[package]] name = "bytemuck" -version = "1.16.1" +version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b236fc92302c97ed75b38da1f4917b5cdda4984745740f153a5d3059e48d725e" +checksum = "8334215b81e418a0a7bdb8ef0849474f40bb10c8b71f1c4ed315cff49f32494d" dependencies = [ "bytemuck_derive", ] [[package]] name = "bytemuck_derive" -version = "1.7.0" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ee891b04274a59bd38b412188e24b849617b2e45a0fd8d057deb63e7403761b" +checksum = "bcfcc3cd946cb52f0bbfdbbcfa2f4e24f75ebb6c0e1002f7c25904fada18b9ec" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.87", ] [[package]] @@ -332,9 +331,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.6.1" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a12916984aab3fa6e39d655a33e09c0071eb36d6ab3aea5c2d78551f1df6d952" +checksum = "9ac0150caa2ae65ca5bd83f25c7de183dea78d4d366469f148435e2acfbad0da" [[package]] name = "cast" @@ -344,9 +343,12 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.1.6" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2aba8f4e9906c7ce3c73463f62a7f0c65183ada1a2d47e397cc8810827f9694f" +checksum = "fd9de9f2205d5ef3fd67e685b0df337994ddd4495e2a28d185500d0e1edfea47" +dependencies = [ + "shlex", +] [[package]] name = "cexpr" @@ -415,18 +417,18 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.10" +version = "4.5.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f6b81fb3c84f5563d509c59b5a48d935f689e993afa90fe39047f05adef9142" +checksum = "fb3b4b9e5a7c7514dfa52869339ee98b3156b0bfb4e8a77c4ff4babb64b1604f" dependencies = [ "clap_builder", ] [[package]] name = "clap_builder" -version = "4.5.10" +version = "4.5.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ca6706fd5224857d9ac5eb9355f6683563cc0541c7cd9d014043b57cbec78ac" +checksum = "b17a95aa67cc7b5ebd32aa5370189aa0d79069ef1c64ce893bd30fb24bff20ec" dependencies = [ "anstyle", "clap_lex", @@ -434,9 +436,9 @@ dependencies = [ [[package]] name = "clap_lex" -version = "0.7.1" +version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b82cf0babdbd58558212896d1a4272303a57bdb245c2bf1147185fb45640e70" +checksum = "afb84c814227b90d6895e01398aee0d8033c00e7466aca416fb6a8e0eb19d8a7" [[package]] name = "coe-rs" @@ -453,8 +455,8 @@ dependencies = [ "encode_unicode", "lazy_static", "libc", - "unicode-width", - "windows-sys", + "unicode-width 0.1.14", + "windows-sys 0.52.0", ] [[package]] @@ -479,15 +481,15 @@ dependencies = [ [[package]] name = "core-foundation-sys" -version = "0.8.6" +version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" +checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" [[package]] name = "cpufeatures" -version = "0.2.12" +version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53fe5e26ff1b7aef8bca9c6080520cfb8d9333c7568e1829cef191a9723e5504" +checksum = "0ca741a962e1b0bff6d724a1a0958b686406e853bb14061f218562e1896f95e6" dependencies = [ "libc", ] @@ -609,14 +611,14 @@ checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" [[package]] name = "enum-as-inner" -version = "0.6.0" +version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ffccbb6966c05b32ef8fbac435df276c4ae4d3dc55a8cd0eb9745e6c12f546a" +checksum = "a1e6a265c649f3f5979b601d26f1d05ada116434c87741c9493cb56218f76cbc" dependencies = [ - "heck", + "heck 0.5.0", "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.87", ] [[package]] @@ -625,7 +627,16 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c35da53b5a021d2484a7cc49b2ac7f2d840f8236a286f84202369bd338d761ea" dependencies = [ - "equator-macro", + "equator-macro 0.2.1", +] + +[[package]] +name = "equator" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5099e7b6f0b7431c7a1c49f75929e2777693da192784f167066977a2965767af" +dependencies = [ + "equator-macro 0.4.1", ] [[package]] @@ -636,7 +647,18 @@ checksum = "3bf679796c0322556351f287a51b49e48f7c4986e727b5dd78c972d30e2e16cc" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.87", +] + +[[package]] +name = "equator-macro" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5322a90066ddae2b705096eb9e10c465c0498ae93bf9bdd6437415327c88e3bb" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", ] [[package]] @@ -646,20 +668,20 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba" dependencies = [ "libc", - "windows-sys", + "windows-sys 0.52.0", ] [[package]] name = "faer" -version = "0.19.1" +version = "0.19.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41543c4de4bfb32efdffdd75cbcca5ef41b800e8a811ea4a41fb9393c6ef3bc0" +checksum = "64bc4855cb2792ae3520e8af22051a47a6d6dc8300ebc0ddf51ad73f65bd0dc9" dependencies = [ "bytemuck", "coe-rs", "dbgf", "dyn-stack", - "equator", + "equator 0.4.1", "faer-entity", "gemm", "libm", @@ -679,9 +701,9 @@ dependencies = [ [[package]] name = "faer-entity" -version = "0.19.0" +version = "0.19.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab968a02be27be95de0f1ad0af901b865fa0866b6a9b553a6cc9cf7f19c2ce71" +checksum = "c9c752ab2bff6f0b9597c6a1adc0112f7fd41fb343bc5a009a6274ae9d32fd03" dependencies = [ "bytemuck", "coe-rs", @@ -861,11 +883,17 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + [[package]] name = "hermit-abi" -version = "0.3.9" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" +checksum = "fbf6a919d6cf397374f7dfeeea91d974c7c0a7221d0d0f4f20d859d329e53fcc" [[package]] name = "home" @@ -873,14 +901,14 @@ version = "0.5.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3d1354bf6b7235cb4a0576c2619fd4ed18183f689b12b006a0ee7329eeff9a5" dependencies = [ - "windows-sys", + "windows-sys 0.52.0", ] [[package]] name = "iana-time-zone" -version = "0.1.60" +version = "0.1.61" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7ffbb5a1b541ea2561f8c41c087286cc091e21e556a4f09a8f6cbf17b69b141" +checksum = "235e081f3925a06703c2d0117ea8b91f042756fd6e7a6e5d901e8ca1a996b220" dependencies = [ "android_system_properties", "core-foundation-sys", @@ -901,15 +929,15 @@ dependencies = [ [[package]] name = "indicatif" -version = "0.17.8" +version = "0.17.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "763a5a8f45087d6bcea4222e7b72c291a054edf80e4ef6efd2a4979878c7bea3" +checksum = "cbf675b85ed934d3c67b5c5469701eec7db22689d0a2139d856e0925fa28b281" dependencies = [ "console", - "instant", "number_prefix", "portable-atomic", - "unicode-width", + "unicode-width 0.2.0", + "web-time", ] [[package]] @@ -918,24 +946,15 @@ version = "2.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5" -[[package]] -name = "instant" -version = "0.1.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0242819d153cba4b4b05a5a8f2a7e9bbf97b6055b2a002b395c96b5ff3c0222" -dependencies = [ - "cfg-if", -] - [[package]] name = "is-terminal" -version = "0.4.12" +version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f23ff5ef2b80d608d61efee834934d862cd92461afc0560dedf493e4c033738b" +checksum = "261f68e344040fbd0edea105bef17c66edf46f984ddb1115b775ce31be948f4b" dependencies = [ "hermit-abi", "libc", - "windows-sys", + "windows-sys 0.52.0", ] [[package]] @@ -973,9 +992,9 @@ checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" [[package]] name = "js-sys" -version = "0.3.69" +version = "0.3.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "29c15563dc2726973df627357ce0c9ddddbea194836909d655df6a75d2cf296d" +checksum = "6a88f1bda2bd75b0452a14784937d796722fdebfe50df998aeb3f0b7603019a9" dependencies = [ "wasm-bindgen", ] @@ -1058,9 +1077,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.155" +version = "0.2.162" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c" +checksum = "18d287de67fe55fd7e1581fe933d965a5a9477b38e949cfa9f8574ef01506398" [[package]] name = "libloading" @@ -1074,9 +1093,9 @@ dependencies = [ [[package]] name = "libm" -version = "0.2.8" +version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" +checksum = "8355be11b20d696c8f18f6cc018c4e372165b1fa8126cef092399c9951984ffa" [[package]] name = "linux-raw-sys" @@ -1118,9 +1137,9 @@ checksum = "b0bdabb30db18805d5290b3da7ceaccbddba795620b86c02145d688e04900a73" [[package]] name = "matrixmultiply" -version = "0.3.8" +version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7574c1cf36da4798ab73da5b215bbf444f50718207754cb522201d78d1cd0ff2" +checksum = "9380b911e3e96d10c1f415da0876389aaf1b56759054eeb0de7df940c456ba1a" dependencies = [ "autocfg", "rawpointer", @@ -1175,7 +1194,7 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f563548d38f390ef9893e4883ec38c1fb312f569e98d76bededdd91a3b41a043" dependencies = [ - "equator", + "equator 0.2.2", "nano-gemm-c32", "nano-gemm-c64", "nano-gemm-codegen", @@ -1372,7 +1391,7 @@ dependencies = [ [[package]] name = "nutpie" -version = "0.13.1" +version = "0.13.2" dependencies = [ "anyhow", "arrow", @@ -1414,9 +1433,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.19.0" +version = "1.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" +checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" [[package]] name = "oorandom" @@ -1473,9 +1492,9 @@ dependencies = [ [[package]] name = "pest" -version = "2.7.11" +version = "2.7.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd53dff83f26735fdc1ca837098ccf133605d794cdae66acfc2bfac3ec809d95" +checksum = "879952a81a83930934cbf1786752d6dedc3b1f29e8f8fb2ad1d0a36f377cf442" dependencies = [ "memchr", "thiserror", @@ -1484,9 +1503,9 @@ dependencies = [ [[package]] name = "pest_derive" -version = "2.7.11" +version = "2.7.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a548d2beca6773b1c244554d36fcf8548a8a58e74156968211567250e48e49a" +checksum = "d214365f632b123a47fd913301e14c946c61d1c183ee245fa76eb752e59a02dd" dependencies = [ "pest", "pest_generator", @@ -1494,22 +1513,22 @@ dependencies = [ [[package]] name = "pest_generator" -version = "2.7.11" +version = "2.7.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c93a82e8d145725dcbaf44e5ea887c8a869efdcc28706df2d08c69e17077183" +checksum = "eb55586734301717aea2ac313f50b2eb8f60d2fc3dc01d190eefa2e625f60c4e" dependencies = [ "pest", "pest_meta", "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.87", ] [[package]] name = "pest_meta" -version = "2.7.11" +version = "2.7.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a941429fea7e08bedec25e4f6785b6ffaacc6b755da98df5ef3e7dcf4a124c4f" +checksum = "b75da2a70cf4d9cb76833c990ac9cd3923c9a8905a8929789ce347c84564d03d" dependencies = [ "once_cell", "pest", @@ -1518,9 +1537,9 @@ dependencies = [ [[package]] name = "plotters" -version = "0.3.6" +version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a15b6eccb8484002195a3e44fe65a4ce8e93a625797a063735536fd59cb01cf3" +checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747" dependencies = [ "num-traits", "plotters-backend", @@ -1531,55 +1550,58 @@ dependencies = [ [[package]] name = "plotters-backend" -version = "0.3.6" +version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "414cec62c6634ae900ea1c56128dfe87cf63e7caece0852ec76aba307cebadb7" +checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a" [[package]] name = "plotters-svg" -version = "0.3.6" +version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81b30686a7d9c3e010b84284bdd26a29f2138574f52f5eb6f794fc0ad924e705" +checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670" dependencies = [ "plotters-backend", ] [[package]] name = "portable-atomic" -version = "1.7.0" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da544ee218f0d287a911e9c99a39a8c9bc8fcad3cb8db5959940044ecfc67265" +checksum = "cc9c68a3f6da06753e9335d63e27f6b9754dd1920d941135b7ea8224f141adb2" [[package]] name = "ppv-lite86" -version = "0.2.17" +version = "0.2.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" +checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" +dependencies = [ + "zerocopy", +] [[package]] name = "prettyplease" -version = "0.2.20" +version = "0.2.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f12335488a2f3b0a83b14edad48dca9879ce89b2edd10e80237e4e852dd645e" +checksum = "64d1ec885c64d0457d564db4ec299b2dae3f9c02808b8ad9c3a089c591b18033" dependencies = [ "proc-macro2", - "syn 2.0.72", + "syn 2.0.87", ] [[package]] name = "proc-macro2" -version = "1.0.86" +version = "1.0.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77" +checksum = "f139b0662de085916d1fb67d2b4169d1addddda1919e696f3252b740b629986e" dependencies = [ "unicode-ident", ] [[package]] name = "pulp" -version = "0.18.21" +version = "0.18.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ec8d02258294f59e4e223b41ad7e81c874aa6b15bc4ced9ba3965826da0eed5" +checksum = "a0a01a0dc67cf4558d279f0c25b0962bd08fc6dec0137699eae304103e882fe6" dependencies = [ "bytemuck", "libm", @@ -1648,7 +1670,7 @@ dependencies = [ "proc-macro2", "pyo3-macros-backend", "quote", - "syn 2.0.72", + "syn 2.0.87", ] [[package]] @@ -1657,18 +1679,18 @@ version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "08260721f32db5e1a5beae69a55553f56b99bd0e1c3e6e0a5e8851a9d0f5a85c" dependencies = [ - "heck", + "heck 0.4.1", "proc-macro2", "pyo3-build-config", "quote", - "syn 2.0.72", + "syn 2.0.87", ] [[package]] name = "quote" -version = "1.0.36" +version = "1.0.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7" +checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" dependencies = [ "proc-macro2", ] @@ -1756,18 +1778,18 @@ checksum = "03251193000f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430" [[package]] name = "redox_syscall" -version = "0.5.3" +version = "0.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a908a6e00f1fdd0dfd9c0eb08ce85126f6d8bbda50017e74bc4a4b7d4a926a4" +checksum = "9b6dfecf2c74bce2466cabf93f6664d6998a69eb21e39f4207930065b27b771f" dependencies = [ "bitflags 2.6.0", ] [[package]] name = "regex" -version = "1.10.5" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b91213439dad192326a0d7c6ee3955910425f441d7038e0d6933b0aec5c4517f" +checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" dependencies = [ "aho-corasick", "memchr", @@ -1777,9 +1799,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.7" +version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df" +checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" dependencies = [ "aho-corasick", "memchr", @@ -1788,9 +1810,9 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.8.4" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b" +checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" [[package]] name = "rustc-hash" @@ -1800,15 +1822,15 @@ checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" [[package]] name = "rustix" -version = "0.38.34" +version = "0.38.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70dc5ec042f7a43c4a73241207cecc9873a06d45debb38b329f8541d85c2730f" +checksum = "99e4ea3e1cdc4b559b8e5650f9c8e5998e3e5c1343b4eaf034565f32318d63c0" dependencies = [ "bitflags 2.6.0", "errno", "libc", "linux-raw-sys", - "windows-sys", + "windows-sys 0.52.0", ] [[package]] @@ -1840,31 +1862,32 @@ checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" [[package]] name = "serde" -version = "1.0.204" +version = "1.0.215" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc76f558e0cbb2a839d37354c575f1dc3fdc6546b5be373ba43d95f231bf7c12" +checksum = "6513c1ad0b11a9376da888e3e0baa0077f1aed55c17f50e7b2397136129fb88f" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.204" +version = "1.0.215" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0cd7e117be63d3c3678776753929474f3b04a43a080c744d6b0ae2a8c28e222" +checksum = "ad1e866f866923f252f05c889987993144fb74e722403468a4ebd70c3cd756c0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.87", ] [[package]] name = "serde_json" -version = "1.0.120" +version = "1.0.132" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e0d21c9a8cae1235ad58a00c11cb40d4b1e5c784f1ef2c537876ed6ffd8b7c5" +checksum = "d726bfaff4b320266d395898905d0eba0345aae23b54aee3a737e260fd46db03" dependencies = [ "itoa", + "memchr", "ryu", "serde", ] @@ -1911,9 +1934,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.72" +version = "2.0.87" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc4b9b9bf2add8093d3f2c0204471e951b2285580335de42f9d2534f3ae7a8af" +checksum = "25aa4ce346d03a6dcd68dd8b4010bcb74e54e62c90c573f394c46eae99aba32d" dependencies = [ "proc-macro2", "quote", @@ -1942,28 +1965,28 @@ checksum = "c1bbb9f3c5c463a01705937a24fdabc5047929ac764b2d5b9cf681c1f5041ed5" [[package]] name = "target-lexicon" -version = "0.12.15" +version = "0.12.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4873307b7c257eddcb50c9bedf158eb669578359fb28428bef438fec8e6ba7c2" +checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" [[package]] name = "thiserror" -version = "1.0.63" +version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.63" +version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.87", ] [[package]] @@ -1999,21 +2022,27 @@ checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" [[package]] name = "ucd-trie" -version = "0.1.6" +version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed646292ffc8188ef8ea4d1e0e0150fb15a5c2e12ad9b8fc191ae7a8a7f3c4b9" +checksum = "2896d95c02a80c6d6a5d6e953d479f5ddf2dfdb6a244441010e373ac0fb88971" [[package]] name = "unicode-ident" -version = "1.0.12" +version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" +checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" [[package]] name = "unicode-width" -version = "0.1.13" +version = "0.1.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0336d538f7abc86d282a4189614dfaa90810dfc2c6f6427eaf88e16311dd225d" +checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af" + +[[package]] +name = "unicode-width" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fc81956842c57dac11422a97c3b8195a1ff727f06e85c84ed2e8aa277c9a0fd" [[package]] name = "unindent" @@ -2029,9 +2058,9 @@ checksum = "9fe29601d1624f104fa9a35ea71a5f523dd8bd1cfc8c31f8124ad2b829f013c0" [[package]] name = "version_check" -version = "0.9.4" +version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" [[package]] name = "walkdir" @@ -2051,34 +2080,35 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.92" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4be2531df63900aeb2bca0daaaddec08491ee64ceecbee5076636a3b026795a8" +checksum = "128d1e363af62632b8eb57219c8fd7877144af57558fb2ef0368d0087bddeb2e" dependencies = [ "cfg-if", + "once_cell", "wasm-bindgen-macro", ] [[package]] name = "wasm-bindgen-backend" -version = "0.2.92" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "614d787b966d3989fa7bb98a654e369c762374fd3213d212cfc0251257e747da" +checksum = "cb6dd4d3ca0ddffd1dd1c9c04f94b868c37ff5fac97c30b97cff2d74fce3a358" dependencies = [ "bumpalo", "log", "once_cell", "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.87", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-macro" -version = "0.2.92" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1f8823de937b71b9460c0c34e25f3da88250760bec0ebac694b49997550d726" +checksum = "e79384be7f8f5a9dd5d7167216f022090cf1f9ec128e6e6a482a2cb5c5422c56" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -2086,28 +2116,38 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.92" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" +checksum = "26c6ab57572f7a24a4985830b120de1594465e5d500f24afe89e16b4e833ef68" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.87", "wasm-bindgen-backend", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.92" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96" +checksum = "65fc09f10666a9f147042251e0dda9c18f166ff7de300607007e96bdebc1068d" [[package]] name = "web-sys" -version = "0.3.69" +version = "0.3.72" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6488b90108c040df0fe62fa815cbdee25124641df01814dd7282749234c6112" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "web-time" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77afa9a11836342370f4817622a2f0f418b134426d91a82dfb48f532d2ec13ef" +checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" dependencies = [ "js-sys", "wasm-bindgen", @@ -2127,11 +2167,11 @@ dependencies = [ [[package]] name = "winapi-util" -version = "0.1.8" +version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4d4cc384e1e73b93bafa6fb4f1df8c41695c8a91cf9c4c64358067d15a7b6c6b" +checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" dependencies = [ - "windows-sys", + "windows-sys 0.59.0", ] [[package]] @@ -2152,6 +2192,15 @@ dependencies = [ "windows-targets", ] +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets", +] + [[package]] name = "windows-targets" version = "0.52.6" @@ -2222,6 +2271,7 @@ version = "0.7.35" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" dependencies = [ + "byteorder", "zerocopy-derive", ] @@ -2233,5 +2283,5 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.87", ] diff --git a/python/nutpie/compile_pymc.py b/python/nutpie/compile_pymc.py index 0679ed2..9a2faa3 100644 --- a/python/nutpie/compile_pymc.py +++ b/python/nutpie/compile_pymc.py @@ -2,16 +2,18 @@ import itertools import warnings from dataclasses import dataclass +from functools import wraps from importlib.util import find_spec from math import prod -from typing import TYPE_CHECKING, Any, Literal, Optional +from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union import numpy as np import pandas as pd from numpy.typing import NDArray +from pymc.initial_point import make_initial_point_fn from nutpie import _lib -from nutpie.compiled_pyfunc import from_pyfunc +from nutpie.compiled_pyfunc import SeedType, from_pyfunc from nutpie.sample import CompiledModel try: @@ -25,6 +27,60 @@ def intrinsic(f): if TYPE_CHECKING: import numba.core.ccallback import pymc as pm + from pytensor.tensor import TensorVariable, Variable + + +def rv_dict_to_flat_array_wrapper( + fn: Callable[[SeedType], dict[str, np.ndarray]], + names: list[str], + shapes: list[tuple[int]], +) -> Callable[[SeedType], np.ndarray]: + """ + Wraps a function that returns a dictionary of string:array key:value pairs + and returns a single flat float64 array. Also checks that the shapes of + the arrays match the expected shapes. + + Parameters + ---------- + fn: Callable + Function that takes a seed and return a dictionary of variable names + to initial values. This function should be the output of + pymc.initial_point.make_initial_point_fn + names: list of str + List of random variable names in the model + shapes: list of tuple of int + Shape of random variables in the model + + Returns + ------- + seeded_array_fn: Callable + Function that takes a seed and returns a flat, contiguous float64 + array of initial values. The ordering of the random variables inside + the array is controlled by the ``names`` parameter. + """ + + @wraps(fn) + def seeded_array_fn(seed: SeedType = None): + initial_value_dict = fn(seed) + total_size = sum(np.prod(shape).astype(int) for shape in shapes) + flat_array = np.empty(total_size, dtype="float64", order="C") + cursor = 0 + + for name, shape in zip(names, shapes, strict=True): + initial_value = initial_value_dict[name] + n = int(np.prod(initial_value.shape)) + if initial_value.shape != shape: + raise ValueError( + f"Size of initial value for {name} is {initial_value.shape}, " + f"expected {shape}" + ) + + flat_array[cursor : cursor + n] = initial_value.ravel().astype("float64") + cursor += n + + return flat_array + + return seeded_array_fn @intrinsic @@ -44,6 +100,7 @@ def codegen(cgctx, builder, sig, args): class CompiledPyMCModel(CompiledModel): compiled_logp_func: "numba.core.ccallback.CFunc" compiled_expand_func: "numba.core.ccallback.CFunc" + initial_point_func: Callable[[SeedType], np.ndarray] shared_data: dict[str, NDArray] user_data: NDArray n_expanded: int @@ -113,14 +170,15 @@ def _make_model(self, init_mean): ) var_sizes = [prod(shape) for shape in self.shape_info[2]] + var_names = self.shape_info[0] return _lib.PyMcModel( self.n_dim, logp_fn, expand_fn, + self.initial_point_func, var_sizes, - self.shape_info[0], - init_mean, + var_names, ) @@ -157,7 +215,11 @@ def make_user_data(shared_vars, shared_data): return user_data -def _compile_pymc_model_numba(model: "pm.Model", **kwargs) -> CompiledPyMCModel: +def _compile_pymc_model_numba( + model: "pm.Model", + pymc_initial_point_fn: Callable[[SeedType], dict[str, np.ndarray]], + **kwargs, +) -> CompiledPyMCModel: if find_spec("numba") is None: raise ImportError( "Numba is not installed in the current environment. " @@ -172,8 +234,15 @@ def _compile_pymc_model_numba(model: "pm.Model", **kwargs) -> CompiledPyMCModel: n_expanded, logp_fn_pt, expand_fn_pt, + initial_point_fn, shape_info, - ) = _make_functions(model, mode="NUMBA", compute_grad=True, join_expanded=True) + ) = _make_functions( + model, + mode="NUMBA", + compute_grad=True, + join_expanded=True, + pymc_initial_point_fn=pymc_initial_point_fn, + ) expand_fn = expand_fn_pt.vm.jit_fn logp_fn = logp_fn_pt.vm.jit_fn @@ -228,6 +297,7 @@ def _compile_pymc_model_numba(model: "pm.Model", **kwargs) -> CompiledPyMCModel: _shapes={name: tuple(shape) for name, _, shape in zip(*shape_info)}, compiled_logp_func=logp_numba, compiled_expand_func=expand_numba, + initial_point_func=initial_point_fn, shared_data=shared_data, user_data=user_data, n_expanded=n_expanded, @@ -262,7 +332,13 @@ def _prepare_dims_and_coords(model, shape_info): return dims, coords -def _compile_pymc_model_jax(model, *, gradient_backend=None, **kwargs): +def _compile_pymc_model_jax( + model, + *, + gradient_backend=None, + pymc_initial_point_fn: Callable[[SeedType], dict[str, np.ndarray]], + **kwargs, +): if find_spec("jax") is None: raise ImportError( "Jax is not installed in the current environment. " @@ -282,12 +358,14 @@ def _compile_pymc_model_jax(model, *, gradient_backend=None, **kwargs): _, logp_fn_pt, expand_fn_pt, + initial_point_fn, shape_info, ) = _make_functions( model, mode="JAX", compute_grad=gradient_backend == "pytensor", join_expanded=False, + pymc_initial_point_fn=pymc_initial_point_fn, ) logp_fn = logp_fn_pt.vm.jit_fn @@ -339,12 +417,13 @@ def expand(x, **shared): dims, coords = _prepare_dims_and_coords(model, shape_info) return from_pyfunc( - n_dim, - make_logp_func, - make_expand_func, - dtypes, - shapes, - names, + ndim=n_dim, + make_logp_fn=make_logp_func, + make_expand_fn=make_expand_func, + make_initial_point_fn=initial_point_fn, + expanded_dtypes=dtypes, + expanded_shapes=shapes, + expanded_names=names, shared_data=shared_data, dims=dims, coords=coords, @@ -356,6 +435,12 @@ def compile_pymc_model( *, backend: Literal["numba", "jax"] = "numba", gradient_backend: Literal["pytensor", "jax"] = "pytensor", + initial_points: dict[Union["Variable", str], np.ndarray | float | int] + | None = None, + jitter_rvs: set["TensorVariable"] | None = None, + default_initialization_strategy: Literal[ + "support_point", "prior" + ] = "support_point", **kwargs, ) -> CompiledModel: """Compile necessary functions for sampling a pymc model. @@ -367,9 +452,18 @@ def compile_pymc_model( backend : ["jax", "numba"] The pytensor backend that is used to compile the logp function. gradient_backend: ["pytensor", "jax"] - Which library is used to compute the gradients. This can only be - changed to "jax" if the jax backend is used. - + Which library is used to compute the gradients. This can only be changed + to "jax" if the jax backend is used. + jitter_rvs : set + The set (or list or tuple) of random variables for which a U(-1, +1) + jitter should be added to the initial value. Only available for + variables that have a transform or real-valued support. + default_initialization_strategy : str + Which of { "support_point", "prior" } to prefer if the initval setting + for an RV is None. + initial_points : dict + Initial value (strategies) to use instead of what's specified in + `Model.initial_values`. Returns ------- compiled_model : CompiledPyMCModel @@ -384,13 +478,29 @@ def compile_pymc_model( "and restart your kernel in case you are in an interactive session." ) + if default_initialization_strategy == "support_point" and jitter_rvs is None: + jitter_rvs = set(model.free_RVs) + + initial_point_fn = make_initial_point_fn( + model=model, + overrides=initial_points, + default_strategy=default_initialization_strategy, + jitter_rvs=jitter_rvs, + return_transformed=True, + ) + if backend.lower() == "numba": if gradient_backend == "jax": raise ValueError("Gradient backend cannot be jax when using numba backend") - return _compile_pymc_model_numba(model, **kwargs) + return _compile_pymc_model_numba( + model=model, pymc_initial_point_fn=initial_point_fn, **kwargs + ) elif backend.lower() == "jax": return _compile_pymc_model_jax( - model, gradient_backend=gradient_backend, **kwargs + model=model, + gradient_backend=gradient_backend, + pymc_initial_point_fn=initial_point_fn, + **kwargs, ) else: raise ValueError(f"Backend must be one of numba and jax. Got {backend}") @@ -425,9 +535,64 @@ def _compute_shapes(model): return dict(zip(trace_vars.keys(), shape_func())) -def _make_functions(model, *, mode, compute_grad, join_expanded): +def _make_functions( + model: "pm.Model", + *, + mode: Literal["JAX", "NUMBA"], + compute_grad: bool, + join_expanded: bool, + pymc_initial_point_fn: Callable[[SeedType], dict[str, np.ndarray]], +) -> tuple[ + int, + int, + Callable, + Callable, + Callable, + tuple[list[str], list[slice], list[tuple[int, ...]]], +]: + """ + Compile functions required by nuts-rs from a given PyMC model. + + Parameters + ---------- + model: pymc.Model + The model to compile + mode: str + Pytensor compile mode. One of "NUMBA" or "JAX" + compute_grad: bool + Whether to compute gradients using pytensor. Must be True if mode is + "NUMBA", otherwise False implies Jax will be used to compute gradients + join_expanded: bool + Whether to join the expanded variables into a single array. If False, + the expanded variables will be returned as a list of arrays. + pymc_initial_point_fn: Callable + Initial point function created by + pymc.initial_point.make_initial_point_fn + + Returns + ------- + num_free_vars: int + Number of free (root) random variables in the model + num_expanded: int + Total number of all random variables (root and dependent) in the model + logp_fn_pt: Callable + Compiled pytensor log probability function. If compute_grad is True, the + function will return both the logp and the gradient, otherwise only the + logp is returned. + expand_fn_pt: Callable + Compiled pytensor function that computes the remaining variables for the + trace + initial_point_fn: Callable + Python function that takes a random seed and returns a flat array of + initial values + param_data: tuple of lists + Tuple containing data necessary to unravel a flat array of model + variables back into a ragged list of arrays. The first list contains the + names of the variables, the second list contains the slices that + correspond to the variables in the flat array, and the third list + contains the shapes of the variables. + """ import pytensor - import pytensor.link.numba.dispatch import pytensor.tensor as pt from pymc.pytensorf import compile_pymc @@ -471,6 +636,10 @@ def _make_functions(model, *, mode, compute_grad, join_expanded): num_free_vars = count + initial_point_fn = rv_dict_to_flat_array_wrapper( + pymc_initial_point_fn, names=joined_names, shapes=joined_shapes + ) + joined = pt.TensorType("float64", shape=(num_free_vars,))( name="_unconstrained_point" ) @@ -537,6 +706,7 @@ def _make_functions(model, *, mode, compute_grad, join_expanded): num_expanded, logp_fn_pt, expand_fn_pt, + initial_point_fn, (all_names, all_slices, all_shapes), ) diff --git a/python/nutpie/compiled_pyfunc.py b/python/nutpie/compiled_pyfunc.py index 4db549c..9ede109 100644 --- a/python/nutpie/compiled_pyfunc.py +++ b/python/nutpie/compiled_pyfunc.py @@ -8,11 +8,14 @@ from nutpie import _lib from nutpie.sample import CompiledModel +SeedType = int + @dataclass(frozen=True) class PyFuncModel(CompiledModel): _make_logp_func: Callable _make_expand_func: Callable + _make_initial_points: Callable[[SeedType], np.ndarray] | None _shared_data: dict[str, Any] _n_dim: int _variables: list[_lib.PyVariable] @@ -62,6 +65,7 @@ def make_expand_func(seed1, seed2, chain): make_expand_func, self._variables, self.n_dim, + self._make_initial_points, ) @@ -73,10 +77,10 @@ def from_pyfunc( expanded_shapes: list[tuple[int, ...]], expanded_names: list[str], *, - initial_mean: np.ndarray | None = None, coords: dict[str, Any] | None = None, dims: dict[str, tuple[str, ...]] | None = None, shared_data: dict[str, Any] | None = None, + make_initial_point_fn: Callable[[SeedType], np.ndarray] | None, ): variables = [] for name, shape, dtype in zip( @@ -98,14 +102,13 @@ def from_pyfunc( if shared_data is None: shared_data = {} - if shared_data is None: - shared_data = dict() return PyFuncModel( _n_dim=ndim, dims=dims, _coords=coords, _make_logp_func=make_logp_fn, _make_expand_func=make_expand_fn, + _make_initial_points=make_initial_point_fn, _variables=variables, _shared_data=shared_data, ) diff --git a/src/pyfunc.rs b/src/pyfunc.rs index 6a85bfa..07b32d5 100644 --- a/src/pyfunc.rs +++ b/src/pyfunc.rs @@ -16,7 +16,7 @@ use pyo3::{ Bound, Py, PyAny, PyErr, Python, }; use rand::Rng; -use rand_distr::{Distribution, StandardNormal}; +use rand_distr::{Distribution, StandardNormal, Uniform}; use smallvec::SmallVec; use thiserror::Error; @@ -73,6 +73,7 @@ impl PyVariable { pub struct PyModel { make_logp_func: Py, make_expand_func: Py, + init_point_func: Option>, variables: Vec, ndim: usize, } @@ -85,10 +86,12 @@ impl PyModel { make_expand_func: Py, variables: Vec, ndim: usize, + init_point_func: Option>, ) -> Self { Self { make_logp_func, make_expand_func, + init_point_func, variables, ndim, } @@ -437,11 +440,13 @@ impl DrawStorage for PyTrace { } impl Model for PyModel { - type Math<'model> = CpuMath + type Math<'model> + = CpuMath where Self: 'model; - type DrawStorage<'model, S: nuts_rs::Settings> = PyTrace + type DrawStorage<'model, S: nuts_rs::Settings> + = PyTrace where Self: 'model; @@ -474,10 +479,34 @@ impl Model for PyModel { rng: &mut R, position: &mut [f64], ) -> Result<()> { - let dist = StandardNormal; - dist.sample_iter(rng) - .zip(position.iter_mut()) - .for_each(|(val, pos)| *pos = val); + let Some(init_func) = self.init_point_func.as_ref() else { + let dist = Uniform::new(-2f64, 2f64); + position.iter_mut().for_each(|x| *x = dist.sample(rng)); + return Ok(()); + }; + + let seed = rng.next_u64(); + + Python::with_gil(|py| { + let init_point = init_func + .call1(py, (seed,)) + .context("Failed to initialize point")?; + + let init_point: PyReadonlyArray1 = init_point + .extract(py) + .context("Initializition array returned incorrect argument")?; + + let init_point = init_point + .as_slice() + .context("Initial point must be contiguous")?; + + if init_point.len() != position.len() { + bail!("Initial point has incorrect length"); + } + + position.copy_from_slice(init_point); + Ok(()) + })?; Ok(()) } } diff --git a/src/pymc.rs b/src/pymc.rs index 8865f84..b7862cf 100644 --- a/src/pymc.rs +++ b/src/pymc.rs @@ -1,6 +1,6 @@ use std::{ffi::c_void, fmt::Display, sync::Arc}; -use anyhow::{Context, Result}; +use anyhow::{bail, Context, Result}; use arrow::{ array::{Array, FixedSizeListArray, Float64Array, StructArray}, datatypes::{DataType, Field, Fields}, @@ -11,7 +11,7 @@ use nuts_rs::{CpuLogpFunc, CpuMath, DrawStorage, LogpError, Model, Settings}; use pyo3::{ pyclass, pymethods, types::{PyAnyMethods, PyList}, - Bound, PyObject, PyResult, + Bound, Py, PyAny, PyObject, PyResult, Python, }; use rand::{distributions::Uniform, prelude::Distribution}; @@ -231,7 +231,7 @@ pub(crate) struct PyMcModel { dim: usize, density: LogpFunc, expand: ExpandFunc, - mu: Box<[f64]>, + init_func: Py, var_sizes: Vec, var_names: Vec, } @@ -243,15 +243,15 @@ impl PyMcModel { dim: usize, density: LogpFunc, expand: ExpandFunc, + init_func: Py, var_sizes: &Bound<'py, PyList>, var_names: &Bound<'py, PyList>, - start_point: PyReadonlyArray1<'py, f64>, ) -> PyResult { Ok(Self { dim, density, expand, - mu: start_point.as_slice()?.into(), + init_func, var_names: var_names.extract()?, var_sizes: var_sizes.extract()?, }) @@ -292,11 +292,29 @@ impl Model for PyMcModel { rng: &mut R, position: &mut [f64], ) -> Result<()> { - let dist = Uniform::new(-2f64, 2f64); - position - .iter_mut() - .zip_eq(self.mu.iter()) - .for_each(|(x, mu)| *x = dist.sample(rng) + mu); + let seed = rng.next_u64(); + + Python::with_gil(|py| { + let init_point = self + .init_func + .call1(py, (seed,)) + .context("Failed to initialize point")?; + + let init_point: PyReadonlyArray1 = init_point + .extract(py) + .context("Initializition array returned incorrect argument")?; + + let init_point = init_point + .as_slice() + .context("Initial point must be contiguous")?; + + if init_point.len() != position.len() { + bail!("Initial point has incorrect length"); + } + + position.copy_from_slice(init_point); + Ok(()) + })?; Ok(()) } diff --git a/tests/test_pymc.py b/tests/test_pymc.py index 8f70699..d59fb18 100644 --- a/tests/test_pymc.py +++ b/tests/test_pymc.py @@ -38,6 +38,19 @@ def test_pymc_model_float32(backend, gradient_backend): trace.posterior.a # noqa: B018 +@parameterize_backends +def test_pymc_model_no_prior(backend, gradient_backend): + with pm.Model() as model: + a = pm.Flat("a") + pm.Normal("b", mu=a, observed=0.0) + + compiled = nutpie.compile_pymc_model( + model, backend=backend, gradient_backend=gradient_backend + ) + trace = nutpie.sample(compiled, chains=1) + trace.posterior.a # noqa: B018 + + @parameterize_backends def test_blocking(backend, gradient_backend): with pm.Model() as model: