Skip to content

Commit 62390ab

Browse files
committed
[DAGCombiner] Add DAG combine for PARTIAL_REDUCE_MLA when no mul op
Adds a generic DAG combine for ISD::PARTIAL_REDUCE_UMLA and ISD::PARTIAL_REDUCE_SMLA that converts PARTIAL_REDUCE_*MLA(Acc, ZEXT(UnextOp1), Splat(1)) into PARTIAL_REDUCE_UMLA(Acc, UnextOp1, TRUNC(Splat(1))), and PARTIAL_REDUCE_*MLA(Acc, SEXT(UnextOp1), Splat(1)) into PARTIAL_REDUCE_SMLA(Acc, UnextOp1, TRUNC(Splat(1))).
1 parent 6a99d81 commit 62390ab

File tree

1 file changed

+53
-1
lines changed

1 file changed

+53
-1
lines changed

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -618,6 +618,8 @@ namespace {
618618
SDValue CombineConsecutiveLoads(SDNode *N, EVT VT);
619619
SDValue foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG,
620620
const TargetLowering &TLI);
621+
SDValue foldPartialReduceMLAMulOp(SDNode *N);
622+
SDValue foldPartialReduceMLANoMulOp(SDNode *N);
621623

622624
SDValue CombineExtLoad(SDNode *N);
623625
SDValue CombineZExtLogicopShiftLoad(SDNode *N);
@@ -12601,12 +12603,20 @@ SDValue DAGCombiner::visitMHISTOGRAM(SDNode *N) {
1260112603
return SDValue();
1260212604
}
1260312605

12606+
// Combine entry point for PARTIAL_REDUCE_*MLA nodes. The fold that looks
// through a multiply operand is attempted first; only if it produces nothing
// is the fold for a bare extend operand tried. An empty SDValue is returned
// when neither fold applies.
SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
  if (SDValue Folded = foldPartialReduceMLAMulOp(N))
    return Folded;

  // Last fold in the chain: its result (possibly empty) is the final answer.
  return foldPartialReduceMLANoMulOp(N);
}
12613+
1260412614
// partial_reduce_*mla(acc, mul(ext(a), ext(b)), splat(1))
1260512615
// -> partial_reduce_*mla(acc, a, b)
1260612616
//
1260712617
// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
1260812618
// -> partial_reduce_*mla(acc, x, C)
12609-
SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
12619+
SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
1261012620
SDLoc DL(N);
1261112621
auto *Context = DAG.getContext();
1261212622
SDValue Acc = N->getOperand(0);
@@ -12672,6 +12682,48 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
1267212682
RHSExtOp);
1267312683
}
1267412684

12685+
// Makes PARTIAL_REDUCE_*MLA(Acc, ZEXT(UnextOp1), Splat(1)) into
12686+
// PARTIAL_REDUCE_UMLA(Acc, Op, TRUNC(Splat(1)))
12687+
// Makes PARTIAL_REDUCE_*MLA(Acc, SEXT(UnextOp1), Splat(1)) into
12688+
// PARTIAL_REDUCE_SMLA(Acc, Op, TRUNC(Splat(1)))
12689+
SDValue DAGCombiner::foldPartialReduceMLANoMulOp(SDNode *N) {
12690+
SDLoc DL(N);
12691+
SDValue Acc = N->getOperand(0);
12692+
SDValue Op1 = N->getOperand(1);
12693+
SDValue Op2 = N->getOperand(2);
12694+
12695+
APInt ConstantOne;
12696+
if (!ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) ||
12697+
!ConstantOne.isOne())
12698+
return SDValue();
12699+
12700+
unsigned Op1Opcode = Op1.getOpcode();
12701+
if (!ISD::isExtOpcode(Op1Opcode))
12702+
return SDValue();
12703+
12704+
SDValue UnextOp1 = Op1.getOperand(0);
12705+
EVT UnextOp1VT = UnextOp1.getValueType();
12706+
12707+
if (!TLI.isPartialReduceMLALegalOrCustom(N->getValueType(0), UnextOp1VT))
12708+
return SDValue();
12709+
12710+
SDValue TruncOp2 = DAG.getNode(ISD::TRUNCATE, DL, UnextOp1VT, Op2);
12711+
12712+
bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND;
12713+
12714+
bool NodeIsSigned = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
12715+
EVT AccElemVT = Acc.getValueType().getVectorElementType();
12716+
if (Op1IsSigned != NodeIsSigned &&
12717+
(Op1.getValueType().getVectorElementType() != AccElemVT ||
12718+
Op2.getValueType().getVectorElementType() != AccElemVT))
12719+
return SDValue();
12720+
12721+
unsigned NewOpcode =
12722+
Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
12723+
return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, UnextOp1,
12724+
TruncOp2);
12725+
}
12726+
1267512727
SDValue DAGCombiner::visitVP_STRIDED_LOAD(SDNode *N) {
1267612728
auto *SLD = cast<VPStridedLoadSDNode>(N);
1267712729
EVT EltVT = SLD->getValueType(0).getVectorElementType();

0 commit comments

Comments
 (0)