-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[NVPTX] Add IR pass for FMA transformation in the llc pipeline #154735
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
//===------ NVPTXFoldFMA.cpp - Fold FMA --------------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// This file implements FMA folding for float/double type for NVPTX. It folds | ||
// following patterns: | ||
// 1. fadd(fmul(a, b), c) => fma(a, b, c) | ||
// 2. fadd(c, fmul(a, b)) => fma(a, b, c) | ||
// 3. fadd(fmul(a, b), fmul(c, d)) => fma(a, b, fmul(c, d)) | ||
// 4. fsub(fmul(a, b), c) => fma(a, b, fneg(c)) | ||
// 5. fsub(a, fmul(b, c)) => fma(fneg(b), c, a) | ||
// 6. fsub(fmul(a, b), fmul(c, d)) => fma(a, b, fneg(fmul(c, d))) | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "NVPTXUtilities.h" | ||
#include "llvm/IR/IRBuilder.h" | ||
#include "llvm/IR/InstIterator.h" | ||
#include "llvm/IR/Instructions.h" | ||
#include "llvm/IR/Intrinsics.h" | ||
|
||
#define DEBUG_TYPE "nvptx-fold-fma" | ||
|
||
using namespace llvm; | ||
|
||
static bool foldFMA(Function &F) { | ||
bool Changed = false; | ||
SmallVector<BinaryOperator *, 16> FAddFSubInsts; | ||
|
||
// Collect all float/double FAdd/FSub instructions with allow-contract | ||
for (auto &I : instructions(F)) { | ||
if (auto *BI = dyn_cast<BinaryOperator>(&I)) { | ||
// Only FAdd and FSub are supported. | ||
if (BI->getOpcode() != Instruction::FAdd && | ||
BI->getOpcode() != Instruction::FSub) | ||
continue; | ||
|
||
// At minimum, the instruction should have allow-contract. | ||
if (!BI->hasAllowContract()) | ||
continue; | ||
|
||
// Only float and double are supported. | ||
if (!BI->getType()->isFloatTy() && !BI->getType()->isDoubleTy()) | ||
continue; | ||
|
||
FAddFSubInsts.push_back(BI); | ||
rajatbajpai marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
} | ||
} | ||
|
||
auto tryFoldBinaryFMul = [](BinaryOperator *BI, Value *MulOperand, | ||
rajatbajpai marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
Value *OtherOperand, bool IsFirstOperand, | ||
bool IsFSub) -> bool { | ||
auto *FMul = dyn_cast<BinaryOperator>(MulOperand); | ||
if (!FMul || FMul->getOpcode() != Instruction::FMul || !FMul->hasOneUse() || | ||
!FMul->hasAllowContract()) | ||
return false; | ||
|
||
LLVM_DEBUG({ | ||
const char *OpName = IsFSub ? "FSub" : "FAdd"; | ||
dbgs() << "Found " << OpName << " with FMul (single use) as " | ||
<< (IsFirstOperand ? "first" : "second") << " operand: " << *BI | ||
<< "\n"; | ||
}); | ||
|
||
Value *MulOp0 = FMul->getOperand(0); | ||
Value *MulOp1 = FMul->getOperand(1); | ||
IRBuilder<> Builder(BI); | ||
Value *FMA = nullptr; | ||
|
||
if (!IsFSub) { | ||
// fadd(fmul(a, b), c) => fma(a, b, c) | ||
// fadd(c, fmul(a, b)) => fma(a, b, c) | ||
FMA = Builder.CreateIntrinsic(Intrinsic::fma, {BI->getType()}, | ||
{MulOp0, MulOp1, OtherOperand}); | ||
} else { | ||
if (IsFirstOperand) { | ||
// fsub(fmul(a, b), c) => fma(a, b, fneg(c)) | ||
Value *NegOtherOp = Builder.CreateFNeg(OtherOperand); | ||
cast<Instruction>(NegOtherOp)->setFastMathFlags(BI->getFastMathFlags()); | ||
rajatbajpai marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
FMA = Builder.CreateIntrinsic(Intrinsic::fma, {BI->getType()}, | ||
{MulOp0, MulOp1, NegOtherOp}); | ||
} else { | ||
// fsub(a, fmul(b, c)) => fma(fneg(b), c, a) | ||
Value *NegMulOp0 = Builder.CreateFNeg(MulOp0); | ||
cast<Instruction>(NegMulOp0)->setFastMathFlags( | ||
FMul->getFastMathFlags()); | ||
rajatbajpai marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
FMA = Builder.CreateIntrinsic(Intrinsic::fma, {BI->getType()}, | ||
{NegMulOp0, MulOp1, OtherOperand}); | ||
} | ||
} | ||
|
||
// Combine fast-math flags from the original instructions | ||
auto *FMAInst = cast<Instruction>(FMA); | ||
rajatbajpai marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
FastMathFlags BinaryFMF = BI->getFastMathFlags(); | ||
FastMathFlags FMulFMF = FMul->getFastMathFlags(); | ||
FastMathFlags NewFMF = FastMathFlags::intersectRewrite(BinaryFMF, FMulFMF) | | ||
FastMathFlags::unionValue(BinaryFMF, FMulFMF); | ||
|
||
FMAInst->setFastMathFlags(NewFMF); | ||
|
||
LLVM_DEBUG({ | ||
const char *OpName = IsFSub ? "FSub" : "FAdd"; | ||
dbgs() << "Replacing " << OpName << " with FMA: " << *FMA << "\n"; | ||
}); | ||
BI->replaceAllUsesWith(FMA); | ||
BI->eraseFromParent(); | ||
FMul->eraseFromParent(); | ||
return true; | ||
}; | ||
|
||
for (auto *BI : FAddFSubInsts) { | ||
Value *Op0 = BI->getOperand(0); | ||
Value *Op1 = BI->getOperand(1); | ||
bool IsFSub = BI->getOpcode() == Instruction::FSub; | ||
|
||
if (tryFoldBinaryFMul(BI, Op0, Op1, true /*IsFirstOperand*/, IsFSub) || | ||
tryFoldBinaryFMul(BI, Op1, Op0, false /*IsFirstOperand*/, IsFSub)) | ||
Changed = true; | ||
} | ||
|
||
return Changed; | ||
} | ||
|
||
namespace { | ||
|
||
struct NVPTXFoldFMA : public FunctionPass { | ||
static char ID; | ||
NVPTXFoldFMA() : FunctionPass(ID) {} | ||
bool runOnFunction(Function &F) override; | ||
}; | ||
|
||
} // namespace | ||
|
||
char NVPTXFoldFMA::ID = 0; | ||
INITIALIZE_PASS(NVPTXFoldFMA, "nvptx-fold-fma", "NVPTX Fold FMA", false, false) | ||
|
||
bool NVPTXFoldFMA::runOnFunction(Function &F) { return foldFMA(F); } | ||
|
||
FunctionPass *llvm::createNVPTXFoldFMAPass() { return new NVPTXFoldFMA(); } | ||
|
||
PreservedAnalyses NVPTXFoldFMAPass::run(Function &F, | ||
FunctionAnalysisManager &) { | ||
return foldFMA(F) ? PreservedAnalyses::none() : PreservedAnalyses::all(); | ||
rajatbajpai marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not half and bfloat?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No specific reasons—first, I wanted to support the float and double. This pass could be extended in the future to handle half and bfloat types as well.