diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp index 1a95636f37ed7..d656dcc21ae1e 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -2380,6 +2380,51 @@ static Value *optimizeIntegerToVectorInsertions(BitCastInst &CI, return Result; } +/// If the input is (extractelement (bitcast X), Idx) and the source and +/// destination types are vectors, we are performing a vector extract from X. We +/// can replace the extractelement+bitcast with a shufflevector, avoiding the +/// final scalar->vector bitcast. This pattern is usually handled better by the +/// backend. +/// +/// Example: +/// %bc = bitcast <8 x i32> %X to <2 x i128> +/// %ext = extractelement <2 x i128> %bc, i64 1 +/// bitcast i128 %ext to <2 x i64> +/// ---> +/// %bc = bitcast <8 x i32> %X to <4 x i64> +/// shufflevector <4 x i64> %bc, <4 x i64> poison, <2 x i32> <i32 2, i32 3> +static Instruction *foldBitCastExtElt(BitCastInst &BitCast, + InstCombiner::BuilderTy &Builder) { + Value *X; + uint64_t Index; + if (!match( + BitCast.getOperand(0), + m_OneUse(m_ExtractElt(m_BitCast(m_Value(X)), m_ConstantInt(Index))))) + return nullptr; + + auto *SrcTy = dyn_cast<FixedVectorType>(X->getType()); + auto *DstTy = dyn_cast<FixedVectorType>(BitCast.getType()); + if (!SrcTy || !DstTy) + return nullptr; + + // Check if the mask indices would overflow.
+ unsigned NumElts = DstTy->getNumElements(); + if (Index > INT32_MAX || NumElts > INT32_MAX || + (Index + 1) * NumElts - 1 > INT32_MAX) + return nullptr; + + unsigned DstEltWidth = DstTy->getScalarSizeInBits(); + unsigned SrcVecWidth = SrcTy->getPrimitiveSizeInBits(); + assert((SrcVecWidth % DstEltWidth == 0) && "Invalid types."); + auto *NewVecTy = + FixedVectorType::get(DstTy->getElementType(), SrcVecWidth / DstEltWidth); + auto *NewBC = Builder.CreateBitCast(X, NewVecTy, "bc"); + + unsigned StartIdx = Index * NumElts; + auto Mask = llvm::to_vector<16>(llvm::seq<int>(StartIdx, StartIdx + NumElts)); + return new ShuffleVectorInst(NewBC, Mask); +} + /// Canonicalize scalar bitcasts of extracted elements into a bitcast of the /// vector followed by extract element. The backend tends to handle bitcasts of /// vectors better than bitcasts of scalars because vector registers are @@ -2866,6 +2911,9 @@ Instruction *InstCombinerImpl::visitBitCast(BitCastInst &CI) { if (Instruction *I = canonicalizeBitCastExtElt(CI, *this)) return I; + if (Instruction *I = foldBitCastExtElt(CI, Builder)) + return I; + if (Instruction *I = foldBitCastBitwiseLogic(CI, Builder)) return I; diff --git a/llvm/test/Transforms/InstCombine/bitcast.ll b/llvm/test/Transforms/InstCombine/bitcast.ll index 37d41de3e9991..cade44412341d 100644 --- a/llvm/test/Transforms/InstCombine/bitcast.ll +++ b/llvm/test/Transforms/InstCombine/bitcast.ll @@ -480,6 +480,34 @@ define double @bitcast_extelt8(<1 x i64> %A) { ret double %bc } +; Extract a subvector from a vector, extracted element wider than source.
+ +define <2 x i64> @bitcast_extelt9(<8 x i32> %A) { +; CHECK-LABEL: @bitcast_extelt9( +; CHECK-NEXT: [[BC:%.*]] = bitcast <8 x i32> [[A:%.*]] to <4 x i64> +; CHECK-NEXT: [[BC2:%.*]] = shufflevector <4 x i64> [[BC]], <4 x i64> poison, <2 x i32> <i32 2, i32 3> +; CHECK-NEXT: ret <2 x i64> [[BC2]] +; + %bc1 = bitcast <8 x i32> %A to <2 x i128> + %ext = extractelement <2 x i128> %bc1, i64 1 + %bc2 = bitcast i128 %ext to <2 x i64> + ret <2 x i64> %bc2 +} + +; Extract a subvector from a vector, extracted element narrower than source. + +define <2 x i8> @bitcast_extelt10(<8 x i32> %A) { +; CHECK-LABEL: @bitcast_extelt10( +; CHECK-NEXT: [[BC:%.*]] = bitcast <8 x i32> [[A:%.*]] to <32 x i8> +; CHECK-NEXT: [[BC2:%.*]] = shufflevector <32 x i8> [[BC]], <32 x i8> poison, <2 x i32> <i32 6, i32 7> +; CHECK-NEXT: ret <2 x i8> [[BC2]] +; + %bc1 = bitcast <8 x i32> %A to <16 x i16> + %ext = extractelement <16 x i16> %bc1, i64 3 + %bc2 = bitcast i16 %ext to <2 x i8> + ret <2 x i8> %bc2 +} + define <2 x i32> @test4(i32 %A, i32 %B){ ; CHECK-LABEL: @test4( ; CHECK-NEXT: [[TMP1:%.*]] = insertelement <2 x i32> poison, i32 [[A:%.*]], i64 0