Skip to content

Commit 0015c50

Browse files
committed
[WIP][SDAG] Add partial_reduce_sumla node
We have recently added the partial_reduce_smla and partial_reduce_umla nodes to represent Acc += ext(a) * ext(b), where the two extends must have the same source type and the same extend kind. For riscv64 w/zvqdotq, we have the vqdot and vqdotu instructions, which correspond to the existing nodes, but we also have vqdotsu, which represents the case where the two extends are sign and zero respectively (i.e. not the same kind of extend). This patch adds a partial_reduce_sumla node which has sign extension for A and zero extension for B. The addition is somewhat mechanical, except that it exposes an implementation challenge because AArch64 doesn't have an analogous instruction (that I've found). The current legalization table assumes that all of the partial_reduce_*mla variants have the same handling for a given type pair. Questions to the AArch64 folks: * Does AArch64 have a good implementation for this that I missed? * If not, are you okay with my somewhat hacky custom legalization approach (in this patch)? It does look like there are some small regressions here, but I haven't dug into why. * If not, any suggestions on how to structure splitting the legalization table? I could add the opcode to the table key; that's probably the easiest.
1 parent 9b4de7d commit 0015c50

File tree

12 files changed

+400
-130
lines changed

12 files changed

+400
-130
lines changed

llvm/include/llvm/CodeGen/ISDOpcodes.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1484,8 +1484,9 @@ enum NodeType {
14841484
VECREDUCE_UMIN,
14851485

14861486
// PARTIAL_REDUCE_[U|S|SU]MLA(Accumulator, Input1, Input2)
1487-
// The partial reduction nodes sign or zero extend Input1 and Input2 to the
1488-
// element type of Accumulator before multiplying their results.
1487+
// The partial reduction nodes sign or zero extend Input1 and Input2
1488+
// (with the extension kind noted below) to the element type of
1489+
// Accumulator before multiplying their results.
14891490
// This result is concatenated to the Accumulator, and this is then reduced,
14901491
// using addition, to the result type.
14911492
// The output is only expected to either be given to another partial reduction
@@ -1497,8 +1498,9 @@ enum NodeType {
14971498
// multiple of the number of elements in the Accumulator / output type.
14981499
// Input1 and Input2 must have an element type which is the same as or smaller
14991500
// than the element type of the Accumulator and output.
1500-
PARTIAL_REDUCE_SMLA,
1501-
PARTIAL_REDUCE_UMLA,
1501+
PARTIAL_REDUCE_SMLA, // sext, sext
1502+
PARTIAL_REDUCE_UMLA, // zext, zext
1503+
PARTIAL_REDUCE_SUMLA, // sext, zext
15021504

15031505
// The `llvm.experimental.stackmap` intrinsic.
15041506
// Operands: input chain, glue, <id>, <numShadowBytes>, [live0[, live1...]]

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 37 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1991,6 +1991,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
19911991
case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM: return visitMHISTOGRAM(N);
19921992
case ISD::PARTIAL_REDUCE_SMLA:
19931993
case ISD::PARTIAL_REDUCE_UMLA:
1994+
case ISD::PARTIAL_REDUCE_SUMLA:
19941995
return visitPARTIAL_REDUCE_MLA(N);
19951996
case ISD::VECTOR_COMPRESS: return visitVECTOR_COMPRESS(N);
19961997
case ISD::LIFETIME_END: return visitLIFETIME_END(N);
@@ -12675,19 +12676,19 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
1267512676
TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
1267612677
return SDValue();
1267712678

12678-
bool ExtIsSigned = LHSOpcode == ISD::SIGN_EXTEND;
12679-
unsigned NewOpcode =
12680-
ExtIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
12681-
1268212679
// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
1268312680
// -> partial_reduce_*mla(acc, x, C)
1268412681
if (ISD::isConstantSplatVector(RHS.getNode(), C)) {
12682+
// TODO: Make use of partial_reduce_sumla here
1268512683
APInt CTrunc = C.trunc(LHSExtOpVT.getScalarSizeInBits());
1268612684
unsigned LHSBits = LHS.getValueType().getScalarSizeInBits();
1268712685
if ((LHSOpcode != ISD::ZERO_EXTEND || CTrunc.zext(LHSBits) != C) &&
1268812686
(LHSOpcode != ISD::SIGN_EXTEND || CTrunc.sext(LHSBits) != C))
1268912687
return SDValue();
1269012688

12689+
unsigned NewOpcode = LHSOpcode == ISD::SIGN_EXTEND
12690+
? ISD::PARTIAL_REDUCE_SMLA
12691+
: ISD::PARTIAL_REDUCE_UMLA;
1269112692
return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
1269212693
DAG.getConstant(CTrunc, DL, LHSExtOpVT));
1269312694
}
@@ -12697,26 +12698,46 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
1269712698
return SDValue();
1269812699

1269912700
SDValue RHSExtOp = RHS->getOperand(0);
12700-
if (LHSExtOpVT != RHSExtOp.getValueType() || LHSOpcode != RHSOpcode)
12701+
if (LHSExtOpVT != RHSExtOp.getValueType())
1270112702
return SDValue();
1270212703

12703-
// For a 2-stage extend the signedness of both of the extends must be the
12704-
// same. This is so the node can be folded into only a signed or unsigned
12705-
// node.
12706-
bool NodeIsSigned = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
12704+
unsigned NewOpc = ISD::PARTIAL_REDUCE_SMLA;
12705+
// For a 2-stage extend the signedness of both of the extends must match
12706+
// If the mul has the same type, there is no outer extend, and thus we
12707+
// can simply use the inner extends to pick the result node.
1270712708
EVT AccElemVT = Acc.getValueType().getVectorElementType();
12708-
if (ExtIsSigned != NodeIsSigned &&
12709-
Op1.getValueType().getVectorElementType() != AccElemVT)
12710-
return SDValue();
12711-
12712-
return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
12713-
RHSExtOp);
12709+
if (Op1.getValueType().getVectorElementType() != AccElemVT) {
12710+
// TODO: Split this into canonicalization rules
12711+
if (LHSOpcode == ISD::SIGN_EXTEND && RHSOpcode == ISD::SIGN_EXTEND &&
12712+
(N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA ||
12713+
N->getOpcode() == ISD::PARTIAL_REDUCE_SUMLA))
12714+
NewOpc = ISD::PARTIAL_REDUCE_SMLA;
12715+
else if (LHSOpcode == ISD::ZERO_EXTEND && RHSOpcode == ISD::ZERO_EXTEND &&
12716+
N->getOpcode() == ISD::PARTIAL_REDUCE_UMLA)
12717+
NewOpc = ISD::PARTIAL_REDUCE_UMLA;
12718+
else
12719+
return SDValue();
12720+
} else {
12721+
// TODO: Add canonicalization rule
12722+
if (LHSOpcode == ISD::SIGN_EXTEND && RHSOpcode == ISD::SIGN_EXTEND)
12723+
NewOpc = ISD::PARTIAL_REDUCE_SMLA;
12724+
else if (LHSOpcode == ISD::ZERO_EXTEND && RHSOpcode == ISD::ZERO_EXTEND)
12725+
NewOpc = ISD::PARTIAL_REDUCE_UMLA;
12726+
else if (LHSOpcode == ISD::SIGN_EXTEND && RHSOpcode == ISD::ZERO_EXTEND)
12727+
NewOpc = ISD::PARTIAL_REDUCE_SUMLA;
12728+
else
12729+
// TODO: Handle the swapped sumla case here
12730+
return SDValue();
12731+
}
12732+
return DAG.getNode(NewOpc, DL, N->getValueType(0), Acc, LHSExtOp, RHSExtOp);
1271412733
}
1271512734

