Skip to content

Commit 7aa3cad

Browse files
committed
[NVPTX] Enable lowering of atomics on local memory
LLVM does not have valid assembly backends for atomicrmw on local memory. However, as this memory is thread local, we should be able to lower this to the relevant load/store. Differential Revision: https://reviews.llvm.org/D98650
1 parent 2509f9f commit 7aa3cad

File tree

7 files changed

+123
-2
lines changed

7 files changed

+123
-2
lines changed

llvm/include/llvm/Transforms/Scalar/LowerAtomic.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,12 @@ class LowerAtomicPass : public PassInfoMixin<LowerAtomicPass> {
2424
PreservedAnalyses run(Function &F, FunctionAnalysisManager &);
2525
static bool isRequired() { return true; }
2626
};
27+
28+
class AtomicRMWInst;
29+
/// Convert the given RMWI into primitive loads and stores,
30+
/// assuming that doing so is legal (legality is not checked here; it is the
31+
/// caller's responsibility). Return true if the lowering succeeds.
32+
bool lowerAtomicRMWInst(AtomicRMWInst *RMWI);
2733
}
2834

2935
#endif // LLVM_TRANSFORMS_SCALAR_LOWERATOMIC_H

llvm/lib/Target/NVPTX/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ add_public_tablegen_target(NVPTXCommonTableGen)
1212

