Skip to content

Commit 0be4495

Browse files
committed
AMDGPU: Improve getShuffleCost accuracy for 8- and 16-bit shuffles
These shuffles can always be implemented using v_perm_b32, and so this rewrites the analysis from the perspective of "how many v_perm_b32s does it take to assemble each register of the result?" The test changes in Transforms/SLPVectorizer/reduction.ll are reasonable: VI (gfx8) has native f16 math, but not packed math. commit-id:8b76e888
1 parent 6683adb commit 0be4495

File tree

6 files changed

+747
-666
lines changed

6 files changed

+747
-666
lines changed

llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp

Lines changed: 91 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -1241,46 +1241,108 @@ InstructionCost GCNTTIImpl::getShuffleCost(TTI::ShuffleKind Kind,
12411241
(ScalarSize == 16 || ScalarSize == 8)) {
12421242
// Larger vector widths may require additional instructions, but are
12431243
// typically cheaper than scalarized versions.
1244-
unsigned NumVectorElts = cast<FixedVectorType>(SrcTy)->getNumElements();
1245-
unsigned RequestedElts =
1246-
count_if(Mask, [](int MaskElt) { return MaskElt != -1; });
1247-
unsigned EltsPerReg = 32 / ScalarSize;
1248-
if (RequestedElts == 0)
1244+
//
1245+
// We assume that shuffling at a register granularity can be done for free.
1246+
// This is not true for vectors fed into memory instructions, but it is
1247+
// effectively true for all other shuffling. The emphasis of the logic here
1248+
// is to assist generic transform in cleaning up / canonicalizing those
1249+
// shuffles.
1250+
unsigned NumDstElts = cast<FixedVectorType>(DstTy)->getNumElements();
1251+
unsigned NumSrcElts = cast<FixedVectorType>(SrcTy)->getNumElements();
1252+
1253+
// With op_sel VOP3P instructions freely can access the low half or high
1254+
// half of a register, so any swizzle of two elements is free.
1255+
if (ST->hasVOP3PInsts() && ScalarSize == 16 && NumSrcElts == 2 &&
1256+
(Kind == TTI::SK_Broadcast || Kind == TTI::SK_Reverse ||
1257+
Kind == TTI::SK_PermuteSingleSrc))
12491258
return 0;
1259+
1260+
unsigned EltsPerReg = 32 / ScalarSize;
12501261
switch (Kind) {
12511262
case TTI::SK_Broadcast:
1263+
// A single v_perm_b32 can be re-used for all destination registers.
1264+
return 1;
12521265
case TTI::SK_Reverse:
1253-
case TTI::SK_PermuteSingleSrc: {
1254-
// With op_sel VOP3P instructions freely can access the low half or high
1255-
// half of a register, so any swizzle of two elements is free.
1256-
if (ST->hasVOP3PInsts() && ScalarSize == 16 && NumVectorElts == 2)
1257-
return 0;
1258-
unsigned NumPerms = alignTo(RequestedElts, EltsPerReg) / EltsPerReg;
1259-
// SK_Broadcast just reuses the same mask
1260-
unsigned NumPermMasks = Kind == TTI::SK_Broadcast ? 1 : NumPerms;
1261-
return NumPerms + NumPermMasks;
1262-
}
1266+
// One instruction per register.
1267+
return divideCeil(NumDstElts, EltsPerReg);
12631268
case TTI::SK_ExtractSubvector:
1269+
if (Index % EltsPerReg == 0)
1270+
return 0; // Shuffling at register granularity
1271+
return divideCeil(NumDstElts, EltsPerReg);
12641272
case TTI::SK_InsertSubvector: {
1265-
// Even aligned accesses are free
1266-
if (!(Index % 2))
1267-
return 0;
1268-
// Insert/extract subvectors only require shifts / extract code to get the
1269-
// relevant bits
1270-
return alignTo(RequestedElts, EltsPerReg) / EltsPerReg;
1273+
unsigned NumInsertElts = cast<FixedVectorType>(SubTp)->getNumElements();
1274+
unsigned EndIndex = Index + NumInsertElts;
1275+
unsigned BeginSubIdx = Index % EltsPerReg;
1276+
unsigned EndSubIdx = EndIndex % EltsPerReg;
1277+
unsigned Cost = 0;
1278+
1279+
if (BeginSubIdx != 0) {
1280+
// Need to shift the inserted vector into place. The cost is the number
1281+
// of destination registers overlapped by the inserted vector.
1282+
Cost = divideCeil(EndIndex, EltsPerReg) - (Index / EltsPerReg);
1283+
}
1284+
1285+
// If the last register overlap is partial, there may be three source
1286+
// registers feeding into it; that takes an extra instruction.
1287+
if (EndIndex < NumDstElts && BeginSubIdx < EndSubIdx)
1288+
Cost += 1;
1289+
1290+
return Cost;
12711291
}
1272-
case TTI::SK_PermuteTwoSrc:
1273-
case TTI::SK_Splice:
1274-
case TTI::SK_Select: {
1275-
unsigned NumPerms = alignTo(RequestedElts, EltsPerReg) / EltsPerReg;
1276-
// SK_Select just reuses the same mask
1277-
unsigned NumPermMasks = Kind == TTI::SK_Select ? 1 : NumPerms;
1278-
return NumPerms + NumPermMasks;
1292+
case TTI::SK_Splice: {
1293+
assert(NumDstElts == NumSrcElts);
1294+
// Determine the sub-region of the result vector that requires
1295+
// sub-register shuffles / mixing.
1296+
unsigned EltsFromLHS = NumSrcElts - Index;
1297+
bool LHSIsAligned = (Index % EltsPerReg) == 0;
1298+
bool RHSIsAligned = (EltsFromLHS % EltsPerReg) == 0;
1299+
if (LHSIsAligned && RHSIsAligned)
1300+
return 0;
1301+
if (LHSIsAligned && !RHSIsAligned)
1302+
return divideCeil(NumDstElts - EltsFromLHS, EltsPerReg);
1303+
if (!LHSIsAligned && RHSIsAligned)
1304+
return divideCeil(EltsFromLHS, EltsPerReg);
1305+
return divideCeil(NumDstElts, EltsPerReg);
12791306
}
1280-
12811307
default:
12821308
break;
12831309
}
1310+
1311+
if (!Mask.empty()) {
1312+
// Generically estimate the cost by assuming that each destination
1313+
// register is derived from sources via v_perm_b32 instructions if it
1314+
// can't be copied as-is.
1315+
//
1316+
// For each destination register, derive the cost of obtaining it based
1317+
// on the number of source registers that feed into it.
1318+
unsigned Cost = 0;
1319+
for (unsigned DstIdx = 0; DstIdx < Mask.size(); DstIdx += EltsPerReg) {
1320+
SmallVector<int, 4> Regs;
1321+
bool Aligned = true;
1322+
for (unsigned I = 0; I < EltsPerReg && DstIdx + I < Mask.size(); ++I) {
1323+
int SrcIdx = Mask[DstIdx + I];
1324+
if (SrcIdx == -1)
1325+
continue;
1326+
int Reg;
1327+
if (SrcIdx < (int)NumSrcElts) {
1328+
Reg = SrcIdx / EltsPerReg;
1329+
if (SrcIdx % EltsPerReg != I)
1330+
Aligned = false;
1331+
} else {
1332+
Reg = NumSrcElts + (SrcIdx - NumSrcElts) / EltsPerReg;
1333+
if ((SrcIdx - NumSrcElts) % EltsPerReg != I)
1334+
Aligned = false;
1335+
}
1336+
if (!llvm::is_contained(Regs, Reg))
1337+
Regs.push_back(Reg);
1338+
}
1339+
if (Regs.size() >= 2)
1340+
Cost += Regs.size() - 1;
1341+
else if (!Aligned)
1342+
Cost += 1;
1343+
}
1344+
return Cost;
1345+
}
12841346
}
12851347

12861348
return BaseT::getShuffleCost(Kind, DstTy, SrcTy, Mask, CostKind, Index,

0 commit comments

Comments
 (0)