Skip to content

Commit 1e2b257

Browse files
authored
fix CSE (#414)
pointed out by @kaihsin we should actually use a key object with both hash and eq implemented.
1 parent 69607af commit 1e2b257

File tree

5 files changed

+66
-40
lines changed

5 files changed

+66
-40
lines changed

src/kirin/ir/attrs/abc.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ class Attribute(ABC, Printable, metaclass=AttributeMeta):
5050
@abstractmethod
5151
def __hash__(self) -> int: ...
5252

53+
@abstractmethod
54+
def __eq__(self, value: object) -> bool: ...
55+
5356
@classmethod
5457
def has_trait(cls, trait_type: type[Trait["Attribute"]]) -> bool:
5558
"""Check if the Statement has a specific trait.

src/kirin/ir/attrs/py.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,14 @@ def __init__(self, data: T, pytype: TypeAttribute | None = None):
3434
else:
3535
self.type = pytype
3636

37-
def __hash__(self):
38-
# Fix hash(-1) == hash(-2) collision
39-
# assume maximum is 8 bytes == 64 bits
40-
if isinstance(self.data, int):
41-
return hash(self.data.to_bytes(signed=True, byteorder="big", length=8))
42-
elif isinstance(self.data, float):
43-
return hash(self.data.hex())
44-
return hash(self.data) + hash(self.type)
37+
def __hash__(self) -> int:
38+
return hash((self.type, self.data))
39+
40+
def __eq__(self, value: object) -> bool:
41+
if not isinstance(value, PyAttr):
42+
return False
43+
44+
return self.type == value.type and self.data == value.data
4545

4646
def print_impl(self, printer: Printer) -> None:
4747
printer.plain_print(repr(self.data))

src/kirin/ir/attrs/types.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -351,14 +351,20 @@ def print_impl(self, printer: Printer) -> None:
351351

352352

353353
@typing.final
354-
@dataclass(eq=False)
354+
@dataclass
355355
class Vararg(Attribute):
356356
name = "Vararg"
357357
typ: TypeAttribute
358358

359359
def __hash__(self) -> int:
360360
return hash((Vararg, self.typ))
361361

362+
def __eq__(self, value: object) -> bool:
363+
if not isinstance(value, Vararg):
364+
return False
365+
366+
return self.typ == value.typ
367+
362368
def print_impl(self, printer: Printer) -> None:
363369
printer.plain_print("*")
364370
printer.print(self.typ)

src/kirin/rewrite/cse.py

Lines changed: 48 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,74 @@
1+
from __future__ import annotations
2+
13
from dataclasses import dataclass
24

3-
from kirin.ir import Pure, Block, Statement
5+
from kirin import ir
46
from kirin.rewrite.abc import RewriteRule, RewriteResult
57

68

9+
@dataclass
10+
class Info:
11+
"""An object to hold the comparison information of a statement."""
12+
13+
head: type[ir.Statement]
14+
args: tuple[ir.SSAValue, ...]
15+
attributes: tuple[ir.Attribute, ...]
16+
successors: tuple[ir.Block, ...]
17+
regions: tuple[ir.Region, ...]
18+
19+
def __hash__(self) -> int:
20+
return hash(
21+
(id(self.head),)
22+
+ tuple(id(ssa) for ssa in self.args)
23+
+ tuple(hash(attr) for attr in self.attributes)
24+
+ tuple(id(succ) for succ in self.successors)
25+
+ tuple(id(region) for region in self.regions)
26+
)
27+
28+
def __eq__(self, other: object) -> bool:
29+
if not isinstance(other, Info):
30+
return False
31+
32+
return (
33+
self.head == other.head
34+
and self.args == other.args
35+
and self.attributes == other.attributes
36+
and self.successors == other.successors
37+
and self.regions == other.regions
38+
)
39+
40+
741
@dataclass
842
class CommonSubexpressionElimination(RewriteRule):
943

10-
def rewrite_Block(self, node: Block) -> RewriteResult:
11-
seen: dict[int, Statement] = {}
44+
def rewrite_Block(self, node: ir.Block) -> RewriteResult:
45+
seen: dict[Info, ir.Statement] = {}
1246

1347
for stmt in node.stmts:
14-
if not stmt.has_trait(Pure):
48+
if not stmt.has_trait(ir.Pure):
1549
continue
1650

1751
if stmt.regions:
1852
continue
1953

20-
hash_value = hash(
21-
(type(stmt),)
22-
+ tuple(stmt.args)
23-
+ tuple(stmt.attributes.values())
24-
+ tuple(stmt.successors)
25-
+ tuple(stmt.regions)
54+
info = Info(
55+
head=type(stmt),
56+
args=tuple(stmt.args),
57+
attributes=tuple(stmt.attributes.values()),
58+
successors=tuple(stmt.successors),
59+
regions=tuple(stmt.regions),
2660
)
27-
if hash_value in seen:
28-
old_stmt = seen[hash_value]
61+
if info in seen:
62+
old_stmt = seen[info]
2963
for result, old_result in zip(stmt._results, old_stmt.results):
3064
result.replace_by(old_result)
3165
stmt.delete()
3266
return RewriteResult(has_done_something=True)
3367
else:
34-
seen[hash_value] = stmt
68+
seen[info] = stmt
3569
return RewriteResult()
3670

37-
def rewrite_Statement(self, node: Statement) -> RewriteResult:
71+
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
3872
if not node.regions:
3973
return RewriteResult()
4074

test/ir/test_hash.py

Lines changed: 0 additions & 17 deletions
This file was deleted.

0 commit comments

Comments
 (0)