1271612735
// partial.reduce.umla(acc, zext(op), splat(1))
1271712736
// -> partial.reduce.umla(acc, op, splat(trunc(1)))
1271812737
// partial.reduce.smla(acc, sext(op), splat(1))
1271912738
// -> partial.reduce.smla(acc, op, splat(trunc(1)))
12739+
// partial.reduce.sumla(acc, sext(op), splat(1))
12740+
// -> partial.reduce.smla(acc, op, splat(trunc(1)))
1272012741
SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
1272112742
SDLoc DL(N);
1272212743
SDValue Acc = N->getOperand(0);
@@ -12738,7 +12759,7 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
1273812759
return SDValue();
1273912760

1274012761
bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND;
12741-
bool NodeIsSigned = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
12762+
bool NodeIsSigned = N->getOpcode() != ISD::PARTIAL_REDUCE_UMLA;
1274212763
EVT AccElemVT = Acc.getValueType().getVectorElementType();
1274312764
if (Op1IsSigned != NodeIsSigned &&
1274412765
Op1.getValueType().getVectorElementType() != AccElemVT)

llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
166166

167167
case ISD::PARTIAL_REDUCE_UMLA:
168168
case ISD::PARTIAL_REDUCE_SMLA:
169+
case ISD::PARTIAL_REDUCE_SUMLA:
169170
Res = PromoteIntRes_PARTIAL_REDUCE_MLA(N);
170171
break;
171172

