142 changes: 50 additions & 92 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -196,7 +196,8 @@ static bool IsPTXVectorType(MVT VT) {
// - unsigned int NumElts - The number of elements in the final vector
// - EVT EltVT - The type of the elements in the final vector
static std::optional<std::pair<unsigned int, MVT>>
getVectorLoweringShape(EVT VectorEVT, bool CanLowerTo256Bit) {
getVectorLoweringShape(EVT VectorEVT, const NVPTXSubtarget &STI,
unsigned AddressSpace) {
if (!VectorEVT.isSimple())
return std::nullopt;
const MVT VectorVT = VectorEVT.getSimpleVT();
@@ -213,6 +214,8 @@ getVectorLoweringShape(EVT VectorEVT, bool CanLowerTo256Bit) {
// The size of the PTX virtual register that holds a packed type.
unsigned PackRegSize;

bool CanLowerTo256Bit = STI.has256BitVectorLoadStore(AddressSpace);

// We only handle "native" vector sizes for now, e.g. <4 x double> is not
// legal. We can (and should) split that into 2 stores of <2 x double> here
// but I'm leaving that as a TODO for now.
@@ -263,6 +266,8 @@ getVectorLoweringShape(EVT VectorEVT, bool CanLowerTo256Bit) {
LLVM_FALLTHROUGH;
case MVT::v2f32: // <1 x f32x2>
case MVT::v4f32: // <2 x f32x2>
if (!STI.hasF32x2Instructions())
return std::pair(NumElts, EltVT);
PackRegSize = 64;
break;
}
@@ -278,97 +283,44 @@ getVectorLoweringShape(EVT VectorEVT, bool CanLowerTo256Bit) {
}
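For context, a minimal sketch of how a caller now drives the reworked helper; the address-space constant and the results quoted in the comments are illustrative assumptions based on the cases above, not text from this patch:

// Illustrative only: query the lowering shape for a <4 x float> access to
// global memory. The subtarget now answers both the 256-bit load/store and
// the f32x2 packing questions internally.
const auto Shape =
    getVectorLoweringShape(MVT::v4f32, STI, NVPTXAS::ADDRESS_SPACE_GLOBAL);
// With f32x2 available (sm_100+, PTX 8.6, not disabled by flag) this should
// yield {2, MVT::v2f32}; otherwise the elements stay unpacked as {4, MVT::f32}.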

/// ComputePTXValueVTs - For the given Type \p Ty, returns the set of primitive
/// EVTs that compose it. Unlike ComputeValueVTs, this will break apart vectors
/// into their primitive components.
/// legal-ish MVTs that compose it. Unlike ComputeValueVTs, this will legalize
/// the types as required by the calling convention (with special handling for
/// i8s).
/// NOTE: This is a band-aid for code that expects ComputeValueVTs to return the
/// same number of types as the Ins/Outs arrays in LowerFormalArguments,
/// LowerCall, and LowerReturn.
static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
LLVMContext &Ctx, CallingConv::ID CallConv,
Type *Ty, SmallVectorImpl<EVT> &ValueVTs,
SmallVectorImpl<uint64_t> *Offsets = nullptr,
SmallVectorImpl<uint64_t> &Offsets,
uint64_t StartingOffset = 0) {
SmallVector<EVT, 16> TempVTs;
SmallVector<uint64_t, 16> TempOffsets;

// Special case for i128 - decompose to (i64, i64)
if (Ty->isIntegerTy(128) || Ty->isFP128Ty()) {
ValueVTs.append({MVT::i64, MVT::i64});

if (Offsets)
Offsets->append({StartingOffset + 0, StartingOffset + 8});

return;
}

// Given a struct type, recursively traverse the elements with custom ComputePTXValueVTs.
if (StructType *STy = dyn_cast<StructType>(Ty)) {
auto const *SL = DL.getStructLayout(STy);
auto ElementNum = 0;
for(auto *EI : STy->elements()) {
ComputePTXValueVTs(TLI, DL, EI, ValueVTs, Offsets,
StartingOffset + SL->getElementOffset(ElementNum));
++ElementNum;
}
return;
}

// Given an array type, recursively traverse the elements with custom ComputePTXValueVTs.
if (ArrayType *ATy = dyn_cast<ArrayType>(Ty)) {
Type *EltTy = ATy->getElementType();
uint64_t EltSize = DL.getTypeAllocSize(EltTy);
for (int I : llvm::seq<int>(ATy->getNumElements()))
ComputePTXValueVTs(TLI, DL, EltTy, ValueVTs, Offsets, StartingOffset + I * EltSize);
return;
}

// Will split structs and arrays into member types, but will not split vector
// types. We do that manually below.
ComputeValueVTs(TLI, DL, Ty, TempVTs, &TempOffsets, StartingOffset);

for (auto [VT, Off] : zip(TempVTs, TempOffsets)) {
// Split vectors into individual elements that fit into registers.
if (VT.isVector()) {
unsigned NumElts = VT.getVectorNumElements();
EVT EltVT = VT.getVectorElementType();
// Below we must maintain power-of-2 sized vectors because
// TargetLoweringBase::getVectorTypeBreakdown() which is invoked in
// ComputePTXValueVTs() cannot currently break down non-power-of-2 sized
// vectors.

// If the element type belongs to one of the supported packed vector types
// then we can pack multiples of this element into a single register.
if (VT == MVT::v2i8) {
// We can pack 2 i8s into a single 16-bit register. We only do this for
// loads and stores, which is why we have a separate case for it.
EltVT = MVT::v2i8;
NumElts = 1;
} else if (VT == MVT::v3i8) {
// We can also pack 3 i8s into 32-bit register, leaving the 4th
// element undefined.
EltVT = MVT::v4i8;
NumElts = 1;
} else if (NumElts > 1 && isPowerOf2_32(NumElts)) {
// Handle default packed types.
for (MVT PackedVT : NVPTX::packed_types()) {
const auto NumEltsPerReg = PackedVT.getVectorNumElements();
if (NumElts % NumEltsPerReg == 0 &&
EltVT == PackedVT.getVectorElementType()) {
EltVT = PackedVT;
NumElts /= NumEltsPerReg;
break;
}
}
}
for (const auto [VT, Off] : zip(TempVTs, TempOffsets)) {
MVT RegisterVT = TLI.getRegisterTypeForCallingConv(Ctx, CallConv, VT);
unsigned NumRegs = TLI.getNumRegistersForCallingConv(Ctx, CallConv, VT);

// Since we actually can load/store b8, we need to ensure that we'll use
// the original sized type for any i8s or i8 vectors.
if (VT.getScalarType() == MVT::i8) {
if (RegisterVT == MVT::i16)
RegisterVT = MVT::i8;
else if (RegisterVT == MVT::v2i16)
RegisterVT = MVT::v2i8;
else
assert(RegisterVT == MVT::v4i8 &&
"Expected v4i8, v2i16, or i16 for i8 RegisterVT");
}

for (unsigned J : seq(NumElts)) {
ValueVTs.push_back(EltVT);
if (Offsets)
Offsets->push_back(Off + J * EltVT.getStoreSize());
}
} else {
ValueVTs.push_back(VT);
if (Offsets)
Offsets->push_back(Off);
// TODO: This is horribly incorrect for cases where the vector elements are
// not a multiple of bytes (ex i1) and legal or i8. However, this problem
// has existed for as long as NVPTX has and no one has complained, so we'll
// leave it for now.
for (unsigned I : seq(NumRegs)) {
ValueVTs.push_back(RegisterVT);
Offsets.push_back(Off + I * RegisterVT.getStoreSize());
}
}
}
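A minimal sketch of a call site for the updated ComputePTXValueVTs, mirroring the LowerCall/LowerReturn changes further down; the surrounding variables (DAG, DL, CallConv, RetTy) are assumed to be in scope as they are in those functions:

// Illustrative only: the LLVMContext and calling convention are now required
// so register types can be legalized per-CC, and Offsets is a mandatory
// reference rather than an optional pointer.
LLVMContext &Ctx = *DAG.getContext();
SmallVector<EVT, 16> VTs;
SmallVector<uint64_t, 16> Offsets;
ComputePTXValueVTs(*this, DL, Ctx, CallConv, RetTy, VTs, Offsets);
assert(VTs.size() == Offsets.size() && "VT/offset count mismatch");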
@@ -631,7 +583,9 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
addRegisterClass(MVT::v2f16, &NVPTX::B32RegClass);
addRegisterClass(MVT::bf16, &NVPTX::B16RegClass);
addRegisterClass(MVT::v2bf16, &NVPTX::B32RegClass);
addRegisterClass(MVT::v2f32, &NVPTX::B64RegClass);

if (STI.hasF32x2Instructions())
addRegisterClass(MVT::v2f32, &NVPTX::B64RegClass);

// Conversion to/from FP16/FP16x2 is always legal.
setOperationAction(ISD::BUILD_VECTOR, MVT::v2f16, Custom);
@@ -672,7 +626,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v2f32, Expand);
setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v2f32, Expand);
// Need custom lowering in case the index is dynamic.
setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v2f32, Custom);
if (STI.hasF32x2Instructions())
setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v2f32, Custom);

// Custom conversions to/from v2i8.
setOperationAction(ISD::BITCAST, MVT::v2i8, Custom);
@@ -1606,7 +1561,8 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
} else {
SmallVector<EVT, 16> VTs;
SmallVector<uint64_t, 16> Offsets;
ComputePTXValueVTs(*this, DL, Arg.Ty, VTs, &Offsets, VAOffset);
ComputePTXValueVTs(*this, DL, Ctx, CLI.CallConv, Arg.Ty, VTs, Offsets,
VAOffset);
assert(VTs.size() == Offsets.size() && "Size mismatch");
assert(VTs.size() == ArgOuts.size() && "Size mismatch");

@@ -1756,7 +1712,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
if (!Ins.empty()) {
SmallVector<EVT, 16> VTs;
SmallVector<uint64_t, 16> Offsets;
ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets);
ComputePTXValueVTs(*this, DL, Ctx, CLI.CallConv, RetTy, VTs, Offsets);
assert(VTs.size() == Ins.size() && "Bad value decomposition");

const Align RetAlign = getArgumentAlignment(CB, RetTy, 0, DL);
@@ -3217,8 +3173,8 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
if (ValVT != MemVT)
return SDValue();

const auto NumEltsAndEltVT = getVectorLoweringShape(
ValVT, STI.has256BitVectorLoadStore(N->getAddressSpace()));
const auto NumEltsAndEltVT =
getVectorLoweringShape(ValVT, STI, N->getAddressSpace());
if (!NumEltsAndEltVT)
return SDValue();
const auto [NumElts, EltVT] = NumEltsAndEltVT.value();
@@ -3386,6 +3342,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
const SmallVectorImpl<ISD::InputArg> &Ins, const SDLoc &dl,
SelectionDAG &DAG, SmallVectorImpl<SDValue> &InVals) const {
const DataLayout &DL = DAG.getDataLayout();
LLVMContext &Ctx = *DAG.getContext();
auto PtrVT = getPointerTy(DAG.getDataLayout());

const Function &F = DAG.getMachineFunction().getFunction();
@@ -3457,7 +3414,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
} else {
SmallVector<EVT, 16> VTs;
SmallVector<uint64_t, 16> Offsets;
ComputePTXValueVTs(*this, DL, Ty, VTs, &Offsets, 0);
ComputePTXValueVTs(*this, DL, Ctx, CallConv, Ty, VTs, Offsets);
assert(VTs.size() == ArgIns.size() && "Size mismatch");
assert(VTs.size() == Offsets.size() && "Size mismatch");

@@ -3469,7 +3426,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
for (const unsigned NumElts : VI) {
// i1 is loaded/stored as i8
const EVT LoadVT = VTs[I] == MVT::i1 ? MVT::i8 : VTs[I];
const EVT VecVT = getVectorizedVT(LoadVT, NumElts, *DAG.getContext());
const EVT VecVT = getVectorizedVT(LoadVT, NumElts, Ctx);

SDValue VecAddr = DAG.getObjectPtrOffset(
dl, ArgSymbol, TypeSize::getFixed(Offsets[I]));
@@ -3514,6 +3471,7 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
}

const DataLayout &DL = DAG.getDataLayout();
LLVMContext &Ctx = *DAG.getContext();

const SDValue RetSymbol = DAG.getExternalSymbol("func_retval0", MVT::i32);
const auto RetAlign = getFunctionParamOptimizedAlign(&F, RetTy, DL);
@@ -3526,7 +3484,7 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,

SmallVector<EVT, 16> VTs;
SmallVector<uint64_t, 16> Offsets;
ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets);
ComputePTXValueVTs(*this, DL, Ctx, CallConv, RetTy, VTs, Offsets);
assert(VTs.size() == OutVals.size() && "Bad return value decomposition");

const auto GetRetVal = [&](unsigned I) -> SDValue {
@@ -5985,8 +5943,8 @@ static void replaceLoadVector(SDNode *N, SelectionDAG &DAG,
if (ResVT != MemVT)
return;

const auto NumEltsAndEltVT = getVectorLoweringShape(
ResVT, STI.has256BitVectorLoadStore(LD->getAddressSpace()));
const auto NumEltsAndEltVT =
getVectorLoweringShape(ResVT, STI, LD->getAddressSpace());
if (!NumEltsAndEltVT)
return;
const auto [NumElts, EltVT] = NumEltsAndEltVT.value();
10 changes: 10 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp
@@ -29,6 +29,12 @@ static cl::opt<bool>
NoF16Math("nvptx-no-f16-math", cl::Hidden,
cl::desc("NVPTX Specific: Disable generation of f16 math ops."),
cl::init(false));

static cl::opt<bool> NoF32x2("nvptx-no-f32x2", cl::Hidden,
cl::desc("NVPTX Specific: Disable generation of "
"f32x2 instructions and registers."),
cl::init(false));

// Pin the vtable to this file.
void NVPTXSubtarget::anchor() {}

@@ -70,6 +76,10 @@ bool NVPTXSubtarget::allowFP16Math() const {
return hasFP16Math() && NoF16Math == false;
}

bool NVPTXSubtarget::hasF32x2Instructions() const {
return SmVersion >= 100 && PTXVersion >= 86 && !NoF32x2;
}
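A hedged example of exercising the new escape hatch; the exact invocation below is an assumption (any recent llc that accepts sm_100 and PTX 8.6 should do), not a RUN line taken from this patch:

llc -mtriple=nvptx64 -mcpu=sm_100 -mattr=+ptx86 -nvptx-no-f32x2 input.ll -o out.ptx

With the flag set, hasF32x2Instructions() returns false even on Blackwell, so v2f32 values are kept in two .b32 registers instead of one packed .b64 register.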

bool NVPTXSubtarget::hasNativeBF16Support(int Opcode) const {
if (!hasBF16Math())
return false;
4 changes: 1 addition & 3 deletions llvm/lib/Target/NVPTX/NVPTXSubtarget.h
@@ -117,9 +117,7 @@ class NVPTXSubtarget : public NVPTXGenSubtargetInfo {
return HasTcgen05 && PTXVersion >= 86;
}
// f32x2 instructions in Blackwell family
bool hasF32x2Instructions() const {
return SmVersion >= 100 && PTXVersion >= 86;
}
bool hasF32x2Instructions() const;

// TMA G2S copy with cta_group::1/2 support
bool hasCpAsyncBulkTensorCTAGroupSupport() const {
39 changes: 22 additions & 17 deletions llvm/test/CodeGen/NVPTX/aggregate-return.ll
@@ -10,19 +10,20 @@ declare {float, float} @bars({float, float} %input)
define void @test_v2f32(<2 x float> %input, ptr %output) {
; CHECK-LABEL: test_v2f32(
; CHECK: {
; CHECK-NEXT: .reg .b64 %rd<4>;
; CHECK-NEXT: .reg .b32 %r<5>;
; CHECK-NEXT: .reg .b64 %rd<2>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b64 %rd1, [test_v2f32_param_0];
; CHECK-NEXT: ld.param.v2.b32 {%r1, %r2}, [test_v2f32_param_0];
; CHECK-NEXT: { // callseq 0, 0
; CHECK-NEXT: .param .align 8 .b8 param0[8];
; CHECK-NEXT: .param .align 8 .b8 retval0[8];
; CHECK-NEXT: st.param.b64 [param0], %rd1;
; CHECK-NEXT: st.param.v2.b32 [param0], {%r1, %r2};
; CHECK-NEXT: call.uni (retval0), barv, (param0);
; CHECK-NEXT: ld.param.b64 %rd2, [retval0];
; CHECK-NEXT: ld.param.v2.b32 {%r3, %r4}, [retval0];
; CHECK-NEXT: } // callseq 0
; CHECK-NEXT: ld.param.b64 %rd3, [test_v2f32_param_1];
; CHECK-NEXT: st.b64 [%rd3], %rd2;
; CHECK-NEXT: ld.param.b64 %rd1, [test_v2f32_param_1];
; CHECK-NEXT: st.v2.b32 [%rd1], {%r3, %r4};
; CHECK-NEXT: ret;
%call = tail call <2 x float> @barv(<2 x float> %input)
store <2 x float> %call, ptr %output, align 8
@@ -32,24 +33,28 @@ define void @test_v2f32(<2 x float> %input, ptr %output) {
define void @test_v3f32(<3 x float> %input, ptr %output) {
; CHECK-LABEL: test_v3f32(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<3>;
; CHECK-NEXT: .reg .b64 %rd<4>;
; CHECK-NEXT: .reg .b32 %r<7>;
; CHECK-NEXT: .reg .b64 %rd<6>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b64 %rd1, [test_v3f32_param_0];
; CHECK-NEXT: ld.param.b32 %r1, [test_v3f32_param_0+8];
; CHECK-NEXT: ld.param.v2.b32 {%r1, %r2}, [test_v3f32_param_0];
; CHECK-NEXT: ld.param.b32 %r3, [test_v3f32_param_0+8];
; CHECK-NEXT: { // callseq 1, 0
; CHECK-NEXT: .param .align 16 .b8 param0[16];
; CHECK-NEXT: .param .align 16 .b8 retval0[16];
; CHECK-NEXT: st.param.b32 [param0+8], %r1;
; CHECK-NEXT: st.param.b64 [param0], %rd1;
; CHECK-NEXT: st.param.b32 [param0+8], %r3;
; CHECK-NEXT: st.param.v2.b32 [param0], {%r1, %r2};
; CHECK-NEXT: call.uni (retval0), barv3, (param0);
; CHECK-NEXT: ld.param.b32 %r2, [retval0+8];
; CHECK-NEXT: ld.param.b64 %rd2, [retval0];
; CHECK-NEXT: ld.param.b32 %r4, [retval0+8];
; CHECK-NEXT: ld.param.v2.b32 {%r5, %r6}, [retval0];
; CHECK-NEXT: } // callseq 1
; CHECK-NEXT: ld.param.b64 %rd3, [test_v3f32_param_1];
; CHECK-NEXT: st.b32 [%rd3+8], %r2;
; CHECK-NEXT: st.b64 [%rd3], %rd2;
; CHECK-NEXT: cvt.u64.u32 %rd1, %r5;
; CHECK-NEXT: cvt.u64.u32 %rd2, %r6;
; CHECK-NEXT: shl.b64 %rd3, %rd2, 32;
; CHECK-NEXT: or.b64 %rd4, %rd1, %rd3;
; CHECK-NEXT: ld.param.b64 %rd5, [test_v3f32_param_1];
; CHECK-NEXT: st.b32 [%rd5+8], %r4;
; CHECK-NEXT: st.b64 [%rd5], %rd4;
; CHECK-NEXT: ret;
%call = tail call <3 x float> @barv3(<3 x float> %input)
; Make sure we don't load more values than we need to.
Expand Down