Skip to content

Commit 7549dcf

Browse files
[MLIR][OpenMP] Add support for critical construct
This patch adds the critical construct to the OpenMP dialect. The implementation models the definition in 2.17.1 of the OpenMP 5 standard. A name and hint can be specified. The name is a global entity or has external linkage, it is modelled as a FlatSymbolRefAttr. Hint is modelled as an integer enum attribute. Also lowering to LLVM IR using the OpenMP IRBuilder. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D107135
1 parent c046931 commit 7549dcf

File tree

6 files changed

+136
-0
lines changed

6 files changed

+136
-0
lines changed

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,43 @@ def MasterOp : OpenMP_Op<"master"> {
334334
let assemblyFormat = "$region attr-dict";
335335
}
336336

337+
//===----------------------------------------------------------------------===//
338+
// 2.17.1 critical Construct
339+
//===----------------------------------------------------------------------===//
340+
// TODO: Autogenerate this from OMP.td in llvm/include/Frontend
341+
def omp_sync_hint_none: I32EnumAttrCase<"none", 0>;
342+
def omp_sync_hint_uncontended: I32EnumAttrCase<"uncontended", 1>;
343+
def omp_sync_hint_contended: I32EnumAttrCase<"contended", 2>;
344+
def omp_sync_hint_nonspeculative: I32EnumAttrCase<"nonspeculative", 3>;
345+
def omp_sync_hint_speculative: I32EnumAttrCase<"speculative", 4>;
346+
347+
def SyncHintKind: I32EnumAttr<"SyncHintKind", "OpenMP Sync Hint Kind",
348+
[omp_sync_hint_none, omp_sync_hint_uncontended, omp_sync_hint_contended,
349+
omp_sync_hint_nonspeculative, omp_sync_hint_speculative]> {
350+
let cppNamespace = "::mlir::omp";
351+
let stringToSymbolFnName = "ConvertToEnum";
352+
let symbolToStringFnName = "ConvertToString";
353+
}
354+
355+
def CriticalOp : OpenMP_Op<"critical"> {
356+
let summary = "critical construct";
357+
let description = [{
358+
The critical construct imposes a restriction on the associated structured
359+
block (region) to be executed by only a single thread at a time.
360+
}];
361+
362+
let arguments = (ins OptionalAttr<FlatSymbolRefAttr>:$name,
363+
OptionalAttr<SyncHintKind>:$hint);
364+
365+
let regions = (region AnyRegion:$region);
366+
367+
let assemblyFormat = [{
368+
(`(` $name^ `)`)? (`hint` `(` $hint^ `)`)? $region attr-dict
369+
}];
370+
371+
let verifier = "return ::verifyCriticalOp(*this);";
372+
}
373+
337374
//===----------------------------------------------------------------------===//
338375
// 2.17.2 barrier Construct
339376
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1002,5 +1002,13 @@ static LogicalResult verifyWsLoopOp(WsLoopOp op) {
10021002
return success();
10031003
}
10041004

1005+
static LogicalResult verifyCriticalOp(CriticalOp op) {
1006+
if (!op.name().hasValue() && op.hint().hasValue() &&
1007+
(op.hint().getValue() != SyncHintKind::none))
1008+
return op.emitOpError() << "must specify a name unless the effect is as if "
1009+
"hint(none) is specified";
1010+
return success();
1011+
}
1012+
10051013
#define GET_OP_CLASSES
10061014
#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,45 @@ convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder,
204204
return success();
205205
}
206206

