|
| 1 | +from abc import abstractmethod |
1 | 2 | from dataclasses import dataclass |
2 | 3 |
|
3 | 4 | from kirin import ir |
@@ -40,33 +41,47 @@ def print_impl(self, printer: Printer) -> None: |
40 | 41 |
|
41 | 42 |
|
42 | 43 | @dataclass |
43 | | -class WrapSquinAnalysis(RewriteRule): |
| 44 | +class WrapAnalysis(RewriteRule): |
44 | 45 |
|
| 46 | + @abstractmethod |
| 47 | + def wrap(self, value: ir.SSAValue) -> bool: |
| 48 | + pass |
| 49 | + |
| 50 | + def rewrite_Block(self, node: ir.Block) -> RewriteResult: |
| 51 | + has_done_something = any(self.wrap(arg) for arg in node.args) |
| 52 | + return RewriteResult(has_done_something=has_done_something) |
| 53 | + |
| 54 | + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: |
| 55 | + has_done_something = any(self.wrap(result) for result in node.results) |
| 56 | + return RewriteResult(has_done_something=has_done_something) |
| 57 | + |
| 58 | + |
| 59 | +@dataclass |
| 60 | +class WrapAddressAnalysis(WrapAnalysis): |
45 | 61 | address_analysis: dict[ir.SSAValue, Address] |
46 | | - op_site_analysis: dict[ir.SSAValue, Sites] |
47 | 62 |
|
48 | 63 | def wrap(self, value: ir.SSAValue) -> bool: |
49 | 64 | address_analysis_result = self.address_analysis[value] |
50 | | - op_site_analysis_result = self.op_site_analysis[value] |
51 | 65 |
|
52 | | - if value.hints.get("address") and value.hints.get("sites"): |
| 66 | + if value.hints.get("address") is not None: |
53 | 67 | return False |
54 | | - else: |
55 | | - value.hints["address"] = AddressAttribute(address_analysis_result) |
56 | | - value.hints["sites"] = SitesAttribute(op_site_analysis_result) |
| 68 | + |
| 69 | + value.hints["address"] = AddressAttribute(address_analysis_result) |
57 | 70 |
|
58 | 71 | return True |
59 | 72 |
|
60 | | - def rewrite_Block(self, node: ir.Block) -> RewriteResult: |
61 | | - has_done_something = False |
62 | | - for arg in node.args: |
63 | | - if self.wrap(arg): |
64 | | - has_done_something = True |
65 | | - return RewriteResult(has_done_something=has_done_something) |
66 | 73 |
|
67 | | - def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: |
68 | | - has_done_something = False |
69 | | - for result in node.results: |
70 | | - if self.wrap(result): |
71 | | - has_done_something = True |
72 | | - return RewriteResult(has_done_something=has_done_something) |
| 74 | +@dataclass |
| 75 | +class WrapOpSiteAnalysis(WrapAnalysis): |
| 76 | + |
| 77 | + op_site_analysis: dict[ir.SSAValue, Sites] |
| 78 | + |
| 79 | + def wrap(self, value: ir.SSAValue) -> bool: |
| 80 | + op_site_analysis_result = self.op_site_analysis[value] |
| 81 | + |
| 82 | + if value.hints.get("sites") is not None: |
| 83 | + return False |
| 84 | + |
| 85 | + value.hints["sites"] = SitesAttribute(op_site_analysis_result) |
| 86 | + |
| 87 | + return True |
0 commit comments