Skip to content

Commit 05fae11

Browse files
Automerge: [mlir][python] Add bindings for OpenACC dialect (#163620)
Adds initial support for Python bindings to the OpenACC dialect. * The bindings do not provide any niceties yet, just the barebones exposure of the dialect to Python. Construction of OpenACC ops is therefore verbose and somewhat inconvenient, as evidenced by the test. * The test only constructs one module, but I attempted to use enough operations to be meaningful. It does not test all the ops exposed, but does contain a realistic example of a memcpy idiom.
2 parents fdf68ac + e5825c4 commit 05fae11

File tree

4 files changed

+200
-0
lines changed

4 files changed

+200
-0
lines changed

mlir/python/CMakeLists.txt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,16 @@ declare_mlir_dialect_python_bindings(
134134
dialects/func.py
135135
DIALECT_NAME func)
136136

137+
# OpenACC dialect bindings: dialects/OpenACCOps.td is the TableGen entry point,
# dialects/openacc.py is the Python source, and the dialect name is `acc`.
declare_mlir_dialect_python_bindings(
138+
ADD_TO_PARENT MLIRPythonSources.Dialects
139+
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
140+
TD_FILE dialects/OpenACCOps.td
141+
SOURCES
142+
dialects/openacc.py
143+
DIALECT_NAME acc
144+
DEPENDS acc_common_td
145+
)
146+
137147
declare_mlir_dialect_python_bindings(
138148
ADD_TO_PARENT MLIRPythonSources.Dialects
139149
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
//===- OpenACCOps.td - Entry point for OpenACCOps bindings -*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
// Include guard for the Python-bindings entry point of the OpenACC ops.
#ifndef PYTHON_BINDINGS_OPENACC_OPS
10+
#define PYTHON_BINDINGS_OPENACC_OPS
11+
12+
// The only content of this file: pull in the dialect's op definitions.
include "mlir/Dialect/OpenACC/OpenACCOps.td"
13+
14+
#endif
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
5+
from ._acc_ops_gen import *
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
# RUN: python %s | FileCheck %s
2+
from unittest import result
3+
from mlir.ir import (
4+
Context,
5+
FunctionType,
6+
Location,
7+
Module,
8+
InsertionPoint,
9+
IntegerType,
10+
IndexType,
11+
MemRefType,
12+
F32Type,
13+
Block,
14+
ArrayAttr,
15+
Attribute,
16+
UnitAttr,
17+
StringAttr,
18+
DenseI32ArrayAttr,
19+
ShapedType,
20+
)
21+
from mlir.dialects import openacc, func, arith, memref
22+
from mlir.extras import types
23+
24+
25+
def run(f):
    """Run test function *f* immediately and hand it back (decorator).

    Prints a ``// TEST: <name>`` banner first, then calls ``f()`` with a fresh
    ``Context`` and an unknown ``Location`` active. Returning ``f`` keeps the
    decorated name bound to the original function.
    """
    test_name = f.__name__
    print("\n// TEST:", test_name)
    with Context():
        with Location.unknown():
            f()
    return f
30+
31+
32+
@run
def testParallelMemcpy():
    """Build an OpenACC memcpy idiom through the Python bindings and print it.

    Creates a public ``memcpy_idiom`` function taking two dynamically sized
    1-D f32 memrefs and an i64 loop upper bound. The first memref goes through
    ``openacc.copyin`` and the second through ``openacc.create_``; an
    ``acc.loop`` nested inside an ``acc.parallel`` region then loads from the
    copyin result and stores into the create result, after which
    ``openacc.delete`` and ``openacc.copyout`` are emitted. The module is
    printed so the FileCheck directives in this file can match it.
    """
    module = Module.create()

    dynamic = ShapedType.get_dynamic_size()
    memref_f32_1d_any = MemRefType.get([dynamic], types.f32())

    with InsertionPoint(module.body):
        function_type = FunctionType.get(
            [memref_f32_1d_any, memref_f32_1d_any, types.i64()], []
        )
        f = func.FuncOp(
            type=function_type,
            name="memcpy_idiom",
        )
        f.attributes["sym_visibility"] = StringAttr.get("public")

        with InsertionPoint(f.add_entry_block()):
            c1024 = arith.ConstantOp(types.i32(), 1024)
            c128 = arith.ConstantOp(types.i32(), 128)

            src, dst, upper_bound = f.arguments

            copied = openacc.copyin(
                acc_var=src.type,
                var=src,
                var_type=types.f32(),
                bounds=[],
                async_operands=[],
                implicit=False,
                structured=True,
            )
            created = openacc.create_(
                acc_var=dst.type,
                var=dst,
                var_type=types.f32(),
                bounds=[],
                async_operands=[],
                implicit=False,
                structured=True,
            )

            parallel_op = openacc.ParallelOp(
                asyncOperands=[],
                waitOperands=[],
                numGangs=[c1024],
                numWorkers=[],
                vectorLength=[c128],
                reductionOperands=[],
                privateOperands=[],
                firstprivateOperands=[],
                dataClauseOperands=[],
            )

            # Set required device_type and segment attributes to satisfy
            # verifier. The same attribute is reused for the loop below.
            acc_device_none = ArrayAttr.get(
                [Attribute.parse("#acc.device_type<none>")]
            )
            parallel_op.numGangsDeviceType = acc_device_none
            parallel_op.numGangsSegments = DenseI32ArrayAttr.get([1])
            parallel_op.vectorLengthDeviceType = acc_device_none

            parallel_block = Block.create_at_start(
                parent=parallel_op.region, arg_types=[]
            )

            with InsertionPoint(parallel_block):
                c0 = arith.ConstantOp(types.i64(), 0)
                c1 = arith.ConstantOp(types.i64(), 1)

                loop_op = openacc.LoopOp(
                    results_=[],
                    lowerbound=[c0],
                    upperbound=[upper_bound],
                    step=[c1],
                    gangOperands=[],
                    workerNumOperands=[],
                    vectorOperands=[],
                    tileOperands=[],
                    cacheOperands=[],
                    privateOperands=[],
                    reductionOperands=[],
                    firstprivateOperands=[],
                )

                # Set loop attributes: gang and independent on device_type<none>
                loop_op.gang = acc_device_none
                loop_op.independent = acc_device_none

                # The loop body takes the i64 induction variable as its only
                # block argument.
                loop_block = Block.create_at_start(
                    parent=loop_op.region, arg_types=[types.i64()]
                )

                with InsertionPoint(loop_block):
                    idx = arith.index_cast(
                        out=IndexType.get(), in_=loop_block.arguments[0]
                    )
                    val = memref.load(memref=copied, indices=[idx])
                    memref.store(value=val, memref=created, indices=[idx])
                    openacc.YieldOp([])

                # Terminator of the parallel region.
                openacc.YieldOp([])

            # Results of the two ops below are not needed afterwards, so they
            # are not bound to names.
            openacc.delete(
                acc_var=copied,
                bounds=[],
                async_operands=[],
                implicit=False,
                structured=True,
            )
            openacc.copyout(
                acc_var=created,
                var=dst,
                var_type=types.f32(),
                bounds=[],
                async_operands=[],
                implicit=False,
                structured=True,
            )
            func.ReturnOp([])

    print(module)
149+
150+
# CHECK: TEST: testParallelMemcpy
151+
# CHECK-LABEL: func.func public @memcpy_idiom(
152+
# CHECK-SAME: %[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: memref<?xf32>, %[[ARG2:.*]]: i64) {
153+
# CHECK: %[[CONSTANT_0:.*]] = arith.constant 1024 : i32
154+
# CHECK: %[[CONSTANT_1:.*]] = arith.constant 128 : i32
155+
# CHECK: %[[COPYIN_0:.*]] = acc.copyin varPtr(%[[ARG0]] : memref<?xf32>) -> memref<?xf32>
156+
# CHECK: %[[CREATE_0:.*]] = acc.create varPtr(%[[ARG1]] : memref<?xf32>) -> memref<?xf32>
157+
# CHECK: acc.parallel num_gangs({%[[CONSTANT_0]] : i32}) vector_length(%[[CONSTANT_1]] : i32) {
158+
# CHECK: %[[CONSTANT_2:.*]] = arith.constant 0 : i64
159+
# CHECK: %[[CONSTANT_3:.*]] = arith.constant 1 : i64
160+
# CHECK: acc.loop gang control(%[[VAL_0:.*]] : i64) = (%[[CONSTANT_2]] : i64) to (%[[ARG2]] : i64) step (%[[CONSTANT_3]] : i64) {
161+
# CHECK: %[[INDEX_CAST_0:.*]] = arith.index_cast %[[VAL_0]] : i64 to index
162+
# CHECK: %[[LOAD_0:.*]] = memref.load %[[COPYIN_0]]{{\[}}%[[INDEX_CAST_0]]] : memref<?xf32>
163+
# CHECK: memref.store %[[LOAD_0]], %[[CREATE_0]]{{\[}}%[[INDEX_CAST_0]]] : memref<?xf32>
164+
# CHECK: acc.yield
165+
# CHECK: } attributes {independent = [#acc.device_type<none>]}
166+
# CHECK: acc.yield
167+
# CHECK: }
168+
# CHECK: acc.delete accPtr(%[[COPYIN_0]] : memref<?xf32>)
169+
# CHECK: acc.copyout accPtr(%[[CREATE_0]] : memref<?xf32>) to varPtr(%[[ARG1]] : memref<?xf32>)
170+
# CHECK: return
171+
# CHECK: }

0 commit comments

Comments
 (0)