Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 0 additions & 11 deletions llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -311,17 +311,6 @@ void NVPTXInstPrinter::printLdStCode(const MCInst *MI, int OpNum,
default:
llvm_unreachable("Unknown register type");
}
} else if (Modifier == "vec") {
switch (Imm) {
case NVPTX::PTXLdStInstCode::V2:
O << ".v2";
return;
case NVPTX::PTXLdStInstCode::V4:
O << ".v4";
return;
}
// TODO: evaluate whether cases not covered by this switch are bugs
return;
}
llvm_unreachable(formatv("Unknown Modifier: {}", Modifier).str().c_str());
}
Expand Down
12 changes: 1 addition & 11 deletions llvm/lib/Target/NVPTX/NVPTX.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,17 +190,7 @@ enum AddressSpace : AddressSpaceUnderlyingType {
};

namespace PTXLdStInstCode {
enum FromType {
Unsigned = 0,
Signed,
Float,
Untyped
};
enum VecType {
Scalar = 1,
V2 = 2,
V4 = 4
};
enum FromType { Unsigned = 0, Signed, Float, Untyped };
} // namespace PTXLdStInstCode

/// PTXCvtMode - Conversion code enumeration
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ static bool eliminateMove(MachineInstr &Mov, const MachineRegisterInfo &MRI,
const MachineOperand *ParamSymbol = Mov.uses().begin();
assert(ParamSymbol->isSymbol());

constexpr unsigned LDInstBasePtrOpIdx = 6;
constexpr unsigned LDInstBasePtrOpIdx = 5;
constexpr unsigned LDInstAddrSpaceOpIdx = 2;
for (auto *LI : LoadInsts) {
(LI->uses().begin() + LDInstBasePtrOpIdx)
Expand Down
126 changes: 66 additions & 60 deletions llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
return;
case NVPTXISD::LoadV2:
case NVPTXISD::LoadV4:
case NVPTXISD::LoadV8:
if (tryLoadVector(N))
return;
break;
Expand All @@ -139,6 +140,7 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
break;
case NVPTXISD::StoreV2:
case NVPTXISD::StoreV4:
case NVPTXISD::StoreV8:
if (tryStoreVector(N))
return;
break;
Expand Down Expand Up @@ -1012,11 +1014,11 @@ void NVPTXDAGToDAGISel::SelectAddrSpaceCast(SDNode *N) {

// Helper function to reduce the amount of boilerplate code needed for
// opcode selection.
static std::optional<unsigned>
pickOpcodeForVT(MVT::SimpleValueType VT, unsigned Opcode_i8,
unsigned Opcode_i16, unsigned Opcode_i32,
std::optional<unsigned> Opcode_i64, unsigned Opcode_f32,
std::optional<unsigned> Opcode_f64) {
static std::optional<unsigned> pickOpcodeForVT(
MVT::SimpleValueType VT, std::optional<unsigned> Opcode_i8,
std::optional<unsigned> Opcode_i16, std::optional<unsigned> Opcode_i32,
std::optional<unsigned> Opcode_i64, std::optional<unsigned> Opcode_f32,
std::optional<unsigned> Opcode_f64) {
switch (VT) {
case MVT::i1:
case MVT::i8:
Expand Down Expand Up @@ -1091,7 +1093,6 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
SDValue Ops[] = {getI32Imm(Ordering, DL),
getI32Imm(Scope, DL),
getI32Imm(CodeAddrSpace, DL),
getI32Imm(NVPTX::PTXLdStInstCode::Scalar, DL),
getI32Imm(FromType, DL),
getI32Imm(FromTypeWidth, DL),
Base,
Expand Down Expand Up @@ -1128,6 +1129,22 @@ static bool isSubVectorPackedInI32(EVT EltVT) {
return Isv2x16VT(EltVT) || EltVT == MVT::v4i8;
}

/// Returns the element count encoded in the opcode of a vector load/store
/// node: 2 for LoadV2/StoreV2, 4 for LoadV4/StoreV4, 8 for LoadV8/StoreV8.
/// Any other opcode is treated as unreachable.
static unsigned getLoadStoreVectorNumElts(SDNode *N) {
  const unsigned Opc = N->getOpcode();

  // Loads and stores of the same width share an element count.
  if (Opc == NVPTXISD::LoadV2 || Opc == NVPTXISD::StoreV2)
    return 2;
  if (Opc == NVPTXISD::LoadV4 || Opc == NVPTXISD::StoreV4)
    return 4;
  if (Opc == NVPTXISD::LoadV8 || Opc == NVPTXISD::StoreV8)
    return 8;

  llvm_unreachable("Unexpected opcode");
}

bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
MemSDNode *MemSD = cast<MemSDNode>(N);
const EVT MemEVT = MemSD->getMemoryVT();
Expand Down Expand Up @@ -1159,35 +1176,21 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
? NVPTX::PTXLdStInstCode::Signed
: NVPTX::PTXLdStInstCode::Untyped;

unsigned VecType;
unsigned FromTypeWidth;
switch (N->getOpcode()) {
case NVPTXISD::LoadV2:
FromTypeWidth = TotalWidth / 2;
VecType = NVPTX::PTXLdStInstCode::V2;
break;
case NVPTXISD::LoadV4:
FromTypeWidth = TotalWidth / 4;
VecType = NVPTX::PTXLdStInstCode::V4;
break;
default:
return false;
}
unsigned FromTypeWidth = TotalWidth / getLoadStoreVectorNumElts(N);

if (isSubVectorPackedInI32(EltVT)) {
assert(ExtensionType == ISD::NON_EXTLOAD);
EltVT = MVT::i32;
}

assert(isPowerOf2_32(FromTypeWidth) && FromTypeWidth >= 8 &&
FromTypeWidth <= 128 && TotalWidth <= 128 && "Invalid width for load");
FromTypeWidth <= 128 && TotalWidth <= 256 && "Invalid width for load");

SDValue Offset, Base;
SelectADDR(N->getOperand(1), Base, Offset);
SDValue Ops[] = {getI32Imm(Ordering, DL),
getI32Imm(Scope, DL),
getI32Imm(CodeAddrSpace, DL),
getI32Imm(VecType, DL),
getI32Imm(FromType, DL),
getI32Imm(FromTypeWidth, DL),
Base,
Expand All @@ -1205,9 +1208,16 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
NVPTX::LDV_f32_v2, NVPTX::LDV_f64_v2);
break;
case NVPTXISD::LoadV4:
Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::LDV_i8_v4,
NVPTX::LDV_i16_v4, NVPTX::LDV_i32_v4, std::nullopt,
NVPTX::LDV_f32_v4, std::nullopt);
Opcode =
pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::LDV_i8_v4,
NVPTX::LDV_i16_v4, NVPTX::LDV_i32_v4, NVPTX::LDV_i64_v4,
NVPTX::LDV_f32_v4, NVPTX::LDV_f64_v4);
break;
case NVPTXISD::LoadV8:
Opcode =
pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, {/* no v8i8 */},
{/* no v8i16 */}, NVPTX::LDV_i32_v8, {/* no v8i64 */},
NVPTX::LDV_f32_v8, {/* no v8f64 */});
break;
}
if (!Opcode)
Expand Down Expand Up @@ -1303,13 +1313,20 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
Opcode = pickOpcodeForVT(
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v4i8_ELE,
NVPTX::INT_PTX_LDG_G_v4i16_ELE, NVPTX::INT_PTX_LDG_G_v4i32_ELE,
std::nullopt, NVPTX::INT_PTX_LDG_G_v4f32_ELE, std::nullopt);
NVPTX::INT_PTX_LDG_G_v4i64_ELE, NVPTX::INT_PTX_LDG_G_v4f32_ELE,
NVPTX::INT_PTX_LDG_G_v4f64_ELE);
break;
case NVPTXISD::LDUV4:
Opcode = pickOpcodeForVT(
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v4i8_ELE,
NVPTX::INT_PTX_LDU_G_v4i16_ELE, NVPTX::INT_PTX_LDU_G_v4i32_ELE,
std::nullopt, NVPTX::INT_PTX_LDU_G_v4f32_ELE, std::nullopt);
{/* no v4i64 */}, NVPTX::INT_PTX_LDU_G_v4f32_ELE, {/* no v4f64 */});
break;
case NVPTXISD::LoadV8:
Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, {/* no v8i8 */},
{/* no v8i16 */}, NVPTX::INT_PTX_LDG_G_v8i32_ELE,
{/* no v8i64 */}, NVPTX::INT_PTX_LDG_G_v8f32_ELE,
{/* no v8f64 */});
break;
}
if (!Opcode)
Expand Down Expand Up @@ -1395,7 +1412,6 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
getI32Imm(Ordering, DL),
getI32Imm(Scope, DL),
getI32Imm(CodeAddrSpace, DL),
getI32Imm(NVPTX::PTXLdStInstCode::Scalar, DL),
getI32Imm(NVPTX::PTXLdStInstCode::Untyped, DL),
getI32Imm(ToTypeWidth, DL),
Base,
Expand Down Expand Up @@ -1443,41 +1459,24 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
// - for integer type, always use 'u'
const unsigned TotalWidth = StoreVT.getSimpleVT().getSizeInBits();

