Skip to content

Conversation

@abhishek-kaushik22
Copy link
Contributor

For the function (https://godbolt.org/z/4rTYeMY4b)

#include <immintrin.h>

__m512i foo(__m512i a){
    __m512i r0 = _mm512_maskz_shuffle_epi32(0xaaaa, a, 0xab);
    return r0;
}

The assembly generated is unnecessarily long

.LCPI0_1:
        .byte   0
        .byte   18
        .byte   2
        .byte   18
        .byte   4
        .byte   22
        .byte   6
        .byte   22
        .byte   8
        .byte   26
        .byte   10
        .byte   26
        .byte   12
        .byte   30
        .byte   14
        .byte   30
foo(long long vector[8]):
        vpmovsxbd       zmm2, xmmword ptr [rip + .LCPI0_1]
        vpxor   xmm1, xmm1, xmm1
        vpermt2d        zmm1, zmm2, zmm0
        vmovdqa64       zmm0, zmm1
        ret

Instead we could simply generate a vpshufd {{.*#+}} zmm0 {%k1} {z} instruction and pass the mask and the imm8 value to it.

The selection dag generated from the IR doesn't contain the mask and the imm8 value directly but there is a pattern that can be matched here.

t6: v16i32 = BUILD_VECTOR Constant:i32<0>, undef:i32, Constant:i32<0>, undef:i32, Constant:i32<0>, undef:i32, Constant:i32<0>, undef:i32, Constant:i32<0>, undef:i32, Constant:i32<0>, undef:i32, Constant:i32<0>, undef:i32, Constant:i32<0>, undef:i32
          t2: v8i64,ch = CopyFromReg t0, Register:v8i64 %0
        t3: v16i32 = bitcast t2
      t7: v16i32 = vector_shuffle<0,18,2,18,4,22,6,22,8,26,10,26,12,30,14,30> t6, t3
    t8: v8i64 = bitcast t7

I've tried to match this pattern to get the value of the mask and imm8, and generate a VSELECT node. The resulting assembly looks like

        movw    $-21846, %ax                    # imm = 0xAAAA
        kmovw   %eax, %k1
        vpshufd $136, %zmm0, %zmm0 {%k1} {z}    # zmm0 {%k1} {z} = zmm0[0,2,0,2,4,6,4,6,8,10,8,10,12,14,12,14]
        retq

For the function (https://godbolt.org/z/4rTYeMY4b)
```
#include <immintrin.h>

__m512i foo(__m512i a){
    __m512i r0 = _mm512_maskz_shuffle_epi32(0xaaaa, a, 0xab);
    return r0;
}
```
The assembly generated is unnecessarily long
```
.LCPI0_1:
        .byte   0
        .byte   18
        .byte   2
        .byte   18
        .byte   4
        .byte   22
        .byte   6
        .byte   22
        .byte   8
        .byte   26
        .byte   10
        .byte   26
        .byte   12
        .byte   30
        .byte   14
        .byte   30
foo(long long vector[8]):
        vpmovsxbd       zmm2, xmmword ptr [rip + .LCPI0_1]
        vpxor   xmm1, xmm1, xmm1
        vpermt2d        zmm1, zmm2, zmm0
        vmovdqa64       zmm0, zmm1
        ret
```
Instead we could simply generate a `vpshufd {{.*#+}} zmm0 {%k1} {z}` instruction and pass the mask and the `imm8` value to it.

The selection dag generated from the IR doesn't contain the mask and the `imm8` value directly but there is a pattern that can be matched here.
```
t6: v16i32 = BUILD_VECTOR Constant:i32<0>, undef:i32, Constant:i32<0>, undef:i32, Constant:i32<0>, undef:i32, Constant:i32<0>, undef:i32, Constant:i32<0>, undef:i32, Constant:i32<0>, undef:i32, Constant:i32<0>, undef:i32, Constant:i32<0>, undef:i32
          t2: v8i64,ch = CopyFromReg t0, Register:v8i64 %0
        t3: v16i32 = bitcast t2
      t7: v16i32 = vector_shuffle<0,18,2,18,4,22,6,22,8,26,10,26,12,30,14,30> t6, t3
    t8: v8i64 = bitcast t7
```
I've tried to match this pattern to get the value of the mask and imm8, and generate a `VSELECT` node.
The resulting assembly looks like
```
        movw    $-21846, %ax                    # imm = 0xAAAA
        kmovw   %eax, %k1
        vpshufd $136, %zmm0, %zmm0 {%k1} {z}    # zmm0 {%k1} {z} = zmm0[0,2,0,2,4,6,4,6,8,10,8,10,12,14,12,14]
        retq
```
@llvmbot
Copy link
Member

llvmbot commented Dec 26, 2024

@llvm/pr-subscribers-backend-x86

Author: None (abhishek-kaushik22)

Changes

For the function (https://godbolt.org/z/4rTYeMY4b)

#include <immintrin.h>

__m512i foo(__m512i a){
    __m512i r0 = _mm512_maskz_shuffle_epi32(0xaaaa, a, 0xab);
    return r0;
}

The assembly generated is unnecessarily long

.LCPI0_1:
        .byte   0
        .byte   18
        .byte   2
        .byte   18
        .byte   4
        .byte   22
        .byte   6
        .byte   22
        .byte   8
        .byte   26
        .byte   10
        .byte   26
        .byte   12
        .byte   30
        .byte   14
        .byte   30
foo(long long vector[8]):
        vpmovsxbd       zmm2, xmmword ptr [rip + .LCPI0_1]
        vpxor   xmm1, xmm1, xmm1
        vpermt2d        zmm1, zmm2, zmm0
        vmovdqa64       zmm0, zmm1
        ret

Instead we could simply generate a vpshufd {{.*#+}} zmm0 {%k1} {z} instruction and pass the mask and the imm8 value to it.

The selection dag generated from the IR doesn't contain the mask and the imm8 value directly but there is a pattern that can be matched here.

t6: v16i32 = BUILD_VECTOR Constant:i32<0>, undef:i32, Constant:i32<0>, undef:i32, Constant:i32<0>, undef:i32, Constant:i32<0>, undef:i32, Constant:i32<0>, undef:i32, Constant:i32<0>, undef:i32, Constant:i32<0>, undef:i32, Constant:i32<0>, undef:i32
          t2: v8i64,ch = CopyFromReg t0, Register:v8i64 %0
        t3: v16i32 = bitcast t2
      t7: v16i32 = vector_shuffle<0,18,2,18,4,22,6,22,8,26,10,26,12,30,14,30> t6, t3
    t8: v8i64 = bitcast t7

I've tried to match this pattern to get the value of the mask and imm8, and generate a VSELECT node. The resulting assembly looks like

        movw    $-21846, %ax                    # imm = 0xAAAA
        kmovw   %eax, %k1
        vpshufd $136, %zmm0, %zmm0 {%k1} {z}    # zmm0 {%k1} {z} = zmm0[0,2,0,2,4,6,4,6,8,10,8,10,12,14,12,14]
        retq

Full diff: https://github.com/llvm/llvm-project/pull/121147.diff

1 Files Affected:

  • (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+55)
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index e7f6032ee7d749..ca07b81f3fb984 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -17172,6 +17172,58 @@ static SDValue lowerV8I64Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
   return lowerShuffleWithPERMV(DL, MVT::v8i64, Mask, V1, V2, Subtarget, DAG);
 }
 
+static SDValue lowerShuffleAsVSELECT(const SDLoc &DL,
+                                     ArrayRef<int> RepeatedMask, SDValue V1,
+                                     SDValue V2, SelectionDAG &DAG) {
+  if (V1.getOpcode() != ISD::BUILD_VECTOR &&
+      V2.getOpcode() != ISD::BUILD_VECTOR)
+    return SDValue();
+  SDValue BuildVector;
+  if (V1.getOpcode() == ISD::BUILD_VECTOR) {
+    BuildVector = V1;
+    if (V2.getOpcode() != ISD::BITCAST)
+      return SDValue();
+  } else {
+    BuildVector = V2;
+    if (V1.getOpcode() != ISD::BITCAST)
+      return SDValue();
+  }
+  if (!ISD::isBuildVectorAllZeros(BuildVector.getNode()))
+    return SDValue();
+  APInt DestMask(16, 0);
+  for (unsigned i = 0; i < 16; ++i) {
+    SDValue Op = BuildVector->getOperand(i);
+    if (Op.isUndef())
+      DestMask.setBit(i);
+  }
+  if (DestMask.isZero())
+    return SDValue();
+
+  unsigned Imm8 = 0;
+  for (unsigned i = 0; i < 4; ++i) {
+    if (V1.getOpcode() != ISD::BUILD_VECTOR) {
+      if (RepeatedMask[i] >= 4) {
+        continue;
+      }
+    } else if (RepeatedMask[i] < 4) {
+      continue;
+    }
+    Imm8 += (RepeatedMask[i] % 4) << (2 * i);
+  }
+
+  SDValue Bitcast = DAG.getNode(ISD::BITCAST, DL, MVT::v16i1,
+                                DAG.getConstant(DestMask, DL, MVT::i16));
+
+  std::vector<SDValue> ZeroElements(16, DAG.getConstant(0, DL, MVT::i32));
+  SDValue Zeros = DAG.getBuildVector(MVT::v16i32, DL, ZeroElements);
+
+  return DAG.getNode(ISD::VSELECT, DL, MVT::v16i32, Bitcast,
+                     DAG.getNode(X86ISD::PSHUFD, DL, MVT::v16i32,
+                                 V1.getOpcode() != ISD::BUILD_VECTOR ? V1 : V2,
+                                 DAG.getTargetConstant(Imm8, DL, MVT::i8)),
+                     Zeros);
+}
+
 /// Handle lowering of 16-lane 32-bit integer shuffles.
 static SDValue lowerV16I32Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
                                   const APInt &Zeroable, SDValue V1, SDValue V2,
@@ -17217,6 +17269,9 @@ static SDValue lowerV16I32Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
     // Use dedicated unpack instructions for masks that match their pattern.
     if (SDValue V = lowerShuffleWithUNPCK(DL, MVT::v16i32, V1, V2, Mask, DAG))
       return V;
+
+    if (SDValue V = lowerShuffleAsVSELECT(DL, RepeatedMask, V1, V2, DAG))
+      return V;
   }
 
   // Try to use shift instructions.

@abhishek-kaushik22
Copy link
Contributor Author

@phoebewang @e-kud @RKSimon can you please review?

Copy link
Collaborator

@RKSimon RKSimon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a few comments. It's tricky to decide when it's worth using a constant predicate mask, so we've not tried too hard to lower to them.

return SDValue();
}
if (!ISD::isBuildVectorAllZeros(BuildVector.getNode()))
return SDValue();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Most of the code above can be removed by using the Zeroable mask used in shuffle lowering.

@github-actions
Copy link

github-actions bot commented Mar 10, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

; AVX512BW-NEXT: movw $-21846, %ax # imm = 0xAAAA
; AVX512BW-NEXT: kmovd %eax, %k1
; AVX512BW-NEXT: vpshufd {{.*#+}} zmm0 {%k1} {z} = zmm0[2,2,2,2,6,6,6,6,10,10,10,10,14,14,14,14]
; AVX512BW-NEXT: retq
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still not sure if this wouldn't be better off as a VPSHUFB node

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants