Skip to content

Commit b95d4f3

Browse files
committed
[NFC][Scalarizer][TargetTransformInfo] Add isVectorIntrinsicWithOverloadTypeAtArg
This changes allows target intrinsic to specify overloaded types. This change will let us add scalarization for `asdouble`:
1 parent eaa7b38 commit b95d4f3

File tree

7 files changed

+49
-3
lines changed

7 files changed

+49
-3
lines changed

llvm/include/llvm/Analysis/TargetTransformInfo.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -896,6 +896,10 @@ class TargetTransformInfo {
896896
bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
897897
unsigned ScalarOpdIdx) const;
898898

899+
bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
900+
unsigned ScalarOpdIdx,
901+
bool Default) const;
902+
899903
/// Estimate the overhead of scalarizing an instruction. Insert and Extract
900904
/// are set if the demanded result elements need to be inserted and/or
901905
/// extracted from vectors.
@@ -1969,6 +1973,9 @@ class TargetTransformInfo::Concept {
19691973
virtual bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) = 0;
19701974
virtual bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
19711975
unsigned ScalarOpdIdx) = 0;
1976+
virtual bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
1977+
unsigned ScalarOpdIdx,
1978+
bool Default) = 0;
19721979
virtual InstructionCost getScalarizationOverhead(VectorType *Ty,
19731980
const APInt &DemandedElts,
19741981
bool Insert, bool Extract,
@@ -2530,6 +2537,13 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
25302537
return Impl.isTargetIntrinsicWithScalarOpAtArg(ID, ScalarOpdIdx);
25312538
}
25322539

2540+
bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
2541+
unsigned ScalarOpdIdx,
2542+
bool Default) override {
2543+
return Impl.isVectorIntrinsicWithOverloadTypeAtArg(ID, ScalarOpdIdx,
2544+
Default);
2545+
}
2546+
25332547
InstructionCost getScalarizationOverhead(VectorType *Ty,
25342548
const APInt &DemandedElts,
25352549
bool Insert, bool Extract,

llvm/include/llvm/Analysis/TargetTransformInfoImpl.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,12 @@ class TargetTransformInfoImplBase {
392392
return false;
393393
}
394394

395+
bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
396+
unsigned ScalarOpdIdx,
397+
bool Default) const {
398+
return Default;
399+
}
400+
395401
InstructionCost getScalarizationOverhead(VectorType *Ty,
396402
const APInt &DemandedElts,
397403
bool Insert, bool Extract,

llvm/include/llvm/CodeGen/BasicTTIImpl.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -798,6 +798,12 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
798798
return false;
799799
}
800800

801+
bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
802+
unsigned ScalarOpdIdx,
803+
bool Default) const {
804+
return Default;
805+
}
806+
801807
/// Helper wrapper for the DemandedElts variant of getScalarizationOverhead.
802808
InstructionCost getScalarizationOverhead(VectorType *InTy, bool Insert,
803809
bool Extract,

llvm/lib/Analysis/TargetTransformInfo.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -612,6 +612,12 @@ bool TargetTransformInfo::isTargetIntrinsicWithScalarOpAtArg(
612612
return TTIImpl->isTargetIntrinsicWithScalarOpAtArg(ID, ScalarOpdIdx);
613613
}
614614

615+
bool TargetTransformInfo::isVectorIntrinsicWithOverloadTypeAtArg(
616+
Intrinsic::ID ID, unsigned ScalarOpdIdx, bool Default) const {
617+
return TTIImpl->isVectorIntrinsicWithOverloadTypeAtArg(ID, ScalarOpdIdx,
618+
Default);
619+
}
620+
615621
InstructionCost TargetTransformInfo::getScalarizationOverhead(
616622
VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract,
617623
TTI::TargetCostKind CostKind) const {

llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,14 @@ bool DirectXTTIImpl::isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
2525
}
2626
}
2727

28+
bool DirectXTTIImpl::isVectorIntrinsicWithOverloadTypeAtArg(
29+
Intrinsic::ID ID, unsigned ScalarOpdIdx, bool Default) {
30+
switch (ID) {
31+
default:
32+
return Default;
33+
}
34+
}
35+
2836
bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable(
2937
Intrinsic::ID ID) const {
3038
switch (ID) {

llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ class DirectXTTIImpl : public BasicTTIImplBase<DirectXTTIImpl> {
3737
bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) const;
3838
bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
3939
unsigned ScalarOpdIdx);
40+
bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
41+
unsigned ScalarOpdIdx,
42+
bool Default);
4043
};
4144
} // namespace llvm
4245

llvm/lib/Transforms/Scalar/Scalarizer.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -727,7 +727,8 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
727727

728728
SmallVector<llvm::Type *, 3> Tys;
729729
// Add return type if intrinsic is overloaded on it.
730-
if (isVectorIntrinsicWithOverloadTypeAtArg(ID, -1))
730+
if (TTI->isVectorIntrinsicWithOverloadTypeAtArg(
731+
ID, -1, isVectorIntrinsicWithOverloadTypeAtArg(ID, -1)))
731732
Tys.push_back(VS->SplitTy);
732733

733734
if (AreAllVectorsOfMatchingSize) {
@@ -767,13 +768,15 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
767768
}
768769

769770
Scattered[I] = scatter(&CI, OpI, *OpVS);
770-
if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I)) {
771+
if (TTI->isVectorIntrinsicWithOverloadTypeAtArg(
772+
ID, I, isVectorIntrinsicWithOverloadTypeAtArg(ID, I))) {
771773
OverloadIdx[I] = Tys.size();
772774
Tys.push_back(OpVS->SplitTy);
773775
}
774776
} else {
775777
ScalarOperands[I] = OpI;
776-
if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I))
778+
if (TTI->isVectorIntrinsicWithOverloadTypeAtArg(
779+
ID, I, isVectorIntrinsicWithOverloadTypeAtArg(ID, I)))
777780
Tys.push_back(OpI->getType());
778781
}
779782
}

0 commit comments

Comments
 (0)