Skip to content

Commit 42d6f8c

Browse files
authored
move range implementation to IList (#154)
now range is conditionally enabled by IList. cc: @kaihsin
1 parent 09db089 commit 42d6f8c

File tree

4 files changed

+22
-10
lines changed

4 files changed

+22
-10
lines changed

src/kirin/dialects/ilist/interp.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from kirin.interp import Err, Frame, Interpreter, MethodTable, impl
33
from kirin.dialects.py.len import Len
44
from kirin.dialects.py.binop import Add
5+
from kirin.dialects.py.range import Range
56

67
from .stmts import Map, New, Push, Scan, Foldl, Foldr, ForEach
78
from .runtime import IList
@@ -11,6 +12,10 @@
1112
@dialect.register
1213
class IListInterpreter(MethodTable):
1314

15+
@impl(Range)
16+
def _range(self, interp, frame: Frame, stmt: Range):
17+
return (IList(range(*frame.get_values(stmt.args))),)
18+
1419
@impl(New)
1520
def new(self, interp, frame: Frame, stmt: New):
1621
return (IList(list(frame.get_values(stmt.values))),)

src/kirin/dialects/ilist/runtime.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# TODO: replace with something faster
22
from typing import Any, Generic, TypeVar, overload
33
from dataclasses import dataclass
4+
from collections.abc import Sequence
45

56
T = TypeVar("T")
67
L = TypeVar("L")
@@ -10,7 +11,7 @@
1011
class IList(Generic[T, L]):
1112
"""A simple immutable list."""
1213

13-
data: list[T]
14+
data: Sequence[T]
1415

1516
def __len__(self) -> int:
1617
return len(self.data)

src/kirin/dialects/py/range.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import ast
22
from dataclasses import dataclass
33

4-
from kirin import ir, types, interp, lowering, exceptions
4+
from kirin import ir, types, lowering, exceptions
55
from kirin.decl import info, statement
66

77
dialect = ir.Dialect("py.range")
@@ -35,14 +35,6 @@ def lower_Call_range(
3535
return _lower_range(state, node)
3636

3737

38-
@dialect.register
39-
class Concrete(interp.MethodTable):
40-
41-
@interp.impl(Range)
42-
def _range(self, interp, frame: interp.Frame, stmt: Range):
43-
return (range(*frame.get_values(stmt.args)),)
44-
45-
4638
def _lower_range(state: lowering.LoweringState, node: ast.Call) -> lowering.Result:
4739
if len(node.args) == 1:
4840
start = state.visit(ast.Constant(0)).expect_one()

test/dialects/test_ilist.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,3 +145,17 @@ def foldr(xs: ilist.IList[int, Literal[3]]):
145145
assert isinstance(stmt, py.Constant)
146146
assert stmt.value == 0
147147
assert isinstance(foldr.callable_region.blocks[0].stmts.at(10), func.Call)
148+
149+
150+
def test_ilist_range():
151+
@basic.add(py.range)
152+
def map():
153+
return ilist.Map(add1, range(0, 3)) # type: ignore
154+
155+
assert map() == ilist.IList([1, 2, 3])
156+
157+
@basic.add(py.range)
158+
def const_range():
159+
return range(0, 3)
160+
161+
assert const_range() == ilist.IList(range(0, 3))

0 commit comments

Comments
 (0)