Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions docs/reference/egglog-translation.md
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,21 @@ with egraph:
egraph.check_fail(eq(Math(0)).to(Math(1)))
```

## Function Sizes

The `(print-size <functon name>?)` command is translated into either `egraph.function_size(fn)` to get the number of
rows of one function or `egraph.all_function_sizes()` to get a list of all the function sizes:

```{code-cell} python
# (function-size Math)
egraph.function_size(Math)
```

```{code-cell} python
# (function-size)
egraph.all_function_sizes()
```

## Include

The `(include <path>)` command is used to add modularity, by allowing you to pull in the source from another egglog file into the current file.
Expand Down
35 changes: 33 additions & 2 deletions python/egglog/egraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,7 +805,7 @@ class GraphvizKwargs(TypedDict, total=False):
max_calls_per_function: int | None
n_inline_leaves: int
split_primitive_outputs: bool
split_functions: list[object]
split_functions: list[ExprCallable]
include_temporary_functions: bool


Expand Down Expand Up @@ -854,7 +854,7 @@ def input(self, fn: Callable[..., String], path: str) -> None:
"""
self._egraph.run_program(bindings.Input(span(1), self._callable_to_egg(fn), path))

def _callable_to_egg(self, fn: object) -> str:
def _callable_to_egg(self, fn: ExprCallable) -> str:
ref, decls = resolve_callable(fn)
self._add_decls(decls)
return self._state.callable_ref_to_egg(ref)[0]
Expand Down Expand Up @@ -1179,6 +1179,37 @@ def _command_to_egg(self, cmd: Command) -> bindings._Command | None:
assert_never(cmd)
return self._state.command_to_egg(cmd_decl, ruleset_name)

def function_size(self, fn: ExprCallable) -> int:
"""
Returns the number of rows in a certain function
"""
egg_name = self._callable_to_egg(fn)
(output,) = self._egraph.run_program(bindings.PrintSize(span(1), egg_name))
assert isinstance(output, bindings.PrintFunctionSize)
return output.size

def all_function_sizes(self) -> list[tuple[ExprCallable, int]]:
"""
Returns a list of all functions and their sizes.
"""
(output,) = self._egraph.run_program(bindings.PrintSize(span(1), None))
assert isinstance(output, bindings.PrintAllFunctionsSize)
return [
(
cast(
"ExprCallable",
create_callable(self._state.__egg_decls__, next(iter(refs))),
),
size,
)
for (name, size) in output.sizes
if (refs := self._state.egg_fn_to_callable_refs[name])
]


# Either a constant or a function.
ExprCallable: TypeAlias = Callable[..., BaseExpr] | BaseExpr


@dataclass(frozen=True)
class _WrappedMethod:
Expand Down
11 changes: 7 additions & 4 deletions python/egglog/egraph_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def fact_to_egg(self, fact: FactDecl) -> bindings._Fact:
case _:
assert_never(fact)

def callable_ref_to_egg(self, ref: CallableRef) -> tuple[str, bool]:
def callable_ref_to_egg(self, ref: CallableRef) -> tuple[str, bool]: # noqa: C901, PLR0912
"""
Returns the egg function name for a callable reference, registering it if it is not already registered.

Expand All @@ -245,9 +245,12 @@ def callable_ref_to_egg(self, ref: CallableRef) -> tuple[str, bool]:
case ConstantDecl(tp, _):
# Use constructor decleration instead of constant b/c constants cannot be extracted
# https://github.com/egraphs-good/egglog/issues/334
self.egraph.run_program(
bindings.Constructor(span(), egg_name, bindings.Schema([], self.type_ref_to_egg(tp)), None, False)
)
is_function = self.__egg_decls__._classes[tp.name].builtin
schema = bindings.Schema([], self.type_ref_to_egg(tp))
if is_function:
self.egraph.run_program(bindings.FunctionCommand(span(), egg_name, schema, None))
else:
self.egraph.run_program(bindings.Constructor(span(), egg_name, schema, None, False))
case FunctionDecl(signature, builtin, _, merge):
if isinstance(signature, FunctionSignature):
reverse_args = signature.reverse_args
Expand Down
29 changes: 26 additions & 3 deletions python/egglog/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
"RuntimeClass",
"RuntimeExpr",
"RuntimeFunction",
"create_callable",
"define_expr_method",
"resolve_callable",
"resolve_type_annotation",
Expand Down Expand Up @@ -340,7 +341,7 @@ def __repr__(self) -> str:

# Make hashable so can go in Union
def __hash__(self) -> int:
return hash((id(self.__egg_decls_thunk__), self.__egg_tp__))
return hash(self.__egg_tp__)

def __eq__(self, other: object) -> bool:
"""
Expand Down Expand Up @@ -478,6 +479,9 @@ def __str__(self) -> str:
bound_tp_params = args
return pretty_callable_ref(self.__egg_decls__, self.__egg_ref__, first_arg, bound_tp_params)

def __repr__(self) -> str:
return str(self)


def to_py_signature(sig: FunctionSignature, decls: Declarations, optional_args: bool) -> Signature:
"""
Expand Down Expand Up @@ -670,11 +674,12 @@ def resolve_callable(callable: object) -> tuple[CallableRef, Declarations]:
"""
Resolves a runtime callable into a ref
"""
# TODO: Make runtime class work with __match_args__
if isinstance(callable, RuntimeClass):
return InitRef(callable.__egg_tp__.name), callable.__egg_decls__
match callable:
case RuntimeFunction(decls, ref, _):
return ref(), decls()
case RuntimeClass(thunk, tp):
return InitRef(tp.name), thunk()
case RuntimeExpr(decl_thunk, expr_thunk):
if not isinstance((expr := expr_thunk().expr), CallDecl) or not isinstance(
expr.callable, ConstantRef | ClassVariableRef
Expand All @@ -683,3 +688,21 @@ def resolve_callable(callable: object) -> tuple[CallableRef, Declarations]:
return expr.callable, decl_thunk()
case _:
raise NotImplementedError(f"Cannot turn {callable} of type {type(callable)} into a callable ref")


def create_callable(decls: Declarations, ref: CallableRef) -> RuntimeFunction | RuntimeExpr:
"""
Creates a callable object from a callable ref. This might not actually be callable, if the ref is a constant
or classvar then it is a value
"""
match ref:
case InitRef(name):
return RuntimeClass(Thunk.value(decls), TypeRefWithVars(name))
case FunctionRef() | MethodRef() | ClassMethodRef() | PropertyRef() | UnnamedFunctionRef():
bound = JustTypeRef(ref.class_name) if isinstance(ref, ClassMethodRef) else None
return RuntimeFunction(Thunk.value(decls), Thunk.value(ref), bound)
case ConstantRef(name):
tp = decls._constants[name].type_ref
case ClassVariableRef(cls_name, var_name):
tp = decls._classes[cls_name].class_variables[var_name].type_ref
return RuntimeExpr.__from_values__(decls, TypedExprDecl(tp, CallDecl(ref)))
61 changes: 61 additions & 0 deletions python/tests/test_high_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -987,3 +987,64 @@ def f(x: StringLike) -> A: ...
@pytest.mark.parametrize("name", [f.stem for f in EXAMPLE_FILES if f.stem != "__init__"])
def test_example(name):
importlib.import_module(f"egglog.examples.{name}")


@function
def f() -> i64: ...


class E(Expr):
X: ClassVar[i64]

def __init__(self) -> None: ...
def m(self) -> i64: ...

@property
def p(self) -> i64: ...

@classmethod
def cm(cls) -> i64: ...


egraph = EGraph()

C = constant("C", i64)

zero = i64(0)
egraph.register(
set_(f()).to(zero),
set_(E().m()).to(zero),
set_(E.X).to(zero),
set_(E().p).to(zero),
set_(C).to(zero),
set_(E.cm()).to(zero),
)


@pytest.mark.parametrize(
"c",
[
pytest.param(E, id="init"),
pytest.param(f, id="function"),
pytest.param(E.m, id="method"),
pytest.param(E.X, id="class var"),
pytest.param(E.p, id="property"),
pytest.param(C, id="constant"),
pytest.param(E.cm, id="class method"),
],
)
def test_function_size(c):
assert egraph.function_size(c) == 1


def test_all_function_size():
res = egraph.all_function_sizes()
assert set(res) == {
(E, 1),
(f, 1),
(E.m, 1),
(E.X, 1),
(E.p, 1),
(C, 1),
(E.cm, 1),
}
Loading