diff --git a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
index 32c7d2bfea6c2..c5b9a0a2ef057 100644
--- a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
@@ -1139,24 +1139,62 @@ void X86DAGToDAGISel::PreprocessISelDAG() {
       break;
     }
     case ISD::VSELECT: {
-      // Replace VSELECT with non-mask conditions with with BLENDV/VPTERNLOG.
-      EVT EleVT = N->getOperand(0).getValueType().getVectorElementType();
-      if (EleVT == MVT::i1)
-        break;
-
-      assert(Subtarget->hasSSE41() && "Expected SSE4.1 support!");
-      assert(N->getValueType(0).getVectorElementType() != MVT::i16 &&
-             "We can't replace VSELECT with BLENDV in vXi16!");
+      SDValue Cond = N->getOperand(0);
+      SDValue LHS = N->getOperand(1);
+      SDValue RHS = N->getOperand(2);
+      EVT CondVT = Cond.getValueType();
+      EVT CondSVT = CondVT.getVectorElementType();
+      EVT VT = N->getValueType(0);
+      SDLoc DL(N);
       SDValue R;
-      if (Subtarget->hasVLX() && CurDAG->ComputeNumSignBits(N->getOperand(0)) ==
-                                     EleVT.getSizeInBits()) {
-        R = CurDAG->getNode(X86ISD::VPTERNLOG, SDLoc(N), N->getValueType(0),
-                            N->getOperand(0), N->getOperand(1), N->getOperand(2),
-                            CurDAG->getTargetConstant(0xCA, SDLoc(N), MVT::i8));
+
+      if (CondSVT == MVT::i1) {
+        assert(Subtarget->hasAVX512() && "Expected AVX512 support!");
+        if (!Cond->hasOneUse() || !ISD::isBuildVectorAllZeros(LHS.getNode()) ||
+            ISD::isBuildVectorAllZeros(RHS.getNode()))
+          break;
+        // If this is an avx512 target we can improve the use of zero masking by
+        // swapping the operands and inverting the condition.
+        // vselect cond, zero, op = vselect not(cond), op, zero
+        auto InverseCondition = [this](SDValue Cond, const SDLoc &DL) {
+          EVT CondVT = Cond.getValueType();
+          if (Cond.getOpcode() == ISD::SETCC &&
+              !ISD::isBuildVectorAllZeros(Cond.getOperand(0).getNode())) {
+            ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
+            CC = ISD::getSetCCInverse(CC, Cond.getOperand(0).getValueType());
+            return CurDAG->getSetCC(DL, CondVT, Cond.getOperand(0),
+                                    Cond.getOperand(1), CC);
+          }
+          if (Cond.getOpcode() == X86ISD::CMPM ||
+              Cond.getOpcode() == X86ISD::FSETCCM) {
+            unsigned CC = Cond.getConstantOperandVal(2);
+            return CurDAG->getNode(
+                Cond.getOpcode(), DL, CondVT, Cond.getOperand(0),
+                Cond.getOperand(1),
+                CurDAG->getTargetConstant(CC ^ 4, DL, MVT::i8));
+          }
+          return CurDAG->getNOT(DL, Cond, CondVT);
+        };
+        if (Cond.getOpcode() == ISD::INSERT_SUBVECTOR &&
+            Cond.getOperand(0).isUndef())
+          R = CurDAG->getNode(
+              ISD::INSERT_SUBVECTOR, DL, CondVT, Cond.getOperand(0),
+              InverseCondition(Cond.getOperand(1), DL), Cond.getOperand(2));
+        else
+          R = InverseCondition(Cond, DL);
+        R = CurDAG->getSelect(DL, VT, R, RHS, LHS);
       } else {
-        R = CurDAG->getNode(X86ISD::BLENDV, SDLoc(N), N->getValueType(0),
-                            N->getOperand(0), N->getOperand(1),
-                            N->getOperand(2));
+        // Replace VSELECT with non-mask conditions with BLENDV/VPTERNLOG.
+        assert(Subtarget->hasSSE41() && "Expected SSE4.1 support!");
+        assert(VT.getVectorElementType() != MVT::i16 &&
+               "We can't replace VSELECT with BLENDV in vXi16!");
+        if (Subtarget->hasVLX() &&
+            CurDAG->ComputeNumSignBits(Cond) == CondSVT.getSizeInBits()) {
+          R = CurDAG->getNode(X86ISD::VPTERNLOG, DL, VT, Cond, LHS, RHS,
+                              CurDAG->getTargetConstant(0xCA, DL, MVT::i8));
+        } else {
+          R = CurDAG->getNode(X86ISD::BLENDV, DL, VT, Cond, LHS, RHS);
+        }
       }
       --I;
       CurDAG->ReplaceAllUsesWith(N, R.getNode());
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 307a237e2955c..46d7be3b4a3f7 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -48039,19 +48039,6 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG,
     }
   }
 
-  // Check if the first operand is all zeros and Cond type is vXi1.
-  // If this an avx512 target we can improve the use of zero masking by
-  // swapping the operands and inverting the condition.
-  if (N->getOpcode() == ISD::VSELECT && Cond.hasOneUse() &&
-      Subtarget.hasAVX512() && CondVT.getVectorElementType() == MVT::i1 &&
-      ISD::isBuildVectorAllZeros(LHS.getNode()) &&
-      !ISD::isBuildVectorAllZeros(RHS.getNode())) {
-    // Invert the cond to not(cond) : xor(op,allones)=not(op)
-    SDValue CondNew = DAG.getNOT(DL, Cond, CondVT);
-    // Vselect cond, op1, op2 = Vselect not(cond), op2, op1
-    return DAG.getSelect(DL, VT, CondNew, RHS, LHS);
-  }
-
   // Attempt to convert a (vXi1 bitcast(iX Cond)) selection mask before it might
   // get split by legalization.
   if (N->getOpcode() == ISD::VSELECT && Cond.getOpcode() == ISD::BITCAST &&
@@ -48115,33 +48102,35 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG,
     return V;
 
   // select(~Cond, X, Y) -> select(Cond, Y, X)
-  if (CondVT.getScalarType() != MVT::i1) {
+  // Limit vXi1 cases to AVX512 canonicalization of zero mask to the RHS.
+  if (CondVT.getScalarType() != MVT::i1 ||
+      (ISD::isBuildVectorAllZeros(LHS.getNode()) &&
+       !ISD::isBuildVectorAllZeros(RHS.getNode())))
    if (SDValue CondNot = IsNOT(Cond, DAG))
      return DAG.getNode(N->getOpcode(), DL, VT,
                         DAG.getBitcast(CondVT, CondNot), RHS, LHS);
 
-    // select(pcmpeq(and(X,Pow2),0),A,B) -> select(pcmpeq(and(X,Pow2),Pow2),B,A)
-    if (Cond.getOpcode() == X86ISD::PCMPEQ &&
-        Cond.getOperand(0).getOpcode() == ISD::AND &&
-        ISD::isBuildVectorAllZeros(Cond.getOperand(1).getNode()) &&
-        isConstantPowerOf2(Cond.getOperand(0).getOperand(1),
-                           Cond.getScalarValueSizeInBits(),
-                           /*AllowUndefs=*/true) &&
-        Cond.hasOneUse()) {
-      Cond = DAG.getNode(X86ISD::PCMPEQ, DL, CondVT, Cond.getOperand(0),
-                         Cond.getOperand(0).getOperand(1));
-      return DAG.getNode(N->getOpcode(), DL, VT, Cond, RHS, LHS);
-    }
-
-    // pcmpgt(X, -1) -> pcmpgt(0, X) to help select/blendv just use the
-    // signbit.
-    if (Cond.getOpcode() == X86ISD::PCMPGT &&
-        ISD::isBuildVectorAllOnes(Cond.getOperand(1).getNode()) &&
-        Cond.hasOneUse()) {
-      Cond = DAG.getNode(X86ISD::PCMPGT, DL, CondVT,
-                         DAG.getConstant(0, DL, CondVT), Cond.getOperand(0));
-      return DAG.getNode(N->getOpcode(), DL, VT, Cond, RHS, LHS);
-    }
+  // select(pcmpeq(and(X,Pow2),0),A,B) -> select(pcmpeq(and(X,Pow2),Pow2),B,A)
+  if (Cond.getOpcode() == X86ISD::PCMPEQ &&
+      Cond.getOperand(0).getOpcode() == ISD::AND &&
+      ISD::isBuildVectorAllZeros(Cond.getOperand(1).getNode()) &&
+      isConstantPowerOf2(Cond.getOperand(0).getOperand(1),
+                         Cond.getScalarValueSizeInBits(),
+                         /*AllowUndefs=*/true) &&
+      Cond.hasOneUse()) {
+    Cond = DAG.getNode(X86ISD::PCMPEQ, DL, CondVT, Cond.getOperand(0),
+                       Cond.getOperand(0).getOperand(1));
+    return DAG.getNode(N->getOpcode(), DL, VT, Cond, RHS, LHS);
+  }
+
+  // pcmpgt(X, -1) -> pcmpgt(0, X) to help select/blendv just use the
+  // signbit.
+  if (Cond.getOpcode() == X86ISD::PCMPGT &&
+      ISD::isBuildVectorAllOnes(Cond.getOperand(1).getNode()) &&
+      Cond.hasOneUse()) {
+    Cond = DAG.getNode(X86ISD::PCMPGT, DL, CondVT,
+                       DAG.getConstant(0, DL, CondVT), Cond.getOperand(0));
+    return DAG.getNode(N->getOpcode(), DL, VT, Cond, RHS, LHS);
   }
 
   // Try to optimize vXi1 selects if both operands are either all constants or
diff --git a/llvm/test/CodeGen/X86/psubus.ll b/llvm/test/CodeGen/X86/psubus.ll
index e10b360b35b56..cc3aee4feba2d 100644
--- a/llvm/test/CodeGen/X86/psubus.ll
+++ b/llvm/test/CodeGen/X86/psubus.ll
@@ -981,9 +981,9 @@ define <16 x i8> @test14(<16 x i8> %x, <16 x i32> %y) nounwind {
 ; AVX512-LABEL: test14:
 ; AVX512:       # %bb.0: # %vector.ph
 ; AVX512-NEXT:    vpmovzxbd {{.*#+}} zmm2 = xmm0[0],zero,zero,zero,xmm0[1],zero,zero,zero,xmm0[2],zero,zero,zero,xmm0[3],zero,zero,zero,xmm0[4],zero,zero,zero,xmm0[5],zero,zero,zero,xmm0[6],zero,zero,zero,xmm0[7],zero,zero,zero,xmm0[8],zero,zero,zero,xmm0[9],zero,zero,zero,xmm0[10],zero,zero,zero,xmm0[11],zero,zero,zero,xmm0[12],zero,zero,zero,xmm0[13],zero,zero,zero,xmm0[14],zero,zero,zero,xmm0[15],zero,zero,zero
+; AVX512-NEXT:    vpmovdb %zmm1, %xmm3
 ; AVX512-NEXT:    vpcmpnltud %zmm2, %zmm1, %k1
-; AVX512-NEXT:    vpmovdb %zmm1, %xmm1
-; AVX512-NEXT:    vpsubb %xmm0, %xmm1, %xmm0 {%k1} {z}
+; AVX512-NEXT:    vpsubb %xmm0, %xmm3, %xmm0 {%k1} {z}
 ; AVX512-NEXT:    vzeroupper
 ; AVX512-NEXT:    retq
 vector.ph:
diff --git a/llvm/test/CodeGen/X86/var-permute-256.ll b/llvm/test/CodeGen/X86/var-permute-256.ll
index 283c6a303a581..ab2ffdfd0ff2c 100644
--- a/llvm/test/CodeGen/X86/var-permute-256.ll
+++ b/llvm/test/CodeGen/X86/var-permute-256.ll
@@ -148,11 +148,11 @@ define <4 x i64> @var_shuffle_zero_v4i64(<4 x i64> %v, <4 x i64> %indices) nounw
 ; AVX512-NEXT:    # kill: def $ymm1 killed $ymm1 def $zmm1
 ; AVX512-NEXT:    # kill: def $ymm0 killed $ymm0 def $zmm0
 ; AVX512-NEXT:    vpbroadcastq {{.*#+}} ymm2 = [3,3,3,3]
+; AVX512-NEXT:    vpcmpnleuq %zmm2, %zmm1, %k1
+; AVX512-NEXT:    vpcmpeqd %ymm3, %ymm3, %ymm3
+; AVX512-NEXT:    vpblendmq %zmm3, %zmm1, %zmm3 {%k1}
 ; AVX512-NEXT:    vpcmpleuq %zmm2, %zmm1, %k1
-; AVX512-NEXT:    vpcmpnleuq %zmm2, %zmm1, %k2
-; AVX512-NEXT:    vpcmpeqd %ymm2, %ymm2, %ymm2
-; AVX512-NEXT:    vmovdqa64 %zmm2, %zmm1 {%k2}
-; AVX512-NEXT:    vpermq %zmm0, %zmm1, %zmm0 {%k1} {z}
+; AVX512-NEXT:    vpermq %zmm0, %zmm3, %zmm0 {%k1} {z}
 ; AVX512-NEXT:    # kill: def $ymm0 killed $ymm0 killed $zmm0
 ; AVX512-NEXT:    retq
 ;
@@ -1192,11 +1192,11 @@ define <4 x double> @var_shuffle_zero_v4f64(<4 x double> %v, <4 x i64> %indices)
 ; AVX512-NEXT:    # kill: def $ymm1 killed $ymm1 def $zmm1
 ; AVX512-NEXT:    # kill: def $ymm0 killed $ymm0 def $zmm0
 ; AVX512-NEXT:    vpbroadcastq {{.*#+}} ymm2 = [3,3,3,3]
+; AVX512-NEXT:    vpcmpnleuq %zmm2, %zmm1, %k1
+; AVX512-NEXT:    vpcmpeqd %ymm3, %ymm3, %ymm3
+; AVX512-NEXT:    vpblendmq %zmm3, %zmm1, %zmm3 {%k1}
 ; AVX512-NEXT:    vpcmpleuq %zmm2, %zmm1, %k1
-; AVX512-NEXT:    vpcmpnleuq %zmm2, %zmm1, %k2
-; AVX512-NEXT:    vpcmpeqd %ymm2, %ymm2, %ymm2
-; AVX512-NEXT:    vmovdqa64 %zmm2, %zmm1 {%k2}
-; AVX512-NEXT:    vpermpd %zmm0, %zmm1, %zmm0 {%k1} {z}
+; AVX512-NEXT:    vpermpd %zmm0, %zmm3, %zmm0 {%k1} {z}
 ; AVX512-NEXT:    # kill: def $ymm0 killed $ymm0 killed $zmm0
 ; AVX512-NEXT:    retq
 ;
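
For reference, a minimal LLVM IR reproducer in the spirit of the touched tests (the function name and types are illustrative, not taken from the patch): the select carries the zero vector on its true side, so the PreprocessISelDAG change above swaps the operands and folds the inversion into the compare itself (setcc inverse, or CMPM/FSETCCM immediate ^ 4), letting the zero come out as AVX512 zero-masking instead of a separate knot of the mask.

; Illustrative sketch only: vselect cond, zero, op -> vselect not(cond), op, zero
define <16 x i32> @zero_mask_select(<16 x i32> %x, <16 x i32> %y) {
  %cond = icmp ugt <16 x i32> %x, %y
  %sel = select <16 x i1> %cond, <16 x i32> zeroinitializer, <16 x i32> %y
  ret <16 x i32> %sel
}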