From a6b5d960f8ce6217800b90597b1b572cca4dccbf Mon Sep 17 00:00:00 2001 From: Raghu Maddhipatla Date: Mon, 8 Sep 2025 02:13:39 -0500 Subject: [PATCH 1/3] [MLIR] [OpenMP] Modify definition of ALLOCATOR clause to support allocator type defined in user program. --- .../mlir/Dialect/OpenMP/OpenMPClauses.td | 4 +- .../mlir/Dialect/OpenMP/OpenMPEnums.td | 30 ------- mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 87 +++++++++++++++++++ mlir/test/Dialect/OpenMP/invalid.mlir | 8 -- 4 files changed, 89 insertions(+), 40 deletions(-) diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td index 5f40abe62a0f6..675f62902e75b 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td @@ -120,11 +120,11 @@ class OpenMP_AllocatorClauseSkip< extraClassDeclaration> { let arguments = (ins - OptionalAttr:$allocator + DefaultValuedOptionalAttr:$allocator ); let optAssemblyFormat = [{ - `allocator` `(` custom($allocator) `)` + `allocator` `(` custom($allocator) `)` }]; let description = [{ diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td index c080c3fac87d4..9dbe6897a3304 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td @@ -263,34 +263,4 @@ def VariableCaptureKindAttr : OpenMP_EnumAttr; -def OpenMP_AllocatorHandleDefaultMemAlloc : I32EnumAttrCase<"omp_default_mem_alloc", 1>; -def OpenMP_AllocatorHandleLargeCapMemAlloc : I32EnumAttrCase<"omp_large_cap_mem_alloc", 2>; -def OpenMP_AllocatorHandleConstMemAlloc : I32EnumAttrCase<"omp_const_mem_alloc", 3>; -def OpenMP_AllocatorHandleHighBwMemAlloc : I32EnumAttrCase<"omp_high_bw_mem_alloc", 4>; -def OpenMP_AllocatorHandleLowLatMemAlloc : I32EnumAttrCase<"omp_low_lat_mem_alloc", 5>; -def OpenMP_AllocatorHandleCgroupMemAlloc : I32EnumAttrCase<"omp_cgroup_mem_alloc", 6>; -def OpenMP_AllocatorHandlePteamMemAlloc : I32EnumAttrCase<"omp_pteam_mem_alloc", 7>; -def OpenMP_AllocatorHandlethreadMemAlloc : I32EnumAttrCase<"omp_thread_mem_alloc", 8>; - -def AllocatorHandle : OpenMP_I32EnumAttr< - "AllocatorHandle", - "OpenMP allocator_handle", [ - OpenMP_AllocatorHandleNullAllocator, - OpenMP_AllocatorHandleDefaultMemAlloc, - OpenMP_AllocatorHandleLargeCapMemAlloc, - OpenMP_AllocatorHandleConstMemAlloc, - OpenMP_AllocatorHandleHighBwMemAlloc, - OpenMP_AllocatorHandleLowLatMemAlloc, - OpenMP_AllocatorHandleCgroupMemAlloc, - OpenMP_AllocatorHandlePteamMemAlloc, - OpenMP_AllocatorHandlethreadMemAlloc - ]>; - -def AllocatorHandleAttr : OpenMP_EnumAttr; #endif // OPENMP_ENUMS diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 3d70e28ed23ab..cee9230f1ff0f 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -1250,6 +1250,93 @@ verifyReductionVarList(Operation *op, std::optional reductionSyms, return success(); } +//===----------------------------------------------------------------------===// +// Parser, printer and verifier for Allocator (Section 8.4 in OpenMP 6.0) +//===----------------------------------------------------------------------===// + +/// Parses a allocator clause. The value of allocator handle is an integer +/// which is a combination of different allocator handles from +/// `omp_allocator_handle_t`. +/// +/// allocator-clause = `allocator` `(` allocator-value `)` +static ParseResult parseAllocatorHandle(OpAsmParser &parser, + IntegerAttr &allocatorHandleAttr) { + StringRef allocatorKeyword; + int64_t allocator = 0; + if (succeeded(parser.parseOptionalKeyword("none"))) { + allocatorHandleAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0); + return success(); + } + auto parseKeyword = [&]() -> ParseResult { + if (failed(parser.parseKeyword(&allocatorKeyword))) + return failure(); + if (allocatorKeyword == "omp_null_allocator") + allocator = 0; + else if (allocatorKeyword == "omp_default_mem_alloc") + allocator = 1; + else if (allocatorKeyword == "omp_large_cap_mem_alloc") + allocator = 2; + else if (allocatorKeyword == "omp_const_mem_alloc") + allocator = 3; + else if (allocatorKeyword == "omp_high_bw_mem_alloc") + allocator = 4; + else if (allocatorKeyword == "omp_low_lat_mem_alloc") + allocator = 5; + else if (allocatorKeyword == "omp_cgroup_mem_alloc") + allocator = 6; + else if (allocatorKeyword == "omp_pteam_mem_alloc") + allocator = 7; + else if (allocatorKeyword == "omp_thread_mem_alloc") + allocator = 8; + else + return parser.emitError(parser.getCurrentLocation()) + << allocatorKeyword << " is not a valid allocator"; + return success(); + }; + if (parser.parseCommaSeparatedList(parseKeyword)) + return failure(); + allocatorHandleAttr = + IntegerAttr::get(parser.getBuilder().getI64Type(), allocator); + return success(); +} + +/// Prints a allocator clause +static void printAllocatorHandle(OpAsmPrinter &p, Operation *op, + IntegerAttr allocatorHandleAttr) { + int64_t allocator = allocatorHandleAttr.getInt(); + StringRef allocatorHandle; + switch (allocator) { + case 0: + allocatorHandle = "omp_null_allocator"; + break; + case 1: + allocatorHandle = "omp_default_mem_alloc"; + break; + case 2: + allocatorHandle = "omp_large_cap_mem_alloc"; + break; + case 3: + allocatorHandle = "omp_const_mem_alloc"; + break; + case 4: + allocatorHandle = "omp_high_bw_mem_alloc"; + break; + case 5: + allocatorHandle = "omp_low_lat_mem_alloc"; + break; + case 6: + allocatorHandle = "omp_cgroup_mem_alloc"; + break; + case 7: + allocatorHandle = "omp_pteam_mem_alloc"; + break; + case 8: + allocatorHandle = "omp_thread_mem_alloc"; + break; + } + p << allocatorHandle; +} + //===----------------------------------------------------------------------===// // Parser, printer and verifier for Copyprivate //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir index 763f41c5420b8..af24d969064ab 100644 --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -3033,14 +3033,6 @@ func.func @invalid_allocate_align_2(%arg0 : memref) -> () { return } -// ----- -func.func @invalid_allocate_allocator(%arg0 : memref) -> () { - // expected-error @below {{invalid clause value}} - omp.allocate_dir (%arg0 : memref) allocator(omp_small_cap_mem_alloc) - - return -} - // ----- func.func @invalid_workdistribute_empty_region() -> () { omp.teams { From d455dcafef67f2457bc51eb5608d786af28fec15 Mon Sep 17 00:00:00 2001 From: Raghu Maddhipatla Date: Tue, 9 Sep 2025 01:10:20 -0500 Subject: [PATCH 2/3] Add test-case for user-defined allocator value --- mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 53 ++++++++++---------- mlir/test/Dialect/OpenMP/ops.mlir | 6 ++- 2 files changed, 32 insertions(+), 27 deletions(-) diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index cee9230f1ff0f..7d15eab5c5232 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -1267,34 +1267,32 @@ static ParseResult parseAllocatorHandle(OpAsmParser &parser, allocatorHandleAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0); return success(); } - auto parseKeyword = [&]() -> ParseResult { - if (failed(parser.parseKeyword(&allocatorKeyword))) - return failure(); - if (allocatorKeyword == "omp_null_allocator") - allocator = 0; - else if (allocatorKeyword == "omp_default_mem_alloc") - allocator = 1; - else if (allocatorKeyword == "omp_large_cap_mem_alloc") - allocator = 2; - else if (allocatorKeyword == "omp_const_mem_alloc") - allocator = 3; - else if (allocatorKeyword == "omp_high_bw_mem_alloc") - allocator = 4; - else if (allocatorKeyword == "omp_low_lat_mem_alloc") - allocator = 5; - else if (allocatorKeyword == "omp_cgroup_mem_alloc") - allocator = 6; - else if (allocatorKeyword == "omp_pteam_mem_alloc") - allocator = 7; - else if (allocatorKeyword == "omp_thread_mem_alloc") - allocator = 8; - else - return parser.emitError(parser.getCurrentLocation()) - << allocatorKeyword << " is not a valid allocator"; + OptionalParseResult parsedInteger = parser.parseOptionalInteger(allocator); + if (parsedInteger.has_value()) { + allocatorHandleAttr = + IntegerAttr::get(parser.getBuilder().getI64Type(), allocator); return success(); - }; - if (parser.parseCommaSeparatedList(parseKeyword)) + } + if (failed(parser.parseKeyword(&allocatorKeyword))) return failure(); + if (allocatorKeyword == "omp_null_allocator") + allocator = 0; + else if (allocatorKeyword == "omp_default_mem_alloc") + allocator = 1; + else if (allocatorKeyword == "omp_large_cap_mem_alloc") + allocator = 2; + else if (allocatorKeyword == "omp_const_mem_alloc") + allocator = 3; + else if (allocatorKeyword == "omp_high_bw_mem_alloc") + allocator = 4; + else if (allocatorKeyword == "omp_low_lat_mem_alloc") + allocator = 5; + else if (allocatorKeyword == "omp_cgroup_mem_alloc") + allocator = 6; + else if (allocatorKeyword == "omp_pteam_mem_alloc") + allocator = 7; + else if (allocatorKeyword == "omp_thread_mem_alloc") + allocator = 8; allocatorHandleAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), allocator); return success(); @@ -1333,6 +1331,9 @@ static void printAllocatorHandle(OpAsmPrinter &p, Operation *op, case 8: allocatorHandle = "omp_thread_mem_alloc"; break; + default: + p << Twine(allocator).str(); + return; } p << allocatorHandle; } diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index 60b1f61135ac2..79046a72006d7 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -3279,7 +3279,7 @@ func.func @omp_allocate_dir(%arg0 : memref, %arg1 : memref) -> () { // Test with one data var and allocator clause // CHECK: omp.allocate_dir(%[[ARG0]] : memref) allocator(omp_pteam_mem_alloc) - omp.allocate_dir (%arg0 : memref) allocator(omp_pteam_mem_alloc) + omp.allocate_dir (%arg0 : memref) allocator(7) // Test with one data var, align clause and allocator clause // CHECK: omp.allocate_dir(%[[ARG0]] : memref) align(2) allocator(omp_thread_mem_alloc) @@ -3289,6 +3289,10 @@ func.func @omp_allocate_dir(%arg0 : memref, %arg1 : memref) -> () { // CHECK: omp.allocate_dir(%[[ARG0]], %[[ARG1]] : memref, memref) align(2) allocator(omp_cgroup_mem_alloc) omp.allocate_dir (%arg0, %arg1 : memref, memref) align(2) allocator(omp_cgroup_mem_alloc) + // Test with one data var and user defined allocator clause + // CHECK: omp.allocate_dir(%[[ARG0]] : memref) allocator(9) + omp.allocate_dir (%arg0 : memref) allocator(9) + return } From c89d0ff593e18ef6264316b2913a2f7cdcaa0a3e Mon Sep 17 00:00:00 2001 From: Raghu Maddhipatla Date: Wed, 17 Sep 2025 17:38:03 -0500 Subject: [PATCH 3/3] Changed allocator clause definition to use Integer type value argument instead of IntegerAttr. --- .../mlir/Dialect/OpenMP/OpenMPClauses.td | 4 +- mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 2 +- mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 88 ------------------- mlir/test/Dialect/OpenMP/ops.mlir | 29 ++++-- 4 files changed, 24 insertions(+), 99 deletions(-) diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td index 675f62902e75b..1eda5e4bc1618 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td @@ -120,11 +120,11 @@ class OpenMP_AllocatorClauseSkip< extraClassDeclaration> { let arguments = (ins - DefaultValuedOptionalAttr:$allocator + Optional:$allocator ); let optAssemblyFormat = [{ - `allocator` `(` custom($allocator) `)` + `allocator` `(` $allocator `)` }]; let description = [{ diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index 830b36f440098..5c77e215467e4 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -2100,7 +2100,7 @@ def MaskedOp : OpenMP_Op<"masked", clauses = [ //===----------------------------------------------------------------------===// // [Spec 5.2] 6.5 allocate Directive //===----------------------------------------------------------------------===// -def AllocateDirOp : OpenMP_Op<"allocate_dir", clauses = [ +def AllocateDirOp : OpenMP_Op<"allocate_dir", [AttrSizedOperandSegments], clauses = [ OpenMP_AlignClause, OpenMP_AllocatorClause ]> { let summary = "allocate directive"; diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 7d15eab5c5232..3d70e28ed23ab 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -1250,94 +1250,6 @@ verifyReductionVarList(Operation *op, std::optional reductionSyms, return success(); } -//===----------------------------------------------------------------------===// -// Parser, printer and verifier for Allocator (Section 8.4 in OpenMP 6.0) -//===----------------------------------------------------------------------===// - -/// Parses a allocator clause. The value of allocator handle is an integer -/// which is a combination of different allocator handles from -/// `omp_allocator_handle_t`. -/// -/// allocator-clause = `allocator` `(` allocator-value `)` -static ParseResult parseAllocatorHandle(OpAsmParser &parser, - IntegerAttr &allocatorHandleAttr) { - StringRef allocatorKeyword; - int64_t allocator = 0; - if (succeeded(parser.parseOptionalKeyword("none"))) { - allocatorHandleAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0); - return success(); - } - OptionalParseResult parsedInteger = parser.parseOptionalInteger(allocator); - if (parsedInteger.has_value()) { - allocatorHandleAttr = - IntegerAttr::get(parser.getBuilder().getI64Type(), allocator); - return success(); - } - if (failed(parser.parseKeyword(&allocatorKeyword))) - return failure(); - if (allocatorKeyword == "omp_null_allocator") - allocator = 0; - else if (allocatorKeyword == "omp_default_mem_alloc") - allocator = 1; - else if (allocatorKeyword == "omp_large_cap_mem_alloc") - allocator = 2; - else if (allocatorKeyword == "omp_const_mem_alloc") - allocator = 3; - else if (allocatorKeyword == "omp_high_bw_mem_alloc") - allocator = 4; - else if (allocatorKeyword == "omp_low_lat_mem_alloc") - allocator = 5; - else if (allocatorKeyword == "omp_cgroup_mem_alloc") - allocator = 6; - else if (allocatorKeyword == "omp_pteam_mem_alloc") - allocator = 7; - else if (allocatorKeyword == "omp_thread_mem_alloc") - allocator = 8; - allocatorHandleAttr = - IntegerAttr::get(parser.getBuilder().getI64Type(), allocator); - return success(); -} - -/// Prints a allocator clause -static void printAllocatorHandle(OpAsmPrinter &p, Operation *op, - IntegerAttr allocatorHandleAttr) { - int64_t allocator = allocatorHandleAttr.getInt(); - StringRef allocatorHandle; - switch (allocator) { - case 0: - allocatorHandle = "omp_null_allocator"; - break; - case 1: - allocatorHandle = "omp_default_mem_alloc"; - break; - case 2: - allocatorHandle = "omp_large_cap_mem_alloc"; - break; - case 3: - allocatorHandle = "omp_const_mem_alloc"; - break; - case 4: - allocatorHandle = "omp_high_bw_mem_alloc"; - break; - case 5: - allocatorHandle = "omp_low_lat_mem_alloc"; - break; - case 6: - allocatorHandle = "omp_cgroup_mem_alloc"; - break; - case 7: - allocatorHandle = "omp_pteam_mem_alloc"; - break; - case 8: - allocatorHandle = "omp_thread_mem_alloc"; - break; - default: - p << Twine(allocator).str(); - return; - } - p << allocatorHandle; -} - //===----------------------------------------------------------------------===// // Parser, printer and verifier for Copyprivate //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index 79046a72006d7..cbd863f88fd1f 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -3260,6 +3260,10 @@ func.func @omp_workshare_loop_wrapper_attrs(%idx : index) { return } +func.func @omp_init_allocator(%custom_allocator : i64) -> i64 { + return %custom_allocator : i64 +} + // CHECK-LABEL: func.func @omp_allocate_dir( // CHECK-SAME: %[[ARG0:.*]]: memref, // CHECK-SAME: %[[ARG1:.*]]: memref) { @@ -3278,20 +3282,29 @@ func.func @omp_allocate_dir(%arg0 : memref, %arg1 : memref) -> () { omp.allocate_dir (%arg0 : memref) align(2) // Test with one data var and allocator clause - // CHECK: omp.allocate_dir(%[[ARG0]] : memref) allocator(omp_pteam_mem_alloc) - omp.allocate_dir (%arg0 : memref) allocator(7) + // CHECK: %[[VAL_1:.*]] = arith.constant 1 : i64 + %omp_default_mem_alloc = arith.constant 1 : i64 + // CHECK: omp.allocate_dir(%[[ARG0]] : memref) allocator(%[[VAL_1:.*]]) + omp.allocate_dir (%arg0 : memref) allocator(%omp_default_mem_alloc) // Test with one data var, align clause and allocator clause - // CHECK: omp.allocate_dir(%[[ARG0]] : memref) align(2) allocator(omp_thread_mem_alloc) - omp.allocate_dir (%arg0 : memref) align(2) allocator(omp_thread_mem_alloc) + // CHECK: %[[VAL_2:.*]] = arith.constant 7 : i64 + %omp_pteam_mem_alloc = arith.constant 7 : i64 + // CHECK: omp.allocate_dir(%[[ARG0]] : memref) align(4) allocator(%[[VAL_2:.*]]) + omp.allocate_dir (%arg0 : memref) align(4) allocator(%omp_pteam_mem_alloc) // Test with two data vars, align clause and allocator clause - // CHECK: omp.allocate_dir(%[[ARG0]], %[[ARG1]] : memref, memref) align(2) allocator(omp_cgroup_mem_alloc) - omp.allocate_dir (%arg0, %arg1 : memref, memref) align(2) allocator(omp_cgroup_mem_alloc) + // CHECK: %[[VAL_3:.*]] = arith.constant 6 : i64 + %omp_cgroup_mem_alloc = arith.constant 6 : i64 + // CHECK: omp.allocate_dir(%[[ARG0]], %[[ARG1]] : memref, memref) align(8) allocator(%[[VAL_3:.*]]) + omp.allocate_dir (%arg0, %arg1 : memref, memref) align(8) allocator(%omp_cgroup_mem_alloc) // Test with one data var and user defined allocator clause - // CHECK: omp.allocate_dir(%[[ARG0]] : memref) allocator(9) - omp.allocate_dir (%arg0 : memref) allocator(9) + // CHECK: %[[VAL_4:.*]] = arith.constant 9 : i64 + %custom_allocator = arith.constant 9 : i64 + %custom_mem_alloc = func.call @omp_init_allocator(%custom_allocator) : (i64) -> (i64) + // CHECK: omp.allocate_dir(%[[ARG0]] : memref) allocator(%[[VAL_5:.*]]) + omp.allocate_dir (%arg0 : memref) allocator(%custom_mem_alloc) return }