
Commit ec4373b

[InstCombine] Fold bitcast (extelt (bitcast X), Idx) into bitcast+shuffle.
Fold sequences such as:

```llvm
%bc = bitcast <8 x i32> %v to <2 x i128>
%ext = extractelement <2 x i128> %bc, i64 0
%res = bitcast i128 %ext to <2 x i64>
```

into:

```llvm
%bc = bitcast <8 x i32> %v to <4 x i64>
%res = shufflevector <4 x i64> %bc, <4 x i64> poison, <2 x i32> <i32 0, i32 1>
```

The motivation for this is a long-standing regression affecting SIMDe on AArch64, introduced indirectly by the AlwaysInliner (1a2e77c).

Some reproducers:
* https://godbolt.org/z/53qx18s6M
* https://godbolt.org/z/o5e43h5M7
1 parent 13aae75 commit ec4373b
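
To make the before/after equivalence above concrete, here is a small standalone C++ sketch (illustration only, not part of the patch; it models the little-endian in-memory lane layout with plain `memcpy` rather than LLVM IR) showing that extracting the i128 element and re-splitting it into i64 lanes reads exactly the lanes that the shuffle of the `<4 x i64>` view selects:

```cpp
// Standalone illustration only; uses plain memory operations instead of LLVM IR.
#include <cassert>
#include <cstdint>
#include <cstring>

int main() {
  uint32_t V[8] = {0, 1, 2, 3, 4, 5, 6, 7}; // the <8 x i32> source value
  const unsigned Idx = 0;                   // extracted <2 x i128> element

  // extractelement <2 x i128> + bitcast i128 to <2 x i64>:
  // take 16 bytes at offset Idx * 16 and view them as two i64 lanes.
  uint64_t Ext[2];
  std::memcpy(Ext, reinterpret_cast<const char *>(V) + Idx * 16, 16);

  // shufflevector of the <4 x i64> view of V with mask {2*Idx, 2*Idx + 1}.
  uint64_t As64[4];
  std::memcpy(As64, V, sizeof(V));
  uint64_t Shuf[2] = {As64[2 * Idx], As64[2 * Idx + 1]};

  assert(std::memcmp(Ext, Shuf, sizeof(Ext)) == 0); // same bytes either way
  return 0;
}
```

With `Idx = 0` this is the commit-message example (lanes 0 and 1); with `Idx = 1` the same comparison selects lanes 2 and 3, which is the case exercised by the updated bitcast_extelt9 test below.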

2 files changed (+52, -6 lines)


llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp

Lines changed: 48 additions & 0 deletions
```diff
@@ -2380,6 +2380,51 @@ static Value *optimizeIntegerToVectorInsertions(BitCastInst &CI,
   return Result;
 }
 
+/// If the input is (extractelement (bitcast X), Idx) and the source and
+/// destination types are vectors, we are performing a vector extract from X. We
+/// can replace the extractelement+bitcast with a shufflevector, avoiding the
+/// final scalar->vector bitcast. This pattern is usually handled better by the
+/// backend.
+///
+/// Example:
+///   %bc = bitcast <8 x i32> %X to <2 x i128>
+///   %ext = extractelement <2 x i128> %bc1, i64 1
+///   bitcast i128 %ext to <2 x i64>
+/// --->
+///   %bc = bitcast <8 x i32> %X to <4 x i64>
+///   shufflevector <4 x i64> %bc, <4 x i64> poison, <2 x i32> <i32 2, i32 3>
+static Instruction *foldBitCastExtElt(BitCastInst &BitCast,
+                                      InstCombiner::BuilderTy &Builder) {
+  Value *X;
+  uint64_t Index;
+  if (!match(
+          BitCast.getOperand(0),
+          m_OneUse(m_ExtractElt(m_BitCast(m_Value(X)), m_ConstantInt(Index)))))
+    return nullptr;
+
+  auto *SrcTy = dyn_cast<FixedVectorType>(X->getType());
+  auto *DstTy = dyn_cast<FixedVectorType>(BitCast.getType());
+  if (!SrcTy || !DstTy)
+    return nullptr;
+
+  // Check if the mask indices would overflow.
+  unsigned NumElts = DstTy->getNumElements();
+  if (Index > INT32_MAX || NumElts > INT32_MAX ||
+      (Index + 1) * NumElts - 1 > INT32_MAX)
+    return nullptr;
+
+  unsigned DstEltWidth = DstTy->getScalarSizeInBits();
+  unsigned SrcVecWidth = SrcTy->getPrimitiveSizeInBits();
+  assert((SrcVecWidth % DstEltWidth == 0) && "Invalid types.");
+  auto *NewVecTy =
+      FixedVectorType::get(DstTy->getElementType(), SrcVecWidth / DstEltWidth);
+  auto *NewBC = Builder.CreateBitCast(X, NewVecTy, "bc");
+
+  unsigned StartIdx = Index * NumElts;
+  auto Mask = llvm::to_vector<16>(llvm::seq<int>(StartIdx, StartIdx + NumElts));
+  return new ShuffleVectorInst(NewBC, Mask);
+}
+
 /// Canonicalize scalar bitcasts of extracted elements into a bitcast of the
 /// vector followed by extract element. The backend tends to handle bitcasts of
 /// vectors better than bitcasts of scalars because vector registers are
@@ -2866,6 +2911,9 @@ Instruction *InstCombinerImpl::visitBitCast(BitCastInst &CI) {
   if (Instruction *I = canonicalizeBitCastExtElt(CI, *this))
     return I;
 
+  if (Instruction *I = foldBitCastExtElt(CI, Builder))
+    return I;
+
   if (Instruction *I = foldBitCastBitwiseLogic(CI, Builder))
     return I;
 
```

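The doc comment's example (source `<8 x i32>`, wide view `<2 x i128>`, extract index 1, destination `<2 x i64>`) works through the helper's arithmetic as follows; this is a standalone sketch that mirrors the patch's variable names but uses no LLVM APIs:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Values for the doc-comment example: bitcast <8 x i32> to <2 x i128>,
  // extractelement index 1, then bitcast the i128 to <2 x i64>.
  const uint64_t Index = 1;         // extracted element of the wide view
  const unsigned NumElts = 2;       // elements in the destination <2 x i64>
  const unsigned DstEltWidth = 64;  // bits per destination element
  const unsigned SrcVecWidth = 256; // total bits of the <8 x i32> source

  // New vector type the source is bitcast to: <SrcVecWidth / DstEltWidth x i64>.
  const unsigned NewVecElts = SrcVecWidth / DstEltWidth;
  assert(NewVecElts == 4); // <4 x i64>

  // Shuffle mask: NumElts consecutive lanes starting at Index * NumElts.
  const unsigned StartIdx = Index * NumElts;
  std::vector<int> Mask;
  for (unsigned I = 0; I != NumElts; ++I)
    Mask.push_back(static_cast<int>(StartIdx + I));
  assert(Mask[0] == 2 && Mask[1] == 3); // matches the doc comment's <i32 2, i32 3>
  return 0;
}
```

The early `INT32_MAX` guards in the patch exist because shuffle mask elements are 32-bit integers, so the largest mask index, `(Index + 1) * NumElts - 1`, must stay representable.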
llvm/test/Transforms/InstCombine/bitcast.ll

Lines changed: 4 additions & 6 deletions
```diff
@@ -484,9 +484,8 @@ define double @bitcast_extelt8(<1 x i64> %A) {
 
 define <2 x i64> @bitcast_extelt9(<8 x i32> %A) {
 ; CHECK-LABEL: @bitcast_extelt9(
-; CHECK-NEXT:    [[BC1:%.*]] = bitcast <8 x i32> [[A:%.*]] to <2 x i128>
-; CHECK-NEXT:    [[EXT:%.*]] = extractelement <2 x i128> [[BC1]], i64 1
-; CHECK-NEXT:    [[BC2:%.*]] = bitcast i128 [[EXT]] to <2 x i64>
+; CHECK-NEXT:    [[BC:%.*]] = bitcast <8 x i32> [[A:%.*]] to <4 x i64>
+; CHECK-NEXT:    [[BC2:%.*]] = shufflevector <4 x i64> [[BC]], <4 x i64> poison, <2 x i32> <i32 2, i32 3>
 ; CHECK-NEXT:    ret <2 x i64> [[BC2]]
 ;
   %bc1 = bitcast <8 x i32> %A to <2 x i128>
@@ -499,9 +498,8 @@ define <2 x i64> @bitcast_extelt9(<8 x i32> %A) {
 
 define <2 x i8> @bitcast_extelt10(<8 x i32> %A) {
 ; CHECK-LABEL: @bitcast_extelt10(
-; CHECK-NEXT:    [[BC1:%.*]] = bitcast <8 x i32> [[A:%.*]] to <16 x i16>
-; CHECK-NEXT:    [[EXT:%.*]] = extractelement <16 x i16> [[BC1]], i64 3
-; CHECK-NEXT:    [[BC2:%.*]] = bitcast i16 [[EXT]] to <2 x i8>
+; CHECK-NEXT:    [[BC:%.*]] = bitcast <8 x i32> [[A:%.*]] to <32 x i8>
+; CHECK-NEXT:    [[BC2:%.*]] = shufflevector <32 x i8> [[BC]], <32 x i8> poison, <2 x i32> <i32 6, i32 7>
 ; CHECK-NEXT:    ret <2 x i8> [[BC2]]
 ;
   %bc1 = bitcast <8 x i32> %A to <16 x i16>
```

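For the updated CHECK lines, the same `Index * NumElts` arithmetic predicts the expected masks; here is a minimal standalone check of the bitcast_extelt10 numbers (illustration only, not part of the test suite):

```cpp
#include <cassert>

int main() {
  // bitcast_extelt10: %A is <8 x i32> (256 bits), the wide view is <16 x i16>,
  // the extracted element index is 3, and the final type is <2 x i8>.
  const unsigned SrcVecWidth = 256; // bits in <8 x i32>
  const unsigned DstEltWidth = 8;   // bits per element of <2 x i8>
  const unsigned NumElts = 2;       // elements of <2 x i8>
  const unsigned Index = 3;         // extractelement index into <16 x i16>

  // The re-bitcast type has 256 / 8 = 32 lanes, i.e. the <32 x i8> in the
  // expected CHECK line.
  assert(SrcVecWidth / DstEltWidth == 32);

  // Mask starts at Index * NumElts and covers NumElts lanes: <i32 6, i32 7>.
  assert(Index * NumElts == 6 && Index * NumElts + NumElts - 1 == 7);
  return 0;
}
```

Note that the destination element width (i8), not the source element width (i32) or the wide view's i16, determines the re-bitcast vector type in the expected output.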