diff --git a/Cargo.lock b/Cargo.lock index c433a6a1..046bd304 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,28 +2,6 @@ # It is not intended for manual editing. version = 3 -[[package]] -name = "ahash" -version = "0.7.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fcb51a0695d8f838b1ee009b3fbf66bda078cd64590202a864a8f3e8c4315c47" -dependencies = [ - "getrandom", - "once_cell", - "version_check", -] - -[[package]] -name = "ahash" -version = "0.8.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c99f64d1e06488f620f932677e24bc6e2897582980441ae90a671415bd7ec2f" -dependencies = [ - "cfg-if", - "once_cell", - "version_check", -] - [[package]] name = "aho-corasick" version = "1.0.2" @@ -93,36 +71,12 @@ version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bddcadddf5e9015d310179a59bb28c4d4b9920ad0f11e8e14dbadf654890c9a6" -[[package]] -name = "ascii-canvas" -version = "3.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8824ecca2e851cec16968d54a01dd372ef8f95b244fb84b84e70128be347c3c6" -dependencies = [ - "term", -] - [[package]] name = "autocfg" version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" -[[package]] -name = "bit-set" -version = "0.5.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1" -dependencies = [ - "bit-vec", -] - -[[package]] -name = "bit-vec" -version = "0.6.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb" - [[package]] name = "bitflags" version = "1.3.2" @@ -239,10 +193,10 @@ dependencies = [ ] [[package]] -name = "crunchy" -version = "0.2.2" +name = "crossbeam-utils" +version = "0.8.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" [[package]] name = "crypto-common" @@ -264,27 +218,6 @@ dependencies = [ "crypto-common", ] -[[package]] -name = "dirs-next" -version = "2.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b98cf8ebf19c3d1b223e151f99a4f9f0690dca41414773390fc824184ac833e1" -dependencies = [ - "cfg-if", - "dirs-sys-next", -] - -[[package]] -name = "dirs-sys-next" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ebda144c4fe02d1f7ea1a7d9641b6fc6b580adcfa024ae48797ecdeb6825b4d" -dependencies = [ - "libc", - "redox_users", - "winapi", -] - [[package]] name = "dot-generator" version = "0.2.0" @@ -302,41 +235,46 @@ checksum = "675e35c02a51bb4d4618cb4885b3839ce6d1787c97b664474d9208d074742e20" [[package]] name = "egglog" -version = "0.3.0" -source = "git+https://github.com/egraphs-good/egglog?rev=b0db06832264c9b22694bd3de2bdacd55bbe9e32#b0db06832264c9b22694bd3de2bdacd55bbe9e32" +version = "0.4.0" +source = "git+https://github.com/saulshanabrook/egg-smol?rev=889ca7635368d7e382e16a93b2883aba82f1078f#889ca7635368d7e382e16a93b2883aba82f1078f" dependencies = [ "chrono", "clap", "egraph-serialize", "env_logger", - "generic_symbolic_expressions", - "hashbrown 0.14.1", + "hashbrown 0.15.2", + "im", "im-rc", "indexmap", "instant", - "lalrpop", - "lalrpop-util 0.20.2", "lazy_static", "log", - "num-integer", - "num-rational", - "num-traits", + "num", "ordered-float", - "regex", "rustc-hash", - "serde_json", "smallvec", "symbol_table", "thiserror", ] +[[package]] +name = "egglog-experimental" +version = "0.1.0" +source = "git+https://github.com/egraphs-good/egglog-experimental?rev=8a1b3d6ad2723a8438f51f05027161e51f37917c#8a1b3d6ad2723a8438f51f05027161e51f37917c" +dependencies = [ + "egglog", + "lazy_static", + "num", +] + [[package]] name = "egglog_python" version = "8.0.1" dependencies = [ "egglog", + "egglog-experimental", "egraph-serialize", - "lalrpop-util 0.22.0", + "lalrpop-util", "log", "ordered-float", "pyo3", @@ -359,21 +297,6 @@ dependencies = [ "serde_json", ] -[[package]] -name = "either" -version = "1.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" - -[[package]] -name = "ena" -version = "0.14.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c533630cf40e9caa44bd91aadc88a75d75a4c3a12b4cfde353cbed41daa1e1f1" -dependencies = [ - "log", -] - [[package]] name = "env_logger" version = "0.10.0" @@ -421,10 +344,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9fc0510504f03c51ada170672ac806f1f105a88aa97a5281117e1ddc3368e51a" [[package]] -name = "fixedbitset" -version = "0.4.2" +name = "foldhash" +version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" +checksum = "a0d2fde1f7b3d48b8395d5f2de76c18a528bd6a9cdde438df747bfcba3e05d6f" [[package]] name = "generic-array" @@ -436,12 +359,6 @@ dependencies = [ "version_check", ] -[[package]] -name = "generic_symbolic_expressions" -version = "5.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "597eb584fb7cfd1935294fc3608a453fc35a58dfa9da4299c8fd3bc75a4c0b4b" - [[package]] name = "getrandom" version = "0.2.10" @@ -471,21 +388,19 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.12.3" +version = "0.14.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" -dependencies = [ - "ahash 0.7.6", -] +checksum = "7dfda62a12f55daeae5015f81b0baea145391cb4520f86c248fc615d72640d12" [[package]] name = "hashbrown" -version = "0.14.1" +version = "0.15.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7dfda62a12f55daeae5015f81b0baea145391cb4520f86c248fc615d72640d12" +checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" dependencies = [ - "ahash 0.8.3", "allocator-api2", + "equivalent", + "foldhash", ] [[package]] @@ -512,6 +427,20 @@ version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" +[[package]] +name = "im" +version = "15.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0acd33ff0285af998aaf9b57342af478078f53492322fafc47450e09397e0e9" +dependencies = [ + "bitmaps", + "rand_core", + "rand_xoshiro", + "sized-chunks", + "typenum", + "version_check", +] + [[package]] name = "im-rc" version = "15.1.0" @@ -585,52 +514,12 @@ dependencies = [ "windows-sys", ] -[[package]] -name = "itertools" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1c173a5686ce8bfa551b3563d0c2170bf24ca44da99c7ca4bfdab5418c3fe57" -dependencies = [ - "either", -] - [[package]] name = "itoa" version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "af150ab688ff2122fcef229be89cb50dd66af9e01a4ff320cc137eecc9bacc38" -[[package]] -name = "lalrpop" -version = "0.20.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "55cb077ad656299f160924eb2912aa147d7339ea7d69e1b5517326fdcec3c1ca" -dependencies = [ - "ascii-canvas", - "bit-set", - "ena", - "itertools", - "lalrpop-util 0.20.2", - "petgraph", - "pico-args", - "regex", - "regex-syntax", - "string_cache", - "term", - "tiny-keccak", - "unicode-xid", - "walkdir", -] - -[[package]] -name = "lalrpop-util" -version = "0.20.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "507460a910eb7b32ee961886ff48539633b788a36b65692b95f225b844c82553" -dependencies = [ - "regex-automata", -] - [[package]] name = "lalrpop-util" version = "0.22.0" @@ -659,16 +548,6 @@ version = "0.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3852614a3bd9ca9804678ba6be5e3b8ce76dfc902cae004e3e0c44051b6e88db" -[[package]] -name = "lock_api" -version = "0.4.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1cc9717a20b1bb222f333e6a92fd32f7d8a18ddc5a3191a11af45dcbf4dcd16" -dependencies = [ - "autocfg", - "scopeguard", -] - [[package]] name = "log" version = "0.4.22" @@ -691,39 +570,64 @@ dependencies = [ ] [[package]] -name = "new_debug_unreachable" -version = "1.0.4" +name = "num" +version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e4a24736216ec316047a1fc4252e27dabb04218aa4a3f37c6e7ddbf1f9782b54" +checksum = "35bd024e8b2ff75562e5f34e7f4905839deb4b22955ef5e73d2fea1b9813cb23" +dependencies = [ + "num-bigint", + "num-complex", + "num-integer", + "num-iter", + "num-rational", + "num-traits", +] [[package]] name = "num-bigint" -version = "0.4.3" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f93ab6289c7b344a8a9f60f88d80aa20032336fe78da341afc91c8a2341fc75f" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" dependencies = [ - "autocfg", "num-integer", "num-traits", ] +[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", +] + [[package]] name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-iter" version = "0.1.45" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9" +checksum = "1429034a0490724d0075ebb2bc9e875d6503c3cf69e235a8941aa757d83ef5bf" dependencies = [ "autocfg", + "num-integer", "num-traits", ] [[package]] name = "num-rational" -version = "0.4.1" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0638a1c9d0a3c0914158145bc76cff373a75a627e6ecbfb71cbe6f453a5a19b0" +checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824" dependencies = [ - "autocfg", "num-bigint", "num-integer", "num-traits", @@ -731,9 +635,9 @@ dependencies = [ [[package]] name = "num-traits" -version = "0.2.15" +version = "0.2.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" dependencies = [ "autocfg", ] @@ -755,29 +659,6 @@ dependencies = [ "serde", ] -[[package]] -name = "parking_lot" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" -dependencies = [ - "lock_api", - "parking_lot_core", -] - -[[package]] -name = "parking_lot_core" -version = "0.9.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93f00c865fe7cabf650081affecd3871070f26767e7b2070a3ffae14c654b447" -dependencies = [ - "cfg-if", - "libc", - "redox_syscall 0.3.5", - "smallvec", - "windows-targets", -] - [[package]] name = "pest" version = "2.7.11" @@ -823,31 +704,6 @@ dependencies = [ "sha2", ] -[[package]] -name = "petgraph" -version = "0.6.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1d3afd2628e69da2be385eb6f2fd57c8ac7977ceeff6dc166ff1657b0e386a9" -dependencies = [ - "fixedbitset", - "indexmap", -] - -[[package]] -name = "phf_shared" -version = "0.10.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6796ad771acdc0123d2a88dc428b5e38ef24456743ddb1744ed628f9815c096" -dependencies = [ - "siphasher", -] - -[[package]] -name = "pico-args" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5be167a7af36ee22fe3115051bc51f6e6c7054c9348e28deb4f49bd6f705a315" - [[package]] name = "portable-atomic" version = "1.6.0" @@ -863,12 +719,6 @@ dependencies = [ "zerocopy", ] -[[package]] -name = "precomputed-hash" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c" - [[package]] name = "proc-macro2" version = "1.0.81" @@ -1002,15 +852,6 @@ dependencies = [ "rand_core", ] -[[package]] -name = "redox_syscall" -version = "0.2.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fb5a58c1855b4b6819d59012155603f0b22ad30cad752600aadfcb695265519a" -dependencies = [ - "bitflags 1.3.2", -] - [[package]] name = "redox_syscall" version = "0.3.5" @@ -1020,17 +861,6 @@ dependencies = [ "bitflags 1.3.2", ] -[[package]] -name = "redox_users" -version = "0.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b033d837a7cf162d7993aded9304e30a83213c648b6e389db233191f891e5c2b" -dependencies = [ - "getrandom", - "redox_syscall 0.2.16", - "thiserror", -] - [[package]] name = "regex" version = "1.10.4" @@ -1091,21 +921,6 @@ version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1ad4cc8da4ef723ed60bced201181d83791ad433213d8c24efffda1eec85d741" -[[package]] -name = "same-file" -version = "1.0.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" -dependencies = [ - "winapi-util", -] - -[[package]] -name = "scopeguard" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" - [[package]] name = "serde" version = "1.0.199" @@ -1150,12 +965,6 @@ dependencies = [ "digest", ] -[[package]] -name = "siphasher" -version = "0.3.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38b58827f4464d87d377d175e90bf58eb00fd8716ff0a62f80356b5e61555d0d" - [[package]] name = "sized-chunks" version = "0.6.5" @@ -1172,19 +981,6 @@ version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "62bb4feee49fdd9f707ef802e22365a35de4b7b299de4763d44bfea899442ff9" -[[package]] -name = "string_cache" -version = "0.8.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f91138e76242f575eb1d3b38b4f1362f10d3a43f47d182a5b359af488a02293b" -dependencies = [ - "new_debug_unreachable", - "once_cell", - "parking_lot", - "phf_shared", - "precomputed-hash", -] - [[package]] name = "strsim" version = "0.10.0" @@ -1193,12 +989,13 @@ checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" [[package]] name = "symbol_table" -version = "0.3.0" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "828f672b631c220bf6ea8a1d3b82c7d0fc998e5ba8373383d8604bc1e2a6245a" +checksum = "f19bffd69fb182e684d14e3c71d04c0ef33d1641ac0b9e81c712c734e83703bc" dependencies = [ - "ahash 0.7.6", - "hashbrown 0.12.3", + "crossbeam-utils", + "foldhash", + "hashbrown 0.15.2", ] [[package]] @@ -1237,22 +1034,11 @@ checksum = "cb94d2f3cc536af71caac6b6fcebf65860b347e7ce0cc9ebe8f70d3e521054ef" dependencies = [ "cfg-if", "fastrand", - "redox_syscall 0.3.5", + "redox_syscall", "rustix", "windows-sys", ] -[[package]] -name = "term" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c59df8ac95d96ff9bede18eb7300b0fda5e5d8d90960e76f8e14ae765eedbf1f" -dependencies = [ - "dirs-next", - "rustversion", - "winapi", -] - [[package]] name = "termcolor" version = "1.3.0" @@ -1282,15 +1068,6 @@ dependencies = [ "syn 2.0.60", ] -[[package]] -name = "tiny-keccak" -version = "2.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c9d3793400a45f954c52e73d068316d76b6f4e36977e3fcebb13a2721e80237" -dependencies = [ - "crunchy", -] - [[package]] name = "typenum" version = "1.17.0" @@ -1309,12 +1086,6 @@ version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "301abaae475aa91687eb82514b328ab47a211a533026cb25fc3e519b86adfc3c" -[[package]] -name = "unicode-xid" -version = "0.2.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f962df74c8c05a667b5ee8bcf162993134c104e96440b663c8daa176dc772d8c" - [[package]] name = "unindent" version = "0.2.3" @@ -1342,16 +1113,6 @@ version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" -[[package]] -name = "walkdir" -version = "2.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" -dependencies = [ - "same-file", - "winapi-util", -] - [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" diff --git a/Cargo.toml b/Cargo.toml index 5525a391..ceb65b94 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,6 +3,7 @@ name = "egglog_python" version = "8.0.1" edition = "2021" + # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [lib] name = "egglog" @@ -11,27 +12,23 @@ crate-type = ["cdylib"] [dependencies] pyo3 = { version = "0.22.5", features = ["extension-module"] } -# https://github.com/egraphs-good/egglog/compare/ceed816e9369570ffed9feeba157b19471dda70d...main -egglog = { git = "https://github.com/egraphs-good/egglog", rev = "b0db06832264c9b22694bd3de2bdacd55bbe9e32" } -# egglog = { path = "../egg-smol" } -# egglog = { git = "https://github.com/oflatt/egg-smol", branch = "oflatt-fast-terms" } -# egglog = { git = "https://github.com/saulshanabrook/egg-smol", rev = "a555b2f5e82c684442775cc1a5da94b71930113c" } +egglog = { git = "https://github.com/saulshanabrook/egg-smol", rev = "889ca7635368d7e382e16a93b2883aba82f1078f" } +egglog-experimental = { git = "https://github.com/egraphs-good/egglog-experimental", rev = "8a1b3d6ad2723a8438f51f05027161e51f37917c" } egraph-serialize = { version = "0.2.0", features = ["serde", "graphviz"] } -# egraph-serialize = { path = "../egraph-serialize", features = [ -# "serde", -# "graphviz", -# ] } serde_json = "1.0.132" pyo3-log = "0.11.0" log = "0.4.22" lalrpop-util = { version = "0.22", features = ["lexer"] } -ordered-float = "3.7" +ordered-float = "3.7.0" uuid = { version = "1.11.0", features = ["v4"] } -# Use unreleased version of egraph-serialize in egglog as well -# [patch.crates-io] -# egraph-serialize = { git = "https://github.com/egraphs-good/egraph-serialize", rev = "5838c036623e91540831745b1574539e01c8cb23" } -# egraph-serialize = { path = "../egraph-serialize" } +# Use unreleased version of egglog in experimental +[patch.'https://github.com/egraphs-good/egglog'] +# https://github.com/rust-lang/cargo/issues/5478#issuecomment-522719793 +egglog = { git = "https://github.com/saulshanabrook/egg-smol.git", rev = "889ca7635368d7e382e16a93b2883aba82f1078f" } + +# [replace] +# 'https://github.com/egraphs-good/egglog.git#egglog@0.3.0' = { git = "https://github.com/egraphs-good/egglog.git", rev = "215714e1cbb13ae9e21bed2f2e1bf95804571512" } # enable debug symbols for easier profiling # [profile.release] diff --git a/docs/changelog.md b/docs/changelog.md index ccca4290..98b903e1 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -14,6 +14,16 @@ _This project uses semantic versioning_ - Use `add_note` for exception to add more context, instead of raising a new exception, to make it easier to debug. - Add conversions from generic types to be supported at runtime and typing level (so can go from `(1, 2, 3)` to `TupleInt`) - Open files with webbrowser instead of internal graphviz util for better support +- Add support for not visualizing when using `.saturate()` method [#254](https://github.com/egraphs-good/egglog-python/pull/254) +- Upgrade [egglog](https://github.com/egraphs-good/egglog/compare/b0db06832264c9b22694bd3de2bdacd55bbe9e32...saulshanabrook:egg-smol:889ca7635368d7e382e16a93b2883aba82f1078f) + - This includes a few big changes to the underlying bindings, which I won't go over in full detail here. See the [pyi diff](https://github.com/egraphs-good/egglog-python/pull/258/files#diff-f34a5dd5d6568cd258ed9f786e5abce03df5ee95d356ea9e1b1b39e3505e5d62) for all public changes. + - Creates seperate parent classes for `BuiltinExpr` vs `Expr` (aka eqsort aka user defined expressions). This is to + allow us statically to differentiate between the two, to be more precise about what behavior is allowed. For example, + `union` can only take `Expr` and not `BuiltinExpr`. + - Removes deprecated support for modules and building functions off of the e-egraph. + - Updates function constructor to remove `default` and `on_merge`. You also can't set a `cost` when you use a `merge` + function or return a primitive. + - `eq` now only takes two args, instead of being able to compare any number of values. ## 8.0.1 (2024-10-24) diff --git a/docs/how-to-guides.md b/docs/how-to-guides.md index f51e975b..3e2abe89 100644 --- a/docs/how-to-guides.md +++ b/docs/how-to-guides.md @@ -6,13 +6,13 @@ file_format: mystnb ## Parsing and running program strings -You can provide your program in a special DSL language. You can parse this with {meth}`egglog.bindings.parse_program` and then run the result with You can parse this with {meth}`egglog.bindings.EGraph.run_program`:: +You can provide your program in a special DSL language. You can parse this with {meth}`egglog.bindings.EGraph.parse_program` and then run the result with You can parse this with {meth}`egglog.bindings.EGraph.run_program`:: ```{code-cell} -from egglog.bindings import EGraph, parse_program +from egglog.bindings import EGraph egraph = EGraph() -commands = parse_program("(check (= (+ 1 2) 3))") +commands = egraph.parse_program("(check (= (+ 1 2) 3))") commands ``` diff --git a/docs/reference/bindings.md b/docs/reference/bindings.md index 77199ce5..a677ec83 100644 --- a/docs/reference/bindings.md +++ b/docs/reference/bindings.md @@ -36,7 +36,7 @@ eqsat_basic = """(datatype Math (check (= expr1 expr2))""" egraph = EGraph() -commands = parse_program(eqsat_basic) +commands = egraph.parse_program(eqsat_basic) egraph.run_program(*commands) ``` diff --git a/docs/reference/egglog-translation.md b/docs/reference/egglog-translation.md index 9c30feb6..570be5f1 100644 --- a/docs/reference/egglog-translation.md +++ b/docs/reference/egglog-translation.md @@ -12,10 +12,7 @@ Any EGraph can also be converted to egglog with the `egraph.as_egglog_string` pr The currently unsupported features are: -- Proof mode: Not currently tested, but could add support if needed. -- Naive mode: Not currently exposed, but could add support - `(output ...)`: No examples in the tests, so not sure how this works. -- `(calc ...)`: Could be implemented, but haven't yet. ## Builtin Types @@ -124,6 +121,16 @@ def my_foo() -> i64: The static types on the decorator preserve the type of the underlying function, so that they can all be checked statically. +### Functions vs Constructors + +Egglog has changed how it handles functions, seperating them into two seperate commands: + +- `function` which can include a `merge` expression. +- `constructor` which can include a cost and requires the result to be an "eqsort" aka a non builtin type. + +Since this was added after the Python API was first created, we added support to automatically choose between the two based on the return type of the function and whether a merge function is provided. If the return type is a builtin type, it will be a `function`, otherwise it will be a `constructor`, unless it has a merge function +provided then it will always be a `function`. + ### Datatype functions In egglog, the `(datatype ...)` command can also be used to declare functions. All of the functions declared in this block return the type of the declared datatype. Similarly, in Python, any methods of an `Expr` will be registered automatically. These diff --git a/pyproject.toml b/pyproject.toml index 16dd80dc..bd08f40d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -227,6 +227,9 @@ strict_equality = true warn_unused_configs = true allow_redefinition = true exclude = ["__snapshots__", "_build", "^conftest.py$"] +# mypy_path = "python" +# explicit_package_bases = true +# namespace_packages = true [tool.maturin] python-source = "python" diff --git a/python/egglog/bindings.pyi b/python/egglog/bindings.pyi index bebcd798..1789469c 100644 --- a/python/egglog/bindings.pyi +++ b/python/egglog/bindings.pyi @@ -1,10 +1,84 @@ from datetime import timedelta -from fractions import Fraction from pathlib import Path from typing import TypeAlias from typing_extensions import final +__all__ = [ + "ActionCommand", + "AddRuleset", + "Best", + "BiRewriteCommand", + "Bool", + "Call", + "Change", + "Check", + "Constructor", + "Datatype", + "Datatypes", + "Delete", + "EGraph", + "EggSmolError", + "EgglogSpan", + "Eq", + "Expr_", + "Extract", + "Fact", + "Fail", + "Float", + "Function", + "IdentSort", + "Include", + "Input", + "Int", + "Let", + "Lit", + "NewSort", + "Output", + "Panic", + "PanicSpan", + "Pop", + "PrintFunction", + "PrintOverallStatistics", + "PrintSize", + "Push", + "PyObjectSort", + "QueryExtract", + "Relation", + "Repeat", + "Rewrite", + "RewriteCommand", + "Rule", + "RuleCommand", + "Run", + "RunConfig", + "RunReport", + "RunSchedule", + "RustSpan", + "Saturate", + "Schema", + "Sequence", + "SerializedEGraph", + "Set", + "SetOption", + "Simplify", + "Sort", + "SrcFile", + "String", + "SubVariants", + "Subsume", + "TermApp", + "TermDag", + "TermLit", + "TermVar", + "Union", + "Unit", + "UnstableCombinedRuleset", + "Var", + "Variant", + "Variants", +] + @final class SerializedEGraph: def inline_leaves(self) -> None: ... @@ -19,7 +93,6 @@ class PyObjectSort: def __init__(self) -> None: ... def store(self, __o: object, /) -> _Expr: ... -def parse_program(__input: str, /, filename: str | None = None) -> list[_Command]: ... @final class EGraph: def __init__( @@ -30,6 +103,7 @@ class EGraph: seminaive: bool = True, record: bool = False, ) -> None: ... + def parse_program(self, __input: str, /, filename: str | None = None) -> list[_Command]: ... def commands(self) -> str | None: ... def run_program(self, *commands: _Command) -> list[str]: ... def extract_report(self) -> _ExtractReport | None: ... @@ -47,7 +121,6 @@ class EGraph: def eval_f64(self, __expr: _Expr) -> float: ... def eval_string(self, __expr: _Expr) -> str: ... def eval_bool(self, __expr: _Expr) -> bool: ... - def eval_rational(self, __expr: _Expr) -> Fraction: ... @final class EggSmolError(Exception): @@ -57,20 +130,31 @@ class EggSmolError(Exception): # Spans ## +@final +class PanicSpan: + def __init__(self) -> None: ... + @final class SrcFile: - def __init__(self, name: str, contents: str | None = None) -> None: ... - name: str - contents: str | None + name: str | None + contents: str + def __init__(self, name: str | None, contents: str) -> None: ... @final -class Span: - def __init__(self, file: SrcFile, start: int, end: int) -> None: ... +class EgglogSpan: file: SrcFile - start: int - end: int + i: int + j: int + def __init__(self, file: SrcFile, i: int, j: int) -> None: ... -DUMMY_SPAN: Span = ... +@final +class RustSpan: + file: str + line: int + column: int + def __init__(self, file: str, line: int, column: int) -> None: ... + +_Span: TypeAlias = PanicSpan | EgglogSpan | RustSpan ## # Literals @@ -82,7 +166,7 @@ class Int: value: int @final -class F64: +class Float: value: float def __init__(self, value: float) -> None: ... @@ -100,7 +184,7 @@ class Bool: def __init__(self, b: bool) -> None: ... value: bool -_Literal: TypeAlias = Int | F64 | String | Bool | Unit +_Literal: TypeAlias = Int | Float | String | Bool | Unit ## # Expressions @@ -108,20 +192,20 @@ _Literal: TypeAlias = Int | F64 | String | Bool | Unit @final class Lit: - def __init__(self, span: Span, value: _Literal) -> None: ... - span: Span + def __init__(self, span: _Span, value: _Literal) -> None: ... + span: _Span value: _Literal @final class Var: - def __init__(self, span: Span, name: str) -> None: ... - span: Span + def __init__(self, span: _Span, name: str) -> None: ... + span: _Span name: str @final class Call: - def __init__(self, span: Span, name: str, args: list[_Expr]) -> None: ... - span: Span + def __init__(self, span: _Span, name: str, args: list[_Expr]) -> None: ... + span: _Span name: str args: list[_Expr] @@ -150,20 +234,16 @@ class TermApp: _Term: TypeAlias = TermLit | TermVar | TermApp -@final -class TermDag: - nodes: list[_Term] - hashcons: dict[_Term, int] - ## # Facts ## @final class Eq: - def __init__(self, span: Span, exprs: list[_Expr]) -> None: ... - span: Span - exprs: list[_Expr] + def __init__(self, span: _Span, left: _Expr, right: _Expr) -> None: ... + span: _Span + left: _Expr + right: _Expr @final class Fact: @@ -192,50 +272,50 @@ _Change: TypeAlias = Delete | Subsume @final class Let: - def __init__(self, span: Span, lhs: str, rhs: _Expr) -> None: ... - span: Span + def __init__(self, span: _Span, lhs: str, rhs: _Expr) -> None: ... + span: _Span lhs: str rhs: _Expr @final class Set: - def __init__(self, span: Span, lhs: str, args: list[_Expr], rhs: _Expr) -> None: ... - span: Span + def __init__(self, span: _Span, lhs: str, args: list[_Expr], rhs: _Expr) -> None: ... + span: _Span lhs: str args: list[_Expr] rhs: _Expr @final class Change: - span: Span + span: _Span change: _Change sym: str args: list[_Expr] - def __init__(self, span: Span, change: _Change, sym: str, args: list[_Expr]) -> None: ... + def __init__(self, span: _Span, change: _Change, sym: str, args: list[_Expr]) -> None: ... @final class Union: - def __init__(self, span: Span, lhs: _Expr, rhs: _Expr) -> None: ... - span: Span + def __init__(self, span: _Span, lhs: _Expr, rhs: _Expr) -> None: ... + span: _Span lhs: _Expr rhs: _Expr @final class Panic: - def __init__(self, span: Span, msg: str) -> None: ... - span: Span + def __init__(self, span: _Span, msg: str) -> None: ... + span: _Span msg: str @final class Expr_: # noqa: N801 - def __init__(self, span: Span, expr: _Expr) -> None: ... - span: Span + def __init__(self, span: _Span, expr: _Expr) -> None: ... + span: _Span expr: _Expr @final class Extract: - def __init__(self, span: Span, expr: _Expr, variants: _Expr) -> None: ... - span: Span + def __init__(self, span: _Span, expr: _Expr, variants: _Expr) -> None: ... + span: _Span expr: _Expr variants: _Expr @@ -245,35 +325,10 @@ _Action: TypeAlias = Let | Set | Change | Union | Panic | Expr_ | Extract # Other Structs ## -@final -class FunctionDecl: - span: Span - name: str - schema: Schema - default: _Expr | None - merge: _Expr | None - merge_action: list[_Action] - cost: int | None - unextractable: bool - ignore_viz: bool - - def __init__( - self, - span: Span, - name: str, - schema: Schema, - default: _Expr | None = None, - merge: _Expr | None = None, - merge_action: list[_Action] = [], - cost: int | None = None, - unextractable: bool = False, - ignore_viz: bool = False, - ) -> None: ... - @final class Variant: - def __init__(self, span: Span, name: str, types: list[str], cost: int | None = None) -> None: ... - span: Span + def __init__(self, span: _Span, name: str, types: list[str], cost: int | None = None) -> None: ... + span: _Span name: str types: list[str] cost: int | None @@ -286,19 +341,19 @@ class Schema: @final class Rule: - span: Span + span: _Span head: list[_Action] body: list[_Fact] - def __init__(self, span: Span, head: list[_Action], body: list[_Fact]) -> None: ... + def __init__(self, span: _Span, head: list[_Action], body: list[_Fact]) -> None: ... @final class Rewrite: - span: Span + span: _Span lhs: _Expr rhs: _Expr conditions: list[_Fact] - def __init__(self, span: Span, lhs: _Expr, rhs: _Expr, conditions: list[_Fact] = []) -> None: ... + def __init__(self, span: _Span, lhs: _Expr, rhs: _Expr, conditions: list[_Fact] = []) -> None: ... @final class RunConfig: @@ -354,28 +409,28 @@ _ExtractReport: TypeAlias = Variants | Best @final class Saturate: - span: Span + span: _Span schedule: _Schedule - def __init__(self, span: Span, schedule: _Schedule) -> None: ... + def __init__(self, span: _Span, schedule: _Schedule) -> None: ... @final class Repeat: - span: Span + span: _Span length: int schedule: _Schedule - def __init__(self, span: Span, length: int, schedule: _Schedule) -> None: ... + def __init__(self, span: _Span, length: int, schedule: _Schedule) -> None: ... @final class Run: - span: Span + span: _Span config: RunConfig - def __init__(self, span: Span, config: RunConfig) -> None: ... + def __init__(self, span: _Span, config: RunConfig) -> None: ... @final class Sequence: - span: Span + span: _Span schedules: list[_Schedule] - def __init__(self, span: Span, schedules: list[_Schedule]) -> None: ... + def __init__(self, span: _Span, schedules: list[_Schedule]) -> None: ... _Schedule: TypeAlias = Saturate | Repeat | Run | Sequence @@ -408,28 +463,31 @@ class SetOption: @final class Datatype: - span: Span + span: _Span name: str variants: list[Variant] - def __init__(self, span: Span, name: str, variants: list[Variant]) -> None: ... + def __init__(self, span: _Span, name: str, variants: list[Variant]) -> None: ... @final class Datatypes: - span: Span - datatypes: list[tuple[Span, str, _Subdatatypes]] - def __init__(self, span: Span, datatypes: list[tuple[Span, str, _Subdatatypes]]) -> None: ... + span: _Span + datatypes: list[tuple[_Span, str, _Subdatatypes]] + def __init__(self, span: _Span, datatypes: list[tuple[_Span, str, _Subdatatypes]]) -> None: ... @final class Sort: - span: Span + span: _Span name: str presort_and_args: tuple[str, list[_Expr]] | None - def __init__(self, span: Span, name: str, presort_and_args: tuple[str, list[_Expr]] | None = None) -> None: ... + def __init__(self, span: _Span, name: str, presort_and_args: tuple[str, list[_Expr]] | None = None) -> None: ... @final class Function: - decl: FunctionDecl - def __init__(self, decl: FunctionDecl) -> None: ... + span: _Span + name: str + schema: Schema + merge: _Expr | None + def __init__(self, span: _Span, name: str, schema: Schema, merge: _Expr | None) -> None: ... @final class AddRuleset: @@ -470,50 +528,50 @@ class RunSchedule: @final class Simplify: - span: Span + span: _Span expr: _Expr schedule: _Schedule - def __init__(self, span: Span, expr: _Expr, schedule: _Schedule) -> None: ... + def __init__(self, span: _Span, expr: _Expr, schedule: _Schedule) -> None: ... @final class QueryExtract: - span: Span + span: _Span variants: int expr: _Expr - def __init__(self, span: Span, variants: int, expr: _Expr) -> None: ... + def __init__(self, span: _Span, variants: int, expr: _Expr) -> None: ... @final class Check: - span: Span + span: _Span facts: list[_Fact] - def __init__(self, span: Span, facts: list[_Fact]) -> None: ... + def __init__(self, span: _Span, facts: list[_Fact]) -> None: ... @final class PrintFunction: - span: Span + span: _Span name: str length: int - def __init__(self, span: Span, name: str, length: int) -> None: ... + def __init__(self, span: _Span, name: str, length: int) -> None: ... @final class PrintSize: - span: Span + span: _Span name: str | None - def __init__(self, span: Span, name: str | None) -> None: ... + def __init__(self, span: _Span, name: str | None) -> None: ... @final class Output: - span: Span + span: _Span file: str exprs: list[_Expr] - def __init__(self, span: Span, file: str, exprs: list[_Expr]) -> None: ... + def __init__(self, span: _Span, file: str, exprs: list[_Expr]) -> None: ... @final class Input: - span: Span + span: _Span name: str file: str - def __init__(self, span: Span, name: str, file: str) -> None: ... + def __init__(self, span: _Span, name: str, file: str) -> None: ... @final class Push: @@ -522,29 +580,38 @@ class Push: @final class Pop: - span: Span + span: _Span length: int - def __init__(self, span: Span, length: int) -> None: ... + def __init__(self, span: _Span, length: int) -> None: ... @final class Fail: - span: Span + span: _Span command: _Command - def __init__(self, span: Span, command: _Command) -> None: ... + def __init__(self, span: _Span, command: _Command) -> None: ... @final class Include: - span: Span + span: _Span path: str - def __init__(self, span: Span, path: str) -> None: ... + def __init__(self, span: _Span, path: str) -> None: ... @final class Relation: - span: Span - constructor: str + span: _Span + name: str inputs: list[str] - def __init__(self, span: Span, constructor: str, inputs: list[str]) -> None: ... + def __init__(self, span: _Span, name: str, inputs: list[str]) -> None: ... + +@final +class Constructor: + span: _Span + name: str + schema: Schema + cost: int | None + unextractable: bool + def __init__(self, span: _Span, name: str, schema: Schema, cost: int | None, unextractable: bool) -> None: ... @final class PrintOverallStatistics: @@ -582,6 +649,22 @@ _Command: TypeAlias = ( | Relation | PrintOverallStatistics | UnstableCombinedRuleset + | Constructor ) -def termdag_term_to_expr(termdag: TermDag, term: _Term) -> _Expr: ... +## +# TermDag +## + +@final +class TermDag: + def __init__(self) -> None: ... + def size(self) -> int: ... + def lookup(self, node: _Term) -> int: ... + def get(self, id: int) -> _Term: ... + def app(self, sym: str, children: list[int]) -> _Term: ... + def lit(self, lit: _Literal) -> _Term: ... + def var(self, sym: str) -> _Term: ... + def expr_to_term(self, expr: _Expr) -> _Term: ... + def term_to_expr(self, term: _Term, span: _Span) -> _Expr: ... + def to_string(self, term: _Term) -> str: ... diff --git a/python/egglog/builtins.py b/python/egglog/builtins.py index 9818ccaa..e6fd3c42 100644 --- a/python/egglog/builtins.py +++ b/python/egglog/builtins.py @@ -12,7 +12,7 @@ from typing_extensions import TypeVarTuple, Unpack from .conversion import convert, converter, get_type_args -from .egraph import Expr, Unit, function, get_current_ruleset, method +from .egraph import BaseExpr, BuiltinExpr, Unit, function, get_current_ruleset, method from .functionalize import functionalize from .runtime import RuntimeClass, RuntimeExpr, RuntimeFunction from .thunk import Thunk @@ -46,7 +46,7 @@ ] -class String(Expr, builtin=True): +class String(BuiltinExpr): def __init__(self, value: str) -> None: ... @method(egg_fn="replace") @@ -65,7 +65,7 @@ def join(*strings: StringLike) -> String: ... BoolLike = Union["Bool", bool] -class Bool(Expr, egg_sort="bool", builtin=True): +class Bool(BuiltinExpr, egg_sort="bool"): def __init__(self, value: bool) -> None: ... @method(egg_fn="not") @@ -90,7 +90,7 @@ def implies(self, other: BoolLike) -> Bool: ... i64Like: TypeAlias = Union["i64", int] # noqa: N816, PYI042 -class i64(Expr, builtin=True): # noqa: N801 +class i64(BuiltinExpr): # noqa: N801 def __init__(self, value: int) -> None: ... @method(egg_fn="+") @@ -192,7 +192,7 @@ def count_matches(s: StringLike, pattern: StringLike) -> i64: ... f64Like: TypeAlias = Union["f64", float] # noqa: N816, PYI042 -class f64(Expr, builtin=True): # noqa: N801 +class f64(BuiltinExpr): # noqa: N801 def __init__(self, value: float) -> None: ... @method(egg_fn="neg") @@ -260,11 +260,11 @@ def to_string(self) -> String: ... converter(float, f64, f64) -T = TypeVar("T", bound=Expr) -V = TypeVar("V", bound=Expr) +T = TypeVar("T", bound=BaseExpr) +V = TypeVar("V", bound=BaseExpr) -class Map(Expr, Generic[T, V], builtin=True): +class Map(BuiltinExpr, Generic[T, V]): @method(egg_fn="map-empty") @classmethod def empty(cls) -> Map[T, V]: ... @@ -304,7 +304,7 @@ def rebuild(self) -> Map[T, V]: ... MapLike: TypeAlias = Map[T, V] | dict[TO, VO] -class Set(Expr, Generic[T], builtin=True): +class Set(BuiltinExpr, Generic[T]): @method(egg_fn="set-of") def __init__(self, *args: T) -> None: ... @@ -348,7 +348,7 @@ def rebuild(self) -> Set[T]: ... SetLike: TypeAlias = Set[T] | set[TO] -class Rational(Expr, builtin=True): +class Rational(BuiltinExpr): @method(egg_fn="rational") def __init__(self, num: i64Like, den: i64Like) -> None: ... @@ -409,7 +409,7 @@ def numer(self) -> i64: ... def denom(self) -> i64: ... -class Vec(Expr, Generic[T], builtin=True): +class Vec(BuiltinExpr, Generic[T]): @method(egg_fn="vec-of") def __init__(self, *args: T) -> None: ... @@ -460,7 +460,7 @@ def set(self, index: i64Like, value: T) -> Vec[T]: ... VecLike: TypeAlias = Vec[T] | tuple[TO, ...] | list[TO] -class PyObject(Expr, builtin=True): +class PyObject(BuiltinExpr): def __init__(self, value: object) -> None: ... @method(egg_fn="py-from-string") @@ -530,7 +530,7 @@ def py_exec(code: StringLike, globals: object = PyObject.dict(), locals: object T3 = TypeVar("T3") -class UnstableFn(Expr, Generic[T, Unpack[TS]], builtin=True): +class UnstableFn(BuiltinExpr, Generic[T, Unpack[TS]]): @overload def __init__(self, f: Callable[[Unpack[TS]], T]) -> None: ... diff --git a/python/egglog/conversion.py b/python/egglog/conversion.py index 52474b2a..a0adaea0 100644 --- a/python/egglog/conversion.py +++ b/python/egglog/conversion.py @@ -14,7 +14,7 @@ if TYPE_CHECKING: from collections.abc import Callable, Generator - from .egraph import Expr + from .egraph import BaseExpr __all__ = ["convert", "convert_to_same_type", "converter", "resolve_literal", "ConvertError"] # Mapping from (source type, target type) to and function which takes in the runtimes values of the source and return the target @@ -33,7 +33,7 @@ def _retrieve_conversion_decls() -> Declarations: T = TypeVar("T") -V = TypeVar("V", bound="Expr") +V = TypeVar("V", bound="BaseExpr") class ConvertError(Exception): diff --git a/python/egglog/declarations.py b/python/egglog/declarations.py index 3d1cee8f..ac989e12 100644 --- a/python/egglog/declarations.py +++ b/python/egglog/declarations.py @@ -32,6 +32,7 @@ "CommandDecl", "ConstantDecl", "ConstantRef", + "ConstructorDecl", "Declarations", "Declarations", "DeclerationsLike", @@ -121,7 +122,7 @@ def upcast_declerations(declerations_like: Iterable[DeclerationsLike]) -> list[D @dataclass class Declarations: _unnamed_functions: set[UnnamedFunctionRef] = field(default_factory=set) - _functions: dict[str, FunctionDecl | RelationDecl] = field(default_factory=dict) + _functions: dict[str, FunctionDecl | RelationDecl | ConstructorDecl] = field(default_factory=dict) _constants: dict[str, ConstantDecl] = field(default_factory=dict) _classes: dict[str, ClassDecl] = field(default_factory=dict) _rulesets: dict[str, RulesetDecl | CombinedRulesetDecl] = field(default_factory=lambda: {"": RulesetDecl([])}) @@ -198,11 +199,14 @@ def get_callable_decl(self, ref: CallableRef) -> CallableDecl: # noqa: PLR0911 assert init_fn return init_fn case UnnamedFunctionRef(): - return ref.to_function_decl() + return ConstructorDecl(ref.signature) + assert_never(ref) def set_function_decl( - self, ref: FunctionRef | MethodRef | ClassMethodRef | PropertyRef | InitRef, decl: FunctionDecl + self, + ref: FunctionRef | MethodRef | ClassMethodRef | PropertyRef | InitRef, + decl: FunctionDecl | ConstructorDecl, ) -> None: match ref: case FunctionRef(name): @@ -242,12 +246,12 @@ class ClassDecl: egg_name: str | None = None type_vars: tuple[ClassTypeVarRef, ...] = () builtin: bool = False - init: FunctionDecl | None = None - class_methods: dict[str, FunctionDecl] = field(default_factory=dict) + init: ConstructorDecl | FunctionDecl | None = None + class_methods: dict[str, FunctionDecl | ConstructorDecl] = field(default_factory=dict) # These have to be seperate from class_methods so that printing them can be done easily class_variables: dict[str, ConstantDecl] = field(default_factory=dict) - methods: dict[str, FunctionDecl] = field(default_factory=dict) - properties: dict[str, FunctionDecl] = field(default_factory=dict) + methods: dict[str, FunctionDecl | ConstructorDecl] = field(default_factory=dict) + properties: dict[str, FunctionDecl | ConstructorDecl] = field(default_factory=dict) preserved_methods: dict[str, Callable] = field(default_factory=dict) @@ -350,20 +354,19 @@ class UnnamedFunctionRef: args: tuple[TypedExprDecl, ...] res: TypedExprDecl - def to_function_decl(self) -> FunctionDecl: + @property + def signature(self) -> FunctionSignature: arg_types = [] arg_names = [] for a in self.args: arg_types.append(a.tp.to_var()) assert isinstance(a.expr, VarDecl) arg_names.append(a.expr.name) - return FunctionDecl( - FunctionSignature( - arg_types=tuple(arg_types), - arg_names=tuple(arg_names), - arg_defaults=(None,) * len(self.args), - return_type=self.res.tp.to_var(), - ), + return FunctionSignature( + arg_types=tuple(arg_types), + arg_names=tuple(arg_names), + arg_defaults=(None,) * len(self.args), + return_type=self.res.tp.to_var(), ) @property @@ -434,16 +437,13 @@ class RelationDecl: arg_defaults: tuple[ExprDecl | None, ...] egg_name: str | None - def to_function_decl(self) -> FunctionDecl: - return FunctionDecl( - FunctionSignature( - arg_types=tuple(a.to_var() for a in self.arg_types), - arg_names=tuple(f"__{i}" for i in range(len(self.arg_types))), - arg_defaults=self.arg_defaults, - return_type=TypeRefWithVars("Unit"), - ), - egg_name=self.egg_name, - default=LitDecl(None), + @property + def signature(self) -> FunctionSignature: + return FunctionSignature( + arg_types=tuple(a.to_var() for a in self.arg_types), + arg_names=tuple(f"__{i}" for i in range(len(self.arg_types))), + arg_defaults=self.arg_defaults, + return_type=TypeRefWithVars("Unit"), ) @@ -456,11 +456,9 @@ class ConstantDecl: type_ref: JustTypeRef egg_name: str | None = None - def to_function_decl(self) -> FunctionDecl: - return FunctionDecl( - FunctionSignature(return_type=self.type_ref.to_var()), - egg_name=self.egg_name, - ) + @property + def signature(self) -> FunctionSignature: + return FunctionSignature(return_type=self.type_ref.to_var()) # special cases for partial function creation and application, which cannot use the normal python rules @@ -492,20 +490,20 @@ def mutates(self) -> bool: @dataclass(frozen=True) class FunctionDecl: signature: FunctionSignature | SpecialFunctions = field(default_factory=FunctionSignature) - # Egg params builtin: bool = False egg_name: str | None = None - cost: int | None = None - default: ExprDecl | None = None - on_merge: tuple[ActionDecl, ...] = () merge: ExprDecl | None = None - unextractable: bool = False - def to_function_decl(self) -> FunctionDecl: - return self + +@dataclass(frozen=True) +class ConstructorDecl: + signature: FunctionSignature = field(default_factory=FunctionSignature) + egg_name: str | None = None + cost: int | None = None + unextractable: bool = False -CallableDecl: TypeAlias = RelationDecl | ConstantDecl | FunctionDecl +CallableDecl: TypeAlias = RelationDecl | ConstantDecl | FunctionDecl | ConstructorDecl ## # Expressions @@ -697,7 +695,8 @@ class RunDecl: @dataclass(frozen=True) class EqDecl: tp: JustTypeRef - exprs: tuple[ExprDecl, ...] + left: ExprDecl + right: ExprDecl @dataclass(frozen=True) diff --git a/python/egglog/egraph.py b/python/egglog/egraph.py index be2c9319..adeea9d2 100644 --- a/python/egglog/egraph.py +++ b/python/egglog/egraph.py @@ -4,7 +4,6 @@ import inspect import pathlib import tempfile -from abc import abstractmethod from collections.abc import Callable, Generator, Iterable from contextvars import ContextVar, Token from dataclasses import InitVar, dataclass, field @@ -17,6 +16,7 @@ ClassVar, Generic, Literal, + Never, NoReturn, TypeAlias, TypedDict, @@ -27,7 +27,7 @@ ) import graphviz -from typing_extensions import ParamSpec, Self, Unpack, assert_never, deprecated +from typing_extensions import ParamSpec, Self, Unpack, assert_never from . import bindings from .conversion import * @@ -36,7 +36,6 @@ from .ipython_magic import IN_IPYTHON from .pretty import pretty_decl from .runtime import * -from .runtime import resolve_type_annotation_mutate from .thunk import * if TYPE_CHECKING: @@ -48,11 +47,12 @@ "Command", "Command", "EGraph", + "BuiltinExpr", + "BaseExpr", "Expr", "Fact", "Fact", "GraphvizKwargs", - "Module", "RewriteOrRule", "Ruleset", "Schedule", @@ -63,7 +63,6 @@ "_RewriteBuilder", "_SetBuilder", "_UnionBuilder", - "action_command", "birewrite", "check", "check_eq", @@ -95,13 +94,14 @@ T = TypeVar("T") P = ParamSpec("P") -TYPE = TypeVar("TYPE", bound="type[Expr]") -CALLABLE = TypeVar("CALLABLE", bound=Callable) +EXPR_TYPE = TypeVar("EXPR_TYPE", bound="type[Expr]") +BASE_EXPR_TYPE = TypeVar("BASE_EXPR_TYPE", bound="type[BaseExpr]") EXPR = TypeVar("EXPR", bound="Expr") -E1 = TypeVar("E1", bound="Expr") -E2 = TypeVar("E2", bound="Expr") -E3 = TypeVar("E3", bound="Expr") -E4 = TypeVar("E4", bound="Expr") +BASE_EXPR = TypeVar("BASE_EXPR", bound="BaseExpr") +BE1 = TypeVar("BE1", bound="BaseExpr") +BE2 = TypeVar("BE2", bound="BaseExpr") +BE3 = TypeVar("BE3", bound="BaseExpr") +BE4 = TypeVar("BE4", bound="BaseExpr") # Attributes which are sometimes added to classes by the interpreter or the dataclass decorator, or by ipython. # We ignore these when inspecting the class. @@ -186,212 +186,15 @@ def check(x: FactLike, schedule: Schedule | None = None, *given: ActionLike) -> egraph.check(x) -@dataclass -class _BaseModule: - """ - Base Module which provides methods to register sorts, expressions, actions etc. - - Inherited by: - - EGraph: Holds a live EGraph instance - - Builtins: Stores a list of the builtins which have already been pre-regsietered - - Module: Stores a list of commands and additional declerations - """ - - # TODO: If we want to preserve existing semantics, then we use the module to find the default schedules - # and add them to the - - modules: InitVar[list[Module]] = [] # noqa: RUF008 - - # TODO: Move commands to Decleraration instance. Pass in is_builtins to declerations so we can skip adding commands for those. Pass in from module, set as argument of module and subclcass - - # Any modules you want to depend on - # # All dependencies flattened - _flatted_deps: list[Module] = field(init=False, default_factory=list) - # _mod_decls: ModuleDeclarations = field(init=False) - - def __post_init__(self, modules: list[Module]) -> None: - for mod in modules: - for child_mod in [*mod._flatted_deps, mod]: - if child_mod not in self._flatted_deps: - self._flatted_deps.append(child_mod) - - @deprecated("Remove this decorator and move the egg_sort to the class statement, i.e. E(Expr, egg_sort='MySort').") - @overload - def class_(self, *, egg_sort: str) -> Callable[[TYPE], TYPE]: ... - - @deprecated("Remove this decorator. Simply subclassing Expr is enough now.") - @overload - def class_(self, cls: TYPE, /) -> TYPE: ... - - def class_(self, *args, **kwargs) -> Any: - """ - Registers a class. - """ - if kwargs: - msg = "Switch to subclassing from Expr and passing egg_sort as a keyword arg to the class constructor" - raise NotImplementedError(msg) - - assert len(args) == 1 - return args[0] - - @overload - def method( - self, - *, - preserve: Literal[True], - ) -> Callable[[CALLABLE], CALLABLE]: ... - - @overload - def method( - self, - *, - egg_fn: str | None = None, - cost: int | None = None, - merge: Callable[[Any, Any], Any] | None = None, - on_merge: Callable[[Any, Any], Iterable[ActionLike]] | None = None, - mutates_self: bool = False, - unextractable: bool = False, - ) -> Callable[[CALLABLE], CALLABLE]: ... - - @overload - def method( - self, - *, - egg_fn: str | None = None, - cost: int | None = None, - default: EXPR | None = None, - merge: Callable[[EXPR, EXPR], EXPR] | None = None, - on_merge: Callable[[EXPR, EXPR], Iterable[ActionLike]] | None = None, - mutates_self: bool = False, - unextractable: bool = False, - ) -> Callable[[Callable[P, EXPR]], Callable[P, EXPR]]: ... - - @deprecated("Use top level method function instead") - def method( - self, - *, - egg_fn: str | None = None, - cost: int | None = None, - default: EXPR | None = None, - merge: Callable[[EXPR, EXPR], EXPR] | None = None, - on_merge: Callable[[EXPR, EXPR], Iterable[ActionLike]] | None = None, - preserve: bool = False, - mutates_self: bool = False, - unextractable: bool = False, - ) -> Callable[[Callable[P, EXPR]], Callable[P, EXPR]]: - return lambda fn: _WrappedMethod( - egg_fn, cost, default, merge, on_merge, fn, preserve, mutates_self, unextractable, False - ) - - @overload - def function(self, fn: CALLABLE, /) -> CALLABLE: ... - - @overload - def function( - self, - *, - egg_fn: str | None = None, - cost: int | None = None, - merge: Callable[[Any, Any], Any] | None = None, - on_merge: Callable[[Any, Any], Iterable[ActionLike]] | None = None, - mutates_first_arg: bool = False, - unextractable: bool = False, - ) -> Callable[[CALLABLE], CALLABLE]: ... - - @overload - def function( - self, - *, - egg_fn: str | None = None, - cost: int | None = None, - default: EXPR | None = None, - merge: Callable[[EXPR, EXPR], EXPR] | None = None, - on_merge: Callable[[EXPR, EXPR], Iterable[ActionLike]] | None = None, - mutates_first_arg: bool = False, - unextractable: bool = False, - ) -> Callable[[Callable[P, EXPR]], Callable[P, EXPR]]: ... - - @deprecated("Use top level function `function` instead") - def function(self, *args, **kwargs) -> Any: - """ - Registers a function. - """ - fn_locals = currentframe().f_back.f_back.f_locals # type: ignore[union-attr] - # If we have any positional args, then we are calling it directly on a function - if args: - assert len(args) == 1 - return _FunctionConstructor(fn_locals)(args[0]) - # otherwise, we are passing some keyword args, so save those, and then return a partial - return _FunctionConstructor(fn_locals, **kwargs) - - @deprecated("Use top level `ruleset` function instead") - def ruleset(self, name: str) -> Ruleset: - return Ruleset(name) - - # Overload to support aritys 0-4 until variadic generic support map, so we can map from type to value - @overload - def relation( - self, name: str, tp1: type[E1], tp2: type[E2], tp3: type[E3], tp4: type[E4], / - ) -> Callable[[E1, E2, E3, E4], Unit]: ... - - @overload - def relation(self, name: str, tp1: type[E1], tp2: type[E2], tp3: type[E3], /) -> Callable[[E1, E2, E3], Unit]: ... - - @overload - def relation(self, name: str, tp1: type[E1], tp2: type[E2], /) -> Callable[[E1, E2], Unit]: ... - - @overload - def relation(self, name: str, tp1: type[T], /, *, egg_fn: str | None = None) -> Callable[[T], Unit]: ... - - @overload - def relation(self, name: str, /, *, egg_fn: str | None = None) -> Callable[[], Unit]: ... - - @deprecated("Use top level relation function instead") - def relation(self, name: str, /, *tps: type, egg_fn: str | None = None) -> Callable[..., Unit]: - """ - Defines a relation, which is the same as a function which returns unit. - """ - return relation(name, *tps, egg_fn=egg_fn) - - @deprecated("Use top level constant function instead") - def constant(self, name: str, tp: type[EXPR], egg_name: str | None = None) -> EXPR: - """ - - Defines a named constant of a certain type. - - This is the same as defining a nullary function with a high cost. - # TODO: Rename as declare to match eggglog? - """ - return constant(name, tp, egg_name=egg_name) - - def register( - self, - /, - command_or_generator: ActionLike | RewriteOrRule | RewriteOrRuleGenerator, - *command_likes: ActionLike | RewriteOrRule, - ) -> None: - """ - Registers any number of rewrites or rules. - """ - if isinstance(command_or_generator, FunctionType): - assert not command_likes - current_frame = inspect.currentframe() - assert current_frame - original_frame = current_frame.f_back - assert original_frame - command_likes = tuple(_rewrite_or_rule_generator(command_or_generator, original_frame)) - else: - command_likes = (cast(CommandLike, command_or_generator), *command_likes) - commands = [_command_like(c) for c in command_likes] - self._register_commands(commands) +# We seperate the function and method overloads to make it simpler to know if we are modifying a function or method, +# So that we can add the functions eagerly to the registry and wait on the methods till we process the class. - @abstractmethod - def _register_commands(self, cmds: list[Command]) -> None: - raise NotImplementedError +CALLABLE = TypeVar("CALLABLE", bound=Callable) +CONSTRUCTOR_CALLABLE = TypeVar("CONSTRUCTOR_CALLABLE", bound=Callable[..., "Expr | None"]) -# We seperate the function and method overloads to make it simpler to know if we are modifying a function or method, -# So that we can add the functions eagerly to the registry and wait on the methods till we process the class. +EXPR_NONE = TypeVar("EXPR_NONE", bound="Expr | None") +BASE_EXPR_NONE = TypeVar("BASE_EXPR_NONE", bound="BaseExpr | None") @overload @@ -401,58 +204,121 @@ def method( ) -> Callable[[CALLABLE], CALLABLE]: ... -# We have to seperate method/function overloads for those that use the T params and those that don't -# Otherwise, if you say just pass in `cost` then the T param is inferred as `Nothing` and -# It will break the typing. +# function wihout merge +@overload +def method( + *, + egg_fn: str | None = ..., + mutates_self: bool = ..., +) -> Callable[[CALLABLE], CALLABLE]: ... +# function @overload def method( *, - egg_fn: str | None = None, - cost: int | None = None, - merge: Callable[[Any, Any], Any] | None = None, - on_merge: Callable[[Any, Any], Iterable[ActionLike]] | None = None, - mutates_self: bool = False, - unextractable: bool = False, - subsume: bool = False, -) -> Callable[[CALLABLE], CALLABLE]: ... + egg_fn: str | None = ..., + merge: Callable[[BASE_EXPR, BASE_EXPR], BASE_EXPR] | None = ..., + mutates_self: bool = ..., +) -> Callable[[Callable[P, BASE_EXPR]], Callable[P, BASE_EXPR]]: ... +# constructor @overload def method( *, - egg_fn: str | None = None, - cost: int | None = None, - default: EXPR | None = None, - merge: Callable[[EXPR, EXPR], EXPR] | None = None, - on_merge: Callable[[EXPR, EXPR], Iterable[ActionLike]] | None = None, - mutates_self: bool = False, - unextractable: bool = False, - subsume: bool = False, -) -> Callable[[Callable[P, EXPR]], Callable[P, EXPR]]: ... + egg_fn: str | None = ..., + cost: int | None = ..., + mutates_self: bool = ..., + unextractable: bool = ..., + subsume: bool = ..., +) -> Callable[[Callable[P, EXPR_NONE]], Callable[P, EXPR_NONE]]: ... def method( *, egg_fn: str | None = None, cost: int | None = None, - default: EXPR | None = None, - merge: Callable[[EXPR, EXPR], EXPR] | None = None, - on_merge: Callable[[EXPR, EXPR], Iterable[ActionLike]] | None = None, + merge: Callable[[BASE_EXPR, BASE_EXPR], BASE_EXPR] | None = None, preserve: bool = False, mutates_self: bool = False, unextractable: bool = False, subsume: bool = False, -) -> Callable[[Callable[P, EXPR]], Callable[P, EXPR]]: +) -> Callable[[Callable[P, BASE_EXPR_NONE]], Callable[P, BASE_EXPR_NONE]]: """ Any method can be decorated with this to customize it's behavior. This is only supported in classes which subclass :class:`Expr`. """ + merge = cast(Callable[[object, object], object], merge) return lambda fn: _WrappedMethod( - egg_fn, cost, default, merge, on_merge, fn, preserve, mutates_self, unextractable, subsume + egg_fn, + cost, + merge, + fn, + preserve, + mutates_self, + unextractable, + subsume, ) +@overload +def function(fn: CALLABLE, /) -> CALLABLE: ... + + +# function without merge +@overload +def function( + *, + egg_fn: str | None = ..., + builtin: bool = ..., + mutates_first_arg: bool = ..., +) -> Callable[[CALLABLE], CALLABLE]: ... + + +# function +@overload +def function( + *, + egg_fn: str | None = ..., + merge: Callable[[BASE_EXPR, BASE_EXPR], BASE_EXPR] | None = ..., + builtin: bool = ..., + mutates_first_arg: bool = ..., +) -> Callable[[Callable[P, BASE_EXPR]], Callable[P, BASE_EXPR]]: ... + + +# constructor +@overload +def function( + *, + egg_fn: str | None = ..., + cost: int | None = ..., + mutates_first_arg: bool = ..., + unextractable: bool = ..., + ruleset: Ruleset | None = ..., + use_body_as_name: bool = ..., + subsume: bool = ..., +) -> Callable[[CONSTRUCTOR_CALLABLE], CONSTRUCTOR_CALLABLE]: ... + + +def function(*args, **kwargs) -> Any: + """ + Decorate a function typing stub to create an egglog function for it. + + If a body is included, it will be added to the `ruleset` passed in as a default rewrite. + + This will default to creating a "constructor" in egglog, unless a merge function is passed in or the return + type is a primtive, then it will be a "function". + """ + fn_locals = currentframe().f_back.f_locals # type: ignore[union-attr] + + # If we have any positional args, then we are calling it directly on a function + if args: + assert len(args) == 1 + return _FunctionConstructor(fn_locals)(args[0]) + # otherwise, we are passing some keyword args, so save those, and then return a partial + return _FunctionConstructor(fn_locals, **kwargs) + + class _ExprMetaclass(type): """ Metaclass of Expr. @@ -466,12 +332,12 @@ def __new__( # type: ignore[misc] bases: tuple[type, ...], namespace: dict[str, Any], egg_sort: str | None = None, - builtin: bool = False, ruleset: Ruleset | None = None, ) -> RuntimeClass | type: # If this is the Expr subclass, just return the class - if not bases: + if not bases or bases == (BaseExpr,): return super().__new__(cls, name, bases, namespace) + builtin = BuiltinExpr in bases # TODO: Raise error on subclassing or multiple inheritence frame = currentframe() @@ -502,6 +368,30 @@ def __instancecheck__(cls, instance: object) -> bool: return isinstance(instance, RuntimeExpr) +class BaseExpr(metaclass=_ExprMetaclass): + """ + Either a builtin or a user defined expression type. + """ + + def __ne__(self, other: NoReturn) -> NoReturn: # type: ignore[override, empty-body] + ... + + def __eq__(self, other: NoReturn) -> NoReturn: # type: ignore[override, empty-body] + ... + + +class BuiltinExpr(BaseExpr, metaclass=_ExprMetaclass): + """ + A builtin expr type, not an eqsort. + """ + + +class Expr(BaseExpr, metaclass=_ExprMetaclass): + """ + Subclass this to define a custom expression type. + """ + + def _generate_class_decls( # noqa: C901,PLR0912 namespace: dict[str, Any], frame: FrameType, @@ -560,10 +450,10 @@ def _generate_class_decls( # noqa: C901,PLR0912 if is_init and cls_name in LIT_CLASS_NAMES: continue match method: - case _WrappedMethod(egg_fn, cost, default, merge, on_merge, fn, preserve, mutates, unextractable, subsume): + case _WrappedMethod(egg_fn, cost, merge, fn, preserve, mutates, unextractable, subsume): pass case _: - egg_fn, cost, default, merge, on_merge = None, None, None, None, None + egg_fn, cost, merge = None, None, None fn = method unextractable, preserve, subsume = False, False, False mutates = method_name in ALWAYS_MUTATES_SELF @@ -599,10 +489,8 @@ def _generate_class_decls( # noqa: C901,PLR0912 ref, fn, locals, - default, cost, merge, - on_merge, mutates, builtin, ruleset=ruleset, @@ -624,58 +512,6 @@ def _generate_class_decls( # noqa: C901,PLR0912 return decls -@overload -def function(fn: CALLABLE, /) -> CALLABLE: ... - - -@overload -def function( - *, - egg_fn: str | None = None, - cost: int | None = None, - merge: Callable[[Any, Any], Any] | None = None, - on_merge: Callable[[Any, Any], Iterable[ActionLike]] | None = None, - mutates_first_arg: bool = False, - unextractable: bool = False, - builtin: bool = False, - ruleset: Ruleset | None = None, - use_body_as_name: bool = False, - subsume: bool = False, -) -> Callable[[CALLABLE], CALLABLE]: ... - - -@overload -def function( - *, - egg_fn: str | None = None, - cost: int | None = None, - default: EXPR | None = None, - merge: Callable[[EXPR, EXPR], EXPR] | None = None, - on_merge: Callable[[EXPR, EXPR], Iterable[ActionLike]] | None = None, - mutates_first_arg: bool = False, - unextractable: bool = False, - ruleset: Ruleset | None = None, - use_body_as_name: bool = False, - subsume: bool = False, -) -> Callable[[Callable[P, EXPR]], Callable[P, EXPR]]: ... - - -def function(*args, **kwargs) -> Any: - """ - Defined by a unique name and a typing relation that will specify the return type based on the types of the argument expressions. - - - """ - fn_locals = currentframe().f_back.f_locals # type: ignore[union-attr] - - # If we have any positional args, then we are calling it directly on a function - if args: - assert len(args) == 1 - return _FunctionConstructor(fn_locals)(args[0]) - # otherwise, we are passing some keyword args, so save those, and then return a partial - return _FunctionConstructor(fn_locals, **kwargs) - - @dataclass class _FunctionConstructor: hint_locals: dict[str, Any] @@ -683,18 +519,16 @@ class _FunctionConstructor: mutates_first_arg: bool = False egg_fn: str | None = None cost: int | None = None - default: RuntimeExpr | None = None - merge: Callable[[RuntimeExpr, RuntimeExpr], RuntimeExpr] | None = None - on_merge: Callable[[RuntimeExpr, RuntimeExpr], Iterable[ActionLike]] | None = None + merge: Callable[[object, object], object] | None = None unextractable: bool = False ruleset: Ruleset | None = None use_body_as_name: bool = False subsume: bool = False - def __call__(self, fn: Callable[..., RuntimeExpr]) -> RuntimeFunction: + def __call__(self, fn: Callable) -> RuntimeFunction: return RuntimeFunction(*split_thunk(Thunk.fn(self.create_decls, fn))) - def create_decls(self, fn: Callable[..., RuntimeExpr]) -> tuple[Declarations, CallableRef]: + def create_decls(self, fn: Callable) -> tuple[Declarations, CallableRef]: decls = Declarations() ref = None if self.use_body_as_name else FunctionRef(fn.__name__) ref, add_rewrite = _fn_decl( @@ -703,10 +537,8 @@ def create_decls(self, fn: Callable[..., RuntimeExpr]) -> tuple[Declarations, Ca ref, fn, self.hint_locals, - self.default, self.cost, self.merge, - self.on_merge, self.mutates_first_arg, self.builtin, ruleset=self.ruleset, @@ -726,10 +558,8 @@ def _fn_decl( # Pass in the locals, retrieved from the frame when wrapping, # so that we support classes and function defined inside of other functions (which won't show up in the globals) hint_locals: dict[str, Any], - default: RuntimeExpr | None, cost: int | None, - merge: Callable[[RuntimeExpr, RuntimeExpr], RuntimeExpr] | None, - on_merge: Callable[[RuntimeExpr, RuntimeExpr], Iterable[ActionLike]] | None, + merge: Callable[[object, object], object] | None, mutates_first_arg: bool, is_builtin: bool, subsume: bool, @@ -794,28 +624,20 @@ def _fn_decl( arg_names = tuple(t.name for t in params) - decls |= default merged = ( None if merge is None - else merge( - RuntimeExpr.__from_values__(decls, TypedExprDecl(return_type.to_just(), VarDecl("old", False))), - RuntimeExpr.__from_values__(decls, TypedExprDecl(return_type.to_just(), VarDecl("new", False))), - ) - ) - decls |= merged - - merge_action = ( - [] - if on_merge is None - else _action_likes( - on_merge( + else resolve_literal( + return_type, + merge( RuntimeExpr.__from_values__(decls, TypedExprDecl(return_type.to_just(), VarDecl("old", False))), RuntimeExpr.__from_values__(decls, TypedExprDecl(return_type.to_just(), VarDecl("new", False))), - ) + ), + lambda: decls, ) ) - decls.update(*merge_action) + decls |= merged + # defer this in generator so it doesnt resolve for builtins eagerly args = (TypedExprDecl(tp.to_just(), VarDecl(name, False)) for name, tp in zip(arg_names, arg_types, strict=True)) res_ref: FunctionRef | MethodRef | ClassMethodRef | PropertyRef | InitRef | UnnamedFunctionRef @@ -830,6 +652,10 @@ def _fn_decl( res_thunk = Thunk.value(res) else: + return_type_is_eqsort = ( + not decls._classes[return_type.name].builtin if isinstance(return_type, TypeRefWithVars) else False + ) + is_constructor = not is_builtin and return_type_is_eqsort and merged is None signature_ = FunctionSignature( return_type=None if mutates_first_arg else return_type, var_arg_type=var_arg_type, @@ -837,16 +663,22 @@ def _fn_decl( arg_names=arg_names, arg_defaults=tuple(a.__egg_typed_expr__.expr if a is not None else None for a in arg_defaults), ) - decl = FunctionDecl( - signature=signature_, - cost=cost, - egg_name=egg_name, - merge=merged.__egg_typed_expr__.expr if merged is not None else None, - unextractable=unextractable, - builtin=is_builtin, - default=None if default is None else default.__egg_typed_expr__.expr, - on_merge=tuple(a.action for a in merge_action), - ) + decl: ConstructorDecl | FunctionDecl + if is_constructor: + decl = ConstructorDecl(signature_, egg_name, cost, unextractable) + else: + if cost is not None: + msg = "Cost can only be set for constructors" + raise ValueError(msg) + if unextractable: + msg = "Unextractable can only be set for constructors" + raise ValueError(msg) + decl = FunctionDecl( + signature=signature_, + egg_name=egg_name, + merge=merged.__egg_typed_expr__.expr if merged is not None else None, + builtin=is_builtin, + ) res_ref = ref decls.set_function_decl(ref, decl) res_thunk = Thunk.fn(_create_default_value, decls, ref, fn, args, ruleset) @@ -856,20 +688,20 @@ def _fn_decl( # Overload to support aritys 0-4 until variadic generic support map, so we can map from type to value @overload def relation( - name: str, tp1: type[E1], tp2: type[E2], tp3: type[E3], tp4: type[E4], / -) -> Callable[[E1, E2, E3, E4], Unit]: ... + name: str, tp1: type[BE1], tp2: type[BE2], tp3: type[BE3], tp4: type[BE4], / +) -> Callable[[BE1, BE2, BE3, BE4], Unit]: ... @overload -def relation(name: str, tp1: type[E1], tp2: type[E2], tp3: type[E3], /) -> Callable[[E1, E2, E3], Unit]: ... +def relation(name: str, tp1: type[BE1], tp2: type[BE2], tp3: type[BE3], /) -> Callable[[BE1, BE2, BE3], Unit]: ... @overload -def relation(name: str, tp1: type[E1], tp2: type[E2], /) -> Callable[[E1, E2], Unit]: ... +def relation(name: str, tp1: type[BE1], tp2: type[BE2], /) -> Callable[[BE1, BE2], Unit]: ... @overload -def relation(name: str, tp1: type[T], /, *, egg_fn: str | None = None) -> Callable[[T], Unit]: ... +def relation(name: str, tp1: type[BE1], /, *, egg_fn: str | None = None) -> Callable[[BE1], Unit]: ... @overload @@ -894,19 +726,20 @@ def _relation_decls(name: str, tps: tuple[type, ...], egg_fn: str | None) -> Dec def constant( name: str, - tp: type[EXPR], - default_replacement: EXPR | None = None, + tp: type[BASE_EXPR], + default_replacement: BASE_EXPR | None = None, /, *, egg_name: str | None = None, ruleset: Ruleset | None = None, -) -> EXPR: +) -> BASE_EXPR: """ A "constant" is implemented as the instantiation of a value that takes no args. This creates a function with `name` and return type `tp` and returns a value of it being called. """ return cast( - EXPR, RuntimeExpr(*split_thunk(Thunk.fn(_constant_thunk, name, tp, egg_name, default_replacement, ruleset))) + BASE_EXPR, + RuntimeExpr(*split_thunk(Thunk.fn(_constant_thunk, name, tp, egg_name, default_replacement, ruleset))), ) @@ -998,29 +831,6 @@ def _last_param_variable(params: list[Parameter]) -> bool: return found_var_arg -@deprecated( - "Modules are deprecated, use top level functions to register classes/functions and rulesets to register rules" -) -@dataclass -class Module(_BaseModule): - cmds: list[Command] = field(default_factory=list) - - def _register_commands(self, cmds: list[Command]) -> None: - self.cmds.extend(cmds) - - def without_rules(self) -> Module: - return Module() - - # Use identity for hash and equility, so we don't have to compare commands and compare expressions - def __hash__(self) -> int: - return id(self) - - def __eq__(self, other: object) -> bool: - if not isinstance(other, Module): - return NotImplemented - return self is other - - class GraphvizKwargs(TypedDict, total=False): max_functions: int | None max_calls_per_function: int | None @@ -1031,7 +841,7 @@ class GraphvizKwargs(TypedDict, total=False): @dataclass -class EGraph(_BaseModule): +class EGraph: """ A collection of expressions where each expression is part of a distinct equivalence class. @@ -1047,13 +857,9 @@ class EGraph(_BaseModule): # For storing the global "current" egraph _token_stack: list[Token[EGraph]] = field(default_factory=list, repr=False) - def __post_init__(self, modules: list[Module], seminaive: bool, save_egglog_string: bool) -> None: + def __post_init__(self, seminaive: bool, save_egglog_string: bool) -> None: egraph = bindings.EGraph(GLOBAL_PY_OBJECT_SORT, seminaive=seminaive, record=save_egglog_string) self._state = EGraphState(egraph) - super().__post_init__(modules) - - for m in self._flatted_deps: - self._register_commands(m.cmds) def _add_decls(self, *decls: DeclerationsLike) -> None: for d in decls: @@ -1077,14 +883,14 @@ def input(self, fn: Callable[..., String], path: str) -> None: """ Loads a CSV file and sets it as *input, output of the function. """ - self._egraph.run_program(bindings.Input(bindings.DUMMY_SPAN, self._callable_to_egg(fn), path)) + self._egraph.run_program(bindings.Input(span(1), self._callable_to_egg(fn), path)) def _callable_to_egg(self, fn: object) -> str: ref, decls = resolve_callable(fn) self._add_decls(decls) return self._state.callable_ref_to_egg(ref) - def let(self, name: str, expr: EXPR) -> EXPR: + def let(self, name: str, expr: BASE_EXPR) -> BASE_EXPR: """ Define a new expression in the egraph and return a reference to it. """ @@ -1093,21 +899,21 @@ def let(self, name: str, expr: EXPR) -> EXPR: runtime_expr = to_runtime_expr(expr) self._add_decls(runtime_expr) return cast( - EXPR, + BASE_EXPR, RuntimeExpr.__from_values__( self.__egg_decls__, TypedExprDecl(runtime_expr.__egg_typed_expr__.tp, VarDecl(name, True)) ), ) @overload - def simplify(self, expr: EXPR, limit: int, /, *until: Fact, ruleset: Ruleset | None = None) -> EXPR: ... + def simplify(self, expr: BASE_EXPR, limit: int, /, *until: Fact, ruleset: Ruleset | None = None) -> BASE_EXPR: ... @overload - def simplify(self, expr: EXPR, schedule: Schedule, /) -> EXPR: ... + def simplify(self, expr: BASE_EXPR, schedule: Schedule, /) -> BASE_EXPR: ... def simplify( - self, expr: EXPR, limit_or_schedule: int | Schedule, /, *until: Fact, ruleset: Ruleset | None = None - ) -> EXPR: + self, expr: BASE_EXPR, limit_or_schedule: int | Schedule, /, *until: Fact, ruleset: Ruleset | None = None + ) -> BASE_EXPR: """ Simplifies the given expression. """ @@ -1119,13 +925,13 @@ def simplify( typed_expr = runtime_expr.__egg_typed_expr__ # Must also register type egg_expr = self._state.typed_expr_to_egg(typed_expr) - self._egraph.run_program(bindings.Simplify(bindings.DUMMY_SPAN, egg_expr, egg_schedule)) + self._egraph.run_program(bindings.Simplify(span(1), egg_expr, egg_schedule)) extract_report = self._egraph.extract_report() if not isinstance(extract_report, bindings.Best): msg = "No extract report saved" raise ValueError(msg) # noqa: TRY004 (new_typed_expr,) = self._state.exprs_from_egg(extract_report.termdag, [extract_report.term], typed_expr.tp) - return cast(EXPR, RuntimeExpr.__from_values__(self.__egg_decls__, new_typed_expr)) + return cast(BASE_EXPR, RuntimeExpr.__from_values__(self.__egg_decls__, new_typed_expr)) def include(self, path: str) -> None: """ @@ -1174,21 +980,21 @@ def check_fail(self, *facts: FactLike) -> None: """ Checks that one of the facts is not true """ - self._egraph.run_program(bindings.Fail(bindings.DUMMY_SPAN, self._facts_to_check(facts))) + self._egraph.run_program(bindings.Fail(span(1), self._facts_to_check(facts))) def _facts_to_check(self, fact_likes: Iterable[FactLike]) -> bindings.Check: facts = _fact_likes(fact_likes) self._add_decls(*facts) egg_facts = [self._state.fact_to_egg(f.fact) for f in _fact_likes(facts)] - return bindings.Check(bindings.DUMMY_SPAN, egg_facts) + return bindings.Check(span(2), egg_facts) @overload - def extract(self, expr: EXPR, /, include_cost: Literal[False] = False) -> EXPR: ... + def extract(self, expr: BASE_EXPR, /, include_cost: Literal[False] = False) -> BASE_EXPR: ... @overload - def extract(self, expr: EXPR, /, include_cost: Literal[True]) -> tuple[EXPR, int]: ... + def extract(self, expr: BASE_EXPR, /, include_cost: Literal[True]) -> tuple[BASE_EXPR, int]: ... - def extract(self, expr: EXPR, include_cost: bool = False) -> EXPR | tuple[EXPR, int]: + def extract(self, expr: BASE_EXPR, include_cost: bool = False) -> BASE_EXPR | tuple[BASE_EXPR, int]: """ Extract the lowest cost expression from the egraph. """ @@ -1202,12 +1008,12 @@ def extract(self, expr: EXPR, include_cost: bool = False) -> EXPR | tuple[EXPR, raise ValueError(msg) # noqa: TRY004 (new_typed_expr,) = self._state.exprs_from_egg(extract_report.termdag, [extract_report.term], typed_expr.tp) - res = cast(EXPR, RuntimeExpr.__from_values__(self.__egg_decls__, new_typed_expr)) + res = cast(BASE_EXPR, RuntimeExpr.__from_values__(self.__egg_decls__, new_typed_expr)) if include_cost: return res, extract_report.cost return res - def extract_multiple(self, expr: EXPR, n: int) -> list[EXPR]: + def extract_multiple(self, expr: BASE_EXPR, n: int) -> list[BASE_EXPR]: """ Extract multiple expressions from the egraph. """ @@ -1220,14 +1026,12 @@ def extract_multiple(self, expr: EXPR, n: int) -> list[EXPR]: msg = "Wrong extract report type" raise ValueError(msg) # noqa: TRY004 new_exprs = self._state.exprs_from_egg(extract_report.termdag, extract_report.terms, typed_expr.tp) - return [cast(EXPR, RuntimeExpr.__from_values__(self.__egg_decls__, expr)) for expr in new_exprs] + return [cast(BASE_EXPR, RuntimeExpr.__from_values__(self.__egg_decls__, expr)) for expr in new_exprs] def _run_extract(self, typed_expr: TypedExprDecl, n: int) -> bindings._ExtractReport: expr = self._state.typed_expr_to_egg(typed_expr) self._egraph.run_program( - bindings.ActionCommand( - bindings.Extract(bindings.DUMMY_SPAN, expr, bindings.Lit(bindings.DUMMY_SPAN, bindings.Int(n))) - ) + bindings.ActionCommand(bindings.Extract(span(2), expr, bindings.Lit(span(2), bindings.Int(n)))) ) extract_report = self._egraph.extract_report() if not extract_report: @@ -1247,7 +1051,7 @@ def pop(self) -> None: """ Pop the current state of the egraph, reverting back to the previous state. """ - self._egraph.run_program(bindings.Pop(bindings.DUMMY_SPAN, 1)) + self._egraph.run_program(bindings.Pop(span(1), 1)) self._state = self._state_stack.pop() def __enter__(self) -> Self: @@ -1265,13 +1069,13 @@ def __exit__(self, exc_type, exc, exc_tb) -> None: self.pop() @overload - def eval(self, expr: i64) -> int: ... + def eval(self, expr: Bool) -> bool: ... @overload - def eval(self, expr: f64) -> float: ... + def eval(self, expr: i64) -> int: ... @overload - def eval(self, expr: Bool) -> bool: ... + def eval(self, expr: f64) -> float: ... @overload def eval(self, expr: String) -> str: ... @@ -1279,7 +1083,7 @@ def eval(self, expr: String) -> str: ... @overload def eval(self, expr: PyObject) -> object: ... - def eval(self, expr: Expr) -> object: + def eval(self, expr: BuiltinExpr) -> object: """ Evaluates the given expression (which must be a primitive type), returning the result. """ @@ -1435,6 +1239,27 @@ def _egraph(self) -> bindings.EGraph: def __egg_decls__(self) -> Declarations: return self._state.__egg_decls__ + def register( + self, + /, + command_or_generator: ActionLike | RewriteOrRule | RewriteOrRuleGenerator, + *command_likes: ActionLike | RewriteOrRule, + ) -> None: + """ + Registers any number of rewrites or rules. + """ + if isinstance(command_or_generator, FunctionType): + assert not command_likes + current_frame = inspect.currentframe() + assert current_frame + original_frame = current_frame.f_back + assert original_frame + command_likes = tuple(_rewrite_or_rule_generator(command_or_generator, original_frame)) + else: + command_likes = (cast(CommandLike, command_or_generator), *command_likes) + commands = [_command_like(c) for c in command_likes] + self._register_commands(commands) + def _register_commands(self, cmds: list[Command]) -> None: self._add_decls(*cmds) egg_cmds = list(map(self._command_to_egg, cmds)) @@ -1458,40 +1283,26 @@ def _command_to_egg(self, cmd: Command) -> bindings._Command: @dataclass(frozen=True) -class _WrappedMethod(Generic[P, EXPR]): +class _WrappedMethod: """ Used to wrap a method and store some extra options on it before processing it when processing the class. """ egg_fn: str | None cost: int | None - default: EXPR | None - merge: Callable[[EXPR, EXPR], EXPR] | None - on_merge: Callable[[EXPR, EXPR], Iterable[ActionLike]] | None - fn: Callable[P, EXPR] + merge: Callable[[object, object], object] | None + fn: Callable preserve: bool mutates_self: bool unextractable: bool subsume: bool - def __call__(self, *args: P.args, **kwargs: P.kwargs) -> EXPR: + def __call__(self, *args, **kwargs) -> Never: msg = "We should never call a wrapped method. Did you forget to wrap the class?" raise NotImplementedError(msg) -class Expr(metaclass=_ExprMetaclass): - """ - Either a function called with some number of argument expressions or a literal integer, float, or string, with a particular type. - """ - - def __ne__(self, other: NoReturn) -> NoReturn: # type: ignore[override, empty-body] - ... - - def __eq__(self, other: NoReturn) -> NoReturn: # type: ignore[override, empty-body] - ... - - -class Unit(Expr, egg_sort="Unit", builtin=True): +class Unit(BuiltinExpr, egg_sort="Unit"): """ The unit type. This is also used to reprsent if a value exists, if it is resolved or not. """ @@ -1715,40 +1526,22 @@ def __repr__(self) -> str: # if the arguments are the same type of expression -@deprecated("Use .register() instead of passing rulesets as arguments to rewrites.") -@overload -def rewrite(lhs: EXPR, ruleset: Ruleset, *, subsume: bool = False) -> _RewriteBuilder[EXPR]: ... - - -@overload -def rewrite(lhs: EXPR, ruleset: None = None, *, subsume: bool = False) -> _RewriteBuilder[EXPR]: ... - - -def rewrite(lhs: EXPR, ruleset: Ruleset | None = None, *, subsume: bool = False) -> _RewriteBuilder[EXPR]: +def rewrite(lhs: EXPR, ruleset: None = None, *, subsume: bool = False) -> _RewriteBuilder[EXPR]: """Rewrite the given expression to a new expression.""" return _RewriteBuilder(lhs, ruleset, subsume) -@deprecated("Use .register() instead of passing rulesets as arguments to birewrites.") -@overload -def birewrite(lhs: EXPR, ruleset: Ruleset) -> _BirewriteBuilder[EXPR]: ... - - -@overload -def birewrite(lhs: EXPR, ruleset: None = None) -> _BirewriteBuilder[EXPR]: ... - - -def birewrite(lhs: EXPR, ruleset: Ruleset | None = None) -> _BirewriteBuilder[EXPR]: +def birewrite(lhs: EXPR, ruleset: None = None) -> _BirewriteBuilder[EXPR]: """Rewrite the given expression to a new expression and vice versa.""" return _BirewriteBuilder(lhs, ruleset) -def eq(expr: EXPR) -> _EqBuilder[EXPR]: +def eq(expr: BASE_EXPR) -> _EqBuilder[BASE_EXPR]: """Check if the given expression is equal to the given value.""" return _EqBuilder(expr) -def ne(expr: EXPR) -> _NeBuilder[EXPR]: +def ne(expr: BASE_EXPR) -> _NeBuilder[BASE_EXPR]: """Check if the given expression is not equal to the given value.""" return _NeBuilder(expr) @@ -1758,18 +1551,18 @@ def panic(message: str) -> Action: return Action(Declarations(), PanicDecl(message)) -def let(name: str, expr: Expr) -> Action: +def let(name: str, expr: BaseExpr) -> Action: """Create a let binding.""" runtime_expr = to_runtime_expr(expr) return Action(runtime_expr.__egg_decls__, LetDecl(name, runtime_expr.__egg_typed_expr__)) -def expr_action(expr: Expr) -> Action: +def expr_action(expr: BaseExpr) -> Action: runtime_expr = to_runtime_expr(expr) return Action(runtime_expr.__egg_decls__, ExprActionDecl(runtime_expr.__egg_typed_expr__)) -def delete(expr: Expr) -> Action: +def delete(expr: BaseExpr) -> Action: """Create a delete expression.""" runtime_expr = to_runtime_expr(expr) typed_expr = runtime_expr.__egg_typed_expr__ @@ -1787,7 +1580,7 @@ def subsume(expr: Expr) -> Action: return Action(runtime_expr.__egg_decls__, ChangeDecl(typed_expr.tp, call_decl, "subsume")) -def expr_fact(expr: Expr) -> Fact: +def expr_fact(expr: BaseExpr) -> Fact: runtime_expr = to_runtime_expr(expr) return Fact(runtime_expr.__egg_decls__, ExprFactDecl(runtime_expr.__egg_typed_expr__)) @@ -1797,30 +1590,16 @@ def union(lhs: EXPR) -> _UnionBuilder[EXPR]: return _UnionBuilder(lhs=lhs) -def set_(lhs: EXPR) -> _SetBuilder[EXPR]: +def set_(lhs: BASE_EXPR) -> _SetBuilder[BASE_EXPR]: """Create a set of the given expression.""" return _SetBuilder(lhs=lhs) -@deprecated("Use .register() instead of passing rulesets as arguments to rules.") -@overload -def rule(*facts: FactLike, ruleset: Ruleset, name: str | None = None) -> _RuleBuilder: ... - - -@overload -def rule(*facts: FactLike, ruleset: None = None, name: str | None = None) -> _RuleBuilder: ... - - -def rule(*facts: FactLike, ruleset: Ruleset | None = None, name: str | None = None) -> _RuleBuilder: +def rule(*facts: FactLike, ruleset: None = None, name: str | None = None) -> _RuleBuilder: """Create a rule with the given facts.""" return _RuleBuilder(facts=_fact_likes(facts), name=name, ruleset=ruleset) -@deprecated("This function is now a no-op, you can remove it and use actions as commands") -def action_command(action: Action) -> Action: - return action - - def var(name: str, bound: type[T]) -> T: """Create a new variable with the given name and type.""" return cast(T, _var(name, bound)) @@ -1834,7 +1613,7 @@ def _var(name: str, bound: object) -> RuntimeExpr: ) -def vars_(names: str, bound: type[EXPR]) -> Iterable[EXPR]: +def vars_(names: str, bound: type[BASE_EXPR]) -> Iterable[BASE_EXPR]: """Create variables with the given names and type.""" for name in names.split(" "): yield var(name, bound) @@ -1897,15 +1676,15 @@ def __str__(self) -> str: @dataclass -class _EqBuilder(Generic[EXPR]): - expr: EXPR +class _EqBuilder(Generic[BASE_EXPR]): + expr: BASE_EXPR - def to(self, *exprs: EXPR) -> Fact: + def to(self, other: BASE_EXPR) -> Fact: expr = to_runtime_expr(self.expr) - args = [expr, *(convert_to_same_type(e, expr) for e in exprs)] + other = convert_to_same_type(other, expr) return Fact( - Declarations.create(*args), - EqDecl(expr.__egg_typed_expr__.tp, tuple(a.__egg_typed_expr__.expr for a in args)), + Declarations.create(expr, other), + EqDecl(expr.__egg_typed_expr__.tp, expr.__egg_typed_expr__.expr, other.__egg_typed_expr__.expr), ) def __repr__(self) -> str: @@ -1917,10 +1696,10 @@ def __str__(self) -> str: @dataclass -class _NeBuilder(Generic[EXPR]): - lhs: EXPR +class _NeBuilder(Generic[BASE_EXPR]): + lhs: BASE_EXPR - def to(self, rhs: EXPR) -> Unit: + def to(self, rhs: BASE_EXPR) -> Unit: lhs = to_runtime_expr(self.lhs) rhs = convert_to_same_type(rhs, lhs) assert isinstance(Unit, RuntimeClass) @@ -1941,10 +1720,10 @@ def __str__(self) -> str: @dataclass -class _SetBuilder(Generic[EXPR]): - lhs: EXPR +class _SetBuilder(Generic[BASE_EXPR]): + lhs: BASE_EXPR - def to(self, rhs: EXPR) -> Action: + def to(self, rhs: BASE_EXPR) -> Action: lhs = to_runtime_expr(self.lhs) rhs = convert_to_same_type(rhs, lhs) lhs_expr = lhs.__egg_typed_expr__.expr @@ -2008,7 +1787,7 @@ def __str__(self) -> str: return f"rule({', '.join(args)})" -def expr_parts(expr: Expr) -> TypedExprDecl: +def expr_parts(expr: BaseExpr) -> TypedExprDecl: """ Returns the underlying type and decleration of the expression. Useful for testing structural equality or debugging. """ @@ -2017,7 +1796,7 @@ def expr_parts(expr: Expr) -> TypedExprDecl: return expr.__egg_typed_expr__ -def to_runtime_expr(expr: Expr) -> RuntimeExpr: +def to_runtime_expr(expr: BaseExpr) -> RuntimeExpr: if not isinstance(expr, RuntimeExpr): raise TypeError(f"Expected a RuntimeExpr not {expr}") return expr @@ -2041,7 +1820,7 @@ def seq(*schedules: Schedule) -> Schedule: return Schedule(Thunk.fn(Declarations.create, *schedules), SequenceDecl(tuple(s.schedule for s in schedules))) -ActionLike: TypeAlias = Action | Expr +ActionLike: TypeAlias = Action | BaseExpr def _action_likes(action_likes: Iterable[ActionLike]) -> tuple[Action, ...]: @@ -2049,9 +1828,9 @@ def _action_likes(action_likes: Iterable[ActionLike]) -> tuple[Action, ...]: def _action_like(action_like: ActionLike) -> Action: - if isinstance(action_like, Expr): - return expr_action(action_like) - return action_like + if isinstance(action_like, Action): + return action_like + return expr_action(action_like) Command: TypeAlias = Action | RewriteOrRule @@ -2082,7 +1861,7 @@ def _rewrite_or_rule_generator(gen: RewriteOrRuleGenerator, frame: FrameType) -> return list(gen(*args)) # type: ignore[misc] -FactLike = Fact | Expr +FactLike = Fact | BaseExpr def _fact_likes(fact_likes: Iterable[FactLike]) -> tuple[Fact, ...]: @@ -2090,9 +1869,9 @@ def _fact_likes(fact_likes: Iterable[FactLike]) -> tuple[Fact, ...]: def _fact_like(fact_like: FactLike) -> Fact: - if isinstance(fact_like, Expr): - return expr_fact(fact_like) - return fact_like + if isinstance(fact_like, Fact): + return fact_like + return expr_fact(fact_like) _CURRENT_RULESET = ContextVar[Ruleset | None]("CURRENT_RULESET", default=None) diff --git a/python/egglog/egraph_state.py b/python/egglog/egraph_state.py index 2c3adfef..c95c45b6 100644 --- a/python/egglog/egraph_state.py +++ b/python/egglog/egraph_state.py @@ -13,19 +13,34 @@ from . import bindings from .declarations import * +from .declarations import ConstructorDecl from .pretty import * from .type_constraint_solver import TypeConstraintError, TypeConstraintSolver if TYPE_CHECKING: from collections.abc import Iterable -__all__ = ["EGraphState", "GLOBAL_PY_OBJECT_SORT"] +__all__ = ["EGraphState", "GLOBAL_PY_OBJECT_SORT", "span"] # Create a global sort for python objects, so we can store them without an e-graph instance # Needed when serializing commands to egg commands when creating modules GLOBAL_PY_OBJECT_SORT = bindings.PyObjectSort() +def span(frame_index: int = 0) -> bindings.RustSpan: + """ + Returns a span for the current file and line. + + If `frame_index` is passed, it will return the span for that frame in the stack, where 0 is the current frame + this is called in and 1 is the parent. + """ + # Currently disable this because it's too expensive. + # import inspect + + # frame = inspect.stack()[frame_index + 1] + return bindings.RustSpan("", 0, 0) + + @dataclass class EGraphState: """ @@ -71,15 +86,15 @@ def copy(self) -> EGraphState: def schedule_to_egg(self, schedule: ScheduleDecl) -> bindings._Schedule: match schedule: case SaturateDecl(schedule): - return bindings.Saturate(bindings.DUMMY_SPAN, self.schedule_to_egg(schedule)) + return bindings.Saturate(span(), self.schedule_to_egg(schedule)) case RepeatDecl(schedule, times): - return bindings.Repeat(bindings.DUMMY_SPAN, times, self.schedule_to_egg(schedule)) + return bindings.Repeat(span(), times, self.schedule_to_egg(schedule)) case SequenceDecl(schedules): - return bindings.Sequence(bindings.DUMMY_SPAN, [self.schedule_to_egg(s) for s in schedules]) + return bindings.Sequence(span(), [self.schedule_to_egg(s) for s in schedules]) case RunDecl(ruleset_name, until): self.ruleset_to_egg(ruleset_name) config = bindings.RunConfig(ruleset_name, None if not until else list(map(self.fact_to_egg, until))) - return bindings.Run(bindings.DUMMY_SPAN, config) + return bindings.Run(span(), config) case _: assert_never(schedule) @@ -116,7 +131,7 @@ def command_to_egg(self, cmd: CommandDecl, ruleset: str) -> bindings._Command: case RewriteDecl(tp, lhs, rhs, conditions) | BiRewriteDecl(tp, lhs, rhs, conditions): self.type_ref_to_egg(tp) rewrite = bindings.Rewrite( - bindings.DUMMY_SPAN, + span(), self._expr_to_egg(lhs), self._expr_to_egg(rhs), [self.fact_to_egg(c) for c in conditions], @@ -128,15 +143,14 @@ def command_to_egg(self, cmd: CommandDecl, ruleset: str) -> bindings._Command: ) case RuleDecl(head, body, name): rule = bindings.Rule( - bindings.DUMMY_SPAN, + span(), [self.action_to_egg(a) for a in head], [self.fact_to_egg(f) for f in body], ) return bindings.RuleCommand(name or "", ruleset, rule) # TODO: Replace with just constants value and looking at REF of function case DefaultRewriteDecl(ref, expr, subsume): - decl = self.__egg_decls__.get_callable_decl(ref).to_function_decl() - sig = decl.signature + sig = self.__egg_decls__.get_callable_decl(ref).signature assert isinstance(sig, FunctionSignature) # Replace args with rule_var_name mapping arg_mapping = tuple( @@ -156,13 +170,13 @@ def action_to_egg(self, action: ActionDecl) -> bindings._Action: var_decl = VarDecl(name, True) var_egg = self._expr_to_egg(var_decl) self.expr_to_egg_cache[var_decl] = var_egg - return bindings.Let(bindings.DUMMY_SPAN, var_egg.name, self.typed_expr_to_egg(typed_expr)) + return bindings.Let(span(), var_egg.name, self.typed_expr_to_egg(typed_expr)) case SetDecl(tp, call, rhs): self.type_ref_to_egg(tp) call_ = self._expr_to_egg(call) - return bindings.Set(bindings.DUMMY_SPAN, call_.name, call_.args, self._expr_to_egg(rhs)) + return bindings.Set(span(), call_.name, call_.args, self._expr_to_egg(rhs)) case ExprActionDecl(typed_expr): - return bindings.Expr_(bindings.DUMMY_SPAN, self.typed_expr_to_egg(typed_expr)) + return bindings.Expr_(span(), self.typed_expr_to_egg(typed_expr)) case ChangeDecl(tp, call, change): self.type_ref_to_egg(tp) call_ = self._expr_to_egg(call) @@ -174,20 +188,20 @@ def action_to_egg(self, action: ActionDecl) -> bindings._Action: egg_change = bindings.Subsume() case _: assert_never(change) - return bindings.Change(bindings.DUMMY_SPAN, egg_change, call_.name, call_.args) + return bindings.Change(span(), egg_change, call_.name, call_.args) case UnionDecl(tp, lhs, rhs): self.type_ref_to_egg(tp) - return bindings.Union(bindings.DUMMY_SPAN, self._expr_to_egg(lhs), self._expr_to_egg(rhs)) + return bindings.Union(span(), self._expr_to_egg(lhs), self._expr_to_egg(rhs)) case PanicDecl(name): - return bindings.Panic(bindings.DUMMY_SPAN, name) + return bindings.Panic(span(), name) case _: assert_never(action) def fact_to_egg(self, fact: FactDecl) -> bindings._Fact: match fact: - case EqDecl(tp, exprs): + case EqDecl(tp, left, right): self.type_ref_to_egg(tp) - return bindings.Eq(bindings.DUMMY_SPAN, [self._expr_to_egg(e) for e in exprs]) + return bindings.Eq(span(), self._expr_to_egg(left), self._expr_to_egg(right)) case ExprFactDecl(typed_expr): return bindings.Fact(self.typed_expr_to_egg(typed_expr, False)) case _: @@ -207,40 +221,55 @@ def callable_ref_to_egg(self, ref: CallableRef) -> str: match decl: case RelationDecl(arg_types, _, _): self.egraph.run_program( - bindings.Relation(bindings.DUMMY_SPAN, egg_name, [self.type_ref_to_egg(a) for a in arg_types]) + bindings.Relation(span(), egg_name, [self.type_ref_to_egg(a) for a in arg_types]) ) case ConstantDecl(tp, _): - # Use function decleration instead of constant b/c constants cannot be extracted + # Use constructor decleration instead of constant b/c constants cannot be extracted # https://github.com/egraphs-good/egglog/issues/334 self.egraph.run_program( - bindings.Function( - bindings.FunctionDecl( - bindings.DUMMY_SPAN, egg_name, bindings.Schema([], self.type_ref_to_egg(tp)) - ) - ) + bindings.Constructor(span(), egg_name, bindings.Schema([], self.type_ref_to_egg(tp)), None, False) ) - case FunctionDecl(): - if not decl.builtin: - signature = decl.signature + case FunctionDecl(signature, builtin, _, merge): + if not builtin: assert isinstance(signature, FunctionSignature), "Cannot turn special function to egg" - egg_fn_decl = bindings.FunctionDecl( - bindings.DUMMY_SPAN, + # Compile functions that return unit to relations, because these show up in methods where you + # cant use the relation helper + schema = self._signature_to_egg_schema(signature) + if signature.return_type == TypeRefWithVars("Unit"): + if merge: + msg = "Cannot specify a merge function for a function that returns unit" + raise ValueError(msg) + self.egraph.run_program(bindings.Relation(span(), egg_name, schema.input)) + else: + self.egraph.run_program( + bindings.Function( + span(), + egg_name, + self._signature_to_egg_schema(signature), + self._expr_to_egg(merge) if merge else None, + ) + ) + case ConstructorDecl(signature, _, cost, unextractable): + self.egraph.run_program( + bindings.Constructor( + span(), egg_name, - bindings.Schema( - [self.type_ref_to_egg(a.to_just()) for a in signature.arg_types], - self.type_ref_to_egg(signature.semantic_return_type.to_just()), - ), - self._expr_to_egg(decl.default) if decl.default else None, - self._expr_to_egg(decl.merge) if decl.merge else None, - [self.action_to_egg(a) for a in decl.on_merge], - decl.cost, - decl.unextractable, + self._signature_to_egg_schema(signature), + cost, + unextractable, ) - self.egraph.run_program(bindings.Function(egg_fn_decl)) + ) + case _: assert_never(decl) return egg_name + def _signature_to_egg_schema(self, signature: FunctionSignature) -> bindings.Schema: + return bindings.Schema( + [self.type_ref_to_egg(a.to_just()) for a in signature.arg_types], + self.type_ref_to_egg(signature.semantic_return_type.to_just()), + ) + def type_ref_to_egg(self, ref: JustTypeRef) -> str: """ Returns the egg sort name for a type reference, registering it if it is not already registered. @@ -257,18 +286,18 @@ def type_ref_to_egg(self, ref: JustTypeRef) -> str: # UnstableFn is a special case, where the rest of args are collected into a call type_args: list[bindings._Expr] = [ bindings.Call( - bindings.DUMMY_SPAN, + span(), self.type_ref_to_egg(ref.args[1]), - [bindings.Var(bindings.DUMMY_SPAN, self.type_ref_to_egg(a)) for a in ref.args[2:]], + [bindings.Var(span(), self.type_ref_to_egg(a)) for a in ref.args[2:]], ), - bindings.Var(bindings.DUMMY_SPAN, self.type_ref_to_egg(ref.args[0])), + bindings.Var(span(), self.type_ref_to_egg(ref.args[0])), ] else: - type_args = [bindings.Var(bindings.DUMMY_SPAN, self.type_ref_to_egg(a)) for a in ref.args] + type_args = [bindings.Var(span(), self.type_ref_to_egg(a)) for a in ref.args] args = (self.type_ref_to_egg(JustTypeRef(ref.name)), type_args) else: args = None - self.egraph.run_program(bindings.Sort(bindings.DUMMY_SPAN, egg_name, args)) + self.egraph.run_program(bindings.Sort(span(), egg_name, args)) # For builtin classes, let's also make sure we have the mapping of all egg fn names for class methods, because # these can be created even without adding them to the e-graph, like `vec-empty` which can be extracted # even if you never use that function. @@ -310,9 +339,7 @@ def _transform_let(self, typed_expr: TypedExprDecl) -> None: if var_decl in self.expr_to_egg_cache: return var_egg = self._expr_to_egg(var_decl) - cmd = bindings.ActionCommand( - bindings.Let(bindings.DUMMY_SPAN, var_egg.name, self.typed_expr_to_egg(typed_expr)) - ) + cmd = bindings.ActionCommand(bindings.Let(span(), var_egg.name, self.typed_expr_to_egg(typed_expr))) try: self.egraph.run_program(cmd) # errors when creating let bindings for things like `(vec-empty)` @@ -344,7 +371,7 @@ def _expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr: # noqa: PLR0912, # prefix let bindings with % to avoid name conflicts with rewrites if is_let: name = f"%{name}" - res = bindings.Var(bindings.DUMMY_SPAN, name) + res = bindings.Var(span(), name) case LitDecl(value): l: bindings._Literal match value: @@ -355,24 +382,24 @@ def _expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr: # noqa: PLR0912, case int(i): l = bindings.Int(i) case float(f): - l = bindings.F64(f) + l = bindings.Float(f) case str(s): l = bindings.String(s) case _: assert_never(value) - res = bindings.Lit(bindings.DUMMY_SPAN, l) + res = bindings.Lit(span(), l) case CallDecl(ref, args, _): egg_fn = self.callable_ref_to_egg(ref) egg_args = [self.typed_expr_to_egg(a, False) for a in args] - res = bindings.Call(bindings.DUMMY_SPAN, egg_fn, egg_args) + res = bindings.Call(span(), egg_fn, egg_args) case PyObjectDecl(value): res = GLOBAL_PY_OBJECT_SORT.store(value) case PartialCallDecl(call_decl): egg_fn_call = self._expr_to_egg(call_decl) res = bindings.Call( - bindings.DUMMY_SPAN, + span(), "unstable-fn", - [bindings.Lit(bindings.DUMMY_SPAN, bindings.String(egg_fn_call.name)), *egg_fn_call.args], + [bindings.Lit(span(), bindings.String(egg_fn_call.name)), *egg_fn_call.args], ) case _: assert_never(expr_decl.expr) @@ -491,7 +518,7 @@ def from_expr(self, tp: JustTypeRef, term: bindings._Term) -> TypedExprDecl: expr_decl = LitDecl(None if isinstance(value, bindings.Unit) else value.value) elif isinstance(term, bindings.TermApp): if term.name == "py-object": - call = bindings.termdag_term_to_expr(self.termdag, term) + call = self.termdag.term_to_expr(term, span()) expr_decl = PyObjectDecl(self.state.egraph.eval_py_object(call)) elif term.name == "unstable-fn": # Get function name @@ -530,7 +557,7 @@ def from_call( # If this is a classmethod, we might need the type params that were bound for this type # This could be multiple types if the classmethod is ambiguous, like map create. possible_types: Iterable[JustTypeRef | None] - signature = self.decls.get_callable_decl(callable_ref).to_function_decl().signature + signature = self.decls.get_callable_decl(callable_ref).signature assert isinstance(signature, FunctionSignature) if isinstance(callable_ref, ClassMethodRef | InitRef | MethodRef): # Need OR in case we have class method whose class whas never added as a sort, which would happen @@ -566,5 +593,5 @@ def resolve_term(self, term_id: int, tp: JustTypeRef) -> TypedExprDecl: try: return self.cache[term_id] except KeyError: - res = self.cache[term_id] = self.from_expr(tp, self.termdag.nodes[term_id]) + res = self.cache[term_id] = self.from_expr(tp, self.termdag.get(term_id)) return res diff --git a/python/egglog/examples/higher_order_functions.py b/python/egglog/examples/higher_order_functions.py index 4077af6b..d82638cf 100644 --- a/python/egglog/examples/higher_order_functions.py +++ b/python/egglog/examples/higher_order_functions.py @@ -37,7 +37,7 @@ def incr_list(xs: MathList) -> MathList: return xs.map(lambda x: x + Math(1)) -egraph = EGraph() +egraph = EGraph(save_egglog_string=True) y = egraph.let("y", incr_list(MathList().append(Math(1)).append(Math(2)))) egraph.run(math_ruleset.saturate()) egraph.check(eq(y).to(MathList().append(Math(2)).append(Math(3)))) diff --git a/python/egglog/exp/program_gen.py b/python/egglog/exp/program_gen.py index cd882ef0..0a83b631 100644 --- a/python/egglog/exp/program_gen.py +++ b/python/egglog/exp/program_gen.py @@ -75,13 +75,13 @@ def next_sym(self) -> i64: Returns the next gensym to use. This is set after calling `compile(i)` on a program. """ - @method(default=Unit()) + # TODO: Replace w/ def next_sym(self) -> i64: ... ? def compile(self, next_sym: i64 = i64(0)) -> Unit: """ Triggers compilation of the program. """ - @method(merge=lambda old, _new: old, unextractable=True) # type: ignore[misc] + @method(merge=lambda old, _new: old) # type: ignore[misc] @property def parent(self) -> Program: """ diff --git a/python/egglog/pretty.py b/python/egglog/pretty.py index cffc5a8a..7358c75b 100644 --- a/python/egglog/pretty.py +++ b/python/egglog/pretty.py @@ -159,7 +159,7 @@ def __call__(self, decl: AllDecls, toplevel: bool = False) -> None: # noqa: C90 self(action) for fact in body: self(fact) - case SetDecl(_, lhs, rhs) | UnionDecl(_, lhs, rhs): + case SetDecl(_, lhs, rhs) | UnionDecl(_, lhs, rhs) | EqDecl(_, lhs, rhs): self(lhs) self(rhs) case LetDecl(_, d) | ExprActionDecl(d) | ExprFactDecl(d): @@ -168,7 +168,7 @@ def __call__(self, decl: AllDecls, toplevel: bool = False) -> None: # noqa: C90 self(d) case PanicDecl(_) | VarDecl(_) | LitDecl(_) | PyObjectDecl(_): pass - case EqDecl(_, decls) | SequenceDecl(decls) | RulesetDecl(decls): + case SequenceDecl(decls) | RulesetDecl(decls): for de in decls: if isinstance(de, DefaultRewriteDecl): continue @@ -281,9 +281,8 @@ def uncached(self, decl: AllDecls, *, unwrap_lit: bool, parens: bool, ruleset_na return f"{change}({self(expr)})", "action" case PanicDecl(s): return f"panic({s!r})", "action" - case EqDecl(_, exprs): - first, *rest = exprs - return f"eq({self(first)}).to({', '.join(map(self, rest))})", "fact" + case EqDecl(_, left, right): + return f"eq({self(left)}).to({self(right)})", "fact" case RulesetDecl(rules): if ruleset_name: return f"ruleset(name={ruleset_name!r})", f"ruleset_{ruleset_name}" @@ -330,8 +329,7 @@ def _call( if decl.callable == FunctionRef("!="): l, r = self(args[0]), self(args[1]) return f"ne({l}).to({r})", "Unit" - function_decl = self.decls.get_callable_decl(ref).to_function_decl() - signature = function_decl.signature + signature = self.decls.get_callable_decl(ref).signature # Determine how many of the last arguments are defaults, by iterating from the end and comparing the arg with the default n_defaults = 0 diff --git a/python/egglog/runtime.py b/python/egglog/runtime.py index 5195d31e..6b8c8571 100644 --- a/python/egglog/runtime.py +++ b/python/egglog/runtime.py @@ -35,6 +35,7 @@ "RuntimeFunction", "resolve_callable", "resolve_type_annotation", + "resolve_type_annotation_mutate", ] @@ -300,7 +301,7 @@ def __call__(self, *args: object, _egg_partial_function: bool = False, **kwargs: if isinstance(self.__egg_bound__, RuntimeExpr): args = (self.__egg_bound__, *args) try: - signature = self.__egg_decls__.get_callable_decl(self.__egg_ref__).to_function_decl().signature + signature = self.__egg_decls__.get_callable_decl(self.__egg_ref__).signature except Exception as e: e.add_note(f"Failed to find callable {self}") raise diff --git a/python/tests/__snapshots__/test_bindings/TestEGraph.test_parse_program.py b/python/tests/__snapshots__/test_bindings/TestEGraph.test_parse_program.py index 70d74d26..6e332947 100644 --- a/python/tests/__snapshots__/test_bindings/TestEGraph.test_parse_program.py +++ b/python/tests/__snapshots__/test_bindings/TestEGraph.test_parse_program.py @@ -1,6 +1,6 @@ [ Datatype( - Span( + EgglogSpan( SrcFile( "test.egg", '(datatype Math\n (Num i64)\n (Var String)\n (Add Math Math)\n (Mul Math Math))\n\n ;; expr1 = 2 * (x + 3)\n (let expr1 (Mul (Num 2) (Add (Var "x") (Num 3))))', @@ -11,7 +11,7 @@ "Math", [ Variant( - Span( + EgglogSpan( SrcFile( "test.egg", '(datatype Math\n (Num i64)\n (Var String)\n (Add Math Math)\n (Mul Math Math))\n\n ;; expr1 = 2 * (x + 3)\n (let expr1 (Mul (Num 2) (Add (Var "x") (Num 3))))', @@ -24,7 +24,7 @@ None, ), Variant( - Span( + EgglogSpan( SrcFile( "test.egg", '(datatype Math\n (Num i64)\n (Var String)\n (Add Math Math)\n (Mul Math Math))\n\n ;; expr1 = 2 * (x + 3)\n (let expr1 (Mul (Num 2) (Add (Var "x") (Num 3))))', @@ -37,7 +37,7 @@ None, ), Variant( - Span( + EgglogSpan( SrcFile( "test.egg", '(datatype Math\n (Num i64)\n (Var String)\n (Add Math Math)\n (Mul Math Math))\n\n ;; expr1 = 2 * (x + 3)\n (let expr1 (Mul (Num 2) (Add (Var "x") (Num 3))))', @@ -50,7 +50,7 @@ None, ), Variant( - Span( + EgglogSpan( SrcFile( "test.egg", '(datatype Math\n (Num i64)\n (Var String)\n (Add Math Math)\n (Mul Math Math))\n\n ;; expr1 = 2 * (x + 3)\n (let expr1 (Mul (Num 2) (Add (Var "x") (Num 3))))', @@ -66,7 +66,7 @@ ), ActionCommand( Let( - Span( + EgglogSpan( SrcFile( "test.egg", '(datatype Math\n (Num i64)\n (Var String)\n (Add Math Math)\n (Mul Math Math))\n\n ;; expr1 = 2 * (x + 3)\n (let expr1 (Mul (Num 2) (Add (Var "x") (Num 3))))', @@ -76,7 +76,7 @@ ), "expr1", Call( - Span( + EgglogSpan( SrcFile( "test.egg", '(datatype Math\n (Num i64)\n (Var String)\n (Add Math Math)\n (Mul Math Math))\n\n ;; expr1 = 2 * (x + 3)\n (let expr1 (Mul (Num 2) (Add (Var "x") (Num 3))))', @@ -87,7 +87,7 @@ "Mul", [ Call( - Span( + EgglogSpan( SrcFile( "test.egg", '(datatype Math\n (Num i64)\n (Var String)\n (Add Math Math)\n (Mul Math Math))\n\n ;; expr1 = 2 * (x + 3)\n (let expr1 (Mul (Num 2) (Add (Var "x") (Num 3))))', @@ -98,7 +98,7 @@ "Num", [ Lit( - Span( + EgglogSpan( SrcFile( "test.egg", '(datatype Math\n (Num i64)\n (Var String)\n (Add Math Math)\n (Mul Math Math))\n\n ;; expr1 = 2 * (x + 3)\n (let expr1 (Mul (Num 2) (Add (Var "x") (Num 3))))', @@ -111,7 +111,7 @@ ], ), Call( - Span( + EgglogSpan( SrcFile( "test.egg", '(datatype Math\n (Num i64)\n (Var String)\n (Add Math Math)\n (Mul Math Math))\n\n ;; expr1 = 2 * (x + 3)\n (let expr1 (Mul (Num 2) (Add (Var "x") (Num 3))))', @@ -122,7 +122,7 @@ "Add", [ Call( - Span( + EgglogSpan( SrcFile( "test.egg", '(datatype Math\n (Num i64)\n (Var String)\n (Add Math Math)\n (Mul Math Math))\n\n ;; expr1 = 2 * (x + 3)\n (let expr1 (Mul (Num 2) (Add (Var "x") (Num 3))))', @@ -133,7 +133,7 @@ "Var", [ Lit( - Span( + EgglogSpan( SrcFile( "test.egg", '(datatype Math\n (Num i64)\n (Var String)\n (Add Math Math)\n (Mul Math Math))\n\n ;; expr1 = 2 * (x + 3)\n (let expr1 (Mul (Num 2) (Add (Var "x") (Num 3))))', @@ -146,7 +146,7 @@ ], ), Call( - Span( + EgglogSpan( SrcFile( "test.egg", '(datatype Math\n (Num i64)\n (Var String)\n (Add Math Math)\n (Mul Math Math))\n\n ;; expr1 = 2 * (x + 3)\n (let expr1 (Mul (Num 2) (Add (Var "x") (Num 3))))', @@ -157,7 +157,7 @@ "Num", [ Lit( - Span( + EgglogSpan( SrcFile( "test.egg", '(datatype Math\n (Num i64)\n (Var String)\n (Add Math Math)\n (Mul Math Math))\n\n ;; expr1 = 2 * (x + 3)\n (let expr1 (Mul (Num 2) (Add (Var "x") (Num 3))))', diff --git a/python/tests/test_bindings.py b/python/tests/test_bindings.py index 73dbbb0f..6a40c492 100644 --- a/python/tests/test_bindings.py +++ b/python/tests/test_bindings.py @@ -1,5 +1,4 @@ import _thread -import fractions import json import os import pathlib @@ -10,6 +9,8 @@ from egglog.bindings import * +DUMMY_SPAN = RustSpan(__name__, 0, 0) + def get_egglog_folder() -> pathlib.Path: """ @@ -44,7 +45,7 @@ def get_egglog_folder() -> pathlib.Path: ) def test_example(example_file: pathlib.Path): egraph = EGraph(fact_directory=EGG_SMOL_FOLDER) - commands = parse_program(example_file.read_text()) + commands = egraph.parse_program(example_file.read_text()) # TODO: Include currently relies on the CWD instead of the fact directory. We should fix this upstream # and then remove this workaround. os.chdir(EGG_SMOL_FOLDER) @@ -56,7 +57,7 @@ def test_example(example_file: pathlib.Path): class TestEGraph: def test_parse_program(self, snapshot_py): - res = parse_program( + res = EGraph().parse_program( """(datatype Math (Num i64) (Var String) @@ -74,7 +75,7 @@ def test_parse_and_run_program(self): program = "(check (= (+ 2 2) 4))" egraph = EGraph() - assert egraph.run_program(*parse_program(program)) == [] + assert egraph.run_program(*egraph.parse_program(program)) == [] def test_parse_and_run_program_exception(self): program = "(check (= 1 1.0))" @@ -84,7 +85,7 @@ def test_parse_and_run_program_exception(self): EggSmolError, match="to have type", ): - egraph.run_program(*parse_program(program)) + egraph.run_program(*egraph.parse_program(program)) def test_run_rules(self): egraph = EGraph() @@ -117,7 +118,7 @@ def test_extract(self): extract_report = egraph.extract_report() assert isinstance(extract_report, Best) assert extract_report.cost == 6 - assert termdag_term_to_expr(extract_report.termdag, extract_report.term) == Call( + assert extract_report.termdag.term_to_expr(extract_report.term, DUMMY_SPAN) == Call( DUMMY_SPAN, "Num", [Lit(DUMMY_SPAN, Int(1))] ) @@ -132,7 +133,7 @@ def test_simplify(self): extract_report = egraph.extract_report() assert isinstance(extract_report, Best) assert extract_report.cost == 6 - assert termdag_term_to_expr(extract_report.termdag, extract_report.term) == Call( + assert extract_report.termdag.term_to_expr(extract_report.term, DUMMY_SPAN) == Call( DUMMY_SPAN, "Num", [Lit(DUMMY_SPAN, Int(1))] ) @@ -172,10 +173,8 @@ def test_sort_alias(self): [ Eq( DUMMY_SPAN, - [ - Lit(DUMMY_SPAN, String("one")), - Call(DUMMY_SPAN, "map-get", [Var(DUMMY_SPAN, "my_map1"), Lit(DUMMY_SPAN, Int(1))]), - ], + Lit(DUMMY_SPAN, String("one")), + Call(DUMMY_SPAN, "map-get", [Var(DUMMY_SPAN, "my_map1"), Lit(DUMMY_SPAN, Int(1))]), ) ], ), @@ -184,7 +183,7 @@ def test_sort_alias(self): extract_report = egraph.extract_report() assert isinstance(extract_report, Best) - assert termdag_term_to_expr(extract_report.termdag, extract_report.term) == Call( + assert extract_report.termdag.term_to_expr(extract_report.term, DUMMY_SPAN) == Call( DUMMY_SPAN, "map-insert", [ @@ -224,7 +223,7 @@ def test_i64(self): assert EGraph().eval_i64(Lit(DUMMY_SPAN, Int(1))) == 1 def test_f64(self): - assert EGraph().eval_f64(Lit(DUMMY_SPAN, F64(1.0))) == 1.0 + assert EGraph().eval_f64(Lit(DUMMY_SPAN, Float(1.0))) == 1.0 def test_string(self): assert EGraph().eval_string(Lit(DUMMY_SPAN, String("hi"))) == "hi" @@ -232,17 +231,6 @@ def test_string(self): def test_bool(self): assert EGraph().eval_bool(Lit(DUMMY_SPAN, Bool(True))) is True - @pytest.mark.xfail(reason="Depends on getting actual sort from egraph") - def test_rational(self): - egraph = EGraph() - rational = Call(DUMMY_SPAN, "rational", [Lit(DUMMY_SPAN, Int(1)), Lit(DUMMY_SPAN, Int(2))]) - egraph.run_program( - ActionCommand( - Expr_(DUMMY_SPAN, Call(DUMMY_SPAN, "rational", [Lit(DUMMY_SPAN, Int(1)), Lit(DUMMY_SPAN, Int(2))])) - ) - ) - assert egraph.eval_rational(rational) == fractions.Fraction(1, 2) - class TestThreads: """ diff --git a/python/tests/test_convert.py b/python/tests/test_convert.py index 20ece58e..3a52e7a8 100644 --- a/python/tests/test_convert.py +++ b/python/tests/test_convert.py @@ -82,7 +82,7 @@ def __init__(self) -> None: ... assert expr_parts(convert(MyType(), MyTypeExpr)) == expr_parts(MyTypeExpr()) -T = TypeVar("T", bound=Expr) +T = TypeVar("T", bound=BaseExpr) def test_convert_to_generic(): @@ -91,7 +91,7 @@ def test_convert_to_generic(): particular instance of that generic even if the general instance is registered """ - class G(Expr, Generic[T], builtin=True): + class G(BuiltinExpr, Generic[T]): def __init__(self, x: T) -> None: ... converter(i64, G[i64], lambda x: G(x)) @@ -110,7 +110,7 @@ def test_convert_to_unbound_generic(): particular instance of that generic even if the general instance is registered """ - class G(Expr, Generic[T], builtin=True): + class G(BuiltinExpr, Generic[T]): def __init__(self, x: i64) -> None: ... converter(i64, G, lambda x: G[get_type_args()[0]](x)) # type: ignore[misc, operator] @@ -126,7 +126,7 @@ def test_convert_generic_transitive(): class A(Expr): def __init__(self) -> None: ... - class B(Expr, Generic[T], builtin=True): + class B(BuiltinExpr, Generic[T]): def __init__( self, ) -> None: ... diff --git a/python/tests/test_high_level.py b/python/tests/test_high_level.py index 389b55fd..30e1206c 100644 --- a/python/tests/test_high_level.py +++ b/python/tests/test_high_level.py @@ -17,14 +17,6 @@ TypedExprDecl, ) -EXAMPLE_FILES = list((pathlib.Path(__file__).parent / "../egglog/examples").glob("*.py")) - - -# Test all files in the `examples` directory by importing them in this parametrized test -@pytest.mark.parametrize("name", [f.stem for f in EXAMPLE_FILES if f.stem != "__init__"]) -def test_example(name): - importlib.import_module(f"egglog.examples.{name}") - class TestExprStr: def test_unwrap_lit(self): @@ -223,33 +215,6 @@ def __init__(self, x: i64Like) -> None: ... assert expr_parts(Foo(1)) == expr_parts(Foo(x=1)) -def test_modules() -> None: - with pytest.deprecated_call(): - m = Module() - - @m.class_ - class Numeric(Expr): - ONE: ClassVar[Numeric] - - with pytest.deprecated_call(): - m2 = Module() - - with pytest.deprecated_call(): - - @m2.class_ - class OtherNumeric(Expr): - @m2.method(cost=10) - def __init__(self, v: i64Like) -> None: ... - - egraph = EGraph([m, m2]) - - @function - def from_numeric(n: Numeric) -> OtherNumeric: ... - - egraph.register(rewrite(OtherNumeric(1)).to(from_numeric(Numeric.ONE))) - assert expr_parts(egraph.simplify(OtherNumeric(i64(1)), 10)) == expr_parts(from_numeric(Numeric.ONE)) - - def test_property(): egraph = EGraph() @@ -280,7 +245,10 @@ def test_from_string(self): assert EGraph().eval(PyObject.from_string("foo")) == "foo" def test_to_string(self): - assert EGraph().eval(PyObject("foo").to_string()) == "foo" + x: String = PyObject("foo").to_string() + # reveal_type(cast(Bool, x))) + # reveal_type(EGraph().eval(x)) + assert EGraph().eval(x) == "foo" def test_dict_update(self): original_d = {"foo": "bar"} @@ -462,12 +430,14 @@ def from_int(cls, other: Int) -> Float: ... def test_rewrite_upcasts(): - rewrite(i64(1)).to(0) # type: ignore[arg-type] + class X(Expr): + def __init__(self, value: i64Like) -> None: ... + converter(i64, X, X) + rewrite(X(1)).to(0) # type: ignore[arg-type] -def test_function_default_upcasts(): - EGraph() +def test_function_default_upcasts(): @function def f(x: i64Like) -> i64: ... @@ -578,7 +548,7 @@ def f(x: T) -> T: ... with egraph: @function - def f(x: T, y: T) -> T: ... + def f(x: T, y: T) -> T: ... # type: ignore[misc] egraph.register(f(T(1), T(2))) # type: ignore[call-arg] @@ -802,3 +772,12 @@ def my_fn(xs: MapLike[i64, String, i64Like, StringLike]) -> Unit: ... assert expr_parts(my_fn({1: "hi"})) == expr_parts(my_fn(Map[i64, String].empty().insert(i64(1), String("hi")))) assert expr_parts(my_fn({})) == expr_parts(my_fn(Map[i64, String].empty())) + + +EXAMPLE_FILES = list((pathlib.Path(__file__).parent / "../egglog/examples").glob("*.py")) + + +# Test all files in the `examples` directory by importing them in this parametrized test +@pytest.mark.parametrize("name", [f.stem for f in EXAMPLE_FILES if f.stem != "__init__"]) +def test_example(name): + importlib.import_module(f"egglog.examples.{name}") diff --git a/python/tests/test_modules.py b/python/tests/test_modules.py deleted file mode 100644 index 85b5c38b..00000000 --- a/python/tests/test_modules.py +++ /dev/null @@ -1,43 +0,0 @@ -import pytest - -# from egglog.declarations import ModuleDeclarations -from egglog.egraph import * - -# from egglog.egraph import _BUILTIN_DECLS, BUILTINS - - -def test_tree_modules(): - """ - BUILTINS - / | \ - A B C - | / - D - """ - # assert _BUILTIN_DECLS - # assert BUILTINS._mod_decls == ModuleDeclarations(_BUILTIN_DECLS, []) - - with pytest.deprecated_call(): - A, B, C = Module(), Module(), Module() - # assert list(A._mod_decls._included_decls) == [_BUILTIN_DECLS] - - with pytest.deprecated_call(): - a = A.relation("a") - b = B.relation("b") - c = C.relation("c") - A.register(a()) - B.register(b()) - C.register(c()) - - with pytest.deprecated_call(): - D = Module([A, B]) - d = D.relation("d") - D.register(d()) - - assert D._flatted_deps == [A, B] - - egraph = EGraph([D, B]) - # assert egraph._flatted_deps == [A, B, D] - egraph.check(a(), b(), d()) - with pytest.raises(Exception): # noqa: B017, PT011 - egraph.check(c()) diff --git a/python/tests/test_pretty.py b/python/tests/test_pretty.py index 96fdf4ba..b372e7f1 100644 --- a/python/tests/test_pretty.py +++ b/python/tests/test_pretty.py @@ -139,7 +139,7 @@ def my_very_long_function_name() -> A: ... pytest.param(panic("oh no"), 'panic("oh no")', id="panic"), # Fact pytest.param(expr_fact(A()), "A()", id="expr fact"), - pytest.param(eq(g()).to(h(), A()), "eq(g()).to(h(), A())", id="eq"), + pytest.param(eq(g()).to(h()), "eq(g()).to(h())", id="eq"), # Ruleset pytest.param(ruleset(rewrite(g()).to(h())), "ruleset(rewrite(g()).to(h()))", id="ruleset"), # Schedules diff --git a/python/tests/test_py_object_sort.py b/python/tests/test_py_object_sort.py index c56d2ba6..d69e6f4f 100644 --- a/python/tests/test_py_object_sort.py +++ b/python/tests/test_py_object_sort.py @@ -6,6 +6,8 @@ from egglog.bindings import * +DUMMY_SPAN = RustSpan(__name__, 0, 0) + @dataclasses.dataclass(frozen=True) class MyObject: diff --git a/src/conversions.rs b/src/conversions.rs index a1cab1d9..aa097844 100644 --- a/src/conversions.rs +++ b/src/conversions.rs @@ -8,14 +8,16 @@ use pyo3::types::PyDeltaAccess; use std::collections::HashMap; use std::sync::Arc; +use crate::termdag::TermDag; + convert_enums!( egglog::ast::Literal: "{:}" Hash => Literal { Int[trait=Hash](value: i64) i -> egglog::ast::Literal::Int(i.value), egglog::ast::Literal::Int(i) => Int { value: *i }; - F64[trait=Hash](value: WrappedOrderedF64) - f -> egglog::ast::Literal::F64(f.value.0), - egglog::ast::Literal::F64(f) => F64 { value: WrappedOrderedF64(*f) }; + Float[trait=Hash](value: WrappedOrderedF64) + f -> egglog::ast::Literal::Float(f.value.0), + egglog::ast::Literal::Float(f) => Float { value: WrappedOrderedF64(*f) }; String_[name="String"][trait=Hash](value: String) s -> egglog::ast::Literal::String((&s.value).into()), egglog::ast::Literal::String(s) => String_ { value: s.to_string() }; @@ -42,9 +44,9 @@ convert_enums!( } }; egglog::ast::Fact: "{}" => Fact_ { - Eq(span: Span, exprs: Vec) - eq -> egglog::ast::Fact::Eq(eq.span.clone().into(), eq.exprs.iter().map(|e| e.into()).collect()), - egglog::ast::Fact::Eq(span, e) => Eq { span: span.into(), exprs: e.iter().map(|e| e.into()).collect() }; + Eq(span: Span, left: Expr, right: Expr) + eq -> egglog::ast::Fact::Eq(eq.span.clone().into(), eq.left.clone().into(), eq.right.clone().into()), + egglog::ast::Fact::Eq(span, left, right) => Eq { span: span.into(), left: left.into(), right: right.into() }; Fact(expr: Expr) f -> egglog::ast::Fact::Fact((&f.expr).into()), egglog::ast::Fact::Fact(e) => Fact { expr: e.into() } @@ -154,9 +156,19 @@ convert_enums!( presort_and_args: presort_and_args.as_ref().map(|(p, a)| (p.to_string(), a.iter().map(|e| e.into()).collect())), span: span.into() }; - Function(decl: FunctionDecl) - f -> egglog::ast::Command::Function((&f.decl).into()), - egglog::ast::Command::Function(f) => Function { decl: f.into() }; + Function(span: Span, name: String, schema: Schema, merge: Option) + f -> egglog::ast::Command::Function{ + span: f.span.clone().into(), + name: (&f.name).into(), + schema: (&f.schema).into(), + merge: f.merge.as_ref().map(|e| e.into()) + }, + egglog::ast::Command::Function {span, name, schema, merge} => Function { + span: span.into(), + name: name.to_string(), + schema: schema.into(), + merge: merge.as_ref().map(|e| e.into()) + }; AddRuleset(name: String) a -> egglog::ast::Command::AddRuleset((&a.name).into()), egglog::ast::Command::AddRuleset(n) => AddRuleset { name: n.to_string() }; @@ -259,15 +271,30 @@ convert_enums!( Include(span: Span, path: String) i -> egglog::ast::Command::Include(i.span.clone().into(), (&i.path).into()), egglog::ast::Command::Include(span, p) => Include { span: span.into(), path: p.to_string() }; - Relation(span: Span, constructor: String, inputs: Vec) + Constructor(span: Span, name: String, schema: Schema, cost: Option, unextractable: bool) + c -> egglog::ast::Command::Constructor { + span: c.span.clone().into(), + name: (&c.name).into(), + schema: (&c.schema).into(), + cost: c.cost, + unextractable: c.unextractable + }, + egglog::ast::Command::Constructor {span, name, schema, cost, unextractable} => Constructor { + span: span.into(), + name: name.to_string(), + schema: schema.into(), + cost: *cost, + unextractable: *unextractable + }; + Relation(span: Span, name: String, inputs: Vec) r -> egglog::ast::Command::Relation { span: r.span.clone().into(), - constructor: (&r.constructor).into(), + name: (&r.name).into(), inputs: r.inputs.iter().map(|i| i.into()).collect() }, - egglog::ast::Command::Relation {span, constructor, inputs} => Relation { + egglog::ast::Command::Relation {span, name, inputs} => Relation { span: span.into(), - constructor: constructor.to_string(), + name: name.to_string(), inputs: inputs.iter().map(|i| i.to_string()).collect() }; PrintOverallStatistics() @@ -322,62 +349,41 @@ convert_enums!( termdag: termdag.into(), terms: terms.iter().map(|v| v.into()).collect() } + }; + egglog::ast::Span: "{:?}" => Span { + PanicSpan() + _p -> egglog::ast::Span::Panic, + egglog::ast::Span::Panic => PanicSpan {}; + EgglogSpan(file: SrcFile, i: usize, j: usize) + e -> egglog::ast::Span::Egglog(Arc::new({ + egglog::ast::EgglogSpan { + file: Arc::new(e.file.clone().into()), + i: e.i, + j: e.j + } + })), + egglog::ast::Span::Egglog(e) => EgglogSpan { + file: (*e.file.clone()).clone().into(), + i: e.i, + j: e.j + }; + RustSpan(file: String, line: u32, column: u32) + r -> egglog::ast::Span::Rust(Arc::new(egglog::ast::RustSpan { + file: Box::leak(r.file.clone().into_boxed_str()), + line: r.line, + column: r.column + })), + egglog::ast::Span::Rust(r) => RustSpan {file: r.file.to_string(), line: r.line, column: r.column} } ); convert_struct!( egglog::ast::SrcFile: "{:?}" => SrcFile( - name: String, - contents: Option - ) - s -> egglog::ast::SrcFile {name: s.name.to_string(), contents: s.contents.clone()}, - s -> SrcFile {name: s.name.to_string(), contents: s.contents.clone()}; - egglog::ast::Span: "{:?}" => Span( - file: SrcFile, - start: usize, - end: usize + name: Option, + contents: String ) - s -> egglog::ast::Span(Arc::new(s.file.clone().into()), s.start, s.end), - s -> Span {file: (*s.0.clone()).clone().into(), start: s.1, end: s.2}; - egglog::TermDag: "{:?}" => TermDag( - nodes: Vec, - hashcons: HashMap - ) - t -> egglog::TermDag {nodes: t.nodes.iter().map(|v| v.into()).collect(), hashcons: t.hashcons.iter().map(|(k, v)| (k.clone().into(), *v)).collect()}, - t -> TermDag {nodes: t.nodes.iter().map(|v| v.into()).collect(), hashcons: t.hashcons.iter().map(|(k, v)| (k.clone().into(), *v)).collect()}; - egglog::ast::FunctionDecl: "{:?}" => FunctionDecl( - span: Span, - name: String, - schema: Schema, - default: Option = None, - merge: Option = None, - merge_action: Vec = Vec::new(), - cost: Option = None, - unextractable: bool = false, - ignore_viz: bool = false - ) - f -> egglog::ast::FunctionDecl { - span: f.span.clone().into(), - name: (&f.name).into(), - schema: (&f.schema).into(), - default: f.default.as_ref().map(|e| e.into()), - merge: f.merge.as_ref().map(|e| e.into()), - merge_action: egglog::ast::GenericActions(f.merge_action.iter().map(|a| a.into()).collect()), - cost: f.cost, - unextractable: f.unextractable, - ignore_viz: f.ignore_viz - }, - f -> FunctionDecl { - span: f.span.clone().into(), - name: f.name.to_string(), - schema: (&f.schema).into(), - default: f.default.as_ref().map(|e| e.into()), - merge: f.merge.as_ref().map(|e| e.into()), - merge_action: f.merge_action.0.iter().map(|a| a.into()).collect(), - cost: f.cost, - unextractable: f.unextractable, - ignore_viz: f.ignore_viz - }; + s -> egglog::ast::SrcFile {name: s.name.clone(), contents: s.contents.clone()}, + s -> SrcFile {name: s.name.clone(), contents: s.contents.clone()}; egglog::ast::Variant: "{:?}" => Variant( span: Span, name: String, @@ -392,7 +398,7 @@ convert_struct!( ) s -> egglog::ast::Schema {input: s.input.iter().map(|v| v.into()).collect(), output: (&s.output).into()}, s -> Schema {input: s.input.iter().map(|v| v.to_string()).collect(), output: s.output.to_string()}; - egglog::ast::GenericRule: "{}" => Rule( + egglog::ast::GenericRule: "{:?}" => Rule( span: Span, head: Vec, body: Vec diff --git a/src/egraph.rs b/src/egraph.rs index 24bd7fbc..b5fca9dc 100644 --- a/src/egraph.rs +++ b/src/egraph.rs @@ -5,9 +5,8 @@ use crate::error::{EggResult, WrappedError}; use crate::py_object_sort::{ArcPyObjectSort, MyPyObject, PyObjectSort}; use crate::serialize::SerializedEGraph; -use egglog::ast::DUMMY_SPAN; use egglog::sort::{BoolSort, F64Sort, I64Sort, StringSort}; -use egglog::SerializeConfig; +use egglog::{span, SerializeConfig}; use log::info; use pyo3::prelude::*; use std::path::PathBuf; @@ -37,11 +36,13 @@ impl EGraph { seminaive: bool, record: bool, ) -> Self { - let mut egraph = egglog::EGraph::default(); + let mut egraph = egglog_experimental::new_experimental_egraph(); egraph.fact_directory = fact_directory; egraph.seminaive = seminaive; let py_object_arcsort = if let Some(py_object_sort) = py_object_sort { - egraph.add_arcsort(py_object_sort.0.clone()).unwrap(); + egraph + .add_arcsort(py_object_sort.0.clone(), span!()) + .unwrap(); Some(py_object_sort.0) } else { None @@ -53,6 +54,16 @@ impl EGraph { } } + /// Parse a program into a list of commands. + #[pyo3(signature = (input, /, filename=None))] + fn parse_program(&mut self, input: &str, filename: Option) -> EggResult> { + let commands = self + .egraph + .parser + .get_program_from_string(filename, input)?; + Ok(commands.into_iter().map(|x| x.into()).collect()) + } + /// Run a series of commands on the EGraph. /// Returns a list of strings representing the output. /// An EggSmolError is raised if there is problem parsing or executing. @@ -113,14 +124,9 @@ impl EGraph { max_calls_per_function: Option, include_temporary_functions: bool, ) -> SerializedEGraph { - let root_eclasses: Vec = root_eclasses + let root_eclasses: Vec<_> = root_eclasses .into_iter() - .map(|x| { - self.egraph - .eval_expr(&egglog::ast::Expr::from(x)) - .unwrap() - .1 - }) + .map(|x| self.egraph.eval_expr(&egglog::ast::Expr::from(x)).unwrap()) .collect(); SerializedEGraph { egraph: self.egraph.serialize(SerializeConfig { @@ -157,40 +163,6 @@ impl EGraph { fn eval_bool(&mut self, expr: Expr) -> EggResult { self.eval_sort(expr, Arc::new(BoolSort)) } - - #[pyo3(signature = (expr, /))] - fn eval_rational(&mut self, _py: Python<'_>, expr: Expr) -> EggResult { - // Need to get actual sort for rational, this hack doesnt work. - // todo!(); - // For rational we need the actual sort on the e-graph, because it contains state - // There isn't a public way to get a sort right now, so until there is, we use a hack where we create - // a dummy expression of that sort, and use eval_expr to get the sort - let _one = egglog::ast::Expr::Lit(DUMMY_SPAN.clone(), egglog::ast::Literal::Int(1)); - // let arcsort = self - // .egraph - // .eval_expr(&egglog::ast::Expr::Call( - // (), - // "rational".into(), - // vec![one.clone(), one], - // )) - // .unwrap() - // .0; - let expr: egglog::ast::Expr = expr.into(); - let (_, _value) = self.egraph.eval_expr(&expr)?; - // Need to get actual sort for rational, this hack doesnt work. - todo!(); - // let r = num_rational::Rational64::load(&arcsort, &value); - - // // let r: num_rational::Rational64 = - // // self.eval_sort(expr, Arc::downcast::(arcsort).unwrap())?; - // let frac = py.import("fractions")?; - // let f = frac.call_method( - // "Fraction", - // (r.numer().into_py(py), r.denom().into_py(py)), - // None, - // )?; - // Ok(f.into()) - } } impl EGraph { diff --git a/src/error.rs b/src/error.rs index 59319b1c..8b468e11 100644 --- a/src/error.rs +++ b/src/error.rs @@ -23,6 +23,7 @@ impl EggSmolError { pub enum WrappedError { // Add additional context for egglog error Egglog(egglog::Error, String), + ParseError(egglog::ast::ParseError), Py(PyErr), } @@ -34,6 +35,7 @@ impl From for PyErr { PyErr::new::(error.to_string() + &str) } WrappedError::Py(error) => error, + WrappedError::ParseError(error) => PyErr::new::(error.to_string()), } } } @@ -45,6 +47,12 @@ impl From for WrappedError { } } +impl From for WrappedError { + fn from(other: egglog::ast::ParseError) -> Self { + Self::ParseError(other) + } +} + // Convert from a PyErr to a WrappedError impl From for WrappedError { fn from(other: PyErr) -> Self { diff --git a/src/lib.rs b/src/lib.rs index 8e92e6d5..365bcbeb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,39 +3,21 @@ mod egraph; mod error; mod py_object_sort; mod serialize; +mod termdag; mod utils; -use conversions::{Command, Expr, Span, Term, TermDag}; -use error::EggResult; use pyo3::prelude::*; -#[pyfunction] -fn termdag_term_to_expr(termdag: &TermDag, term: Term) -> Expr { - let termdag: egglog::TermDag = termdag.into(); - let term: egglog::Term = term.into(); - termdag.term_to_expr(&term).into() -} - -/// Parse a program into a list of commands. -#[pyfunction(signature = (input, /, filename=None))] -fn parse_program(input: &str, filename: Option) -> EggResult> { - let commands = egglog::ast::parse_program(filename, input)?; - Ok(commands.into_iter().map(|x| x.into()).collect()) -} - /// Bindings for egglog rust library #[pymodule] fn bindings(m: &Bound<'_, PyModule>) -> PyResult<()> { pyo3_log::init(); - let dummy: Span = egglog::ast::DUMMY_SPAN.clone().into(); - m.add("DUMMY_SPAN", dummy)?; m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; - m.add_function(wrap_pyfunction!(termdag_term_to_expr, m)?)?; - m.add_function(wrap_pyfunction!(parse_program, m)?)?; + m.add_class::()?; crate::conversions::add_structs_to_module(m)?; crate::conversions::add_enums_to_module(m)?; diff --git a/src/py_object_sort.rs b/src/py_object_sort.rs index ad20dabc..47960367 100644 --- a/src/py_object_sort.rs +++ b/src/py_object_sort.rs @@ -1,14 +1,13 @@ use crate::error::EggResult; use egglog::{ - ast::{Expr, Literal, Span, Symbol, DUMMY_SPAN}, + ast::{Expr, Literal, Span, Symbol}, + call, constraint::{AllEqualTypeConstraint, SimpleTypeConstraint, TypeConstraint}, - // core::AtomTerm, + extract::{Cost, Extractor}, + lit, sort::{BoolSort, FromSort, I64Sort, IntoSort as _, Sort, StringSort}, util::IndexMap, - EGraph, - PrimitiveLike, - TypeInfo, - Value, + ArcSort, EGraph, PrimitiveLike, Term, TermDag, TypeInfo, Value, }; use pyo3::{ ffi, intern, prelude::*, types::PyDict, AsPyPointer, IntoPy, PyAny, PyErr, PyObject, PyResult, @@ -29,6 +28,7 @@ const NAME: &str = "PyObject"; fn value(i: usize) -> Value { Value { + #[cfg(debug_assertions)] tag: NAME.into(), bits: i as u64, } @@ -55,16 +55,16 @@ impl PyObjectIdent { pub fn to_expr(self) -> Expr { let children = match self { PyObjectIdent::Unhashable(id) => { - vec![Expr::Lit(DUMMY_SPAN.clone(), Literal::Int(id as i64))] + vec![lit!(Literal::Int(id as i64))] } PyObjectIdent::Hashable(type_hash, hash) => { vec![ - Expr::Lit(DUMMY_SPAN.clone(), Literal::Int(type_hash as i64)), - Expr::Lit(DUMMY_SPAN.clone(), Literal::Int(hash as i64)), + lit!(Literal::Int(type_hash as i64)), + lit!(Literal::Int(hash as i64)), ] } }; - Expr::call_no_span("py-object", children) + call!("py-object", children) } } @@ -227,20 +227,27 @@ impl Sort for PyObjectSort { int: typeinfo.get_sort_nofail(), }); } - fn make_expr(&self, _egraph: &EGraph, value: Value) -> (usize, Expr) { + fn extract_term( + &self, + _egraph: &EGraph, + value: Value, + _extractor: &Extractor, + termdag: &mut TermDag, + ) -> Option<(Cost, Term)> { + #[cfg(debug_assertions)] assert!(value.tag == self.name()); let children = match self.load_ident(&value) { PyObjectIdent::Unhashable(id) => { - vec![Expr::Lit(DUMMY_SPAN.clone(), Literal::Int(id as i64))] + vec![termdag.lit(Literal::Int(id as i64))] } PyObjectIdent::Hashable(type_hash, hash) => { vec![ - Expr::Lit(DUMMY_SPAN.clone(), Literal::Int(type_hash as i64)), - Expr::Lit(DUMMY_SPAN.clone(), Literal::Int(hash as i64)), + termdag.lit(Literal::Int(type_hash as i64)), + termdag.lit(Literal::Int(hash as i64)), ] } }; - (1, Expr::call_no_span("py-object", children)) + Some((1, termdag.app("py-object".into(), children))) } } @@ -263,7 +270,12 @@ impl PrimitiveLike for Ctor { .into_box() } - fn apply(&self, values: &[Value], _egraph: Option<&mut EGraph>) -> Option { + fn apply( + &self, + values: &[Value], + _sorts: (&[ArcSort], &ArcSort), + _egraph: Option<&mut EGraph>, + ) -> Option { let ident = match values { [id] => PyObjectIdent::Unhashable(i64::load(self.i64.as_ref(), id) as usize), [type_hash, hash] => PyObjectIdent::Hashable( @@ -302,7 +314,12 @@ impl PrimitiveLike for Eval { .into_box(); } - fn apply(&self, values: &[Value], _egraph: Option<&mut EGraph>) -> Option { + fn apply( + &self, + values: &[Value], + _sorts: (&[ArcSort], &ArcSort), + _egraph: Option<&mut EGraph>, + ) -> Option { let code: Symbol = Symbol::load(self.string.as_ref(), &values[0]); let res_obj: PyObject = Python::with_gil(|py| { let globals = self.py_object.load(py, values[1]); @@ -344,7 +361,12 @@ impl PrimitiveLike for Exec { .into_box() } - fn apply(&self, values: &[Value], _egraph: Option<&mut EGraph>) -> Option { + fn apply( + &self, + values: &[Value], + _sorts: (&[ArcSort], &ArcSort), + _egraph: Option<&mut EGraph>, + ) -> Option { let code: Symbol = Symbol::load(self.string.as_ref(), &values[0]); let code: &str = code.into(); let locals: PyObject = Python::with_gil(|py| { @@ -386,7 +408,12 @@ impl PrimitiveLike for Dict { .into_box() } - fn apply(&self, values: &[Value], _egraph: Option<&mut EGraph>) -> Option { + fn apply( + &self, + values: &[Value], + _sorts: (&[ArcSort], &ArcSort), + _egraph: Option<&mut EGraph>, + ) -> Option { let dict: PyObject = Python::with_gil(|py| { let dict = PyDict::new_bound(py); // Update the dict with the key-value pairs @@ -419,7 +446,12 @@ impl PrimitiveLike for DictUpdate { .into_box() } - fn apply(&self, values: &[Value], _egraph: Option<&mut EGraph>) -> Option { + fn apply( + &self, + values: &[Value], + _sorts: (&[ArcSort], &ArcSort), + _egraph: Option<&mut EGraph>, + ) -> Option { let dict: PyObject = Python::with_gil(|py| { let dict = self.py_object.load(py, values[0]); // Copy the dict so we can mutate it and return it @@ -456,7 +488,12 @@ impl PrimitiveLike for ToString { ) .into_box() } - fn apply(&self, values: &[Value], _egraph: Option<&mut EGraph>) -> Option { + fn apply( + &self, + values: &[Value], + _sorts: (&[ArcSort], &ArcSort), + _egraph: Option<&mut EGraph>, + ) -> Option { let obj: String = Python::with_gil(|py| self.py_object.load(py, values[0]).extract(py).unwrap()); let symbol: Symbol = obj.into(); @@ -485,7 +522,12 @@ impl PrimitiveLike for ToBool { .into_box() } - fn apply(&self, values: &[Value], _egraph: Option<&mut EGraph>) -> Option { + fn apply( + &self, + values: &[Value], + _sorts: (&[ArcSort], &ArcSort), + _egraph: Option<&mut EGraph>, + ) -> Option { let obj: bool = Python::with_gil(|py| self.py_object.load(py, values[0]).extract(py).unwrap()); obj.store(self.bool_.as_ref()) @@ -513,7 +555,12 @@ impl PrimitiveLike for FromString { .into_box() } - fn apply(&self, values: &[Value], _egraph: Option<&mut EGraph>) -> Option { + fn apply( + &self, + values: &[Value], + _sorts: (&[ArcSort], &ArcSort), + _egraph: Option<&mut EGraph>, + ) -> Option { let str = Symbol::load(self.string.as_ref(), &values[0]).to_string(); let obj: PyObject = Python::with_gil(|py| str.into_py(py)); Some(self.py_object.store(obj)) @@ -541,7 +588,12 @@ impl PrimitiveLike for FromInt { .into_box(); } - fn apply(&self, values: &[Value], _egraph: Option<&mut EGraph>) -> Option { + fn apply( + &self, + values: &[Value], + _sorts: (&[ArcSort], &ArcSort), + _egraph: Option<&mut EGraph>, + ) -> Option { let int = i64::load(self.int.as_ref(), &values[0]); let obj: PyObject = Python::with_gil(|py| int.into_py(py)); Some(self.py_object.store(obj)) diff --git a/src/termdag.rs b/src/termdag.rs new file mode 100644 index 00000000..6e333a7b --- /dev/null +++ b/src/termdag.rs @@ -0,0 +1,107 @@ +use crate::conversions::{Expr, Literal, Span, Term}; +use egglog::TermId; +use pyo3::prelude::*; + +#[pyclass()] +#[derive(Clone, PartialEq, Eq)] +pub struct TermDag { + pub termdag: egglog::TermDag, +} + +#[pymethods] +impl TermDag { + /// Create a new, empty TermDag. + #[new] + fn new() -> Self { + Self { + termdag: egglog::TermDag::default(), + } + } + + /// Returns the number of nodes in this DAG. + pub fn size(&self) -> usize { + self.termdag.size() + } + + /// Convert the given term to its id. + /// + /// Panics if the term does not already exist in this [TermDag]. + pub fn lookup(&self, node: Term) -> TermId { + self.termdag.lookup(&node.into()).into() + } + + /// Convert the given id to the corresponding term. + /// + /// Panics if the id is not valid. + pub fn get(&self, id: TermId) -> Term { + self.termdag.get(id).into() + } + /// Make and return a App with the given head symbol and children, + /// and insert into the DAG if it is not already present. + /// + /// Panics if any of the children are not already in the DAG. + pub fn app(&mut self, sym: String, children: Vec) -> Term { + self.termdag + .app(sym.into(), children.into_iter().map(|c| c.into()).collect()) + .into() + } + + /// Make and return a [`Term::Lit`] with the given literal, and insert into + /// the DAG if it is not already present. + pub fn lit(&mut self, lit: Literal) -> Term { + self.termdag.lit(lit.into()).into() + } + + /// Make and return a [`Term::Var`] with the given symbol, and insert into + /// the DAG if it is not already present. + pub fn var(&mut self, sym: String) -> Term { + self.termdag.var(sym.into()).into() + } + + /// Recursively converts the given expression to a term. + /// + /// This involves inserting every subexpression into this DAG. Because + /// TermDags are hashconsed, the resulting term is guaranteed to maximally + /// share subterms. + pub fn expr_to_term(&mut self, expr: Expr) -> Term { + self.termdag.expr_to_term(&expr.into()).into() + } + + /// Recursively converts the given term to an expression. + /// + /// Panics if the term contains subterms that are not in the DAG. + pub fn term_to_expr(&self, term: Term, span: Span) -> Expr { + self.termdag.term_to_expr(&term.into(), span.into()).into() + } + + /// Converts the given term to a string. + /// + /// Panics if the term or any of its subterms are not in the DAG. + pub fn to_string(&self, term: Term) -> String { + self.termdag.to_string(&term.into()) + } +} + +impl From<&egglog::TermDag> for TermDag { + fn from(termdag: &egglog::TermDag) -> Self { + Self { + termdag: termdag.clone(), + } + } +} +impl From<&TermDag> for egglog::TermDag { + fn from(termdag: &TermDag) -> Self { + termdag.termdag.clone() + } +} + +impl From for TermDag { + fn from(termdag: egglog::TermDag) -> Self { + Self { termdag } + } +} +impl From for egglog::TermDag { + fn from(termdag: TermDag) -> Self { + termdag.termdag + } +} diff --git a/src/utils.rs b/src/utils.rs index 9bd50bd2..7a77455c 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -25,7 +25,7 @@ pub fn data_repr( // Macro to create a wrapper around rust enums. // We create Python classes for each variant of the enum // and create a wrapper enum around all variants to enable conversion to/from Python -// and to/from egg_smol +// and to/from egglog #[macro_export] macro_rules! convert_enums { ($( @@ -38,7 +38,7 @@ macro_rules! convert_enums { } );*) => { $($( - #[pyclass(frozen, module="egg_smol.bindings"$(, name=$py_name)?)] + #[pyclass(frozen$(, name=$py_name)?)] #[derive(Clone, PartialEq, Eq$(, $trait_inner)?)] pub struct $variant { $( @@ -178,7 +178,7 @@ macro_rules! convert_struct { $to_ident:ident -> $to:expr );*) => { $( - #[pyclass(frozen, module="egg_smol.bindings")] + #[pyclass(frozen)] #[derive(Clone, PartialEq, Eq$(, $struct_trait)?)] pub struct $to_type { $( diff --git a/stubtest_allow b/stubtest_allow index 8d561883..ba434be8 100644 --- a/stubtest_allow +++ b/stubtest_allow @@ -4,3 +4,5 @@ .*egglog.bindings.PyObjectSort.__init__.* .*egglog.bindings.Delete.__init__.* .*egglog.bindings.Subsume.__init__.* +.*egglog.bindings.TermDag.__init__.* +.*egglog.bindings.PanicSpan.__init__.* diff --git a/uv.lock b/uv.lock index d55bb028..747ab17d 100644 --- a/uv.lock +++ b/uv.lock @@ -1656,36 +1656,40 @@ wheels = [ [[package]] name = "mypy" -version = "1.13.0" +version = "1.15.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "mypy-extensions" }, { name = "tomli", marker = "python_full_version < '3.11'" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/e8/21/7e9e523537991d145ab8a0a2fd98548d67646dc2aaaf6091c31ad883e7c1/mypy-1.13.0.tar.gz", hash = "sha256:0291a61b6fbf3e6673e3405cfcc0e7650bebc7939659fdca2702958038bd835e", size = 3152532 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/5e/8c/206de95a27722b5b5a8c85ba3100467bd86299d92a4f71c6b9aa448bfa2f/mypy-1.13.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6607e0f1dd1fb7f0aca14d936d13fd19eba5e17e1cd2a14f808fa5f8f6d8f60a", size = 11020731 }, - { url = "https://files.pythonhosted.org/packages/ab/bb/b31695a29eea76b1569fd28b4ab141a1adc9842edde080d1e8e1776862c7/mypy-1.13.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8a21be69bd26fa81b1f80a61ee7ab05b076c674d9b18fb56239d72e21d9f4c80", size = 10184276 }, - { url = "https://files.pythonhosted.org/packages/a5/2d/4a23849729bb27934a0e079c9c1aad912167d875c7b070382a408d459651/mypy-1.13.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7b2353a44d2179846a096e25691d54d59904559f4232519d420d64da6828a3a7", size = 12587706 }, - { url = "https://files.pythonhosted.org/packages/5c/c3/d318e38ada50255e22e23353a469c791379825240e71b0ad03e76ca07ae6/mypy-1.13.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:0730d1c6a2739d4511dc4253f8274cdd140c55c32dfb0a4cf8b7a43f40abfa6f", size = 13105586 }, - { url = "https://files.pythonhosted.org/packages/4a/25/3918bc64952370c3dbdbd8c82c363804678127815febd2925b7273d9482c/mypy-1.13.0-cp310-cp310-win_amd64.whl", hash = "sha256:c5fc54dbb712ff5e5a0fca797e6e0aa25726c7e72c6a5850cfd2adbc1eb0a372", size = 9632318 }, - { url = "https://files.pythonhosted.org/packages/d0/19/de0822609e5b93d02579075248c7aa6ceaddcea92f00bf4ea8e4c22e3598/mypy-1.13.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:581665e6f3a8a9078f28d5502f4c334c0c8d802ef55ea0e7276a6e409bc0d82d", size = 10939027 }, - { url = "https://files.pythonhosted.org/packages/c8/71/6950fcc6ca84179137e4cbf7cf41e6b68b4a339a1f5d3e954f8c34e02d66/mypy-1.13.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3ddb5b9bf82e05cc9a627e84707b528e5c7caaa1c55c69e175abb15a761cec2d", size = 10108699 }, - { url = "https://files.pythonhosted.org/packages/26/50/29d3e7dd166e74dc13d46050b23f7d6d7533acf48f5217663a3719db024e/mypy-1.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:20c7ee0bc0d5a9595c46f38beb04201f2620065a93755704e141fcac9f59db2b", size = 12506263 }, - { url = "https://files.pythonhosted.org/packages/3f/1d/676e76f07f7d5ddcd4227af3938a9c9640f293b7d8a44dd4ff41d4db25c1/mypy-1.13.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:3790ded76f0b34bc9c8ba4def8f919dd6a46db0f5a6610fb994fe8efdd447f73", size = 12984688 }, - { url = "https://files.pythonhosted.org/packages/9c/03/5a85a30ae5407b1d28fab51bd3e2103e52ad0918d1e68f02a7778669a307/mypy-1.13.0-cp311-cp311-win_amd64.whl", hash = "sha256:51f869f4b6b538229c1d1bcc1dd7d119817206e2bc54e8e374b3dfa202defcca", size = 9626811 }, - { url = "https://files.pythonhosted.org/packages/fb/31/c526a7bd2e5c710ae47717c7a5f53f616db6d9097caf48ad650581e81748/mypy-1.13.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:5c7051a3461ae84dfb5dd15eff5094640c61c5f22257c8b766794e6dd85e72d5", size = 11077900 }, - { url = "https://files.pythonhosted.org/packages/83/67/b7419c6b503679d10bd26fc67529bc6a1f7a5f220bbb9f292dc10d33352f/mypy-1.13.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:39bb21c69a5d6342f4ce526e4584bc5c197fd20a60d14a8624d8743fffb9472e", size = 10074818 }, - { url = "https://files.pythonhosted.org/packages/ba/07/37d67048786ae84e6612575e173d713c9a05d0ae495dde1e68d972207d98/mypy-1.13.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:164f28cb9d6367439031f4c81e84d3ccaa1e19232d9d05d37cb0bd880d3f93c2", size = 12589275 }, - { url = "https://files.pythonhosted.org/packages/1f/17/b1018c6bb3e9f1ce3956722b3bf91bff86c1cefccca71cec05eae49d6d41/mypy-1.13.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:a4c1bfcdbce96ff5d96fc9b08e3831acb30dc44ab02671eca5953eadad07d6d0", size = 13037783 }, - { url = "https://files.pythonhosted.org/packages/cb/32/cd540755579e54a88099aee0287086d996f5a24281a673f78a0e14dba150/mypy-1.13.0-cp312-cp312-win_amd64.whl", hash = "sha256:a0affb3a79a256b4183ba09811e3577c5163ed06685e4d4b46429a271ba174d2", size = 9726197 }, - { url = "https://files.pythonhosted.org/packages/11/bb/ab4cfdc562cad80418f077d8be9b4491ee4fb257440da951b85cbb0a639e/mypy-1.13.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:a7b44178c9760ce1a43f544e595d35ed61ac2c3de306599fa59b38a6048e1aa7", size = 11069721 }, - { url = "https://files.pythonhosted.org/packages/59/3b/a393b1607cb749ea2c621def5ba8c58308ff05e30d9dbdc7c15028bca111/mypy-1.13.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:5d5092efb8516d08440e36626f0153b5006d4088c1d663d88bf79625af3d1d62", size = 10063996 }, - { url = "https://files.pythonhosted.org/packages/d1/1f/6b76be289a5a521bb1caedc1f08e76ff17ab59061007f201a8a18cc514d1/mypy-1.13.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:de2904956dac40ced10931ac967ae63c5089bd498542194b436eb097a9f77bc8", size = 12584043 }, - { url = "https://files.pythonhosted.org/packages/a6/83/5a85c9a5976c6f96e3a5a7591aa28b4a6ca3a07e9e5ba0cec090c8b596d6/mypy-1.13.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:7bfd8836970d33c2105562650656b6846149374dc8ed77d98424b40b09340ba7", size = 13036996 }, - { url = "https://files.pythonhosted.org/packages/b4/59/c39a6f752f1f893fccbcf1bdd2aca67c79c842402b5283563d006a67cf76/mypy-1.13.0-cp313-cp313-win_amd64.whl", hash = "sha256:9f73dba9ec77acb86457a8fc04b5239822df0c14a082564737833d2963677dbc", size = 9737709 }, - { url = "https://files.pythonhosted.org/packages/3b/86/72ce7f57431d87a7ff17d442f521146a6585019eb8f4f31b7c02801f78ad/mypy-1.13.0-py3-none-any.whl", hash = "sha256:9c250883f9fd81d212e0952c92dbfcc96fc237f4b7c92f56ac81fd48460b3e5a", size = 2647043 }, +sdist = { url = "https://files.pythonhosted.org/packages/ce/43/d5e49a86afa64bd3839ea0d5b9c7103487007d728e1293f52525d6d5486a/mypy-1.15.0.tar.gz", hash = "sha256:404534629d51d3efea5c800ee7c42b72a6554d6c400e6a79eafe15d11341fd43", size = 3239717 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/68/f8/65a7ce8d0e09b6329ad0c8d40330d100ea343bd4dd04c4f8ae26462d0a17/mypy-1.15.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:979e4e1a006511dacf628e36fadfecbcc0160a8af6ca7dad2f5025529e082c13", size = 10738433 }, + { url = "https://files.pythonhosted.org/packages/b4/95/9c0ecb8eacfe048583706249439ff52105b3f552ea9c4024166c03224270/mypy-1.15.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c4bb0e1bd29f7d34efcccd71cf733580191e9a264a2202b0239da95984c5b559", size = 9861472 }, + { url = "https://files.pythonhosted.org/packages/84/09/9ec95e982e282e20c0d5407bc65031dfd0f0f8ecc66b69538296e06fcbee/mypy-1.15.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:be68172e9fd9ad8fb876c6389f16d1c1b5f100ffa779f77b1fb2176fcc9ab95b", size = 11611424 }, + { url = "https://files.pythonhosted.org/packages/78/13/f7d14e55865036a1e6a0a69580c240f43bc1f37407fe9235c0d4ef25ffb0/mypy-1.15.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c7be1e46525adfa0d97681432ee9fcd61a3964c2446795714699a998d193f1a3", size = 12365450 }, + { url = "https://files.pythonhosted.org/packages/48/e1/301a73852d40c241e915ac6d7bcd7fedd47d519246db2d7b86b9d7e7a0cb/mypy-1.15.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:2e2c2e6d3593f6451b18588848e66260ff62ccca522dd231cd4dd59b0160668b", size = 12551765 }, + { url = "https://files.pythonhosted.org/packages/77/ba/c37bc323ae5fe7f3f15a28e06ab012cd0b7552886118943e90b15af31195/mypy-1.15.0-cp310-cp310-win_amd64.whl", hash = "sha256:6983aae8b2f653e098edb77f893f7b6aca69f6cffb19b2cc7443f23cce5f4828", size = 9274701 }, + { url = "https://files.pythonhosted.org/packages/03/bc/f6339726c627bd7ca1ce0fa56c9ae2d0144604a319e0e339bdadafbbb599/mypy-1.15.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:2922d42e16d6de288022e5ca321cd0618b238cfc5570e0263e5ba0a77dbef56f", size = 10662338 }, + { url = "https://files.pythonhosted.org/packages/e2/90/8dcf506ca1a09b0d17555cc00cd69aee402c203911410136cd716559efe7/mypy-1.15.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2ee2d57e01a7c35de00f4634ba1bbf015185b219e4dc5909e281016df43f5ee5", size = 9787540 }, + { url = "https://files.pythonhosted.org/packages/05/05/a10f9479681e5da09ef2f9426f650d7b550d4bafbef683b69aad1ba87457/mypy-1.15.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:973500e0774b85d9689715feeffcc980193086551110fd678ebe1f4342fb7c5e", size = 11538051 }, + { url = "https://files.pythonhosted.org/packages/e9/9a/1f7d18b30edd57441a6411fcbc0c6869448d1a4bacbaee60656ac0fc29c8/mypy-1.15.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5a95fb17c13e29d2d5195869262f8125dfdb5c134dc8d9a9d0aecf7525b10c2c", size = 12286751 }, + { url = "https://files.pythonhosted.org/packages/72/af/19ff499b6f1dafcaf56f9881f7a965ac2f474f69f6f618b5175b044299f5/mypy-1.15.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:1905f494bfd7d85a23a88c5d97840888a7bd516545fc5aaedff0267e0bb54e2f", size = 12421783 }, + { url = "https://files.pythonhosted.org/packages/96/39/11b57431a1f686c1aed54bf794870efe0f6aeca11aca281a0bd87a5ad42c/mypy-1.15.0-cp311-cp311-win_amd64.whl", hash = "sha256:c9817fa23833ff189db061e6d2eff49b2f3b6ed9856b4a0a73046e41932d744f", size = 9265618 }, + { url = "https://files.pythonhosted.org/packages/98/3a/03c74331c5eb8bd025734e04c9840532226775c47a2c39b56a0c8d4f128d/mypy-1.15.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:aea39e0583d05124836ea645f412e88a5c7d0fd77a6d694b60d9b6b2d9f184fd", size = 10793981 }, + { url = "https://files.pythonhosted.org/packages/f0/1a/41759b18f2cfd568848a37c89030aeb03534411eef981df621d8fad08a1d/mypy-1.15.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2f2147ab812b75e5b5499b01ade1f4a81489a147c01585cda36019102538615f", size = 9749175 }, + { url = "https://files.pythonhosted.org/packages/12/7e/873481abf1ef112c582db832740f4c11b2bfa510e829d6da29b0ab8c3f9c/mypy-1.15.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ce436f4c6d218a070048ed6a44c0bbb10cd2cc5e272b29e7845f6a2f57ee4464", size = 11455675 }, + { url = "https://files.pythonhosted.org/packages/b3/d0/92ae4cde706923a2d3f2d6c39629134063ff64b9dedca9c1388363da072d/mypy-1.15.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8023ff13985661b50a5928fc7a5ca15f3d1affb41e5f0a9952cb68ef090b31ee", size = 12410020 }, + { url = "https://files.pythonhosted.org/packages/46/8b/df49974b337cce35f828ba6fda228152d6db45fed4c86ba56ffe442434fd/mypy-1.15.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:1124a18bc11a6a62887e3e137f37f53fbae476dc36c185d549d4f837a2a6a14e", size = 12498582 }, + { url = "https://files.pythonhosted.org/packages/13/50/da5203fcf6c53044a0b699939f31075c45ae8a4cadf538a9069b165c1050/mypy-1.15.0-cp312-cp312-win_amd64.whl", hash = "sha256:171a9ca9a40cd1843abeca0e405bc1940cd9b305eaeea2dda769ba096932bb22", size = 9366614 }, + { url = "https://files.pythonhosted.org/packages/6a/9b/fd2e05d6ffff24d912f150b87db9e364fa8282045c875654ce7e32fffa66/mypy-1.15.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:93faf3fdb04768d44bf28693293f3904bbb555d076b781ad2530214ee53e3445", size = 10788592 }, + { url = "https://files.pythonhosted.org/packages/74/37/b246d711c28a03ead1fd906bbc7106659aed7c089d55fe40dd58db812628/mypy-1.15.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:811aeccadfb730024c5d3e326b2fbe9249bb7413553f15499a4050f7c30e801d", size = 9753611 }, + { url = "https://files.pythonhosted.org/packages/a6/ac/395808a92e10cfdac8003c3de9a2ab6dc7cde6c0d2a4df3df1b815ffd067/mypy-1.15.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:98b7b9b9aedb65fe628c62a6dc57f6d5088ef2dfca37903a7d9ee374d03acca5", size = 11438443 }, + { url = "https://files.pythonhosted.org/packages/d2/8b/801aa06445d2de3895f59e476f38f3f8d610ef5d6908245f07d002676cbf/mypy-1.15.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c43a7682e24b4f576d93072216bf56eeff70d9140241f9edec0c104d0c515036", size = 12402541 }, + { url = "https://files.pythonhosted.org/packages/c7/67/5a4268782eb77344cc613a4cf23540928e41f018a9a1ec4c6882baf20ab8/mypy-1.15.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:baefc32840a9f00babd83251560e0ae1573e2f9d1b067719479bfb0e987c6357", size = 12494348 }, + { url = "https://files.pythonhosted.org/packages/83/3e/57bb447f7bbbfaabf1712d96f9df142624a386d98fb026a761532526057e/mypy-1.15.0-cp313-cp313-win_amd64.whl", hash = "sha256:b9378e2c00146c44793c98b8d5a61039a048e31f429fb0eb546d93f4b000bedf", size = 9373648 }, + { url = "https://files.pythonhosted.org/packages/09/4e/a7d65c7322c510de2c409ff3828b03354a7c43f5a8ed458a7a131b41c7b9/mypy-1.15.0-py3-none-any.whl", hash = "sha256:5469affef548bd1895d86d3bf10ce2b44e33d86923c29e4d675b3e323437ea3e", size = 2221777 }, ] [[package]]