Skip to content

Commit f634373

Browse files
authored
[Branch Hinting] Add binary support (#7572)
1 parent d84d376 commit f634373

File tree

4 files changed

+366
-8
lines changed

4 files changed

+366
-8
lines changed

src/wasm-binary.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1404,6 +1404,12 @@ class WasmBinaryWriter {
14041404
void trackExpressionEnd(Expression* curr, Function* func);
14051405
void trackExpressionDelimiter(Expression* curr, Function* func, size_t id);
14061406

1407+
// Writes code annotations into a buffer and returns it. We cannot write them
1408+
// directly into the output since we write function code first (to get the
1409+
// offsets for the annotations), and only then can write annotations, which we
1410+
// must then insert before the code (as the spec requires that).
1411+
std::optional<BufferWithRandomAccess> writeCodeAnnotations();
1412+
14071413
// helpers
14081414
void writeInlineString(std::string_view name);
14091415
void writeEscapedName(std::string_view name);
@@ -1667,6 +1673,15 @@ class WasmBinaryReader {
16671673
void readDylink(size_t payloadLen);
16681674
void readDylink0(size_t payloadLen);
16691675

1676+
// We read branch hints *after* the code section, even though they appear
1677+
// earlier. That is simpler for us as we note expression locations as we scan
1678+
// code, and then just need to match them up. To do this, we note the branch
1679+
// hint position and size in the first pass, and handle it later.
1680+
size_t branchHintsPos = 0;
1681+
size_t branchHintsLen = 0;
1682+
1683+
void readBranchHints(size_t payloadLen);
1684+
16701685
Index readMemoryAccess(Address& alignment, Address& offset);
16711686
std::tuple<Name, Address, Address> getMemarg();
16721687
MemoryOrder getMemoryOrder(bool isRMW = false);

src/wasm/wasm-binary.cpp

Lines changed: 211 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include "support/debug.h"
2929
#include "support/stdckdint.h"
3030
#include "support/string.h"
31+
#include "wasm-annotations.h"
3132
#include "wasm-binary.h"
3233
#include "wasm-debug.h"
3334
#include "wasm-limits.h"
@@ -474,6 +475,21 @@ void WasmBinaryWriter::writeFunctions() {
474475
}
475476
});
476477
finishSection(sectionStart);
478+
479+
// Code annotations must come before the code section (see comment on
480+
// writeCodeAnnotations).
481+
if (auto annotations = writeCodeAnnotations()) {
482+
// We need to move the code section and put the annotations before it.
483+
auto& annotationsBuffer = *annotations;
484+
auto oldSize = o.size();
485+
o.resize(oldSize + annotationsBuffer.size());
486+
487+
// |sectionStart| is the start of the contents of the section. Subtract 1 to
488+
// include the section code as well, so we move all of it.
489+
std::move_backward(&o[sectionStart - 1], &o[oldSize], o.end());
490+
std::copy(
491+
annotationsBuffer.begin(), annotationsBuffer.end(), &o[sectionStart - 1]);
492+
}
477493
}
478494

