Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions llvm/lib/Target/NVPTX/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ set(NVPTXCodeGen_sources
NVPTXAssignValidGlobalNames.cpp
NVPTXAtomicLower.cpp
NVPTXCtorDtorLowering.cpp
NVPTXFoldFMA.cpp
NVPTXForwardParams.cpp
NVPTXFrameLowering.cpp
NVPTXGenericToNVVM.cpp
Expand Down
6 changes: 6 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTX.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ FunctionPass *createNVPTXLowerAllocaPass();
FunctionPass *createNVPTXLowerUnreachablePass(bool TrapUnreachable,
bool NoTrapAfterNoreturn);
FunctionPass *createNVPTXTagInvariantLoadsPass();
FunctionPass *createNVPTXFoldFMAPass();
MachineFunctionPass *createNVPTXPeephole();
MachineFunctionPass *createNVPTXProxyRegErasurePass();
MachineFunctionPass *createNVPTXForwardParamsPass();
Expand All @@ -76,12 +77,17 @@ void initializeNVPTXAAWrapperPassPass(PassRegistry &);
void initializeNVPTXExternalAAWrapperPass(PassRegistry &);
void initializeNVPTXPeepholePass(PassRegistry &);
void initializeNVPTXTagInvariantLoadLegacyPassPass(PassRegistry &);
void initializeNVPTXFoldFMAPass(PassRegistry &);
void initializeNVPTXPrologEpilogPassPass(PassRegistry &);

/// Function pass usable with the new pass manager (it derives from
/// PassInfoMixin). Only the run() entry point is declared here.
struct NVVMIntrRangePass : PassInfoMixin<NVVMIntrRangePass> {
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
};

/// New pass manager wrapper for the NVPTX FMA folding transform, registered
/// as "nvptx-fold-fma". run() is declared here and defined in
/// NVPTXFoldFMA.cpp.
struct NVPTXFoldFMAPass : PassInfoMixin<NVPTXFoldFMAPass> {
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
};

struct NVVMReflectPass : PassInfoMixin<NVVMReflectPass> {
NVVMReflectPass() : SmVersion(0) {}
NVVMReflectPass(unsigned SmVersion) : SmVersion(SmVersion) {}
Expand Down
150 changes: 150 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXFoldFMA.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
//===------ NVPTXFoldFMA.cpp - Fold FMA --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements FMA folding for the float and double types on NVPTX. It
// folds the following patterns:
// 1. fadd(fmul(a, b), c) => fma(a, b, c)
// 2. fadd(c, fmul(a, b)) => fma(a, b, c)
// 3. fadd(fmul(a, b), fmul(c, d)) => fma(a, b, fmul(c, d))
// 4. fsub(fmul(a, b), c) => fma(a, b, fneg(c))
// 5. fsub(a, fmul(b, c)) => fma(fneg(b), c, a)
// 6. fsub(fmul(a, b), fmul(c, d)) => fma(a, b, fneg(fmul(c, d)))
//===----------------------------------------------------------------------===//

#include "NVPTXUtilities.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/Pass.h"

#define DEBUG_TYPE "nvptx-fold-fma"

using namespace llvm;

/// Attempts to contract \p BI (an fadd or fsub) with the fmul that feeds one
/// of its operands into a single llvm.fma call.
///
/// The fold only fires when \p MulOperand is an fmul instruction that has
/// exactly one use and carries the allow-contract flag. The caller is expected
/// to have checked \p BI itself.
///
/// \param BI             The fadd/fsub to be replaced.
/// \param MulOperand     The operand of \p BI that is tested for being an fmul.
/// \param OtherOperand   The remaining operand of \p BI.
/// \param IsFirstOperand True when \p MulOperand is operand 0 of \p BI.
/// \param IsFSub         True when \p BI is an fsub, false for an fadd.
/// \returns true if \p BI and the fmul were replaced by an fma and erased,
///          false if the IR was left untouched.
static bool tryFoldBinaryFMul(BinaryOperator *BI, Value *MulOperand,
                              Value *OtherOperand, bool IsFirstOperand,
                              bool IsFSub) {
  auto *Mul = dyn_cast<BinaryOperator>(MulOperand);
  if (!Mul)
    return false;
  if (Mul->getOpcode() != Instruction::FMul)
    return false;
  // The fmul is erased below, so BI must be its only user, and contraction
  // must be permitted on the multiply as well.
  if (!Mul->hasOneUse() || !Mul->hasAllowContract())
    return false;

  LLVM_DEBUG(dbgs() << "Found " << (IsFSub ? "FSub" : "FAdd")
                    << " with FMul (single use) as "
                    << (IsFirstOperand ? "first" : "second")
                    << " operand: " << *BI << "\n");

  FastMathFlags BinFlags = BI->getFastMathFlags();
  FastMathFlags MulFlags = Mul->getFastMathFlags();

  // Select the three fma operands. For fadd they are used as-is; for fsub one
  // of them is negated first:
  //   fsub(fmul(a, b), c) => fma(a, b, fneg(c))
  //   fsub(a, fmul(b, c)) => fma(fneg(b), c, a)
  IRBuilder<> Builder(BI);
  Value *Factor0 = Mul->getOperand(0);
  Value *Factor1 = Mul->getOperand(1);
  Value *Addend = OtherOperand;
  if (IsFSub) {
    if (IsFirstOperand)
      Addend = Builder.CreateFNegFMF(OtherOperand, BinFlags);
    else
      Factor0 = Builder.CreateFNegFMF(Factor0, MulFlags);
  }

  Value *FMA = Builder.CreateIntrinsic(Intrinsic::fma, {BI->getType()},
                                       {Factor0, Factor1, Addend});

  // The fma inherits fast-math flags merged from both original instructions.
  cast<Instruction>(FMA)->setFastMathFlags(
      FastMathFlags::intersectRewrite(BinFlags, MulFlags) |
      FastMathFlags::unionValue(BinFlags, MulFlags));

  LLVM_DEBUG(dbgs() << "Replacing " << (IsFSub ? "FSub" : "FAdd")
                    << " with FMA: " << *FMA << "\n");

  // BI goes first: it is the sole user of the fmul, which then becomes dead.
  BI->replaceAllUsesWith(FMA);
  BI->eraseFromParent();
  Mul->eraseFromParent();
  return true;
}

/// Runs FMA folding over every eligible fadd/fsub in \p F.
///
/// Candidates are gathered into a worklist before any rewriting so that
/// instructions are not erased while the function is being iterated.
///
/// \returns true if at least one instruction was folded into an fma.
static bool foldFMA(Function &F) {
  SmallVector<BinaryOperator *, 16> Candidates;

  for (Instruction &Inst : instructions(F)) {
    auto *BinOp = dyn_cast<BinaryOperator>(&Inst);
    if (!BinOp)
      continue;

    // Only FAdd and FSub are supported.
    const unsigned Opcode = BinOp->getOpcode();
    const bool IsAddOrSub =
        Opcode == Instruction::FAdd || Opcode == Instruction::FSub;

    // Only float and double are handled for now; half and bfloat are not
    // covered yet and could be added later.
    Type *Ty = BinOp->getType();
    const bool IsF32OrF64 = Ty->isFloatTy() || Ty->isDoubleTy();

    // At minimum, the instruction should have allow-contract. The opcode test
    // comes first so the flag is only queried on floating-point operations.
    if (IsAddOrSub && BinOp->hasAllowContract() && IsF32OrF64)
      Candidates.push_back(BinOp);
  }

  bool Changed = false;
  for (BinaryOperator *BinOp : Candidates) {
    Value *LHS = BinOp->getOperand(0);
    Value *RHS = BinOp->getOperand(1);
    const bool IsFSub = BinOp->getOpcode() == Instruction::FSub;

    // Try the left operand as the multiply first. BinOp is erased on success,
    // so the right operand is attempted only if that fails.
    if (tryFoldBinaryFMul(BinOp, LHS, RHS, /*IsFirstOperand=*/true, IsFSub)) {
      Changed = true;
      continue;
    }
    if (tryFoldBinaryFMul(BinOp, RHS, LHS, /*IsFirstOperand=*/false, IsFSub))
      Changed = true;
  }

  return Changed;
}

namespace {

struct NVPTXFoldFMA : public FunctionPass {
static char ID;
NVPTXFoldFMA() : FunctionPass(ID) {}
bool runOnFunction(Function &F) override;
};

} // namespace

// Storage for the pass ID declared in NVPTXFoldFMA, followed by registration
// of the legacy pass under the name "nvptx-fold-fma".
char NVPTXFoldFMA::ID = 0;
INITIALIZE_PASS(NVPTXFoldFMA, "nvptx-fold-fma", "NVPTX Fold FMA", false, false)

/// Legacy pass manager entry point: folds fmul + fadd/fsub pairs in \p F into
/// fma calls.
///
/// \returns true if \p F was modified.
bool NVPTXFoldFMA::runOnFunction(Function &F) {
  // FMA folding is an optimization, not a required lowering, so honor optnone
  // functions and opt-bisect limits like other optional passes.
  if (skipFunction(F))
    return false;
  return foldFMA(F);
}

/// Returns a newly heap-allocated instance of the legacy NVPTXFoldFMA pass.
FunctionPass *llvm::createNVPTXFoldFMAPass() { return new NVPTXFoldFMA(); }

/// New pass manager entry point for FMA folding.
///
/// \returns PreservedAnalyses::all() when \p F was not changed; otherwise a
///          set that preserves only the CFG analyses.
PreservedAnalyses NVPTXFoldFMAPass::run(Function &F,
                                        FunctionAnalysisManager &) {
  if (foldFMA(F)) {
    // Instructions were rewritten, but no blocks or branches were touched.
    PreservedAnalyses Preserved;
    Preserved.preserveSet<CFGAnalyses>();
    return Preserved;
  }

  return PreservedAnalyses::all();
}
1 change: 1 addition & 0 deletions llvm/lib/Target/NVPTX/NVPTXPassRegistry.def
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,5 @@ FUNCTION_PASS("nvvm-intr-range", NVVMIntrRangePass())
FUNCTION_PASS("nvptx-copy-byval-args", NVPTXCopyByValArgsPass())
FUNCTION_PASS("nvptx-lower-args", NVPTXLowerArgsPass(*this))
FUNCTION_PASS("nvptx-tag-invariant-loads", NVPTXTagInvariantLoadsPass())
FUNCTION_PASS("nvptx-fold-fma", NVPTXFoldFMAPass())
#undef FUNCTION_PASS
9 changes: 9 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ static cl::opt<bool>
cl::desc("Disable load/store vectorizer"),
cl::init(false), cl::Hidden);

// FoldFMA is a new pass; this option lets us turn it off in case we
// encounter any issues.
static cl::opt<bool> DisableFoldFMA("disable-nvptx-fold-fma",
cl::desc("Disable NVPTX Fold FMA"),
cl::init(false), cl::Hidden);

// TODO: Remove this flag when we are confident with no regressions.
static cl::opt<bool> DisableRequireStructuredCFG(
"disable-nvptx-require-structured-cfg",
Expand Down Expand Up @@ -115,6 +121,7 @@ extern "C" LLVM_ABI LLVM_EXTERNAL_VISIBILITY void LLVMInitializeNVPTXTarget() {
initializeNVPTXExternalAAWrapperPass(PR);
initializeNVPTXPeepholePass(PR);
initializeNVPTXTagInvariantLoadLegacyPassPass(PR);
initializeNVPTXFoldFMAPass(PR);
initializeNVPTXPrologEpilogPassPass(PR);
}

Expand Down Expand Up @@ -397,6 +404,8 @@ void NVPTXPassConfig::addIRPasses() {
addPass(createLoadStoreVectorizerPass());
addPass(createSROAPass());
addPass(createNVPTXTagInvariantLoadsPass());
if (!DisableFoldFMA)
addPass(createNVPTXFoldFMAPass());
}

if (ST.hasPTXASUnreachableBug()) {
Expand Down
Loading