Skip to content

Commit e81e641

Browse files
Merge pull request #257 from egraphs-good/fix-loopnest
Adds working loopnest example
2 parents 3829d69 + 61b0024 commit e81e641

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+1857
-866
lines changed

.pre-commit-config.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,8 @@ repos:
66
- id: ruff
77
args: [--fix, --exit-non-zero-on-fix]
88
- id: ruff-format
9+
- repo: https://github.com/astral-sh/uv-pre-commit
10+
# uv version.
11+
rev: 0.5.31
12+
hooks:
13+
- id: uv-lock

docs/changelog.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,11 @@ _This project uses semantic versioning_
99
- Add better error message when using @function in class (thanks @shinawy)
1010
- Add error method if `@method` decorator is in wrong place
1111
- Subsumes lambda functions after replacing
12-
- Add working loopnest test
12+
- Add working loopnest test and rewrite array api suport to be more general
13+
- Improve tracebacks on failing conversions.
14+
- Use `add_note` for exception to add more context, instead of raising a new exception, to make it easier to debug.
15+
- Add conversions from generic types to be supported at runtime and typing level (so can go from `(1, 2, 3)` to `TupleInt`)
16+
- Open files with webbrowser instead of internal graphviz util for better support
1317

1418
## 8.0.1 (2024-10-24)
1519

docs/how-to-guides.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ file_format: mystnb
99
You can provide your program in a special DSL language. You can parse this with {meth}`egglog.bindings.parse_program` and then run the result with You can parse this with {meth}`egglog.bindings.EGraph.run_program`::
1010

