Skip to content

Commit 638c0ca

Browse files
Change splitting functions.
Adjust ISDOpcode description. Rename variables in expand function. Remove unnecessary assert statement.
1 parent d5719f9 commit 638c0ca

File tree

5 files changed: 112 additions and 28 deletions.

llvm/include/llvm/CodeGen/ISDOpcodes.h

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1456,6 +1456,9 @@ enum NodeType {
14561456
// element type of Accumulator before multiplying their results.
14571457
// This result is concatenated to the Accumulator, and this is then reduced,
14581458
// using addition, to the result type.
1459+
// The output is only expected to either be given to another partial reduction
1460+
// operation or an equivalent vector reduce operation, so the order in which
1461+
// the elements are reduced is deliberately not specified.
14591462
// Input1 and Input2 must be the same type. Accumulator and the output must be
14601463
// the same type.
14611464
// The number of elements in Input1 and Input2 must be a positive integer

llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -970,7 +970,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
970970
void SplitVecRes_VAARG(SDNode *N, SDValue &Lo, SDValue &Hi);
971971
void SplitVecRes_FP_TO_XINT_SAT(SDNode *N, SDValue &Lo, SDValue &Hi);
972972
void SplitVecRes_VP_REVERSE(SDNode *N, SDValue &Lo, SDValue &Hi);
973-
void SplitVecRes_PARTIAL_REDUCE_MLA(SDNode *N);
973+
void SplitVecRes_PARTIAL_REDUCE_MLA(SDNode *N, SDValue &Lo, SDValue &Hi);
974974

975975
// Vector Operand Splitting: <128 x ty> -> 2 x <64 x ty>.
976976
bool SplitVectorOperand(SDNode *N, unsigned OpNo);

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1375,7 +1375,7 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
13751375
break;
13761376
case ISD::PARTIAL_REDUCE_UMLA:
13771377
case ISD::PARTIAL_REDUCE_SMLA:
1378-
SplitVecRes_PARTIAL_REDUCE_MLA(N);
1378+
SplitVecRes_PARTIAL_REDUCE_MLA(N, Lo, Hi);
13791379
break;
13801380
}
13811381

@@ -3186,9 +3186,11 @@ void DAGTypeLegalizer::SplitVecRes_VP_REVERSE(SDNode *N, SDValue &Lo,
31863186
std::tie(Lo, Hi) = DAG.SplitVector(Load, DL);
31873187
}
31883188

