Skip to content

Commit dbf79f8

Browse files
Emit warnings when functions omitted when visualizing
Adds support for upstream egglog change to emit when serializing and functions or calls were omitted based on maxmimums
1 parent 21b0777 commit dbf79f8

File tree

3 files changed

+37
-0
lines changed

3 files changed

+37
-0
lines changed

python/egglog/egraph.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
get_type_hints,
2424
overload,
2525
)
26+
from warnings import warn
2627

2728
import graphviz
2829
from typing_extensions import Never, ParamSpec, Self, Unpack, assert_never
@@ -1024,6 +1025,12 @@ def _serialize(
10241025
max_calls_per_function=max_calls_per_function,
10251026
include_temporary_functions=include_temporary_functions,
10261027
)
1028+
if serialized.discarded_functions:
1029+
msg = ", ".join(set(self._state.possible_egglog_functions(serialized.discarded_functions)))
1030+
warn(f"Ommitted: {msg}", stacklevel=3)
1031+
if serialized.truncated_functions:
1032+
msg = ", ".join(set(self._state.possible_egglog_functions(serialized.truncated_functions)))
1033+
warn(f"Truncated: {msg}", stacklevel=3)
10271034
if split_primitive_outputs or split_functions:
10281035
additional_ops = set(map(self._callable_to_egg, split_functions))
10291036
serialized.split_classes(self._egraph, additional_ops)

python/egglog/egraph_state.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,14 @@ def op_mapping(self) -> dict[str, str]:
354354
if len(v) == 1
355355
}
356356

357+
def possible_egglog_functions(self, names: list[str]) -> Iterable[str]:
358+
"""
359+
Given a list of egglog functions, returns all the possible Python function strings
360+
"""
361+
for name in names:
362+
for c in self.egg_fn_to_callable_refs[name]:
363+
yield pretty_callable_ref(self.__egg_decls__, c)
364+
357365
def typed_expr_to_egg(self, typed_expr_decl: TypedExprDecl, transform_let: bool = True) -> bindings._Expr:
358366
# transform all expressions with multiple parents into a let binding, so that less expressions
359367
# are sent to egglog. Only for performance reasons.

python/tests/test_high_level.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -958,6 +958,28 @@ def __hash__(self) -> int:
958958
assert hash(A()) == 42
959959

960960

961+
def test_serialize_warning_max_functions():
962+
class A(Expr):
963+
def __init__(self) -> None: ...
964+
965+
egraph = EGraph()
966+
egraph.register(A())
967+
with pytest.warns(UserWarning, match="A"):
968+
egraph._serialize(max_functions=0)
969+
970+
971+
def test_serialize_warning_max_calls():
972+
class A(Expr): ...
973+
974+
@function
975+
def f(x: StringLike) -> A: ...
976+
977+
egraph = EGraph()
978+
egraph.register(f("a"), f("b"))
979+
with pytest.warns(UserWarning, match="f"):
980+
egraph._serialize(max_calls_per_function=1)
981+
982+
961983
EXAMPLE_FILES = list((pathlib.Path(__file__).parent / "../egglog/examples").glob("*.py"))
962984

963985

0 commit comments

Comments
 (0)