Skip to content

Commit c3ed3a6

Browse files
authored
[PyRTG] Wrapper around SSA values and label support (#8228)
1 parent 435dcee commit c3ed3a6

File tree

9 files changed

+213
-9
lines changed

9 files changed

+213
-9
lines changed

frontends/PyRTG/src/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@ declare_mlir_python_sources(PyRTGSources
1414
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}"
1515
SOURCES
1616
pyrtg/__init__.py
17+
pyrtg/core.py
18+
pyrtg/labels.py
19+
pyrtg/rtg.py
20+
pyrtg/support.py
1721
pyrtg/tests.py
1822
rtgtool/rtgtool.py
1923
)

frontends/PyRTG/src/pyrtg/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
# See https://llvm.org/LICENSE.txt for license information.
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44

5-
from . import circt
65
from . import tests
6+
from . import core
77
from .tests import test
8+
from .labels import Label
9+
from .rtg import rtg

frontends/PyRTG/src/pyrtg/core.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
5+
from .circt import ir
6+
7+
8+
class CodeGenRoot:
9+
"""
10+
This is the base class for classes that have to be visited by the RTG tool
11+
during codegen.
12+
"""
13+
14+
def codegen(self):
15+
assert False, "must be implemented by the subclass"
16+
17+
18+
class Value:
19+
"""
20+
This class wraps around MLIR SSA values to provide a more Python native
21+
experience. Instead of having a value class that stores the type, classes
22+
deriving from this class represent specific types of values. Operations on
23+
those values can then be exposed as methods that can support more convenient
24+
bridging between Python values and MLIR values (e.g., accepting a Python
25+
integer and automatically building a ConstantOp in MLIR).
26+
"""
27+
28+
def get_type(self) -> ir.Type:
29+
assert False, "must be implemented by subclass"
30+
31+
def _get_ssa_value(self) -> ir.Value:
32+
assert False, "must be implemented by subclass"
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
5+
from __future__ import annotations
6+
7+
from .circt import ir
8+
from .core import Value
9+
from .rtg import rtg
10+
11+
12+
class Label(Value):
13+
"""
14+
Represents an ISA Assembly label. It can be declared and then passed around
15+
like every value. To place a label at a specific location in a sequence call
16+
'place'. It is the user's responsibility to place a label such that if the
17+
label is used by an instruction in the fully randomized test, there exists
18+
exactly one placement of the label to not end up with ambiguity or usage of
19+
an undeclared label.
20+
"""
21+
22+
def __init__(self, value: ir.Value):
23+
self._value = value
24+
25+
def declare(string: str) -> Label:
26+
"""
27+
Declares a label with a fixed name. Labels returned by different calls to
28+
this function but with the same arguments refer to the same label.
29+
"""
30+
31+
return rtg.LabelDeclOp(string, [])
32+
33+
def declare_unique(string: str) -> Label:
34+
"""
35+
Declares a unique label. This means, all usages of the value returned by this
36+
function will refer to the same label, but no other label declarations can
37+
conflict with this label, including labels returned by other calls to this
38+
function or fixed labels declared with 'declare_label'.
39+
"""
40+
41+
return rtg.LabelUniqueDeclOp(string, [])
42+
43+
def place(
44+
self,
45+
visibility: rtg.LabelVisibility = rtg.LabelVisibility.LOCAL) -> None:
46+
"""
47+
Places a declared label in a sequence or test.
48+
"""
49+
50+
return rtg.LabelOp(rtg.LabelVisibilityAttr.get(visibility), self._value)
51+
52+
def get_type(self) -> ir.Type:
53+
return rtg.LabelType.get()
54+
55+
def _get_ssa_value(self) -> ir.Value:
56+
return self._value

frontends/PyRTG/src/pyrtg/rtg.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
5+
from .support import wrap_opviews_with_values
6+
from .circt.dialects import rtg
7+
8+
wrap_opviews_with_values(rtg, rtg.__name__)
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
5+
from .circt import support, ir
6+
from .core import Value
7+
8+
9+
def _FromCirctValue(value: ir.Value) -> Value:
10+
type = support.type_to_pytype(value.type)
11+
from .rtg import rtg
12+
if isinstance(type, rtg.LabelType):
13+
from .labels import Label
14+
return Label(value)
15+
assert False, "Unsupported value"
16+
17+
18+
def wrap_opviews_with_values(dialect, module_name, excluded=[]):
19+
"""
20+
Wraps all of a dialect's OpView classes to have their create method return a
21+
Value instead of an OpView.
22+
"""
23+
24+
import sys
25+
module = sys.modules[module_name]
26+
27+
for attr in dir(dialect):
28+
cls = getattr(dialect, attr)
29+
30+
if attr not in excluded and isinstance(cls, type) and issubclass(
31+
cls, ir.OpView):
32+
33+
def specialize_create(cls):
34+
35+
def create(*args, **kwargs):
36+
# If any of the arguments are 'pyrtg.Value', we need to convert them.
37+
def to_circt(arg):
38+
if isinstance(arg, (list, tuple)):
39+
return [to_circt(a) for a in arg]
40+
return arg
41+
42+
args = [to_circt(arg) for arg in args]
43+
kwargs = {k: to_circt(v) for k, v in kwargs.items()}
44+
# Create the OpView.
45+
if hasattr(cls, "create"):
46+
created = cls.create(*args, **kwargs)
47+
else:
48+
created = cls(*args, **kwargs)
49+
if isinstance(created, support.NamedValueOpView):
50+
created = created.opview
51+
52+
# Return the wrapped values, if any.
53+
converted_results = tuple(
54+
_FromCirctValue(res) for res in created.results)
55+
return converted_results[0] if len(
56+
converted_results) == 1 else created
57+
58+
return create
59+
60+
wrapped_class = specialize_create(cls)
61+
setattr(module, attr, wrapped_class)
62+
else:
63+
setattr(module, attr, cls)

