Skip to content

Commit 22636ac

Browse files
committed
Adjust how usdot cases are lowered
1 parent d40773d commit 22636ac

File tree

2 files changed

+15
-13
lines changed

2 files changed

+15
-13
lines changed

llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -925,8 +925,18 @@ SDValue DAGTypeLegalizer::CreateStackStoreLoad(SDValue Op,
925925
bool DAGTypeLegalizer::CustomLowerNode(SDNode *N, EVT VT, bool LegalizeResult) {
926926
// See if the target wants to custom lower this node.
927927
unsigned Opcode = N->getOpcode();
928-
if (TLI.getOperationAction(Opcode, VT) != TargetLowering::Custom)
929-
return false;
928+
bool IsPRMLAOpcode =
929+
Opcode == ISD::PARTIAL_REDUCE_UMLA || Opcode == ISD::PARTIAL_REDUCE_SMLA;
930+
931+
if (IsPRMLAOpcode) {
932+
if (TLI.getPartialReduceMLAAction(N->getValueType(0),
933+
N->getOperand(1).getValueType()) !=
934+
TargetLowering::Custom)
935+
return false;
936+
} else {
937+
if (TLI.getOperationAction(Opcode, VT) != TargetLowering::Custom)
938+
return false;
939+
}
930940

931941
SmallVector<SDValue, 8> Results;
932942
if (LegalizeResult)

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1872,17 +1872,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
18721872
// 8to64
18731873
setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv16i8, Custom);
18741874

1875-
if (Subtarget->hasMatMulInt8()) {
1876-
// USDOT
1877-
setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv8i64, Custom);
1875+
// USDOT
1876+
if (Subtarget->hasMatMulInt8())
18781877
setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i32, Custom);
1879-
setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv4i64, Custom);
1880-
setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv8i32, Custom);
1881-
setPartialReduceMLAAction(MVT::nxv8i16, MVT::nxv16i16, Custom);
1882-
setPartialReduceMLAAction(MVT::nxv16i8, MVT::nxv32i8, Custom);
1883-
1884-
setOperationAction(ISD::PARTIAL_REDUCE_UMLA, MVT::nxv16i32, Custom);
1885-
}
18861878
}
18871879

18881880
// Handle operations that are only available in non-streaming SVE mode.
@@ -7755,8 +7747,8 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
77557747
return LowerFLDEXP(Op, DAG);
77567748
case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
77577749
return LowerVECTOR_HISTOGRAM(Op, DAG);
7758-
case ISD::PARTIAL_REDUCE_UMLA:
77597750
case ISD::PARTIAL_REDUCE_SMLA:
7751+
case ISD::PARTIAL_REDUCE_UMLA:
77607752
return LowerPARTIAL_REDUCE_MLA(Op, DAG);
77617753
}
77627754
}

0 commit comments

Comments
 (0)