@@ -2090,6 +2091,7 @@ bool DAGTypeLegalizer::PromoteIntegerOperand(SDNode *N, unsigned OpNo) {
20902091
break;
20912092
case ISD::PARTIAL_REDUCE_UMLA:
20922093
case ISD::PARTIAL_REDUCE_SMLA:
2094+
case ISD::PARTIAL_REDUCE_SUMLA:
20932095
Res = PromoteIntOp_PARTIAL_REDUCE_MLA(N);
20942096
break;
20952097
}
@@ -2876,12 +2878,21 @@ SDValue DAGTypeLegalizer::PromoteIntOp_VECTOR_FIND_LAST_ACTIVE(SDNode *N,
28762878

28772879
SDValue DAGTypeLegalizer::PromoteIntOp_PARTIAL_REDUCE_MLA(SDNode *N) {
28782880
SmallVector<SDValue, 1> NewOps(N->ops());
2879-
if (N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA) {
2881+
switch (N->getOpcode()) {
2882+
case ISD::PARTIAL_REDUCE_SMLA:
28802883
NewOps[1] = SExtPromotedInteger(N->getOperand(1));
28812884
NewOps[2] = SExtPromotedInteger(N->getOperand(2));
2882-
} else {
2885+
break;
2886+
case ISD::PARTIAL_REDUCE_UMLA:
28832887
NewOps[1] = ZExtPromotedInteger(N->getOperand(1));
28842888
NewOps[2] = ZExtPromotedInteger(N->getOperand(2));
2889+
break;
2890+
case ISD::PARTIAL_REDUCE_SUMLA:
2891+
NewOps[1] = SExtPromotedInteger(N->getOperand(1));
2892+
NewOps[2] = ZExtPromotedInteger(N->getOperand(2));
2893+
break;
2894+
default:
2895+
llvm_unreachable("unexpected opcode");
28852896
}
28862897
return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0);
28872898
}

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,7 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
530530
}
531531
case ISD::PARTIAL_REDUCE_UMLA:
532532
case ISD::PARTIAL_REDUCE_SMLA:
533+
case ISD::PARTIAL_REDUCE_SUMLA:
533534
Action = TLI.getPartialReduceMLAAction(Node->getValueType(0),
534535
Node->getOperand(1).getValueType());
535536
break;
@@ -1210,6 +1211,7 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
12101211
return;
12111212
case ISD::PARTIAL_REDUCE_UMLA:
12121213
case ISD::PARTIAL_REDUCE_SMLA:
1214+
case ISD::PARTIAL_REDUCE_SUMLA:
12131215
Results.push_back(TLI.expandPartialReduceMLA(Node, DAG));
12141216
return;
12151217
case ISD::VECREDUCE_SEQ_FADD:

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1387,6 +1387,7 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
13871387
break;
13881388
case ISD::PARTIAL_REDUCE_UMLA:
13891389
case ISD::PARTIAL_REDUCE_SMLA:
1390+
case ISD::PARTIAL_REDUCE_SUMLA:
13901391
SplitVecRes_PARTIAL_REDUCE_MLA(N, Lo, Hi);
13911392
break;
13921393
}
@@ -3454,6 +3455,7 @@ bool DAGTypeLegalizer::SplitVectorOperand(SDNode *N, unsigned OpNo) {
34543455
break;
34553456
case ISD::PARTIAL_REDUCE_UMLA:
34563457
case ISD::PARTIAL_REDUCE_SMLA:
3458+
case ISD::PARTIAL_REDUCE_SUMLA:
34573459
Res = SplitVecOp_PARTIAL_REDUCE_MLA(N);
34583460
break;
34593461
}

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7967,7 +7967,8 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
79677967
break;
79687968
}
79697969
case ISD::PARTIAL_REDUCE_UMLA:
7970-
case ISD::PARTIAL_REDUCE_SMLA: {
7970+
case ISD::PARTIAL_REDUCE_SMLA:
7971+
case ISD::PARTIAL_REDUCE_SUMLA: {
79717972
[[maybe_unused]] EVT AccVT = N1.getValueType();
79727973
[[maybe_unused]] EVT Input1VT = N2.getValueType();
79737974
[[maybe_unused]] EVT Input2VT = N3.getValueType();

llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -584,6 +584,8 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
584584
return "partial_reduce_umla";
585585
case ISD::PARTIAL_REDUCE_SMLA:
586586
return "partial_reduce_smla";
587+
case ISD::PARTIAL_REDUCE_SUMLA:
588+
return "partial_reduce_sumla";
587589

588590
// Vector Predication
589591
#define BEGIN_REGISTER_VP_SDNODE(SDID, LEGALARG, NAME, ...) \

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11887,13 +11887,23 @@ SDValue TargetLowering::expandPartialReduceMLA(SDNode *N,
1188711887
EVT ExtMulOpVT =
1188811888
EVT::getVectorVT(*DAG.getContext(), AccVT.getVectorElementType(),
1188911889
MulOpVT.getVectorElementCount());
11890-
unsigned ExtOpc = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA
11891-
? ISD::SIGN_EXTEND
11892-
: ISD::ZERO_EXTEND;
11893-
1189411890
if (ExtMulOpVT != MulOpVT) {
11895-
MulLHS = DAG.getNode(ExtOpc, DL, ExtMulOpVT, MulLHS);
11896-
MulRHS = DAG.getNode(ExtOpc, DL, ExtMulOpVT, MulRHS);
11891+
switch (N->getOpcode()) {
11892+
case ISD::PARTIAL_REDUCE_SMLA:
11893+
MulLHS = DAG.getNode(ISD::SIGN_EXTEND, DL, ExtMulOpVT, MulLHS);
11894+
MulRHS = DAG.getNode(ISD::SIGN_EXTEND, DL, ExtMulOpVT, MulRHS);
11895+
break;
11896+
case ISD::PARTIAL_REDUCE_UMLA:
11897+
MulLHS = DAG.getNode(ISD::ZERO_EXTEND, DL, ExtMulOpVT, MulLHS);
11898+
MulRHS = DAG.getNode(ISD::ZERO_EXTEND, DL, ExtMulOpVT, MulRHS);
11899+
break;
11900+
case ISD::PARTIAL_REDUCE_SUMLA:
11901+
MulLHS = DAG.getNode(ISD::SIGN_EXTEND, DL, ExtMulOpVT, MulLHS);
11902+
MulRHS = DAG.getNode(ISD::ZERO_EXTEND, DL, ExtMulOpVT, MulRHS);
11903+
break;
11904+
default:
11905+
llvm_unreachable("unexpected opcode");
11906+
}
1189711907
}
1189811908
SDValue Input = MulLHS;
1189911909
APInt ConstantOne;

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1874,8 +1874,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
18741874
if (EnablePartialReduceNodes && Subtarget->isSVEorStreamingSVEAvailable()) {
18751875
// Mark known legal pairs as 'Legal' (these will expand to UDOT or SDOT).
18761876
// Other pairs will default to 'Expand'.
1877-
setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv8i16, Legal);
1878-
setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Legal);
1877+
setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv8i16, Custom);
1878+
setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Custom);
18791879

