Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 32 additions & 6 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7338,16 +7338,23 @@ SDValue SelectionDAG::FoldConstantArithmetic(unsigned Opcode, const SDLoc &DL,
Op.getValueType().getVectorElementCount() == NumElts;
};

auto IsBuildVectorSplatVectorOrUndef = [](const SDValue &Op) {
// UNDEF: folds to undef
// BUILD_VECTOR: may have constant elements
// SPLAT_VECTOR: could be a splat of a constant
// INSERT_SUBVECTOR: could be inserting a constant splat into an undef vector
// - This pattern occurs when a fixed-length vector splat is inserted into
// a scalable vector
auto VectorOpMayConstantFold = [](const SDValue &Op) {
return Op.isUndef() || Op.getOpcode() == ISD::CONDCODE ||
Op.getOpcode() == ISD::BUILD_VECTOR ||
Op.getOpcode() == ISD::SPLAT_VECTOR;
Op.getOpcode() == ISD::SPLAT_VECTOR ||
Op.getOpcode() == ISD::INSERT_SUBVECTOR;
};

// Every operand must be UNDEF, a CONDCODE, a BUILD_VECTOR, a SPLAT_VECTOR or
// an INSERT_SUBVECTOR, and must be either a scalar or a vector with the same
// number of elements as the result type.
if (!llvm::all_of(Ops, IsBuildVectorSplatVectorOrUndef) ||
if (!llvm::all_of(Ops, VectorOpMayConstantFold) ||
!llvm::all_of(Ops, IsScalarOrSameVectorSize))
return SDValue();

Expand All @@ -7374,23 +7381,42 @@ SDValue SelectionDAG::FoldConstantArithmetic(unsigned Opcode, const SDLoc &DL,
// a combination of BUILD_VECTOR and SPLAT_VECTOR.
unsigned NumVectorElts = NumElts.isScalable() ? 1 : NumElts.getFixedValue();

// Preprocess insert_subvector to avoid repeatedly matching the splat.
SmallVector<SDValue, 4> PreprocessedOps;
for (SDValue Op : Ops) {
if (Op.getOpcode() == ISD::INSERT_SUBVECTOR) {
// match: `insert_subvector undef, (splat X), N2` as `splat X`
SDValue N0 = Op.getOperand(0);
auto *BV = dyn_cast<BuildVectorSDNode>(Op.getOperand(1));
if (!N0.isUndef() || !BV || !(Op = BV->getSplatValue()))
return SDValue();
}
PreprocessedOps.push_back(Op);
}

// Constant fold each scalar lane separately.
SmallVector<SDValue, 4> ScalarResults;
for (unsigned I = 0; I != NumVectorElts; I++) {
SmallVector<SDValue, 4> ScalarOps;
for (SDValue Op : Ops) {
for (SDValue Op : PreprocessedOps) {
EVT InSVT = Op.getValueType().getScalarType();
if (Op.getOpcode() != ISD::BUILD_VECTOR &&
Op.getOpcode() != ISD::SPLAT_VECTOR) {
Op.getOpcode() != ISD::SPLAT_VECTOR &&
Op.getOpcode() != ISD::INSERT_SUBVECTOR) {
if (Op.isUndef())
ScalarOps.push_back(getUNDEF(InSVT));
else
ScalarOps.push_back(Op);
continue;
}

// insert_subvector has been preprocessed, so if it was of the form
// `insert_subvector undef, (splat X), N2`, it has been replaced with the
// splat value (X).
SDValue ScalarOp =
Op.getOperand(Op.getOpcode() == ISD::SPLAT_VECTOR ? 0 : I);
Op.getOpcode() == ISD::INSERT_SUBVECTOR
? Op
: Op.getOperand(Op.getOpcode() == ISD::SPLAT_VECTOR ? 0 : I);
EVT ScalarVT = ScalarOp.getValueType();

// Build vector (integer) scalar operands may need implicit
Expand Down
46 changes: 46 additions & 0 deletions llvm/test/CodeGen/AArch64/fixed-subvector-insert-into-scalable.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve < %s | FileCheck %s

; Inserts the fixed-length splat <4 x i32> of 9 at index 0 of a poison
; <vscale x 4 x i32> vector, then unsigned-divides the result by a splat of 3.
; The CHECK lines expect no divide in the output: only a `mov` of the
; immediate 3 (= 9 / 3) into z0.s followed by `ret`.
define <vscale x 4 x i32> @insert_div() {
; CHECK-LABEL: insert_div:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: mov z0.s, #3 // =0x3
; CHECK-NEXT: ret
entry:
%0 = tail call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> poison, <4 x i32> splat (i32 9), i64 0)
%div = udiv <vscale x 4 x i32> %0, splat (i32 3)
ret <vscale x 4 x i32> %div
}

; Inserts the fixed-length splat <4 x i32> of 1 at index 0 of a poison
; <vscale x 4 x i32> vector, then multiplies the result by a splat of 7.
; The CHECK lines expect no multiply in the output: only a `mov` of the
; immediate 7 (= 1 * 7) into z0.s followed by `ret`.
define <vscale x 4 x i32> @insert_mul() {
; CHECK-LABEL: insert_mul:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: mov z0.s, #7 // =0x7
; CHECK-NEXT: ret
entry:
%0 = tail call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> poison, <4 x i32> splat (i32 1), i64 0)
%mul = mul <vscale x 4 x i32> %0, splat (i32 7)
ret <vscale x 4 x i32> %mul
}

; Inserts the fixed-length splat <4 x i32> of 5 at index 0 of a poison
; <vscale x 4 x i32> vector, then adds a splat of 11 to the result.
; The CHECK lines expect no add in the output: only a `mov` of the
; immediate 16 (= 5 + 11) into z0.s followed by `ret`.
define <vscale x 4 x i32> @insert_add() {
; CHECK-LABEL: insert_add:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: mov z0.s, #16 // =0x10
; CHECK-NEXT: ret
entry:
%0 = tail call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> poison, <4 x i32> splat (i32 5), i64 0)
%add = add <vscale x 4 x i32> %0, splat (i32 11)
ret <vscale x 4 x i32> %add
}

; Inserts the fixed-length splat <4 x i32> of 11 at index 0 of a poison
; <vscale x 4 x i32> vector, then adds a splat of -11 to the result.
; NOTE: despite the function and value names, the subtraction is written as
; an `add` of the negated constant rather than a `sub` instruction.
; The CHECK lines expect no arithmetic in the output: only a `movi` of zero
; (= 11 + -11) followed by `ret`.
define <vscale x 4 x i32> @insert_sub() {
; CHECK-LABEL: insert_sub:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: movi v0.2d, #0000000000000000
; CHECK-NEXT: ret
entry:
%0 = tail call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> poison, <4 x i32> splat (i32 11), i64 0)
%sub = add <vscale x 4 x i32> %0, splat (i32 -11)
ret <vscale x 4 x i32> %sub
}
Loading