frontends/PyRTG/src/pyrtg/tests.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55
import inspect
66

77
from .circt import ir
8-
from .circt.dialects import rtg
8+
from .core import CodeGenRoot
9+
from .rtg import rtg
910

1011

11-
class Test:
12+
class Test(CodeGenRoot):
1213
"""
1314
Represents an RTG Test. Stores the test function and location.
1415
"""

frontends/PyRTG/src/rtgtool/rtgtool.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def frontend_codegen(args: argparse.Namespace) -> ir.Module:
128128
module = ir.Module.create()
129129
with ir.InsertionPoint(module.body):
130130
for _, obj in inspect.getmembers(file):
131-
if isinstance(obj, pyrtg.tests.Test):
131+
if isinstance(obj, pyrtg.core.CodeGenRoot):
132132
obj.codegen()
133133
return module
134134

frontends/PyRTG/test/basic.py

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,57 @@
11
# RUN: %rtgtool% %s --seed=0 --output-format=mlir | FileCheck %s --check-prefix=MLIR
22
# RUN: %rtgtool% %s --seed=0 --output-format=elaborated | FileCheck %s --check-prefix=ELABORATED
3-
# RUN: %rtgtool% %s --seed=0 -o %t --output-format=asm | FileCheck %s --input-file=%t --check-prefix=ASM
3+
# RUN: %rtgtool% %s --seed=0 -o %t --output-format=asm && FileCheck %s --input-file=%t --check-prefix=ASM
44

5-
from pyrtg import test
5+
from pyrtg import test, Label, rtg
66

7-
# MLIR: rtg.test @test0
7+
# MLIR-LABEL: rtg.test @test0
88
# MLIR-NEXT: }
99

10-
# ELABORATED: rtg.test @test0
10+
# ELABORATED-LABEL: rtg.test @test0
1111
# ELABORATED-NEXT: }
1212

13-
# ASM: Begin of test0
13+
# ASM-LABEL: Begin of test0
1414
# ASM: End of test0
1515

1616

1717
@test
1818
def test0():
1919
pass
20+
21+
22+
# MLIR-LABEL: rtg.test @test_labels
23+
# MLIR-NEXT: [[L0:%.+]] = rtg.label_decl "l0"
24+
# MLIR-NEXT: [[L1:%.+]] = rtg.label_unique_decl "l1"
25+
# MLIR-NEXT: [[L2:%.+]] = rtg.label_unique_decl "l1"
26+
# MLIR-NEXT: rtg.label global [[L0]]
27+
# MLIR-NEXT: rtg.label external [[L1]]
28+
# MLIR-NEXT: rtg.label local [[L2]]
29+
# MLIR-NEXT: }
30+
31+
# ELABORATED-LABEL: rtg.test @test_labels
32+
# ELABORATED-NEXT: [[L0:%.+]] = rtg.label_decl "l0"
33+
# ELABORATED-NEXT: rtg.label global [[L0]]
34+
# ELABORATED-NEXT: [[L1:%.+]] = rtg.label_decl "l1_0"
35+
# ELABORATED-NEXT: rtg.label external [[L1]]
36+
# ELABORATED-NEXT: [[L2:%.+]] = rtg.label_decl "l1_1"
37+
# ELABORATED-NEXT: rtg.label local [[L2]]
38+
# ELABORATED-NEXT: }
39+
40+
# ASM-LABEL: Begin of test_labels
41+
# ASM-EMPTY:
42+
# ASM-NEXT: .global l0
43+
# ASM-NEXT: l0:
44+
# ASM-NEXT: .extern l1_0
45+
# ASM-NEXT: l1_1:
46+
# ASM-EMPTY:
47+
# ASM: End of test_labels
48+
49+
50+
@test
51+
def test_labels():
52+
l0 = Label.declare("l0")
53+
l1 = Label.declare_unique("l1")
54+
l2 = Label.declare_unique("l1")
55+
l0.place(rtg.LabelVisibility.GLOBAL)
56+
l1.place(rtg.LabelVisibility.EXTERNAL)
57+
l2.place()

0 commit comments

Comments
 (0)