Skip to content

Commit 04da652

Browse files
Change binary conversions to look at all possible conversion types
1 parent 80d1c07 commit 04da652

File tree

8 files changed

+84
-38
lines changed

8 files changed

+84
-38
lines changed

python/egglog/conversion.py

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -154,30 +154,42 @@ def min_binary_conversion(
154154
both of them can be converted. However, if both are egglog types, then only one of them can be converted.
155155
"""
156156
decls = retrieve_conversion_decls()
157-
# tuple of (cost, convert_self)
158-
best_method: tuple[int, Callable[[Any], RuntimeExpr]] | None = None
159-
# Start by checking if we have a LHS that matches exactly and a RHS which can be converted
160-
if (
161-
isinstance(lhs, JustTypeRef)
162-
and (desired_other_type := decls.check_binary_method_with_self_type(method_name, lhs))
163-
and (converter := CONVERSIONS.get((rhs, desired_other_type)))
164-
):
165-
best_method = (converter[0], lambda x: x)
166-
167-
# Next see if it's possible to convert the LHS and keep the RHS as is
168-
if isinstance(rhs, JustTypeRef):
169-
decls = retrieve_conversion_decls()
170-
for desired_self_type in decls.check_binary_method_with_other_type(method_name, rhs):
157+
# tuple of (cost, convert lhs, convert rhs)
158+
best_method: tuple[int, Callable[[Any], RuntimeExpr], Callable[[Any], RuntimeExpr]] | None = None
159+
160+
possible_lhs = _all_conversions_from(lhs) if isinstance(lhs, type) else [(0, lhs, identity)]
161+
possible_rhs = _all_conversions_from(rhs) if isinstance(rhs, type) else [(0, rhs, identity)]
162+
for lhs_cost, lhs_converted_type, lhs_convert in possible_lhs:
163+
# Start by checking if we have a LHS that matches exactly and a RHS which can be converted
164+
if (desired_other_type := decls.check_binary_method_with_self_type(method_name, lhs_converted_type)) and (
165+
converter := CONVERSIONS.get((rhs, desired_other_type))
166+
):
167+
cost = lhs_cost + converter[0]
168+
if best_method is None or best_method[0] > cost:
169+
best_method = (cost, lhs_convert, converter[1])
170+
171+
for rhs_cost, rhs_converted_type, rhs_convert in possible_rhs:
172+
# Next see if it's possible to convert the LHS and keep the RHS as is
173+
for desired_self_type in decls.check_binary_method_with_other_type(method_name, rhs_converted_type):
171174
if converter := CONVERSIONS.get((lhs, desired_self_type)):
172-
cost, convert_self = converter
175+
cost = rhs_cost + converter[0]
173176
if best_method is None or best_method[0] > cost:
174-
best_method = (cost, convert_self)
177+
best_method = (cost, converter[1], rhs_convert)
175178
if best_method is None:
176179
return None
177-
return best_method[1], best_method[1]
180+
return best_method[1], best_method[2]
178181

179182

180-
def identity(x: object) -> object:
183+
def _all_conversions_from(tp: JustTypeRef | type) -> list[tuple[int, JustTypeRef, Callable[[Any], RuntimeExpr]]]:
184+
"""
185+
Get all conversions from a type to other types.
186+
187+
Returns a list of tuples of (cost, target type, conversion function).
188+
"""
189+
return [(cost, target, fn) for (source, target), (cost, fn) in CONVERSIONS.items() if source == tp]
190+
191+
192+
def identity(x: Any) -> Any:
181193
return x
182194

183195

python/egglog/exp/array_api.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -368,8 +368,14 @@ def __add__(self, other: FloatLike) -> Float: ...
368368
def __sub__(self, other: FloatLike) -> Float: ...
369369

370370
def __pow__(self, other: FloatLike) -> Float: ...
371+
def __round__(self, ndigits: OptionalIntLike = None) -> Float: ...
371372

372373
def __eq__(self, other: FloatLike) -> Boolean: ... # type: ignore[override]
374+
def __ne__(self, other: FloatLike) -> Boolean: ... # type: ignore[override]
375+
def __lt__(self, other: FloatLike) -> Boolean: ...
376+
def __le__(self, other: FloatLike) -> Boolean: ...
377+
def __gt__(self, other: FloatLike) -> Boolean: ...
378+
def __ge__(self, other: FloatLike) -> Boolean: ...
373379

374380

375381
converter(float, Float, lambda x: Float(x))
@@ -380,9 +386,10 @@ def __eq__(self, other: FloatLike) -> Boolean: ... # type: ignore[override]
380386

381387

382388
@array_api_ruleset.register
383-
def _float(fl: Float, f: f64, f2: f64, i: i64, r: BigRat, r1: BigRat):
389+
def _float(fl: Float, f: f64, f2: f64, i: i64, r: BigRat, r1: BigRat, i_: Int):
384390
return [
385391
rule(eq(fl).to(Float(f))).then(set_(fl.to_f64).to(f)),
392+
rewrite(Float.from_int(Int(i))).to(Float(f64.from_i64(i))),
386393
rewrite(Float(f).abs()).to(Float(f), f >= 0.0),
387394
rewrite(Float(f).abs()).to(Float(-f), f < 0.0),
388395
# Convert from float to rationl, if its a whole number i.e. can be converted to int
@@ -397,11 +404,22 @@ def _float(fl: Float, f: f64, f2: f64, i: i64, r: BigRat, r1: BigRat):
397404
rewrite(Float.rational(r) - Float.rational(r1)).to(Float.rational(r - r1)),
398405
rewrite(Float.rational(r) * Float.rational(r1)).to(Float.rational(r * r1)),
399406
rewrite(Float(f) ** Float(f2)).to(Float(f**f2)),
400-
# ==
407+
# comparisons
401408
rewrite(Float(f) == Float(f)).to(TRUE),
402409
rewrite(Float(f) == Float(f2)).to(FALSE, ne(f).to(f2)),
410+
rewrite(Float(f) != Float(f2)).to(TRUE, f != f2),
411+
rewrite(Float(f) != Float(f)).to(FALSE),
412+
rewrite(Float(f) >= Float(f2)).to(TRUE, f >= f2),
413+
rewrite(Float(f) >= Float(f2)).to(FALSE, f < f2),
414+
rewrite(Float(f) <= Float(f2)).to(TRUE, f <= f2),
415+
rewrite(Float(f) <= Float(f2)).to(FALSE, f > f2),
416+
rewrite(Float(f) > Float(f2)).to(TRUE, f > f2),
417+
rewrite(Float(f) > Float(f2)).to(FALSE, f <= f2),
418+
rewrite(Float(f) < Float(f2)).to(TRUE, f < f2),
403419
rewrite(Float.rational(r) == Float.rational(r)).to(TRUE),
404420
rewrite(Float.rational(r) == Float.rational(r1)).to(FALSE, ne(r).to(r1)),
421+
# round
422+
rewrite(Float.rational(r).__round__()).to(Float.rational(r.round())),
405423
]
406424

407425

python/egglog/runtime.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,10 @@
6565
"__and__",
6666
"__xor__",
6767
"__or__",
68+
"__lt__",
69+
"__le__",
70+
"__gt__",
71+
"__ge__",
6872
}
6973

7074

@@ -85,10 +89,6 @@
8589
"__pos__",
8690
"__neg__",
8791
"__invert__",
88-
"__lt__",
89-
"__le__",
90-
"__gt__",
91-
"__ge__",
9292
"__round__",
9393
}
9494

@@ -656,7 +656,7 @@ def _numeric_binary_method(self: object, other: object, name: str = name, r_meth
656656

657657
if not best_method:
658658
raise RuntimeError(f"Cannot resolve {name} for {self} and {other}, no conversion found")
659-
self = best_method[0](self)
659+
self, other = best_method[0](self), best_method[1](other)
660660

661661
method_ref = MethodRef(self.__egg_class_name__, name)
662662
fn = RuntimeFunction(Thunk.value(self.__egg_decls__), Thunk.value(method_ref), self)

python/tests/__snapshots__/test_array_api/test_jit[lda][code].py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def __fn(X, y):
2525
_8[2, :,] = _14
2626
_15 = _7 @ _8
2727
_16 = X - _15
28-
_17 = np.sqrt(np.asarray(np.array((float(1) / 147)), np.dtype(np.float64)))
28+
_17 = np.sqrt(np.asarray(np.array(float(1 / 147)), np.dtype(np.float64)))
2929
_18 = X[_0] - _8[0, :,]
3030
_19 = X[_2] - _8[1, :,]
3131
_20 = X[_4] - _8[2, :,]
@@ -49,7 +49,7 @@ def __fn(X, y):
4949
_37 = _33[2][:_36, :,] / _29
5050
_38 = _37.T / _33[1][:_36]
5151
_39 = np.array(150) * _7
52-
_40 = _39 * np.array((float(1) / 2))
52+
_40 = _39 * np.array(float(1 / 2))
5353
_41 = np.sqrt(_40)
5454
_42 = _8 - _15
5555
_43 = _41 * _42.T

python/tests/__snapshots__/test_array_api/test_jit[lda][expr].py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,22 +30,31 @@
3030
_IndexKey_3 = IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec[MultiAxisIndexKeyItem](MultiAxisIndexKeyItem.int(Int(2)), _MultiAxisIndexKeyItem_1)))
3131
_NDArray_7 = _NDArray_1[IndexKey.ndarray(_NDArray_2 == NDArray.scalar(Value.int(Int(2))))]
3232
_NDArray_4[_IndexKey_3] = sum(_NDArray_7, _OptionalIntOrTuple_1) / NDArray.scalar(Value.int(_NDArray_7.shape[Int(0)]))
33-
_Value_1 = Value.float(Float.rational(BigRat(BigInt.from_string("1"), BigInt.from_string("1"))))
3433
_NDArray_8 = concat(
3534
TupleNDArray.from_vec(Vec[NDArray](_NDArray_5 - _NDArray_4[_IndexKey_1], _NDArray_6 - _NDArray_4[_IndexKey_2], _NDArray_7 - _NDArray_4[_IndexKey_3])), OptionalInt.some(Int(0))
3635
)
3736
_NDArray_9 = square(_NDArray_8 - expand_dims(sum(_NDArray_8, _OptionalIntOrTuple_1) / NDArray.scalar(Value.int(_NDArray_8.shape[Int(0)]))))
3837
_NDArray_10 = sqrt(sum(_NDArray_9, _OptionalIntOrTuple_1) / NDArray.scalar(Value.int(_NDArray_9.shape[Int(0)])))
3938
_NDArray_11 = copy(_NDArray_10)
40-
_NDArray_11[IndexKey.ndarray(_NDArray_10 == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar(_Value_1)
41-
_TupleNDArray_1 = svd(sqrt(asarray(NDArray.scalar(Value.int(_Value_1.to_int / Int(147))), OptionalDType.some(DType.float64))) * (_NDArray_8 / _NDArray_11), Boolean(False))
39+
_NDArray_11[IndexKey.ndarray(_NDArray_10 == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar(
40+
Value.float(Float.rational(BigRat(BigInt.from_string("1"), BigInt.from_string("1"))))
41+
)
42+
_TupleNDArray_1 = svd(
43+
sqrt(asarray(NDArray.scalar(Value.float(Float.rational(BigRat(BigInt.from_string("1"), BigInt.from_string("147"))))), OptionalDType.some(DType.float64)))
44+
* (_NDArray_8 / _NDArray_11),
45+
Boolean(False),
46+
)
4247
_Slice_1 = Slice(OptionalInt.none, OptionalInt.some(sum(astype(_TupleNDArray_1[Int(1)] > NDArray.scalar(Value.float(Float(0.0001))), DType.int32)).to_value().to_int))
4348
_NDArray_12 = (
4449
_TupleNDArray_1[Int(2)][IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec[MultiAxisIndexKeyItem](MultiAxisIndexKeyItem.slice(_Slice_1), _MultiAxisIndexKeyItem_1)))]
4550
/ _NDArray_11
4651
).T / _TupleNDArray_1[Int(1)][IndexKey.slice(_Slice_1)]
4752
_TupleNDArray_2 = svd(
48-
(sqrt((NDArray.scalar(Value.int(Int(150))) * _NDArray_3) * NDArray.scalar(Value.int(_Value_1.to_int / Int(2)))) * (_NDArray_4 - (_NDArray_3 @ _NDArray_4)).T).T @ _NDArray_12,
53+
(
54+
sqrt((NDArray.scalar(Value.int(Int(150))) * _NDArray_3) * NDArray.scalar(Value.float(Float.rational(BigRat(BigInt.from_string("1"), BigInt.from_string("2"))))))
55+
* (_NDArray_4 - (_NDArray_3 @ _NDArray_4)).T
56+
).T
57+
@ _NDArray_12,
4958
Boolean(False),
5059
)
5160
(

python/tests/__snapshots__/test_array_api/test_jit[lda][initial_expr].py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
_IndexKey_5 = IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec(MultiAxisIndexKeyItem.int(Int(2)), _MultiAxisIndexKeyItem_1)))
2424
_IndexKey_6 = IndexKey.ndarray(unique_inverse(_NDArray_2)[Int(1)] == NDArray.scalar(Value.int(Int(2))))
2525
_NDArray_3[_IndexKey_5] = mean(asarray(_NDArray_1)[_IndexKey_6], _OptionalIntOrTuple_1)
26-
_Value_1 = Value.float(Float.rational(BigRat(BigInt.from_string("1"), BigInt.from_string("1"))))
2726
_NDArray_4 = zeros(TupleInt.from_vec(Vec[Int](Int(3), Int(4))), OptionalDType.some(DType.float64), OptionalDevice.some(_NDArray_1.device))
2827
_IndexKey_7 = IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec[MultiAxisIndexKeyItem](MultiAxisIndexKeyItem.int(Int(0)), _MultiAxisIndexKeyItem_1)))
2928
_NDArray_4[_IndexKey_7] = mean(_NDArray_1[_IndexKey_2], _OptionalIntOrTuple_1)
@@ -42,8 +41,14 @@
4241
OptionalInt.some(Int(0)),
4342
)
4443
_NDArray_6 = std(_NDArray_5, _OptionalIntOrTuple_1)
45-
_NDArray_6[IndexKey.ndarray(std(_NDArray_5, _OptionalIntOrTuple_1) == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar(_Value_1)
46-
_TupleNDArray_1 = svd(sqrt(asarray(NDArray.scalar(Value.int(_Value_1.to_int / Int(147))), OptionalDType.some(DType.float64))) * (_NDArray_5 / _NDArray_6), Boolean(False))
44+
_NDArray_6[IndexKey.ndarray(std(_NDArray_5, _OptionalIntOrTuple_1) == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar(
45+
Value.float(Float.rational(BigRat(BigInt.from_string("1"), BigInt.from_string("1"))))
46+
)
47+
_TupleNDArray_1 = svd(
48+
sqrt(asarray(NDArray.scalar(Value.float(Float.rational(BigRat(BigInt.from_string("1"), BigInt.from_string("147"))))), OptionalDType.some(DType.float64)))
49+
* (_NDArray_5 / _NDArray_6),
50+
Boolean(False),
51+
)
4752
_Slice_1 = Slice(OptionalInt.none, OptionalInt.some(sum(astype(_TupleNDArray_1[Int(1)] > NDArray.scalar(Value.float(Float(0.0001))), DType.int32)).to_value().to_int))
4853
_NDArray_7 = asarray(reshape(asarray(_NDArray_2), TupleInt.from_vec(Vec[Int](Int(-1)))))
4954
_NDArray_8 = unique_values(concat(TupleNDArray.from_vec(Vec[NDArray](unique_values(asarray(_NDArray_7))))))
@@ -64,7 +69,10 @@
6469
_NDArray_10[IndexKey.ndarray(_NDArray_9 == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar(Value.float(Float(1.0)))
6570
_NDArray_11 = astype(unique_counts(_NDArray_2)[Int(1)], DType.float64) / NDArray.scalar(Value.float(Float.rational(BigRat(BigInt.from_string("150"), BigInt.from_string("1")))))
6671
_TupleNDArray_2 = svd(
67-
(sqrt((NDArray.scalar(Value.int(Int(150))) * _NDArray_11) * NDArray.scalar(Value.int(_Value_1.to_int / Int(2)))) * (_NDArray_4 - (_NDArray_11 @ _NDArray_4)).T).T
72+
(
73+
sqrt((NDArray.scalar(Value.int(Int(150))) * _NDArray_11) * NDArray.scalar(Value.float(Float.rational(BigRat(BigInt.from_string("1"), BigInt.from_string("2"))))))
74+
* (_NDArray_4 - (_NDArray_11 @ _NDArray_4)).T
75+
).T
6876
@ (
6977
(
7078
_TupleNDArray_1[Int(2)][IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec[MultiAxisIndexKeyItem](MultiAxisIndexKeyItem.slice(_Slice_1), _MultiAxisIndexKeyItem_1)))]

python/tests/test_array_api.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,7 @@
2828

2929

3030
def test_upcast_order():
31-
# verify that the following works by upcasting properly to floats
32-
assert Int(2) > round(0.5 * Int(2))
31+
assert Int(2) > round(0.5 * Int(2)) # type: ignore[operator]
3332

3433

3534
@function(ruleset=array_api_ruleset)

python/tests/test_high_level.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -947,7 +947,7 @@ def test_type_param_sub():
947947
assert str(V[Unit]) == str(Vec[Unit] | int) # type: ignore[misc]
948948

949949

950-
def test_override_hash(self):
950+
def test_override_hash():
951951
class A(Expr):
952952
def __init__(self) -> None: ...
953953

0 commit comments

Comments
 (0)