[SDAG] Add partial_reduce_sumla node #141267
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1874,8 +1874,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, | |
if (EnablePartialReduceNodes && Subtarget->isSVEorStreamingSVEAvailable()) { | ||
// Mark known legal pairs as 'Legal' (these will expand to UDOT or SDOT). | ||
// Other pairs will default to 'Expand'. | ||
setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv8i16, Legal); | ||
setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Legal); | ||
setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv8i16, Custom); | ||
|
||
setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Custom); | ||
|
||
setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv16i8, Custom); | ||
} | ||
|
@@ -7745,6 +7745,7 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op, | |
return LowerVECTOR_HISTOGRAM(Op, DAG); | ||
case ISD::PARTIAL_REDUCE_SMLA: | ||
case ISD::PARTIAL_REDUCE_UMLA: | ||
case ISD::PARTIAL_REDUCE_SUMLA: | ||
return LowerPARTIAL_REDUCE_MLA(Op, DAG); | ||
} | ||
} | ||
|
@@ -29532,13 +29533,24 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op, | |
SDValue | ||
AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op, | ||
SelectionDAG &DAG) const { | ||
// No support for sumla forms, let generic legalization handle them | ||
if (Op->getOpcode() == ISD::PARTIAL_REDUCE_SUMLA) | ||
return SDValue(); | ||
|
||
SDLoc DL(Op); | ||
|
||
SDValue Acc = Op.getOperand(0); | ||
SDValue LHS = Op.getOperand(1); | ||
SDValue RHS = Op.getOperand(2); | ||
EVT ResultVT = Op.getValueType(); | ||
assert(ResultVT == MVT::nxv2i64 && LHS.getValueType() == MVT::nxv16i8); | ||
EVT OpVT = LHS.getValueType(); | ||
|
||
// These two are legal... | ||
if ((ResultVT == MVT::nxv2i64 && OpVT == MVT::nxv8i16) || | ||
(ResultVT == MVT::nxv4i32 && OpVT == MVT::nxv16i8)) | ||
return Op; | ||
|
||
assert(ResultVT == MVT::nxv2i64 && OpVT == MVT::nxv16i8); | ||
|
||
SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, MVT::nxv4i32, | ||
DAG.getConstant(0, DL, MVT::nxv4i32), LHS, RHS); | ||
|
Can you elaborate on what you mean by this TODO? I'm not sure I follow why we'd want to handle a `zext` as a `sext` in this case.
A zext nonneg is a zext for which the high bit is known to be zero, and thus is equivalent to a sext. We canonicalize such cases to zext nonneg. As such, handling zext nonneg would allow us to recognize more partial_reduce_smla cases which we'd currently miss. Note that for partial_reduce_sumla-enabled targets, this might not matter since we'd just choose an alternate instruction.
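To make the equivalence concrete, here is a minimal standalone sketch (not code from this PR) that models both extensions on a scalar 8-bit value widened to 32 bits. The helper names `zext8`, `sext8`, and `agreeOnAllNonNegative` are hypothetical and exist only for this illustration; the real transform works on vector operands in the DAG.

```cpp
#include <cstdint>

// Hypothetical helper: widen an 8-bit pattern to 32 bits by zero-extension,
// i.e. the upper bits of the result are always zero.
constexpr std::int32_t zext8(std::uint8_t bits) {
  return static_cast<std::int32_t>(bits);
}

// Hypothetical helper: widen the same 8-bit pattern by sign-extension,
// i.e. the upper bits of the result copy bit 7 of the input.
constexpr std::int32_t sext8(std::uint8_t bits) {
  return bits >= 0x80 ? static_cast<std::int32_t>(bits) - 256
                      : static_cast<std::int32_t>(bits);
}

// Checks every 8-bit pattern whose high bit is clear (0x00..0x7f), which is
// the condition the "nonneg" flag promises about the operand.
constexpr bool agreeOnAllNonNegative() {
  for (int v = 0; v < 0x80; ++v) {
    const std::uint8_t bits = static_cast<std::uint8_t>(v);
    if (zext8(bits) != sext8(bits))
      return false;
  }
  return true;
}

// With the high bit clear, both extensions produce the same value.
static_assert(agreeOnAllNonNegative(),
              "zext and sext agree when the high bit is zero");

// With the high bit set they differ: 0x80 widens to 128 versus -128.
static_assert(zext8(0x80) != sext8(0x80),
              "zext and sext differ when the high bit is set");
```

Because the two extensions agree on every input the flag permits, a combine that matches `sext` operands could, under that assumption, also accept a `zext` carrying the non-negative flag without changing the result.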