207+
/// Converts an OpenMP 'critical' operation into LLVM IR using OpenMPIRBuilder.
208+
static LogicalResult
209+
convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder,
210+
LLVM::ModuleTranslation &moduleTranslation) {
211+
using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
212+
auto criticalOp = cast<omp::CriticalOp>(opInst);
213+
// TODO: support error propagation in OpenMPIRBuilder and use it instead of
214+
// relying on captured variables.
215+
LogicalResult bodyGenStatus = success();
216+
217+
auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
218+
llvm::BasicBlock &continuationBlock) {
219+
// CriticalOp has only one region associated with it.
220+
auto &region = cast<omp::CriticalOp>(opInst).getRegion();
221+
convertOmpOpRegions(region, "omp.critical.region", *codeGenIP.getBlock(),
222+
continuationBlock, builder, moduleTranslation,
223+
bodyGenStatus);
224+
};
225+
226+
// TODO: Perform finalization actions for variables. This has to be
227+
// called for variables which have destructors/finalizers.
228+
auto finiCB = [&](InsertPointTy codeGenIP) {};
229+
230+
llvm::OpenMPIRBuilder::LocationDescription ompLoc(
231+
builder.saveIP(), builder.getCurrentDebugLocation());
232+
llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
233+
llvm::Constant *hint = nullptr;
234+
if (criticalOp.hint().hasValue()) {
235+
hint =
236+
llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext),
237+
static_cast<int>(criticalOp.hint().getValue()));
238+
} else {
239+
hint = llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext), 0);
240+
}
241+
builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createCritical(
242+
ompLoc, bodyGenCB, finiCB, criticalOp.name().getValueOr(""), hint));
243+
return success();
244+
}
245+
207246
/// Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
208247
static LogicalResult
209248
convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
@@ -419,6 +458,9 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
419458
.Case([&](omp::MasterOp) {
420459
return convertOmpMaster(*op, builder, moduleTranslation);
421460
})
461+
.Case([&](omp::CriticalOp) {
462+
return convertOmpCritical(*op, builder, moduleTranslation);
463+
})
422464
.Case([&](omp::WsLoopOp) {
423465
return convertOmpWsLoop(*op, builder, moduleTranslation);
424466
})

mlir/test/Dialect/OpenMP/invalid.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,3 +293,13 @@ func @foo(%lb : index, %ub : index, %step : index, %mem : memref<1xf32>) {
293293
}
294294
return
295295
}
296+
297+
// -----
298+
299+
func @omp_critical() -> () {
300+
// expected-error @below {{must specify a name unless the effect is as if hint(none) is specified}}
301+
omp.critical hint(nonspeculative) {
302+
omp.terminator
303+
}
304+
return
305+
}

mlir/test/Dialect/OpenMP/ops.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,3 +366,15 @@ func @reduction2(%lb : index, %ub : index, %step : index) {
366366
}
367367
return
368368
}
369+
370+
// CHECK-LABEL: omp_critical
371+
func @omp_critical() -> () {
372+
omp.critical {
373+
omp.terminator
374+
}
375+
376+
omp.critical(@mutex) hint(nonspeculative) {
377+
omp.terminator
378+
}
379+
return
380+
}

mlir/test/Target/LLVMIR/openmp-llvm.mlir

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -538,3 +538,30 @@ llvm.func @collapse_wsloop(
538538
}
539539
llvm.return
540540
}
541+
542+
// -----
543+
544+
// CHECK-LABEL: @omp_critical
545+
llvm.func @omp_critical(%x : !llvm.ptr<i32>, %xval : i32) -> () {
546+
// CHECK: call void @__kmpc_critical_with_hint({{.*}}critical_user_.var{{.*}}, i32 0)
547+
// CHECK: br label %omp.critical.region
548+
// CHECK: omp.critical.region
549+
omp.critical {
550+
// CHECK: store
551+
llvm.store %xval, %x : !llvm.ptr<i32>
552+
omp.terminator
553+
}
554+
// CHECK: call void @__kmpc_end_critical({{.*}}critical_user_.var{{.*}})
555+
556+
// CHECK: call void @__kmpc_critical_with_hint({{.*}}critical_user_mutex.var{{.*}}, i32 2)
557+
// CHECK: br label %omp.critical.region
558+
// CHECK: omp.critical.region
559+
omp.critical(@mutex) hint(contended) {
560+
// CHECK: store
561+
llvm.store %xval, %x : !llvm.ptr<i32>
562+
omp.terminator
563+
}
564+
// CHECK: call void @__kmpc_end_critical({{.*}}critical_user_mutex.var{{.*}})
565+
566+
llvm.return
567+
}

0 commit comments

Comments
 (0)