Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 35 additions & 20 deletions llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -542,27 +542,42 @@ Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) {
}
}
} else if (auto *SVI = dyn_cast<ShuffleVectorInst>(I)) {
// If this is extracting an element from a shufflevector, figure out where
// it came from and extract from the appropriate input element instead.
// Restrict the following transformation to fixed-length vector.
if (isa<FixedVectorType>(SVI->getType()) && isa<ConstantInt>(Index)) {
int SrcIdx =
SVI->getMaskValue(cast<ConstantInt>(Index)->getZExtValue());
Value *Src;
unsigned LHSWidth = cast<FixedVectorType>(SVI->getOperand(0)->getType())
->getNumElements();

if (SrcIdx < 0)
return replaceInstUsesWith(EI, PoisonValue::get(EI.getType()));
if (SrcIdx < (int)LHSWidth)
Src = SVI->getOperand(0);
else {
SrcIdx -= LHSWidth;
Src = SVI->getOperand(1);
int SplatIndex = getSplatIndex(SVI->getShuffleMask());
// We know such a splat must be reading from the first operand, even
// in the case of scalable vectors (vscale is always > 0).
if (SplatIndex == 0)
return ExtractElementInst::Create(SVI->getOperand(0),
Builder.getInt64(0));
// Restrict the non-zero index case to fixed-length vectors
if (isa<FixedVectorType>(SVI->getType())) {

// getSplatIndex doesn't distinguish between the all-poison splat and
// a non-splat mask. However, if Index is -1, we still want to propagate
// that poison value.
int SrcIdx = -2;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not entirely happy with this. I've also considered either a tiny helper function like bool test(int* outParam), or carrying along bool ValidSrcIdx, or rewriting getSplatIndex to be able to meaningfully return the all-poison splat indicator, and this seemed the best option.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use std::optional<int> SrcIdx instead?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@agorenstein-nvidia did std::optional not help simplify this logic?

Copy link
Contributor Author

@agorenstein-nvidia agorenstein-nvidia Jun 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestion and ping! I was out from Thursday until yesterday [edit: inclusive]; this remains top of my "priority queue" and I'm integrating your suggestions now. Provisionally it looks like that exactly captures the intent, and clearly I hadn't considered this approach before. I'll finish putting this change in, and your test-improvement-suggestions (and run the tests/linter!) and push ASAP.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This has been done in the latest push, please take a look when you get the chance. It passes my local versions of git-clang-format and the lit tests, so hopefully that carries through to the CI here.

if (SplatIndex != PoisonMaskElem)
SrcIdx = SplatIndex;
else if (ConstantInt *CI = dyn_cast<ConstantInt>(Index))
SrcIdx = SVI->getMaskValue(CI->getZExtValue());

if (SrcIdx != -2) {
Value *Src;
unsigned LHSWidth =
cast<FixedVectorType>(SVI->getOperand(0)->getType())
->getNumElements();

if (SrcIdx < 0)
return replaceInstUsesWith(EI, PoisonValue::get(EI.getType()));
if (SrcIdx < (int)LHSWidth)
Src = SVI->getOperand(0);
else {
SrcIdx -= LHSWidth;
Src = SVI->getOperand(1);
}
Type *Int64Ty = Type::getInt64Ty(EI.getContext());
return ExtractElementInst::Create(
Src, ConstantInt::get(Int64Ty, SrcIdx, false));
}
Type *Int64Ty = Type::getInt64Ty(EI.getContext());
return ExtractElementInst::Create(
Src, ConstantInt::get(Int64Ty, SrcIdx, false));
}
} else if (auto *CI = dyn_cast<CastInst>(I)) {
// Canonicalize extractelement(cast) -> cast(extractelement).
Expand Down
52 changes: 52 additions & 0 deletions llvm/test/Transforms/InstCombine/vec_extract_through_broadcast.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
; RUN: opt -passes=instcombine -S < %s | FileCheck %s

define float @extract_from_zero_init_shuffle(<2 x float> %1, i64 %idx) {
; CHECK-LABEL: @extract_from_zero_init_shuffle(
; CHECK-NEXT: [[TMP1:%.*]] = extractelement <2 x float> [[W:%.*]], i64 0
; CHECK-NEXT: ret float [[TMP1]]
;
%3 = shufflevector <2 x float> %1, <2 x float> poison, <4 x i32> zeroinitializer
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(style) avoid using numbers for variable names (%vec, %shuffle and %extract would be better)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in latest push, please take a look. Thanks.

%4 = extractelement <4 x float> %3, i64 %idx
ret float %4
}


define float @extract_from_general_splat(<2 x float> %1, i64 %idx) {
; CHECK-LABEL: @extract_from_general_splat(
; CHECK-NEXT: [[TMP1:%.*]] = extractelement <2 x float> [[W:%.*]], i64 1
; CHECK-NEXT: ret float [[TMP1]]
;
%3 = shufflevector <2 x float> %1, <2 x float> poison, <4 x i32> <i32 1, i32 1, i32 1, i32 1>
%4 = extractelement <4 x float> %3, i64 %idx
ret float %4
}

define float @extract_from_general_scalable_splat(<vscale x 2 x float> %1, i64 %idx) {
; CHECK-LABEL: @extract_from_general_scalable_splat(
; CHECK-NEXT: [[TMP1:%.*]] = extractelement <vscale x 2 x float> [[W:%.*]], i64 0
; CHECK-NEXT: ret float [[TMP1]]
;
%3 = shufflevector <vscale x 2 x float> %1, <vscale x 2 x float> poison, <vscale x 4 x i32> zeroinitializer
%4 = extractelement <vscale x 4 x float> %3, i64 %idx
ret float %4
}

define float @extract_from_splat_with_poison_0(<2 x float> %1, i64 %idx) {
; CHECK-LABEL: @extract_from_splat_with_poison_0(
; CHECK-NEXT: [[TMP2:%.*]] = extractelement <2 x float> [[TMP1:%.*]], i64 1
; CHECK-NEXT: ret float [[TMP2]]
;
%3 = shufflevector <2 x float> %1, <2 x float> poison, <4 x i32> <i32 poison, i32 1, i32 1, i32 1>
%4 = extractelement <4 x float> %3, i64 %idx
ret float %4
}

define float @extract_from_splat_with_poison_1(<2 x float> %1, i64 %idx) {
; CHECK-LABEL: @extract_from_splat_with_poison_1(
; CHECK-NEXT: [[TMP2:%.*]] = extractelement <2 x float> [[TMP1:%.*]], i64 1
; CHECK-NEXT: ret float [[TMP2]]
;
%3 = shufflevector <2 x float> %1, <2 x float> poison, <4 x i32> <i32 1, i32 poison, i32 1, i32 1>
%4 = extractelement <4 x float> %3, i64 %idx
ret float %4
}
3 changes: 1 addition & 2 deletions llvm/test/Transforms/InstCombine/vec_shuffle-inseltpoison.ll
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,7 @@ define float @test6(<4 x float> %X) {

define float @testvscale6(<vscale x 4 x float> %X) {
; CHECK-LABEL: @testvscale6(
; CHECK-NEXT: [[T2:%.*]] = shufflevector <vscale x 4 x float> [[X:%.*]], <vscale x 4 x float> poison, <vscale x 4 x i32> zeroinitializer
; CHECK-NEXT: [[R:%.*]] = extractelement <vscale x 4 x float> [[T2]], i64 0
; CHECK-NEXT: [[R:%.*]] = extractelement <vscale x 4 x float> [[X:%.*]], i64 0
; CHECK-NEXT: ret float [[R]]
;
%X1 = bitcast <vscale x 4 x float> %X to <vscale x 4 x i32>
Expand Down
3 changes: 1 addition & 2 deletions llvm/test/Transforms/InstCombine/vec_shuffle.ll
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,7 @@ define float @test6(<4 x float> %X) {

define float @testvscale6(<vscale x 4 x float> %X) {
; CHECK-LABEL: @testvscale6(
; CHECK-NEXT: [[T2:%.*]] = shufflevector <vscale x 4 x float> [[X:%.*]], <vscale x 4 x float> poison, <vscale x 4 x i32> zeroinitializer
; CHECK-NEXT: [[R:%.*]] = extractelement <vscale x 4 x float> [[T2]], i64 0
; CHECK-NEXT: [[R:%.*]] = extractelement <vscale x 4 x float> [[X:%.*]], i64 0
; CHECK-NEXT: ret float [[R]]
;
%X1 = bitcast <vscale x 4 x float> %X to <vscale x 4 x i32>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,7 @@ define i8 @extractelement_bitcast_insert_extra_use_bitcast(<vscale x 2 x i32> %a

define i32 @extractelement_shuffle_maybe_out_of_range(i32 %v) {
; CHECK-LABEL: @extractelement_shuffle_maybe_out_of_range(
; CHECK-NEXT: [[IN:%.*]] = insertelement <vscale x 4 x i32> poison, i32 [[V:%.*]], i64 0
; CHECK-NEXT: [[SPLAT:%.*]] = shufflevector <vscale x 4 x i32> [[IN]], <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
; CHECK-NEXT: [[R:%.*]] = extractelement <vscale x 4 x i32> [[SPLAT]], i64 4
; CHECK-NEXT: ret i32 [[R]]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is analogous to the other comment; it may be that the extractelement ... i32 4 is out of bounds (when vscale=1, so to speak). However, we are able to know that any valid index into %splat must be %v, so we just return %v. So, similar questions as the other comment.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

might be useful to add a comment above the test explaining the reasoning,similar to your comment here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did so! Thanks.

; CHECK-NEXT: ret i32 [[V:%.*]]
;
%in = insertelement <vscale x 4 x i32> poison, i32 %v, i32 0
%splat = shufflevector <vscale x 4 x i32> %in, <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
Expand All @@ -104,10 +101,7 @@ define i32 @extractelement_shuffle_maybe_out_of_range(i32 %v) {

define i32 @extractelement_shuffle_invalid_index(i32 %v) {
; CHECK-LABEL: @extractelement_shuffle_invalid_index(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here we optimize, but also in doing so miss the poison "opportunity" (in that the illegal -1 index-usage is removed). So I believe this is legal, but perhaps not optimal. However, I'm a bit confused: it is the case that earlier in our transformation we explicitly look for invalid indices in extractelements:

    // InstSimplify should handle cases where the index is invalid.
    // For fixed-length vector, it's invalid to extract out-of-range element.

Without tracing through, it looks like this may "just" be an order-of-transformations issue. I'm inclined to preserve the behavior. I've simply changed the test to reflect the new output; is there a preferred way of updating/changing these tests in this sort of situation?

; CHECK-NEXT: [[IN:%.*]] = insertelement <vscale x 4 x i32> poison, i32 [[V:%.*]], i64 0
; CHECK-NEXT: [[SPLAT:%.*]] = shufflevector <vscale x 4 x i32> [[IN]], <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
; CHECK-NEXT: [[R:%.*]] = extractelement <vscale x 4 x i32> [[SPLAT]], i64 4294967295
; CHECK-NEXT: ret i32 [[R]]
; CHECK-NEXT: ret i32 [[V:%.*]]
;
%in = insertelement <vscale x 4 x i32> poison, i32 %v, i32 0
%splat = shufflevector <vscale x 4 x i32> %in, <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
Expand Down
10 changes: 2 additions & 8 deletions llvm/test/Transforms/InstCombine/vscale_extractelement.ll
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,7 @@ define i8 @extractelement_bitcast_useless_insert(<vscale x 2 x i32> %a, i32 %x)

define i32 @extractelement_shuffle_maybe_out_of_range(i32 %v) {
; CHECK-LABEL: @extractelement_shuffle_maybe_out_of_range(
; CHECK-NEXT: [[IN:%.*]] = insertelement <vscale x 4 x i32> undef, i32 [[V:%.*]], i64 0
; CHECK-NEXT: [[SPLAT:%.*]] = shufflevector <vscale x 4 x i32> [[IN]], <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
; CHECK-NEXT: [[R:%.*]] = extractelement <vscale x 4 x i32> [[SPLAT]], i64 4
; CHECK-NEXT: ret i32 [[R]]
; CHECK-NEXT: ret i32 [[V:%.*]]
;
%in = insertelement <vscale x 4 x i32> undef, i32 %v, i32 0
%splat = shufflevector <vscale x 4 x i32> %in, <vscale x 4 x i32> undef, <vscale x 4 x i32> zeroinitializer
Expand All @@ -68,10 +65,7 @@ define i32 @extractelement_shuffle_maybe_out_of_range(i32 %v) {

define i32 @extractelement_shuffle_invalid_index(i32 %v) {
; CHECK-LABEL: @extractelement_shuffle_invalid_index(
; CHECK-NEXT: [[IN:%.*]] = insertelement <vscale x 4 x i32> undef, i32 [[V:%.*]], i64 0
; CHECK-NEXT: [[SPLAT:%.*]] = shufflevector <vscale x 4 x i32> [[IN]], <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
; CHECK-NEXT: [[R:%.*]] = extractelement <vscale x 4 x i32> [[SPLAT]], i64 4294967295
; CHECK-NEXT: ret i32 [[R]]
; CHECK-NEXT: ret i32 [[V:%.*]]
;
%in = insertelement <vscale x 4 x i32> undef, i32 %v, i32 0
%splat = shufflevector <vscale x 4 x i32> %in, <vscale x 4 x i32> undef, <vscale x 4 x i32> zeroinitializer
Expand Down
Loading