Skip to content

Commit 196ca9b

Browse files
committed
[NVPTX] Lower LLVM masked vector stores to PTX using the new sink symbol syntax
1 parent d2738c0 commit 196ca9b

20 files changed

+530
-18
lines changed

llvm/include/llvm/Analysis/TargetTransformInfo.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -810,9 +810,13 @@ class TargetTransformInfo {
810810
LLVM_ABI AddressingModeKind
811811
getPreferredAddressingMode(const Loop *L, ScalarEvolution *SE) const;
812812

813-
/// Return true if the target supports masked store.
813+
/// Return true if the target supports masked store. A value of false for
814+
/// IsMaskConstant indicates that the mask could either be variable or
815+
/// constant. This is for targets that only support masked store with a
816+
/// constant mask.
814817
LLVM_ABI bool isLegalMaskedStore(Type *DataType, Align Alignment,
815-
unsigned AddressSpace) const;
818+
unsigned AddressSpace,
819+
bool IsMaskConstant = false) const;
816820
/// Return true if the target supports masked load.
817821
LLVM_ABI bool isLegalMaskedLoad(Type *DataType, Align Alignment,
818822
unsigned AddressSpace) const;

llvm/include/llvm/Analysis/TargetTransformInfoImpl.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ class TargetTransformInfoImplBase {
309309
}
310310

311311
virtual bool isLegalMaskedStore(Type *DataType, Align Alignment,
312-
unsigned AddressSpace) const {
312+
unsigned AddressSpace, bool IsMaskConstant) const {
313313
return false;
314314
}
315315

llvm/lib/Analysis/TargetTransformInfo.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -467,8 +467,8 @@ TargetTransformInfo::getPreferredAddressingMode(const Loop *L,
467467
}
468468

469469
bool TargetTransformInfo::isLegalMaskedStore(Type *DataType, Align Alignment,
470-
unsigned AddressSpace) const {
471-
return TTIImpl->isLegalMaskedStore(DataType, Alignment, AddressSpace);
470+
unsigned AddressSpace, bool IsMaskConstant) const {
471+
return TTIImpl->isLegalMaskedStore(DataType, Alignment, AddressSpace, IsMaskConstant);
472472
}
473473

474474
bool TargetTransformInfo::isLegalMaskedLoad(Type *DataType, Align Alignment,

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
321321
}
322322

323323
bool isLegalMaskedStore(Type *DataType, Align Alignment,
324-
unsigned /*AddressSpace*/) const override {
324+
unsigned /*AddressSpace*/, bool /*IsMaskConstant*/) const override {
325325
return isLegalMaskedLoadStore(DataType, Alignment);
326326
}
327327

llvm/lib/Target/ARM/ARMTargetTransformInfo.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ class ARMTTIImpl final : public BasicTTIImplBase<ARMTTIImpl> {
190190
unsigned AddressSpace) const override;
191191

192192
bool isLegalMaskedStore(Type *DataTy, Align Alignment,
193-
unsigned AddressSpace) const override {
193+
unsigned AddressSpace, bool /*IsMaskConstant*/) const override {
194194
return isLegalMaskedLoad(DataTy, Alignment, AddressSpace);
195195
}
196196

llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,7 @@ InstructionCost HexagonTTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
341341
}
342342

343343
bool HexagonTTIImpl::isLegalMaskedStore(Type *DataType, Align /*Alignment*/,
344-
unsigned /*AddressSpace*/) const {
344+
unsigned /*AddressSpace*/, bool /*IsMaskConstant*/) const {
345345
// This function is called from scalarize-masked-mem-intrin, which runs
346346
// in pre-isel. Use ST directly instead of calling isHVXVectorType.
347347
return HexagonMaskedVMem && ST.isTypeForHVX(DataType);

llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ class HexagonTTIImpl final : public BasicTTIImplBase<HexagonTTIImpl> {
166166
}
167167

168168
bool isLegalMaskedStore(Type *DataType, Align Alignment,
169-
unsigned AddressSpace) const override;
169+
unsigned AddressSpace, bool IsMaskConstant) const override;
170170
bool isLegalMaskedLoad(Type *DataType, Align Alignment,
171171
unsigned AddressSpace) const override;
172172

llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,16 @@ void NVPTXInstPrinter::printMemOperand(const MCInst *MI, int OpNum,
392392
}
393393
}
394394

