diff --git a/Cargo.lock b/Cargo.lock index b29dac40..d4ca4592 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -317,7 +317,7 @@ checksum = "d0881ea181b1df73ff77ffaaf9c7544ecc11e82fba9b5f27b262a3c73a332555" [[package]] name = "egglog" version = "1.0.0" -source = "git+https://github.com/egraphs-good/egglog.git?branch=main#ef90b97de1f5e7778186439b8fb0549179f82a45" +source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=fix-fn-bug#fdc03716d3acc47d603a9797bb50330025db0269" dependencies = [ "csv", "dyn-clone", @@ -344,7 +344,7 @@ dependencies = [ [[package]] name = "egglog-add-primitive" version = "1.0.0" -source = "git+https://github.com/egraphs-good/egglog.git?branch=main#ef90b97de1f5e7778186439b8fb0549179f82a45" +source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=fix-fn-bug#fdc03716d3acc47d603a9797bb50330025db0269" dependencies = [ "quote", "syn 2.0.107", @@ -353,7 +353,7 @@ dependencies = [ [[package]] name = "egglog-ast" version = "1.0.0" -source = "git+https://github.com/egraphs-good/egglog.git?branch=main#ef90b97de1f5e7778186439b8fb0549179f82a45" +source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=fix-fn-bug#fdc03716d3acc47d603a9797bb50330025db0269" dependencies = [ "ordered-float", ] @@ -361,7 +361,7 @@ dependencies = [ [[package]] name = "egglog-bridge" version = "1.0.0" -source = "git+https://github.com/egraphs-good/egglog.git?branch=main#ef90b97de1f5e7778186439b8fb0549179f82a45" +source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=fix-fn-bug#fdc03716d3acc47d603a9797bb50330025db0269" dependencies = [ "anyhow", "dyn-clone", @@ -385,7 +385,7 @@ dependencies = [ [[package]] name = "egglog-concurrency" version = "1.0.0" -source = "git+https://github.com/egraphs-good/egglog.git?branch=main#ef90b97de1f5e7778186439b8fb0549179f82a45" +source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=fix-fn-bug#fdc03716d3acc47d603a9797bb50330025db0269" dependencies = [ "arc-swap", "rayon", @@ -394,7 +394,7 @@ dependencies = [ [[package]] name = "egglog-core-relations" version = "1.0.0" -source = "git+https://github.com/egraphs-good/egglog.git?branch=main#ef90b97de1f5e7778186439b8fb0549179f82a45" +source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=fix-fn-bug#fdc03716d3acc47d603a9797bb50330025db0269" dependencies = [ "anyhow", "bumpalo", @@ -437,7 +437,7 @@ dependencies = [ [[package]] name = "egglog-numeric-id" version = "1.0.0" -source = "git+https://github.com/egraphs-good/egglog.git?branch=main#ef90b97de1f5e7778186439b8fb0549179f82a45" +source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=fix-fn-bug#fdc03716d3acc47d603a9797bb50330025db0269" dependencies = [ "rayon", ] @@ -445,7 +445,7 @@ dependencies = [ [[package]] name = "egglog-reports" version = "1.0.0" -source = "git+https://github.com/egraphs-good/egglog.git?branch=main#ef90b97de1f5e7778186439b8fb0549179f82a45" +source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=fix-fn-bug#fdc03716d3acc47d603a9797bb50330025db0269" dependencies = [ "clap", "hashbrown 0.16.0", @@ -459,7 +459,7 @@ dependencies = [ [[package]] name = "egglog-union-find" version = "1.0.0" -source = "git+https://github.com/egraphs-good/egglog.git?branch=main#ef90b97de1f5e7778186439b8fb0549179f82a45" +source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=fix-fn-bug#fdc03716d3acc47d603a9797bb50330025db0269" dependencies = [ "crossbeam", "egglog-concurrency", diff --git a/Cargo.toml b/Cargo.toml index 263c73ac..9e9e58ed 100644 --- a/Cargo.toml +++ 
b/Cargo.toml
@@ -13,12 +13,12 @@ crate-type = ["cdylib"]
 pyo3 = { version = "0.27", features = ["extension-module", "num-bigint", "num-rational"] }
 num-bigint = "*"
 num-rational = "*"
-egglog = { git = "https://github.com/egraphs-good/egglog.git", branch = "main", default-features = false }
-egglog-bridge = { git = "https://github.com/egraphs-good/egglog.git", branch = "main" }
-egglog-core-relations = { git = "https://github.com/egraphs-good/egglog.git", branch = "main" }
+egglog = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-fn-bug", default-features = false }
+egglog-bridge = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-fn-bug" }
+egglog-core-relations = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-fn-bug" }
 egglog-experimental = { git = "https://github.com/egraphs-good/egglog-experimental", branch = "main", default-features = false }
-egglog-ast = { git = "https://github.com/egraphs-good/egglog.git", branch = "main" }
-egglog-reports = { git = "https://github.com/egraphs-good/egglog.git", branch = "main" }
+egglog-ast = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-fn-bug" }
+egglog-reports = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-fn-bug" }
 egraph-serialize = { version = "0.3", features = ["serde", "graphviz"] }
 serde_json = "1"
 pyo3-log = "*"
@@ -31,10 +31,11 @@ base64 = "0.22.1"
 # Use patched version of egglog in experimental
 [patch.'https://github.com/egraphs-good/egglog']
-# egglog = { git = "https://github.com/egraphs-good/egglog.git", branch = "main" }
-# egglog-bridge = { git = "https://github.com/egraphs-good/egglog.git", branch = "main" }
-# egglog-core-relations = { git = "https://github.com/egraphs-good/egglog.git", branch = "main" }
-# egglog-ast = { git = "https://github.com/egraphs-good/egglog.git", branch = "main" }
+egglog = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-fn-bug" }
+egglog-ast = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-fn-bug" }
+egglog-core-relations = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-fn-bug" }
+egglog-bridge = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-fn-bug" }
+egglog-reports = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-fn-bug" }
 # enable debug symbols for easier profiling
 [profile.release]
diff --git a/python/egglog/exp/MoA.ipynb b/python/egglog/exp/MoA.ipynb
new file mode 100644
index 00000000..c4988e10
--- /dev/null
+++ b/python/egglog/exp/MoA.ipynb
@@ -0,0 +1,617 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "922a695b",
+ "metadata": {},
+ "source": [
+ "# Mathematics of Arrays in Egglog\n",
+ "\n",
+ "\n",
+ "This notebook shows that if you define array operations as higher-order functions, you can compose them and end up with a simpler algebra that uses only booleans, integers, and functions.\n",
+ "\n",
+ "We take as our input this MoA program, defined in [the PSI compiler](https://saulshanabrook.github.io/psi-compiler/src/):\n",
+ "\n",
+ "\n",
+ "```\n",
+ "main ()\n",
+ "\n",
+ "{\n",
+ " array Amts^3 <2 3 4>;\n",
+ " array Ams^3 <2 3 4>;\n",
+ " const array RAMY^3 <2 3 4>=<1 2 3 4 5 6 7 8 9 10 11 12 1 2 3 4 5 6 7 8 9 10 \n",
+ "\t\t\t\t11 12>;\n",
+ " const array AMY^3 <2 3 4>=<9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9>;\n",
+ " Amts=<2> take (<2> drop (RAMY cat AMY));\n",
+ "}\n",
+ "```\n",
+ "\n",
+ "The result `Amts` is equivalent to `AMY`, since we are concatenating `RAMY` and `AMY` along the first axis, dropping the first 2 elements (which removes all of `RAMY`), and then taking the next 2 elements (which is all of `AMY`).\n",
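+ "\n",
+ "As a quick numerical sanity check of that claim (an aside: NumPy is not otherwise used in this notebook), the same cat/drop/take pipeline can be written with ordinary array slicing:\n",
+ "\n",
+ "```python\n",
+ "import numpy as np\n",
+ "\n",
+ "# The two constant <2 3 4> arrays from the MoA program above, in row-major order.\n",
+ "RAMY = np.tile(np.arange(1.0, 13.0), 2).reshape(2, 3, 4)\n",
+ "AMY = np.full((2, 3, 4), 9.0)\n",
+ "\n",
+ "# `cat` joins along the first axis; `drop` and `take` slice along it.\n",
+ "Amts = np.concatenate([RAMY, AMY])[2:][:2]\n",
+ "assert np.array_equal(Amts, AMY)\n",
+ "```\n",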
+ "\n",
+ "Compiling it produces this C program, which copies AMY into Amts:\n",
+ "\n",
+ "```c\n",
+ "#include \n",
+ "#include \"moalib.e\"\n",
+ "\n",
+ "main()\n",
+ "\n",
+ "{\n",
+ " double *offset0;\n",
+ " int i0;\n",
+ " int i1;\n",
+ " int i2;\n",
+ " double *shift;\n",
+ " double _RAMY[]={1.000000, 2.000000, 3.000000, 4.000000, 5.000000,\n",
+ " 6.000000, 7.000000, 8.000000, 9.000000, 10.000000,\n",
+ " 11.000000, 12.000000, 1.000000, 2.000000, 3.000000,\n",
+ " 4.000000, 5.000000, 6.000000, 7.000000, 8.000000,\n",
+ " 9.000000, 10.000000, 11.000000, 12.000000};\n",
+ " double _AMY[]={9.000000, 9.000000, 9.000000, 9.000000, 9.000000,\n",
+ " 9.000000, 9.000000, 9.000000, 9.000000, 9.000000,\n",
+ " 9.000000, 9.000000, 9.000000, 9.000000, 9.000000,\n",
+ " 9.000000, 9.000000, 9.000000, 9.000000, 9.000000,\n",
+ " 9.000000, 9.000000, 9.000000, 9.000000};\n",
+ " double _Y[]={8.000000, 8.000000, 8.000000, 8.000000, 8.000000,\n",
+ " 8.000000, 8.000000, 8.000000, 8.000000, 8.000000,\n",
+ " 8.000000, 8.000000, 8.000000, 8.000000, 8.000000,\n",
+ " 8.000000, 8.000000, 8.000000, 8.000000, 8.000000,\n",
+ " 8.000000, 8.000000, 8.000000, 8.000000};\n",
+ " double _V[]={1.000000, 1.000000};\n",
+ " double _Amts[2*3*4];\n",
+ "\n",
+ "/*******\n",
+ "Amts=<2.000000> take (<2.000000> drop (RAMY cat AMY))\n",
+ "********/\n",
+ "\n",
+ " shift=_Amts+0*12+0*4+0;\n",
+ " offset0=_AMY+0*12+0*4+0;\n",
+ " for (i0=0; i0<2; i0++) {\n",
+ " for (i1=0; i1<3; i1++) {\n",
+ " for (i2=0; i2<4; i2++) {\n",
+ " *(shift)= *(offset0);\n",
+ " offset0+=1;\n",
+ " shift+=1;\n",
+ " }\n",
+ " }\n",
+ " }\n",
+ "```\n",
+ "\n",
+ "What we want to show here is not the full compilation into C and loops, but the fact that by defining each array operation as a higher-order function, we can compose them and end up with a simpler algebra that uses only booleans, integers, and functions. This simpler algebra could then be compiled into loops. The hypothesis here is that we don't *lose* any information by erasing the `take`, `drop`, and `cat` operations and replacing them with their definitions in terms of functions.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "id": "1b715c58",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
take(Int(2), drop(Int(2), cat(NDArray.from_memory(TupleInt.from_vec(Vec(Int(2), Int(3), Int(4))), RAMY), NDArray.from_memory(TupleInt.from_vec(Vec(Int(2), Int(3), Int(4))), AMY))))\n",
+       "
\n" + ], + "text/latex": [ + "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", + "\\PY{n}{take}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{2}\\PY{p}{)}\\PY{p}{,} \\PY{n}{drop}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{2}\\PY{p}{)}\\PY{p}{,} \\PY{n}{cat}\\PY{p}{(}\\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{from\\PYZus{}memory}\\PY{p}{(}\\PY{n}{TupleInt}\\PY{o}{.}\\PY{n}{from\\PYZus{}vec}\\PY{p}{(}\\PY{n}{Vec}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{2}\\PY{p}{)}\\PY{p}{,} \\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{3}\\PY{p}{)}\\PY{p}{,} \\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{4}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,} \\PY{n}{RAMY}\\PY{p}{)}\\PY{p}{,} \\PY{n}{NDArray}\\PY{o}{.}\\PY{n}{from\\PYZus{}memory}\\PY{p}{(}\\PY{n}{TupleInt}\\PY{o}{.}\\PY{n}{from\\PYZus{}vec}\\PY{p}{(}\\PY{n}{Vec}\\PY{p}{(}\\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{2}\\PY{p}{)}\\PY{p}{,} \\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{3}\\PY{p}{)}\\PY{p}{,} \\PY{n}{Int}\\PY{p}{(}\\PY{l+m+mi}{4}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,} \\PY{n}{AMY}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n", + "\\end{Verbatim}\n" + ], + "text/plain": [ + "take(Int(2), drop(Int(2), cat(NDArray.from_memory(TupleInt.from_vec(Vec(Int(2), Int(3), Int(4))), RAMY), NDArray.from_memory(TupleInt.from_vec(Vec(Int(2), Int(3), Int(4))), AMY))))" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from __future__ import annotations\n", + "\n", + "from collections.abc import Callable\n", + "\n", + "from egglog import *\n", + "\n", + "array_ruleset = ruleset(name=\"array_ruleset\")\n", + "\n", + "\n", + "class Boolean(Expr):\n", + " def __init__(self, val: BoolLike) -> None: ...\n", + " def if_bool(self, then: Int, else_: Int) -> Int: ...\n", + "\n", + "\n", + "class Int(Expr):\n", + " def __init__(self, val: i64Like) -> None: ...\n", + " def __eq__(self, other: Int) -> Boolean: ... 
# type: ignore[override]\n", + " def __lt__(self, other: Int) -> Boolean: ...\n", + " def __add__(self, other: Int) -> Int: ...\n", + " def __sub__(self, other: Int) -> Int: ...\n", + " def __mul__(self, other: Int) -> Int: ...\n", + "\n", + "\n", + "@array_ruleset.register\n", + "def _int(i: i64, j: i64, x: Int, y: Int):\n", + " yield rewrite(Int(i) + Int(j)).to(Int(i + j))\n", + " yield rewrite(Int(i) - Int(j)).to(Int(i - j))\n", + " yield rewrite(Int(i) * Int(j)).to(Int(i * j))\n", + " yield rewrite(Int(i) == Int(i)).to(Boolean(True))\n", + " yield rewrite(Int(i) == Int(j)).to(Boolean(False), i != j)\n", + " yield rewrite(Int(i) < Int(j)).to(Boolean(True), i < j)\n", + " yield rewrite(Int(i) < Int(j)).to(Boolean(False), i >= j)\n", + " yield rewrite(Boolean(True).if_bool(x, y)).to(x)\n", + " yield rewrite(Boolean(False).if_bool(x, y)).to(y)\n", + "\n", + "\n", + "@function\n", + "def vec_index(vec: Vec[Int], index: Int) -> Int: ...\n", + "\n", + "\n", + "@array_ruleset.register\n", + "def _vec_index(i: i64, xs: Vec[Int]):\n", + " yield rewrite(vec_index(xs, Int(i))).to(xs[i])\n", + "\n", + "\n", + "class TupleInt(Expr, ruleset=array_ruleset):\n", + " def __init__(self, length: Int, getitem_fn: Callable[[Int], Int]) -> None: ...\n", + " def __getitem__(self, index: Int) -> Int: ...\n", + "\n", + " @property\n", + " def length(self) -> Int: ...\n", + "\n", + " @classmethod\n", + " def from_vec(cls, xs: Vec[Int]) -> TupleInt:\n", + " return TupleInt(\n", + " Int(xs.length()),\n", + " lambda i: vec_index(xs, i),\n", + " )\n", + "\n", + "\n", + "@array_ruleset.register\n", + "def _tuple_int(l: Int, fn: Callable[[Int], Int], i: Int):\n", + " ti = TupleInt(l, fn)\n", + " yield rewrite(ti.length).to(l)\n", + " yield rewrite(ti[i]).to(fn(i))\n", + "\n", + "\n", + "class NDArray(Expr, ruleset=array_ruleset):\n", + " def __init__(self, shape: TupleInt, idx_fn: Callable[[TupleInt], Int]) -> None: ...\n", + "\n", + " @classmethod\n", + " def from_memory(cls, shape: TupleInt, values: TupleInt) -> NDArray:\n", + " # Only work on ndim = 3 for now\n", + " return NDArray(\n", + " shape,\n", + " lambda idx: values[\n", + " idx[Int(0)] * (shape[Int(1)] * shape[Int(2)]) + idx[Int(1)] * shape[Int(2)] + idx[Int(2)]\n", + " ],\n", + " )\n", + "\n", + " @property\n", + " def shape(self) -> TupleInt: ...\n", + "\n", + " def __getitem__(self, index: TupleInt) -> Int: ...\n", + "\n", + "\n", + "@array_ruleset.register\n", + "def _ndarray(shape: TupleInt, fn: Callable[[TupleInt], Int], idx: TupleInt):\n", + " nda = NDArray(shape, fn)\n", + " yield rewrite(nda.shape).to(shape)\n", + " yield rewrite(nda[idx]).to(fn(idx))\n", + "\n", + "\n", + "@function(subsume=True, ruleset=array_ruleset)\n", + "def cat(l: NDArray, r: NDArray) -> NDArray:\n", + " \"\"\"\n", + " Returns the concatenation of two arrays, they should have the same shape and the first dimension is added.\n", + " \"\"\"\n", + " return NDArray(\n", + " TupleInt(\n", + " l.shape.length,\n", + " lambda i: (i == Int(0)).if_bool(l.shape[Int(0)] + r.shape[Int(0)], l.shape[i]),\n", + " ),\n", + " lambda idx: (idx[Int(0)] < l.shape[Int(0)]).if_bool(\n", + " l[idx], r[TupleInt(r.shape.length, lambda i: (i == Int(0)).if_bool(idx[Int(0)] - l.shape[Int(0)], idx[i]))]\n", + " ),\n", + " )\n", + "\n", + "\n", + "@function(subsume=True, ruleset=array_ruleset)\n", + "def drop(x: Int, arr: NDArray) -> NDArray:\n", + " \"\"\"\n", + " Drops the first `x` elements off the front of the array `arr`.\n", + " \"\"\"\n", + " return NDArray(\n", + " TupleInt(\n", + " 
arr.shape.length,\n",
+ " lambda i: (i == Int(0)).if_bool(arr.shape[Int(0)] - x, arr.shape[i]),\n",
+ " ),\n",
+ " lambda idx: arr[\n",
+ " TupleInt(\n",
+ " arr.shape.length,\n",
+ " # Add x to the first index, so it skips the first x elements\n",
+ " lambda i: (i == Int(0)).if_bool(idx[Int(0)] + x, idx[i]),\n",
+ " )\n",
+ " ],\n",
+ " )\n",
+ "\n",
+ "\n",
+ "@function(subsume=True, ruleset=array_ruleset)\n",
+ "def take(x: Int, arr: NDArray) -> NDArray:\n",
+ " \"\"\"\n",
+ " Takes the first `x` elements off the front of the array `arr`.\n",
+ " \"\"\"\n",
+ " return NDArray(\n",
+ " TupleInt(\n",
+ " arr.shape.length,\n",
+ " lambda i: (i == Int(0)).if_bool(x, arr.shape[i]),\n",
+ " ),\n",
+ " lambda idx: arr[idx],\n",
+ " )\n",
+ "\n",
+ "\n",
+ "shape = TupleInt.from_vec(Vec(Int(2), Int(3), Int(4)))\n",
+ "RAMY = NDArray.from_memory(shape, constant(\"RAMY\", TupleInt))\n",
+ "AMY = NDArray.from_memory(shape, constant(\"AMY\", TupleInt))\n",
+ "Amts = take(Int(2), drop(Int(2), cat(RAMY, AMY)))\n",
+ "Amts"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "id": "1ada95b6",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Amts.shape.length()=Int(3)\n",
+ "Amts.shape[0]=Int(2)\n",
+ "Amts.shape[1]=Int(3)\n",
+ "Amts.shape[2]=Int(4)\n",
+ "Amts[i, j, k]=((i + Int(2)) < Int(2)).if_bool(RAMY[(((i + Int(2)) * Int(12)) + (j * Int(4))) + k], AMY[((((i + Int(2)) - Int(2)) * Int(12)) + (j * Int(4))) + k])\n",
+ "AMY[i, j, k]=AMY[((i * Int(12)) + (j * Int(4))) + k]\n"
+ ]
+ }
+ ],
+ "source": [
+ "egraph = EGraph()\n",
+ "ndim = egraph.let(\"ndim\", Amts.shape.length)\n",
+ "shape_1 = egraph.let(\"shape_1\", Amts.shape[Int(0)])\n",
+ "shape_2 = egraph.let(\"shape_2\", Amts.shape[Int(1)])\n",
+ "shape_3 = egraph.let(\"shape_3\", Amts.shape[Int(2)])\n",
+ "idxs = TupleInt.from_vec(Vec(constant(\"i\", Int), constant(\"j\", Int), constant(\"k\", Int)))\n",
+ "idxed = egraph.let(\"idxed\", Amts[idxs])\n",
+ "amy_idxed = egraph.let(\"amy_idxed\", AMY[idxs])\n",
+ "\n",
+ "egraph.run(array_ruleset.saturate())\n",
+ "print(f\"Amts.shape.length()={egraph.extract(ndim)}\")\n",
+ "print(f\"Amts.shape[0]={egraph.extract(shape_1)}\")\n",
+ "print(f\"Amts.shape[1]={egraph.extract(shape_2)}\")\n",
+ "print(f\"Amts.shape[2]={egraph.extract(shape_3)}\")\n",
+ "print(f\"Amts[i, j, k]={egraph.extract(idxed)}\")\n",
+ "print(f\"AMY[i, j, k]={egraph.extract(amy_idxed)}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e3dfbd1f",
+ "metadata": {},
+ "source": [
+ "We can see that Amts is equal to AMY, since they have the same shape and indexing them produces the same result.\n",
+ "\n",
+ "With some basic range analysis we could make them simplify to the same expression in the e-graph as well."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5a232786",
+ "metadata": {},
+ "source": [
+ "If we want, we can also see all the intermediate steps to get to the indexed result.\n",
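+ "\n",
+ "As an aside, here is a rough sketch of the range analysis mentioned above, written in the style of the rules earlier in this notebook. It is hypothetical and untested, and is not used anywhere else in the notebook; it only illustrates the idea of tracking a lower bound for an `Int` and using it to decide comparisons:\n",
+ "\n",
+ "```python\n",
+ "# Hypothetical sketch, not run in this notebook.\n",
+ "# lower_bound(x, i) is meant to assert that x >= i.\n",
+ "lower_bound = relation(\"lower_bound\", Int, i64)\n",
+ "\n",
+ "\n",
+ "@array_ruleset.register\n",
+ "def _range_analysis(x: Int, y: Int, i: i64, j: i64):\n",
+ "    # An integer literal is its own lower bound.\n",
+ "    yield rule(eq(x).to(Int(i))).then(lower_bound(x, i))\n",
+ "    # Adding a literal shifts the bound: if x >= i, then x + Int(j) >= i + j.\n",
+ "    yield rule(lower_bound(x, i), eq(y).to(x + Int(j))).then(lower_bound(y, i + j))\n",
+ "    # If x >= i and i >= j, then x < Int(j) must be False.\n",
+ "    yield rewrite(x < Int(j)).to(Boolean(False), lower_bound(x, i), i >= j)\n",
+ "```\n",
+ "\n",
+ "With `lower_bound` facts asserted for the index constants `i`, `j`, and `k`, the guard `((i + Int(2)) < Int(2))` could then rewrite to `Boolean(False)`, collapsing the `if_bool` above to the `AMY` branch."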
+ ] + }, + { + "cell_type": "markdown", + "id": "326942be", + "metadata": {}, + "source": [] + }, + { + "cell_type": "markdown", + "id": "a56c640a", + "metadata": {}, + "source": [] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "c7b757ff", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "take(\n", + " Int(2),\n", + " drop(\n", + " Int(2), cat(NDArray.from_memory(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4))), RAMY), NDArray.from_memory(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4))), AMY))\n", + " ),\n", + ")[TupleInt.from_vec(Vec[Int](i, j, k))] \n", + "\n", + "_NDArray_1 = NDArray.from_memory(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4))), RAMY)\n", + "_NDArray_2 = NDArray.from_memory(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4))), AMY)\n", + "_NDArray_3 = NDArray(\n", + " TupleInt(_NDArray_1.shape.length, lambda i: (i == Int(0)).if_bool(_NDArray_1.shape[Int(0)] + _NDArray_2.shape[Int(0)], _NDArray_1.shape[i])),\n", + " lambda idx: (idx[Int(0)] < _NDArray_1.shape[Int(0)]).if_bool(\n", + " _NDArray_1[idx], _NDArray_2[TupleInt(_NDArray_2.shape.length, lambda i: (i == Int(0)).if_bool(idx[Int(0)] - _NDArray_1.shape[Int(0)], idx[i]))]\n", + " ),\n", + ")\n", + "_NDArray_4 = NDArray(\n", + " TupleInt(_NDArray_3.shape.length, lambda i: (i == Int(0)).if_bool(_NDArray_3.shape[Int(0)] - Int(2), _NDArray_3.shape[i])),\n", + " lambda idx: _NDArray_3[TupleInt(_NDArray_3.shape.length, lambda i: (i == Int(0)).if_bool(idx[Int(0)] + Int(2), idx[i]))],\n", + ")\n", + "NDArray(TupleInt(_NDArray_4.shape.length, lambda i: (i == Int(0)).if_bool(Int(2), _NDArray_4.shape[i])), lambda idx: _NDArray_4[idx])[TupleInt.from_vec(Vec[Int](i, j, k))] \n", + "\n", + "_TupleInt_1 = TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4)))\n", + "_TupleInt_2 = TupleInt(\n", + " _TupleInt_1.length,\n", + " lambda i: (i == Int(0)).if_bool(\n", + " NDArray.from_memory(_TupleInt_1, RAMY).shape[Int(0)] + NDArray.from_memory(_TupleInt_1, AMY).shape[Int(0)], NDArray.from_memory(_TupleInt_1, RAMY).shape[i]\n", + " ),\n", + ")\n", + "_NDArray_1 = NDArray(\n", + " _TupleInt_2,\n", + " lambda idx: (idx[Int(0)] < NDArray.from_memory(_TupleInt_1, RAMY).shape[Int(0)]).if_bool(\n", + " NDArray.from_memory(_TupleInt_1, RAMY)[idx],\n", + " NDArray.from_memory(_TupleInt_1, AMY)[\n", + " TupleInt(\n", + " NDArray.from_memory(_TupleInt_1, AMY).shape.length, lambda i: (i == Int(0)).if_bool(idx[Int(0)] - NDArray.from_memory(_TupleInt_1, RAMY).shape[Int(0)], idx[i])\n", + " )\n", + " ],\n", + " ),\n", + ")\n", + "(lambda arr, idx: arr[idx])(\n", + " NDArray(\n", + " TupleInt(_TupleInt_2.length, lambda i: (i == Int(0)).if_bool(_NDArray_1.shape[Int(0)] - Int(2), _NDArray_1.shape[i])),\n", + " lambda idx: _NDArray_1[TupleInt(_NDArray_1.shape.length, lambda i: (i == Int(0)).if_bool(idx[Int(0)] + Int(2), idx[i]))],\n", + " ),\n", + " TupleInt.from_vec(Vec[Int](i, j, k)),\n", + ") \n", + "\n", + "_NDArray_1 = NDArray.from_memory(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4))), RAMY)\n", + "_NDArray_2 = NDArray.from_memory(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4))), AMY)\n", + "_NDArray_3 = NDArray(\n", + " TupleInt(Int(3), lambda i: (i == Int(0)).if_bool(_NDArray_1.shape[Int(0)] + _NDArray_2.shape[Int(0)], _NDArray_1.shape[i])),\n", + " lambda idx: (idx[Int(0)] < _NDArray_1.shape[Int(0)]).if_bool(\n", + " _NDArray_1[idx], _NDArray_2[TupleInt(_NDArray_2.shape.length, lambda i: (i == Int(0)).if_bool(idx[Int(0)] - _NDArray_1.shape[Int(0)], idx[i]))]\n", + " 
),\n", + ")\n", + "NDArray(\n", + " TupleInt(Int(3), lambda i: (i == Int(0)).if_bool(_NDArray_3.shape[Int(0)] - Int(2), _NDArray_3.shape[i])),\n", + " lambda idx: _NDArray_3[TupleInt(_NDArray_3.shape.length, lambda i: (i == Int(0)).if_bool(idx[Int(0)] + Int(2), idx[i]))],\n", + ")[TupleInt.from_vec(Vec[Int](i, j, k))] \n", + "\n", + "_NDArray_1 = NDArray.from_memory(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4))), RAMY)\n", + "_NDArray_2 = NDArray.from_memory(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4))), AMY)\n", + "(lambda arr, x, idx: arr[TupleInt(arr.shape.length, lambda i: (i == Int(0)).if_bool(idx[Int(0)] + x, idx[i]))])(\n", + " NDArray(\n", + " TupleInt(Int(3), lambda i: (i == Int(0)).if_bool(_NDArray_1.shape[Int(0)] + _NDArray_2.shape[Int(0)], _NDArray_1.shape[i])),\n", + " lambda idx: (idx[Int(0)] < _NDArray_1.shape[Int(0)]).if_bool(\n", + " _NDArray_1[idx], _NDArray_2[TupleInt(_NDArray_2.shape.length, lambda i: (i == Int(0)).if_bool(idx[Int(0)] - _NDArray_1.shape[Int(0)], idx[i]))]\n", + " ),\n", + " ),\n", + " Int(2),\n", + " TupleInt.from_vec(Vec[Int](i, j, k)),\n", + ") \n", + "\n", + "_NDArray_1 = NDArray.from_memory(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4))), RAMY)\n", + "_NDArray_2 = NDArray.from_memory(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4))), AMY)\n", + "NDArray(\n", + " TupleInt(Int(3), lambda i: (i == Int(0)).if_bool(_NDArray_1.shape[Int(0)] + _NDArray_2.shape[Int(0)], _NDArray_1.shape[i])),\n", + " lambda idx: (idx[Int(0)] < _NDArray_1.shape[Int(0)]).if_bool(\n", + " _NDArray_1[idx], _NDArray_2[TupleInt(_NDArray_2.shape.length, lambda i: (i == Int(0)).if_bool(idx[Int(0)] - _NDArray_1.shape[Int(0)], idx[i]))]\n", + " ),\n", + ")[TupleInt(Int(3), lambda i: (i == Int(0)).if_bool(TupleInt.from_vec(Vec[Int](i, j, k))[Int(0)] + Int(2), TupleInt.from_vec(Vec[Int](i, j, k))[i]))] \n", + "\n", + "(lambda l, r, idx: (idx[Int(0)] < l.shape[Int(0)]).if_bool(l[idx], r[TupleInt(r.shape.length, lambda i: (i == Int(0)).if_bool(idx[Int(0)] - l.shape[Int(0)], idx[i]))]))(\n", + " NDArray.from_memory(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4))), RAMY),\n", + " NDArray.from_memory(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4))), AMY),\n", + " TupleInt(Int(3), lambda i: (i == Int(0)).if_bool(TupleInt.from_vec(Vec[Int](i, j, k))[Int(0)] + Int(2), TupleInt.from_vec(Vec[Int](i, j, k))[i])),\n", + ") \n", + "\n", + "_NDArray_1 = NDArray.from_memory(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4))), RAMY)\n", + "_NDArray_2 = NDArray.from_memory(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4))), AMY)\n", + "NDArray(\n", + " TupleInt(Int(3), lambda i: (i == Int(0)).if_bool(_NDArray_1.shape[Int(0)] + _NDArray_2.shape[Int(0)], _NDArray_1.shape[i])),\n", + " lambda idx: (idx[Int(0)] < _NDArray_1.shape[Int(0)]).if_bool(\n", + " _NDArray_1[idx], _NDArray_2[TupleInt(_NDArray_2.shape.length, lambda i: (i == Int(0)).if_bool(idx[Int(0)] - _NDArray_1.shape[Int(0)], idx[i]))]\n", + " ),\n", + ")[TupleInt(Int(3), lambda i: (i == Int(0)).if_bool(TupleInt.from_vec(Vec[Int](i, j, k))[Int(0)] + Int(2), TupleInt.from_vec(Vec[Int](i, j, k))[i]))] \n", + "\n", + "_NDArray_1 = NDArray.from_memory(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4))), RAMY)\n", + "_NDArray_2 = NDArray.from_memory(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4))), AMY)\n", + "NDArray(\n", + " TupleInt(Int(3), lambda i: (i == Int(0)).if_bool(_NDArray_1.shape[Int(0)] + _NDArray_2.shape[Int(0)], _NDArray_1.shape[i])),\n", + " lambda idx: (idx[Int(0)] < _NDArray_1.shape[Int(0)]).if_bool(\n", + 
" _NDArray_1[idx], _NDArray_2[TupleInt(_NDArray_2.shape.length, lambda i: (i == Int(0)).if_bool(idx[Int(0)] - _NDArray_1.shape[Int(0)], idx[i]))]\n", + " ),\n", + ")[TupleInt(Int(3), lambda i: (i == Int(0)).if_bool(TupleInt.from_vec(Vec[Int](i, j, k))[Int(0)] + Int(2), TupleInt.from_vec(Vec[Int](i, j, k))[i]))] \n", + "\n", + "_NDArray_1 = NDArray.from_memory(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4))), RAMY)\n", + "_NDArray_2 = NDArray.from_memory(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4))), AMY)\n", + "NDArray(\n", + " TupleInt(Int(3), lambda i: (i == Int(0)).if_bool(_NDArray_1.shape[Int(0)] + _NDArray_2.shape[Int(0)], _NDArray_1.shape[i])),\n", + " lambda idx: (idx[Int(0)] < _NDArray_1.shape[Int(0)]).if_bool(\n", + " _NDArray_1[idx], _NDArray_2[TupleInt(_NDArray_2.shape.length, lambda i: (i == Int(0)).if_bool(idx[Int(0)] - _NDArray_1.shape[Int(0)], idx[i]))]\n", + " ),\n", + ")[TupleInt(Int(3), lambda i: (i == Int(0)).if_bool(TupleInt.from_vec(Vec[Int](i, j, k))[Int(0)] + Int(2), TupleInt.from_vec(Vec[Int](i, j, k))[i]))] \n", + "\n", + "_NDArray_1 = NDArray.from_memory(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4))), RAMY)\n", + "_NDArray_2 = NDArray.from_memory(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4))), AMY)\n", + "NDArray(\n", + " TupleInt(Int(3), lambda i: (i == Int(0)).if_bool(_NDArray_1.shape[Int(0)] + _NDArray_2.shape[Int(0)], _NDArray_1.shape[i])),\n", + " lambda idx: (idx[Int(0)] < _NDArray_1.shape[Int(0)]).if_bool(\n", + " _NDArray_1[idx], _NDArray_2[TupleInt(_NDArray_2.shape.length, lambda i: (i == Int(0)).if_bool(idx[Int(0)] - _NDArray_1.shape[Int(0)], idx[i]))]\n", + " ),\n", + ")[TupleInt(Int(3), lambda i: (i == Int(0)).if_bool(TupleInt.from_vec(Vec[Int](i, j, k))[Int(0)] + Int(2), TupleInt.from_vec(Vec[Int](i, j, k))[i]))] \n", + "\n", + "_NDArray_1 = NDArray.from_memory(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4))), RAMY)\n", + "_NDArray_2 = NDArray.from_memory(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4))), AMY)\n", + "NDArray(\n", + " TupleInt(Int(3), lambda i: (i == Int(0)).if_bool(_NDArray_1.shape[Int(0)] + _NDArray_2.shape[Int(0)], _NDArray_1.shape[i])),\n", + " lambda idx: (idx[Int(0)] < _NDArray_1.shape[Int(0)]).if_bool(\n", + " _NDArray_1[idx], _NDArray_2[TupleInt(_NDArray_2.shape.length, lambda i: (i == Int(0)).if_bool(idx[Int(0)] - _NDArray_1.shape[Int(0)], idx[i]))]\n", + " ),\n", + ")[TupleInt(Int(3), lambda i: (i == Int(0)).if_bool(TupleInt.from_vec(Vec[Int](i, j, k))[Int(0)] + Int(2), TupleInt.from_vec(Vec[Int](i, j, k))[i]))] \n", + "\n", + "_NDArray_1 = NDArray.from_memory(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4))), RAMY)\n", + "_NDArray_2 = NDArray.from_memory(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4))), AMY)\n", + "NDArray(\n", + " TupleInt(Int(3), lambda i: (i == Int(0)).if_bool(_NDArray_1.shape[Int(0)] + _NDArray_2.shape[Int(0)], _NDArray_1.shape[i])),\n", + " lambda idx: (idx[Int(0)] < _NDArray_1.shape[Int(0)]).if_bool(\n", + " _NDArray_1[idx], _NDArray_2[TupleInt(_NDArray_2.shape.length, lambda i: (i == Int(0)).if_bool(idx[Int(0)] - _NDArray_1.shape[Int(0)], idx[i]))]\n", + " ),\n", + ")[TupleInt(Int(3), lambda i: (i == Int(0)).if_bool(TupleInt.from_vec(Vec[Int](i, j, k))[Int(0)] + Int(2), TupleInt.from_vec(Vec[Int](i, j, k))[i]))] \n", + "\n", + "_NDArray_1 = NDArray.from_memory(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4))), RAMY)\n", + "_NDArray_2 = NDArray.from_memory(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4))), AMY)\n", + "NDArray(\n", + " TupleInt(Int(3), lambda i: (i == 
Int(0)).if_bool(_NDArray_1.shape[Int(0)] + _NDArray_2.shape[Int(0)], _NDArray_1.shape[i])),\n", + " lambda idx: (idx[Int(0)] < _NDArray_1.shape[Int(0)]).if_bool(\n", + " _NDArray_1[idx], _NDArray_2[TupleInt(_NDArray_2.shape.length, lambda i: (i == Int(0)).if_bool(idx[Int(0)] - _NDArray_1.shape[Int(0)], idx[i]))]\n", + " ),\n", + ")[TupleInt(Int(3), lambda i: (i == Int(0)).if_bool(TupleInt.from_vec(Vec[Int](i, j, k))[Int(0)] + Int(2), TupleInt.from_vec(Vec[Int](i, j, k))[i]))] \n", + "\n", + "((i + Int(2)) < Int(2)).if_bool(RAMY[(((i + Int(2)) * Int(12)) + (j * Int(4))) + k], AMY[((((i + Int(2)) - Int(2)) * Int(12)) + (j * Int(4))) + k]) \n", + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "9da93b4d1d6241819757834a6da521dd", + "version_major": 2, + "version_minor": 1 + }, + "text/plain": [ + "VisualizerWidget(egraphs=['{\"nodes\":{\"primitive-i64-2\":{\"op\":\"2\",\"children\":[],\"eclass\":\"i64-2\",\"cost\":1.0,\"su…" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "egraph = EGraph()\n", + "idxed = egraph.let(\"idxed\", Amts[idxs])\n", + "egraph.saturate(array_ruleset, expr=idxed)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2642b054", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "egglog", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}