Skip to content

Commit f8d552c

Browse files
Merge pull request #336 from egraphs-good/warning-viz
Emit warnings when functions omitted when visualizing
2 parents 0fc9674 + f85b35d commit f8d552c

File tree

4 files changed

+38
-0
lines changed

4 files changed

+38
-0
lines changed

docs/changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ _This project uses semantic versioning_
44

55
## UNRELEASED
66

7+
- Emit warnings when functions omitted when visualizing [#336](https://github.com/egraphs-good/egglog-python/pull/336)
78
## 11.1.0 (2025-08-21)
89

910
- Allow changing number of threads with env variable [#330](https://github.com/egraphs-good/egglog-python/pull/330)

python/egglog/egraph.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
get_type_hints,
2424
overload,
2525
)
26+
from warnings import warn
2627

2728
import graphviz
2829
from typing_extensions import Never, ParamSpec, Self, Unpack, assert_never
@@ -1024,6 +1025,12 @@ def _serialize(
10241025
max_calls_per_function=max_calls_per_function,
10251026
include_temporary_functions=include_temporary_functions,
10261027
)
1028+
if serialized.discarded_functions:
1029+
msg = ", ".join(set(self._state.possible_egglog_functions(serialized.discarded_functions)))
1030+
warn(f"Omitted: {msg}", stacklevel=3)
1031+
if serialized.truncated_functions:
1032+
msg = ", ".join(set(self._state.possible_egglog_functions(serialized.truncated_functions)))
1033+
warn(f"Truncated: {msg}", stacklevel=3)
10271034
if split_primitive_outputs or split_functions:
10281035
additional_ops = set(map(self._callable_to_egg, split_functions))
10291036
serialized.split_classes(self._egraph, additional_ops)

python/egglog/egraph_state.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,14 @@ def op_mapping(self) -> dict[str, str]:
354354
if len(v) == 1
355355
}
356356

357+
def possible_egglog_functions(self, names: list[str]) -> Iterable[str]:
358+
"""
359+
Given a list of egglog functions, returns all the possible Python function strings
360+
"""
361+
for name in names:
362+
for c in self.egg_fn_to_callable_refs[name]:
363+
yield pretty_callable_ref(self.__egg_decls__, c)
364+
357365
def typed_expr_to_egg(self, typed_expr_decl: TypedExprDecl, transform_let: bool = True) -> bindings._Expr:
358366
# transform all expressions with multiple parents into a let binding, so that less expressions
359367
# are sent to egglog. Only for performance reasons.

python/tests/test_high_level.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -958,6 +958,28 @@ def __hash__(self) -> int:
958958
assert hash(A()) == 42
959959

960960

961+
def test_serialize_warning_max_functions():
962+
class A(Expr):
963+
def __init__(self) -> None: ...
964+
965+
egraph = EGraph()
966+
egraph.register(A())
967+
with pytest.warns(UserWarning, match="A"):
968+
egraph._serialize(max_functions=0)
969+
970+
971+
def test_serialize_warning_max_calls():
972+
class A(Expr): ...
973+
974+
@function
975+
def f(x: StringLike) -> A: ...
976+
977+
egraph = EGraph()
978+
egraph.register(f("a"), f("b"))
979+
with pytest.warns(UserWarning, match="f"):
980+
egraph._serialize(max_calls_per_function=1)
981+
982+
961983
EXAMPLE_FILES = list((pathlib.Path(__file__).parent / "../egglog/examples").glob("*.py"))
962984

963985

0 commit comments

Comments
 (0)