Commit d33c9cb
[NVPTX] Disable v2f32 registers when no operations supported, or via cl::opt
1 parent fc62990

20 files changed (+2566, -1720 lines)

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 52 additions & 92 deletions

@@ -196,7 +196,8 @@ static bool IsPTXVectorType(MVT VT) {
 // - unsigned int NumElts - The number of elements in the final vector
 // - EVT EltVT - The type of the elements in the final vector
 static std::optional<std::pair<unsigned int, MVT>>
-getVectorLoweringShape(EVT VectorEVT, bool CanLowerTo256Bit) {
+getVectorLoweringShape(EVT VectorEVT, const NVPTXSubtarget &STI,
+                       unsigned AddressSpace) {
   if (!VectorEVT.isSimple())
     return std::nullopt;
   const MVT VectorVT = VectorEVT.getSimpleVT();
@@ -213,6 +214,8 @@ getVectorLoweringShape(EVT VectorEVT, bool CanLowerTo256Bit) {
   // The size of the PTX virtual register that holds a packed type.
   unsigned PackRegSize;
 
+  bool CanLowerTo256Bit = STI.has256BitVectorLoadStore(AddressSpace);
+
   // We only handle "native" vector sizes for now, e.g. <4 x double> is not
   // legal. We can (and should) split that into 2 stores of <2 x double> here
   // but I'm leaving that as a TODO for now.
@@ -263,6 +266,8 @@ getVectorLoweringShape(EVT VectorEVT, bool CanLowerTo256Bit) {
     LLVM_FALLTHROUGH;
   case MVT::v2f32: // <1 x f32x2>
   case MVT::v4f32: // <2 x f32x2>
+    if (!STI.hasF32x2Instructions())
+      return std::pair(NumElts, EltVT);
     PackRegSize = 64;
     break;
   }
@@ -278,97 +283,46 @@ getVectorLoweringShape(EVT VectorEVT, bool CanLowerTo256Bit) {
 }
 
 /// ComputePTXValueVTs - For the given Type \p Ty, returns the set of primitive
-/// EVTs that compose it. Unlike ComputeValueVTs, this will break apart vectors
-/// into their primitive components.
+/// legal-ish MVTs that compose it. Unlike ComputeValueVTs, this will legalize
+/// the types as required by the calling convention (with special handling for
+/// i8s).
 /// NOTE: This is a band-aid for code that expects ComputeValueVTs to return the
 /// same number of types as the Ins/Outs arrays in LowerFormalArguments,
 /// LowerCall, and LowerReturn.
 static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
+                               LLVMContext &Ctx, CallingConv::ID CallConv,
                                Type *Ty, SmallVectorImpl<EVT> &ValueVTs,
-                               SmallVectorImpl<uint64_t> *Offsets = nullptr,
+                               SmallVectorImpl<uint64_t> *Offsets,
                                uint64_t StartingOffset = 0) {
+  assert(Offsets && "Offsets must be non-null");
+
   SmallVector<EVT, 16> TempVTs;
   SmallVector<uint64_t, 16> TempOffsets;
-
-  // Special case for i128 - decompose to (i64, i64)
-  if (Ty->isIntegerTy(128) || Ty->isFP128Ty()) {
-    ValueVTs.append({MVT::i64, MVT::i64});
-
-    if (Offsets)
-      Offsets->append({StartingOffset + 0, StartingOffset + 8});
-
-    return;
-  }
-
-  // Given a struct type, recursively traverse the elements with custom ComputePTXValueVTs.
-  if (StructType *STy = dyn_cast<StructType>(Ty)) {
-    auto const *SL = DL.getStructLayout(STy);
-    auto ElementNum = 0;
-    for (auto *EI : STy->elements()) {
-      ComputePTXValueVTs(TLI, DL, EI, ValueVTs, Offsets,
-                         StartingOffset + SL->getElementOffset(ElementNum));
-      ++ElementNum;
-    }
-    return;
-  }
-
-  // Given an array type, recursively traverse the elements with custom ComputePTXValueVTs.
-  if (ArrayType *ATy = dyn_cast<ArrayType>(Ty)) {
-    Type *EltTy = ATy->getElementType();
-    uint64_t EltSize = DL.getTypeAllocSize(EltTy);
-    for (int I : llvm::seq<int>(ATy->getNumElements()))
-      ComputePTXValueVTs(TLI, DL, EltTy, ValueVTs, Offsets, StartingOffset + I * EltSize);
-    return;
-  }
-
-  // Will split structs and arrays into member types, but will not split vector
-  // types. We do that manually below.
   ComputeValueVTs(TLI, DL, Ty, TempVTs, &TempOffsets, StartingOffset);
 
-  for (auto [VT, Off] : zip(TempVTs, TempOffsets)) {
-    // Split vectors into individual elements that fit into registers.
-    if (VT.isVector()) {
-      unsigned NumElts = VT.getVectorNumElements();
-      EVT EltVT = VT.getVectorElementType();
-      // Below we must maintain power-of-2 sized vectors because
-      // TargetLoweringBase::getVectorTypeBreakdown() which is invoked in
-      // ComputePTXValueVTs() cannot currently break down non-power-of-2 sized
-      // vectors.
-
-      // If the element type belongs to one of the supported packed vector types
-      // then we can pack multiples of this element into a single register.
-      if (VT == MVT::v2i8) {
-        // We can pack 2 i8s into a single 16-bit register. We only do this for
-        // loads and stores, which is why we have a separate case for it.
-        EltVT = MVT::v2i8;
-        NumElts = 1;
-      } else if (VT == MVT::v3i8) {
-        // We can also pack 3 i8s into 32-bit register, leaving the 4th
-        // element undefined.
-        EltVT = MVT::v4i8;
-        NumElts = 1;
-      } else if (NumElts > 1 && isPowerOf2_32(NumElts)) {
-        // Handle default packed types.
-        for (MVT PackedVT : NVPTX::packed_types()) {
-          const auto NumEltsPerReg = PackedVT.getVectorNumElements();
-          if (NumElts % NumEltsPerReg == 0 &&
-              EltVT == PackedVT.getVectorElementType()) {
-            EltVT = PackedVT;
-            NumElts /= NumEltsPerReg;
-            break;
-          }
-        }
-      }
+  for (const auto [VT, Off] : zip(TempVTs, TempOffsets)) {
+    MVT RegisterVT = TLI.getRegisterTypeForCallingConv(Ctx, CallConv, VT);
+    unsigned NumRegs = TLI.getNumRegistersForCallingConv(Ctx, CallConv, VT);
+
+    // Since we actually can load/store b8, we need to ensure that we'll use
+    // the original sized type for any i8s or i8 vectors.
+    if (VT.getScalarType() == MVT::i8) {
+      if (RegisterVT == MVT::i16)
+        RegisterVT = MVT::i8;
+      else if (RegisterVT == MVT::v2i16)
+        RegisterVT = MVT::v2i8;
+      else
+        assert(RegisterVT == MVT::v4i8 &&
+               "Expected v4i8, v2i16, or i16 for i8 RegisterVT");
+    }
 
-      for (unsigned J : seq(NumElts)) {
-        ValueVTs.push_back(EltVT);
-        if (Offsets)
-          Offsets->push_back(Off + J * EltVT.getStoreSize());
-      }
-    } else {
-      ValueVTs.push_back(VT);
-      if (Offsets)
-        Offsets->push_back(Off);
+    // TODO: This is horribly incorrect for cases where the vector elements are
+    // not a multiple of bytes (ex i1) and legal or i8. However, this problem
+    // has existed for as long as NVPTX has and no one has complained, so we'll
+    // leave it for now.
+    for (unsigned I : seq(NumRegs)) {
+      ValueVTs.push_back(RegisterVT);
+      Offsets->push_back(Off + I * RegisterVT.getStoreSize());
     }
   }
 }
@@ -631,7 +585,9 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   addRegisterClass(MVT::v2f16, &NVPTX::B32RegClass);
   addRegisterClass(MVT::bf16, &NVPTX::B16RegClass);
   addRegisterClass(MVT::v2bf16, &NVPTX::B32RegClass);
-  addRegisterClass(MVT::v2f32, &NVPTX::B64RegClass);
+
+  if (STI.hasF32x2Instructions())
+    addRegisterClass(MVT::v2f32, &NVPTX::B64RegClass);
 
   // Conversion to/from FP16/FP16x2 is always legal.
   setOperationAction(ISD::BUILD_VECTOR, MVT::v2f16, Custom);
@@ -672,7 +628,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v2f32, Expand);
   setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v2f32, Expand);
   // Need custom lowering in case the index is dynamic.
-  setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v2f32, Custom);
+  if (STI.hasF32x2Instructions())
+    setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v2f32, Custom);
 
   // Custom conversions to/from v2i8.
   setOperationAction(ISD::BITCAST, MVT::v2i8, Custom);
@@ -1606,7 +1563,8 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     } else {
       SmallVector<EVT, 16> VTs;
       SmallVector<uint64_t, 16> Offsets;
-      ComputePTXValueVTs(*this, DL, Arg.Ty, VTs, &Offsets, VAOffset);
+      ComputePTXValueVTs(*this, DL, Ctx, CLI.CallConv, Arg.Ty, VTs, &Offsets,
+                         VAOffset);
       assert(VTs.size() == Offsets.size() && "Size mismatch");
       assert(VTs.size() == ArgOuts.size() && "Size mismatch");
 
@@ -1756,7 +1714,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
   if (!Ins.empty()) {
     SmallVector<EVT, 16> VTs;
     SmallVector<uint64_t, 16> Offsets;
-    ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets);
+    ComputePTXValueVTs(*this, DL, Ctx, CLI.CallConv, RetTy, VTs, &Offsets);
     assert(VTs.size() == Ins.size() && "Bad value decomposition");
 
     const Align RetAlign = getArgumentAlignment(CB, RetTy, 0, DL);
@@ -3217,8 +3175,8 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
   if (ValVT != MemVT)
    return SDValue();
 
-  const auto NumEltsAndEltVT = getVectorLoweringShape(
-      ValVT, STI.has256BitVectorLoadStore(N->getAddressSpace()));
+  const auto NumEltsAndEltVT =
+      getVectorLoweringShape(ValVT, STI, N->getAddressSpace());
   if (!NumEltsAndEltVT)
     return SDValue();
   const auto [NumElts, EltVT] = NumEltsAndEltVT.value();
@@ -3386,6 +3344,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
     const SmallVectorImpl<ISD::InputArg> &Ins, const SDLoc &dl,
     SelectionDAG &DAG, SmallVectorImpl<SDValue> &InVals) const {
   const DataLayout &DL = DAG.getDataLayout();
+  LLVMContext &Ctx = *DAG.getContext();
   auto PtrVT = getPointerTy(DAG.getDataLayout());
 
   const Function &F = DAG.getMachineFunction().getFunction();
@@ -3457,7 +3416,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
     } else {
      SmallVector<EVT, 16> VTs;
      SmallVector<uint64_t, 16> Offsets;
-      ComputePTXValueVTs(*this, DL, Ty, VTs, &Offsets, 0);
+      ComputePTXValueVTs(*this, DL, Ctx, CallConv, Ty, VTs, &Offsets, 0);
      assert(VTs.size() == ArgIns.size() && "Size mismatch");
      assert(VTs.size() == Offsets.size() && "Size mismatch");
 
@@ -3469,7 +3428,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
      for (const unsigned NumElts : VI) {
        // i1 is loaded/stored as i8
        const EVT LoadVT = VTs[I] == MVT::i1 ? MVT::i8 : VTs[I];
-        const EVT VecVT = getVectorizedVT(LoadVT, NumElts, *DAG.getContext());
+        const EVT VecVT = getVectorizedVT(LoadVT, NumElts, Ctx);
 
        SDValue VecAddr = DAG.getObjectPtrOffset(
            dl, ArgSymbol, TypeSize::getFixed(Offsets[I]));
@@ -3514,6 +3473,7 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
   }
 
   const DataLayout &DL = DAG.getDataLayout();
+  LLVMContext &Ctx = *DAG.getContext();
 
   const SDValue RetSymbol = DAG.getExternalSymbol("func_retval0", MVT::i32);
   const auto RetAlign = getFunctionParamOptimizedAlign(&F, RetTy, DL);
@@ -3526,7 +3486,7 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
 
   SmallVector<EVT, 16> VTs;
   SmallVector<uint64_t, 16> Offsets;
-  ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets);
+  ComputePTXValueVTs(*this, DL, Ctx, CallConv, RetTy, VTs, &Offsets);
   assert(VTs.size() == OutVals.size() && "Bad return value decomposition");
 
   const auto GetRetVal = [&](unsigned I) -> SDValue {
@@ -5985,8 +5945,8 @@ static void replaceLoadVector(SDNode *N, SelectionDAG &DAG,
   if (ResVT != MemVT)
     return;
 
-  const auto NumEltsAndEltVT = getVectorLoweringShape(
-      ResVT, STI.has256BitVectorLoadStore(LD->getAddressSpace()));
+  const auto NumEltsAndEltVT =
+      getVectorLoweringShape(ResVT, STI, LD->getAddressSpace());
   if (!NumEltsAndEltVT)
    return;
   const auto [NumElts, EltVT] = NumEltsAndEltVT.value();
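To make the effect of the gating concrete, here is a minimal LLVM IR sketch (illustrative only, not part of this commit; the function name @demo_v2f32 is invented). On subtargets where hasF32x2Instructions() is true, the <2 x float> value can live in a single packed 64-bit register; otherwise the lowering above falls back to a pair of 32-bit float registers, matching the updated CHECK lines in aggregate-return.ll further down.

; Hypothetical example, not taken from this patch.
define void @demo_v2f32(<2 x float> %v, ptr %p) {
  ; Expected to become a single st.b64 of one packed register when f32x2 is
  ; available, and an st.v2.b32 of two b32 registers when it is not.
  store <2 x float> %v, ptr %p, align 8
  ret void
}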

llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp

Lines changed: 10 additions & 0 deletions

@@ -29,6 +29,12 @@
     NoF16Math("nvptx-no-f16-math", cl::Hidden,
               cl::desc("NVPTX Specific: Disable generation of f16 math ops."),
               cl::init(false));
+
+static cl::opt<bool> NoF32x2("nvptx-no-f32x2", cl::Hidden,
+                             cl::desc("NVPTX Specific: Disable generation of "
+                                      "f32x2 instructions and registers."),
+                             cl::init(false));
+
 // Pin the vtable to this file.
 void NVPTXSubtarget::anchor() {}
 
@@ -70,6 +76,10 @@ bool NVPTXSubtarget::allowFP16Math() const {
   return hasFP16Math() && NoF16Math == false;
 }
 
+bool NVPTXSubtarget::hasF32x2Instructions() const {
+  return SmVersion >= 100 && PTXVersion >= 86 && !NoF32x2;
+}
+
 bool NVPTXSubtarget::hasNativeBF16Support(int Opcode) const {
   if (!hasBF16Math())
     return false;
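Since NoF32x2 is a hidden cl::opt, the new behavior can also be turned off from the llc command line on a Blackwell subtarget. A sketch of how a lit test might exercise both modes (the -nvptx-no-f32x2 spelling comes from the option above; sm_100 and ptx86 mirror the thresholds in hasF32x2Instructions(); the check prefixes are hypothetical):

; RUN: llc < %s -march=nvptx64 -mcpu=sm_100 -mattr=+ptx86 | FileCheck %s --check-prefix=F32X2
; RUN: llc < %s -march=nvptx64 -mcpu=sm_100 -mattr=+ptx86 -nvptx-no-f32x2 | FileCheck %s --check-prefix=NOF32X2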

llvm/lib/Target/NVPTX/NVPTXSubtarget.h

Lines changed: 1 addition & 3 deletions

@@ -117,9 +117,7 @@ class NVPTXSubtarget : public NVPTXGenSubtargetInfo {
     return HasTcgen05 && PTXVersion >= 86;
   }
   // f32x2 instructions in Blackwell family
-  bool hasF32x2Instructions() const {
-    return SmVersion >= 100 && PTXVersion >= 86;
-  }
+  bool hasF32x2Instructions() const;
 
   // TMA G2S copy with cta_group::1/2 support
   bool hasCpAsyncBulkTensorCTAGroupSupport() const {

llvm/test/CodeGen/NVPTX/aggregate-return.ll

Lines changed: 22 additions & 17 deletions

@@ -10,19 +10,20 @@ declare {float, float} @bars({float, float} %input)
 define void @test_v2f32(<2 x float> %input, ptr %output) {
 ; CHECK-LABEL: test_v2f32(
 ; CHECK: {
-; CHECK-NEXT: .reg .b64 %rd<4>;
+; CHECK-NEXT: .reg .b32 %r<5>;
+; CHECK-NEXT: .reg .b64 %rd<2>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT: // %bb.0:
-; CHECK-NEXT: ld.param.b64 %rd1, [test_v2f32_param_0];
+; CHECK-NEXT: ld.param.v2.b32 {%r1, %r2}, [test_v2f32_param_0];
 ; CHECK-NEXT: { // callseq 0, 0
 ; CHECK-NEXT: .param .align 8 .b8 param0[8];
 ; CHECK-NEXT: .param .align 8 .b8 retval0[8];
-; CHECK-NEXT: st.param.b64 [param0], %rd1;
+; CHECK-NEXT: st.param.v2.b32 [param0], {%r1, %r2};
 ; CHECK-NEXT: call.uni (retval0), barv, (param0);
-; CHECK-NEXT: ld.param.b64 %rd2, [retval0];
+; CHECK-NEXT: ld.param.v2.b32 {%r3, %r4}, [retval0];
 ; CHECK-NEXT: } // callseq 0
-; CHECK-NEXT: ld.param.b64 %rd3, [test_v2f32_param_1];
-; CHECK-NEXT: st.b64 [%rd3], %rd2;
+; CHECK-NEXT: ld.param.b64 %rd1, [test_v2f32_param_1];
+; CHECK-NEXT: st.v2.b32 [%rd1], {%r3, %r4};
 ; CHECK-NEXT: ret;
   %call = tail call <2 x float> @barv(<2 x float> %input)
   store <2 x float> %call, ptr %output, align 8
@@ -32,24 +33,28 @@ define void @test_v2f32(<2 x float> %input, ptr %output) {
 define void @test_v3f32(<3 x float> %input, ptr %output) {
 ; CHECK-LABEL: test_v3f32(
 ; CHECK: {
-; CHECK-NEXT: .reg .b32 %r<3>;
-; CHECK-NEXT: .reg .b64 %rd<4>;
+; CHECK-NEXT: .reg .b32 %r<7>;
+; CHECK-NEXT: .reg .b64 %rd<6>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT: // %bb.0:
-; CHECK-NEXT: ld.param.b64 %rd1, [test_v3f32_param_0];
-; CHECK-NEXT: ld.param.b32 %r1, [test_v3f32_param_0+8];
+; CHECK-NEXT: ld.param.v2.b32 {%r1, %r2}, [test_v3f32_param_0];
+; CHECK-NEXT: ld.param.b32 %r3, [test_v3f32_param_0+8];
 ; CHECK-NEXT: { // callseq 1, 0
 ; CHECK-NEXT: .param .align 16 .b8 param0[16];
 ; CHECK-NEXT: .param .align 16 .b8 retval0[16];
-; CHECK-NEXT: st.param.b32 [param0+8], %r1;
-; CHECK-NEXT: st.param.b64 [param0], %rd1;
+; CHECK-NEXT: st.param.b32 [param0+8], %r3;
+; CHECK-NEXT: st.param.v2.b32 [param0], {%r1, %r2};
 ; CHECK-NEXT: call.uni (retval0), barv3, (param0);
-; CHECK-NEXT: ld.param.b32 %r2, [retval0+8];
-; CHECK-NEXT: ld.param.b64 %rd2, [retval0];
+; CHECK-NEXT: ld.param.b32 %r4, [retval0+8];
+; CHECK-NEXT: ld.param.v2.b32 {%r5, %r6}, [retval0];
 ; CHECK-NEXT: } // callseq 1
-; CHECK-NEXT: ld.param.b64 %rd3, [test_v3f32_param_1];
-; CHECK-NEXT: st.b32 [%rd3+8], %r2;
-; CHECK-NEXT: st.b64 [%rd3], %rd2;
+; CHECK-NEXT: cvt.u64.u32 %rd1, %r5;
+; CHECK-NEXT: cvt.u64.u32 %rd2, %r6;
+; CHECK-NEXT: shl.b64 %rd3, %rd2, 32;
+; CHECK-NEXT: or.b64 %rd4, %rd1, %rd3;
+; CHECK-NEXT: ld.param.b64 %rd5, [test_v3f32_param_1];
+; CHECK-NEXT: st.b32 [%rd5+8], %r4;
+; CHECK-NEXT: st.b64 [%rd5], %rd4;
 ; CHECK-NEXT: ret;
   %call = tail call <3 x float> @barv3(<3 x float> %input)
 ; Make sure we don't load more values than than we need to.
