Skip to content

Commit d1500d1

Browse files
authored
[SelectionDAG] Add SelectionDAG::getTypeSize. NFC (#169764)
Similar to how getElementCount avoids the need to reason about fixed and scalable ElementCounts separately, this patch adds getTypeSize to do the same for TypeSize. It also goes through and replaces some of the manual uses of getVScale with getTypeSize/getElementCount where possible.
1 parent b162099 commit d1500d1

File tree

7 files changed

+54
-95
lines changed

7 files changed

+54
-95
lines changed

llvm/include/llvm/CodeGen/SelectionDAG.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1185,11 +1185,11 @@ class SelectionDAG {
11851185
SDValue getPOISON(EVT VT) { return getNode(ISD::POISON, SDLoc(), VT); }
11861186

11871187
/// Return a node that represents the runtime scaling 'MulImm * RuntimeVL'.
1188-
LLVM_ABI SDValue getVScale(const SDLoc &DL, EVT VT, APInt MulImm,
1189-
bool ConstantFold = true);
1188+
LLVM_ABI SDValue getVScale(const SDLoc &DL, EVT VT, APInt MulImm);
11901189

1191-
LLVM_ABI SDValue getElementCount(const SDLoc &DL, EVT VT, ElementCount EC,
1192-
bool ConstantFold = true);
1190+
LLVM_ABI SDValue getElementCount(const SDLoc &DL, EVT VT, ElementCount EC);
1191+
1192+
LLVM_ABI SDValue getTypeSize(const SDLoc &DL, EVT VT, TypeSize TS);
11931193

11941194
/// Return a GLOBAL_OFFSET_TABLE node. This does not have a useful SDLoc.
11951195
SDValue getGLOBAL_OFFSET_TABLE(EVT VT) {

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1702,10 +1702,8 @@ void DAGTypeLegalizer::SplitVecRes_LOOP_DEPENDENCE_MASK(SDNode *N, SDValue &Lo,
17021702
Lo = DAG.getNode(N->getOpcode(), DL, LoVT, PtrA, PtrB, N->getOperand(2));
17031703

17041704
unsigned EltSize = N->getConstantOperandVal(2);
1705-
unsigned Offset = EltSize * HiVT.getVectorMinNumElements();
1706-
SDValue Addend = HiVT.isScalableVT()
1707-
? DAG.getVScale(DL, MVT::i64, APInt(64, Offset))
1708-
: DAG.getConstant(Offset, DL, MVT::i64);
1705+
ElementCount Offset = HiVT.getVectorElementCount() * EltSize;
1706+
SDValue Addend = DAG.getElementCount(DL, MVT::i64, Offset);
17091707

17101708
PtrA = DAG.getNode(ISD::ADD, DL, MVT::i64, PtrA, Addend);
17111709
Hi = DAG.getNode(N->getOpcode(), DL, HiVT, PtrA, PtrB, N->getOperand(2));

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 29 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2098,32 +2098,43 @@ SDValue SelectionDAG::getCondCode(ISD::CondCode Cond) {
20982098
return SDValue(CondCodeNodes[Cond], 0);
20992099
}
21002100

2101-
SDValue SelectionDAG::getVScale(const SDLoc &DL, EVT VT, APInt MulImm,
2102-
bool ConstantFold) {
2101+
SDValue SelectionDAG::getVScale(const SDLoc &DL, EVT VT, APInt MulImm) {
21032102
assert(MulImm.getBitWidth() == VT.getSizeInBits() &&
21042103
"APInt size does not match type size!");
21052104

21062105
if (MulImm == 0)
21072106
return getConstant(0, DL, VT);
21082107

2109-
if (ConstantFold) {
2110-
const MachineFunction &MF = getMachineFunction();
2111-
const Function &F = MF.getFunction();
2112-
ConstantRange CR = getVScaleRange(&F, 64);
2113-
if (const APInt *C = CR.getSingleElement())
2114-
return getConstant(MulImm * C->getZExtValue(), DL, VT);
2115-
}
2108+
const MachineFunction &MF = getMachineFunction();
2109+
const Function &F = MF.getFunction();
2110+
ConstantRange CR = getVScaleRange(&F, 64);
2111+
if (const APInt *C = CR.getSingleElement())
2112+
return getConstant(MulImm * C->getZExtValue(), DL, VT);
21162113

21172114
return getNode(ISD::VSCALE, DL, VT, getConstant(MulImm, DL, VT));
21182115
}
21192116

2120-
SDValue SelectionDAG::getElementCount(const SDLoc &DL, EVT VT, ElementCount EC,
2121-
bool ConstantFold) {
2122-
if (EC.isScalable())
2123-
return getVScale(DL, VT,
2124-
APInt(VT.getSizeInBits(), EC.getKnownMinValue()));
2117+
/// \returns a value of type \p VT that represents the runtime value of \p
2118+
/// Quantity, i.e. scaled by vscale if it's scalable, or a fixed constant
2119+
/// otherwise. Quantity should be a FixedOrScalableQuantity, i.e. ElementCount
2120+
/// or TypeSize.
2121+
template <typename Ty>
2122+
static SDValue getFixedOrScalableQuantity(SelectionDAG &DAG, const SDLoc &DL,
2123+
EVT VT, Ty Quantity) {
2124+
if (Quantity.isScalable())
2125+
return DAG.getVScale(
2126+
DL, VT, APInt(VT.getSizeInBits(), Quantity.getKnownMinValue()));
2127+
2128+
return DAG.getConstant(Quantity.getKnownMinValue(), DL, VT);
2129+
}
2130+
2131+
SDValue SelectionDAG::getElementCount(const SDLoc &DL, EVT VT,
2132+
ElementCount EC) {
2133+
return getFixedOrScalableQuantity(*this, DL, VT, EC);
2134+
}
21252135

2126-
return getConstant(EC.getKnownMinValue(), DL, VT);
2136+
SDValue SelectionDAG::getTypeSize(const SDLoc &DL, EVT VT, TypeSize TS) {
2137+
return getFixedOrScalableQuantity(*this, DL, VT, TS);
21272138
}
21282139

21292140
SDValue SelectionDAG::getStepVector(const SDLoc &DL, EVT ResVT) {
@@ -8500,16 +8511,7 @@ static SDValue getMemsetStringVal(EVT VT, const SDLoc &dl, SelectionDAG &DAG,
85008511
SDValue SelectionDAG::getMemBasePlusOffset(SDValue Base, TypeSize Offset,
85018512
const SDLoc &DL,
85028513
const SDNodeFlags Flags) {
8503-
EVT VT = Base.getValueType();
8504-
SDValue Index;
8505-
8506-
if (Offset.isScalable())
8507-
Index = getVScale(DL, Base.getValueType(),
8508-
APInt(Base.getValueSizeInBits().getFixedValue(),
8509-
Offset.getKnownMinValue()));
8510-
else
8511-
Index = getConstant(Offset.getFixedValue(), DL, VT);
8512-
8514+
SDValue Index = getTypeSize(DL, Base.getValueType(), Offset);
85138515
return getMemBasePlusOffset(Base, Index, DL, Flags);
85148516
}
85158517

@@ -13585,11 +13587,8 @@ std::pair<SDValue, SDValue> SelectionDAG::SplitEVL(SDValue N, EVT VecVT,
1358513587
EVT VT = N.getValueType();
1358613588
assert(VecVT.getVectorElementCount().isKnownEven() &&
1358713589
"Expecting the mask to be an evenly-sized vector");
13588-
unsigned HalfMinNumElts = VecVT.getVectorMinNumElements() / 2;
13589-
SDValue HalfNumElts =
13590-
VecVT.isFixedLengthVector()
13591-
? getConstant(HalfMinNumElts, DL, VT)
13592-
: getVScale(DL, VT, APInt(VT.getScalarSizeInBits(), HalfMinNumElts));
13590+
SDValue HalfNumElts = getElementCount(
13591+
DL, VT, VecVT.getVectorElementCount().divideCoefficientBy(2));
1359313592
SDValue Lo = getNode(ISD::UMIN, DL, VT, N, HalfNumElts);
1359413593
SDValue Hi = getNode(ISD::USUBSAT, DL, VT, N, HalfNumElts);
1359513594
return std::make_pair(Lo, Hi);

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4584,17 +4584,9 @@ void SelectionDAGBuilder::visitAlloca(const AllocaInst &I) {
45844584
if (AllocSize.getValueType() != IntPtr)
45854585
AllocSize = DAG.getZExtOrTrunc(AllocSize, dl, IntPtr);
45864586

4587-
if (TySize.isScalable())
4588-
AllocSize = DAG.getNode(ISD::MUL, dl, IntPtr, AllocSize,
4589-
DAG.getVScale(dl, IntPtr,
4590-
APInt(IntPtr.getScalarSizeInBits(),
4591-
TySize.getKnownMinValue())));
4592-
else {
4593-
SDValue TySizeValue =
4594-
DAG.getConstant(TySize.getFixedValue(), dl, MVT::getIntegerVT(64));
4595-
AllocSize = DAG.getNode(ISD::MUL, dl, IntPtr, AllocSize,
4596-
DAG.getZExtOrTrunc(TySizeValue, dl, IntPtr));
4597-
}
4587+
AllocSize = DAG.getNode(
4588+
ISD::MUL, dl, IntPtr, AllocSize,
4589+
DAG.getZExtOrTrunc(DAG.getTypeSize(dl, MVT::i64, TySize), dl, IntPtr));
45984590

45994591
// Handle alignment. If the requested alignment is less than or equal to
46004592
// the stack alignment, ignore it. If the size is greater than or equal to

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10628,12 +10628,8 @@ TargetLowering::IncrementMemoryAddress(SDValue Addr, SDValue Mask,
1062810628
AddrVT);
1062910629
Increment = DAG.getZExtOrTrunc(Increment, DL, AddrVT);
1063010630
Increment = DAG.getNode(ISD::MUL, DL, AddrVT, Increment, Scale);
10631-
} else if (DataVT.isScalableVector()) {
10632-
Increment = DAG.getVScale(DL, AddrVT,
10633-
APInt(AddrVT.getFixedSizeInBits(),
10634-
DataVT.getStoreSize().getKnownMinValue()));
1063510631
} else
10636-
Increment = DAG.getConstant(DataVT.getStoreSize(), DL, AddrVT);
10632+
Increment = DAG.getTypeSize(DL, AddrVT, DataVT.getStoreSize());
1063710633

1063810634
return DAG.getNode(ISD::ADD, DL, AddrVT, Addr, Increment);
1063910635
}
@@ -11926,10 +11922,8 @@ SDValue TargetLowering::expandVectorSplice(SDNode *Node,
1192611922
// Store the lo part of CONCAT_VECTORS(V1, V2)
1192711923
SDValue StoreV1 = DAG.getStore(DAG.getEntryNode(), DL, V1, StackPtr, PtrInfo);
1192811924
// Store the hi part of CONCAT_VECTORS(V1, V2)
11929-
SDValue OffsetToV2 = DAG.getVScale(
11930-
DL, PtrVT,
11931-
APInt(PtrVT.getFixedSizeInBits(), VT.getStoreSize().getKnownMinValue()));
11932-
SDValue StackPtr2 = DAG.getNode(ISD::ADD, DL, PtrVT, StackPtr, OffsetToV2);
11925+
SDValue VTBytes = DAG.getTypeSize(DL, PtrVT, VT.getStoreSize());
11926+
SDValue StackPtr2 = DAG.getNode(ISD::ADD, DL, PtrVT, StackPtr, VTBytes);
1193311927
SDValue StoreV2 = DAG.getStore(StoreV1, DL, V2, StackPtr2, PtrInfo);
1193411928

1193511929
if (Imm >= 0) {
@@ -11948,13 +11942,8 @@ SDValue TargetLowering::expandVectorSplice(SDNode *Node,
1194811942
SDValue TrailingBytes =
1194911943
DAG.getConstant(TrailingElts * EltByteSize, DL, PtrVT);
1195011944

11951-
if (TrailingElts > VT.getVectorMinNumElements()) {
11952-
SDValue VLBytes =
11953-
DAG.getVScale(DL, PtrVT,
11954-
APInt(PtrVT.getFixedSizeInBits(),
11955-
VT.getStoreSize().getKnownMinValue()));
11956-
TrailingBytes = DAG.getNode(ISD::UMIN, DL, PtrVT, TrailingBytes, VLBytes);
11957-
}
11945+
if (TrailingElts > VT.getVectorMinNumElements())
11946+
TrailingBytes = DAG.getNode(ISD::UMIN, DL, PtrVT, TrailingBytes, VTBytes);
1195811947

1195911948
// Calculate the start address of the spliced result.
1196011949
StackPtr2 = DAG.getNode(ISD::SUB, DL, PtrVT, StackPtr2, TrailingBytes);

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 9 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8647,7 +8647,7 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
86478647
Subtarget->isWindowsArm64EC()) &&
86488648
"Indirect arguments should be scalable on most subtargets");
86498649

8650-
uint64_t PartSize = VA.getValVT().getStoreSize().getKnownMinValue();
8650+
TypeSize PartSize = VA.getValVT().getStoreSize();
86518651
unsigned NumParts = 1;
86528652
if (Ins[i].Flags.isInConsecutiveRegs()) {
86538653
while (!Ins[i + NumParts - 1].Flags.isInConsecutiveRegsLast())
@@ -8664,16 +8664,8 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
86648664
InVals.push_back(ArgValue);
86658665
NumParts--;
86668666
if (NumParts > 0) {
8667-
SDValue BytesIncrement;
8668-
if (PartLoad.isScalableVector()) {
8669-
BytesIncrement = DAG.getVScale(
8670-
DL, Ptr.getValueType(),
8671-
APInt(Ptr.getValueSizeInBits().getFixedValue(), PartSize));
8672-
} else {
8673-
BytesIncrement = DAG.getConstant(
8674-
APInt(Ptr.getValueSizeInBits().getFixedValue(), PartSize), DL,
8675-
Ptr.getValueType());
8676-
}
8667+
SDValue BytesIncrement =
8668+
DAG.getTypeSize(DL, Ptr.getValueType(), PartSize);
86778669
Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr,
86788670
BytesIncrement, SDNodeFlags::NoUnsignedWrap);
86798671
ExtraArgLocs++;
@@ -9880,8 +9872,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
98809872
assert((isScalable || Subtarget->isWindowsArm64EC()) &&
98819873
"Indirect arguments should be scalable on most subtargets");
98829874

9883-
uint64_t StoreSize = VA.getValVT().getStoreSize().getKnownMinValue();
9884-
uint64_t PartSize = StoreSize;
9875+
TypeSize StoreSize = VA.getValVT().getStoreSize();
9876+
TypeSize PartSize = StoreSize;
98859877
unsigned NumParts = 1;
98869878
if (Outs[i].Flags.isInConsecutiveRegs()) {
98879879
while (!Outs[i + NumParts - 1].Flags.isInConsecutiveRegsLast())
@@ -9892,7 +9884,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
98929884
Type *Ty = EVT(VA.getValVT()).getTypeForEVT(*DAG.getContext());
98939885
Align Alignment = DAG.getDataLayout().getPrefTypeAlign(Ty);
98949886
MachineFrameInfo &MFI = MF.getFrameInfo();
9895-
int FI = MFI.CreateStackObject(StoreSize, Alignment, false);
9887+
int FI =
9888+
MFI.CreateStackObject(StoreSize.getKnownMinValue(), Alignment, false);
98969889
if (isScalable) {
98979890
bool IsPred = VA.getValVT() == MVT::aarch64svcount ||
98989891
VA.getValVT().getVectorElementType() == MVT::i1;
@@ -9913,16 +9906,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
99139906

99149907
NumParts--;
99159908
if (NumParts > 0) {
9916-
SDValue BytesIncrement;
9917-
if (isScalable) {
9918-
BytesIncrement = DAG.getVScale(
9919-
DL, Ptr.getValueType(),
9920-
APInt(Ptr.getValueSizeInBits().getFixedValue(), PartSize));
9921-
} else {
9922-
BytesIncrement = DAG.getConstant(
9923-
APInt(Ptr.getValueSizeInBits().getFixedValue(), PartSize), DL,
9924-
Ptr.getValueType());
9925-
}
9909+
SDValue BytesIncrement =
9910+
DAG.getTypeSize(DL, Ptr.getValueType(), PartSize);
99269911
MPI = MachinePointerInfo(MPI.getAddrSpace());
99279912
Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr,
99289913
BytesIncrement, SDNodeFlags::NoUnsignedWrap);

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12783,10 +12783,7 @@ SDValue RISCVTargetLowering::lowerVECTOR_INTERLEAVE(SDValue Op,
1278312783

1278412784
SmallVector<SDValue, 8> Loads(Factor);
1278512785

12786-
SDValue Increment =
12787-
DAG.getVScale(DL, PtrVT,
12788-
APInt(PtrVT.getFixedSizeInBits(),
12789-
VecVT.getStoreSize().getKnownMinValue()));
12786+
SDValue Increment = DAG.getTypeSize(DL, PtrVT, VecVT.getStoreSize());
1279012787
for (unsigned i = 0; i != Factor; ++i) {
1279112788
if (i != 0)
1279212789
StackPtr = DAG.getNode(ISD::ADD, DL, PtrVT, StackPtr, Increment);
@@ -14184,9 +14181,8 @@ RISCVTargetLowering::lowerVPReverseExperimental(SDValue Op,
1418414181

1418514182
// Slide off any elements from past EVL that were reversed into the low
1418614183
// elements.
14187-
unsigned MinElts = GatherVT.getVectorMinNumElements();
1418814184
SDValue VLMax =
14189-
DAG.getVScale(DL, XLenVT, APInt(XLenVT.getSizeInBits(), MinElts));
14185+
DAG.getElementCount(DL, XLenVT, GatherVT.getVectorElementCount());
1419014186
SDValue Diff = DAG.getNode(ISD::SUB, DL, XLenVT, VLMax, EVL);
1419114187

1419214188
Result = getVSlidedown(DAG, Subtarget, DL, GatherVT,

0 commit comments

Comments
 (0)