Skip to content

Commit 8666cb3

Browse files
RobinKastbergtoppercXinlong-WuScottEgerton
committed
[RISCV][LLD] Zcmt RISC-V extension in lld
This patch implements optimizations for the Zcmt extension in lld. A new TableJumpSection has been added. Before linker relaxation, each R_RISCV_CALL/R_RISCV_CALL_PLT relocation in each section is scanned and the symbol it targets is recorded. In finalizeContents, the recorded symbols are sorted in descending order by the number of jumps. The top symbols are compressed to table jumps during the relaxation process. This is a continuation of PR llvm#77884. Co-authored-by: Craig Topper <[email protected]> Co-authored-by: VincentWu <[email protected]> Co-authored-by: Scott Egerton <[email protected]>
1 parent 98ceb45 commit 8666cb3

File tree

12 files changed

+810
-0
lines changed

12 files changed

+810
-0
lines changed

lld/ELF/Arch/RISCV.cpp

Lines changed: 277 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ class RISCV final : public TargetInfo {
4343
void scanSectionImpl(InputSectionBase &, Relocs<RelTy>);
4444
template <class ELFT> void scanSection1(InputSectionBase &);
4545
void scanSection(InputSectionBase &) override;
46+
void writeTableJumpHeader(uint8_t *buf) const override;
47+
void writeTableJumpEntry(uint8_t *buf, const uint64_t symbol) const override;
4648
RelType getDynRel(RelType type) const override;
4749
RelExpr getRelExpr(RelType type, const Symbol &s,
4850
const uint8_t *loc) const override;
@@ -75,6 +77,7 @@ class RISCV final : public TargetInfo {
7577
#define INTERNAL_R_RISCV_GPREL_S 257
7678
#define INTERNAL_R_RISCV_X0REL_I 258
7779
#define INTERNAL_R_RISCV_X0REL_S 259
80+
#define INTERNAL_R_RISCV_TBJAL 260
7881

7982
const uint64_t dtpOffset = 0x800;
8083

@@ -274,6 +277,20 @@ void RISCV::writePlt(uint8_t *buf, const Symbol &sym,
274277
write32le(buf + 12, itype(ADDI, 0, 0, 0));
275278
}
276279

280+
// Stores the VA of ctx.mainPart->dynamic at `buf`, little-endian, as a
// 64-bit value when ctx.arg.is64 is set and as a 32-bit value otherwise.
void RISCV::writeTableJumpHeader(uint8_t *buf) const {
  const uint64_t dynamicVA = ctx.mainPart->dynamic->getVA();
  if (!ctx.arg.is64) {
    write32le(buf, dynamicVA);
    return;
  }
  write64le(buf, dynamicVA);
}
286+
287+
// Stores one jump-table slot at `buf`: `address` written little-endian,
// 8 bytes wide when ctx.arg.is64 is set and 4 bytes wide otherwise.
void RISCV::writeTableJumpEntry(uint8_t *buf, const uint64_t address) const {
  if (!ctx.arg.is64) {
    write32le(buf, address);
    return;
  }
  write64le(buf, address);
}
293+
277294
RelType RISCV::getDynRel(RelType type) const {
278295
return type == ctx.target->symbolicRel ? type
279296
: static_cast<RelType>(R_RISCV_NONE);
@@ -496,6 +513,9 @@ void RISCV::relocate(uint8_t *loc, const Relocation &rel, uint64_t val) const {
496513
return;
497514
}
498515

516+
case INTERNAL_R_RISCV_TBJAL:
517+
return;
518+
499519
case R_RISCV_ADD8:
500520
*loc += val;
501521
return;
@@ -745,6 +765,32 @@ void elf::initSymbolAnchors(Ctx &ctx) {
745765
}
746766
}
747767

768+
// Try to replace the jump/call covered by relocation `i` of `sec` with a
// 2-byte table jump: cm.jt when the link register is x0, cm.jalt when it is
// ra.
//
// On success the relocation type recorded in relaxAux becomes
// INTERNAL_R_RISCV_TBJAL, the 16-bit encoding is queued in relaxAux->writes,
// `remove` is set to the number of bytes saved (2 for R_RISCV_JAL, 6 for the
// auipc+jalr pair) and true is returned. Otherwise nothing is modified and
// false is returned. `loc` is currently unused.
static bool relaxTableJump(Ctx &ctx, const InputSection &sec, size_t i,
                           uint64_t loc, Relocation &r, uint32_t &remove) {
  if (!ctx.in.riscvTableJumpSection ||
      !ctx.in.riscvTableJumpSection->isFinalized)
    return false;

  // Table slots are looked up by symbol alone and are filled with the
  // symbol's own address, so a target of `sym + addend` with a non-zero
  // addend cannot be reached through the table. Leave such jumps alone.
  if (r.addend != 0)
    return false;

  // For R_RISCV_JAL the relocated instruction itself carries rd; for the
  // auipc+jalr pair it is the second instruction, 4 bytes further on.
  const bool isJal = r.type == R_RISCV_JAL;
  const uint32_t jalr = read32le(sec.contentMaybeDecompress().data() +
                                 r.offset + (isJal ? 0 : 4));
  const uint8_t rd = extractBits(jalr, 11, 7);

  int tblEntryIndex = -1;
  if (rd == X_X0)
    tblEntryIndex = ctx.in.riscvTableJumpSection->getCMJTEntryIndex(r.sym);
  else if (rd == X_RA)
    tblEntryIndex = ctx.in.riscvTableJumpSection->getCMJALTEntryIndex(r.sym);

  // Negative means the symbol did not make it into the finalized table (or
  // rd is neither x0 nor ra).
  if (tblEntryIndex < 0)
    return false;

  sec.relaxAux->relocTypes[i] = INTERNAL_R_RISCV_TBJAL;
  sec.relaxAux->writes.push_back(0xA002 |
                                 (tblEntryIndex << 2)); // cm.jt or cm.jalt
  remove = isJal ? 2 : 6;
  return true;
}
793+
748794
// Relax R_RISCV_CALL/R_RISCV_CALL_PLT auipc+jalr to c.j, c.jal, or jal.
749795
static void relaxCall(Ctx &ctx, const InputSection &sec, size_t i, uint64_t loc,
750796
Relocation &r, uint32_t &remove) {
@@ -767,6 +813,8 @@ static void relaxCall(Ctx &ctx, const InputSection &sec, size_t i, uint64_t loc,
767813
sec.relaxAux->relocTypes[i] = R_RISCV_RVC_JUMP;
768814
sec.relaxAux->writes.push_back(0x2001); // c.jal
769815
remove = 6;
816+
} else if (remove >= 6 && relaxTableJump(ctx, sec, i, loc, r, remove)) {
817+
// relaxTableJump sets remove
770818
} else if (remove >= 4 && isInt<21>(displace)) {
771819
sec.relaxAux->relocTypes[i] = R_RISCV_JAL;
772820
sec.relaxAux->writes.push_back(0x6f | rd << 7); // jal
@@ -890,6 +938,11 @@ static bool relax(Ctx &ctx, int pass, InputSection &sec) {
890938
relaxCall(ctx, sec, i, loc, r, remove);
891939
}
892940
break;
941+
case R_RISCV_JAL:
942+
if (relaxable(relocs, i)) {
943+
relaxTableJump(ctx, sec, i, loc, r, remove);
944+
}
945+
break;
893946
case R_RISCV_TPREL_HI20:
894947
case R_RISCV_TPREL_ADD:
895948
case R_RISCV_TPREL_LO12_I:
@@ -1144,6 +1197,12 @@ void RISCV::finalizeRelax(int passes) const {
11441197
case INTERNAL_R_RISCV_X0REL_I:
11451198
case INTERNAL_R_RISCV_X0REL_S:
11461199
break;
1200+
case INTERNAL_R_RISCV_TBJAL:
1201+
assert(ctx.arg.relaxTbljal);
1202+
assert((aux.writes[writesIdx] & 0xfc03) == 0xA002);
1203+
skip = 2;
1204+
write16le(p, aux.writes[writesIdx++]);
1205+
break;
11471206
case R_RISCV_RELAX:
11481207
// Used by relaxTlsLe to indicate the relocation is ignored.
11491208
break;
@@ -1155,6 +1214,8 @@ void RISCV::finalizeRelax(int passes) const {
11551214
skip = 4;
11561215
write32le(p, aux.writes[writesIdx++]);
11571216
break;
1217+
case R_RISCV_64:
1218+
break;
11581219
case R_RISCV_32:
11591220
// Used by relaxTlsLe to write a uint32_t then suppress the handling
11601221
// in relocateAlloc.
@@ -1533,3 +1594,219 @@ template <class ELFT> void RISCV::scanSection1(InputSectionBase &sec) {
15331594
// TargetInfo override: hands `sec` to the templated scanSection1 through the
// invokeELFT macro.
void RISCV::scanSection(InputSectionBase &sec) {
15341595
invokeELFT(scanSection1, sec);
15351596
}
1597+
1598+
// Creates the synthetic ".riscv.jvt" section: SHT_PROGBITS with
// SHF_ALLOC | SHF_EXECINSTR flags. The last argument is tableAlign
// (presumably the section alignment; confirm against SyntheticSection).
TableJumpSection::TableJumpSection(Ctx &ctx)
1599+
: SyntheticSection(ctx, ".riscv.jvt", SHT_PROGBITS,
1600+
SHF_ALLOC | SHF_EXECINSTR, tableAlign) {}
1601+
1602+
// Credits `csReduction` bytes of potential saving to `symbol` in the cm.jt
// candidate map, creating the entry on first use.
void TableJumpSection::addCMJTEntryCandidate(const Symbol *symbol,
                                             int csReduction) {
  CMJTEntryCandidates[symbol] += csReduction;
}
1606+
1607+
int TableJumpSection::getCMJTEntryIndex(const Symbol *symbol) {
1608+
uint32_t index = getIndex(symbol, maxCMJTEntrySize, finalizedCMJTEntries);
1609+
return index < finalizedCMJTEntries.size() ? (int)(startCMJTEntryIdx + index)
1610+
: -1;
1611+
}
1612+
1613+
// Credits `csReduction` bytes of potential saving to `symbol` in the cm.jalt
// candidate map, creating the entry on first use.
void TableJumpSection::addCMJALTEntryCandidate(const Symbol *symbol,
                                               int csReduction) {
  CMJALTEntryCandidates[symbol] += csReduction;
}
1617+
1618+
int TableJumpSection::getCMJALTEntryIndex(const Symbol *symbol) {
1619+
uint32_t index = getIndex(symbol, maxCMJALTEntrySize, finalizedCMJALTEntries);
1620+
return index < finalizedCMJALTEntries.size()
1621+
? (int)(startCMJALTEntryIdx + index)
1622+
: -1;
1623+
}
1624+
1625+
void TableJumpSection::addEntry(
1626+
const Symbol *symbol, llvm::DenseMap<const Symbol *, int> &entriesList,
1627+
int csReduction) {
1628+
entriesList[symbol] += csReduction;
1629+
}
1630+
1631+
// Linear search for `symbol` in the ordered `entriesList`. Returns its
// position, or entriesList.size() when it is absent. `maxSize` is only used
// to assert that the list has not outgrown its limit.
uint32_t TableJumpSection::getIndex(
    const Symbol *symbol, uint32_t maxSize,
    SmallVector<llvm::detail::DenseMapPair<const Symbol *, int>, 0>
        &entriesList) {
  assert(maxSize >= entriesList.size() &&
         "Finalized vector of entries exceeds maximum");

  const uint32_t count = entriesList.size();
  for (uint32_t pos = 0; pos != count; ++pos)
    if (entriesList[pos].first == symbol)
      return pos;
  return count;
}
1648+
1649+
void TableJumpSection::scanTableJumpEntries(const InputSection &sec) const {
1650+
for (auto [i, r] : llvm::enumerate(sec.relocations)) {
1651+
Defined *definedSymbol = dyn_cast<Defined>(r.sym);
1652+
if (!definedSymbol)
1653+
continue;
1654+
if (i + 1 == sec.relocs().size() ||
1655+
sec.relocs()[i + 1].type != R_RISCV_RELAX)
1656+
continue;
1657+
switch (r.type) {
1658+
case R_RISCV_JAL:
1659+
case R_RISCV_CALL:
1660+
case R_RISCV_CALL_PLT: {
1661+
const uint32_t jalr =
1662+
read32le(sec.contentMaybeDecompress().data() + r.offset +
1663+
(r.type == R_RISCV_JAL ? 0 : 4));
1664+
const uint8_t rd = extractBits(jalr, 11, 7);
1665+
1666+
int csReduction = 6;
1667+
if (sec.relaxAux->relocTypes[i] == R_RISCV_RVC_JUMP)
1668+
continue;
1669+
else if (sec.relaxAux->relocTypes[i] == R_RISCV_JAL)
1670+
csReduction = 2;
1671+
1672+
if (rd == 0)
1673+
ctx.in.riscvTableJumpSection->addCMJTEntryCandidate(r.sym, csReduction);
1674+
else if (rd == X_RA)
1675+
ctx.in.riscvTableJumpSection->addCMJALTEntryCandidate(r.sym,
1676+
csReduction);
1677+
}
1678+
}
1679+
}
1680+
}
1681+
1682+
void TableJumpSection::finalizeContents() {
1683+
if (isFinalized)
1684+
return;
1685+
isFinalized = true;
1686+
1687+
finalizedCMJTEntries = finalizeEntry(CMJTEntryCandidates, maxCMJTEntrySize);
1688+
CMJTEntryCandidates.clear();
1689+
int32_t CMJTSizeReduction = getSizeReduction();
1690+
finalizedCMJALTEntries =
1691+
finalizeEntry(CMJALTEntryCandidates, maxCMJALTEntrySize);
1692+
CMJALTEntryCandidates.clear();
1693+
1694+
if (!finalizedCMJALTEntries.empty() &&
1695+
getSizeReduction() < CMJTSizeReduction) {
1696+
// In memory, the cm.jt table occupies the first 0x20 entries.
1697+
// To be able to use the cm.jalt table which comes afterwards
1698+
// it is necessary to pad out the cm.jt table.
1699+
// Remove cm.jalt entries if the code reduction of cm.jalt is
1700+
// smaller than the size of the padding.
1701+
finalizedCMJALTEntries.clear();
1702+
}
1703+
// if table jump still got negative effect, give up.
1704+
if (getSizeReduction() <= 0) {
1705+
warn("Table Jump Relaxation didn't got any reduction for code size.");
1706+
finalizedCMJTEntries.clear();
1707+
}
1708+
}
1709+
1710+
// Sort the map in decreasing order of the amount of code reduction provided
1711+
// by the entries. Drop any entries that can't fit in the map from the tail
1712+
// end since they provide less code reduction. Drop any entries that cause
1713+
// an increase in code size (i.e. the reduction from instruction conversion
1714+
// does not cover the code size gain from adding a table entry).
1715+
SmallVector<llvm::detail::DenseMapPair<const Symbol *, int>, 0>
1716+
TableJumpSection::finalizeEntry(llvm::DenseMap<const Symbol *, int> EntryMap,
1717+
uint32_t maxSize) {
1718+
auto cmp = [](const llvm::detail::DenseMapPair<const Symbol *, int> &p1,
1719+
const llvm::detail::DenseMapPair<const Symbol *, int> &p2) {
1720+
return p1.second > p2.second;
1721+
};
1722+
1723+
SmallVector<llvm::detail::DenseMapPair<const Symbol *, int>, 0>
1724+
tempEntryVector;
1725+
std::copy(EntryMap.begin(), EntryMap.end(),
1726+
std::back_inserter(tempEntryVector));
1727+
std::sort(tempEntryVector.begin(), tempEntryVector.end(), cmp);
1728+
1729+
auto finalizedVector = tempEntryVector;
1730+
1731+
finalizedVector.resize(maxSize);
1732+
1733+
// Drop any items that have a negative effect (i.e. increase code size).
1734+
while (!finalizedVector.empty()) {
1735+
if (finalizedVector.rbegin()->second < ctx.arg.wordsize)
1736+
finalizedVector.pop_back();
1737+
else
1738+
break;
1739+
}
1740+
return finalizedVector;
1741+
}
1742+
1743+
// Size of the table in bytes. When any cm.jalt entry exists the table runs
// from slot 0 through the last cm.jalt slot; otherwise it ends after the
// last cm.jt slot. Before finalization the candidate maps are counted, after
// it the finalized lists.
size_t TableJumpSection::getSize() const {
  if (!isFinalized) {
    if (CMJALTEntryCandidates.empty())
      return (startCMJTEntryIdx + CMJTEntryCandidates.size()) *
             ctx.arg.wordsize;
    return (startCMJALTEntryIdx + CMJALTEntryCandidates.size()) *
           ctx.arg.wordsize;
  }

  if (finalizedCMJALTEntries.empty())
    return (startCMJTEntryIdx + finalizedCMJTEntries.size()) *
           ctx.arg.wordsize;
  return (startCMJALTEntryIdx + finalizedCMJALTEntries.size()) *
         ctx.arg.wordsize;
}
1756+
1757+
// Net change in output size, in bytes, from using the table: the savings
// recorded for every finalized cm.jt and cm.jalt entry minus the size of the
// table itself (getSize()). A result <= 0 means the table does not pay off.
int32_t TableJumpSection::getSizeReduction() {
  // Start from the cost of the table, then add back what each entry saves.
  int32_t netSaving = -getSize();
  for (const auto &jt : finalizedCMJTEntries)
    netSaving += jt.second;
  for (const auto &jalt : finalizedCMJALTEntries)
    netSaving += jalt.second;
  return netSaving;
}
1774+
1775+
void TableJumpSection::writeTo(uint8_t *buf) {
1776+
if (getSizeReduction() <= 0)
1777+
return;
1778+
ctx.target->writeTableJumpHeader(buf);
1779+
writeEntries(buf + startCMJTEntryIdx * ctx.arg.wordsize,
1780+
finalizedCMJTEntries);
1781+
if (finalizedCMJALTEntries.size() > 0) {
1782+
padWords(buf + ((startCMJTEntryIdx + finalizedCMJTEntries.size()) *
1783+
ctx.arg.wordsize),
1784+
startCMJALTEntryIdx);
1785+
writeEntries(buf + (startCMJALTEntryIdx * ctx.arg.wordsize),
1786+
finalizedCMJALTEntries);
1787+
}
1788+
}
1789+
1790+
void TableJumpSection::padWords(uint8_t *buf, const uint8_t maxWordCount) {
1791+
for (size_t i = 0; i < maxWordCount; ++i) {
1792+
if (ctx.arg.is64)
1793+
write64le(buf + i, 0);
1794+
else
1795+
write32le(buf + i, 0);
1796+
}
1797+
}
1798+
1799+
// Writes the address of every symbol in `entriesList` into consecutive
// wordsize slots starting at `buf`. Slot k always belongs to entry k, because
// that is the position getIndex() reports and the relaxed instructions
// encode. An entry whose symbol is not defined is therefore left unwritten
// but still consumes its slot.
void TableJumpSection::writeEntries(
    uint8_t *buf,
    SmallVector<llvm::detail::DenseMapPair<const Symbol *, int>, 0>
        &entriesList) {
  for (const auto &entry : entriesList) {
    assert(entry.second > 0);
    if (entry.first->isDefined())
      ctx.target->writeTableJumpEntry(buf, entry.first->getVA(ctx, 0));
    // Advance unconditionally so later entries stay on their own slots.
    buf += ctx.arg.wordsize;
  }
}

lld/ELF/Config.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ class MipsGotSection;
6767
class MipsRldMapSection;
6868
class PPC32Got2Section;
6969
class PPC64LongBranchTargetSection;
70+
class TableJumpSection;
7071
class PltSection;
7172
class RelocationBaseSection;
7273
class RelroPaddingSection;
@@ -370,6 +371,7 @@ struct Config {
370371
bool resolveGroups;
371372
bool relrGlibc = false;
372373
bool relrPackDynRelocs = false;
374+
bool relaxTbljal;
373375
llvm::DenseSet<llvm::StringRef> saveTempsArgs;
374376
llvm::SmallVector<std::pair<llvm::GlobPattern, uint32_t>, 0> shuffleSections;
375377
bool singleRoRx;
@@ -582,6 +584,7 @@ struct InStruct {
582584
std::unique_ptr<RelroPaddingSection> relroPadding;
583585
std::unique_ptr<SyntheticSection> armCmseSGSection;
584586
std::unique_ptr<PPC64LongBranchTargetSection> ppc64LongBranchTarget;
587+
std::unique_ptr<TableJumpSection> riscvTableJumpSection;
585588
std::unique_ptr<SyntheticSection> mipsAbiFlags;
586589
std::unique_ptr<MipsGotSection> mipsGot;
587590
std::unique_ptr<SyntheticSection> mipsOptions;

lld/ELF/Driver.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1623,6 +1623,7 @@ static void readConfigs(Ctx &ctx, opt::InputArgList &args) {
16231623
}
16241624
ctx.arg.zCombreloc = getZFlag(args, "combreloc", "nocombreloc", true);
16251625
ctx.arg.zCopyreloc = getZFlag(args, "copyreloc", "nocopyreloc", true);
1626+
ctx.arg.relaxTbljal = args.hasArg(OPT_relax_tbljal);
16261627
ctx.arg.zForceBti = hasZOption(args, "force-bti");
16271628
ctx.arg.zForceIbt = hasZOption(args, "force-ibt");
16281629
ctx.arg.zZicfilp = getZZicfilp(ctx, args);

lld/ELF/Options.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,11 @@ defm use_android_relr_tags: BB<"use-android-relr-tags",
378378
"Use SHT_ANDROID_RELR / DT_ANDROID_RELR* tags instead of SHT_RELR / DT_RELR*",
379379
"Use SHT_RELR / DT_RELR* tags (default)">;
380380

381+
// Flag "relax-tbljal". readConfigs in Driver.cpp records whether it was
// given (args.hasArg(OPT_relax_tbljal)) in ctx.arg.relaxTbljal.
def relax_tbljal : FF<"relax-tbljal">,
382+
HelpText<"Enable conversion of call instructions to table "
383+
"jump instruction from the Zcmt extension for "
384+
"frequently called functions (RISC-V only)">;
385+
381386
def pic_veneer: F<"pic-veneer">,
382387
HelpText<"Always generate position independent thunks (veneers)">;
383388

0 commit comments

Comments
 (0)