Skip to content
Merged
28 changes: 21 additions & 7 deletions llvm/include/llvm/Analysis/TargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -899,14 +899,20 @@ class TargetTransformInfo {

bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) const;

/// Identifies if the vector form of the intrinsic has a scalar operand.
bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
unsigned ScalarOpdIdx) const;

/// Identifies if the vector form of the intrinsic is overloaded on the type
/// of the operand at index \p OpdIdx, or on the return type if \p OpdIdx is
/// -1.
bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
int ScalarOpdIdx) const;
bool isTargetIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
int OpdIdx) const;

/// Identifies if the vector form of the intrinsic that returns a struct is
/// overloaded at the struct element index \p RetIdx.
bool isTargetIntrinsicWithStructReturnOverloadAtField(Intrinsic::ID ID,
int RetIdx) const;

/// Estimate the overhead of scalarizing an instruction. Insert and Extract
/// are set if the demanded result elements need to be inserted and/or
Expand Down Expand Up @@ -2002,8 +2008,11 @@ class TargetTransformInfo::Concept {
virtual bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) = 0;
virtual bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
unsigned ScalarOpdIdx) = 0;
virtual bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
int ScalarOpdIdx) = 0;
virtual bool isTargetIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
int OpdIdx) = 0;
virtual bool
isTargetIntrinsicWithStructReturnOverloadAtField(Intrinsic::ID ID,
int RetIdx) = 0;
virtual InstructionCost
getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts,
bool Insert, bool Extract, TargetCostKind CostKind,
Expand Down Expand Up @@ -2580,9 +2589,14 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
return Impl.isTargetIntrinsicWithScalarOpAtArg(ID, ScalarOpdIdx);
}

bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
int ScalarOpdIdx) override {
return Impl.isVectorIntrinsicWithOverloadTypeAtArg(ID, ScalarOpdIdx);
bool isTargetIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
int OpdIdx) override {
return Impl.isTargetIntrinsicWithOverloadTypeAtArg(ID, OpdIdx);
}

/// Concept override: forwards the struct-return overload query for intrinsic
/// \p ID and struct field index \p RetIdx unchanged to the wrapped Impl object
/// and returns its answer.
bool isTargetIntrinsicWithStructReturnOverloadAtField(Intrinsic::ID ID,
int RetIdx) override {
return Impl.isTargetIntrinsicWithStructReturnOverloadAtField(ID, RetIdx);
}

InstructionCost getScalarizationOverhead(VectorType *Ty,
Expand Down
11 changes: 8 additions & 3 deletions llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -396,9 +396,14 @@ class TargetTransformInfoImplBase {
return false;
}

bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
int ScalarOpdIdx) const {
return ScalarOpdIdx == -1;
bool isTargetIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
int OpdIdx) const {
return OpdIdx == -1;
}

/// Default answer for the struct-return overload query: only struct field
/// index 0 is reported as overloaded. \p ID is not consulted here, so the
/// result is the same for every intrinsic.
bool isTargetIntrinsicWithStructReturnOverloadAtField(Intrinsic::ID ID,
int RetIdx) const {
return RetIdx == 0;
}

InstructionCost getScalarizationOverhead(VectorType *Ty,
Expand Down
26 changes: 21 additions & 5 deletions llvm/include/llvm/Analysis/VectorUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,25 @@ inline Type *ToVectorTy(Type *Scalar, unsigned VF) {
/// This method returns true if the intrinsic's argument types are all scalars
/// for the scalar form of the intrinsic and all vectors (or scalars handled by
/// isVectorIntrinsicWithScalarOpAtArg) for the vector form of the intrinsic.
///
/// Note: isTriviallyVectorizable implies isTriviallyScalarizable.
bool isTriviallyVectorizable(Intrinsic::ID ID);

/// Identify if the intrinsic is trivially scalarizable.
/// This method returns true following the same predicates of
/// isTriviallyVectorizable.
///
/// Note: There are intrinsics where implementing vectorization for the
/// intrinsic is redundant, but we want to implement scalarization of the
/// vector. To prevent the requirement that an intrinsic also implements
/// vectorization we provide this separate function.
bool isTriviallyScalarizable(Intrinsic::ID ID, const TargetTransformInfo *TTI);

/// Identifies if the vector form of the intrinsic has a scalar operand.
bool isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
unsigned ScalarOpdIdx);
/// \p TTI is used to consider target specific intrinsics, if no target specific
/// intrinsics will be considered then it is appropriate to pass in nullptr.
bool isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID, unsigned ScalarOpdIdx,
const TargetTransformInfo *TTI);

