@@ -866,6 +866,14 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
866866 setBF16OperationAction (ISD::FNEG, MVT::v2bf16, Legal, Expand);
867867 // (would be) Library functions.
868868
869+ if (STI.hasF32x2Instructions ()) {
870+ // Handle custom lowering for: v2f32 = OP v2f32, v2f32
871+ for (const auto &Op : {ISD::FADD, ISD::FSUB, ISD::FMUL, ISD::FMA})
872+ setOperationAction (Op, MVT::v2f32, Custom);
873+ // Handle custom lowering for: i64 = bitcast v2f32
874+ setOperationAction (ISD::BITCAST, MVT::v2f32, Custom);
875+ }
876+
869877 // These map to conversion instructions for scalar FP types.
870878 for (const auto &Op : {ISD::FCEIL, ISD::FFLOOR, ISD::FNEARBYINT, ISD::FRINT,
871879 ISD::FROUNDEVEN, ISD::FTRUNC}) {
@@ -1066,6 +1074,10 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
10661074 MAKE_CASE (NVPTXISD::STACKSAVE)
10671075 MAKE_CASE (NVPTXISD::SETP_F16X2)
10681076 MAKE_CASE (NVPTXISD::SETP_BF16X2)
1077+ MAKE_CASE (NVPTXISD::FADD_F32X2)
1078+ MAKE_CASE (NVPTXISD::FSUB_F32X2)
1079+ MAKE_CASE (NVPTXISD::FMUL_F32X2)
1080+ MAKE_CASE (NVPTXISD::FMA_F32X2)
10691081 MAKE_CASE (NVPTXISD::Dummy)
10701082 MAKE_CASE (NVPTXISD::MUL_WIDE_SIGNED)
10711083 MAKE_CASE (NVPTXISD::MUL_WIDE_UNSIGNED)
@@ -2099,24 +2111,58 @@ SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
20992111 // Handle bitcasting from v2i8 without hitting the default promotion
21002112 // strategy which goes through stack memory.
21012113 EVT FromVT = Op->getOperand (0 )->getValueType (0 );
2102- if (FromVT != MVT::v2i8) {
2103- return Op;
2104- }
2105-
2106- // Pack vector elements into i16 and bitcast to final type
2107- SDLoc DL (Op);
2108- SDValue Vec0 = DAG.getNode (ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8 ,
2109- Op->getOperand (0 ), DAG.getIntPtrConstant (0 , DL));
2110- SDValue Vec1 = DAG.getNode (ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8 ,
2111- Op->getOperand (0 ), DAG.getIntPtrConstant (1 , DL));
2112- SDValue Extend0 = DAG.getNode (ISD::ZERO_EXTEND, DL, MVT::i16 , Vec0);
2113- SDValue Extend1 = DAG.getNode (ISD::ZERO_EXTEND, DL, MVT::i16 , Vec1);
2114- SDValue Const8 = DAG.getConstant (8 , DL, MVT::i16 );
2115- SDValue AsInt = DAG.getNode (
2116- ISD::OR, DL, MVT::i16 ,
2117- {Extend0, DAG.getNode (ISD::SHL, DL, MVT::i16 , {Extend1, Const8})});
21182114 EVT ToVT = Op->getValueType (0 );
2119- return MaybeBitcast (DAG, DL, ToVT, AsInt);
2115+ SDLoc DL (Op);
2116+
2117+ if (FromVT == MVT::v2i8) {
2118+ // Pack vector elements into i16 and bitcast to final type
2119+ SDValue Vec0 = DAG.getNode (ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8 ,
2120+ Op->getOperand (0 ), DAG.getIntPtrConstant (0 , DL));
2121+ SDValue Vec1 = DAG.getNode (ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8 ,
2122+ Op->getOperand (0 ), DAG.getIntPtrConstant (1 , DL));
2123+ SDValue Extend0 = DAG.getNode (ISD::ZERO_EXTEND, DL, MVT::i16 , Vec0);
2124+ SDValue Extend1 = DAG.getNode (ISD::ZERO_EXTEND, DL, MVT::i16 , Vec1);
2125+ SDValue Const8 = DAG.getConstant (8 , DL, MVT::i16 );
2126+ SDValue AsInt = DAG.getNode (
2127+ ISD::OR, DL, MVT::i16 ,
2128+ {Extend0, DAG.getNode (ISD::SHL, DL, MVT::i16 , {Extend1, Const8})});
2129+ EVT ToVT = Op->getValueType (0 );
2130+ return MaybeBitcast (DAG, DL, ToVT, AsInt);
2131+ }
2132+
2133+ if (FromVT == MVT::v2f32) {
2134+ assert (ToVT == MVT::i64 );
2135+
2136+ // A bitcast to i64 from v2f32.
2137+ // See if we can legalize the operand.
2138+ const SDValue &Operand = Op->getOperand (0 );
2139+ if (Operand.getOpcode () == ISD::BUILD_VECTOR) {
2140+ const SDValue &BVOp0 = Operand.getOperand (0 );
2141+ const SDValue &BVOp1 = Operand.getOperand (1 );
2142+
2143+ auto CastToAPInt = [](SDValue Op) -> APInt {
2144+ if (Op->isUndef ())
2145+ return APInt (64 , 0 ); // undef values default to 0
2146+ return cast<ConstantFPSDNode>(Op)->getValueAPF ().bitcastToAPInt ().zext (
2147+ 64 );
2148+ };
2149+
2150+ if ((BVOp0->isUndef () || isa<ConstantFPSDNode>(BVOp0)) &&
2151+ (BVOp1->isUndef () || isa<ConstantFPSDNode>(BVOp1))) {
2152+ // cast two constants
2153+ APInt Value (64 , 0 );
2154+ Value = CastToAPInt (BVOp0) | CastToAPInt (BVOp1).shl (32 );
2155+ SDValue Const = DAG.getConstant (Value, DL, MVT::i64 );
2156+ return DAG.getBitcast (ToVT, Const);
2157+ }
2158+
2159+ // otherwise build an i64
2160+ return DAG.getNode (ISD::BUILD_PAIR, DL, MVT::i64 ,
2161+ DAG.getBitcast (MVT::i32 , BVOp0),
2162+ DAG.getBitcast (MVT::i32 , BVOp1));
2163+ }
2164+ }
2165+ return Op;
21202166}
21212167
21222168// We can init constant f16x2/v2i16/v4i8 with a single .b32 move. Normally it
@@ -3055,6 +3101,13 @@ bool NVPTXTargetLowering::splitValueIntoRegisterParts(
30553101 return false ;
30563102}
30573103
3104+ const TargetRegisterClass *
3105+ NVPTXTargetLowering::getRegClassFor (MVT VT, bool isDivergent) const {
3106+ if (VT == MVT::v2f32)
3107+ return &NVPTX::Int64RegsRegClass;
3108+ return TargetLowering::getRegClassFor (VT, isDivergent);
3109+ }
3110+
30583111// This creates target external symbol for a function parameter.
30593112// Name of the symbol is composed from its index and the function name.
30603113// Negative index corresponds to special parameter (unsized array) used for
@@ -5055,10 +5108,10 @@ static SDValue PerformEXTRACTCombine(SDNode *N,
50555108 IsPTXVectorType (VectorVT.getSimpleVT ()))
50565109 return SDValue (); // Native vector loads already combine nicely w/
50575110 // extract_vector_elt.
5058- // Don't mess with singletons or v2*16, v4i8 and v8i8 types, we already
5111+ // Don't mess with singletons or v2*16, v4i8, v8i8, or v2f32 types, we already
50595112 // handle them OK.
50605113 if (VectorVT.getVectorNumElements () == 1 || Isv2x16VT (VectorVT) ||
5061- VectorVT == MVT::v4i8 || VectorVT == MVT::v8i8)
5114+ VectorVT == MVT::v4i8 || VectorVT == MVT::v8i8 || VectorVT == MVT::v2f32 )
50625115 return SDValue ();
50635116
50645117 // Don't mess with undef values as sra may be simplified to 0, not undef.
@@ -5478,6 +5531,45 @@ static void ReplaceCopyFromReg_128(SDNode *N, SelectionDAG &DAG,
54785531 Results.push_back (NewValue.getValue (3 ));
54795532}
54805533
5534+ static void ReplaceF32x2Op (SDNode *N, SelectionDAG &DAG,
5535+ SmallVectorImpl<SDValue> &Results,
5536+ bool UseFTZ) {
5537+ SDLoc DL (N);
5538+ EVT OldResultTy = N->getValueType (0 ); // <2 x float>
5539+ assert (OldResultTy == MVT::v2f32 && " Unexpected result type for F32x2 op!" );
5540+
5541+ SmallVector<SDValue> NewOps;
5542+
5543+ // whether we use FTZ (TODO)
5544+
5545+ // replace with NVPTX F32x2 op:
5546+ unsigned Opcode;
5547+ switch (N->getOpcode ()) {
5548+ case ISD::FADD:
5549+ Opcode = NVPTXISD::FADD_F32X2;
5550+ break ;
5551+ case ISD::FSUB:
5552+ Opcode = NVPTXISD::FSUB_F32X2;
5553+ break ;
5554+ case ISD::FMUL:
5555+ Opcode = NVPTXISD::FMUL_F32X2;
5556+ break ;
5557+ case ISD::FMA:
5558+ Opcode = NVPTXISD::FMA_F32X2;
5559+ break ;
5560+ default :
5561+ llvm_unreachable (" Unexpected opcode" );
5562+ }
5563+
5564+ // bitcast operands: <2 x float> -> i64
5565+ for (const SDValue &Op : N->ops ())
5566+ NewOps.push_back (DAG.getNode (ISD::BITCAST, DL, MVT::i64 , Op));
5567+
5568+ // cast i64 result of new op back to <2 x float>
5569+ SDValue NewValue = DAG.getNode (Opcode, DL, MVT::i64 , NewOps);
5570+ Results.push_back (DAG.getBitcast (OldResultTy, NewValue));
5571+ }
5572+
54815573void NVPTXTargetLowering::ReplaceNodeResults (
54825574 SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
54835575 switch (N->getOpcode ()) {
@@ -5495,6 +5587,12 @@ void NVPTXTargetLowering::ReplaceNodeResults(
54955587 case ISD::CopyFromReg:
54965588 ReplaceCopyFromReg_128 (N, DAG, Results);
54975589 return ;
5590+ case ISD::FADD:
5591+ case ISD::FSUB:
5592+ case ISD::FMUL:
5593+ case ISD::FMA:
5594+ ReplaceF32x2Op (N, DAG, Results, useF32FTZ (DAG.getMachineFunction ()));
5595+ return ;
54985596 }
54995597}
55005598
0 commit comments