@@ -2462,6 +2462,25 @@ convertOrderKind(std::optional<omp::ClauseOrderKind> o) {
24622462 llvm_unreachable (" Unknown ClauseOrderKind kind" );
24632463}
24642464
2465+ static void
2466+ appendNontemporalVars (llvm::BasicBlock *Block,
2467+ SmallVectorImpl<llvm::Value *> &NontemporalVars) {
2468+ for (llvm::Instruction &I : *Block) {
2469+ if (const llvm::CallInst *CI = dyn_cast<llvm::CallInst>(&I)) {
2470+ if (CI->getIntrinsicID () == llvm::Intrinsic::memcpy) {
2471+ llvm::Value *DestPtr = CI->getArgOperand (0 );
2472+ llvm::Value *SrcPtr = CI->getArgOperand (1 );
2473+ for (const llvm::Value *Var : NontemporalVars) {
2474+ if (Var == SrcPtr) {
2475+ NontemporalVars.push_back (DestPtr);
2476+ break ;
2477+ }
2478+ }
2479+ }
2480+ }
2481+ }
2482+ }
2483+
24652484// / Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
24662485static LogicalResult
24672486convertOmpSimd (Operation &opInst, llvm::IRBuilderBase &builder,
@@ -2523,13 +2542,71 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
25232542 llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
25242543 llvm::omp::OrderKind order = convertOrderKind (simdOp.getOrder ());
25252544
2526- llvm::SmallVector<llvm::Value *> nontemporalVars ;
2545+ llvm::SmallVector<llvm::Value *> nontemporalOrigVars ;
25272546 mlir::OperandRange nontemporals = simdOp.getNontemporalVars ();
25282547 for (mlir::Value nontemporal : nontemporals) {
25292548 llvm::Value *nt = moduleTranslation.lookupValue (nontemporal);
2530- nontemporalVars .push_back (nt);
2549+ nontemporalOrigVars .push_back (nt);
25312550 }
25322551
2552+ /* * Call back function to attach nontemporal metadata to the load/store
2553+ * instructions of nontemporal variables of Block.
2554+ * Nontemporal variables may be a scalar, fixed size or allocatable
2555+ * or pointer array
2556+ *
2557+ * Example scenarios for nontemporal variables:
2558+ * Case 1: Scalar variable
2559+ * If the nontemporal variable is a scalar, it is allocated on stack.Load
2560+ * and store instructions directly access the alloca pointer of the scalar
2561+ * variable for fetching information about scalar variable or writing
2562+ * into the scalar variable. Mark those load and store instructions as
2563+ * non-temporal.
2564+ *
2565+ * Case 2: Fixed Size array
2566+ * If the nontemporal variable is a fixed-size array, it is allocated
2567+ * as a contiguous block of memory. It uses one GEP instruction, to compute
2568+ * the address of each individual array elements and perform load or store
2569+ * operation on it. Mark those load and store instructions as non-temporal.
2570+ *
2571+ * Case 3: Allocatable array
2572+ * For an allocatable array, which might involve runtime type descriptor,
2573+ * needs to navigate through descriptors using two or more GEP and load
2574+ * instructions to compute the address of each individual element in an array.
2575+ * Mark those load or store which access the individual array elements as
2576+ * non-temporal.
2577+ */
2578+ auto addNonTemporalMetadataCB = [&](llvm::BasicBlock *Block,
2579+ llvm::MDNode *Nontemporal) {
2580+ SmallVector<llvm::Value *> NontemporalVars{nontemporalOrigVars};
2581+ appendNontemporalVars (Block, NontemporalVars);
2582+ for (llvm::Instruction &I : *Block) {
2583+ llvm::Value *mem_ptr = nullptr ;
2584+ bool MetadataFlag = true ;
2585+ if (llvm::LoadInst *li = dyn_cast<llvm::LoadInst>(&I)) {
2586+ if (!(li->getType ()->isPointerTy ()))
2587+ mem_ptr = li->getPointerOperand ();
2588+ } else if (llvm::StoreInst *si = dyn_cast<llvm::StoreInst>(&I))
2589+ mem_ptr = si->getPointerOperand ();
2590+ if (mem_ptr) {
2591+ while (mem_ptr && !(isa<llvm::AllocaInst>(mem_ptr))) {
2592+ if (llvm::GetElementPtrInst *gep =
2593+ dyn_cast<llvm::GetElementPtrInst>(mem_ptr)) {
2594+ llvm::Type *sourceType = gep->getSourceElementType ();
2595+ if (sourceType->isStructTy () && gep->getNumIndices () >= 2 &&
2596+ !(gep->hasAllZeroIndices ())) {
2597+ MetadataFlag = false ;
2598+ break ;
2599+ }
2600+ mem_ptr = gep->getPointerOperand ();
2601+ } else if (llvm::LoadInst *li = dyn_cast<llvm::LoadInst>(mem_ptr))
2602+ mem_ptr = li->getPointerOperand ();
2603+ }
2604+ if (MetadataFlag && is_contained (NontemporalVars, mem_ptr))
2605+ I.setMetadata (llvm::LLVMContext::MD_nontemporal, Nontemporal);
2606+ }
2607+ }
2608+ };
2609+
25332610 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock ();
25342611 std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments ();
25352612 mlir::OperandRange operands = simdOp.getAlignedVars ();
@@ -2557,11 +2634,11 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
25572634
25582635 builder.SetInsertPoint (*regionBlock, (*regionBlock)->begin ());
25592636 llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo (moduleTranslation);
2560- ompBuilder->applySimd (loopInfo, alignedVars,
2561- simdOp. getIfExpr ()
2562- ? moduleTranslation.lookupValue (simdOp.getIfExpr ())
2563- : nullptr ,
2564- order, simdlen, safelen, nontemporalVars );
2637+ ompBuilder->applySimd (
2638+ loopInfo, alignedVars,
2639+ simdOp. getIfExpr () ? moduleTranslation.lookupValue (simdOp.getIfExpr ())
2640+ : nullptr ,
2641+ order, simdlen, safelen, addNonTemporalMetadataCB, nontemporalOrigVars );
25652642
25662643 return cleanupPrivateVars (builder, moduleTranslation, simdOp.getLoc (),
25672644 llvmPrivateVars, privateDecls);
0 commit comments