Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ _This project uses semantic versioning_

- Fix pretty printing of lambda functions
- Add support for subsuming rewrite generated by default function and method definitions
- Add better error message when using @function in class (thanks @shinawy)

## 8.0.1 (2024-10-24)

Expand Down
39 changes: 22 additions & 17 deletions python/egglog/egraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,23 +577,25 @@ def _generate_class_decls( # noqa: C901,PLR0912
decl = FunctionDecl(special_function_name, builtin=True, egg_name=egg_fn)
decls.set_function_decl(ref, decl)
continue

_, add_rewrite = _fn_decl(
decls,
egg_fn,
ref,
fn,
locals,
default,
cost,
merge,
on_merge,
mutates,
builtin,
ruleset=ruleset,
unextractable=unextractable,
subsume=subsume,
)
try:
_, add_rewrite = _fn_decl(
decls,
egg_fn,
ref,
fn,
locals,
default,
cost,
merge,
on_merge,
mutates,
builtin,
ruleset=ruleset,
unextractable=unextractable,
subsume=subsume,
)
except ValueError as e:
raise ValueError(f"Error processing {cls_name}.{method_name}: {e}") from e

if not builtin and not isinstance(ref, InitRef) and not mutates:
add_default_funcs.append(add_rewrite)
Expand Down Expand Up @@ -721,6 +723,9 @@ def _fn_decl(
"""
Sets the function decl for the function object and returns the ref as well as a thunk that sets the default callable.
"""
if isinstance(fn, RuntimeFunction):
msg = "Inside of classes, wrap methods with the `method` decorator, not `function`"
raise ValueError(msg) # noqa: TRY004
if not isinstance(fn, FunctionType):
raise NotImplementedError(f"Can only generate function decls for functions not {fn} {type(fn)}")

Expand Down
12 changes: 12 additions & 0 deletions python/tests/test_high_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,3 +762,15 @@ def test_inserting_map(self):

def test_creating_map(self):
EGraph().simplify(Map[String, i64].empty(), 1)


def test_helpful_error_function_class():
class E(Expr):
@function(cost=10)
def __init__(self) -> None: ...

with pytest.raises(
ValueError,
match="Error processing E.__init__: Inside of classes, wrap methods with the `method` decorator, not `function`",
):
E()