Skip to content

Commit 7848e74

Browse files
slight fixes in array api filter
1 parent 0dbb348 commit 7848e74

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

python/egglog/exp/array_api.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ def single(cls, i: Int) -> TupleInt:
278278
return TupleInt(Int(1), lambda _: i)
279279

280280
@classmethod
281-
def range(cls, stop: Int) -> TupleInt:
281+
def range(cls, stop: IntLike) -> TupleInt:
282282
return TupleInt(stop, lambda i: i)
283283

284284
@classmethod
@@ -346,7 +346,6 @@ def _tuple_int(
346346
ti: TupleInt,
347347
ti2: TupleInt,
348348
):
349-
remaining = TupleInt(k - 1, lambda i: idx_fn(i + 1)).filter(filter_f)
350349
return [
351350
rewrite(TupleInt(i, idx_fn).length()).to(i),
352351
rewrite(TupleInt(i, idx_fn)[i2]).to(idx_fn(i2)),
@@ -367,7 +366,11 @@ def _tuple_int(
367366
# filter TODO: could be written as fold w/ generic types
368367
rewrite(TupleInt(0, idx_fn).filter(filter_f)).to(TupleInt(0, idx_fn)),
369368
rewrite(TupleInt(Int(k), idx_fn).filter(filter_f)).to(
370-
TupleInt.if_(filter_f(value := idx_fn(Int(k))), TupleInt.single(value) + remaining, remaining),
369+
TupleInt.if_(
370+
filter_f(value := idx_fn(Int(k - 1))),
371+
(remaining := TupleInt(k - 1, idx_fn).filter(filter_f)) + TupleInt.single(value),
372+
remaining,
373+
),
371374
ne(k).to(i64(0)),
372375
),
373376
# Empty

0 commit comments

Comments
 (0)