Skip to content

Commit 0c087bd

Browse files
[mlir][acc] Extend PointerLikeType to provide alloc, dealloc, copy (#162328)
A variable in an acc data clause operation must have a type that implements either PointerLikeType or a MappableType interface. These interfaces provide the contract that allows acc dialect and its transform passes to interact with a source dialect. One of these requirements is ability to generate code that creates memory for a private copy and ability to initialize that copy from another variable. Thus, update the PointerLikeType API to provide the means to create allocation, deallocation, and copy. This will be used as a way to fill in privatization and firstprivatization recipes. This new API was implemented for memref along with testing to exercise it via the implementation of a testing pass.
1 parent 750e64d commit 0c087bd

File tree

11 files changed

+617
-0
lines changed

11 files changed

+617
-0
lines changed

mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,86 @@ def OpenACC_PointerLikeTypeInterface : TypeInterface<"PointerLikeType"> {
7070
return ::mlir::acc::VariableTypeCategory::uncategorized;
7171
}]
7272
>,
73+
InterfaceMethod<
74+
/*description=*/[{
75+
Generates allocation operations for the pointer-like type. It will create
76+
an allocate that produces memory space for an instance of the current type.
77+
78+
The `varName` parameter is optional and can be used to provide a name
79+
for the allocated variable. If the current type is represented
80+
in a way that it does not capture the pointee type, `varType` must be
81+
passed in to provide the necessary type information.
82+
83+
The `originalVar` parameter is optional but enables support for dynamic
84+
types (e.g., dynamic memrefs). When provided, implementations can extract
85+
runtime dimension information from the original variable to create
86+
allocations with matching dynamic sizes.
87+
88+
Returns a Value representing the result of the allocation. If no value
89+
is returned, it means the allocation was not successfully generated.
90+
}],
91+
/*retTy=*/"::mlir::Value",
92+
/*methodName=*/"genAllocate",
93+
/*args=*/(ins "::mlir::OpBuilder &":$builder,
94+
"::mlir::Location":$loc,
95+
"::llvm::StringRef":$varName,
96+
"::mlir::Type":$varType,
97+
"::mlir::Value":$originalVar),
98+
/*methodBody=*/"",
99+
/*defaultImplementation=*/[{
100+
return {};
101+
}]
102+
>,
103+
InterfaceMethod<
104+
/*description=*/[{
105+
Generates deallocation operations for the pointer-like type. It deallocates
106+
the instance provided.
107+
108+
The `varPtr` parameter is required and must represent an instance that was
109+
previously allocated. If the current type is represented in a way that it
110+
does not capture the pointee type, `varType` must be passed in to provide
111+
the necessary type information. Nothing is generated in case the allocate
112+
is `alloca`-like.
113+
114+
Returns true if deallocation was successfully generated or successfully
115+
deemed as not needed to be generated, false otherwise.
116+
}],
117+
/*retTy=*/"bool",
118+
/*methodName=*/"genFree",
119+
/*args=*/(ins "::mlir::OpBuilder &":$builder,
120+
"::mlir::Location":$loc,
121+
"::mlir::TypedValue<::mlir::acc::PointerLikeType>":$varPtr,
122+
"::mlir::Type":$varType),
123+
/*methodBody=*/"",
124+
/*defaultImplementation=*/[{
125+
return false;
126+
}]
127+
>,
128+
InterfaceMethod<
129+
/*description=*/[{
130+
Generates copy operations for the pointer-like type. It copies the memory
131+
from the source to the destination. Typically used to initialize one
132+
variable of this type from another.
133+
134+
The `destination` and `source` parameters represent the target and source
135+
instances respectively. If the current type is represented in a way that it
136+
does not capture the pointee type, `varType` must be passed in to provide
137+
the necessary type information.
138+
139+
Returns true if copy was successfully generated, false otherwise.
140+
}],
141+
/*retTy=*/"bool",
142+
/*methodName=*/"genCopy",
143+
/*args=*/(ins "::mlir::OpBuilder &":$builder,
144+
"::mlir::Location":$loc,
145+
"::mlir::TypedValue<::mlir::acc::PointerLikeType>":$destination,
146+
"::mlir::TypedValue<::mlir::acc::PointerLikeType>":$source,
147+
"::mlir::Type":$varType),
148+
/*methodBody=*/"",
149+
/*defaultImplementation=*/[{
150+
return false;
151+
}]
152+
>,
73153
];
74154
}
75155

mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
// =============================================================================
88

99
#include "mlir/Dialect/OpenACC/OpenACC.h"
10+
#include "mlir/Dialect/Arith/IR/Arith.h"
1011
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1112
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
1213
#include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -44,6 +45,7 @@ struct MemRefPointerLikeModel
4445
Type getElementType(Type pointer) const {
4546
return cast<MemRefType>(pointer).getElementType();
4647
}
48+
4749
mlir::acc::VariableTypeCategory
4850
getPointeeTypeCategory(Type pointer, TypedValue<PointerLikeType> varPtr,
4951
Type varType) const {
@@ -70,6 +72,115 @@ struct MemRefPointerLikeModel
7072
assert(memrefTy.getRank() > 0 && "rank expected to be positive");
7173
return mlir::acc::VariableTypeCategory::array;
7274
}
75+
76+
mlir::Value genAllocate(Type pointer, OpBuilder &builder, Location loc,
77+
StringRef varName, Type varType,
78+
Value originalVar) const {
79+
auto memrefTy = cast<MemRefType>(pointer);
80+
81+
// Check if this is a static memref (all dimensions are known) - if yes
82+
// then we can generate an alloca operation.
83+
if (memrefTy.hasStaticShape())
84+
return memref::AllocaOp::create(builder, loc, memrefTy).getResult();
85+
86+
// For dynamic memrefs, extract sizes from the original variable if
87+
// provided. Otherwise they cannot be handled.
88+
if (originalVar && originalVar.getType() == memrefTy &&
89+
memrefTy.hasRank()) {
90+
SmallVector<Value> dynamicSizes;
91+
for (int64_t i = 0; i < memrefTy.getRank(); ++i) {
92+
if (memrefTy.isDynamicDim(i)) {
93+
// Extract the size of dimension i from the original variable
94+
auto indexValue = arith::ConstantIndexOp::create(builder, loc, i);
95+
auto dimSize =
96+
memref::DimOp::create(builder, loc, originalVar, indexValue);
97+
dynamicSizes.push_back(dimSize);
98+
}
99+
// Note: We only add dynamic sizes to the dynamicSizes array
100+
// Static dimensions are handled automatically by AllocOp
101+
}
102+
return memref::AllocOp::create(builder, loc, memrefTy, dynamicSizes)
103+
.getResult();
104+
}
105+
106+
// TODO: Unranked not yet supported.
107+
return {};
108+
}
109+
110+
bool genFree(Type pointer, OpBuilder &builder, Location loc,
111+
TypedValue<PointerLikeType> varPtr, Type varType) const {
112+
if (auto memrefValue = dyn_cast<TypedValue<MemRefType>>(varPtr)) {
113+
// Walk through casts to find the original allocation
114+
Value currentValue = memrefValue;
115+
Operation *originalAlloc = nullptr;
116+
117+
// Follow the chain of operations to find the original allocation
118+
// even if a casted result is provided.
119+
while (currentValue) {
120+
if (auto *definingOp = currentValue.getDefiningOp()) {
121+
// Check if this is an allocation operation
122+
if (isa<memref::AllocOp, memref::AllocaOp>(definingOp)) {
123+
originalAlloc = definingOp;
124+
break;
125+
}
126+
127+
// Check if this is a cast operation we can look through
128+
if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
129+
currentValue = castOp.getSource();
130+
continue;
131+
}
132+
133+
// Check for other cast-like operations
134+
if (auto reinterpretCastOp =
135+
dyn_cast<memref::ReinterpretCastOp>(definingOp)) {
136+
currentValue = reinterpretCastOp.getSource();
137+
continue;
138+
}
139+
140+
// If we can't look through this operation, stop
141+
break;
142+
}
143+
// This is a block argument or similar - can't trace further.
144+
break;
145+
}
146+
147+
if (originalAlloc) {
148+
if (isa<memref::AllocaOp>(originalAlloc)) {
149+
// This is an alloca - no dealloc needed, but return true (success)
150+
return true;
151+
}
152+
if (isa<memref::AllocOp>(originalAlloc)) {
153+
// This is an alloc - generate dealloc
154+
memref::DeallocOp::create(builder, loc, memrefValue);
155+
return true;
156+
}
157+
}
158+
}
159+
160+
return false;
161+
}
162+
163+
bool genCopy(Type pointer, OpBuilder &builder, Location loc,
164+
TypedValue<PointerLikeType> destination,
165+
TypedValue<PointerLikeType> source, Type varType) const {
166+
// Generate a copy operation between two memrefs
167+
auto destMemref = dyn_cast_if_present<TypedValue<MemRefType>>(destination);
168+
auto srcMemref = dyn_cast_if_present<TypedValue<MemRefType>>(source);
169+
170+
// As per memref documentation, source and destination must have same
171+
// element type and shape in order to be compatible. We do not want to fail
172+
// with an IR verification error - thus check that before generating the
173+
// copy operation.
174+
if (destMemref && srcMemref &&
175+
destMemref.getType().getElementType() ==
176+
srcMemref.getType().getElementType() &&
177+
destMemref.getType().getShape() == srcMemref.getType().getShape()) {
178+
memref::CopyOp::create(builder, loc, srcMemref, destMemref);
179+
return true;
180+
}
181+
182+
return false;
183+
}
73184
};
74185

