@@ -4423,6 +4423,62 @@ static SDValue performSRLCombine(SDNode *N, SelectionDAG &DAG,
44234423 return SDValue ();
44244424}
44254425
4426+ // Helper to peek through bitops/trunc/setcc to determine size of source vector.
4427+ // Allows BITCASTCombine to determine what size vector generated a <X x i1>.
4428+ static bool checkBitcastSrcVectorSize (SDValue Src, unsigned Size,
4429+ unsigned Depth) {
4430+ // Limit recursion.
4431+ if (Depth >= SelectionDAG::MaxRecursionDepth)
4432+ return false ;
4433+ switch (Src.getOpcode ()) {
4434+ case ISD::SETCC:
4435+ case ISD::TRUNCATE:
4436+ return Src.getOperand (0 ).getValueSizeInBits () == Size;
4437+ case ISD::FREEZE:
4438+ return checkBitcastSrcVectorSize (Src.getOperand (0 ), Size, Depth + 1 );
4439+ case ISD::AND:
4440+ case ISD::XOR:
4441+ case ISD::OR:
4442+ return checkBitcastSrcVectorSize (Src.getOperand (0 ), Size, Depth + 1 ) &&
4443+ checkBitcastSrcVectorSize (Src.getOperand (1 ), Size, Depth + 1 );
4444+ case ISD::SELECT:
4445+ case ISD::VSELECT:
4446+ return Src.getOperand (0 ).getScalarValueSizeInBits () == 1 &&
4447+ checkBitcastSrcVectorSize (Src.getOperand (1 ), Size, Depth + 1 ) &&
4448+ checkBitcastSrcVectorSize (Src.getOperand (2 ), Size, Depth + 1 );
4449+ case ISD::BUILD_VECTOR:
4450+ return ISD::isBuildVectorAllZeros (Src.getNode ()) ||
4451+ ISD::isBuildVectorAllOnes (Src.getNode ());
4452+ }
4453+ return false ;
4454+ }
4455+
4456+ // Helper to push sign extension of vXi1 SETCC result through bitops.
4457+ static SDValue signExtendBitcastSrcVector (SelectionDAG &DAG, EVT SExtVT,
4458+ SDValue Src, const SDLoc &DL) {
4459+ switch (Src.getOpcode ()) {
4460+ case ISD::SETCC:
4461+ case ISD::FREEZE:
4462+ case ISD::TRUNCATE:
4463+ case ISD::BUILD_VECTOR:
4464+ return DAG.getNode (ISD::SIGN_EXTEND, DL, SExtVT, Src);
4465+ case ISD::AND:
4466+ case ISD::XOR:
4467+ case ISD::OR:
4468+ return DAG.getNode (
4469+ Src.getOpcode (), DL, SExtVT,
4470+ signExtendBitcastSrcVector (DAG, SExtVT, Src.getOperand (0 ), DL),
4471+ signExtendBitcastSrcVector (DAG, SExtVT, Src.getOperand (1 ), DL));
4472+ case ISD::SELECT:
4473+ case ISD::VSELECT:
4474+ return DAG.getSelect (
4475+ DL, SExtVT, Src.getOperand (0 ),
4476+ signExtendBitcastSrcVector (DAG, SExtVT, Src.getOperand (1 ), DL),
4477+ signExtendBitcastSrcVector (DAG, SExtVT, Src.getOperand (2 ), DL));
4478+ }
4479+ llvm_unreachable (" Unexpected node type for vXi1 sign extension" );
4480+ }
4481+
44264482static SDValue performBITCASTCombine (SDNode *N, SelectionDAG &DAG,
44274483 TargetLowering::DAGCombinerInfo &DCI,
44284484 const LoongArchSubtarget &Subtarget) {
@@ -4493,10 +4549,56 @@ static SDValue performBITCASTCombine(SDNode *N, SelectionDAG &DAG,
44934549 }
44944550 }
44954551
4496- if (Opc == ISD::DELETED_NODE)
4497- return SDValue ();
4552+ // Generate vXi1 using [X]VMSKLTZ
4553+ if (Opc == ISD::DELETED_NODE) {
4554+ MVT SExtVT;
4555+ bool UseLASX = false ;
4556+ bool PropagateSExt = false ;
4557+ switch (SrcVT.getSimpleVT ().SimpleTy ) {
4558+ default :
4559+ return SDValue ();
4560+ case MVT::v2i1:
4561+ SExtVT = MVT::v2i64;
4562+ break ;
4563+ case MVT::v4i1:
4564+ SExtVT = MVT::v4i32;
4565+ if (Subtarget.hasExtLASX () && checkBitcastSrcVectorSize (Src, 256 , 0 )) {
4566+ SExtVT = MVT::v4i64;
4567+ UseLASX = true ;
4568+ PropagateSExt = true ;
4569+ }
4570+ break ;
4571+ case MVT::v8i1:
4572+ SExtVT = MVT::v8i16;
4573+ if (Subtarget.hasExtLASX () && checkBitcastSrcVectorSize (Src, 256 , 0 )) {
4574+ SExtVT = MVT::v8i32;
4575+ UseLASX = true ;
4576+ PropagateSExt = true ;
4577+ }
4578+ break ;
4579+ case MVT::v16i1:
4580+ SExtVT = MVT::v16i8;
4581+ if (Subtarget.hasExtLASX () && checkBitcastSrcVectorSize (Src, 256 , 0 )) {
4582+ SExtVT = MVT::v16i16;
4583+ UseLASX = true ;
4584+ PropagateSExt = true ;
4585+ }
4586+ break ;
4587+ case MVT::v32i1:
4588+ SExtVT = MVT::v32i8;
4589+ UseLASX = true ;
4590+ break ;
4591+ };
4592+ if (UseLASX && !Subtarget.has32S () && !Subtarget.hasExtLASX ())
4593+ return SDValue ();
4594+ Src = PropagateSExt ? signExtendBitcastSrcVector (DAG, SExtVT, Src, DL)
4595+ : DAG.getNode (ISD::SIGN_EXTEND, DL, SExtVT, Src);
4596+ Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
4597+ } else {
4598+ Src = Src.getOperand (0 );
4599+ }
44984600
4499- SDValue V = DAG.getNode (Opc, DL, MVT::i64 , Src. getOperand ( 0 ) );
4601+ SDValue V = DAG.getNode (Opc, DL, MVT::i64 , Src);
45004602 EVT T = EVT::getIntegerVT (*DAG.getContext (), SrcVT.getVectorNumElements ());
45014603 V = DAG.getZExtOrTrunc (V, DL, T);
45024604 return DAG.getBitcast (VT, V);
0 commit comments