395+
void NVPTXInstPrinter::printRegisterOrSinkSymbol(const MCInst *MI, int OpNum,
396+
raw_ostream &O,
397+
const char *Modifier) {
398+
const MCOperand &Op = MI->getOperand(OpNum);
399+
if (Op.isReg() && Op.getReg() == MCRegister::NoRegister)
400+
O << "_";
401+
else
402+
printOperand(MI, OpNum, O);
403+
}
404+
395405
void NVPTXInstPrinter::printHexu32imm(const MCInst *MI, int OpNum,
396406
raw_ostream &O) {
397407
int64_t Imm = MI->getOperand(OpNum).getImm();

llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ class NVPTXInstPrinter : public MCInstPrinter {
4646
StringRef Modifier = {});
4747
void printMemOperand(const MCInst *MI, int OpNum, raw_ostream &O,
4848
StringRef Modifier = {});
49+
void printRegisterOrSinkSymbol(const MCInst *MI, int OpNum, raw_ostream &O,
50+
const char *Modifier = nullptr);
4951
void printHexu32imm(const MCInst *MI, int OpNum, raw_ostream &O);
5052
void printProtoIdent(const MCInst *MI, int OpNum, raw_ostream &O);
5153
void printPrmtMode(const MCInst *MI, int OpNum, raw_ostream &O);

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 88 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -753,7 +753,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
753753
setOperationAction({ISD::LOAD, ISD::STORE}, {MVT::i128, MVT::f128}, Custom);
754754
for (MVT VT : MVT::fixedlen_vector_valuetypes())
755755
if (!isTypeLegal(VT) && VT.getStoreSizeInBits() <= 256)
756-
setOperationAction({ISD::STORE, ISD::LOAD}, VT, Custom);
756+
setOperationAction({ISD::STORE, ISD::LOAD, ISD::MSTORE}, VT, Custom);
757757

758758
// Custom legalization for LDU intrinsics.
759759
// TODO: The logic to lower these is not very robust and we should rewrite it.
@@ -2869,6 +2869,87 @@ static SDValue lowerSELECT(SDValue Op, SelectionDAG &DAG) {
28692869
return Or;
28702870
}
28712871

2872+
/// Lower a masked vector store (ISD::MSTORE) to an NVPTXISD::StoreV4 or
/// NVPTXISD::StoreV8 memory node.
///
/// The mask is folded into the stored values: a lane whose mask element is 0
/// is replaced by the sentinel register MCRegister::NoRegister, while an
/// enabled lane is extracted from the value vector. The mask must be a
/// BUILD_VECTOR of constant i1 elements with one element per stored lane.
///
/// \param Op  The MSTORE node. Its operands are read as
///            (chain, value, base pointer, offset, mask).
/// \param DAG The SelectionDAG in which the replacement nodes are created.
/// \returns The new store node, whose only result type is MVT::Other.
static SDValue lowerMSTORE(SDValue Op, SelectionDAG &DAG) {
  SDNode *N = Op.getNode();

  SDValue Chain = N->getOperand(0);
  SDValue Val = N->getOperand(1);
  SDValue BasePtr = N->getOperand(2);
  SDValue Offset = N->getOperand(3);
  SDValue Mask = N->getOperand(4);

  SDLoc DL(N);
  EVT ValVT = Val.getValueType();
  MemSDNode *MemSD = cast<MemSDNode>(N);
  assert(ValVT.isVector() && "Masked vector store must have vector type");
  assert(MemSD->getAlign() >= DAG.getEVTAlign(ValVT) &&
         "Unexpected alignment for masked store");

  // Select the vector store opcode from the stored value type.
  unsigned Opcode = 0;
  switch (ValVT.getSimpleVT().SimpleTy) {
  default:
    llvm_unreachable("Unexpected masked vector store type");
  case MVT::v4i64:
  case MVT::v4f64:
    Opcode = NVPTXISD::StoreV4;
    break;
  case MVT::v8i32:
  case MVT::v8f32:
    Opcode = NVPTXISD::StoreV8;
    break;
  }

  const EVT EltVT = ValVT.getVectorElementType();
  const unsigned NumElts = ValVT.getVectorNumElements();

  assert(Mask.getValueType().isVector() &&
         Mask.getValueType().getVectorElementType() == MVT::i1 &&
         "Mask must be a vector of i1");
  assert(Mask.getOpcode() == ISD::BUILD_VECTOR &&
         "Mask expected to be a BUILD_VECTOR");
  assert(Mask.getValueType().getVectorNumElements() == NumElts &&
         "Mask size must be the same as the vector size");

  // Operand layout of the new node: chain, one operand per lane, pointer,
  // offset. Inline capacity covers the widest case (1 + 8 + 1 + 1).
  SmallVector<SDValue, 12> Ops;
  Ops.push_back(Chain);

  // Encode the mask into the per-lane value operands.
  for (unsigned I : llvm::seq(NumElts)) {
    assert(isa<ConstantSDNode>(Mask.getOperand(I)) &&
           "Mask elements must be constants");
    if (Mask->getConstantOperandVal(I) == 0) {
      // Masked-off lane: append the sentinel register 0. This will be handled
      // in tablegen.
      Ops.push_back(DAG.getRegister(MCRegister::NoRegister, EltVT));
      continue;
    }
    // Enabled lane: extract the element to store. The index is built with the
    // target's vector-index type rather than the pointer type.
    Ops.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
                              DAG.getVectorIdxConstant(I, DL)));
  }

  // Next, the pointer operand.
  Ops.push_back(BasePtr);

  // Finally, the offset operand. We expect this to always be undef, and it will
  // be ignored in lowering, but to mirror the handling of the other vector
  // store instructions we include it in the new SDNode.
  assert(Offset.getOpcode() == ISD::UNDEF &&
         "Offset operand expected to be undef");
  Ops.push_back(Offset);

  return DAG.getMemIntrinsicNode(Opcode, DL, DAG.getVTList(MVT::Other), Ops,
                                 MemSD->getMemoryVT(), MemSD->getMemOperand());
}
2952+
28722953
SDValue
28732954
NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
28742955
switch (Op.getOpcode()) {
@@ -2905,6 +2986,12 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
29052986
return LowerVECREDUCE(Op, DAG);
29062987
case ISD::STORE:
29072988
return LowerSTORE(Op, DAG);
2989+
case ISD::MSTORE: {
2990+
assert(STI.has256BitVectorLoadStore(
2991+
cast<MemSDNode>(Op.getNode())->getAddressSpace()) &&
2992+
"Masked store vector not supported on subtarget.");
2993+
return lowerMSTORE(Op, DAG);
2994+
}
29082995
case ISD::LOAD:
29092996
return LowerLOAD(Op, DAG);
29102997
case ISD::SHL_PARTS:

0 commit comments

Comments
 (0)