Skip to content

Commit d7cc18d

Browse files
authored
Add get_present_trait method to statement class (#370)
Noticed in QuEraComputing/bloqade-circuit#163 that this is useful for linting, but also just for shortening code here and there. I also replaced code that raises a `ValueError` if the result from `get_trait` is `None` here (2nd commit).
1 parent 1fb1de2 commit d7cc18d

File tree

2 files changed

+12
-9
lines changed

2 files changed

+12
-9
lines changed

src/kirin/ir/method.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -68,24 +68,18 @@ def arg_types(self):
6868
@property
6969
def self_type(self):
7070
"""Return the type of the self argument of the method."""
71-
trait = self.code.get_trait(HasSignature)
72-
if trait is None:
73-
raise ValueError("Method body must implement HasSignature")
71+
trait = self.code.get_present_trait(HasSignature)
7472
signature = trait.get_signature(self.code)
7573
return Generic(Method, Generic(tuple, *signature.inputs), signature.output)
7674

7775
@property
7876
def callable_region(self):
79-
trait = self.code.get_trait(CallableStmtInterface)
80-
if trait is None:
81-
raise ValueError("Method body must implement CallableStmtInterface")
77+
trait = self.code.get_present_trait(CallableStmtInterface)
8278
return trait.get_callable_region(self.code)
8379

8480
@property
8581
def return_type(self):
86-
trait = self.code.get_trait(HasSignature)
87-
if trait is None:
88-
raise ValueError("Method body must implement HasSignature")
82+
trait = self.code.get_present_trait(HasSignature)
8983
return trait.get_signature(self.code).output
9084

9185
def __repr__(self) -> str:

src/kirin/ir/nodes/stmt.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -689,6 +689,15 @@ def get_trait(cls, trait: type[TraitType]) -> TraitType | None:
689689
return t
690690
return None
691691

692+
@classmethod
693+
def get_present_trait(cls, trait: type[TraitType]) -> TraitType:
694+
"""Just like get_trait, but expects the trait to be there.
695+
Useful for linter checks, when you know the trait is present."""
696+
for t in cls.traits:
697+
if isinstance(t, trait):
698+
return t
699+
raise ValueError(f"Trait {trait} not present in statement {cls}")
700+
692701
def expect_one_result(self) -> ResultValue:
693702
"""Check if the statement contain only one result, and return it"""
694703
if len(self._results) != 1:

0 commit comments

Comments
 (0)