3189-
void DAGTypeLegalizer::SplitVecRes_PARTIAL_REDUCE_MLA(SDNode *N) {
3190-
SDValue Res = TLI.expandPartialReduceMLA(N, DAG);
3191-
ReplaceValueWith(SDValue(N, 0), Res);
3189+
void DAGTypeLegalizer::SplitVecRes_PARTIAL_REDUCE_MLA(SDNode *N, SDValue &Lo,
3190+
SDValue &Hi) {
3191+
SDLoc DL(N);
3192+
SDValue Expanded = TLI.expandPartialReduceMLA(N, DAG);
3193+
std::tie(Lo, Hi) = DAG.SplitVector(Expanded, DL);
31923194
}
31933195

31943196
void DAGTypeLegalizer::SplitVecRes_VECTOR_DEINTERLEAVE(SDNode *N) {
@@ -4449,9 +4451,7 @@ SDValue DAGTypeLegalizer::SplitVecOp_VECTOR_HISTOGRAM(SDNode *N) {
44494451
}
44504452

44514453
SDValue DAGTypeLegalizer::SplitVecOp_PARTIAL_REDUCE_MLA(SDNode *N) {
4452-
SDValue Res = TLI.expandPartialReduceMLA(N, DAG);
4453-
ReplaceValueWith(SDValue(N, 0), Res);
4454-
return SDValue();
4454+
return TLI.expandPartialReduceMLA(N, DAG);
44554455
}
44564456

44574457
//===----------------------------------------------------------------------===//

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Lines changed: 15 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -11897,46 +11897,41 @@ SDValue TargetLowering::expandPartialReduceMLA(SDNode *N,
1189711897
SDValue Acc = N->getOperand(0);
1189811898
SDValue MulLHS = N->getOperand(1);
1189911899
SDValue MulRHS = N->getOperand(2);
11900-
EVT ReducedTy = Acc.getValueType();
11901-
EVT FullTy = MulLHS.getValueType();
11900+
EVT AccVT = Acc.getValueType();
11901+
EVT MulOpVT = MulLHS.getValueType();
1190211902

11903-
EVT NewVT =
11904-
EVT::getVectorVT(*DAG.getContext(), ReducedTy.getVectorElementType(),
11905-
FullTy.getVectorElementCount());
11903+
EVT ExtMulOpVT =
11904+
EVT::getVectorVT(*DAG.getContext(), AccVT.getVectorElementType(),
11905+
MulOpVT.getVectorElementCount());
1190611906
unsigned ExtOpc = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA
1190711907
? ISD::SIGN_EXTEND
1190811908
: ISD::ZERO_EXTEND;
11909-
EVT MulLHSVT = MulLHS.getValueType();
11910-
assert(MulLHSVT == MulRHS.getValueType() &&
11911-
"The second and third operands of a PARTIAL_REDUCE_MLA node must have "
11912-
"the same value type!");
11913-
EVT ExtVT = MulLHSVT.changeVectorElementType(
11914-
Acc.getValueType().getVectorElementType());
11915-
if (ExtVT != FullTy) {
11916-
MulLHS = DAG.getNode(ExtOpc, DL, ExtVT, MulLHS);
11917-
MulRHS = DAG.getNode(ExtOpc, DL, ExtVT, MulRHS);
11909+
11910+
if (ExtMulOpVT != MulOpVT) {
11911+
MulLHS = DAG.getNode(ExtOpc, DL, ExtMulOpVT, MulLHS);
11912+
MulRHS = DAG.getNode(ExtOpc, DL, ExtMulOpVT, MulRHS);
1191811913
}
1191911914
SDValue Input = MulLHS;
1192011915
APInt ConstantOne;
1192111916
if (!ISD::isConstantSplatVector(MulRHS.getNode(), ConstantOne) ||
1192211917
!ConstantOne.isOne())
11923-
Input = DAG.getNode(ISD::MUL, DL, NewVT, MulLHS, MulRHS);
11918+
Input = DAG.getNode(ISD::MUL, DL, ExtMulOpVT, MulLHS, MulRHS);
1192411919

11925-
unsigned Stride = ReducedTy.getVectorMinNumElements();
11926-
unsigned ScaleFactor = FullTy.getVectorMinNumElements() / Stride;
11920+
unsigned Stride = AccVT.getVectorMinNumElements();
11921+
unsigned ScaleFactor = MulOpVT.getVectorMinNumElements() / Stride;
1192711922

1192811923
// Collect all of the subvectors
1192911924
std::deque<SDValue> Subvectors = {Acc};
1193011925
for (unsigned I = 0; I < ScaleFactor; I++) {
1193111926
auto SourceIndex = DAG.getVectorIdxConstant(I * Stride, DL);
11932-
Subvectors.push_back(DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReducedTy,
11933-
{Input, SourceIndex}));
11927+
Subvectors.push_back(
11928+
DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, AccVT, {Input, SourceIndex}));
1193411929
}
1193511930

1193611931
// Flatten the subvector tree
1193711932
while (Subvectors.size() > 1) {
1193811933
Subvectors.push_back(
11939-
DAG.getNode(ISD::ADD, DL, ReducedTy, {Subvectors[0], Subvectors[1]}));
11934+
DAG.getNode(ISD::ADD, DL, AccVT, {Subvectors[0], Subvectors[1]}));
1194011935
Subvectors.pop_front();
1194111936
Subvectors.pop_front();
1194211937
}

llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll

Lines changed: 86 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1272,3 +1272,89 @@ entry:
12721272
%partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %mult)
12731273
ret <vscale x 2 x i64> %partial.reduce
12741274
}
1275+
1276+
define <vscale x 2 x i16> @udot_nxv8i8_promote (<vscale x 2 x i16> %acc, <vscale x 8 x i8> %a, <vscale x 8 x i8> %b){
1277+
; CHECK-LABEL: udot_nxv8i8_promote:
1278+
; CHECK: // %bb.0: // %entry
1279+
; CHECK-NEXT: and z1.h, z1.h, #0xff
1280+
; CHECK-NEXT: and z2.h, z2.h, #0xff
1281+
; CHECK-NEXT: mul z1.h, z1.h, z2.h
1282+
; CHECK-NEXT: uunpklo z2.s, z1.h
1283+
; CHECK-NEXT: uunpkhi z1.s, z1.h
1284+
; CHECK-NEXT: uunpklo z3.d, z2.s
1285+
; CHECK-NEXT: uunpklo z4.d, z1.s
1286+
; CHECK-NEXT: uunpkhi z2.d, z2.s
1287+
; CHECK-NEXT: uunpkhi z1.d, z1.s
1288+
; CHECK-NEXT: add z0.d, z0.d, z3.d
1289+
; CHECK-NEXT: add z2.d, z2.d, z4.d
1290+
; CHECK-NEXT: add z0.d, z1.d, z0.d
1291+
; CHECK-NEXT: add z0.d, z2.d, z0.d
1292+
; CHECK-NEXT: ret
1293+
;
1294+
; CHECK-NEWLOWERING-LABEL: udot_nxv8i8_promote:
1295+
; CHECK-NEWLOWERING: // %bb.0: // %entry
1296+
; CHECK-NEWLOWERING-NEXT: and z1.h, z1.h, #0xff
1297+
; CHECK-NEWLOWERING-NEXT: and z2.h, z2.h, #0xff
1298+
; CHECK-NEWLOWERING-NEXT: mul z1.h, z1.h, z2.h
1299+
; CHECK-NEWLOWERING-NEXT: uunpklo z2.s, z1.h
1300+
; CHECK-NEWLOWERING-NEXT: uunpkhi z1.s, z1.h
1301+
; CHECK-NEWLOWERING-NEXT: uunpklo z3.d, z2.s
1302+
; CHECK-NEWLOWERING-NEXT: uunpklo z4.d, z1.s
1303+
; CHECK-NEWLOWERING-NEXT: uunpkhi z2.d, z2.s
1304+
; CHECK-NEWLOWERING-NEXT: uunpkhi z1.d, z1.s
1305+
; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z3.d
1306+
; CHECK-NEWLOWERING-NEXT: add z2.d, z2.d, z4.d
1307+
; CHECK-NEWLOWERING-NEXT: add z0.d, z1.d, z0.d
1308+
; CHECK-NEWLOWERING-NEXT: add z0.d, z2.d, z0.d
1309+
; CHECK-NEWLOWERING-NEXT: ret
1310+
entry:
1311+
%a.wide = zext <vscale x 8 x i8> %a to <vscale x 8 x i16>
1312+
%b.wide = zext <vscale x 8 x i8> %b to <vscale x 8 x i16>
1313+
%mult = mul nuw nsw <vscale x 8 x i16> %a.wide, %b.wide
1314+
%partial.reduce = tail call <vscale x 2 x i16> @llvm.experimental.vector.partial.reduce.add.nxv2i16.nxv8i16(<vscale x 2 x i16> %acc, <vscale x 8 x i16> %mult)
1315+
ret <vscale x 2 x i16> %partial.reduce
1316+
}
1317+
1318+
define <vscale x 2 x i16> @sdot_nxv8i8_promote (<vscale x 2 x i16> %acc, <vscale x 8 x i8> %a, <vscale x 8 x i8> %b){
1319+
; CHECK-LABEL: sdot_nxv8i8_promote:
1320+
; CHECK: // %bb.0: // %entry
1321+
; CHECK-NEXT: ptrue p0.h
1322+
; CHECK-NEXT: sxtb z1.h, p0/m, z1.h
1323+
; CHECK-NEXT: sxtb z2.h, p0/m, z2.h
1324+
; CHECK-NEXT: mul z1.h, z1.h, z2.h
1325+
; CHECK-NEXT: uunpklo z2.s, z1.h
1326+
; CHECK-NEXT: uunpkhi z1.s, z1.h
1327+
; CHECK-NEXT: uunpklo z3.d, z2.s
1328+
; CHECK-NEXT: uunpklo z4.d, z1.s
1329+
; CHECK-NEXT: uunpkhi z2.d, z2.s
1330+
; CHECK-NEXT: uunpkhi z1.d, z1.s
1331+
; CHECK-NEXT: add z0.d, z0.d, z3.d
1332+
; CHECK-NEXT: add z2.d, z2.d, z4.d
1333+
; CHECK-NEXT: add z0.d, z1.d, z0.d
1334+
; CHECK-NEXT: add z0.d, z2.d, z0.d
1335+
; CHECK-NEXT: ret
1336+
;
1337+
; CHECK-NEWLOWERING-LABEL: sdot_nxv8i8_promote:
1338+
; CHECK-NEWLOWERING: // %bb.0: // %entry
1339+
; CHECK-NEWLOWERING-NEXT: ptrue p0.h
1340+
; CHECK-NEWLOWERING-NEXT: sxtb z1.h, p0/m, z1.h
1341+
; CHECK-NEWLOWERING-NEXT: sxtb z2.h, p0/m, z2.h
1342+
; CHECK-NEWLOWERING-NEXT: mul z1.h, z1.h, z2.h
1343+
; CHECK-NEWLOWERING-NEXT: uunpklo z2.s, z1.h
1344+
; CHECK-NEWLOWERING-NEXT: uunpkhi z1.s, z1.h
1345+
; CHECK-NEWLOWERING-NEXT: uunpklo z3.d, z2.s
1346+
; CHECK-NEWLOWERING-NEXT: uunpklo z4.d, z1.s
1347+
; CHECK-NEWLOWERING-NEXT: uunpkhi z2.d, z2.s
1348+
; CHECK-NEWLOWERING-NEXT: uunpkhi z1.d, z1.s
1349+
; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z3.d
1350+
; CHECK-NEWLOWERING-NEXT: add z2.d, z2.d, z4.d
1351+
; CHECK-NEWLOWERING-NEXT: add z0.d, z1.d, z0.d
1352+
; CHECK-NEWLOWERING-NEXT: add z0.d, z2.d, z0.d
1353+
; CHECK-NEWLOWERING-NEXT: ret
1354+
entry:
1355+
%a.wide = sext <vscale x 8 x i8> %a to <vscale x 8 x i16>
1356+
%b.wide = sext <vscale x 8 x i8> %b to <vscale x 8 x i16>
1357+
%mult = mul nuw nsw <vscale x 8 x i16> %a.wide, %b.wide
1358+
%partial.reduce = tail call <vscale x 2 x i16> @llvm.experimental.vector.partial.reduce.add.nxv2i16.nxv8i16(<vscale x 2 x i16> %acc, <vscale x 8 x i16> %mult)
1359+
ret <vscale x 2 x i16> %partial.reduce
1360+
}

0 commit comments

Comments
 (0)