Skip to content

Commit 74a703b

Browse files
authored
Add HostAlignedByteCount to enforce alignment at compile time (#9620)
* Add HostAlignedByteCount to enforce alignment at compile time As part of the work to allow mmaps to be backed by other implementations, I realized that we didn't have any way to track whether a particular usize is host-page-aligned at compile time. Add a `HostAlignedByteCount` which tracks that a particular usize is aligned to the host page size. This also does not expose safe unchecked arithmetic operations, to ensure that overflows always error out. With `HostAlignedByteCount`, a lot of runtime checks can go away thanks to the type-level assertion. In the interest of keeping the diff relatively small, I haven't converted everything over yet. More can be converted over as time permits. * Make zero-sized mprotects a no-op, add tests
1 parent 642ee73 commit 74a703b

File tree

12 files changed

+611
-154
lines changed

12 files changed

+611
-154
lines changed

crates/wasmtime/src/runtime/vm.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,8 @@ pub use send_sync_unsafe_cell::SendSyncUnsafeCell;
8585
mod module_id;
8686
pub use module_id::CompiledModuleId;
8787

88+
#[cfg(feature = "signals-based-traps")]
89+
mod byte_count;
8890
#[cfg(feature = "signals-based-traps")]
8991
mod cow;
9092
#[cfg(not(feature = "signals-based-traps"))]
@@ -94,6 +96,7 @@ mod mmap;
9496

9597
cfg_if::cfg_if! {
9698
if #[cfg(feature = "signals-based-traps")] {
99+
pub use crate::runtime::vm::byte_count::*;
97100
pub use crate::runtime::vm::mmap::Mmap;
98101
pub use self::cow::{MemoryImage, MemoryImageSlot, ModuleMemoryImages};
99102
} else {
@@ -365,6 +368,8 @@ pub fn host_page_size() -> usize {
365368
}
366369

367370
/// Is `bytes` a multiple of the host page size?
371+
///
372+
/// (Deprecated: consider switching to `HostAlignedByteCount`.)
368373
#[cfg(feature = "signals-based-traps")]
369374
pub fn usize_is_multiple_of_host_page_size(bytes: usize) -> bool {
370375
bytes % host_page_size() == 0
@@ -373,6 +378,8 @@ pub fn usize_is_multiple_of_host_page_size(bytes: usize) -> bool {
373378
/// Round the given byte size up to a multiple of the host OS page size.
374379
///
375380
/// Returns an error if rounding up overflows.
381+
///
382+
/// (Deprecated: consider switching to `HostAlignedByteCount`.)
376383
#[cfg(feature = "signals-based-traps")]
377384
pub fn round_u64_up_to_host_pages(bytes: u64) -> Result<u64> {
378385
let page_size = u64::try_from(crate::runtime::vm::host_page_size()).err2anyhow()?;
@@ -386,6 +393,8 @@ pub fn round_u64_up_to_host_pages(bytes: u64) -> Result<u64> {
386393
}
387394

388395
/// Same as `round_u64_up_to_host_pages` but for `usize`s.
396+
///
397+
/// (Deprecated: consider switching to `HostAlignedByteCount`.)
389398
#[cfg(feature = "signals-based-traps")]
390399
pub fn round_usize_up_to_host_pages(bytes: usize) -> Result<usize> {
391400
let bytes = u64::try_from(bytes).err2anyhow()?;
Lines changed: 303 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,303 @@
1+
use core::fmt;
2+
3+
use super::host_page_size;
4+
5+
/// A number of bytes that's guaranteed to be aligned to the host page size.
6+
///
7+
/// This is used to manage page-aligned memory allocations.
8+
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
9+
pub struct HostAlignedByteCount(
10+
// Invariant: this is always a multiple of the host page size.
11+
usize,
12+
);
13+
14+
impl HostAlignedByteCount {
15+
/// A zero byte count.
16+
pub const ZERO: Self = Self(0);
17+
18+
/// Creates a new `HostAlignedByteCount` from an aligned byte count.
19+
///
20+
/// Returns an error if `bytes` is not page-aligned.
21+
pub fn new(bytes: usize) -> Result<Self, ByteCountNotAligned> {
22+
let host_page_size = host_page_size();
23+
if bytes % host_page_size == 0 {
24+
Ok(Self(bytes))
25+
} else {
26+
Err(ByteCountNotAligned(bytes))
27+
}
28+
}
29+
30+
/// Creates a new `HostAlignedByteCount` from an aligned byte count without
31+
/// checking validity.
32+
///
33+
/// ## Safety
34+
///
35+
/// The caller must ensure that `bytes` is page-aligned.
36+
pub unsafe fn new_unchecked(bytes: usize) -> Self {
37+
debug_assert!(
38+
bytes % host_page_size() == 0,
39+
"byte count {bytes} is not page-aligned (page size = {})",
40+
host_page_size(),
41+
);
42+
Self(bytes)
43+
}
44+
45+
/// Creates a new `HostAlignedByteCount`, rounding up to the nearest page.
46+
///
47+
/// Returns an error if `bytes + page_size - 1` overflows.
48+
pub fn new_rounded_up(bytes: usize) -> Result<Self, ByteCountOutOfBounds> {
49+
let page_size = host_page_size();
50+
debug_assert!(page_size.is_power_of_two());
51+
match bytes.checked_add(page_size - 1) {
52+
Some(v) => Ok(Self(v & !(page_size - 1))),
53+
None => Err(ByteCountOutOfBounds(ByteCountOutOfBoundsKind::RoundUp)),
54+
}
55+
}
56+
57+
/// Creates a new `HostAlignedByteCount` from a `u64`, rounding up to the nearest page.
58+
///
59+
/// Returns an error if the `u64` overflows `usize`, or if `bytes +
60+
/// page_size - 1` overflows.
61+
pub fn new_rounded_up_u64(bytes: u64) -> Result<Self, ByteCountOutOfBounds> {
62+
let bytes = bytes
63+
.try_into()
64+
.map_err(|_| ByteCountOutOfBounds(ByteCountOutOfBoundsKind::ConvertU64))?;
65+
Self::new_rounded_up(bytes)
66+
}
67+
68+
/// Returns the host page size.
69+
pub fn host_page_size() -> HostAlignedByteCount {
70+
// The host page size is always a multiple of itself.
71+
HostAlignedByteCount(host_page_size())
72+
}
73+
74+
/// Returns true if the page count is zero.
75+
#[inline]
76+
pub fn is_zero(self) -> bool {
77+
self == Self::ZERO
78+
}
79+
80+
/// Returns the number of bytes as a `usize`.
81+
#[inline]
82+
pub fn byte_count(self) -> usize {
83+
self.0
84+
}
85+
86+
/// Add two aligned byte counts together.
87+
///
88+
/// Returns an error if the result overflows.
89+
pub fn checked_add(self, bytes: HostAlignedByteCount) -> Result<Self, ByteCountOutOfBounds> {
90+
// aligned + aligned = aligned
91+
self.0
92+
.checked_add(bytes.0)
93+
.map(Self)
94+
.ok_or(ByteCountOutOfBounds(ByteCountOutOfBoundsKind::Add))
95+
}
96+
97+
/// Compute `self - bytes`.
98+
///
99+
/// Returns an error if the result underflows.
100+
pub fn checked_sub(self, bytes: HostAlignedByteCount) -> Result<Self, ByteCountOutOfBounds> {
101+
// aligned - aligned = aligned
102+
self.0
103+
.checked_sub(bytes.0)
104+
.map(Self)
105+
.ok_or_else(|| ByteCountOutOfBounds(ByteCountOutOfBoundsKind::Sub))
106+
}
107+
108+
/// Multiply an aligned byte count by a scalar value.
109+
///
110+
/// Returns an error if the result overflows.
111+
pub fn checked_mul(self, scalar: usize) -> Result<Self, ByteCountOutOfBounds> {
112+
// aligned * scalar = aligned
113+
self.0
114+
.checked_mul(scalar)
115+
.map(Self)
116+
.ok_or_else(|| ByteCountOutOfBounds(ByteCountOutOfBoundsKind::Mul))
117+
}
118+
119+
/// Unchecked multiplication by a scalar value.
120+
///
121+
/// ## Safety
122+
///
123+
/// The result must not overflow.
124+
#[inline]
125+
pub unsafe fn unchecked_mul(self, n: usize) -> Self {
126+
Self(self.0 * n)
127+
}
128+
}
129+
130+
impl PartialEq<usize> for HostAlignedByteCount {
131+
#[inline]
132+
fn eq(&self, other: &usize) -> bool {
133+
self.0 == *other
134+
}
135+
}
136+
137+
impl PartialEq<HostAlignedByteCount> for usize {
138+
#[inline]
139+
fn eq(&self, other: &HostAlignedByteCount) -> bool {
140+
*self == other.0
141+
}
142+
}
143+
144+
struct LowerHexDisplay<T>(T);
145+
146+
impl<T: fmt::LowerHex> fmt::Display for LowerHexDisplay<T> {
147+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
148+
// Use the LowerHex impl as the Display impl, ensuring that there's
149+
// always a 0x in the beginning (i.e. that the alternate formatter is
150+
// used.)
151+
if f.alternate() {
152+
fmt::LowerHex::fmt(&self.0, f)
153+
} else {
154+
// Unfortunately, fill and alignment aren't respected this way, but
155+
// it's quite hard to construct a new formatter with mostly the same
156+
// options but the alternate flag set.
157+
// https://github.com/rust-lang/rust/pull/118159 would make this
158+
// easier.
159+
write!(f, "{:#x}", self.0)
160+
}
161+
}
162+
}
163+
164+
impl fmt::Display for HostAlignedByteCount {
165+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
166+
// Use the LowerHex impl as the Display impl, ensuring that there's
167+
// always a 0x in the beginning (i.e. that the alternate formatter is
168+
// used.)
169+
fmt::Display::fmt(&LowerHexDisplay(self.0), f)
170+
}
171+
}
172+
173+
impl fmt::LowerHex for HostAlignedByteCount {
174+
#[inline]
175+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
176+
fmt::LowerHex::fmt(&self.0, f)
177+
}
178+
}
179+
180+
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
181+
pub struct ByteCountNotAligned(usize);
182+
183+
impl fmt::Display for ByteCountNotAligned {
184+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
185+
write!(
186+
f,
187+
"byte count not page-aligned: {}",
188+
LowerHexDisplay(self.0)
189+
)
190+
}
191+
}
192+
193+
#[cfg(feature = "std")]
194+
impl std::error::Error for ByteCountNotAligned {}
195+
196+
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
197+
pub struct ByteCountOutOfBounds(ByteCountOutOfBoundsKind);
198+
199+
impl fmt::Display for ByteCountOutOfBounds {
200+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
201+
write!(f, "{}", self.0)
202+
}
203+
}
204+
205+
#[cfg(feature = "std")]
206+
impl std::error::Error for ByteCountOutOfBounds {}
207+
208+
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
209+
enum ByteCountOutOfBoundsKind {
210+
// We don't carry the arguments that errored out to avoid the error type
211+
// becoming too big.
212+
RoundUp,
213+
ConvertU64,
214+
Add,
215+
Sub,
216+
Mul,
217+
}
218+
219+
impl fmt::Display for ByteCountOutOfBoundsKind {
220+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
221+
match self {
222+
ByteCountOutOfBoundsKind::RoundUp => f.write_str("byte count overflow rounding up"),
223+
ByteCountOutOfBoundsKind::ConvertU64 => {
224+
f.write_str("byte count overflow converting u64")
225+
}
226+
ByteCountOutOfBoundsKind::Add => f.write_str("byte count overflow during addition"),
227+
ByteCountOutOfBoundsKind::Sub => f.write_str("byte count underflow during subtraction"),
228+
ByteCountOutOfBoundsKind::Mul => {
229+
f.write_str("byte count overflow during multiplication")
230+
}
231+
}
232+
}
233+
}
234+
235+
#[cfg(test)]
236+
mod tests {
237+
use super::*;
238+
239+
#[test]
240+
fn byte_count_display() {
241+
// Pages should hopefully be 64k or smaller.
242+
let byte_count = HostAlignedByteCount::new(65536).unwrap();
243+
244+
assert_eq!(format!("{byte_count}"), "0x10000");
245+
assert_eq!(format!("{byte_count:x}"), "10000");
246+
assert_eq!(format!("{byte_count:#x}"), "0x10000");
247+
}
248+
249+
#[test]
250+
fn byte_count_ops() {
251+
let host_page_size = host_page_size();
252+
HostAlignedByteCount::new(0).expect("0 is aligned");
253+
HostAlignedByteCount::new(host_page_size).expect("host_page_size is aligned");
254+
HostAlignedByteCount::new(host_page_size * 2).expect("host_page_size * 2 is aligned");
255+
HostAlignedByteCount::new(host_page_size + 1)
256+
.expect_err("host_page_size + 1 is not aligned");
257+
HostAlignedByteCount::new(host_page_size / 2)
258+
.expect_err("host_page_size / 2 is not aligned");
259+
260+
// Rounding up.
261+
HostAlignedByteCount::new_rounded_up(usize::MAX).expect_err("usize::MAX overflows");
262+
assert_eq!(
263+
HostAlignedByteCount::new_rounded_up(usize::MAX - host_page_size)
264+
.expect("(usize::MAX - 1 page) is in bounds"),
265+
HostAlignedByteCount::new((usize::MAX - host_page_size) + 1)
266+
.expect("usize::MAX is 2**N - 1"),
267+
);
268+
269+
// Addition.
270+
let half_max = HostAlignedByteCount::new((usize::MAX >> 1) + 1)
271+
.expect("(usize::MAX >> 1) + 1 is aligned");
272+
half_max
273+
.checked_add(HostAlignedByteCount::host_page_size())
274+
.expect("half max + page size is in bounds");
275+
half_max
276+
.checked_add(half_max)
277+
.expect_err("half max + half max is out of bounds");
278+
279+
// Subtraction.
280+
let half_max_minus_one = half_max
281+
.checked_sub(HostAlignedByteCount::host_page_size())
282+
.expect("(half_max - 1 page) is in bounds");
283+
assert_eq!(
284+
half_max.checked_sub(half_max),
285+
Ok(HostAlignedByteCount::ZERO)
286+
);
287+
assert_eq!(
288+
half_max.checked_sub(half_max_minus_one),
289+
Ok(HostAlignedByteCount::host_page_size())
290+
);
291+
half_max_minus_one
292+
.checked_sub(half_max)
293+
.expect_err("(half_max - 1 page) - half_max is out of bounds");
294+
295+
// Multiplication.
296+
half_max
297+
.checked_mul(2)
298+
.expect_err("half max * 2 is out of bounds");
299+
half_max_minus_one
300+
.checked_mul(2)
301+
.expect("(half max - 1 page) * 2 is in bounds");
302+
}
303+
}

0 commit comments

Comments
 (0)