Skip to content

Commit 9bedb14

Browse files
committed
DAG: Move scalarizeExtractedVectorLoad to TargetLowering
SimplifyDemandedVectorElts should be able to use this on loads
1 parent acbd822 commit 9bedb14

File tree

3 files changed

+93
-2
lines changed

3 files changed

+93
-2
lines changed

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5618,6 +5618,18 @@ class TargetLowering : public TargetLoweringBase {
56185618
// joining their results. SDValue() is returned when expansion did not happen.
56195619
SDValue expandVectorNaryOpBySplitting(SDNode *Node, SelectionDAG &DAG) const;
56205620

5621+
/// Replace an extraction of a load with a narrowed load.
5622+
///
5623+
/// \param ResultVT type of the result extraction.
5624+
/// \param InVecVT type of the input vector to with bitcasts resolved.
5625+
/// \param EltNo index of the vector element to load.
5626+
/// \param OriginalLoad vector load that to be replaced.
5627+
/// \returns \p ResultVT Load on success SDValue() on failure.
5628+
SDValue scalarizeExtractedVectorLoad(EVT ResultVT, const SDLoc &DL,
5629+
EVT InVecVT, SDValue EltNo,
5630+
LoadSDNode *OriginalLoad,
5631+
SelectionDAG &DAG) const;
5632+
56215633
private:
56225634
SDValue foldSetCCWithAnd(EVT VT, SDValue N0, SDValue N1, ISD::CondCode Cond,
56235635
const SDLoc &DL, DAGCombinerInfo &DCI) const;

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23246,8 +23246,13 @@ SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) {
2324623246
ISD::isNormalLoad(VecOp.getNode()) &&
2324723247
!Index->hasPredecessor(VecOp.getNode())) {
2324823248
auto *VecLoad = dyn_cast<LoadSDNode>(VecOp);
23249-
if (VecLoad && VecLoad->isSimple())
23250-
return scalarizeExtractedVectorLoad(N, VecVT, Index, VecLoad);
23249+
if (VecLoad && VecLoad->isSimple()) {
23250+
if (SDValue Scalarized = TLI.scalarizeExtractedVectorLoad(
23251+
ExtVT, SDLoc(N), VecVT, Index, VecLoad, DAG)) {
23252+
++OpsNarrowed;
23253+
return Scalarized;
23254+
}
23255+
}
2325123256
}
2325223257

2325323258
// Perform only after legalization to ensure build_vector / vector_shuffle

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12069,3 +12069,77 @@ SDValue TargetLowering::expandVectorNaryOpBySplitting(SDNode *Node,
1206912069
SDValue SplitOpHi = DAG.getNode(Opcode, DL, HiVT, HiOps);
1207012070
return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, SplitOpLo, SplitOpHi);
1207112071
}
12072+
12073+
SDValue TargetLowering::scalarizeExtractedVectorLoad(EVT ResultVT,
12074+
const SDLoc &DL,
12075+
EVT InVecVT, SDValue EltNo,
12076+
LoadSDNode *OriginalLoad,
12077+
SelectionDAG &DAG) const {
12078+
assert(OriginalLoad->isSimple());
12079+
12080+
EVT VecEltVT = InVecVT.getVectorElementType();
12081+
12082+
// If the vector element type is not a multiple of a byte then we are unable
12083+
// to correctly compute an address to load only the extracted element as a
12084+
// scalar.
12085+
if (!VecEltVT.isByteSized())
12086+
return SDValue();
12087+
12088+
ISD::LoadExtType ExtTy =
12089+
ResultVT.bitsGT(VecEltVT) ? ISD::EXTLOAD : ISD::NON_EXTLOAD;
12090+
if (!isOperationLegalOrCustom(ISD::LOAD, VecEltVT) ||
12091+
!shouldReduceLoadWidth(OriginalLoad, ExtTy, VecEltVT))
12092+
return SDValue();
12093+
12094+
Align Alignment = OriginalLoad->getAlign();
12095+
MachinePointerInfo MPI;
12096+
if (auto *ConstEltNo = dyn_cast<ConstantSDNode>(EltNo)) {
12097+
int Elt = ConstEltNo->getZExtValue();
12098+
unsigned PtrOff = VecEltVT.getSizeInBits() * Elt / 8;
12099+
MPI = OriginalLoad->getPointerInfo().getWithOffset(PtrOff);
12100+
Alignment = commonAlignment(Alignment, PtrOff);
12101+
} else {
12102+
// Discard the pointer info except the address space because the memory
12103+
// operand can't represent this new access since the offset is variable.
12104+
MPI = MachinePointerInfo(OriginalLoad->getPointerInfo().getAddrSpace());
12105+
Alignment = commonAlignment(Alignment, VecEltVT.getSizeInBits() / 8);
12106+
}
12107+
12108+
unsigned IsFast = 0;
12109+
if (!allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VecEltVT,
12110+
OriginalLoad->getAddressSpace(), Alignment,
12111+
OriginalLoad->getMemOperand()->getFlags(), &IsFast) ||
12112+
!IsFast)
12113+
return SDValue();
12114+
12115+
SDValue NewPtr =
12116+
getVectorElementPointer(DAG, OriginalLoad->getBasePtr(), InVecVT, EltNo);
12117+
12118+
// We are replacing a vector load with a scalar load. The new load must have
12119+
// identical memory op ordering to the original.
12120+
SDValue Load;
12121+
if (ResultVT.bitsGT(VecEltVT)) {
12122+
// If the result type of vextract is wider than the load, then issue an
12123+
// extending load instead.
12124+
ISD::LoadExtType ExtType = isLoadExtLegal(ISD::ZEXTLOAD, ResultVT, VecEltVT)
12125+
? ISD::ZEXTLOAD
12126+
: ISD::EXTLOAD;
12127+
Load = DAG.getExtLoad(ExtType, DL, ResultVT, OriginalLoad->getChain(),
12128+
NewPtr, MPI, VecEltVT, Alignment,
12129+
OriginalLoad->getMemOperand()->getFlags(),
12130+
OriginalLoad->getAAInfo());
12131+
DAG.makeEquivalentMemoryOrdering(OriginalLoad, Load);
12132+
} else {
12133+
// The result type is narrower or the same width as the vector element
12134+
Load = DAG.getLoad(VecEltVT, DL, OriginalLoad->getChain(), NewPtr, MPI,
12135+
Alignment, OriginalLoad->getMemOperand()->getFlags(),
12136+
OriginalLoad->getAAInfo());
12137+
DAG.makeEquivalentMemoryOrdering(OriginalLoad, Load);
12138+
if (ResultVT.bitsLT(VecEltVT))
12139+
Load = DAG.getNode(ISD::TRUNCATE, DL, ResultVT, Load);
12140+
else
12141+
Load = DAG.getBitcast(ResultVT, Load);
12142+
}
12143+
12144+
return Load;
12145+
}

0 commit comments

Comments
 (0)