diff --git a/src/kirin/dialects/py/indexing.py b/src/kirin/dialects/py/indexing.py index 564f25915..a1d88243b 100644 --- a/src/kirin/dialects/py/indexing.py +++ b/src/kirin/dialects/py/indexing.py @@ -98,17 +98,7 @@ class Concrete(interp.MethodTable): @interp.impl(GetItem) def getindex(self, interp, frame: interp.Frame, stmt: GetItem): - from kirin.dialects.py.slice import SliceAttribute - - index = frame.get(stmt.index) - - # need to handle special case of slice attribute - if isinstance(index, SliceAttribute): - index_value = index.unwrap() - else: - index_value = index - - return (frame.get(stmt.obj)[index_value],) + return (frame.get(stmt.obj)[frame.get(stmt.index)],) @dialect.register(key="typeinfer") @@ -218,16 +208,7 @@ def getitem( return (const.Unknown(),) if isinstance(obj, const.Value): - from kirin.dialects.py.slice import SliceAttribute - - # need to handle special case of slice attribute - if isinstance(index.data, SliceAttribute): - index_value = index.data.unwrap() - else: - index_value = index.data - - return (const.Value(obj.data[index_value]),) - + return (const.Value(obj.data[index.data]),) elif isinstance(obj, const.PartialTuple): obj = obj.data if isinstance(index.data, int) and 0 <= index.data < len(obj): diff --git a/src/kirin/dialects/py/slice.py b/src/kirin/dialects/py/slice.py index b45f27277..e7a2adde1 100644 --- a/src/kirin/dialects/py/slice.py +++ b/src/kirin/dialects/py/slice.py @@ -13,7 +13,6 @@ from kirin import ir, types, interp, lowering from kirin.decl import info, statement -from kirin.print.printer import Printer from kirin.dialects.py.constant import Constant dialect = ir.Dialect("py.slice") @@ -63,43 +62,6 @@ def __init__( ) -@dataclass -class SliceAttribute(ir.Data[slice]): - - start: int | None - stop: int | None - step: int | None - - def __post_init__(self) -> None: - if self.start is None and self.step is None: - self.type = types.Slice[types.Literal(self.stop)] - else: - self.type = types.Slice3[ - types.Literal(self.start), - types.Literal(self.stop), - types.Literal(self.step), - ] - - def unwrap(self): - return slice(self.start, self.stop, self.step) - - def __hash__(self): - return hash((type(self), self.start, self.stop, self.step)) - - def print_impl(self, printer: Printer) -> None: - return printer.plain_print(f"slice({self.start}, {self.stop}, {self.step})") - - def is_structurally_equal( - self, other: ir.Attribute, context: dict | None = None - ) -> bool: - return ( - isinstance(other, SliceAttribute) - and self.start == other.start - and self.stop == other.stop - and self.step == other.step - ) - - @dialect.register class Concrete(interp.MethodTable): @@ -107,11 +69,11 @@ class Concrete(interp.MethodTable): def _slice(self, interp, frame: interp.Frame, stmt: Slice): start, stop, step = frame.get_values(stmt.args) if start is None and step is None: - return (SliceAttribute(None, stop, None),) + return (slice(stop),) elif step is None: - return (SliceAttribute(start, stop, None),) + return (slice(start, stop),) else: - return (SliceAttribute(start, stop, step),) + return (slice(start, stop, step),) @dialect.register diff --git a/src/kirin/types.py b/src/kirin/types.py index e24249ca8..2f57e173d 100644 --- a/src/kirin/types.py +++ b/src/kirin/types.py @@ -28,7 +28,6 @@ NoneType = PyClass(type(None)) List = Generic(list, TypeVar("T")) Slice = Generic(slice, TypeVar("T")) -Slice3 = Generic(slice, TypeVar("T1"), TypeVar("T2"), TypeVar("T3")) Tuple = Generic(tuple, Vararg(TypeVar("T"))) Dict = Generic(dict, TypeVar("K"), TypeVar("V")) Set = Generic(set, TypeVar("T")) diff --git a/test/dialects/pystmts/test_slice.py b/test/dialects/pystmts/test_slice.py index 260ccf995..1d02f3fa8 100644 --- a/test/dialects/pystmts/test_slice.py +++ b/test/dialects/pystmts/test_slice.py @@ -1,7 +1,6 @@ from kirin import types from kirin.prelude import basic_no_opt -from kirin.dialects import py, ilist -from kirin.dialects.py.slice import SliceAttribute +from kirin.dialects import py @basic_no_opt @@ -45,32 +44,3 @@ def test_wrong_slice(): stmt: py.slice.Slice = wrong_slice.code.body.blocks[0].stmts.at(7) assert stmt.result.type.is_subseteq(types.Bottom) - - -def test_slice_attr(): - - @basic_no_opt - def test(): - - return (slice(0, 20), slice(30), slice(1, 40, 5)) - - result = test() - assert result == ( - SliceAttribute(0, 20, None), - SliceAttribute(None, 30, None), - SliceAttribute(1, 40, 5), - ) - - -def test_slice_attr_hash(): - assert hash(SliceAttribute(0, 20, None)) == hash((SliceAttribute, 0, 20, None)) - - -def test_slice_get_index(): - @basic_no_opt - def test(): - x = slice(0, 20, None) - y = range(40) - return y[x] - - assert test() == ilist.IList(range(0, 20, 1))