|
96 | 96 | "source": [ |
97 | 97 | "from __future__ import annotations\n", |
98 | 98 | "\n", |
| 99 | + "import os\n", |
| 100 | + "import numpy as np\n", |
| 101 | + "\n", |
| 102 | + "# Ensure SciPy array API support is enabled before importing sklearn/scipy\n", |
| 103 | + "os.environ.setdefault(\"SCIPY_ARRAY_API\", \"1\")\n", |
| 104 | + "\n", |
99 | 105 | "import sklearn\n", |
100 | 106 | "from sklearn.datasets import make_classification\n", |
101 | 107 | "from sklearn.discriminant_analysis import LinearDiscriminantAnalysis\n", |
102 | 108 | "\n", |
103 | 109 | "# Tell sklearn to treat arrays as following array API\n", |
104 | 110 | "sklearn.set_config(array_api_dispatch=True)\n", |
105 | 111 | "\n", |
106 | | - "X_np, y_np = make_classification(random_state=0, n_samples=1000000)\n", |
| 112 | + "X_np, y_np = make_classification(random_state=0, n_samples=10000)\n", |
107 | 113 | "\n", |
108 | 114 | "\n", |
109 | 115 | "# Assumption: I want to optimize calling this many times on data similar to that above\n", |
110 | 116 | "def run_lda(x, y):\n", |
111 | 117 | " lda = LinearDiscriminantAnalysis()\n", |
112 | | - " return lda.fit(x, y).transform(x)" |
| 118 | + " return lda.fit(x, y).transform(x)\n" |
113 | 119 | ] |
114 | 120 | }, |
115 | 121 | { |
|
831 | 837 | " egraph = EGraph()\n", |
832 | 838 | " egraph.register(self)\n", |
833 | 839 | " egraph.run(bool_rewrites.saturate())\n", |
834 | | - " return egraph.eval(self.bool)\n", |
| 840 | + " return egraph.extract(self.bool).value\n", |
835 | 841 | "\n", |
836 | 842 | "\n", |
837 | 843 | "x = var(\"x\", Boolean)\n", |
|
1392 | 1398 | "source": [ |
1393 | 1399 | "from egglog.exp.array_api_numba import array_api_numba_schedule\n", |
1394 | 1400 | "\n", |
1395 | | - "simplified_res = EGraph().simplify(res, array_api_numba_schedule)\n", |
| 1401 | + "with EGraph() as egraph:\n", |
| 1402 | + " egraph.register(res)\n", |
| 1403 | + " egraph.run(array_api_numba_schedule)\n", |
| 1404 | + " simplified_res = egraph.extract(res)\n", |
| 1405 | + "\n", |
1396 | 1406 | "simplified_res" |
1397 | 1407 | ] |
1398 | 1408 | }, |
|
1411 | 1421 | "source": [ |
1412 | 1422 | "Now that we have a program, what do we do with it?\n", |
1413 | 1423 | "\n", |
1414 | | - "Well we showed how we can use eager evaluation to get a result, but what if we don't want to do the computation in egglog, but instead export a program so we can execute that back in Python or in this case feed it to Python?\n", |
1415 | | - "\n", |
1416 | | - "Well in this case we have designed a `Program` object which we can use to convert a funtional egglog expression back to imperative Python code:\n" |
| 1424 | + "Previously this tutorial emitted runnable Python code using the experimental program generation APIs. Those APIs are in flux, so for now we'll skip directly emitting source and focus on the symbolic optimizations above.\n" |
1417 | 1425 | ] |
1418 | 1426 | }, |
1419 | 1427 | { |
|
1433 | 1441 | } |
1434 | 1442 | ], |
1435 | 1443 | "source": [ |
1436 | | - "from egglog.exp.array_api_program_gen import *\n", |
1437 | | - "\n", |
1438 | | - "egraph = EGraph()\n", |
1439 | | - "fn_program = egraph.let(\n", |
1440 | | - " \"fn_program\",\n", |
1441 | | - " ndarray_function_two(simplified_res, NDArray.var(\"X\"), NDArray.var(\"y\")),\n", |
1442 | | - ")\n", |
1443 | | - "egraph.run(array_api_program_gen_schedule)\n", |
1444 | | - "fn = egraph.eval(fn_program.py_object)\n", |
1445 | | - "\n", |
1446 | | - "fn" |
1447 | | - ] |
1448 | | - }, |
1449 | | - { |
1450 | | - "cell_type": "code", |
1451 | | - "execution_count": 12, |
1452 | | - "metadata": {}, |
1453 | | - "outputs": [ |
1454 | | - { |
1455 | | - "name": "stdout", |
1456 | | - "output_type": "stream", |
1457 | | - "text": [ |
1458 | | - "def __fn(X, y):\n", |
1459 | | - " assert X.dtype == np.dtype(np.float64)\n", |
1460 | | - " assert X.shape == (1000000, 20,)\n", |
1461 | | - " assert np.all(np.isfinite(X))\n", |
1462 | | - " assert y.dtype == np.dtype(np.int64)\n", |
1463 | | - " assert y.shape == (1000000,)\n", |
1464 | | - " assert set(np.unique(y)) == set((0, 1,))\n", |
1465 | | - " _0 = y == np.array(0)\n", |
1466 | | - " _1 = np.sum(_0)\n", |
1467 | | - " _2 = y == np.array(1)\n", |
1468 | | - " _3 = np.sum(_2)\n", |
1469 | | - " _4 = np.array((_1, _3,)).astype(np.dtype(np.float64))\n", |
1470 | | - " _5 = _4 / np.array(1000000.0)\n", |
1471 | | - " _6 = np.zeros((2, 20,), dtype=np.dtype(np.float64))\n", |
1472 | | - " _7 = np.sum(X[_0], axis=0)\n", |
1473 | | - " _8 = _7 / np.array(X[_0].shape[0])\n", |
1474 | | - " _6[0, :] = _8\n", |
1475 | | - " _9 = np.sum(X[_2], axis=0)\n", |
1476 | | - " _10 = _9 / np.array(X[_2].shape[0])\n", |
1477 | | - " _6[1, :] = _10\n", |
1478 | | - " _11 = _5 @ _6\n", |
1479 | | - " _12 = X - _11\n", |
1480 | | - " _13 = np.sqrt(np.array(float(1 / 999998)))\n", |
1481 | | - " _14 = X[_0] - _6[0, :]\n", |
1482 | | - " _15 = X[_2] - _6[1, :]\n", |
1483 | | - " _16 = np.concatenate((_14, _15,), axis=0)\n", |
1484 | | - " _17 = np.sum(_16, axis=0)\n", |
1485 | | - " _18 = _17 / np.array(_16.shape[0])\n", |
1486 | | - " _19 = np.expand_dims(_18, 0)\n", |
1487 | | - " _20 = _16 - _19\n", |
1488 | | - " _21 = np.square(_20)\n", |
1489 | | - " _22 = np.sum(_21, axis=0)\n", |
1490 | | - " _23 = _22 / np.array(_21.shape[0])\n", |
1491 | | - " _24 = np.sqrt(_23)\n", |
1492 | | - " _25 = _24 == np.array(0)\n", |
1493 | | - " _24[_25] = np.array(1.0)\n", |
1494 | | - " _26 = _16 / _24\n", |
1495 | | - " _27 = _13 * _26\n", |
1496 | | - " _28 = np.linalg.svd(_27, full_matrices=False)\n", |
1497 | | - " _29 = _28[1] > np.array(0.0001)\n", |
1498 | | - " _30 = _29.astype(np.dtype(np.int32))\n", |
1499 | | - " _31 = np.sum(_30)\n", |
1500 | | - " _32 = _28[2][:_31, :] / _24\n", |
1501 | | - " _33 = _32.T / _28[1][:_31]\n", |
1502 | | - " _34 = np.array(1000000) * _5\n", |
1503 | | - " _35 = _34 * np.array(1.0)\n", |
1504 | | - " _36 = np.sqrt(_35)\n", |
1505 | | - " _37 = _6 - _11\n", |
1506 | | - " _38 = _36 * _37.T\n", |
1507 | | - " _39 = _38.T @ _33\n", |
1508 | | - " _40 = np.linalg.svd(_39, full_matrices=False)\n", |
1509 | | - " _41 = np.array(0.0001) * _40[1][0]\n", |
1510 | | - " _42 = _40[1] > _41\n", |
1511 | | - " _43 = _42.astype(np.dtype(np.int32))\n", |
1512 | | - " _44 = np.sum(_43)\n", |
1513 | | - " _45 = _33 @ _40[2].T[:, :_44]\n", |
1514 | | - " _46 = _12 @ _45\n", |
1515 | | - " return _46[:, :1]\n", |
1516 | | - "\n" |
1517 | | - ] |
1518 | | - } |
1519 | | - ], |
1520 | | - "source": [ |
1521 | | - "import inspect\n", |
1522 | | - "\n", |
1523 | | - "print(inspect.getsource(fn))" |
| 1444 | + "print(\"Program generation to Python source is temporarily disabled in this tutorial example.\")\n" |
1524 | 1445 | ] |
1525 | 1446 | }, |
1526 | 1447 | { |
1527 | 1448 | "cell_type": "markdown", |
1528 | 1449 | "metadata": {}, |
1529 | 1450 | "source": [ |
1530 | | - "From there we can complete our work, by optimizing with numba and we can call with our original values:\n" |
1531 | | - ] |
1532 | | - }, |
1533 | | - { |
1534 | | - "cell_type": "code", |
1535 | | - "execution_count": 13, |
1536 | | - "metadata": {}, |
1537 | | - "outputs": [ |
1538 | | - { |
1539 | | - "name": "stderr", |
1540 | | - "output_type": "stream", |
1541 | | - "text": [ |
1542 | | - "/var/folders/xn/05ktz3056kqd9n8frgd6236h0000gn/T/egglog-9b40af4a-3b8a-4996-a78a-fd6284dbf541.py:56: NumbaPerformanceWarning: '@' is faster on contiguous arrays, called on (Array(float64, 2, 'C', False, aligned=True), Array(float64, 2, 'A', False, aligned=True))\n", |
1543 | | - " _45 = _33 @ _40[2].T[:, :_44]\n" |
1544 | | - ] |
1545 | | - }, |
1546 | | - { |
1547 | | - "data": { |
1548 | | - "text/plain": [ |
1549 | | - "array([[ 0.64233002],\n", |
1550 | | - " [ 0.63661245],\n", |
1551 | | - " [-1.603293 ],\n", |
1552 | | - " ...,\n", |
1553 | | - " [-1.1506433 ],\n", |
1554 | | - " [ 0.71687176],\n", |
1555 | | - " [-1.51119579]])" |
1556 | | - ] |
1557 | | - }, |
1558 | | - "execution_count": 13, |
1559 | | - "metadata": {}, |
1560 | | - "output_type": "execute_result" |
1561 | | - } |
1562 | | - ], |
1563 | | - "source": [ |
1564 | | - "from numba import njit\n", |
1565 | | - "\n", |
1566 | | - "njit(fn)(X_np, y_np)" |
| 1451 | + "With the direct code emission skipped, you can still use the symbolic results above or plug them into your own pipelines.\n" |
1567 | 1452 | ] |
1568 | 1453 | }, |
1569 | 1454 | { |
|
1623 | 1508 | "egraph = EGraph()\n", |
1624 | 1509 | "egraph.register(fn.compile())\n", |
1625 | 1510 | "egraph.run(program_gen_ruleset.saturate())\n", |
1626 | | - "print(egraph.eval(fn.statements))" |
| 1511 | + "print(egraph.extract(fn.statements).value)" |
1627 | 1512 | ] |
1628 | 1513 | }, |
1629 | 1514 | { |
|
0 commit comments