Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions llvm/include/llvm/Analysis/TargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -810,12 +810,20 @@ class TargetTransformInfo {
LLVM_ABI AddressingModeKind
getPreferredAddressingMode(const Loop *L, ScalarEvolution *SE) const;

/// Some targets only support masked load/store with a constant mask.
enum MaskKind {
VariableOrConstantMask,
ConstantMask,
};

/// Return true if the target supports masked store.
LLVM_ABI bool isLegalMaskedStore(Type *DataType, Align Alignment,
unsigned AddressSpace) const;
LLVM_ABI bool
isLegalMaskedStore(Type *DataType, Align Alignment, unsigned AddressSpace,
MaskKind MaskKind = VariableOrConstantMask) const;
/// Return true if the target supports masked load.
LLVM_ABI bool isLegalMaskedLoad(Type *DataType, Align Alignment,
unsigned AddressSpace) const;
LLVM_ABI bool
isLegalMaskedLoad(Type *DataType, Align Alignment, unsigned AddressSpace,
MaskKind MaskKind = VariableOrConstantMask) const;

/// Return true if the target supports nontemporal store.
LLVM_ABI bool isLegalNTStore(Type *DataType, Align Alignment) const;
Expand Down
6 changes: 4 additions & 2 deletions llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -309,12 +309,14 @@ class TargetTransformInfoImplBase {
}

virtual bool isLegalMaskedStore(Type *DataType, Align Alignment,
unsigned AddressSpace) const {
unsigned AddressSpace,
TTI::MaskKind MaskKind) const {
return false;
}

virtual bool isLegalMaskedLoad(Type *DataType, Align Alignment,
unsigned AddressSpace) const {
unsigned AddressSpace,
TTI::MaskKind MaskKind) const {
return false;
}

Expand Down
12 changes: 8 additions & 4 deletions llvm/lib/Analysis/TargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -467,13 +467,17 @@ TargetTransformInfo::getPreferredAddressingMode(const Loop *L,
}

bool TargetTransformInfo::isLegalMaskedStore(Type *DataType, Align Alignment,
unsigned AddressSpace) const {
return TTIImpl->isLegalMaskedStore(DataType, Alignment, AddressSpace);
unsigned AddressSpace,
TTI::MaskKind MaskKind) const {
return TTIImpl->isLegalMaskedStore(DataType, Alignment, AddressSpace,
MaskKind);
}

bool TargetTransformInfo::isLegalMaskedLoad(Type *DataType, Align Alignment,
unsigned AddressSpace) const {
return TTIImpl->isLegalMaskedLoad(DataType, Alignment, AddressSpace);
unsigned AddressSpace,
TTI::MaskKind MaskKind) const {
return TTIImpl->isLegalMaskedLoad(DataType, Alignment, AddressSpace,
MaskKind);
}

bool TargetTransformInfo::isLegalNTStore(Type *DataType,
Expand Down
10 changes: 5 additions & 5 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2449,6 +2449,7 @@ void DAGTypeLegalizer::SplitVecRes_MLOAD(MaskedLoadSDNode *MLD,
SDValue PassThru = MLD->getPassThru();
Align Alignment = MLD->getBaseAlign();
ISD::LoadExtType ExtType = MLD->getExtensionType();
MachineMemOperand::Flags MMOFlags = MLD->getMemOperand()->getFlags();

// Split Mask operand
SDValue MaskLo, MaskHi;
Expand All @@ -2474,9 +2475,8 @@ void DAGTypeLegalizer::SplitVecRes_MLOAD(MaskedLoadSDNode *MLD,
std::tie(PassThruLo, PassThruHi) = DAG.SplitVector(PassThru, dl);

MachineMemOperand *MMO = DAG.getMachineFunction().getMachineMemOperand(
MLD->getPointerInfo(), MachineMemOperand::MOLoad,
LocationSize::beforeOrAfterPointer(), Alignment, MLD->getAAInfo(),
MLD->getRanges());
MLD->getPointerInfo(), MMOFlags, LocationSize::beforeOrAfterPointer(),
Alignment, MLD->getAAInfo(), MLD->getRanges());

Lo = DAG.getMaskedLoad(LoVT, dl, Ch, Ptr, Offset, MaskLo, PassThruLo, LoMemVT,
MMO, MLD->getAddressingMode(), ExtType,
Expand All @@ -2499,8 +2499,8 @@ void DAGTypeLegalizer::SplitVecRes_MLOAD(MaskedLoadSDNode *MLD,
LoMemVT.getStoreSize().getFixedValue());

MMO = DAG.getMachineFunction().getMachineMemOperand(
MPI, MachineMemOperand::MOLoad, LocationSize::beforeOrAfterPointer(),
Alignment, MLD->getAAInfo(), MLD->getRanges());
MPI, MMOFlags, LocationSize::beforeOrAfterPointer(), Alignment,
MLD->getAAInfo(), MLD->getRanges());

Hi = DAG.getMaskedLoad(HiVT, dl, Ch, Ptr, Offset, MaskHi, PassThruHi,
HiMemVT, MMO, MLD->getAddressingMode(), ExtType,
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5044,6 +5044,8 @@ void SelectionDAGBuilder::visitMaskedLoad(const CallInst &I, bool IsExpanding) {
auto MMOFlags = MachineMemOperand::MOLoad;
if (I.hasMetadata(LLVMContext::MD_nontemporal))
MMOFlags |= MachineMemOperand::MONonTemporal;
if (I.hasMetadata(LLVMContext::MD_invariant_load))
MMOFlags |= MachineMemOperand::MOInvariant;

MachineMemOperand *MMO = DAG.getMachineFunction().getMachineMemOperand(
MachinePointerInfo(PtrOperand), MMOFlags,
Expand Down
6 changes: 4 additions & 2 deletions llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -316,12 +316,14 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
}

bool isLegalMaskedLoad(Type *DataType, Align Alignment,
unsigned /*AddressSpace*/) const override {
unsigned /*AddressSpace*/,
TTI::MaskKind /*MaskKind*/) const override {
return isLegalMaskedLoadStore(DataType, Alignment);
}

bool isLegalMaskedStore(Type *DataType, Align Alignment,
unsigned /*AddressSpace*/) const override {
unsigned /*AddressSpace*/,
TTI::MaskKind /*MaskKind*/) const override {
return isLegalMaskedLoadStore(DataType, Alignment);
}

Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1125,7 +1125,8 @@ bool ARMTTIImpl::isProfitableLSRChainElement(Instruction *I) const {
}

bool ARMTTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment,
unsigned /*AddressSpace*/) const {
unsigned /*AddressSpace*/,
TTI::MaskKind /*MaskKind*/) const {
if (!EnableMaskedLoadStores || !ST->hasMVEIntegerOps())
return false;

Expand Down
10 changes: 5 additions & 5 deletions llvm/lib/Target/ARM/ARMTargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,12 +186,12 @@ class ARMTTIImpl final : public BasicTTIImplBase<ARMTTIImpl> {

bool isProfitableLSRChainElement(Instruction *I) const override;

bool isLegalMaskedLoad(Type *DataTy, Align Alignment,
unsigned AddressSpace) const override;
bool isLegalMaskedLoad(Type *DataTy, Align Alignment, unsigned AddressSpace,
TTI::MaskKind MaskKind) const override;

bool isLegalMaskedStore(Type *DataTy, Align Alignment,
unsigned AddressSpace) const override {
return isLegalMaskedLoad(DataTy, Alignment, AddressSpace);
bool isLegalMaskedStore(Type *DataTy, Align Alignment, unsigned AddressSpace,
TTI::MaskKind MaskKind) const override {
return isLegalMaskedLoad(DataTy, Alignment, AddressSpace, MaskKind);
}

bool forceScalarizeMaskedGather(VectorType *VTy,
Expand Down
6 changes: 4 additions & 2 deletions llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -341,14 +341,16 @@ InstructionCost HexagonTTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
}

bool HexagonTTIImpl::isLegalMaskedStore(Type *DataType, Align /*Alignment*/,
unsigned /*AddressSpace*/) const {
unsigned /*AddressSpace*/,
TTI::MaskKind /*MaskKind*/) const {
// This function is called from scalarize-masked-mem-intrin, which runs
// in pre-isel. Use ST directly instead of calling isHVXVectorType.
return HexagonMaskedVMem && ST.isTypeForHVX(DataType);
}

bool HexagonTTIImpl::isLegalMaskedLoad(Type *DataType, Align /*Alignment*/,
unsigned /*AddressSpace*/) const {
unsigned /*AddressSpace*/,
TTI::MaskKind /*MaskKind*/) const {
// This function is called from scalarize-masked-mem-intrin, which runs
// in pre-isel. Use ST directly instead of calling isHVXVectorType.
return HexagonMaskedVMem && ST.isTypeForHVX(DataType);
Expand Down
7 changes: 4 additions & 3 deletions llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,10 @@ class HexagonTTIImpl final : public BasicTTIImplBase<HexagonTTIImpl> {
}

bool isLegalMaskedStore(Type *DataType, Align Alignment,
unsigned AddressSpace) const override;
bool isLegalMaskedLoad(Type *DataType, Align Alignment,
unsigned AddressSpace) const override;
unsigned AddressSpace,
TTI::MaskKind MaskKind) const override;
bool isLegalMaskedLoad(Type *DataType, Align Alignment, unsigned AddressSpace,
TTI::MaskKind MaskKind) const override;

/// @}

Expand Down
20 changes: 20 additions & 0 deletions llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,26 @@ void NVPTXInstPrinter::printMemOperand(const MCInst *MI, int OpNum,
}
}

void NVPTXInstPrinter::printUsedBytesMaskPragma(const MCInst *MI, int OpNum,
raw_ostream &O) {
auto &Op = MI->getOperand(OpNum);
assert(Op.isImm() && "Invalid operand");
uint32_t Imm = (uint32_t)Op.getImm();
if (Imm != UINT32_MAX) {
O << ".pragma \"used_bytes_mask " << Imm << "\";\n\t";
}
}

void NVPTXInstPrinter::printRegisterOrSinkSymbol(const MCInst *MI, int OpNum,
raw_ostream &O,
const char *Modifier) {
const MCOperand &Op = MI->getOperand(OpNum);
if (Op.isReg() && Op.getReg() == MCRegister::NoRegister)
O << "_";
else
printOperand(MI, OpNum, O);
}

void NVPTXInstPrinter::printHexu32imm(const MCInst *MI, int OpNum,
raw_ostream &O) {
int64_t Imm = MI->getOperand(OpNum).getImm();
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ class NVPTXInstPrinter : public MCInstPrinter {
StringRef Modifier = {});
void printMemOperand(const MCInst *MI, int OpNum, raw_ostream &O,
StringRef Modifier = {});
void printUsedBytesMaskPragma(const MCInst *MI, int OpNum, raw_ostream &O);
void printRegisterOrSinkSymbol(const MCInst *MI, int OpNum, raw_ostream &O,
const char *Modifier = nullptr);
void printHexu32imm(const MCInst *MI, int OpNum, raw_ostream &O);
void printProtoIdent(const MCInst *MI, int OpNum, raw_ostream &O);
void printPrmtMode(const MCInst *MI, int OpNum, raw_ostream &O);
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ static bool eliminateMove(MachineInstr &Mov, const MachineRegisterInfo &MRI,
const MachineOperand *ParamSymbol = Mov.uses().begin();
assert(ParamSymbol->isSymbol());

constexpr unsigned LDInstBasePtrOpIdx = 5;
constexpr unsigned LDInstBasePtrOpIdx = 6;
constexpr unsigned LDInstAddrSpaceOpIdx = 2;
for (auto *LI : LoadInsts) {
(LI->uses().begin() + LDInstBasePtrOpIdx)
Expand Down
33 changes: 31 additions & 2 deletions llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
switch (N->getOpcode()) {
case ISD::LOAD:
case ISD::ATOMIC_LOAD:
case NVPTXISD::MLoadV1:
if (tryLoad(N))
return;
break;
Expand Down Expand Up @@ -1118,6 +1119,19 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
? NVPTX::PTXLdStInstCode::Signed
: NVPTX::PTXLdStInstCode::Untyped;

uint32_t UsedBytesMask;
switch (N->getOpcode()) {
case ISD::LOAD:
case ISD::ATOMIC_LOAD:
UsedBytesMask = UINT32_MAX;
break;
case NVPTXISD::MLoadV1:
UsedBytesMask = N->getConstantOperandVal(N->getNumOperands() - 2);
break;
default:
llvm_unreachable("Unexpected opcode");
}

assert(isPowerOf2_32(FromTypeWidth) && FromTypeWidth >= 8 &&
FromTypeWidth <= 128 && "Invalid width for load");

Expand All @@ -1128,6 +1142,7 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
getI32Imm(CodeAddrSpace, DL),
getI32Imm(FromType, DL),
getI32Imm(FromTypeWidth, DL),
getI32Imm(UsedBytesMask, DL),
Base,
Offset,
Chain};
Expand Down Expand Up @@ -1190,6 +1205,8 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
: NVPTX::PTXLdStInstCode::Untyped;

const unsigned FromTypeWidth = getFromTypeWidthForLoad(LD);
const uint32_t UsedBytesMask =
N->getConstantOperandVal(N->getNumOperands() - 2);

assert(!(EltVT.isVector() && ExtensionType != ISD::NON_EXTLOAD));

Expand All @@ -1199,6 +1216,7 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
getI32Imm(CodeAddrSpace, DL),
getI32Imm(FromType, DL),
getI32Imm(FromTypeWidth, DL),
getI32Imm(UsedBytesMask, DL),
Base,
Offset,
Chain};
Expand Down Expand Up @@ -1236,10 +1254,13 @@ bool NVPTXDAGToDAGISel::tryLDG(MemSDNode *LD) {
SDLoc DL(LD);

unsigned ExtensionType;
uint32_t UsedBytesMask;
if (const auto *Load = dyn_cast<LoadSDNode>(LD)) {
ExtensionType = Load->getExtensionType();
UsedBytesMask = UINT32_MAX;
} else {
ExtensionType = LD->getConstantOperandVal(LD->getNumOperands() - 1);
UsedBytesMask = LD->getConstantOperandVal(LD->getNumOperands() - 2);
}
const unsigned FromType = (ExtensionType == ISD::SEXTLOAD)
? NVPTX::PTXLdStInstCode::Signed
Expand All @@ -1251,8 +1272,12 @@ bool NVPTXDAGToDAGISel::tryLDG(MemSDNode *LD) {
ExtensionType != ISD::NON_EXTLOAD));

const auto [Base, Offset] = selectADDR(LD->getOperand(1), CurDAG);
SDValue Ops[] = {getI32Imm(FromType, DL), getI32Imm(FromTypeWidth, DL), Base,
Offset, LD->getChain()};
SDValue Ops[] = {getI32Imm(FromType, DL),
getI32Imm(FromTypeWidth, DL),
getI32Imm(UsedBytesMask, DL),
Base,
Offset,
LD->getChain()};

const MVT::SimpleValueType TargetVT = LD->getSimpleValueType(0).SimpleTy;
std::optional<unsigned> Opcode;
Expand All @@ -1263,6 +1288,10 @@ bool NVPTXDAGToDAGISel::tryLDG(MemSDNode *LD) {
Opcode = pickOpcodeForVT(TargetVT, NVPTX::LD_GLOBAL_NC_i16,
NVPTX::LD_GLOBAL_NC_i32, NVPTX::LD_GLOBAL_NC_i64);
break;
case NVPTXISD::MLoadV1:
Opcode = pickOpcodeForVT(TargetVT, std::nullopt, NVPTX::LD_GLOBAL_NC_i32,
NVPTX::LD_GLOBAL_NC_i64);
break;
case NVPTXISD::LoadV2:
Opcode =
pickOpcodeForVT(TargetVT, NVPTX::LD_GLOBAL_NC_v2i16,
Expand Down
Loading
Loading