Skip to content

Commit 1eccb26

Browse files
committed
[MLIR][OpenMP] Refactor omp.target_allocmem to allow reuse, NFC
This patch moves tablegen definitions that could be used for all kinds of heap allocations out of `omp.target_allocmem` and into a new `OpenMP_HeapAllocClause` that can be reused. Descriptions are updated to follow the format of most other operations and the custom verifier for `omp.target_allocmem` is removed as it only made a redundant check on its result type.
1 parent d31b265 commit 1eccb26

File tree

5 files changed

+176
-148
lines changed

5 files changed

+176
-148
lines changed

mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#define OPENMP_CLAUSES
2121

2222
include "mlir/Dialect/OpenMP/OpenMPOpBase.td"
23+
include "mlir/Interfaces/SideEffectInterfaces.td"
2324
include "mlir/IR/SymbolInterfaces.td"
2425

2526
//===----------------------------------------------------------------------===//
@@ -547,6 +548,58 @@ class OpenMP_HasDeviceAddrClauseSkip<
547548

548549
def OpenMP_HasDeviceAddrClause : OpenMP_HasDeviceAddrClauseSkip<>;
549550

551+
//===----------------------------------------------------------------------===//
552+
// Not in the spec: Clause-like structure to hold heap allocation information.
553+
//===----------------------------------------------------------------------===//
554+
555+
class OpenMP_HeapAllocClauseSkip<
556+
bit traits = false, bit arguments = false, bit assemblyFormat = false,
557+
bit description = false, bit extraClassDeclaration = false
558+
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
559+
extraClassDeclaration> {
560+
let traits = [
561+
MemoryEffects<[MemAlloc<DefaultResource>]>
562+
];
563+
564+
let arguments = (ins
565+
TypeAttr:$in_type,
566+
OptionalAttr<StrAttr>:$uniq_name,
567+
OptionalAttr<StrAttr>:$bindc_name,
568+
Variadic<IntLikeType>:$typeparams,
569+
Variadic<IntLikeType>:$shape
570+
);
571+
572+
// The custom parser doesn't parse `uniq_name` and `bindc_name`. This is
573+
// handled by the attr-dict, which must be present in the operation's
574+
// `assemblyFormat`.
575+
let reqAssemblyFormat = [{
576+
custom<HeapAllocClause>($in_type, $typeparams, type($typeparams), $shape,
577+
type($shape))
578+
}];
579+
580+
let extraClassDeclaration = [{
581+
mlir::Type getAllocatedType() { return getInTypeAttr().getValue(); }
582+
}];
583+
584+
let description = [{
585+
The `in_type` is the type of the object for which memory is being allocated.
586+
For arrays, this can be a static or dynamic array type.
587+
588+
The optional `uniq_name` is a unique name for the allocated memory.
589+
590+
The optional `bindc_name` is a name used for C interoperability.
591+
592+
The `typeparams` are runtime type parameters for polymorphic or
593+
parameterized types. These are typically integer values that define aspects
594+
of a type not fixed at compile time.
595+
596+
The `shape` holds runtime shape operands for dynamic arrays. Each operand is
597+
an integer value representing the extent of a specific dimension.
598+
}];
599+
}
600+
601+
def OpenMP_HeapAllocClause : OpenMP_HeapAllocClauseSkip<>;
602+
550603
//===----------------------------------------------------------------------===//
551604
// V5.2: [5.4.7] `inclusive` clause
552605
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 33 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -2128,59 +2128,45 @@ def AllocateDirOp : OpenMP_Op<"allocate_dir", [AttrSizedOperandSegments], clause
21282128
// TargetAllocMemOp
21292129
//===----------------------------------------------------------------------===//
21302130

