Skip to content

Commit 99e6a33

Browse files
committed
[MLIR][OpenMP] Lowering nontemporal clause to LLVM IR for SIMD directive
1 parent caaac84 commit 99e6a33

File tree

4 files changed

+61
-3
lines changed

4 files changed

+61
-3
lines changed

llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1187,7 +1187,8 @@ class OpenMPIRBuilder {
11871187
void applySimd(CanonicalLoopInfo *Loop,
11881188
MapVector<Value *, Value *> AlignedVars, Value *IfCond,
11891189
omp::OrderKind Order, ConstantInt *Simdlen,
1190-
ConstantInt *Safelen);
1190+
ConstantInt *Safelen,
1191+
SmallVector<Value *> NontemporalVars = {});
11911192

11921193
/// Generator for '#omp flush'
11931194
///

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5183,10 +5183,31 @@ OpenMPIRBuilder::getOpenMPDefaultSimdAlign(const Triple &TargetTriple,
51835183
return 0;
51845184
}
51855185

5186+
/// Attach nontemporal metadata to the load/store instructions of nontemporal
5187+
/// variables of \p Block
5188+
static void addNonTemporalMetadata(BasicBlock *Block, MDNode *Nontemporal,
5189+
SmallVector<Value *> NontemporalVars) {
5190+
for (Instruction &I : *Block) {
5191+
llvm::Value *mem_ptr = nullptr;
5192+
if (llvm::LoadInst *li = dyn_cast<llvm::LoadInst>(&I))
5193+
mem_ptr = li->getPointerOperand();
5194+
else if (llvm::StoreInst *si = dyn_cast<llvm::StoreInst>(&I))
5195+
mem_ptr = si->getPointerOperand();
5196+
if (mem_ptr) {
5197+
if (llvm::GetElementPtrInst *gep =
5198+
dyn_cast<llvm::GetElementPtrInst>(mem_ptr))
5199+
mem_ptr = gep->getPointerOperand();
5200+
if (is_contained(NontemporalVars, mem_ptr))
5201+
I.setMetadata(LLVMContext::MD_nontemporal, Nontemporal);
5202+
}
5203+
}
5204+
}
5205+
51865206
void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop,
51875207
MapVector<Value *, Value *> AlignedVars,
51885208
Value *IfCond, OrderKind Order,
5189-
ConstantInt *Simdlen, ConstantInt *Safelen) {
5209+
ConstantInt *Simdlen, ConstantInt *Safelen,
5210+
SmallVector<Value *> NontemporalVars) {
51905211
LLVMContext &Ctx = Builder.getContext();
51915212

51925213
Function *F = CanonicalLoop->getFunction();
@@ -5283,6 +5304,12 @@ void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop,
52835304
}
52845305

52855306
addLoopMetadata(CanonicalLoop, LoopMDList);
5307+
// Attach nontemporal metadata to the loads and stores of nontemporal variables.
5308+
if (NontemporalVars.size()) {
5309+
MDNode *NontemporalNode = MDNode::getDistinct(Ctx, {});
5310+
for (BasicBlock *BB : Reachable)
5311+
addNonTemporalMetadata(BB, NontemporalNode, NontemporalVars);
5312+
}
52865313
}
52875314

52885315
/// Create the TargetMachine object to query the backend for optimization

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1867,11 +1867,19 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
18671867

18681868
llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
18691869
llvm::omp::OrderKind order = convertOrderKind(simdOp.getOrder());
1870+
1871+
llvm::SmallVector<llvm::Value *> nontemporalVars;
1872+
mlir::OperandRange nontemporals = simdOp.getNontemporalVars();
1873+
for (mlir::Value nontemporal : nontemporals) {
1874+
llvm::Value *nt = moduleTranslation.lookupValue(nontemporal);
1875+
nontemporalVars.push_back(nt);
1876+
}
1877+
18701878
ompBuilder->applySimd(loopInfo, alignedVars,
18711879
simdOp.getIfExpr()
18721880
? moduleTranslation.lookupValue(simdOp.getIfExpr())
18731881
: nullptr,
1874-
order, simdlen, safelen);
1882+
order, simdlen, safelen, nontemporalVars);
18751883

18761884
builder.restoreIP(afterIP);
18771885
return success();

mlir/test/Target/LLVMIR/openmp-llvm.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -872,6 +872,28 @@ llvm.func @simd_order() {
872872
// CHECK-NEXT: llvm.loop.vectorize.width{{.*}}i64 2
873873
// -----
874874

875+
// The load and store below access the variables listed in the nontemporal
// clause, so both must carry !nontemporal metadata. Metadata IDs are captured
// rather than hard-coded so the test does not depend on metadata numbering.
// CHECK-LABEL: @simd_nontemporal
llvm.func @simd_nontemporal() {
  %0 = llvm.mlir.constant(10 : i64) : i64
  %1 = llvm.mlir.constant(1 : i64) : i64
  %2 = llvm.alloca %1 x i64 : (i64) -> !llvm.ptr
  %3 = llvm.alloca %1 x i64 : (i64) -> !llvm.ptr
  // CHECK: %[[A_ADDR:.*]] = alloca i64, i64 1, align 8
  // CHECK: %[[B_ADDR:.*]] = alloca i64, i64 1, align 8
  // CHECK: %[[B:.*]] = load i64, ptr %[[B_ADDR]], align 4, !nontemporal ![[NT:[0-9]+]], !llvm.access.group ![[AG:[0-9]+]]
  // CHECK: store i64 %[[B]], ptr %[[A_ADDR]], align 4, !nontemporal ![[NT]], !llvm.access.group ![[AG]]
  omp.simd nontemporal(%2, %3 : !llvm.ptr, !llvm.ptr) {
    omp.loop_nest (%arg0) : i64 = (%1) to (%0) inclusive step (%1) {
      %4 = llvm.load %3 : !llvm.ptr -> i64
      llvm.store %4, %2 : i64, !llvm.ptr
      omp.yield
    }
    omp.terminator
  }
  llvm.return
}
895+
// -----
896+
875897
llvm.func @body(i64)
876898

877899
llvm.func @test_omp_wsloop_ordered(%lb : i64, %ub : i64, %step : i64) -> () {

0 commit comments

Comments
 (0)