Skip to content

Commit 67126c6

Browse files
tmp working
1 parent f3b96c6 commit 67126c6

File tree

5 files changed

+208
-125
lines changed

5 files changed

+208
-125
lines changed

python/egglog/builtins.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,24 +22,24 @@
2222

2323

2424
__all__ = [
25-
"i64",
26-
"i64Like",
27-
"f64",
28-
"f64Like",
2925
"Bool",
3026
"BoolLike",
31-
"String",
32-
"StringLike",
3327
"Map",
28+
"PyObject",
3429
"Rational",
3530
"Set",
31+
"String",
32+
"StringLike",
33+
"UnstableFn",
3634
"Vec",
35+
"f64",
36+
"f64Like",
37+
"i64",
38+
"i64Like",
3739
"join",
38-
"PyObject",
3940
"py_eval",
40-
"py_exec",
4141
"py_eval_fn",
42-
"UnstableFn",
42+
"py_exec",
4343
]
4444

4545

@@ -210,6 +210,9 @@ def __truediv__(self, other: f64Like) -> f64: ...
210210
@method(egg_fn="%")
211211
def __mod__(self, other: f64Like) -> f64: ...
212212

213+
@method(egg_fn="^")
214+
def __pow__(self, other: f64Like) -> f64: ...
215+
213216
def __radd__(self, other: f64Like) -> f64: ...
214217

215218
def __rsub__(self, other: f64Like) -> f64: ...

python/egglog/exp/array_api.py

Lines changed: 134 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def _int(i: i64, j: i64, r: Boolean, o: Int, b: Int):
189189
yield rewrite(Int(i) + Int(j)).to(Int(i + j))
190190
yield rewrite(Int(i) - Int(j)).to(Int(i - j))
191191
yield rewrite(Int(i) * Int(j)).to(Int(i * j))
192-
yield rewrite(Int(i) / Int(j)).to(Int(i / j))
192+
yield rewrite(Int(i) // Int(j)).to(Int(i / j))
193193
yield rewrite(Int(i) % Int(j)).to(Int(i % j))
194194
yield rewrite(Int(i) & Int(j)).to(Int(i & j))
195195
yield rewrite(Int(i) | Int(j)).to(Int(i | j))
@@ -219,15 +219,17 @@ def abs(self) -> Float: ...
219219
def rational(cls, r: Rational) -> Float: ...
220220

221221
@classmethod
222-
def from_int(cls, i: Int) -> Float: ...
222+
def from_int(cls, i: IntLike) -> Float: ...
223223

224-
def __truediv__(self, other: Float) -> Float: ...
224+
def __truediv__(self, other: FloatLike) -> Float: ...
225225

226-
def __mul__(self, other: Float) -> Float: ...
226+
def __mul__(self, other: FloatLike) -> Float: ...
227227

228-
def __add__(self, other: Float) -> Float: ...
228+
def __add__(self, other: FloatLike) -> Float: ...
229229

230-
def __sub__(self, other: Float) -> Float: ...
230+
def __sub__(self, other: FloatLike) -> Float: ...
231+
232+
def __pow__(self, other: FloatLike) -> Float: ...
231233

232234

233235
converter(float, Float, lambda x: Float(x))
@@ -252,6 +254,7 @@ def _float(f: f64, f2: f64, i: i64, r: Rational, r1: Rational):
252254
rewrite(Float.rational(r) + Float.rational(r1)).to(Float.rational(r + r1)),
253255
rewrite(Float.rational(r) - Float.rational(r1)).to(Float.rational(r - r1)),
254256
rewrite(Float.rational(r) * Float.rational(r1)).to(Float.rational(r * r1)),
257+
rewrite(Float(f) ** Float(f2)).to(Float(f**f2)),
255258
]
256259

257260

@@ -271,6 +274,7 @@ def var(cls, name: StringLike) -> TupleInt: ...
271274

272275
EMPTY: ClassVar[TupleInt]
273276

277+
@method(cost=100)
274278
def __init__(self, length: IntLike, idx_fn: Callable[[Int], Int]) -> None: ...
275279

276280
@classmethod
@@ -325,13 +329,57 @@ def if_(cls, b: Boolean, i: TupleInt, j: TupleInt) -> TupleInt: ...
325329
def to_py(self) -> tuple[int, ...]:
326330
return tuple(int(i) for i in self)
327331

332+
@method(subsume=True)
333+
def drop(self, n: Int) -> TupleInt:
334+
return TupleInt(self.length() - n, lambda i: self[i + n])
335+
336+
@method(subsume=True)
337+
def product(self) -> Int:
338+
return self.fold(Int(1), lambda acc, i: acc * i)
339+
340+
def map_tuple_int(self, f: Callable[[Int], TupleInt]) -> TupleTupleInt: ...
341+
342+
def append(self, i: Int) -> TupleInt: ...
343+
328344

329345
# TODO: Upcast args for Vec[Int] constructor
330346
converter(tuple, TupleInt, lambda x: TupleInt.from_vec(Vec(*(convert(i, Int) for i in x))))
331347