2131-
def TargetAllocMemOp : OpenMP_Op<"target_allocmem",
2132-
[MemoryEffects<[MemAlloc<DefaultResource>]>, AttrSizedOperandSegments]> {
2131+
def TargetAllocMemOp : OpenMP_Op<"target_allocmem", traits = [
2132+
AttrSizedOperandSegments
2133+
], clauses = [
2134+
OpenMP_HeapAllocClause
2135+
]> {
21332136
let summary = "allocate storage on an openmp device for an object of a given type";
21342137

21352138
let description = [{
2136-
Allocates memory on the specified OpenMP device for an object of the given type.
2137-
Returns an integer value representing the device pointer to the allocated memory.
2138-
The memory is uninitialized after allocation. Operations must be paired with
2139-
`omp.target_freemem` to avoid memory leaks.
2140-
2141-
* `$device`: The integer ID of the OpenMP device where the memory will be allocated.
2142-
* `$in_type`: The type of the object for which memory is being allocated.
2143-
For arrays, this can be a static or dynamic array type.
2144-
* `$uniq_name`: An optional unique name for the allocated memory.
2145-
* `$bindc_name`: An optional name used for C interoperability.
2146-
* `$typeparams`: Runtime type parameters for polymorphic or parameterized types.
2147-
These are typically integer values that define aspects of a type not fixed at compile time.
2148-
* `$shape`: Runtime shape operands for dynamic arrays.
2149-
Each operand is an integer value representing the extent of a specific dimension.
2150-
2151-
```mlir
2152-
// Allocate a static 3x3 integer vector on device 0
2153-
%device_0 = arith.constant 0 : i32
2154-
%ptr_static = omp.target_allocmem %device_0 : i32, vector<3x3xi32>
2155-
// ... use %ptr_static ...
2156-
omp.target_freemem %device_0, %ptr_static : i32, i64
2157-
2158-
// Allocate a dynamic 2D Fortran array (fir.array) on device 1
2159-
%device_1 = arith.constant 1 : i32
2160-
%rows = arith.constant 10 : index
2161-
%cols = arith.constant 20 : index
2162-
%ptr_dynamic = omp.target_allocmem %device_1 : i32, !fir.array<?x?xf32>, %rows, %cols : index, index
2163-
// ... use %ptr_dynamic ...
2164-
omp.target_freemem %device_1, %ptr_dynamic : i32, i64
2165-
```
2166-
}];
2139+
Allocates memory on the specified OpenMP device for an object of the given
2140+
type. Returns an integer value representing the device pointer to the
2141+
allocated memory. The memory is uninitialized after allocation. Operations
2142+
must be paired with `omp.target_freemem` to avoid memory leaks.
21672143

2168-
let arguments = (ins
2169-
Arg<AnyInteger>:$device,
2170-
TypeAttr:$in_type,
2171-
OptionalAttr<StrAttr>:$uniq_name,
2172-
OptionalAttr<StrAttr>:$bindc_name,
2173-
Variadic<IntLikeType>:$typeparams,
2174-
Variadic<IntLikeType>:$shape
2175-
);
2176-
let results = (outs I64);
2144+
```mlir
2145+
// Allocate a static 3x3 integer vector on device 0
2146+
%device_0 = arith.constant 0 : i32
2147+
%ptr_static = omp.target_allocmem %device_0 : i32, vector<3x3xi32>
2148+
// ... use %ptr_static ...
2149+
omp.target_freemem %device_0, %ptr_static : i32, i64
2150+
2151+
// Allocate a dynamic 2D Fortran array (fir.array) on device 1
2152+
%device_1 = arith.constant 1 : i32
2153+
%rows = arith.constant 10 : index
2154+
%cols = arith.constant 20 : index
2155+
%ptr_dynamic = omp.target_allocmem %device_1 : i32, !fir.array<?x?xf32>, %rows, %cols : index, index
2156+
// ... use %ptr_dynamic ...
2157+
omp.target_freemem %device_1, %ptr_dynamic : i32, i64
2158+
```
21772159

2178-
let hasCustomAssemblyFormat = 1;
2179-
let hasVerifier = 1;
2160+
The `device` is an integer ID of the OpenMP device where the memory will be
2161+
allocated.
2162+
}] # clausesDescription;
21802163

