diff --git a/Cargo.lock b/Cargo.lock index cfa919f8..f53e8160 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,7 +17,7 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b62040eef329ede76c8b2af149362b0fd42ec9d8c436d9eb9d47db35b71cb300" dependencies = [ - "bitflags", + "bitflags 2.10.0", "safe-mmio", "thiserror", "zerocopy", @@ -40,12 +40,24 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" +[[package]] +name = "bitflags" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + [[package]] name = "bitflags" version = "2.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "812e12b5285cc515a9c72a5c1d3b6d46a19dac5acfef5265968c166106e31dd3" +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + [[package]] name = "bytes" version = "1.11.0" @@ -79,6 +91,47 @@ version = "0.8.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" +[[package]] +name = "defmt" +version = "0.3.100" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0963443817029b2024136fc4dd07a5107eb8f977eaf18fcd1fdeb11306b64ad" +dependencies = [ + "defmt 1.0.1", +] + +[[package]] +name = "defmt" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "548d977b6da32fa1d1fda2876453da1e7df63ad0304c8b3dae4dbe7b96f39b78" +dependencies = [ + "bitflags 1.3.2", + "defmt-macros", +] + +[[package]] +name = "defmt-macros" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d4fc12a85bcf441cfe44344c4b72d58493178ce635338a3f3b78943aceb258e" +dependencies = [ + "defmt-parser", + "proc-macro-error2", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "defmt-parser" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10d60334b3b2e7c9d91ef8150abfb6fa4c1c39ebbcf4a81c2e346aad939fee3e" +dependencies = [ + "thiserror", +] + [[package]] name = "errno" version = "0.3.14" @@ -95,7 +148,7 @@ version = "0.9.3" source = "git+https://github.com/arihant2math/ext4-view-rs.git?branch=main#adb7c88dd8acba198f8c701e664b2a03ca9df339" dependencies = [ "async-trait", - "bitflags", + "bitflags 2.10.0", "crc", ] @@ -196,6 +249,25 @@ dependencies = [ "wasip2", ] +[[package]] +name = "hash32" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47d60b12902ba28e2730cd37e95b8c9223af2808df9e902d4df49588d1470606" +dependencies = [ + "byteorder", +] + +[[package]] +name = "heapless" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bfb9eb618601c89945a70e254898da93b13be0388091d42117462b265bb3fad" +dependencies = [ + "hash32", + "stable_deref_trait", +] + [[package]] name = "intrusive-collections" version = "0.9.7" @@ -216,7 +288,7 @@ name = "libkernel" version = "0.0.0" dependencies = [ "async-trait", - "bitflags", + "bitflags 2.10.0", "ext4-view", "intrusive-collections", "log", @@ -253,6 +325,12 @@ version = "0.4.29" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" +[[package]] +name = "managed" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ca88d725a0a943b096803bd34e73a4437208b6077654cc4ecb2947a5f91618d" + [[package]] name = "memchr" version = "2.7.6" @@ -286,7 +364,7 @@ dependencies = [ "aarch64-cpu", "arm-pl011-uart", "async-trait", - "bitflags", + "bitflags 2.10.0", "fdt-parser", "futures", "getargs", @@ -297,6 +375,7 @@ dependencies = [ "paste", "rand", "ringbuf", + "smoltcp", "tock-registers", ] @@ -375,6 +454,28 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "proc-macro-error-attr2" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96de42df36bb9bba5542fe9f1a054b8cc87e172759a1868aa05c1f3acc89dfc5" +dependencies = [ + "proc-macro2", + "quote", +] + +[[package]] +name = "proc-macro-error2" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11ec05c52be0a07b08061f7dd003e7d7092e0472bc731b4af7bb1ef876109802" +dependencies = [ + "proc-macro-error-attr2", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "proc-macro2" version = "1.0.105" @@ -434,7 +535,7 @@ version = "0.5.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d" dependencies = [ - "bitflags", + "bitflags 2.10.0", ] [[package]] @@ -484,6 +585,20 @@ version = "1.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" +[[package]] +name = "smoltcp" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dad095989c1533c1c266d9b1e8d70a1329dd3723c3edac6d03bbd67e7bf6f4bb" +dependencies = [ + "bitflags 1.3.2", + "byteorder", + "cfg-if", + "defmt 0.3.100", + "heapless", + "managed", +] + [[package]] name = "socket2" version = "0.6.1" @@ -503,6 +618,12 @@ dependencies = [ "lock_api", ] +[[package]] +name = "stable_deref_trait" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596" + [[package]] name = "syn" version = "2.0.114" diff --git a/Cargo.toml b/Cargo.toml index f8feab70..02b50d61 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,7 @@ ringbuf = { version = "0.4.8", default-features = false, features = ["alloc"] } bitflags = "2.9.1" futures = { version = "0.3.31", default-features = false, features = ["alloc", "async-await"] } rand = { version = "0.9.2", default-features = false, features = ["small_rng"] } +smoltcp = { version = "0.12.0", default-features = false, features = ["alloc", "medium-ethernet", "medium-ip", "proto-ipv4", "proto-ipv6", "socket-tcp", "socket-udp"] } [features] default = ["smp"] diff --git a/libkernel/src/error.rs b/libkernel/src/error.rs index 2779f21b..3a47f704 100644 --- a/libkernel/src/error.rs +++ b/libkernel/src/error.rs @@ -141,6 +141,9 @@ pub enum KernelError { #[error("Operation not supported")] NotSupported, + #[error("Address family not supported")] + AddressFamilyNotSupported, + #[error("Device probe failed: {0}")] Probe(#[from] ProbeError), @@ -186,6 +189,9 @@ pub enum KernelError { #[error("Operation timed out")] TimedOut, + #[error("Not a socket")] + NotASocket, + #[error("{0}")] Other(&'static str), } diff --git a/src/arch/arm64/exceptions/syscall.rs b/src/arch/arm64/exceptions/syscall.rs index 33c2280e..9f7e92a9 100644 --- a/src/arch/arm64/exceptions/syscall.rs +++ b/src/arch/arm64/exceptions/syscall.rs @@ -74,6 +74,10 @@ use crate::{ threading::{futex::sys_futex, sys_set_robust_list, sys_set_tid_address}, }, sched::{current::current_task, sys_sched_yield}, + socket::syscalls::{ + accept::sys_accept, bind::sys_bind, connect::sys_connect, listen::sys_listen, + shutdown::sys_shutdown, socket::sys_socket, + }, }; use alloc::boxed::Box; use libkernel::{ @@ -398,7 +402,12 @@ pub async fn handle_syscall() { 0xb1 => sys_getegid().map_err(|e| match e {}), 0xb2 => sys_gettid().map_err(|e| match e {}), 0xb3 => sys_sysinfo(TUA::from_value(arg1 as _)).await, - 0xc6 => Err(KernelError::NotSupported), + 0xc6 => sys_socket(arg1 as _, arg2 as _, arg3 as _).await, + 0xc8 => sys_bind(arg1.into(), UA::from_value(arg2 as _), arg3 as _).await, + 0xc9 => sys_listen(arg1.into(), arg2 as _).await, + 0xca => sys_accept(arg1.into()).await, + 0xcb => sys_connect(arg1.into(), UA::from_value(arg2 as _), arg3 as _).await, + 0xd2 => sys_shutdown(arg1.into(), arg2 as _).await, 0xd6 => sys_brk(VA::from_value(arg1 as _)) .await .map_err(|e| match e {}), diff --git a/src/fs/fops.rs b/src/fs/fops.rs index 6388cd0b..41aff5b1 100644 --- a/src/fs/fops.rs +++ b/src/fs/fops.rs @@ -137,4 +137,8 @@ pub trait FileOps: Send + Sync { ) -> Result { Err(KernelError::InvalidValue) } + + fn as_socket(&mut self) -> Option<&mut dyn crate::socket::SocketOps> { + None + } } diff --git a/src/main.rs b/src/main.rs index 93333420..f32213ba 100644 --- a/src/main.rs +++ b/src/main.rs @@ -44,6 +44,7 @@ mod kernel; mod memory; mod process; mod sched; +mod socket; mod sync; #[panic_handler] diff --git a/src/socket/mod.rs b/src/socket/mod.rs new file mode 100644 index 00000000..46ebbb43 --- /dev/null +++ b/src/socket/mod.rs @@ -0,0 +1,149 @@ +mod sops; +pub mod syscalls; +mod tcp; +mod unix; + +use crate::drivers::timer::now; +use crate::memory::uaccess::copy_from_user; +use crate::sync::OnceLock; +use crate::sync::SpinLock; +use alloc::vec; +use core::net::Ipv4Addr; +use libkernel::error::KernelError; +use libkernel::memory::address::UA; +use libkernel::sync::waker_set::WakerSet; +use smoltcp::iface::SocketSet; +use smoltcp::wire::{IpAddress, IpEndpoint}; +pub use sops::SocketOps; + +static SOCKETS: OnceLock> = OnceLock::new(); + +fn sockets() -> &'static SpinLock> { + SOCKETS.get_or_init(|| SpinLock::new(SocketSet::new(vec![]))) +} + +// static INTERFACE: OnceLock>> = OnceLock::new(); + +static SOCKET_WAIT_QUEUE: OnceLock> = OnceLock::new(); + +fn socket_wait_queue() -> &'static SpinLock { + SOCKET_WAIT_QUEUE.get_or_init(|| SpinLock::new(WakerSet::new())) +} + +pub const AF_UNIX: i32 = 1; +pub const AF_INET: i32 = 2; +pub const SOCK_STREAM: i32 = 1; +pub const SOCK_DGRAM: i32 = 2; +pub const SOCK_SEQPACKET: i32 = 5; +pub const IPPROTO_TCP: i32 = 6; +#[expect(dead_code)] +pub const IPPROTO_UDP: i32 = 17; + +#[repr(i32)] +pub enum ShutdownHow { + Read = 0, + Write = 1, + ReadWrite = 2, +} + +impl TryFrom for ShutdownHow { + type Error = KernelError; + fn try_from(value: i32) -> Result { + match value { + 0 => Ok(ShutdownHow::Read), + 1 => Ok(ShutdownHow::Write), + 2 => Ok(ShutdownHow::ReadWrite), + _ => Err(KernelError::InvalidValue), + } + } +} + +#[non_exhaustive] +#[derive(Debug, Clone)] +pub enum SockAddr { + In(SockAddrIn), + Un(SockAddrUn), +} + +#[derive(Copy, Clone, Debug)] +#[repr(C, packed)] +pub struct SockAddrIn { + family: u16, + port: [u8; 2], + addr: [u8; 4], + zero: [u8; 8], +} + +#[derive(Copy, Clone, Debug)] +#[repr(C, packed)] +pub struct SockAddrUn { + family: u16, + path: [u8; 108], +} + +unsafe impl crate::memory::uaccess::UserCopyable for SockAddrIn {} +unsafe impl crate::memory::uaccess::UserCopyable for SockAddrUn {} + +impl TryFrom for IpEndpoint { + type Error = KernelError; + fn try_from(sockaddr: SockAddr) -> Result { + match sockaddr { + SockAddr::In(SockAddrIn { port, addr, .. }) => Ok(IpEndpoint { + port: u16::from_be_bytes(port), + addr: IpAddress::Ipv4(Ipv4Addr::from(addr)), + }), + _ => Err(KernelError::InvalidValue), + } + } +} + +impl From for SockAddr { + fn from(endpoint: IpEndpoint) -> SockAddr { + SockAddr::In(SockAddrIn { + family: AF_INET as u16, + port: endpoint.port.to_be_bytes(), + addr: match endpoint.addr { + IpAddress::Ipv4(addr) => addr.octets(), + _ => unimplemented!(), + }, + zero: [0; 8], + }) + } +} + +pub fn process_packets() { + // For now, just wake any tasks waiting on socket progress. + let _ = sockets().lock_save_irq(); + let _ = now(); + socket_wait_queue().lock_save_irq().wake_all(); +} + +pub async fn parse_sockaddr(uaddr: UA, len: usize) -> Result { + use crate::memory::uaccess::try_copy_from_user; + use libkernel::memory::address::TUA; + + // Need at least a family field + if len < size_of::() { + return Err(KernelError::InvalidValue); + } + + let family: u16 = copy_from_user(TUA::from_value(uaddr.value())).await?; + + match family as i32 { + AF_INET => { + if len < size_of::() { + return Err(KernelError::InvalidValue); + } + let sain: SockAddrIn = try_copy_from_user(uaddr.cast())?; + Ok(SockAddr::In(sain)) + } + AF_UNIX => { + if len < size_of::() { + return Err(KernelError::InvalidValue); + } + let saun: SockAddrUn = try_copy_from_user(uaddr.cast())?; + Ok(SockAddr::Un(saun)) + } + _ => Err(KernelError::AddressFamilyNotSupported), + } +} diff --git a/src/socket/sops.rs b/src/socket/sops.rs new file mode 100644 index 00000000..e509a278 --- /dev/null +++ b/src/socket/sops.rs @@ -0,0 +1,91 @@ +use crate::fs::fops::FileOps; +use crate::fs::open_file::FileCtx; +use crate::socket::{ShutdownHow, SockAddr}; +use alloc::boxed::Box; +use async_trait::async_trait; +use libkernel::error::KernelError; +use libkernel::memory::address::UA; + +#[async_trait] +pub trait SocketOps: Send + Sync { + async fn bind(&self, _addr: SockAddr) -> libkernel::error::Result<()> { + Err(KernelError::NotSupported) + } + + async fn connect(&self, _addr: SockAddr) -> libkernel::error::Result<()> { + Err(KernelError::NotSupported) + } + + async fn listen(&self, _backlog: i32) -> libkernel::error::Result<()> { + Err(KernelError::NotSupported) + } + + async fn accept(&self) -> libkernel::error::Result> { + Err(KernelError::NotSupported) + } + + async fn read( + &mut self, + ctx: &mut FileCtx, + buf: UA, + count: usize, + ) -> libkernel::error::Result; + async fn write( + &mut self, + ctx: &mut FileCtx, + buf: UA, + count: usize, + ) -> libkernel::error::Result; + + async fn shutdown(&self, _how: ShutdownHow) -> libkernel::error::Result<()> { + Err(KernelError::NotSupported) + } + + fn as_file(self: Box) -> Box; +} + +#[async_trait] +impl FileOps for T +where + T: SocketOps, +{ + async fn read( + &mut self, + ctx: &mut FileCtx, + buf: UA, + count: usize, + ) -> libkernel::error::Result { + self.read(ctx, buf, count).await + } + + async fn readat( + &mut self, + _buf: UA, + _count: usize, + _offset: u64, + ) -> libkernel::error::Result { + Err(KernelError::NotSupported) + } + + async fn write( + &mut self, + ctx: &mut FileCtx, + buf: UA, + count: usize, + ) -> libkernel::error::Result { + self.write(ctx, buf, count).await + } + + async fn writeat( + &mut self, + _buf: UA, + _count: usize, + _offset: u64, + ) -> libkernel::error::Result { + Err(KernelError::NotSupported) + } + + fn as_socket(&mut self) -> Option<&mut dyn SocketOps> { + Some(self) + } +} diff --git a/src/socket/syscalls/accept.rs b/src/socket/syscalls/accept.rs new file mode 100644 index 00000000..6f11e48a --- /dev/null +++ b/src/socket/syscalls/accept.rs @@ -0,0 +1,29 @@ +use crate::fs::open_file::OpenFile; +use crate::process::fd_table::Fd; +use crate::sched::current::current_task_shared; +use libkernel::error::KernelError; +use libkernel::fs::OpenFlags; + +pub async fn sys_accept(fd: Fd) -> libkernel::error::Result { + let file = current_task_shared() + .fd_table + .lock_save_irq() + .get(fd) + .ok_or(KernelError::BadFd)?; + + let (ops, _ctx) = &mut *file.lock().await; + + let new_socket = ops + .as_socket() + .ok_or(KernelError::NotASocket)? + .accept() + .await? + .as_file(); + + let open_file = OpenFile::new(new_socket, OpenFlags::empty()); + let new_fd = current_task_shared() + .fd_table + .lock_save_irq() + .insert(alloc::sync::Arc::new(open_file))?; + Ok(new_fd.as_raw() as usize) +} diff --git a/src/socket/syscalls/bind.rs b/src/socket/syscalls/bind.rs new file mode 100644 index 00000000..74362732 --- /dev/null +++ b/src/socket/syscalls/bind.rs @@ -0,0 +1,20 @@ +use crate::process::fd_table::Fd; +use crate::socket::parse_sockaddr; +use libkernel::memory::address::UA; + +pub async fn sys_bind(fd: Fd, addr: UA, addrlen: usize) -> libkernel::error::Result { + let file = crate::sched::current::current_task() + .fd_table + .lock_save_irq() + .get(fd) + .ok_or(libkernel::error::KernelError::BadFd)?; + + let (ops, _ctx) = &mut *file.lock().await; + let addr = parse_sockaddr(addr, addrlen).await?; + + ops.as_socket() + .ok_or(libkernel::error::KernelError::NotASocket)? + .bind(addr) + .await?; + Ok(0) +} diff --git a/src/socket/syscalls/connect.rs b/src/socket/syscalls/connect.rs new file mode 100644 index 00000000..b16acd0c --- /dev/null +++ b/src/socket/syscalls/connect.rs @@ -0,0 +1,20 @@ +use crate::process::fd_table::Fd; +use crate::socket::parse_sockaddr; +use libkernel::memory::address::UA; + +pub async fn sys_connect(fd: Fd, addr: UA, addrlen: usize) -> libkernel::error::Result { + let file = crate::sched::current::current_task() + .fd_table + .lock_save_irq() + .get(fd) + .ok_or(libkernel::error::KernelError::BadFd)?; + + let (ops, _ctx) = &mut *file.lock().await; + let addr = parse_sockaddr(addr, addrlen).await?; + + ops.as_socket() + .ok_or(libkernel::error::KernelError::NotASocket)? + .connect(addr) + .await?; + Ok(0) +} diff --git a/src/socket/syscalls/listen.rs b/src/socket/syscalls/listen.rs new file mode 100644 index 00000000..f870c5ab --- /dev/null +++ b/src/socket/syscalls/listen.rs @@ -0,0 +1,18 @@ +use crate::process::fd_table::Fd; +use libkernel::error::KernelError; + +pub async fn sys_listen(fd: Fd, backlog: i32) -> libkernel::error::Result { + let file = crate::sched::current::current_task() + .fd_table + .lock_save_irq() + .get(fd) + .ok_or(KernelError::BadFd)?; + + let (ops, _ctx) = &mut *file.lock().await; + + ops.as_socket() + .ok_or(KernelError::NotASocket)? + .listen(backlog) + .await?; + Ok(0) +} diff --git a/src/socket/syscalls/mod.rs b/src/socket/syscalls/mod.rs new file mode 100644 index 00000000..d62cd5a4 --- /dev/null +++ b/src/socket/syscalls/mod.rs @@ -0,0 +1,6 @@ +pub mod accept; +pub mod bind; +pub mod connect; +pub mod listen; +pub mod shutdown; +pub mod socket; diff --git a/src/socket/syscalls/shutdown.rs b/src/socket/syscalls/shutdown.rs new file mode 100644 index 00000000..6e961ea2 --- /dev/null +++ b/src/socket/syscalls/shutdown.rs @@ -0,0 +1,18 @@ +use crate::process::fd_table::Fd; +use crate::socket::ShutdownHow; + +pub async fn sys_shutdown(fd: Fd, how: i32) -> libkernel::error::Result { + let file = crate::sched::current::current_task() + .fd_table + .lock_save_irq() + .get(fd) + .ok_or(libkernel::error::KernelError::BadFd)?; + + let (ops, _ctx) = &mut *file.lock().await; + + ops.as_socket() + .ok_or(libkernel::error::KernelError::NotASocket)? + .shutdown(ShutdownHow::try_from(how)?) + .await?; + Ok(0) +} diff --git a/src/socket/syscalls/socket.rs b/src/socket/syscalls/socket.rs new file mode 100644 index 00000000..b9271ca5 --- /dev/null +++ b/src/socket/syscalls/socket.rs @@ -0,0 +1,36 @@ +use crate::fs::fops::FileOps; +use crate::fs::open_file::OpenFile; +use crate::sched::current::current_task_shared; +use crate::socket::tcp::TcpSocket; +use crate::socket::unix::UnixSocket; +use crate::socket::{AF_INET, AF_UNIX, IPPROTO_TCP, SOCK_DGRAM, SOCK_SEQPACKET, SOCK_STREAM}; +use alloc::boxed::Box; +use alloc::sync::Arc; +use libkernel::error::KernelError; +use libkernel::fs::OpenFlags; + +pub const CLOSE_ON_EXEC: i32 = 0x80000; +pub const NONBLOCK: i32 = 0x800; + +pub async fn sys_socket(domain: i32, type_: i32, protocol: i32) -> libkernel::error::Result { + let _close_on_exec = (type_ & CLOSE_ON_EXEC) != 0; + let _nonblock = (type_ & NONBLOCK) != 0; + // Mask out flags + let type_ = type_ & !(CLOSE_ON_EXEC | NONBLOCK); + let new_socket: Box = match (domain, type_, protocol) { + (AF_INET, SOCK_STREAM, 0) | (AF_INET, SOCK_STREAM, IPPROTO_TCP) => { + Box::new(TcpSocket::new()) + } + (AF_UNIX, SOCK_STREAM, _) => Box::new(UnixSocket::new_stream()), + (AF_UNIX, SOCK_DGRAM, _) => Box::new(UnixSocket::new_datagram()), + (AF_UNIX, SOCK_SEQPACKET, _) => Box::new(UnixSocket::new_seqpacket()), + _ => return Err(KernelError::AddressFamilyNotSupported), + }; + // TODO: Correct flags + let open_file = OpenFile::new(new_socket, OpenFlags::empty()); + let fd = current_task_shared() + .fd_table + .lock_save_irq() + .insert(Arc::new(open_file))?; + Ok(fd.as_raw() as usize) +} diff --git a/src/socket/tcp.rs b/src/socket/tcp.rs new file mode 100644 index 00000000..10b9b5ca --- /dev/null +++ b/src/socket/tcp.rs @@ -0,0 +1,123 @@ +use crate::arch::ArchImpl; +use crate::fs::fops::FileOps; +use crate::fs::open_file::FileCtx; +use crate::socket::sops::SocketOps; +use crate::socket::{ShutdownHow, SockAddr, process_packets, sockets}; +use crate::sync::SpinLock; +use alloc::boxed::Box; +use alloc::collections::BTreeSet; +use alloc::sync::Arc; +use alloc::vec; +use alloc::vec::Vec; +use async_trait::async_trait; +use core::sync::atomic::{AtomicUsize, Ordering}; +use libkernel::error::KernelError; +use libkernel::memory::address::UA; +use libkernel::sync::spinlock::SpinLockIrqGuard; +use smoltcp::iface::SocketHandle; +use smoltcp::socket::tcp::SocketBuffer; +use smoltcp::wire::IpEndpoint; + +const BACKLOG_MAX: usize = 8; +#[expect(dead_code)] +static INUSE_ENDPOINTS: SpinLock> = SpinLock::new(BTreeSet::new()); +#[expect(dead_code)] +static PASSIVE_OPENS_TOTAL: AtomicUsize = AtomicUsize::new(0); +#[expect(dead_code)] +static WRITTEN_BYTES_TOTAL: AtomicUsize = AtomicUsize::new(0); +#[expect(dead_code)] +static READ_BYTES_TOTAL: AtomicUsize = AtomicUsize::new(0); + +pub struct TcpSocket { + handle: SocketHandle, + local_endpoint: SpinLock>, + backlogs: SpinLock>>, + num_backlogs: AtomicUsize, +} + +impl TcpSocket { + pub fn new() -> Self { + let rx_buffer = SocketBuffer::new(vec![0; 4096]); + let tx_buffer = SocketBuffer::new(vec![0; 4096]); + let inner = smoltcp::socket::tcp::Socket::new(rx_buffer, tx_buffer); + let handle = sockets().lock_save_irq().add(inner); + TcpSocket { + handle, + local_endpoint: SpinLock::new(None), + backlogs: SpinLock::new(Vec::new()), + num_backlogs: AtomicUsize::new(0), + } + } + + fn refill_backlog_sockets( + &self, + backlogs: &mut SpinLockIrqGuard>, ArchImpl>, + ) -> Result<(), KernelError> { + let local_endpoint = match *self.local_endpoint.lock_save_irq() { + Some(local_endpoint) => local_endpoint, + None => return Err(KernelError::InvalidValue), + }; + + for _ in 0..(self.num_backlogs.load(Ordering::Relaxed) - backlogs.len()) { + let socket = TcpSocket::new(); + sockets() + .lock_save_irq() + .get_mut::(socket.handle) + .listen(local_endpoint) + .unwrap(); + backlogs.push(Arc::new(socket)); + } + + Ok(()) + } +} + +#[async_trait] +impl SocketOps for TcpSocket { + async fn bind(&self, addr: SockAddr) -> libkernel::error::Result<()> { + *self.local_endpoint.lock_save_irq() = Some(addr.try_into()?); + Ok(()) + } + + async fn listen(&self, backlog: i32) -> Result<(), KernelError> { + let mut backlogs = self.backlogs.lock_save_irq(); + + let new_num_backlogs = (backlog as usize).min(BACKLOG_MAX); + backlogs.truncate(new_num_backlogs); + self.num_backlogs.store(new_num_backlogs, Ordering::SeqCst); + + self.refill_backlog_sockets(&mut backlogs) + } + + async fn read( + &mut self, + _ctx: &mut FileCtx, + _buf: UA, + _count: usize, + ) -> libkernel::error::Result { + todo!() + } + + async fn write( + &mut self, + _ctx: &mut FileCtx, + _buf: UA, + _count: usize, + ) -> libkernel::error::Result { + todo!() + } + + async fn shutdown(&self, _how: ShutdownHow) -> libkernel::error::Result<()> { + sockets() + .lock_save_irq() + .get_mut::(self.handle) + .close(); + + process_packets(); + Ok(()) + } + + fn as_file(self: Box) -> Box { + self + } +} diff --git a/src/socket/unix.rs b/src/socket/unix.rs new file mode 100644 index 00000000..4de847d6 --- /dev/null +++ b/src/socket/unix.rs @@ -0,0 +1,284 @@ +use crate::fs::open_file::FileCtx; +use crate::kernel::kpipe::KPipe; +use crate::socket::{SockAddr, SocketOps}; +use crate::sync::OnceLock; +use crate::sync::SpinLock; +use alloc::boxed::Box; +use alloc::collections::BTreeMap; +use alloc::sync::Arc; +use alloc::vec::Vec; +use async_trait::async_trait; +use core::future::poll_fn; +use core::task::Poll; +use core::task::Waker; +use libkernel::error::{KernelError, Result}; +use libkernel::memory::address::UA; + +/// Registry mapping Unix socket path bytes to endpoint inbox and listening state +struct Endpoint { + inbox: Arc, + listening: bool, + backlog_max: usize, + pending: Vec, + /// Wakers for tasks waiting in accept + waiters: Vec, +} + +/// Registry mapping Unix socket path bytes to endpoint inbox +static UNIX_ENDPOINTS: OnceLock, Endpoint>>> = OnceLock::new(); + +fn endpoints() -> &'static SpinLock, Endpoint>> { + UNIX_ENDPOINTS.get_or_init(|| SpinLock::new(BTreeMap::new())) +} + +enum SocketType { + Stream, + Datagram, + SeqPacket, +} + +pub struct UnixSocket { + socket_type: SocketType, + /// Recv inbox + inbox: Arc, + /// The peer endpoint's inbox + peer_inbox: SpinLock>>, + local_addr: SpinLock>, + connected: SpinLock, + listening: SpinLock, + backlog: SpinLock, + // Shutdown state + rd_shutdown: SpinLock, + wr_shutdown: SpinLock, +} + +impl UnixSocket { + fn new(socket_type: SocketType) -> Self { + UnixSocket { + socket_type, + inbox: Arc::new(KPipe::new().expect("KPipe::new for UnixSocket")), + peer_inbox: SpinLock::new(None), + local_addr: SpinLock::new(None), + connected: SpinLock::new(false), + listening: SpinLock::new(false), + backlog: SpinLock::new(0), + rd_shutdown: SpinLock::new(false), + wr_shutdown: SpinLock::new(false), + } + } + + pub fn new_stream() -> Self { + Self::new(SocketType::Stream) + } + pub fn new_datagram() -> Self { + Self::new(SocketType::Datagram) + } + pub fn new_seqpacket() -> Self { + Self::new(SocketType::SeqPacket) + } + + fn path_bytes(saun: &crate::socket::SockAddrUn) -> Option> { + // Unix path is a sun_path-like fixed-size buffer which may be null-terminated + let mut end = saun.path.len(); + for (i, b) in saun.path.iter().enumerate() { + if *b == 0 { + end = i; + break; + } + } + if end == 0 { + None + } else { + Some(saun.path[..end].to_vec()) + } + } +} + +#[async_trait] +impl SocketOps for UnixSocket { + async fn bind(&self, addr: SockAddr) -> Result<()> { + match addr { + SockAddr::Un(saun) => { + let Some(path) = UnixSocket::path_bytes(&saun) else { + return Err(KernelError::InvalidValue); + }; + // Register endpoint; if already exists, return error + let mut reg = endpoints().lock_save_irq(); + if reg.contains_key(&path) { + return Err(KernelError::InvalidValue); + } + reg.insert( + path, + Endpoint { + inbox: self.inbox.clone(), + listening: false, + backlog_max: 0, + pending: Vec::new(), + waiters: Vec::new(), + }, + ); + *self.local_addr.lock_save_irq() = Some(saun); + Ok(()) + } + _ => Err(KernelError::InvalidValue), + } + } + + async fn connect(&self, addr: SockAddr) -> Result<()> { + match addr { + SockAddr::Un(saun) => { + let Some(path) = UnixSocket::path_bytes(&saun) else { + return Err(KernelError::InvalidValue); + }; + let mut reg = endpoints().lock_save_irq(); + let Some(ep) = reg.get_mut(&path) else { + return Err(KernelError::InvalidValue); + }; + if ep.listening { + if ep.pending.len() >= ep.backlog_max { + return Err(KernelError::TryAgain); + } + let server_sock = UnixSocket::new(SocketType::Stream); + *server_sock.peer_inbox.lock_save_irq() = Some(self.inbox.clone()); + *server_sock.connected.lock_save_irq() = true; + // Client links to listener inbox to write into server + *self.peer_inbox.lock_save_irq() = Some(server_sock.inbox.clone()); + *self.connected.lock_save_irq() = true; + ep.pending.push(server_sock); + // Wake one waiter if present + if let Some(w) = ep.waiters.pop() { + w.wake(); + } + Ok(()) + } else { + // Non-listening endpoint: treat as datagram or pre-bound stream endpoint + *self.peer_inbox.lock_save_irq() = Some(ep.inbox.clone()); + *self.connected.lock_save_irq() = true; + Ok(()) + } + } + _ => Err(KernelError::InvalidValue), + } + } + + async fn listen(&self, backlog: i32) -> Result<()> { + match self.socket_type { + SocketType::Stream | SocketType::SeqPacket => {} + SocketType::Datagram => return Err(KernelError::NotSupported), + } + if backlog < 0 { + return Err(KernelError::InvalidValue); + } + let Some(saun) = &*self.local_addr.lock_save_irq() else { + return Err(KernelError::InvalidValue); + }; + let Some(path) = UnixSocket::path_bytes(saun) else { + return Err(KernelError::InvalidValue); + }; + let mut reg = endpoints().lock_save_irq(); + let Some(ep) = reg.get_mut(&path) else { + return Err(KernelError::InvalidValue); + }; + ep.listening = true; + ep.backlog_max = backlog as usize; + *self.listening.lock_save_irq() = true; + *self.backlog.lock_save_irq() = backlog as usize; + Ok(()) + } + + async fn accept(&self) -> Result> { + { + if !*self.listening.lock_save_irq() { + return Err(KernelError::InvalidValue); + } + } + let path_vec: Vec = { + let guard = self.local_addr.lock_save_irq(); + let Some(saun) = &*guard else { + return Err(KernelError::InvalidValue); + }; + let Some(pv) = UnixSocket::path_bytes(saun) else { + return Err(KernelError::InvalidValue); + }; + pv + }; + + let sock = poll_fn(|cx| { + let mut reg = endpoints().lock_save_irq(); + let Some(ep) = reg.get_mut(&path_vec) else { + return Poll::Ready(Err(KernelError::InvalidValue)); + }; + if let Some(sock) = ep.pending.pop() { + Poll::Ready(Ok(sock)) + } else { + ep.waiters.push(cx.waker().clone()); + Poll::Pending + } + }) + .await?; + + Ok(Box::new(sock)) + } + + async fn read(&mut self, _ctx: &mut FileCtx, buf: UA, count: usize) -> Result { + if count == 0 { + return Ok(0); + } + if *self.rd_shutdown.lock_save_irq() { + return Ok(0); + } + self.inbox.copy_to_user(buf, count).await + } + + async fn write(&mut self, _ctx: &mut FileCtx, buf: UA, count: usize) -> Result { + if count == 0 { + return Ok(0); + } + if *self.wr_shutdown.lock_save_irq() { + return Err(KernelError::BrokenPipe); + } + match self.socket_type { + SocketType::Stream | SocketType::SeqPacket => { + if !*self.connected.lock_save_irq() { + return Err(KernelError::InvalidValue); + } + } + SocketType::Datagram => {} + } + let Some(peer) = self.peer_inbox.lock_save_irq().clone() else { + return Err(KernelError::InvalidValue); + }; + peer.copy_from_user(buf, count).await + } + + async fn shutdown(&self, how: crate::socket::ShutdownHow) -> Result<()> { + match how { + crate::socket::ShutdownHow::Read => { + *self.rd_shutdown.lock_save_irq() = true; + } + crate::socket::ShutdownHow::Write => { + *self.wr_shutdown.lock_save_irq() = true; + } + crate::socket::ShutdownHow::ReadWrite => { + *self.rd_shutdown.lock_save_irq() = true; + *self.wr_shutdown.lock_save_irq() = true; + } + } + Ok(()) + } + + fn as_file(self: Box) -> Box { + self + } +} + +impl Drop for UnixSocket { + fn drop(&mut self) { + if let Some(saun) = &*self.local_addr.lock_save_irq() + && let Some(path) = UnixSocket::path_bytes(saun) + { + let mut reg = endpoints().lock_save_irq(); + reg.remove(&path); + } + } +} diff --git a/usertest/src/main.rs b/usertest/src/main.rs index 3ffeba00..05007100 100644 --- a/usertest/src/main.rs +++ b/usertest/src/main.rs @@ -7,8 +7,13 @@ use std::{ }; use futex_bitset::test_futex_bitset; +use socket::{ + test_tcp_socket_creation, test_unix_socket_basic_functions, test_unix_socket_creation, + test_unix_socket_fork_msg_passing, +}; mod futex_bitset; +mod socket; fn test_sync() { print!("Testing sync syscall ..."); @@ -779,10 +784,7 @@ fn run_test(test_fn: fn()) { let mut status = 0; libc::waitpid(pid, &mut status, 0); if !libc::WIFEXITED(status) || libc::WEXITSTATUS(status) != 0 { - panic!( - "Test failed in child process: {}", - std::io::Error::last_os_error() - ); + panic!("Test failed in child process."); } } } @@ -820,6 +822,10 @@ fn main() { run_test(test_rust_mutex); run_test(test_parking_lot_mutex_timeout); run_test(test_thread_with_name); + run_test(test_tcp_socket_creation); + run_test(test_unix_socket_creation); + run_test(test_unix_socket_basic_functions); + run_test(test_unix_socket_fork_msg_passing); let end = std::time::Instant::now(); println!("All tests passed in {} ms", (end - start).as_millis()); } diff --git a/usertest/src/socket.rs b/usertest/src/socket.rs new file mode 100644 index 00000000..6a2c0f82 --- /dev/null +++ b/usertest/src/socket.rs @@ -0,0 +1,181 @@ +use libc::{AF_INET, AF_UNIX, SOCK_DGRAM, SOCK_STREAM}; +use libc::{accept, bind, connect, listen, shutdown, socket}; + +pub fn test_tcp_socket_creation() { + print!("Testing TCP socket creation ... "); + unsafe { + let sockfd = socket(AF_INET, SOCK_STREAM, 0); + if sockfd < 0 { + panic!("Failed to create TCP socket"); + } + } + println!("OK"); +} + +pub fn test_unix_socket_creation() { + print!("Testing UNIX stream socket creation ... "); + unsafe { + let sockfd = socket(AF_UNIX, SOCK_STREAM, 0); + if sockfd < 0 { + panic!("Failed to create UNIX stream socket"); + } + } + println!("OK"); + + print!("Testing UNIX datagram socket creation ... "); + unsafe { + let sockfd = socket(AF_UNIX, SOCK_DGRAM, 0); + if sockfd < 0 { + panic!("Failed to create UNIX datagram socket"); + } + } + println!("OK"); +} + +pub fn test_unix_socket_basic_functions() { + print!("Testing UNIX socket functions ... "); + let sockfd = unsafe { socket(AF_UNIX, SOCK_STREAM, 0) }; + if sockfd < 0 { + panic!("Failed to create UNIX stream socket for function tests"); + } + let path = "/tmp/test_socket"; + let sockaddr = libc::sockaddr_un { + sun_family: AF_UNIX as u16, + sun_path: { + let mut path_array = [0u8; 108]; + for (i, &b) in path.as_bytes().iter().enumerate() { + path_array[i] = b; + } + path_array + }, + }; + let bind_result = unsafe { + bind( + sockfd, + &sockaddr as *const libc::sockaddr_un as *const libc::sockaddr, + std::mem::size_of::() as u32, + ) + }; + if bind_result < 0 { + panic!("Failed to bind UNIX socket"); + } + let listen_result = unsafe { listen(sockfd, 5) }; + if listen_result < 0 { + panic!("Failed to listen on UNIX socket"); + } + let shutdown_result = unsafe { shutdown(sockfd, 2) }; + if shutdown_result < 0 { + panic!("Failed to shutdown UNIX socket"); + } + println!("OK"); +} + +pub fn test_unix_socket_fork_msg_passing() { + use std::ptr; + + print!("Testing UNIX socket fork message passing ... "); + + // Create server socket, bind and listen before fork + let server_fd = unsafe { socket(AF_UNIX, SOCK_STREAM, 0) }; + if server_fd < 0 { + panic!("Failed to create server UNIX socket"); + } + + let path = "/tmp/uds_fork_test"; + let sockaddr = libc::sockaddr_un { + sun_family: AF_UNIX as u16, + sun_path: { + let mut path_array = [0u8; 108]; + for (i, &b) in path.as_bytes().iter().enumerate() { + path_array[i] = b; + } + path_array + }, + }; + + let ret = unsafe { + bind( + server_fd, + &sockaddr as *const libc::sockaddr_un as *const libc::sockaddr, + std::mem::size_of::() as u32, + ) + }; + if ret < 0 { + panic!("Server bind failed"); + } + let ret = unsafe { listen(server_fd, 1) }; + if ret < 0 { + panic!("Server listen failed"); + } + + let pid = unsafe { libc::fork() }; + if pid < 0 { + panic!("fork failed"); + } + + if pid == 0 { + // Child: client + let client_fd = unsafe { socket(AF_UNIX, SOCK_STREAM, 0) }; + if client_fd < 0 { + panic!("Client socket creation failed"); + } + let ret = unsafe { + connect( + client_fd, + &sockaddr as *const libc::sockaddr_un as *const libc::sockaddr, + std::mem::size_of::() as u32, + ) + }; + if ret < 0 { + panic!("Client connect failed"); + } + + // Send request + let req = b"hello"; + let wr = unsafe { libc::write(client_fd, req.as_ptr() as *const _, req.len()) }; + if wr != req.len() as isize { + panic!("Client write failed"); + } + + // Receive response + let mut resp = [0u8; 5]; + let rd = unsafe { libc::read(client_fd, resp.as_mut_ptr() as *mut _, resp.len()) }; + if rd != resp.len() as isize || &resp != b"world" { + panic!("Client read failed"); + } + + unsafe { libc::close(client_fd) }; + unsafe { libc::_exit(0) }; + } else { + // Parent: server + let conn_fd = unsafe { accept(server_fd, ptr::null_mut(), ptr::null_mut()) }; + if conn_fd < 0 { + panic!("Server accept failed"); + } + + // Receive request + let mut buf = [0u8; 5]; + let rd = unsafe { libc::read(conn_fd, buf.as_mut_ptr() as *mut _, buf.len()) }; + if rd != buf.len() as isize || &buf != b"hello" { + panic!("Server read failed"); + } + + // Send response + let resp = b"world"; + let wr = unsafe { libc::write(conn_fd, resp.as_ptr() as *const _, resp.len()) }; + if wr != resp.len() as isize { + panic!("Server write failed"); + } + + // Wait for child + let mut status = 0; + unsafe { libc::waitpid(pid, &mut status, 0) }; + if !libc::WIFEXITED(status) || libc::WEXITSTATUS(status) != 0 { + panic!("Client process did not exit cleanly"); + } + + unsafe { libc::close(conn_fd) }; + unsafe { libc::close(server_fd) }; + println!("OK"); + } +}