Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 71 additions & 44 deletions llvm/docs/LangRef.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20363,6 +20363,77 @@ Arguments:
""""""""""
The argument to this intrinsic must be a vector of floating-point values.

Vector Partial Reduction Intrinsics
-----------------------------------

Partial reductions of vectors can be expressed using the intrinsics described in
this section. Each one reduces the concatenation of the two vector arguments
down to the number of elements of the result vector type.

Other than the reduction operator (e.g. add, fadd), the way in which the
concatenated arguments is reduced is entirely unspecified. By their nature these
intrinsics are not expected to be useful in isolation but can instead be used to
implement the first phase of an overall reduction operation.

The typical use case is loop vectorization where reductions are split into an
in-loop phase, where maintaining an unordered vector result is important for
performance, and an out-of-loop phase is required to calculate the final scalar
result.

By avoiding the introduction of new ordering constraints, these intrinsics
enhance the ability to leverage a target's accumulation instructions.

'``llvm.vector.partial.reduce.add.*``' Intrinsic
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Syntax:
"""""""
This is an overloaded intrinsic.

::

declare <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v4i32.v8i32(<4 x i32> %a, <8 x i32> %b)
declare <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v4i32.v16i32(<4 x i32> %a, <16 x i32> %b)
declare <vscale x 4 x i32> @llvm.vector.partial.reduce.add.nxv4i32.nxv4i32.nxv8i32(<vscale x 4 x i32> %a, <vscale x 8 x i32> %b)
declare <vscale x 4 x i32> @llvm.vector.partial.reduce.add.nxv4i32.nxv4i32.nxv16i32(<vscale x 4 x i32> %a, <vscale x 16 x i32> %b)

Arguments:
""""""""""

The first argument is an integer vector with the same type as the result.

The second argument is a vector with a length that is a known integer multiple
of the result's type, while maintaining the same element type.

'``llvm.vector.partial.reduce.fadd.*``' Intrinsic
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Syntax:
"""""""
This is an overloaded intrinsic.

::

declare <4 x f32> @llvm.vector.partial.reduce.fadd.v4f32.v8f32(<4 x f32> %a, <8 x f32> %b)
declare <vscale x 4 x f32> @llvm.vector.partial.reduce.fadd.nxv4f32.nxv8f32(<vscale x 4 x f32> %a, <vscale x 8 x f32> %b)

Arguments:
""""""""""

The first argument is a floating-point vector with the same type as the result.

The second argument is a vector with a length that is a known integer multiple
of the result's type, while maintaining the same element type.

Semantics:
""""""""""

As the way in which the arguments to this floating-point intrinsic are reduced
is unspecified, this intrinsic will assume floating-point reassociation and
contraction can be leveraged to implement the reduction, which may result in
variations to the results due to reordering or by lowering to different
instructions (including combining multiple instructions into a single one).

'``llvm.vector.insert``' Intrinsic
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Expand Down Expand Up @@ -20736,50 +20807,6 @@ Note that it has the following implications:
- If ``%cnt`` is non-zero, the return value is non-zero as well.
- If ``%cnt`` is less than or equal to ``%max_lanes``, the return value is equal to ``%cnt``.

'``llvm.vector.partial.reduce.add.*``' Intrinsic
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Syntax:
"""""""
This is an overloaded intrinsic.

::

declare <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v4i32.v8i32(<4 x i32> %a, <8 x i32> %b)
declare <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v4i32.v16i32(<4 x i32> %a, <16 x i32> %b)
declare <vscale x 4 x i32> @llvm.vector.partial.reduce.add.nxv4i32.nxv4i32.nxv8i32(<vscale x 4 x i32> %a, <vscale x 8 x i32> %b)
declare <vscale x 4 x i32> @llvm.vector.partial.reduce.add.nxv4i32.nxv4i32.nxv16i32(<vscale x 4 x i32> %a, <vscale x 16 x i32> %b)

Overview:
"""""""""

The '``llvm.vector.partial.reduce.add.*``' intrinsics reduce the
concatenation of the two vector arguments down to the number of elements of the
result vector type.

Arguments:
""""""""""

The first argument is an integer vector with the same type as the result.

The second argument is a vector with a length that is a known integer multiple
of the result's type, while maintaining the same element type.

