Skip to content

Commit ff1426a

Browse files
refactor: remove UserPtr and UserConstPtr usages
1 parent e556238 commit ff1426a

File tree

16 files changed

+328
-507
lines changed

16 files changed

+328
-507
lines changed

api/src/mm.rs

Lines changed: 2 additions & 220 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
use alloc::string::String;
22
use core::{
3-
alloc::Layout,
43
ffi::c_char,
54
hint::unlikely,
65
mem::{MaybeUninit, transmute},
7-
ptr, slice, str,
86
};
97

108
use axerrno::{AxError, AxResult};
@@ -14,226 +12,10 @@ use axhal::{
1412
};
1513
use axio::prelude::*;
1614
use axtask::current;
17-
use memory_addr::{MemoryAddr, PAGE_SIZE_4K, VirtAddr};
18-
use starry_core::{mm::access_user_memory, task::AsThread};
15+
use memory_addr::VirtAddr;
16+
use starry_core::task::AsThread;
1917
use starry_vm::{vm_load_until_nul, vm_read_slice, vm_write_slice};
2018

21-
fn check_region(start: VirtAddr, layout: Layout, access_flags: MappingFlags) -> AxResult<()> {
22-
let align = layout.align();
23-
if start.as_usize() & (align - 1) != 0 {
24-
return Err(AxError::BadAddress);
25-
}
26-
27-
let curr = current();
28-
let mut aspace = curr.as_thread().proc_data.aspace.lock();
29-
30-
if !aspace.can_access_range(start, layout.size(), access_flags) {
31-
return Err(AxError::BadAddress);
32-
}
33-
34-
let page_start = start.align_down_4k();
35-
let page_end = (start + layout.size()).align_up_4k();
36-
aspace.populate_area(page_start, page_end - page_start, access_flags)?;
37-
38-
Ok(())
39-
}
40-
41-
fn check_null_terminated<T: PartialEq + Default>(
42-
start: VirtAddr,
43-
access_flags: MappingFlags,
44-
) -> AxResult<usize> {
45-
let align = Layout::new::<T>().align();
46-
if start.as_usize() & (align - 1) != 0 {
47-
return Err(AxError::BadAddress);
48-
}
49-
50-
let zero = T::default();
51-
52-
let mut page = start.align_down_4k();
53-
54-
let start = start.as_ptr_of::<T>();
55-
let mut len = 0;
56-
57-
access_user_memory(|| {
58-
loop {
59-
// SAFETY: This won't overflow the address space since we'll check
60-
// it below.
61-
let ptr = unsafe { start.add(len) };
62-
while ptr as usize >= page.as_ptr() as usize {
63-
// We cannot prepare `aspace` outside of the loop, since holding
64-
// aspace requires a mutex which would be required on page
65-
// fault, and page faults can trigger inside the loop.
66-
67-
// TODO: this is inefficient, but we have to do this instead of
68-
// querying the page table since the page might has not been
69-
// allocated yet.
70-
let curr = current();
71-
let aspace = curr.as_thread().proc_data.aspace.lock();
72-
if !aspace.can_access_range(page, PAGE_SIZE_4K, access_flags) {
73-
return Err(AxError::BadAddress);
74-
}
75-
76-
page += PAGE_SIZE_4K;
77-
}
78-
79-
// This might trigger a page fault
80-
// SAFETY: The pointer is valid and points to a valid memory region.
81-
if unsafe { ptr.read_volatile() } == zero {
82-
break;
83-
}
84-
len += 1;
85-
}
86-
Ok(())
87-
})?;
88-
89-
Ok(len)
90-
}
91-
92-
/// A pointer to user space memory.
93-
#[repr(transparent)]
94-
#[derive(PartialEq, Clone, Copy)]
95-
pub struct UserPtr<T>(*mut T);
96-
97-
impl<T> From<usize> for UserPtr<T> {
98-
fn from(value: usize) -> Self {
99-
UserPtr(value as *mut _)
100-
}
101-
}
102-
103-
impl<T> From<*mut T> for UserPtr<T> {
104-
fn from(value: *mut T) -> Self {
105-
UserPtr(value)
106-
}
107-
}
108-
109-
impl<T> Default for UserPtr<T> {
110-
fn default() -> Self {
111-
Self(ptr::null_mut())
112-
}
113-
}
114-
115-
impl<T> UserPtr<T> {
116-
const ACCESS_FLAGS: MappingFlags = MappingFlags::READ.union(MappingFlags::WRITE);
117-
118-
pub fn address(&self) -> VirtAddr {
119-
VirtAddr::from_ptr_of(self.0)
120-
}
121-
122-
pub fn cast<U>(self) -> UserPtr<U> {
123-
UserPtr(self.0 as *mut U)
124-
}
125-
126-
pub fn is_null(&self) -> bool {
127-
self.0.is_null()
128-
}
129-
130-
pub fn get_as_mut(self) -> AxResult<&'static mut T> {
131-
check_region(self.address(), Layout::new::<T>(), Self::ACCESS_FLAGS)?;
132-
Ok(unsafe { &mut *self.0 })
133-
}
134-
135-
pub fn get_as_mut_slice(self, len: usize) -> AxResult<&'static mut [T]> {
136-
check_region(
137-
self.address(),
138-
Layout::array::<T>(len).unwrap(),
139-
Self::ACCESS_FLAGS,
140-
)?;
141-
Ok(unsafe { slice::from_raw_parts_mut(self.0, len) })
142-
}
143-
144-
pub fn get_as_mut_null_terminated(self) -> AxResult<&'static mut [T]>
145-
where
146-
T: PartialEq + Default,
147-
{
148-
let len = check_null_terminated::<T>(self.address(), Self::ACCESS_FLAGS)?;
149-
Ok(unsafe { slice::from_raw_parts_mut(self.0, len) })
150-
}
151-
}
152-
153-
/// An immutable pointer to user space memory.
154-
#[repr(transparent)]
155-
#[derive(PartialEq, Clone, Copy)]
156-
pub struct UserConstPtr<T>(*const T);
157-
158-
impl<T> From<usize> for UserConstPtr<T> {
159-
fn from(value: usize) -> Self {
160-
UserConstPtr(value as *const _)
161-
}
162-
}
163-
164-
impl<T> From<*const T> for UserConstPtr<T> {
165-
fn from(value: *const T) -> Self {
166-
UserConstPtr(value)
167-
}
168-
}
169-
170-
impl<T> Default for UserConstPtr<T> {
171-
fn default() -> Self {
172-
Self(ptr::null())
173-
}
174-
}
175-
176-
impl<T> UserConstPtr<T> {
177-
const ACCESS_FLAGS: MappingFlags = MappingFlags::READ;
178-
179-
pub fn address(&self) -> VirtAddr {
180-
VirtAddr::from_ptr_of(self.0)
181-
}
182-
183-
pub fn cast<U>(self) -> UserConstPtr<U> {
184-
UserConstPtr(self.0 as *const U)
185-
}
186-
187-
pub fn is_null(&self) -> bool {
188-
self.0.is_null()
189-
}
190-
191-
pub fn get_as_ref(self) -> AxResult<&'static T> {
192-
check_region(self.address(), Layout::new::<T>(), Self::ACCESS_FLAGS)?;
193-
Ok(unsafe { &*self.0 })
194-
}
195-
196-
pub fn get_as_slice(self, len: usize) -> AxResult<&'static [T]> {
197-
check_region(
198-
self.address(),
199-
Layout::array::<T>(len).unwrap(),
200-
Self::ACCESS_FLAGS,
201-
)?;
202-
Ok(unsafe { slice::from_raw_parts(self.0, len) })
203-
}
204-
205-
pub fn get_as_null_terminated(self) -> AxResult<&'static [T]>
206-
where
207-
T: PartialEq + Default,
208-
{
209-
let len = check_null_terminated::<T>(self.address(), Self::ACCESS_FLAGS)?;
210-
Ok(unsafe { slice::from_raw_parts(self.0, len) })
211-
}
212-
}
213-
214-
impl UserConstPtr<c_char> {
215-
/// Get the pointer as `&str`, validating the memory region.
216-
pub fn get_as_str(self) -> AxResult<&'static str> {
217-
let slice = self.get_as_null_terminated()?;
218-
// SAFETY: c_char is u8
219-
let slice = unsafe { transmute::<&[c_char], &[u8]>(slice) };
220-
221-
str::from_utf8(slice).map_err(|_| AxError::IllegalBytes)
222-
}
223-
}
224-
225-
macro_rules! nullable {
226-
($ptr:ident.$func:ident($($arg:expr),*)) => {
227-
if $ptr.is_null() {
228-
Ok(None)
229-
} else {
230-
Some($ptr.$func($($arg),*)).transpose()
231-
}
232-
};
233-
}
234-
235-
pub(crate) use nullable;
236-
23719
#[register_trap_handler(PAGE_FAULT)]
23820
fn handle_page_fault(vaddr: VirtAddr, access_flags: MappingFlags) -> bool {
23921
debug!("Page fault at {vaddr:#x}, access_flags: {access_flags:#x?}");

0 commit comments

Comments
 (0)