Skip to content

Commit abbf0ba

Browse files
committed
Review feedback
1 parent c44f8a2 commit abbf0ba

File tree

2 files changed

+15
-24
lines changed

2 files changed

+15
-24
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3230,10 +3230,9 @@ static SDValue lowerMSTORE(SDValue Op, SelectionDAG &DAG) {
32303230
assert(Mask.getValueType().getVectorNumElements() ==
32313231
ValVT.getVectorNumElements() &&
32323232
"Mask size must be the same as the vector size");
3233-
for (unsigned I : llvm::seq(ValVT.getVectorNumElements())) {
3234-
assert(isa<ConstantSDNode>(Mask.getOperand(I)) &&
3235-
"Mask elements must be constants");
3236-
if (Mask->getConstantOperandVal(I) == 0) {
3233+
for (auto [I, Op] : enumerate(Mask->ops())) {
3234+
// Mask elements must be constants.
3235+
if (Op.getNode()->getAsZExtVal() == 0) {
32373236
// Append a sentinel register 0 to the Ops vector to represent a masked
32383237
// off element, this will be handled in tablegen
32393238
Ops.push_back(DAG.getRegister(MCRegister::NoRegister,
@@ -3501,7 +3500,7 @@ SDValue NVPTXTargetLowering::LowerVASTART(SDValue Op, SelectionDAG &DAG) const {
35013500
MachinePointerInfo(SV));
35023501
}
35033502

3504-
static std::tuple<MemSDNode *, uint32_t>
3503+
static std::pair<MemSDNode *, uint32_t>
35053504
convertMLOADToLoadWithUsedBytesMask(MemSDNode *N, SelectionDAG &DAG) {
35063505
SDValue Chain = N->getOperand(0);
35073506
SDValue BasePtr = N->getOperand(1);
@@ -3526,16 +3525,14 @@ convertMLOADToLoadWithUsedBytesMask(MemSDNode *N, SelectionDAG &DAG) {
35263525
uint32_t ElementSizeInBytes = ElementSizeInBits / 8;
35273526
uint32_t ElementMask = (1u << ElementSizeInBytes) - 1u;
35283527

3529-
for (unsigned I :
3530-
llvm::reverse(llvm::seq<unsigned>(0, ResVT.getVectorNumElements()))) {
3531-
assert(isa<ConstantSDNode>(Mask.getOperand(I)) &&
3532-
"Mask elements must be constants");
3533-
// We technically only want to do this shift for every iteration *but* the
3534-
// first, but in the first iteration NewMask is 0, so this shift is a
3535-
// no-op.
3528+
for (SDValue Op : llvm::reverse(Mask->ops())) {
3529+
// We technically only want to do this shift for every
3530+
// iteration *but* the first, but in the first iteration NewMask is 0, so
3531+
// this shift is a no-op.
35363532
UsedBytesMask <<= ElementSizeInBytes;
35373533

3538-
if (Mask->getConstantOperandVal(I) != 0)
3534+
// Mask elements must be constants.
3535+
if (Op->getAsZExtVal() != 0)
35393536
UsedBytesMask |= ElementMask;
35403537
}
35413538

@@ -3581,11 +3578,8 @@ replaceLoadVector(SDNode *N, SelectionDAG &DAG, const NVPTXSubtarget &STI) {
35813578

35823579
// If we have a masked load, convert it to a normal load now
35833580
std::optional<uint32_t> UsedBytesMask = std::nullopt;
3584-
if (LD->getOpcode() == ISD::MLOAD) {
3585-
auto Result = convertMLOADToLoadWithUsedBytesMask(LD, DAG);
3586-
LD = std::get<0>(Result);
3587-
UsedBytesMask = std::get<1>(Result);
3588-
}
3581+
if (LD->getOpcode() == ISD::MLOAD)
3582+
std::tie(LD, UsedBytesMask) = convertMLOADToLoadWithUsedBytesMask(LD, DAG);
35893583

35903584
// Since LoadV2 is a target node, we cannot rely on DAG type legalization.
35913585
// Therefore, we must ensure the type is legal. For i1 and i8, we set the

llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -610,12 +610,9 @@ bool NVPTXTTIImpl::isLegalMaskedStore(Type *DataTy, Align Alignment,
610610
if (!VTy)
611611
return false;
612612

613-
auto *ScalarTy = VTy->getScalarType();
614-
if ((ScalarTy->getScalarSizeInBits() == 32 && VTy->getNumElements() == 8) ||
615-
(ScalarTy->getScalarSizeInBits() == 64 && VTy->getNumElements() == 4))
616-
return true;
617-
618-
return false;
613+
auto *ElemTy = VTy->getScalarType();
614+
return (ElemTy->getScalarSizeInBits() == 32 && VTy->getNumElements() == 8) ||
615+
(ElemTy->getScalarSizeInBits() == 64 && VTy->getNumElements() == 4);
619616
}
620617

621618
bool NVPTXTTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment,

0 commit comments

Comments
 (0)