332348
TupleIntLike: TypeAlias = TupleInt | tuple[IntLike, ...]
333349

334350

351+
@array_api_ruleset.register
352+
def _tuple_int_create_from_vec(
353+
x: NDArray, idx_fn: Callable[[Int], Int], i: i64, xs: Vec[Int], ti: TupleInt, ti2: TupleInt, v: Int
354+
):
355+
"""
356+
Turn a tuple into constructor with a known length into a from_vec constructor
357+
"""
358+
# # create from_vec from zero length tuple
359+
# yield rule(eq(ti).to(TupleInt(0, idx_fn))).then(union(ti).with_(TupleInt.from_vec(Vec[Int]())))
360+
361+
# yield rewrite(x.index(TupleInt(0, idx_fn))).to(x.index(TupleInt.from_vec(Vec[Int]())))
362+
# yield rewrite(x.index(TupleInt(1, idx_fn))).to(x.index(TupleInt.from_vec(Vec(idx_fn(Int(0))))))
363+
# yield rewrite(x.index(TupleInt(2, idx_fn))).to(x.index(TupleInt.from_vec(Vec(idx_fn(Int(0)), idx_fn(Int(1))))))
364+
# yield rewrite(x.index(TupleInt(3, idx_fn))).to(
365+
# x.index(TupleInt.from_vec(Vec(idx_fn(Int(0)), idx_fn(Int(1)), idx_fn(Int(2)))))
366+
# )
367+
yield rewrite(x.index(TupleInt(4, idx_fn))).to(
368+
x.index(
369+
TupleInt.from_vec(Vec(idx_fn(Int(0)), idx_fn(Int(1)), idx_fn(Int(2)), idx_fn(Int(3))))
370+
# TupleInt.EMPTY.append(idx_fn(Int(0))).append(idx_fn(Int(1))).append(idx_fn(Int(2))).append(idx_fn(Int(3)))
371+
)
372+
)
373+
374+
375+
# # Also create it when appending onto a tuple that already has a vec
376+
# # yield rule(eq(ti).to(ti2.append(v)), eq(ti2).to(TupleInt.from_vec(xs))).then(
377+
# # union(ti).with_(TupleInt.from_vec(xs.append(Vec(v))))
378+
# # )
379+
# # Split up known length tuple vecs into append calls so they will be transformed into from_vec
380+
# yield rewrite(TupleInt(i, idx_fn)).to(TupleInt(i - 1, idx_fn).append(idx_fn(Int(i - 1))), i > 0)
381+
382+
335383
@array_api_ruleset.register
336384
def _tuple_int(
337385
i: Int,
@@ -340,7 +388,7 @@ def _tuple_int(
340388
f: Callable[[Int, Int], Int],
341389
bool_f: Callable[[Boolean, Int], Boolean],
342390
idx_fn: Callable[[Int], Int],
343-
map_fn: Callable[[Int], Int],
391+
map_tuple_int_fn: Callable[[Int], TupleInt],
344392
filter_f: Callable[[Int], Boolean],
345393
vs: Vec[Int],
346394
b: Boolean,
@@ -351,7 +399,7 @@ def _tuple_int(
351399
rewrite(TupleInt(i, idx_fn).length()).to(i),
352400
rewrite(TupleInt(i, idx_fn)[i2]).to(idx_fn(i2)),
353401
# index_vec_int
354-
rewrite(index_vec_int(vs, Int(k))).to(vs[k], vs.length() > k),
402+
rule(eq(i).to(index_vec_int(vs, Int(k))), k < vs.length(), k >= 0).then(union(i).with_(vs[k])),
355403
# fold
356404
rewrite(TupleInt(0, idx_fn).fold(i, f)).to(i),
357405
rewrite(TupleInt(Int(k), idx_fn).fold(i, f)).to(
@@ -379,6 +427,10 @@ def _tuple_int(
379427
# if_
380428
rewrite(TupleInt.if_(TRUE, ti, ti2)).to(ti),
381429
rewrite(TupleInt.if_(FALSE, ti, ti2)).to(ti2),
430+
# map_tuple_int
431+
rewrite(TupleInt(i, idx_fn).map_tuple_int(map_tuple_int_fn)).to(
432+
TupleTupleInt(i, lambda i: map_tuple_int_fn(idx_fn(i)))
433+
),
382434
]
383435

384436

@@ -418,6 +470,55 @@ def __len__(self) -> int:
418470
def __iter__(self) -> Iterator[TupleInt]:
419471
return iter(self[i] for i in range(len(self)))
420472

473+
def drop(self, n: Int) -> TupleTupleInt:
474+
return TupleTupleInt(self.length() - n, lambda i: self[i + n])
475+
476+
def map_int(self, f: Callable[[TupleInt], Int]) -> TupleInt: ...
477+
478+
def reduce_value(self, f: Callable[[Value, TupleInt], Value], init: ValueLike) -> Value: ...
479+
480+
def product(self) -> TupleTupleInt:
481+
"""
482+
Cartesian product of inputs
483+
484+
https://docs.python.org/3/library/itertools.html#itertools.product
485+
486+
https://github.com/saulshanabrook/saulshanabrook/discussions/39
487+
"""
488+
return TupleTupleInt(
489+
self.map_int(lambda x: x.length()).product(),
490+
lambda i: TupleInt(
491+
self.length(),
492+
lambda j: self[j][i // self.drop(j).map_int(lambda x: x.length()).product() % self[j].length()],
493+
),
494+
)
495+
496+
497+
@array_api_ruleset.register
498+
def _tuple_tuple_int(
499+
length: Int,
500+
fn: Callable[[TupleInt], Int],
501+
idx_fn: Callable[[Int], TupleInt],
502+
f: Callable[[Value, TupleInt], Value],
503+
i: Value,
504+
k: i64,
505+
idx: Int,
506+
):
507+
yield rewrite(TupleTupleInt(length, idx_fn).length()).to(length)
508+
509+
yield rewrite(TupleTupleInt(length, idx_fn)[idx]).to(idx_fn(idx))
510+
511+
yield rewrite(TupleTupleInt(length, idx_fn).map_int(fn), subsume=True).to(TupleInt(length, lambda i: fn(idx_fn(i))))
512+
513+
yield rewrite(TupleTupleInt(0, idx_fn).reduce_value(f, i)).to(i)
514+
yield rewrite(TupleTupleInt(Int(k), idx_fn).reduce_value(f, i), subsume=True).to(
515+
f(
516+
TupleTupleInt(k - 1, lambda i: idx_fn(i + 1)).reduce_value(f, i),
517+
idx_fn(Int(0)),
518+
),
519+
ne(k).to(i64(0)),
520+
)
521+
421522

422523
@function
423524
def bottom_indexing(i: Int) -> Int: ...
@@ -627,19 +728,23 @@ class Device(Expr): ...
627728

628729
class Value(Expr):
629730
@classmethod
630-
def int(cls, i: Int) -> Value: ...
731+
def int(cls, i: IntLike) -> Value: ...
631732

632733
@classmethod
633-
def float(cls, f: Float) -> Value: ...
734+
def float(cls, f: FloatLike) -> Value: ...
634735

635736
@classmethod
636-
def bool(cls, b: Boolean) -> Value: ...
737+
def bool(cls, b: BooleanLike) -> Value: ...
637738

638739
def isfinite(self) -> Boolean: ...
639740

640-
def __lt__(self, other: Value) -> Value: ...
741+
def __lt__(self, other: ValueLike) -> Value: ...
742+
743+
def __truediv__(self, other: ValueLike) -> Value: ...
641744

642-
def __truediv__(self, other: Value) -> Value: ...
745+
def __mul__(self, other: ValueLike) -> Value: ...
746+
747+
def __add__(self, other: ValueLike) -> Value: ...
643748

644749
def astype(self, dtype: DType) -> Value: ...
645750

@@ -665,17 +770,21 @@ def to_truthy_value(self) -> Value:
665770
https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.any.html
666771
"""
667772

773+
def conj(self) -> Value: ...
774+
def real(self) -> Value: ...
775+
def sqrt(self) -> Value: ...
776+
777+
778+
ValueLike: TypeAlias = Value | IntLike | FloatLike | BooleanLike
668779

669780
converter(Int, Value, Value.int)
670781
converter(Float, Value, Value.float)
671782
converter(Boolean, Value, Value.bool)
672783
converter(Value, Int, lambda x: x.to_int, 10)
673784

674-
ValueLike: TypeAlias = Value | IntLike | FloatLike | BooleanLike
675-
676785

677786
@array_api_ruleset.register
678-
def _value(i: Int, f: Float, b: Boolean):
787+
def _value(i: Int, f: Float, b: Boolean, v: Value):
679788
# Default dtypes
680789
# https://data-apis.org/array-api/latest/API_specification/data_types.html?highlight=dtype#default-data-types
681790
yield rewrite(Value.int(i).dtype).to(DType.int64)
@@ -688,6 +797,15 @@ def _value(i: Int, f: Float, b: Boolean):
688797
yield rewrite(Value.bool(b).to_truthy_value).to(Value.bool(b))
689798
# TODO: Add more rules for to_bool_value
690799

800+
yield rewrite(Value.float(f).conj()).to(Value.float(f))
801+
yield rewrite(Value.float(f).real()).to(Value.float(f))
802+
yield rewrite(Value.int(i).real()).to(Value.int(i))
803+
yield rewrite(Value.int(i).conj()).to(Value.int(i))
804+
805+
yield rewrite(Value.float(f).sqrt()).to(Value.float(f ** (0.5)))
806+
807+
yield rewrite(Value.float(Float.rational(Rational(0, 1))) + v).to(v)
808+
691809

692810
class TupleValue(Expr):
693811
EMPTY: ClassVar[TupleValue]

0 commit comments

Comments
 (0)