Skip to content

Commit dc244cb

Browse files
committed
Review feedback
1 parent 8b0b5ec commit dc244cb

File tree

2 files changed

+15
-24
lines changed

2 files changed

+15
-24
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2918,10 +2918,9 @@ static SDValue lowerMSTORE(SDValue Op, SelectionDAG &DAG) {
29182918
assert(Mask.getValueType().getVectorNumElements() ==
29192919
ValVT.getVectorNumElements() &&
29202920
"Mask size must be the same as the vector size");
2921-
for (unsigned I : llvm::seq(ValVT.getVectorNumElements())) {
2922-
assert(isa<ConstantSDNode>(Mask.getOperand(I)) &&
2923-
"Mask elements must be constants");
2924-
if (Mask->getConstantOperandVal(I) == 0) {
2921+
for (auto [I, Op] : enumerate(Mask->ops())) {
2922+
// Mask elements must be constants.
2923+
if (Op.getNode()->getAsZExtVal() == 0) {
29252924
// Append a sentinel register 0 to the Ops vector to represent a masked
29262925
// off element, this will be handled in tablegen
29272926
Ops.push_back(DAG.getRegister(MCRegister::NoRegister,
@@ -3189,7 +3188,7 @@ SDValue NVPTXTargetLowering::LowerVASTART(SDValue Op, SelectionDAG &DAG) const {
31893188
MachinePointerInfo(SV));
31903189
}
31913190

3192-
static std::tuple<MemSDNode *, uint32_t>
3191+
static std::pair<MemSDNode *, uint32_t>
31933192
convertMLOADToLoadWithUsedBytesMask(MemSDNode *N, SelectionDAG &DAG) {
31943193
SDValue Chain = N->getOperand(0);
31953194
SDValue BasePtr = N->getOperand(1);
@@ -3214,16 +3213,14 @@ convertMLOADToLoadWithUsedBytesMask(MemSDNode *N, SelectionDAG &DAG) {
32143213
uint32_t ElementSizeInBytes = ElementSizeInBits / 8;
32153214
uint32_t ElementMask = (1u << ElementSizeInBytes) - 1u;
32163215

3217-
for (unsigned I :
3218-
llvm::reverse(llvm::seq<unsigned>(0, ResVT.getVectorNumElements()))) {
3219-
assert(isa<ConstantSDNode>(Mask.getOperand(I)) &&
3220-
"Mask elements must be constants");
3221-
// We technically only want to do this shift for every iteration *but* the
3222-
// first, but in the first iteration NewMask is 0, so this shift is a
3223-
// no-op.
3216+
for (SDValue Op : llvm::reverse(Mask->ops())) {
3217+
// We technically only want to do this shift for every
3218+
// iteration *but* the first, but in the first iteration NewMask is 0, so
3219+
// this shift is a no-op.
32243220
UsedBytesMask <<= ElementSizeInBytes;
32253221

3226-
if (Mask->getConstantOperandVal(I) != 0)
3222+
// Mask elements must be constants.
3223+
if (Op->getAsZExtVal() != 0)
32273224
UsedBytesMask |= ElementMask;
32283225
}
32293226

@@ -3269,11 +3266,8 @@ replaceLoadVector(SDNode *N, SelectionDAG &DAG, const NVPTXSubtarget &STI) {
32693266

32703267
// If we have a masked load, convert it to a normal load now
32713268
std::optional<uint32_t> UsedBytesMask = std::nullopt;
3272-
if (LD->getOpcode() == ISD::MLOAD) {
3273-
auto Result = convertMLOADToLoadWithUsedBytesMask(LD, DAG);
3274-
LD = std::get<0>(Result);
3275-
UsedBytesMask = std::get<1>(Result);
3276-
}
3269+
if (LD->getOpcode() == ISD::MLOAD)
3270+
std::tie(LD, UsedBytesMask) = convertMLOADToLoadWithUsedBytesMask(LD, DAG);
32773271

32783272
// Since LoadV2 is a target node, we cannot rely on DAG type legalization.
32793273
// Therefore, we must ensure the type is legal. For i1 and i8, we set the

llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -615,12 +615,9 @@ bool NVPTXTTIImpl::isLegalMaskedStore(Type *DataTy, Align Alignment,
615615
if (!VTy)
616616
return false;
617617

618-
auto *ScalarTy = VTy->getScalarType();
619-
if ((ScalarTy->getScalarSizeInBits() == 32 && VTy->getNumElements() == 8) ||
620-
(ScalarTy->getScalarSizeInBits() == 64 && VTy->getNumElements() == 4))
621-
return true;
622-
623-
return false;
618+
auto *ElemTy = VTy->getScalarType();
619+
return (ElemTy->getScalarSizeInBits() == 32 && VTy->getNumElements() == 8) ||
620+
(ElemTy->getScalarSizeInBits() == 64 && VTy->getNumElements() == 4);
624621
}
625622

626623
bool NVPTXTTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment,

0 commit comments

Comments
 (0)