SmallVector<SDValue, 12> Ops;
SDValue N2;
unsigned VecType;
unsigned ToTypeWidth;
unsigned NumElts = getLoadStoreVectorNumElts(N);

switch (N->getOpcode()) {
case NVPTXISD::StoreV2:
VecType = NVPTX::PTXLdStInstCode::V2;
Ops.append({N->getOperand(1), N->getOperand(2)});
N2 = N->getOperand(3);
ToTypeWidth = TotalWidth / 2;
break;
case NVPTXISD::StoreV4:
VecType = NVPTX::PTXLdStInstCode::V4;
Ops.append({N->getOperand(1), N->getOperand(2), N->getOperand(3),
N->getOperand(4)});
N2 = N->getOperand(5);
ToTypeWidth = TotalWidth / 4;
break;
default:
return false;
}
SmallVector<SDValue, 16> Ops(N->ops().slice(1, NumElts));
SDValue N2 = N->getOperand(NumElts + 1);
unsigned ToTypeWidth = TotalWidth / NumElts;

if (isSubVectorPackedInI32(EltVT)) {
EltVT = MVT::i32;
}

assert(isPowerOf2_32(ToTypeWidth) && ToTypeWidth >= 8 && ToTypeWidth <= 128 &&
TotalWidth <= 128 && "Invalid width for store");
TotalWidth <= 256 && "Invalid width for store");

