Skip to content

Commit 96bf07f

Browse files
Mostly working code!
1 parent f816fd1 commit 96bf07f

26 files changed

+1172
-720
lines changed

.pre-commit-config.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,8 @@ repos:
66
- id: ruff
77
args: [--fix, --exit-non-zero-on-fix]
88
- id: ruff-format
9+
- repo: https://github.com/astral-sh/uv-pre-commit
10+
# uv version.
11+
rev: 0.5.31
12+
hooks:
13+
- id: uv-lock

docs/changelog.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,11 @@ _This project uses semantic versioning_
99
- Add better error message when using @function in class (thanks @shinawy)
1010
- Add error method if `@method` decorator is in wrong place
1111
- Subsumes lambda functions after replacing
12-
- Add working loopnest test
12+
- Add working loopnest test and rewrite array api suport to be more general
1313
- Improve tracebacks on failing conversions.
1414
- Use `add_note` for exception to add more context, instead of raising a new exception, to make it easier to debug.
15+
- Add conversions from vecs (TODO)
16+
- Open files with webbrowser instead of internal graphviz util for better support
1517

1618
## 8.0.1 (2024-10-24)
1719

pyproject.toml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,16 @@ classifiers = [
2727
"Typing :: Typed",
2828
]
2929
dependencies = ["typing-extensions", "black", "graphviz", "anywidget"]
30+
dynamic = ["version"]
3031

3132
[project.optional-dependencies]
3233

3334
array = [
3435
"scikit-learn",
3536
"array_api_compat",
36-
"numba==0.59.1",
37-
"llvmlite==0.42.0",
37+
"numba>=0.59.1",
38+
"llvmlite>=0.42.0",
39+
"numpy>2",
3840
]
3941
dev = [
4042
"ruff",
@@ -215,7 +217,7 @@ preview = true
215217

216218
[tool.ruff.lint.per-file-ignores]
217219
# Don't require annotations for tests
218-
"python/tests/**" = ["ANN001", "ANN201"]
220+
"python/tests/**" = ["ANN001", "ANN201", "INP001"]
219221

220222
[tool.mypy]
221223
ignore_missing_imports = true

python/egglog/builtins.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from typing_extensions import TypeVarTuple, Unpack
1313

14-
from .conversion import converter, get_type_args
14+
from .conversion import convert, converter, get_type_args
1515
from .egraph import Expr, Unit, function, get_current_ruleset, method
1616
from .functionalize import functionalize
1717
from .runtime import RuntimeClass, RuntimeExpr, RuntimeFunction
@@ -418,6 +418,9 @@ def remove(self, index: i64Like) -> Vec[T]: ...
418418
def set(self, index: i64Like, value: T) -> Vec[T]: ...
419419

420420

421+
converter(tuple, Vec, lambda t: Vec(*(convert(x, get_type_args()[0]) for x in t)))
422+
423+
421424
class PyObject(Expr, builtin=True):
422425
def __init__(self, value: object) -> None: ...
423426

python/egglog/conversion.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,11 +148,11 @@ def identity(x: object) -> object:
148148
TYPE_ARGS = ContextVar[tuple[RuntimeClass, ...]]("TYPE_ARGS")
149149

150150

151-
def get_type_args() -> tuple[RuntimeClass, ...]:
151+
def get_type_args() -> tuple[type, ...]:
152152
"""
153153
Get the type args for the type being converted.
154154
"""
155-
return TYPE_ARGS.get()
155+
return cast(tuple[type, ...], TYPE_ARGS.get())
156156

157157

158158
@contextmanager

python/egglog/egraph.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from collections.abc import Callable, Generator, Iterable
99
from contextvars import ContextVar, Token
1010
from dataclasses import InitVar, dataclass, field
11+
from functools import partial
1112
from inspect import Parameter, currentframe, signature
1213
from types import FrameType, FunctionType
1314
from typing import (
@@ -1633,7 +1634,9 @@ class UnstableCombinedRuleset(Schedule):
16331634

16341635
def __post_init__(self, rulesets: list[Ruleset | UnstableCombinedRuleset]) -> None:
16351636
self.schedule = RunDecl(self.__egg_name__, ())
1636-
self.__egg_decls_thunk__ = Thunk.fn(self._create_egg_decls, *rulesets)
1637+
# Don't use thunk so that this is re-evaluated each time its requsted, so that additions inside will
1638+
# be added after its been evaluated once.
1639+
self.__egg_decls_thunk__ = partial(self._create_egg_decls, *rulesets)
16371640

16381641
@property
16391642
def __egg_name__(self) -> str:

0 commit comments

Comments
 (0)