diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 61257fb3..74c6f833 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -68,11 +68,12 @@ jobs: - run: | export UV_PROJECT_ENVIRONMENT="${pythonLocation}" uv sync --extra test --locked - - uses: CodSpeedHQ/action@v3.8.1 + - uses: CodSpeedHQ/action@v4.0.1 with: token: ${{ secrets.CODSPEED_TOKEN }} # allow updating snapshots due to indeterministic benchmarks run: pytest -vvv --snapshot-update --durations=10 + mode: "instrumentation" docs: runs-on: ubuntu-latest diff --git a/Cargo.lock b/Cargo.lock index c5614477..102adecb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -19,9 +19,9 @@ checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" [[package]] name = "anyhow" -version = "1.0.99" +version = "1.0.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b0674a1ddeecb70197781e945de4b3b8ffb61fa939a5597bcf48503737663100" +checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61" [[package]] name = "arc-swap" @@ -215,14 +215,14 @@ checksum = "d0881ea181b1df73ff77ffaaf9c7544ecc11e82fba9b5f27b262a3c73a332555" [[package]] name = "egglog" version = "1.0.0" -source = "git+https://github.com/egraphs-good//egglog.git?branch=main#1f0b6ecfc2eda306945a6869ddea838e9c309ec6" +source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=clone-cost#2b6de8034d8fd7af0dc8143f0a8ade5734c34f8c" dependencies = [ "csv", "dyn-clone", "egglog-add-primitive", - "egglog-bridge 1.0.0 (git+https://github.com/egraphs-good//egglog.git?branch=main)", - "egglog-core-relations 1.0.0 (git+https://github.com/egraphs-good//egglog.git?branch=main)", - "egglog-numeric-id 1.0.0 (git+https://github.com/egraphs-good//egglog.git?branch=main)", + "egglog-bridge", + "egglog-core-relations", + "egglog-numeric-id", "egraph-serialize", "hashbrown 0.15.5", "im-rc", @@ -240,7 +240,7 @@ dependencies = [ [[package]] name = "egglog-add-primitive" version = 
"1.0.0" -source = "git+https://github.com/egraphs-good//egglog.git?branch=main#1f0b6ecfc2eda306945a6869ddea838e9c309ec6" +source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=clone-cost#2b6de8034d8fd7af0dc8143f0a8ade5734c34f8c" dependencies = [ "quote", "syn 2.0.106", @@ -249,35 +249,13 @@ dependencies = [ [[package]] name = "egglog-bridge" version = "1.0.0" -source = "git+https://github.com/egraphs-good//egglog.git?branch=main#1f0b6ecfc2eda306945a6869ddea838e9c309ec6" +source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=clone-cost#2b6de8034d8fd7af0dc8143f0a8ade5734c34f8c" dependencies = [ "anyhow", "dyn-clone", - "egglog-core-relations 1.0.0 (git+https://github.com/egraphs-good//egglog.git?branch=main)", - "egglog-numeric-id 1.0.0 (git+https://github.com/egraphs-good//egglog.git?branch=main)", - "egglog-union-find 1.0.0 (git+https://github.com/egraphs-good//egglog.git?branch=main)", - "hashbrown 0.15.5", - "indexmap", - "log", - "num-rational", - "once_cell", - "petgraph", - "rayon", - "smallvec", - "thiserror", - "web-time", -] - -[[package]] -name = "egglog-bridge" -version = "1.0.0" -source = "git+https://github.com/egraphs-good/egglog.git?branch=main#1f0b6ecfc2eda306945a6869ddea838e9c309ec6" -dependencies = [ - "anyhow", - "dyn-clone", - "egglog-core-relations 1.0.0 (git+https://github.com/egraphs-good/egglog.git?branch=main)", - "egglog-numeric-id 1.0.0 (git+https://github.com/egraphs-good/egglog.git?branch=main)", - "egglog-union-find 1.0.0 (git+https://github.com/egraphs-good/egglog.git?branch=main)", + "egglog-core-relations", + "egglog-numeric-id", + "egglog-union-find", "hashbrown 0.15.5", "indexmap", "log", @@ -293,63 +271,25 @@ dependencies = [ [[package]] name = "egglog-concurrency" version = "1.0.0" -source = "git+https://github.com/egraphs-good//egglog.git?branch=main#1f0b6ecfc2eda306945a6869ddea838e9c309ec6" +source = 
"git+https://github.com/saulshanabrook/egg-smol.git?branch=clone-cost#2b6de8034d8fd7af0dc8143f0a8ade5734c34f8c" dependencies = [ "arc-swap", "rayon", ] -[[package]] -name = "egglog-concurrency" -version = "1.0.0" -source = "git+https://github.com/egraphs-good/egglog.git?branch=main#1f0b6ecfc2eda306945a6869ddea838e9c309ec6" -dependencies = [ - "arc-swap", - "rayon", -] - -[[package]] -name = "egglog-core-relations" -version = "1.0.0" -source = "git+https://github.com/egraphs-good//egglog.git?branch=main#1f0b6ecfc2eda306945a6869ddea838e9c309ec6" -dependencies = [ - "anyhow", - "bumpalo", - "crossbeam-queue", - "dashmap", - "dyn-clone", - "egglog-concurrency 1.0.0 (git+https://github.com/egraphs-good//egglog.git?branch=main)", - "egglog-numeric-id 1.0.0 (git+https://github.com/egraphs-good//egglog.git?branch=main)", - "egglog-union-find 1.0.0 (git+https://github.com/egraphs-good//egglog.git?branch=main)", - "fixedbitset 0.5.7", - "hashbrown 0.15.5", - "indexmap", - "lazy_static", - "log", - "num", - "once_cell", - "petgraph", - "rand 0.9.2", - "rayon", - "rustc-hash", - "smallvec", - "thiserror", - "web-time", -] - [[package]] name = "egglog-core-relations" version = "1.0.0" -source = "git+https://github.com/egraphs-good/egglog.git?branch=main#1f0b6ecfc2eda306945a6869ddea838e9c309ec6" +source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=clone-cost#2b6de8034d8fd7af0dc8143f0a8ade5734c34f8c" dependencies = [ "anyhow", "bumpalo", "crossbeam-queue", "dashmap", "dyn-clone", - "egglog-concurrency 1.0.0 (git+https://github.com/egraphs-good/egglog.git?branch=main)", - "egglog-numeric-id 1.0.0 (git+https://github.com/egraphs-good/egglog.git?branch=main)", - "egglog-union-find 1.0.0 (git+https://github.com/egraphs-good/egglog.git?branch=main)", + "egglog-concurrency", + "egglog-numeric-id", + "egglog-union-find", "fixedbitset 0.5.7", "hashbrown 0.15.5", "indexmap", @@ -380,16 +320,7 @@ dependencies = [ [[package]] name = "egglog-numeric-id" version = "1.0.0" 
-source = "git+https://github.com/egraphs-good//egglog.git?branch=main#1f0b6ecfc2eda306945a6869ddea838e9c309ec6" -dependencies = [ - "lazy_static", - "rayon", -] - -[[package]] -name = "egglog-numeric-id" -version = "1.0.0" -source = "git+https://github.com/egraphs-good/egglog.git?branch=main#1f0b6ecfc2eda306945a6869ddea838e9c309ec6" +source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=clone-cost#2b6de8034d8fd7af0dc8143f0a8ade5734c34f8c" dependencies = [ "lazy_static", "rayon", @@ -398,34 +329,26 @@ dependencies = [ [[package]] name = "egglog-union-find" version = "1.0.0" -source = "git+https://github.com/egraphs-good//egglog.git?branch=main#1f0b6ecfc2eda306945a6869ddea838e9c309ec6" +source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=clone-cost#2b6de8034d8fd7af0dc8143f0a8ade5734c34f8c" dependencies = [ "crossbeam", - "egglog-concurrency 1.0.0 (git+https://github.com/egraphs-good//egglog.git?branch=main)", - "egglog-numeric-id 1.0.0 (git+https://github.com/egraphs-good//egglog.git?branch=main)", -] - -[[package]] -name = "egglog-union-find" -version = "1.0.0" -source = "git+https://github.com/egraphs-good/egglog.git?branch=main#1f0b6ecfc2eda306945a6869ddea838e9c309ec6" -dependencies = [ - "crossbeam", - "egglog-concurrency 1.0.0 (git+https://github.com/egraphs-good/egglog.git?branch=main)", - "egglog-numeric-id 1.0.0 (git+https://github.com/egraphs-good/egglog.git?branch=main)", + "egglog-concurrency", + "egglog-numeric-id", ] [[package]] name = "egglog_python" -version = "11.2.0" +version = "11.3.0" dependencies = [ "egglog", - "egglog-bridge 1.0.0 (git+https://github.com/egraphs-good/egglog.git?branch=main)", - "egglog-core-relations 1.0.0 (git+https://github.com/egraphs-good/egglog.git?branch=main)", + "egglog-bridge", + "egglog-core-relations", "egglog-experimental", "egraph-serialize", "lalrpop-util", "log", + "num-bigint", + "num-rational", "ordered-float", "pyo3", "pyo3-log", @@ -524,7 +447,7 @@ dependencies = [ "cfg-if", 
"libc", "r-efi", - "wasi 0.14.5+wasi-0.2.4", + "wasi 0.14.7+wasi-0.2.4", ] [[package]] @@ -560,6 +483,12 @@ dependencies = [ "foldhash", ] +[[package]] +name = "hashbrown" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5419bdc4f6a9207fbeba6d11b604d481addf78ecd10c11ad51e76c2f6482748d" + [[package]] name = "heck" version = "0.5.0" @@ -582,13 +511,14 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.11.1" +version = "2.11.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "206a8042aec68fa4a62e8d3f7aa4ceb508177d9324faf261e1959e495b7a1921" +checksum = "4b0f83760fb341a774ed326568e19f5a863af4a952def8c39f9ab92fd95b88e5" dependencies = [ "equivalent", - "hashbrown 0.15.5", + "hashbrown 0.16.0", "serde", + "serde_core", ] [[package]] @@ -627,9 +557,9 @@ checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" [[package]] name = "js-sys" -version = "0.3.78" +version = "0.3.80" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c0b063578492ceec17683ef2f8c5e89121fbd0b172cbc280635ab7567db2738" +checksum = "852f13bec5eba4ba9afbeb93fd7c13fe56147f055939ae21c43a29a0ecb2702e" dependencies = [ "once_cell", "wasm-bindgen", @@ -884,6 +814,8 @@ dependencies = [ "indoc", "libc", "memoffset", + "num-bigint", + "num-rational", "once_cell", "portable-atomic", "pyo3-build-config", @@ -913,8 +845,9 @@ dependencies = [ [[package]] name = "pyo3-log" -version = "0.12.4" -source = "git+https://github.com/alex/pyo3-log.git?branch=pyo3-bump#3193bba54809be49338815beb363a43252ff7843" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "833e6fdc21553e9938d9443050ed3c7787ac3c1a1aefccbd03dfae0c7a4be529" dependencies = [ "arc-swap", "log", @@ -1116,18 +1049,28 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" [[package]] name = "serde" -version = "1.0.219" +version = "1.0.225" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd6c24dee235d0da097043389623fb913daddf92c76e9f5a1db88607a0bcbd1d" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_core" +version = "1.0.225" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6" +checksum = "659356f9a0cb1e529b24c01e43ad2bdf520ec4ceaf83047b83ddcc2251f96383" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.219" +version = "1.0.225" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00" +checksum = "0ea936adf78b1f766949a4977b91d2f5595825bd6ec079aa9543ad2685fc4516" dependencies = [ "proc-macro2", "quote", @@ -1136,15 +1079,16 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.143" +version = "1.0.145" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d401abef1d108fbd9cbaebc3e46611f4b1021f714a0597a71f41ee463f5f4a5a" +checksum = "402a6f66d8c709116cf22f558eab210f5a50187f702eb4d7e5ef38d9a7f1c79c" dependencies = [ "indexmap", "itoa", "memchr", "ryu", "serde", + "serde_core", ] [[package]] @@ -1284,27 +1228,27 @@ checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" [[package]] name = "wasi" -version = "0.14.5+wasi-0.2.4" +version = "0.14.7+wasi-0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4494f6290a82f5fe584817a676a34b9d6763e8d9d18204009fb31dceca98fd4" +checksum = "883478de20367e224c0090af9cf5f9fa85bed63a95c1abf3afc5c083ebc06e8c" dependencies = [ "wasip2", ] [[package]] name = "wasip2" -version = "1.0.0+wasi-0.2.4" +version = "1.0.1+wasi-0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03fa2761397e5bd52002cd7e73110c71af2109aca4e521a9f40473fe685b0a24" +checksum = 
"0562428422c63773dad2c345a1882263bbf4d65cf3f42e90921f787ef5ad58e7" dependencies = [ "wit-bindgen", ] [[package]] name = "wasm-bindgen" -version = "0.2.101" +version = "0.2.103" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e14915cadd45b529bb8d1f343c4ed0ac1de926144b746e2710f9cd05df6603b" +checksum = "ab10a69fbd0a177f5f649ad4d8d3305499c42bab9aef2f7ff592d0ec8f833819" dependencies = [ "cfg-if", "once_cell", @@ -1315,9 +1259,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.101" +version = "0.2.103" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e28d1ba982ca7923fd01448d5c30c6864d0a14109560296a162f80f305fb93bb" +checksum = "0bb702423545a6007bbc368fde243ba47ca275e549c8a28617f56f6ba53b1d1c" dependencies = [ "bumpalo", "log", @@ -1329,9 +1273,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.101" +version = "0.2.103" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c3d463ae3eff775b0c45df9da45d68837702ac35af998361e2c84e7c5ec1b0d" +checksum = "fc65f4f411d91494355917b605e1480033152658d71f722a90647f56a70c88a0" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -1339,9 +1283,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.101" +version = "0.2.103" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7bb4ce89b08211f923caf51d527662b75bdc9c9c7aab40f86dcb9fb85ac552aa" +checksum = "ffc003a991398a8ee604a401e194b6b3a39677b3173d6e74495eb51b82e99a32" dependencies = [ "proc-macro2", "quote", @@ -1352,9 +1296,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.101" +version = "0.2.103" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f143854a3b13752c6950862c906306adb27c7e839f7414cec8fea35beab624c1" +checksum = "293c37f4efa430ca14db3721dfbe48d8c33308096bd44d80ebaa775ab71ba1cf" dependencies = [ "unicode-ident", ] @@ 
-1450,9 +1394,9 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "wit-bindgen" -version = "0.45.1" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c573471f125075647d03df72e026074b7203790d41351cd6edc96f46bcccd36" +checksum = "f17a85883d4e6d00e8a97c586de764dabcc06133f7f1d55dce5cdc070ad7fe59" [[package]] name = "zerocopy" diff --git a/Cargo.toml b/Cargo.toml index 34cc6c8a..babd1bca 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,17 +10,19 @@ name = "egglog" crate-type = ["cdylib"] [dependencies] -pyo3 = { version = "0.26", features = ["extension-module"] } - -egglog = { git = "https://github.com/egraphs-good/egglog.git", branch = "main", default-features = false } -# egglog = { path = "../egg-smol" } -egglog-bridge = { git = "https://github.com/egraphs-good/egglog.git", branch = "main" } -egglog-core-relations = { git = "https://github.com/egraphs-good/egglog.git", branch = "main" } +pyo3 = { version = "0.26", features = ["extension-module", "num-bigint", "num-rational"] } +num-bigint = "*" +num-rational = "*" +# egglog = { git = "https://github.com/egraphs-good/egglog.git", branch = "main", default-features = false } +# egglog-bridge = { git = "https://github.com/egraphs-good/egglog.git", branch = "main" } +# egglog-core-relations = { git = "https://github.com/egraphs-good/egglog.git", branch = "main" } +egglog = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "clone-cost", default-features = false } +egglog-bridge = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "clone-cost" } +egglog-core-relations = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "clone-cost" } egglog-experimental = { git = "https://github.com/egraphs-good/egglog-experimental", branch = "cli", default-features = false } egraph-serialize = { version = "0.2", features = ["serde", "graphviz"] } serde_json = "1" -# 
https://github.com/vorner/pyo3-log/pull/66 -pyo3-log = { git = "https://github.com/alex/pyo3-log.git", branch = "pyo3-bump" } +pyo3-log = "0.13" log = "0.4" lalrpop-util = { version = "0.22", features = ["lexer"] } ordered-float = "3.7" @@ -29,7 +31,10 @@ rayon = "1.11" # Use patched version of egglog in experimental [patch.'https://github.com/egraphs-good/egglog'] -egglog = { git = "https://github.com/egraphs-good//egglog.git", branch = "main" } +# egglog = { git = "https://github.com/egraphs-good//egglog.git", branch = "main" } +egglog = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "clone-cost" } +egglog-bridge = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "clone-cost" } +egglog-core-relations = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "clone-cost" } # egglog = { path = "../egg-smol" } # egglog = { git = "https://github.com/egraphs-good//egglog.git", rev = "5542549" } diff --git a/Makefile b/Makefile index 42ac1bc3..f03d957f 100644 --- a/Makefile +++ b/Makefile @@ -31,7 +31,7 @@ from-local: # remove once https://github.com/astral-sh/uv/issues/5903 mypy: - uv run dmypy run -- python/ + uv run dmypy run -- python/ docs/ stubtest: uv run python -m mypy.stubtest egglog.bindings --allowlist stubtest_allow diff --git a/docs/changelog.md b/docs/changelog.md index 36b8e725..d51751a1 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -4,6 +4,7 @@ _This project uses semantic versioning_ ## UNRELEASED +- Add ability to create custom cost models and pass them in to extract [#357](https://github.com/egraphs-good/egglog-python/pull/357) ## 11.3.0 (2025-09-12) - Add egglog tutorials, change display to not inline by default, and fix bug looking up binary methods [#352](https://github.com/egraphs-good/egglog-python/pull/352) diff --git a/docs/reference/contributing.md b/docs/reference/contributing.md index 4fe5c2d0..5ba0125c 100644 --- a/docs/reference/contributing.md +++ b/docs/reference/contributing.md @@
-83,7 +83,7 @@ Debug symbols are turned on by default. If there is a performance sensitive piece of code, you could isolate it in a file and profile it locally with: ```bash -uv run py-spy record --format speedscope -- python tmp.py +uv run py-spy record --format speedscope -- python -O tmp.py ``` ### Making changes diff --git a/docs/reference/egglog-translation.md b/docs/reference/egglog-translation.md index c169aabb..a472d916 100644 --- a/docs/reference/egglog-translation.md +++ b/docs/reference/egglog-translation.md @@ -285,6 +285,8 @@ It does this by creating a new table for each function you set the cost for that _Note: Unlike in egglog, where you have to declare which functions support custom costs, in Python all functions are automatically registered to create a custom cost table when they are constructed_ +You can also get the cost of a function with `get_cost`, which will return an `i64` if one has already been set. + ## Defining Rules To define rules in Python, we create a rule with the `rule(*facts).then(*actions) (rule ...)` command in egglog. diff --git a/docs/reference/python-integration.md b/docs/reference/python-integration.md index f6e4c9f6..48154ffa 100644 --- a/docs/reference/python-integration.md +++ b/docs/reference/python-integration.md @@ -635,3 +635,33 @@ r = ruleset( ) egraph.saturate(r) ``` + +## Custom Cost Models + +By default, when extracting from the e-graph, we use a simple cost model that looks at the costs assigned to each +function and any custom costs set with `set_cost`, and finds the lowest cost expression looking at the total tree size. + +Custom cost models are also supported, which can be passed into `extract` as the `cost_model` keyword argument. They +are defined as functions following the `CostModel` protocol that take in an e-graph, an expression, and the costs of the children, and return the total cost of that expression. Costs don't have to be integers; they can be any type that supports comparison.
+ +There are a few builtin cost models: + +- `default_cost_model`: The default cost model, which uses integer costs and sums them up. +- `greedy_dag_cost_model(inner_cost_model=default_cost_model)`: A cost model which uses a greedy DAG algorithm to find the lowest cost expression, allowing for shared sub-expressions. It takes in another cost model to use for the base costs of each expression. + +Note that when passed into your cost model, the expression won't be a full tree. Instead, only the top-level call will be present, and all of its arguments will be opaque "value" expressions, representing e-classes in the e-graph. You can't do much with them except use them to construct other expressions to pass into `egraph.lookup_function_value` to get the resulting value of a call with those arguments. The only exception is that all builtin types, like ints, vecs, strings, etc., will be fully evaluated recursively, so they can be matched against. + +For example, here is a cost model that has a boolean cost indicating whether the value is even: + +```{code-cell} python +def is_even_cost_model(egraph: EGraph, expr: Expr, children_costs: list[bool]) -> bool: + from egglog import i64 # noqa: PLC0415 + + match expr: + case i64(i): + return i % 2 == 0 + return False +assert EGraph().extract(i64(10), include_cost=True, cost_model=is_even_cost_model) == (i64(10), True) + +assert EGraph().extract(i64(5), include_cost=True, cost_model=is_even_cost_model) == (i64(5), False) +``` diff --git a/pyproject.toml b/pyproject.toml index c6c55d3c..c5e4ae08 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,7 +54,7 @@ test = [ "egglog[array]", "pytest-codspeed", "pytest-benchmark", - "pytest-xdist", + "pytest-xdist" ] docs = [ diff --git a/python/egglog/bindings.pyi b/python/egglog/bindings.pyi index 46702646..9c93b64d 100644 --- a/python/egglog/bindings.pyi +++ b/python/egglog/bindings.pyi @@ -1,6 +1,8 @@ +from collections.abc import Callable from datetime import timedelta +from fractions import Fraction
from pathlib import Path -from typing import TypeAlias +from typing import Any, Generic, Protocol, TypeAlias, TypeVar from typing_extensions import final @@ -14,6 +16,7 @@ __all__ = [ "Change", "Check", "Constructor", + "CostModel", "Datatype", "Datatypes", "DefaultPrintFunctionMode", @@ -26,6 +29,7 @@ __all__ = [ "Extract", "ExtractBest", "ExtractVariants", + "Extractor", "Fact", "Fail", "Float", @@ -83,6 +87,7 @@ __all__ = [ "UserDefined", "UserDefinedCommandOutput", "UserDefinedOutput", + "Value", "Var", "Variant", ] @@ -128,6 +133,31 @@ class EGraph: max_calls_per_function: int | None = None, include_temporary_functions: bool = False, ) -> SerializedEGraph: ... + def lookup_function(self, name: str, key: list[Value]) -> Value | None: ... + def eval_expr(self, expr: _Expr) -> tuple[str, Value]: ... + def value_to_i64(self, v: Value) -> int: ... + def value_to_f64(self, v: Value) -> float: ... + def value_to_string(self, v: Value) -> str: ... + def value_to_bool(self, v: Value) -> bool: ... + def value_to_rational(self, v: Value) -> Fraction: ... + def value_to_bigint(self, v: Value) -> int: ... + def value_to_bigrat(self, v: Value) -> Fraction: ... + def value_to_pyobject(self, py_object_sort: PyObjectSort, v: Value) -> object: ... + def value_to_map(self, v: Value) -> dict[Value, Value]: ... + def value_to_multiset(self, v: Value) -> list[Value]: ... + def value_to_vec(self, v: Value) -> list[Value]: ... + def value_to_function(self, v: Value) -> tuple[str, list[Value]]: ... + def value_to_set(self, v: Value) -> set[Value]: ... + # def dynamic_cost_model_enode_cost(self, func: str, args: list[Value]) -> int: ... + +@final +class Value: + def __hash__(self) -> int: ... + def __eq__(self, value: object) -> bool: ... + def __lt__(self, other: object) -> bool: ... + def __le__(self, other: object) -> bool: ... + def __gt__(self, other: object) -> bool: ... + def __ge__(self, other: object) -> bool: ... 
@final class EggSmolError(Exception): @@ -732,3 +762,34 @@ class TermDag: def expr_to_term(self, expr: _Expr) -> _Term: ... def term_to_expr(self, term: _Term, span: _Span) -> _Expr: ... def to_string(self, term: _Term) -> str: ... + +## +# Extraction +## +class _Cost(Protocol): + def __lt__(self, other: _Cost) -> bool: ... + def __le__(self, other: _Cost) -> bool: ... + def __gt__(self, other: _Cost) -> bool: ... + def __ge__(self, other: _Cost) -> bool: ... + +_COST = TypeVar("_COST", bound=_Cost) + +_ENODE_COST = TypeVar("_ENODE_COST") + +@final +class CostModel(Generic[_COST, _ENODE_COST]): + def __init__( + self, + fold: Callable[[str, _ENODE_COST, list[_COST]], _COST], + enode_cost: Callable[[str, list[Value]], _ENODE_COST], + container_cost: Callable[[str, Value, list[_COST]], _COST], + base_value_cost: Callable[[str, Value], _COST], + ) -> None: ... + +@final +class Extractor(Generic[_COST]): + def __init__(self, rootsorts: list[str] | None, egraph: EGraph, cost_model: CostModel[_COST, Any]) -> None: ... + def extract_best(self, egraph: EGraph, termdag: TermDag, value: Value, sort: str) -> tuple[_COST, _Term]: ... + def extract_variants( + self, egraph: EGraph, termdag: TermDag, value: Value, nvariants: int, sort: str + ) -> list[tuple[_COST, _Term]]: ... 
diff --git a/python/egglog/builtins.py b/python/egglog/builtins.py index 9735495c..269566c0 100644 --- a/python/egglog/builtins.py +++ b/python/egglog/builtins.py @@ -33,10 +33,12 @@ "BigRatLike", "Bool", "BoolLike", + "Container", "ExprValueError", "Map", "MapLike", "MultiSet", + "Primitive", "PyObject", "Rational", "Set", @@ -1135,3 +1137,7 @@ def _convert_function(fn: FunctionType) -> UnstableFn: converter(FunctionType, UnstableFn, _convert_function) + + +Container: TypeAlias = Map | Set | MultiSet | Vec | UnstableFn +Primitive: TypeAlias = String | Bool | i64 | f64 | Rational | BigInt | BigRat | PyObject | Unit diff --git a/python/egglog/declarations.py b/python/egglog/declarations.py index c657930f..f3ed264b 100644 --- a/python/egglog/declarations.py +++ b/python/egglog/declarations.py @@ -14,6 +14,8 @@ from typing_extensions import Self, assert_never +from .bindings import Value + if TYPE_CHECKING: from collections.abc import Callable, Iterable, Mapping @@ -49,6 +51,7 @@ "FunctionDecl", "FunctionRef", "FunctionSignature", + "GetCostDecl", "HasDeclerations", "InitRef", "JustTypeRef", @@ -82,6 +85,7 @@ "UnboundVarDecl", "UnionDecl", "UnnamedFunctionRef", + "ValueDecl", "collect_unbound_vars", "replace_typed_expr", "upcast_declerations", @@ -639,7 +643,7 @@ class CallDecl: args: tuple[TypedExprDecl, ...] = () # type parameters that were bound to the callable, if it is a classmethod # Used for pretty printing classmethod calls with type parameters - bound_tp_params: tuple[JustTypeRef, ...] | None = None + bound_tp_params: tuple[JustTypeRef, ...] 
= () # pool objects for faster __eq__ _args_to_value: ClassVar[WeakValueDictionary[tuple[object, ...], CallDecl]] = WeakValueDictionary({}) @@ -654,7 +658,7 @@ def __new__(cls, *args: object, **kwargs: object) -> Self: # normalize the args/kwargs to a tuple so that they can be compared callable = args[0] if args else kwargs["callable"] args_ = args[1] if len(args) > 1 else kwargs.get("args", ()) - bound_tp_params = args[2] if len(args) > 2 else kwargs.get("bound_tp_params") + bound_tp_params = args[2] if len(args) > 2 else kwargs.get("bound_tp_params", ()) normalized_args = (callable, args_, bound_tp_params) try: @@ -696,7 +700,20 @@ class PartialCallDecl: call: CallDecl -ExprDecl: TypeAlias = UnboundVarDecl | LetRefDecl | LitDecl | CallDecl | PyObjectDecl | PartialCallDecl +@dataclass(frozen=True) +class GetCostDecl: + callable: CallableRef + args: tuple[TypedExprDecl, ...] + + +@dataclass(frozen=True) +class ValueDecl: + value: Value + + +ExprDecl: TypeAlias = ( + UnboundVarDecl | LetRefDecl | LitDecl | CallDecl | PyObjectDecl | PartialCallDecl | ValueDecl | GetCostDecl +) @dataclass(frozen=True) diff --git a/python/egglog/deconstruct.py b/python/egglog/deconstruct.py index aab80705..1d60b85d 100644 --- a/python/egglog/deconstruct.py +++ b/python/egglog/deconstruct.py @@ -11,7 +11,7 @@ from typing_extensions import TypeVarTuple, Unpack from .declarations import * -from .egraph import BaseExpr +from .egraph import BaseExpr, Expr from .runtime import * from .thunk import * @@ -49,7 +49,11 @@ def get_literal_value(x: PyObject) -> object: ... def get_literal_value(x: UnstableFn[T, Unpack[TS]]) -> Callable[[Unpack[TS]], T] | None: ... -def get_literal_value(x: String | Bool | i64 | f64 | PyObject | UnstableFn) -> object: +@overload +def get_literal_value(x: Expr) -> None: ... + + +def get_literal_value(x: object) -> object: """ Returns the literal value of an expression if it is a literal. If it is not a literal, returns None. 
@@ -95,12 +99,9 @@ def get_var_name(x: BaseExpr) -> str | None: return None -def get_callable_fn(x: T) -> Callable[..., T] | None: +def get_callable_fn(x: T) -> Callable[..., T] | T | None: """ - Gets the function of an expression if it is a call expression. - If it is not a call expression (a property, a primitive value, constants, classvars, a let value), return None. - For those values, you can check them by comparing them directly with equality or for primitives calling `.eval()` - to return the Python value. + Gets the function of an expression, or if it's a constant or classvar, return that. """ if not isinstance(x, RuntimeExpr): raise TypeError(f"Expected Expression, got {type(x).__name__}") @@ -159,6 +160,7 @@ def _deconstruct_call_decl( """ args = call.args arg_exprs = tuple(RuntimeExpr(decls_thunk, Thunk.value(a)) for a in args) + # TODO: handle values? Like constants if isinstance(call.callable, InitRef): return RuntimeClass( decls_thunk, diff --git a/python/egglog/egraph.py b/python/egglog/egraph.py index f391c642..64d97fe9 100644 --- a/python/egglog/egraph.py +++ b/python/egglog/egraph.py @@ -16,6 +16,7 @@ ClassVar, Generic, Literal, + Protocol, TypeAlias, TypedDict, TypeVar, @@ -41,7 +42,7 @@ from .version_compat import * if TYPE_CHECKING: - from .builtins import String, Unit, i64Like + from .builtins import String, Unit, i64, i64Like __all__ = [ @@ -51,11 +52,14 @@ "BuiltinExpr", "Command", "Command", + "CostModel", "EGraph", "Expr", + "ExprCallable", "Fact", "Fact", "GraphvizKwargs", + "GreedyDagCost", "RewriteOrRule", "Ruleset", "Schedule", @@ -70,12 +74,15 @@ "check", "check_eq", "constant", + "default_cost_model", "delete", "eq", "expr_action", "expr_fact", "expr_parts", "function", + "get_cost", + "greedy_dag_cost_model", "let", "method", "ne", @@ -88,6 +95,7 @@ "seq", "set_", "set_cost", + "set_current_ruleset", "subsume", "union", "unstable_combine_rulesets", @@ -452,7 +460,7 @@ def _generate_class_decls( # noqa: C901,PLR0912 continue locals = 
frame.f_locals ref: ClassMethodRef | MethodRef | PropertyRef | InitRef - # TODO: Store deprecated message so we can print at runtime + # TODO: Store deprecated message so we can get at runtime if (getattr(fn, "__deprecated__", None)) is not None: fn = fn.__wrapped__ # type: ignore[attr-defined] match fn: @@ -953,22 +961,45 @@ def _facts_to_check(self, fact_likes: Iterable[FactLike]) -> bindings.Check: return bindings.Check(span(2), egg_facts) @overload - def extract(self, expr: BASE_EXPR, /, include_cost: Literal[False] = False) -> BASE_EXPR: ... + def extract( + self, expr: BASE_EXPR, /, include_cost: Literal[False] = False, cost_model: CostModel | None = None + ) -> BASE_EXPR: ... @overload - def extract(self, expr: BASE_EXPR, /, include_cost: Literal[True]) -> tuple[BASE_EXPR, int]: ... + def extract( + self, expr: BASE_EXPR, /, include_cost: Literal[True], cost_model: None = None + ) -> tuple[BASE_EXPR, int]: ... - def extract(self, expr: BASE_EXPR, include_cost: bool = False) -> BASE_EXPR | tuple[BASE_EXPR, int]: + @overload + def extract( + self, expr: BASE_EXPR, /, include_cost: Literal[True], cost_model: CostModel[COST] + ) -> tuple[BASE_EXPR, COST]: ... + + def extract( + self, expr: BASE_EXPR, /, include_cost: bool = False, cost_model: CostModel[COST] | None = None + ) -> BASE_EXPR | tuple[BASE_EXPR, COST]: """ Extract the lowest cost expression from the egraph. 
""" runtime_expr = to_runtime_expr(expr) - extract_report = self._run_extract(runtime_expr, 0) - assert isinstance(extract_report, bindings.ExtractBest) - res = self._from_termdag(extract_report.termdag, extract_report.term, runtime_expr.__egg_typed_expr__.tp) - if include_cost: - return res, extract_report.cost - return res + self._add_decls(runtime_expr) + tp = runtime_expr.__egg_typed_expr__.tp + if cost_model is None: + extract_report = self._run_extract(runtime_expr, 0) + assert isinstance(extract_report, bindings.ExtractBest) + res = self._from_termdag(extract_report.termdag, extract_report.term, tp) + cost = cast("COST", extract_report.cost) + else: + # TODO: For some reason we need this or else it wont be registered. Not sure why + self.register(expr) + egg_cost_model = _CostModel(cost_model, self).to_bindings_cost_model() + egg_sort = self._state.type_ref_to_egg(tp) + extractor = bindings.Extractor([egg_sort], self._state.egraph, egg_cost_model) + termdag = bindings.TermDag() + value = self._state.typed_expr_to_value(runtime_expr.__egg_typed_expr__) + cost, term = extractor.extract_best(self._state.egraph, termdag, value, egg_sort) + res = self._from_termdag(termdag, term, tp) + return (res, cost) if include_cost else res def _from_termdag(self, termdag: bindings.TermDag, term: bindings._Term, tp: JustTypeRef) -> Any: (new_typed_expr,) = self._state.exprs_from_egg(termdag, [term], tp) @@ -979,6 +1010,7 @@ def extract_multiple(self, expr: BASE_EXPR, n: int) -> list[BASE_EXPR]: Extract multiple expressions from the egraph. 
""" runtime_expr = to_runtime_expr(expr) + self._add_decls(runtime_expr) extract_report = self._run_extract(runtime_expr, n) assert isinstance(extract_report, bindings.ExtractVariants) new_exprs = self._state.exprs_from_egg( @@ -987,7 +1019,6 @@ def extract_multiple(self, expr: BASE_EXPR, n: int) -> list[BASE_EXPR]: return [cast("BASE_EXPR", RuntimeExpr.__from_values__(self.__egg_decls__, expr)) for expr in new_exprs] def _run_extract(self, expr: RuntimeExpr, n: int) -> bindings._CommandOutput: - self._add_decls(expr) expr = self._state.typed_expr_to_egg(expr.__egg_typed_expr__) # If we have defined any cost tables use the custom extraction args = (expr, bindings.Lit(span(2), bindings.Int(n))) @@ -1212,16 +1243,12 @@ def all_function_sizes(self) -> list[tuple[ExprCallable, int]]: """ (output,) = self._egraph.run_program(bindings.PrintSize(span(1), None)) assert isinstance(output, bindings.PrintAllFunctionsSize) + return [(callables[0], size) for (name, size) in output.sizes if (callables := self._egg_fn_to_callables(name))] + + def _egg_fn_to_callables(self, egg_fn: str) -> list[ExprCallable]: return [ - ( - cast( - "ExprCallable", - create_callable(self._state.__egg_decls__, next(iter(refs))), - ), - size, - ) - for (name, size) in output.sizes - if (refs := self._state.egg_fn_to_callable_refs[name]) + cast("ExprCallable", create_callable(self._state.__egg_decls__, ref)) + for ref in self._state.egg_fn_to_callable_refs[egg_fn] ] def function_values( @@ -1245,6 +1272,33 @@ def function_values( for (call, res) in output.terms } + def lookup_function_value(self, expr: BASE_EXPR) -> BASE_EXPR | None: + """ + Given an expression that is a function call, looks up the value of the function call if it exists. 
+ """ + runtime_expr = to_runtime_expr(expr) + typed_expr = runtime_expr.__egg_typed_expr__ + assert isinstance(typed_expr.expr, CallDecl | GetCostDecl) + egg_fn, typed_args = self._state.translate_call(typed_expr.expr) + values_args = [self._state.typed_expr_to_value(a) for a in typed_args] + possible_value = self._egraph.lookup_function(egg_fn, values_args) + if possible_value is None: + return None + return cast( + "BASE_EXPR", + RuntimeExpr.__from_values__( + self.__egg_decls__, + TypedExprDecl(typed_expr.tp, self._state.value_to_expr(typed_expr.tp, possible_value)), + ), + ) + + def has_custom_cost(self, fn: ExprCallable) -> bool: + """ + Checks if the any custom costs have been set for this expression callable. + """ + resolved, _ = resolve_callable(fn) + return resolved in self._state.cost_callables + # Either a constant or a function. ExprCallable: TypeAlias = Callable[..., BaseExpr] | BaseExpr @@ -1910,3 +1964,246 @@ def set_current_ruleset(r: Ruleset | None) -> Generator[None, None, None]: yield finally: _CURRENT_RULESET.reset(token) + + +def get_cost(expr: BaseExpr) -> i64: + """ + Return a lookup of the cost of an expression. If not set, won't match. + """ + assert isinstance(expr, RuntimeExpr) + expr_decl = expr.__egg_typed_expr__.expr + if not isinstance(expr_decl, CallDecl): + msg = "Can only get cost of function calls, not literals or variables" + raise TypeError(msg) + return RuntimeExpr.__from_values__( + expr.__egg_decls__, + TypedExprDecl(JustTypeRef("i64"), GetCostDecl(expr_decl.callable, expr_decl.args)), + ) + + +class Comparable(Protocol): + def __lt__(self, other: Self) -> bool: ... + def __le__(self, other: Self) -> bool: ... + def __gt__(self, other: Self) -> bool: ... + def __ge__(self, other: Self) -> bool: ... + + +COST = TypeVar("COST", bound=Comparable) + + +class CostModel(Protocol, Generic[COST]): + """ + A cost model for an e-graph. 
Used to determine the cost of an expression based on its structure and the costs of its sub-expressions. + + Called with an expression and the costs of its children, returns the total cost of the expression. + + Additionally, the cost model should guarantee that a term has a no-smaller cost + than its subterms to avoid cycles in the extracted terms for common case usages. + For more niche usages, a term can have a cost less than its subterms. + As long as there is no negative cost cycle, the default extractor is guaranteed to terminate in computing the costs. + However, the user needs to be careful to guarantee acyclicity in the extracted terms. + """ + + def __call__(self, egraph: EGraph, expr: BaseExpr, children_costs: list[COST]) -> COST: + """ + The total cost of a term given the cost of the root e-node and its immediate children's total costs. + """ + raise NotImplementedError + + +def default_cost_model(egraph: EGraph, expr: BaseExpr, children_costs: list[int]) -> int: + """ + A default cost model for an e-graph, which looks up costs set on function calls, or uses 1 as the default cost. + """ + from .builtins import Container # noqa: PLC0415 + from .deconstruct import get_callable_fn # noqa: PLC0415 + + # 1. First prefer if the expr has a custom cost set on it + if ( + (callable_fn := get_callable_fn(expr)) is not None + and egraph.has_custom_cost(callable_fn) + and (i := egraph.lookup_function_value(get_cost(expr))) is not None + ): + self_cost = int(i) + # 2. Else, check if this is a callable and it has a cost set on its declaration + elif callable_fn is not None and (callable_cost := get_callable_cost(callable_fn)) is not None: + self_cost = callable_cost + # 3. 
Else, if this is a container, it has no cost, otherwise it has a cost of 1 + else: + # By default, all nodes have a cost of 1 except for containers which have a cost of 0 + self_cost = 0 if isinstance(expr, Container) else 1 + # Sum up the costs of the children and our own cost + return sum(children_costs, start=self_cost) + + +class ComparableAddSub(Comparable, Protocol): + def __add__(self, other: Self) -> Self: ... + def __sub__(self, other: Self) -> Self: ... + + +DAG_COST = TypeVar("DAG_COST", bound=ComparableAddSub) + + +@dataclass +class GreedyDagCost(Generic[DAG_COST]): + """ + Cost of a DAG, which stores children costs. Use `.total` to get the underlying cost. + """ + + total: DAG_COST + _costs: dict[TypedExprDecl, DAG_COST] = field(repr=False) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, GreedyDagCost): + return NotImplemented + return self.total == other.total + + def __lt__(self, other: Self) -> bool: + return self.total < other.total + + def __le__(self, other: Self) -> bool: + return self.total <= other.total + + def __gt__(self, other: Self) -> bool: + return self.total > other.total + + def __ge__(self, other: Self) -> bool: + return self.total >= other.total + + def __hash__(self) -> int: + return hash(self.total) + + +@dataclass +class GreedyDagCostModel(CostModel[GreedyDagCost[DAG_COST]]): + """ + A cost model which will count duplicate nodes only once. + + Should have similar behavior as https://github.com/egraphs-good/extraction-gym/blob/main/src/extract/greedy_dag.rs + but implemented as a cost model that will be used with the default extractor. 
+ """ + + base: CostModel[DAG_COST] + + def __call__( + self, egraph: EGraph, expr: BaseExpr, children_costs: list[GreedyDagCost[DAG_COST]] + ) -> GreedyDagCost[DAG_COST]: + cost = self.base(egraph, expr, [c.total for c in children_costs]) + for c in children_costs: + cost -= c.total + costs = {} + for c in children_costs: + costs.update(c._costs) + total = sum(costs.values(), start=cost) + costs[to_runtime_expr(expr).__egg_typed_expr__] = cost + return GreedyDagCost(total, costs) + + +@overload +def greedy_dag_cost_model() -> CostModel[GreedyDagCost[int]]: ... + + +@overload +def greedy_dag_cost_model(base: CostModel[DAG_COST]) -> CostModel[GreedyDagCost[DAG_COST]]: ... + + +def greedy_dag_cost_model(base: CostModel[Any] = default_cost_model) -> CostModel[GreedyDagCost[Any]]: + """ + Creates a greedy dag cost model from a base cost model. + """ + return GreedyDagCostModel(base or default_cost_model) + + +def get_callable_cost(fn: ExprCallable) -> int | None: + """ + Returns the cost of a callable, if it has one set. Otherwise returns None. + """ + callable_ref, decls = resolve_callable(fn) + callable_decl = decls.get_callable_decl(callable_ref) + return callable_decl.cost if isinstance(callable_decl, ConstructorDecl) else 1 + + +@dataclass +class _CostModel(Generic[COST]): + """ + Implements the methods compatible with the bindings for the cost model. 
+ """ + + model: CostModel[COST] + egraph: EGraph + enode_cost_results: dict[tuple[str, tuple[bindings.Value, ...]], int] = field(default_factory=dict) + enode_cost_expressions: list[RuntimeExpr] = field(default_factory=list) + fold_results: dict[tuple[int, tuple[COST, ...]], COST] = field(default_factory=dict) + base_value_cost_results: dict[tuple[str, bindings.Value], COST] = field(default_factory=dict) + container_cost_results: dict[tuple[str, bindings.Value, tuple[COST, ...]], COST] = field(default_factory=dict) + + def call_model(self, expr: RuntimeExpr, children_costs: list[COST]) -> COST: + return self.model(self.egraph, cast("BaseExpr", expr), children_costs) + # if __debug__: + # for c in children_costs: + # if res <= c: + # msg = f"Cost model {self.model} produced a cost {res} less than or equal to a child cost {c} for {expr}" + # raise ValueError(msg) + + def fold(self, _fn: str, index: int, children_costs: list[COST]) -> COST: + try: + return self.fold_results[(index, tuple(children_costs))] + except KeyError: + pass + + expr = self.enode_cost_expressions[index] + return self.call_model(expr, children_costs) + + # enode cost is only ever called right before fold, for the head_cost + def enode_cost(self, name: str, args: list[bindings.Value]) -> int: + try: + return self.enode_cost_results[(name, tuple(args))] + except KeyError: + pass + (callable_ref,) = self.egraph._state.egg_fn_to_callable_refs[name] + signature = self.egraph.__egg_decls__.get_callable_decl(callable_ref).signature + assert isinstance(signature, FunctionSignature) + arg_exprs = [ + TypedExprDecl(tp.to_just(), self.egraph._state.value_to_expr(tp.to_just(), arg)) + for (arg, tp) in zip(args, signature.arg_types, strict=True) + ] + res_type = signature.semantic_return_type.to_just() + res = RuntimeExpr.__from_values__( + self.egraph.__egg_decls__, + TypedExprDecl(res_type, CallDecl(callable_ref, tuple(arg_exprs))), + ) + index = len(self.enode_cost_expressions) + 
self.enode_cost_expressions.append(res) + self.enode_cost_results[(name, tuple(args))] = index + return index + + def base_value_cost(self, tp: str, value: bindings.Value) -> COST: + try: + return self.base_value_cost_results[(tp, value)] + except KeyError: + pass + type_ref = self.egraph._state.egg_sort_to_type_ref[tp] + expr = RuntimeExpr.__from_values__( + self.egraph.__egg_decls__, + TypedExprDecl(type_ref, self.egraph._state.value_to_expr(type_ref, value)), + ) + res = self.call_model(expr, []) + self.base_value_cost_results[(tp, value)] = res + return res + + def container_cost(self, tp: str, value: bindings.Value, element_costs: list[COST]) -> COST: + try: + return self.container_cost_results[(tp, value, tuple(element_costs))] + except KeyError: + pass + type_ref = self.egraph._state.egg_sort_to_type_ref[tp] + expr = RuntimeExpr.__from_values__( + self.egraph.__egg_decls__, + TypedExprDecl(type_ref, self.egraph._state.value_to_expr(type_ref, value)), + ) + res = self.call_model(expr, element_costs) + self.container_cost_results[(tp, value, tuple(element_costs))] = res + return res + + def to_bindings_cost_model(self) -> bindings.CostModel[COST, int]: + return bindings.CostModel(self.fold, self.enode_cost, self.container_cost, self.base_value_cost) diff --git a/python/egglog/egraph_state.py b/python/egglog/egraph_state.py index acca0778..19aeba16 100644 --- a/python/egglog/egraph_state.py +++ b/python/egglog/egraph_state.py @@ -68,6 +68,7 @@ class EGraphState: # Bidirectional mapping between egg sort names and python type references. 
type_ref_to_egg_sort: dict[JustTypeRef, str] = field(default_factory=dict) + egg_sort_to_type_ref: dict[str, JustTypeRef] = field(default_factory=dict) # Cache of egg expressions for converting to egg expr_to_egg_cache: dict[ExprDecl, bindings._Expr] = field(default_factory=dict) @@ -86,6 +87,7 @@ def copy(self) -> EGraphState: egg_fn_to_callable_refs=defaultdict(set, {k: v.copy() for k, v in self.egg_fn_to_callable_refs.items()}), callable_ref_to_egg_fn=self.callable_ref_to_egg_fn.copy(), type_ref_to_egg_sort=self.type_ref_to_egg_sort.copy(), + egg_sort_to_type_ref=self.egg_sort_to_type_ref.copy(), expr_to_egg_cache=self.expr_to_egg_cache.copy(), cost_callables=self.cost_callables.copy(), ) @@ -352,6 +354,7 @@ def create_cost_table(self, ref: CallableRef) -> str: Creates the egg cost table if needed and gets the name of the table. """ name = self.cost_table_name(ref) + print(name, self.cost_callables) if ref not in self.cost_callables: self.cost_callables.add(ref) signature = self.__egg_decls__.get_callable_decl(ref).signature @@ -455,10 +458,14 @@ def type_ref_to_egg(self, ref: JustTypeRef) -> str: # noqa: C901, PLR0912 pass decl = self.__egg_decls__._classes[ref.name] self.type_ref_to_egg_sort[ref] = egg_name = decl.egg_name or _generate_type_egg_name(ref) + self.egg_sort_to_type_ref[egg_name] = ref if not decl.builtin or ref.args: if ref.args: if ref.name == "UnstableFn": # UnstableFn is a special case, where the rest of args are collected into a call + if len(ref.args) < 2: + msg = "Zero argument higher order functions not supported" + raise NotImplementedError(msg) type_args: list[bindings._Expr] = [ bindings.Call( span(), @@ -589,11 +596,9 @@ def _expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr: # noqa: PLR0912, case _: assert_never(value) res = bindings.Lit(span(), l) - case CallDecl(ref, args, _): - egg_fn, reverse_args = self.callable_ref_to_egg(ref) - egg_args = [self.typed_expr_to_egg(a, False) for a in args] - if reverse_args: - 
egg_args.reverse() + case CallDecl() | GetCostDecl(): + egg_fn, typed_args = self.translate_call(expr_decl) + egg_args = [self.typed_expr_to_egg(a, False) for a in typed_args] res = bindings.Call(span(), egg_fn, egg_args) case PyObjectDecl(value): res = GLOBAL_PY_OBJECT_SORT.store(value) @@ -604,11 +609,31 @@ def _expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr: # noqa: PLR0912, "unstable-fn", [bindings.Lit(span(), bindings.String(egg_fn_call.name)), *egg_fn_call.args], ) + case ValueDecl(): + msg = "Cannot turn a Value into an expression" + raise ValueError(msg) case _: assert_never(expr_decl.expr) self.expr_to_egg_cache[expr_decl] = res return res + def translate_call(self, expr: CallDecl | GetCostDecl) -> tuple[str, list[TypedExprDecl]]: + """ + Handle get cost and call decl, turn into egg table name and typed expr decls. + """ + match expr: + case CallDecl(ref, args, _): + egg_fn, reverse_args = self.callable_ref_to_egg(ref) + args_list = list(args) + if reverse_args: + args_list.reverse() + return egg_fn, args_list + case GetCostDecl(ref, args): + cost_table = self.create_cost_table(ref) + return cost_table, list(args) + case _: + assert_never(expr) + def exprs_from_egg( self, termdag: bindings.TermDag, terms: list[bindings._Term], tp: JustTypeRef ) -> Iterable[TypedExprDecl]: @@ -652,6 +677,129 @@ def _generate_callable_egg_name(self, ref: CallableRef) -> str: case _: assert_never(ref) + def typed_expr_to_value(self, typed_expr: TypedExprDecl) -> bindings.Value: + egg_expr = self.typed_expr_to_egg(typed_expr, False) + return self.egraph.eval_expr(egg_expr)[1] + + def value_to_expr(self, tp: JustTypeRef, value: bindings.Value) -> ExprDecl: # noqa: C901, PLR0911, PLR0912 + match tp.name: + # Should match list in egraph bindings + case "i64": + return LitDecl(self.egraph.value_to_i64(value)) + case "f64": + return LitDecl(self.egraph.value_to_f64(value)) + case "Bool": + return LitDecl(self.egraph.value_to_bool(value)) + case "String": + return 
LitDecl(self.egraph.value_to_string(value)) + case "Unit": + return LitDecl(None) + case "PyObject": + return PyObjectDecl(self.egraph.value_to_pyobject(GLOBAL_PY_OBJECT_SORT, value)) + case "Rational": + fraction = self.egraph.value_to_rational(value) + return CallDecl( + InitRef("Rational"), + ( + TypedExprDecl(JustTypeRef("i64"), LitDecl(fraction.numerator)), + TypedExprDecl(JustTypeRef("i64"), LitDecl(fraction.denominator)), + ), + ) + case "BigInt": + i = self.egraph.value_to_bigint(value) + return CallDecl( + ClassMethodRef("BigInt", "from_string"), + (TypedExprDecl(JustTypeRef("String"), LitDecl(str(i))),), + ) + case "BigRat": + fraction = self.egraph.value_to_bigrat(value) + return CallDecl( + InitRef("BigRat"), + ( + TypedExprDecl( + JustTypeRef("BigInt"), + CallDecl( + ClassMethodRef("BigInt", "from_string"), + (TypedExprDecl(JustTypeRef("String"), LitDecl(str(fraction.numerator))),), + ), + ), + TypedExprDecl( + JustTypeRef("BigInt"), + CallDecl( + ClassMethodRef("BigInt", "from_string"), + (TypedExprDecl(JustTypeRef("String"), LitDecl(str(fraction.denominator))),), + ), + ), + ), + ) + case "Map": + k_tp, v_tp = tp.args + expr = CallDecl(ClassMethodRef("Map", "empty"), (), (k_tp, v_tp)) + for k, v in self.egraph.value_to_map(value).items(): + expr = CallDecl( + MethodRef("Map", "insert"), + ( + TypedExprDecl(tp, expr), + TypedExprDecl(k_tp, self.value_to_expr(k_tp, k)), + TypedExprDecl(v_tp, self.value_to_expr(v_tp, v)), + ), + ) + return expr + case "Set": + xs_ = self.egraph.value_to_set(value) + (v_tp,) = tp.args + return CallDecl( + InitRef("Set"), tuple(TypedExprDecl(v_tp, self.value_to_expr(v_tp, x)) for x in xs_), (v_tp,) + ) + case "Vec": + xs = self.egraph.value_to_vec(value) + (v_tp,) = tp.args + return CallDecl( + InitRef("Vec"), tuple(TypedExprDecl(v_tp, self.value_to_expr(v_tp, x)) for x in xs), (v_tp,) + ) + case "MultiSet": + xs = self.egraph.value_to_multiset(value) + (v_tp,) = tp.args + return CallDecl( + InitRef("MultiSet"), 
tuple(TypedExprDecl(v_tp, self.value_to_expr(v_tp, x)) for x in xs), (v_tp,) + ) + case "UnstableFn": + _names, _args = self.egraph.value_to_function(value) + return_tp, *arg_types = tp.args + return self._unstable_fn_value_to_expr(_names, _args, return_tp, arg_types) + return ValueDecl(value) + + def _unstable_fn_value_to_expr( + self, name: str, partial_args: list[bindings.Value], return_tp: JustTypeRef, _arg_types: list[JustTypeRef] + ) -> PartialCallDecl: + # Similar to FromEggState::from_call but accepts partial list of args and returns in values + # Find first callable ref whose return type matches and fill in arg types. + for callable_ref in self.egg_fn_to_callable_refs[name]: + signature = self.__egg_decls__.get_callable_decl(callable_ref).signature + if not isinstance(signature, FunctionSignature): + continue + if signature.semantic_return_type.name != return_tp.name: + continue + tcs = TypeConstraintSolver(self.__egg_decls__) + + arg_types, bound_tp_params = tcs.infer_arg_types( + signature.arg_types, signature.semantic_return_type, signature.var_arg_type, return_tp, None + ) + + args = tuple( + TypedExprDecl(tp, self.value_to_expr(tp, v)) for tp, v in zip(arg_types, partial_args, strict=False) + ) + + call_decl = CallDecl( + callable_ref, + args, + # Don't include bound type params if this is just a method, we only needed them for type resolution + # but dont need to store them + bound_tp_params if isinstance(callable_ref, ClassMethodRef | InitRef) else (), + ) + return PartialCallDecl(call_decl) + raise ValueError(f"Function '{name}' not found") + # https://chatgpt.com/share/9ab899b4-4e17-4426-a3f2-79d67a5ec456 _EGGLOG_INVALID_IDENT = re.compile(r"[^\w\-+*/?!=<>&|^/%]") @@ -789,7 +937,7 @@ def from_call( args, # Don't include bound type params if this is just a method, we only needed them for type resolution # but dont need to store them - bound_tp_params if isinstance(callable_ref, ClassMethodRef | InitRef) else None, + bound_tp_params if 
isinstance(callable_ref, ClassMethodRef | InitRef) else (), ) raise ValueError( f"Could not find callable ref for call {term}. None of these refs matched the types: {self.state.egg_fn_to_callable_refs[term.name]}" diff --git a/python/egglog/exp/array_api_jit.py b/python/egglog/exp/array_api_jit.py index 0e2ee53d..529b4a41 100644 --- a/python/egglog/exp/array_api_jit.py +++ b/python/egglog/exp/array_api_jit.py @@ -4,7 +4,7 @@ import numpy as np -from egglog import EGraph +from egglog import EGraph, greedy_dag_cost_model from egglog.exp.array_api import NDArray, set_array_api_egraph, try_evaling from egglog.exp.array_api_numba import array_api_numba_schedule from egglog.exp.array_api_program_gen import EvalProgram, array_api_program_gen_schedule, ndarray_function_two_program @@ -41,7 +41,7 @@ def function_to_program(fn: Callable, save_egglog_string: bool) -> tuple[EGraph, res = fn(NDArray.var(arg1), NDArray.var(arg2)) egraph.register(res) egraph.run(array_api_numba_schedule) - res_optimized = egraph.extract(res) + res_optimized = egraph.extract(res, cost_model=greedy_dag_cost_model()) return ( egraph, diff --git a/python/egglog/pretty.py b/python/egglog/pretty.py index 62531acf..a2eeeed4 100644 --- a/python/egglog/pretty.py +++ b/python/egglog/pretty.py @@ -183,7 +183,7 @@ def __call__(self, decl: AllDecls, toplevel: bool = False) -> None: # noqa: C90 if isinstance(de, DefaultRewriteDecl): continue self(de) - case CallDecl(ref, exprs, _): + case CallDecl(ref, exprs, _) | GetCostDecl(ref, exprs): match ref: case FunctionRef(UnnamedFunctionRef(_, res)): self(res.expr) @@ -205,12 +205,13 @@ def __call__(self, decl: AllDecls, toplevel: bool = False) -> None: # noqa: C90 case SetCostDecl(_, e, c): self(e) self(c) - case BackOffDecl(): + case BackOffDecl() | ValueDecl(): pass case LetSchedulerDecl(scheduler, schedule): self(scheduler) self(schedule) - + case GetCostDecl(ref, args): + self(CallDecl(ref, args)) case _: assert_never(decl) @@ -353,6 +354,10 @@ def 
uncached(self, decl: AllDecls, *, unwrap_lit: bool, parens: bool, ruleset_na if ban_length is not None: list_args.append(f"ban_length={ban_length}") return f"back_off({', '.join(list_args)})", "scheduler" + case ValueDecl(value): + return str(value), "value" + case GetCostDecl(ref, args): + return f"get_cost({self(CallDecl(ref, args))})", "get_cost" assert_never(decl) def _call( diff --git a/python/egglog/runtime.py b/python/egglog/runtime.py index 2a5a9a76..a53aa260 100644 --- a/python/egglog/runtime.py +++ b/python/egglog/runtime.py @@ -457,7 +457,7 @@ def __call__(self, *args: object, _egg_partial_function: bool = False, **kwargs: arg_exprs = tuple(arg.__egg_typed_expr__ for arg in upcasted_args) return_tp = tcs.substitute_typevars(signature.semantic_return_type, cls_name) bound_params = ( - cast("JustTypeRef", bound_tp).args if isinstance(self.__egg_ref__, ClassMethodRef | InitRef) else None + cast("JustTypeRef", bound_tp).args if isinstance(self.__egg_ref__, ClassMethodRef | InitRef) else () ) # If we were using unstable-app to call a funciton, add that function back as the first arg. if function_value: @@ -584,11 +584,17 @@ def __eq__(self, other: object) -> object: # type: ignore[override] if (method := _get_expr_method(self, "__eq__")) is not None: return method(other) - # TODO: Check if two objects can be upcasted to be the same. If not, then return NotImplemented so other - # expr gets a chance to resolve __eq__ which could be a preserved method. 
- from .egraph import BaseExpr, eq # noqa: PLC0415 + if not (isinstance(self, RuntimeExpr) and isinstance(other, RuntimeExpr)): + return NotImplemented + if self.__egg_typed_expr__.tp != other.__egg_typed_expr__.tp: + return NotImplemented - return eq(cast("BaseExpr", self)).to(cast("BaseExpr", other)) + from .egraph import Fact # noqa: PLC0415 + + return Fact( + Declarations.create(self, other), + EqDecl(self.__egg_typed_expr__.tp, self.__egg_typed_expr__.expr, other.__egg_typed_expr__.expr), + ) def __ne__(self, other: object) -> object: # type: ignore[override] if (method := _get_expr_method(self, "__ne__")) is not None: diff --git a/python/egglog/type_constraint_solver.py b/python/egglog/type_constraint_solver.py index fd2f8c71..24779a25 100644 --- a/python/egglog/type_constraint_solver.py +++ b/python/egglog/type_constraint_solver.py @@ -54,7 +54,7 @@ def infer_arg_types( fn_var_args: TypeOrVarRef | None, return_: JustTypeRef, cls_name: str | None, - ) -> tuple[Iterable[JustTypeRef], tuple[JustTypeRef, ...] | None]: + ) -> tuple[Iterable[JustTypeRef], tuple[JustTypeRef, ...]]: """ Given a return type, infer the argument types. If there is a variable arg, it returns an infinite iterable. 
@@ -75,7 +75,7 @@ def infer_arg_types( ) ) if cls_name - else None + else () ) return arg_types, bound_typevars diff --git a/python/tests/test_array_api.py b/python/tests/test_array_api.py index 4ef0f68e..65e9cb63 100644 --- a/python/tests/test_array_api.py +++ b/python/tests/test_array_api.py @@ -11,7 +11,7 @@ from sklearn import config_context, datasets from sklearn.discriminant_analysis import LinearDiscriminantAnalysis -from egglog.egraph import set_current_ruleset +from egglog import greedy_dag_cost_model, set_current_ruleset from egglog.exp.array_api import * from egglog.exp.array_api import NDArray, Value from egglog.exp.array_api_jit import function_to_program, jit @@ -323,7 +323,7 @@ def test_program_compile(program: Program, snapshot_py): egraph = EGraph() egraph.register(program) egraph.run(array_api_numba_schedule) - simplified_program = egraph.extract(program) + simplified_program = egraph.extract(program, cost_model=greedy_dag_cost_model()) assert str(simplified_program) == snapshot_py(name="expr") egraph = EGraph() egraph.register(simplified_program.compile()) diff --git a/python/tests/test_bindings.py b/python/tests/test_bindings.py index 376c4d7f..420b89b7 100644 --- a/python/tests/test_bindings.py +++ b/python/tests/test_bindings.py @@ -3,6 +3,7 @@ import os import pathlib import subprocess +from fractions import Fraction import black import pytest @@ -227,3 +228,156 @@ def test_serialized_egraph(self): serialized = egraph.serialize([]) with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: executor.submit(print, (serialized,)).result() + + +egraph = EGraph() + + +class TestValues: + def test_i64(self): + sort, value = egraph.eval_expr(Lit(DUMMY_SPAN, Int(42))) + assert sort == "i64" + assert egraph.value_to_i64(value) == 42 + + def test_bigint(self): + sort, value = egraph.eval_expr(Call(DUMMY_SPAN, "bigint", [Lit(DUMMY_SPAN, Int(100))])) + assert sort == "BigInt" + assert egraph.value_to_bigint(value) == 100 + + def 
test_bigrat(self): + sort, value = egraph.eval_expr( + Call( + DUMMY_SPAN, + "bigrat", + [ + Call(DUMMY_SPAN, "bigint", [Lit(DUMMY_SPAN, Int(100))]), + Call(DUMMY_SPAN, "bigint", [Lit(DUMMY_SPAN, Int(21))]), + ], + ) + ) + assert sort == "BigRat" + assert egraph.value_to_bigrat(value) == Fraction(100, 21) + + def test_f64(self): + sort, value = egraph.eval_expr(Lit(DUMMY_SPAN, Float(3.14))) + assert sort == "f64" + assert egraph.value_to_f64(value) == 3.14 + + def test_string(self): + sort, value = egraph.eval_expr(Lit(DUMMY_SPAN, String("hello"))) + assert sort == "String" + assert egraph.value_to_string(value) == "hello" + + def test_rational(self): + sort, value = egraph.eval_expr( + Call( + DUMMY_SPAN, + "rational", + [Lit(DUMMY_SPAN, Int(22)), Lit(DUMMY_SPAN, Int(7))], + ) + ) + assert sort == "Rational" + assert egraph.value_to_rational(value) == Fraction(22, 7) + + def test_bool(self): + sort, value = egraph.eval_expr(Lit(DUMMY_SPAN, Bool(True))) + assert sort == "bool" + assert egraph.value_to_bool(value) is True + + def test_py_object(self): + py_object_sort = PyObjectSort() + egraph = EGraph(py_object_sort) + expr = py_object_sort.store("my object") + sort, value = egraph.eval_expr(expr) + assert sort == "PyObject" + assert egraph.value_to_pyobject(py_object_sort, value) == "my object" + + def test_map(self): + k = Lit(DUMMY_SPAN, Int(1)) + v = Lit(DUMMY_SPAN, String("one")) + egraph.run_program(Sort(DUMMY_SPAN, "MyMap", ("Map", [Var(DUMMY_SPAN, "i64"), Var(DUMMY_SPAN, "String")]))) + sort, value = egraph.eval_expr(Call(DUMMY_SPAN, "map-insert", [Call(DUMMY_SPAN, "map-empty", []), k, v])) + assert sort == "MyMap" + m = egraph.value_to_map(value) + assert m == {egraph.eval_expr(k)[1]: egraph.eval_expr(v)[1]} + + def test_multiset(self): + egraph.run_program(Sort(DUMMY_SPAN, "MyMultiSet", ("MultiSet", [Var(DUMMY_SPAN, "i64")]))) + sort, value = egraph.eval_expr( + Call( + DUMMY_SPAN, + "multiset-of", + [ + Lit(DUMMY_SPAN, Int(1)), + Lit(DUMMY_SPAN, Int(2)), 
+ Lit(DUMMY_SPAN, Int(1)), + ], + ) + ) + assert sort == "MyMultiSet" + ms = egraph.value_to_multiset(value) + assert sorted(egraph.value_to_i64(v) for v in ms) == [1, 1, 2] + + def test_set(self): + egraph.run_program(Sort(DUMMY_SPAN, "MySet", ("Set", [Var(DUMMY_SPAN, "i64")]))) + sort, value = egraph.eval_expr( + Call( + DUMMY_SPAN, + "set-of", + [ + Lit(DUMMY_SPAN, Int(1)), + Lit(DUMMY_SPAN, Int(2)), + ], + ) + ) + assert sort == "MySet" + s = egraph.value_to_set(value) + assert isinstance(s, set) + assert {egraph.value_to_i64(v) for v in s} == {1, 2} + + def test_vec(self): + egraph.run_program(Sort(DUMMY_SPAN, "MyVec", ("Vec", [Var(DUMMY_SPAN, "i64")]))) + sort, value = egraph.eval_expr( + Call( + DUMMY_SPAN, + "vec-of", + [ + Lit(DUMMY_SPAN, Int(1)), + Lit(DUMMY_SPAN, Int(2)), + Lit(DUMMY_SPAN, Int(3)), + ], + ) + ) + assert sort == "MyVec" + v = egraph.value_to_vec(value) + assert isinstance(v, list) + assert [egraph.value_to_i64(vi) for vi in v] == [1, 2, 3] + + def test_fn(self): + egraph.run_program( + Sort(DUMMY_SPAN, "MyFn", ("UnstableFn", [Call(DUMMY_SPAN, "i64", []), Var(DUMMY_SPAN, "i64")])) + ) + sort, value = egraph.eval_expr( + Call( + DUMMY_SPAN, + "unstable-fn", + [ + Lit(DUMMY_SPAN, String("+")), + Lit(DUMMY_SPAN, Int(1)), + ], + ) + ) + assert sort == "MyFn" + f, args = egraph.value_to_function(value) + assert f == "+" + assert len(args) == 1 + assert egraph.value_to_i64(args[0]) == 1 + + +def test_lookup_function(): + egraph = EGraph() + egraph.run_program(*egraph.parse_program("(function hi (i64) i64 :no-merge)\n(set (hi 1) 2)")) + assert ( + egraph.lookup_function("hi", [egraph.eval_expr(Lit(DUMMY_SPAN, Int(1)))[1]]) + == egraph.eval_expr(Lit(DUMMY_SPAN, Int(2)))[1] + ) diff --git a/python/tests/test_high_level.py b/python/tests/test_high_level.py index c6e8731b..240bb254 100644 --- a/python/tests/test_high_level.py +++ b/python/tests/test_high_level.py @@ -7,6 +7,7 @@ from fractions import Fraction from functools import partial from typing 
import ClassVar, TypeAlias, TypeVar +from unittest.mock import MagicMock import pytest @@ -344,7 +345,7 @@ def incr(x: Math) -> None: ... assert str(x) == "_Math_1 = Math(10)\nincr(_Math_1)\n_Math_1" assert str(x + Math(10)) == "_Math_1 = Math(10)\nincr(_Math_1)\n_Math_1 + Math(10)" - i, j = vars_("i j", Math) + i, _j = vars_("i j", Math) incr_i = copy(i) incr(incr_i) egraph.register(rewrite(incr_i).to(i + Math(1)), x) @@ -1184,3 +1185,126 @@ def test_custom_scheduler_invalid_until(self): # Multiple until facts should error via high-level run with pytest.raises(ValueError, match="Can only have one until fact with custom scheduler"): egraph.run(run(r, rel(i64(0)), rel(i64(1)), scheduler=bo)) + + +@function +def ff(x: i64Like, y: i64Like) -> E: ... + + +@function +def gg() -> E: ... + + +class TestCustomExtract: + @pytest.mark.parametrize( + "expr", + [ + pytest.param(i64(10), id="i64"), + pytest.param(f64(10.0), id="f64"), + pytest.param(String("hi"), id="String"), + pytest.param(Bool(True), id="Bool"), + pytest.param(Rational(1, 2), id="Rational"), + pytest.param(BigInt(10), id="BigInt"), + pytest.param(BigRat(1, 2), id="BigRat"), + pytest.param(PyObject("hi"), id="PyObject"), + pytest.param(Vec(i64(1), i64(2)), id="Vec"), + pytest.param(Set(i64(1), i64(2)), id="Set"), + pytest.param(Map[i64, String].empty().insert(i64(1), String("hi")), id="Map"), + pytest.param(MultiSet(i64(1), i64(1)), id="MultiSet"), + pytest.param(Unit(), id="Unit"), + pytest.param(UnstableFn[E, i64, i64](ff), id="fn"), + pytest.param(UnstableFn[E, i64](ff, i64(1)), id="fn partial"), + ], + ) + def test_to_from_value(self, expr): + egraph = EGraph() + expr = egraph.extract(expr) + assert expr == self._to_from_value(egraph, expr) + + def _to_from_value(self, egraph, expr): + typed_expr = expr.__egg_typed_expr__ + value = egraph._state.typed_expr_to_value(typed_expr) + res_val = egraph._state.value_to_expr(typed_expr.tp, value) + return expr.__with_expr__(TypedExprDecl(typed_expr.tp, res_val)) + 
+ def test_compare_values(self): + egraph = EGraph() + egraph.register(E(), gg()) + e_value = self._to_from_value(egraph, E()) + gg_value = self._to_from_value(egraph, gg()) + assert e_value != gg_value + assert hash(e_value) != hash(gg_value) + assert str(e_value) != str(gg_value) + + def test_no_changes(self): + egraph = EGraph() + assert egraph.extract(E(), include_cost=True) == egraph.extract( + E(), include_cost=True, cost_model=default_cost_model + ) + + def test_calls_methods(self): + @function + def my_f(xs: Vec[i64]) -> E: ... + + # cost = 2 + x = i64(10) + # cost = 3 + 2 = 5 + xs = Vec[i64](x) + # cost = 100 + res = E() + # cost = 1 + 5 = 6 + called = my_f(xs) + egraph = EGraph() + egraph.register(union(called).with_(res)) + + def my_cost_model(egraph: EGraph, expr: BaseExpr, children_costs: list[int]) -> int: + if get_callable_fn(expr) == E: + return 100 + match expr: + case i64(): + return 2 + case Vec(): + return 3 + sum(children_costs) + return default_cost_model(egraph, expr, children_costs) + + my_cost_model = MagicMock(side_effect=my_cost_model) + assert egraph.extract(called, include_cost=True, cost_model=my_cost_model) == (called, 6) + + my_cost_model.assert_any_call(egraph, res, []) + my_cost_model.assert_any_call(egraph, xs, [2]) + my_cost_model.assert_any_call(egraph, x, []) + my_cost_model.assert_any_call(egraph, called, [5]) + + @pytest.mark.xfail(reason="Errors dont bubble, just panic") + def test_errors_bubble(self): + def my_cost_model(egraph: EGraph, expr: BaseExpr, children_costs: list[int]) -> int: + msg = "bad" + raise ValueError(msg) + + egraph = EGraph() + + with pytest.raises(ValueError, match="bad"): + egraph.extract(i64(10), cost_model=my_cost_model) + + def test_dag_cost_model(self): + egraph = EGraph() + expr = ff(1, 2) + res, cost = egraph.extract(expr, include_cost=True, cost_model=greedy_dag_cost_model()) + assert cost.total == 3 + assert expr == res + + expr = ff(1, 1) + res, cost = egraph.extract(expr, include_cost=True, 
cost_model=greedy_dag_cost_model()) + assert cost.total == 2 + assert expr == res + + @function + def bin(l: E, r: E) -> E: ... + + x = constant("x", E) + y = constant("y", E) + expr = bin(x, bin(x, y)) + egraph.register(expr) + res, cost = egraph.extract(expr, include_cost=True, cost_model=greedy_dag_cost_model()) + assert cost.total == 4 + assert expr == res diff --git a/src/conversions.rs b/src/conversions.rs index c2f1eed7..56449ad6 100644 --- a/src/conversions.rs +++ b/src/conversions.rs @@ -111,7 +111,7 @@ convert_enums!( v -> egglog::Term::Var((&v.name).into()), egglog::Term::Var(v) => TermVar { name: v.to_string() }; TermApp[trait=Hash](name: String, args: Vec) - a -> egglog::Term::App(a.name.clone().into(), a.args.to_vec()), + a -> egglog::Term::App(a.name.clone(), a.args.to_vec()), egglog::Term::App(s, a) => TermApp { name: s.to_string(), args: a.to_vec() @@ -335,22 +335,22 @@ convert_enums!( egglog::CommandOutput::PrintAllFunctionsSize(sizes) => PrintAllFunctionsSize {sizes: sizes.clone()}; ExtractBest(termdag: TermDag, cost: DefaultCost, term: Term) b -> egglog::CommandOutput::ExtractBest( - (&b.termdag).into(), + b.termdag.0.clone(), b.cost, (&b.term).into() ), egglog::CommandOutput::ExtractBest(termdag, cost, term) => ExtractBest { - termdag: termdag.into(), + termdag: TermDag(termdag.clone()), cost: *cost, term: term.into() }; ExtractVariants(termdag: TermDag, terms: Vec) v -> egglog::CommandOutput::ExtractVariants( - (&v.termdag).into(), + v.termdag.0.clone(), v.terms.iter().map(|v| v.into()).collect() ), egglog::CommandOutput::ExtractVariants(termdag, terms) => ExtractVariants { - termdag: termdag.into(), + termdag: TermDag(termdag.clone()), terms: terms.iter().map(|v| v.into()).collect() }; OverallStatistics(report: RunReport) @@ -362,13 +362,13 @@ convert_enums!( PrintFunctionOutput(function: Function, termdag: TermDag, terms: Vec<(Term, Term)>, mode: PrintFunctionMode) v -> egglog::CommandOutput::PrintFunction( v.function.0.clone(), - 
(&v.termdag).into(), + v.termdag.0.clone(), v.terms.iter().map(|(l, r)| (l.into(), r.into())).collect(), v.mode.clone().into() ), egglog::CommandOutput::PrintFunction(function, termdag, terms, mode) => PrintFunctionOutput { function: Function(function.clone()), - termdag: termdag.into(), + termdag: TermDag(termdag.clone()), terms: terms.iter().map(|(l, r)| (l.into(), r.into())).collect(), mode: mode.into() }; @@ -461,11 +461,11 @@ convert_struct!( ) r -> egglog::RunReport { updated: r.updated, - search_and_apply_time_per_rule: r.search_and_apply_time_per_rule.iter().map(|(k, v)| (k.clone().into(), v.clone().0)).collect(), - num_matches_per_rule: r.num_matches_per_rule.iter().map(|(k, v)| (k.clone().into(), *v)).collect(), - search_and_apply_time_per_ruleset: r.search_and_apply_time_per_ruleset.iter().map(|(k, v)| (k.clone().into(), v.clone().0)).collect(), - merge_time_per_ruleset: r.merge_time_per_ruleset.iter().map(|(k, v)| (k.clone().into(), v.clone().0)).collect(), - rebuild_time_per_ruleset: r.rebuild_time_per_ruleset.iter().map(|(k, v)| (k.clone().into(), v.clone().0)).collect(), + search_and_apply_time_per_rule: r.search_and_apply_time_per_rule.iter().map(|(k, v)| (k.clone(), v.clone().0)).collect(), + num_matches_per_rule: r.num_matches_per_rule.iter().map(|(k, v)| (k.clone(), *v)).collect(), + search_and_apply_time_per_ruleset: r.search_and_apply_time_per_ruleset.iter().map(|(k, v)| (k.clone(), v.clone().0)).collect(), + merge_time_per_ruleset: r.merge_time_per_ruleset.iter().map(|(k, v)| (k.clone(), v.clone().0)).collect(), + rebuild_time_per_ruleset: r.rebuild_time_per_ruleset.iter().map(|(k, v)| (k.clone(), v.clone().0)).collect(), }, r -> RunReport { updated: r.updated, diff --git a/src/egraph.rs b/src/egraph.rs index 377ff39e..a7950af8 100644 --- a/src/egraph.rs +++ b/src/egraph.rs @@ -2,13 +2,16 @@ use crate::conversions::*; use crate::error::{EggResult, WrappedError}; -use crate::py_object_sort::PyObjectSort; +use 
crate::py_object_sort::{PyObjectIdent, PyObjectSort}; use crate::serialize::SerializedEGraph; use egglog::prelude::add_base_sort; use egglog::{SerializeConfig, span}; use log::info; +use num_bigint::BigInt; +use num_rational::{BigRational, Rational64}; use pyo3::prelude::*; +use std::collections::{BTreeMap, BTreeSet}; use std::path::PathBuf; /// EGraph() @@ -76,10 +79,10 @@ impl EGraph { WrappedError::Egglog(e, "\nWhen running commands:\n".to_string() + &cmds_str) }) }); - if res.is_ok() { - if let Some(cmds) = &mut self.cmds { - cmds.push_str(&cmds_str); - } + if res.is_ok() + && let Some(cmds) = &mut self.cmds + { + cmds.push_str(&cmds_str); } res.map(|xs| xs.iter().map(|o| o.into()).collect()) } @@ -121,4 +124,130 @@ impl EGraph { } }) } + + fn lookup_function(&self, name: &str, key: Vec) -> Option { + self.egraph + .lookup_function( + name, + key.into_iter().map(|v| v.0).collect::>().as_slice(), + ) + .map(Value) + } + + fn eval_expr(&mut self, expr: Expr) -> EggResult<(String, Value)> { + let expr: egglog::ast::Expr = expr.into(); + self.egraph + .eval_expr(&expr) + .map(|(s, v)| (s.name().to_string(), Value(v))) + .map_err(|e| WrappedError::Egglog(e, format!("\nWhen evaluating expr: {expr}"))) + } + + fn value_to_i64(&self, v: Value) -> i64 { + self.egraph.value_to_base(v.0) + } + + fn value_to_bigint(&self, v: Value) -> BigInt { + let bi: egglog::sort::Z = self.egraph.value_to_base(v.0); + bi.0 + } + + fn value_to_bigrat(&self, v: Value) -> BigRational { + let bi: egglog::sort::Q = self.egraph.value_to_base(v.0); + bi.0 + } + + fn value_to_f64(&self, v: Value) -> f64 { + let f: egglog::sort::F = self.egraph.value_to_base(v.0); + f.0.into_inner() + } + + fn value_to_string(&self, v: Value) -> String { + let s: egglog::sort::S = self.egraph.value_to_base(v.0); + s.0 + } + + fn value_to_bool(&self, v: Value) -> bool { + self.egraph.value_to_base(v.0) + } + fn value_to_rational(&self, v: Value) -> Rational64 { + let r: egglog_experimental::R = 
self.egraph.value_to_base(v.0); + r.0 + } + + fn value_to_pyobject( + &self, + py: Python<'_>, + py_object_sort: PyObjectSort, + v: Value, + ) -> Py { + let ident = self.egraph.value_to_base::(v.0); + py_object_sort.load(py, ident).unbind() + } + + fn value_to_map(&self, v: Value) -> BTreeMap { + let mc = self + .egraph + .value_to_container::(v.0) + .unwrap(); + mc.data + .iter() + .map(|(k, v)| (Value(*k), Value(*v))) + .collect() + } + + fn value_to_multiset(&self, v: Value) -> Vec { + let mc = self + .egraph + .value_to_container::(v.0) + .unwrap(); + mc.data.iter().map(|k| Value(*k)).collect() + } + + fn value_to_set(&self, v: Value) -> BTreeSet { + let sc = self + .egraph + .value_to_container::(v.0) + .unwrap(); + sc.data.iter().map(|k| Value(*k)).collect() + } + + fn value_to_vec(&self, v: Value) -> Vec { + let vc = self + .egraph + .value_to_container::(v.0) + .unwrap(); + vc.data.iter().map(|x| Value(*x)).collect() + } + + fn value_to_function(&self, v: Value) -> (String, Vec) { + let fc = self + .egraph + .value_to_container::(v.0) + .unwrap(); + ( + fc.2.clone(), + fc.1.iter().map(|(_, v)| Value(*v)).collect::>(), + ) + } + + // fn dynamic_cost_model_enode_cost( + // &self, + // func: String, + // args: Vec, + // ) -> EggResult { + // let func = self.egraph.get_function(&func).ok_or_else(|| { + // WrappedError::Py(PyRuntimeError::new_err(format!("No such function: {func}"))) + // })?; + // let vals: Vec = args.into_iter().map(|v| v.0).collect(); + // let row = FunctionRow { + // vals: &vals, + // subsumed: false, + // }; + // Ok(egglog_experimental::DynamicCostModel {}.enode_cost(&self.egraph, &func, &row)) + // } } + +/// Wrapper around Egglog Value. Represents either a primitive base value or a reference to an e-class. 
+#[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Debug, Clone)] +#[pyclass(eq, frozen, hash, str = "{0:?}")] +pub struct Value(pub egglog::Value); diff --git a/src/extract.rs b/src/extract.rs new file mode 100644 index 00000000..a9c28598 --- /dev/null +++ b/src/extract.rs @@ -0,0 +1,260 @@ +use std::cmp::Ordering; + +use pyo3::{exceptions::PyValueError, prelude::*}; + +use crate::{conversions::Term, egraph::EGraph, egraph::Value, termdag::TermDag}; + +#[derive(Debug)] +// We have to store the result, since the cost model does not return errors +struct Cost(Py); + +impl Ord for Cost { + fn cmp(&self, other: &Self) -> Ordering { + Python::attach(|py| self.0.bind(py).compare(other.0.bind(py)).unwrap()) + } +} + +impl PartialOrd for Cost { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl PartialEq for Cost { + fn eq(&self, other: &Self) -> bool { + Python::attach(|py| self.0.bind(py).eq(other.0.bind(py))).unwrap() + } +} + +impl Eq for Cost {} + +impl Clone for Cost { + fn clone(&self) -> Self { + Python::attach(|py| Cost(self.0.clone_ref(py))) + } +} + +impl egglog::extract::Cost for Cost { + fn identity() -> Self { + panic!("Should never be called from Rust directly"); + } + + fn unit() -> Self { + panic!("Should never be called from Rust directly"); + } + + fn combine(self, _other: &Self) -> Self { + panic!("Should never be called from Rust directly"); + } +} + +/// Cost model defined by Python functions. +#[derive(Debug)] +#[pyclass( + frozen, + str = "CostModel({fold:?}, {enode_cost:?}, {container_cost:?}, {base_value_cost:?}" +)] +pub struct CostModel { + /// Function mapping from a term's head and its children's costs to the term's total cost. + /// (head: str, head_cost: COST, children_costs: list[COST]) -> COST + fold: Py, + /// Function mapping from an expression node to its cost. 
+ /// (func_name: str, args: list[Value]) -> COST + enode_cost: Py, + /// Function mapping from a container value to its cost given the costs of its elements. + /// (sort_name: str, value: Value, element_costs: list[COST]) -> COST + container_cost: Py, + /// Function mapping from a base value to its cost. + /// (sort_name: str, value: Value) -> COST + base_value_cost: Py, +} + +#[pymethods] +impl CostModel { + #[new] + fn new( + fold: Py, + enode_cost: Py, + container_cost: Py, + base_value_cost: Py, + ) -> Self { + CostModel { + fold, + enode_cost, + container_cost, + base_value_cost, + } + } +} + +impl Clone for CostModel { + fn clone(&self) -> Self { + Python::attach(|py| CostModel { + fold: self.fold.clone_ref(py), + enode_cost: self.enode_cost.clone_ref(py), + container_cost: self.container_cost.clone_ref(py), + base_value_cost: self.base_value_cost.clone_ref(py), + }) + } +} + +impl egglog::extract::CostModel for CostModel { + fn fold(&self, head: &str, children_cost: &[Cost], head_cost: Cost) -> Cost { + Cost(Python::attach(|py| { + let head_cost = head_cost.0.clone_ref(py); + let children_cost = children_cost + .into_iter() + .cloned() + .map(|c| c.0.clone_ref(py)) + .collect::>(); + self.fold + .call1(py, (head, head_cost, children_cost)) + .unwrap() + })) + } + + fn enode_cost( + &self, + egraph: &egglog::EGraph, + func: &egglog::Function, + row: &egglog::FunctionRow<'_>, + ) -> Cost { + Python::attach(|py| { + let mut values = row.vals.iter().map(|v| Value(*v)).collect::>(); + // Remove last element which is the output + // this is not needed because the only thing we can do with the output is look up an analysis + // which we can also do with the original function + values.pop().unwrap(); + Cost(self.enode_cost.call1(py, (func.name(), values)).unwrap()) + }) + } + + fn container_cost( + &self, + _egraph: &egglog::EGraph, + sort: &egglog::ArcSort, + value: egglog::Value, + element_costs: &[Cost], + ) -> Cost { + Cost(Python::attach(|py| { + let 
element_costs = element_costs + .into_iter() + .cloned() + .map(|c| c.0.clone_ref(py)) + .collect::>(); + self.container_cost + .call1(py, (sort.name(), Value(value), element_costs)) + .unwrap() + })) + } + + // https://github.com/PyO3/pyo3/issues/1190 + fn base_value_cost( + &self, + _egraph: &egglog::EGraph, + sort: &egglog::ArcSort, + value: egglog::Value, + ) -> Cost { + Python::attach(|py| { + Cost( + self.base_value_cost + .call1(py, (sort.name(), Value(value))) + .unwrap(), + ) + }) + } +} + +// TODO: Don't progress just return an error if there was an exception? + +#[pyclass(unsendable)] +pub struct Extractor(egglog::extract::Extractor); + +#[pymethods] +impl Extractor { + /// Create a new extractor from the given egraph and cost model. + /// + /// Bulk of the computation happens at initialization time. + /// The later extractions only reuses saved results. + /// This means a new extractor must be created if the egraph changes. + /// Holding a reference to the egraph would enforce this but prevents the extractor being reused. + /// + /// For convenience, if the rootsorts is `None`, it defaults to extract all extractable rootsorts. + #[new] + fn new( + py: Python<'_>, + rootsorts: Option>, + egraph: &EGraph, + cost_model: CostModel, + ) -> PyResult { + let egraph = &egraph.egraph; + // Transforms sorts to arcsorts, returning an error if any are unknown + let rootsorts = rootsorts + .map(|rs| { + rs.into_iter() + .map(|s| egraph.get_sort_by_name(&s).cloned()) + .collect::>>() + .ok_or(PyValueError::new_err("Unknown sort in rootsorts")) + }) + .map_or(Ok(None), |r| r.map(Some))?; + let extractor = + egglog::extract::Extractor::compute_costs_from_rootsorts(rootsorts, egraph, cost_model); + if let Some(err) = PyErr::take(py) { + return Err(err); + }; + Ok(Extractor(extractor)) + } + + /// Extract the best term of a value from a given sort. 
+ /// + /// This function expects the sort to be already computed, + /// which can be one of the rootsorts, or reachable from rootsorts, or primitives, or containers of computed sorts. + fn extract_best( + &self, + py: Python<'_>, + egraph: &EGraph, + termdag: &mut TermDag, + value: Value, + sort: String, + ) -> PyResult<(Py, Term)> { + let sort = egraph + .egraph + .get_sort_by_name(&sort) + .ok_or(PyValueError::new_err("Unknown sort"))?; + let (cost, term) = self + .0 + .extract_best_with_sort(&egraph.egraph, &mut termdag.0, value.0, sort.clone()) + .ok_or(PyValueError::new_err("Unextractable root".to_string()))?; + Ok((cost.0.clone_ref(py), term.into())) + } + + /// Extract variants of an e-class. + /// + /// The variants are selected by first picking `nvariants` e-nodes with the lowest cost from the e-class + /// and then extracting a term from each e-node. + fn extract_variants( + &self, + py: Python<'_>, + egraph: &EGraph, + termdag: &mut TermDag, + value: Value, + nvariants: usize, + sort: String, + ) -> PyResult, Term)>> { + let sort = egraph + .egraph + .get_sort_by_name(&sort) + .ok_or(PyValueError::new_err("Unknown sort"))?; + let variants = self.0.extract_variants_with_sort( + &egraph.egraph, + &mut termdag.0, + value.0, + nvariants, + sort.clone(), + ); + Ok(variants + .into_iter() + .map(|(cost, term)| (cost.0.clone_ref(py), term.into())) + .collect()) + } +} diff --git a/src/lib.rs b/src/lib.rs index b9c2bf69..2a406718 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,7 @@ mod conversions; mod egraph; mod error; +mod extract; mod py_object_sort; mod serialize; mod termdag; @@ -26,10 +27,13 @@ fn bindings(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; + m.add_class::()?; crate::conversions::add_structs_to_module(m)?; crate::conversions::add_enums_to_module(m)?; diff --git 
a/src/py_object_sort.rs b/src/py_object_sort.rs index d9ff0a75..5206cd3b 100644 --- a/src/py_object_sort.rs +++ b/src/py_object_sort.rs @@ -30,7 +30,7 @@ use pyo3::{ types::{PyCode, PyCodeMethods as _, PyDict}, }; -type PyObjectIdent = usize; +pub type PyObjectIdent = usize; #[derive(Clone)] #[pyclass] @@ -158,7 +158,7 @@ impl PyObjectSort { // Integrate with Python garbage collector // https://pyo3.rs/main/class/protocols#garbage-collector-integration - fn __traverse__<'py>(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> { + fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> { visit.call(self.hashable_to_index.as_ref())?; self.objects .lock() diff --git a/src/termdag.rs b/src/termdag.rs index 6e333a7b..c4027a1f 100644 --- a/src/termdag.rs +++ b/src/termdag.rs @@ -2,60 +2,56 @@ use crate::conversions::{Expr, Literal, Span, Term}; use egglog::TermId; use pyo3::prelude::*; -#[pyclass()] -#[derive(Clone, PartialEq, Eq)] -pub struct TermDag { - pub termdag: egglog::TermDag, -} +#[pyclass(eq, str = "{0:?}")] +#[derive(PartialEq, Eq, Clone)] +pub struct TermDag(pub egglog::TermDag); #[pymethods] impl TermDag { /// Create a new, empty TermDag. #[new] fn new() -> Self { - Self { - termdag: egglog::TermDag::default(), - } + Self(egglog::TermDag::default()) } /// Returns the number of nodes in this DAG. pub fn size(&self) -> usize { - self.termdag.size() + self.0.size() } /// Convert the given term to its id. /// /// Panics if the term does not already exist in this [TermDag]. pub fn lookup(&self, node: Term) -> TermId { - self.termdag.lookup(&node.into()).into() + self.0.lookup(&node.into()) } /// Convert the given id to the corresponding term. /// /// Panics if the id is not valid. pub fn get(&self, id: TermId) -> Term { - self.termdag.get(id).into() + self.0.get(id).into() } /// Make and return a App with the given head symbol and children, /// and insert into the DAG if it is not already present. 
/// /// Panics if any of the children are not already in the DAG. pub fn app(&mut self, sym: String, children: Vec) -> Term { - self.termdag - .app(sym.into(), children.into_iter().map(|c| c.into()).collect()) + self.0 + .app(sym, children.into_iter().map(|c| c.into()).collect()) .into() } /// Make and return a [`Term::Lit`] with the given literal, and insert into /// the DAG if it is not already present. pub fn lit(&mut self, lit: Literal) -> Term { - self.termdag.lit(lit.into()).into() + self.0.lit(lit.into()).into() } /// Make and return a [`Term::Var`] with the given symbol, and insert into /// the DAG if it is not already present. pub fn var(&mut self, sym: String) -> Term { - self.termdag.var(sym.into()).into() + self.0.var(sym).into() } /// Recursively converts the given expression to a term. @@ -64,44 +60,20 @@ impl TermDag { /// TermDags are hashconsed, the resulting term is guaranteed to maximally /// share subterms. pub fn expr_to_term(&mut self, expr: Expr) -> Term { - self.termdag.expr_to_term(&expr.into()).into() + self.0.expr_to_term(&expr.into()).into() } /// Recursively converts the given term to an expression. /// /// Panics if the term contains subterms that are not in the DAG. pub fn term_to_expr(&self, term: Term, span: Span) -> Expr { - self.termdag.term_to_expr(&term.into(), span.into()).into() + self.0.term_to_expr(&term.into(), span.into()).into() } /// Converts the given term to a string. /// /// Panics if the term or any of its subterms are not in the DAG. 
pub fn to_string(&self, term: Term) -> String { - self.termdag.to_string(&term.into()) - } -} - -impl From<&egglog::TermDag> for TermDag { - fn from(termdag: &egglog::TermDag) -> Self { - Self { - termdag: termdag.clone(), - } - } -} -impl From<&TermDag> for egglog::TermDag { - fn from(termdag: &TermDag) -> Self { - termdag.termdag.clone() - } -} - -impl From for TermDag { - fn from(termdag: egglog::TermDag) -> Self { - Self { termdag } - } -} -impl From for egglog::TermDag { - fn from(termdag: TermDag) -> Self { - termdag.termdag + self.0.to_string(&term.into()) } }