Skip to content

Commit 1b6abfc

Browse files
first go
1 parent b9f201c commit 1b6abfc

File tree

3 files changed

+263
-18
lines changed

3 files changed

+263
-18
lines changed

Cargo.lock

Lines changed: 9 additions & 9 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@ crate-type = ["cdylib"]
1313
pyo3 = { version = "0.27", features = ["extension-module", "num-bigint", "num-rational"] }
1414
num-bigint = "*"
1515
num-rational = "*"
16-
egglog = { git = "https://github.com/egraphs-good/egglog.git", branch = "main", default-features = false }
17-
egglog-bridge = { git = "https://github.com/egraphs-good/egglog.git", branch = "main" }
18-
egglog-core-relations = { git = "https://github.com/egraphs-good/egglog.git", branch = "main" }
16+
egglog = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-fn-bug", default-features = false }
17+
egglog-bridge = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-fn-bug" }
18+
egglog-core-relations = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-fn-bug" }
1919
egglog-experimental = { git = "https://github.com/egraphs-good/egglog-experimental", branch = "main", default-features = false }
20-
egglog-ast = { git = "https://github.com/egraphs-good/egglog.git", branch = "main" }
21-
egglog-reports = { git = "https://github.com/egraphs-good/egglog.git", branch = "main" }
20+
egglog-ast = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-fn-bug" }
21+
egglog-reports = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-fn-bug" }
2222
egraph-serialize = { version = "0.3", features = ["serde", "graphviz"] }
2323
serde_json = "1"
2424
pyo3-log = "*"
@@ -31,10 +31,11 @@ base64 = "0.22.1"
3131

3232
# Use patched version of egglog in experimental
3333
[patch.'https://github.com/egraphs-good/egglog']
34-
# egglog = { git = "https://github.com/egraphs-good/egglog.git", branch = "main" }
35-
# egglog-bridge = { git = "https://github.com/egraphs-good/egglog.git", branch = "main" }
36-
# egglog-core-relations = { git = "https://github.com/egraphs-good/egglog.git", branch = "main" }
37-
# egglog-ast = { git = "https://github.com/egraphs-good/egglog.git", branch = "main" }
34+
egglog = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-fn-bug" }
35+
egglog-ast = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-fn-bug" }
36+
egglog-core-relations = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-fn-bug" }
37+
egglog-bridge = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-fn-bug" }
38+
egglog-reports = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-fn-bug" }
3839

3940
# enable debug symbols for easier profiling
4041
[profile.release]

