Skip to content

Commit a6b5d96

Browse files
committed
[MLIR] [OpenMP] Modify definition of ALLOCATOR clause to support
allocator type defined in user program.
1 parent d57aa48 commit a6b5d96

File tree

4 files changed

+89
-40
lines changed

4 files changed

+89
-40
lines changed

mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,11 +120,11 @@ class OpenMP_AllocatorClauseSkip<
120120
extraClassDeclaration> {
121121

122122
let arguments = (ins
123-
OptionalAttr<AllocatorHandleAttr>:$allocator
123+
DefaultValuedOptionalAttr<I64Attr, "0">:$allocator
124124
);
125125

126126
let optAssemblyFormat = [{
127-
`allocator` `(` custom<ClauseAttr>($allocator) `)`
127+
`allocator` `(` custom<AllocatorHandle>($allocator) `)`
128128
}];
129129

130130
let description = [{

mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -263,34 +263,4 @@ def VariableCaptureKindAttr : OpenMP_EnumAttr<VariableCaptureKind,
263263
let assemblyFormat = "`(` $value `)`";
264264
}
265265

266-
267-
//===----------------------------------------------------------------------===//
268-
// allocator_handle enum.
269-
//===----------------------------------------------------------------------===//
270-
271-
def OpenMP_AllocatorHandleNullAllocator : I32EnumAttrCase<"omp_null_allocator", 0>;
272-
def OpenMP_AllocatorHandleDefaultMemAlloc : I32EnumAttrCase<"omp_default_mem_alloc", 1>;
273-
def OpenMP_AllocatorHandleLargeCapMemAlloc : I32EnumAttrCase<"omp_large_cap_mem_alloc", 2>;
274-
def OpenMP_AllocatorHandleConstMemAlloc : I32EnumAttrCase<"omp_const_mem_alloc", 3>;
275-
def OpenMP_AllocatorHandleHighBwMemAlloc : I32EnumAttrCase<"omp_high_bw_mem_alloc", 4>;
276-
def OpenMP_AllocatorHandleLowLatMemAlloc : I32EnumAttrCase<"omp_low_lat_mem_alloc", 5>;
277-
def OpenMP_AllocatorHandleCgroupMemAlloc : I32EnumAttrCase<"omp_cgroup_mem_alloc", 6>;
278-
def OpenMP_AllocatorHandlePteamMemAlloc : I32EnumAttrCase<"omp_pteam_mem_alloc", 7>;
279-
def OpenMP_AllocatorHandlethreadMemAlloc : I32EnumAttrCase<"omp_thread_mem_alloc", 8>;
280-
281-
def AllocatorHandle : OpenMP_I32EnumAttr<
282-
"AllocatorHandle",
283-
"OpenMP allocator_handle", [
284-
OpenMP_AllocatorHandleNullAllocator,
285-
OpenMP_AllocatorHandleDefaultMemAlloc,
286-
OpenMP_AllocatorHandleLargeCapMemAlloc,
287-
OpenMP_AllocatorHandleConstMemAlloc,
288-
OpenMP_AllocatorHandleHighBwMemAlloc,
289-
OpenMP_AllocatorHandleLowLatMemAlloc,
290-
OpenMP_AllocatorHandleCgroupMemAlloc,
291-
OpenMP_AllocatorHandlePteamMemAlloc,
292-
OpenMP_AllocatorHandlethreadMemAlloc
293-
]>;
294-
295-
def AllocatorHandleAttr : OpenMP_EnumAttr<AllocatorHandle, "allocator_handle">;
296266
#endif // OPENMP_ENUMS

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1250,6 +1250,93 @@ verifyReductionVarList(Operation *op, std::optional<ArrayAttr> reductionSyms,
12501250
return success();
12511251
}
12521252

1253+
//===----------------------------------------------------------------------===//
1254+
// Parser, printer and verifier for Allocator (Section 8.4 in OpenMP 6.0)
1255+
//===----------------------------------------------------------------------===//
1256+
1257+
/// Parses a allocator clause. The value of allocator handle is an integer
1258+
/// which is a combination of different allocator handles from
1259+
/// `omp_allocator_handle_t`.
1260+
///
1261+
/// allocator-clause = `allocator` `(` allocator-value `)`
1262+
static ParseResult parseAllocatorHandle(OpAsmParser &parser,
1263+
IntegerAttr &allocatorHandleAttr) {
1264+
StringRef allocatorKeyword;
1265+
int64_t allocator = 0;
1266+
if (succeeded(parser.parseOptionalKeyword("none"))) {
1267+
allocatorHandleAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
1268+
return success();
1269+
}
1270+
auto parseKeyword = [&]() -> ParseResult {
1271+
if (failed(parser.parseKeyword(&allocatorKeyword)))
1272+
return failure();
1273+
if (allocatorKeyword == "omp_null_allocator")
1274+
allocator = 0;
1275+
else if (allocatorKeyword == "omp_default_mem_alloc")
1276+
allocator = 1;
1277+
else if (allocatorKeyword == "omp_large_cap_mem_alloc")
1278+
allocator = 2;
1279+
else if (allocatorKeyword == "omp_const_mem_alloc")
1280+
allocator = 3;
1281+
else if (allocatorKeyword == "omp_high_bw_mem_alloc")
1282+
allocator = 4;
1283+
else if (allocatorKeyword == "omp_low_lat_mem_alloc")
1284+
allocator = 5;
1285+
else if (allocatorKeyword == "omp_cgroup_mem_alloc")
1286+
allocator = 6;
1287+
else if (allocatorKeyword == "omp_pteam_mem_alloc")
1288+
allocator = 7;
1289+
else if (allocatorKeyword == "omp_thread_mem_alloc")
1290+
allocator = 8;
1291+
else
1292+
return parser.emitError(parser.getCurrentLocation())
1293+
<< allocatorKeyword << " is not a valid allocator";
1294+
return success();
1295+
};
1296+
if (parser.parseCommaSeparatedList(parseKeyword))
1297+
return failure();
1298+
allocatorHandleAttr =
1299+
IntegerAttr::get(parser.getBuilder().getI64Type(), allocator);
1300+
return success();
1301+
}
1302+
1303+
/// Prints a allocator clause
1304+
static void printAllocatorHandle(OpAsmPrinter &p, Operation *op,
1305+
IntegerAttr allocatorHandleAttr) {
1306+
int64_t allocator = allocatorHandleAttr.getInt();
1307+
StringRef allocatorHandle;
1308+
switch (allocator) {
1309+
case 0:
1310+
allocatorHandle = "omp_null_allocator";
1311+
break;
1312+
case 1:
1313+
allocatorHandle = "omp_default_mem_alloc";
1314+
break;
1315+
case 2:
1316+
allocatorHandle = "omp_large_cap_mem_alloc";
1317+
break;
1318+
case 3:
1319+
allocatorHandle = "omp_const_mem_alloc";
1320+
break;
1321+
case 4:
1322+
allocatorHandle = "omp_high_bw_mem_alloc";
1323+
break;
1324+
case 5:
1325+
allocatorHandle = "omp_low_lat_mem_alloc";
1326+
break;
1327+
case 6:
1328+
allocatorHandle = "omp_cgroup_mem_alloc";
1329+
break;
1330+
case 7:
1331+
allocatorHandle = "omp_pteam_mem_alloc";
1332+
break;
1333+
case 8:
1334+
allocatorHandle = "omp_thread_mem_alloc";
1335+
break;
1336+
}
1337+
p << allocatorHandle;
1338+
}
1339+
12531340
//===----------------------------------------------------------------------===//
12541341
// Parser, printer and verifier for Copyprivate
12551342
//===----------------------------------------------------------------------===//

mlir/test/Dialect/OpenMP/invalid.mlir

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3033,14 +3033,6 @@ func.func @invalid_allocate_align_2(%arg0 : memref<i32>) -> () {
30333033
return
30343034
}
30353035

3036-
// -----
3037-
func.func @invalid_allocate_allocator(%arg0 : memref<i32>) -> () {
3038-
// expected-error @below {{invalid clause value}}
3039-
omp.allocate_dir (%arg0 : memref<i32>) allocator(omp_small_cap_mem_alloc)
3040-
3041-
return
3042-
}
3043-
30443036
// -----
30453037
func.func @invalid_workdistribute_empty_region() -> () {
30463038
omp.teams {

0 commit comments

Comments
 (0)