Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions python/egglog/egraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
get_type_hints,
overload,
)
from warnings import warn

import graphviz
from typing_extensions import Never, ParamSpec, Self, Unpack, assert_never
Expand Down Expand Up @@ -1024,6 +1025,12 @@ def _serialize(
max_calls_per_function=max_calls_per_function,
include_temporary_functions=include_temporary_functions,
)
if serialized.discarded_functions:
msg = ", ".join(set(self._state.possible_egglog_functions(serialized.discarded_functions)))
warn(f"Ommitted: {msg}", stacklevel=3)
if serialized.truncated_functions:
msg = ", ".join(set(self._state.possible_egglog_functions(serialized.truncated_functions)))
warn(f"Truncated: {msg}", stacklevel=3)
if split_primitive_outputs or split_functions:
additional_ops = set(map(self._callable_to_egg, split_functions))
serialized.split_classes(self._egraph, additional_ops)
Expand Down
8 changes: 8 additions & 0 deletions python/egglog/egraph_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,14 @@ def op_mapping(self) -> dict[str, str]:
if len(v) == 1
}

def possible_egglog_functions(self, names: list[str]) -> Iterable[str]:
"""
Given a list of egglog functions, returns all the possible Python function strings
"""
for name in names:
for c in self.egg_fn_to_callable_refs[name]:
yield pretty_callable_ref(self.__egg_decls__, c)

def typed_expr_to_egg(self, typed_expr_decl: TypedExprDecl, transform_let: bool = True) -> bindings._Expr:
# transform all expressions with multiple parents into a let binding, so that less expressions
# are sent to egglog. Only for performance reasons.
Expand Down
22 changes: 22 additions & 0 deletions python/tests/test_high_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -958,6 +958,28 @@ def __hash__(self) -> int:
assert hash(A()) == 42


def test_serialize_warning_max_functions():
class A(Expr):
def __init__(self) -> None: ...

egraph = EGraph()
egraph.register(A())
with pytest.warns(UserWarning, match="A"):
egraph._serialize(max_functions=0)


def test_serialize_warning_max_calls():
class A(Expr): ...

@function
def f(x: StringLike) -> A: ...

egraph = EGraph()
egraph.register(f("a"), f("b"))
with pytest.warns(UserWarning, match="f"):
egraph._serialize(max_calls_per_function=1)


EXAMPLE_FILES = list((pathlib.Path(__file__).parent / "../egglog/examples").glob("*.py"))


Expand Down
Loading