Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 29 additions & 2 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7008,10 +7008,34 @@ SDValue DAGTypeLegalizer::WidenVecRes_PARTIAL_REDUCE_MLA(SDNode *N) {
EVT VT = N->getValueType(0);
EVT WideAccVT = TLI.getTypeToTransformTo(*DAG.getContext(),
N->getOperand(0).getValueType());
SDValue Zero = DAG.getConstant(0, DL, WideAccVT);
ElementCount WideAccEC = WideAccVT.getVectorElementCount();

// Widen mul-operands if needed, otherwise we'll end up with a
// node that isn't legal because the accumulator vector will not
// be a known multiple of the input vector.
SDValue MulOp1 = N->getOperand(1);
SDValue MulOp2 = N->getOperand(2);
SDValue Acc = DAG.getInsertSubvector(DL, Zero, N->getOperand(0), 0);
EVT MulOpVT = MulOp1.getValueType();
ElementCount MulOpEC = MulOpVT.getVectorElementCount();
if (getTypeAction(MulOpVT) == TargetLowering::TypeWidenVector) {
EVT WideMulVT = GetWidenedVector(MulOp1).getValueType();
assert(WideMulVT.getVectorElementCount().isKnownMultipleOf(WideAccEC) &&
"Widening to a vector with less elements than accumulator?");
SDValue Zero = DAG.getConstant(0, DL, WideMulVT);
MulOp1 = DAG.getInsertSubvector(DL, Zero, MulOp1, 0);
MulOp2 = DAG.getInsertSubvector(DL, Zero, MulOp2, 0);
} else if (!MulOpEC.isKnownMultipleOf(WideAccEC)) {
assert(getTypeAction(MulOpVT) != TargetLowering::TypeLegal &&
"Expected Mul operands to need legalisation");
EVT WideMulVT = EVT::getVectorVT(*DAG.getContext(),
MulOpVT.getVectorElementType(), WideAccEC);
SDValue Zero = DAG.getConstant(0, DL, WideMulVT);
MulOp1 = DAG.getInsertSubvector(DL, Zero, MulOp1, 0);
MulOp2 = DAG.getInsertSubvector(DL, Zero, MulOp2, 0);
}
Comment on lines +7020 to +7035
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like the only difference between these two blocks is their assertions and how they assign `WideMulVT`, so they could share the zero-vector construction and subvector insertion as:

bool NeedsWidening = getTypeAction(MulOpVT) == TargetLowering::TypeWidenVector;
bool NarrowMultipleOfWide = MulOpEC.isKnownMultipleOf(WideAccEC);

if (NeedsWidening || !NarrowMultipleOfWide) {
    EVT WideMulVT;
    if (NeedsWidening) {
        assert(...)
        ...
    } else {
        assert(...)
        ...
    }
    SDValue Zero = ...
    MulOp1 = ...
    MulOp2 = ...
}


SDValue Acc = DAG.getInsertSubvector(DL, DAG.getConstant(0, DL, WideAccVT),
N->getOperand(0), 0);
SDValue WidenedRes =
DAG.getNode(N->getOpcode(), DL, WideAccVT, Acc, MulOp1, MulOp2);
while (ElementCount::isKnownLT(
Expand Down Expand Up @@ -8069,6 +8093,9 @@ SDValue DAGTypeLegalizer::WidenVecOp_PARTIAL_REDUCE_MLA(SDNode *N) {
EVT WideOpVT = TLI.getTypeToTransformTo(*DAG.getContext(),
N->getOperand(1).getValueType());
SDValue Acc = N->getOperand(0);
assert(WideOpVT.getVectorElementCount().isKnownMultipleOf(
Acc.getValueType().getVectorElementCount()) &&
"Expected AccVT to have been legalised");
SDValue WidenedOp1 = DAG.getInsertSubvector(
DL, DAG.getConstant(0, DL, WideOpVT), N->getOperand(1), 0);
SDValue WidenedOp2 = DAG.getInsertSubvector(
Expand Down
190 changes: 151 additions & 39 deletions llvm/test/CodeGen/AArch64/partial-reduce-widen.ll
Original file line number Diff line number Diff line change
@@ -1,39 +1,29 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
; RUN: llc < %s | FileCheck %s
; RUN: llc -mattr=+neon,+sve2p1,+dotprod < %s | FileCheck %s

target triple = "aarch64"

define void @partial_reduce_widen_v1i32_acc_v16i32_vec(ptr %accptr, ptr %resptr, ptr %vecptr) {
; CHECK-LABEL: partial_reduce_widen_v1i32_acc_v16i32_vec:
define void @partial_reduce_widen_v1i32_acc_legal_v4i32_vec(ptr %accptr, ptr %resptr, ptr %vecptr) {
; CHECK-LABEL: partial_reduce_widen_v1i32_acc_legal_v4i32_vec:
; CHECK: // %bb.0:
; CHECK-NEXT: ldp q1, q0, [x2]
; CHECK-NEXT: ldr s2, [x0]
; CHECK-NEXT: ldp q5, q6, [x2, #32]
; CHECK-NEXT: ext v3.16b, v1.16b, v1.16b, #8
; CHECK-NEXT: ext v4.16b, v0.16b, v0.16b, #8
; CHECK-NEXT: add v1.2s, v2.2s, v1.2s
; CHECK-NEXT: ext v2.16b, v5.16b, v5.16b, #8
; CHECK-NEXT: add v0.2s, v1.2s, v0.2s
; CHECK-NEXT: add v1.2s, v4.2s, v3.2s
; CHECK-NEXT: ext v3.16b, v6.16b, v6.16b, #8
; CHECK-NEXT: add v0.2s, v0.2s, v5.2s
; CHECK-NEXT: add v1.2s, v2.2s, v1.2s
; CHECK-NEXT: add v0.2s, v0.2s, v6.2s
; CHECK-NEXT: add v1.2s, v3.2s, v1.2s
; CHECK-NEXT: ldr q0, [x2]
; CHECK-NEXT: ldr s1, [x0]
; CHECK-NEXT: ext v2.16b, v0.16b, v0.16b, #8
; CHECK-NEXT: add v0.2s, v1.2s, v0.2s
; CHECK-NEXT: add v0.2s, v2.2s, v0.2s
; CHECK-NEXT: dup v1.2s, v0.s[1]
; CHECK-NEXT: add v0.2s, v0.2s, v1.2s
; CHECK-NEXT: str s0, [x1]
; CHECK-NEXT: ret
%acc = load <1 x i32>, ptr %accptr
%vec = load <16 x i32>, ptr %vecptr
%partial.reduce = call <1 x i32> @llvm.vector.partial.reduce.add(<1 x i32> %acc, <16 x i32> %vec)
%vec = load <4 x i32>, ptr %vecptr
%partial.reduce = call <1 x i32> @llvm.vector.partial.reduce.add(<1 x i32> %acc, <4 x i32> %vec)
store <1 x i32> %partial.reduce, ptr %resptr
ret void
}

define void @partial_reduce_widen_v3i32_acc_v12i32_vec(ptr %accptr, ptr %resptr, ptr %vecptr) {
; CHECK-LABEL: partial_reduce_widen_v3i32_acc_v12i32_vec:
define void @partial_reduce_widen_v3i32_acc_widen_v12i32_vec(ptr %accptr, ptr %resptr, ptr %vecptr) {
; CHECK-LABEL: partial_reduce_widen_v3i32_acc_widen_v12i32_vec:
; CHECK: // %bb.0:
; CHECK-NEXT: sub sp, sp, #128
; CHECK-NEXT: .cfi_def_cfa_offset 128
Expand All @@ -58,41 +48,163 @@ define void @partial_reduce_widen_v3i32_acc_v12i32_vec(ptr %accptr, ptr %resptr,
ret void
}

define void @partial_reduce_widen_v4i32_acc_v20i32_vec(ptr %accptr, ptr %resptr, ptr %vecptr) {
; CHECK-LABEL: partial_reduce_widen_v4i32_acc_v20i32_vec:
define void @partial_reduce_widen_v1i32_acc_widen_v12i32_vec(ptr %accptr, ptr %resptr, ptr %vecptr) {
; CHECK-LABEL: partial_reduce_widen_v1i32_acc_widen_v12i32_vec:
; CHECK: // %bb.0:
; CHECK-NEXT: sub sp, sp, #272
; CHECK-NEXT: str x29, [sp, #256] // 8-byte Folded Spill
; CHECK-NEXT: .cfi_def_cfa_offset 272
; CHECK-NEXT: .cfi_offset w29, -16
; CHECK-NEXT: sub sp, sp, #128
; CHECK-NEXT: .cfi_def_cfa_offset 128
; CHECK-NEXT: ldp q1, q0, [x2]
; CHECK-NEXT: ldr s2, [x0]
; CHECK-NEXT: ldp q5, q6, [x2, #32]
; CHECK-NEXT: ldr x29, [sp, #256] // 8-byte Folded Reload
; CHECK-NEXT: ldr q5, [x2, #32]
; CHECK-NEXT: ext v3.16b, v1.16b, v1.16b, #8
; CHECK-NEXT: ext v4.16b, v0.16b, v0.16b, #8
; CHECK-NEXT: add v1.2s, v2.2s, v1.2s
; CHECK-NEXT: ext v2.16b, v5.16b, v5.16b, #8
; CHECK-NEXT: add v0.2s, v1.2s, v0.2s
; CHECK-NEXT: add v1.2s, v4.2s, v3.2s
; CHECK-NEXT: ext v3.16b, v6.16b, v6.16b, #8
; CHECK-NEXT: ldr q4, [x2, #64]
; CHECK-NEXT: add v0.2s, v0.2s, v5.2s
; CHECK-NEXT: add v1.2s, v2.2s, v1.2s
; CHECK-NEXT: ext v2.16b, v4.16b, v4.16b, #8
; CHECK-NEXT: add v0.2s, v0.2s, v6.2s
; CHECK-NEXT: add v1.2s, v3.2s, v1.2s
; CHECK-NEXT: add v0.2s, v0.2s, v4.2s
; CHECK-NEXT: add v1.2s, v2.2s, v1.2s
; CHECK-NEXT: add v0.2s, v1.2s, v0.2s
; CHECK-NEXT: dup v1.2s, v0.s[1]
; CHECK-NEXT: add v0.2s, v0.2s, v1.2s
; CHECK-NEXT: str s0, [x1]
; CHECK-NEXT: add sp, sp, #272
; CHECK-NEXT: add sp, sp, #128
; CHECK-NEXT: ret
%acc = load <1 x i32>, ptr %accptr
%vec = load <20 x i32>, ptr %vecptr
%partial.reduce = call <1 x i32> @llvm.vector.partial.reduce.add(<1 x i32> %acc, <20 x i32> %vec)
%vec = load <12 x i32>, ptr %vecptr
%partial.reduce = call <1 x i32> @llvm.vector.partial.reduce.add(<1 x i32> %acc, <12 x i32> %vec)
store <1 x i32> %partial.reduce, ptr %resptr
ret void
}

; Partial reduction of a mul of zero-extended <12 x i8> inputs into a <4 x i32>
; accumulator. The accumulator type is already legal, but the <12 x i32> mul
; result is not a whole multiple of the accumulator's element count, so
; (per the test name) the mul operands must be widened during legalisation.
define void @partial_reduce_widen_v4i32_acc_widen_v12i8_vec(ptr %accptr, ptr %resptr, ptr %vecptr) {
; CHECK-LABEL: partial_reduce_widen_v4i32_acc_widen_v12i8_vec:
; CHECK: // %bb.0:
; CHECK-NEXT: sub sp, sp, #128
; CHECK-NEXT: .cfi_def_cfa_offset 128
; CHECK-NEXT: ldr q0, [x2]
; CHECK-NEXT: ldr q2, [x0]
; CHECK-NEXT: umull v1.8h, v0.8b, v0.8b
; CHECK-NEXT: umull2 v0.8h, v0.16b, v0.16b
; CHECK-NEXT: uaddw v2.4s, v2.4s, v1.4h
; CHECK-NEXT: uaddw2 v1.4s, v2.4s, v1.8h
; CHECK-NEXT: uaddw v0.4s, v1.4s, v0.4h
; CHECK-NEXT: str q0, [x1]
; CHECK-NEXT: add sp, sp, #128
; CHECK-NEXT: ret
  %acc = load <4 x i32>, ptr %accptr
  %vec = load <12 x i8>, ptr %vecptr
  %vec.zext = zext <12 x i8> %vec to <12 x i32>
  %vec.mul = mul <12 x i32> %vec.zext, %vec.zext
  %partial.reduce = call <4 x i32> @llvm.vector.partial.reduce.add(<4 x i32> %acc, <12 x i32> %vec.mul)
  store <4 x i32> %partial.reduce, ptr %resptr
  ret void
}

; Partial reduction with a <1 x i8> accumulator and a <4 x i8> input: the test
; name indicates the accumulator is widened while the i8 input elements are
; promoted, exercising the case where both legalisation actions interact.
define void @partial_reduce_widen_v1i8_acc_promote_v4i8_vec(ptr %accptr, ptr %resptr, ptr %vecptr) {
; CHECK-LABEL: partial_reduce_widen_v1i8_acc_promote_v4i8_vec:
; CHECK: // %bb.0:
; CHECK-NEXT: ldr s1, [x2]
; CHECK-NEXT: movi v0.2d, #0000000000000000
; CHECK-NEXT: adrp x8, .LCPI4_0
; CHECK-NEXT: ldr b2, [x0]
; CHECK-NEXT: zip1 v1.8b, v1.8b, v1.8b
; CHECK-NEXT: uzp1 v0.8b, v1.8b, v0.8b
; CHECK-NEXT: ldr d1, [x8, :lo12:.LCPI4_0]
; CHECK-NEXT: mla v2.8b, v0.8b, v1.8b
; CHECK-NEXT: zip2 v0.8b, v2.8b, v0.8b
; CHECK-NEXT: zip1 v1.8b, v2.8b, v0.8b
; CHECK-NEXT: add v0.4h, v1.4h, v0.4h
; CHECK-NEXT: zip2 v1.4h, v0.4h, v0.4h
; CHECK-NEXT: uaddw v0.4s, v1.4s, v0.4h
; CHECK-NEXT: mov w8, v0.s[1]
; CHECK-NEXT: fmov s1, w8
; CHECK-NEXT: add v0.8b, v0.8b, v1.8b
; CHECK-NEXT: str b0, [x1]
; CHECK-NEXT: ret
  %acc = load <1 x i8>, ptr %accptr
  %vec = load <4 x i8>, ptr %vecptr
  %res = call <1 x i8> @llvm.vector.partial.reduce.add(<1 x i8> %acc, <4 x i8> %vec)
  store <1 x i8> %res, ptr %resptr
  ret void
}

; Partial reduction of a <12 x i8> input into a <3 x i32> accumulator: the
; test name indicates both the accumulator and the input vector need widening
; (<3 x i32> -> <4 x i32>; 12 i8 elements padded to a multiple of the widened
; accumulator's element count). The expected codegen uses udot (+dotprod/sve2p1).
define void @partial_reduce_widen_v3i32_acc_widen_v12i8_vec(ptr %accptr, ptr %resptr, ptr %vecptr) {
; CHECK-LABEL: partial_reduce_widen_v3i32_acc_widen_v12i8_vec:
; CHECK: // %bb.0:
; CHECK-NEXT: movi v0.2d, #0000000000000000
; CHECK-NEXT: ldr q1, [x0]
; CHECK-NEXT: ldr q2, [x2]
; CHECK-NEXT: adrp x8, .LCPI5_0
; CHECK-NEXT: mov v1.s[3], wzr
; CHECK-NEXT: mov v2.s[3], v0.s[3]
; CHECK-NEXT: ldr q0, [x8, :lo12:.LCPI5_0]
; CHECK-NEXT: udot z1.s, z2.b, z0.b
; CHECK-NEXT: ext v0.16b, v1.16b, v1.16b, #8
; CHECK-NEXT: add v0.2s, v1.2s, v0.2s
; CHECK-NEXT: mov s1, v0.s[2]
; CHECK-NEXT: str d0, [x1]
; CHECK-NEXT: str s1, [x1, #8]
; CHECK-NEXT: ret
  %acc = load <3 x i32>, ptr %accptr
  %vec = load <12 x i8>, ptr %vecptr
  %res = call <3 x i32> @llvm.vector.partial.reduce.add(<3 x i32> %acc, <12 x i8> %vec)
  store <3 x i32> %res, ptr %resptr
  ret void
}

; Partial reduction of a <4 x i8> input into a <1 x i32> accumulator: the
; accumulator vector needs widening while (per the test name) the i8 input
; elements are promoted to match the i32 result element type.
define void @partial_reduce_widen_v1i32_acc_promote_v4i8_vec(ptr %accptr, ptr %resptr, ptr %vecptr) {
; CHECK-LABEL: partial_reduce_widen_v1i32_acc_promote_v4i8_vec:
; CHECK: // %bb.0:
; CHECK-NEXT: ldr s0, [x2]
; CHECK-NEXT: ldr s2, [x0]
; CHECK-NEXT: ushll v0.8h, v0.8b, #0
; CHECK-NEXT: ushll v1.4s, v0.4h, #0
; CHECK-NEXT: uaddw v0.4s, v2.4s, v0.4h
; CHECK-NEXT: ext v1.16b, v1.16b, v1.16b, #8
; CHECK-NEXT: add v0.2s, v1.2s, v0.2s
; CHECK-NEXT: dup v1.2s, v0.s[1]
; CHECK-NEXT: add v0.2s, v0.2s, v1.2s
; CHECK-NEXT: str s0, [x1]
; CHECK-NEXT: ret
  %acc = load <1 x i32>, ptr %accptr
  %vec = load <4 x i8>, ptr %vecptr
  %res = call <1 x i32> @llvm.vector.partial.reduce.add(<1 x i32> %acc, <4 x i8> %vec)
  store <1 x i32> %res, ptr %resptr
  ret void
}

; Partial reduction of an <18 x i8> input into a <9 x i32> accumulator: both
; the accumulator and the input are odd-sized and need widening, and the
; widened input element count is not a multiple of the widened accumulator's,
; exercising the extra zero-padding path added for the mul operands.
define void @partial_reduce_widen_v9i32_acc_widen_v18i8_vec(ptr %accptr, ptr %resptr, ptr %vecptr) {
; CHECK-LABEL: partial_reduce_widen_v9i32_acc_widen_v18i8_vec:
; CHECK: // %bb.0:
; CHECK-NEXT: sub sp, sp, #128
; CHECK-NEXT: .cfi_def_cfa_offset 128
; CHECK-NEXT: movi v0.2d, #0000000000000000
; CHECK-NEXT: ldr w8, [x0, #32]
; CHECK-NEXT: mov w9, #257 // =0x101
; CHECK-NEXT: ldp q1, q2, [x0]
; CHECK-NEXT: ldr q3, [x2]
; CHECK-NEXT: str xzr, [sp, #40]
; CHECK-NEXT: movi v4.16b, #1
; CHECK-NEXT: stp w8, wzr, [sp, #32]
; CHECK-NEXT: str q0, [sp, #80]
; CHECK-NEXT: ldr q5, [sp, #32]
; CHECK-NEXT: str q0, [sp, #112]
; CHECK-NEXT: strh w9, [sp, #80]
; CHECK-NEXT: udot z1.s, z3.b, z4.b
; CHECK-NEXT: ldr q0, [x2, #16]
; CHECK-NEXT: str h0, [sp, #112]
; CHECK-NEXT: ldr q0, [sp, #80]
; CHECK-NEXT: ldr q6, [sp, #112]
; CHECK-NEXT: udot z5.s, z6.b, z0.b
; CHECK-NEXT: add v0.4s, v1.4s, v5.4s
; CHECK-NEXT: stp q0, q2, [x1]
; CHECK-NEXT: add sp, sp, #128
; CHECK-NEXT: ret
  %acc = load <9 x i32>, ptr %accptr
  %vec = load <18 x i8>, ptr %vecptr
  %res = call <9 x i32> @llvm.vector.partial.reduce.add(<9 x i32> %acc, <18 x i8> %vec)
  store <9 x i32> %res, ptr %resptr
  ret void
}