Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -1117,6 +1117,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue WidenVecRes_Unary(SDNode *N);
SDValue WidenVecRes_InregOp(SDNode *N);
SDValue WidenVecRes_UnaryOpWithTwoResults(SDNode *N, unsigned ResNo);
SDValue WidenVecRes_PARTIAL_REDUCE_MLA(SDNode *N);
void ReplaceOtherWidenResults(SDNode *N, SDNode *WidenNode,
unsigned WidenResNo);

Expand Down Expand Up @@ -1152,6 +1153,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue WidenVecOp_VP_REDUCE(SDNode *N);
SDValue WidenVecOp_ExpOp(SDNode *N);
SDValue WidenVecOp_VP_CttzElements(SDNode *N);
SDValue WidenVecOp_PARTIAL_REDUCE_MLA(SDNode *N);

/// Helper function to generate a set of operations to perform
/// a vector operation for a wider type.
Expand Down
51 changes: 51 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5136,6 +5136,10 @@ void DAGTypeLegalizer::WidenVectorResult(SDNode *N, unsigned ResNo) {
if (!unrollExpandedOp())
Res = WidenVecRes_UnaryOpWithTwoResults(N, ResNo);
break;
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
Res = WidenVecRes_PARTIAL_REDUCE_MLA(N);
break;
}
}

Expand Down Expand Up @@ -6995,6 +6999,34 @@ SDValue DAGTypeLegalizer::WidenVecRes_STRICT_FSETCC(SDNode *N) {
return DAG.getBuildVector(WidenVT, dl, Scalars);
}

// The result of a partial reduction is widened by performing the
// reduction in the wider type: the accumulator is zero-padded up to
// the widened type, the partial reduction is emitted there, and the
// wide result is then folded back down (split in half and add the
// halves) until it matches the element count of the original type.
SDValue DAGTypeLegalizer::WidenVecRes_PARTIAL_REDUCE_MLA(SDNode *N) {
  SDLoc DL(N);
  EVT ResVT = N->getValueType(0);
  EVT WideVT = TLI.getTypeToTransformTo(*DAG.getContext(),
                                        N->getOperand(0).getValueType());

  // Pad the accumulator with zeroes so the extra lanes contribute
  // nothing to the reduction.
  SDValue PaddedAcc = DAG.getInsertSubvector(
      DL, DAG.getConstant(0, DL, WideVT), N->getOperand(0), 0);
  SDValue Reduced = DAG.getNode(N->getOpcode(), DL, WideVT, PaddedAcc,
                                N->getOperand(1), N->getOperand(2));

  // Halve the wide result repeatedly (lo half + hi half) until its
  // element count matches the original result type's.
  while (ElementCount::isKnownLT(
      ResVT.getVectorElementCount(),
      Reduced.getValueType().getVectorElementCount())) {
    EVT HalfVT =
        Reduced.getValueType().getHalfNumVectorElementsVT(*DAG.getContext());
    SDValue LoHalf = DAG.getExtractSubvector(DL, HalfVT, Reduced, 0);
    SDValue HiHalf = DAG.getExtractSubvector(DL, HalfVT, Reduced,
                                             HalfVT.getVectorMinNumElements());
    Reduced = DAG.getNode(ISD::ADD, DL, HalfVT, LoHalf, HiHalf);
  }

  // Only the low ResVT lanes of the widened result are meaningful; the
  // rest may be poison.
  return DAG.getInsertSubvector(DL, DAG.getPOISON(WideVT), Reduced, 0);
}

//===----------------------------------------------------------------------===//
// Widen Vector Operand
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -7127,6 +7159,10 @@ bool DAGTypeLegalizer::WidenVectorOperand(SDNode *N, unsigned OpNo) {
case ISD::VP_REDUCE_FMINIMUM:
Res = WidenVecOp_VP_REDUCE(N);
break;
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
Res = WidenVecOp_PARTIAL_REDUCE_MLA(N);
break;
case ISD::VP_CTTZ_ELTS:
case ISD::VP_CTTZ_ELTS_ZERO_UNDEF:
Res = WidenVecOp_VP_CttzElements(N);
Expand Down Expand Up @@ -8026,6 +8062,21 @@ SDValue DAGTypeLegalizer::WidenVecOp_VP_CttzElements(SDNode *N) {
{Source, Mask, N->getOperand(2)}, N->getFlags());
}

