diff --git a/docs/changelog.md b/docs/changelog.md index f75c339f..e73c670b 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -4,6 +4,7 @@ _This project uses semantic versioning_ ## UNRELEASED +- Add `all_function_sizes` and `function_size` EGraph methods [#338](https://github.com/egraphs-good/egglog-python/pull/338) - Fix execution of docs [#337](https://github.com/egraphs-good/egglog-python/pull/337) - Emit warnings when functions omitted when visualizing [#336](https://github.com/egraphs-good/egglog-python/pull/336) - Bump Egglog version [#335](https://github.com/egraphs-good/egglog-python/pull/335) diff --git a/docs/reference/egglog-translation.md b/docs/reference/egglog-translation.md index 9ff67f60..3c5f81e7 100644 --- a/docs/reference/egglog-translation.md +++ b/docs/reference/egglog-translation.md @@ -507,6 +507,21 @@ with egraph: egraph.check_fail(eq(Math(0)).to(Math(1))) ``` +## Function Sizes + +The `(print-size ?)` command is translated into either `egraph.function_size(fn)` to get the number of +rows of one function or `egraph.all_function_sizes()` to get a list of all the function sizes: + +```{code-cell} python +# (function-size Math) +egraph.function_size(Math) +``` + +```{code-cell} python +# (function-size) +egraph.all_function_sizes() +``` + ## Include The `(include )` command is used to add modularity, by allowing you to pull in the source from another egglog file into the current file. diff --git a/python/egglog/egraph.py b/python/egglog/egraph.py index cc07d7e7..778458e0 100644 --- a/python/egglog/egraph.py +++ b/python/egglog/egraph.py @@ -805,7 +805,7 @@ class GraphvizKwargs(TypedDict, total=False): max_calls_per_function: int | None n_inline_leaves: int split_primitive_outputs: bool - split_functions: list[object] + split_functions: list[ExprCallable] include_temporary_functions: bool @@ -854,7 +854,7 @@ def input(self, fn: Callable[..., String], path: str) -> None: """ self._egraph.run_program(bindings.Input(span(1), self._callable_to_egg(fn), path)) - def _callable_to_egg(self, fn: object) -> str: + def _callable_to_egg(self, fn: ExprCallable) -> str: ref, decls = resolve_callable(fn) self._add_decls(decls) return self._state.callable_ref_to_egg(ref)[0] @@ -1179,6 +1179,37 @@ def _command_to_egg(self, cmd: Command) -> bindings._Command | None: assert_never(cmd) return self._state.command_to_egg(cmd_decl, ruleset_name) + def function_size(self, fn: ExprCallable) -> int: + """ + Returns the number of rows in a certain function + """ + egg_name = self._callable_to_egg(fn) + (output,) = self._egraph.run_program(bindings.PrintSize(span(1), egg_name)) + assert isinstance(output, bindings.PrintFunctionSize) + return output.size + + def all_function_sizes(self) -> list[tuple[ExprCallable, int]]: + """ + Returns a list of all functions and their sizes. + """ + (output,) = self._egraph.run_program(bindings.PrintSize(span(1), None)) + assert isinstance(output, bindings.PrintAllFunctionsSize) + return [ + ( + cast( + "ExprCallable", + create_callable(self._state.__egg_decls__, next(iter(refs))), + ), + size, + ) + for (name, size) in output.sizes + if (refs := self._state.egg_fn_to_callable_refs[name]) + ] + + +# Either a constant or a function. +ExprCallable: TypeAlias = Callable[..., BaseExpr] | BaseExpr + @dataclass(frozen=True) class _WrappedMethod: diff --git a/python/egglog/egraph_state.py b/python/egglog/egraph_state.py index aee48900..a4acfc0b 100644 --- a/python/egglog/egraph_state.py +++ b/python/egglog/egraph_state.py @@ -225,7 +225,7 @@ def fact_to_egg(self, fact: FactDecl) -> bindings._Fact: case _: assert_never(fact) - def callable_ref_to_egg(self, ref: CallableRef) -> tuple[str, bool]: + def callable_ref_to_egg(self, ref: CallableRef) -> tuple[str, bool]: # noqa: C901, PLR0912 """ Returns the egg function name for a callable reference, registering it if it is not already registered. @@ -245,9 +245,12 @@ def callable_ref_to_egg(self, ref: CallableRef) -> tuple[str, bool]: case ConstantDecl(tp, _): # Use constructor decleration instead of constant b/c constants cannot be extracted # https://github.com/egraphs-good/egglog/issues/334 - self.egraph.run_program( - bindings.Constructor(span(), egg_name, bindings.Schema([], self.type_ref_to_egg(tp)), None, False) - ) + is_function = self.__egg_decls__._classes[tp.name].builtin + schema = bindings.Schema([], self.type_ref_to_egg(tp)) + if is_function: + self.egraph.run_program(bindings.FunctionCommand(span(), egg_name, schema, None)) + else: + self.egraph.run_program(bindings.Constructor(span(), egg_name, schema, None, False)) case FunctionDecl(signature, builtin, _, merge): if isinstance(signature, FunctionSignature): reverse_args = signature.reverse_args diff --git a/python/egglog/runtime.py b/python/egglog/runtime.py index 14698d0f..f2316ad5 100644 --- a/python/egglog/runtime.py +++ b/python/egglog/runtime.py @@ -20,6 +20,8 @@ from itertools import zip_longest from typing import TYPE_CHECKING, Any, TypeVar, Union, cast, get_args, get_origin +from typing_extensions import assert_never + from .declarations import * from .pretty import * from .thunk import Thunk @@ -36,6 +38,7 @@ "RuntimeClass", "RuntimeExpr", "RuntimeFunction", + "create_callable", "define_expr_method", "resolve_callable", "resolve_type_annotation", @@ -340,7 +343,7 @@ def __repr__(self) -> str: # Make hashable so can go in Union def __hash__(self) -> int: - return hash((id(self.__egg_decls_thunk__), self.__egg_tp__)) + return hash(self.__egg_tp__) def __eq__(self, other: object) -> bool: """ @@ -478,6 +481,9 @@ def __str__(self) -> str: bound_tp_params = args return pretty_callable_ref(self.__egg_decls__, self.__egg_ref__, first_arg, bound_tp_params) + def __repr__(self) -> str: + return str(self) + def to_py_signature(sig: FunctionSignature, decls: Declarations, optional_args: bool) -> Signature: """ @@ -670,11 +676,12 @@ def resolve_callable(callable: object) -> tuple[CallableRef, Declarations]: """ Resolves a runtime callable into a ref """ + # TODO: Make runtime class work with __match_args__ + if isinstance(callable, RuntimeClass): + return InitRef(callable.__egg_tp__.name), callable.__egg_decls__ match callable: case RuntimeFunction(decls, ref, _): return ref(), decls() - case RuntimeClass(thunk, tp): - return InitRef(tp.name), thunk() case RuntimeExpr(decl_thunk, expr_thunk): if not isinstance((expr := expr_thunk().expr), CallDecl) or not isinstance( expr.callable, ConstantRef | ClassVariableRef @@ -683,3 +690,23 @@ def resolve_callable(callable: object) -> tuple[CallableRef, Declarations]: return expr.callable, decl_thunk() case _: raise NotImplementedError(f"Cannot turn {callable} of type {type(callable)} into a callable ref") + + +def create_callable(decls: Declarations, ref: CallableRef) -> RuntimeClass | RuntimeFunction | RuntimeExpr: + """ + Creates a callable object from a callable ref. This might not actually be callable, if the ref is a constant + or classvar then it is a value + """ + match ref: + case InitRef(name): + return RuntimeClass(Thunk.value(decls), TypeRefWithVars(name)) + case FunctionRef() | MethodRef() | ClassMethodRef() | PropertyRef() | UnnamedFunctionRef(): + bound = JustTypeRef(ref.class_name) if isinstance(ref, ClassMethodRef) else None + return RuntimeFunction(Thunk.value(decls), Thunk.value(ref), bound) + case ConstantRef(name): + tp = decls._constants[name].type_ref + case ClassVariableRef(cls_name, var_name): + tp = decls._classes[cls_name].class_variables[var_name].type_ref + case _: + assert_never(ref) + return RuntimeExpr.__from_values__(decls, TypedExprDecl(tp, CallDecl(ref))) diff --git a/python/tests/test_high_level.py b/python/tests/test_high_level.py index 9bb543c3..d883fb6f 100644 --- a/python/tests/test_high_level.py +++ b/python/tests/test_high_level.py @@ -139,13 +139,13 @@ class A(Expr): def test_class_vars(): egraph = EGraph() - class A(Expr): - ONE: ClassVar[A] + class B(Expr): + ONE: ClassVar[B] - two = constant("two", A) + two = constant("two", B) - egraph.register(union(A.ONE).with_(two)) - egraph.check(eq(A.ONE).to(two)) + egraph.register(union(B.ONE).with_(two)) + egraph.check(eq(B.ONE).to(two)) def test_extract_constant_twice(): @@ -987,3 +987,64 @@ def f(x: StringLike) -> A: ... @pytest.mark.parametrize("name", [f.stem for f in EXAMPLE_FILES if f.stem != "__init__"]) def test_example(name): importlib.import_module(f"egglog.examples.{name}") + + +@function +def f() -> i64: ... + + +class E(Expr): + X: ClassVar[i64] + + def __init__(self) -> None: ... + def m(self) -> i64: ... + + @property + def p(self) -> i64: ... + + @classmethod + def cm(cls) -> i64: ... + + +egraph = EGraph() + +C = constant("C", i64) + +zero = i64(0) +egraph.register( + set_(f()).to(zero), + set_(E().m()).to(zero), + set_(E.X).to(zero), + set_(E().p).to(zero), + set_(C).to(zero), + set_(E.cm()).to(zero), +) + + +@pytest.mark.parametrize( + "c", + [ + pytest.param(E, id="init"), + pytest.param(f, id="function"), + pytest.param(E.m, id="method"), + pytest.param(E.X, id="class var"), + pytest.param(E.p, id="property"), + pytest.param(C, id="constant"), + pytest.param(E.cm, id="class method"), + ], +) +def test_function_size(c): + assert egraph.function_size(c) == 1 + + +def test_all_function_size(): + res = egraph.all_function_sizes() + assert set(res) == { + (E, 1), + (f, 1), + (E.m, 1), + (E.X, 1), + (E.p, 1), + (C, 1), + (E.cm, 1), + }