Skip to content

Commit bdce217

Browse files
authored
Move where MPK management happens for async (#10550)
Async suspension/resumption has to deal with saving state. Previously this was done in both `AsyncWasmCallState` and in `block_on`, meaning that there were two locations doing pretty similar things. The goal of this commit is to put all "restore the state of the world" logic in one location, so the `FiberFuture::resume` function now exclusively has all of the logic for saving/restoring state around execution of a fiber.
1 parent 2e08d56 commit bdce217

File tree

1 file changed

+98
-42
lines changed

1 file changed

+98
-42
lines changed

crates/wasmtime/src/runtime/store/async_.rs

Lines changed: 98 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,7 @@ impl StoreOpaque {
231231
let current_poll_cx = self.async_state.current_poll_cx.get();
232232
let current_suspend = self.async_state.current_suspend.get();
233233
let stack = self.allocate_fiber_stack()?;
234+
let track_pkey_context_switch = self.pkey.is_some();
234235

235236
let engine = self.engine().clone();
236237
let slot = &mut slot;
@@ -264,7 +265,14 @@ impl StoreOpaque {
264265
fiber: Some(fiber),
265266
current_poll_cx,
266267
engine,
267-
state: Some(crate::runtime::vm::AsyncWasmCallState::new()),
268+
fiber_resume_state: Some(FiberResumeState {
269+
tls: crate::runtime::vm::AsyncWasmCallState::new(),
270+
mpk: if track_pkey_context_switch {
271+
Some(ProtectionMask::all())
272+
} else {
273+
None
274+
},
275+
}),
268276
}
269277
};
270278
(&mut future).await?;
@@ -280,8 +288,8 @@ impl StoreOpaque {
280288
fiber: Option<wasmtime_fiber::Fiber<'a, Result<()>, (), Result<()>>>,
281289
current_poll_cx: *mut PollContext,
282290
engine: Engine,
283-
// See comments in `FiberFuture::resume` for this
284-
state: Option<crate::runtime::vm::AsyncWasmCallState>,
291+
// See comments in `FiberResumeState` for this
292+
fiber_resume_state: Option<FiberResumeState>,
285293
}
286294

287295
// This is surely the most dangerous `unsafe impl Send` in the entire
@@ -353,42 +361,36 @@ impl StoreOpaque {
353361
}
354362

355363
/// This is a helper function to call `resume` on the underlying
356-
/// fiber while correctly managing Wasmtime's thread-local data.
364+
/// fiber while correctly managing Wasmtime's state that the fiber
365+
/// may clobber.
357366
///
358-
/// Wasmtime's implementation of traps leverages thread-local data
359-
/// to get access to metadata during a signal. This thread-local
360-
/// data is a linked list of "activations" where the nodes of the
361-
/// linked list are stored on the stack. It would be invalid as a
362-
/// result to suspend a computation with the head of the linked list
363-
/// on this stack then move the stack to another thread and resume
364-
/// it. That means that a different thread would point to our stack
365-
/// and our thread doesn't point to our stack at all!
367+
/// ## Return Value
366368
///
367-
/// Basically management of TLS is required here one way or another.
368-
/// The strategy currently settled on is to manage the list of
369-
/// activations created by this fiber as a unit. When a fiber
370-
/// resumes the linked list is prepended to the current thread's
371-
/// list. When the fiber is suspended then the fiber's list of
372-
/// activations are all removed en-masse and saved within the fiber.
369+
/// * `Ok(Ok(()))` - the fiber successfully completed and yielded a
370+
/// successful result.
371+
/// * `Ok(Err(e))` - the fiber successfully completed and yielded
372+
/// an error as a result of computation.
373+
/// * `Err(())` - the fiber has not finished and it is suspended.
373374
fn resume(&mut self, val: Result<()>) -> Result<Result<()>, ()> {
374375
unsafe {
375-
let prev = self.state.take().unwrap().push();
376+
let prev = self.fiber_resume_state.take().unwrap().replace();
376377
let restore = Restore {
377378
fiber: self,
378-
state: Some(prev),
379+
prior_fiber_state: Some(prev),
379380
};
380381
return restore.fiber.fiber().resume(val);
381382
}
382383

383384
struct Restore<'a, 'b> {
384385
fiber: &'a mut FiberFuture<'b>,
385-
state: Option<crate::runtime::vm::PreviousAsyncWasmCallState>,
386+
prior_fiber_state: Option<PriorFiberResumeState>,
386387
}
387388

388389
impl Drop for Restore<'_, '_> {
389390
fn drop(&mut self) {
390391
unsafe {
391-
self.fiber.state = Some(self.state.take().unwrap().restore());
392+
self.fiber.fiber_resume_state =
393+
Some(self.prior_fiber_state.take().unwrap().replace());
392394
}
393395
}
394396
}
@@ -489,13 +491,19 @@ impl StoreOpaque {
489491
if !self.fiber().done() {
490492
let result = self.resume(Err(anyhow!("future dropped")));
491493
// This resumption with an error should always complete the
492-
// fiber. While it's technically possible for host code to catch
493-
// the trap and re-resume, we'd ideally like to signal that to
494-
// callers that they shouldn't be doing that.
494+
// fiber. While it's technically possible for host code to
495+
// catch the trap and re-resume, we'd ideally like to
496+
// signal to callers that they shouldn't be doing
497+
// that.
495498
debug_assert!(result.is_ok());
499+
500+
// Note that `result` is `Ok(r)` where `r` is either
501+
// `Ok(())` or `Err(e)`. If it's an error that's disposed of
502+
// here. It's expected to be a propagation of the `future
503+
// dropped` error created above.
496504
}
497505

498-
self.state.take().unwrap().assert_null();
506+
self.fiber_resume_state.take().unwrap().dispose();
499507

500508
unsafe {
501509
self.engine
@@ -504,6 +512,70 @@ impl StoreOpaque {
504512
}
505513
}
506514
}
515+
516+
/// State of the world when a fiber last suspended.
517+
///
518+
/// This structure represents global state that a fiber clobbers during
519+
/// its execution. For example TLS variables are updated, system
520+
/// resources like MPK masks are updated, etc. The purpose of this
521+
/// structure is to track all of this state and appropriately
522+
/// save/restore it around fiber suspension points.
523+
struct FiberResumeState {
524+
/// Saved list of `CallThreadState` activations that are stored on a
525+
/// fiber stack.
526+
///
527+
/// This is a linked list that references stack-stored nodes on the
528+
/// fiber stack that is currently suspended. The
529+
/// `AsyncWasmCallState` type documents this more thoroughly but the
530+
/// general gist is that when this fiber is resumed this linked
531+
/// list needs to be pushed on to the current thread's linked list
532+
/// of activations.
533+
tls: crate::runtime::vm::AsyncWasmCallState,
534+
535+
/// Saved MPK protection mask, if enabled.
536+
///
537+
/// When MPK is enabled then executing WebAssembly will modify the
538+
/// processor's current mask of addressable protection keys. This
539+
/// means that our current state may get clobbered when a fiber
540+
/// suspends. To ensure that `FiberFuture::resume` preserves context it
541+
/// will, when MPK is enabled, save the current mask when this
542+
/// function is called and then restore the mask when the function
543+
/// returns (aka the fiber suspends).
544+
mpk: Option<ProtectionMask>,
545+
}
546+
547+
impl FiberResumeState {
548+
unsafe fn replace(self) -> PriorFiberResumeState {
549+
let tls = self.tls.push();
550+
let mpk = swap_mpk_states(self.mpk);
551+
PriorFiberResumeState { tls, mpk }
552+
}
553+
554+
fn dispose(self) {
555+
self.tls.assert_null();
556+
}
557+
}
558+
559+
struct PriorFiberResumeState {
560+
tls: crate::runtime::vm::PreviousAsyncWasmCallState,
561+
mpk: Option<ProtectionMask>,
562+
}
563+
564+
impl PriorFiberResumeState {
565+
unsafe fn replace(self) -> FiberResumeState {
566+
let tls = self.tls.restore();
567+
let mpk = swap_mpk_states(self.mpk);
568+
FiberResumeState { tls, mpk }
569+
}
570+
}
571+
572+
fn swap_mpk_states(mask: Option<ProtectionMask>) -> Option<ProtectionMask> {
573+
mask.map(|mask| {
574+
let current = mpk::current_mask();
575+
mpk::allow(mask);
576+
current
577+
})
578+
}
507579
}
508580

509581
#[cfg(feature = "gc")]
@@ -585,7 +657,6 @@ impl StoreOpaque {
585657
Some(AsyncCx {
586658
current_suspend: self.async_state.current_suspend.get(),
587659
current_poll_cx: unsafe { &raw mut (*poll_cx_box_ptr).future_context },
588-
track_pkey_context_switch: self.pkey.is_some(),
589660
})
590661
}
591662

@@ -663,7 +734,6 @@ impl<T> StoreContextMut<'_, T> {
663734
pub struct AsyncCx {
664735
current_suspend: *mut *mut wasmtime_fiber::Suspend<Result<()>, (), Result<()>>,
665736
current_poll_cx: *mut *mut Context<'static>,
666-
track_pkey_context_switch: bool,
667737
}
668738

669739
impl AsyncCx {
@@ -725,21 +795,7 @@ impl AsyncCx {
725795
Poll::Pending => {}
726796
}
727797

728-
// In order to prevent this fiber's MPK state from being munged by
729-
// other fibers while it is suspended, we save and restore it once
730-
// execution resumes. Note that when MPK is not supported,
731-
// these are noops.
732-
let previous_mask = if self.track_pkey_context_switch {
733-
let previous_mask = mpk::current_mask();
734-
mpk::allow(ProtectionMask::all());
735-
previous_mask
736-
} else {
737-
ProtectionMask::all()
738-
};
739798
(*suspend).suspend(())?;
740-
if self.track_pkey_context_switch {
741-
mpk::allow(previous_mask);
742-
}
743799
}
744800
}
745801
}

0 commit comments

Comments
 (0)