Skip to content

Commit 011bd7a

Browse files
kalyazinroypat
authored andcommitted
test(uffd_utils): add handling for FaultRequest in secret freedom
There are two ways a UFFD handler receives a fault notification if Secret Fredom is enabled (which is inferred from 3 fds sent by Firecracker instead of 1): - a VMM- or KVM-triggered fault is delivered via a minor UFFD fault event. The handler is supposed to respond to it via memcpying the content of the page (if the page hasn't already been populated) followed by a UFFDIO_CONTINUE call. - a vCPU-triggered fault is delievered via a FaultRequest message on the UDS socket. The handler is supposed to reply with a pwrite64 call on the guest_memfd to populate the page followed by a FaultReply message on the UDS socket. In both cases, the handler also needs to clear the bit in the userfault bitmap at the corresponding offset in order to stop further fault notifications for the same page. UFFD handlers use the userfault bitmap for two purposes: - communicate to the kernel whether a fault at the corresponding guest_memfd offset will cause a VM exit - keep track of pages that have already been populated in order to avoid overwriting the content of the page that is already initialised. Signed-off-by: Nikita Kalyazin <[email protected]>
1 parent f0c0208 commit 011bd7a

File tree

3 files changed

+253
-18
lines changed

3 files changed

+253
-18
lines changed

src/firecracker/examples/uffd/fault_all_handler.rs

Lines changed: 64 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,19 @@
55
//! which loads the whole region from the backing memory file
66
//! when a page fault occurs.
77
8+
#![allow(clippy::cast_possible_truncation)]
9+
810
mod uffd_utils;
911

1012
use std::fs::File;
13+
use std::os::fd::AsRawFd;
1114
use std::os::unix::net::UnixListener;
1215

1316
use uffd_utils::{Runtime, UffdHandler};
1417
use utils::time::{ClockType, get_time_us};
1518

19+
use crate::uffd_utils::uffd_continue;
20+
1621
fn main() {
1722
let mut args = std::env::args();
1823
let uffd_sock_path = args.nth(1).expect("No socket path given");
@@ -37,19 +42,69 @@ fn main() {
3742
.expect("Failed to read uffd_msg")
3843
.expect("uffd_msg not ready");
3944

40-
match event {
41-
userfaultfd::Event::Pagefault { .. } => {
42-
let start = get_time_us(ClockType::Monotonic);
43-
for region in uffd_handler.mem_regions.clone() {
44-
uffd_handler.serve_pf(region.base_host_virt_addr as _, region.size);
45-
}
46-
let end = get_time_us(ClockType::Monotonic);
45+
if let userfaultfd::Event::Pagefault { addr, .. } = event {
46+
let bit =
47+
uffd_handler.addr_to_offset(addr.cast()) as usize / uffd_handler.page_size;
48+
49+
// If Secret Free, we know if this is the first fault based on the userfault
50+
// bitmap state. Otherwise, we assume that we will ever only receive a single fault
51+
// event via UFFD.
52+
let are_we_faulted_yet = uffd_handler
53+
.userfault_bitmap
54+
.as_mut()
55+
.is_some_and(|bitmap| !bitmap.is_bit_set(bit));
4756

48-
println!("Finished Faulting All: {}us", end - start);
57+
if are_we_faulted_yet {
58+
// TODO: we currently ignore the result as we may attempt to
59+
// populate the page that is already present as we may receive
60+
// multiple minor fault events per page.
61+
let _ = uffd_continue(
62+
uffd_handler.uffd.as_raw_fd(),
63+
addr as _,
64+
uffd_handler.page_size as u64,
65+
)
66+
.inspect_err(|err| println!("Error during uffdio_continue: {:?}", err));
67+
} else {
68+
fault_all(uffd_handler, addr);
4969
}
50-
_ => panic!("Unexpected event on userfaultfd"),
5170
}
5271
},
5372
|_uffd_handler: &mut UffdHandler, _offset: usize| {},
5473
);
5574
}
75+
76+
fn fault_all(uffd_handler: &mut UffdHandler, fault_addr: *mut libc::c_void) {
77+
let start = get_time_us(ClockType::Monotonic);
78+
for region in uffd_handler.mem_regions.clone() {
79+
match uffd_handler.guest_memfd {
80+
None => {
81+
uffd_handler.serve_pf(region.base_host_virt_addr as _, region.size);
82+
}
83+
Some(_) => {
84+
let written = uffd_handler.populate_via_write(region.offset as usize, region.size);
85+
86+
// This code is written under the assumption that the first fault triggered by
87+
// Firecracker is either due to an MSR write (on x86) or due to device restoration
88+
// reading from guest memory to check the virtio queues are sane (on
89+
// ARM). This will be reported via a UFFD minor fault which needs to
90+
// be handled via memcpy. Importantly, we get to the UFFD handler
91+
// with the actual guest_memfd page already faulted in, meaning pwrite will stop
92+
// once it gets to the offset of that page (e.g. written < region.size above).
93+
// Thus, to fault in everything, we now need to skip this one page, write the
94+
// remaining region, and then deal with the "gap" via uffd_handler.serve_pf().
95+
96+
if written < region.size - uffd_handler.page_size {
97+
let r = uffd_handler.populate_via_write(
98+
region.offset as usize + written + uffd_handler.page_size,
99+
region.size - written - uffd_handler.page_size,
100+
);
101+
assert_eq!(written + r, region.size - uffd_handler.page_size);
102+
}
103+
}
104+
}
105+
}
106+
uffd_handler.serve_pf(fault_addr.cast(), uffd_handler.page_size);
107+
let end = get_time_us(ClockType::Monotonic);
108+
109+
println!("Finished Faulting All: {}us", end - start);
110+
}

src/firecracker/examples/uffd/on_demand_handler.rs

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,18 @@
55
//! which loads the whole region from the backing memory file
66
//! when a page fault occurs.
77
8+
#![allow(clippy::cast_possible_truncation)]
9+
810
mod uffd_utils;
911

1012
use std::fs::File;
13+
use std::os::fd::AsRawFd;
1114
use std::os::unix::net::UnixListener;
1215

1316
use uffd_utils::{Runtime, UffdHandler};
1417

18+
use crate::uffd_utils::uffd_continue;
19+
1520
fn main() {
1621
let mut args = std::env::args();
1722
let uffd_sock_path = args.nth(1).expect("No socket path given");
@@ -90,7 +95,33 @@ fn main() {
9095
// event (if the balloon device is enabled).
9196
match event {
9297
userfaultfd::Event::Pagefault { addr, .. } => {
93-
if !uffd_handler.serve_pf(addr.cast(), uffd_handler.page_size) {
98+
let bit = uffd_handler.addr_to_offset(addr.cast()) as usize
99+
/ uffd_handler.page_size;
100+
101+
if uffd_handler.userfault_bitmap.is_some() {
102+
if uffd_handler
103+
.userfault_bitmap
104+
.as_mut()
105+
.unwrap()
106+
.is_bit_set(bit)
107+
{
108+
if !uffd_handler.serve_pf(addr.cast(), uffd_handler.page_size) {
109+
deferred_events.push(event);
110+
}
111+
} else {
112+
// TODO: we currently ignore the result as we may attempt to
113+
// populate the page that is already present as we may receive
114+
// multiple minor fault events per page.
115+
let _ = uffd_continue(
116+
uffd_handler.uffd.as_raw_fd(),
117+
addr as _,
118+
uffd_handler.page_size as u64,
119+
)
120+
.inspect_err(|err| {
121+
println!("uffdio_continue error: {:?}", err)
122+
});
123+
}
124+
} else if !uffd_handler.serve_pf(addr.cast(), uffd_handler.page_size) {
94125
deferred_events.push(event);
95126
}
96127
}
@@ -111,6 +142,17 @@ fn main() {
111142
}
112143
}
113144
},
114-
|_uffd_handler: &mut UffdHandler, _offset: usize| {},
145+
|uffd_handler: &mut UffdHandler, offset: usize| {
146+
let bytes_written = uffd_handler.populate_via_write(offset, uffd_handler.page_size);
147+
148+
if bytes_written == 0 {
149+
println!(
150+
"got a vcpu fault for an already populated page at offset {}",
151+
offset
152+
);
153+
} else {
154+
assert_eq!(bytes_written, uffd_handler.page_size);
155+
}
156+
},
115157
);
116158
}

src/firecracker/examples/uffd/uffd_utils.rs

Lines changed: 145 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
clippy::cast_sign_loss,
77
clippy::undocumented_unsafe_blocks,
88
clippy::ptr_as_ptr,
9+
clippy::cast_possible_wrap,
910
// Not everything is used by both binaries
1011
dead_code
1112
)]
@@ -17,6 +18,7 @@ use std::ffi::c_void;
1718
use std::fs::File;
1819
use std::io::{Read, Write};
1920
use std::num::NonZero;
21+
use std::os::fd::RawFd;
2022
use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd};
2123
use std::os::unix::net::UnixStream;
2224
use std::ptr;
@@ -26,10 +28,47 @@ use std::time::Duration;
2628
use serde::{Deserialize, Serialize};
2729
use serde_json::{Deserializer, StreamDeserializer};
2830
use userfaultfd::{Error, Event, Uffd};
31+
use vmm_sys_util::ioctl::ioctl_with_mut_ref;
32+
use vmm_sys_util::ioctl_iowr_nr;
2933
use vmm_sys_util::sock_ctrl_msg::ScmSocket;
3034