2181-
let extraClassDeclaration = [{
2182-
mlir::Type getAllocatedType();
2183-
}];
2164+
let arguments = !con((ins Arg<AnyInteger>:$device), clausesArgs);
2165+
let results = (outs I64);
2166+
2167+
// Override inherited assembly format to include `device`.
2168+
let assemblyFormat = " $device `:` type($device) `,` "
2169+
# clausesReqAssemblyFormat # " attr-dict";
21842170
}
21852171

21862172
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 52 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -797,6 +797,58 @@ static void printNumTasksClause(OpAsmPrinter &p, Operation *op,
797797
p, op, numTasksMod, numTasks, numTasksType, &stringifyClauseNumTasksType);
798798
}
799799

800+
//===----------------------------------------------------------------------===//
801+
// Parser and printer for Heap Alloc Clause
802+
//===----------------------------------------------------------------------===//
803+
804+
/// operation ::= $in_type ( `(` $typeparams `)` )? ( `,` $shape )?
805+
static ParseResult parseHeapAllocClause(
806+
OpAsmParser &parser, TypeAttr &inTypeAttr,
807+
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &typeparams,
808+
SmallVectorImpl<Type> &typeparamsTypes,
809+
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &shape,
810+
SmallVectorImpl<Type> &shapeTypes) {
811+
mlir::Type inType;
812+
if (parser.parseType(inType))
813+
return mlir::failure();
814+
inTypeAttr = TypeAttr::get(inType);
815+
816+
if (!parser.parseOptionalLParen()) {
817+
// parse the LEN params of the derived type. (<params> : <types>)
818+
if (parser.parseOperandList(typeparams, OpAsmParser::Delimiter::None) ||
819+
parser.parseColonTypeList(typeparamsTypes) || parser.parseRParen())
820+
return failure();
821+
}
822+
823+
if (!parser.parseOptionalComma()) {
824+
// parse size to scale by, vector of n dimensions of type index
825+
if (parser.parseOperandList(shape, OpAsmParser::Delimiter::None))
826+
return failure();
827+
828+
// TODO: This overrides the actual types of the operands, which might cause
829+
// issues when they don't match. At the moment this is done in place of
830+
// making the corresponding operand type `Variadic<Index>` because index
831+
// types are lowered to I64 prior to LLVM IR translation.
832+
shapeTypes.append(shape.size(), IndexType::get(parser.getContext()));
833+
}
834+
835+
return success();
836+
}
837+
838+
static void printHeapAllocClause(OpAsmPrinter &p, Operation *op,
839+
TypeAttr inType, ValueRange typeparams,
840+
TypeRange typeparamsTypes, ValueRange shape,
841+
TypeRange shapeTypes) {
842+
p << inType;
843+
if (!typeparams.empty()) {
844+
p << '(' << typeparams << " : " << typeparamsTypes << ')';
845+
}
846+
for (auto sh : shape) {
847+
p << ", ";
848+
p.printOperand(sh);
849+
}
850+
}
851+
800852
//===----------------------------------------------------------------------===//
801853
// Parsers for operations including clauses that define entry block arguments.
802854
//===----------------------------------------------------------------------===//
@@ -4109,107 +4161,6 @@ LogicalResult AllocateDirOp::verify() {
41094161
return success();
41104162
}
41114163

