Merged
1 change: 1 addition & 0 deletions llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -1027,6 +1027,7 @@ pickOpcodeForVT(MVT::SimpleValueType VT, std::optional<unsigned> Opcode_i16,
case MVT::f32:
return Opcode_i32;
case MVT::v2f32:
case MVT::v2i32:
case MVT::i64:
case MVT::f64:
return Opcode_i64;
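Aside for reviewers: the hunk above extends pickOpcodeForVT's type-to-opcode dispatch. A minimal standalone sketch of that dispatch shape (pickOpcodeSketch and its enum are illustrative inventions, not the actual LLVM signature): types that fit one 64-bit register, now including the packed v2i32, all pick the i64 opcode variant.

#include <optional>

enum SimpleVT { i16, f16, i32, f32, v2f32, v2i32, i64, f64 };

// Map a value type to the opcode variant for its register width.
static std::optional<unsigned>
pickOpcodeSketch(SimpleVT VT, std::optional<unsigned> Opc16,
                 std::optional<unsigned> Opc32, std::optional<unsigned> Opc64) {
  switch (VT) {
  case i16: case f16:
    return Opc16;
  case i32: case f32:
    return Opc32;
  case v2f32: case v2i32: // packed pairs occupy one 64-bit register
  case i64: case f64:
    return Opc64;
  default:
    return std::nullopt; // no matching opcode variant
  }
}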
190 changes: 114 additions & 76 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -226,21 +226,20 @@ getVectorLoweringShape(EVT VectorEVT, const NVPTXSubtarget &STI,
switch (VectorVT.SimpleTy) {
default:
return std::nullopt;

case MVT::v4i64:
case MVT::v4f64:
case MVT::v8i32:
// This is a "native" vector type iff the address space is global
// and the target supports 256-bit loads/stores
// This is a "native" vector type iff the address space is global and the
// target supports 256-bit loads/stores
if (!CanLowerTo256Bit)
return std::nullopt;
LLVM_FALLTHROUGH;
case MVT::v2i8:
case MVT::v2i32:
case MVT::v2i64:
case MVT::v2f64:
case MVT::v4i32:
// This is a "native" vector type
return std::pair(NumElts, EltVT);

case MVT::v16f16: // <8 x f16x2>
case MVT::v16bf16: // <8 x bf16x2>
case MVT::v16i16: // <8 x i16x2>
@@ -264,12 +263,18 @@ getVectorLoweringShape(EVT VectorEVT, const NVPTXSubtarget &STI,
case MVT::v16i8: // <4 x i8x4>
PackRegSize = 32;
break;

case MVT::v8f32: // <4 x f32x2>
case MVT::v8i32: // <4 x i32x2>
// This is a "native" vector type iff the address space is global and the
// target supports 256-bit loads/stores
if (!CanLowerTo256Bit)
return std::nullopt;
LLVM_FALLTHROUGH;
case MVT::v2f32: // <1 x f32x2>
case MVT::v4f32: // <2 x f32x2>
case MVT::v2i32: // <1 x i32x2>
case MVT::v4i32: // <2 x i32x2>
if (!STI.hasF32x2Instructions())
return std::pair(NumElts, EltVT);
PackRegSize = 64;
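The PackRegSize bookkeeping above boils down to a simple split rule: a vector lowers to NumElts / EltsPerReg packed registers, where each packed register holds PackRegSize / EltBits lanes. A hedged sketch of that arithmetic (packedShape is an invented helper, not part of the patch):

#include <cassert>
#include <utility>

// Returns {number of packed registers, lanes per register}.
std::pair<unsigned, unsigned> packedShape(unsigned NumElts, unsigned EltBits,
                                          unsigned PackRegSize) {
  const unsigned EltsPerReg = PackRegSize / EltBits; // 64 / 32 = 2 for i32x2
  assert(NumElts % EltsPerReg == 0 && "vector must split evenly");
  return {NumElts / EltsPerReg, EltsPerReg};
}

// packedShape(8, 32, 64) == {4, 2}: v8i32 lowers to <4 x i32x2>, which is
// why it is only legal when 256-bit loads/stores are available.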
@@ -590,8 +595,10 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
addRegisterClass(MVT::bf16, &NVPTX::B16RegClass);
addRegisterClass(MVT::v2bf16, &NVPTX::B32RegClass);

if (STI.hasF32x2Instructions())
if (STI.hasF32x2Instructions()) {
addRegisterClass(MVT::v2f32, &NVPTX::B64RegClass);
addRegisterClass(MVT::v2i32, &NVPTX::B64RegClass);
}

// Conversion to/from FP16/FP16x2 is always legal.
setOperationAction(ISD::BUILD_VECTOR, MVT::v2f16, Custom);
@@ -628,12 +635,13 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v4i8, Custom);
setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v4i8, Custom);

// No support for these operations with v2f32.
setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v2f32, Expand);
setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v2f32, Expand);
// No support for these operations with v2f32/v2i32
setOperationAction(ISD::INSERT_VECTOR_ELT, {MVT::v2f32, MVT::v2i32}, Expand);
setOperationAction(ISD::VECTOR_SHUFFLE, {MVT::v2f32, MVT::v2i32}, Expand);
// Need custom lowering in case the index is dynamic.
if (STI.hasF32x2Instructions())
setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v2f32, Custom);
setOperationAction(ISD::EXTRACT_VECTOR_ELT, {MVT::v2f32, MVT::v2i32},
Custom);

// Custom conversions to/from v2i8.
setOperationAction(ISD::BITCAST, MVT::v2i8, Custom);
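On the dynamic-index comment above: with both lanes packed in one 64-bit register, a plausible custom expansion is a variable shift instead of a stack round-trip. Illustrative C++ only (extractDynamic is invented; the backend's actual emission may differ):

#include <cstdint>

// Extract lane Idx (0 or 1) from a packed pair; lane 0 sits in the low half.
uint32_t extractDynamic(uint64_t Packed, unsigned Idx) {
  return static_cast<uint32_t>(Packed >> (Idx * 32));
}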
@@ -661,14 +669,13 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
// Operations not directly supported by NVPTX.
for (MVT VT : {MVT::bf16, MVT::f16, MVT::v2bf16, MVT::v2f16, MVT::f32,
MVT::v2f32, MVT::f64, MVT::i1, MVT::i8, MVT::i16, MVT::v2i16,
MVT::v4i8, MVT::i32, MVT::i64}) {
MVT::v4i8, MVT::i32, MVT::v2i32, MVT::i64}) {
setOperationAction(ISD::SELECT_CC, VT, Expand);
setOperationAction(ISD::BR_CC, VT, Expand);
}

// Not directly supported. TLI would attempt to expand operations like
// FMINIMUM(v2f32) using invalid SETCC and VSELECT nodes.
setOperationAction(ISD::VSELECT, MVT::v2f32, Expand);
// We don't want ops like FMINIMUM or UMAX to be lowered to SETCC+VSELECT.
setOperationAction(ISD::VSELECT, {MVT::v2f32, MVT::v2i32}, Expand);

// Some SIGN_EXTEND_INREG can be done using cvt instruction.
// For others we will expand to a SHL/SRA pair.
@@ -815,7 +822,14 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setOperationAction({ISD::SDIV, ISD::UDIV, ISD::SRA, ISD::SRL, ISD::MULHS,
ISD::MULHU, ISD::FP_TO_SINT, ISD::FP_TO_UINT,
ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::SETCC},
MVT::v2i16, Expand);
{MVT::v2i16, MVT::v2i32}, Expand);

// v2i32 is not supported for any arithmetic operations
setOperationAction({ISD::ABS, ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX,
ISD::CTPOP, ISD::CTLZ, ISD::ADD, ISD::SUB, ISD::MUL,
ISD::SHL, ISD::SRA, ISD::SRL, ISD::OR, ISD::AND, ISD::XOR,
ISD::SREM, ISD::UREM},
MVT::v2i32, Expand);

setOperationAction(ISD::ADDC, MVT::i32, Legal);
setOperationAction(ISD::ADDE, MVT::i32, Legal);
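Because there are no packed v2i32 ALU instructions, marking those ops Expand a few lines up lets the legalizer scalarize them. The scalar code it is equivalent to, for intuition (addV2i32 is an invented illustration, not generated code):

#include <cstdint>

struct V2i32 { uint32_t lo, hi; };

// A v2i32 add becomes one i32 add per lane after expansion.
V2i32 addV2i32(V2i32 A, V2i32 B) {
  return {A.lo + B.lo, A.hi + B.hi};
}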
@@ -829,7 +843,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
}

setOperationAction(ISD::CTTZ, MVT::i16, Expand);
setOperationAction(ISD::CTTZ, MVT::v2i16, Expand);
setOperationAction(ISD::CTTZ, {MVT::v2i16, MVT::v2i32}, Expand);
setOperationAction(ISD::CTTZ, MVT::i32, Expand);
setOperationAction(ISD::CTTZ, MVT::i64, Expand);

@@ -1071,7 +1085,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
// Custom lowering for tcgen05.st vector operands
setOperationAction(ISD::INTRINSIC_VOID,
{MVT::v2i32, MVT::v4i32, MVT::v8i32, MVT::v16i32,
MVT::v32i32, MVT::v64i32, MVT::v128i32},
MVT::v32i32, MVT::v64i32, MVT::v128i32, MVT::Other},
Custom);

// Enable custom lowering for the following:
@@ -2604,7 +2618,7 @@ static SDValue LowerVectorArith(SDValue Op, SelectionDAG &DAG) {
return V;
}

static SDValue LowerTcgen05St(SDValue Op, SelectionDAG &DAG) {
static SDValue lowerTcgen05St(SDValue Op, SelectionDAG &DAG) {
SDNode *N = Op.getNode();
SDLoc DL(N);
SmallVector<SDValue, 32> Ops;
@@ -2719,7 +2733,52 @@ static SDValue LowerTcgen05MMADisableOutputLane(SDValue Op, SelectionDAG &DAG) {
return Tcgen05MMANode;
}

static SDValue LowerIntrinsicVoid(SDValue Op, SelectionDAG &DAG) {
// Lower vector return type of tcgen05.ld intrinsics
static std::optional<std::pair<SDValue, SDValue>>
lowerTcgen05Ld(SDNode *N, SelectionDAG &DAG, bool HasOffset = false) {
SDLoc DL(N);
EVT ResVT = N->getValueType(0);
if (!ResVT.isVector())
return {}; // already legalized.

const unsigned NumElts = ResVT.getVectorNumElements();

// Create the return type of the instructions
SmallVector<EVT, 5> ListVTs;
for (unsigned i = 0; i < NumElts; ++i)
ListVTs.push_back(MVT::i32);

ListVTs.push_back(N->getValueType(1)); // Chain

SDVTList ResVTs = DAG.getVTList(ListVTs);

SmallVector<SDValue, 8> Ops{N->getOperand(0), N->getOperand(1),
N->getOperand(2)};

if (HasOffset) {
Ops.push_back(N->getOperand(3)); // offset
Ops.push_back(N->getOperand(4)); // Pack flag
} else
Ops.push_back(N->getOperand(3)); // Pack flag

MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(N);
SDValue NewNode =
DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, ResVTs, Ops,
MemSD->getMemoryVT(), MemSD->getMemOperand());

// split the vector result
SmallVector<SDValue, 4> ScalarRes;
for (unsigned i = 0; i < NumElts; ++i) {
SDValue Res = NewNode.getValue(i);
ScalarRes.push_back(Res);
}

SDValue Chain = NewNode.getValue(NumElts);
SDValue BuildVector = DAG.getNode(ISD::BUILD_VECTOR, DL, ResVT, ScalarRes);
return {{BuildVector, Chain}};
}
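Worked example, assuming the x2 variants return <2 x i32>: NumElts is 2, so ListVTs ends up as {i32, i32, chain}, values 0 and 1 are the lanes, and value 2 is the chain that the caller threads through. A toy sketch of that list construction (buildListVTs is invented for illustration):

#include <string>
#include <vector>

// Mirror of the ListVTs loop above: one i32 per lane, then the chain.
std::vector<std::string> buildListVTs(unsigned NumElts) {
  std::vector<std::string> VTs(NumElts, "i32");
  VTs.push_back("ch"); // chain type comes last
  return VTs;          // NumElts == 2 -> {"i32", "i32", "ch"}
}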

static SDValue lowerIntrinsicVoid(SDValue Op, SelectionDAG &DAG) {
SDNode *N = Op.getNode();
SDValue Intrin = N->getOperand(1);

@@ -2765,7 +2824,7 @@ static SDValue LowerIntrinsicVoid(SDValue Op, SelectionDAG &DAG) {
case Intrinsic::nvvm_tcgen05_st_16x64b_x64:
case Intrinsic::nvvm_tcgen05_st_32x32b_x64:
case Intrinsic::nvvm_tcgen05_st_32x32b_x128:
return LowerTcgen05St(Op, DAG);
return lowerTcgen05St(Op, DAG);
case Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg1:
case Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg2:
case Intrinsic::nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg1:
@@ -2867,6 +2926,28 @@ static SDValue lowerPrmtIntrinsic(SDValue Op, SelectionDAG &DAG) {
SDValue Selector = (Op->op_end() - 1)->get();
return getPRMT(A, B, Selector, DL, DAG, Mode);
}

static SDValue lowerIntrinsicWChain(SDValue Op, SelectionDAG &DAG) {
switch (Op->getConstantOperandVal(1)) {
default:
return Op;

// These tcgen05 intrinsics return a v2i32, which is legal, so we have to
// lower them through LowerOperation() instead of ReplaceNodeResults().
case Intrinsic::nvvm_tcgen05_ld_16x64b_x2:
case Intrinsic::nvvm_tcgen05_ld_16x128b_x1:
case Intrinsic::nvvm_tcgen05_ld_32x32b_x2:
if (auto Res = lowerTcgen05Ld(Op.getNode(), DAG))
return DAG.getMergeValues({Res->first, Res->second}, SDLoc(Op));
return SDValue();

case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x2:
if (auto Res = lowerTcgen05Ld(Op.getNode(), DAG, /*HasOffset=*/true))
return DAG.getMergeValues({Res->first, Res->second}, SDLoc(Op));
return SDValue();
}
}

static SDValue lowerIntrinsicWOChain(SDValue Op, SelectionDAG &DAG) {
switch (Op->getConstantOperandVal(0)) {
default:
@@ -3029,11 +3110,11 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
case ISD::ADDRSPACECAST:
return LowerADDRSPACECAST(Op, DAG);
case ISD::INTRINSIC_W_CHAIN:
return Op;
return lowerIntrinsicWChain(Op, DAG);
case ISD::INTRINSIC_WO_CHAIN:
return lowerIntrinsicWOChain(Op, DAG);
case ISD::INTRINSIC_VOID:
return LowerIntrinsicVoid(Op, DAG);
return lowerIntrinsicVoid(Op, DAG);
case ISD::BUILD_VECTOR:
return LowerBUILD_VECTOR(Op, DAG);
case ISD::BITCAST:
@@ -5920,7 +6001,7 @@ static SDValue PerformEXTRACTCombine(SDNode *N,
IsPTXVectorType(VectorVT.getSimpleVT()))
return SDValue(); // Native vector loads already combine nicely w/
// extract_vector_elt.
// Don't mess with singletons or packed types (v2f32, v2*16, v4i8 and v8i8),
// Don't mess with singletons or packed types (v2*32, v2*16, v4i8 and v8i8),
// we already handle them OK.
if (VectorVT.getVectorNumElements() == 1 ||
NVPTX::isPackedVectorTy(VectorVT) || VectorVT == MVT::v8i8)
@@ -6300,53 +6381,6 @@ static void ReplaceBITCAST(SDNode *Node, SelectionDAG &DAG,
DAG.getNode(ISD::BUILD_VECTOR, DL, MVT::v2i8, {Vec0, Vec1}));
}

// Lower vector return type of tcgen05.ld intrinsics
static void ReplaceTcgen05Ld(SDNode *N, SelectionDAG &DAG,
SmallVectorImpl<SDValue> &Results,
bool hasOffset = false) {
SDLoc DL(N);
EVT ResVT = N->getValueType(0);
if (!ResVT.isVector())
return; // already legalized.

const unsigned NumElts = ResVT.getVectorNumElements();

// Create the return type of the instructions
SmallVector<EVT, 5> ListVTs;
for (unsigned i = 0; i < NumElts; ++i)
ListVTs.push_back(MVT::i32);

ListVTs.push_back(N->getValueType(1)); // Chain

SDVTList ResVTs = DAG.getVTList(ListVTs);

SmallVector<SDValue, 8> Ops{N->getOperand(0), N->getOperand(1),
N->getOperand(2)};

if (hasOffset) {
Ops.push_back(N->getOperand(3)); // offset
Ops.push_back(N->getOperand(4)); // Pack flag
} else
Ops.push_back(N->getOperand(3)); // Pack flag

MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(N);
SDValue NewNode =
DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, ResVTs, Ops,
MemSD->getMemoryVT(), MemSD->getMemOperand());

// split the vector result
SmallVector<SDValue, 4> ScalarRes;
for (unsigned i = 0; i < NumElts; ++i) {
SDValue Res = NewNode.getValue(i);
ScalarRes.push_back(Res);
}

SDValue Chain = NewNode.getValue(NumElts);
SDValue BuildVector = DAG.getNode(ISD::BUILD_VECTOR, DL, ResVT, ScalarRes);
Results.push_back(BuildVector); // Build Vector
Results.push_back(Chain); // Chain
}

static void ReplaceINTRINSIC_W_CHAIN(SDNode *N, SelectionDAG &DAG,
SmallVectorImpl<SDValue> &Results) {
SDValue Chain = N->getOperand(0);
@@ -6455,21 +6489,18 @@ static void ReplaceINTRINSIC_W_CHAIN(SDNode *N, SelectionDAG &DAG,
return;
}

case Intrinsic::nvvm_tcgen05_ld_16x64b_x2:
case Intrinsic::nvvm_tcgen05_ld_16x64b_x4:
case Intrinsic::nvvm_tcgen05_ld_16x64b_x8:
case Intrinsic::nvvm_tcgen05_ld_16x64b_x16:
case Intrinsic::nvvm_tcgen05_ld_16x64b_x32:
case Intrinsic::nvvm_tcgen05_ld_16x64b_x64:
case Intrinsic::nvvm_tcgen05_ld_16x64b_x128:
case Intrinsic::nvvm_tcgen05_ld_32x32b_x2:
case Intrinsic::nvvm_tcgen05_ld_32x32b_x4:
case Intrinsic::nvvm_tcgen05_ld_32x32b_x8:
case Intrinsic::nvvm_tcgen05_ld_32x32b_x16:
case Intrinsic::nvvm_tcgen05_ld_32x32b_x32:
case Intrinsic::nvvm_tcgen05_ld_32x32b_x64:
case Intrinsic::nvvm_tcgen05_ld_32x32b_x128:
case Intrinsic::nvvm_tcgen05_ld_16x128b_x1:
case Intrinsic::nvvm_tcgen05_ld_16x128b_x2:
case Intrinsic::nvvm_tcgen05_ld_16x128b_x4:
case Intrinsic::nvvm_tcgen05_ld_16x128b_x8:
@@ -6482,16 +6513,23 @@ static void ReplaceINTRINSIC_W_CHAIN(SDNode *N, SelectionDAG &DAG,
case Intrinsic::nvvm_tcgen05_ld_16x256b_x8:
case Intrinsic::nvvm_tcgen05_ld_16x256b_x16:
case Intrinsic::nvvm_tcgen05_ld_16x256b_x32:
return ReplaceTcgen05Ld(N, DAG, Results);
if (auto Res = lowerTcgen05Ld(N, DAG)) {
Results.push_back(Res->first);
Results.push_back(Res->second);
}
return;

case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x2:
case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x4:
case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x8:
case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x16:
case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x32:
case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x64:
case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x128:
return ReplaceTcgen05Ld(N, DAG, Results, /* Offset */ true);
if (auto Res = lowerTcgen05Ld(N, DAG, /*HasOffset=*/true)) {
Results.push_back(Res->first);
Results.push_back(Res->second);
}
return;
}
}

8 changes: 5 additions & 3 deletions llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -756,8 +756,10 @@ def : Pat<(vt (select i1:$p, vt:$a, vt:$b)),
(SELP_b32rr $a, $b, $p)>;
}

def : Pat<(v2f32 (select i1:$p, v2f32:$a, v2f32:$b)),
foreach vt = [v2f32, v2i32] in {
def : Pat<(vt (select i1:$p, vt:$a, vt:$b)),
(SELP_b64rr $a, $b, $p)>;
}
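The foreach above folds v2i32 into the same SELP_b64rr pattern as v2f32: a single 64-bit select covers both lanes at once. Rough C++ analogy (selectPair is illustrative only):

#include <cstdint>

// One 64-bit select replaces two per-lane selects on the packed pair.
uint64_t selectPair(bool P, uint64_t A, uint64_t B) {
  return P ? A : B;
}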

//-----------------------------------
// Test Instructions
@@ -2101,8 +2103,8 @@ foreach vt = [v2f16, v2bf16, v2i16] in {
(V2I16toI32 $a, $b)>;
}

// Same thing for the 64-bit type v2f32.
foreach vt = [v2f32] in {
// Handle extracting one element from the pair (64-bit types)
foreach vt = [v2f32, v2i32] in {
def : Pat<(extractelt vt:$src, 0), (I64toI32L_Sink $src)>, Requires<[hasPTX<71>]>;
def : Pat<(extractelt vt:$src, 1), (I64toI32H_Sink $src)>, Requires<[hasPTX<71>]>;

3 changes: 2 additions & 1 deletion llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td
@@ -54,7 +54,8 @@ def B16 : NVPTXRegClass<[i16, f16, bf16], 16, (add (sequence "RS%u", 0, 4))>;
def B32 : NVPTXRegClass<[i32, v2f16, v2bf16, v2i16, v4i8, f32], 32,
(add (sequence "R%u", 0, 4),
VRFrame32, VRFrameLocal32)>;
def B64 : NVPTXRegClass<[i64, v2f32, f64], 64, (add (sequence "RL%u", 0, 4),
def B64 : NVPTXRegClass<[i64, v2i32, v2f32, f64], 64,
(add (sequence "RL%u", 0, 4),
VRFrame64, VRFrameLocal64)>;
// 128-bit regs are not defined as general regs in NVPTX. They are used for inlineASM only.
def B128 : NVPTXRegClass<[i128], 128, (add (sequence "RQ%u", 0, 4))>;
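Adding v2i32 (alongside i64, v2f32, and f64) to B64 means one 64-bit register class now models four interpretations of the same bits. Rough C++ analogy (B64Reg is an invented illustration of the overlap, not how register classes are implemented):

#include <cstdint>

union B64Reg {
  uint64_t I64;
  double   F64;
  float    F32x2[2];
  uint32_t I32x2[2]; // newly modeled by listing v2i32 in B64
};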