SDValue Offset, Base;
SelectADDR(N2, Base, Offset);

Ops.append({getI32Imm(Ordering, DL), getI32Imm(Scope, DL),
getI32Imm(CodeAddrSpace, DL), getI32Imm(VecType, DL),
getI32Imm(CodeAddrSpace, DL),
getI32Imm(NVPTX::PTXLdStInstCode::Untyped, DL),
getI32Imm(ToTypeWidth, DL), Base, Offset, Chain});

Expand All @@ -1492,9 +1491,16 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
NVPTX::STV_f32_v2, NVPTX::STV_f64_v2);
break;
case NVPTXISD::StoreV4:
Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::STV_i8_v4,
NVPTX::STV_i16_v4, NVPTX::STV_i32_v4, std::nullopt,
NVPTX::STV_f32_v4, std::nullopt);
Opcode =
pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::STV_i8_v4,
NVPTX::STV_i16_v4, NVPTX::STV_i32_v4, NVPTX::STV_i64_v4,
NVPTX::STV_f32_v4, NVPTX::STV_f64_v4);
break;
case NVPTXISD::StoreV8:
Opcode =
pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, {/* no v8i8 */},
{/* no v8i16 */}, NVPTX::STV_i32_v8, {/* no v8i64 */},
NVPTX::STV_f32_v8, {/* no v8f64 */});
break;
}

Expand Down Expand Up @@ -1554,10 +1560,10 @@ bool NVPTXDAGToDAGISel::tryLoadParam(SDNode *Node) {
NVPTX::LoadParamMemV2F64);
break;
case 4:
Opcode =
pickOpcodeForVT(MemVT.getSimpleVT().SimpleTy, NVPTX::LoadParamMemV4I8,
NVPTX::LoadParamMemV4I16, NVPTX::LoadParamMemV4I32,
std::nullopt, NVPTX::LoadParamMemV4F32, std::nullopt);
Opcode = pickOpcodeForVT(MemVT.getSimpleVT().SimpleTy,
NVPTX::LoadParamMemV4I8, NVPTX::LoadParamMemV4I16,
NVPTX::LoadParamMemV4I32, {/* no v4i64 */},
NVPTX::LoadParamMemV4F32, {/* no v4f64 */});
break;
}
if (!Opcode)
Expand Down Expand Up @@ -1648,8 +1654,8 @@ bool NVPTXDAGToDAGISel::tryStoreRetval(SDNode *N) {
case 4:
Opcode = pickOpcodeForVT(Mem->getMemoryVT().getSimpleVT().SimpleTy,
NVPTX::StoreRetvalV4I8, NVPTX::StoreRetvalV4I16,
NVPTX::StoreRetvalV4I32, std::nullopt,
NVPTX::StoreRetvalV4F32, std::nullopt);
NVPTX::StoreRetvalV4I32, {/* no v4i64 */},
NVPTX::StoreRetvalV4F32, {/* no v4f64 */});
break;
}
if (!Opcode)
Expand Down
Loading