479495
void WasmBinaryWriter::writeStrings() {
@@ -1496,14 +1512,17 @@ void WasmBinaryWriter::trackExpressionStart(Expression* curr, Function* func) {
14961512
// binary locations tracked, then track it in the output as well. We also
14971513
// track locations of instructions that have code annotations, as their binary
14981514
// location goes in the custom section.
1499-
if (func && !func->expressionLocations.empty()) {
1515+
if (func && (!func->expressionLocations.empty() ||
1516+
func->codeAnnotations.count(curr))) {
15001517
binaryLocations.expressions[curr] =
15011518
BinaryLocations::Span{BinaryLocation(o.size()), 0};
15021519
binaryLocationTrackedExpressionsForFunc.push_back(curr);
15031520
}
15041521
}
15051522

15061523
void WasmBinaryWriter::trackExpressionEnd(Expression* curr, Function* func) {
1524+
// TODO: If we need to track the end of annotated code locations, we need to
1525+
// enable that here.
15071526
if (func && !func->expressionLocations.empty()) {
15081527
auto& span = binaryLocations.expressions.at(curr);
15091528
span.end = o.size();
@@ -1513,11 +1532,123 @@ void WasmBinaryWriter::trackExpressionEnd(Expression* curr, Function* func) {
15131532
void WasmBinaryWriter::trackExpressionDelimiter(Expression* curr,
15141533
Function* func,
15151534
size_t id) {
1535+
// TODO: If we need to track the delimiters of annotated code locations, we
1536+
// need to enable that here.
15161537
if (func && !func->expressionLocations.empty()) {
15171538
binaryLocations.delimiters[curr][id] = o.size();
15181539
}
15191540
}
15201541

1542+
std::optional<BufferWithRandomAccess> WasmBinaryWriter::writeCodeAnnotations() {
1543+
// Assemble the info for Branch Hinting: for each function, a vector of the
1544+
// hints.
1545+
struct ExprHint {
1546+
Expression* expr;
1547+
// The offset we will write in the custom section.
1548+
BinaryLocation offset;
1549+
Function::CodeAnnotation* hint;
1550+
};
1551+
1552+
struct FuncHints {
1553+
Name func;
1554+
std::vector<ExprHint> exprHints;
1555+
};
1556+
1557+
std::vector<FuncHints> funcHintsVec;
1558+
1559+
for (auto& func : wasm->functions) {
1560+
// Collect the Branch Hints for this function.
1561+
FuncHints funcHints;
1562+
1563+
// We compute the location of the function declaration area (where the
1564+
// locals are declared) the first time we need it.
1565+
BinaryLocation funcDeclarationsOffset = 0;
1566+
1567+
for (auto& [expr, annotation] : func->codeAnnotations) {
1568+
if (annotation.branchLikely) {
1569+
auto exprIter = binaryLocations.expressions.find(expr);
1570+
if (exprIter == binaryLocations.expressions.end()) {
1571+
// No expression exists for this annotation - perhaps optimizations
1572+
// removed it.
1573+
continue;
1574+
}
1575+
auto exprOffset = exprIter->second.start;
1576+
1577+
if (!funcDeclarationsOffset) {
1578+
auto funcIter = binaryLocations.functions.find(func.get());
1579+
assert(funcIter != binaryLocations.functions.end());
1580+
funcDeclarationsOffset = funcIter->second.declarations;
1581+
}
1582+
1583+
// Compute the offset: it should be relative to the start of the
1584+
// function locals (i.e. the function declarations).
1585+
auto offset = exprOffset - funcDeclarationsOffset;
1586+
1587+
funcHints.exprHints.push_back(ExprHint{expr, offset, &annotation});
1588+
}
1589+
}
1590+
1591+
if (funcHints.exprHints.empty()) {
1592+
continue;
1593+
}
1594+
1595+
// We found something. Finalize the data.
1596+
funcHints.func = func->name;
1597+
1598+
// Hints must be sorted by increasing binary offset.
1599+
std::sort(
1600+
funcHints.exprHints.begin(),
1601+
funcHints.exprHints.end(),
1602+
[](const ExprHint& a, const ExprHint& b) { return a.offset < b.offset; });
1603+
1604+
funcHintsVec.emplace_back(std::move(funcHints));
1605+
}
1606+
1607+
if (funcHintsVec.empty()) {
1608+
return {};
1609+
}
1610+
1611+
if (sourceMap) {
1612+
// TODO: This mode may not matter (when debugging, code annotations are an
1613+
// optimization that can be skipped), but at the moment source maps cause
1614+
// annotations to break.
1615+
Fatal() << "Annotations are not supported with source maps";
1616+
}
1617+
1618+
BufferWithRandomAccess buffer;
1619+
1620+
// We found data: emit the section.
1621+
buffer << uint8_t(BinaryConsts::Custom);
1622+
auto lebPos = buffer.writeU32LEBPlaceholder();
1623+
buffer.writeInlineString(Annotations::BranchHint.str);
1624+
1625+
buffer << U32LEB(funcHintsVec.size());
1626+
for (auto& funcHints : funcHintsVec) {
1627+
buffer << U32LEB(getFunctionIndex(funcHints.func));
1628+
1629+
buffer << U32LEB(funcHints.exprHints.size());
1630+
for (auto& exprHint : funcHints.exprHints) {
1631+
buffer << U32LEB(exprHint.offset);
1632+
1633+
// Hint size, always 1 for now.
1634+
buffer << U32LEB(1);
1635+
1636+
// We must only emit hints that are present.
1637+
assert(exprHint.hint->branchLikely);
1638+
1639+
// Hint contents: likely or not.
1640+
buffer << U32LEB(int(*exprHint.hint->branchLikely));
1641+
}
1642+
}
1643+
1644+
// Write the final size. We can ignore the return value, which is the number
1645+
// of bytes we shrank (if the LEB was smaller than the maximum size), as no
1646+
// value in this section cares.
1647+
buffer.emitRetroactiveSectionSizeLEB(lebPos);
1648+
1649+
return buffer;
1650+
}
1651+
15211652
void WasmBinaryWriter::writeData(const char* data, size_t size) {
15221653
for (size_t i = 0; i < size; i++) {
15231654
o << int8_t(data[i]);
@@ -1792,12 +1923,6 @@ WasmBinaryReader::WasmBinaryReader(Module& wasm,
17921923
}
17931924

17941925
void WasmBinaryReader::preScan() {
1795-
// TODO: Once we support code annotations here, we will need to always scan,
1796-
// but for now, DWARF is the only reason.
1797-
if (!DWARF) {
1798-
return;
1799-
}
1800-
18011926
assert(pos == 0);
18021927
getInt32(); // magic
18031928
getInt32(); // version
@@ -1813,12 +1938,25 @@ void WasmBinaryReader::preScan() {
18131938
auto oldPos = pos;
18141939
if (sectionCode == BinaryConsts::Section::Custom) {
18151940
auto sectionName = getInlineString();
1941+
1942+
// Code annotations require code locations.
1943+
// TODO: For Branch Hinting, we could note which functions require
1944+
// code locations, as an optimization.
1945+
if (sectionName == Annotations::BranchHint) {
1946+
needCodeLocations = true;
1947+
// Do not break, so we keep looking for DWARF.
1948+
}
1949+
18161950
// DWARF sections contain code offsets.
18171951
if (DWARF && Debug::isDWARFSection(sectionName)) {
18181952
needCodeLocations = true;
18191953
foundDWARF = true;
18201954
break;
18211955
}
1956+
1957+
// TODO: We could stop early if we see the Code section and DWARF is
1958+
// disabled, as BranchHint must appear first, but this seems to
1959+
// make almost no difference in practice.
18221960
}
18231961
pos = oldPos + payloadLen;
18241962
}
@@ -1933,6 +2071,12 @@ void WasmBinaryReader::read() {
19332071
}
19342072
}
19352073

2074+
// Go back and parse things we deferred.
2075+
if (branchHintsPos) {
2076+
pos = branchHintsPos;
2077+
readBranchHints(branchHintsLen);
2078+
}
2079+
19362080
validateBinary();
19372081
}
19382082

@@ -1953,6 +2097,10 @@ void WasmBinaryReader::readCustomSection(size_t payloadLen) {
19532097
readDylink(payloadLen);
19542098
} else if (sectionName.equals(BinaryConsts::CustomSections::Dylink0)) {
19552099
readDylink0(payloadLen);
2100+
} else if (sectionName == Annotations::BranchHint) {
2101+
// Only note the position and length, we read this later.
2102+
branchHintsPos = pos;
2103+
branchHintsLen = payloadLen;
19562104
} else {
19572105
// an unfamiliar custom section
19582106
if (sectionName.equals(BinaryConsts::CustomSections::Linking)) {
@@ -5100,6 +5248,62 @@ void WasmBinaryReader::readDylink0(size_t payloadLen) {
51005248
}
51015249
}
51025250

5251+
void WasmBinaryReader::readBranchHints(size_t payloadLen) {
5252+
auto sectionPos = pos;
5253+
5254+
auto numFuncs = getU32LEB();
5255+
for (Index i = 0; i < numFuncs; i++) {
5256+
auto funcIndex = getU32LEB();
5257+
if (funcIndex >= wasm.functions.size()) {
5258+
throwError("bad BranchHint function");
5259+
}
5260+
5261+
auto& func = wasm.functions[funcIndex];
5262+
5263+
// The encoded offsets we read below are relative to the start of the
5264+
// function's locals (the declarations).
5265+
auto funcLocalsOffset = func->funcLocation.declarations;
5266+
5267+
// We have a map of expressions to their locations. Invert that to get the
5268+
// map we will use below, from offsets to expressions.
5269+
std::unordered_map<BinaryLocation, Expression*> locationsMap;
5270+
5271+
for (auto& [expr, span] : func->expressionLocations) {
5272+
locationsMap[span.start] = expr;
5273+
}
5274+
5275+
auto numHints = getU32LEB();
5276+
for (Index hint = 0; hint < numHints; hint++) {
5277+
// To get the absolute offset, add the offset of the function's declarations.
5278+
auto relativeOffset = getU32LEB();
5279+
auto absoluteOffset = funcLocalsOffset + relativeOffset;
5280+
5281+
auto iter = locationsMap.find(absoluteOffset);
5282+
if (iter == locationsMap.end()) {
5283+
throwError("bad BranchHint offset");
5284+
}
5285+
auto* expr = iter->second;
5286+
5287+
auto size = getU32LEB();
5288+
if (size != 1) {
5289+
throwError("bad BranchHint size");
5290+
}
5291+
5292+
auto likely = getU32LEB();
5293+
if (likely != 0 && likely != 1) {
5294+
throwError("bad BranchHint value");
5295+
}
5296+
5297+
// Apply the valid hint.
5298+
func->codeAnnotations[expr].branchLikely = likely;
5299+
}
5300+
}
5301+
5302+
if (pos != sectionPos + payloadLen) {
5303+
throwError("bad BranchHint section size");
5304+
}
5305+
}
5306+
51035307
Index WasmBinaryReader::readMemoryAccess(Address& alignment, Address& offset) {
51045308
auto rawAlignment = getU32LEB();
51055309
bool hasMemIdx = false;
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
;; NOTE: Assertions have been generated by update_lit_checks.py and should not be edited.
2+
3+
;; RUN: wasm-opt %s --monomorphize --roundtrip -all -S -o - | filecheck %s
4+
5+
;; The callee, which we will monomorphize, has a branch hint. The hinted code
6+
;; will end up vanishing entirely, but not the function it is in, so we end up
7+
;; with an annotation without an instruction in the binary for it. We should
8+
;; ignore it and not error.
9+
(module
10+
;; CHECK: (func $callee (type $1) (param $0 i32)
11+
;; CHECK-NEXT: (nop)
12+
;; CHECK-NEXT: )
13+
(func $callee (param $0 i32)
14+
(block $block
15+
(@metadata.code.branch_hint "\00")
16+
(br_if $block
17+
(i32.const 0)
18+
)
19+
)
20+
)
21+
22+
;; CHECK: (func $caller (type $0)
23+
;; CHECK-NEXT: (call $callee_2)
24+
;; CHECK-NEXT: )
25+
(func $caller
26+
(call $callee
27+
(i32.const 0)
28+
)
29+
)
30+
)
31+

0 commit comments

Comments
 (0)