Skip to content

Commit 885b2bf

Browse files
committed
[PyRTG] Support sets
1 parent 36078b3 commit 885b2bf

File tree

5 files changed

+144
-1
lines changed

5 files changed

+144
-1
lines changed

frontends/PyRTG/src/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ declare_mlir_python_sources(PyRTGSources
1717
pyrtg/core.py
1818
pyrtg/labels.py
1919
pyrtg/rtg.py
20+
pyrtg/sets.py
2021
pyrtg/support.py
2122
pyrtg/tests.py
2223
rtgtool/rtgtool.py

frontends/PyRTG/src/pyrtg/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@
77
from .tests import test
88
from .labels import Label
99
from .rtg import rtg
10+
from .sets import Set

frontends/PyRTG/src/pyrtg/sets.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
5+
from __future__ import annotations
6+
7+
from .circt import ir, support
8+
from .rtg import rtg
9+
from .core import Value
10+
11+
12+
class Set(Value):
13+
"""
14+
Represents a statically typed set for any kind of values that allows picking
15+
elements at random.
16+
"""
17+
18+
def __init__(self, value: ir.Value) -> Set:
19+
"""
20+
Intended for library internal usage only.
21+
"""
22+
23+
self._value = value
24+
25+
def create_empty(elementType: ir.Type) -> Set:
26+
"""
27+
Create an empty set that can hold elements of the provided type.
28+
"""
29+
30+
return rtg.SetCreateOp(rtg.SetType.get(elementType), [])
31+
32+
def create(*elements: Value) -> Set:
33+
"""
34+
Create a set containing the provided values. At least one element must be
35+
provided.
36+
"""
37+
38+
assert len(
39+
elements) > 0, "use 'create_empty' to create sets with no elements"
40+
assert all([e.get_type() == elements[0].get_type() for e in elements
41+
]), "all elements must have the same type"
42+
return rtg.SetCreateOp(rtg.SetType.get(elements[0].get_type()), elements)
43+
44+
def __add__(self, other: Value) -> Set:
45+
"""
46+
If another set is provided their types must match and a new Set will be
47+
returned containing all elements of both sets (set union). If a value that
48+
is not a Set is provided, it must match the element type of this Set. A new
49+
Set will be returned containing all elements of this Set plus the provided
50+
value.
51+
"""
52+
53+
if isinstance(other, Set):
54+
assert self.get_type() == other.get_type(
55+
), "sets must be of the same type"
56+
return rtg.SetUnionOp([self._value, other._value])
57+
58+
assert support.type_to_pytype(
59+
self.get_type()).element_type == other.get_type(
60+
), "type of the provided value must match element type of the set"
61+
return self + Set.create(other)
62+
63+
def __sub__(self, other: Value) -> Set:
64+
"""
65+
If another set is provided their types must match and a new Set will be
66+
returned containing all elements in this Set that are not contained in the
67+
'other' Set. If a value that is not a Set is provided, it must match the
68+
element type of this Set. A new Set will be returned containing all
69+
elements of this Set except the provided value in that case.
70+
"""
71+
72+
if isinstance(other, Set):
73+
assert self.get_type() == other.get_type(
74+
), "sets must be of the same type"
75+
return rtg.SetDifferenceOp(self._value, other._value)
76+
77+
assert support.type_to_pytype(
78+
self.get_type()).element_type == other.get_type(
79+
), "type of the provided value must match element type of the set"
80+
return self - Set.create(other)
81+
82+
def get_random(self) -> Value:
83+
"""
84+
Returns an element from the set picked uniformly at random. If the set is
85+
empty, calling this method is undefined behavior.
86+
"""
87+
88+
return rtg.SetSelectRandomOp(self._value)
89+
90+
def get_random_and_exclude(self) -> Value:
91+
"""
92+
Returns an element from the set picked uniformly at random and removes it
93+
from the set. If the set is empty, calling this method is undefined
94+
behavior.
95+
"""
96+
97+
r = self.get_random()
98+
self = self - r
99+
return r
100+
101+
def _get_ssa_value(self) -> ir.Value:
102+
return self._value
103+
104+
def get_type(self) -> ir.Type:
105+
return self._value.type

frontends/PyRTG/src/pyrtg/support.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ def _FromCirctValue(value: ir.Value) -> Value:
1212
if isinstance(type, rtg.LabelType):
1313
from .labels import Label
1414
return Label(value)
15+
if isinstance(type, rtg.SetType):
16+
from .sets import Set
17+
return Set(value)
1518
assert False, "Unsupported value"
1619

1720

@@ -35,6 +38,8 @@ def specialize_create(cls):
3538
def create(*args, **kwargs):
3639
# If any of the arguments are 'pyrtg.Value', we need to convert them.
3740
def to_circt(arg):
41+
if isinstance(arg, Value):
42+
return arg._get_ssa_value()
3843
if isinstance(arg, (list, tuple)):
3944
return [to_circt(a) for a in arg]
4045
return arg

frontends/PyRTG/test/basic.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# RUN: %rtgtool% %s --seed=0 --output-format=elaborated | FileCheck %s --check-prefix=ELABORATED
33
# RUN: %rtgtool% %s --seed=0 -o %t --output-format=asm && FileCheck %s --input-file=%t --check-prefix=ASM
44

5-
from pyrtg import test, Label, rtg
5+
from pyrtg import test, rtg, Label, Set
66

77
# MLIR-LABEL: rtg.test @test0
88
# MLIR-NEXT: }
@@ -26,6 +26,18 @@ def test0():
2626
# MLIR-NEXT: rtg.label global [[L0]]
2727
# MLIR-NEXT: rtg.label external [[L1]]
2828
# MLIR-NEXT: rtg.label local [[L2]]
29+
30+
# MLIR-NEXT: [[SET0:%.+]] = rtg.set_create [[L0]], [[L1]] : !rtg.label
31+
# MLIR-NEXT: [[SET1:%.+]] = rtg.set_create [[L2]] : !rtg.label
32+
# MLIR-NEXT: [[EMPTY_SET:%.+]] = rtg.set_create : !rtg.label
33+
# MLIR-NEXT: [[SET2_1:%.+]] = rtg.set_union [[SET0]], [[SET1]] : !rtg.set<!rtg.label>
34+
# MLIR-NEXT: [[SET2:%.+]] = rtg.set_union [[SET2_1]], [[EMPTY_SET]] : !rtg.set<!rtg.label>
35+
# MLIR-NEXT: [[RL0:%.+]] = rtg.set_select_random [[SET2]] : !rtg.set<!rtg.label>
36+
# MLIR-NEXT: rtg.label local [[RL0]]
37+
# MLIR-NEXT: [[SET2_MINUS_SET0:%.+]] = rtg.set_difference [[SET2]], [[SET0]] : !rtg.set<!rtg.label>
38+
# MLIR-NEXT: [[RL1:%.+]] = rtg.set_select_random [[SET2_MINUS_SET0]] : !rtg.set<!rtg.label>
39+
# MLIR-NEXT: rtg.label local [[RL1]]
40+
2941
# MLIR-NEXT: }
3042

3143
# ELABORATED-LABEL: rtg.test @test_labels
@@ -35,6 +47,10 @@ def test0():
3547
# ELABORATED-NEXT: rtg.label external [[L1]]
3648
# ELABORATED-NEXT: [[L2:%.+]] = rtg.label_decl "l1_1"
3749
# ELABORATED-NEXT: rtg.label local [[L2]]
50+
51+
# ELABORATED-NEXT: rtg.label local [[L0]]
52+
# ELABORATED-NEXT: rtg.label local [[L2]]
53+
3854
# ELABORATED-NEXT: }
3955

4056
# ASM-LABEL: Begin of test_labels
@@ -43,6 +59,10 @@ def test0():
4359
# ASM-NEXT: l0:
4460
# ASM-NEXT: .extern l1_0
4561
# ASM-NEXT: l1_1:
62+
63+
# ASM-NEXT: l0:
64+
# ASM-NEXT: l1_1:
65+
4666
# ASM-EMPTY:
4767
# ASM: End of test_labels
4868

@@ -55,3 +75,14 @@ def test_labels():
5575
l0.place(rtg.LabelVisibility.GLOBAL)
5676
l1.place(rtg.LabelVisibility.EXTERNAL)
5777
l2.place()
78+
79+
set0 = Set.create(l0, l1)
80+
set1 = Set.create(l2)
81+
empty_set = Set.create_empty(rtg.LabelType.get())
82+
set2 = set0 + set1 + empty_set
83+
rl0 = set2.get_random()
84+
rl0.place()
85+
86+
set2 -= set0
87+
rl1 = set2.get_random_and_exclude()
88+
rl1.place()

0 commit comments

Comments
 (0)