/// Identifies if the vector form of the intrinsic is overloaded on the type of
/// the operand at index \p OpdIdx, or on the return type if \p OpdIdx is -1.
Expand All @@ -158,9 +172,11 @@ bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID, int OpdIdx,
const TargetTransformInfo *TTI);

/// Identifies if the vector form of the intrinsic that returns a struct is
/// overloaded at the struct element index \p RetIdx.
bool isVectorIntrinsicWithStructReturnOverloadAtField(Intrinsic::ID ID,
int RetIdx);
/// overloaded at the struct element index \p RetIdx.
/// \p TTI is used to consider target specific intrinsics, if no target specific
/// intrinsics will be considered then it is appropriate to pass in nullptr.
bool isVectorIntrinsicWithStructReturnOverloadAtField(
Intrinsic::ID ID, int RetIdx, const TargetTransformInfo *TTI);

/// Returns intrinsic ID for call.
/// For the input call instruction it finds mapping intrinsic and returns
Expand Down
11 changes: 8 additions & 3 deletions llvm/include/llvm/CodeGen/BasicTTIImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -819,9 +819,14 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
return false;
}

bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
int ScalarOpdIdx) const {
return ScalarOpdIdx == -1;
bool isTargetIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
int OpdIdx) const {
return OpdIdx == -1;
}

/// Default answer for the struct-return overload query: only struct field
/// index 0 is reported as overloaded. \p ID is not consulted here, so the
/// result is the same for every intrinsic.
bool isTargetIntrinsicWithStructReturnOverloadAtField(Intrinsic::ID ID,
int RetIdx) const {
return RetIdx == 0;
}