4112-
//===----------------------------------------------------------------------===//
4113-
// TargetAllocMemOp
4114-
//===----------------------------------------------------------------------===//
4115-
4116-
mlir::Type omp::TargetAllocMemOp::getAllocatedType() {
4117-
return getInTypeAttr().getValue();
4118-
}
4119-
4120-
/// operation ::= %res = (`omp.target_alloc_mem`) $device : devicetype,
4121-
/// $in_type ( `(` $typeparams `)` )? ( `,` $shape )?
4122-
/// attr-dict-without-keyword
4123-
static mlir::ParseResult parseTargetAllocMemOp(mlir::OpAsmParser &parser,
4124-
mlir::OperationState &result) {
4125-
auto &builder = parser.getBuilder();
4126-
bool hasOperands = false;
4127-
std::int32_t typeparamsSize = 0;
4128-
4129-
// Parse device number as a new operand
4130-
mlir::OpAsmParser::UnresolvedOperand deviceOperand;
4131-
mlir::Type deviceType;
4132-
if (parser.parseOperand(deviceOperand) || parser.parseColonType(deviceType))
4133-
return mlir::failure();
4134-
if (parser.resolveOperand(deviceOperand, deviceType, result.operands))
4135-
return mlir::failure();
4136-
if (parser.parseComma())
4137-
return mlir::failure();
4138-
4139-
mlir::Type intype;
4140-
if (parser.parseType(intype))
4141-
return mlir::failure();
4142-
result.addAttribute("in_type", mlir::TypeAttr::get(intype));
4143-
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> operands;
4144-
llvm::SmallVector<mlir::Type> typeVec;
4145-
if (!parser.parseOptionalLParen()) {
4146-
// parse the LEN params of the derived type. (<params> : <types>)
4147-
if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None) ||
4148-
parser.parseColonTypeList(typeVec) || parser.parseRParen())
4149-
return mlir::failure();
4150-
typeparamsSize = operands.size();
4151-
hasOperands = true;
4152-
}
4153-
std::int32_t shapeSize = 0;
4154-
if (!parser.parseOptionalComma()) {
4155-
// parse size to scale by, vector of n dimensions of type index
4156-
if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None))
4157-
return mlir::failure();
4158-
shapeSize = operands.size() - typeparamsSize;
4159-
auto idxTy = builder.getIndexType();
4160-
for (std::int32_t i = typeparamsSize, end = operands.size(); i != end; ++i)
4161-
typeVec.push_back(idxTy);
4162-
hasOperands = true;
4163-
}
4164-
if (hasOperands &&
4165-
parser.resolveOperands(operands, typeVec, parser.getNameLoc(),
4166-
result.operands))
4167-
return mlir::failure();
4168-
4169-
mlir::Type restype = builder.getIntegerType(64);
4170-
if (!restype) {
4171-
parser.emitError(parser.getNameLoc(), "invalid allocate type: ") << intype;
4172-
return mlir::failure();
4173-
}
4174-
llvm::SmallVector<std::int32_t> segmentSizes{1, typeparamsSize, shapeSize};
4175-
result.addAttribute("operandSegmentSizes",
4176-
builder.getDenseI32ArrayAttr(segmentSizes));
4177-
if (parser.parseOptionalAttrDict(result.attributes) ||
4178-
parser.addTypeToList(restype, result.types))
4179-
return mlir::failure();
4180-
return mlir::success();
4181-
}
4182-
4183-
mlir::ParseResult omp::TargetAllocMemOp::parse(mlir::OpAsmParser &parser,
4184-
mlir::OperationState &result) {
4185-
return parseTargetAllocMemOp(parser, result);
4186-
}
4187-
4188-
void omp::TargetAllocMemOp::print(mlir::OpAsmPrinter &p) {
4189-
p << " ";
4190-
p.printOperand(getDevice());
4191-
p << " : ";
4192-
p << getDevice().getType();
4193-
p << ", ";
4194-
p << getInType();
4195-
if (!getTypeparams().empty()) {
4196-
p << '(' << getTypeparams() << " : " << getTypeparams().getTypes() << ')';
4197-
}
4198-
for (auto sh : getShape()) {
4199-
p << ", ";
4200-
p.printOperand(sh);
4201-
}
4202-
p.printOptionalAttrDict((*this)->getAttrs(),
4203-
{"in_type", "operandSegmentSizes"});
4204-
}
4205-
4206-
llvm::LogicalResult omp::TargetAllocMemOp::verify() {
4207-
mlir::Type outType = getType();
4208-
if (!mlir::dyn_cast<IntegerType>(outType))
4209-
return emitOpError("must be a integer type");
4210-
return mlir::success();
4211-
}
4212-
42134164
//===----------------------------------------------------------------------===//
42144165
// WorkdistributeOp
42154166
//===----------------------------------------------------------------------===//