SDValue DAGTypeLegalizer::WidenVecOp_PARTIAL_REDUCE_MLA(SDNode *N) {
  // Only the multiplicand operands need widening here; the result and
  // accumulator types are already legal. Zero-padding the multiplicands
  // is safe because the padded lanes multiply to zero and therefore do
  // not alter the accumulated result.
  SDLoc DL(N);
  EVT WideVT = TLI.getTypeToTransformTo(*DAG.getContext(),
                                        N->getOperand(1).getValueType());
  SDValue WideZero = DAG.getConstant(0, DL, WideVT);
  SDValue PaddedOp1 =
      DAG.getInsertSubvector(DL, WideZero, N->getOperand(1), 0);
  SDValue PaddedOp2 =
      DAG.getInsertSubvector(DL, WideZero, N->getOperand(2), 0);
  return DAG.getNode(N->getOpcode(), DL, N->getOperand(0).getValueType(),
                     N->getOperand(0), PaddedOp1, PaddedOp2);
}

//===----------------------------------------------------------------------===//
// Vector Widening Utilities
//===----------------------------------------------------------------------===//
Expand Down
98 changes: 98 additions & 0 deletions llvm/test/CodeGen/AArch64/partial-reduce-widen.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
; RUN: llc < %s | FileCheck %s

target triple = "aarch64"

