@@ -2073,8 +2073,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
 
   if (Subtarget.hasVBMI2()) {
     for (auto VT : {MVT::v32i16, MVT::v16i32, MVT::v8i64}) {
-      setOperationAction(ISD::FSHL, VT, Custom);
-      setOperationAction(ISD::FSHR, VT, Custom);
+      setOperationAction(ISD::FSHL, VT, Legal);
+      setOperationAction(ISD::FSHR, VT, Legal);
     }
 
     setOperationAction(ISD::ROTL, MVT::v32i16, Custom);
@@ -2089,8 +2089,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
   if (!Subtarget.useSoftFloat() && Subtarget.hasVBMI2()) {
     for (auto VT : {MVT::v8i16, MVT::v4i32, MVT::v2i64, MVT::v16i16, MVT::v8i32,
                     MVT::v4i64}) {
-      setOperationAction(ISD::FSHL, VT, Custom);
-      setOperationAction(ISD::FSHR, VT, Custom);
+      setOperationAction(ISD::FSHL, VT, Subtarget.hasVLX() ? Legal : Custom);
+      setOperationAction(ISD::FSHR, VT, Subtarget.hasVLX() ? Legal : Custom);
     }
   }
 
@@ -2709,6 +2709,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
                        ISD::STRICT_FP_EXTEND,
                        ISD::FP_ROUND,
                        ISD::STRICT_FP_ROUND,
+                       ISD::FSHL,
+                       ISD::FSHR,
                        ISD::INTRINSIC_VOID,
                        ISD::INTRINSIC_WO_CHAIN,
                        ISD::INTRINSIC_W_CHAIN});
@@ -31322,19 +31324,15 @@ static SDValue LowerFunnelShift(SDValue Op, const X86Subtarget &Subtarget,
   bool IsCstSplat = X86::isConstantSplat(Amt, APIntShiftAmt);
   unsigned NumElts = VT.getVectorNumElements();
 
-  if (Subtarget.hasVBMI2() && EltSizeInBits > 8) {
-
-    if (IsCstSplat) {
-      if (IsFSHR)
-        std::swap(Op0, Op1);
-      uint64_t ShiftAmt = APIntShiftAmt.urem(EltSizeInBits);
-      SDValue Imm = DAG.getTargetConstant(ShiftAmt, DL, MVT::i8);
-      return getAVX512Node(IsFSHR ? X86ISD::VSHRD : X86ISD::VSHLD, DL, VT,
-                           {Op0, Op1, Imm}, DAG, Subtarget);
-    }
+  // For VBMI2 targets without VLX, widen 128/256-bit funnel shifts to 512-bit
+  // so the rest of lowering/isel can select the VBMI2 forms. Only the Custom
+  // types (v8i16, v4i32, v2i64, v16i16, v8i32, v4i64) reach LowerFunnelShift
+  // with VBMI2 but no VLX, so no type check is needed.
+  if (Subtarget.hasVBMI2() && !Subtarget.hasVLX() && EltSizeInBits > 8) {
     return getAVX512Node(IsFSHR ? ISD::FSHR : ISD::FSHL, DL, VT,
                          {Op0, Op1, Amt}, DAG, Subtarget);
   }
+
   assert((VT == MVT::v16i8 || VT == MVT::v32i8 || VT == MVT::v64i8 ||
           VT == MVT::v8i16 || VT == MVT::v16i16 || VT == MVT::v32i16 ||
           VT == MVT::v4i32 || VT == MVT::v8i32 || VT == MVT::v16i32) &&
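For reference, the "VBMI2 forms" mentioned in the comment above are the concat-shift instructions VPSHLD/VPSHRD (immediate amount) and VPSHLDV/VPSHRDV (per-lane amount). A minimal sketch of their semantics via the standard immintrin.h intrinsics, assuming an avx512vbmi2 target; the wrapper names are illustrative and not part of this patch:

```cpp
#include <immintrin.h>

// fshl(a, b, splat(3)) on v32i16: each lane keeps the top 16 bits of the
// 32-bit concatenation (a:b) shifted left by 3, i.e. (a << 3) | (b >> 13).
// With a uniform constant amount this is the immediate form (VPSHLDW).
__m512i funnel_left_imm(__m512i a, __m512i b) {
  return _mm512_shldi_epi16(a, b, 3);
}

// Non-uniform amounts use the variable form (VPSHLDVW), one amount per lane.
__m512i funnel_left_var(__m512i a, __m512i b, __m512i amt) {
  return _mm512_shldv_epi16(a, b, amt);
}
```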
@@ -57637,6 +57635,40 @@ static SDValue combineFP_TO_xINT_SAT(SDNode *N, SelectionDAG &DAG,
   return SDValue();
 }
 
+// Combine funnel shifts whose amount is a constant splat into VSHLD/VSHRD.
+static SDValue combineFunnelShift(SDNode *N, SelectionDAG &DAG,
+                                  TargetLowering::DAGCombinerInfo &DCI,
+                                  const X86Subtarget &Subtarget) {
+  SDLoc DL(N);
+  SDValue Op0 = N->getOperand(0);
+  SDValue Op1 = N->getOperand(1);
+  SDValue Amt = N->getOperand(2);
+  EVT VT = Op0.getValueType();
+
+  if (!VT.isVector())
+    return SDValue();
+
+  // Only combine when the operation is legal for this type; this keeps the
+  // combine away from types that still need to be widened or promoted (and,
+  // via the operation actions above, restricts it to VBMI2 targets).
+  if (!DAG.getTargetLoweringInfo().isOperationLegal(N->getOpcode(), VT))
+    return SDValue();
+
+  unsigned EltSize = VT.getScalarSizeInBits();
+  APInt ShiftVal;
+  if (!X86::isConstantSplat(Amt, ShiftVal))
+    return SDValue();
+
+  uint64_t ModAmt = ShiftVal.urem(EltSize);
+  SDValue Imm = DAG.getTargetConstant(ModAmt, DL, MVT::i8);
+  bool IsFSHR = N->getOpcode() == ISD::FSHR;
+
+  if (IsFSHR)
+    std::swap(Op0, Op1);
+  unsigned Opcode = IsFSHR ? X86ISD::VSHRD : X86ISD::VSHLD;
+  return DAG.getNode(Opcode, DL, VT, {Op0, Op1, Imm});
+}
+
 static bool needCarryOrOverflowFlag(SDValue Flags) {
   assert(Flags.getValueType() == MVT::i32 && "Unexpected VT!");
 
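As a cross-check of the amount handling in combineFunnelShift above, here is a scalar model of one 16-bit lane of ISD::FSHL/ISD::FSHR with the amount reduced modulo the lane width, mirroring ShiftVal.urem(EltSize); a sketch only, with illustrative names that are not part of the patch:

```cpp
#include <cstdint>

// fshl: concat(Hi, Lo) shifted left by Amt, high half kept; Amt == 0 gives Hi.
uint16_t fshl16(uint16_t Hi, uint16_t Lo, unsigned Amt) {
  Amt %= 16;
  return Amt ? uint16_t((Hi << Amt) | (Lo >> (16 - Amt))) : Hi;
}

// fshr: concat(Hi, Lo) shifted right by Amt, low half kept; Amt == 0 gives Lo.
// The std::swap on IsFSHR in the combiner covers the operand-order difference
// between ISD::FSHR and X86ISD::VSHRD.
uint16_t fshr16(uint16_t Hi, uint16_t Lo, unsigned Amt) {
  Amt %= 16;
  return Amt ? uint16_t((Lo >> Amt) | (Hi << (16 - Amt))) : Lo;
}
```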
@@ -61279,6 +61311,8 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
   case ISD::INTRINSIC_VOID: return combineINTRINSIC_VOID(N, DAG, DCI);
   case ISD::FP_TO_SINT_SAT:
   case ISD::FP_TO_UINT_SAT: return combineFP_TO_xINT_SAT(N, DAG, Subtarget);
+  case ISD::FSHL:
+  case ISD::FSHR: return combineFunnelShift(N, DAG, DCI, Subtarget);
     // clang-format on
   }
 