
Commit 17852de

[NVPTX] Lower LLVM masked vector loads and stores to PTX (#159387)
This backend support will allow the LoadStoreVectorizer, in certain cases, to fill in gaps when creating load/store vectors and generate LLVM masked load/stores (https://llvm.org/docs/LangRef.html#llvm-masked-store-intrinsics). To accomplish this, the changes are separated into two parts: this first part has the backend lowering and TTI changes, and a follow-up PR will have the LSV generate these intrinsics: #159388.

In this backend change, masked loads are lowered to PTX with `#pragma "used_bytes_mask" [mask];` (https://docs.nvidia.com/cuda/parallel-thread-execution/#pragma-strings-used-bytes-mask), and masked stores are lowered to PTX using the new sink symbol syntax (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-st).

# TTI Changes

TTI changes are needed because NVPTX supports masked loads/stores only with _constant_ masks. `ScalarizeMaskedMemIntrin.cpp` is adjusted to check whether the mask is constant and to pass that result into the TTI check. Behavior shouldn't change for non-NVPTX targets, which do not care whether the mask is variable or constant when determining legality, but all TTI files that implement these APIs need to be updated.

# Masked store lowering implementation details

If masked stores make it to the NVPTX backend without being scalarized, they are handled by the following:

* `NVPTXISelLowering.cpp` - Sets up a custom operation action and handles it in `lowerMSTORE`. The handling is similar to normal store vectors, except we read the mask and place the sentinel register `$noreg` in each position where the mask reads as false. For example:

  ```
  t10: v8i1 = BUILD_VECTOR Constant:i1<-1>, Constant:i1<0>, Constant:i1<0>, Constant:i1<-1>, Constant:i1<-1>, Constant:i1<0>, Constant:i1<0>, Constant:i1<-1>
  t11: ch = masked_store<(store unknown-size into %ir.lsr.iv28, align 32, addrspace 1)> t5:1, t5, t7, undef:i64, t10
  ->
  STV_i32_v8 killed %13:int32regs, $noreg, $noreg, killed %16:int32regs, killed %17:int32regs, $noreg, $noreg, killed %20:int32regs, 0, 0, 1, 8, 0, 32, %4:int64regs, 0, debug-location !18 :: (store unknown-size into %ir.lsr.iv28, align 32, addrspace 1)
  ```

* `NVPTXInstrInfo.td` - Changes the definition of store vectors to allow a mix of sink symbols and registers.
* `NVPTXInstPrinter.h/.cpp` - Handles the `$noreg` case by printing "_".

# Masked load lowering implementation details

Masked loads are routed to normal PTX loads, with one difference: a `#pragma "used_bytes_mask"` is emitted before the load instruction (https://docs.nvidia.com/cuda/parallel-thread-execution/#pragma-strings-used-bytes-mask). To accomplish this, a new operand is added to every NVPTXISD load type representing this mask.

* `NVPTXISelLowering.h/.cpp` - Masked loads are converted into normal NVPTXISD loads with a mask operand in two ways: 1) in type legalization through `replaceLoadVector`, which is the normal path, and 2) through `LowerMLOAD`, to handle the legal vector types (v2f16/v2bf16/v2i16/v4i8/v2f32) that will not be type legalized. Both share the same `convertMLOADToLoadWithUsedBytesMask` helper, and both default this operand to UINT32_MAX, representing all bytes on. For the latter, we need a new `NVPTXISD::MLoadV1` type to represent that edge case, because we cannot put the used bytes mask operand on a generic LoadSDNode.
* `NVPTXISelDAGToDAG.cpp` - Extracts the used bytes mask from loads and adds it to the created machine instructions.
* `NVPTXInstPrinter.h/.cpp` - Prints the pragma when the used bytes mask isn't all ones.
* `NVPTXForwardParams.cpp`, `NVPTXReplaceImageHandles.cpp` - Update manual indexing of load operands to account for the new operand.
* `NVPTXInstrInfo.td`, `NVPTXIntrinsics.td` - Add the used bytes mask to the MI definitions.
* `NVPTXTagInvariantLoads.cpp` - Ensures that masked loads also get tagged as invariant.

Some generic changes are also needed:

* `LegalizeVectorTypes.cpp` - Ensure flags are preserved when splitting masked loads.
* `SelectionDAGBuilder.cpp` - Preserve `MD_invariant_load` on masked load SDNode creation.
1 parent 1a03673 commit 17852de

36 files changed: +1195 −106 lines changed

llvm/include/llvm/Analysis/TargetTransformInfo.h

Lines changed: 12 additions & 4 deletions
```diff
@@ -842,12 +842,20 @@ class TargetTransformInfo {
   LLVM_ABI AddressingModeKind
   getPreferredAddressingMode(const Loop *L, ScalarEvolution *SE) const;
 
+  /// Some targets only support masked load/store with a constant mask.
+  enum MaskKind {
+    VariableOrConstantMask,
+    ConstantMask,
+  };
+
   /// Return true if the target supports masked store.
-  LLVM_ABI bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                                   unsigned AddressSpace) const;
+  LLVM_ABI bool
+  isLegalMaskedStore(Type *DataType, Align Alignment, unsigned AddressSpace,
+                     MaskKind MaskKind = VariableOrConstantMask) const;
   /// Return true if the target supports masked load.
-  LLVM_ABI bool isLegalMaskedLoad(Type *DataType, Align Alignment,
-                                  unsigned AddressSpace) const;
+  LLVM_ABI bool
+  isLegalMaskedLoad(Type *DataType, Align Alignment, unsigned AddressSpace,
+                    MaskKind MaskKind = VariableOrConstantMask) const;
 
   /// Return true if the target supports nontemporal store.
   LLVM_ABI bool isLegalNTStore(Type *DataType, Align Alignment) const;
```

llvm/include/llvm/Analysis/TargetTransformInfoImpl.h

Lines changed: 4 additions & 2 deletions
```diff
@@ -309,12 +309,14 @@ class TargetTransformInfoImplBase {
   }
 
   virtual bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                                  unsigned AddressSpace) const {
+                                  unsigned AddressSpace,
+                                  TTI::MaskKind MaskKind) const {
     return false;
   }
 
   virtual bool isLegalMaskedLoad(Type *DataType, Align Alignment,
-                                 unsigned AddressSpace) const {
+                                 unsigned AddressSpace,
+                                 TTI::MaskKind MaskKind) const {
     return false;
   }
```

llvm/lib/Analysis/TargetTransformInfo.cpp

Lines changed: 8 additions & 4 deletions
```diff
@@ -468,13 +468,17 @@ TargetTransformInfo::getPreferredAddressingMode(const Loop *L,
 }
 
 bool TargetTransformInfo::isLegalMaskedStore(Type *DataType, Align Alignment,
-                                             unsigned AddressSpace) const {
-  return TTIImpl->isLegalMaskedStore(DataType, Alignment, AddressSpace);
+                                             unsigned AddressSpace,
+                                             TTI::MaskKind MaskKind) const {
+  return TTIImpl->isLegalMaskedStore(DataType, Alignment, AddressSpace,
+                                     MaskKind);
 }
 
 bool TargetTransformInfo::isLegalMaskedLoad(Type *DataType, Align Alignment,
-                                            unsigned AddressSpace) const {
-  return TTIImpl->isLegalMaskedLoad(DataType, Alignment, AddressSpace);
+                                            unsigned AddressSpace,
+                                            TTI::MaskKind MaskKind) const {
+  return TTIImpl->isLegalMaskedLoad(DataType, Alignment, AddressSpace,
+                                    MaskKind);
 }
 
 bool TargetTransformInfo::isLegalNTStore(Type *DataType,
```

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp

Lines changed: 5 additions & 5 deletions
```diff
@@ -2465,6 +2465,7 @@ void DAGTypeLegalizer::SplitVecRes_MLOAD(MaskedLoadSDNode *MLD,
   SDValue PassThru = MLD->getPassThru();
   Align Alignment = MLD->getBaseAlign();
   ISD::LoadExtType ExtType = MLD->getExtensionType();
+  MachineMemOperand::Flags MMOFlags = MLD->getMemOperand()->getFlags();
 
   // Split Mask operand
   SDValue MaskLo, MaskHi;
@@ -2490,9 +2491,8 @@
   std::tie(PassThruLo, PassThruHi) = DAG.SplitVector(PassThru, dl);
 
   MachineMemOperand *MMO = DAG.getMachineFunction().getMachineMemOperand(
-      MLD->getPointerInfo(), MachineMemOperand::MOLoad,
-      LocationSize::beforeOrAfterPointer(), Alignment, MLD->getAAInfo(),
-      MLD->getRanges());
+      MLD->getPointerInfo(), MMOFlags, LocationSize::beforeOrAfterPointer(),
+      Alignment, MLD->getAAInfo(), MLD->getRanges());
 
   Lo = DAG.getMaskedLoad(LoVT, dl, Ch, Ptr, Offset, MaskLo, PassThruLo, LoMemVT,
                          MMO, MLD->getAddressingMode(), ExtType,
@@ -2515,8 +2515,8 @@
                             LoMemVT.getStoreSize().getFixedValue());
 
   MMO = DAG.getMachineFunction().getMachineMemOperand(
-      MPI, MachineMemOperand::MOLoad, LocationSize::beforeOrAfterPointer(),
-      Alignment, MLD->getAAInfo(), MLD->getRanges());
+      MPI, MMOFlags, LocationSize::beforeOrAfterPointer(), Alignment,
+      MLD->getAAInfo(), MLD->getRanges());
 
   Hi = DAG.getMaskedLoad(HiVT, dl, Ch, Ptr, Offset, MaskHi, PassThruHi,
                          HiMemVT, MMO, MLD->getAddressingMode(), ExtType,
```

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp

Lines changed: 2 additions & 0 deletions
```diff
@@ -5063,6 +5063,8 @@ void SelectionDAGBuilder::visitMaskedLoad(const CallInst &I, bool IsExpanding) {
   auto MMOFlags = MachineMemOperand::MOLoad;
   if (I.hasMetadata(LLVMContext::MD_nontemporal))
     MMOFlags |= MachineMemOperand::MONonTemporal;
+  if (I.hasMetadata(LLVMContext::MD_invariant_load))
+    MMOFlags |= MachineMemOperand::MOInvariant;
 
   MachineMemOperand *MMO = DAG.getMachineFunction().getMachineMemOperand(
       MachinePointerInfo(PtrOperand), MMOFlags,
```

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h

Lines changed: 4 additions & 2 deletions
```diff
@@ -323,12 +323,14 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
   }
 
   bool isLegalMaskedLoad(Type *DataType, Align Alignment,
-                         unsigned /*AddressSpace*/) const override {
+                         unsigned /*AddressSpace*/,
+                         TTI::MaskKind /*MaskKind*/) const override {
     return isLegalMaskedLoadStore(DataType, Alignment);
   }
 
   bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                          unsigned /*AddressSpace*/) const override {
+                          unsigned /*AddressSpace*/,
+                          TTI::MaskKind /*MaskKind*/) const override {
     return isLegalMaskedLoadStore(DataType, Alignment);
   }
```

llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp

Lines changed: 2 additions & 1 deletion
```diff
@@ -1125,7 +1125,8 @@ bool ARMTTIImpl::isProfitableLSRChainElement(Instruction *I) const {
 }
 
 bool ARMTTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment,
-                                   unsigned /*AddressSpace*/) const {
+                                   unsigned /*AddressSpace*/,
+                                   TTI::MaskKind /*MaskKind*/) const {
   if (!EnableMaskedLoadStores || !ST->hasMVEIntegerOps())
     return false;
```

llvm/lib/Target/ARM/ARMTargetTransformInfo.h

Lines changed: 10 additions & 6 deletions
```diff
@@ -186,12 +186,16 @@ class ARMTTIImpl final : public BasicTTIImplBase<ARMTTIImpl> {
 
   bool isProfitableLSRChainElement(Instruction *I) const override;
 
-  bool isLegalMaskedLoad(Type *DataTy, Align Alignment,
-                         unsigned AddressSpace) const override;
-
-  bool isLegalMaskedStore(Type *DataTy, Align Alignment,
-                          unsigned AddressSpace) const override {
-    return isLegalMaskedLoad(DataTy, Alignment, AddressSpace);
+  bool
+  isLegalMaskedLoad(Type *DataTy, Align Alignment, unsigned AddressSpace,
+                    TTI::MaskKind MaskKind =
+                        TTI::MaskKind::VariableOrConstantMask) const override;
+
+  bool
+  isLegalMaskedStore(Type *DataTy, Align Alignment, unsigned AddressSpace,
+                     TTI::MaskKind MaskKind =
+                         TTI::MaskKind::VariableOrConstantMask) const override {
+    return isLegalMaskedLoad(DataTy, Alignment, AddressSpace, MaskKind);
   }
 
   bool forceScalarizeMaskedGather(VectorType *VTy,
```

llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp

Lines changed: 4 additions & 2 deletions
```diff
@@ -343,14 +343,16 @@ InstructionCost HexagonTTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
 }
 
 bool HexagonTTIImpl::isLegalMaskedStore(Type *DataType, Align /*Alignment*/,
-                                        unsigned /*AddressSpace*/) const {
+                                        unsigned /*AddressSpace*/,
+                                        TTI::MaskKind /*MaskKind*/) const {
   // This function is called from scalarize-masked-mem-intrin, which runs
   // in pre-isel. Use ST directly instead of calling isHVXVectorType.
   return HexagonMaskedVMem && ST.isTypeForHVX(DataType);
 }
 
 bool HexagonTTIImpl::isLegalMaskedLoad(Type *DataType, Align /*Alignment*/,
-                                       unsigned /*AddressSpace*/) const {
+                                       unsigned /*AddressSpace*/,
+                                       TTI::MaskKind /*MaskKind*/) const {
   // This function is called from scalarize-masked-mem-intrin, which runs
   // in pre-isel. Use ST directly instead of calling isHVXVectorType.
   return HexagonMaskedVMem && ST.isTypeForHVX(DataType);
```

llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h

Lines changed: 4 additions & 3 deletions
```diff
@@ -165,9 +165,10 @@ class HexagonTTIImpl final : public BasicTTIImplBase<HexagonTTIImpl> {
   }
 
   bool isLegalMaskedStore(Type *DataType, Align Alignment,
-                          unsigned AddressSpace) const override;
-  bool isLegalMaskedLoad(Type *DataType, Align Alignment,
-                         unsigned AddressSpace) const override;
+                          unsigned AddressSpace,
+                          TTI::MaskKind MaskKind) const override;
+  bool isLegalMaskedLoad(Type *DataType, Align Alignment, unsigned AddressSpace,
+                         TTI::MaskKind MaskKind) const override;
   bool isLegalMaskedGather(Type *Ty, Align Alignment) const override;
   bool isLegalMaskedScatter(Type *Ty, Align Alignment) const override;
   bool forceScalarizeMaskedGather(VectorType *VTy,
```