1313
set(NVPTXCodeGen_sources
1414
NVPTXAllocaHoisting.cpp
15+
NVPTXAtomicLower.cpp
1516
NVPTXAsmPrinter.cpp
1617
NVPTXAssignValidGlobalNames.cpp
1718
NVPTXFrameLowering.cpp
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
//===-- NVPTXAtomicLower.cpp - Lower atomics of local memory ----*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Lower atomics of local memory to simple load/stores
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "NVPTXAtomicLower.h"
14+
#include "llvm/CodeGen/StackProtector.h"
15+
#include "llvm/IR/Constants.h"
16+
#include "llvm/IR/Function.h"
17+
#include "llvm/IR/IRBuilder.h"
18+
#include "llvm/IR/InstIterator.h"
19+
#include "llvm/IR/Instructions.h"
20+
#include "llvm/Transforms/Scalar/LowerAtomic.h"
21+
22+
#include "MCTargetDesc/NVPTXBaseInfo.h"
23+
using namespace llvm;
24+
25+
namespace {
26+
// Function pass that lowers atomicrmw instructions on local-address-space
27+
// memory to simple load/stores.
28+
class NVPTXAtomicLower : public FunctionPass {
29+
public:
30+
static char ID; // Pass ID
31+
NVPTXAtomicLower() : FunctionPass(ID) {}
32+
33+
void getAnalysisUsage(AnalysisUsage &AU) const override {
34+
// Declare that running this pass leaves the control-flow graph intact.
AU.setPreservesCFG();
35+
}
36+
37+
StringRef getPassName() const override {
38+
return "NVPTX lower atomics of local memory";
39+
}
40+
41+
bool runOnFunction(Function &F) override;
42+
};
43+
} // namespace
44+
45+
// Lowers every atomicrmw in F whose pointer operand lives in the local
// address space (ADDRESS_SPACE_LOCAL) by passing it to lowerAtomicRMWInst.
// Returns true if any of those calls reported a change.
bool NVPTXAtomicLower::runOnFunction(Function &F) {
46+
// Candidates are gathered first and lowered in a separate loop, so the
// instruction walk never runs while instructions are being rewritten.
// NOTE(review): presumably lowerAtomicRMWInst replaces/erases the
// atomicrmw; its full body is not visible here -- confirm.
SmallVector<AtomicRMWInst *> LocalMemoryAtomics;
47+
for (Instruction &I : instructions(F))
48+
if (AtomicRMWInst *RMWI = dyn_cast<AtomicRMWInst>(&I))
49+
if (RMWI->getPointerAddressSpace() == ADDRESS_SPACE_LOCAL)
50+
LocalMemoryAtomics.push_back(RMWI);
51+
52+
bool Changed = false;
53+
for (AtomicRMWInst *RMWI : LocalMemoryAtomics)
54+
Changed |= lowerAtomicRMWInst(RMWI);
55+
return Changed;
56+
}
57+
58+
// Out-of-class definition of the static pass ID member of NVPTXAtomicLower.
char NVPTXAtomicLower::ID = 0;
59+
60+
// Declaration of the pass initializer that the NVPTX target calls with its
// PassRegistry. NOTE(review): the definition is presumably generated by the
// INITIALIZE_PASS macro below -- confirm against the macro's documentation.
namespace llvm {
61+
void initializeNVPTXAtomicLowerPass(PassRegistry &);
62+
}
63+
64+
// Registers the pass under the name "nvptx-atomic-lower" (the name the
// regression test passes to opt as -nvptx-atomic-lower) together with its
// human-readable description.
INITIALIZE_PASS(NVPTXAtomicLower, "nvptx-atomic-lower",
65+
"Lower atomics of local memory to simple load/stores", false,
66+
false)
67+
68+
// Factory for the pass: returns a newly heap-allocated NVPTXAtomicLower as a
// raw FunctionPass pointer.
// NOTE(review): the caller (e.g. addPass in the NVPTX pass config) is
// presumably expected to take ownership -- confirm.
FunctionPass *llvm::createNVPTXAtomicLowerPass() {
69+
return new NVPTXAtomicLower();
70+
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
//===-- NVPTXAtomicLower.h - Lower atomics of local memory ------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Lower atomics of local memory to simple load/stores
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef LLVM_LIB_TARGET_NVPTX_NVPTXATOMICLOWER_H
14+
#define LLVM_LIB_TARGET_NVPTX_NVPTXATOMICLOWER_H
15+
16+
namespace llvm {
17+
class FunctionPass;
18+
19+
/// Creates the pass that lowers atomics of local memory to simple
/// load/stores.
extern FunctionPass *createNVPTXAtomicLowerPass();
20+
} // end namespace llvm
21+
22+
#endif

llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "NVPTXTargetMachine.h"
1414
#include "NVPTX.h"
1515
#include "NVPTXAllocaHoisting.h"
16+
#include "NVPTXAtomicLower.h"
1617
#include "NVPTXLowerAggrCopies.h"
1718
#include "NVPTXTargetObjectFile.h"
1819
#include "NVPTXTargetTransformInfo.h"
@@ -65,6 +66,7 @@ void initializeNVVMIntrRangePass(PassRegistry&);
6566
void initializeNVVMReflectPass(PassRegistry&);
6667
void initializeGenericToNVVMPass(PassRegistry&);
6768
void initializeNVPTXAllocaHoistingPass(PassRegistry &);
69+
void initializeNVPTXAtomicLowerPass(PassRegistry &);
6870
void initializeNVPTXAssignValidGlobalNamesPass(PassRegistry&);
6971
void initializeNVPTXLowerAggrCopiesPass(PassRegistry &);
7072
void initializeNVPTXLowerArgsPass(PassRegistry &);
@@ -86,6 +88,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeNVPTXTarget() {
8688
initializeGenericToNVVMPass(PR);
8789
initializeNVPTXAllocaHoistingPass(PR);
8890
initializeNVPTXAssignValidGlobalNamesPass(PR);
91+
initializeNVPTXAtomicLowerPass(PR);
8992
initializeNVPTXLowerArgsPass(PR);
9093
initializeNVPTXLowerAllocaPass(PR);
9194
initializeNVPTXLowerAggrCopiesPass(PR);
@@ -252,6 +255,7 @@ void NVPTXPassConfig::addAddressSpaceInferencePasses() {
252255
addPass(createSROAPass());
253256
addPass(createNVPTXLowerAllocaPass());
254257
addPass(createInferAddressSpacesPass());
258+
addPass(createNVPTXAtomicLowerPass());
255259
}
256260

257261
void NVPTXPassConfig::addStraightLineScalarOptimizationPasses() {

llvm/lib/Transforms/Scalar/LowerAtomic.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ static bool LowerAtomicCmpXchgInst(AtomicCmpXchgInst *CXI) {
4040
return true;
4141
}
4242

43-
static bool LowerAtomicRMWInst(AtomicRMWInst *RMWI) {
43+
bool llvm::lowerAtomicRMWInst(AtomicRMWInst *RMWI) {
4444
IRBuilder<> Builder(RMWI);
4545
Value *Ptr = RMWI->getPointerOperand();
4646
Value *Val = RMWI->getValOperand();
@@ -123,7 +123,7 @@ static bool runOnBasicBlock(BasicBlock &BB) {
123123
else if (AtomicCmpXchgInst *CXI = dyn_cast<AtomicCmpXchgInst>(&Inst))
124124
Changed |= LowerAtomicCmpXchgInst(CXI);
125125
else if (AtomicRMWInst *RMWI = dyn_cast<AtomicRMWInst>(&Inst))
126-
Changed |= LowerAtomicRMWInst(RMWI);
126+
Changed |= lowerAtomicRMWInst(RMWI);
127127
else if (LoadInst *LI = dyn_cast<LoadInst>(&Inst)) {
128128
if (LI->isAtomic())
129129
LowerLoadInst(LI);
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
; RUN: opt < %s -S -nvptx-atomic-lower | FileCheck %s
2+
3+
; This test ensures that there is a legal way for ptx to lower atomics
4+
; on local memory. Here, we demonstrate this by lowering them to simple
5+
; load and stores.
6+
7+
target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64"
8+
target triple = "nvptx64-unknown-unknown"
9+
10+
define double @kernel(double addrspace(5)* %ptr, double %val) {
11+
%res = atomicrmw fadd double addrspace(5)* %ptr, double %val monotonic, align 8
12+
ret double %res
13+
; CHECK: %1 = load double, double addrspace(5)* %ptr, align 8
14+
; CHECK-NEXT: %2 = fadd double %1, %val
15+
; CHECK-NEXT: store double %2, double addrspace(5)* %ptr, align 8
16+
; CHECK-NEXT: ret double %1
17+
}
18+

0 commit comments

Comments
 (0)