3135
use crate::uffd_utils::userfault_bitmap::UserfaultBitmap;
3236

37+
// TODO: remove when UFFDIO_CONTINUE for guest_memfd is available in the crate
38+
#[repr(C)]
39+
struct uffdio_continue {
40+
range: uffdio_range,
41+
mode: u64,
42+
mapped: u64,
43+
}
44+
45+
ioctl_iowr_nr!(UFFDIO_CONTINUE, 0xAA, 0x7, uffdio_continue);
46+
47+
#[repr(C)]
48+
struct uffdio_range {
49+
start: u64,
50+
len: u64,
51+
}
52+
53+
pub fn uffd_continue(uffd: RawFd, fault_addr: u64, len: u64) -> std::io::Result<()> {
54+
let mut cont = uffdio_continue {
55+
range: uffdio_range {
56+
start: fault_addr,
57+
len,
58+
},
59+
mode: 0, // Normal continuation mode
60+
mapped: 0,
61+
};
62+
63+
let ret = unsafe { ioctl_with_mut_ref(&uffd, UFFDIO_CONTINUE(), &mut cont) };
64+
65+
if ret == -1 {
66+
return Err(std::io::Error::last_os_error());
67+
}
68+
69+
Ok(())
70+
}
71+
3372
// This is the same with the one used in src/vmm.
3473
/// This describes the mapping between Firecracker base virtual address and offset in the
3574
/// buffer or file backend for a guest memory region. It is used to tell an external
@@ -122,7 +161,7 @@ pub struct UffdHandler {
122161
pub mem_regions: Vec<GuestRegionUffdMapping>,
123162
pub page_size: usize,
124163
backing_buffer: *const u8,
125-
uffd: Uffd,
164+
pub uffd: Uffd,
126165
removed_pages: HashSet<u64>,
127166
pub guest_memfd: Option<File>,
128167
pub guest_memfd_addr: Option<*mut u8>,
@@ -266,6 +305,20 @@ impl UffdHandler {
266305
}
267306
}
268307

308+
pub fn addr_to_offset(&self, addr: *mut u8) -> u64 {
309+
let addr = addr as u64;
310+
for region in &self.mem_regions {
311+
if region.contains(addr) {
312+
return addr - region.base_host_virt_addr + region.offset;
313+
}
314+
}
315+
316+
panic!(
317+
"Could not find addr: {:#x} within guest region mappings.",
318+
addr
319+
);
320+
}
321+
269322
pub fn serve_pf(&mut self, addr: *mut u8, len: usize) -> bool {
270323
// Find the start of the page that the current faulting address belongs to.
271324
let dst = (addr as usize & !(self.page_size - 1)) as *mut libc::c_void;
@@ -278,7 +331,7 @@ impl UffdHandler {
278331

279332
for region in self.mem_regions.iter() {
280333
if region.contains(fault_page_addr) {
281-
return self.populate_from_file(region, fault_page_addr, len);
334+
return self.populate_from_file(&region.clone(), fault_page_addr, len);
282335
}
283336
}
284337

@@ -292,12 +345,61 @@ impl UffdHandler {
292345
self.mem_regions.iter().map(|r| r.size).sum()
293346
}
294347

295-
fn populate_from_file(&self, region: &GuestRegionUffdMapping, dst: u64, len: usize) -> bool {
296-
let offset = dst - region.base_host_virt_addr;
297-
let src = self.backing_buffer as u64 + region.offset + offset;
348+
pub fn populate_via_write(&mut self, offset: usize, len: usize) -> usize {
349+
// man 2 write:
350+
//
351+
// On Linux, write() (and similar system calls) will transfer at most
352+
// 0x7ffff000 (2,147,479,552) bytes, returning the number of bytes
353+
// actually transferred. (This is true on both 32-bit and 64-bit
354+
// systems.)
355+
const MAX_WRITE_LEN: usize = 2_147_479_552;
356+
357+
assert!(
358+
offset.checked_add(len).unwrap() <= self.size(),
359+
"{} + {} >= {}",
360+
offset,
361+
len,
362+
self.size()
363+
);
364+
365+
let mut total_written = 0;
366+
367+
while total_written < len {
368+
let src = unsafe { self.backing_buffer.add(offset + total_written) };
369+
let len_to_write = (len - total_written).min(MAX_WRITE_LEN);
370+
let bytes_written = unsafe {
371+
libc::pwrite64(
372+
self.guest_memfd.as_ref().unwrap().as_raw_fd(),
373+
src.cast(),
374+
len_to_write,
375+
(offset + total_written) as libc::off64_t,
376+
)
377+
};
378+
379+
let bytes_written = match bytes_written {
380+
-1 if vmm_sys_util::errno::Error::last().errno() == libc::ENOSPC => 0,
381+
written @ 0.. => written as usize,
382+
_ => panic!("{:?}", std::io::Error::last_os_error()),
383+
};
384+
385+
self.userfault_bitmap
386+
.as_mut()
387+
.unwrap()
388+
.reset_addr_range(offset + total_written, bytes_written);
389+
390+
total_written += bytes_written;
391+
392+
if bytes_written != len_to_write {
393+
break;
394+
}
395+
}
396+
397+
total_written
398+
}
298399

400+
fn populate_via_uffdio_copy(&self, src: *const u8, dst: u64, len: usize) -> bool {
299401
unsafe {
300-
match self.uffd.copy(src as *const _, dst as *mut _, len, true) {
402+
match self.uffd.copy(src.cast(), dst as *mut _, len, true) {
301403
// Make sure the UFFD copied some bytes.
302404
Ok(value) => assert!(value > 0),
303405
// Catch EAGAIN errors, which occur when a `remove` event lands in the UFFD
@@ -322,6 +424,42 @@ impl UffdHandler {
322424
true
323425
}
324426

427+
fn populate_via_memcpy(&mut self, src: *const u8, dst: u64, offset: usize, len: usize) -> bool {
428+
let dst_memcpy = unsafe {
429+
self.guest_memfd_addr
430+
.expect("no guest_memfd addr")
431+
.add(offset)
432+
};
433+
434+
unsafe {
435+
std::ptr::copy_nonoverlapping(src, dst_memcpy, len);
436+
}
437+
438+
self.userfault_bitmap
439+
.as_mut()
440+
.unwrap()
441+
.reset_addr_range(offset, len);
442+
443+
uffd_continue(self.uffd.as_raw_fd(), dst, len as u64).expect("uffd_continue");
444+
445+
true
446+
}
447+
448+
fn populate_from_file(
449+
&mut self,
450+
region: &GuestRegionUffdMapping,
451+
dst: u64,
452+
len: usize,
453+
) -> bool {
454+
let offset = (region.offset + dst - region.base_host_virt_addr) as usize;
455+
let src = unsafe { self.backing_buffer.add(offset) };
456+
457+
match self.guest_memfd {
458+
Some(_) => self.populate_via_memcpy(src, dst, offset, len),
459+
None => self.populate_via_uffdio_copy(src, dst, len),
460+
}
461+
}
462+
325463
fn zero_out(&mut self, addr: u64) -> bool {
326464
match unsafe { self.uffd.zeropage(addr as *mut _, self.page_size, true) } {
327465
Ok(_) => true,
@@ -614,7 +752,7 @@ mod tests {
614752
let (stream, _) = listener.accept().expect("Cannot listen on UDS socket");
615753
// Update runtime with actual runtime
616754
let runtime = uninit_runtime.write(Runtime::new(stream, file));
617-
runtime.run(|_: &mut UffdHandler| {});
755+
runtime.run(|_: &mut UffdHandler| {}, |_: &mut UffdHandler, _: usize| {});
618756
});
619757

620758
// wait for runtime thread to initialize itself

0 commit comments

Comments
 (0)