Skip to content

Commit 451c207

Browse files
ArnavM3434 and kcloudy0717
authored and committed
[X86] Make VBMI2 funnel shifts use VSHLD/VSHRD for const splats (llvm#169401)
Make ISD::FSHL/FSHR legal on VBMI2 vector targets and convert to VSHLD/VSHRD in a combine. Closes llvm#166949.
1 parent a9808eb commit 451c207

File tree

1 file changed

+48
-14
lines changed

1 file changed

+48
-14
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 48 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -2073,8 +2073,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
20732073

20742074
if (Subtarget.hasVBMI2()) {
20752075
for (auto VT : {MVT::v32i16, MVT::v16i32, MVT::v8i64}) {
2076-
setOperationAction(ISD::FSHL, VT, Custom);
2077-
setOperationAction(ISD::FSHR, VT, Custom);
2076+
setOperationAction(ISD::FSHL, VT, Legal);
2077+
setOperationAction(ISD::FSHR, VT, Legal);
20782078
}
20792079

20802080
setOperationAction(ISD::ROTL, MVT::v32i16, Custom);
@@ -2089,8 +2089,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
20892089
if (!Subtarget.useSoftFloat() && Subtarget.hasVBMI2()) {
20902090
for (auto VT : {MVT::v8i16, MVT::v4i32, MVT::v2i64, MVT::v16i16, MVT::v8i32,
20912091
MVT::v4i64}) {
2092-
setOperationAction(ISD::FSHL, VT, Custom);
2093-
setOperationAction(ISD::FSHR, VT, Custom);
2092+
setOperationAction(ISD::FSHL, VT, Subtarget.hasVLX() ? Legal : Custom);
2093+
setOperationAction(ISD::FSHR, VT, Subtarget.hasVLX() ? Legal : Custom);
20942094
}
20952095
}
20962096

@@ -2709,6 +2709,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
27092709
ISD::STRICT_FP_EXTEND,
27102710
ISD::FP_ROUND,
27112711
ISD::STRICT_FP_ROUND,
2712+
ISD::FSHL,
2713+
ISD::FSHR,
27122714
ISD::INTRINSIC_VOID,
27132715
ISD::INTRINSIC_WO_CHAIN,
27142716
ISD::INTRINSIC_W_CHAIN});
@@ -31322,19 +31324,15 @@ static SDValue LowerFunnelShift(SDValue Op, const X86Subtarget &Subtarget,
3132231324
bool IsCstSplat = X86::isConstantSplat(Amt, APIntShiftAmt);
3132331325
unsigned NumElts = VT.getVectorNumElements();
3132431326

31325-
if (Subtarget.hasVBMI2() && EltSizeInBits > 8) {
31326-
31327-
if (IsCstSplat) {
31328-
if (IsFSHR)
31329-
std::swap(Op0, Op1);
31330-
uint64_t ShiftAmt = APIntShiftAmt.urem(EltSizeInBits);
31331-
SDValue Imm = DAG.getTargetConstant(ShiftAmt, DL, MVT::i8);
31332-
return getAVX512Node(IsFSHR ? X86ISD::VSHRD : X86ISD::VSHLD, DL, VT,
31333-
{Op0, Op1, Imm}, DAG, Subtarget);
31334-
}
31327+
// For non-VLX VBMI2 targets, widen 128/256-bit to 512-bit so
31328+
// the rest of the lowering/isel can select the VBMI2 forms.
31329+
// Only Custom types (v8i16, v4i32, v2i64, v16i16, v8i32, v4i64) can
31330+
// reach LowerFunnelShift with VBMI2 but no VLX, so no type check needed.
31331+
if (Subtarget.hasVBMI2() && !Subtarget.hasVLX() && EltSizeInBits > 8) {
3133531332
return getAVX512Node(IsFSHR ? ISD::FSHR : ISD::FSHL, DL, VT,
3133631333
{Op0, Op1, Amt}, DAG, Subtarget);
3133731334
}
31335+
3133831336
assert((VT == MVT::v16i8 || VT == MVT::v32i8 || VT == MVT::v64i8 ||
3133931337
VT == MVT::v8i16 || VT == MVT::v16i16 || VT == MVT::v32i16 ||
3134031338
VT == MVT::v4i32 || VT == MVT::v8i32 || VT == MVT::v16i32) &&
@@ -57637,6 +57635,40 @@ static SDValue combineFP_TO_xINT_SAT(SDNode *N, SelectionDAG &DAG,
5763757635
return SDValue();
5763857636
}
5763957637

57638+
// Combiner: turn uniform-constant splat funnel shifts into VSHLD/VSHRD.
//
// Rewrites a vector ISD::FSHL / ISD::FSHR node whose shift amount (operand 2)
// is a constant splat into X86ISD::VSHLD / X86ISD::VSHRD carrying the amount
// as an i8 target-constant immediate.
//
// Returns an empty SDValue (i.e. no replacement) when:
//   - the value type is not a vector,
//   - the funnel-shift opcode is not Legal for that type, or
//   - the shift amount is not a constant splat.
//
// DCI and Subtarget are accepted for signature uniformity with the caller but
// are not referenced in this body.
57639+
static SDValue combineFunnelShift(SDNode *N, SelectionDAG &DAG,
57640+
TargetLowering::DAGCombinerInfo &DCI,
57641+
const X86Subtarget &Subtarget) {
57642+
SDLoc DL(N);
57643+
SDValue Op0 = N->getOperand(0);
57644+
SDValue Op1 = N->getOperand(1);
57645+
SDValue Amt = N->getOperand(2);
57646+
// The result type is taken from the first data operand.
EVT VT = Op0.getValueType();
57647+
57648+
// Scalar funnel shifts are left untouched by this combine.
if (!VT.isVector())
57649+
return SDValue();
57650+
57651+
// Only combine if the operation is legal for this type.
57652+
// This ensures we don't try to convert types that need to be
57653+
// widened/promoted.
57654+
if (!DAG.getTargetLoweringInfo().isOperationLegal(N->getOpcode(), VT))
57655+
return SDValue();
57656+
57657+
unsigned EltSize = VT.getScalarSizeInBits();
57658+
APInt ShiftVal;
57659+
// Bail out unless every lane shifts by the same constant; on success the
// splatted constant is written to ShiftVal.
if (!X86::isConstantSplat(Amt, ShiftVal))
57660+
return SDValue();
57661+
57662+
// Reduce the amount modulo the element width before encoding it as the
// i8 immediate operand.
uint64_t ModAmt = ShiftVal.urem(EltSize);
57663+
SDValue Imm = DAG.getTargetConstant(ModAmt, DL, MVT::i8);
57664+
bool IsFSHR = N->getOpcode() == ISD::FSHR;
57665+
57666+
// For FSHR the two data operands are exchanged before building the node.
// NOTE(review): presumably X86ISD::VSHRD takes its data operands in the
// opposite order to ISD::FSHR -- confirm against the X86ISD node definition.
if (IsFSHR)
57667+
std::swap(Op0, Op1);
57668+
unsigned Opcode = IsFSHR ? X86ISD::VSHRD : X86ISD::VSHLD;
57669+
return DAG.getNode(Opcode, DL, VT, {Op0, Op1, Imm});
57670+
}
57671+
5764057672
static bool needCarryOrOverflowFlag(SDValue Flags) {
5764157673
assert(Flags.getValueType() == MVT::i32 && "Unexpected VT!");
5764257674

@@ -61279,6 +61311,8 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
6127961311
case ISD::INTRINSIC_VOID: return combineINTRINSIC_VOID(N, DAG, DCI);
6128061312
case ISD::FP_TO_SINT_SAT:
6128161313
case ISD::FP_TO_UINT_SAT: return combineFP_TO_xINT_SAT(N, DAG, Subtarget);
61314+
case ISD::FSHL:
61315+
case ISD::FSHR: return combineFunnelShift(N, DAG, DCI, Subtarget);
6128261316
// clang-format on
6128361317
}
6128461318

0 commit comments

Comments
 (0)