75186
struct LLVMPointerPointerLikeModel
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
// RUN: mlir-opt %s --split-input-file --pass-pipeline="builtin.module(func.func(test-acc-pointer-like-interface{test-mode=alloc}))" 2>&1 | FileCheck %s
2+
3+
func.func @test_static_memref_alloc() {
4+
%0 = memref.alloca() {test.ptr} : memref<10x20xf32>
5+
// CHECK: Successfully generated alloc for operation: %[[ORIG:.*]] = memref.alloca() {test.ptr} : memref<10x20xf32>
6+
// CHECK: Generated: %{{.*}} = memref.alloca() : memref<10x20xf32>
7+
return
8+
}
9+
10+
// -----
11+
12+
func.func @test_dynamic_memref_alloc() {
13+
%c10 = arith.constant 10 : index
14+
%c20 = arith.constant 20 : index
15+
%orig = memref.alloc(%c10, %c20) {test.ptr} : memref<?x?xf32>
16+
17+
// CHECK: Successfully generated alloc for operation: %[[ORIG:.*]] = memref.alloc(%[[C10:.*]], %[[C20:.*]]) {test.ptr} : memref<?x?xf32>
18+
// CHECK: Generated: %[[C0:.*]] = arith.constant 0 : index
19+
// CHECK: Generated: %[[DIM0:.*]] = memref.dim %[[ORIG]], %[[C0]] : memref<?x?xf32>
20+
// CHECK: Generated: %[[C1:.*]] = arith.constant 1 : index
21+
// CHECK: Generated: %[[DIM1:.*]] = memref.dim %[[ORIG]], %[[C1]] : memref<?x?xf32>
22+
// CHECK: Generated: %{{.*}} = memref.alloc(%[[DIM0]], %[[DIM1]]) : memref<?x?xf32>
23+
return
24+
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// RUN: mlir-opt %s --split-input-file --pass-pipeline="builtin.module(func.func(test-acc-pointer-like-interface{test-mode=copy}))" 2>&1 | FileCheck %s
2+
3+
func.func @test_copy_static() {
4+
%src = memref.alloca() {test.src_ptr} : memref<10x20xf32>
5+
%dest = memref.alloca() {test.dest_ptr} : memref<10x20xf32>
6+
7+
// CHECK: Successfully generated copy from source: %[[SRC:.*]] = memref.alloca() {test.src_ptr} : memref<10x20xf32> to destination: %[[DEST:.*]] = memref.alloca() {test.dest_ptr} : memref<10x20xf32>
8+
// CHECK: Generated: memref.copy %[[SRC]], %[[DEST]] : memref<10x20xf32> to memref<10x20xf32>
9+
return
10+
}
11+
12+
// -----
13+
14+
func.func @test_copy_dynamic() {
15+
%c10 = arith.constant 10 : index
16+
%c20 = arith.constant 20 : index
17+
%src = memref.alloc(%c10, %c20) {test.src_ptr} : memref<?x?xf32>
18+
%dest = memref.alloc(%c10, %c20) {test.dest_ptr} : memref<?x?xf32>
19+
20+
// CHECK: Successfully generated copy from source: %[[SRC:.*]] = memref.alloc(%[[C10:.*]], %[[C20:.*]]) {test.src_ptr} : memref<?x?xf32> to destination: %[[DEST:.*]] = memref.alloc(%[[C10]], %[[C20]]) {test.dest_ptr} : memref<?x?xf32>
21+
// CHECK: Generated: memref.copy %[[SRC]], %[[DEST]] : memref<?x?xf32> to memref<?x?xf32>
22+
return
23+
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
// RUN: mlir-opt %s --split-input-file --pass-pipeline="builtin.module(func.func(test-acc-pointer-like-interface{test-mode=free}))" 2>&1 | FileCheck %s
2+
3+
func.func @test_static_memref_free() {
4+
%0 = memref.alloca() {test.ptr} : memref<10x20xf32>
5+
// CHECK: Successfully generated free for operation: %[[ORIG:.*]] = memref.alloca() {test.ptr} : memref<10x20xf32>
6+
// CHECK-NOT: Generated
7+
return
8+
}
9+
10+
// -----
11+
12+
func.func @test_dynamic_memref_free() {
13+
%c10 = arith.constant 10 : index
14+
%c20 = arith.constant 20 : index
15+
%orig = memref.alloc(%c10, %c20) {test.ptr} : memref<?x?xf32>
16+
17+
// CHECK: Successfully generated free for operation: %[[ORIG:.*]] = memref.alloc(%[[C10:.*]], %[[C20:.*]]) {test.ptr} : memref<?x?xf32>
18+
// CHECK: Generated: memref.dealloc %[[ORIG]] : memref<?x?xf32>
19+
return
20+
}
21+
22+
// -----
23+
24+
func.func @test_cast_walking_free() {
25+
%0 = memref.alloca() : memref<10x20xf32>
26+
%1 = memref.cast %0 {test.ptr} : memref<10x20xf32> to memref<?x?xf32>
27+
28+
// CHECK: Successfully generated free for operation: %[[CAST:.*]] = memref.cast %[[ALLOCA:.*]] {test.ptr} : memref<10x20xf32> to memref<?x?xf32>
29+
// CHECK-NOT: Generated
30+
return
31+
}

mlir/test/lib/Dialect/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ add_subdirectory(Math)
1212
add_subdirectory(MemRef)
1313
add_subdirectory(Shard)
1414
add_subdirectory(NVGPU)
15+
add_subdirectory(OpenACC)
1516
add_subdirectory(SCF)
1617
add_subdirectory(Shape)
1718
add_subdirectory(SPIRV)
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
add_mlir_library(MLIROpenACCTestPasses
2+
TestOpenACC.cpp
3+
TestPointerLikeTypeInterface.cpp
4+
5+
EXCLUDE_FROM_LIBMLIR
6+
)
7+
mlir_target_link_libraries(MLIROpenACCTestPasses PUBLIC
8+
MLIRIR
9+
MLIRArithDialect
10+
MLIRFuncDialect
11+
MLIRMemRefDialect
12+
MLIROpenACCDialect
13+
MLIRPass
14+
MLIRSupport
15+
)
16+
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
//===- TestOpenACC.cpp - OpenACC Test Registration ------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file contains unified registration for all OpenACC test passes.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
namespace mlir {
14+
namespace test {
15+
16+
// Forward declarations of individual test pass registration functions
17+
void registerTestPointerLikeTypeInterfacePass();
18+
19+
// Unified registration function for all OpenACC tests
20+
void registerTestOpenACC() { registerTestPointerLikeTypeInterfacePass(); }
21+
22+
} // namespace test
23+
} // namespace mlir

0 commit comments

Comments
 (0)