diff --git a/docs/changelog.md b/docs/changelog.md index 3da9ffa2..44417898 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -4,6 +4,7 @@ _This project uses semantic versioning_ ## UNRELEASED +- Add `egraph.function_values(fn)` to export all function values like `print-function` [#340](https://github.com/egraphs-good/egglog-python/pull/340) - Add `egraph.stats()` method to print overall stats [#339](https://github.com/egraphs-good/egglog-python/pull/339) - Add `all_function_sizes` and `function_size` EGraph methods [#338](https://github.com/egraphs-good/egglog-python/pull/338) - Fix execution of docs [#337](https://github.com/egraphs-good/egglog-python/pull/337) diff --git a/docs/reference/egglog-translation.md b/docs/reference/egglog-translation.md index 7bf13607..fcbb034f 100644 --- a/docs/reference/egglog-translation.md +++ b/docs/reference/egglog-translation.md @@ -531,6 +531,15 @@ The `(print-stats)` command is translated into `egraph.stats()` to get overall s egraph.stats() ``` +## Function Values + +The `print-function` command is translated into `egraph.function_values(fn, [length]?)` to get the values of a specific function. Note that the function provided must either return a primitive or be created with a merge function. + +```{code-cell} python +# (print-function fib 3) +egraph.function_values(fib, length=3) +``` + ## Include The `(include )` command is used to add modularity, by allowing you to pull in the source from another egglog file into the current file. diff --git a/python/egglog/egraph.py b/python/egglog/egraph.py index d73ab7d1..ac2a64b2 100644 --- a/python/egglog/egraph.py +++ b/python/egglog/egraph.py @@ -852,12 +852,12 @@ def input(self, fn: Callable[..., String], path: str) -> None: """ Loads a CSV file and sets it as *input, output of the function. """ - self._egraph.run_program(bindings.Input(span(1), self._callable_to_egg(fn), path)) + self._egraph.run_program(bindings.Input(span(1), self._callable_to_egg(fn)[1], path)) - def _callable_to_egg(self, fn: ExprCallable) -> str: + def _callable_to_egg(self, fn: ExprCallable) -> tuple[CallableRef, str]: ref, decls = resolve_callable(fn) self._add_decls(decls) - return self._state.callable_ref_to_egg(ref)[0] + return ref, self._state.callable_ref_to_egg(ref)[0] # TODO: Change let to be action... def let(self, name: str, expr: BASE_EXPR) -> BASE_EXPR: @@ -961,15 +961,15 @@ def extract(self, expr: BASE_EXPR, include_cost: bool = False) -> BASE_EXPR | tu runtime_expr = to_runtime_expr(expr) extract_report = self._run_extract(runtime_expr, 0) assert isinstance(extract_report, bindings.ExtractBest) - (new_typed_expr,) = self._state.exprs_from_egg( - extract_report.termdag, [extract_report.term], runtime_expr.__egg_typed_expr__.tp - ) - - res = cast("BASE_EXPR", RuntimeExpr.__from_values__(self.__egg_decls__, new_typed_expr)) + res = self._from_termdag(extract_report.termdag, extract_report.term, runtime_expr.__egg_typed_expr__.tp) if include_cost: return res, extract_report.cost return res + def _from_termdag(self, termdag: bindings.TermDag, term: bindings._Term, tp: JustTypeRef) -> Any: + (new_typed_expr,) = self._state.exprs_from_egg(termdag, [term], tp) + return RuntimeExpr.__from_values__(self.__egg_decls__, new_typed_expr) + def extract_multiple(self, expr: BASE_EXPR, n: int) -> list[BASE_EXPR]: """ Extract multiple expressions from the egraph. @@ -1040,7 +1040,7 @@ def _serialize( msg = ", ".join(set(self._state.possible_egglog_functions(serialized.truncated_functions))) warn(f"Truncated: {msg}", stacklevel=3) if split_primitive_outputs or split_functions: - additional_ops = set(map(self._callable_to_egg, split_functions)) + additional_ops = {self._callable_to_egg(f)[1] for f in split_functions} serialized.split_classes(self._egraph, additional_ops) serialized.map_ops(self._state.op_mapping()) @@ -1191,7 +1191,7 @@ def function_size(self, fn: ExprCallable) -> int: """ Returns the number of rows in a certain function """ - egg_name = self._callable_to_egg(fn) + egg_name = self._callable_to_egg(fn)[1] (output,) = self._egraph.run_program(bindings.PrintSize(span(1), egg_name)) assert isinstance(output, bindings.PrintFunctionSize) return output.size @@ -1214,6 +1214,27 @@ def all_function_sizes(self) -> list[tuple[ExprCallable, int]]: if (refs := self._state.egg_fn_to_callable_refs[name]) ] + def function_values( + self, fn: Callable[..., BASE_EXPR] | BASE_EXPR, length: int | None = None + ) -> dict[BASE_EXPR, BASE_EXPR]: + """ + Given a callable that is a "function", meaning it returns a primitive or has a merge set, + returns a mapping of the function applied with its arguments to its values + + If length is specified, only the first `length` values will be returned. + """ + ref, egg_name = self._callable_to_egg(fn) + cmd = bindings.PrintFunction(span(1), egg_name, length, None, bindings.DefaultPrintFunctionMode()) + (output,) = self._egraph.run_program(cmd) + assert isinstance(output, bindings.PrintFunctionOutput) + signature = self.__egg_decls__.get_callable_decl(ref).signature + assert isinstance(signature, FunctionSignature) + tp = signature.semantic_return_type.to_just() + return { + self._from_termdag(output.termdag, call, tp): self._from_termdag(output.termdag, res, tp) + for (call, res) in output.terms + } + # Either a constant or a function. ExprCallable: TypeAlias = Callable[..., BaseExpr] | BaseExpr diff --git a/python/tests/test_high_level.py b/python/tests/test_high_level.py index bdca5ddc..2771feaa 100644 --- a/python/tests/test_high_level.py +++ b/python/tests/test_high_level.py @@ -1052,3 +1052,14 @@ def test_all_function_size(): def test_overall_run_report(): assert EGraph().stats() + + +def test_function_values(): + egraph = EGraph() + + @function + def f(x: i64Like) -> i64: ... + + egraph.register(set_(f(i64(1))).to(i64(2))) + values = egraph.function_values(f) + assert values == {f(i64(1)): i64(2)}