diff --git a/docs/changelog.md b/docs/changelog.md index cbccdfce..43c902ee 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -4,6 +4,7 @@ _This project uses semantic versioning_ ## UNRELEASED +- Emit warnings when functions omitted when visualizing [#336](https://github.com/egraphs-good/egglog-python/pull/336) ## 11.1.0 (2025-08-21) - Allow changing number of threads with env variable [#330](https://github.com/egraphs-good/egglog-python/pull/330) diff --git a/python/egglog/egraph.py b/python/egglog/egraph.py index 20904852..cc07d7e7 100644 --- a/python/egglog/egraph.py +++ b/python/egglog/egraph.py @@ -23,6 +23,7 @@ get_type_hints, overload, ) +from warnings import warn import graphviz from typing_extensions import Never, ParamSpec, Self, Unpack, assert_never @@ -1024,6 +1025,12 @@ def _serialize( max_calls_per_function=max_calls_per_function, include_temporary_functions=include_temporary_functions, ) + if serialized.discarded_functions: + msg = ", ".join(set(self._state.possible_egglog_functions(serialized.discarded_functions))) + warn(f"Omitted: {msg}", stacklevel=3) + if serialized.truncated_functions: + msg = ", ".join(set(self._state.possible_egglog_functions(serialized.truncated_functions))) + warn(f"Truncated: {msg}", stacklevel=3) if split_primitive_outputs or split_functions: additional_ops = set(map(self._callable_to_egg, split_functions)) serialized.split_classes(self._egraph, additional_ops) diff --git a/python/egglog/egraph_state.py b/python/egglog/egraph_state.py index 8ad314f2..aee48900 100644 --- a/python/egglog/egraph_state.py +++ b/python/egglog/egraph_state.py @@ -354,6 +354,14 @@ def op_mapping(self) -> dict[str, str]: if len(v) == 1 } + def possible_egglog_functions(self, names: list[str]) -> Iterable[str]: + """ + Given a list of egglog functions, returns all the possible Python function strings + """ + for name in names: + for c in self.egg_fn_to_callable_refs[name]: + yield pretty_callable_ref(self.__egg_decls__, c) + def typed_expr_to_egg(self, typed_expr_decl: TypedExprDecl, transform_let: bool = True) -> bindings._Expr: # transform all expressions with multiple parents into a let binding, so that less expressions # are sent to egglog. Only for performance reasons. diff --git a/python/tests/test_high_level.py b/python/tests/test_high_level.py index c5bb13e2..9bb543c3 100644 --- a/python/tests/test_high_level.py +++ b/python/tests/test_high_level.py @@ -958,6 +958,28 @@ def __hash__(self) -> int: assert hash(A()) == 42 +def test_serialize_warning_max_functions(): + class A(Expr): + def __init__(self) -> None: ... + + egraph = EGraph() + egraph.register(A()) + with pytest.warns(UserWarning, match="A"): + egraph._serialize(max_functions=0) + + +def test_serialize_warning_max_calls(): + class A(Expr): ... + + @function + def f(x: StringLike) -> A: ... + + egraph = EGraph() + egraph.register(f("a"), f("b")) + with pytest.warns(UserWarning, match="f"): + egraph._serialize(max_calls_per_function=1) + + EXAMPLE_FILES = list((pathlib.Path(__file__).parent / "../egglog/examples").glob("*.py"))