Skip to content

Commit 09db089

Browse files
authored
remove return_type make return_type return signature output (#152)
1 parent d7ad7ef commit 09db089

File tree

4 files changed

+18
-36
lines changed

4 files changed

+18
-36
lines changed

src/kirin/ir/method.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,10 @@
22
from typing import TYPE_CHECKING, Generic, TypeVar, Callable, ParamSpec
33
from dataclasses import field, dataclass
44

5-
from kirin.ir.traits import CallableStmtInterface
5+
from kirin.ir.traits import HasSignature, CallableStmtInterface
66
from kirin.exceptions import InterpreterError, VerificationError
77
from kirin.ir.nodes.stmt import Statement
88
from kirin.print.printer import Printer
9-
from kirin.ir.attrs.types import TypeAttribute
109
from kirin.print.printable import Printable
1110

1211
if TYPE_CHECKING:
@@ -27,7 +26,6 @@ class Method(Printable, Generic[Param, RetType]):
2726
# values contained if closure
2827
fields: tuple = field(default_factory=tuple) # own
2928
file: str = ""
30-
return_type: TypeAttribute | None = None
3129
inferred: bool = False
3230
"""if typeinfer has been run on this method
3331
"""
@@ -58,6 +56,13 @@ def callable_region(self):
5856
raise ValueError("Method body must implement CallableStmtInterface")
5957
return trait.get_callable_region(self.code)
6058

59+
@property
60+
def return_type(self):
61+
trait = self.code.get_trait(HasSignature)
62+
if trait is None:
63+
raise ValueError("Method body must implement HasSignature")
64+
return trait.get_signature(self.code).output
65+
6166
def __repr__(self) -> str:
6267
return f'Method("{self.sym_name}")'
6368

src/kirin/passes/typeinfer.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from dataclasses import dataclass
22

3-
from kirin.ir import Method
3+
from kirin.ir import Method, HasSignature
44
from kirin.rewrite import Walk
55
from kirin.passes.abc import Pass
66
from kirin.rewrite.abc import RewriteResult
7+
from kirin.dialects.func import Signature
78
from kirin.analysis.typeinfer import TypeInference
89
from kirin.rewrite.apply_type import ApplyType
910

@@ -16,8 +17,10 @@ def __post_init__(self):
1617

1718
def unsafe_run(self, mt: Method) -> RewriteResult:
1819
return_type = self.infer.eval(mt, mt.arg_types).expect()
19-
mt.return_type = return_type
20+
if trait := mt.code.get_trait(HasSignature):
21+
trait.set_signature(mt.code, Signature(mt.arg_types, return_type))
22+
23+
result = Walk(ApplyType(self.infer.results)).rewrite(mt.code)
2024
mt.inferred = True
21-
result = Walk(ApplyType(return_type, self.infer.results)).rewrite(mt.code)
2225
self.infer.results.clear()
2326
return result

src/kirin/rewrite/apply_type.py

Lines changed: 3 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,27 @@
11
from dataclasses import dataclass
22

3-
from kirin.ir import (
4-
Block,
5-
SSAValue,
6-
Statement,
7-
HasSignature,
8-
CallableStmtInterface,
9-
types,
10-
)
3+
from kirin.ir import Block, SSAValue, Statement, types
114
from kirin.rewrite.abc import RewriteRule, RewriteResult
12-
from kirin.dialects.func import Signature
135

146

157
@dataclass
168
class ApplyType(RewriteRule):
17-
ret_type: types.TypeAttribute
189
results: dict[SSAValue, types.TypeAttribute]
1910

20-
def get_type(self, value: SSAValue) -> types.TypeAttribute:
21-
return self.results[value]
22-
# if isinstance(typ, types.PyConst):
23-
# return typ.typ
24-
# return typ
25-
2611
def rewrite_Block(self, node: Block) -> RewriteResult:
2712
has_done_something = False
2813
for arg in node.args:
2914
if arg in self.results:
30-
arg.type = self.get_type(arg)
15+
arg.type = self.results[arg]
3116
has_done_something = True
3217

3318
return RewriteResult(has_done_something=has_done_something)
3419

3520
def rewrite_Statement(self, node: Statement) -> RewriteResult:
36-
if (fn_trait := node.get_trait(HasSignature)) and (
37-
call_trait := node.get_trait(CallableStmtInterface)
38-
):
39-
signature = fn_trait.get_signature(node)
40-
body = call_trait.get_callable_region(node)
41-
inputs = [
42-
self.get_type(arg) if arg in self.results else signature.inputs[idx]
43-
for idx, arg in enumerate(body.blocks[0].args[1:])
44-
]
45-
fn_trait.set_signature(node, Signature(tuple(inputs), self.ret_type))
46-
4721
has_done_something = False
4822
for result in node._results:
4923
if result in self.results:
50-
result.type = self.get_type(result)
24+
result.type = self.results[result]
5125
has_done_something = True
5226

5327
return RewriteResult(has_done_something=has_done_something)

test/analysis/dataflow/typeinfer/test_inter_method.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,4 +27,4 @@ def test_inter_method_infer():
2727
# NOTE: inference of moo should not update foo
2828
assert foo.arg_types[0] == types.Int
2929
assert foo.inferred is False
30-
assert foo.return_type is None
30+
assert foo.return_type is types.Any

0 commit comments

Comments
 (0)