Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/egglog/egraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1241,10 +1241,10 @@ def register(

def _register_commands(self, cmds: list[Command]) -> None:
self._add_decls(*cmds)
egg_cmds = list(map(self._command_to_egg, cmds))
egg_cmds = [egg_cmd for cmd in cmds if (egg_cmd := self._command_to_egg(cmd)) is not None]
self._egraph.run_program(*egg_cmds)

def _command_to_egg(self, cmd: Command) -> bindings._Command:
def _command_to_egg(self, cmd: Command) -> bindings._Command | None:
ruleset_name = ""
cmd_decl: CommandDecl
match cmd:
Expand Down
32 changes: 23 additions & 9 deletions python/egglog/egraph_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import re
from collections import defaultdict
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, overload
from typing import TYPE_CHECKING, Literal, overload

from typing_extensions import assert_never

Expand Down Expand Up @@ -116,7 +116,8 @@ def ruleset_to_egg(self, name: str) -> None:
if rule in added_rules:
continue
cmd = self.command_to_egg(rule, name)
self.egraph.run_program(cmd)
if cmd is not None:
self.egraph.run_program(cmd)
added_rules.add(rule)
case CombinedRulesetDecl(rulesets):
if name in self.rulesets:
Expand All @@ -126,10 +127,13 @@ def ruleset_to_egg(self, name: str) -> None:
self.ruleset_to_egg(ruleset)
self.egraph.run_program(bindings.UnstableCombinedRuleset(name, list(rulesets)))

def command_to_egg(self, cmd: CommandDecl, ruleset: str) -> bindings._Command:
def command_to_egg(self, cmd: CommandDecl, ruleset: str) -> bindings._Command | None:
match cmd:
case ActionCommandDecl(action):
return bindings.ActionCommand(self.action_to_egg(action, expr_to_let=True))
action_egg = self.action_to_egg(action, expr_to_let=True)
if not action_egg:
return None
return bindings.ActionCommand(action_egg)
case RewriteDecl(tp, lhs, rhs, conditions) | BiRewriteDecl(tp, lhs, rhs, conditions):
self.type_ref_to_egg(tp)
rewrite = bindings.Rewrite(
Expand Down Expand Up @@ -166,7 +170,13 @@ def command_to_egg(self, cmd: CommandDecl, ruleset: str) -> bindings._Command:
case _:
assert_never(cmd)

def action_to_egg(self, action: ActionDecl, expr_to_let: bool = False) -> bindings._Action:
@overload
def action_to_egg(self, action: ActionDecl) -> bindings._Action: ...

@overload
def action_to_egg(self, action: ActionDecl, expr_to_let: Literal[True] = ...) -> bindings._Action | None: ...

def action_to_egg(self, action: ActionDecl, expr_to_let: bool = False) -> bindings._Action | None: # noqa: C901, PLR0911, PLR0912
match action:
case LetDecl(name, typed_expr):
var_decl = VarDecl(name, True)
Expand All @@ -179,7 +189,11 @@ def action_to_egg(self, action: ActionDecl, expr_to_let: bool = False) -> bindin
return bindings.Set(span(), call_.name, call_.args, self._expr_to_egg(rhs))
case ExprActionDecl(typed_expr):
if expr_to_let:
typed_expr = self._transform_let(typed_expr)
maybe_typed_expr = self._transform_let(typed_expr)
if maybe_typed_expr:
typed_expr = maybe_typed_expr
else:
return None
return bindings.Expr_(span(), self.typed_expr_to_egg(typed_expr))
case ChangeDecl(tp, call, change):
self.type_ref_to_egg(tp)
Expand Down Expand Up @@ -351,13 +365,13 @@ def typed_expr_to_egg(self, typed_expr_decl: TypedExprDecl, transform_let: bool
self.type_ref_to_egg(typed_expr_decl.tp)
return self._expr_to_egg(typed_expr_decl.expr)

def _transform_let(self, typed_expr: TypedExprDecl) -> TypedExprDecl:
def _transform_let(self, typed_expr: TypedExprDecl) -> TypedExprDecl | None:
"""
Rewrites this expression as a let binding if it's not already a let binding.
"""
var_decl = VarDecl(f"__expr_{hash(typed_expr)}", True)
if var_decl in self.expr_to_egg_cache:
return TypedExprDecl(typed_expr.tp, var_decl)
return None
var_egg = self._expr_to_egg(var_decl)
cmd = bindings.ActionCommand(bindings.Let(span(), var_egg.name, self.typed_expr_to_egg(typed_expr)))
try:
Expand All @@ -367,7 +381,7 @@ def _transform_let(self, typed_expr: TypedExprDecl) -> TypedExprDecl:
return typed_expr
self.expr_to_egg_cache[typed_expr.expr] = var_egg
self.expr_to_egg_cache[var_decl] = var_egg
return TypedExprDecl(typed_expr.tp, var_decl)
return None

@overload
def _expr_to_egg(self, expr_decl: CallDecl) -> bindings.Call: ...
Expand Down
34 changes: 17 additions & 17 deletions python/egglog/exp/array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,8 +290,8 @@ def check_index(length: IntLike, idx: IntLike) -> Int:
"""
Returns the index if 0 <= idx < length, otherwise returns Int.NEVER
"""
length = cast(Int, length)
idx = cast(Int, idx)
length = cast("Int", length)
idx = cast("Int", idx)
return Int.if_(((idx >= 0) & (idx < length)), idx, Int.NEVER)


Expand Down Expand Up @@ -336,7 +336,7 @@ def abs(self) -> Float: ...

@method(cost=2)
@classmethod
def rational(cls, r: Rational) -> Float: ...
def rational(cls, r: BigRat) -> Float: ...

@classmethod
def from_int(cls, i: IntLike) -> Float: ...
Expand All @@ -362,15 +362,15 @@ def __eq__(self, other: FloatLike) -> Boolean: ... # type: ignore[override]


@array_api_ruleset.register
def _float(fl: Float, f: f64, f2: f64, i: i64, r: Rational, r1: Rational):
def _float(fl: Float, f: f64, f2: f64, i: i64, r: BigRat, r1: BigRat):
return [
rule(eq(fl).to(Float(f))).then(set_(fl.to_f64).to(f)),
rewrite(Float(f).abs()).to(Float(f), f >= 0.0),
rewrite(Float(f).abs()).to(Float(-f), f < 0.0),
# Convert from float to rationl, if its a whole number i.e. can be converted to int
rewrite(Float(f)).to(Float.rational(Rational(f.to_i64(), 1)), eq(f64.from_i64(f.to_i64())).to(f)),
rewrite(Float(f)).to(Float.rational(BigRat(f.to_i64(), 1)), eq(f64.from_i64(f.to_i64())).to(f)),
# always convert from int to rational
rewrite(Float.from_int(Int(i))).to(Float.rational(Rational(i, 1))),
rewrite(Float.from_int(Int(i))).to(Float.rational(BigRat(i, 1))),
rewrite(Float(f) + Float(f2)).to(Float(f + f2)),
rewrite(Float(f) - Float(f2)).to(Float(f - f2)),
rewrite(Float(f) * Float(f2)).to(Float(f * f2)),
Expand Down Expand Up @@ -417,7 +417,7 @@ def range(cls, stop: IntLike) -> TupleInt:
def from_vec(cls, vec: VecLike[Int, IntLike]) -> TupleInt: ...

def __add__(self, other: TupleIntLike) -> TupleInt:
other = cast(TupleInt, other)
other = cast("TupleInt", other)
return TupleInt(
self.length() + other.length(), lambda i: Int.if_(i < self.length(), self[i], other[i - self.length()])
)
Expand Down Expand Up @@ -475,14 +475,14 @@ def select(self, indices: TupleIntLike) -> TupleInt:
"""
Return a new tuple with the elements at the given indices
"""
indices = cast(TupleInt, indices)
indices = cast("TupleInt", indices)
return indices.map(lambda i: self[i])

def deselect(self, indices: TupleIntLike) -> TupleInt:
"""
Return a new tuple with the elements not at the given indices
"""
indices = cast(TupleInt, indices)
indices = cast("TupleInt", indices)
return TupleInt.range(self.length()).filter(lambda i: ~indices.contains(i)).map(lambda i: self[i])


Expand Down Expand Up @@ -554,7 +554,7 @@ def __init__(self, length: IntLike, idx_fn: Callable[[Int], TupleInt]) -> None:
@method(subsume=True)
@classmethod
def single(cls, i: TupleIntLike) -> TupleTupleInt:
i = cast(TupleInt, i)
i = cast("TupleInt", i)
return TupleTupleInt(1, lambda _: i)

@method(subsume=True)
Expand All @@ -564,7 +564,7 @@ def from_vec(cls, vec: Vec[TupleInt]) -> TupleTupleInt: ...
def append(self, i: TupleIntLike) -> TupleTupleInt: ...

def __add__(self, other: TupleTupleIntLike) -> TupleTupleInt:
other = cast(TupleTupleInt, other)
other = cast("TupleTupleInt", other)
return TupleTupleInt(
self.length() + other.length(),
lambda i: TupleInt.if_(i < self.length(), self[i], other[i - self.length()]),
Expand Down Expand Up @@ -840,7 +840,7 @@ def _value(i: Int, f: Float, b: Boolean, v: Value, v1: Value, i1: Int, f1: Float

yield rewrite(Value.float(f).sqrt()).to(Value.float(f ** (0.5)))

yield rewrite(Value.float(Float.rational(Rational(0, 1))) + v).to(v)
yield rewrite(Value.float(Float.rational(BigRat(0, 1))) + v).to(v)

yield rewrite(Value.if_(TRUE, v, v1)).to(v)
yield rewrite(Value.if_(FALSE, v, v1)).to(v1)
Expand All @@ -862,7 +862,7 @@ def append(self, i: ValueLike) -> TupleValue: ...
def from_vec(cls, vec: Vec[Value]) -> TupleValue: ...

def __add__(self, other: TupleValueLike) -> TupleValue:
other = cast(TupleValue, other)
other = cast("TupleValue", other)
return TupleValue(
self.length() + other.length(),
lambda i: Value.if_(i < self.length(), self[i], other[i - self.length()]),
Expand All @@ -875,13 +875,13 @@ def __getitem__(self, i: Int) -> Value: ...
def foldl_boolean(self, f: Callable[[Boolean, Value], Boolean], init: BooleanLike) -> Boolean: ...

def contains(self, value: ValueLike) -> Boolean:
value = cast(Value, value)
value = cast("Value", value)
return self.foldl_boolean(lambda acc, j: acc | (value == j), FALSE)

@method(subsume=True)
@classmethod
def from_tuple_int(cls, ti: TupleIntLike) -> TupleValue:
ti = cast(TupleInt, ti)
ti = cast("TupleInt", ti)
return TupleValue(ti.length(), lambda i: Value.int(ti[i]))


Expand Down Expand Up @@ -1259,7 +1259,7 @@ def append(self, i: NDArrayLike) -> TupleNDArray: ...
def from_vec(cls, vec: Vec[NDArray]) -> TupleNDArray: ...

def __add__(self, other: TupleNDArrayLike) -> TupleNDArray:
other = cast(TupleNDArray, other)
other = cast("TupleNDArray", other)
return TupleNDArray(
self.length() + other.length(),
lambda i: NDArray.if_(i < self.length(), self[i], other[i - self.length()]),
Expand Down Expand Up @@ -1632,7 +1632,7 @@ def ndindex(shape: TupleIntLike) -> TupleTupleInt:
"""
https://numpy.org/doc/stable/reference/generated/numpy.ndindex.html
"""
shape = cast(TupleInt, shape)
shape = cast("TupleInt", shape)
return shape.map_tuple_int(TupleInt.range).product()


Expand Down
36 changes: 26 additions & 10 deletions python/egglog/exp/array_api_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@
from collections.abc import Callable
from typing import TypeVar, cast

import numpy as np

from egglog import EGraph, try_evaling
from egglog.exp.array_api import NDArray
from egglog.exp.array_api_numba import array_api_numba_schedule
from egglog.exp.array_api_program_gen import array_api_program_gen_schedule, ndarray_function_two
from egglog.exp.array_api_program_gen import EvalProgram, array_api_program_gen_schedule, ndarray_function_two_program

from .program_gen import Program

X = TypeVar("X", bound=Callable)

Expand All @@ -14,15 +18,27 @@ def jit(fn: X) -> X:
"""
Jit compiles a function
"""
sig = inspect.signature(fn)
arg1, arg2 = sig.parameters.keys()
egraph = EGraph()
egraph, res, res_optimized, program = function_to_program(fn, save_egglog_string=False)
fn_program = EvalProgram(program, {"np": np})
with egraph.set_current():
res = fn(NDArray.var(arg1), NDArray.var(arg2))
res_optimized = egraph.simplify(res, array_api_numba_schedule)

fn_program = ndarray_function_two(res_optimized, NDArray.var(arg1), NDArray.var(arg2))
fn = try_evaling(array_api_program_gen_schedule, fn_program, fn_program.as_py_object)
fn = cast("X", try_evaling(array_api_program_gen_schedule, fn_program, fn_program.as_py_object))
fn.initial_expr = res # type: ignore[attr-defined]
fn.expr = res_optimized # type: ignore[attr-defined]
return cast(X, fn)
return fn


def function_to_program(fn: Callable, save_egglog_string: bool) -> tuple[EGraph, NDArray, NDArray, Program]:
sig = inspect.signature(fn)
arg1, arg2 = sig.parameters.keys()
egraph = EGraph(save_egglog_string=save_egglog_string)
with egraph:
with egraph.set_current():
res = fn(NDArray.var(arg1), NDArray.var(arg2))
res_optimized = egraph.simplify(res, array_api_numba_schedule)

return (
egraph,
res,
res_optimized,
ndarray_function_two_program(res_optimized, NDArray.var(arg1), NDArray.var(arg2)),
)
12 changes: 6 additions & 6 deletions python/egglog/exp/array_api_program_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@
array_api_program_gen_ruleset = ruleset(name="array_api_program_gen_ruleset")
array_api_program_gen_eval_ruleset = ruleset(name="array_api_program_gen_eval_ruleset")

array_api_program_gen_schedule = (
array_api_program_gen_combined_ruleset = (
array_api_program_gen_ruleset
| program_gen_ruleset
| array_api_program_gen_eval_ruleset
| eval_program_rulseset
| array_api_vec_to_cons_ruleset
).saturate()
)
array_api_program_gen_schedule = (array_api_program_gen_combined_ruleset | eval_program_rulseset).saturate()


@function
Expand Down Expand Up @@ -116,7 +116,7 @@ def float_program(x: Float) -> Program: ...


@array_api_program_gen_ruleset.register
def _float_program(f: Float, g: Float, f64_: f64, i: Int, r: Rational):
def _float_program(f: Float, g: Float, f64_: f64, i: Int, r: BigRat):
yield rewrite(float_program(Float(f64_))).to(Program(f64_.to_string()))
yield rewrite(float_program(f.abs())).to(Program("np.abs(") + float_program(f) + ")")
yield rewrite(float_program(Float.from_int(i))).to(int_program(i))
Expand All @@ -126,10 +126,10 @@ def _float_program(f: Float, g: Float, f64_: f64, i: Int, r: Rational):
yield rewrite(float_program(f / g)).to(Program("(") + float_program(f) + " / " + float_program(g) + ")")
yield rewrite(float_program(Float.rational(r))).to(
Program("float(") + Program(r.numer.to_string()) + " / " + Program(r.denom.to_string()) + ")",
ne(r.denom).to(i64(1)),
ne(r.denom).to(BigInt(1)),
)
yield rewrite(float_program(Float.rational(r))).to(
Program("float(") + Program(r.numer.to_string()) + ")", eq(r.denom).to(i64(1))
Program("float(") + Program(r.numer.to_string()) + ")", eq(r.denom).to(BigInt(1))
)


Expand Down
17 changes: 13 additions & 4 deletions python/tests/__snapshots__/test_array_api/test_jit[lda][expr].py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
)
),
DType.float64,
) / NDArray.scalar(Value.float(Float.rational(Rational(150, 1))))
) / NDArray.scalar(Value.float(Float.rational(BigRat(BigInt.from_string("150"), BigInt.from_string("1")))))
_NDArray_4 = zeros(TupleInt.from_vec(Vec[Int](Int(3), Int(4))), OptionalDType.some(DType.float64), OptionalDevice.some(_NDArray_1.device))
_MultiAxisIndexKeyItem_1 = MultiAxisIndexKeyItem.slice(Slice())
_IndexKey_1 = IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec[MultiAxisIndexKeyItem](MultiAxisIndexKeyItem.int(Int(0)), _MultiAxisIndexKeyItem_1)))
Expand All @@ -36,15 +36,24 @@
_NDArray_9 = square(_NDArray_8 - expand_dims(sum(_NDArray_8, _OptionalIntOrTuple_1) / NDArray.scalar(Value.int(_NDArray_8.shape[Int(0)]))))
_NDArray_10 = sqrt(sum(_NDArray_9, _OptionalIntOrTuple_1) / NDArray.scalar(Value.int(_NDArray_9.shape[Int(0)])))
_NDArray_11 = copy(_NDArray_10)
_NDArray_11[IndexKey.ndarray(_NDArray_10 == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar(Value.float(Float.rational(Rational(1, 1))))
_TupleNDArray_1 = svd(sqrt(asarray(NDArray.scalar(Value.float(Float.rational(Rational(1, 147)))), OptionalDType.some(DType.float64))) * (_NDArray_8 / _NDArray_11), Boolean(False))
_NDArray_11[IndexKey.ndarray(_NDArray_10 == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar(
Value.float(Float.rational(BigRat(BigInt.from_string("1"), BigInt.from_string("1"))))
)
_TupleNDArray_1 = svd(
sqrt(asarray(NDArray.scalar(Value.float(Float.rational(BigRat(BigInt.from_string("1"), BigInt.from_string("147"))))), OptionalDType.some(DType.float64)))
* (_NDArray_8 / _NDArray_11),
Boolean(False),
)
_Slice_1 = Slice(OptionalInt.none, OptionalInt.some(sum(astype(_TupleNDArray_1[Int(1)] > NDArray.scalar(Value.float(Float(0.0001))), DType.int32)).to_value().to_int))
_NDArray_12 = (
_TupleNDArray_1[Int(2)][IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec[MultiAxisIndexKeyItem](MultiAxisIndexKeyItem.slice(_Slice_1), _MultiAxisIndexKeyItem_1)))]
/ _NDArray_11
).T / _TupleNDArray_1[Int(1)][IndexKey.slice(_Slice_1)]
_TupleNDArray_2 = svd(
(sqrt((NDArray.scalar(Value.int(Int(150))) * _NDArray_3) * NDArray.scalar(Value.float(Float.rational(Rational(1, 2))))) * (_NDArray_4 - (_NDArray_3 @ _NDArray_4)).T).T
(
sqrt((NDArray.scalar(Value.int(Int(150))) * _NDArray_3) * NDArray.scalar(Value.float(Float.rational(BigRat(BigInt.from_string("1"), BigInt.from_string("2"))))))
* (_NDArray_4 - (_NDArray_3 @ _NDArray_4)).T
).T
@ _NDArray_12,
Boolean(False),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,14 @@
OptionalInt.some(Int(0)),
)
_NDArray_6 = std(_NDArray_5, _OptionalIntOrTuple_1)
_NDArray_6[IndexKey.ndarray(std(_NDArray_5, _OptionalIntOrTuple_1) == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar(Value.float(Float.rational(Rational(1, 1))))
_TupleNDArray_1 = svd(sqrt(asarray(NDArray.scalar(Value.float(Float.rational(Rational(1, 147)))), OptionalDType.some(DType.float64))) * (_NDArray_5 / _NDArray_6), Boolean(False))
_NDArray_6[IndexKey.ndarray(std(_NDArray_5, _OptionalIntOrTuple_1) == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar(
Value.float(Float.rational(BigRat(BigInt.from_string("1"), BigInt.from_string("1"))))
)
_TupleNDArray_1 = svd(
sqrt(asarray(NDArray.scalar(Value.float(Float.rational(BigRat(BigInt.from_string("1"), BigInt.from_string("147"))))), OptionalDType.some(DType.float64)))
* (_NDArray_5 / _NDArray_6),
Boolean(False),
)
_Slice_1 = Slice(OptionalInt.none, OptionalInt.some(sum(astype(_TupleNDArray_1[Int(1)] > NDArray.scalar(Value.float(Float(0.0001))), DType.int32)).to_value().to_int))
_NDArray_7 = asarray(reshape(asarray(_NDArray_2), TupleInt.from_vec(Vec[Int](Int(-1)))))
_NDArray_8 = unique_values(concat(TupleNDArray.from_vec(Vec[NDArray](unique_values(asarray(_NDArray_7))))))
Expand All @@ -61,9 +67,12 @@
)
_NDArray_10 = copy(_NDArray_9)
_NDArray_10[IndexKey.ndarray(_NDArray_9 == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar(Value.float(Float(1.0)))
_NDArray_11 = astype(unique_counts(_NDArray_2)[Int(1)], DType.float64) / NDArray.scalar(Value.float(Float.rational(Rational(150, 1))))
_NDArray_11 = astype(unique_counts(_NDArray_2)[Int(1)], DType.float64) / NDArray.scalar(Value.float(Float.rational(BigRat(BigInt.from_string("150"), BigInt.from_string("1")))))
_TupleNDArray_2 = svd(
(sqrt((NDArray.scalar(Value.int(Int(150))) * _NDArray_11) * NDArray.scalar(Value.float(Float.rational(Rational(1, 2))))) * (_NDArray_4 - (_NDArray_11 @ _NDArray_4)).T).T
(
sqrt((NDArray.scalar(Value.int(Int(150))) * _NDArray_11) * NDArray.scalar(Value.float(Float.rational(BigRat(BigInt.from_string("1"), BigInt.from_string("2"))))))
* (_NDArray_4 - (_NDArray_11 @ _NDArray_4)).T
).T
@ (
(
_TupleNDArray_1[Int(2)][IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec[MultiAxisIndexKeyItem](MultiAxisIndexKeyItem.slice(_Slice_1), _MultiAxisIndexKeyItem_1)))]
Expand Down
Loading