Skip to content

Commit 1a2ee10

Browse files
Merge pull request #338 from egraphs-good/extract-report-warning
Add `all_function_sizes` and `function_size` EGraph methods
2 parents b8c53e0 + b5d7e7f commit 1a2ee10

File tree

6 files changed

+152
-14
lines changed

6 files changed

+152
-14
lines changed

docs/changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ _This project uses semantic versioning_
44

55
## UNRELEASED
66

7+
- Add `all_function_sizes` and `function_size` EGraph methods [#338](https://github.com/egraphs-good/egglog-python/pull/338)
78
- Fix execution of docs [#337](https://github.com/egraphs-good/egglog-python/pull/337)
89
- Emit warnings when functions omitted when visualizing [#336](https://github.com/egraphs-good/egglog-python/pull/336)
910
- Bump Egglog version [#335](https://github.com/egraphs-good/egglog-python/pull/335)

docs/reference/egglog-translation.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -507,6 +507,21 @@ with egraph:
507507
egraph.check_fail(eq(Math(0)).to(Math(1)))
508508
```
509509

510+
## Function Sizes
511+
512+
The `(print-size <function name>?)` command is translated into either `egraph.function_size(fn)` to get the number of
513+
rows of one function or `egraph.all_function_sizes()` to get a list of all the function sizes:
514+
515+
```{code-cell} python
516+
# (function-size Math)
517+
egraph.function_size(Math)
518+
```
519+
520+
```{code-cell} python
521+
# (function-size)
522+
egraph.all_function_sizes()
523+
```
524+
510525
## Include
511526

512527
The `(include <path>)` command is used to add modularity, by allowing you to pull in the source from another egglog file into the current file.

python/egglog/egraph.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -805,7 +805,7 @@ class GraphvizKwargs(TypedDict, total=False):
805805
max_calls_per_function: int | None
806806
n_inline_leaves: int
807807
split_primitive_outputs: bool
808-
split_functions: list[object]
808+
split_functions: list[ExprCallable]
809809
include_temporary_functions: bool
810810

811811

@@ -854,7 +854,7 @@ def input(self, fn: Callable[..., String], path: str) -> None:
854854
"""
855855
self._egraph.run_program(bindings.Input(span(1), self._callable_to_egg(fn), path))
856856

857-
def _callable_to_egg(self, fn: object) -> str:
857+
def _callable_to_egg(self, fn: ExprCallable) -> str:
858858
ref, decls = resolve_callable(fn)
859859
self._add_decls(decls)
860860
return self._state.callable_ref_to_egg(ref)[0]
@@ -1179,6 +1179,37 @@ def _command_to_egg(self, cmd: Command) -> bindings._Command | None:
11791179
assert_never(cmd)
11801180
return self._state.command_to_egg(cmd_decl, ruleset_name)
11811181

1182+
def function_size(self, fn: ExprCallable) -> int:
1183+
"""
1184+
Returns the number of rows in a certain function
1185+
"""
1186+
egg_name = self._callable_to_egg(fn)
1187+
(output,) = self._egraph.run_program(bindings.PrintSize(span(1), egg_name))
1188+
assert isinstance(output, bindings.PrintFunctionSize)
1189+
return output.size
1190+
1191+
def all_function_sizes(self) -> list[tuple[ExprCallable, int]]:
1192+
"""
1193+
Returns a list of all functions and their sizes.
1194+
"""
1195+
(output,) = self._egraph.run_program(bindings.PrintSize(span(1), None))
1196+
assert isinstance(output, bindings.PrintAllFunctionsSize)
1197+
return [
1198+
(
1199+
cast(
1200+
"ExprCallable",
1201+
create_callable(self._state.__egg_decls__, next(iter(refs))),
1202+
),
1203+
size,
1204+
)
1205+
for (name, size) in output.sizes
1206+
if (refs := self._state.egg_fn_to_callable_refs[name])
1207+
]
1208+
1209+
1210+
# Either a constant or a function.
1211+
ExprCallable: TypeAlias = Callable[..., BaseExpr] | BaseExpr
1212+
11821213

11831214
@dataclass(frozen=True)
11841215
class _WrappedMethod:

python/egglog/egraph_state.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ def fact_to_egg(self, fact: FactDecl) -> bindings._Fact:
225225
case _:
226226
assert_never(fact)
227227

228-
def callable_ref_to_egg(self, ref: CallableRef) -> tuple[str, bool]:
228+
def callable_ref_to_egg(self, ref: CallableRef) -> tuple[str, bool]: # noqa: C901, PLR0912
229229
"""
230230
Returns the egg function name for a callable reference, registering it if it is not already registered.
231231
@@ -245,9 +245,12 @@ def callable_ref_to_egg(self, ref: CallableRef) -> tuple[str, bool]:
245245
case ConstantDecl(tp, _):
246246
# Use constructor decleration instead of constant b/c constants cannot be extracted
247247
# https://github.com/egraphs-good/egglog/issues/334
248-
self.egraph.run_program(
249-
bindings.Constructor(span(), egg_name, bindings.Schema([], self.type_ref_to_egg(tp)), None, False)
250-
)
248+
is_function = self.__egg_decls__._classes[tp.name].builtin
249+
schema = bindings.Schema([], self.type_ref_to_egg(tp))
250+
if is_function:
251+
self.egraph.run_program(bindings.FunctionCommand(span(), egg_name, schema, None))
252+
else:
253+
self.egraph.run_program(bindings.Constructor(span(), egg_name, schema, None, False))
251254
case FunctionDecl(signature, builtin, _, merge):
252255
if isinstance(signature, FunctionSignature):
253256
reverse_args = signature.reverse_args

python/egglog/runtime.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from itertools import zip_longest
2121
from typing import TYPE_CHECKING, Any, TypeVar, Union, cast, get_args, get_origin
2222

23+
from typing_extensions import assert_never
24+
2325
from .declarations import *
2426
from .pretty import *
2527
from .thunk import Thunk
@@ -36,6 +38,7 @@
3638
"RuntimeClass",
3739
"RuntimeExpr",
3840
"RuntimeFunction",
41+
"create_callable",
3942
"define_expr_method",
4043
"resolve_callable",
4144
"resolve_type_annotation",
@@ -340,7 +343,7 @@ def __repr__(self) -> str:
340343

341344
# Make hashable so can go in Union
342345
def __hash__(self) -> int:
343-
return hash((id(self.__egg_decls_thunk__), self.__egg_tp__))
346+
return hash(self.__egg_tp__)
344347

345348
def __eq__(self, other: object) -> bool:
346349
"""
@@ -478,6 +481,9 @@ def __str__(self) -> str:
478481
bound_tp_params = args
479482
return pretty_callable_ref(self.__egg_decls__, self.__egg_ref__, first_arg, bound_tp_params)
480483

484+
def __repr__(self) -> str:
485+
return str(self)
486+
481487

482488
def to_py_signature(sig: FunctionSignature, decls: Declarations, optional_args: bool) -> Signature:
483489
"""
@@ -670,11 +676,12 @@ def resolve_callable(callable: object) -> tuple[CallableRef, Declarations]:
670676
"""
671677
Resolves a runtime callable into a ref
672678
"""
679+
# TODO: Make runtime class work with __match_args__
680+
if isinstance(callable, RuntimeClass):
681+
return InitRef(callable.__egg_tp__.name), callable.__egg_decls__
673682
match callable:
674683
case RuntimeFunction(decls, ref, _):
675684
return ref(), decls()
676-
case RuntimeClass(thunk, tp):
677-
return InitRef(tp.name), thunk()
678685
case RuntimeExpr(decl_thunk, expr_thunk):
679686
if not isinstance((expr := expr_thunk().expr), CallDecl) or not isinstance(
680687
expr.callable, ConstantRef | ClassVariableRef
@@ -683,3 +690,23 @@ def resolve_callable(callable: object) -> tuple[CallableRef, Declarations]:
683690
return expr.callable, decl_thunk()
684691
case _:
685692
raise NotImplementedError(f"Cannot turn {callable} of type {type(callable)} into a callable ref")
693+
694+
695+
def create_callable(decls: Declarations, ref: CallableRef) -> RuntimeClass | RuntimeFunction | RuntimeExpr:
696+
"""
697+
Creates a callable object from a callable ref. This might not actually be callable, if the ref is a constant
698+
or classvar then it is a value
699+
"""
700+
match ref:
701+
case InitRef(name):
702+
return RuntimeClass(Thunk.value(decls), TypeRefWithVars(name))
703+
case FunctionRef() | MethodRef() | ClassMethodRef() | PropertyRef() | UnnamedFunctionRef():
704+
bound = JustTypeRef(ref.class_name) if isinstance(ref, ClassMethodRef) else None
705+
return RuntimeFunction(Thunk.value(decls), Thunk.value(ref), bound)
706+
case ConstantRef(name):
707+
tp = decls._constants[name].type_ref
708+
case ClassVariableRef(cls_name, var_name):
709+
tp = decls._classes[cls_name].class_variables[var_name].type_ref
710+
case _:
711+
assert_never(ref)
712+
return RuntimeExpr.__from_values__(decls, TypedExprDecl(tp, CallDecl(ref)))

python/tests/test_high_level.py

Lines changed: 66 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -139,13 +139,13 @@ class A(Expr):
139139
def test_class_vars():
140140
egraph = EGraph()
141141

142-
class A(Expr):
143-
ONE: ClassVar[A]
142+
class B(Expr):
143+
ONE: ClassVar[B]
144144

145-
two = constant("two", A)
145+
two = constant("two", B)
146146

147-
egraph.register(union(A.ONE).with_(two))
148-
egraph.check(eq(A.ONE).to(two))
147+
egraph.register(union(B.ONE).with_(two))
148+
egraph.check(eq(B.ONE).to(two))
149149

150150

151151
def test_extract_constant_twice():
@@ -987,3 +987,64 @@ def f(x: StringLike) -> A: ...
987987
@pytest.mark.parametrize("name", [f.stem for f in EXAMPLE_FILES if f.stem != "__init__"])
988988
def test_example(name):
989989
importlib.import_module(f"egglog.examples.{name}")
990+
991+
992+
@function
993+
def f() -> i64: ...
994+
995+
996+
class E(Expr):
997+
X: ClassVar[i64]
998+
999+
def __init__(self) -> None: ...
1000+
def m(self) -> i64: ...
1001+
1002+
@property
1003+
def p(self) -> i64: ...
1004+
1005+
@classmethod
1006+
def cm(cls) -> i64: ...
1007+
1008+
1009+
egraph = EGraph()
1010+
1011+
C = constant("C", i64)
1012+
1013+
zero = i64(0)
1014+
egraph.register(
1015+
set_(f()).to(zero),
1016+
set_(E().m()).to(zero),
1017+
set_(E.X).to(zero),
1018+
set_(E().p).to(zero),
1019+
set_(C).to(zero),
1020+
set_(E.cm()).to(zero),
1021+
)
1022+
1023+
1024+
@pytest.mark.parametrize(
1025+
"c",
1026+
[
1027+
pytest.param(E, id="init"),
1028+
pytest.param(f, id="function"),
1029+
pytest.param(E.m, id="method"),
1030+
pytest.param(E.X, id="class var"),
1031+
pytest.param(E.p, id="property"),
1032+
pytest.param(C, id="constant"),
1033+
pytest.param(E.cm, id="class method"),
1034+
],
1035+
)
1036+
def test_function_size(c):
1037+
assert egraph.function_size(c) == 1
1038+
1039+
1040+
def test_all_function_size():
1041+
res = egraph.all_function_sizes()
1042+
assert set(res) == {
1043+
(E, 1),
1044+
(f, 1),
1045+
(E.m, 1),
1046+
(E.X, 1),
1047+
(E.p, 1),
1048+
(C, 1),
1049+
(E.cm, 1),
1050+
}

0 commit comments

Comments
 (0)