@@ -63,30 +63,21 @@ static Name getStoreName(Store* curr) {
6363}
6464
6565struct AccessInstrumenter : public WalkerPass <PostWalker<AccessInstrumenter>> {
66- // If the getSbrkPtr function is implemented in the wasm, we must not
67- // instrument that, as it would lead to infinite recursion of it calling
68- // SAFE_HEAP_LOAD that calls it and so forth.
69- // As well as the getSbrkPtr function we also avoid instrumenting the
70- // module start function. This is because this function is used in
71- // shared memory builds to load the passive memory segments, which in
72- // turn means that value of sbrk() is not available.
73- Name getSbrkPtr;
66+ // A set of function that we should ignore (not instrument).
67+ std::set<Name> ignoreFunctions;
7468
7569 bool isFunctionParallel () override { return true ; }
7670
7771 AccessInstrumenter* create () override {
78- return new AccessInstrumenter (getSbrkPtr );
72+ return new AccessInstrumenter (ignoreFunctions );
7973 }
8074
81- AccessInstrumenter (Name getSbrkPtr) : getSbrkPtr(getSbrkPtr) {}
75+ AccessInstrumenter (std::set<Name> ignoreFunctions)
76+ : ignoreFunctions(ignoreFunctions) {}
8277
8378 void visitLoad (Load* curr) {
84- // As well as the getSbrkPtr function we also avoid insturmenting the
85- // module start function. This is because this function is used in
86- // shared memory builds to load the passive memory segments, which in
87- // turn means that value of sbrk() is not available.
88- if (getFunction ()->name == getModule ()->start ||
89- getFunction ()->name == getSbrkPtr || curr->type == Type::unreachable) {
79+ if (ignoreFunctions.count (getFunction ()->name ) != 0 ||
80+ curr->type == Type::unreachable) {
9081 return ;
9182 }
9283 Builder builder (*getModule ());
@@ -97,8 +88,8 @@ struct AccessInstrumenter : public WalkerPass<PostWalker<AccessInstrumenter>> {
9788 }
9889
9990 void visitStore (Store* curr) {
100- if (getFunction ()->name == getModule ()-> start ||
101- getFunction ()-> name == getSbrkPtr || curr->type == Type::unreachable) {
91+ if (ignoreFunctions. count ( getFunction ()->name ) != 0 ||
92+ curr->type == Type::unreachable) {
10293 return ;
10394 }
10495 Builder builder (*getModule ());
@@ -109,6 +100,12 @@ struct AccessInstrumenter : public WalkerPass<PostWalker<AccessInstrumenter>> {
109100 }
110101};
111102
103+ struct FindDirectCallees : public WalkerPass <PostWalker<FindDirectCallees>> {
104+ public:
105+ void visitCall (Call* curr) { callees.insert (curr->target ); }
106+ std::set<Name> callees;
107+ };
108+
112109struct SafeHeap : public Pass {
113110 PassOptions options;
114111
@@ -117,12 +114,31 @@ struct SafeHeap : public Pass {
117114 // add imports
118115 addImports (module );
119116 // instrument loads and stores
120- AccessInstrumenter (getSbrkPtr).run (runner, module );
117+ // We avoid instrumenting the module start function of any function
118+ // that it directly calls. This is because in some cases the linker
119+ // generates `__wasm_init_memory` (either as the start function or
120+ // a function directly called from it) and this function is used in shared
121+ // memory builds to load the passive memory segments, which in turn means
122+ // that value of sbrk() is not available until after it has run.
123+ std::set<Name> ignoreFunctions;
124+ if (module ->start .is ()) {
125+ // Note that this only finds directly called functions, not transitively
126+ // called ones. That is enough given the current LLVM output as start
127+ // will only contain very specific, linker-generated code
128+ // (__wasm_init_memory etc. as mentioned above).
129+ FindDirectCallees findDirectCallees;
130+ findDirectCallees.walkFunctionInModule (module ->getFunction (module ->start ),
131+ module );
132+ ignoreFunctions = findDirectCallees.callees ;
133+ ignoreFunctions.insert (module ->start );
134+ }
135+ ignoreFunctions.insert (getSbrkPtr);
136+ AccessInstrumenter (ignoreFunctions).run (runner, module );
121137 // add helper checking funcs and imports
122138 addGlobals (module , module ->features );
123139 }
124140
125- Name dynamicTopPtr, getSbrkPtr , sbrk, segfault, alignfault;
141+ Name getSbrkPtr, dynamicTopPtr , sbrk, segfault, alignfault;
126142
127143 void addImports (Module* module ) {
128144 ImportInfo info (*module );
0 commit comments