18801880
setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv16i8, Custom);
18811881
}
@@ -7745,6 +7745,7 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
77457745
return LowerVECTOR_HISTOGRAM(Op, DAG);
77467746
case ISD::PARTIAL_REDUCE_SMLA:
77477747
case ISD::PARTIAL_REDUCE_UMLA:
7748+
case ISD::PARTIAL_REDUCE_SUMLA:
77487749
return LowerPARTIAL_REDUCE_MLA(Op, DAG);
77497750
}
77507751
}
@@ -29532,13 +29533,24 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
2953229533
SDValue
2953329534
AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
2953429535
SelectionDAG &DAG) const {
29536+
// No support for sumla forms, let generic legalization handle them
29537+
if (Op->getOpcode() == ISD::PARTIAL_REDUCE_SUMLA)
29538+
return SDValue();
29539+
2953529540
SDLoc DL(Op);
2953629541

2953729542
SDValue Acc = Op.getOperand(0);
2953829543
SDValue LHS = Op.getOperand(1);
2953929544
SDValue RHS = Op.getOperand(2);
2954029545
EVT ResultVT = Op.getValueType();
29541-
assert(ResultVT == MVT::nxv2i64 && LHS.getValueType() == MVT::nxv16i8);
29546+
EVT OpVT = LHS.getValueType();
29547+
29548+
// These two are legal...
29549+
if ((ResultVT == MVT::nxv2i64 && OpVT == MVT::nxv8i16) ||
29550+
(ResultVT == MVT::nxv4i32 && OpVT == MVT::nxv16i8))
29551+
return Op;
29552+
29553+
assert(ResultVT == MVT::nxv2i64 && OpVT == MVT::nxv16i8);
2954229554

2954329555
SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, MVT::nxv4i32,
2954429556
DAG.getConstant(0, DL, MVT::nxv4i32), LHS, RHS);

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8240,6 +8240,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
82408240
return lowerADJUST_TRAMPOLINE(Op, DAG);
82418241
case ISD::PARTIAL_REDUCE_UMLA:
82428242
case ISD::PARTIAL_REDUCE_SMLA:
8243+
case ISD::PARTIAL_REDUCE_SUMLA:
82438244
return lowerPARTIAL_REDUCE_MLA(Op, DAG);
82448245
}
82458246
}
@@ -8391,8 +8392,20 @@ SDValue RISCVTargetLowering::lowerPARTIAL_REDUCE_MLA(SDValue Op,
83918392
SDValue B = Op.getOperand(2);
83928393
assert(A.getSimpleValueType() == B.getSimpleValueType() &&
83938394
A.getSimpleValueType().getVectorElementType() == MVT::i8);
8394-
bool IsSigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
8395-
unsigned Opc = IsSigned ? RISCVISD::VQDOT_VL : RISCVISD::VQDOTU_VL;
8395+
unsigned Opc;
8396+
switch (Op.getOpcode()) {
8397+
case ISD::PARTIAL_REDUCE_SMLA:
8398+
Opc = RISCVISD::VQDOT_VL;
8399+
break;
8400+
case ISD::PARTIAL_REDUCE_UMLA:
8401+
Opc = RISCVISD::VQDOTU_VL;
8402+
break;
8403+
case ISD::PARTIAL_REDUCE_SUMLA:
8404+
Opc = RISCVISD::VQDOTSU_VL;
8405+
break;
8406+
default:
8407+
llvm_unreachable("Unexpected opcode");
8408+
}
83968409
auto [Mask, VL] = getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
83978410
return DAG.getNode(Opc, DL, VT, {A, B, Accum, Mask, VL});
83988411
}

0 commit comments

Comments
 (0)