Skip to content

Commit 6ec5661

Browse files
committed
Add masked load lowering support
1 parent 8e7140a commit 6ec5661

15 files changed

+584
-37
lines changed

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2449,6 +2449,7 @@ void DAGTypeLegalizer::SplitVecRes_MLOAD(MaskedLoadSDNode *MLD,
24492449
SDValue PassThru = MLD->getPassThru();
24502450
Align Alignment = MLD->getBaseAlign();
24512451
ISD::LoadExtType ExtType = MLD->getExtensionType();
2452+
MachineMemOperand::Flags MMOFlags = MLD->getMemOperand()->getFlags();
24522453

24532454
// Split Mask operand
24542455
SDValue MaskLo, MaskHi;
@@ -2474,9 +2475,8 @@ void DAGTypeLegalizer::SplitVecRes_MLOAD(MaskedLoadSDNode *MLD,
24742475
std::tie(PassThruLo, PassThruHi) = DAG.SplitVector(PassThru, dl);
24752476

24762477
MachineMemOperand *MMO = DAG.getMachineFunction().getMachineMemOperand(
2477-
MLD->getPointerInfo(), MachineMemOperand::MOLoad,
2478-
LocationSize::beforeOrAfterPointer(), Alignment, MLD->getAAInfo(),
2479-
MLD->getRanges());
2478+
MLD->getPointerInfo(), MMOFlags, LocationSize::beforeOrAfterPointer(),
2479+
Alignment, MLD->getAAInfo(), MLD->getRanges());
24802480

24812481
Lo = DAG.getMaskedLoad(LoVT, dl, Ch, Ptr, Offset, MaskLo, PassThruLo, LoMemVT,
24822482
MMO, MLD->getAddressingMode(), ExtType,
@@ -2499,8 +2499,8 @@ void DAGTypeLegalizer::SplitVecRes_MLOAD(MaskedLoadSDNode *MLD,
24992499
LoMemVT.getStoreSize().getFixedValue());
25002500

25012501
MMO = DAG.getMachineFunction().getMachineMemOperand(
2502-
MPI, MachineMemOperand::MOLoad, LocationSize::beforeOrAfterPointer(),
2503-
Alignment, MLD->getAAInfo(), MLD->getRanges());
2502+
MPI, MMOFlags, LocationSize::beforeOrAfterPointer(), Alignment,
2503+
MLD->getAAInfo(), MLD->getRanges());
25042504

25052505
Hi = DAG.getMaskedLoad(HiVT, dl, Ch, Ptr, Offset, MaskHi, PassThruHi,
25062506
HiMemVT, MMO, MLD->getAddressingMode(), ExtType,

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5010,6 +5010,8 @@ void SelectionDAGBuilder::visitMaskedLoad(const CallInst &I, bool IsExpanding) {
50105010
auto MMOFlags = MachineMemOperand::MOLoad;
50115011
if (I.hasMetadata(LLVMContext::MD_nontemporal))
50125012
MMOFlags |= MachineMemOperand::MONonTemporal;
5013+
if (I.hasMetadata(LLVMContext::MD_invariant_load))
5014+
MMOFlags |= MachineMemOperand::MOInvariant;
50135015

50145016
MachineMemOperand *MMO = DAG.getMachineFunction().getMachineMemOperand(
50155017
MachinePointerInfo(PtrOperand), MMOFlags,

llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,16 @@ void NVPTXInstPrinter::printMemOperand(const MCInst *MI, int OpNum,
395395
}
396396
}
397397

398+
void NVPTXInstPrinter::printUsedBytesMaskPragma(const MCInst *MI, int OpNum,
399+
raw_ostream &O) {
400+
auto &Op = MI->getOperand(OpNum);
401+
assert(Op.isImm() && "Invalid operand");
402+
uint32_t Imm = (uint32_t)Op.getImm();
403+
if (Imm != UINT32_MAX) {
404+
O << ".pragma \"used_bytes_mask " << Imm << "\";\n\t";
405+
}
406+
}
407+
398408
void NVPTXInstPrinter::printRegisterOrSinkSymbol(const MCInst *MI, int OpNum,
399409
raw_ostream &O,
400410
const char *Modifier) {

llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ class NVPTXInstPrinter : public MCInstPrinter {
4646
StringRef Modifier = {});
4747
void printMemOperand(const MCInst *MI, int OpNum, raw_ostream &O,
4848
StringRef Modifier = {});
49+
void printUsedBytesMaskPragma(const MCInst *MI, int OpNum, raw_ostream &O);
4950
void printRegisterOrSinkSymbol(const MCInst *MI, int OpNum, raw_ostream &O,
5051
const char *Modifier = nullptr);
5152
void printHexu32imm(const MCInst *MI, int OpNum, raw_ostream &O);

llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ static bool eliminateMove(MachineInstr &Mov, const MachineRegisterInfo &MRI,
9696
const MachineOperand *ParamSymbol = Mov.uses().begin();
9797
assert(ParamSymbol->isSymbol());
9898

99-
constexpr unsigned LDInstBasePtrOpIdx = 5;
99+
constexpr unsigned LDInstBasePtrOpIdx = 6;
100100
constexpr unsigned LDInstAddrSpaceOpIdx = 2;
101101
for (auto *LI : LoadInsts) {
102102
(LI->uses().begin() + LDInstBasePtrOpIdx)

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
105105
switch (N->getOpcode()) {
106106
case ISD::LOAD:
107107
case ISD::ATOMIC_LOAD:
108+
case NVPTXISD::MLoadV1:
108109
if (tryLoad(N))
109110
return;
110111
break;
@@ -1132,6 +1133,19 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
11321133
? NVPTX::PTXLdStInstCode::Signed
11331134
: NVPTX::PTXLdStInstCode::Untyped;
11341135

1136+
uint32_t UsedBytesMask;
1137+
switch (N->getOpcode()) {
1138+
case ISD::LOAD:
1139+
case ISD::ATOMIC_LOAD:
1140+
UsedBytesMask = UINT32_MAX;
1141+
break;
1142+
case NVPTXISD::MLoadV1:
1143+
UsedBytesMask = N->getConstantOperandVal(N->getNumOperands() - 2);
1144+
break;
1145+
default:
1146+
llvm_unreachable("Unexpected opcode");
1147+
}
1148+
11351149
assert(isPowerOf2_32(FromTypeWidth) && FromTypeWidth >= 8 &&
11361150
FromTypeWidth <= 128 && "Invalid width for load");
11371151

@@ -1142,6 +1156,7 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
11421156
getI32Imm(CodeAddrSpace, DL),
11431157
getI32Imm(FromType, DL),
11441158
getI32Imm(FromTypeWidth, DL),
1159+
getI32Imm(UsedBytesMask, DL),
11451160
Base,
11461161
Offset,
11471162
Chain};
@@ -1204,6 +1219,8 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
12041219
: NVPTX::PTXLdStInstCode::Untyped;
12051220

12061221
const unsigned FromTypeWidth = getFromTypeWidthForLoad(LD);
1222+
const uint32_t UsedBytesMask =
1223+
N->getConstantOperandVal(N->getNumOperands() - 2);
12071224

12081225
assert(!(EltVT.isVector() && ExtensionType != ISD::NON_EXTLOAD));
12091226

@@ -1213,6 +1230,7 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
12131230
getI32Imm(CodeAddrSpace, DL),
12141231
getI32Imm(FromType, DL),
12151232
getI32Imm(FromTypeWidth, DL),
1233+
getI32Imm(UsedBytesMask, DL),
12161234
Base,
12171235
Offset,
12181236
Chain};
@@ -1250,10 +1268,13 @@ bool NVPTXDAGToDAGISel::tryLDG(MemSDNode *LD) {
12501268
SDLoc DL(LD);
12511269

12521270
unsigned ExtensionType;
1271+
uint32_t UsedBytesMask;
12531272
if (const auto *Load = dyn_cast<LoadSDNode>(LD)) {
12541273
ExtensionType = Load->getExtensionType();
1274+
UsedBytesMask = UINT32_MAX;
12551275
} else {
12561276
ExtensionType = LD->getConstantOperandVal(LD->getNumOperands() - 1);
1277+
UsedBytesMask = LD->getConstantOperandVal(LD->getNumOperands() - 2);
12571278
}
12581279
const unsigned FromType = (ExtensionType == ISD::SEXTLOAD)
12591280
? NVPTX::PTXLdStInstCode::Signed
@@ -1265,8 +1286,12 @@ bool NVPTXDAGToDAGISel::tryLDG(MemSDNode *LD) {
12651286
ExtensionType != ISD::NON_EXTLOAD));
12661287

12671288
const auto [Base, Offset] = selectADDR(LD->getOperand(1), CurDAG);
1268-
SDValue Ops[] = {getI32Imm(FromType, DL), getI32Imm(FromTypeWidth, DL), Base,
1269-
Offset, LD->getChain()};
1289+
SDValue Ops[] = {getI32Imm(FromType, DL),
1290+
getI32Imm(FromTypeWidth, DL),
1291+
getI32Imm(UsedBytesMask, DL),
1292+
Base,
1293+
Offset,
1294+
LD->getChain()};
12701295

12711296
const MVT::SimpleValueType TargetVT = LD->getSimpleValueType(0).SimpleTy;
12721297
std::optional<unsigned> Opcode;
@@ -1277,6 +1302,10 @@ bool NVPTXDAGToDAGISel::tryLDG(MemSDNode *LD) {
12771302
Opcode = pickOpcodeForVT(TargetVT, NVPTX::LD_GLOBAL_NC_i16,
12781303
NVPTX::LD_GLOBAL_NC_i32, NVPTX::LD_GLOBAL_NC_i64);
12791304
break;
1305+
case NVPTXISD::MLoadV1:
1306+
Opcode = pickOpcodeForVT(TargetVT, std::nullopt, NVPTX::LD_GLOBAL_NC_i32,
1307+
NVPTX::LD_GLOBAL_NC_i64);
1308+
break;
12801309
case NVPTXISD::LoadV2:
12811310
Opcode =
12821311
pickOpcodeForVT(TargetVT, NVPTX::LD_GLOBAL_NC_v2i16,

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 109 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -769,7 +769,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
769769
setOperationAction({ISD::LOAD, ISD::STORE}, {MVT::i128, MVT::f128}, Custom);
770770
for (MVT VT : MVT::fixedlen_vector_valuetypes())
771771
if (!isTypeLegal(VT) && VT.getStoreSizeInBits() <= 256)
772-
setOperationAction({ISD::STORE, ISD::LOAD, ISD::MSTORE}, VT, Custom);
772+
setOperationAction({ISD::STORE, ISD::LOAD, ISD::MSTORE, ISD::MLOAD}, VT,
773+
Custom);
773774

774775
// Custom legalization for LDU intrinsics.
775776
// TODO: The logic to lower these is not very robust and we should rewrite it.
@@ -1130,6 +1131,7 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
11301131
MAKE_CASE(NVPTXISD::LoadV2)
11311132
MAKE_CASE(NVPTXISD::LoadV4)
11321133
MAKE_CASE(NVPTXISD::LoadV8)
1134+
MAKE_CASE(NVPTXISD::MLoadV1)
11331135
MAKE_CASE(NVPTXISD::LDUV2)
11341136
MAKE_CASE(NVPTXISD::LDUV4)
11351137
MAKE_CASE(NVPTXISD::StoreV2)
@@ -3306,6 +3308,8 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
33063308
}
33073309
case ISD::LOAD:
33083310
return LowerLOAD(Op, DAG);
3311+
case ISD::MLOAD:
3312+
return LowerMLOAD(Op, DAG);
33093313
case ISD::SHL_PARTS:
33103314
return LowerShiftLeftParts(Op, DAG);
33113315
case ISD::SRA_PARTS:
@@ -3497,10 +3501,58 @@ SDValue NVPTXTargetLowering::LowerVASTART(SDValue Op, SelectionDAG &DAG) const {
34973501
MachinePointerInfo(SV));
34983502
}
34993503

3504+
static std::tuple<MemSDNode *, uint32_t>
3505+
convertMLOADToLoadWithUsedBytesMask(MemSDNode *N, SelectionDAG &DAG) {
3506+
SDValue Chain = N->getOperand(0);
3507+
SDValue BasePtr = N->getOperand(1);
3508+
SDValue Mask = N->getOperand(3);
3509+
SDValue Passthru = N->getOperand(4);
3510+
3511+
SDLoc DL(N);
3512+
EVT ResVT = N->getValueType(0);
3513+
assert(ResVT.isVector() && "Masked vector load must have vector type");
3514+
// While we only expect poison passthru vectors as an input to the backend,
3515+
// when the legalization framework splits a poison vector in half, it creates
3516+
// two undef vectors, so we can technically expect those too.
3517+
assert((Passthru.getOpcode() == ISD::POISON ||
3518+
Passthru.getOpcode() == ISD::UNDEF) &&
3519+
"Passthru operand expected to be poison or undef");
3520+
3521+
// Extract the mask and convert it to a uint32_t representing the used bytes
3522+
// of the entire vector load
3523+
uint32_t UsedBytesMask = 0;
3524+
uint32_t ElementSizeInBits = ResVT.getVectorElementType().getSizeInBits();
3525+
assert(ElementSizeInBits % 8 == 0 && "Unexpected element size");
3526+
uint32_t ElementSizeInBytes = ElementSizeInBits / 8;
3527+
uint32_t ElementMask = (1u << ElementSizeInBytes) - 1u;
3528+
3529+
for (unsigned I :
3530+
llvm::reverse(llvm::seq<unsigned>(0, ResVT.getVectorNumElements()))) {
3531+
assert(isa<ConstantSDNode>(Mask.getOperand(I)) &&
3532+
"Mask elements must be constants");
3533+
// We technically only want to do this shift for every iteration *but* the
3534+
// first, but in the first iteration NewMask is 0, so this shift is a
3535+
// no-op.
3536+
UsedBytesMask <<= ElementSizeInBytes;
3537+
3538+
if (Mask->getConstantOperandVal(I) != 0)
3539+
UsedBytesMask |= ElementMask;
3540+
}
3541+
3542+
assert(UsedBytesMask != 0 && UsedBytesMask != UINT32_MAX &&
3543+
"Unexpected masked load with elements masked all on or all off");
3544+
3545+
// Create a new load sd node to be handled normally by ReplaceLoadVector.
3546+
MemSDNode *NewLD = cast<MemSDNode>(
3547+
DAG.getLoad(ResVT, DL, Chain, BasePtr, N->getMemOperand()).getNode());
3548+
3549+
return {NewLD, UsedBytesMask};
3550+
}
3551+
35003552
/// replaceLoadVector - Convert vector loads into multi-output scalar loads.
35013553
static std::optional<std::pair<SDValue, SDValue>>
35023554
replaceLoadVector(SDNode *N, SelectionDAG &DAG, const NVPTXSubtarget &STI) {
3503-
LoadSDNode *LD = cast<LoadSDNode>(N);
3555+
MemSDNode *LD = cast<MemSDNode>(N);
35043556
const EVT ResVT = LD->getValueType(0);
35053557
const EVT MemVT = LD->getMemoryVT();
35063558

@@ -3527,6 +3579,14 @@ replaceLoadVector(SDNode *N, SelectionDAG &DAG, const NVPTXSubtarget &STI) {
35273579
return std::nullopt;
35283580
}
35293581

3582+
// If we have a masked load, convert it to a normal load now
3583+
std::optional<uint32_t> UsedBytesMask = std::nullopt;
3584+
if (LD->getOpcode() == ISD::MLOAD) {
3585+
auto Result = convertMLOADToLoadWithUsedBytesMask(LD, DAG);
3586+
LD = std::get<0>(Result);
3587+
UsedBytesMask = std::get<1>(Result);
3588+
}
3589+
35303590
// Since LoadV2 is a target node, we cannot rely on DAG type legalization.
35313591
// Therefore, we must ensure the type is legal. For i1 and i8, we set the
35323592
// loaded type to i16 and propagate the "real" type as the memory type.
@@ -3555,9 +3615,13 @@ replaceLoadVector(SDNode *N, SelectionDAG &DAG, const NVPTXSubtarget &STI) {
35553615
// Copy regular operands
35563616
SmallVector<SDValue, 8> OtherOps(LD->ops());
35573617

3618+
OtherOps.push_back(
3619+
DAG.getConstant(UsedBytesMask.value_or(UINT32_MAX), DL, MVT::i32));
3620+
35583621
// The select routine does not have access to the LoadSDNode instance, so
35593622
// pass along the extension information
3560-
OtherOps.push_back(DAG.getIntPtrConstant(LD->getExtensionType(), DL));
3623+
OtherOps.push_back(
3624+
DAG.getIntPtrConstant(cast<LoadSDNode>(LD)->getExtensionType(), DL));
35613625

35623626
SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, DL, LdResVTs, OtherOps, MemVT,
35633627
LD->getMemOperand());
@@ -3645,6 +3709,43 @@ SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
36453709
llvm_unreachable("Unexpected custom lowering for load");
36463710
}
36473711

3712+
SDValue NVPTXTargetLowering::LowerMLOAD(SDValue Op, SelectionDAG &DAG) const {
3713+
// v2f16/v2bf16/v2i16/v4i8 are legal, so we can't rely on legalizer to handle
3714+
// masked loads of these types and have to handle them here.
3715+
// v2f32 also needs to be handled here if the subtarget has f32x2
3716+
// instructions, making it legal.
3717+
//
3718+
// Note: misaligned masked loads should never reach this point
3719+
// because the override of isLegalMaskedLoad in NVPTXTargetTransformInfo.cpp
3720+
// will validate alignment. Therefore, we do not need to special case handle
3721+
// them here.
3722+
EVT VT = Op.getValueType();
3723+
if (NVPTX::isPackedVectorTy(VT) &&
3724+
(VT != MVT::v2f32 || STI.hasF32x2Instructions())) {
3725+
auto Result =
3726+
convertMLOADToLoadWithUsedBytesMask(cast<MemSDNode>(Op.getNode()), DAG);
3727+
MemSDNode *LD = std::get<0>(Result);
3728+
uint32_t UsedBytesMask = std::get<1>(Result);
3729+
3730+
SDLoc DL(LD);
3731+
3732+
// Copy regular operands
3733+
SmallVector<SDValue, 8> OtherOps(LD->ops());
3734+
3735+
OtherOps.push_back(DAG.getConstant(UsedBytesMask, DL, MVT::i32));
3736+
3737+
// The select routine does not have access to the LoadSDNode instance, so
3738+
// pass along the extension information
3739+
OtherOps.push_back(
3740+
DAG.getIntPtrConstant(cast<LoadSDNode>(LD)->getExtensionType(), DL));
3741+
SDValue NewLD = DAG.getMemIntrinsicNode(
3742+
NVPTXISD::MLoadV1, DL, LD->getVTList(), OtherOps, LD->getMemoryVT(),
3743+
LD->getMemOperand());
3744+
return NewLD;
3745+
}
3746+
return SDValue();
3747+
}
3748+
36483749
static SDValue lowerSTOREVector(SDValue Op, SelectionDAG &DAG,
36493750
const NVPTXSubtarget &STI) {
36503751
MemSDNode *N = cast<MemSDNode>(Op.getNode());
@@ -5555,9 +5656,13 @@ combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
55555656
// ISD::LOAD -> NVPTXISD::Load (unless it's under-aligned). We have to do it
55565657
// here.
55575658
Opcode = NVPTXISD::LoadV2;
5659+
// append a "full" used bytes mask operand right before the extension type
5660+
// operand, signifying that all bytes are used.
5661+
Operands.push_back(DCI.DAG.getConstant(UINT32_MAX, DL, MVT::i32));
55585662
Operands.push_back(DCI.DAG.getIntPtrConstant(
55595663
cast<LoadSDNode>(LD)->getExtensionType(), DL));
55605664
break;
5665+
// TODO do we need to support MLoadV1 here?
55615666
case NVPTXISD::LoadV2:
55625667
OldNumOutputs = 2;
55635668
Opcode = NVPTXISD::LoadV4;
@@ -6793,6 +6898,7 @@ void NVPTXTargetLowering::ReplaceNodeResults(
67936898
ReplaceBITCAST(N, DAG, Results);
67946899
return;
67956900
case ISD::LOAD:
6901+
case ISD::MLOAD:
67966902
replaceLoadVector(N, DAG, Results, STI);
67976903
return;
67986904
case ISD::INTRINSIC_W_CHAIN:

llvm/lib/Target/NVPTX/NVPTXISelLowering.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ enum NodeType : unsigned {
9999
LoadV2,
100100
LoadV4,
101101
LoadV8,
102+
MLoadV1,
102103
LDUV2, // LDU.v2
103104
LDUV4, // LDU.v4
104105
StoreV2,
@@ -349,6 +350,7 @@ class NVPTXTargetLowering : public TargetLowering {
349350
SDValue LowerFP_EXTEND(SDValue Op, SelectionDAG &DAG) const;
350351

351352
SDValue LowerLOAD(SDValue Op, SelectionDAG &DAG) const;
353+
SDValue LowerMLOAD(SDValue Op, SelectionDAG &DAG) const;
352354
SDValue LowerSTORE(SDValue Op, SelectionDAG &DAG) const;
353355
SDValue LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const;
354356

0 commit comments

Comments
 (0)