Skip to content

Commit 651081f

Browse files
merged is_equal with is_structurally_equal. issue 514 (#515)
Addressing #514 Merged all usage of `is_equal` with `is_structurally_equal`. --------- Co-authored-by: kaihsin <[email protected]>
1 parent 1bc9f87 commit 651081f

File tree

17 files changed

+81
-111
lines changed

17 files changed

+81
-111
lines changed

src/kirin/analysis/const/lattice.py

Lines changed: 12 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,6 @@ class Value(Result):
8484
def is_subseteq_Value(self, other: "Value") -> bool:
8585
return self.data == other.data
8686

87-
def is_equal(self, other: Result) -> bool:
88-
if isinstance(other, Value):
89-
return self.data == other.data
90-
return False
91-
9287
def __hash__(self) -> int:
9388
# NOTE: we use id here because the data
9489
# may not be hashable. This is fine because
@@ -100,9 +95,7 @@ def is_structurally_equal(
10095
) -> bool:
10196
if not isinstance(other, Value):
10297
return False
103-
if isinstance(self.data, ir.Attribute) and isinstance(other.data, ir.Attribute):
104-
return self.data.is_structurally_equal(other.data, context=context)
105-
return self.data.is_structurally_equal(other.data, context=context)
98+
return self.data == other.data
10699

107100

108101
@dataclass
@@ -159,13 +152,6 @@ def meet(self, other: Result) -> Result:
159152
)
160153
return self.bottom()
161154

162-
def is_equal(self, other: Result) -> bool:
163-
if isinstance(other, PartialTuple):
164-
return all(x.is_equal(y) for x, y in zip(self.data, other.data))
165-
elif isinstance(other, Value) and isinstance(other.data, tuple):
166-
return all(x.is_equal(Value(y)) for x, y in zip(self.data, other.data))
167-
return False
168-
169155
def is_subseteq_PartialTuple(self, other: "PartialTuple") -> bool:
170156
return all(x.is_subseteq(y) for x, y in zip(self.data, other.data))
171157

@@ -180,14 +166,17 @@ def __hash__(self) -> int:
180166
def is_structurally_equal(
181167
self, other: ir.Attribute, context: dict | None = None
182168
) -> bool:
183-
if not isinstance(other, PartialTuple):
184-
return False
185-
if len(self.data) != len(other.data):
186-
return False
187-
return all(
188-
x.is_structurally_equal(y, context=context)
189-
for x, y in zip(self.data, other.data)
190-
)
169+
if isinstance(other, PartialTuple):
170+
return all(
171+
x.is_structurally_equal(y, context=context)
172+
for x, y in zip(self.data, other.data)
173+
)
174+
elif isinstance(other, Value) and isinstance(other.data, tuple):
175+
return all(
176+
x.is_structurally_equal(y, context=context)
177+
for x, y in zip(self.data, other.data)
178+
)
179+
return False
191180

192181

193182
@final

