Skip to content

Commit 7796031

Browse files
authored
SafeHeap: Avoid instrumenting functions directly called from the "start" (#4439)
1 parent c12de34 commit 7796031

File tree

3 files changed

+53
-23
lines changed

3 files changed

+53
-23
lines changed

src/passes/SafeHeap.cpp

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -63,30 +63,21 @@ static Name getStoreName(Store* curr) {
6363
}
6464

6565
struct AccessInstrumenter : public WalkerPass<PostWalker<AccessInstrumenter>> {
66-
// If the getSbrkPtr function is implemented in the wasm, we must not
67-
// instrument that, as it would lead to infinite recursion of it calling
68-
// SAFE_HEAP_LOAD that calls it and so forth.
69-
// As well as the getSbrkPtr function we also avoid instrumenting the
70-
// module start function. This is because this function is used in
71-
// shared memory builds to load the passive memory segments, which in
72-
// turn means that value of sbrk() is not available.
73-
Name getSbrkPtr;
66+
// A set of function that we should ignore (not instrument).
67+
std::set<Name> ignoreFunctions;
7468

7569
bool isFunctionParallel() override { return true; }
7670

7771
AccessInstrumenter* create() override {
78-
return new AccessInstrumenter(getSbrkPtr);
72+
return new AccessInstrumenter(ignoreFunctions);
7973
}
8074

81-
AccessInstrumenter(Name getSbrkPtr) : getSbrkPtr(getSbrkPtr) {}
75+
AccessInstrumenter(std::set<Name> ignoreFunctions)
76+
: ignoreFunctions(ignoreFunctions) {}
8277

8378
void visitLoad(Load* curr) {
84-
// As well as the getSbrkPtr function we also avoid insturmenting the
85-
// module start function. This is because this function is used in
86-
// shared memory builds to load the passive memory segments, which in
87-
// turn means that value of sbrk() is not available.
88-
if (getFunction()->name == getModule()->start ||
89-
getFunction()->name == getSbrkPtr || curr->type == Type::unreachable) {
79+
if (ignoreFunctions.count(getFunction()->name) != 0 ||
80+
curr->type == Type::unreachable) {
9081
return;
9182
}
9283
Builder builder(*getModule());
@@ -97,8 +88,8 @@ struct AccessInstrumenter : public WalkerPass<PostWalker<AccessInstrumenter>> {
9788
}
9889

9990
void visitStore(Store* curr) {
100-
if (getFunction()->name == getModule()->start ||
101-
getFunction()->name == getSbrkPtr || curr->type == Type::unreachable) {
91+
if (ignoreFunctions.count(getFunction()->name) != 0 ||
92+
curr->type == Type::unreachable) {
10293
return;
10394
}
10495
Builder builder(*getModule());
@@ -109,6 +100,12 @@ struct AccessInstrumenter : public WalkerPass<PostWalker<AccessInstrumenter>> {
109100
}
110101
};
111102

103+
struct FindDirectCallees : public WalkerPass<PostWalker<FindDirectCallees>> {
104+
public:
105+
void visitCall(Call* curr) { callees.insert(curr->target); }
106+
std::set<Name> callees;
107+
};
108+
112109
struct SafeHeap : public Pass {
113110
PassOptions options;
114111

@@ -117,12 +114,31 @@ struct SafeHeap : public Pass {
117114
// add imports
118115
addImports(module);
119116
// instrument loads and stores
120-
AccessInstrumenter(getSbrkPtr).run(runner, module);
117+
// We avoid instrumenting the module start function of any function
118+
// that it directly calls. This is because in some cases the linker
119+
// generates `__wasm_init_memory` (either as the start function or
120+
// a function directly called from it) and this function is used in shared
121+
// memory builds to load the passive memory segments, which in turn means
122+
// that value of sbrk() is not available until after it has run.
123+
std::set<Name> ignoreFunctions;
124+
if (module->start.is()) {
125+
// Note that this only finds directly called functions, not transitively
126+
// called ones. That is enough given the current LLVM output as start
127+
// will only contain very specific, linker-generated code
128+
// (__wasm_init_memory etc. as mentioned above).
129+
FindDirectCallees findDirectCallees;
130+
findDirectCallees.walkFunctionInModule(module->getFunction(module->start),
131+
module);
132+
ignoreFunctions = findDirectCallees.callees;
133+
ignoreFunctions.insert(module->start);
134+
}
135+
ignoreFunctions.insert(getSbrkPtr);
136+
AccessInstrumenter(ignoreFunctions).run(runner, module);
121137
// add helper checking funcs and imports
122138
addGlobals(module, module->features);
123139
}
124140

125-
Name dynamicTopPtr, getSbrkPtr, sbrk, segfault, alignfault;
141+
Name getSbrkPtr, dynamicTopPtr, sbrk, segfault, alignfault;
126142

127143
void addImports(Module* module) {
128144
ImportInfo info(*module);

test/passes/safe-heap_start-function.txt

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,16 @@
1313
(import "env" "segfault" (func $segfault))
1414
(import "env" "alignfault" (func $alignfault))
1515
(memory $0 1 1)
16-
(start $foo)
16+
(start $mystart)
17+
(func $mystart
18+
(i32.store
19+
(i32.load
20+
(i32.const 42)
21+
)
22+
(i32.const 43)
23+
)
24+
(call $foo)
25+
)
1726
(func $foo
1827
(i32.store
1928
(i32.load
Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
(module
22
(memory 1 1)
3-
(func $foo
3+
(func $mystart
44
;; should not be modified because its the start function
5+
(i32.store (i32.load (i32.const 42)) (i32.const 43))
6+
(call $foo)
7+
)
8+
(func $foo
9+
;; should not be modified because its called from the start function
510
(i32.store (i32.load (i32.const 1234)) (i32.const 5678))
611
)
712
(func $bar
813
(i32.store (i32.load (i32.const 1234)) (i32.const 5678))
914
)
10-
(start $foo)
15+
(start $mystart)
1116
)

0 commit comments

Comments
 (0)