Skip to content

Commit d8b722a

Browse files
add fork hooks
1 parent a09f3f3 commit d8b722a

File tree

2 files changed

+52
-26
lines changed

2 files changed

+52
-26
lines changed

libc/src/__support/OSUtil/linux/random.cpp

Lines changed: 49 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -214,22 +214,7 @@ class StateFactory {
214214
return true;
215215
}
216216

217-
static StateFactory *instance() {
218-
alignas(StateFactory) static char storage[sizeof(StateFactory)]{};
219-
static CallOnceFlag flag = callonce_impl::NOT_CALLED;
220-
static bool valid = false;
221-
callonce(&flag, []() {
222-
auto *factory = new (storage) StateFactory();
223-
valid = factory->prepare();
224-
if (valid)
225-
atexit([]() {
226-
auto factory = reinterpret_cast<StateFactory *>(storage);
227-
factory->~StateFactory();
228-
valid = false;
229-
});
230-
});
231-
return valid ? reinterpret_cast<StateFactory *>(storage) : nullptr;
232-
}
217+
static StateFactory *instance();
233218

234219
void *acquire() {
235220
cpp::lock_guard guard{mutex};
@@ -263,32 +248,62 @@ class StateFactory {
263248
static size_t size_of_opaque_state() {
264249
return instance()->params.size_of_opaque_state;
265250
}
251+
static void postfork_cleanup();
266252
};
267253

254+
thread_local bool fork_inflight = false;
255+
thread_local void *tls_state = nullptr;
256+
alignas(StateFactory) static char factory_storage[sizeof(StateFactory)]{};
257+
static CallOnceFlag factory_onceflag = callonce_impl::NOT_CALLED;
258+
static bool factory_valid = false;
259+
260+
StateFactory *StateFactory::instance() {
261+
callonce(&factory_onceflag, []() {
262+
auto *factory = new (factory_storage) StateFactory();
263+
factory_valid = factory->prepare();
264+
if (factory_valid)
265+
atexit([]() {
266+
auto factory = reinterpret_cast<StateFactory *>(factory_storage);
267+
factory->~StateFactory();
268+
factory_valid = false;
269+
});
270+
});
271+
return factory_valid ? reinterpret_cast<StateFactory *>(factory_storage)
272+
: nullptr;
273+
}
274+
275+
void StateFactory::postfork_cleanup() {
276+
if (factory_valid)
277+
reinterpret_cast<StateFactory *>(factory_storage)->~StateFactory();
278+
factory_onceflag = callonce_impl::NOT_CALLED;
279+
factory_valid = false;
280+
}
281+
268282
void *acquire_tls() {
269-
static thread_local void *state = nullptr;
283+
if (fork_inflight)
284+
return nullptr;
270285
// previous acquire failed, do not try again
271-
if (state == MAP_FAILED)
286+
if (tls_state == MAP_FAILED)
272287
return nullptr;
273288
// first acquirement
274-
if (state == nullptr) {
275-
state = StateFactory::acquire_global();
289+
if (tls_state == nullptr) {
290+
tls_state = StateFactory::acquire_global();
276291
// if still fails, remember the failure
277-
if (state == nullptr) {
278-
state = MAP_FAILED;
292+
if (tls_state == nullptr) {
293+
tls_state = MAP_FAILED;
279294
return nullptr;
280295
} else {
281296
// register the release callback.
282297
if (__cxa_thread_atexit_impl(
283-
[](void *s) { StateFactory::release_global(s); }, state,
298+
[](void *s) { StateFactory::release_global(s); }, tls_state,
284299
__dso_handle)) {
285-
StateFactory::release_global(state);
286-
state = MAP_FAILED;
300+
StateFactory::release_global(tls_state);
301+
tls_state = MAP_FAILED;
287302
return nullptr;
288303
}
289304
}
290305
}
291-
return state;
306+
return tls_state;
292307
}
293308

294309
template <class F> void random_fill_impl(F gen, void *buf, size_t size) {
@@ -331,4 +346,12 @@ void random_fill(void *buf, size_t size) {
331346
}
332347
}
333348

349+
void random_prefork() { fork_inflight = true; }
350+
void random_postfork_parent() { fork_inflight = false; }
351+
void random_postfork_child() {
352+
tls_state = nullptr;
353+
StateFactory::postfork_cleanup();
354+
fork_inflight = false;
355+
}
356+
334357
} // namespace LIBC_NAMESPACE_DECL

libc/src/__support/OSUtil/linux/random.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,8 @@
1616

1717
namespace LIBC_NAMESPACE_DECL {
1818
void random_fill(void *buf, unsigned long size);
19+
void random_prefork();
20+
void random_postfork_parent();
21+
void random_postfork_child();
1922
} // namespace LIBC_NAMESPACE_DECL
2023
#endif // LLVM_LIBC_SRC___SUPPORT_RANDOMNESS_H

0 commit comments

Comments
 (0)