Skip to content

Commit 87f2261

Browse files
committed
update
1 parent 38b566f commit 87f2261

File tree

3 files changed

+33
-5
lines changed

3 files changed

+33
-5
lines changed

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2246,6 +2246,7 @@ def GroupprivateOp : OpenMP_Op<"groupprivate",
22462246

22472247
let arguments = (ins OpenMP_PointerLikeType:$sym_addr);
22482248
let results = (outs OpenMP_PointerLikeType:$gp_addr);
2249+
let hasVerifier = 1;
22492250
let assemblyFormat = [{
22502251
$sym_addr `:` type($sym_addr) `->` type($gp_addr) attr-dict
22512252
}];

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4433,6 +4433,14 @@ LogicalResult WorkdistributeOp::verify() {
44334433
return success();
44344434
}
44354435

4436+
//===----------------------------------------------------------------------===//
4437+
// GroupprivateOp
4438+
//===----------------------------------------------------------------------===//
4439+
4440+
LogicalResult GroupprivateOp::verify() {
4441+
return success();
4442+
}
4443+
44364444
#define GET_ATTRDEF_CLASSES
44374445
#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
44384446

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6117,11 +6117,30 @@ convertOmpGroupprivate(Operation &opInst, llvm::IRBuilderBase &builder,
61176117
addressOfOp.getGlobal(moduleTranslation.symbolTable());
61186118
llvm::GlobalValue *globalValue = moduleTranslation.lookupGlobal(global);
61196119

6120-
if (!ompBuilder->Config.isTargetDevice()) {
6121-
llvm_unreachable("NYI");
6122-
} else {
6123-
moduleTranslation.mapValue(opInst.getResult(0), globalValue);
6124-
}
6120+
// Get the size of the variable
6121+
llvm::Type *varType = globalValue->getValueType();
6122+
llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
6123+
llvm::DataLayout DL = llvmModule->getDataLayout();
6124+
uint64_t typeSize = DL.getTypeAllocSize(varType);
6125+
// Call omp_alloc_shared to allocate memory for groupprivate variable.
6126+
// Need PR: https://github.com/llvm/llvm-project/pull/150923.
6127+
//groupPrivatePtr = ompBuilder->createOMPAllocShared(ompLoc, varType);
6128+
llvm::FunctionCallee allocSharedFn = ompBuilder->getOrCreateRuntimeFunction(
6129+
*llvmModule,
6130+
llvm::omp::OMPRTL___kmpc_alloc_shared);
6131+
uint32_t SrcLocStrSize;
6132+
llvm::Constant *Loc =
6133+
ompBuilder->getOrCreateDefaultSrcLocStr(SrcLocStrSize);
6134+
// Call runtime to allocate shared memory for this group
6135+
llvm::Value *args[] = {
6136+
ompBuilder->getOrCreateIdent(Loc, SrcLocStrSize),
6137+
builder.getInt64(typeSize)
6138+
};
6139+
llvm::Value *groupPrivatePtr = builder.CreateCall(allocSharedFn, args);
6140+
groupPrivatePtr = builder.CreateBitCast(groupPrivatePtr,
6141+
globalValue->getType());
6142+
6143+
moduleTranslation.mapValue(opInst.getResult(0), groupPrivatePtr);
61256144

61266145
return success();
61276146
}

0 commit comments

Comments
 (0)