@@ -214,22 +214,7 @@ class StateFactory {
214214 return true ;
215215 }
216216
217- static StateFactory *instance () {
218- alignas (StateFactory) static char storage[sizeof (StateFactory)]{};
219- static CallOnceFlag flag = callonce_impl::NOT_CALLED;
220- static bool valid = false ;
221- callonce (&flag, []() {
222- auto *factory = new (storage) StateFactory ();
223- valid = factory->prepare ();
224- if (valid)
225- atexit ([]() {
226- auto factory = reinterpret_cast <StateFactory *>(storage);
227- factory->~StateFactory ();
228- valid = false ;
229- });
230- });
231- return valid ? reinterpret_cast <StateFactory *>(storage) : nullptr ;
232- }
217+ static StateFactory *instance ();
233218
234219 void *acquire () {
235220 cpp::lock_guard guard{mutex};
@@ -263,32 +248,62 @@ class StateFactory {
263248 static size_t size_of_opaque_state () {
264249 return instance ()->params .size_of_opaque_state ;
265250 }
251+ static void postfork_cleanup ();
266252};
267253
254+ thread_local bool fork_inflight = false ;
255+ thread_local void *tls_state = nullptr ;
256+ alignas (StateFactory) static char factory_storage[sizeof (StateFactory)]{};
257+ static CallOnceFlag factory_onceflag = callonce_impl::NOT_CALLED;
258+ static bool factory_valid = false ;
259+
260+ StateFactory *StateFactory::instance () {
261+ callonce (&factory_onceflag, []() {
262+ auto *factory = new (factory_storage) StateFactory ();
263+ factory_valid = factory->prepare ();
264+ if (factory_valid)
265+ atexit ([]() {
266+ auto factory = reinterpret_cast <StateFactory *>(factory_storage);
267+ factory->~StateFactory ();
268+ factory_valid = false ;
269+ });
270+ });
271+ return factory_valid ? reinterpret_cast <StateFactory *>(factory_storage)
272+ : nullptr ;
273+ }
274+
275+ void StateFactory::postfork_cleanup () {
276+ if (factory_valid)
277+ reinterpret_cast <StateFactory *>(factory_storage)->~StateFactory ();
278+ factory_onceflag = callonce_impl::NOT_CALLED;
279+ factory_valid = false ;
280+ }
281+
268282void *acquire_tls () {
269- static thread_local void *state = nullptr ;
283+ if (fork_inflight)
284+ return nullptr ;
270285 // previous acquire failed, do not try again
271- if (state == MAP_FAILED)
286+ if (tls_state == MAP_FAILED)
272287 return nullptr ;
273288 // first acquirement
274- if (state == nullptr ) {
275- state = StateFactory::acquire_global ();
289+ if (tls_state == nullptr ) {
290+ tls_state = StateFactory::acquire_global ();
276291 // if still fails, remember the failure
277- if (state == nullptr ) {
278- state = MAP_FAILED;
292+ if (tls_state == nullptr ) {
293+ tls_state = MAP_FAILED;
279294 return nullptr ;
280295 } else {
281296 // register the release callback.
282297 if (__cxa_thread_atexit_impl (
283- [](void *s) { StateFactory::release_global (s); }, state ,
298+ [](void *s) { StateFactory::release_global (s); }, tls_state ,
284299 __dso_handle)) {
285- StateFactory::release_global (state );
286- state = MAP_FAILED;
300+ StateFactory::release_global (tls_state );
301+ tls_state = MAP_FAILED;
287302 return nullptr ;
288303 }
289304 }
290305 }
291- return state ;
306+ return tls_state ;
292307}
293308
294309template <class F > void random_fill_impl (F gen, void *buf, size_t size) {
@@ -331,4 +346,12 @@ void random_fill(void *buf, size_t size) {
331346 }
332347}
333348
349+ void random_prefork () { fork_inflight = true ; }
350+ void random_postfork_parent () { fork_inflight = false ; }
351+ void random_postfork_child () {
352+ tls_state = nullptr ;
353+ StateFactory::postfork_cleanup ();
354+ fork_inflight = false ;
355+ }
356+
334357} // namespace LIBC_NAMESPACE_DECL
0 commit comments