; Result widening of a partial reduction with a <1 x i32> accumulator and a
; <16 x i32> input vector (both illegal types on AArch64). Operands are loaded
; from / stored to memory so the function signature itself only carries legal
; (pointer) types — see the review note below.
define void @partial_reduce_widen_v1i32_acc_v16i32_vec(ptr %accptr, ptr %resptr, ptr %vecptr) {
; CHECK-LABEL: partial_reduce_widen_v1i32_acc_v16i32_vec:
; CHECK: // %bb.0:
; CHECK-NEXT: ldp q1, q0, [x2]
; CHECK-NEXT: ldr s2, [x0]
; CHECK-NEXT: ldp q5, q6, [x2, #32]
; CHECK-NEXT: ext v3.16b, v1.16b, v1.16b, #8
; CHECK-NEXT: ext v4.16b, v0.16b, v0.16b, #8
; CHECK-NEXT: add v1.2s, v2.2s, v1.2s
; CHECK-NEXT: ext v2.16b, v5.16b, v5.16b, #8
; CHECK-NEXT: add v0.2s, v1.2s, v0.2s
; CHECK-NEXT: add v1.2s, v4.2s, v3.2s
; CHECK-NEXT: ext v3.16b, v6.16b, v6.16b, #8
; CHECK-NEXT: add v0.2s, v0.2s, v5.2s
; CHECK-NEXT: add v1.2s, v2.2s, v1.2s
; CHECK-NEXT: add v0.2s, v0.2s, v6.2s
; CHECK-NEXT: add v1.2s, v3.2s, v1.2s
; CHECK-NEXT: add v0.2s, v1.2s, v0.2s
; CHECK-NEXT: dup v1.2s, v0.s[1]
; CHECK-NEXT: add v0.2s, v0.2s, v1.2s
; CHECK-NEXT: str s0, [x1]
; CHECK-NEXT: ret
%acc = load <1 x i32>, ptr %accptr
%vec = load <16 x i32>, ptr %vecptr
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could these tests just take the vectors as parameters and return the vector instead?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason I didn't do that was so that I wouldn't have to pass/return illegal types to the function (the ABI only describes how legal types are passed)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense 👍

%partial.reduce = call <1 x i32> @llvm.vector.partial.reduce.add(<1 x i32> %acc, <16 x i32> %vec)
store <1 x i32> %partial.reduce, ptr %resptr
ret void
}

; Result widening where both the accumulator (<3 x i32>) and the input vector
; (<12 x i32>) have non-power-of-two element counts. Operands go through
; memory so the function signature only carries legal (pointer) types.
define void @partial_reduce_widen_v3i32_acc_v12i32_vec(ptr %accptr, ptr %resptr, ptr %vecptr) {
; CHECK-LABEL: partial_reduce_widen_v3i32_acc_v12i32_vec:
; CHECK: // %bb.0:
; CHECK-NEXT: sub sp, sp, #128
; CHECK-NEXT: .cfi_def_cfa_offset 128
; CHECK-NEXT: ldp q1, q0, [x2]
; CHECK-NEXT: ldr q2, [x0]
; CHECK-NEXT: mov v2.s[3], wzr
; CHECK-NEXT: add v0.4s, v1.4s, v0.4s
; CHECK-NEXT: ldr q1, [x2, #32]
; CHECK-NEXT: add v0.4s, v0.4s, v1.4s
; CHECK-NEXT: add v0.4s, v2.4s, v0.4s
; CHECK-NEXT: ext v1.16b, v0.16b, v0.16b, #8
; CHECK-NEXT: add v0.2s, v0.2s, v1.2s
; CHECK-NEXT: mov s1, v0.s[2]
; CHECK-NEXT: str d0, [x1]
; CHECK-NEXT: str s1, [x1, #8]
; CHECK-NEXT: add sp, sp, #128
; CHECK-NEXT: ret
%acc = load <3 x i32>, ptr %accptr
%vec = load <12 x i32>, ptr %vecptr
%partial.reduce = call <3 x i32> @llvm.vector.partial.reduce.add(<3 x i32> %acc, <12 x i32> %vec)
store <3 x i32> %partial.reduce, ptr %resptr
ret void
}

; Result widening with a <1 x i32> accumulator and a <20 x i32> input vector
; whose element count is not a power of two. Operands go through memory so
; the function signature only carries legal (pointer) types.
define void @partial_reduce_widen_v4i32_acc_v20i32_vec(ptr %accptr, ptr %resptr, ptr %vecptr) {
; CHECK-LABEL: partial_reduce_widen_v4i32_acc_v20i32_vec:
; CHECK: // %bb.0:
; CHECK-NEXT: sub sp, sp, #272
; CHECK-NEXT: str x29, [sp, #256] // 8-byte Folded Spill
; CHECK-NEXT: .cfi_def_cfa_offset 272
; CHECK-NEXT: .cfi_offset w29, -16
; CHECK-NEXT: ldp q1, q0, [x2]
; CHECK-NEXT: ldr s2, [x0]
; CHECK-NEXT: ldp q5, q6, [x2, #32]
; CHECK-NEXT: ldr x29, [sp, #256] // 8-byte Folded Reload
; CHECK-NEXT: ext v3.16b, v1.16b, v1.16b, #8
; CHECK-NEXT: ext v4.16b, v0.16b, v0.16b, #8
; CHECK-NEXT: add v1.2s, v2.2s, v1.2s
; CHECK-NEXT: ext v2.16b, v5.16b, v5.16b, #8
; CHECK-NEXT: add v0.2s, v1.2s, v0.2s
; CHECK-NEXT: add v1.2s, v4.2s, v3.2s
; CHECK-NEXT: ext v3.16b, v6.16b, v6.16b, #8
; CHECK-NEXT: ldr q4, [x2, #64]
; CHECK-NEXT: add v0.2s, v0.2s, v5.2s
; CHECK-NEXT: add v1.2s, v2.2s, v1.2s
; CHECK-NEXT: ext v2.16b, v4.16b, v4.16b, #8
; CHECK-NEXT: add v0.2s, v0.2s, v6.2s
; CHECK-NEXT: add v1.2s, v3.2s, v1.2s
; CHECK-NEXT: add v0.2s, v0.2s, v4.2s
; CHECK-NEXT: add v1.2s, v2.2s, v1.2s
; CHECK-NEXT: add v0.2s, v1.2s, v0.2s
; CHECK-NEXT: dup v1.2s, v0.s[1]
; CHECK-NEXT: add v0.2s, v0.2s, v1.2s
; CHECK-NEXT: str s0, [x1]
; CHECK-NEXT: add sp, sp, #272
; CHECK-NEXT: ret
%acc = load <1 x i32>, ptr %accptr
%vec = load <20 x i32>, ptr %vecptr
%partial.reduce = call <1 x i32> @llvm.vector.partial.reduce.add(<1 x i32> %acc, <20 x i32> %vec)
store <1 x i32> %partial.reduce, ptr %resptr
ret void
}
Loading