@@ -391,8 +391,10 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,
391391
392392 // Set DAG combine for 'LSX' feature.
393393
394- if (Subtarget.hasExtLSX ())
394+ if (Subtarget.hasExtLSX ()) {
395395 setTargetDAGCombine (ISD::INTRINSIC_WO_CHAIN);
396+ setTargetDAGCombine (ISD::BITCAST);
397+ }
396398
397399 // Compute derived properties from the register classes.
398400 computeRegisterProperties (Subtarget.getRegisterInfo ());
@@ -4329,6 +4331,85 @@ static SDValue performSRLCombine(SDNode *N, SelectionDAG &DAG,
43294331 return SDValue ();
43304332}
43314333
4334+ static SDValue performBITCASTCombine (SDNode *N, SelectionDAG &DAG,
4335+ TargetLowering::DAGCombinerInfo &DCI,
4336+ const LoongArchSubtarget &Subtarget) {
4337+ SDLoc DL (N);
4338+ EVT VT = N->getValueType (0 );
4339+ SDValue Src = N->getOperand (0 );
4340+ EVT SrcVT = Src.getValueType ();
4341+
4342+ if (!DCI.isBeforeLegalizeOps ())
4343+ return SDValue ();
4344+
4345+ if (!SrcVT.isSimple () || SrcVT.getScalarType () != MVT::i1)
4346+ return SDValue ();
4347+
4348+ unsigned Opc = ISD::DELETED_NODE;
4349+ // Combine SETCC and BITCAST into [X]VMSK{LT,GE,NE} when possible
4350+ if (Src.getOpcode () == ISD::SETCC && Src.hasOneUse ()) {
4351+ bool UseLASX;
4352+ EVT CmpVT = Src.getOperand (0 ).getValueType ();
4353+ EVT EltVT = CmpVT.getVectorElementType ();
4354+
4355+ if (Subtarget.hasExtLSX () && CmpVT.getSizeInBits () <= 128 )
4356+ UseLASX = false ;
4357+ else if (Subtarget.has32S () && Subtarget.hasExtLASX () &&
4358+ CmpVT.getSizeInBits () <= 256 )
4359+ UseLASX = true ;
4360+ else
4361+ return SDValue ();
4362+
4363+ SDValue SrcN1 = Src.getOperand (1 );
4364+ switch (cast<CondCodeSDNode>(Src.getOperand (2 ))->get ()) {
4365+ default :
4366+ break ;
4367+ case ISD::SETEQ:
4368+ // x == 0 => not (vmsknez.b x)
4369+ if (ISD::isBuildVectorAllZeros (SrcN1.getNode ()) && EltVT == MVT::i8 )
4370+ Opc = UseLASX ? LoongArchISD::XVMSKEQZ : LoongArchISD::VMSKEQZ;
4371+ break ;
4372+ case ISD::SETGT:
4373+ // x > -1 => vmskgez.b x
4374+ if (ISD::isBuildVectorAllOnes (SrcN1.getNode ()) && EltVT == MVT::i8 )
4375+ Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
4376+ break ;
4377+ case ISD::SETGE:
4378+ // x >= 0 => vmskgez.b x
4379+ if (ISD::isBuildVectorAllZeros (SrcN1.getNode ()) && EltVT == MVT::i8 )
4380+ Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
4381+ break ;
4382+ case ISD::SETLT:
4383+ // x < 0 => vmskltz.{b,h,w,d} x
4384+ if (ISD::isBuildVectorAllZeros (SrcN1.getNode ()) &&
4385+ (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
4386+ EltVT == MVT::i64 ))
4387+ Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
4388+ break ;
4389+ case ISD::SETLE:
4390+ // x <= -1 => vmskltz.{b,h,w,d} x
4391+ if (ISD::isBuildVectorAllOnes (SrcN1.getNode ()) &&
4392+ (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
4393+ EltVT == MVT::i64 ))
4394+ Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
4395+ break ;
4396+ case ISD::SETNE:
4397+ // x != 0 => vmsknez.b x
4398+ if (ISD::isBuildVectorAllZeros (SrcN1.getNode ()) && EltVT == MVT::i8 )
4399+ Opc = UseLASX ? LoongArchISD::XVMSKNEZ : LoongArchISD::VMSKNEZ;
4400+ break ;
4401+ }
4402+ }
4403+
4404+ if (Opc == ISD::DELETED_NODE)
4405+ return SDValue ();
4406+
4407+ SDValue V = DAG.getNode (Opc, DL, MVT::i64 , Src.getOperand (0 ));
4408+ EVT T = EVT::getIntegerVT (*DAG.getContext (), SrcVT.getVectorNumElements ());
4409+ V = DAG.getZExtOrTrunc (V, DL, T);
4410+ return DAG.getBitcast (VT, V);
4411+ }
4412+
43324413static SDValue performORCombine (SDNode *N, SelectionDAG &DAG,
43334414 TargetLowering::DAGCombinerInfo &DCI,
43344415 const LoongArchSubtarget &Subtarget) {
@@ -5373,6 +5454,8 @@ SDValue LoongArchTargetLowering::PerformDAGCombine(SDNode *N,
53735454 return performSETCCCombine (N, DAG, DCI, Subtarget);
53745455 case ISD::SRL:
53755456 return performSRLCombine (N, DAG, DCI, Subtarget);
5457+ case ISD::BITCAST:
5458+ return performBITCASTCombine (N, DAG, DCI, Subtarget);
53765459 case LoongArchISD::BITREV_W:
53775460 return performBITREV_WCombine (N, DAG, DCI, Subtarget);
53785461 case ISD::INTRINSIC_WO_CHAIN:
@@ -5663,6 +5746,120 @@ static MachineBasicBlock *emitPseudoCTPOP(MachineInstr &MI,
56635746 return BB;
56645747}
56655748
5749+ static MachineBasicBlock *
5750+ emitPseudoVMSKCOND (MachineInstr &MI, MachineBasicBlock *BB,
5751+ const LoongArchSubtarget &Subtarget) {
5752+ const TargetInstrInfo *TII = Subtarget.getInstrInfo ();
5753+ const TargetRegisterClass *RC = &LoongArch::LSX128RegClass;
5754+ const LoongArchRegisterInfo *TRI = Subtarget.getRegisterInfo ();
5755+ MachineRegisterInfo &MRI = BB->getParent ()->getRegInfo ();
5756+ Register Dst = MI.getOperand (0 ).getReg ();
5757+ Register Src = MI.getOperand (1 ).getReg ();
5758+ DebugLoc DL = MI.getDebugLoc ();
5759+ unsigned EleBits = 8 ;
5760+ unsigned NotOpc = 0 ;
5761+ unsigned MskOpc;
5762+
5763+ switch (MI.getOpcode ()) {
5764+ default :
5765+ llvm_unreachable (" Unexpected opcode" );
5766+ case LoongArch::PseudoVMSKLTZ_B:
5767+ MskOpc = LoongArch::VMSKLTZ_B;
5768+ break ;
5769+ case LoongArch::PseudoVMSKLTZ_H:
5770+ MskOpc = LoongArch::VMSKLTZ_H;
5771+ EleBits = 16 ;
5772+ break ;
5773+ case LoongArch::PseudoVMSKLTZ_W:
5774+ MskOpc = LoongArch::VMSKLTZ_W;
5775+ EleBits = 32 ;
5776+ break ;
5777+ case LoongArch::PseudoVMSKLTZ_D:
5778+ MskOpc = LoongArch::VMSKLTZ_D;
5779+ EleBits = 64 ;
5780+ break ;
5781+ case LoongArch::PseudoVMSKGEZ_B:
5782+ MskOpc = LoongArch::VMSKGEZ_B;
5783+ break ;
5784+ case LoongArch::PseudoVMSKEQZ_B:
5785+ MskOpc = LoongArch::VMSKNZ_B;
5786+ NotOpc = LoongArch::VNOR_V;
5787+ break ;
5788+ case LoongArch::PseudoVMSKNEZ_B:
5789+ MskOpc = LoongArch::VMSKNZ_B;
5790+ break ;
5791+ case LoongArch::PseudoXVMSKLTZ_B:
5792+ MskOpc = LoongArch::XVMSKLTZ_B;
5793+ RC = &LoongArch::LASX256RegClass;
5794+ break ;
5795+ case LoongArch::PseudoXVMSKLTZ_H:
5796+ MskOpc = LoongArch::XVMSKLTZ_H;
5797+ RC = &LoongArch::LASX256RegClass;
5798+ EleBits = 16 ;
5799+ break ;
5800+ case LoongArch::PseudoXVMSKLTZ_W:
5801+ MskOpc = LoongArch::XVMSKLTZ_W;
5802+ RC = &LoongArch::LASX256RegClass;
5803+ EleBits = 32 ;
5804+ break ;
5805+ case LoongArch::PseudoXVMSKLTZ_D:
5806+ MskOpc = LoongArch::XVMSKLTZ_D;
5807+ RC = &LoongArch::LASX256RegClass;
5808+ EleBits = 64 ;
5809+ break ;
5810+ case LoongArch::PseudoXVMSKGEZ_B:
5811+ MskOpc = LoongArch::XVMSKGEZ_B;
5812+ RC = &LoongArch::LASX256RegClass;
5813+ break ;
5814+ case LoongArch::PseudoXVMSKEQZ_B:
5815+ MskOpc = LoongArch::XVMSKNZ_B;
5816+ NotOpc = LoongArch::XVNOR_V;
5817+ RC = &LoongArch::LASX256RegClass;
5818+ break ;
5819+ case LoongArch::PseudoXVMSKNEZ_B:
5820+ MskOpc = LoongArch::XVMSKNZ_B;
5821+ RC = &LoongArch::LASX256RegClass;
5822+ break ;
5823+ }
5824+
5825+ Register Msk = MRI.createVirtualRegister (RC);
5826+ if (NotOpc) {
5827+ Register Tmp = MRI.createVirtualRegister (RC);
5828+ BuildMI (*BB, MI, DL, TII->get (MskOpc), Tmp).addReg (Src);
5829+ BuildMI (*BB, MI, DL, TII->get (NotOpc), Msk)
5830+ .addReg (Tmp, RegState::Kill)
5831+ .addReg (Tmp, RegState::Kill);
5832+ } else {
5833+ BuildMI (*BB, MI, DL, TII->get (MskOpc), Msk).addReg (Src);
5834+ }
5835+
5836+ if (TRI->getRegSizeInBits (*RC) > 128 ) {
5837+ Register Lo = MRI.createVirtualRegister (&LoongArch::GPRRegClass);
5838+ Register Hi = MRI.createVirtualRegister (&LoongArch::GPRRegClass);
5839+ BuildMI (*BB, MI, DL, TII->get (LoongArch::XVPICKVE2GR_WU), Lo)
5840+ .addReg (Msk)
5841+ .addImm (0 );
5842+ BuildMI (*BB, MI, DL, TII->get (LoongArch::XVPICKVE2GR_WU), Hi)
5843+ .addReg (Msk, RegState::Kill)
5844+ .addImm (4 );
5845+ BuildMI (*BB, MI, DL,
5846+ TII->get (Subtarget.is64Bit () ? LoongArch::BSTRINS_D
5847+ : LoongArch::BSTRINS_W),
5848+ Dst)
5849+ .addReg (Lo, RegState::Kill)
5850+ .addReg (Hi, RegState::Kill)
5851+ .addImm (256 / EleBits - 1 )
5852+ .addImm (128 / EleBits);
5853+ } else {
5854+ BuildMI (*BB, MI, DL, TII->get (LoongArch::VPICKVE2GR_HU), Dst)
5855+ .addReg (Msk, RegState::Kill)
5856+ .addImm (0 );
5857+ }
5858+
5859+ MI.eraseFromParent ();
5860+ return BB;
5861+ }
5862+
56665863static bool isSelectPseudo (MachineInstr &MI) {
56675864 switch (MI.getOpcode ()) {
56685865 default :
@@ -5869,6 +6066,21 @@ MachineBasicBlock *LoongArchTargetLowering::EmitInstrWithCustomInserter(
58696066 return emitPseudoXVINSGR2VR (MI, BB, Subtarget);
58706067 case LoongArch::PseudoCTPOP:
58716068 return emitPseudoCTPOP (MI, BB, Subtarget);
6069+ case LoongArch::PseudoVMSKLTZ_B:
6070+ case LoongArch::PseudoVMSKLTZ_H:
6071+ case LoongArch::PseudoVMSKLTZ_W:
6072+ case LoongArch::PseudoVMSKLTZ_D:
6073+ case LoongArch::PseudoVMSKGEZ_B:
6074+ case LoongArch::PseudoVMSKEQZ_B:
6075+ case LoongArch::PseudoVMSKNEZ_B:
6076+ case LoongArch::PseudoXVMSKLTZ_B:
6077+ case LoongArch::PseudoXVMSKLTZ_H:
6078+ case LoongArch::PseudoXVMSKLTZ_W:
6079+ case LoongArch::PseudoXVMSKLTZ_D:
6080+ case LoongArch::PseudoXVMSKGEZ_B:
6081+ case LoongArch::PseudoXVMSKEQZ_B:
6082+ case LoongArch::PseudoXVMSKNEZ_B:
6083+ return emitPseudoVMSKCOND (MI, BB, Subtarget);
58726084 case TargetOpcode::STATEPOINT:
58736085 // STATEPOINT is a pseudo instruction which has no implicit defs/uses
58746086 // while bl call instruction (where statepoint will be lowered at the
@@ -5990,6 +6202,14 @@ const char *LoongArchTargetLowering::getTargetNodeName(unsigned Opcode) const {
59906202 NODE_NAME_CASE (VBSLL)
59916203 NODE_NAME_CASE (VBSRL)
59926204 NODE_NAME_CASE (VLDREPL)
6205+ NODE_NAME_CASE (VMSKLTZ)
6206+ NODE_NAME_CASE (VMSKGEZ)
6207+ NODE_NAME_CASE (VMSKEQZ)
6208+ NODE_NAME_CASE (VMSKNEZ)
6209+ NODE_NAME_CASE (XVMSKLTZ)
6210+ NODE_NAME_CASE (XVMSKGEZ)
6211+ NODE_NAME_CASE (XVMSKEQZ)
6212+ NODE_NAME_CASE (XVMSKNEZ)
59936213 }
59946214#undef NODE_NAME_CASE
59956215 return nullptr ;
0 commit comments