Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 22 additions & 14 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3938,42 +3938,50 @@ SDValue DAGTypeLegalizer::SplitVecOp_EXTRACT_SUBVECTOR(SDNode *N) {

GetSplitVector(N->getOperand(0), Lo, Hi);

uint64_t LoEltsMin = Lo.getValueType().getVectorMinNumElements();
uint64_t IdxVal = Idx->getAsZExtVal();
ElementCount LoElts = Lo.getValueType().getVectorElementCount();
ElementCount IdxVal =
ElementCount::get(Idx->getAsZExtVal(), SubVT.isScalableVector());
uint64_t IdxValMin = IdxVal.getKnownMinValue();

unsigned NumResultElts = SubVT.getVectorMinNumElements();
EVT SrcVT = N->getOperand(0).getValueType();
ElementCount NumResultElts = SubVT.getVectorElementCount();

if (IdxVal < LoEltsMin) {
// If the extracted elements are all in the low half, do a simple extract.
if (IdxVal + NumResultElts <= LoEltsMin)
return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, SubVT, Lo, Idx);
// If the extracted elements are all in the low half, do a simple extract.
if (ElementCount::isKnownLE(IdxVal + NumResultElts, LoElts))
return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, SubVT, Lo, Idx);

unsigned LoEltsMin = LoElts.getKnownMinValue();
if (IdxValMin < LoEltsMin &&
!(SubVT.isScalableVector() || SrcVT.isScalableVector())) {
// Extracted subvector crosses vector split, so we need to blend the two
// halves.
// TODO: May be able to emit partial extract_subvector.
SmallVector<SDValue, 8> Elts;
Elts.reserve(NumResultElts);
Elts.reserve(NumResultElts.getFixedValue());

DAG.ExtractVectorElements(Lo, Elts, /*Start=*/IdxVal,
/*Count=*/LoEltsMin - IdxVal);
// This is not valid for scalable vectors. If SubVT is scalable, this is the
// same as unrolling a scalable dimension (invalid). If SrcVT is scalable,
// `Lo[LoEltsMin]` may not be the last element of `Lo`.
DAG.ExtractVectorElements(Lo, Elts, /*Start=*/IdxValMin,
/*Count=*/LoEltsMin - IdxValMin);
DAG.ExtractVectorElements(Hi, Elts, /*Start=*/0,
/*Count=*/SubVT.getVectorNumElements() -
Elts.size());
return DAG.getBuildVector(SubVT, dl, Elts);
}

EVT SrcVT = N->getOperand(0).getValueType();
if (SubVT.isScalableVector() == SrcVT.isScalableVector()) {
uint64_t ExtractIdx = IdxVal - LoEltsMin;
if (ExtractIdx % NumResultElts == 0)
uint64_t ExtractIdx = IdxValMin - LoEltsMin;
unsigned NumResultEltsMin = NumResultElts.getKnownMinValue();
if (ExtractIdx % NumResultEltsMin == 0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think ElementCount has a nicer way to check for divisibility in it?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. I've also added an assertion that the case below this only runs for fixed-length vectors (as it's also invalid for scalable vectors, but I don't think it's reachable in that case).

return DAG.getExtractSubvector(dl, SubVT, Hi, ExtractIdx);

// We cannot create an extract_subvector that isn't a multiple of the result
// size, which may go out of bounds for the last elements. Shuffle the
// desired elements down to 0 and do a simple 0 extract.
EVT HiVT = Hi.getValueType();
SmallVector<int, 8> Mask(HiVT.getVectorNumElements(), -1);
for (int I = 0; I != static_cast<int>(NumResultElts); ++I)
for (int I = 0; I != static_cast<int>(NumResultEltsMin); ++I)
Mask[I] = ExtractIdx + I;

SDValue Shuffle =
Expand Down
25 changes: 25 additions & 0 deletions llvm/test/CodeGen/AArch64/sve-fixed-vector-extract-256-bits.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve -aarch64-sve-vector-bits-min=256 -aarch64-sve-vector-bits-max=256 < %s -o - | FileCheck %s

; Note: This test case is reduced from: https://github.com/llvm/llvm-project/pull/166748#issuecomment-3600498185

; Extracts the fixed-length <8 x i32> subvector starting at element 0 of the
; scalable <vscale x 8 x i32> argument and returns the integer add-reduction
; of those eight lanes.
; NOTE(review): presumably this exercises type legalization of an
; extract_subvector whose source vector operand must be split - confirm
; against the change this test accompanies.
define i32 @test_extract_v8i32_from_nxv8i32(<vscale x 8 x i32> %vec) {
; CHECK-LABEL: test_extract_v8i32_from_nxv8i32:
; CHECK: // %bb.0:
; CHECK-NEXT: str x29, [sp, #-16]! // 8-byte Folded Spill
; CHECK-NEXT: addvl sp, sp, #-2
; CHECK-NEXT: .cfi_escape 0x0f, 0x08, 0x8f, 0x10, 0x92, 0x2e, 0x00, 0x40, 0x1e, 0x22 // sp + 16 + 16 * VG
; CHECK-NEXT: .cfi_offset w29, -16
; CHECK-NEXT: str z0, [sp]
; CHECK-NEXT: ptrue p0.s
; CHECK-NEXT: ldr z0, [sp]
; CHECK-NEXT: str z1, [sp, #1, mul vl]
; CHECK-NEXT: uaddv d0, p0, z0.s
; CHECK-NEXT: fmov w0, s0
; CHECK-NEXT: addvl sp, sp, #2
; CHECK-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
; CHECK-NEXT: ret
%1 = tail call <8 x i32> @llvm.vector.extract.v8i32.nxv8i32(<vscale x 8 x i32> %vec, i64 0)
%2 = tail call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %1)
ret i32 %2
}
Loading