Semantics:
""""""""""

Other than the reduction operator (e.g. add) the way in which the concatenated
arguments is reduced is entirely unspecified. By their nature these intrinsics
are not expected to be useful in isolation but instead implement the first phase
of an overall reduction operation.

The typical use case is loop vectorization where reductions are split into an
in-loop phase, where maintaining an unordered vector result is important for
performance, and an out-of-loop phase to calculate the final scalar result.

By avoiding the introduction of new ordering constraints, these intrinsics
enhance the ability to leverage a target's accumulation instructions.

'``llvm.experimental.vector.histogram.*``' Intrinsic
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Expand Down
7 changes: 6 additions & 1 deletion llvm/include/llvm/Analysis/TargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,12 @@ class TargetTransformInfoImplBase;
/// for IR-level transformations.
class TargetTransformInfo {
public:
enum PartialReductionExtendKind { PR_None, PR_SignExtend, PR_ZeroExtend };
enum PartialReductionExtendKind {
PR_None,
PR_SignExtend,
PR_ZeroExtend,
PR_FPExtend
};

/// Get the kind of extension that an instruction represents.
LLVM_ABI static PartialReductionExtendKind
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/CodeGen/ISDOpcodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -1516,6 +1516,7 @@ enum NodeType {
PARTIAL_REDUCE_SMLA, // sext, sext
PARTIAL_REDUCE_UMLA, // zext, zext
PARTIAL_REDUCE_SUMLA, // sext, zext
PARTIAL_REDUCE_FMLA, // fpext, fpext

// The `llvm.experimental.stackmap` intrinsic.
// Operands: input chain, glue, <id>, <numShadowBytes>, [live0[, live1...]]
Expand Down
4 changes: 4 additions & 0 deletions llvm/include/llvm/CodeGen/SelectionDAGNodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -1938,6 +1938,10 @@ LLVM_ABI bool isNullOrNullSplat(SDValue V, bool AllowUndefs = false);
/// be zero.
LLVM_ABI bool isOneOrOneSplat(SDValue V, bool AllowUndefs = false);

/// Return true if the value is a constant floating-point value, or a splatted
/// vector of a constant floating-point value, of 1.0 (with no undefs).
LLVM_ABI bool isOneOrOneSplatFP(SDValue V, bool AllowUndefs = false);

/// Return true if the value is a constant -1 integer or a splatted vector of a
/// constant -1 integer (with no undefs).
/// Does not permit build vector implicit truncation.
Expand Down
4 changes: 2 additions & 2 deletions llvm/include/llvm/CodeGen/TargetLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -1664,7 +1664,7 @@ class LLVM_ABI TargetLoweringBase {
LegalizeAction getPartialReduceMLAAction(unsigned Opc, EVT AccVT,
EVT InputVT) const {
assert(Opc == ISD::PARTIAL_REDUCE_SMLA || Opc == ISD::PARTIAL_REDUCE_UMLA ||
Opc == ISD::PARTIAL_REDUCE_SUMLA);
Opc == ISD::PARTIAL_REDUCE_SUMLA || Opc == ISD::PARTIAL_REDUCE_FMLA);
PartialReduceActionTypes Key = {Opc, AccVT.getSimpleVT().SimpleTy,
InputVT.getSimpleVT().SimpleTy};
auto It = PartialReduceMLAActions.find(Key);
Expand Down Expand Up @@ -2766,7 +2766,7 @@ class LLVM_ABI TargetLoweringBase {
void setPartialReduceMLAAction(unsigned Opc, MVT AccVT, MVT InputVT,
LegalizeAction Action) {
assert(Opc == ISD::PARTIAL_REDUCE_SMLA || Opc == ISD::PARTIAL_REDUCE_UMLA ||
Opc == ISD::PARTIAL_REDUCE_SUMLA);
Opc == ISD::PARTIAL_REDUCE_SUMLA || Opc == ISD::PARTIAL_REDUCE_FMLA);
assert(AccVT.isValid() && InputVT.isValid() &&
"setPartialReduceMLAAction types aren't valid");
PartialReduceActionTypes Key = {Opc, AccVT.SimpleTy, InputVT.SimpleTy};
Expand Down
4 changes: 4 additions & 0 deletions llvm/include/llvm/IR/Intrinsics.td
Original file line number Diff line number Diff line change
Expand Up @@ -2799,6 +2799,10 @@ def int_vector_partial_reduce_add : DefaultAttrsIntrinsic<[LLVMMatchType<0>],
[IntrNoMem,
IntrSpeculatable]>;

def int_vector_partial_reduce_fadd : DefaultAttrsIntrinsic<[LLVMMatchType<0>],
[llvm_anyfloat_ty, llvm_anyfloat_ty],
[IntrNoMem]>;

//===----------------- Pointer Authentication Intrinsics ------------------===//
//

Expand Down
2 changes: 2 additions & 0 deletions llvm/include/llvm/Target/TargetSelectionDAG.td
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,8 @@ def partial_reduce_smla : SDNode<"ISD::PARTIAL_REDUCE_SMLA",
SDTPartialReduceMLA>;
def partial_reduce_sumla : SDNode<"ISD::PARTIAL_REDUCE_SUMLA",
SDTPartialReduceMLA>;
def partial_reduce_fmla : SDNode<"ISD::PARTIAL_REDUCE_FMLA",
SDTPartialReduceMLA>;

def fadd : SDNode<"ISD::FADD" , SDTFPBinOp, [SDNPCommutative]>;
def fsub : SDNode<"ISD::FSUB" , SDTFPBinOp>;
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Analysis/TargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1005,6 +1005,8 @@ TargetTransformInfo::getPartialReductionExtendKind(Instruction *I) {
return PR_SignExtend;
if (isa<ZExtInst>(I))
return PR_ZeroExtend;
if (isa<FPExtInst>(I))
return PR_FPExtend;
return PR_None;
}

Expand Down
45 changes: 31 additions & 14 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2042,6 +2042,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
case ISD::PARTIAL_REDUCE_SMLA:
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SUMLA:
case ISD::PARTIAL_REDUCE_FMLA:
return visitPARTIAL_REDUCE_MLA(N);
case ISD::VECTOR_COMPRESS: return visitVECTOR_COMPRESS(N);
case ISD::LIFETIME_END: return visitLIFETIME_END(N);
Expand Down Expand Up @@ -12988,6 +12989,9 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
//
// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
// -> partial_reduce_*mla(acc, x, C)
//
// partial_reduce_fmla(acc, fmul(fpext(a), fpext(b)), splat(1.0))
// -> partial_reduce_fmla(acc, a, b)
SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
SDLoc DL(N);
auto *Context = DAG.getContext();
Expand All @@ -12996,7 +13000,7 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
SDValue Op2 = N->getOperand(2);

unsigned Opc = Op1->getOpcode();
if (Opc != ISD::MUL && Opc != ISD::SHL)
if (Opc != ISD::MUL && Opc != ISD::FMUL && Opc != ISD::SHL)
return SDValue();

SDValue LHS = Op1->getOperand(0);
Expand All @@ -13015,20 +13019,24 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
Opc = ISD::MUL;
}

APInt C;
if (Opc != ISD::MUL || !ISD::isConstantSplatVector(Op2.getNode(), C) ||
!C.isOne())
if (!(Opc == ISD::MUL && llvm::isOneOrOneSplat(Op2)) &&
!(Opc == ISD::FMUL && llvm::isOneOrOneSplatFP(Op2)))
return SDValue();

auto IsIntOrFPExtOpcode = [](unsigned int Opcode) {
return (ISD::isExtOpcode(Opcode) || Opcode == ISD::FP_EXTEND);
};

unsigned LHSOpcode = LHS->getOpcode();
if (!ISD::isExtOpcode(LHSOpcode))
if (!IsIntOrFPExtOpcode(LHSOpcode))
return SDValue();

SDValue LHSExtOp = LHS->getOperand(0);
EVT LHSExtOpVT = LHSExtOp.getValueType();

// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
// -> partial_reduce_*mla(acc, x, C)
APInt C;
if (ISD::isConstantSplatVector(RHS.getNode(), C)) {
// TODO: Make use of partial_reduce_sumla here
APInt CTrunc = C.trunc(LHSExtOpVT.getScalarSizeInBits());
Expand All @@ -13053,7 +13061,7 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
}

unsigned RHSOpcode = RHS->getOpcode();
if (!ISD::isExtOpcode(RHSOpcode))
if (!IsIntOrFPExtOpcode(RHSOpcode))
return SDValue();

SDValue RHSExtOp = RHS->getOperand(0);
Expand All @@ -13070,6 +13078,8 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
else if (LHSOpcode == ISD::ZERO_EXTEND && RHSOpcode == ISD::SIGN_EXTEND) {
NewOpc = ISD::PARTIAL_REDUCE_SUMLA;
std::swap(LHSExtOp, RHSExtOp);
} else if (LHSOpcode == ISD::FP_EXTEND && RHSOpcode == ISD::FP_EXTEND) {
NewOpc = ISD::PARTIAL_REDUCE_FMLA;
} else
return SDValue();
// For a 2-stage extend the signedness of both of the extends must match
Expand Down Expand Up @@ -13097,30 +13107,33 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
// -> partial.reduce.smla(acc, op, splat(trunc(1)))
// partial.reduce.sumla(acc, sext(op), splat(1))
// -> partial.reduce.smla(acc, op, splat(trunc(1)))
// partial.reduce.fmla(acc, fpext(op), splat(1.0))
// -> partial.reduce.fmla(acc, op, splat(1.0))
SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
SDLoc DL(N);
SDValue Acc = N->getOperand(0);
SDValue Op1 = N->getOperand(1);
SDValue Op2 = N->getOperand(2);

APInt ConstantOne;
if (!ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) ||
!ConstantOne.isOne())
if (!llvm::isOneOrOneSplat(Op2) && !llvm::isOneOrOneSplatFP(Op2))
return SDValue();

unsigned Op1Opcode = Op1.getOpcode();
if (!ISD::isExtOpcode(Op1Opcode))
if (!ISD::isExtOpcode(Op1Opcode) && Op1Opcode != ISD::FP_EXTEND)
return SDValue();

bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND;
bool Op1IsSigned =
Op1Opcode == ISD::SIGN_EXTEND || Op1Opcode == ISD::FP_EXTEND;
bool NodeIsSigned = N->getOpcode() != ISD::PARTIAL_REDUCE_UMLA;
EVT AccElemVT = Acc.getValueType().getVectorElementType();
if (Op1IsSigned != NodeIsSigned &&
Op1.getValueType().getVectorElementType() != AccElemVT)
return SDValue();

unsigned NewOpcode =
Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
unsigned NewOpcode = N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA
? ISD::PARTIAL_REDUCE_FMLA
: Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA
: ISD::PARTIAL_REDUCE_UMLA;

SDValue UnextOp1 = Op1.getOperand(0);
EVT UnextOp1VT = UnextOp1.getValueType();
Expand All @@ -13130,8 +13143,12 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
TLI.getTypeToTransformTo(*Context, UnextOp1VT)))
return SDValue();

SDValue Constant = N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA
? DAG.getConstantFP(1, DL, UnextOp1VT)
: DAG.getConstant(1, DL, UnextOp1VT);

return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, UnextOp1,
DAG.getConstant(1, DL, UnextOp1VT));
Constant);
}

SDValue DAGCombiner::visitVP_STRIDED_LOAD(SDNode *N) {
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,7 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
case ISD::PARTIAL_REDUCE_SUMLA:
case ISD::PARTIAL_REDUCE_FMLA:
Action =
TLI.getPartialReduceMLAAction(Op.getOpcode(), Node->getValueType(0),
Node->getOperand(1).getValueType());
Expand Down Expand Up @@ -1243,6 +1244,7 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
case ISD::PARTIAL_REDUCE_SUMLA:
case ISD::PARTIAL_REDUCE_FMLA:
Results.push_back(TLI.expandPartialReduceMLA(Node, DAG));
return;
case ISD::VECREDUCE_SEQ_FADD:
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1459,6 +1459,7 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
case ISD::PARTIAL_REDUCE_SUMLA:
case ISD::PARTIAL_REDUCE_FMLA:
SplitVecRes_PARTIAL_REDUCE_MLA(N, Lo, Hi);
break;
case ISD::GET_ACTIVE_LANE_MASK:
Expand Down Expand Up @@ -3674,6 +3675,7 @@ bool DAGTypeLegalizer::SplitVectorOperand(SDNode *N, unsigned OpNo) {
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
case ISD::PARTIAL_REDUCE_SUMLA:
case ISD::PARTIAL_REDUCE_FMLA:
Res = SplitVecOp_PARTIAL_REDUCE_MLA(N);
break;
}
Expand Down
8 changes: 7 additions & 1 deletion llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8404,7 +8404,8 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
}
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
case ISD::PARTIAL_REDUCE_SUMLA: {
case ISD::PARTIAL_REDUCE_SUMLA:
case ISD::PARTIAL_REDUCE_FMLA: {
[[maybe_unused]] EVT AccVT = N1.getValueType();
[[maybe_unused]] EVT Input1VT = N2.getValueType();
[[maybe_unused]] EVT Input2VT = N3.getValueType();
Expand Down Expand Up @@ -13054,6 +13055,11 @@ bool llvm::isOneOrOneSplat(SDValue N, bool AllowUndefs) {
return C && C->isOne();
}

bool llvm::isOneOrOneSplatFP(SDValue N, bool AllowUndefs) {
ConstantFPSDNode *C = isConstOrConstSplatFP(N, AllowUndefs);
return C && C->isExactlyValue(1.0);
}

bool llvm::isAllOnesOrAllOnesSplat(SDValue N, bool AllowUndefs) {
N = peekThroughBitcasts(N);
unsigned BitWidth = N.getScalarValueSizeInBits();
Expand Down
Loading
Loading