@@ -226,21 +226,20 @@ getVectorLoweringShape(EVT VectorEVT, const NVPTXSubtarget &STI,
226226 switch (VectorVT.SimpleTy ) {
227227 default :
228228 return std::nullopt ;
229+
229230 case MVT::v4i64:
230231 case MVT::v4f64:
231- case MVT::v8i32:
232- // This is a "native" vector type iff the address space is global
233- // and the target supports 256-bit loads/stores
232+ // This is a "native" vector type iff the address space is global and the
233+ // target supports 256-bit loads/stores
234234 if (!CanLowerTo256Bit)
235235 return std::nullopt ;
236236 LLVM_FALLTHROUGH;
237237 case MVT::v2i8:
238- case MVT::v2i32:
239238 case MVT::v2i64:
240239 case MVT::v2f64:
241- case MVT::v4i32:
242240 // This is a "native" vector type
243241 return std::pair (NumElts, EltVT);
242+
244243 case MVT::v16f16: // <8 x f16x2>
245244 case MVT::v16bf16: // <8 x bf16x2>
246245 case MVT::v16i16: // <8 x i16x2>
@@ -264,12 +263,18 @@ getVectorLoweringShape(EVT VectorEVT, const NVPTXSubtarget &STI,
264263 case MVT::v16i8: // <4 x i8x4>
265264 PackRegSize = 32 ;
266265 break ;
266+
267267 case MVT::v8f32: // <4 x f32x2>
268+ case MVT::v8i32: // <4 x i32x2>
269+ // This is a "native" vector type iff the address space is global and the
270+ // target supports 256-bit loads/stores
268271 if (!CanLowerTo256Bit)
269272 return std::nullopt ;
270273 LLVM_FALLTHROUGH;
271274 case MVT::v2f32: // <1 x f32x2>
272275 case MVT::v4f32: // <2 x f32x2>
276+ case MVT::v2i32: // <1 x i32x2>
277+ case MVT::v4i32: // <2 x i32x2>
273278 if (!STI.hasF32x2Instructions ())
274279 return std::pair (NumElts, EltVT);
275280 PackRegSize = 64 ;
@@ -590,8 +595,10 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
590595 addRegisterClass (MVT::bf16 , &NVPTX::B16RegClass);
591596 addRegisterClass (MVT::v2bf16, &NVPTX::B32RegClass);
592597
593- if (STI.hasF32x2Instructions ())
598+ if (STI.hasF32x2Instructions ()) {
594599 addRegisterClass (MVT::v2f32, &NVPTX::B64RegClass);
600+ addRegisterClass (MVT::v2i32, &NVPTX::B64RegClass);
601+ }
595602
596603 // Conversion to/from FP16/FP16x2 is always legal.
597604 setOperationAction (ISD::BUILD_VECTOR, MVT::v2f16, Custom);
@@ -628,12 +635,13 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
628635 setOperationAction (ISD::INSERT_VECTOR_ELT, MVT::v4i8, Custom);
629636 setOperationAction (ISD::VECTOR_SHUFFLE, MVT::v4i8, Custom);
630637
631- // No support for these operations with v2f32.
632- setOperationAction (ISD::INSERT_VECTOR_ELT, MVT::v2f32, Expand);
633- setOperationAction (ISD::VECTOR_SHUFFLE, MVT::v2f32, Expand);
638+ // No support for these operations with v2f32/v2i32
639+ setOperationAction (ISD::INSERT_VECTOR_ELT, { MVT::v2f32, MVT::v2i32} , Expand);
640+ setOperationAction (ISD::VECTOR_SHUFFLE, { MVT::v2f32, MVT::v2i32} , Expand);
634641 // Need custom lowering in case the index is dynamic.
635642 if (STI.hasF32x2Instructions ())
636- setOperationAction (ISD::EXTRACT_VECTOR_ELT, MVT::v2f32, Custom);
643+ setOperationAction (ISD::EXTRACT_VECTOR_ELT, {MVT::v2f32, MVT::v2i32},
644+ Custom);
637645
638646 // Custom conversions to/from v2i8.
639647 setOperationAction (ISD::BITCAST, MVT::v2i8, Custom);
@@ -661,14 +669,13 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
661669 // Operations not directly supported by NVPTX.
662670 for (MVT VT : {MVT::bf16 , MVT::f16 , MVT::v2bf16, MVT::v2f16, MVT::f32 ,
663671 MVT::v2f32, MVT::f64 , MVT::i1, MVT::i8 , MVT::i16 , MVT::v2i16,
664- MVT::v4i8, MVT::i32 , MVT::i64 }) {
672+ MVT::v4i8, MVT::i32 , MVT::v2i32, MVT:: i64 }) {
665673 setOperationAction (ISD::SELECT_CC, VT, Expand);
666674 setOperationAction (ISD::BR_CC, VT, Expand);
667675 }
668676
669- // Not directly supported. TLI would attempt to expand operations like
670- // FMINIMUM(v2f32) using invalid SETCC and VSELECT nodes.
671- setOperationAction (ISD::VSELECT, MVT::v2f32, Expand);
677+ // We don't want ops like FMINIMUM or UMAX to be lowered to SETCC+VSELECT.
678+ setOperationAction (ISD::VSELECT, {MVT::v2f32, MVT::v2i32}, Expand);
672679
673680 // Some SIGN_EXTEND_INREG can be done using cvt instruction.
674681 // For others we will expand to a SHL/SRA pair.
@@ -815,7 +822,14 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
815822 setOperationAction ({ISD::SDIV, ISD::UDIV, ISD::SRA, ISD::SRL, ISD::MULHS,
816823 ISD::MULHU, ISD::FP_TO_SINT, ISD::FP_TO_UINT,
817824 ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::SETCC},
818- MVT::v2i16, Expand);
825+ {MVT::v2i16, MVT::v2i32}, Expand);
826+
827+ // v2i32 is not supported for any arithmetic operations
828+ setOperationAction ({ISD::ABS, ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX,
829+ ISD::CTPOP, ISD::CTLZ, ISD::ADD, ISD::SUB, ISD::MUL,
830+ ISD::SHL, ISD::SRA, ISD::SRL, ISD::OR, ISD::AND, ISD::XOR,
831+ ISD::SREM, ISD::UREM},
832+ MVT::v2i32, Expand);
819833
820834 setOperationAction (ISD::ADDC, MVT::i32 , Legal);
821835 setOperationAction (ISD::ADDE, MVT::i32 , Legal);
@@ -829,7 +843,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
829843 }
830844
831845 setOperationAction (ISD::CTTZ, MVT::i16 , Expand);
832- setOperationAction (ISD::CTTZ, MVT::v2i16, Expand);
846+ setOperationAction (ISD::CTTZ, { MVT::v2i16, MVT::v2i32} , Expand);
833847 setOperationAction (ISD::CTTZ, MVT::i32 , Expand);
834848 setOperationAction (ISD::CTTZ, MVT::i64 , Expand);
835849
@@ -1071,7 +1085,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
10711085 // Custom lowering for tcgen05.st vector operands
10721086 setOperationAction (ISD::INTRINSIC_VOID,
10731087 {MVT::v2i32, MVT::v4i32, MVT::v8i32, MVT::v16i32,
1074- MVT::v32i32, MVT::v64i32, MVT::v128i32},
1088+ MVT::v32i32, MVT::v64i32, MVT::v128i32, MVT::Other },
10751089 Custom);
10761090
10771091 // Enable custom lowering for the following:
@@ -2604,7 +2618,7 @@ static SDValue LowerVectorArith(SDValue Op, SelectionDAG &DAG) {
26042618 return V;
26052619}
26062620
2607- static SDValue LowerTcgen05St (SDValue Op, SelectionDAG &DAG) {
2621+ static SDValue lowerTcgen05St (SDValue Op, SelectionDAG &DAG) {
26082622 SDNode *N = Op.getNode ();
26092623 SDLoc DL (N);
26102624 SmallVector<SDValue, 32 > Ops;
@@ -2719,7 +2733,52 @@ static SDValue LowerTcgen05MMADisableOutputLane(SDValue Op, SelectionDAG &DAG) {
27192733 return Tcgen05MMANode;
27202734}
27212735
2722- static SDValue LowerIntrinsicVoid (SDValue Op, SelectionDAG &DAG) {
2736+ // Lower vector return type of tcgen05.ld intrinsics
2737+ static std::optional<std::pair<SDValue, SDValue>>
2738+ lowerTcgen05Ld (SDNode *N, SelectionDAG &DAG, bool HasOffset = false ) {
2739+ SDLoc DL (N);
2740+ EVT ResVT = N->getValueType (0 );
2741+ if (!ResVT.isVector ())
2742+ return {}; // already legalized.
2743+
2744+ const unsigned NumElts = ResVT.getVectorNumElements ();
2745+
2746+ // Create the return type of the instructions
2747+ SmallVector<EVT, 5 > ListVTs;
2748+ for (unsigned i = 0 ; i < NumElts; ++i)
2749+ ListVTs.push_back (MVT::i32 );
2750+
2751+ ListVTs.push_back (N->getValueType (1 )); // Chain
2752+
2753+ SDVTList ResVTs = DAG.getVTList (ListVTs);
2754+
2755+ SmallVector<SDValue, 8 > Ops{N->getOperand (0 ), N->getOperand (1 ),
2756+ N->getOperand (2 )};
2757+
2758+ if (HasOffset) {
2759+ Ops.push_back (N->getOperand (3 )); // offset
2760+ Ops.push_back (N->getOperand (4 )); // Pack flag
2761+ } else
2762+ Ops.push_back (N->getOperand (3 )); // Pack flag
2763+
2764+ MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(N);
2765+ SDValue NewNode =
2766+ DAG.getMemIntrinsicNode (ISD::INTRINSIC_W_CHAIN, DL, ResVTs, Ops,
2767+ MemSD->getMemoryVT (), MemSD->getMemOperand ());
2768+
2769+ // split the vector result
2770+ SmallVector<SDValue, 4 > ScalarRes;
2771+ for (unsigned i = 0 ; i < NumElts; ++i) {
2772+ SDValue Res = NewNode.getValue (i);
2773+ ScalarRes.push_back (Res);
2774+ }
2775+
2776+ SDValue Chain = NewNode.getValue (NumElts);
2777+ SDValue BuildVector = DAG.getNode (ISD::BUILD_VECTOR, DL, ResVT, ScalarRes);
2778+ return {{BuildVector, Chain}};
2779+ }
2780+
2781+ static SDValue lowerIntrinsicVoid (SDValue Op, SelectionDAG &DAG) {
27232782 SDNode *N = Op.getNode ();
27242783 SDValue Intrin = N->getOperand (1 );
27252784
@@ -2765,7 +2824,7 @@ static SDValue LowerIntrinsicVoid(SDValue Op, SelectionDAG &DAG) {
27652824 case Intrinsic::nvvm_tcgen05_st_16x64b_x64:
27662825 case Intrinsic::nvvm_tcgen05_st_32x32b_x64:
27672826 case Intrinsic::nvvm_tcgen05_st_32x32b_x128:
2768- return LowerTcgen05St (Op, DAG);
2827+ return lowerTcgen05St (Op, DAG);
27692828 case Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg1:
27702829 case Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg2:
27712830 case Intrinsic::nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg1:
@@ -2867,6 +2926,28 @@ static SDValue lowerPrmtIntrinsic(SDValue Op, SelectionDAG &DAG) {
28672926 SDValue Selector = (Op->op_end () - 1 )->get ();
28682927 return getPRMT (A, B, Selector, DL, DAG, Mode);
28692928}
2929+
2930+ static SDValue lowerIntrinsicWChain (SDValue Op, SelectionDAG &DAG) {
2931+ switch (Op->getConstantOperandVal (1 )) {
2932+ default :
2933+ return Op;
2934+
2935+ // These tcgen05 intrinsics return a v2i32, which is legal, so we have to
2936+ // lower them through LowerOperation() instead of ReplaceNodeResults().
2937+ case Intrinsic::nvvm_tcgen05_ld_16x64b_x2:
2938+ case Intrinsic::nvvm_tcgen05_ld_16x128b_x1:
2939+ case Intrinsic::nvvm_tcgen05_ld_32x32b_x2:
2940+ if (auto Res = lowerTcgen05Ld (Op.getNode (), DAG))
2941+ return DAG.getMergeValues ({Res->first , Res->second }, SDLoc (Op));
2942+ return SDValue ();
2943+
2944+ case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x2:
2945+ if (auto Res = lowerTcgen05Ld (Op.getNode (), DAG, /* HasOffset=*/ true ))
2946+ return DAG.getMergeValues ({Res->first , Res->second }, SDLoc (Op));
2947+ return SDValue ();
2948+ }
2949+ }
2950+
28702951static SDValue lowerIntrinsicWOChain (SDValue Op, SelectionDAG &DAG) {
28712952 switch (Op->getConstantOperandVal (0 )) {
28722953 default :
@@ -3029,11 +3110,11 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
30293110 case ISD::ADDRSPACECAST:
30303111 return LowerADDRSPACECAST (Op, DAG);
30313112 case ISD::INTRINSIC_W_CHAIN:
3032- return Op ;
3113+ return lowerIntrinsicWChain (Op, DAG) ;
30333114 case ISD::INTRINSIC_WO_CHAIN:
30343115 return lowerIntrinsicWOChain (Op, DAG);
30353116 case ISD::INTRINSIC_VOID:
3036- return LowerIntrinsicVoid (Op, DAG);
3117+ return lowerIntrinsicVoid (Op, DAG);
30373118 case ISD::BUILD_VECTOR:
30383119 return LowerBUILD_VECTOR (Op, DAG);
30393120 case ISD::BITCAST:
@@ -5920,7 +6001,7 @@ static SDValue PerformEXTRACTCombine(SDNode *N,
59206001 IsPTXVectorType (VectorVT.getSimpleVT ()))
59216002 return SDValue (); // Native vector loads already combine nicely w/
59226003 // extract_vector_elt.
5923- // Don't mess with singletons or packed types (v2f32 , v2*16, v4i8 and v8i8),
6004+ // Don't mess with singletons or packed types (v2*32 , v2*16, v4i8 and v8i8),
59246005 // we already handle them OK.
59256006 if (VectorVT.getVectorNumElements () == 1 ||
59266007 NVPTX::isPackedVectorTy (VectorVT) || VectorVT == MVT::v8i8)
@@ -6300,53 +6381,6 @@ static void ReplaceBITCAST(SDNode *Node, SelectionDAG &DAG,
63006381 DAG.getNode (ISD::BUILD_VECTOR, DL, MVT::v2i8, {Vec0, Vec1}));
63016382}
63026383
6303- // Lower vector return type of tcgen05.ld intrinsics
6304- static void ReplaceTcgen05Ld (SDNode *N, SelectionDAG &DAG,
6305- SmallVectorImpl<SDValue> &Results,
6306- bool hasOffset = false ) {
6307- SDLoc DL (N);
6308- EVT ResVT = N->getValueType (0 );
6309- if (!ResVT.isVector ())
6310- return ; // already legalized.
6311-
6312- const unsigned NumElts = ResVT.getVectorNumElements ();
6313-
6314- // Create the return type of the instructions
6315- SmallVector<EVT, 5 > ListVTs;
6316- for (unsigned i = 0 ; i < NumElts; ++i)
6317- ListVTs.push_back (MVT::i32 );
6318-
6319- ListVTs.push_back (N->getValueType (1 )); // Chain
6320-
6321- SDVTList ResVTs = DAG.getVTList (ListVTs);
6322-
6323- SmallVector<SDValue, 8 > Ops{N->getOperand (0 ), N->getOperand (1 ),
6324- N->getOperand (2 )};
6325-
6326- if (hasOffset) {
6327- Ops.push_back (N->getOperand (3 )); // offset
6328- Ops.push_back (N->getOperand (4 )); // Pack flag
6329- } else
6330- Ops.push_back (N->getOperand (3 )); // Pack flag
6331-
6332- MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(N);
6333- SDValue NewNode =
6334- DAG.getMemIntrinsicNode (ISD::INTRINSIC_W_CHAIN, DL, ResVTs, Ops,
6335- MemSD->getMemoryVT (), MemSD->getMemOperand ());
6336-
6337- // split the vector result
6338- SmallVector<SDValue, 4 > ScalarRes;
6339- for (unsigned i = 0 ; i < NumElts; ++i) {
6340- SDValue Res = NewNode.getValue (i);
6341- ScalarRes.push_back (Res);
6342- }
6343-
6344- SDValue Chain = NewNode.getValue (NumElts);
6345- SDValue BuildVector = DAG.getNode (ISD::BUILD_VECTOR, DL, ResVT, ScalarRes);
6346- Results.push_back (BuildVector); // Build Vector
6347- Results.push_back (Chain); // Chain
6348- }
6349-
63506384static void ReplaceINTRINSIC_W_CHAIN (SDNode *N, SelectionDAG &DAG,
63516385 SmallVectorImpl<SDValue> &Results) {
63526386 SDValue Chain = N->getOperand (0 );
@@ -6455,21 +6489,18 @@ static void ReplaceINTRINSIC_W_CHAIN(SDNode *N, SelectionDAG &DAG,
64556489 return ;
64566490 }
64576491
6458- case Intrinsic::nvvm_tcgen05_ld_16x64b_x2:
64596492 case Intrinsic::nvvm_tcgen05_ld_16x64b_x4:
64606493 case Intrinsic::nvvm_tcgen05_ld_16x64b_x8:
64616494 case Intrinsic::nvvm_tcgen05_ld_16x64b_x16:
64626495 case Intrinsic::nvvm_tcgen05_ld_16x64b_x32:
64636496 case Intrinsic::nvvm_tcgen05_ld_16x64b_x64:
64646497 case Intrinsic::nvvm_tcgen05_ld_16x64b_x128:
6465- case Intrinsic::nvvm_tcgen05_ld_32x32b_x2:
64666498 case Intrinsic::nvvm_tcgen05_ld_32x32b_x4:
64676499 case Intrinsic::nvvm_tcgen05_ld_32x32b_x8:
64686500 case Intrinsic::nvvm_tcgen05_ld_32x32b_x16:
64696501 case Intrinsic::nvvm_tcgen05_ld_32x32b_x32:
64706502 case Intrinsic::nvvm_tcgen05_ld_32x32b_x64:
64716503 case Intrinsic::nvvm_tcgen05_ld_32x32b_x128:
6472- case Intrinsic::nvvm_tcgen05_ld_16x128b_x1:
64736504 case Intrinsic::nvvm_tcgen05_ld_16x128b_x2:
64746505 case Intrinsic::nvvm_tcgen05_ld_16x128b_x4:
64756506 case Intrinsic::nvvm_tcgen05_ld_16x128b_x8:
@@ -6482,16 +6513,23 @@ static void ReplaceINTRINSIC_W_CHAIN(SDNode *N, SelectionDAG &DAG,
64826513 case Intrinsic::nvvm_tcgen05_ld_16x256b_x8:
64836514 case Intrinsic::nvvm_tcgen05_ld_16x256b_x16:
64846515 case Intrinsic::nvvm_tcgen05_ld_16x256b_x32:
6485- return ReplaceTcgen05Ld (N, DAG, Results);
6516+ if (auto Res = lowerTcgen05Ld (N, DAG)) {
6517+ Results.push_back (Res->first );
6518+ Results.push_back (Res->second );
6519+ }
6520+ return ;
64866521
6487- case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x2:
64886522 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x4:
64896523 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x8:
64906524 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x16:
64916525 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x32:
64926526 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x64:
64936527 case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x128:
6494- return ReplaceTcgen05Ld (N, DAG, Results, /* Offset */ true );
6528+ if (auto Res = lowerTcgen05Ld (N, DAG, /* HasOffset=*/ true )) {
6529+ Results.push_back (Res->first );
6530+ Results.push_back (Res->second );
6531+ }
6532+ return ;
64956533 }
64966534}
64976535
0 commit comments