@@ -4302,65 +4302,66 @@ static SDValue performBITCASTCombine(SDNode *N, SelectionDAG &DAG,
43024302 if (!SrcVT.isSimple () || SrcVT.getScalarType () != MVT::i1)
43034303 return SDValue ();
43044304
4305- if (Src.getOpcode () != ISD::SETCC || !Src.hasOneUse ())
4306- return SDValue ();
4307-
4308- bool UseLASX;
4309- EVT CmpVT = Src.getOperand (0 ).getValueType ();
4310- EVT EltVT = CmpVT.getVectorElementType ();
4311- if (Subtarget.hasExtLSX () && CmpVT.getSizeInBits () <= 128 )
4312- UseLASX = false ;
4313- else if (Subtarget.has32S () && Subtarget.hasExtLASX () &&
4314- CmpVT.getSizeInBits () <= 256 )
4315- UseLASX = true ;
4316- else
4317- return SDValue ();
4305+ unsigned Opc = ISD::DELETED_NODE;
4306+ // Combine SETCC and BITCAST into [X]VMSK{LT,GE,NE} when possible
4307+ if (Src.getOpcode () == ISD::SETCC && Src.hasOneUse ()) {
4308+ bool UseLASX;
4309+ EVT CmpVT = Src.getOperand (0 ).getValueType ();
4310+ EVT EltVT = CmpVT.getVectorElementType ();
4311+
4312+ if (Subtarget.hasExtLSX () && CmpVT.getSizeInBits () <= 128 )
4313+ UseLASX = false ;
4314+ else if (Subtarget.has32S () && Subtarget.hasExtLASX () &&
4315+ CmpVT.getSizeInBits () <= 256 )
4316+ UseLASX = true ;
4317+ else
4318+ return SDValue ();
43184319
4319- unsigned ISD = ISD::DELETED_NODE ;
4320- SDValue SrcN1 = Src.getOperand (1 );
4321- switch (cast<CondCodeSDNode>(Src. getOperand ( 2 ))-> get ()) {
4322- default :
4323- return SDValue ();
4324- case ISD::SETEQ:
4325- // x == 0 => not (vmsknez.b x )
4326- if ( ISD::isBuildVectorAllZeros (SrcN1. getNode ()) && EltVT == MVT:: i8 )
4327- ISD = UseLASX ? LoongArchISD::XVMSKEQZ : LoongArchISD::VMSKEQZ ;
4328- break ;
4329- case ISD::SETGT:
4330- // x > -1 => vmskgez.b x
4331- if ( ISD::isBuildVectorAllOnes (SrcN1. getNode ()) && EltVT == MVT:: i8 )
4332- ISD = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ ;
4333- break ;
4334- case ISD::SETGE:
4335- // x >= 0 => vmskgez.b x
4336- if ( ISD::isBuildVectorAllZeros (SrcN1. getNode ()) && EltVT == MVT:: i8 )
4337- ISD = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ ;
4338- break ;
4339- case ISD::SETLT:
4340- // x < 0 => vmskltz.{b,h,w,d} x
4341- if ( ISD::isBuildVectorAllZeros (SrcN1. getNode ()) &&
4342- (EltVT == MVT:: i8 || EltVT == MVT::i16 || EltVT == MVT:: i32 ||
4343- EltVT == MVT:: i64 ))
4344- ISD = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ ;
4345- break ;
4346- case ISD::SETLE:
4347- // x <= -1 => vmskltz.{b,h,w,d} x
4348- if ( ISD::isBuildVectorAllOnes (SrcN1. getNode ()) &&
4349- (EltVT == MVT:: i8 || EltVT == MVT::i16 || EltVT == MVT:: i32 ||
4350- EltVT == MVT:: i64 ))
4351- ISD = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ ;
4352- break ;
4353- case ISD::SETNE:
4354- // x != 0 => vmsknez.b x
4355- if ( ISD::isBuildVectorAllZeros (SrcN1. getNode ()) && EltVT == MVT:: i8 )
4356- ISD = UseLASX ? LoongArchISD::XVMSKNEZ : LoongArchISD::VMSKNEZ ;
4357- break ;
4320+ SDValue SrcN1 = Src. getOperand ( 1 ) ;
4321+ switch (cast<CondCodeSDNode>( Src.getOperand (2 ))-> get ()) {
4322+ default :
4323+ break ;
4324+ case ISD::SETEQ:
4325+ // x == 0 => not (vmsknez.b x)
4326+ if ( ISD::isBuildVectorAllZeros (SrcN1. getNode ()) && EltVT == MVT:: i8 )
4327+ Opc = UseLASX ? LoongArchISD::XVMSKEQZ : LoongArchISD::VMSKEQZ;
4328+ break ;
4329+ case ISD::SETGT:
4330+ // x > -1 => vmskgez.b x
4331+ if ( ISD::isBuildVectorAllOnes (SrcN1. getNode ()) && EltVT == MVT:: i8 )
4332+ Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
4333+ break ;
4334+ case ISD::SETGE:
4335+ // x >= 0 => vmskgez.b x
4336+ if ( ISD::isBuildVectorAllZeros (SrcN1. getNode ()) && EltVT == MVT:: i8 )
4337+ Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
4338+ break ;
4339+ case ISD::SETLT:
4340+ // x < 0 => vmskltz.{b,h,w,d} x
4341+ if ( ISD::isBuildVectorAllZeros (SrcN1. getNode ()) &&
4342+ (EltVT == MVT:: i8 || EltVT == MVT:: i16 || EltVT == MVT:: i32 ||
4343+ EltVT == MVT::i64 ))
4344+ Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
4345+ break ;
4346+ case ISD::SETLE:
4347+ // x <= -1 => vmskltz.{b,h,w,d} x
4348+ if ( ISD::isBuildVectorAllOnes (SrcN1. getNode ()) &&
4349+ (EltVT == MVT:: i8 || EltVT == MVT:: i16 || EltVT == MVT:: i32 ||
4350+ EltVT == MVT::i64 ))
4351+ Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
4352+ break ;
4353+ case ISD::SETNE:
4354+ // x != 0 => vmsknez.b x
4355+ if ( ISD::isBuildVectorAllZeros (SrcN1. getNode ()) && EltVT == MVT:: i8 )
4356+ Opc = UseLASX ? LoongArchISD::XVMSKNEZ : LoongArchISD::VMSKNEZ;
4357+ break ;
4358+ }
43584359 }
43594360
4360- if (ISD == ISD::DELETED_NODE)
4361+ if (Opc == ISD::DELETED_NODE)
43614362 return SDValue ();
43624363
4363- SDValue V = DAG.getNode (ISD , DL, MVT::i64 , Src.getOperand (0 ));
4364+ SDValue V = DAG.getNode (Opc , DL, MVT::i64 , Src.getOperand (0 ));
43644365 EVT T = EVT::getIntegerVT (*DAG.getContext (), SrcVT.getVectorNumElements ());
43654366 V = DAG.getZExtOrTrunc (V, DL, T);
43664367 return DAG.getBitcast (VT, V);
0 commit comments