diff --git a/Cargo.lock b/Cargo.lock index e702c9c..bd30888 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -54,9 +54,9 @@ checksum = "55cc3b69f167a1ef2e161439aa98aed94e6028e5f9a59be9a6ffb47aef1651f9" [[package]] name = "anyhow" -version = "1.0.93" +version = "1.0.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c95c10ba0b00a02636238b814946408b1322d5ac4760326e6fb8ec956d85775" +checksum = "34ac096ce696dc2fcabef30516bb13c0a68a11d30131d3df6f04711467681b04" [[package]] name = "arrow" @@ -185,7 +185,7 @@ version = "52.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9e972cd1ff4a4ccd22f86d3e53e835c2ed92e0eea6a3e8eadb72b4f1ac802cf8" dependencies = [ - "bitflags 2.6.0", + "bitflags", ] [[package]] @@ -242,38 +242,29 @@ checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" [[package]] name = "bindgen" -version = "0.69.5" +version = "0.71.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088" +checksum = "5f58bf3d7db68cfbac37cfc485a8d711e87e064c3d0fe0435b92f7a407f9d6b3" dependencies = [ - "bitflags 2.6.0", + "bitflags", "cexpr", "clang-sys", - "itertools 0.12.1", - "lazy_static", - "lazycell", + "itertools 0.13.0", "log", "prettyplease", "proc-macro2", "quote", "regex", - "rustc-hash", + "rustc-hash 2.1.0", "shlex", - "syn 2.0.87", - "which", + "syn 2.0.96", ] [[package]] name = "bitflags" -version = "1.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" - -[[package]] -name = "bitflags" -version = "2.6.0" +version = "2.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" +checksum = "1be3f42a67d6d345ecd59f675f3f012d6974981560836e938c22b424b85ce1be" [[package]] name = "block-buffer" @@ -286,15 +277,15 @@ dependencies = [ [[package]] name = "bridgestan" -version = "2.5.0" +version = "2.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc8db213e11ba8b22c444912e269f8164d4af17292e6e78d9d6a4162225e929b" +checksum = "3bcc63930a0b64b9c4ca744d32aa2c664bcc4fdd34bd68778e531572318cd0bc" dependencies = [ "bindgen", "libloading", "log", "path-absolutize", - "thiserror", + "thiserror 2.0.11", ] [[package]] @@ -305,22 +296,22 @@ checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" [[package]] name = "bytemuck" -version = "1.19.0" +version = "1.21.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8334215b81e418a0a7bdb8ef0849474f40bb10c8b71f1c4ed315cff49f32494d" +checksum = "ef657dfab802224e671f5818e9a4935f9b1957ed18e58292690cc39e7a4092a3" dependencies = [ "bytemuck_derive", ] [[package]] name = "bytemuck_derive" -version = "1.8.0" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bcfcc3cd946cb52f0bbfdbbcfa2f4e24f75ebb6c0e1002f7c25904fada18b9ec" +checksum = "3fa76293b4f7bb636ab88fd78228235b5248b4d05cc589aed610f954af5d7c7a" dependencies = [ "proc-macro2", "quote", - "syn 2.0.87", + "syn 2.0.96", ] [[package]] @@ -331,9 +322,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.8.0" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ac0150caa2ae65ca5bd83f25c7de183dea78d4d366469f148435e2acfbad0da" +checksum = "325918d6fe32f23b19878fe4b34794ae41fc19ddbe53b10571a4874d44ffd39b" [[package]] name = "cast" @@ -343,9 +334,9 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.2.1" +version = "1.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd9de9f2205d5ef3fd67e685b0df337994ddd4495e2a28d185500d0e1edfea47" +checksum = "c8293772165d9345bdaaa39b45b2109591e63fe5e6fbc23c6ff930a048aa310b" dependencies = [ "shlex", ] @@ -367,9 +358,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "chrono" -version = "0.4.38" +version = "0.4.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a21f936df1771bf62b77f047b726c4625ff2e8aa607c01ec06e5a05bd8463401" +checksum = "7e36cc9d416881d2e24f9a963be5fb1cd90966419ac844274161d10488b3e825" dependencies = [ "android-tzdata", "iana-time-zone", @@ -417,18 +408,18 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.21" +version = "4.5.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fb3b4b9e5a7c7514dfa52869339ee98b3156b0bfb4e8a77c4ff4babb64b1604f" +checksum = "a8eb5e908ef3a6efbe1ed62520fb7287959888c88485abe072543190ecc66783" dependencies = [ "clap_builder", ] [[package]] name = "clap_builder" -version = "4.5.21" +version = "4.5.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b17a95aa67cc7b5ebd32aa5370189aa0d79069ef1c64ce893bd30fb24bff20ec" +checksum = "96b01801b5fc6a0a232407abc821660c9c6d25a1cafc0d4f85f29fb8d9afc121" dependencies = [ "anstyle", "clap_lex", @@ -436,9 +427,9 @@ dependencies = [ [[package]] name = "clap_lex" -version = "0.7.3" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "afb84c814227b90d6895e01398aee0d8033c00e7466aca416fb6a8e0eb19d8a7" +checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6" [[package]] name = "coe-rs" @@ -448,15 +439,15 @@ checksum = "7e8f1e641542c07631228b1e0dc04b69ae3c1d58ef65d5691a439711d805c698" [[package]] name = "console" -version = "0.15.8" +version = "0.15.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e1f83fc076bd6dd27517eacdf25fef6c4dfe5f1d7448bafaaf3a26f13b5e4eb" +checksum = "ea3c6ecd8059b57859df5c69830340ed3c41d30e3da0c1cbed90a96ac853041b" dependencies = [ "encode_unicode", - "lazy_static", "libc", - "unicode-width 0.1.14", - "windows-sys 0.52.0", + "once_cell", + "unicode-width", + "windows-sys 0.59.0", ] [[package]] @@ -487,9 +478,9 @@ checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" [[package]] name = "cpufeatures" -version = "0.2.15" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ca741a962e1b0bff6d724a1a0958b686406e853bb14061f218562e1896f95e6" +checksum = "16b80225097f2e5ae4e7179dd2266824648f3e2f49d9134d584b76389d31c4c3" dependencies = [ "libc", ] @@ -532,9 +523,9 @@ dependencies = [ [[package]] name = "crossbeam-deque" -version = "0.8.5" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" +checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" dependencies = [ "crossbeam-epoch", "crossbeam-utils", @@ -551,9 +542,9 @@ dependencies = [ [[package]] name = "crossbeam-utils" -version = "0.8.20" +version = "0.8.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" [[package]] name = "crunchy" @@ -597,6 +588,15 @@ dependencies = [ "reborrow", ] +[[package]] +name = "dyn-stack" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "490bd48eb68fffcfed519b4edbfd82c69cbe741d175b84f0e0cbe8c57cbe0bdd" +dependencies = [ + "bytemuck", +] + [[package]] name = "either" version = "1.13.0" @@ -605,9 +605,9 @@ checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" [[package]] name = "encode_unicode" -version = "0.3.6" +version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" +checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" [[package]] name = "enum-as-inner" @@ -618,7 +618,7 @@ dependencies = [ "heck 0.5.0", "proc-macro2", "quote", - "syn 2.0.87", + "syn 2.0.96", ] [[package]] @@ -632,11 +632,11 @@ dependencies = [ [[package]] name = "equator" -version = "0.4.1" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5099e7b6f0b7431c7a1c49f75929e2777693da192784f167066977a2965767af" +checksum = "4711b213838dfee0117e3be6ac926007d7f433d7bbe33595975d4190cb07e6fc" dependencies = [ - "equator-macro 0.4.1", + "equator-macro 0.4.2", ] [[package]] @@ -647,28 +647,18 @@ checksum = "3bf679796c0322556351f287a51b49e48f7c4986e727b5dd78c972d30e2e16cc" dependencies = [ "proc-macro2", "quote", - "syn 2.0.87", + "syn 2.0.96", ] [[package]] name = "equator-macro" -version = "0.4.1" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5322a90066ddae2b705096eb9e10c465c0498ae93bf9bdd6437415327c88e3bb" +checksum = "44f23cf4b44bfce11a86ace86f8a73ffdec849c9fd00a386a53d278bd9e81fb3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.87", -] - -[[package]] -name = "errno" -version = "0.3.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba" -dependencies = [ - "libc", - "windows-sys 0.52.0", + "syn 2.0.96", ] [[package]] @@ -680,8 +670,8 @@ dependencies = [ "bytemuck", "coe-rs", "dbgf", - "dyn-stack", - "equator 0.4.1", + "dyn-stack 0.10.0", + "equator 0.4.2", "faer-entity", "gemm", "libm", @@ -710,17 +700,17 @@ dependencies = [ "libm", "num-complex", "num-traits", - "pulp", + "pulp 0.18.22", "reborrow", ] [[package]] name = "gemm" -version = "0.18.0" +version = "0.18.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e400f2ffd14e7548356236c35dc39cad6666d833a852cb8a8f3f28029359bb03" +checksum = "ab96b703d31950f1aeddded248bc95543c9efc7ac9c4a21fda8703a83ee35451" dependencies = [ - "dyn-stack", + "dyn-stack 0.13.0", "gemm-c32", "gemm-c64", "gemm-common", @@ -736,11 +726,11 @@ dependencies = [ [[package]] name = "gemm-c32" -version = "0.18.0" +version = "0.18.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "10dc4a6176c8452d60eac1a155b454c91c668f794151a303bf3c75ea2874812d" +checksum = "f6db9fd9f40421d00eea9dd0770045a5603b8d684654816637732463f4073847" dependencies = [ - "dyn-stack", + "dyn-stack 0.13.0", "gemm-common", "num-complex", "num-traits", @@ -751,11 +741,11 @@ dependencies = [ [[package]] name = "gemm-c64" -version = "0.18.0" +version = "0.18.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc2032ce2c0bb150da0256338759a6fb01ca056f6dfe28c4d14af32d7f878f6f" +checksum = "dfcad8a3d35a43758330b635d02edad980c1e143dc2f21e6fd25f9e4eada8edf" dependencies = [ - "dyn-stack", + "dyn-stack 0.13.0", "gemm-common", "num-complex", "num-traits", @@ -766,18 +756,19 @@ dependencies = [ [[package]] name = "gemm-common" -version = "0.18.0" +version = "0.18.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90fd234fc525939654f47b39325fd5f55e552ceceea9135f3aa8bdba61eabef6" +checksum = "a352d4a69cbe938b9e2a9cb7a3a63b7e72f9349174a2752a558a8a563510d0f3" dependencies = [ "bytemuck", - "dyn-stack", + "dyn-stack 0.13.0", "half", + "libm", "num-complex", "num-traits", "once_cell", "paste", - "pulp", + "pulp 0.21.4", "raw-cpuid", "rayon", "seq-macro", @@ -786,11 +777,11 @@ dependencies = [ [[package]] name = "gemm-f16" -version = "0.18.0" +version = "0.18.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fc3652651f96a711d46b8833e1fac27a864be4bdfa81a374055f33ddd25c0c6" +checksum = "cff95ae3259432f3c3410eaa919033cd03791d81cebd18018393dc147952e109" dependencies = [ - "dyn-stack", + "dyn-stack 0.13.0", "gemm-common", "gemm-f32", "half", @@ -804,11 +795,11 @@ dependencies = [ [[package]] name = "gemm-f32" -version = "0.18.0" +version = "0.18.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "acbc51c44ae3defd207e6d9416afccb3c4af1e7cef5e4960e4c720ac4d6f998e" +checksum = "bc8d3d4385393304f407392f754cd2dc4b315d05063f62cf09f47b58de276864" dependencies = [ - "dyn-stack", + "dyn-stack 0.13.0", "gemm-common", "num-complex", "num-traits", @@ -819,11 +810,11 @@ dependencies = [ [[package]] name = "gemm-f64" -version = "0.18.0" +version = "0.18.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f37fc86e325c2415a4d0cab8324a0c5371ec06fc7d2f9cb1636fcfc9536a8d8" +checksum = "35b2a4f76ce4b8b16eadc11ccf2e083252d8237c1b589558a49b0183545015bd" dependencies = [ - "dyn-stack", + "dyn-stack 0.13.0", "gemm-common", "num-complex", "num-traits", @@ -855,9 +846,9 @@ dependencies = [ [[package]] name = "glob" -version = "0.3.1" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" +checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" [[package]] name = "half" @@ -895,15 +886,6 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fbf6a919d6cf397374f7dfeeea91d974c7c0a7221d0d0f4f20d859d329e53fcc" -[[package]] -name = "home" -version = "0.5.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3d1354bf6b7235cb4a0576c2619fd4ed18183f689b12b006a0ee7329eeff9a5" -dependencies = [ - "windows-sys 0.52.0", -] - [[package]] name = "iana-time-zone" version = "0.1.61" @@ -936,7 +918,7 @@ dependencies = [ "console", "number_prefix", "portable-atomic", - "unicode-width 0.2.0", + "unicode-width", "web-time", ] @@ -968,49 +950,38 @@ dependencies = [ [[package]] name = "itertools" -version = "0.12.1" +version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" dependencies = [ "either", ] [[package]] name = "itertools" -version = "0.13.0" +version = "0.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +checksum = "2b192c782037fadd9cfa75548310488aabdbf3d2da73885b31bd0abd03351285" dependencies = [ "either", ] [[package]] name = "itoa" -version = "1.0.11" +version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" +checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674" [[package]] name = "js-sys" -version = "0.3.72" +version = "0.3.77" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a88f1bda2bd75b0452a14784937d796722fdebfe50df998aeb3f0b7603019a9" +checksum = "1cfaf33c695fc6e08064efbc1f72ec937429614f25eef83af942d0e227c3a28f" dependencies = [ + "once_cell", "wasm-bindgen", ] -[[package]] -name = "lazy_static" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" - -[[package]] -name = "lazycell" -version = "1.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" - [[package]] name = "lexical-core" version = "0.8.5" @@ -1077,15 +1048,15 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.162" +version = "0.2.169" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18d287de67fe55fd7e1581fe933d965a5a9477b38e949cfa9f8574ef01506398" +checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a" [[package]] name = "libloading" -version = "0.8.5" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4979f22fdb869068da03c9f7528f8297c6fd2606bc3a4affe42e6a823fdb8da4" +checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34" dependencies = [ "cfg-if", "windows-targets", @@ -1097,12 +1068,6 @@ version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8355be11b20d696c8f18f6cc018c4e372165b1fa8126cef092399c9951984ffa" -[[package]] -name = "linux-raw-sys" -version = "0.4.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" - [[package]] name = "lock_api" version = "0.4.12" @@ -1115,9 +1080,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.22" +version = "0.4.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" +checksum = "04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15e29fff3f6c077fd9cd9f" [[package]] name = "matrixcompare" @@ -1386,7 +1351,7 @@ dependencies = [ "num-integer", "num-traits", "pyo3", - "rustc-hash", + "rustc-hash 1.1.0", ] [[package]] @@ -1398,7 +1363,7 @@ dependencies = [ "bridgestan", "criterion", "indicatif", - "itertools 0.13.0", + "itertools 0.14.0", "numpy", "nuts-rs", "pyo3", @@ -1407,7 +1372,7 @@ dependencies = [ "rand_distr", "rayon", "smallvec", - "thiserror", + "thiserror 1.0.69", "time-humanize", "upon", ] @@ -1423,12 +1388,12 @@ dependencies = [ "faer", "itertools 0.13.0", "multiversion", - "pulp", + "pulp 0.18.22", "rand", "rand_chacha", "rand_distr", "rayon", - "thiserror", + "thiserror 1.0.69", ] [[package]] @@ -1492,20 +1457,20 @@ dependencies = [ [[package]] name = "pest" -version = "2.7.14" +version = "2.7.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "879952a81a83930934cbf1786752d6dedc3b1f29e8f8fb2ad1d0a36f377cf442" +checksum = "8b7cafe60d6cf8e62e1b9b2ea516a089c008945bb5a275416789e7db0bc199dc" dependencies = [ "memchr", - "thiserror", + "thiserror 2.0.11", "ucd-trie", ] [[package]] name = "pest_derive" -version = "2.7.14" +version = "2.7.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d214365f632b123a47fd913301e14c946c61d1c183ee245fa76eb752e59a02dd" +checksum = "816518421cfc6887a0d62bf441b6ffb4536fcc926395a69e1a85852d4363f57e" dependencies = [ "pest", "pest_generator", @@ -1513,22 +1478,22 @@ dependencies = [ [[package]] name = "pest_generator" -version = "2.7.14" +version = "2.7.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb55586734301717aea2ac313f50b2eb8f60d2fc3dc01d190eefa2e625f60c4e" +checksum = "7d1396fd3a870fc7838768d171b4616d5c91f6cc25e377b673d714567d99377b" dependencies = [ "pest", "pest_meta", "proc-macro2", "quote", - "syn 2.0.87", + "syn 2.0.96", ] [[package]] name = "pest_meta" -version = "2.7.14" +version = "2.7.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b75da2a70cf4d9cb76833c990ac9cd3923c9a8905a8929789ce347c84564d03d" +checksum = "e1e58089ea25d717bfd31fb534e4f3afcc2cc569c70de3e239778991ea3b7dea" dependencies = [ "once_cell", "pest", @@ -1565,9 +1530,9 @@ dependencies = [ [[package]] name = "portable-atomic" -version = "1.9.0" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc9c68a3f6da06753e9335d63e27f6b9754dd1920d941135b7ea8224f141adb2" +checksum = "280dc24453071f1b63954171985a0b0d30058d287960968b9b2aca264c8d4ee6" [[package]] name = "ppv-lite86" @@ -1580,19 +1545,19 @@ dependencies = [ [[package]] name = "prettyplease" -version = "0.2.25" +version = "0.2.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64d1ec885c64d0457d564db4ec299b2dae3f9c02808b8ad9c3a089c591b18033" +checksum = "6924ced06e1f7dfe3fa48d57b9f74f55d8915f5036121bef647ef4b204895fac" dependencies = [ "proc-macro2", - "syn 2.0.87", + "syn 2.0.96", ] [[package]] name = "proc-macro2" -version = "1.0.89" +version = "1.0.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f139b0662de085916d1fb67d2b4169d1addddda1919e696f3252b740b629986e" +checksum = "60946a68e5f9d28b0dc1c21bb8a97ee7d018a8b322fa57838ba31cc878e22d99" dependencies = [ "unicode-ident", ] @@ -1609,6 +1574,20 @@ dependencies = [ "reborrow", ] +[[package]] +name = "pulp" +version = "0.21.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95fb7a99b37aaef4c7dd2fd15a819eb8010bfc7a2c2155230d51f497316cad6d" +dependencies = [ + "bytemuck", + "cfg-if", + "libm", + "num-complex", + "reborrow", + "version_check", +] + [[package]] name = "py_literal" version = "0.4.0" @@ -1670,7 +1649,7 @@ dependencies = [ "proc-macro2", "pyo3-macros-backend", "quote", - "syn 2.0.87", + "syn 2.0.96", ] [[package]] @@ -1683,14 +1662,14 @@ dependencies = [ "proc-macro2", "pyo3-build-config", "quote", - "syn 2.0.87", + "syn 2.0.96", ] [[package]] name = "quote" -version = "1.0.37" +version = "1.0.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" +checksum = "0e4dccaaaf89514f546c693ddc140f729f958c247918a13380cccc6078391acc" dependencies = [ "proc-macro2", ] @@ -1737,11 +1716,11 @@ dependencies = [ [[package]] name = "raw-cpuid" -version = "10.7.0" +version = "11.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c297679cb867470fa8c9f67dbba74a78d78e3e98d7cf2b08d6d71540f797332" +checksum = "c6928fa44c097620b706542d428957635951bade7143269085389d42c8a4927e" dependencies = [ - "bitflags 1.3.2", + "bitflags", ] [[package]] @@ -1778,11 +1757,11 @@ checksum = "03251193000f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430" [[package]] name = "redox_syscall" -version = "0.5.7" +version = "0.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b6dfecf2c74bce2466cabf93f6664d6998a69eb21e39f4207930065b27b771f" +checksum = "03a862b389f93e68874fbf580b9de08dd02facb9a788ebadaf4a3fd33cf58834" dependencies = [ - "bitflags 2.6.0", + "bitflags", ] [[package]] @@ -1821,17 +1800,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" [[package]] -name = "rustix" -version = "0.38.40" +name = "rustc-hash" +version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "99e4ea3e1cdc4b559b8e5650f9c8e5998e3e5c1343b4eaf034565f32318d63c0" -dependencies = [ - "bitflags 2.6.0", - "errno", - "libc", - "linux-raw-sys", - "windows-sys 0.52.0", -] +checksum = "c7fb8039b3032c191086b10f11f319a6e99e1e82889c5cc6046f515c9db1d497" + +[[package]] +name = "rustversion" +version = "1.0.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7c45b9784283f1b2e7fb61b42047c2fd678ef0960d4f6f1eba131594cc369d4" [[package]] name = "ryu" @@ -1862,29 +1840,29 @@ checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" [[package]] name = "serde" -version = "1.0.215" +version = "1.0.217" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6513c1ad0b11a9376da888e3e0baa0077f1aed55c17f50e7b2397136129fb88f" +checksum = "02fc4265df13d6fa1d00ecff087228cc0a2b5f3c0e87e258d8b94a156e984c70" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.215" +version = "1.0.217" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad1e866f866923f252f05c889987993144fb74e722403468a4ebd70c3cd756c0" +checksum = "5a9bf7cf98d04a2b28aead066b7496853d4779c9cc183c440dbac457641e19a0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.87", + "syn 2.0.96", ] [[package]] name = "serde_json" -version = "1.0.132" +version = "1.0.135" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d726bfaff4b320266d395898905d0eba0345aae23b54aee3a737e260fd46db03" +checksum = "2b0d7ba2887406110130a978386c4e1befb98c674b4fba677954e4db976630d9" dependencies = [ "itoa", "memchr", @@ -1934,9 +1912,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.87" +version = "2.0.96" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25aa4ce346d03a6dcd68dd8b4010bcb74e54e62c90c573f394c46eae99aba32d" +checksum = "d5d0adab1ae378d7f53bdebc67a39f1f151407ef230f0ce2883572f5d8985c80" dependencies = [ "proc-macro2", "quote", @@ -1945,15 +1923,15 @@ dependencies = [ [[package]] name = "sysctl" -version = "0.5.5" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec7dddc5f0fee506baf8b9fdb989e242f17e4b11c61dfbb0635b705217199eea" +checksum = "01198a2debb237c62b6826ec7081082d951f46dbb64b0e8c7649a452230d1dfc" dependencies = [ - "bitflags 2.6.0", + "bitflags", "byteorder", "enum-as-inner", "libc", - "thiserror", + "thiserror 1.0.69", "walkdir", ] @@ -1975,7 +1953,16 @@ version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" dependencies = [ - "thiserror-impl", + "thiserror-impl 1.0.69", +] + +[[package]] +name = "thiserror" +version = "2.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d452f284b73e6d76dd36758a0c8684b1d5be31f92b89d07fd5822175732206fc" +dependencies = [ + "thiserror-impl 2.0.11", ] [[package]] @@ -1986,7 +1973,18 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.87", + "syn 2.0.96", +] + +[[package]] +name = "thiserror-impl" +version = "2.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26afc1baea8a989337eeb52b6e72a039780ce45c3edfcc9c5b9d112feeb173c2" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.96", ] [[package]] @@ -2028,15 +2026,9 @@ checksum = "2896d95c02a80c6d6a5d6e953d479f5ddf2dfdb6a244441010e373ac0fb88971" [[package]] name = "unicode-ident" -version = "1.0.13" +version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" - -[[package]] -name = "unicode-width" -version = "0.1.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af" +checksum = "adb9e6ca4f869e1180728b7950e35922a7fc6397f7b641499e8f3ef06e50dc83" [[package]] name = "unicode-width" @@ -2080,35 +2072,35 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.95" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "128d1e363af62632b8eb57219c8fd7877144af57558fb2ef0368d0087bddeb2e" +checksum = "1edc8929d7499fc4e8f0be2262a241556cfc54a0bea223790e71446f2aab1ef5" dependencies = [ "cfg-if", "once_cell", + "rustversion", "wasm-bindgen-macro", ] [[package]] name = "wasm-bindgen-backend" -version = "0.2.95" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb6dd4d3ca0ddffd1dd1c9c04f94b868c37ff5fac97c30b97cff2d74fce3a358" +checksum = "2f0a0651a5c2bc21487bde11ee802ccaf4c51935d0d3d42a6101f98161700bc6" dependencies = [ "bumpalo", "log", - "once_cell", "proc-macro2", "quote", - "syn 2.0.87", + "syn 2.0.96", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-macro" -version = "0.2.95" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e79384be7f8f5a9dd5d7167216f022090cf1f9ec128e6e6a482a2cb5c5422c56" +checksum = "7fe63fc6d09ed3792bd0897b314f53de8e16568c2b3f7982f468c0bf9bd0b407" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -2116,28 +2108,31 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.95" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26c6ab57572f7a24a4985830b120de1594465e5d500f24afe89e16b4e833ef68" +checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de" dependencies = [ "proc-macro2", "quote", - "syn 2.0.87", + "syn 2.0.96", "wasm-bindgen-backend", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.95" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "65fc09f10666a9f147042251e0dda9c18f166ff7de300607007e96bdebc1068d" +checksum = "1a05d73b933a847d6cccdda8f838a22ff101ad9bf93e33684f39c1f5f0eece3d" +dependencies = [ + "unicode-ident", +] [[package]] name = "web-sys" -version = "0.3.72" +version = "0.3.77" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6488b90108c040df0fe62fa815cbdee25124641df01814dd7282749234c6112" +checksum = "33b6dd2ef9186f1f2072e409e99cd22a975331a6b3591b12c764e0e55c60d5d2" dependencies = [ "js-sys", "wasm-bindgen", @@ -2153,18 +2148,6 @@ dependencies = [ "wasm-bindgen", ] -[[package]] -name = "which" -version = "4.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7" -dependencies = [ - "either", - "home", - "once_cell", - "rustix", -] - [[package]] name = "winapi-util" version = "0.1.9" @@ -2283,5 +2266,5 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.87", + "syn 2.0.96", ] diff --git a/Cargo.toml b/Cargo.toml index 48634d8..d0df604 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,8 +31,8 @@ rayon = "1.9.0" # Keep arrow in sync with nuts-rs requirements arrow = { version = "52.0.0", default-features = false, features = ["ffi"] } anyhow = "1.0.72" -itertools = "0.13.0" -bridgestan = "2.5.0" +itertools = "0.14.0" +bridgestan = "2.6.1" rand_distr = "0.4.3" smallvec = "1.11.0" upon = { version = "0.8.1", default-features = false, features = [] } diff --git a/pyproject.toml b/pyproject.toml index af9c60f..7e674af 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ name = "nutpie" description = "Sample Stan or PyMC models" authors = [{ name = "PyMC Developers", email = "pymc.devs@gmail.com" }] readme = "README.md" -requires-python = ">=3.10" +requires-python = ">=3.10,<3.13" license = { text = "MIT" } classifiers = [ "Programming Language :: Rust", @@ -26,6 +26,7 @@ dependencies = [ "xarray >= 2023.06.0", "arviz >= 0.15.0", ] +dynamic = ["version"] [project.optional-dependencies] stan = ["bridgestan >= 2.4.1"] diff --git a/python/nutpie/__init__.py b/python/nutpie/__init__.py index 980b7e5..443f099 100644 --- a/python/nutpie/__init__.py +++ b/python/nutpie/__init__.py @@ -4,4 +4,4 @@ from nutpie.sample import sample __version__: str = _lib.__version__ -__all__ = ["__version__", "sample", "compile_pymc_model", "compile_stan_model"] +__all__ = ["__version__", "compile_pymc_model", "compile_stan_model", "sample"] diff --git a/python/nutpie/compile_pymc.py b/python/nutpie/compile_pymc.py index 9a2faa3..3149e4e 100644 --- a/python/nutpie/compile_pymc.py +++ b/python/nutpie/compile_pymc.py @@ -1,6 +1,7 @@ import dataclasses import itertools import warnings +from collections.abc import Iterable from dataclasses import dataclass from functools import wraps from importlib.util import find_spec @@ -69,7 +70,7 @@ def seeded_array_fn(seed: SeedType = None): for name, shape in zip(names, shapes, strict=True): initial_value = initial_value_dict[name] n = int(np.prod(initial_value.shape)) - if initial_value.shape != shape: + if tuple(initial_value.shape) != tuple(shape): raise ValueError( f"Size of initial value for {name} is {initial_value.shape}, " f"expected {shape}" @@ -218,6 +219,7 @@ def make_user_data(shared_vars, shared_data): def _compile_pymc_model_numba( model: "pm.Model", pymc_initial_point_fn: Callable[[SeedType], dict[str, np.ndarray]], + var_names: Iterable[str] | None = None, **kwargs, ) -> CompiledPyMCModel: if find_spec("numba") is None: @@ -242,6 +244,7 @@ def _compile_pymc_model_numba( compute_grad=True, join_expanded=True, pymc_initial_point_fn=pymc_initial_point_fn, + var_names=var_names, ) expand_fn = expand_fn_pt.vm.jit_fn @@ -337,6 +340,7 @@ def _compile_pymc_model_jax( *, gradient_backend=None, pymc_initial_point_fn: Callable[[SeedType], dict[str, np.ndarray]], + var_names: Iterable[str] | None = None, **kwargs, ): if find_spec("jax") is None: @@ -366,6 +370,7 @@ def _compile_pymc_model_jax( compute_grad=gradient_backend == "pytensor", join_expanded=False, pymc_initial_point_fn=pymc_initial_point_fn, + var_names=var_names, ) logp_fn = logp_fn_pt.vm.jit_fn @@ -441,6 +446,7 @@ def compile_pymc_model( default_initialization_strategy: Literal[ "support_point", "prior" ] = "support_point", + var_names: Iterable[str] | None = None, **kwargs, ) -> CompiledModel: """Compile necessary functions for sampling a pymc model. @@ -464,6 +470,8 @@ def compile_pymc_model( initial_points : dict Initial value (strategies) to use instead of what's specified in `Model.initial_values`. + var_names : list[str] | None + A list of variables to store in the trace. If None, store all variables. Returns ------- compiled_model : CompiledPyMCModel @@ -493,13 +501,17 @@ def compile_pymc_model( if gradient_backend == "jax": raise ValueError("Gradient backend cannot be jax when using numba backend") return _compile_pymc_model_numba( - model=model, pymc_initial_point_fn=initial_point_fn, **kwargs + model=model, + pymc_initial_point_fn=initial_point_fn, + var_names=var_names, + **kwargs, ) elif backend.lower() == "jax": return _compile_pymc_model_jax( model=model, gradient_backend=gradient_backend, pymc_initial_point_fn=initial_point_fn, + var_names=var_names, **kwargs, ) else: @@ -542,6 +554,7 @@ def _make_functions( compute_grad: bool, join_expanded: bool, pymc_initial_point_fn: Callable[[SeedType], dict[str, np.ndarray]], + var_names: Iterable[str] | None = None, ) -> tuple[ int, int, @@ -568,6 +581,8 @@ def _make_functions( pymc_initial_point_fn: Callable Initial point function created by pymc.initial_point.make_initial_point_fn + var_names: + Names of variables to store in the trace. Defaults to all variables. Returns ------- @@ -673,6 +688,10 @@ def _make_functions( var for var in model.unobserved_value_vars if var.name not in joined_names ] + if var_names is not None: + names = set(var_names) + remaining_rvs = [var for var in remaining_rvs if var.name in names] + all_names = joined_names + remaining_rvs all_names = joined_names.copy() diff --git a/src/pyfunc.rs b/src/pyfunc.rs index 07b32d5..f914974 100644 --- a/src/pyfunc.rs +++ b/src/pyfunc.rs @@ -3,8 +3,8 @@ use std::sync::Arc; use anyhow::{anyhow, bail, Context, Result}; use arrow::{ array::{ - Array, ArrayBuilder, BooleanBuilder, FixedSizeListBuilder, Float32Builder, Float64Builder, - Int64Builder, ListBuilder, PrimitiveBuilder, StructBuilder, + Array, ArrayBuilder, BooleanBuilder, Float32Builder, Float64Builder, Int64Builder, + LargeListBuilder, PrimitiveBuilder, StructBuilder, }, datatypes::{DataType, Field, Float32Type, Float64Type, Int64Type}, }; @@ -16,7 +16,7 @@ use pyo3::{ Bound, Py, PyAny, PyErr, Python, }; use rand::Rng; -use rand_distr::{Distribution, StandardNormal, Uniform}; +use rand_distr::{Distribution, Uniform}; use smallvec::SmallVec; use thiserror::Error; @@ -37,21 +37,21 @@ impl PyVariable { ExpandDtype::Float64 {} => DataType::Float64, ExpandDtype::Float32 {} => DataType::Float32, ExpandDtype::Int64 {} => DataType::Int64, - ExpandDtype::BooleanArray { tensor_type } => { + ExpandDtype::BooleanArray { tensor_type: _ } => { let field = Arc::new(Field::new("item", DataType::Boolean, false)); - DataType::FixedSizeList(field, tensor_type.size() as i32) + DataType::LargeList(field) } ExpandDtype::ArrayFloat64 { tensor_type: _ } => { let field = Arc::new(Field::new("item", DataType::Float64, true)); - DataType::List(field) + DataType::LargeList(field) } - ExpandDtype::ArrayFloat32 { tensor_type } => { + ExpandDtype::ArrayFloat32 { tensor_type: _ } => { let field = Arc::new(Field::new("item", DataType::Float32, false)); - DataType::FixedSizeList(field, tensor_type.size() as i32) + DataType::LargeList(field) } - ExpandDtype::ArrayInt64 { tensor_type } => { + ExpandDtype::ArrayInt64 { tensor_type: _ } => { let field = Arc::new(Field::new("item", DataType::Int64, false)); - DataType::FixedSizeList(field, tensor_type.size() as i32) + DataType::LargeList(field) } } } @@ -368,10 +368,10 @@ impl DrawStorage for PyTrace { )?; builder.append_value(value.extract().expect("Return value from expand function could not be converted to int64")) }, - ExpandDtype::BooleanArray { tensor_type} => { - let builder: &mut FixedSizeListBuilder> = + ExpandDtype::BooleanArray { tensor_type } => { + let builder: &mut LargeListBuilder> = self.builder.field_builder(i).context( - "Builder has incorrect type", + "Builder has incorrect type. Expected LargeListBuilder of Bool", )?; let value_builder = builder.values().as_any_mut().downcast_mut::().context("Could not downcast builder to boolean type")?; let values: PyReadonlyArray1 = value.extract().context("Could not convert object to array")?; @@ -383,9 +383,9 @@ impl DrawStorage for PyTrace { }, ExpandDtype::ArrayFloat64 { tensor_type } => { //let builder: &mut FixedSizeListBuilder> = - let builder: &mut ListBuilder> = + let builder: &mut LargeListBuilder> = self.builder.field_builder(i).context( - "Builder has incorrect type", + "Builder has incorrect type. Expected LargeListBuilder of Float64", )?; let value_builder = builder.values().as_any_mut().downcast_mut::>().context("Could not downcast builder to float64 type")?; let values: PyReadonlyArray1 = value.extract().context("Could not convert object to array")?; @@ -396,9 +396,9 @@ impl DrawStorage for PyTrace { builder.append(true); }, ExpandDtype::ArrayFloat32 { tensor_type } => { - let builder: &mut FixedSizeListBuilder> = + let builder: &mut LargeListBuilder> = self.builder.field_builder(i).context( - "Builder has incorrect type", + "Builder has incorrect type. Expected LargeListBuilder of Float32", )?; let value_builder = builder.values().as_any_mut().downcast_mut::>().context("Could not downcast builder to float32 type")?; let values: PyReadonlyArray1 = value.extract().context("Could not convert object to array")?; @@ -409,9 +409,9 @@ impl DrawStorage for PyTrace { builder.append(true); }, ExpandDtype::ArrayInt64 {tensor_type} => { - let builder: &mut FixedSizeListBuilder> = + let builder: &mut LargeListBuilder> = self.builder.field_builder(i).context( - "Builder has incorrect type", + "Builder has incorrect type. Expected LargeListBuilder of Int64", )?; let value_builder = builder.values().as_any_mut().downcast_mut::>().context("Could not downcast builder to i64 type")?; let values: PyReadonlyArray1 = value.extract().context("Could not convert object to array")?; diff --git a/src/pymc.rs b/src/pymc.rs index b7862cf..98426c7 100644 --- a/src/pymc.rs +++ b/src/pymc.rs @@ -2,7 +2,8 @@ use std::{ffi::c_void, fmt::Display, sync::Arc}; use anyhow::{bail, Context, Result}; use arrow::{ - array::{Array, FixedSizeListArray, Float64Array, StructArray}, + array::{Array, Float64Array, LargeListArray, StructArray}, + buffer::OffsetBuffer, datatypes::{DataType, Field, Fields}, }; use itertools::{izip, Itertools}; @@ -13,7 +14,6 @@ use pyo3::{ types::{PyAnyMethods, PyList}, Bound, Py, PyAny, PyObject, PyResult, Python, }; -use rand::{distributions::Uniform, prelude::Distribution}; use thiserror::Error; @@ -170,11 +170,13 @@ impl<'model> DrawStorage for PyMcTrace<'model> { fn finalize(self) -> Result> { let (fields, arrays): (Vec<_>, _) = izip!(self.data, self.var_names, self.var_sizes) .map(|(data, name, size)| { + assert!(data.len() % size == 0); + let num_arrays = data.len() / size; let data = Float64Array::from(data); let item_field = Arc::new(Field::new("item", DataType::Float64, false)); - let array = - FixedSizeListArray::new(item_field.clone(), size as _, Arc::new(data), None); - let field = Field::new(name, DataType::FixedSizeList(item_field, size as _), false); + let offsets = OffsetBuffer::from_lengths((0..num_arrays).into_iter().map(|_| size)); + let array = LargeListArray::new(item_field.clone(), offsets, Arc::new(data), None); + let field = Field::new(name, DataType::LargeList(item_field), false); (Arc::new(field), Arc::new(array) as Arc) }) .unzip(); diff --git a/tests/test_pymc.py b/tests/test_pymc.py index d59fb18..12fc75c 100644 --- a/tests/test_pymc.py +++ b/tests/test_pymc.py @@ -193,6 +193,53 @@ def test_pymc_model_shared(backend, gradient_backend): nutpie.sample(compiled3, chains=1) +@parameterize_backends +def test_pymc_var_names(backend, gradient_backend): + with pm.Model() as model: + mu = pm.Data("mu", -0.1) + sigma = pm.Data("sigma", np.ones(3)) + a = pm.Normal("a", mu=mu, sigma=sigma, shape=3) + + b = pm.Deterministic("b", mu * a) + pm.Deterministic("c", mu * b) + + compiled = nutpie.compile_pymc_model( + model, + backend=backend, + gradient_backend=gradient_backend, + var_names=None, + ) + trace = nutpie.sample(compiled, chains=1, seed=1) + + # Check that variables are stored + assert hasattr(trace.posterior, "b") + assert hasattr(trace.posterior, "c") + + compiled = nutpie.compile_pymc_model( + model, + backend=backend, + gradient_backend=gradient_backend, + var_names=[], + ) + trace = nutpie.sample(compiled, chains=1, seed=1) + + # Check that variables are stored + assert not hasattr(trace.posterior, "b") + assert not hasattr(trace.posterior, "c") + + compiled = nutpie.compile_pymc_model( + model, + backend=backend, + gradient_backend=gradient_backend, + var_names=["b"], + ) + trace = nutpie.sample(compiled, chains=1, seed=1) + + # Check that variables are stored + assert hasattr(trace.posterior, "b") + assert not hasattr(trace.posterior, "c") + + @pytest.mark.parametrize( ("backend", "gradient_backend"), [