Skip to content

Commit a135a13

Browse files
paulwalker-arm authored and Lukacma committed
[LLVM][InstCombine] Preserve vector types when shrinking FP constants. (llvm#163598)
While my objective is to make the shrinkfp path safe for ConstantFP-based splats, I discovered that the following issues also affect ConstantVector-based splats: 1. PreferBFloat is not set for bfloat vectors. 2. getMinimumFPType() returns a scalar type for vector constants where getSplatValue() is successful.
1 parent a72fe96 commit a135a13

File tree

2 files changed

+69
-21
lines changed

2 files changed

+69
-21
lines changed

llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp

Lines changed: 33 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -1643,33 +1643,46 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &Sext) {
16431643

16441644
/// Return true if the specified floating-point constant fits
16451645
/// in the specified FP type without changing its value.
1646-
static bool fitsInFPType(ConstantFP *CFP, const fltSemantics &Sem) {
1646+
static bool fitsInFPType(APFloat F, const fltSemantics &Sem) {
16471647
bool losesInfo;
1648-
APFloat F = CFP->getValueAPF();
16491648
(void)F.convert(Sem, APFloat::rmNearestTiesToEven, &losesInfo);
16501649
return !losesInfo;
16511650
}
16521651

1653-
static Type *shrinkFPConstant(ConstantFP *CFP, bool PreferBFloat) {
1654-
if (CFP->getType() == Type::getPPC_FP128Ty(CFP->getContext()))
1655-
return nullptr; // No constant folding of this.
1652+
static Type *shrinkFPConstant(LLVMContext &Ctx, const APFloat &F,
1653+
bool PreferBFloat) {
16561654
// See if the value can be truncated to bfloat and then reextended.
1657-
if (PreferBFloat && fitsInFPType(CFP, APFloat::BFloat()))
1658-
return Type::getBFloatTy(CFP->getContext());
1655+
if (PreferBFloat && fitsInFPType(F, APFloat::BFloat()))
1656+
return Type::getBFloatTy(Ctx);
16591657
// See if the value can be truncated to half and then reextended.
1660-
if (!PreferBFloat && fitsInFPType(CFP, APFloat::IEEEhalf()))
1661-
return Type::getHalfTy(CFP->getContext());
1658+
if (!PreferBFloat && fitsInFPType(F, APFloat::IEEEhalf()))
1659+
return Type::getHalfTy(Ctx);
16621660
// See if the value can be truncated to float and then reextended.
1663-
if (fitsInFPType(CFP, APFloat::IEEEsingle()))
1664-
return Type::getFloatTy(CFP->getContext());
1665-
if (CFP->getType()->isDoubleTy())
1666-
return nullptr; // Won't shrink.
1667-
if (fitsInFPType(CFP, APFloat::IEEEdouble()))
1668-
return Type::getDoubleTy(CFP->getContext());
1661+
if (fitsInFPType(F, APFloat::IEEEsingle()))
1662+
return Type::getFloatTy(Ctx);
1663+
if (&F.getSemantics() == &APFloat::IEEEdouble())
1664+
return nullptr; // Won't shrink.
1665+
// See if the value can be truncated to double and then reextended.
1666+
if (fitsInFPType(F, APFloat::IEEEdouble()))
1667+
return Type::getDoubleTy(Ctx);
16691668
// Don't try to shrink to various long double types.
16701669
return nullptr;
16711670
}
16721671

1672+
static Type *shrinkFPConstant(ConstantFP *CFP, bool PreferBFloat) {
1673+
Type *Ty = CFP->getType();
1674+
if (Ty->getScalarType()->isPPC_FP128Ty())
1675+
return nullptr; // No constant folding of this.
1676+
1677+
Type *ShrinkTy =
1678+
shrinkFPConstant(CFP->getContext(), CFP->getValueAPF(), PreferBFloat);
1679+
if (ShrinkTy)
1680+
if (auto *VecTy = dyn_cast<VectorType>(Ty))
1681+
ShrinkTy = VectorType::get(ShrinkTy, VecTy);
1682+
1683+
return ShrinkTy;
1684+
}
1685+
16731686
// Determine if this is a vector of ConstantFPs and if so, return the minimal
16741687
// type we can safely truncate all elements to.
16751688
static Type *shrinkFPConstantVector(Value *V, bool PreferBFloat) {
@@ -1720,10 +1733,10 @@ static Type *getMinimumFPType(Value *V, bool PreferBFloat) {
17201733

17211734
// Try to shrink scalable and fixed splat vectors.
17221735
if (auto *FPC = dyn_cast<Constant>(V))
1723-
if (isa<VectorType>(V->getType()))
1736+
if (auto *VTy = dyn_cast<VectorType>(V->getType()))
17241737
if (auto *Splat = dyn_cast_or_null<ConstantFP>(FPC->getSplatValue()))
17251738
if (Type *T = shrinkFPConstant(Splat, PreferBFloat))
1726-
return T;
1739+
return VectorType::get(T, VTy);
17271740

17281741
// Try to shrink a vector of FP constants. This returns nullptr on scalable
17291742
// vectors
@@ -1796,10 +1809,9 @@ Instruction *InstCombinerImpl::visitFPTrunc(FPTruncInst &FPT) {
17961809
Type *Ty = FPT.getType();
17971810
auto *BO = dyn_cast<BinaryOperator>(FPT.getOperand(0));
17981811
if (BO && BO->hasOneUse()) {
1799-
Type *LHSMinType =
1800-
getMinimumFPType(BO->getOperand(0), /*PreferBFloat=*/Ty->isBFloatTy());
1801-
Type *RHSMinType =
1802-
getMinimumFPType(BO->getOperand(1), /*PreferBFloat=*/Ty->isBFloatTy());
1812+
bool PreferBFloat = Ty->getScalarType()->isBFloatTy();
1813+
Type *LHSMinType = getMinimumFPType(BO->getOperand(0), PreferBFloat);
1814+
Type *RHSMinType = getMinimumFPType(BO->getOperand(1), PreferBFloat);
18031815
unsigned OpWidth = BO->getType()->getFPMantissaWidth();
18041816
unsigned LHSWidth = LHSMinType->getFPMantissaWidth();
18051817
unsigned RHSWidth = RHSMinType->getFPMantissaWidth();

llvm/test/Transforms/InstCombine/fpextend.ll

Lines changed: 36 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
22
; RUN: opt < %s -passes=instcombine -S | FileCheck %s
3+
; RUN: opt < %s -passes=instcombine -use-constant-fp-for-fixed-length-splat -S | FileCheck %s
34

45
define float @test(float %x) nounwind {
56
; CHECK-LABEL: @test(
@@ -449,6 +450,28 @@ define bfloat @bf16_frem(bfloat %x) {
449450
ret bfloat %t3
450451
}
451452

453+
define <4 x bfloat> @v4bf16_frem_x_const(<4 x bfloat> %x) {
454+
; CHECK-LABEL: @v4bf16_frem_x_const(
455+
; CHECK-NEXT: [[TMP1:%.*]] = frem <4 x bfloat> [[X:%.*]], splat (bfloat 0xR40C9)
456+
; CHECK-NEXT: ret <4 x bfloat> [[TMP1]]
457+
;
458+
%t1 = fpext <4 x bfloat> %x to <4 x float>
459+
%t2 = frem <4 x float> %t1, splat(float 6.281250e+00)
460+
%t3 = fptrunc <4 x float> %t2 to <4 x bfloat>
461+
ret <4 x bfloat> %t3
462+
}
463+
464+
define <4 x bfloat> @v4bf16_frem_const_x(<4 x bfloat> %x) {
465+
; CHECK-LABEL: @v4bf16_frem_const_x(
466+
; CHECK-NEXT: [[TMP1:%.*]] = frem <4 x bfloat> splat (bfloat 0xR40C9), [[X:%.*]]
467+
; CHECK-NEXT: ret <4 x bfloat> [[TMP1]]
468+
;
469+
%t1 = fpext <4 x bfloat> %x to <4 x float>
470+
%t2 = frem <4 x float> splat(float 6.281250e+00), %t1
471+
%t3 = fptrunc <4 x float> %t2 to <4 x bfloat>
472+
ret <4 x bfloat> %t3
473+
}
474+
452475
define <4 x float> @v4f32_fadd(<4 x float> %a) {
453476
; CHECK-LABEL: @v4f32_fadd(
454477
; CHECK-NEXT: [[TMP1:%.*]] = fadd <4 x float> [[A:%.*]], splat (float -1.000000e+00)
@@ -459,3 +482,16 @@ define <4 x float> @v4f32_fadd(<4 x float> %a) {
459482
%5 = fptrunc <4 x double> %4 to <4 x float>
460483
ret <4 x float> %5
461484
}
485+
486+
define <4 x float> @v4f32_fadd_const_not_shrinkable(<4 x float> %a) {
487+
; CHECK-LABEL: @v4f32_fadd_const_not_shrinkable(
488+
; CHECK-NEXT: [[TMP1:%.*]] = fpext <4 x float> [[A:%.*]] to <4 x double>
489+
; CHECK-NEXT: [[TMP2:%.*]] = fadd <4 x double> [[TMP1]], splat (double -1.000000e+100)
490+
; CHECK-NEXT: [[TMP3:%.*]] = fptrunc <4 x double> [[TMP2]] to <4 x float>
491+
; CHECK-NEXT: ret <4 x float> [[TMP3]]
492+
;
493+
%2 = fpext <4 x float> %a to <4 x double>
494+
%4 = fadd <4 x double> %2, splat (double -1.000000e+100)
495+
%5 = fptrunc <4 x double> %4 to <4 x float>
496+
ret <4 x float> %5
497+
}

0 commit comments

Comments
 (0)