@@ -4563,6 +4563,80 @@ static SDValue signExtendBitcastSrcVector(SelectionDAG &DAG, EVT SExtVT,
4563
4563
llvm_unreachable (" Unexpected node type for vXi1 sign extension" );
4564
4564
}
4565
4565
4566
+ static SDValue
4567
+ performSETCC_BITCASTCombine (SDNode *N, SelectionDAG &DAG,
4568
+ TargetLowering::DAGCombinerInfo &DCI,
4569
+ const LoongArchSubtarget &Subtarget) {
4570
+ SDLoc DL (N);
4571
+ EVT VT = N->getValueType (0 );
4572
+ SDValue Src = N->getOperand (0 );
4573
+ EVT SrcVT = Src.getValueType ();
4574
+
4575
+ if (Src.getOpcode () != ISD::SETCC || !Src.hasOneUse ())
4576
+ return SDValue ();
4577
+
4578
+ bool UseLASX;
4579
+ unsigned Opc = ISD::DELETED_NODE;
4580
+ EVT CmpVT = Src.getOperand (0 ).getValueType ();
4581
+ EVT EltVT = CmpVT.getVectorElementType ();
4582
+
4583
+ if (Subtarget.hasExtLSX () && CmpVT.getSizeInBits () == 128 )
4584
+ UseLASX = false ;
4585
+ else if (Subtarget.has32S () && Subtarget.hasExtLASX () &&
4586
+ CmpVT.getSizeInBits () == 256 )
4587
+ UseLASX = true ;
4588
+ else
4589
+ return SDValue ();
4590
+
4591
+ SDValue SrcN1 = Src.getOperand (1 );
4592
+ switch (cast<CondCodeSDNode>(Src.getOperand (2 ))->get ()) {
4593
+ default :
4594
+ break ;
4595
+ case ISD::SETEQ:
4596
+ // x == 0 => not (vmsknez.b x)
4597
+ if (ISD::isBuildVectorAllZeros (SrcN1.getNode ()) && EltVT == MVT::i8 )
4598
+ Opc = UseLASX ? LoongArchISD::XVMSKEQZ : LoongArchISD::VMSKEQZ;
4599
+ break ;
4600
+ case ISD::SETGT:
4601
+ // x > -1 => vmskgez.b x
4602
+ if (ISD::isBuildVectorAllOnes (SrcN1.getNode ()) && EltVT == MVT::i8 )
4603
+ Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
4604
+ break ;
4605
+ case ISD::SETGE:
4606
+ // x >= 0 => vmskgez.b x
4607
+ if (ISD::isBuildVectorAllZeros (SrcN1.getNode ()) && EltVT == MVT::i8 )
4608
+ Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
4609
+ break ;
4610
+ case ISD::SETLT:
4611
+ // x < 0 => vmskltz.{b,h,w,d} x
4612
+ if (ISD::isBuildVectorAllZeros (SrcN1.getNode ()) &&
4613
+ (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
4614
+ EltVT == MVT::i64 ))
4615
+ Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
4616
+ break ;
4617
+ case ISD::SETLE:
4618
+ // x <= -1 => vmskltz.{b,h,w,d} x
4619
+ if (ISD::isBuildVectorAllOnes (SrcN1.getNode ()) &&
4620
+ (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
4621
+ EltVT == MVT::i64 ))
4622
+ Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
4623
+ break ;
4624
+ case ISD::SETNE:
4625
+ // x != 0 => vmsknez.b x
4626
+ if (ISD::isBuildVectorAllZeros (SrcN1.getNode ()) && EltVT == MVT::i8 )
4627
+ Opc = UseLASX ? LoongArchISD::XVMSKNEZ : LoongArchISD::VMSKNEZ;
4628
+ break ;
4629
+ }
4630
+
4631
+ if (Opc == ISD::DELETED_NODE)
4632
+ return SDValue ();
4633
+
4634
+ SDValue V = DAG.getNode (Opc, DL, MVT::i64 , Src.getOperand (0 ));
4635
+ EVT T = EVT::getIntegerVT (*DAG.getContext (), SrcVT.getVectorNumElements ());
4636
+ V = DAG.getZExtOrTrunc (V, DL, T);
4637
+ return DAG.getBitcast (VT, V);
4638
+ }
4639
+
4566
4640
static SDValue performBITCASTCombine (SDNode *N, SelectionDAG &DAG,
4567
4641
TargetLowering::DAGCombinerInfo &DCI,
4568
4642
const LoongArchSubtarget &Subtarget) {
@@ -4577,110 +4651,63 @@ static SDValue performBITCASTCombine(SDNode *N, SelectionDAG &DAG,
4577
4651
if (!SrcVT.isSimple () || SrcVT.getScalarType () != MVT::i1)
4578
4652
return SDValue ();
4579
4653
4580
- unsigned Opc = ISD::DELETED_NODE;
4581
4654
// Combine SETCC and BITCAST into [X]VMSK{LT,GE,NE} when possible
4655
+ SDValue Res = performSETCC_BITCASTCombine (N, DAG, DCI, Subtarget);
4656
+ if (Res)
4657
+ return Res;
4658
+
4659
+ // Generate vXi1 using [X]VMSKLTZ
4660
+ MVT SExtVT;
4661
+ unsigned Opc;
4662
+ bool UseLASX = false ;
4663
+ bool PropagateSExt = false ;
4664
+
4582
4665
if (Src.getOpcode () == ISD::SETCC && Src.hasOneUse ()) {
4583
- bool UseLASX;
4584
4666
EVT CmpVT = Src.getOperand (0 ).getValueType ();
4585
- EVT EltVT = CmpVT.getVectorElementType ();
4586
-
4587
- if (Subtarget.hasExtLSX () && CmpVT.getSizeInBits () <= 128 )
4588
- UseLASX = false ;
4589
- else if (Subtarget.has32S () && Subtarget.hasExtLASX () &&
4590
- CmpVT.getSizeInBits () <= 256 )
4591
- UseLASX = true ;
4592
- else
4667
+ if (CmpVT.getSizeInBits () > 256 )
4593
4668
return SDValue ();
4594
-
4595
- SDValue SrcN1 = Src.getOperand (1 );
4596
- switch (cast<CondCodeSDNode>(Src.getOperand (2 ))->get ()) {
4597
- default :
4598
- break ;
4599
- case ISD::SETEQ:
4600
- // x == 0 => not (vmsknez.b x)
4601
- if (ISD::isBuildVectorAllZeros (SrcN1.getNode ()) && EltVT == MVT::i8 )
4602
- Opc = UseLASX ? LoongArchISD::XVMSKEQZ : LoongArchISD::VMSKEQZ;
4603
- break ;
4604
- case ISD::SETGT:
4605
- // x > -1 => vmskgez.b x
4606
- if (ISD::isBuildVectorAllOnes (SrcN1.getNode ()) && EltVT == MVT::i8 )
4607
- Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
4608
- break ;
4609
- case ISD::SETGE:
4610
- // x >= 0 => vmskgez.b x
4611
- if (ISD::isBuildVectorAllZeros (SrcN1.getNode ()) && EltVT == MVT::i8 )
4612
- Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
4613
- break ;
4614
- case ISD::SETLT:
4615
- // x < 0 => vmskltz.{b,h,w,d} x
4616
- if (ISD::isBuildVectorAllZeros (SrcN1.getNode ()) &&
4617
- (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
4618
- EltVT == MVT::i64 ))
4619
- Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
4620
- break ;
4621
- case ISD::SETLE:
4622
- // x <= -1 => vmskltz.{b,h,w,d} x
4623
- if (ISD::isBuildVectorAllOnes (SrcN1.getNode ()) &&
4624
- (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
4625
- EltVT == MVT::i64 ))
4626
- Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
4627
- break ;
4628
- case ISD::SETNE:
4629
- // x != 0 => vmsknez.b x
4630
- if (ISD::isBuildVectorAllZeros (SrcN1.getNode ()) && EltVT == MVT::i8 )
4631
- Opc = UseLASX ? LoongArchISD::XVMSKNEZ : LoongArchISD::VMSKNEZ;
4632
- break ;
4633
- }
4634
4669
}
4635
4670
4636
- // Generate vXi1 using [X]VMSKLTZ
4637
- if (Opc == ISD::DELETED_NODE) {
4638
- MVT SExtVT;
4639
- bool UseLASX = false ;
4640
- bool PropagateSExt = false ;
4641
- switch (SrcVT.getSimpleVT ().SimpleTy ) {
4642
- default :
4643
- return SDValue ();
4644
- case MVT::v2i1:
4645
- SExtVT = MVT::v2i64;
4646
- break ;
4647
- case MVT::v4i1:
4648
- SExtVT = MVT::v4i32;
4649
- if (Subtarget.hasExtLASX () && checkBitcastSrcVectorSize (Src, 256 , 0 )) {
4650
- SExtVT = MVT::v4i64;
4651
- UseLASX = true ;
4652
- PropagateSExt = true ;
4653
- }
4654
- break ;
4655
- case MVT::v8i1:
4656
- SExtVT = MVT::v8i16;
4657
- if (Subtarget.hasExtLASX () && checkBitcastSrcVectorSize (Src, 256 , 0 )) {
4658
- SExtVT = MVT::v8i32;
4659
- UseLASX = true ;
4660
- PropagateSExt = true ;
4661
- }
4662
- break ;
4663
- case MVT::v16i1:
4664
- SExtVT = MVT::v16i8;
4665
- if (Subtarget.hasExtLASX () && checkBitcastSrcVectorSize (Src, 256 , 0 )) {
4666
- SExtVT = MVT::v16i16;
4667
- UseLASX = true ;
4668
- PropagateSExt = true ;
4669
- }
4670
- break ;
4671
- case MVT::v32i1:
4672
- SExtVT = MVT::v32i8;
4671
+ switch (SrcVT.getSimpleVT ().SimpleTy ) {
4672
+ default :
4673
+ return SDValue ();
4674
+ case MVT::v2i1:
4675
+ SExtVT = MVT::v2i64;
4676
+ break ;
4677
+ case MVT::v4i1:
4678
+ SExtVT = MVT::v4i32;
4679
+ if (Subtarget.hasExtLASX () && checkBitcastSrcVectorSize (Src, 256 , 0 )) {
4680
+ SExtVT = MVT::v4i64;
4673
4681
UseLASX = true ;
4674
- break ;
4675
- };
4676
- if (UseLASX && !Subtarget.has32S () && !Subtarget.hasExtLASX ())
4677
- return SDValue ();
4678
- Src = PropagateSExt ? signExtendBitcastSrcVector (DAG, SExtVT, Src, DL)
4679
- : DAG.getNode (ISD::SIGN_EXTEND, DL, SExtVT, Src);
4680
- Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
4681
- } else {
4682
- Src = Src.getOperand (0 );
4683
- }
4682
+ PropagateSExt = true ;
4683
+ }
4684
+ break ;
4685
+ case MVT::v8i1:
4686
+ SExtVT = MVT::v8i16;
4687
+ if (Subtarget.hasExtLASX () && checkBitcastSrcVectorSize (Src, 256 , 0 )) {
4688
+ SExtVT = MVT::v8i32;
4689
+ UseLASX = true ;
4690
+ PropagateSExt = true ;
4691
+ }
4692
+ break ;
4693
+ case MVT::v16i1:
4694
+ SExtVT = MVT::v16i8;
4695
+ if (Subtarget.hasExtLASX () && checkBitcastSrcVectorSize (Src, 256 , 0 )) {
4696
+ SExtVT = MVT::v16i16;
4697
+ UseLASX = true ;
4698
+ PropagateSExt = true ;
4699
+ }
4700
+ break ;
4701
+ case MVT::v32i1:
4702
+ SExtVT = MVT::v32i8;
4703
+ UseLASX = true ;
4704
+ break ;
4705
+ };
4706
+ if (UseLASX && !(Subtarget.has32S () && Subtarget.hasExtLASX ()))
4707
+ return SDValue ();
4708
+ Src = PropagateSExt ? signExtendBitcastSrcVector (DAG, SExtVT, Src, DL)
4709
+ : DAG.getNode (ISD::SIGN_EXTEND, DL, SExtVT, Src);
4710
+ Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
4684
4711
4685
4712
SDValue V = DAG.getNode (Opc, DL, MVT::i64 , Src);
4686
4713
EVT T = EVT::getIntegerVT (*DAG.getContext (), SrcVT.getVectorNumElements ());
0 commit comments