diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp index f946c3856948b..2c7ecf8d60b77 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp @@ -542,27 +542,39 @@ Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) { } } } else if (auto *SVI = dyn_cast(I)) { - // If this is extracting an element from a shufflevector, figure out where - // it came from and extract from the appropriate input element instead. - // Restrict the following transformation to fixed-length vector. - if (isa(SVI->getType()) && isa(Index)) { - int SrcIdx = - SVI->getMaskValue(cast(Index)->getZExtValue()); - Value *Src; - unsigned LHSWidth = cast(SVI->getOperand(0)->getType()) - ->getNumElements(); - - if (SrcIdx < 0) - return replaceInstUsesWith(EI, PoisonValue::get(EI.getType())); - if (SrcIdx < (int)LHSWidth) - Src = SVI->getOperand(0); - else { - SrcIdx -= LHSWidth; - Src = SVI->getOperand(1); + int SplatIndex = getSplatIndex(SVI->getShuffleMask()); + // We know the all-0 splat must be reading from the first operand, even + // in the case of scalable vectors (vscale is always > 0). + if (SplatIndex == 0) + return ExtractElementInst::Create(SVI->getOperand(0), + Builder.getInt64(0)); + + if (isa(SVI->getType())) { + std::optional SrcIdx; + // getSplatIndex returns -1 to mean not-found. 
+ if (SplatIndex != -1) + SrcIdx = SplatIndex; + else if (ConstantInt *CI = dyn_cast(Index)) + SrcIdx = SVI->getMaskValue(CI->getZExtValue()); + + if (SrcIdx) { + Value *Src; + unsigned LHSWidth = + cast(SVI->getOperand(0)->getType()) + ->getNumElements(); + + if (*SrcIdx < 0) + return replaceInstUsesWith(EI, PoisonValue::get(EI.getType())); + if (*SrcIdx < (int)LHSWidth) + Src = SVI->getOperand(0); + else { + *SrcIdx -= LHSWidth; + Src = SVI->getOperand(1); + } + Type *Int64Ty = Type::getInt64Ty(EI.getContext()); + return ExtractElementInst::Create( + Src, ConstantInt::get(Int64Ty, *SrcIdx, false)); } - Type *Int64Ty = Type::getInt64Ty(EI.getContext()); - return ExtractElementInst::Create( - Src, ConstantInt::get(Int64Ty, SrcIdx, false)); } } else if (auto *CI = dyn_cast(I)) { // Canonicalize extractelement(cast) -> cast(extractelement). diff --git a/llvm/test/Transforms/InstCombine/vec_extract_through_broadcast.ll b/llvm/test/Transforms/InstCombine/vec_extract_through_broadcast.ll new file mode 100644 index 0000000000000..1ec3dc3e4b40a --- /dev/null +++ b/llvm/test/Transforms/InstCombine/vec_extract_through_broadcast.ll @@ -0,0 +1,52 @@ +; RUN: opt -passes=instcombine -S < %s | FileCheck %s + +define float @extract_from_zero_init_shuffle(<2 x float> %vec, i64 %idx) { +; CHECK-LABEL: @extract_from_zero_init_shuffle( +; CHECK-NEXT: %extract = extractelement <2 x float> %vec, i64 0 +; CHECK-NEXT: ret float %extract +; + %shuffle = shufflevector <2 x float> %vec, <2 x float> poison, <4 x i32> zeroinitializer + %extract = extractelement <4 x float> %shuffle, i64 %idx + ret float %extract +} + + +define float @extract_from_general_splat(<2 x float> %vec, i64 %idx) { +; CHECK-LABEL: @extract_from_general_splat( +; CHECK-NEXT: %extract = extractelement <2 x float> %vec, i64 1 +; CHECK-NEXT: ret float %extract +; + %shuffle = shufflevector <2 x float> %vec, <2 x float> poison, <4 x i32> + %extract = extractelement <4 x float> %shuffle, i64 %idx + ret float %extract +} + 
+define float @extract_from_general_scalable_splat( %vec, i64 %idx) { +; CHECK-LABEL: @extract_from_general_scalable_splat( +; CHECK-NEXT: %extract = extractelement %vec, i64 0 +; CHECK-NEXT: ret float %extract +; + %shuffle = shufflevector %vec, poison, zeroinitializer + %extract = extractelement %shuffle, i64 %idx + ret float %extract +} + +define float @extract_from_splat_with_poison_0(<2 x float> %vec, i64 %idx) { +; CHECK-LABEL: @extract_from_splat_with_poison_0( +; CHECK-NEXT: %extract = extractelement <2 x float> %vec, i64 1 +; CHECK-NEXT: ret float %extract +; + %shuffle = shufflevector <2 x float> %vec, <2 x float> poison, <4 x i32> + %extract = extractelement <4 x float> %shuffle, i64 %idx + ret float %extract +} + +define float @extract_from_splat_with_poison_1(<2 x float> %vec, i64 %idx) { +; CHECK-LABEL: @extract_from_splat_with_poison_1( +; CHECK-NEXT: %extract = extractelement <2 x float> %vec, i64 1 +; CHECK-NEXT: ret float %extract +; + %shuffle = shufflevector <2 x float> %vec, <2 x float> poison, <4 x i32> + %extract = extractelement <4 x float> %shuffle, i64 %idx + ret float %extract +} diff --git a/llvm/test/Transforms/InstCombine/vec_shuffle-inseltpoison.ll b/llvm/test/Transforms/InstCombine/vec_shuffle-inseltpoison.ll index 9aa050e8cd500..cc8ecd9aefb1c 100644 --- a/llvm/test/Transforms/InstCombine/vec_shuffle-inseltpoison.ll +++ b/llvm/test/Transforms/InstCombine/vec_shuffle-inseltpoison.ll @@ -61,8 +61,7 @@ define float @test6(<4 x float> %X) { define float @testvscale6( %X) { ; CHECK-LABEL: @testvscale6( -; CHECK-NEXT: [[T2:%.*]] = shufflevector [[X:%.*]], poison, zeroinitializer -; CHECK-NEXT: [[R:%.*]] = extractelement [[T2]], i64 0 +; CHECK-NEXT: [[R:%.*]] = extractelement [[X:%.*]], i64 0 ; CHECK-NEXT: ret float [[R]] ; %X1 = bitcast %X to diff --git a/llvm/test/Transforms/InstCombine/vec_shuffle.ll b/llvm/test/Transforms/InstCombine/vec_shuffle.ll index 83919e743d384..f4ee0e7f2eb95 100644 --- 
a/llvm/test/Transforms/InstCombine/vec_shuffle.ll +++ b/llvm/test/Transforms/InstCombine/vec_shuffle.ll @@ -67,8 +67,7 @@ define float @test6(<4 x float> %X) { define float @testvscale6( %X) { ; CHECK-LABEL: @testvscale6( -; CHECK-NEXT: [[T2:%.*]] = shufflevector [[X:%.*]], poison, zeroinitializer -; CHECK-NEXT: [[R:%.*]] = extractelement [[T2]], i64 0 +; CHECK-NEXT: [[R:%.*]] = extractelement [[X:%.*]], i64 0 ; CHECK-NEXT: ret float [[R]] ; %X1 = bitcast %X to diff --git a/llvm/test/Transforms/InstCombine/vscale_extractelement-inseltpoison.ll b/llvm/test/Transforms/InstCombine/vscale_extractelement-inseltpoison.ll index 2655c20354607..36ed39a3d3242 100644 --- a/llvm/test/Transforms/InstCombine/vscale_extractelement-inseltpoison.ll +++ b/llvm/test/Transforms/InstCombine/vscale_extractelement-inseltpoison.ll @@ -89,12 +89,12 @@ define i8 @extractelement_bitcast_insert_extra_use_bitcast( %a ret i8 %r } +; While the extract index may be out of bounds, any in-bounds index +; yields %v, because the all-zeros mask splats element 0 of the source. 
+ define i32 @extractelement_shuffle_maybe_out_of_range(i32 %v) { ; CHECK-LABEL: @extractelement_shuffle_maybe_out_of_range( -; CHECK-NEXT: [[IN:%.*]] = insertelement poison, i32 [[V:%.*]], i64 0 -; CHECK-NEXT: [[SPLAT:%.*]] = shufflevector [[IN]], poison, zeroinitializer -; CHECK-NEXT: [[R:%.*]] = extractelement [[SPLAT]], i64 4 -; CHECK-NEXT: ret i32 [[R]] +; CHECK-NEXT: ret i32 [[V:%.*]] ; %in = insertelement poison, i32 %v, i32 0 %splat = shufflevector %in, poison, zeroinitializer @@ -104,10 +104,7 @@ define i32 @extractelement_shuffle_maybe_out_of_range(i32 %v) { define i32 @extractelement_shuffle_invalid_index(i32 %v) { ; CHECK-LABEL: @extractelement_shuffle_invalid_index( -; CHECK-NEXT: [[IN:%.*]] = insertelement poison, i32 [[V:%.*]], i64 0 -; CHECK-NEXT: [[SPLAT:%.*]] = shufflevector [[IN]], poison, zeroinitializer -; CHECK-NEXT: [[R:%.*]] = extractelement [[SPLAT]], i64 4294967295 -; CHECK-NEXT: ret i32 [[R]] +; CHECK-NEXT: ret i32 [[V:%.*]] ; %in = insertelement poison, i32 %v, i32 0 %splat = shufflevector %in, poison, zeroinitializer diff --git a/llvm/test/Transforms/InstCombine/vscale_extractelement.ll b/llvm/test/Transforms/InstCombine/vscale_extractelement.ll index 07090e9099ae1..9ac8a92abb689 100644 --- a/llvm/test/Transforms/InstCombine/vscale_extractelement.ll +++ b/llvm/test/Transforms/InstCombine/vscale_extractelement.ll @@ -53,12 +53,12 @@ define i8 @extractelement_bitcast_useless_insert( %a, i32 %x) ret i8 %r } +; While the extract index in these tests may be out of bounds, any +; in-bounds index yields %v, because the all-zeros mask splats element 0. 
+ define i32 @extractelement_shuffle_maybe_out_of_range(i32 %v) { ; CHECK-LABEL: @extractelement_shuffle_maybe_out_of_range( -; CHECK-NEXT: [[IN:%.*]] = insertelement undef, i32 [[V:%.*]], i64 0 -; CHECK-NEXT: [[SPLAT:%.*]] = shufflevector [[IN]], poison, zeroinitializer -; CHECK-NEXT: [[R:%.*]] = extractelement [[SPLAT]], i64 4 -; CHECK-NEXT: ret i32 [[R]] +; CHECK-NEXT: ret i32 [[V:%.*]] ; %in = insertelement undef, i32 %v, i32 0 %splat = shufflevector %in, undef, zeroinitializer @@ -68,10 +68,7 @@ define i32 @extractelement_shuffle_maybe_out_of_range(i32 %v) { define i32 @extractelement_shuffle_invalid_index(i32 %v) { ; CHECK-LABEL: @extractelement_shuffle_invalid_index( -; CHECK-NEXT: [[IN:%.*]] = insertelement undef, i32 [[V:%.*]], i64 0 -; CHECK-NEXT: [[SPLAT:%.*]] = shufflevector [[IN]], poison, zeroinitializer -; CHECK-NEXT: [[R:%.*]] = extractelement [[SPLAT]], i64 4294967295 -; CHECK-NEXT: ret i32 [[R]] +; CHECK-NEXT: ret i32 [[V:%.*]] ; %in = insertelement undef, i32 %v, i32 0 %splat = shufflevector %in, undef, zeroinitializer