@@ -4560,6 +4560,80 @@ static SDValue signExtendBitcastSrcVector(SelectionDAG &DAG, EVT SExtVT,
45604560 llvm_unreachable (" Unexpected node type for vXi1 sign extension" );
45614561}
45624562
4563+ static SDValue
4564+ performSETCC_BITCASTCombine (SDNode *N, SelectionDAG &DAG,
4565+ TargetLowering::DAGCombinerInfo &DCI,
4566+ const LoongArchSubtarget &Subtarget) {
4567+ SDLoc DL (N);
4568+ EVT VT = N->getValueType (0 );
4569+ SDValue Src = N->getOperand (0 );
4570+ EVT SrcVT = Src.getValueType ();
4571+
4572+ if (Src.getOpcode () != ISD::SETCC || !Src.hasOneUse ())
4573+ return SDValue ();
4574+
4575+ bool UseLASX;
4576+ unsigned Opc = ISD::DELETED_NODE;
4577+ EVT CmpVT = Src.getOperand (0 ).getValueType ();
4578+ EVT EltVT = CmpVT.getVectorElementType ();
4579+
4580+ if (Subtarget.hasExtLSX () && CmpVT.getSizeInBits () == 128 )
4581+ UseLASX = false ;
4582+ else if (Subtarget.has32S () && Subtarget.hasExtLASX () &&
4583+ CmpVT.getSizeInBits () == 256 )
4584+ UseLASX = true ;
4585+ else
4586+ return SDValue ();
4587+
4588+ SDValue SrcN1 = Src.getOperand (1 );
4589+ switch (cast<CondCodeSDNode>(Src.getOperand (2 ))->get ()) {
4590+ default :
4591+ break ;
4592+ case ISD::SETEQ:
4593+ // x == 0 => not (vmsknez.b x)
4594+ if (ISD::isBuildVectorAllZeros (SrcN1.getNode ()) && EltVT == MVT::i8 )
4595+ Opc = UseLASX ? LoongArchISD::XVMSKEQZ : LoongArchISD::VMSKEQZ;
4596+ break ;
4597+ case ISD::SETGT:
4598+ // x > -1 => vmskgez.b x
4599+ if (ISD::isBuildVectorAllOnes (SrcN1.getNode ()) && EltVT == MVT::i8 )
4600+ Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
4601+ break ;
4602+ case ISD::SETGE:
4603+ // x >= 0 => vmskgez.b x
4604+ if (ISD::isBuildVectorAllZeros (SrcN1.getNode ()) && EltVT == MVT::i8 )
4605+ Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
4606+ break ;
4607+ case ISD::SETLT:
4608+ // x < 0 => vmskltz.{b,h,w,d} x
4609+ if (ISD::isBuildVectorAllZeros (SrcN1.getNode ()) &&
4610+ (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
4611+ EltVT == MVT::i64 ))
4612+ Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
4613+ break ;
4614+ case ISD::SETLE:
4615+ // x <= -1 => vmskltz.{b,h,w,d} x
4616+ if (ISD::isBuildVectorAllOnes (SrcN1.getNode ()) &&
4617+ (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
4618+ EltVT == MVT::i64 ))
4619+ Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
4620+ break ;
4621+ case ISD::SETNE:
4622+ // x != 0 => vmsknez.b x
4623+ if (ISD::isBuildVectorAllZeros (SrcN1.getNode ()) && EltVT == MVT::i8 )
4624+ Opc = UseLASX ? LoongArchISD::XVMSKNEZ : LoongArchISD::VMSKNEZ;
4625+ break ;
4626+ }
4627+
4628+ if (Opc == ISD::DELETED_NODE)
4629+ return SDValue ();
4630+
4631+ SDValue V = DAG.getNode (Opc, DL, MVT::i64 , Src.getOperand (0 ));
4632+ EVT T = EVT::getIntegerVT (*DAG.getContext (), SrcVT.getVectorNumElements ());
4633+ V = DAG.getZExtOrTrunc (V, DL, T);
4634+ return DAG.getBitcast (VT, V);
4635+ }
4636+
45634637static SDValue performBITCASTCombine (SDNode *N, SelectionDAG &DAG,
45644638 TargetLowering::DAGCombinerInfo &DCI,
45654639 const LoongArchSubtarget &Subtarget) {
@@ -4574,110 +4648,63 @@ static SDValue performBITCASTCombine(SDNode *N, SelectionDAG &DAG,
45744648 if (!SrcVT.isSimple () || SrcVT.getScalarType () != MVT::i1)
45754649 return SDValue ();
45764650
4577- unsigned Opc = ISD::DELETED_NODE;
45784651 // Combine SETCC and BITCAST into [X]VMSK{LT,GE,NE} when possible
4652+ SDValue Res = performSETCC_BITCASTCombine (N, DAG, DCI, Subtarget);
4653+ if (Res)
4654+ return Res;
4655+
4656+ // Generate vXi1 using [X]VMSKLTZ
4657+ MVT SExtVT;
4658+ unsigned Opc;
4659+ bool UseLASX = false ;
4660+ bool PropagateSExt = false ;
4661+
45794662 if (Src.getOpcode () == ISD::SETCC && Src.hasOneUse ()) {
4580- bool UseLASX;
45814663 EVT CmpVT = Src.getOperand (0 ).getValueType ();
4582- EVT EltVT = CmpVT.getVectorElementType ();
4583-
4584- if (Subtarget.hasExtLSX () && CmpVT.getSizeInBits () <= 128 )
4585- UseLASX = false ;
4586- else if (Subtarget.has32S () && Subtarget.hasExtLASX () &&
4587- CmpVT.getSizeInBits () <= 256 )
4588- UseLASX = true ;
4589- else
4664+ if (CmpVT.getSizeInBits () > 256 )
45904665 return SDValue ();
4591-
4592- SDValue SrcN1 = Src.getOperand (1 );
4593- switch (cast<CondCodeSDNode>(Src.getOperand (2 ))->get ()) {
4594- default :
4595- break ;
4596- case ISD::SETEQ:
4597- // x == 0 => not (vmsknez.b x)
4598- if (ISD::isBuildVectorAllZeros (SrcN1.getNode ()) && EltVT == MVT::i8 )
4599- Opc = UseLASX ? LoongArchISD::XVMSKEQZ : LoongArchISD::VMSKEQZ;
4600- break ;
4601- case ISD::SETGT:
4602- // x > -1 => vmskgez.b x
4603- if (ISD::isBuildVectorAllOnes (SrcN1.getNode ()) && EltVT == MVT::i8 )
4604- Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
4605- break ;
4606- case ISD::SETGE:
4607- // x >= 0 => vmskgez.b x
4608- if (ISD::isBuildVectorAllZeros (SrcN1.getNode ()) && EltVT == MVT::i8 )
4609- Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
4610- break ;
4611- case ISD::SETLT:
4612- // x < 0 => vmskltz.{b,h,w,d} x
4613- if (ISD::isBuildVectorAllZeros (SrcN1.getNode ()) &&
4614- (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
4615- EltVT == MVT::i64 ))
4616- Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
4617- break ;
4618- case ISD::SETLE:
4619- // x <= -1 => vmskltz.{b,h,w,d} x
4620- if (ISD::isBuildVectorAllOnes (SrcN1.getNode ()) &&
4621- (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
4622- EltVT == MVT::i64 ))
4623- Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
4624- break ;
4625- case ISD::SETNE:
4626- // x != 0 => vmsknez.b x
4627- if (ISD::isBuildVectorAllZeros (SrcN1.getNode ()) && EltVT == MVT::i8 )
4628- Opc = UseLASX ? LoongArchISD::XVMSKNEZ : LoongArchISD::VMSKNEZ;
4629- break ;
4630- }
46314666 }
46324667
4633- // Generate vXi1 using [X]VMSKLTZ
4634- if (Opc == ISD::DELETED_NODE) {
4635- MVT SExtVT;
4636- bool UseLASX = false ;
4637- bool PropagateSExt = false ;
4638- switch (SrcVT.getSimpleVT ().SimpleTy ) {
4639- default :
4640- return SDValue ();
4641- case MVT::v2i1:
4642- SExtVT = MVT::v2i64;
4643- break ;
4644- case MVT::v4i1:
4645- SExtVT = MVT::v4i32;
4646- if (Subtarget.hasExtLASX () && checkBitcastSrcVectorSize (Src, 256 , 0 )) {
4647- SExtVT = MVT::v4i64;
4648- UseLASX = true ;
4649- PropagateSExt = true ;
4650- }
4651- break ;
4652- case MVT::v8i1:
4653- SExtVT = MVT::v8i16;
4654- if (Subtarget.hasExtLASX () && checkBitcastSrcVectorSize (Src, 256 , 0 )) {
4655- SExtVT = MVT::v8i32;
4656- UseLASX = true ;
4657- PropagateSExt = true ;
4658- }
4659- break ;
4660- case MVT::v16i1:
4661- SExtVT = MVT::v16i8;
4662- if (Subtarget.hasExtLASX () && checkBitcastSrcVectorSize (Src, 256 , 0 )) {
4663- SExtVT = MVT::v16i16;
4664- UseLASX = true ;
4665- PropagateSExt = true ;
4666- }
4667- break ;
4668- case MVT::v32i1:
4669- SExtVT = MVT::v32i8;
4668+ switch (SrcVT.getSimpleVT ().SimpleTy ) {
4669+ default :
4670+ return SDValue ();
4671+ case MVT::v2i1:
4672+ SExtVT = MVT::v2i64;
4673+ break ;
4674+ case MVT::v4i1:
4675+ SExtVT = MVT::v4i32;
4676+ if (Subtarget.hasExtLASX () && checkBitcastSrcVectorSize (Src, 256 , 0 )) {
4677+ SExtVT = MVT::v4i64;
46704678 UseLASX = true ;
4671- break ;
4672- };
4673- if (UseLASX && !Subtarget.has32S () && !Subtarget.hasExtLASX ())
4674- return SDValue ();
4675- Src = PropagateSExt ? signExtendBitcastSrcVector (DAG, SExtVT, Src, DL)
4676- : DAG.getNode (ISD::SIGN_EXTEND, DL, SExtVT, Src);
4677- Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
4678- } else {
4679- Src = Src.getOperand (0 );
4680- }
4679+ PropagateSExt = true ;
4680+ }
4681+ break ;
4682+ case MVT::v8i1:
4683+ SExtVT = MVT::v8i16;
4684+ if (Subtarget.hasExtLASX () && checkBitcastSrcVectorSize (Src, 256 , 0 )) {
4685+ SExtVT = MVT::v8i32;
4686+ UseLASX = true ;
4687+ PropagateSExt = true ;
4688+ }
4689+ break ;
4690+ case MVT::v16i1:
4691+ SExtVT = MVT::v16i8;
4692+ if (Subtarget.hasExtLASX () && checkBitcastSrcVectorSize (Src, 256 , 0 )) {
4693+ SExtVT = MVT::v16i16;
4694+ UseLASX = true ;
4695+ PropagateSExt = true ;
4696+ }
4697+ break ;
4698+ case MVT::v32i1:
4699+ SExtVT = MVT::v32i8;
4700+ UseLASX = true ;
4701+ break ;
4702+ };
4703+ if (UseLASX && !(Subtarget.has32S () && Subtarget.hasExtLASX ()))
4704+ return SDValue ();
4705+ Src = PropagateSExt ? signExtendBitcastSrcVector (DAG, SExtVT, Src, DL)
4706+ : DAG.getNode (ISD::SIGN_EXTEND, DL, SExtVT, Src);
4707+ Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
46814708
46824709 SDValue V = DAG.getNode (Opc, DL, MVT::i64 , Src);
46834710 EVT T = EVT::getIntegerVT (*DAG.getContext (), SrcVT.getVectorNumElements ());
0 commit comments