docs/explanation/Untitled-1.ipynb

Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,244 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"id": "1b715c58",
7+
"metadata": {},
8+
"outputs": [],
9+
"source": [
10+
"from __future__ import annotations\n",
11+
"from collections.abc import Callable\n",
12+
"from egglog import *\n",
13+
"\n",
14+
"array_ruleset = ruleset(name=\"array_ruleset\")\n",
15+
"\n",
16+
"\n",
17+
"class Boolean(Expr):\n",
18+
" def __init__(self, val: BoolLike) -> None: ...\n",
19+
" def if_bool(self, then: Int, else_: Int) -> Int: ...\n",
20+
"\n",
21+
"\n",
22+
"class Int(Expr):\n",
23+
" @classmethod\n",
24+
" def var(cls, name: StringLike) -> Int: ...\n",
25+
"\n",
26+
" def __init__(self, val: i64Like) -> None: ...\n",
27+
" def __eq__(self, other: Int) -> Boolean: ... # type: ignore[override]\n",
28+
" def __lt__(self, other: Int) -> Boolean: ...\n",
29+
" def __add__(self, other: Int) -> Int: ...\n",
30+
" def __sub__(self, other: Int) -> Int: ...\n",
31+
"\n",
32+
"\n",
33+
"@array_ruleset.register\n",
34+
"def _int(i: i64, j: i64, x: Int, y: Int):\n",
35+
" yield rewrite(Int(i) + Int(j)).to(Int(i + j))\n",
36+
" yield rewrite(Int(i) - Int(j)).to(Int(i - j))\n",
37+
" yield rewrite(Int(i) == Int(i)).to(Boolean(True))\n",
38+
" yield rewrite(Int(i) == Int(j)).to(Boolean(False), i != j)\n",
39+
" yield rewrite(Int(i) < Int(j)).to(Boolean(True), i < j)\n",
40+
" yield rewrite(Int(i) < Int(j)).to(Boolean(False), i >= j)\n",
41+
" yield rewrite(Boolean(True).if_bool(x, y)).to(x)\n",
42+
" yield rewrite(Boolean(False).if_bool(x, y)).to(y)\n",
43+
"\n",
44+
"\n",
45+
"@function\n",
46+
"def vec_index(vec: Vec[Int], index: Int) -> Int: ...\n",
47+
"\n",
48+
"\n",
49+
"@array_ruleset.register\n",
50+
"def _vec_index(i: i64, xs: Vec[Int]):\n",
51+
" yield rewrite(vec_index(xs, Int(i))).to(xs[i])\n",
52+
"\n",
53+
"\n",
54+
"class TupleInt(Expr, ruleset=array_ruleset):\n",
55+
" def __init__(self, length: Int, getitem_fn: Callable[[Int], Int]) -> None: ...\n",
56+
" def __getitem__(self, index: Int) -> Int: ...\n",
57+
"\n",
58+
" @property\n",
59+
" def length(self) -> Int: ...\n",
60+
"\n",
61+
" @classmethod\n",
62+
" def from_vec(cls, xs: Vec[Int]) -> TupleInt:\n",
63+
" return TupleInt(\n",
64+
" Int(xs.length()),\n",
65+
" lambda i: vec_index(xs, i),\n",
66+
" )\n",
67+
"\n",
68+
"\n",
69+
"@array_ruleset.register\n",
70+
"def _tuple_int(l: Int, fn: Callable[[Int], Int], i: Int):\n",
71+
" ti = TupleInt(l, fn)\n",
72+
" yield rewrite(ti.length).to(l)\n",
73+
" yield rewrite(ti[i]).to(fn(i))\n",
74+
"\n",
75+
"\n",
76+
"class NDArray(Expr, ruleset=array_ruleset):\n",
77+
" def __init__(self, shape: TupleInt, idx_fn: Callable[[TupleInt], Int]) -> None: ...\n",
78+
"\n",
79+
" @classmethod\n",
80+
" def from_vec(cls, values: Vec[Int]) -> NDArray:\n",
81+
" return NDArray(\n",
82+
" TupleInt(Int(1), lambda i: Int(values.length())),\n",
83+
" lambda idx: vec_index(values, idx[Int(0)]),\n",
84+
" )\n",
85+
"\n",
86+
" def with_shape(self, shape: TupleInt) -> NDArray:\n",
87+
" return NDArray(shape, self.__getitem__)\n",
88+
"\n",
89+
" @classmethod\n",
90+
" def var(cls, name: StringLike) -> NDArray: ...\n",
91+
"\n",
92+
" @property\n",
93+
" def shape(self) -> TupleInt: ...\n",
94+
"\n",
95+
" def __getitem__(self, index: TupleInt) -> Int: ...\n",
96+
"\n",
97+
"\n",
98+
"@array_ruleset.register\n",
99+
"def _ndarray(shape: TupleInt, fn: Callable[[TupleInt], Int], idx: TupleInt):\n",
100+
" nda = NDArray(shape, fn)\n",
101+
" yield rewrite(nda.shape).to(shape)\n",
102+
" yield rewrite(nda[idx]).to(fn(idx))\n",
103+
"\n",
104+
"\n",
105+
"@function(subsume=True, ruleset=array_ruleset)\n",
106+
"def cat(l: NDArray, r: NDArray) -> NDArray:\n",
107+
" \"\"\"\n",
108+
" Returns the concatenation of two arrays, they should have the same shape and the first dimension is added.\n",
109+
" \"\"\"\n",
110+
" return NDArray(\n",
111+
" TupleInt(\n",
112+
" l.shape.length,\n",
113+
" lambda i: (i == Int(0)).if_bool(l.shape[Int(0)] + r.shape[Int(0)], l.shape[i]),\n",
114+
" ),\n",
115+
" lambda idx: (idx[Int(0)] < l.shape[Int(0)]).if_bool(\n",
116+
" l[idx], r[TupleInt(r.shape.length, lambda i: (i == Int(0)).if_bool(idx[Int(0)] - l.shape[Int(0)], idx[i]))]\n",
117+
" ),\n",
118+
" )\n",
119+
"\n",
120+
"\n",
121+
"@function(subsume=True, ruleset=array_ruleset)\n",
122+
"def drop(x: Int, arr: NDArray) -> NDArray:\n",
123+
" \"\"\"\n",
124+
" Drops the first `x` elements off the front of the array `arr`.\n",
125+
" \"\"\"\n",
126+
" return NDArray(\n",
127+
" TupleInt(\n",
128+
" arr.shape.length,\n",
129+
" lambda i: (i == Int(0)).if_bool(arr.shape[Int(0)] - x, arr.shape[i]),\n",
130+
" ),\n",
131+
" lambda idx: arr[\n",
132+
" TupleInt(\n",
133+
" arr.shape.length,\n",
134+
" # Add x to the first index, so it skips the first x elements\n",
135+
" lambda i: (i == Int(0)).if_bool(idx[Int(0)] + x, idx[i]),\n",
136+
" )\n",
137+
" ],\n",
138+
" )\n",
139+
"\n",
140+
"\n",
141+
"@function(subsume=True, ruleset=array_ruleset)\n",
142+
"def take(x: Int, arr: NDArray) -> NDArray:\n",
143+
" \"\"\"\n",
144+
" Takes the first `x` elements off the front of the array `arr`.\n",
145+
" \"\"\"\n",
146+
" return NDArray(\n",
147+
" TupleInt(\n",
148+
" arr.shape.length,\n",
149+
" lambda i: (i == Int(0)).if_bool(x, arr.shape[i]),\n",
150+
" ),\n",
151+
" lambda idx: arr[idx],\n",
152+
" )"
153+
]
154+
},
155+
{
156+
"cell_type": "code",
157+
"execution_count": null,
158+
"id": "1ada95b6",
159+
"metadata": {},
160+
"outputs": [
161+
{
162+
"name": "stdout",
163+
"output_type": "stream",
164+
"text": [
165+
"Amts.shape.length()=Int(3)\n",
166+
"Amts.shape[0]=Int(2)\n",
167+
"Amts.shape[1]=Int(3)\n",
168+
"Amts.shape[2]=Int(4)\n",
169+
"Amts[i, j, k]=\n",
170+
"_TupleInt_1 = TupleInt(\n",
171+
" Int(3),\n",
172+
" lambda i: (i == Int(0)).if_bool(\n",
173+
" TupleInt.from_vec(Vec[Int](Int.var(\"i\"), Int.var(\"j\"), Int.var(\"k\")))[Int(0)] + Int(2), TupleInt.from_vec(Vec[Int](Int.var(\"i\"), Int.var(\"j\"), Int.var(\"k\")))[i]\n",
174+
" ),\n",
175+
")\n",
176+
"((Int.var(\"i\") + Int(2)) < Int(2)).if_bool(\n",
177+
" NDArray.var(\"RAMY\")[_TupleInt_1],\n",
178+
" NDArray.var(\"AMY\")[\n",
179+
" TupleInt(\n",
180+
" Int(3),\n",
181+
" lambda i: (i == Int(0)).if_bool(\n",
182+
" _TupleInt_1[Int(0)] - NDArray.var(\"RAMY\").with_shape(TupleInt.from_vec(Vec[Int](Int(2), Int(3), Int(4)))).shape[Int(0)], _TupleInt_1[i]\n",
183+
" ),\n",
184+
" )\n",
185+
" ],\n",
186+
")\n"
187+
]
188+
}
189+
],
190+
"source": [
191+
"shape = TupleInt.from_vec(Vec(Int(2), Int(3), Int(4)))\n",
192+
"RAMY = NDArray.var(\"RAMY\").with_shape(shape)\n",
193+
"AMY = NDArray.var(\"AMY\").with_shape(shape)\n",
194+
"\n",
195+
"\n",
196+
"egraph = EGraph()\n",
197+
"\n",
198+
"Amts = egraph.let(\"Amts\", take(Int(2), drop(Int(2), cat(RAMY, AMY))))\n",
199+
"\n",
200+
"ndim = egraph.let(\"ndim\", Amts.shape.length)\n",
201+
"shape_1 = egraph.let(\"shape_1\", Amts.shape[Int(0)])\n",
202+
"shape_2 = egraph.let(\"shape_2\", Amts.shape[Int(1)])\n",
203+
"shape_3 = egraph.let(\"shape_3\", Amts.shape[Int(2)])\n",
204+
"idxed = egraph.let(\"idxed\", Amts[TupleInt.from_vec(Vec(Int.var(\"i\"), Int.var(\"j\"), Int.var(\"k\")))])\n",
205+
"\n",
206+
"egraph.run(array_ruleset.saturate())\n",
207+
"print(f\"Amts.shape.length()={egraph.extract(ndim)}\")\n",
208+
"print(f\"Amts.shape[0]={egraph.extract(shape_1)}\")\n",
209+
"print(f\"Amts.shape[1]={egraph.extract(shape_2)}\")\n",
210+
"print(f\"Amts.shape[2]={egraph.extract(shape_3)}\")\n",
211+
"print(f\"Amts[i, j, k]=\\n{egraph.extract(idxed)}\")"
212+
]
213+
},
214+
{
215+
"cell_type": "code",
216+
"execution_count": null,
217+
"id": "e3dfbd1f",
218+
"metadata": {},
219+
"outputs": [],
220+
"source": []
221+
}
222+
],
223+
"metadata": {
224+
"kernelspec": {
225+
"display_name": "egglog",
226+
"language": "python",
227+
"name": "python3"
228+
},
229+
"language_info": {
230+
"codemirror_mode": {
231+
"name": "ipython",
232+
"version": 3
233+
},
234+
"file_extension": ".py",
235+
"mimetype": "text/x-python",
236+
"name": "python",
237+
"nbconvert_exporter": "python",
238+
"pygments_lexer": "ipython3",
239+
"version": "3.13.3"
240+
}
241+
},
242+
"nbformat": 4,
243+
"nbformat_minor": 5
244+
}

0 commit comments

Comments
 (0)