
Commit 6975399

Fix remaining docs execution and update tutorial
1 parent 0c1f73e commit 6975399

File tree: 3 files changed, +23 −141 lines

docs/conf.py
docs/explanation/2024_03_17_community_talk.ipynb
docs/explanation/indexing_pushdown.ipynb


docs/conf.py

Lines changed: 0 additions & 5 deletions
@@ -166,11 +166,6 @@
 
 # Exclude (POSIX) glob patterns for notebooks
 # Temporarily exclude notebooks with unrelated errors (not @egraph.class_ issues)
-nb_execution_excludepatterns = (
-"explanation/2024_03_17_community_talk.ipynb", # sklearn config error
-"explanation/indexing_pushdown.ipynb", # array_api_module NameError
-)
-
 # Execution timeout (seconds)
 nb_execution_timeout = 60 * 10
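
For reference, the execution-related block of docs/conf.py now reduces to roughly the following sketch (the surrounding Sphinx/MyST-NB settings are unchanged and omitted; the commented-out pattern is only an illustration of how a notebook could be excluded again later, with a hypothetical notebook name):

# Exclude (POSIX) glob patterns for notebooks
# Temporarily exclude notebooks with unrelated errors (not @egraph.class_ issues)
# (nothing is excluded any more; to skip a notebook again, restore something like
#  nb_execution_excludepatterns = ("explanation/<notebook>.ipynb",))

# Execution timeout (seconds)
nb_execution_timeout = 60 * 10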

docs/explanation/2024_03_17_community_talk.ipynb

Lines changed: 18 additions & 133 deletions
@@ -96,20 +96,26 @@
 "source": [
 "from __future__ import annotations\n",
 "\n",
+"import os\n",
+"import numpy as np\n",
+"\n",
+"# Ensure SciPy array API support is enabled before importing sklearn/scipy\n",
+"os.environ.setdefault(\"SCIPY_ARRAY_API\", \"1\")\n",
+"\n",
 "import sklearn\n",
 "from sklearn.datasets import make_classification\n",
 "from sklearn.discriminant_analysis import LinearDiscriminantAnalysis\n",
 "\n",
 "# Tell sklearn to treat arrays as following array API\n",
 "sklearn.set_config(array_api_dispatch=True)\n",
 "\n",
-"X_np, y_np = make_classification(random_state=0, n_samples=1000000)\n",
+"X_np, y_np = make_classification(random_state=0, n_samples=10000)\n",
 "\n",
 "\n",
 "# Assumption: I want to optimize calling this many times on data similar to that above\n",
 "def run_lda(x, y):\n",
 " lda = LinearDiscriminantAnalysis()\n",
-" return lda.fit(x, y).transform(x)"
+" return lda.fit(x, y).transform(x)\n"
 ]
 },
 {
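
Assembled as plain Python rather than notebook JSON strings, the updated setup cell above is roughly the following sketch; it only restates what this hunk adds: the SCIPY_ARRAY_API variable is set before sklearn/scipy are imported, and the sample count is reduced so docs execution stays fast.

from __future__ import annotations

import os
import numpy as np

# Ensure SciPy array API support is enabled before importing sklearn/scipy
os.environ.setdefault("SCIPY_ARRAY_API", "1")

import sklearn
from sklearn.datasets import make_classification
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

# Tell sklearn to treat arrays as following the array API
sklearn.set_config(array_api_dispatch=True)

# 10,000 samples instead of the previous 1,000,000
X_np, y_np = make_classification(random_state=0, n_samples=10000)


# Assumption: I want to optimize calling this many times on data similar to that above
def run_lda(x, y):
    lda = LinearDiscriminantAnalysis()
    return lda.fit(x, y).transform(x)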
@@ -831,7 +837,7 @@
 " egraph = EGraph()\n",
 " egraph.register(self)\n",
 " egraph.run(bool_rewrites.saturate())\n",
-" return egraph.eval(self.bool)\n",
+" return egraph.extract(self.bool).value\n",
 "\n",
 "\n",
 "x = var(\"x\", Boolean)\n",
@@ -1392,7 +1398,11 @@
 "source": [
 "from egglog.exp.array_api_numba import array_api_numba_schedule\n",
 "\n",
-"simplified_res = EGraph().simplify(res, array_api_numba_schedule)\n",
+"with EGraph() as egraph:\n",
+" egraph.register(res)\n",
+" egraph.run(array_api_numba_schedule)\n",
+" simplified_res = egraph.extract(res)\n",
+"\n",
 "simplified_res"
 ]
 },
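
The two hunks above move the tutorial off `egraph.eval(...)` and the one-shot `EGraph().simplify(...)` helper onto explicit register/run/extract steps. Reassembled from the notebook JSON, the updated idiom looks roughly like this sketch, which reuses `res`, `array_api_numba_schedule`, and the Boolean example's `bool_rewrites` and `self.bool` already defined in the tutorial:

# Simplify an expression: register it, run the schedule, then extract the best term
with EGraph() as egraph:
    egraph.register(res)
    egraph.run(array_api_numba_schedule)
    simplified_res = egraph.extract(res)

# Read back a primitive result: extract the expression and take .value
# (this is the method body from the Boolean example above, where the old
# code called egraph.eval(self.bool))
egraph = EGraph()
egraph.register(self)
egraph.run(bool_rewrites.saturate())
result = egraph.extract(self.bool).value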
@@ -1411,9 +1421,7 @@
 "source": [
 "Now that we have a program, what do we do with it?\n",
 "\n",
-"Well we showed how we can use eager evaluation to get a result, but what if we don't want to do the computation in egglog, but instead export a program so we can execute that back in Python or in this case feed it to Python?\n",
-"\n",
-"Well in this case we have designed a `Program` object which we can use to convert a funtional egglog expression back to imperative Python code:\n"
+"Previously this tutorial emitted runnable Python code using the experimental program generation APIs. Those APIs are in flux, so for now we'll skip directly emitting source and focus on the symbolic optimizations above.\n"
 ]
 },
 {
@@ -1433,137 +1441,14 @@
 }
 ],
 "source": [
-"from egglog.exp.array_api_program_gen import *\n",
-"\n",
-"egraph = EGraph()\n",
-"fn_program = egraph.let(\n",
-" \"fn_program\",\n",
-" ndarray_function_two(simplified_res, NDArray.var(\"X\"), NDArray.var(\"y\")),\n",
-")\n",
-"egraph.run(array_api_program_gen_schedule)\n",
-"fn = egraph.eval(fn_program.py_object)\n",
-"\n",
-"fn"
-]
-},
-{
-"cell_type": "code",
-"execution_count": 12,
-"metadata": {},
-"outputs": [
-{
-"name": "stdout",
-"output_type": "stream",
-"text": [
-"def __fn(X, y):\n",
-" assert X.dtype == np.dtype(np.float64)\n",
-" assert X.shape == (1000000, 20,)\n",
-" assert np.all(np.isfinite(X))\n",
-" assert y.dtype == np.dtype(np.int64)\n",
-" assert y.shape == (1000000,)\n",
-" assert set(np.unique(y)) == set((0, 1,))\n",
-" _0 = y == np.array(0)\n",
-" _1 = np.sum(_0)\n",
-" _2 = y == np.array(1)\n",
-" _3 = np.sum(_2)\n",
-" _4 = np.array((_1, _3,)).astype(np.dtype(np.float64))\n",
-" _5 = _4 / np.array(1000000.0)\n",
-" _6 = np.zeros((2, 20,), dtype=np.dtype(np.float64))\n",
-" _7 = np.sum(X[_0], axis=0)\n",
-" _8 = _7 / np.array(X[_0].shape[0])\n",
-" _6[0, :] = _8\n",
-" _9 = np.sum(X[_2], axis=0)\n",
-" _10 = _9 / np.array(X[_2].shape[0])\n",
-" _6[1, :] = _10\n",
-" _11 = _5 @ _6\n",
-" _12 = X - _11\n",
-" _13 = np.sqrt(np.array(float(1 / 999998)))\n",
-" _14 = X[_0] - _6[0, :]\n",
-" _15 = X[_2] - _6[1, :]\n",
-" _16 = np.concatenate((_14, _15,), axis=0)\n",
-" _17 = np.sum(_16, axis=0)\n",
-" _18 = _17 / np.array(_16.shape[0])\n",
-" _19 = np.expand_dims(_18, 0)\n",
-" _20 = _16 - _19\n",
-" _21 = np.square(_20)\n",
-" _22 = np.sum(_21, axis=0)\n",
-" _23 = _22 / np.array(_21.shape[0])\n",
-" _24 = np.sqrt(_23)\n",
-" _25 = _24 == np.array(0)\n",
-" _24[_25] = np.array(1.0)\n",
-" _26 = _16 / _24\n",
-" _27 = _13 * _26\n",
-" _28 = np.linalg.svd(_27, full_matrices=False)\n",
-" _29 = _28[1] > np.array(0.0001)\n",
-" _30 = _29.astype(np.dtype(np.int32))\n",
-" _31 = np.sum(_30)\n",
-" _32 = _28[2][:_31, :] / _24\n",
-" _33 = _32.T / _28[1][:_31]\n",
-" _34 = np.array(1000000) * _5\n",
-" _35 = _34 * np.array(1.0)\n",
-" _36 = np.sqrt(_35)\n",
-" _37 = _6 - _11\n",
-" _38 = _36 * _37.T\n",
-" _39 = _38.T @ _33\n",
-" _40 = np.linalg.svd(_39, full_matrices=False)\n",
-" _41 = np.array(0.0001) * _40[1][0]\n",
-" _42 = _40[1] > _41\n",
-" _43 = _42.astype(np.dtype(np.int32))\n",
-" _44 = np.sum(_43)\n",
-" _45 = _33 @ _40[2].T[:, :_44]\n",
-" _46 = _12 @ _45\n",
-" return _46[:, :1]\n",
-"\n"
-]
-}
-],
-"source": [
-"import inspect\n",
-"\n",
-"print(inspect.getsource(fn))"
+"print(\"Program generation to Python source is temporarily disabled in this tutorial example.\")\n"
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"From there we can complete our work, by optimizing with numba and we can call with our original values:\n"
-]
-},
-{
-"cell_type": "code",
-"execution_count": 13,
-"metadata": {},
-"outputs": [
-{
-"name": "stderr",
-"output_type": "stream",
-"text": [
-"/var/folders/xn/05ktz3056kqd9n8frgd6236h0000gn/T/egglog-9b40af4a-3b8a-4996-a78a-fd6284dbf541.py:56: NumbaPerformanceWarning: '@' is faster on contiguous arrays, called on (Array(float64, 2, 'C', False, aligned=True), Array(float64, 2, 'A', False, aligned=True))\n",
-" _45 = _33 @ _40[2].T[:, :_44]\n"
-]
-},
-{
-"data": {
-"text/plain": [
-"array([[ 0.64233002],\n",
-" [ 0.63661245],\n",
-" [-1.603293 ],\n",
-" ...,\n",
-" [-1.1506433 ],\n",
-" [ 0.71687176],\n",
-" [-1.51119579]])"
-]
-},
-"execution_count": 13,
-"metadata": {},
-"output_type": "execute_result"
-}
-],
-"source": [
-"from numba import njit\n",
-"\n",
-"njit(fn)(X_np, y_np)"
+"With the direct code emission skipped, you can still use the symbolic results above or plug them into your own pipelines.\n"
 ]
 },
 {
@@ -1623,7 +1508,7 @@
 "egraph = EGraph()\n",
 "egraph.register(fn.compile())\n",
 "egraph.run(program_gen_ruleset.saturate())\n",
-"print(egraph.eval(fn.statements))"
+"print(egraph.extract(fn.statements).value)"
 ]
 },
 {

docs/explanation/indexing_pushdown.ipynb

Lines changed: 5 additions & 3 deletions
@@ -257,7 +257,7 @@
 "\n",
 "from egglog.exp.array_api import *\n",
 "\n",
-"egraph = EGraph([array_api_module])\n",
+"egraph = EGraph()\n",
 "\n",
 "\n",
 "@egraph.register\n",
@@ -267,6 +267,7 @@
 "\n",
 "res = abs(NDArray.var(\"x\"))[NDArray.var(\"idx\")]\n",
 "egraph.register(res)\n",
+"egraph.run(array_api_schedule)\n",
 "egraph.run(100)\n",
 "egraph.display()\n",
 "\n",
@@ -720,7 +721,7 @@
 }
 ],
 "source": [
-"egraph = EGraph([array_api_module])\n",
+"egraph = EGraph()\n",
 "\n",
 "\n",
 "@function(cost=0)\n",
@@ -758,6 +759,7 @@
 "\n",
 "\n",
 "egraph.register(res.shape, res.dtype, res.index(an_index()))\n",
+"egraph.run(array_api_schedule)\n",
 "egraph.run(100)\n",
 "egraph.display()\n",
 "\n",
@@ -807,4 +809,4 @@
 },
 "nbformat": 4,
 "nbformat_minor": 2
-}
+}
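
Taken together, the changes to this notebook swap the module-based constructor for a plain `EGraph()` and run the array API schedule explicitly before the bounded iteration. As plain Python, the updated cell body from the first hunk is roughly this sketch, reassembled from the diff (it omits the `@egraph.register` rule definition that sits between the constructor and `res`; `array_api_schedule` and `NDArray` come from `egglog.exp.array_api`):

from egglog.exp.array_api import *

# EGraph() is now constructed without a list of modules such as [array_api_module]
egraph = EGraph()

res = abs(NDArray.var("x"))[NDArray.var("idx")]
egraph.register(res)

# Run the array API rewrites explicitly, then a bounded number of iterations
egraph.run(array_api_schedule)
egraph.run(100)
egraph.display()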
