Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions llvm/include/llvm/CodeGen/SelectionDAG.h
Original file line number Diff line number Diff line change
Expand Up @@ -1184,11 +1184,11 @@ class SelectionDAG {
SDValue getPOISON(EVT VT) { return getNode(ISD::POISON, SDLoc(), VT); }

/// Return a node that represents the runtime scaling 'MulImm * RuntimeVL'.
LLVM_ABI SDValue getVScale(const SDLoc &DL, EVT VT, APInt MulImm,
bool ConstantFold = true);
LLVM_ABI SDValue getVScale(const SDLoc &DL, EVT VT, APInt MulImm);

LLVM_ABI SDValue getElementCount(const SDLoc &DL, EVT VT, ElementCount EC,
bool ConstantFold = true);
LLVM_ABI SDValue getElementCount(const SDLoc &DL, EVT VT, ElementCount EC);

LLVM_ABI SDValue getTypeSize(const SDLoc &DL, EVT VT, TypeSize TS);

/// Return a GLOBAL_OFFSET_TABLE node. This does not have a useful SDLoc.
SDValue getGLOBAL_OFFSET_TABLE(EVT VT) {
Expand Down
6 changes: 2 additions & 4 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1702,10 +1702,8 @@ void DAGTypeLegalizer::SplitVecRes_LOOP_DEPENDENCE_MASK(SDNode *N, SDValue &Lo,
Lo = DAG.getNode(N->getOpcode(), DL, LoVT, PtrA, PtrB, N->getOperand(2));

unsigned EltSize = N->getConstantOperandVal(2);
unsigned Offset = EltSize * HiVT.getVectorMinNumElements();
SDValue Addend = HiVT.isScalableVT()
? DAG.getVScale(DL, MVT::i64, APInt(64, Offset))
: DAG.getConstant(Offset, DL, MVT::i64);
ElementCount Offset = HiVT.getVectorElementCount() * EltSize;
SDValue Addend = DAG.getElementCount(DL, MVT::i64, Offset);

PtrA = DAG.getNode(ISD::ADD, DL, MVT::i64, PtrA, Addend);
Hi = DAG.getNode(N->getOpcode(), DL, HiVT, PtrA, PtrB, N->getOperand(2));
Expand Down
57 changes: 27 additions & 30 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2083,32 +2083,41 @@ SDValue SelectionDAG::getCondCode(ISD::CondCode Cond) {
return SDValue(CondCodeNodes[Cond], 0);
}

SDValue SelectionDAG::getVScale(const SDLoc &DL, EVT VT, APInt MulImm,
bool ConstantFold) {
SDValue SelectionDAG::getVScale(const SDLoc &DL, EVT VT, APInt MulImm) {
assert(MulImm.getBitWidth() == VT.getSizeInBits() &&
"APInt size does not match type size!");

if (MulImm == 0)
return getConstant(0, DL, VT);

if (ConstantFold) {
const MachineFunction &MF = getMachineFunction();
const Function &F = MF.getFunction();
ConstantRange CR = getVScaleRange(&F, 64);
if (const APInt *C = CR.getSingleElement())
return getConstant(MulImm * C->getZExtValue(), DL, VT);
}
const MachineFunction &MF = getMachineFunction();
const Function &F = MF.getFunction();
ConstantRange CR = getVScaleRange(&F, 64);
if (const APInt *C = CR.getSingleElement())
return getConstant(MulImm * C->getZExtValue(), DL, VT);

return getNode(ISD::VSCALE, DL, VT, getConstant(MulImm, DL, VT));
}

SDValue SelectionDAG::getElementCount(const SDLoc &DL, EVT VT, ElementCount EC,
bool ConstantFold) {
if (EC.isScalable())
return getVScale(DL, VT,
APInt(VT.getSizeInBits(), EC.getKnownMinValue()));
/// \returns a value of type \p VT that represents the runtime value of \p X,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

Suggested change
/// \returns a value of type \p VT that represents the runtime value of \p X,
/// \returns a value of type \p VT that represents the runtime value of \p Quantity,

maybe also say that this can only be an ElementCount or a TypeSize?

/// i.e. scaled by vscale if it's scalable, or a fixed constant otherwise.
template <typename Ty>
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add documentation for this function to describe what VT and X mean and how they are used to create the end result?

Also, why does it need to be a templated function? Is that for distinguishing EVT and MVT?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a comment in 5c13c4a.

It's templated since it needs to work over both ElementType and TypeSize, both of which inherit from details::FixedOrScalableQuantity<LeafTy, ScalarTy>, but with different LeafTy and ScalarTys.

static SDValue getFixedOrScalableQuantity(SelectionDAG &DAG, const SDLoc &DL,
EVT VT, Ty Quantity) {
if (Quantity.isScalable())
return DAG.getVScale(
DL, VT, APInt(VT.getSizeInBits(), Quantity.getKnownMinValue()));

return DAG.getConstant(Quantity.getKnownMinValue(), DL, VT);
}

SDValue SelectionDAG::getElementCount(const SDLoc &DL, EVT VT,
ElementCount EC) {
return getFixedOrScalableQuantity(*this, DL, VT, EC);
}

return getConstant(EC.getKnownMinValue(), DL, VT);
SDValue SelectionDAG::getTypeSize(const SDLoc &DL, EVT VT, TypeSize TS) {
return getFixedOrScalableQuantity(*this, DL, VT, TS);
}

SDValue SelectionDAG::getStepVector(const SDLoc &DL, EVT ResVT) {
Expand Down Expand Up @@ -8485,16 +8494,7 @@ static SDValue getMemsetStringVal(EVT VT, const SDLoc &dl, SelectionDAG &DAG,
SDValue SelectionDAG::getMemBasePlusOffset(SDValue Base, TypeSize Offset,
const SDLoc &DL,
const SDNodeFlags Flags) {
EVT VT = Base.getValueType();
SDValue Index;

if (Offset.isScalable())
Index = getVScale(DL, Base.getValueType(),
APInt(Base.getValueSizeInBits().getFixedValue(),
Offset.getKnownMinValue()));
else
Index = getConstant(Offset.getFixedValue(), DL, VT);

SDValue Index = getTypeSize(DL, Base.getValueType(), Offset);
return getMemBasePlusOffset(Base, Index, DL, Flags);
}

Expand Down Expand Up @@ -13570,11 +13570,8 @@ std::pair<SDValue, SDValue> SelectionDAG::SplitEVL(SDValue N, EVT VecVT,
EVT VT = N.getValueType();
assert(VecVT.getVectorElementCount().isKnownEven() &&
"Expecting the mask to be an evenly-sized vector");
unsigned HalfMinNumElts = VecVT.getVectorMinNumElements() / 2;
SDValue HalfNumElts =
VecVT.isFixedLengthVector()
? getConstant(HalfMinNumElts, DL, VT)
: getVScale(DL, VT, APInt(VT.getScalarSizeInBits(), HalfMinNumElts));
SDValue HalfNumElts = getElementCount(
DL, VT, VecVT.getVectorElementCount().divideCoefficientBy(2));
SDValue Lo = getNode(ISD::UMIN, DL, VT, N, HalfNumElts);
SDValue Hi = getNode(ISD::USUBSAT, DL, VT, N, HalfNumElts);
return std::make_pair(Lo, Hi);
Expand Down
14 changes: 3 additions & 11 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4583,17 +4583,9 @@ void SelectionDAGBuilder::visitAlloca(const AllocaInst &I) {
if (AllocSize.getValueType() != IntPtr)
AllocSize = DAG.getZExtOrTrunc(AllocSize, dl, IntPtr);

if (TySize.isScalable())
AllocSize = DAG.getNode(ISD::MUL, dl, IntPtr, AllocSize,
DAG.getVScale(dl, IntPtr,
APInt(IntPtr.getScalarSizeInBits(),
TySize.getKnownMinValue())));
else {
SDValue TySizeValue =
DAG.getConstant(TySize.getFixedValue(), dl, MVT::getIntegerVT(64));
AllocSize = DAG.getNode(ISD::MUL, dl, IntPtr, AllocSize,
DAG.getZExtOrTrunc(TySizeValue, dl, IntPtr));
}
AllocSize = DAG.getNode(
ISD::MUL, dl, IntPtr, AllocSize,
DAG.getZExtOrTrunc(DAG.getTypeSize(dl, MVT::i64, TySize), dl, IntPtr));

// Handle alignment. If the requested alignment is less than or equal to
// the stack alignment, ignore it. If the size is greater than or equal to
Expand Down
21 changes: 5 additions & 16 deletions llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10625,12 +10625,8 @@ TargetLowering::IncrementMemoryAddress(SDValue Addr, SDValue Mask,
SDValue Scale = DAG.getConstant(DataVT.getScalarSizeInBits() / 8, DL,
AddrVT);
Increment = DAG.getNode(ISD::MUL, DL, AddrVT, Increment, Scale);
} else if (DataVT.isScalableVector()) {
Increment = DAG.getVScale(DL, AddrVT,
APInt(AddrVT.getFixedSizeInBits(),
DataVT.getStoreSize().getKnownMinValue()));
} else
Increment = DAG.getConstant(DataVT.getStoreSize(), DL, AddrVT);
Increment = DAG.getTypeSize(DL, AddrVT, DataVT.getStoreSize());

return DAG.getNode(ISD::ADD, DL, AddrVT, Addr, Increment);
}
Expand Down Expand Up @@ -11923,10 +11919,8 @@ SDValue TargetLowering::expandVectorSplice(SDNode *Node,
// Store the lo part of CONCAT_VECTORS(V1, V2)
SDValue StoreV1 = DAG.getStore(DAG.getEntryNode(), DL, V1, StackPtr, PtrInfo);
// Store the hi part of CONCAT_VECTORS(V1, V2)
SDValue OffsetToV2 = DAG.getVScale(
DL, PtrVT,
APInt(PtrVT.getFixedSizeInBits(), VT.getStoreSize().getKnownMinValue()));
SDValue StackPtr2 = DAG.getNode(ISD::ADD, DL, PtrVT, StackPtr, OffsetToV2);
SDValue VTBytes = DAG.getTypeSize(DL, PtrVT, VT.getStoreSize());
SDValue StackPtr2 = DAG.getNode(ISD::ADD, DL, PtrVT, StackPtr, VTBytes);
SDValue StoreV2 = DAG.getStore(StoreV1, DL, V2, StackPtr2, PtrInfo);

if (Imm >= 0) {
Expand All @@ -11945,13 +11939,8 @@ SDValue TargetLowering::expandVectorSplice(SDNode *Node,
SDValue TrailingBytes =
DAG.getConstant(TrailingElts * EltByteSize, DL, PtrVT);

if (TrailingElts > VT.getVectorMinNumElements()) {
SDValue VLBytes =
DAG.getVScale(DL, PtrVT,
APInt(PtrVT.getFixedSizeInBits(),
VT.getStoreSize().getKnownMinValue()));
TrailingBytes = DAG.getNode(ISD::UMIN, DL, PtrVT, TrailingBytes, VLBytes);
}
if (TrailingElts > VT.getVectorMinNumElements())
TrailingBytes = DAG.getNode(ISD::UMIN, DL, PtrVT, TrailingBytes, VTBytes);

// Calculate the start address of the spliced result.
StackPtr2 = DAG.getNode(ISD::SUB, DL, PtrVT, StackPtr2, TrailingBytes);
Expand Down
33 changes: 9 additions & 24 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8639,7 +8639,7 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
Subtarget->isWindowsArm64EC()) &&
"Indirect arguments should be scalable on most subtargets");

uint64_t PartSize = VA.getValVT().getStoreSize().getKnownMinValue();
TypeSize PartSize = VA.getValVT().getStoreSize();
unsigned NumParts = 1;
if (Ins[i].Flags.isInConsecutiveRegs()) {
while (!Ins[i + NumParts - 1].Flags.isInConsecutiveRegsLast())
Expand All @@ -8656,16 +8656,8 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
InVals.push_back(ArgValue);
NumParts--;
if (NumParts > 0) {
SDValue BytesIncrement;
if (PartLoad.isScalableVector()) {
BytesIncrement = DAG.getVScale(
DL, Ptr.getValueType(),
APInt(Ptr.getValueSizeInBits().getFixedValue(), PartSize));
} else {
BytesIncrement = DAG.getConstant(
APInt(Ptr.getValueSizeInBits().getFixedValue(), PartSize), DL,
Ptr.getValueType());
}
SDValue BytesIncrement =
DAG.getTypeSize(DL, Ptr.getValueType(), PartSize);
Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr,
BytesIncrement, SDNodeFlags::NoUnsignedWrap);
ExtraArgLocs++;
Expand Down Expand Up @@ -9868,8 +9860,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
assert((isScalable || Subtarget->isWindowsArm64EC()) &&
"Indirect arguments should be scalable on most subtargets");

uint64_t StoreSize = VA.getValVT().getStoreSize().getKnownMinValue();
uint64_t PartSize = StoreSize;
TypeSize StoreSize = VA.getValVT().getStoreSize();
TypeSize PartSize = StoreSize;
unsigned NumParts = 1;
if (Outs[i].Flags.isInConsecutiveRegs()) {
while (!Outs[i + NumParts - 1].Flags.isInConsecutiveRegsLast())
Expand All @@ -9880,7 +9872,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
Type *Ty = EVT(VA.getValVT()).getTypeForEVT(*DAG.getContext());
Align Alignment = DAG.getDataLayout().getPrefTypeAlign(Ty);
MachineFrameInfo &MFI = MF.getFrameInfo();
int FI = MFI.CreateStackObject(StoreSize, Alignment, false);
int FI =
MFI.CreateStackObject(StoreSize.getKnownMinValue(), Alignment, false);
if (isScalable) {
bool IsPred = VA.getValVT() == MVT::aarch64svcount ||
VA.getValVT().getVectorElementType() == MVT::i1;
Expand All @@ -9901,16 +9894,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,

NumParts--;
if (NumParts > 0) {
SDValue BytesIncrement;
if (isScalable) {
BytesIncrement = DAG.getVScale(
DL, Ptr.getValueType(),
APInt(Ptr.getValueSizeInBits().getFixedValue(), PartSize));
} else {
BytesIncrement = DAG.getConstant(
APInt(Ptr.getValueSizeInBits().getFixedValue(), PartSize), DL,
Ptr.getValueType());
}
SDValue BytesIncrement =
DAG.getTypeSize(DL, Ptr.getValueType(), PartSize);
MPI = MachinePointerInfo(MPI.getAddrSpace());
Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr,
BytesIncrement, SDNodeFlags::NoUnsignedWrap);
Expand Down
8 changes: 2 additions & 6 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12739,10 +12739,7 @@ SDValue RISCVTargetLowering::lowerVECTOR_INTERLEAVE(SDValue Op,

SmallVector<SDValue, 8> Loads(Factor);

SDValue Increment =
DAG.getVScale(DL, PtrVT,
APInt(PtrVT.getFixedSizeInBits(),
VecVT.getStoreSize().getKnownMinValue()));
SDValue Increment = DAG.getTypeSize(DL, PtrVT, VecVT.getStoreSize());
for (unsigned i = 0; i != Factor; ++i) {
if (i != 0)
StackPtr = DAG.getNode(ISD::ADD, DL, PtrVT, StackPtr, Increment);
Expand Down Expand Up @@ -14140,9 +14137,8 @@ RISCVTargetLowering::lowerVPReverseExperimental(SDValue Op,

// Slide off any elements from past EVL that were reversed into the low
// elements.
unsigned MinElts = GatherVT.getVectorMinNumElements();
SDValue VLMax =
DAG.getVScale(DL, XLenVT, APInt(XLenVT.getSizeInBits(), MinElts));
DAG.getElementCount(DL, XLenVT, GatherVT.getVectorElementCount());
SDValue Diff = DAG.getNode(ISD::SUB, DL, XLenVT, VLMax, EVL);

Result = getVSlidedown(DAG, Subtarget, DL, GatherVT,
Expand Down