Skip to content

Commit dacb401

Browse files
authored
feat: add discard to DialectGroup (#192)
1 parent 926f2e6 commit dacb401

File tree

2 files changed

+57
-1
lines changed

2 files changed

+57
-1
lines changed

src/kirin/ir/group.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,12 @@ def identity(code: Method):
6464
def __iter__(self):
6565
return iter(self.data)
6666

67+
def __repr__(self) -> str:
68+
names = ", ".join(each.name for each in self.data)
69+
return f"DialectGroup([{names}])"
70+
6771
@staticmethod
68-
def map_module(dialect):
72+
def map_module(dialect: Union["Dialect", ModuleType]) -> "Dialect":
6973
"""map the module to the dialect if it is a module.
7074
It assumes that the module has a `dialect` attribute
7175
that is an instance of [`Dialect`][kirin.ir.Dialect].
@@ -99,6 +103,26 @@ def union(self, dialect: Iterable[Union["Dialect", ModuleType]]) -> "DialectGrou
99103
run_pass=self.run_pass_gen, # pass the run_pass_gen function
100104
)
101105

106+
def discard(self, dialect: Union["Dialect", ModuleType]) -> "DialectGroup":
107+
"""discard a dialect from the group.
108+
109+
!!! note
110+
This does not raise an error if the dialect is not in the group.
111+
112+
Args:
113+
dialect (Union[Dialect, ModuleType]): the dialect to discard
114+
115+
Returns:
116+
DialectGroup: the new dialect group with the discarded dialect.
117+
"""
118+
dialect_ = self.map_module(dialect)
119+
return DialectGroup(
120+
dialects=frozenset(
121+
each for each in self.data if each.name != dialect_.name
122+
),
123+
run_pass=self.run_pass_gen, # pass the run_pass_gen function
124+
)
125+
102126
@property
103127
def registry(self) -> "Registry":
104128
"""return the registry for the dialect group. This

test/ir/test_group.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,35 @@ def test_union():
1313
target_b = group_a.union(group_c)
1414
assert target_a.data == group_a.data
1515
assert target_b.data == group_d.data
16+
17+
target_a_repr = repr(target_a)
18+
assert "DialectGroup(" in target_a_repr
19+
assert base.dialect.name in target_a_repr
20+
assert cf.dialect.name in target_a_repr
21+
22+
target_b_repr = repr(target_b)
23+
assert "DialectGroup(" in target_b_repr
24+
assert base.dialect.name in target_b_repr
25+
assert cf.dialect.name in target_b_repr
26+
assert func.dialect.name in target_b_repr
27+
28+
29+
def test_discard():
30+
group_a = DialectGroup([base, cf])
31+
group_c = DialectGroup([base, func])
32+
group_d = DialectGroup([base, func, cf])
33+
34+
target_a = group_d.discard(cf)
35+
target_b = group_d.discard(func)
36+
assert target_a.data == group_c.data
37+
assert target_b.data == group_a.data
38+
39+
target_a_repr = repr(target_a)
40+
assert "DialectGroup(" in target_a_repr
41+
assert base.dialect.name in target_a_repr
42+
assert func.dialect.name in target_a_repr
43+
44+
target_b_repr = repr(target_b)
45+
assert "DialectGroup(" in target_b_repr
46+
assert base.dialect.name in target_b_repr
47+
assert cf.dialect.name in target_b_repr

0 commit comments

Comments
 (0)