diff --git a/src/kirin/dialects/ilist/rewrite/const.py b/src/kirin/dialects/ilist/rewrite/const.py index dedf58181..81eb3cf95 100644 --- a/src/kirin/dialects/ilist/rewrite/const.py +++ b/src/kirin/dialects/ilist/rewrite/const.py @@ -39,11 +39,24 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: return RewriteResult(has_done_something=has_done_something) def rewrite_Constant(self, node: Constant) -> RewriteResult: - if isinstance(node.value, ir.PyAttr) and isinstance(node.value.data, list): - stmt = Constant(value=IList(data=node.value.data)) + if not isinstance(node.value, ir.PyAttr): + return RewriteResult() + + if isinstance(data := node.value.data, list): + stmt = Constant(value=IList(data=data)) + node.replace_by(stmt) + self._rewrite_IList_type(stmt.result, data) + return RewriteResult(has_done_something=True) + elif isinstance(data, range): + new_constant = IList(data=data, elem=types.Int) + stmt = Constant(value=new_constant) + # specializing the type computation since we know that a + # range will always be integer typed. + stmt.result.hints["const"] = const.Value(new_constant) + stmt.result.type = IListType[types.Int, types.Literal(len(data))] node.replace_by(stmt) - self._rewrite_IList_type(stmt.result, node.value.data) return RewriteResult(has_done_something=True) + return RewriteResult() def _rewrite_IList_type(self, result: ir.SSAValue, data): diff --git a/test/dialects/test_ilist2list.py b/test/dialects/test_ilist2list.py index 14bad06ee..4dcea6842 100644 --- a/test/dialects/test_ilist2list.py +++ b/test/dialects/test_ilist2list.py @@ -28,3 +28,18 @@ def ilist2_list(): x = ilist2_list() assert isinstance(x, ilist.IList) + + +def test_range_rewrite(): + + r = range(10) + + @basic_desugar + def ilist_range(): + return r + + ilist_range.print() + + x = ilist_range() + + assert isinstance(x, ilist.IList)