Skip to content

Commit 4f2e5e4

Browse files
[X86][AVX512] Better lowering for _mm512_maskz_shuffle_epi32
For the function (https://godbolt.org/z/4rTYeMY4b) ``` #include <immintrin.h> __m512i foo(__m512i a){ __m512i r0 = _mm512_maskz_shuffle_epi32(0xaaaa, a, 0xab); return r0; } ``` The assembly generated is unnecessarily long ``` .LCPI0_1: .byte 0 .byte 18 .byte 2 .byte 18 .byte 4 .byte 22 .byte 6 .byte 22 .byte 8 .byte 26 .byte 10 .byte 26 .byte 12 .byte 30 .byte 14 .byte 30 foo(long long vector[8]): vpmovsxbd zmm2, xmmword ptr [rip + .LCPI0_1] vpxor xmm1, xmm1, xmm1 vpermt2d zmm1, zmm2, zmm0 vmovdqa64 zmm0, zmm1 ret ``` Instead we could simply generate a `vpshufd {{.*#+}} zmm0 {%k1} {z}` instruction and pass the mask and the `imm8` value to it. The selection dag generated from the IR doesn't contain the mask and the `imm8` value directly but there is a pattern that can be matched here. ``` t6: v16i32 = BUILD_VECTOR Constant:i32<0>, undef:i32, Constant:i32<0>, undef:i32, Constant:i32<0>, undef:i32, Constant:i32<0>, undef:i32, Constant:i32<0>, undef:i32, Constant:i32<0>, undef:i32, Constant:i32<0>, undef:i32, Constant:i32<0>, undef:i32 t2: v8i64,ch = CopyFromReg t0, Register:v8i64 %0 t3: v16i32 = bitcast t2 t7: v16i32 = vector_shuffle<0,18,2,18,4,22,6,22,8,26,10,26,12,30,14,30> t6, t3 t8: v8i64 = bitcast t7 ``` I've tried to match this pattern to get the value of the mask and imm8, and generate a `VSELECT` node. The resulting assembly looks like ``` movw $-21846, %ax # imm = 0xAAAA kmovw %eax, %k1 vpshufd $136, %zmm0, %zmm0 {%k1} {z} # zmm0 {%k1} {z} = zmm0[0,2,0,2,4,6,4,6,8,10,8,10,12,14,12,14] retq ```
1 parent cbe583b commit 4f2e5e4

File tree

1 file changed

+55
-0
lines changed

1 file changed

+55
-0
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17172,6 +17172,58 @@ static SDValue lowerV8I64Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
1717217172
return lowerShuffleWithPERMV(DL, MVT::v8i64, Mask, V1, V2, Subtarget, DAG);
1717317173
}
1717417174

17175+
static SDValue lowerShuffleAsVSELECT(const SDLoc &DL,
17176+
ArrayRef<int> RepeatedMask, SDValue V1,
17177+
SDValue V2, SelectionDAG &DAG) {
17178+
if (V1.getOpcode() != ISD::BUILD_VECTOR &&
17179+
V2.getOpcode() != ISD::BUILD_VECTOR)
17180+
return SDValue();
17181+
SDValue BuildVector;
17182+
if (V1.getOpcode() == ISD::BUILD_VECTOR) {
17183+
BuildVector = V1;
17184+
if (V2.getOpcode() != ISD::BITCAST)
17185+
return SDValue();
17186+
} else {
17187+
BuildVector = V2;
17188+
if (V1.getOpcode() != ISD::BITCAST)
17189+
return SDValue();
17190+
}
17191+
if (!ISD::isBuildVectorAllZeros(BuildVector.getNode()))
17192+
return SDValue();
17193+
APInt DestMask(16, 0);
17194+
for (unsigned i = 0; i < 16; ++i) {
17195+
SDValue Op = BuildVector->getOperand(i);
17196+
if (Op.isUndef())
17197+
DestMask.setBit(i);
17198+
}
17199+
if (DestMask.isZero())
17200+
return SDValue();
17201+
17202+
unsigned Imm8 = 0;
17203+
for (unsigned i = 0; i < 4; ++i) {
17204+
if (V1.getOpcode() != ISD::BUILD_VECTOR) {
17205+
if (RepeatedMask[i] >= 4) {
17206+
continue;
17207+
}
17208+
} else if (RepeatedMask[i] < 4) {
17209+
continue;
17210+
}
17211+
Imm8 += (RepeatedMask[i] % 4) << (2 * i);
17212+
}
17213+
17214+
SDValue Bitcast = DAG.getNode(ISD::BITCAST, DL, MVT::v16i1,
17215+
DAG.getConstant(DestMask, DL, MVT::i16));
17216+
17217+
std::vector<SDValue> ZeroElements(16, DAG.getConstant(0, DL, MVT::i32));
17218+
SDValue Zeros = DAG.getBuildVector(MVT::v16i32, DL, ZeroElements);
17219+
17220+
return DAG.getNode(ISD::VSELECT, DL, MVT::v16i32, Bitcast,
17221+
DAG.getNode(X86ISD::PSHUFD, DL, MVT::v16i32,
17222+
V1.getOpcode() != ISD::BUILD_VECTOR ? V1 : V2,
17223+
DAG.getTargetConstant(Imm8, DL, MVT::i8)),
17224+
Zeros);
17225+
}
17226+
1717517227
/// Handle lowering of 16-lane 32-bit integer shuffles.
1717617228
static SDValue lowerV16I32Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
1717717229
const APInt &Zeroable, SDValue V1, SDValue V2,
@@ -17217,6 +17269,9 @@ static SDValue lowerV16I32Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
1721717269
// Use dedicated unpack instructions for masks that match their pattern.
1721817270
if (SDValue V = lowerShuffleWithUNPCK(DL, MVT::v16i32, V1, V2, Mask, DAG))
1721917271
return V;
17272+
17273+
if (SDValue V = lowerShuffleAsVSELECT(DL, RepeatedMask, V1, V2, DAG))
17274+
return V;
1722017275
}
1722117276

1722217277
// Try to use shift instructions.

0 commit comments

Comments
 (0)