Skip to content

Conversation

@RKSimon
Copy link
Collaborator

@RKSimon RKSimon commented Dec 17, 2024

No description provided.

@llvmbot
Copy link
Member

llvmbot commented Dec 17, 2024

@llvm/pr-subscribers-backend-x86

Author: Simon Pilgrim (RKSimon)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/120241.diff

1 Files Affected:

  • (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+5-8)
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index c50db59464724c..4ed0a886c9d298 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -30000,6 +30000,7 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
   SDLoc dl(Op);
   SDValue R = Op.getOperand(0);
   SDValue Amt = Op.getOperand(1);
+  unsigned NumElts = VT.getVectorNumElements();
   unsigned EltSizeInBits = VT.getScalarSizeInBits();
   bool ConstantAmt = ISD::isBuildVectorOfConstantSDNodes(Amt.getNode());
 
@@ -30069,7 +30070,6 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
   if (ConstantAmt && (VT == MVT::v8i16 || VT == MVT::v4i32 ||
                       (VT == MVT::v16i16 && Subtarget.hasInt256()))) {
     SDValue Amt1, Amt2;
-    unsigned NumElts = VT.getVectorNumElements();
     SmallVector<int, 8> ShuffleMask;
     for (unsigned i = 0; i != NumElts; ++i) {
       SDValue A = Amt->getOperand(i);
@@ -30116,7 +30116,6 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
        VT == MVT::v8i16 || VT == MVT::v16i16 || VT == MVT::v32i16) &&
       !Subtarget.hasXOP()) {
     MVT NarrowScalarVT = VT.getScalarType();
-    int NumElts = VT.getVectorNumElements();
     // We can do this extra fast if each pair of narrow elements is shifted by
     // the same amount by doing this SWAR style: use a shift to move the valid
     // bits to the right position, mask out any bits which crossed from one
@@ -30377,7 +30376,6 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
   if ((VT == MVT::v16i8 && Subtarget.hasSSSE3()) ||
       (VT == MVT::v32i8 && Subtarget.hasInt256()) ||
       (VT == MVT::v64i8 && Subtarget.hasBWI())) {
-    unsigned NumElts = VT.getVectorNumElements();
     unsigned NumLanes = VT.getSizeInBits() / 128u;
     unsigned NumEltsPerLane = NumElts / NumLanes;
     SmallVector<APInt, 16> LUT;
@@ -30417,7 +30415,7 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
     assert((!Subtarget.hasBWI() || VT == MVT::v32i8 || VT == MVT::v16i8) &&
            "Unexpected vector type");
     MVT EvtSVT = Subtarget.hasBWI() ? MVT::i16 : MVT::i32;
-    MVT ExtVT = MVT::getVectorVT(EvtSVT, VT.getVectorNumElements());
+    MVT ExtVT = MVT::getVectorVT(EvtSVT, NumElts);
     unsigned ExtOpc = Opc == ISD::SRA ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
     R = DAG.getNode(ExtOpc, dl, ExtVT, R);
     Amt = DAG.getNode(ISD::ZERO_EXTEND, dl, ExtVT, Amt);
@@ -30431,7 +30429,6 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
       (VT == MVT::v16i8 || (VT == MVT::v32i8 && Subtarget.hasInt256()) ||
        (VT == MVT::v64i8 && Subtarget.hasBWI())) &&
       !Subtarget.hasXOP()) {
-    int NumElts = VT.getVectorNumElements();
     MVT VT16 = MVT::getVectorVT(MVT::i16, NumElts / 2);
     SDValue Cst8 = DAG.getTargetConstant(8, dl, MVT::i8);
 
@@ -30477,14 +30474,14 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
   if (VT == MVT::v16i8 ||
       (VT == MVT::v32i8 && Subtarget.hasInt256() && !Subtarget.hasXOP()) ||
       (VT == MVT::v64i8 && Subtarget.hasBWI())) {
-    MVT ExtVT = MVT::getVectorVT(MVT::i16, VT.getVectorNumElements() / 2);
+    MVT ExtVT = MVT::getVectorVT(MVT::i16, NumElts / 2);
 
     auto SignBitSelect = [&](MVT SelVT, SDValue Sel, SDValue V0, SDValue V1) {
       if (VT.is512BitVector()) {
         // On AVX512BW targets we make use of the fact that VSELECT lowers
         // to a masked blend which selects bytes based just on the sign bit
         // extracted to a mask.
-        MVT MaskVT = MVT::getVectorVT(MVT::i1, VT.getVectorNumElements());
+        MVT MaskVT = MVT::getVectorVT(MVT::i1, NumElts);
         V0 = DAG.getBitcast(VT, V0);
         V1 = DAG.getBitcast(VT, V1);
         Sel = DAG.getBitcast(VT, Sel);
@@ -30611,7 +30608,7 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
       // On SSE41 targets we can use PBLENDVB which selects bytes based just on
       // the sign bit.
       if (UseSSE41) {
-        MVT ExtVT = MVT::getVectorVT(MVT::i8, VT.getVectorNumElements() * 2);
+        MVT ExtVT = MVT::getVectorVT(MVT::i8, NumElts * 2);
         V0 = DAG.getBitcast(ExtVT, V0);
         V1 = DAG.getBitcast(ExtVT, V1);
         Sel = DAG.getBitcast(ExtVT, Sel);

@RKSimon RKSimon merged commit 7ab8dd7 into llvm:main Dec 17, 2024
7 of 10 checks passed
@RKSimon RKSimon deleted the x86-lower-shift-num-elts branch December 17, 2024 16:27
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants