Skip to content

Commit e2a9857

Browse files
committed
patina_internal_collections: Use MaybeUninit to avoid Default requirement
Replace Default trait requirement with MaybeUninit wrapper for Node data field. Only initialize data when nodes move from available to in-use list. Fixes UB where Cell::set() was reading uninitialized memory.
1 parent f921e1e commit e2a9857

File tree

3 files changed

+158
-81
lines changed

3 files changed

+158
-81
lines changed

core/patina_internal_collections/src/bst.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ where
159159
///
160160
pub fn get(&self, key: &D::Key) -> Option<&D> {
161161
match self.get_node(key) {
162-
Some(node) => Some(&node.data),
162+
Some(node) => Some(unsafe { node.data() }),
163163
None => None,
164164
}
165165
}
@@ -186,7 +186,7 @@ where
186186
// SAFETY: The pointer comes from as_mut_ptr() on a valid node reference obtained from get_node().
187187
// The caller is responsible for ensuring that the mutable reference doesn't modify key-affecting
188188
// values.
189-
Some(unsafe { &mut (*ptr).data })
189+
Some(unsafe { (*ptr).data_mut() })
190190
}
191191
None => None,
192192
}
@@ -209,7 +209,7 @@ where
209209
///
210210
pub fn get_with_idx(&self, idx: usize) -> Option<&D> {
211211
match self.storage.get(idx) {
212-
Some(node) => Some(&node.data),
212+
Some(node) => Some(unsafe { node.data() }),
213213
None => None,
214214
}
215215
}
@@ -236,7 +236,7 @@ where
236236
///
237237
pub unsafe fn get_with_idx_mut(&mut self, idx: usize) -> Option<&mut D> {
238238
match self.storage.get_mut(idx) {
239-
Some(node) => Some(&mut node.data),
239+
Some(node) => Some(unsafe { node.data_mut() }),
240240
None => None,
241241
}
242242
}
@@ -281,7 +281,7 @@ where
281281
let mut current = self.root();
282282
let mut closest = None;
283283
while let Some(node) = current {
284-
match key.cmp(node.data.key()) {
284+
match key.cmp(unsafe { node.data() }.key()) {
285285
Ordering::Equal => return Some(self.storage.idx(node.as_mut_ptr())),
286286
Ordering::Less => current = node.left(),
287287
Ordering::Greater => {
@@ -494,7 +494,7 @@ where
494494
fn get_node(&self, key: &D::Key) -> Option<&Node<D>> {
495495
let mut current_idx = self.root();
496496
while let Some(node) = current_idx {
497-
match key.cmp(node.data.key()) {
497+
match key.cmp(unsafe { node.data() }.key()) {
498498
Ordering::Equal => return Some(node),
499499
Ordering::Less => current_idx = node.left(),
500500
Ordering::Greater => current_idx = node.right(),
@@ -639,7 +639,7 @@ where
639639
fn _dfs(node: Option<&Node<D>>, values: &mut alloc::vec::Vec<D>) {
640640
if let Some(node) = node {
641641
Self::_dfs(node.left(), values);
642-
values.push(node.data);
642+
values.push(unsafe { *node.data() });
643643
Self::_dfs(node.right(), values);
644644
}
645645
}

core/patina_internal_collections/src/node.rs

Lines changed: 114 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
//!
77
//! SPDX-License-Identifier: Apache-2.0
88
//!
9-
use core::{cell::Cell, mem, ptr::NonNull, slice};
9+
use core::{cell::Cell, mem, mem::MaybeUninit, ptr::NonNull, slice};
1010

1111
use crate::{Error, Result, SliceKey};
1212

@@ -51,23 +51,32 @@ where
5151

5252
/// Create a new storage container with a slice of memory.
5353
pub fn with_capacity(slice: &'a mut [u8]) -> Storage<'a, D> {
54-
let storage = Storage {
55-
// SAFETY: This is reinterpreting a byte slice as a Node<D> slice.
56-
// 1. The alignment is checked implicitly by the slice bounds.
57-
// 2. The correct number of Node<D> elements that fit in the byte slice is calculated.
58-
// 3. The lifetime ensures the byte slice remains valid for the storage's lifetime
59-
data: unsafe {
60-
slice::from_raw_parts_mut::<'a, Node<D>>(
61-
slice as *mut [u8] as *mut Node<D>,
62-
slice.len() / mem::size_of::<Node<D>>(),
63-
)
64-
},
65-
length: 0,
66-
available: Cell::default(),
54+
// SAFETY: This is reinterpreting a byte slice as a MaybeUninit<Node<D>> slice.
55+
// Using MaybeUninit explicitly represents uninitialized memory.
56+
let uninit_buffer = unsafe {
57+
slice::from_raw_parts_mut::<'a, MaybeUninit<Node<D>>>(
58+
slice as *mut [u8] as *mut MaybeUninit<Node<D>>,
59+
slice.len() / mem::size_of::<Node<D>>(),
60+
)
6761
};
6862

69-
Self::build_linked_list(storage.data);
70-
storage.available.set(storage.data[0].as_mut_ptr());
63+
// Initialize nodes with uninitialized data fields
64+
for elem in uninit_buffer.iter_mut() {
65+
elem.write(Node::new_uninit());
66+
}
67+
68+
// SAFETY: All nodes have been initialized (though their data fields are uninitialized).
69+
// We can now safely convert from MaybeUninit<Node<D>> to Node<D>.
70+
let buffer =
71+
unsafe { slice::from_raw_parts_mut(uninit_buffer.as_mut_ptr() as *mut Node<D>, uninit_buffer.len()) };
72+
73+
let storage = Storage { data: buffer, length: 0, available: Cell::default() };
74+
75+
if !storage.data.is_empty() {
76+
Self::build_linked_list(storage.data);
77+
storage.available.set(storage.data[0].as_mut_ptr());
78+
}
79+
7180
storage
7281
}
7382

@@ -105,7 +114,11 @@ where
105114
node.set_left(None);
106115
node.set_right(None);
107116
node.set_parent(None);
108-
node.data = data;
117+
// SAFETY: The node is from the available list, so its data field is uninitialized.
118+
// We initialize it here when moving the node to the "in use" state.
119+
unsafe {
120+
node.init_data(data);
121+
}
109122
self.length += 1;
110123
Ok((self.idx(node.as_mut_ptr()), node))
111124
} else {
@@ -188,17 +201,26 @@ where
188201
///
189202
/// O(n)
190203
pub fn resize(&mut self, slice: &'a mut [u8]) {
191-
// SAFETY: This is reinterpreting a byte slice as a Node<D> slice.
192-
// 1. The alignment is handled by slice casting rules
193-
// 2. The correct number of Node<D> elements that fit in the byte slice is calculated
194-
// 3. The lifetime 'a ensures the byte slice remains valid for the storage's lifetime
195-
let buffer = unsafe {
196-
slice::from_raw_parts_mut::<'a, Node<D>>(
197-
slice as *mut [u8] as *mut Node<D>,
204+
// SAFETY: This is reinterpreting a byte slice as a MaybeUninit<Node<D>> slice.
205+
// Using MaybeUninit explicitly represents uninitialized memory.
206+
let uninit_buffer = unsafe {
207+
slice::from_raw_parts_mut::<'a, MaybeUninit<Node<D>>>(
208+
slice as *mut [u8] as *mut MaybeUninit<Node<D>>,
198209
slice.len() / mem::size_of::<Node<D>>(),
199210
)
200211
};
201212

213+
assert!(uninit_buffer.len() >= self.len());
214+
215+
// Initialize all nodes with uninitialized data fields
216+
for elem in uninit_buffer.iter_mut() {
217+
elem.write(Node::new_uninit());
218+
}
219+
220+
// SAFETY: All nodes have been initialized (though their data fields may be uninitialized).
221+
let buffer =
222+
unsafe { slice::from_raw_parts_mut(uninit_buffer.as_mut_ptr() as *mut Node<D>, uninit_buffer.len()) };
223+
202224
assert!(buffer.len() >= self.len());
203225

204226
// When current capacity is 0, we just need to copy the data and build the available list
@@ -213,7 +235,13 @@ where
213235
for i in 0..self.len() {
214236
let old = &self.data[i];
215237

216-
buffer[i].data = old.data;
238+
// SAFETY: Nodes at indices 0..self.len() are "in use" and have initialized data.
239+
// We copy the initialized data from old to new.
240+
unsafe {
241+
let old_data = old.data();
242+
// Use ptr::copy to copy the data from old to new
243+
buffer[i].data = MaybeUninit::new(*old_data);
244+
}
217245
buffer[i].set_color(old.color());
218246

219247
if let Some(left) = old.left() {
@@ -464,7 +492,7 @@ pub struct Node<D>
464492
where
465493
D: SliceKey,
466494
{
467-
pub data: D,
495+
pub data: MaybeUninit<D>,
468496
color: Cell<bool>,
469497
parent: Cell<*mut Node<D>>,
470498
left: Cell<*mut Node<D>>,
@@ -475,8 +503,48 @@ impl<D> Node<D>
475503
where
476504
D: SliceKey,
477505
{
506+
/// Create a new node with uninitialized data.
507+
/// The data field must be initialized separately using `init_data()`.
508+
pub fn new_uninit() -> Self {
509+
Node {
510+
data: MaybeUninit::uninit(),
511+
color: Cell::new(RED),
512+
parent: Cell::default(),
513+
left: Cell::default(),
514+
right: Cell::default(),
515+
}
516+
}
517+
518+
/// Initialize the data field of an uninitialized node.
519+
/// # Safety
520+
/// The caller must ensure the data field has not been previously initialized.
521+
pub unsafe fn init_data(&mut self, data: D) {
522+
self.data.write(data);
523+
}
524+
525+
/// Creates a new Node with initialized data.
526+
/// Used for testing purposes.
527+
#[cfg(test)]
478528
pub fn new(data: D) -> Self {
479-
Node { data, color: Cell::new(RED), parent: Cell::default(), left: Cell::default(), right: Cell::default() }
529+
let mut node = Self::new_uninit();
530+
node.data.write(data);
531+
node
532+
}
533+
534+
/// Get a reference to the data, assuming it is initialized.
535+
/// # Safety
536+
/// The caller must ensure the data field has been initialized.
537+
pub unsafe fn data(&self) -> &D {
538+
// SAFETY: Caller guarantees data is initialized
539+
unsafe { self.data.assume_init_ref() }
540+
}
541+
542+
/// Get a mutable reference to the data, assuming it is initialized.
543+
/// # Safety
544+
/// The caller must ensure the data field has been initialized.
545+
pub unsafe fn data_mut(&mut self) -> &mut D {
546+
// SAFETY: Caller guarantees data is initialized
547+
unsafe { self.data.assume_init_mut() }
480548
}
481549

482550
pub fn height_and_balance(node: Option<&Node<D>>) -> (i32, bool) {
@@ -584,7 +652,9 @@ where
584652
impl<D: SliceKey> SliceKey for Node<D> {
585653
type Key = D::Key;
586654
fn key(&self) -> &Self::Key {
587-
self.data.key()
655+
// SAFETY: This method is only called on nodes that are in use (initialized).
656+
// Nodes in the available list are never accessed for their key.
657+
unsafe { self.data().key() }
588658
}
589659
}
590660

@@ -602,7 +672,8 @@ mod tests {
602672
for i in 0..10 {
603673
let (index, node) = storage.add(i).unwrap();
604674
assert_eq!(index, i);
605-
assert_eq!(node.data, i);
675+
// SAFETY: Node was just added with data, so it's initialized
676+
assert_eq!(unsafe { *node.data() }, i);
606677
assert_eq!(storage.len(), i + 1);
607678
}
608679

@@ -613,16 +684,22 @@ mod tests {
613684
storage.delete(storage.get(5).unwrap().as_mut_ptr());
614685
let (index, node) = storage.add(11).unwrap();
615686
assert_eq!(index, 5);
616-
assert_eq!(node.data, 11);
687+
// SAFETY: Node was just added with data, so it's initialized
688+
assert_eq!(unsafe { *node.data() }, 11);
617689

618690
// Try and get a mutable reference to a node
619691
{
620692
let node = storage.get_mut(5).unwrap();
621-
assert_eq!(node.data, 11);
622-
node.data = 12;
693+
// SAFETY: Node is in use, so data is initialized
694+
assert_eq!(unsafe { *node.data() }, 11);
695+
// SAFETY: Node is in use, we can modify the initialized data
696+
unsafe {
697+
*node.data_mut() = 12;
698+
}
623699
}
624700
let node = storage.get(5).unwrap();
625-
assert_eq!(node.data, 12);
701+
// SAFETY: Node is in use, so data is initialized
702+
assert_eq!(unsafe { *node.data() }, 12);
626703
}
627704

628705
#[test]
@@ -640,8 +717,8 @@ mod tests {
640717

641718
p4.set_parent(Some(p1));
642719

643-
assert_eq!(Node::sibling(p2).unwrap().data, 3);
644-
assert_eq!(Node::sibling(p3).unwrap().data, 2);
720+
assert_eq!(unsafe { *Node::sibling(p2).unwrap().data() }, 3);
721+
assert_eq!(unsafe { *Node::sibling(p3).unwrap().data() }, 2);
645722
assert!(Node::sibling(p1).is_none());
646723
}
647724

@@ -680,7 +757,7 @@ mod tests {
680757
p2.set_right(Some(p4));
681758
p4.set_parent(Some(p2));
682759

683-
assert_eq!(Node::predecessor(p1).unwrap().data, 4);
760+
assert_eq!(unsafe { *Node::predecessor(p1).unwrap().data() }, 4);
684761
assert!(Node::predecessor(p4).is_none());
685762
}
686763

@@ -700,7 +777,7 @@ mod tests {
700777
p2.set_right(Some(p4));
701778
p4.set_parent(Some(p2));
702779

703-
assert_eq!(Node::successor(p1).unwrap().data, 3);
780+
assert_eq!(unsafe { *Node::successor(p1).unwrap().data() }, 3);
704781
assert!(Node::successor(p4).is_none());
705782
}
706783

0 commit comments

Comments
 (0)