diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index 586c3411791f9..c4d69aa48434a 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -1117,6 +1117,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue WidenVecRes_Unary(SDNode *N);
   SDValue WidenVecRes_InregOp(SDNode *N);
   SDValue WidenVecRes_UnaryOpWithTwoResults(SDNode *N, unsigned ResNo);
+  SDValue WidenVecRes_PARTIAL_REDUCE_MLA(SDNode *N);
   void ReplaceOtherWidenResults(SDNode *N, SDNode *WidenNode,
                                 unsigned WidenResNo);
 
@@ -1152,6 +1153,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue WidenVecOp_VP_REDUCE(SDNode *N);
   SDValue WidenVecOp_ExpOp(SDNode *N);
   SDValue WidenVecOp_VP_CttzElements(SDNode *N);
+  SDValue WidenVecOp_PARTIAL_REDUCE_MLA(SDNode *N);
 
   /// Helper function to generate a set of operations to perform
   /// a vector operation for a wider type.
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index 87d5453cd98cf..ef59706c17f80 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -5136,6 +5136,10 @@ void DAGTypeLegalizer::WidenVectorResult(SDNode *N, unsigned ResNo) {
     if (!unrollExpandedOp())
       Res = WidenVecRes_UnaryOpWithTwoResults(N, ResNo);
     break;
+  case ISD::PARTIAL_REDUCE_UMLA:
+  case ISD::PARTIAL_REDUCE_SMLA:
+    Res = WidenVecRes_PARTIAL_REDUCE_MLA(N);
+    break;
   }
 }
 
@@ -6995,6 +6999,58 @@ SDValue DAGTypeLegalizer::WidenVecRes_STRICT_FSETCC(SDNode *N) {
   return DAG.getBuildVector(WidenVT, dl, Scalars);
 }
 
+// Widening the result of a partial reduction is implemented by
+// accumulating into a wider (zero-padded) vector, then incrementally
+// reducing that (extract half vector and add) until it fits
+// the original type.
+SDValue DAGTypeLegalizer::WidenVecRes_PARTIAL_REDUCE_MLA(SDNode *N) {
+  SDLoc DL(N);
+  EVT VT = N->getValueType(0);
+  EVT WideAccVT = TLI.getTypeToTransformTo(*DAG.getContext(),
+                                           N->getOperand(0).getValueType());
+  ElementCount WideAccEC = WideAccVT.getVectorElementCount();
+
+  // Widen the mul operands if needed, otherwise we'll end up with a
+  // node that isn't legal because the input vectors would not be a
+  // known multiple of the accumulator vector.
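+  // As a worked example (drawn from the AArch64 tests added below, not a
+  // new guarantee): a v3i32 accumulator widens to v4i32 and v12i32 mul
+  // operands widen to v16i32; inserting them into an all-zero v16i32
+  // keeps the wide inputs a known (4x) multiple of the wide accumulator,
+  // and the zero-padding lanes contribute nothing to the accumulated sums.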
+  SDValue MulOp1 = N->getOperand(1);
+  SDValue MulOp2 = N->getOperand(2);
+  EVT MulOpVT = MulOp1.getValueType();
+  ElementCount MulOpEC = MulOpVT.getVectorElementCount();
+  if (getTypeAction(MulOpVT) == TargetLowering::TypeWidenVector) {
+    EVT WideMulVT = GetWidenedVector(MulOp1).getValueType();
+    assert(WideMulVT.getVectorElementCount().isKnownMultipleOf(WideAccEC) &&
+           "Widening to a vector with fewer elements than the accumulator?");
+    SDValue Zero = DAG.getConstant(0, DL, WideMulVT);
+    MulOp1 = DAG.getInsertSubvector(DL, Zero, MulOp1, 0);
+    MulOp2 = DAG.getInsertSubvector(DL, Zero, MulOp2, 0);
+  } else if (!MulOpEC.isKnownMultipleOf(WideAccEC)) {
+    assert(getTypeAction(MulOpVT) != TargetLowering::TypeLegal &&
+           "Expected Mul operands to need legalisation");
+    EVT WideMulVT = EVT::getVectorVT(*DAG.getContext(),
+                                     MulOpVT.getVectorElementType(), WideAccEC);
+    SDValue Zero = DAG.getConstant(0, DL, WideMulVT);
+    MulOp1 = DAG.getInsertSubvector(DL, Zero, MulOp1, 0);
+    MulOp2 = DAG.getInsertSubvector(DL, Zero, MulOp2, 0);
+  }
+
+  SDValue Acc = DAG.getInsertSubvector(DL, DAG.getConstant(0, DL, WideAccVT),
+                                       N->getOperand(0), 0);
+  SDValue WidenedRes =
+      DAG.getNode(N->getOpcode(), DL, WideAccVT, Acc, MulOp1, MulOp2);
+  while (ElementCount::isKnownLT(
+      VT.getVectorElementCount(),
+      WidenedRes.getValueType().getVectorElementCount())) {
+    EVT HalfVT =
+        WidenedRes.getValueType().getHalfNumVectorElementsVT(*DAG.getContext());
+    SDValue Lo = DAG.getExtractSubvector(DL, HalfVT, WidenedRes, 0);
+    SDValue Hi = DAG.getExtractSubvector(DL, HalfVT, WidenedRes,
+                                         HalfVT.getVectorMinNumElements());
+    WidenedRes = DAG.getNode(ISD::ADD, DL, HalfVT, Lo, Hi);
+  }
+  return DAG.getInsertSubvector(DL, DAG.getPOISON(WideAccVT), WidenedRes, 0);
+}
+
 //===----------------------------------------------------------------------===//
 // Widen Vector Operand
 //===----------------------------------------------------------------------===//
@@ -7127,6 +7183,10 @@ bool DAGTypeLegalizer::WidenVectorOperand(SDNode *N, unsigned OpNo) {
   case ISD::VP_REDUCE_FMINIMUM:
     Res = WidenVecOp_VP_REDUCE(N);
     break;
+  case ISD::PARTIAL_REDUCE_UMLA:
+  case ISD::PARTIAL_REDUCE_SMLA:
+    Res = WidenVecOp_PARTIAL_REDUCE_MLA(N);
+    break;
   case ISD::VP_CTTZ_ELTS:
   case ISD::VP_CTTZ_ELTS_ZERO_UNDEF:
     Res = WidenVecOp_VP_CttzElements(N);
     break;
@@ -8026,6 +8086,24 @@ SDValue DAGTypeLegalizer::WidenVecOp_VP_CttzElements(SDNode *N) {
                      {Source, Mask, N->getOperand(2)}, N->getFlags());
 }
 
+SDValue DAGTypeLegalizer::WidenVecOp_PARTIAL_REDUCE_MLA(SDNode *N) {
+  // Widening of the multiplicand operands only. The result and accumulator
+  // should already be legal types.
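+  // For instance (again mirroring the AArch64 tests below): a <12 x i8>
+  // multiplicand of a legal <4 x i32> accumulator widens to <16 x i8> by
+  // insertion into an all-zero vector, which is safe because the padding
+  // lanes only ever add zero to the accumulator lanes.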
+ SDLoc DL(N); + EVT WideOpVT = TLI.getTypeToTransformTo(*DAG.getContext(), + N->getOperand(1).getValueType()); + SDValue Acc = N->getOperand(0); + assert(WideOpVT.getVectorElementCount().isKnownMultipleOf( + Acc.getValueType().getVectorElementCount()) && + "Expected AccVT to have been legalised"); + SDValue WidenedOp1 = DAG.getInsertSubvector( + DL, DAG.getConstant(0, DL, WideOpVT), N->getOperand(1), 0); + SDValue WidenedOp2 = DAG.getInsertSubvector( + DL, DAG.getConstant(0, DL, WideOpVT), N->getOperand(2), 0); + return DAG.getNode(N->getOpcode(), DL, Acc.getValueType(), Acc, WidenedOp1, + WidenedOp2); +} + //===----------------------------------------------------------------------===// // Vector Widening Utilities //===----------------------------------------------------------------------===// diff --git a/llvm/test/CodeGen/AArch64/partial-reduce-widen.ll b/llvm/test/CodeGen/AArch64/partial-reduce-widen.ll new file mode 100644 index 0000000000000..9bcb5a6672063 --- /dev/null +++ b/llvm/test/CodeGen/AArch64/partial-reduce-widen.ll @@ -0,0 +1,210 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6 +; RUN: llc -mattr=+neon,+sve2p1,+dotprod < %s | FileCheck %s + +target triple = "aarch64" + +define void @partial_reduce_widen_v1i32_acc_legal_v4i32_vec(ptr %accptr, ptr %resptr, ptr %vecptr) { +; CHECK-LABEL: partial_reduce_widen_v1i32_acc_legal_v4i32_vec: +; CHECK: // %bb.0: +; CHECK-NEXT: ldr q0, [x2] +; CHECK-NEXT: ldr s1, [x0] +; CHECK-NEXT: ext v2.16b, v0.16b, v0.16b, #8 +; CHECK-NEXT: add v0.2s, v1.2s, v0.2s +; CHECK-NEXT: add v0.2s, v2.2s, v0.2s +; CHECK-NEXT: dup v1.2s, v0.s[1] +; CHECK-NEXT: add v0.2s, v0.2s, v1.2s +; CHECK-NEXT: str s0, [x1] +; CHECK-NEXT: ret + %acc = load <1 x i32>, ptr %accptr + %vec = load <4 x i32>, ptr %vecptr + %partial.reduce = call <1 x i32> @llvm.vector.partial.reduce.add(<1 x i32> %acc, <4 x i32> %vec) + store <1 x i32> %partial.reduce, ptr %resptr + ret void +} + +define void @partial_reduce_widen_v3i32_acc_widen_v12i32_vec(ptr %accptr, ptr %resptr, ptr %vecptr) { +; CHECK-LABEL: partial_reduce_widen_v3i32_acc_widen_v12i32_vec: +; CHECK: // %bb.0: +; CHECK-NEXT: sub sp, sp, #128 +; CHECK-NEXT: .cfi_def_cfa_offset 128 +; CHECK-NEXT: ldp q1, q0, [x2] +; CHECK-NEXT: ldr q2, [x0] +; CHECK-NEXT: mov v2.s[3], wzr +; CHECK-NEXT: add v0.4s, v1.4s, v0.4s +; CHECK-NEXT: ldr q1, [x2, #32] +; CHECK-NEXT: add v0.4s, v0.4s, v1.4s +; CHECK-NEXT: add v0.4s, v2.4s, v0.4s +; CHECK-NEXT: ext v1.16b, v0.16b, v0.16b, #8 +; CHECK-NEXT: add v0.2s, v0.2s, v1.2s +; CHECK-NEXT: mov s1, v0.s[2] +; CHECK-NEXT: str d0, [x1] +; CHECK-NEXT: str s1, [x1, #8] +; CHECK-NEXT: add sp, sp, #128 +; CHECK-NEXT: ret + %acc = load <3 x i32>, ptr %accptr + %vec = load <12 x i32>, ptr %vecptr + %partial.reduce = call <3 x i32> @llvm.vector.partial.reduce.add(<3 x i32> %acc, <12 x i32> %vec) + store <3 x i32> %partial.reduce, ptr %resptr + ret void +} + +define void @partial_reduce_widen_v1i32_acc_widen_v12i32_vec(ptr %accptr, ptr %resptr, ptr %vecptr) { +; CHECK-LABEL: partial_reduce_widen_v1i32_acc_widen_v12i32_vec: +; CHECK: // %bb.0: +; CHECK-NEXT: sub sp, sp, #128 +; CHECK-NEXT: .cfi_def_cfa_offset 128 +; CHECK-NEXT: ldp q1, q0, [x2] +; CHECK-NEXT: ldr s2, [x0] +; CHECK-NEXT: ldr q5, [x2, #32] +; CHECK-NEXT: ext v3.16b, v1.16b, v1.16b, #8 +; CHECK-NEXT: ext v4.16b, v0.16b, v0.16b, #8 +; CHECK-NEXT: add v1.2s, v2.2s, v1.2s +; CHECK-NEXT: ext v2.16b, v5.16b, v5.16b, #8 +; CHECK-NEXT: add v0.2s, v1.2s, v0.2s +; CHECK-NEXT: add v1.2s, v4.2s, v3.2s 
+; CHECK-NEXT: add v0.2s, v0.2s, v5.2s +; CHECK-NEXT: add v1.2s, v2.2s, v1.2s +; CHECK-NEXT: add v0.2s, v1.2s, v0.2s +; CHECK-NEXT: dup v1.2s, v0.s[1] +; CHECK-NEXT: add v0.2s, v0.2s, v1.2s +; CHECK-NEXT: str s0, [x1] +; CHECK-NEXT: add sp, sp, #128 +; CHECK-NEXT: ret + %acc = load <1 x i32>, ptr %accptr + %vec = load <12 x i32>, ptr %vecptr + %partial.reduce = call <1 x i32> @llvm.vector.partial.reduce.add(<1 x i32> %acc, <12 x i32> %vec) + store <1 x i32> %partial.reduce, ptr %resptr + ret void +} + +define void @partial_reduce_widen_v4i32_acc_widen_v12i8_vec(ptr %accptr, ptr %resptr, ptr %vecptr) { +; CHECK-LABEL: partial_reduce_widen_v4i32_acc_widen_v12i8_vec: +; CHECK: // %bb.0: +; CHECK-NEXT: sub sp, sp, #128 +; CHECK-NEXT: .cfi_def_cfa_offset 128 +; CHECK-NEXT: ldr q0, [x2] +; CHECK-NEXT: ldr q2, [x0] +; CHECK-NEXT: umull v1.8h, v0.8b, v0.8b +; CHECK-NEXT: umull2 v0.8h, v0.16b, v0.16b +; CHECK-NEXT: uaddw v2.4s, v2.4s, v1.4h +; CHECK-NEXT: uaddw2 v1.4s, v2.4s, v1.8h +; CHECK-NEXT: uaddw v0.4s, v1.4s, v0.4h +; CHECK-NEXT: str q0, [x1] +; CHECK-NEXT: add sp, sp, #128 +; CHECK-NEXT: ret + %acc = load <4 x i32>, ptr %accptr + %vec = load <12 x i8>, ptr %vecptr + %vec.zext = zext <12 x i8> %vec to <12 x i32> + %vec.mul = mul <12 x i32> %vec.zext, %vec.zext + %partial.reduce = call <4 x i32> @llvm.vector.partial.reduce.add(<4 x i32> %acc, <12 x i32> %vec.mul) + store <4 x i32> %partial.reduce, ptr %resptr + ret void +} + +define void @partial_reduce_widen_v1i8_acc_promote_v4i8_vec(ptr %accptr, ptr %resptr, ptr %vecptr) { +; CHECK-LABEL: partial_reduce_widen_v1i8_acc_promote_v4i8_vec: +; CHECK: // %bb.0: +; CHECK-NEXT: ldr s1, [x2] +; CHECK-NEXT: movi v0.2d, #0000000000000000 +; CHECK-NEXT: adrp x8, .LCPI4_0 +; CHECK-NEXT: ldr b2, [x0] +; CHECK-NEXT: zip1 v1.8b, v1.8b, v1.8b +; CHECK-NEXT: uzp1 v0.8b, v1.8b, v0.8b +; CHECK-NEXT: ldr d1, [x8, :lo12:.LCPI4_0] +; CHECK-NEXT: mla v2.8b, v0.8b, v1.8b +; CHECK-NEXT: zip2 v0.8b, v2.8b, v0.8b +; CHECK-NEXT: zip1 v1.8b, v2.8b, v0.8b +; CHECK-NEXT: add v0.4h, v1.4h, v0.4h +; CHECK-NEXT: zip2 v1.4h, v0.4h, v0.4h +; CHECK-NEXT: uaddw v0.4s, v1.4s, v0.4h +; CHECK-NEXT: mov w8, v0.s[1] +; CHECK-NEXT: fmov s1, w8 +; CHECK-NEXT: add v0.8b, v0.8b, v1.8b +; CHECK-NEXT: str b0, [x1] +; CHECK-NEXT: ret + %acc = load <1 x i8>, ptr %accptr + %vec = load <4 x i8>, ptr %vecptr + %res = call <1 x i8> @llvm.vector.partial.reduce.add(<1 x i8> %acc, <4 x i8> %vec) + store <1 x i8> %res, ptr %resptr + ret void +} + +define void @partial_reduce_widen_v3i32_acc_widen_v12i8_vec(ptr %accptr, ptr %resptr, ptr %vecptr) { +; CHECK-LABEL: partial_reduce_widen_v3i32_acc_widen_v12i8_vec: +; CHECK: // %bb.0: +; CHECK-NEXT: movi v0.2d, #0000000000000000 +; CHECK-NEXT: ldr q1, [x0] +; CHECK-NEXT: ldr q2, [x2] +; CHECK-NEXT: adrp x8, .LCPI5_0 +; CHECK-NEXT: mov v1.s[3], wzr +; CHECK-NEXT: mov v2.s[3], v0.s[3] +; CHECK-NEXT: ldr q0, [x8, :lo12:.LCPI5_0] +; CHECK-NEXT: udot z1.s, z2.b, z0.b +; CHECK-NEXT: ext v0.16b, v1.16b, v1.16b, #8 +; CHECK-NEXT: add v0.2s, v1.2s, v0.2s +; CHECK-NEXT: mov s1, v0.s[2] +; CHECK-NEXT: str d0, [x1] +; CHECK-NEXT: str s1, [x1, #8] +; CHECK-NEXT: ret + %acc = load <3 x i32>, ptr %accptr + %vec = load <12 x i8>, ptr %vecptr + %res = call <3 x i32> @llvm.vector.partial.reduce.add(<3 x i32> %acc, <12 x i8> %vec) + store <3 x i32> %res, ptr %resptr + ret void +} + +define void @partial_reduce_widen_v1i32_acc_promote_v4i8_vec(ptr %accptr, ptr %resptr, ptr %vecptr) { +; CHECK-LABEL: partial_reduce_widen_v1i32_acc_promote_v4i8_vec: +; CHECK: // %bb.0: +; 
CHECK-NEXT: ldr s0, [x2] +; CHECK-NEXT: ldr s2, [x0] +; CHECK-NEXT: ushll v0.8h, v0.8b, #0 +; CHECK-NEXT: ushll v1.4s, v0.4h, #0 +; CHECK-NEXT: uaddw v0.4s, v2.4s, v0.4h +; CHECK-NEXT: ext v1.16b, v1.16b, v1.16b, #8 +; CHECK-NEXT: add v0.2s, v1.2s, v0.2s +; CHECK-NEXT: dup v1.2s, v0.s[1] +; CHECK-NEXT: add v0.2s, v0.2s, v1.2s +; CHECK-NEXT: str s0, [x1] +; CHECK-NEXT: ret + %acc = load <1 x i32>, ptr %accptr + %vec = load <4 x i8>, ptr %vecptr + %res = call <1 x i32> @llvm.vector.partial.reduce.add(<1 x i32> %acc, <4 x i8> %vec) + store <1 x i32> %res, ptr %resptr + ret void +} + +define void @partial_reduce_widen_v9i32_acc_widen_v18i8_vec(ptr %accptr, ptr %resptr, ptr %vecptr) { +; CHECK-LABEL: partial_reduce_widen_v9i32_acc_widen_v18i8_vec: +; CHECK: // %bb.0: +; CHECK-NEXT: sub sp, sp, #128 +; CHECK-NEXT: .cfi_def_cfa_offset 128 +; CHECK-NEXT: movi v0.2d, #0000000000000000 +; CHECK-NEXT: ldr w8, [x0, #32] +; CHECK-NEXT: mov w9, #257 // =0x101 +; CHECK-NEXT: ldp q1, q2, [x0] +; CHECK-NEXT: ldr q3, [x2] +; CHECK-NEXT: str xzr, [sp, #40] +; CHECK-NEXT: movi v4.16b, #1 +; CHECK-NEXT: stp w8, wzr, [sp, #32] +; CHECK-NEXT: str q0, [sp, #80] +; CHECK-NEXT: ldr q5, [sp, #32] +; CHECK-NEXT: str q0, [sp, #112] +; CHECK-NEXT: strh w9, [sp, #80] +; CHECK-NEXT: udot z1.s, z3.b, z4.b +; CHECK-NEXT: ldr q0, [x2, #16] +; CHECK-NEXT: str h0, [sp, #112] +; CHECK-NEXT: ldr q0, [sp, #80] +; CHECK-NEXT: ldr q6, [sp, #112] +; CHECK-NEXT: udot z5.s, z6.b, z0.b +; CHECK-NEXT: add v0.4s, v1.4s, v5.4s +; CHECK-NEXT: stp q0, q2, [x1] +; CHECK-NEXT: add sp, sp, #128 +; CHECK-NEXT: ret + %acc = load <9 x i32>, ptr %accptr + %vec = load <18 x i8>, ptr %vecptr + %res = call <9 x i32> @llvm.vector.partial.reduce.add(<9 x i32> %acc, <18 x i8> %vec) + store <9 x i32> %res, ptr %resptr + ret void +} +
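+
+; Illustrative note (hand-written, not part of the autogenerated checks):
+; in the simplest case above, the <1 x i32> result is the plain total
+;   acc0 + v0 + v1 + v2 + v3
+; and the widened lowering computes it via the "extract half vector and
+; add" loop in WidenVecRes_PARTIAL_REDUCE_MLA, visible as the repeated
+; "add v0.2s" / "dup" sequence in the first test's CHECK lines.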