mlir/test/Dialect/OpenMP/invalid.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3139,3 +3139,17 @@ func.func @invalid_workdistribute() -> () {
31393139
}
31403140
return
31413141
}
3142+
3143+
// -----
3144+
func.func @target_allocmem_invalid_uniq_name(%device : i32) -> () {
3145+
// expected-error @below {{op attribute 'uniq_name' failed to satisfy constraint: string attribute}}
3146+
%0 = omp.target_allocmem %device : i32, i64 {uniq_name=2}
3147+
return
3148+
}
3149+
3150+
// -----
3151+
func.func @target_allocmem_invalid_bindc_name(%device : i32) -> () {
3152+
// expected-error @below {{op attribute 'bindc_name' failed to satisfy constraint: string attribute}}
3153+
%0 = omp.target_allocmem %device : i32, i64 {bindc_name=2}
3154+
return
3155+
}

mlir/test/Dialect/OpenMP/ops.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3321,3 +3321,27 @@ func.func @omp_workdistribute() {
33213321
}
33223322
return
33233323
}
3324+
3325+
// CHECK-LABEL: func.func @omp_target_allocmem(
3326+
// CHECK-SAME: %[[DEVICE:.*]]: i32, %[[X:.*]]: index, %[[Y:.*]]: index, %[[Z:.*]]: i32) {
3327+
func.func @omp_target_allocmem(%device: i32, %x: index, %y: index, %z: i32) {
3328+
// CHECK: %{{.*}} = omp.target_allocmem %[[DEVICE]] : i32, i64
3329+
%0 = omp.target_allocmem %device : i32, i64
3330+
// CHECK: %{{.*}} = omp.target_allocmem %[[DEVICE]] : i32, vector<16x16xf32> {bindc_name = "bindc", uniq_name = "uniq"}
3331+
%1 = omp.target_allocmem %device : i32, vector<16x16xf32> {uniq_name="uniq", bindc_name="bindc"}
3332+
// CHECK: %{{.*}} = omp.target_allocmem %[[DEVICE]] : i32, !llvm.ptr(%[[X]], %[[Y]], %[[Z]] : index, index, i32)
3333+
%2 = omp.target_allocmem %device : i32, !llvm.ptr(%x, %y, %z : index, index, i32)
3334+
// CHECK: %{{.*}} = omp.target_allocmem %[[DEVICE]] : i32, !llvm.ptr, %[[X]], %[[Y]]
3335+
%3 = omp.target_allocmem %device : i32, !llvm.ptr, %x, %y
3336+
// CHECK: %{{.*}} = omp.target_allocmem %[[DEVICE]] : i32, !llvm.ptr(%[[X]], %[[Y]], %[[Z]] : index, index, i32), %[[X]], %[[Y]]
3337+
%4 = omp.target_allocmem %device : i32, !llvm.ptr(%x, %y, %z : index, index, i32), %x, %y
3338+
return
3339+
}
3340+
3341+
// CHECK-LABEL: func.func @omp_target_freemem(
3342+
// CHECK-SAME: %[[DEVICE:.*]]: i32, %[[PTR:.*]]: i64) {
3343+
func.func @omp_target_freemem(%device : i32, %ptr : i64) {
3344+
// CHECK: omp.target_freemem %[[DEVICE]], %[[PTR]] : i32, i64
3345+
omp.target_freemem %device, %ptr : i32, i64
3346+
return
3347+
}

0 commit comments

Comments
 (0)