@@ -196,7 +196,8 @@ static bool IsPTXVectorType(MVT VT) {
 // - unsigned int NumElts - The number of elements in the final vector
 // - EVT EltVT - The type of the elements in the final vector
 static std::optional<std::pair<unsigned int, MVT>>
-getVectorLoweringShape(EVT VectorEVT, bool CanLowerTo256Bit) {
+getVectorLoweringShape(EVT VectorEVT, const NVPTXSubtarget &STI,
+                       unsigned AddressSpace) {
   if (!VectorEVT.isSimple())
     return std::nullopt;
   const MVT VectorVT = VectorEVT.getSimpleVT();
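As a rough sketch of what this helper returns (inferred from the packed-type cases in this file; illustrative, not verified output):

  // getVectorLoweringShape(MVT::v8i16, STI, AS) -> {4, MVT::v2i16}
  //   (four 32-bit registers, each packing a pair of i16s)
  // getVectorLoweringShape(MVT::v4f64, STI, AS) -> std::nullopt
  //   (<4 x double> is not a "native" vector size; see the TODO below)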
@@ -213,6 +214,8 @@ getVectorLoweringShape(EVT VectorEVT, bool CanLowerTo256Bit) {
   // The size of the PTX virtual register that holds a packed type.
   unsigned PackRegSize;
 
+  bool CanLowerTo256Bit = STI.has256BitVectorLoadStore(AddressSpace);
+
   // We only handle "native" vector sizes for now, e.g. <4 x double> is not
   // legal. We can (and should) split that into 2 stores of <2 x double> here
   // but I'm leaving that as a TODO for now.
@@ -263,6 +266,8 @@ getVectorLoweringShape(EVT VectorEVT, bool CanLowerTo256Bit) {
     LLVM_FALLTHROUGH;
   case MVT::v2f32: // <1 x f32x2>
   case MVT::v4f32: // <2 x f32x2>
+    if (!STI.hasF32x2Instructions())
+      return std::pair(NumElts, EltVT);
     PackRegSize = 64;
     break;
   }
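On subtargets without f32x2 instructions the early return above leaves f32 vectors unpacked. A sketch of the intended effect, assuming NumElts and EltVT hold the unpacked shape at this point:

  // with f32x2 support:    v4f32 -> {2, MVT::v2f32} (two 64-bit registers)
  // without f32x2 support: v4f32 -> {4, MVT::f32}   (four 32-bit registers)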
@@ -278,97 +283,44 @@ getVectorLoweringShape(EVT VectorEVT, bool CanLowerTo256Bit) {
 }
 
 /// ComputePTXValueVTs - For the given Type \p Ty, returns the set of primitive
-/// EVTs that compose it. Unlike ComputeValueVTs, this will break apart vectors
-/// into their primitive components.
+/// legal-ish MVTs that compose it. Unlike ComputeValueVTs, this will legalize
+/// the types as required by the calling convention (with special handling for
+/// i8s).
 /// NOTE: This is a band-aid for code that expects ComputeValueVTs to return the
 /// same number of types as the Ins/Outs arrays in LowerFormalArguments,
 /// LowerCall, and LowerReturn.
 static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
+                               LLVMContext &Ctx, CallingConv::ID CallConv,
                                Type *Ty, SmallVectorImpl<EVT> &ValueVTs,
-                               SmallVectorImpl<uint64_t> *Offsets = nullptr,
+                               SmallVectorImpl<uint64_t> &Offsets,
                                uint64_t StartingOffset = 0) {
   SmallVector<EVT, 16> TempVTs;
   SmallVector<uint64_t, 16> TempOffsets;
-
-  // Special case for i128 - decompose to (i64, i64)
-  if (Ty->isIntegerTy(128) || Ty->isFP128Ty()) {
-    ValueVTs.append({MVT::i64, MVT::i64});
-
-    if (Offsets)
-      Offsets->append({StartingOffset + 0, StartingOffset + 8});
-
-    return;
-  }
-
-  // Given a struct type, recursively traverse the elements with custom ComputePTXValueVTs.
-  if (StructType *STy = dyn_cast<StructType>(Ty)) {
-    auto const *SL = DL.getStructLayout(STy);
-    auto ElementNum = 0;
-    for (auto *EI : STy->elements()) {
-      ComputePTXValueVTs(TLI, DL, EI, ValueVTs, Offsets,
-                         StartingOffset + SL->getElementOffset(ElementNum));
-      ++ElementNum;
-    }
-    return;
-  }
-
-  // Given an array type, recursively traverse the elements with custom ComputePTXValueVTs.
-  if (ArrayType *ATy = dyn_cast<ArrayType>(Ty)) {
-    Type *EltTy = ATy->getElementType();
-    uint64_t EltSize = DL.getTypeAllocSize(EltTy);
-    for (int I : llvm::seq<int>(ATy->getNumElements()))
-      ComputePTXValueVTs(TLI, DL, EltTy, ValueVTs, Offsets, StartingOffset + I * EltSize);
-    return;
-  }
-
-  // Will split structs and arrays into member types, but will not split vector
-  // types. We do that manually below.
   ComputeValueVTs(TLI, DL, Ty, TempVTs, &TempOffsets, StartingOffset);
 
-  for (auto [VT, Off] : zip(TempVTs, TempOffsets)) {
-    // Split vectors into individual elements that fit into registers.
-    if (VT.isVector()) {
-      unsigned NumElts = VT.getVectorNumElements();
-      EVT EltVT = VT.getVectorElementType();
-      // Below we must maintain power-of-2 sized vectors because
-      // TargetLoweringBase::getVectorTypeBreakdown() which is invoked in
-      // ComputePTXValueVTs() cannot currently break down non-power-of-2 sized
-      // vectors.
-
-      // If the element type belongs to one of the supported packed vector types
-      // then we can pack multiples of this element into a single register.
-      if (VT == MVT::v2i8) {
-        // We can pack 2 i8s into a single 16-bit register. We only do this for
-        // loads and stores, which is why we have a separate case for it.
-        EltVT = MVT::v2i8;
-        NumElts = 1;
-      } else if (VT == MVT::v3i8) {
-        // We can also pack 3 i8s into 32-bit register, leaving the 4th
-        // element undefined.
-        EltVT = MVT::v4i8;
-        NumElts = 1;
-      } else if (NumElts > 1 && isPowerOf2_32(NumElts)) {
-        // Handle default packed types.
-        for (MVT PackedVT : NVPTX::packed_types()) {
-          const auto NumEltsPerReg = PackedVT.getVectorNumElements();
-          if (NumElts % NumEltsPerReg == 0 &&
-              EltVT == PackedVT.getVectorElementType()) {
-            EltVT = PackedVT;
-            NumElts /= NumEltsPerReg;
-            break;
-          }
-        }
-      }
+  for (const auto [VT, Off] : zip(TempVTs, TempOffsets)) {
+    MVT RegisterVT = TLI.getRegisterTypeForCallingConv(Ctx, CallConv, VT);
+    unsigned NumRegs = TLI.getNumRegistersForCallingConv(Ctx, CallConv, VT);
+
+    // Since we actually can load/store b8, we need to ensure that we'll use
+    // the original sized type for any i8s or i8 vectors.
+    if (VT.getScalarType() == MVT::i8) {
+      if (RegisterVT == MVT::i16)
+        RegisterVT = MVT::i8;
+      else if (RegisterVT == MVT::v2i16)
+        RegisterVT = MVT::v2i8;
+      else
+        assert(RegisterVT == MVT::v4i8 &&
+               "Expected v4i8, v2i16, or i16 for i8 RegisterVT");
+    }
 
-      for (unsigned J : seq(NumElts)) {
-        ValueVTs.push_back(EltVT);
-        if (Offsets)
-          Offsets->push_back(Off + J * EltVT.getStoreSize());
-      }
-    } else {
-      ValueVTs.push_back(VT);
-      if (Offsets)
-        Offsets->push_back(Off);
+    // TODO: This is horribly incorrect for cases where the vector elements are
+    // not a multiple of bytes (ex i1) and legal or i8. However, this problem
+    // has existed for as long as NVPTX has and no one has complained, so we'll
+    // leave it for now.
+    for (unsigned I : seq(NumRegs)) {
+      ValueVTs.push_back(RegisterVT);
+      Offsets.push_back(Off + I * RegisterVT.getStoreSize());
     }
   }
 }
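For illustration, the i8 special case added above can be read as a small standalone helper. This is a hypothetical restatement of the patch's logic (remapI8RegisterVT is not a name from the patch), written as if it were a file-local function in NVPTXISelLowering.cpp:

  // Hypothetical restatement: undo the calling convention's widening of i8
  // types so loads and stores keep their true byte width.
  static MVT remapI8RegisterVT(MVT RegisterVT) {
    if (RegisterVT == MVT::i16)
      return MVT::i8; // a lone i8 legalizes to i16; keep it loadable as b8
    if (RegisterVT == MVT::v2i16)
      return MVT::v2i8; // two i8s legalize to v2i16; keep them as v2i8
    assert(RegisterVT == MVT::v4i8 && "Expected v4i8, v2i16, or i16");
    return RegisterVT; // v4i8 already loads/stores as a packed b32
  }

The deleted i128 special case should likewise be covered by the new path: getNumRegistersForCallingConv can be expected to report two i64 registers for an i128, reproducing the old {i64, i64} decomposition at offsets 0 and 8.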
@@ -631,7 +583,9 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   addRegisterClass(MVT::v2f16, &NVPTX::B32RegClass);
   addRegisterClass(MVT::bf16, &NVPTX::B16RegClass);
   addRegisterClass(MVT::v2bf16, &NVPTX::B32RegClass);
-  addRegisterClass(MVT::v2f32, &NVPTX::B64RegClass);
+
+  if (STI.hasF32x2Instructions())
+    addRegisterClass(MVT::v2f32, &NVPTX::B64RegClass);
 
   // Conversion to/from FP16/FP16x2 is always legal.
   setOperationAction(ISD::BUILD_VECTOR, MVT::v2f16, Custom);
@@ -672,7 +626,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v2f32, Expand);
   setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v2f32, Expand);
   // Need custom lowering in case the index is dynamic.
-  setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v2f32, Custom);
+  if (STI.hasF32x2Instructions())
+    setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v2f32, Custom);
 
   // Custom conversions to/from v2i8.
   setOperationAction(ISD::BITCAST, MVT::v2i8, Custom);
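Combined with the register-class hunk earlier in the constructor, the gating amounts to one pattern. A sketch of the net effect (merging the two hunks; the Expand actions above remain unconditional):

  if (STI.hasF32x2Instructions()) {
    // v2f32 is a legal packed type: it lives in a 64-bit register and gets
    // custom lowering for dynamically indexed element extraction.
    addRegisterClass(MVT::v2f32, &NVPTX::B64RegClass);
    setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v2f32, Custom);
  }
  // Without f32x2 instructions, v2f32 has no register class, so type
  // legalization splits it into two f32 values.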
@@ -1606,7 +1561,8 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     } else {
       SmallVector<EVT, 16> VTs;
       SmallVector<uint64_t, 16> Offsets;
-      ComputePTXValueVTs(*this, DL, Arg.Ty, VTs, &Offsets, VAOffset);
+      ComputePTXValueVTs(*this, DL, Ctx, CLI.CallConv, Arg.Ty, VTs, Offsets,
+                         VAOffset);
       assert(VTs.size() == Offsets.size() && "Size mismatch");
       assert(VTs.size() == ArgOuts.size() && "Size mismatch");
 
@@ -1756,7 +1712,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
   if (!Ins.empty()) {
     SmallVector<EVT, 16> VTs;
     SmallVector<uint64_t, 16> Offsets;
-    ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets);
+    ComputePTXValueVTs(*this, DL, Ctx, CLI.CallConv, RetTy, VTs, Offsets);
     assert(VTs.size() == Ins.size() && "Bad value decomposition");
 
     const Align RetAlign = getArgumentAlignment(CB, RetTy, 0, DL);
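Every remaining ComputePTXValueVTs call site migrates the same way. Schematically, using the names at this call site:

  // Before: offsets were optional and no calling convention was consulted.
  ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets);
  // After: the LLVMContext and calling convention are threaded through, and
  // Offsets is a required out-parameter.
  ComputePTXValueVTs(*this, DL, Ctx, CLI.CallConv, RetTy, VTs, Offsets);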
@@ -3217,8 +3173,8 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
   if (ValVT != MemVT)
     return SDValue();
 
-  const auto NumEltsAndEltVT = getVectorLoweringShape(
-      ValVT, STI.has256BitVectorLoadStore(N->getAddressSpace()));
+  const auto NumEltsAndEltVT =
+      getVectorLoweringShape(ValVT, STI, N->getAddressSpace());
   if (!NumEltsAndEltVT)
     return SDValue();
   const auto [NumElts, EltVT] = NumEltsAndEltVT.value();
@@ -3386,6 +3342,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
     const SmallVectorImpl<ISD::InputArg> &Ins, const SDLoc &dl,
     SelectionDAG &DAG, SmallVectorImpl<SDValue> &InVals) const {
   const DataLayout &DL = DAG.getDataLayout();
+  LLVMContext &Ctx = *DAG.getContext();
   auto PtrVT = getPointerTy(DAG.getDataLayout());
 
   const Function &F = DAG.getMachineFunction().getFunction();
@@ -3457,7 +3414,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
     } else {
       SmallVector<EVT, 16> VTs;
       SmallVector<uint64_t, 16> Offsets;
-      ComputePTXValueVTs(*this, DL, Ty, VTs, &Offsets, 0);
+      ComputePTXValueVTs(*this, DL, Ctx, CallConv, Ty, VTs, Offsets);
       assert(VTs.size() == ArgIns.size() && "Size mismatch");
       assert(VTs.size() == Offsets.size() && "Size mismatch");
 
@@ -3469,7 +3426,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
       for (const unsigned NumElts : VI) {
         // i1 is loaded/stored as i8
         const EVT LoadVT = VTs[I] == MVT::i1 ? MVT::i8 : VTs[I];
-        const EVT VecVT = getVectorizedVT(LoadVT, NumElts, *DAG.getContext());
+        const EVT VecVT = getVectorizedVT(LoadVT, NumElts, Ctx);
 
         SDValue VecAddr = DAG.getObjectPtrOffset(
             dl, ArgSymbol, TypeSize::getFixed(Offsets[I]));
@@ -3514,6 +3471,7 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
   }
 
   const DataLayout &DL = DAG.getDataLayout();
+  LLVMContext &Ctx = *DAG.getContext();
 
   const SDValue RetSymbol = DAG.getExternalSymbol("func_retval0", MVT::i32);
   const auto RetAlign = getFunctionParamOptimizedAlign(&F, RetTy, DL);
@@ -3526,7 +3484,7 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
 
   SmallVector<EVT, 16> VTs;
   SmallVector<uint64_t, 16> Offsets;
-  ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets);
+  ComputePTXValueVTs(*this, DL, Ctx, CallConv, RetTy, VTs, Offsets);
   assert(VTs.size() == OutVals.size() && "Bad return value decomposition");
 
   const auto GetRetVal = [&](unsigned I) -> SDValue {
@@ -5985,8 +5943,8 @@ static void replaceLoadVector(SDNode *N, SelectionDAG &DAG,
   if (ResVT != MemVT)
     return;
 
-  const auto NumEltsAndEltVT = getVectorLoweringShape(
-      ResVT, STI.has256BitVectorLoadStore(LD->getAddressSpace()));
+  const auto NumEltsAndEltVT =
+      getVectorLoweringShape(ResVT, STI, LD->getAddressSpace());
   if (!NumEltsAndEltVT)
     return;
   const auto [NumElts, EltVT] = NumEltsAndEltVT.value();