/// Helper wrapper for the DemandedElts variant of getScalarizationOverhead.
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Analysis/ConstantFolding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3447,7 +3447,7 @@ static Constant *ConstantFoldFixedVectorCall(
// Gather a column of constants.
for (unsigned J = 0, JE = Operands.size(); J != JE; ++J) {
// Some intrinsics use a scalar type for certain arguments.
if (isVectorIntrinsicWithScalarOpAtArg(IntrinsicID, J)) {
if (isVectorIntrinsicWithScalarOpAtArg(IntrinsicID, J, /*TTI=*/nullptr)) {
Lane[J] = Operands[J];
continue;
}
Expand Down
11 changes: 8 additions & 3 deletions llvm/lib/Analysis/TargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -615,9 +615,14 @@ bool TargetTransformInfo::isTargetIntrinsicWithScalarOpAtArg(
return TTIImpl->isTargetIntrinsicWithScalarOpAtArg(ID, ScalarOpdIdx);
}

bool TargetTransformInfo::isVectorIntrinsicWithOverloadTypeAtArg(
Intrinsic::ID ID, int ScalarOpdIdx) const {
return TTIImpl->isVectorIntrinsicWithOverloadTypeAtArg(ID, ScalarOpdIdx);
bool TargetTransformInfo::isTargetIntrinsicWithOverloadTypeAtArg(
Intrinsic::ID ID, int OpdIdx) const {
return TTIImpl->isTargetIntrinsicWithOverloadTypeAtArg(ID, OpdIdx);
}

// Forwards the struct-return overload query for intrinsic ID and struct field
// index RetIdx to TTIImpl and returns its result unchanged.
bool TargetTransformInfo::isTargetIntrinsicWithStructReturnOverloadAtField(
Intrinsic::ID ID, int RetIdx) const {
return TTIImpl->isTargetIntrinsicWithStructReturnOverloadAtField(ID, RetIdx);
}

InstructionCost TargetTransformInfo::getScalarizationOverhead(
Expand Down
34 changes: 30 additions & 4 deletions llvm/lib/Analysis/VectorUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,31 @@ bool llvm::isTriviallyVectorizable(Intrinsic::ID ID) {
}
}

/// Decide whether calls to intrinsic \p ID may be scalarized trivially.
///
/// The checks run in this order:
///   1. anything accepted by isTriviallyVectorizable is accepted here;
///   2. when \p TTI is non-null and \p ID is a target intrinsic, the answer is
///      whatever TTI->isTargetIntrinsicTriviallyScalarizable reports;
///   3. otherwise only Intrinsic::frexp is accepted.
bool llvm::isTriviallyScalarizable(Intrinsic::ID ID,
                                   const TargetTransformInfo *TTI) {
  if (isTriviallyVectorizable(ID))
    return true;

  // The target hook is consulted only when a TTI was supplied; the null check
  // comes first so the pointer is never dereferenced when absent.
  const bool UseTargetHook = TTI && Intrinsic::isTargetIntrinsic(ID);
  if (UseTargetHook)
    return TTI->isTargetIntrinsicTriviallyScalarizable(ID);

  // TODO: Move frexp to isTriviallyVectorizable.
  // https://github.com/llvm/llvm-project/issues/112408
  return ID == Intrinsic::frexp;
}

/// Identifies if the vector form of the intrinsic has a scalar operand.
bool llvm::isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
unsigned ScalarOpdIdx) {
unsigned ScalarOpdIdx,
const TargetTransformInfo *TTI) {

if (TTI && Intrinsic::isTargetIntrinsic(ID))
return TTI->isTargetIntrinsicWithScalarOpAtArg(ID, ScalarOpdIdx);

switch (ID) {
case Intrinsic::abs:
case Intrinsic::vp_abs:
Expand All @@ -142,7 +164,7 @@ bool llvm::isVectorIntrinsicWithOverloadTypeAtArg(
assert(ID != Intrinsic::not_intrinsic && "Not an intrinsic!");

if (TTI && Intrinsic::isTargetIntrinsic(ID))
return TTI->isVectorIntrinsicWithOverloadTypeAtArg(ID, OpdIdx);
return TTI->isTargetIntrinsicWithOverloadTypeAtArg(ID, OpdIdx);

if (VPCastIntrinsic::isVPCast(ID))
return OpdIdx == -1 || OpdIdx == 0;
Expand All @@ -167,8 +189,12 @@ bool llvm::isVectorIntrinsicWithOverloadTypeAtArg(
}
}

bool llvm::isVectorIntrinsicWithStructReturnOverloadAtField(Intrinsic::ID ID,
int RetIdx) {
bool llvm::isVectorIntrinsicWithStructReturnOverloadAtField(
Intrinsic::ID ID, int RetIdx, const TargetTransformInfo *TTI) {

if (TTI && Intrinsic::isTargetIntrinsic(ID))
return TTI->isTargetIntrinsicWithStructReturnOverloadAtField(ID, RetIdx);

switch (ID) {
case Intrinsic::frexp:
return RetIdx == 0 || RetIdx == 1;
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/CodeGen/ReplaceWithVeclib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
auto *ArgTy = Arg.value()->getType();
bool IsOloadTy = isVectorIntrinsicWithOverloadTypeAtArg(IID, Arg.index(),
/*TTI=*/nullptr);
if (isVectorIntrinsicWithScalarOpAtArg(IID, Arg.index())) {
if (isVectorIntrinsicWithScalarOpAtArg(IID, Arg.index(), /*TTI=*/nullptr)) {
ScalarArgTypes.push_back(ArgTy);
if (IsOloadTy)
OloadTys.push_back(ArgTy);
Expand Down
8 changes: 4 additions & 4 deletions llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@ bool DirectXTTIImpl::isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
}
}

bool DirectXTTIImpl::isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
int ScalarOpdIdx) {
bool DirectXTTIImpl::isTargetIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
int OpdIdx) {
switch (ID) {
case Intrinsic::dx_asdouble:
return ScalarOpdIdx == 0;
return OpdIdx == 0;
default:
return ScalarOpdIdx == -1;
return OpdIdx == -1;
}
}

Expand Down
3 changes: 1 addition & 2 deletions llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ class DirectXTTIImpl : public BasicTTIImplBase<DirectXTTIImpl> {
bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) const;
bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
unsigned ScalarOpdIdx);
bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
int ScalarOpdIdx);
bool isTargetIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID, int OpdIdx);
};
} // namespace llvm

Expand Down
24 changes: 4 additions & 20 deletions llvm/lib/Transforms/Scalar/Scalarizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,8 +279,6 @@ class ScalarizerVisitor : public InstVisitor<ScalarizerVisitor, bool> {

bool visit(Function &F);

bool isTriviallyScalarizable(Intrinsic::ID ID);

// InstVisitor methods. They return true if the instruction was scalarized,
// false if nothing changed.
bool visitInstruction(Instruction &I) { return false; }
Expand Down Expand Up @@ -683,19 +681,6 @@ bool ScalarizerVisitor::splitBinary(Instruction &I, const Splitter &Split) {
return true;
}

bool ScalarizerVisitor::isTriviallyScalarizable(Intrinsic::ID ID) {
if (isTriviallyVectorizable(ID))
return true;
// TODO: Move frexp to isTriviallyVectorizable.
// https://github.com/llvm/llvm-project/issues/112408
switch (ID) {
case Intrinsic::frexp:
return true;
}
return Intrinsic::isTargetIntrinsic(ID) &&
TTI->isTargetIntrinsicTriviallyScalarizable(ID);
}

/// If a call to a vector typed intrinsic function, split into a scalar call per
/// element if possible for the intrinsic.
bool ScalarizerVisitor::splitCall(CallInst &CI) {
Expand All @@ -715,7 +700,7 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {

Intrinsic::ID ID = F->getIntrinsicID();

if (ID == Intrinsic::not_intrinsic || !isTriviallyScalarizable(ID))
if (ID == Intrinsic::not_intrinsic || !isTriviallyScalarizable(ID, TTI))
return false;

// unsigned NumElems = VT->getNumElements();
Expand Down Expand Up @@ -743,7 +728,7 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
// will only scalarize when the struct elements have the same bitness.
if (!CurrVS || CurrVS->NumPacked != VS->NumPacked)
return false;
if (isVectorIntrinsicWithStructReturnOverloadAtField(ID, I))
if (isVectorIntrinsicWithStructReturnOverloadAtField(ID, I, TTI))
Tys.push_back(CurrVS->SplitTy);
}
}
Expand Down Expand Up @@ -794,8 +779,7 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
Tys[0] = VS->RemainderTy;

for (unsigned J = 0; J != NumArgs; ++J) {
if (isVectorIntrinsicWithScalarOpAtArg(ID, J) ||
TTI->isTargetIntrinsicWithScalarOpAtArg(ID, J)) {
if (isVectorIntrinsicWithScalarOpAtArg(ID, J, TTI)) {
ScalarCallOps.push_back(ScalarOperands[J]);
} else {
ScalarCallOps.push_back(Scattered[J][I]);
Expand Down Expand Up @@ -1089,7 +1073,7 @@ bool ScalarizerVisitor::visitExtractValueInst(ExtractValueInst &EVI) {
if (!F)
return false;
Intrinsic::ID ID = F->getIntrinsicID();
if (ID == Intrinsic::not_intrinsic || !isTriviallyScalarizable(ID))
if (ID == Intrinsic::not_intrinsic || !isTriviallyScalarizable(ID, TTI))
return false;
// Note: Fall through means Operand is a `CallInst` and it is defined in
// `isTriviallyScalarizable`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -927,7 +927,7 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {
auto *SE = PSE.getSE();
Intrinsic::ID IntrinID = getVectorIntrinsicIDForCall(CI, TLI);
for (unsigned Idx = 0; Idx < CI->arg_size(); ++Idx)
if (isVectorIntrinsicWithScalarOpAtArg(IntrinID, Idx)) {
if (isVectorIntrinsicWithScalarOpAtArg(IntrinID, Idx, TTI)) {
if (!SE->isLoopInvariant(PSE.getSCEV(CI->getOperand(Idx)),
TheLoop)) {
reportVectorizationFailure("Found unvectorizable intrinsic",
Expand Down
Loading
Loading