Skip to content

Commit 477d339

Browse files
committed
[AMDGPU] expand-fp: Unify scalarization
Extend the existing "scalarize" function which is used for the fp-integer conversion instruction expansion to BinaryOperator instructions and reuse it for the frem expansion. Furthermore, extract a function to dispatch instructions to the scalar and vector queues and hoist a check for scalable vectors to the top of the instruction visiting loop.
1 parent 5149e51 commit 477d339

File tree

5 files changed

+2028
-1097
lines changed

5 files changed

+2028
-1097
lines changed

llvm/lib/CodeGen/ExpandFp.cpp

Lines changed: 42 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -356,8 +356,9 @@ Value *FRemExpander::buildFRem(Value *X, Value *Y,
356356
static bool expandFRem(BinaryOperator &I, std::optional<SimplifyQuery> &SQ) {
357357
LLVM_DEBUG(dbgs() << "Expanding instruction: " << I << '\n');
358358

359-
Type *ReturnTy = I.getType();
360-
assert(FRemExpander::canExpandType(ReturnTy->getScalarType()));
359+
Type *Ty = I.getType();
360+
assert(Ty->isFloatingPointTy() && "Instruction should have been scalarized");
361+
assert(FRemExpander::canExpandType(Ty));
361362

362363
FastMathFlags FMF = I.getFastMathFlags();
363364
// TODO Make use of those flags for optimization?
@@ -368,32 +369,10 @@ static bool expandFRem(BinaryOperator &I, std::optional<SimplifyQuery> &SQ) {
368369
B.setFastMathFlags(FMF);
369370
B.SetCurrentDebugLocation(I.getDebugLoc());
370371

371-
Type *ElemTy = ReturnTy->getScalarType();
372-
const FRemExpander Expander = FRemExpander::create(B, ElemTy);
373-
374-
Value *Ret;
375-
if (ReturnTy->isFloatingPointTy())
376-
Ret = FMF.approxFunc()
377-
? Expander.buildApproxFRem(I.getOperand(0), I.getOperand(1))
378-
: Expander.buildFRem(I.getOperand(0), I.getOperand(1), SQ);
379-
else {
380-
auto *VecTy = cast<FixedVectorType>(ReturnTy);
381-
382-
// This could use SplitBlockAndInsertForEachLane but the interface
383-
// is a bit awkward for a constant number of elements and it will
384-
// boil down to the same code.
385-
// TODO Expand the FRem instruction only once and reuse the code.
386-
Value *Nums = I.getOperand(0);
387-
Value *Denums = I.getOperand(1);
388-
Ret = PoisonValue::get(I.getType());
389-
for (int I = 0, E = VecTy->getNumElements(); I != E; ++I) {
390-
Value *Num = B.CreateExtractElement(Nums, I);
391-
Value *Denum = B.CreateExtractElement(Denums, I);
392-
Value *Rem = FMF.approxFunc() ? Expander.buildApproxFRem(Num, Denum)
393-
: Expander.buildFRem(Num, Denum, SQ);
394-
Ret = B.CreateInsertElement(Ret, Rem, I);
395-
}
396-
}
372+
const FRemExpander Expander = FRemExpander::create(B, Ty);
373+
Value *Ret = FMF.approxFunc()
374+
? Expander.buildApproxFRem(I.getOperand(0), I.getOperand(1))
375+
: Expander.buildFRem(I.getOperand(0), I.getOperand(1), SQ);
397376

398377
I.replaceAllUsesWith(Ret);
399378
Ret->takeName(&I);
@@ -948,11 +927,17 @@ static void scalarize(Instruction *I, SmallVectorImpl<Instruction *> &Replace) {
948927
Value *Result = PoisonValue::get(VTy);
949928
for (unsigned Idx = 0; Idx < NumElements; ++Idx) {
950929
Value *Ext = Builder.CreateExtractElement(I->getOperand(0), Idx);
951-
Value *Cast = Builder.CreateCast(cast<CastInst>(I)->getOpcode(), Ext,
952-
I->getType()->getScalarType());
953-
Result = Builder.CreateInsertElement(Result, Cast, Idx);
954-
if (isa<Instruction>(Cast))
955-
Replace.push_back(cast<Instruction>(Cast));
930+
Value *Op;
931+
if (isa<BinaryOperator>(I))
932+
Op = Builder.CreateBinOp(
933+
cast<BinaryOperator>(I)->getOpcode(), Ext,
934+
Builder.CreateExtractElement(I->getOperand(1), Idx));
935+
else
936+
Op = Builder.CreateCast(cast<CastInst>(I)->getOpcode(), Ext,
937+
I->getType()->getScalarType());
938+
Result = Builder.CreateInsertElement(Result, Op, Idx);
939+
if (isa<Instruction>(Op))
940+
Replace.push_back(cast<Instruction>(Op));
956941
}
957942
I->replaceAllUsesWith(Result);
958943
I->dropAllReferences();
@@ -989,6 +974,16 @@ static bool targetSupportsFrem(const TargetLowering &TLI, Type *Ty) {
989974
return TLI.getLibcallName(fremToLibcall(Ty->getScalarType()));
990975
}
991976

977+
static void enqueueInstruction(Instruction &I,
978+
SmallVector<Instruction *, 4> &Replace,
979+
SmallVector<Instruction *, 4> &ReplaceVector) {
980+
981+
if (I.getOperand(0)->getType()->isVectorTy())
982+
ReplaceVector.push_back(&I);
983+
else
984+
Replace.push_back(&I);
985+
}
986+
992987
static bool runImpl(Function &F, const TargetLowering &TLI,
993988
AssumptionCache *AC) {
994989
SmallVector<Instruction *, 4> Replace;
@@ -1004,55 +999,37 @@ static bool runImpl(Function &F, const TargetLowering &TLI,
1004999
return false;
10051000

10061001
for (auto &I : instructions(F)) {
1007-
switch (I.getOpcode()) {
1008-
case Instruction::FRem: {
1009-
Type *Ty = I.getType();
1010-
// TODO: This pass doesn't handle scalable vectors.
1011-
if (Ty->isScalableTy())
1012-
continue;
1013-
1014-
if (targetSupportsFrem(TLI, Ty) ||
1015-
!FRemExpander::canExpandType(Ty->getScalarType()))
1016-
continue;
1017-
1018-
Replace.push_back(&I);
1019-
Modified = true;
1002+
Type *Ty = I.getType();
1003+
// TODO: This pass doesn't handle scalable vectors.
1004+
if (Ty->isScalableTy())
1005+
continue;
10201006

1007+
switch (I.getOpcode()) {
1008+
case Instruction::FRem:
1009+
if (!targetSupportsFrem(TLI, Ty) &&
1010+
FRemExpander::canExpandType(Ty->getScalarType())) {
1011+
enqueueInstruction(I, Replace, ReplaceVector);
1012+
Modified = true;
1013+
}
10211014
break;
1022-
}
10231015
case Instruction::FPToUI:
10241016
case Instruction::FPToSI: {
1025-
// TODO: This pass doesn't handle scalable vectors.
1026-
if (I.getOperand(0)->getType()->isScalableTy())
1027-
continue;
1028-
1029-
auto *IntTy = cast<IntegerType>(I.getType()->getScalarType());
1017+
auto *IntTy = cast<IntegerType>(Ty->getScalarType());
10301018
if (IntTy->getIntegerBitWidth() <= MaxLegalFpConvertBitWidth)
10311019
continue;
10321020

1033-
if (I.getOperand(0)->getType()->isVectorTy())
1034-
ReplaceVector.push_back(&I);
1035-
else
1036-
Replace.push_back(&I);
1021+
enqueueInstruction(I, Replace, ReplaceVector);
10371022
Modified = true;
10381023
break;
10391024
}
10401025
case Instruction::UIToFP:
10411026
case Instruction::SIToFP: {
1042-
// TODO: This pass doesn't handle scalable vectors.
1043-
if (I.getOperand(0)->getType()->isScalableTy())
1044-
continue;
1045-
10461027
auto *IntTy =
10471028
cast<IntegerType>(I.getOperand(0)->getType()->getScalarType());
10481029
if (IntTy->getIntegerBitWidth() <= MaxLegalFpConvertBitWidth)
10491030
continue;
10501031

1051-
if (I.getOperand(0)->getType()->isVectorTy())
1052-
ReplaceVector.push_back(&I);
1053-
else
1054-
Replace.push_back(&I);
1055-
Modified = true;
1032+
enqueueInstruction(I, Replace, ReplaceVector);
10561033
break;
10571034
}
10581035
default:

0 commit comments

Comments
 (0)