src/kirin/dialects/ilist/runtime.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def is_structurally_equal(
101101
return (
102102
isinstance(other, IList)
103103
and self.data == other.data
104-
and self.elem.is_equal(other.elem)
104+
and self.elem.is_structurally_equal(other.elem, context=context)
105105
)
106106

107107
def serialize(self, serializer: "Serializer") -> "SerializationUnit":

src/kirin/ir/attrs/abc.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,6 @@ def __hash__(self) -> int: ...
5555
@abstractmethod
5656
def __eq__(self, value: object) -> bool: ...
5757

58-
@abstractmethod
5958
def is_structurally_equal(self, other: Self, context: dict | None = None) -> bool:
6059
return self == other
6160

src/kirin/ir/attrs/types.py

Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,12 @@ def __or__(self, other: "TypeAttribute"):
8484
return self.join(other)
8585

8686
def __eq__(self, value: object) -> bool:
87-
return isinstance(value, TypeAttribute) and self.is_equal(value)
87+
return isinstance(value, TypeAttribute) and self.is_structurally_equal(value)
88+
89+
@abstractmethod
90+
def is_structurally_equal(
91+
self, other: "Attribute", context: dict | None = None
92+
) -> bool: ...
8893

8994
@abstractmethod
9095
def __hash__(self) -> int: ...
@@ -286,13 +291,6 @@ def __eq__(self, other: object) -> bool:
286291
return False
287292
return self.data == other.data and self.type == other.type
288293

289-
def is_equal(self, other: TypeAttribute) -> bool:
290-
return (
291-
isinstance(other, Literal)
292-
and self.type.is_equal(other.type)
293-
and self.data == other.data
294-
)
295-
296294
def is_subseteq_TypeVar(self, other: "TypeVar") -> bool:
297295
return self.is_subseteq(other.bound)
298296

@@ -317,7 +315,7 @@ def is_structurally_equal(
317315
return (
318316
isinstance(other, Literal)
319317
and self.data == other.data
320-
and self.type.is_equal(other.type)
318+
and self.type.is_structurally_equal(other.type, context=context)
321319
)
322320

323321
def serialize(self, serializer: "Serializer") -> "SerializationUnit":
@@ -355,9 +353,6 @@ def __init__(
355353
types = types.union({typ})
356354
self.types = types
357355

358-
def is_equal(self, other: TypeAttribute) -> bool:
359-
return isinstance(other, Union) and self.types == other.types
360-
361356
def is_subseteq_fallback(self, other: TypeAttribute) -> bool:
362357
return all(t.is_subseteq(other) for t in self.types)
363358

@@ -416,13 +411,6 @@ def __init__(self, name: str, bound: TypeAttribute | None = None):
416411
self.varname = name
417412
self.bound = bound or AnyType()
418413

419-
def is_equal(self, other: TypeAttribute) -> bool:
420-
return (
421-
isinstance(other, TypeVar)
422-
and self.varname == other.varname
423-
and self.bound.is_equal(other.bound)
424-
)
425-
426414
def is_subseteq_TypeVar(self, other: "TypeVar") -> bool:
427415
return self.bound.is_subseteq(other.bound)
428416

@@ -447,7 +435,7 @@ def is_structurally_equal(
447435
return (
448436
isinstance(other, TypeVar)
449437
and self.varname == other.varname
450-
and self.bound.is_equal(other.bound)
438+
and self.bound.is_structurally_equal(other.bound, context=context)
451439
)
452440

453441
def serialize(self, serializer: "Serializer") -> "SerializationUnit":
@@ -482,7 +470,7 @@ def print_impl(self, printer: Printer) -> None:
482470
def is_structurally_equal(
483471
self, other: Attribute, context: dict | None = None
484472
) -> bool:
485-
return isinstance(other, Vararg) and self.typ.is_equal(other.typ)
473+
return isinstance(other, Vararg) and self.typ.is_structurally_equal(other.typ)
486474

487475
def serialize(self, serializer: "Serializer") -> "SerializationUnit":
488476
return serializer.serialize_vararg(self)
@@ -626,12 +614,17 @@ def is_structurally_equal(
626614
return False
627615
if len(self.vars) != len(other.vars):
628616
return False
629-
if any(not v.is_equal(o) for v, o in zip(self.vars, other.vars)):
617+
if any(
618+
not v.is_structurally_equal(o, context=context)
619+
for v, o in zip(self.vars, other.vars)
620+
):
630621
return False
631622
if self.vararg is None and other.vararg is None:
632623
return True
633624
if self.vararg is not None and other.vararg is not None:
634-
return self.vararg.typ.is_equal(other.vararg.typ)
625+
return self.vararg.typ.is_structurally_equal(
626+
other.vararg.typ, context=context
627+
)
635628
return False
636629

637630
def serialize(self, serializer: "Serializer") -> "SerializationUnit":

src/kirin/ir/nodes/base.py

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -58,25 +58,6 @@ def get_root(self) -> IRNode:
5858
return self
5959
return parent.get_root()
6060

61-
def is_equal(self, other: IRNode, context: dict = {}) -> bool:
62-
"""Check if the current node is equal to the other node.
63-
64-
Args:
65-
other: The other node to compare.
66-
context: The context to store the visited nodes. Defaults to {}.
67-
68-
Returns:
69-
True if the nodes are equal, False otherwise.
70-
71-
!!! note
72-
This method is not the same as the `==` operator. It checks for
73-
structural equality rather than identity. To change the behavior
74-
of structural equality, override the `is_structurally_equal` method.
75-
"""
76-
if not isinstance(other, type(self)):
77-
return False
78-
return self.is_structurally_equal(other, context)
79-
8061
def attach(self, parent: ParentType) -> None:
8162
"""Attach the current node to the parent node."""
8263
assert isinstance(parent, IRNode), f"Expected IRNode, got {type(parent)}"
@@ -116,7 +97,7 @@ def is_structurally_equal(
11697
11798
!!! note
11899
This method is for tweaking the behavior of structural equality.
119-
To check if two nodes are structurally equal, use the `is_equal` method.
100+
To check if two nodes are structurally equal, use the `is_structurally_equal` method.
120101
121102
Args:
122103
other: The other node to compare.

src/kirin/lattice/abc.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,9 @@ def is_subseteq(self, other: LatticeType) -> bool:
4949
"""Subseteq operation."""
5050
...
5151

52-
def is_equal(self, other: LatticeType) -> bool:
52+
def is_structurally_equal(
53+
self, other: LatticeType, context: dict | None = None
54+
) -> bool:
5355
"""Check if two lattices are equal."""
5456
if self is other:
5557
return True
@@ -61,7 +63,7 @@ def is_subset(self, other: LatticeType) -> bool:
6163

6264
def __eq__(self, value: object) -> bool:
6365
raise NotImplementedError(
64-
"Equality is not implemented for lattices, use is_equal instead"
66+
"Equality is not implemented for lattices, use is_structurally_equal instead"
6567
)
6668

6769
def __hash__(self) -> int:

src/kirin/rewrite/wrap_const.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def wrap(self, value: ir.SSAValue) -> bool:
2727
const_hint = value.hints.get("const")
2828
if const_hint and isinstance(const_hint, const.Result):
2929
const_result = result.meet(const_hint)
30-
if const_result.is_equal(const_hint):
30+
if const_result.is_structurally_equal(const_hint):
3131
return False
3232
else:
3333
const_result = result

test/analysis/dataflow/constprop/test_constprop.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -72,40 +72,40 @@ def test_join(self):
7272
== const.Unknown()
7373
)
7474

75-
def test_is_equal(self):
76-
assert const.Unknown().is_equal(const.Unknown())
77-
assert not const.Unknown().is_equal(const.Bottom())
78-
assert not const.Unknown().is_equal(const.Value(1))
79-
assert const.Bottom().is_equal(const.Bottom())
80-
assert not const.Bottom().is_equal(const.Value(1))
81-
assert const.Value(1).is_equal(const.Value(1))
82-
assert not const.Value(1).is_equal(const.Value(2))
83-
assert const.PartialTuple((const.Value(1), const.Bottom())).is_equal(
84-
const.PartialTuple((const.Value(1), const.Bottom()))
85-
)
86-
assert not const.PartialTuple((const.Value(1), const.Bottom())).is_equal(
87-
const.PartialTuple((const.Value(1), const.Value(2)))
88-
)
75+
def test_is_structurally_equal(self):
76+
assert const.Unknown().is_structurally_equal(const.Unknown())
77+
assert not const.Unknown().is_structurally_equal(const.Bottom())
78+
assert not const.Unknown().is_structurally_equal(const.Value(1))
79+
assert const.Bottom().is_structurally_equal(const.Bottom())
80+
assert not const.Bottom().is_structurally_equal(const.Value(1))
81+
assert const.Value(1).is_structurally_equal(const.Value(1))
82+
assert not const.Value(1).is_structurally_equal(const.Value(2))
83+
assert const.PartialTuple(
84+
(const.Value(1), const.Bottom())
85+
).is_structurally_equal(const.PartialTuple((const.Value(1), const.Bottom())))
86+
assert not const.PartialTuple(
87+
(const.Value(1), const.Bottom())
88+
).is_structurally_equal(const.PartialTuple((const.Value(1), const.Value(2))))
8989

9090
def test_partial_tuple(self):
9191
pt1 = const.PartialTuple((const.Value(1), const.Bottom()))
9292
pt2 = const.PartialTuple((const.Value(1), const.Bottom()))
93-
assert pt1.is_equal(pt2)
93+
assert pt1.is_structurally_equal(pt2)
9494
assert pt1.is_subseteq(pt2)
9595
assert pt1.join(pt2) == pt1
9696
assert pt1.meet(pt2) == pt1
9797
pt2 = const.PartialTuple((const.Value(1), const.Value(2)))
98-
assert not pt1.is_equal(pt2)
98+
assert not pt1.is_structurally_equal(pt2)
9999
assert pt1.is_subseteq(pt2)
100100
assert pt1.join(pt2) == const.PartialTuple((const.Value(1), const.Value(2)))
101101
assert pt1.meet(pt2) == const.PartialTuple((const.Value(1), const.Bottom()))
102102
pt2 = const.PartialTuple((const.Value(1), const.Bottom()))
103-
assert pt1.is_equal(pt2)
103+
assert pt1.is_structurally_equal(pt2)
104104
assert pt1.is_subseteq(pt2)
105105
assert pt1.join(pt2) == pt1
106106
assert pt1.meet(pt2) == pt1
107107
pt2 = const.PartialTuple((const.Value(1), const.Unknown()))
108-
assert not pt1.is_equal(pt2)
108+
assert not pt1.is_structurally_equal(pt2)
109109
assert pt1.is_subseteq(pt2)
110110
assert pt1.join(pt2) == pt2
111111
assert pt1.meet(pt2) == pt1

test/analysis/dataflow/typeinfer/test_selfref_closure.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,6 @@ def self_ref_source(i_layer):
2222
infer = TypeInference(structural_no_opt)
2323
frame, ret = infer.run(should_work, types.Int)
2424
should_work.print(analysis=frame.entries)
25-
assert ret.is_equal(types.MethodType[types.Tuple[types.Any], types.NoneType])
25+
assert ret.is_structurally_equal(
26+
types.MethodType[types.Tuple[types.Any], types.NoneType]
27+
)

test/dialects/py/test_assign.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def test_ann_assign():
1818

1919
typeinfer = TypeInference(basic_no_opt)
2020
_, ret = typeinfer.run(main, types.Int)
21-
assert ret.is_equal(types.Int)
21+
assert ret.is_structurally_equal(types.Int)
2222
_, ret = typeinfer.run(main, types.Float)
2323
assert ret is ret.bottom()
2424

@@ -41,9 +41,13 @@ def list_assign():
4141

4242
stmt = list_assign.callable_region.blocks[0].stmts.at(3)
4343
assert isinstance(stmt, ilist.New)
44-
assert stmt.elem_type.is_equal(types.Float)
45-
assert stmt.result.type.is_equal(ilist.IListType[types.Float, types.Literal(3)])
44+
assert stmt.elem_type.is_structurally_equal(types.Float)
45+
assert stmt.result.type.is_structurally_equal(
46+
ilist.IListType[types.Float, types.Literal(3)]
47+
)
4648

4749
stmt = list_assign.callable_region.blocks[0].stmts.at(4)
4850
assert isinstance(stmt, py.assign.TypeAssert)
49-
assert stmt.expected.is_equal(ilist.IListType[types.Float, types.Literal(3)])
51+
assert stmt.expected.is_structurally_equal(
52+
ilist.IListType[types.Float, types.Literal(3)]
53+
)

0 commit comments

Comments
 (0)