Skip to content

Commit 23a94d5

Browse files
Add all_function_sizes and function_size EGraph methods
1 parent 6b46a65 commit 23a94d5

File tree

5 files changed

+142
-9
lines changed

5 files changed

+142
-9
lines changed

docs/reference/egglog-translation.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -507,6 +507,21 @@ with egraph:
507507
egraph.check_fail(eq(Math(0)).to(Math(1)))
508508
```
509509

510+
## Function Sizes
511+
512+
The `(print-size <functon name>?)` command is translated into either `egraph.function_size(fn)` to get the number of
513+
rows of one function or `egraph.all_function_sizes()` to get a list of all the function sizes:
514+
515+
```{code-cell} python
516+
# (function-size Math)
517+
egraph.function_size(Math)
518+
```
519+
520+
```{code-cell} python
521+
# (function-size)
522+
egraph.all_function_sizes()
523+
```
524+
510525
## Include
511526

512527
The `(include <path>)` command is used to add modularity, by allowing you to pull in the source from another egglog file into the current file.

python/egglog/egraph.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -805,7 +805,7 @@ class GraphvizKwargs(TypedDict, total=False):
805805
max_calls_per_function: int | None
806806
n_inline_leaves: int
807807
split_primitive_outputs: bool
808-
split_functions: list[object]
808+
split_functions: list[ExprCallable]
809809
include_temporary_functions: bool
810810

811811

@@ -854,7 +854,7 @@ def input(self, fn: Callable[..., String], path: str) -> None:
854854
"""
855855
self._egraph.run_program(bindings.Input(span(1), self._callable_to_egg(fn), path))
856856

857-
def _callable_to_egg(self, fn: object) -> str:
857+
def _callable_to_egg(self, fn: ExprCallable) -> str:
858858
ref, decls = resolve_callable(fn)
859859
self._add_decls(decls)
860860
return self._state.callable_ref_to_egg(ref)[0]
@@ -1179,6 +1179,37 @@ def _command_to_egg(self, cmd: Command) -> bindings._Command | None:
11791179
assert_never(cmd)
11801180
return self._state.command_to_egg(cmd_decl, ruleset_name)
11811181

1182+
def function_size(self, fn: ExprCallable) -> int:
1183+
"""
1184+
Returns the number of rows in a certain function
1185+
"""
1186+
egg_name = self._callable_to_egg(fn)
1187+
(output,) = self._egraph.run_program(bindings.PrintSize(span(1), egg_name))
1188+
assert isinstance(output, bindings.PrintFunctionSize)
1189+
return output.size
1190+
1191+
def all_function_sizes(self) -> list[tuple[ExprCallable, int]]:
1192+
"""
1193+
Returns a list of all functions and their sizes.
1194+
"""
1195+
(output,) = self._egraph.run_program(bindings.PrintSize(span(1), None))
1196+
assert isinstance(output, bindings.PrintAllFunctionsSize)
1197+
return [
1198+
(
1199+
cast(
1200+
"ExprCallable",
1201+
create_callable(self._state.__egg_decls__, next(iter(refs))),
1202+
),
1203+
size,
1204+
)
1205+
for (name, size) in output.sizes
1206+
if (refs := self._state.egg_fn_to_callable_refs[name])
1207+
]
1208+
1209+
1210+
# Either a constant or a function.
1211+
ExprCallable: TypeAlias = Callable[..., BaseExpr] | BaseExpr
1212+
11821213

11831214
@dataclass(frozen=True)
11841215
class _WrappedMethod:

python/egglog/egraph_state.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ def fact_to_egg(self, fact: FactDecl) -> bindings._Fact:
225225
case _:
226226
assert_never(fact)
227227

228-
def callable_ref_to_egg(self, ref: CallableRef) -> tuple[str, bool]:
228+
def callable_ref_to_egg(self, ref: CallableRef) -> tuple[str, bool]: # noqa: C901, PLR0912
229229
"""
230230
Returns the egg function name for a callable reference, registering it if it is not already registered.
231231
@@ -245,9 +245,12 @@ def callable_ref_to_egg(self, ref: CallableRef) -> tuple[str, bool]:
245245
case ConstantDecl(tp, _):
246246
# Use constructor decleration instead of constant b/c constants cannot be extracted
247247
# https://github.com/egraphs-good/egglog/issues/334
248-
self.egraph.run_program(
249-
bindings.Constructor(span(), egg_name, bindings.Schema([], self.type_ref_to_egg(tp)), None, False)
250-
)
248+
is_function = self.__egg_decls__._classes[tp.name].builtin
249+
schema = bindings.Schema([], self.type_ref_to_egg(tp))
250+
if is_function:
251+
self.egraph.run_program(bindings.FunctionCommand(span(), egg_name, schema, None))
252+
else:
253+
self.egraph.run_program(bindings.Constructor(span(), egg_name, schema, None, False))
251254
case FunctionDecl(signature, builtin, _, merge):
252255
if isinstance(signature, FunctionSignature):
253256
reverse_args = signature.reverse_args

python/egglog/runtime.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
"RuntimeClass",
3737
"RuntimeExpr",
3838
"RuntimeFunction",
39+
"create_callable",
3940
"define_expr_method",
4041
"resolve_callable",
4142
"resolve_type_annotation",
@@ -340,7 +341,7 @@ def __repr__(self) -> str:
340341

341342
# Make hashable so can go in Union
342343
def __hash__(self) -> int:
343-
return hash((id(self.__egg_decls_thunk__), self.__egg_tp__))
344+
return hash(self.__egg_tp__)
344345

345346
def __eq__(self, other: object) -> bool:
346347
"""
@@ -478,6 +479,9 @@ def __str__(self) -> str:
478479
bound_tp_params = args
479480
return pretty_callable_ref(self.__egg_decls__, self.__egg_ref__, first_arg, bound_tp_params)
480481

482+
def __repr__(self) -> str:
483+
return str(self)
484+
481485

482486
def to_py_signature(sig: FunctionSignature, decls: Declarations, optional_args: bool) -> Signature:
483487
"""
@@ -670,11 +674,12 @@ def resolve_callable(callable: object) -> tuple[CallableRef, Declarations]:
670674
"""
671675
Resolves a runtime callable into a ref
672676
"""
677+
# TODO: Make runtime class work with __match_args__
678+
if isinstance(callable, RuntimeClass):
679+
return InitRef(callable.__egg_tp__.name), callable.__egg_decls__
673680
match callable:
674681
case RuntimeFunction(decls, ref, _):
675682
return ref(), decls()
676-
case RuntimeClass(thunk, tp):
677-
return InitRef(tp.name), thunk()
678683
case RuntimeExpr(decl_thunk, expr_thunk):
679684
if not isinstance((expr := expr_thunk().expr), CallDecl) or not isinstance(
680685
expr.callable, ConstantRef | ClassVariableRef
@@ -683,3 +688,21 @@ def resolve_callable(callable: object) -> tuple[CallableRef, Declarations]:
683688
return expr.callable, decl_thunk()
684689
case _:
685690
raise NotImplementedError(f"Cannot turn {callable} of type {type(callable)} into a callable ref")
691+
692+
693+
def create_callable(decls: Declarations, ref: CallableRef) -> RuntimeFunction | RuntimeExpr:
694+
"""
695+
Creates a callable object from a callable ref. This might not actually be callable, if the ref is a constant
696+
or classvar then it is a value
697+
"""
698+
match ref:
699+
case InitRef(name):
700+
return RuntimeClass(Thunk.value(decls), TypeRefWithVars(name))
701+
case FunctionRef() | MethodRef() | ClassMethodRef() | PropertyRef() | UnnamedFunctionRef():
702+
bound = JustTypeRef(ref.class_name) if isinstance(ref, ClassMethodRef) else None
703+
return RuntimeFunction(Thunk.value(decls), Thunk.value(ref), bound)
704+
case ConstantRef(name):
705+
tp = decls._constants[name].type_ref
706+
case ClassVariableRef(cls_name, var_name):
707+
tp = decls._classes[cls_name].class_variables[var_name].type_ref
708+
return RuntimeExpr.__from_values__(decls, TypedExprDecl(tp, CallDecl(ref)))

python/tests/test_high_level.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -987,3 +987,64 @@ def f(x: StringLike) -> A: ...
987987
@pytest.mark.parametrize("name", [f.stem for f in EXAMPLE_FILES if f.stem != "__init__"])
988988
def test_example(name):
989989
importlib.import_module(f"egglog.examples.{name}")
990+
991+
992+
@function
993+
def f() -> i64: ...
994+
995+
996+
class E(Expr):
997+
X: ClassVar[i64]
998+
999+
def __init__(self) -> None: ...
1000+
def m(self) -> i64: ...
1001+
1002+
@property
1003+
def p(self) -> i64: ...
1004+
1005+
@classmethod
1006+
def cm(cls) -> i64: ...
1007+
1008+
1009+
egraph = EGraph()
1010+
1011+
C = constant("C", i64)
1012+
1013+
zero = i64(0)
1014+
egraph.register(
1015+
set_(f()).to(zero),
1016+
set_(E().m()).to(zero),
1017+
set_(E.X).to(zero),
1018+
set_(E().p).to(zero),
1019+
set_(C).to(zero),
1020+
set_(E.cm()).to(zero),
1021+
)
1022+
1023+
1024+
@pytest.mark.parametrize(
1025+
"c",
1026+
[
1027+
pytest.param(E, id="init"),
1028+
pytest.param(f, id="function"),
1029+
pytest.param(E.m, id="method"),
1030+
pytest.param(E.X, id="class var"),
1031+
pytest.param(E.p, id="property"),
1032+
pytest.param(C, id="constant"),
1033+
pytest.param(E.cm, id="class method"),
1034+
],
1035+
)
1036+
def test_function_size(c):
1037+
assert egraph.function_size(c) == 1
1038+
1039+
1040+
def test_all_function_size():
1041+
res = egraph.all_function_sizes()
1042+
assert set(res) == {
1043+
(E, 1),
1044+
(f, 1),
1045+
(E.m, 1),
1046+
(E.X, 1),
1047+
(E.p, 1),
1048+
(C, 1),
1049+
(E.cm, 1),
1050+
}

0 commit comments

Comments
 (0)