1111
```{code-cell}
12-
from egglog.bindings import EGraph
12+
from egglog.bindings import EGraph, parse_program
1313
1414
egraph = EGraph()
1515
commands = parse_program("(check (= (+ 1 2) 3))")

pyproject.toml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ build-backend = "maturin"
66
name = "egglog"
77
description = "e-graphs in Python built around the the egglog rust library"
88
readme = "README.md"
9-
dynamic = ["version"]
9+
dynamic = ["version"]
1010
license = { text = "MIT" }
1111
requires-python = ">=3.10"
1212
classifiers = [
@@ -34,8 +34,9 @@ dependencies = ["typing-extensions", "black", "graphviz", "anywidget"]
3434
array = [
3535
"scikit-learn",
3636
"array_api_compat",
37-
"numba==0.59.1",
38-
"llvmlite==0.42.0",
37+
"numba>=0.59.1",
38+
"llvmlite>=0.42.0",
39+
"numpy>2",
3940
]
4041
dev = [
4142
"ruff",
@@ -216,7 +217,7 @@ preview = true
216217

217218
[tool.ruff.lint.per-file-ignores]
218219
# Don't require annotations for tests
219-
"python/tests/**" = ["ANN001", "ANN201"]
220+
"python/tests/**" = ["ANN001", "ANN201", "INP001"]
220221

221222
[tool.mypy]
222223
ignore_missing_imports = true
@@ -239,7 +240,6 @@ python_files = ["test_*.py", "test.py"]
239240
markers = ["slow: marks tests as slow (deselect with '-m \"not slow\"')"]
240241
norecursedirs = ["__snapshots__"]
241242
filterwarnings = [
242-
"error",
243243
"ignore::numba.core.errors.NumbaPerformanceWarning",
244244
"ignore::pytest_benchmark.logger.PytestBenchmarkWarning",
245245
# https://github.com/manzt/anywidget/blob/d38bb3f5f9cfc7e49e2ff1aa1ba994d66327cb02/pyproject.toml#L120

python/egglog/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from . import config, ipython_magic # noqa: F401
66
from .builtins import * # noqa: UP029
7-
from .conversion import convert, converter # noqa: F401
7+
from .conversion import ConvertError, convert, converter, get_type_args # noqa: F401
88
from .egraph import *
99

1010
del ipython_magic

python/egglog/builtins.py

Lines changed: 56 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@
55

66
from __future__ import annotations
77

8-
from functools import partial
8+
from functools import partial, reduce
99
from types import FunctionType
1010
from typing import TYPE_CHECKING, Generic, Protocol, TypeAlias, TypeVar, Union, cast, overload
1111

1212
from typing_extensions import TypeVarTuple, Unpack
1313

14-
from .conversion import converter, get_type_args
14+
from .conversion import convert, converter, get_type_args
1515
from .egraph import Expr, Unit, function, get_current_ruleset, method
1616
from .functionalize import functionalize
1717
from .runtime import RuntimeClass, RuntimeExpr, RuntimeFunction
@@ -22,24 +22,27 @@
2222

2323

2424
__all__ = [
25-
"i64",
26-
"i64Like",
27-
"f64",
28-
"f64Like",
2925
"Bool",
3026
"BoolLike",
31-
"String",
32-
"StringLike",
3327
"Map",
28+
"MapLike",
29+
"PyObject",
3430
"Rational",
3531
"Set",
32+
"SetLike",
33+
"String",
34+
"StringLike",
35+
"UnstableFn",
3636
"Vec",
37+
"VecLike",
38+
"f64",
39+
"f64Like",
40+
"i64",
41+
"i64Like",
3742
"join",
38-
"PyObject",
3943
"py_eval",
40-
"py_exec",
4144
"py_eval_fn",
42-
"UnstableFn",
45+
"py_exec",
4346
]
4447

4548

@@ -210,6 +213,9 @@ def __truediv__(self, other: f64Like) -> f64: ...
210213
@method(egg_fn="%")
211214
def __mod__(self, other: f64Like) -> f64: ...
212215

216+
@method(egg_fn="^")
217+
def __pow__(self, other: f64Like) -> f64: ...
218+
213219
def __radd__(self, other: f64Like) -> f64: ...
214220

215221
def __rsub__(self, other: f64Like) -> f64: ...
@@ -282,6 +288,22 @@ def remove(self, key: T) -> Map[T, V]: ...
282288
def rebuild(self) -> Map[T, V]: ...
283289

284290

291+
TO = TypeVar("TO")
292+
VO = TypeVar("VO")
293+
294+
converter(
295+
dict,
296+
Map,
297+
lambda t: reduce(
298+
(lambda acc, kv: acc.insert(convert(kv[0], get_type_args()[0]), convert(kv[1], get_type_args()[1]))),
299+
t.items(),
300+
Map[get_type_args()].empty(), # type: ignore[misc]
301+
),
302+
)
303+
304+
MapLike: TypeAlias = Map[T, V] | dict[TO, VO]
305+
306+
285307
class Set(Expr, Generic[T], builtin=True):
286308
@method(egg_fn="set-of")
287309
def __init__(self, *args: T) -> None: ...
@@ -315,6 +337,17 @@ def __and__(self, other: Set[T]) -> Set[T]: ...
315337
def rebuild(self) -> Set[T]: ...
316338

317339

340+
converter(
341+
set,
342+
Set,
343+
lambda t: Set[get_type_args()[0]]( # type: ignore[misc,operator]
344+
*(convert(x, get_type_args()[0]) for x in t)
345+
),
346+
)
347+
348+
SetLike: TypeAlias = Set[T] | set[TO]
349+
350+
318351
class Rational(Expr, builtin=True):
319352
@method(egg_fn="rational")
320353
def __init__(self, num: i64Like, den: i64Like) -> None: ...
@@ -415,6 +448,18 @@ def remove(self, index: i64Like) -> Vec[T]: ...
415448
def set(self, index: i64Like, value: T) -> Vec[T]: ...
416449

417450

451+
for sequence_type in (list, tuple):
452+
converter(
453+
sequence_type,
454+
Vec,
455+
lambda t: Vec[get_type_args()[0]]( # type: ignore[misc,operator]
456+
*(convert(x, get_type_args()[0]) for x in t)
457+
),
458+
)
459+
460+
VecLike: TypeAlias = Vec[T] | tuple[TO, ...] | list[TO]
461+
462+
418463
class PyObject(Expr, builtin=True):
419464
def __init__(self, value: object) -> None: ...
420465

0 commit comments

Comments
 (0)