Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 2 additions & 21 deletions src/kirin/dialects/py/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,17 +98,7 @@ class Concrete(interp.MethodTable):

@interp.impl(GetItem)
def getindex(self, interp, frame: interp.Frame, stmt: GetItem):
from kirin.dialects.py.slice import SliceAttribute

index = frame.get(stmt.index)

# need to handle special case of slice attribute
if isinstance(index, SliceAttribute):
index_value = index.unwrap()
else:
index_value = index

return (frame.get(stmt.obj)[index_value],)
return (frame.get(stmt.obj)[frame.get(stmt.index)],)


@dialect.register(key="typeinfer")
Expand Down Expand Up @@ -218,16 +208,7 @@ def getitem(
return (const.Unknown(),)

if isinstance(obj, const.Value):
from kirin.dialects.py.slice import SliceAttribute

# need to handle special case of slice attribute
if isinstance(index.data, SliceAttribute):
index_value = index.data.unwrap()
else:
index_value = index.data

return (const.Value(obj.data[index_value]),)

return (const.Value(obj.data[index.data]),)
elif isinstance(obj, const.PartialTuple):
obj = obj.data
if isinstance(index.data, int) and 0 <= index.data < len(obj):
Expand Down
44 changes: 3 additions & 41 deletions src/kirin/dialects/py/slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

from kirin import ir, types, interp, lowering
from kirin.decl import info, statement
from kirin.print.printer import Printer
from kirin.dialects.py.constant import Constant

dialect = ir.Dialect("py.slice")
Expand Down Expand Up @@ -63,55 +62,18 @@ def __init__(
)


@dataclass
class SliceAttribute(ir.Data[slice]):

start: int | None
stop: int | None
step: int | None

def __post_init__(self) -> None:
if self.start is None and self.step is None:
self.type = types.Slice[types.Literal(self.stop)]
else:
self.type = types.Slice3[
types.Literal(self.start),
types.Literal(self.stop),
types.Literal(self.step),
]

def unwrap(self):
return slice(self.start, self.stop, self.step)

def __hash__(self):
return hash((type(self), self.start, self.stop, self.step))

def print_impl(self, printer: Printer) -> None:
return printer.plain_print(f"slice({self.start}, {self.stop}, {self.step})")

def is_structurally_equal(
self, other: ir.Attribute, context: dict | None = None
) -> bool:
return (
isinstance(other, SliceAttribute)
and self.start == other.start
and self.stop == other.stop
and self.step == other.step
)


@dialect.register
class Concrete(interp.MethodTable):

@interp.impl(Slice)
def _slice(self, interp, frame: interp.Frame, stmt: Slice):
start, stop, step = frame.get_values(stmt.args)
if start is None and step is None:
return (SliceAttribute(None, stop, None),)
return (slice(stop),)
elif step is None:
return (SliceAttribute(start, stop, None),)
return (slice(start, stop),)
else:
return (SliceAttribute(start, stop, step),)
return (slice(start, stop, step),)


@dialect.register
Expand Down
1 change: 0 additions & 1 deletion src/kirin/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
NoneType = PyClass(type(None))
List = Generic(list, TypeVar("T"))
Slice = Generic(slice, TypeVar("T"))
Slice3 = Generic(slice, TypeVar("T1"), TypeVar("T2"), TypeVar("T3"))
Tuple = Generic(tuple, Vararg(TypeVar("T")))
Dict = Generic(dict, TypeVar("K"), TypeVar("V"))
Set = Generic(set, TypeVar("T"))
Expand Down
32 changes: 1 addition & 31 deletions test/dialects/pystmts/test_slice.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from kirin import types
from kirin.prelude import basic_no_opt
from kirin.dialects import py, ilist
from kirin.dialects.py.slice import SliceAttribute
from kirin.dialects import py


@basic_no_opt
Expand Down Expand Up @@ -45,32 +44,3 @@ def test_wrong_slice():

stmt: py.slice.Slice = wrong_slice.code.body.blocks[0].stmts.at(7)
assert stmt.result.type.is_subseteq(types.Bottom)


def test_slice_attr():

@basic_no_opt
def test():

return (slice(0, 20), slice(30), slice(1, 40, 5))

result = test()
assert result == (
SliceAttribute(0, 20, None),
SliceAttribute(None, 30, None),
SliceAttribute(1, 40, 5),
)


def test_slice_attr_hash():
assert hash(SliceAttribute(0, 20, None)) == hash((SliceAttribute, 0, 20, None))


def test_slice_get_index():
@basic_no_opt
def test():
x = slice(0, 20, None)
y = range(40)
return y[x]

assert test() == ilist.IList(range(0, 20, 1))