diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 054c83f..4b51ed9 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -4,7 +4,7 @@ unsafe_code = "forbid" [workspace] resolver = "2" -members = ["master", "cluster", "client", "shared", "auth", "sustenet"] +members = ["master", "cluster", "client", "shared", "auth", "sustenet", "tests"] exclude = ["backup"] [workspace.package] @@ -22,13 +22,13 @@ sustenet-master = { path = "master", version = "0.1.4" } sustenet-shared = { path = "shared", version = "0.1.4" } aes = "0.8.4" -public-ip = "0.2.2" - aes-gcm = "0.10.3" base64 = "0.22.1" -config = "0.15.4" -ctrlc = "3.4.5" +bytes = "1.10.1" +config = "0.15.11" dashmap = "6.1.0" -getrandom = "0.3.2" +getrandom = "0.3.3" lazy_static = "1.5.0" -tokio = { version = "1.41.1", default-features = false, features = [] } +num_cpus = "1.17.0" +public-ip = "0.2.2" +tokio = { version = "1.45.1", default-features = false, features = [] } diff --git a/rust/Config.toml b/rust/Config.toml deleted file mode 100644 index 8176b72..0000000 --- a/rust/Config.toml +++ /dev/null @@ -1,12 +0,0 @@ -[all] -server_name = "Default Server" - -max_connections = 0 -port = 0 - -[cluster] -key_name = "cluster_key" -master_ip = "127.0.0.1" -master_port = 0 - -domain_pub_key = "https://www.playreia.com/game/pubkey.pub" # Remove this if you want to use the server's bandwidth to send a key to a user directly. diff --git a/rust/auth/src/lib.rs b/rust/auth/src/lib.rs index e69de29..fde3dd5 100644 --- a/rust/auth/src/lib.rs +++ b/rust/auth/src/lib.rs @@ -0,0 +1,26 @@ +// This is the auth server. It handles all authentication from everywhere. +// I may eventually make this distributed as well, depending on the load. +// The idea is that the auth server will be a trusted endpoint for all clusters and clients. + +// If some random person wants to host a server, how do we handle authentication without trusting them with the password? Simple, we're the middleman. For every server. + +// 1. The client tells the cluster, "Yo, I'd like to authenticate." +// 2. The untrusted cluster tells their client, "Yo, your secret id is `5d600d55-2261-4b12-a543-3dc2f6f54a81`." +// 3. The Trusted Auth Server gets a message from the untrusted cluster and says, "Alright, I'll save your UUID for 30 seconds." +// 4. The client sends their UUID with their credentials. +// 5. The auth server says to the cluster server, "Yeah, they're good. Bye." +// 6. The untrusted cluster tells the client, "Alright, bossman said you're good to go. Come in." +// 7. The cluster server also gets a UID they can poll every few seconds if they want to. This is good for security features like changing passwords and triggering an optional "sign out all clients". This will be done with a WebSocket server that publishes when a specific UID wants to be logged out. This is safe because usernames and emails are never public. Only UIDs and display names. So even if someone knew a specific UID was changed recently, they have no way to really target that user. Especially because we'll be enforcing that you can never have your username the same as your display name UNLESS you have 2FA. If you have 2FA then we care a little less. Still not safe. + +// Why do clusters from untrusted servers need authentication? +// We still want to be able to get their purchases. +// The cluster will get all of the purchases for the UID from the auth server and then send them to the client. +// The cluster will have a copy of the purchases.
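+
+// A rough sketch of how this file might eventually model the handshake above.
+// Everything here is illustrative: the names (`PendingAuth`, `register`) and the
+// `uuid` crate are assumptions, not part of this change.
+//
+// use std::time::{ Duration, Instant };
+// use uuid::Uuid;
+//
+// /// A secret id a cluster registered for a client, valid for 30 seconds.
+// struct PendingAuth {
+//     secret_id: Uuid,
+//     expires_at: Instant,
+// }
+//
+// impl PendingAuth {
+//     fn register(secret_id: Uuid) -> Self {
+//         Self { secret_id, expires_at: Instant::now() + Duration::from_secs(30) }
+//     }
+//
+//     fn is_valid(&self, candidate: Uuid) -> bool {
+//         self.secret_id == candidate && Instant::now() < self.expires_at
+//     }
+// }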
+ +// An idea to improve security so people don't just spam for every possible UID is to have a token the auth server encrypts +// and gives to the cluster server. The cluster server has to send that token back with the UID to get the purchases. +// It'll also send an expiration time for the token. The token is just the UID and the expiration time. + +// Additionally, on a successful auth, the auth server will send the client a token that allows them to access all of their data. +// THIS token should never be shared with the server. It's an actual token to their account. The cluster server can tell the +// client what their UID is at this point to save bandwidth on the auth server. \ No newline at end of file diff --git a/rust/client/Cargo.toml b/rust/client/Cargo.toml index 2d576bb..1a08af3 100644 --- a/rust/client/Cargo.toml +++ b/rust/client/Cargo.toml @@ -12,7 +12,7 @@ homepage.workspace = true workspace = true [dependencies] -lazy_static.workspace = true +bytes.workspace = true sustenet-shared.workspace = true tokio = { workspace = true, features = [ # "socket2", diff --git a/rust/client/src/client.rs b/rust/client/src/client.rs new file mode 100644 index 0000000..c14c2ef --- /dev/null +++ b/rust/client/src/client.rs @@ -0,0 +1,276 @@ +//! Handles connections to a server and sending messages. +use sustenet_shared::logging::{ LogType, Logger }; +use sustenet_shared::lselect; +use sustenet_shared::packets::{ Connection, Diagnostics, Messaging }; + +use std::io::Error; +use std::sync::LazyLock; + +use bytes::Bytes; +use tokio::io::AsyncReadExt; +use tokio::io::{ self, AsyncWriteExt }; +use tokio::net::TcpStream; +use tokio::sync::{ broadcast, mpsc }; + +/// Global logger for the client module. +pub static LOGGER: LazyLock<Logger> = LazyLock::new(|| Logger::new(LogType::Client)); + +#[derive(Debug, Clone, PartialEq)] +pub struct ClusterInfo { + pub name: String, + pub ip: String, + pub port: u16, + pub max_connections: u32, +} + +/// Events emitted by the client to notify listeners. +/// +/// Should be handled with `event_receiver` or `next_event` externally. +#[derive(Debug, Clone)] +pub enum ClientEvent { + Connected, + Disconnected, + CommandSent(u8), + MessageSent(Bytes), + CommandReceived(u8), + MessageReceived(Bytes), + Error(String), +} + +/// Handles connection to a master or cluster server, and provides async channels for interaction. +pub struct Client { + /// Sends messages to the server. + sender: mpsc::Sender<Bytes>, + /// Sends events to listeners. + event_tx: broadcast::Sender<ClientEvent>, + /// Receives events about connection state and activity. + event_rx: broadcast::Receiver<ClientEvent>, + /// Cluster servers this client knows about. + pub cluster_servers: Vec<ClusterInfo>, +} + +impl Client { + /// Attempts to connect to a server at the specified address and port and returns a `Client`. + pub async fn connect(address: &str, port: u16) -> io::Result<Self> { + let addr = format!("{}:{}", address, port); + LOGGER.info(&format!("Connecting to {addr}...")); + + // Establish a connection to the server.
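+        // Note: `TcpStream::connect` has no built-in timeout. If one is ever needed,
+        // wrapping the call is enough (a sketch; the 5-second budget is an arbitrary choice):
+        //
+        // match tokio::time::timeout(std::time::Duration::from_secs(5), TcpStream::connect(&addr)).await {
+        //     Ok(result) => result, // the usual io::Result from connect
+        //     Err(_elapsed) => return Err(Error::new(io::ErrorKind::TimedOut, "connect timed out")),
+        // }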
+ let mut stream = match TcpStream::connect(&addr).await { + Ok(s) => { + LOGGER.success(&format!("Connected to {addr}")); + s + } + Err(e) => { + LOGGER.error(&format!("Failed to connect to {addr}")); + return Err(Error::new(e.kind(), format!("Failed to connect to ({addr}): {e}"))); + } + }; + + let (sender, mut receiver) = mpsc::channel::<Bytes>(64); + let (event_tx, event_rx) = broadcast::channel::<ClientEvent>(16); + + let sender_clone = sender.clone(); + let event_tx_clone = event_tx.clone(); + + tokio::spawn(async move { + let (reader, mut writer) = stream.split(); + let mut reader = io::BufReader::new(reader); + + lselect!( + // Handle local requests to send a message to the server. + msg = receiver.recv() => { + match msg { + Some(msg) => { + if msg.is_empty() { + LOGGER.warning("Received empty message, shutting down client"); + Self::handle_shutdown(writer, event_tx_clone).await; + break; + } + + LOGGER.debug(&format!("Sending message: {:?}", msg)); + if let Err(e) = writer.write_all(&msg).await { + let msg = format!("Failed to send message to server: {e}"); + LOGGER.error(&msg); + let _ = event_tx_clone.send(ClientEvent::Error(msg)); + } else { + let _ = event_tx_clone.send(ClientEvent::MessageSent(msg)); + } + }, + None => { + LOGGER.warning("Connection closed"); + Self::handle_shutdown(writer, event_tx_clone).await; + break; + } + } + }, + command = reader.read_u8() => { + match command { + Ok(command) => { + LOGGER.debug(&format!("Received command: {command}")); + + Self::handle_command(command, &sender_clone, &mut reader, &mut writer, &event_tx_clone).await; + + // Notify listeners about the received message. + let _ = event_tx_clone.send(ClientEvent::CommandReceived(command)); + }, + Err(e) => { + let msg = format!("Failed to read command from server: {e}"); + LOGGER.error(&msg); + let _ = event_tx_clone.send(ClientEvent::Error(msg)); + } + } + } + ) + }); + + // Notify connected immediately. + let _ = event_tx.send(ClientEvent::Connected); + + Ok(Client { + sender, + event_tx, + event_rx, + cluster_servers: Vec::new(), + }) + } + + async fn handle_shutdown( + mut writer: tokio::net::tcp::WriteHalf<'_>, + event_tx_clone: broadcast::Sender<ClientEvent> + ) { + if let Err(e) = writer.shutdown().await { + let msg = format!("Failed to shutdown writer: {e}"); + LOGGER.error(&msg); + let _ = event_tx_clone.send(ClientEvent::Error(msg)); + } + let _ = event_tx_clone.send(ClientEvent::Disconnected); + } + + /// Handles commands received from the server. + /// This function is called in a separate task to handle incoming commands. + async fn handle_command( + command: u8, + _sender: &mpsc::Sender<Bytes>, + _reader: &mut io::BufReader<tokio::net::tcp::ReadHalf<'_>>, + _writer: &mut tokio::net::tcp::WriteHalf<'_>, + event_tx: &broadcast::Sender<ClientEvent> + ) { + // TODO: Handle commands. + // Handle the command received from the server.
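+        // The `x if x == (Variant as u8)` guards below are how a raw command byte is
+        // matched against the packet enums today. If `sustenet_shared` ever gains a
+        // `TryFrom<u8>` impl for these enums (an assumption, not part of this change),
+        // the dispatch could instead look like:
+        //
+        // match Connection::try_from(command) {
+        //     Ok(Connection::Connect) => Self::handle_connect_command().await,
+        //     // ...other variants...
+        //     Err(_) => Self::handle_extra_command(command, event_tx).await,
+        // }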
+ match command { + x if x == (Connection::Connect as u8) => Self::handle_connect_command().await, + x if x == (Connection::Disconnect as u8) => Self::handle_disconnect_command().await, + x if x == (Connection::Authenticate as u8) => Self::handle_authenticate_command().await, + + x if x == (Messaging::SendGlobalMessage as u8) => { + Self::handle_send_global_message_command().await + } + x if x == (Messaging::SendPrivateMessage as u8) => { + Self::handle_send_private_message_command().await + } + x if x == (Messaging::SendPartyMessage as u8) => { + Self::handle_send_party_message_command().await + } + x if x == (Messaging::SendLocalMessage as u8) => { + Self::handle_send_local_message_command().await + } + + x if x == (Diagnostics::CheckServerType as u8) => { + Self::handle_check_server_type_command().await + } + x if x == (Diagnostics::CheckServerUptime as u8) => { + Self::handle_check_server_uptime_command().await + } + x if x == (Diagnostics::CheckServerPlayerCount as u8) => { + Self::handle_check_server_player_count_command().await + } + + _ => Self::handle_extra_command(command, event_tx).await, + } + } + + async fn handle_connect_command() { + todo!(); + } + async fn handle_disconnect_command() { + todo!(); + } + async fn handle_authenticate_command() { + todo!(); + } + + async fn handle_send_global_message_command() { + todo!(); + } + async fn handle_send_private_message_command() { + todo!(); + } + async fn handle_send_party_message_command() { + todo!(); + } + async fn handle_send_local_message_command() { + todo!(); + } + + async fn handle_check_server_type_command() { + todo!(); + } + async fn handle_check_server_uptime_command() { + todo!(); + } + async fn handle_check_server_player_count_command() { + todo!(); + } + + async fn handle_extra_command(command: u8, event_tx: &broadcast::Sender<ClientEvent>) { + let msg = format!("Unknown command received: {command}"); + LOGGER.error(&msg); + let _ = event_tx.send(ClientEvent::Error(msg)); + } + + /// Sends data to the server. + pub async fn send(&self, msg: Bytes) -> Result<(), mpsc::error::SendError<Bytes>> { + self.sender.send(msg.clone()).await?; + let _ = self.event_tx.send(ClientEvent::MessageSent(msg)); + Ok(()) + } + + /// Returns a cloneable event receiver for status updates. + pub fn event_receiver(&self) -> broadcast::Receiver<ClientEvent> { + self.event_rx.resubscribe() + } + + /// Returns the next event from the event receiver.
+ pub async fn next_event(&mut self) -> Option { + let event = self.event_rx.recv().await; + match event { + Ok(event) => Some(event), + Err(_) => None, + } + } + + // region: Cluster Server Utilities + pub fn get_cluster_servers(&self) -> &[ClusterInfo] { + &self.cluster_servers + } + + pub fn add_cluster_server(&mut self, server: ClusterInfo) { + self.cluster_servers.push(server); + } + + pub fn add_cluster_servers(&mut self, servers: Vec) { + self.cluster_servers.extend(servers); + } + + pub fn remove_cluster_server(&mut self, server: &ClusterInfo) { + self.cluster_servers.retain(|s| s != server); + } + + pub fn clear_cluster_servers(&mut self) { + self.cluster_servers.clear(); + } + // endregion: Cluster Server Utilities +} diff --git a/rust/client/src/lib.rs b/rust/client/src/lib.rs index 0fa54c2..d5f2848 100644 --- a/rust/client/src/lib.rs +++ b/rust/client/src/lib.rs @@ -1,305 +1,475 @@ -use sustenet_shared as shared; - -use std::net::{ IpAddr, Ipv4Addr }; -use std::str::FromStr; -use std::sync::{ Arc, LazyLock }; - -use tokio::io::{ AsyncReadExt, AsyncWriteExt, BufReader }; -use tokio::net::TcpStream; -use tokio::sync::mpsc::Sender; -use tokio::sync::{ RwLock, mpsc }; - -use sustenet_shared::ClientPlugin; -use shared::logging::{ LogType, Logger }; -use shared::packets::cluster::ToClient; -use shared::packets::master::ToUnknown; -use shared::utils::constants::{ DEFAULT_IP, MASTER_PORT }; -use shared::{ lread_string, lselect }; - -lazy_static::lazy_static! { - pub static ref CLUSTER_SERVERS: Arc>> = Arc::new( - RwLock::new(Vec::new()) - ); - pub static ref CONNECTION: Arc>> = Arc::new( - RwLock::new( - Some(Connection { - ip: get_ip(DEFAULT_IP), - port: MASTER_PORT, - connection_type: ConnectionType::MasterServer, - }) - ) - ); -} -pub static LOGGER: LazyLock = LazyLock::new(|| Logger::new(LogType::Cluster)); - -#[derive(Debug, Clone)] -pub struct ClusterInfo { - pub name: String, - pub ip: String, - pub port: u16, - pub max_connections: u32, -} - -#[derive(Clone, Copy)] -pub struct Connection { - pub ip: IpAddr, - pub port: u16, - pub connection_type: ConnectionType, -} - -impl From for Connection { - fn from(info: ClusterInfo) -> Self { - Connection { - ip: IpAddr::from_str(info.ip.as_str()).expect("Failed to parse the IP."), - port: info.port, - connection_type: ConnectionType::ClusterServer, - } - } -} - -#[derive(Clone, Copy, Eq, PartialEq)] -pub enum ConnectionType { - MasterServer, - ClusterServer, - None, -} - -impl std::fmt::Display for ConnectionType { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - match self { - ConnectionType::MasterServer => write!(f, "Master Server"), - ConnectionType::ClusterServer => write!(f, "Cluster Server"), - ConnectionType::None => write!(f, "Unknown"), - } - } -} - -pub fn get_ip(ip: &str) -> IpAddr { - IpAddr::from_str(ip).unwrap_or( - IpAddr::from_str(DEFAULT_IP).unwrap_or(IpAddr::V4(Ipv4Addr::LOCALHOST)) - ) -} - -pub async fn cleanup() {} - -pub async fn start
<P>
(plugin: P) where P: ClientPlugin + Send + Sync + 'static { - // Get the connection LOGGER.information. - let connection = *CONNECTION.read().await; - if connection.is_none() { - return; - } - let connection = connection.unwrap(); - - let ip = connection.ip; - let port = connection.port; - let connection_type = connection.connection_type; - { - *CONNECTION.write().await = None; - } - - let (tx, mut rx) = mpsc::channel::>(10); - plugin.set_sender(tx.clone()); - - let handler = tokio::spawn(async move { - LOGGER.warning(format!("Connecting to the {connection_type}...").as_str()); - let mut stream = TcpStream::connect(format!("{}:{}", ip, port)).await.expect( - format!("Failed to connect to the {connection_type} at {ip}:{port}.").as_str() - ); - LOGGER.success(format!("Connected to the {connection_type} at {ip}:{port}.").as_str()); - - let (reader, mut writer) = stream.split(); - let mut reader = BufReader::new(reader); - - lselect! { - command = reader.read_u8() => { - if command.is_err() { - continue; - } - - LOGGER.info(format!("Received data: {:?}", command).as_str()); - - match connection_type { - ConnectionType::MasterServer => match command.unwrap() { - x if x == ToUnknown::SendClusters as u8 => { - let amount = match reader.read_u8().await { - Ok(amount) => amount, - Err(_) => { - LOGGER.error("Failed to read the amount of clusters."); - continue; - } - }; - - let mut cluster_servers_tmp = Vec::new(); - for _ in 0..amount { - let name = lread_string!(reader, |msg| LOGGER.error(msg), "cluster name"); - let ip = lread_string!(reader, |msg| LOGGER.error(msg), "cluster IP"); - let port = match reader.read_u16().await { - Ok(port) => port, - Err(_) => { - LOGGER.error("Failed to read the cluster port."); - continue; - } - }; - let max_connections = match reader.read_u32().await { - Ok(max_connections) => max_connections, - Err(_) => { - LOGGER.error("Failed to read the cluster max connections."); - continue; - } - }; - - cluster_servers_tmp.push(ClusterInfo { - name, - ip, - port, - max_connections, - }); - } - - { - { - let mut cluster_servers = CLUSTER_SERVERS.write().await; - *cluster_servers = cluster_servers_tmp; - - LOGGER.success(format!("Received {amount} Cluster servers from the {connection_type}.").as_str()); - println!("{:?}", *cluster_servers); - } - } - }, - cmd => plugin.receive_master(tx.clone(), cmd, &mut reader).await, - } - ConnectionType::ClusterServer => match command.unwrap() { - x if x == ToClient::SendClusters as u8 => { - let amount = match reader.read_u8().await { - Ok(amount) => amount, - Err(_) => { - LOGGER.error("Failed to read the amount of clusters."); - continue; - } - }; - - let mut cluster_servers_tmp = Vec::new(); - for _ in 0..amount { - let name = lread_string!(reader, |msg| LOGGER.error(msg), "cluster name"); - let ip = lread_string!(reader, |msg| LOGGER.error(msg), "cluster IP"); - let port = match reader.read_u16().await { - Ok(port) => port, - Err(_) => { - LOGGER.error("Failed to read the cluster port."); - continue; - } - }; - let max_connections = match reader.read_u32().await { - Ok(max_connections) => max_connections, - Err(_) => { - LOGGER.error("Failed to read the cluster max connections."); - continue; - } - }; - - cluster_servers_tmp.push(ClusterInfo { - name, - ip, - port, - max_connections, - }); - } - - { - { - let mut cluster_servers = CLUSTER_SERVERS.write().await; - *cluster_servers = cluster_servers_tmp; - - LOGGER.success(format!("Received {amount} Cluster servers from the {connection_type}.").as_str()); - println!("{:?}", 
*cluster_servers); - } - } - }, - x if x == ToClient::DisconnectCluster as u8 => todo!(), - x if x == ToClient::LeaveCluster as u8 => todo!(), - - x if x == ToClient::VersionOfKey as u8 => todo!(), - x if x == ToClient::SendPubKey as u8 => todo!(), - x if x == ToClient::Authenticate as u8 => todo!(), - - x if x == ToClient::Move as u8 => todo!(), - cmd => plugin.receive_cluster(tx.clone(), cmd, &mut reader).await, - } - _ => (), - } - } - result = rx.recv() => { - if let Some(data) = result { - if data.is_empty() { - writer.shutdown().await.expect("Failed to shutdown the writer."); - LOGGER.info("Closing connection..."); - break; - } - - writer.write_all(&data).await.expect("Failed to write to the Server."); - writer.flush().await.expect("Failed to flush the writer."); - LOGGER.info(format!("Sent {data:?} as data to the {connection_type}.").as_str()); - } else { - writer.shutdown().await.expect("Failed to shutdown the writer."); - LOGGER.info("Shutting down connection..."); - break; - } - } - } - }); - - let _ = handler.await; -} - -pub async fn send_data(tx: &Sender>, data: Box<[u8]>) { - tx.send(data).await.expect("Failed to send data to the Server."); -} - -pub async fn join_cluster(tx: &Sender>, id: usize) { - if id < (0 as usize) { - LOGGER.error("Failed to join a cluster. The cluster ID is invalid (less than 0)."); - return; - } - - let cluster_servers = CLUSTER_SERVERS.read().await; - if cluster_servers.is_empty() { - LOGGER.error("Failed to join a cluster. No cluster servers are available."); - return; - } - - if id >= cluster_servers.len() { - LOGGER.error( - "Failed to join a cluster. The cluster ID is invalid (greater than the amount of clusters)." - ); - return; - } - - let cluster = ( - match cluster_servers.get(id) { - Some(cluster) => cluster, - None => { - LOGGER.error("Failed to join a cluster. The cluster ID is invalid."); - return; - } - } - ).clone(); - - LOGGER.success(format!("Client is joining cluster {}", cluster.name).as_str()); - - let connection = match std::panic::catch_unwind(|| Connection::from(cluster)) { - Ok(connection) => connection, - Err(_) => { - LOGGER.error("Failed to create a connection with the Cluster Server."); - return; - } - }; - { - // Overwrite the current connection with the cluster connection. - *CONNECTION.write().await = Some(connection); - stop(tx).await; - } -} - -async fn stop(tx: &Sender>) { - tx.send(Box::new([])).await.expect("Failed to shutdown."); -} +//! This library provides a client for connecting to a server. +//! It includes functionality for sending and receiving messages, as well as handling events. 
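+//!
+//! A minimal usage sketch (the address and port below are placeholders):
+//!
+//! ```no_run
+//! use sustenet_client::Client;
+//!
+//! #[tokio::main]
+//! async fn main() {
+//!     let mut client = Client::connect("127.0.0.1", 6256).await.expect("failed to connect");
+//!     while let Some(event) = client.next_event().await {
+//!         println!("client event: {event:?}");
+//!     }
+//! }
+//! ```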
+pub mod client; + +pub use client::Client; + +// use std::{ io::{ Error, ErrorKind }, sync::LazyLock }; + +// use bytes::Bytes; +// use sustenet_shared::logging::{ LogType, Logger }; +// use tokio::io; +// use tokio::{ io::{ AsyncWriteExt, BufReader, BufWriter, split }, net::TcpStream, sync::mpsc }; + +// pub static LOGGER: LazyLock = LazyLock::new(|| Logger::new(LogType::Client)); + +// //#region Connection +// #[derive(Debug, Clone)] +// pub struct ClusterInfo { +// pub name: String, +// pub ip: String, +// pub port: u16, +// pub max_connections: u32, +// } + +// #[derive(Clone, Copy, Eq, PartialEq)] +// pub enum ConnectionType { +// MasterServer, +// ClusterServer, +// None, +// } +// //#endregion + +// //#region Client +// #[derive(Debug)] +// pub enum ClientCommand { +// Send(Bytes), +// Shutdown, +// } + +// #[derive(Debug)] +// pub enum ClientEvent { +// Received(Bytes), +// Disconnected, +// } + +// struct ClientHandle { +// pub cluster_servers: Vec, +// pub cmd_tx: mpsc::Sender, +// pub event_rx: mpsc::Receiver, +// } + +// impl ClientHandle { +// /// Connects to the server and creates a new client handle at the given address and port. +// /// Returns an error if the connection fails. +// pub async fn connect(address: &str, port: u16) -> io::Result { +// LOGGER.info(&format!("Connecting to {}:{}...", address, port)).await; + +// let (cmd_tx, mut cmd_rx) = mpsc::channel(10); +// let (event_tx, event_rx) = mpsc::channel(10); +// let addr = format!("{}:{}", address, port); + +// let stream = match TcpStream::connect(&addr).await { +// Ok(s) => { +// LOGGER.success(&format!("Connected to {}:{}", address, port)).await; +// s +// } +// Err(e) => { +// LOGGER.error(&format!("Failed to connect to {}:{}", address, port)).await; +// return Err( +// Error::new( +// ErrorKind::ConnectionRefused, +// format!("Failed to connect to ({}:{}): {}", addr, port, e) +// ) +// ); +// } +// }; + +// let (reader, writer) = split(stream); +// let mut reader = BufReader::new(reader); +// let mut writer = BufWriter::new(writer); + +// tokio::spawn(async move { +// loop { +// tokio::select! { +// Some(cmd) = cmd_rx.recv() => { +// match cmd { +// ClientCommand::Send(data) => { +// if data.is_empty() { +// LOGGER.info("Received empty data. Shutting down...").await; +// cmd_tx.send(ClientCommand::Shutdown).await.expect("Failed to send shutdown command."); +// break; +// } +// if writer.write_all(&data).await.is_err() { +// LOGGER.error("Failed to write to the server. Disconnecting...").await; +// let _ = event_tx.send(ClientEvent::Disconnected).await; +// break; +// } +// if writer.flush().await.is_err() { +// LOGGER.error("Failed to flush the writer. 
Disconnecting...").await; +// let _ = event_tx.send(ClientEvent::Disconnected).await; +// break; +// } + +// LOGGER.info(&format!("Sent data: {data:?}")).await; +// } +// ClientCommand::Shutdown => { +// writer.shutdown().await.expect("Failed to shutdown the writer."); +// LOGGER.info("Closing connection...").await; +// break; +// }, +// } +// } + +// // read_result = reader.read_until(b'\n', &mut read_buf) => { +// // match read_result { +// // Ok(0) => { +// // let _ = event_tx.send(ClientEvent::Disconnected).await; +// // break; +// // } +// // Ok(n) => { +// // let data = Bytes::copy_from_slice(&read_buf[..n]); +// // let _ = event_tx.send(ClientEvent::Received(data)).await; +// // read_buf.clear(); +// // } +// // Err(_) => { +// // let _ = event_tx.send(ClientEvent::Disconnected).await; +// // break; +// // } +// // } +// // } +// else => break, +// } +// } +// }); + +// Ok(ClientHandle { +// cluster_servers: Vec::new(), +// cmd_tx, +// event_rx, +// }) +// } + +// pub async fn send(&self, data: Vec) -> Result<(), mpsc::error::SendError> { +// self.cmd_tx.send(ClientCommand::Send(data)).await +// } + +// /// Waits for the next event from the client +// pub async fn next_event(&mut self) -> Option { +// self.event_rx.recv().await +// } + +// pub fn add_cluster_server(&mut self, server: ClusterInfo) { +// self.cluster_servers.push(server); +// } + +// pub fn add_cluster_servers(&mut self, servers: Vec) { +// self.cluster_servers.extend(servers); +// } + +// pub fn get_cluster_servers(&self) -> &Vec { +// &self.cluster_servers +// } + +// async fn handle_data(&self) { +// let connection_type = self.connection.connection_type; +// LOGGER.warning(format!("Connecting to the {connection_type}...").as_str()); +// } +// } +//#endregion + +// use sustenet_shared as shared; + +// use std::net::{ IpAddr, Ipv4Addr }; +// use std::str::FromStr; +// use std::sync::{ Arc, LazyLock }; + +// use tokio::io::{ AsyncReadExt, AsyncWriteExt, BufReader }; +// use tokio::net::TcpStream; +// use tokio::sync::mpsc::Sender; +// use tokio::sync::{ RwLock, mpsc }; + +// use sustenet_shared::ClientPlugin; +// use shared::logging::{ LogType, Logger }; +// use shared::packets::cluster::ToClient; +// use shared::packets::master::ToUnknown; +// use shared::utils::constants::{ DEFAULT_IP, MASTER_PORT }; +// use shared::{ lread_string, lselect }; + +// lazy_static::lazy_static! 
{ +// pub static ref CLUSTER_SERVERS: Arc>> = Arc::new( +// RwLock::new(Vec::new()) +// ); +// pub static ref CONNECTION: Arc>> = Arc::new( +// RwLock::new( +// Some(Connection { +// ip: get_ip(DEFAULT_IP), +// port: MASTER_PORT, +// connection_type: ConnectionType::MasterServer, +// }) +// ) +// ); +// } +// pub static LOGGER: LazyLock = LazyLock::new(|| Logger::new(LogType::Cluster)); + +// #[derive(Debug, Clone)] +// pub struct ClusterInfo { +// pub name: String, +// pub ip: String, +// pub port: u16, +// pub max_connections: u32, +// } + +// #[derive(Clone, Copy)] +// pub struct Connection { +// pub ip: IpAddr, +// pub port: u16, +// pub connection_type: ConnectionType, +// } + +// impl From for Connection { +// fn from(info: ClusterInfo) -> Self { +// Connection { +// ip: IpAddr::from_str(info.ip.as_str()).expect("Failed to parse the IP."), +// port: info.port, +// connection_type: ConnectionType::ClusterServer, +// } +// } +// } + +// #[derive(Clone, Copy, Eq, PartialEq)] +// pub enum ConnectionType { +// MasterServer, +// ClusterServer, +// None, +// } + +// impl std::fmt::Display for ConnectionType { +// fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { +// match self { +// ConnectionType::MasterServer => write!(f, "Master Server"), +// ConnectionType::ClusterServer => write!(f, "Cluster Server"), +// ConnectionType::None => write!(f, "Unknown"), +// } +// } +// } + +// pub fn get_ip(ip: &str) -> IpAddr { +// IpAddr::from_str(ip).unwrap_or( +// IpAddr::from_str(DEFAULT_IP).unwrap_or(IpAddr::V4(Ipv4Addr::LOCALHOST)) +// ) +// } + +// pub async fn cleanup() {} + +// pub async fn start
<P>
(plugin: P) where P: ClientPlugin + Send + Sync + 'static { +// // Get the connection LOGGER.information. +// let connection = *CONNECTION.read().await; +// if connection.is_none() { +// return; +// } +// let connection = connection.unwrap(); + +// let ip = connection.ip; +// let port = connection.port; +// let connection_type = connection.connection_type; +// { +// *CONNECTION.write().await = None; +// } + +// let (tx, mut rx) = mpsc::channel::>(10); +// plugin.set_sender(tx.clone()); + +// let handler = tokio::spawn(async move { +// LOGGER.warning(format!("Connecting to the {connection_type}...").as_str()); +// let mut stream = TcpStream::connect(format!("{}:{}", ip, port)).await.expect( +// format!("Failed to connect to the {connection_type} at {ip}:{port}.").as_str() +// ); +// LOGGER.success(format!("Connected to the {connection_type} at {ip}:{port}.").as_str()); + +// let (reader, mut writer) = stream.split(); +// let mut reader = BufReader::new(reader); + +// lselect! { +// command = reader.read_u8() => { +// if command.is_err() { +// continue; +// } + +// LOGGER.info(format!("Received data: {:?}", command).as_str()); + +// match connection_type { +// ConnectionType::MasterServer => match command.unwrap() { +// x if x == ToUnknown::SendClusters as u8 => { +// let amount = match reader.read_u8().await { +// Ok(amount) => amount, +// Err(_) => { +// LOGGER.error("Failed to read the amount of clusters."); +// continue; +// } +// }; + +// let mut cluster_servers_tmp = Vec::new(); +// for _ in 0..amount { +// let name = lread_string!(reader, |msg| LOGGER.error(msg), "cluster name"); +// let ip = lread_string!(reader, |msg| LOGGER.error(msg), "cluster IP"); +// let port = match reader.read_u16().await { +// Ok(port) => port, +// Err(_) => { +// LOGGER.error("Failed to read the cluster port."); +// continue; +// } +// }; +// let max_connections = match reader.read_u32().await { +// Ok(max_connections) => max_connections, +// Err(_) => { +// LOGGER.error("Failed to read the cluster max connections."); +// continue; +// } +// }; + +// cluster_servers_tmp.push(ClusterInfo { +// name, +// ip, +// port, +// max_connections, +// }); +// } + +// { +// { +// let mut cluster_servers = CLUSTER_SERVERS.write().await; +// *cluster_servers = cluster_servers_tmp; + +// LOGGER.success(format!("Received {amount} Cluster servers from the {connection_type}.").as_str()); +// println!("{:?}", *cluster_servers); +// } +// } +// }, +// cmd => plugin.receive_master(tx.clone(), cmd, &mut reader).await, +// } +// ConnectionType::ClusterServer => match command.unwrap() { +// x if x == ToClient::SendClusters as u8 => { +// let amount = match reader.read_u8().await { +// Ok(amount) => amount, +// Err(_) => { +// LOGGER.error("Failed to read the amount of clusters."); +// continue; +// } +// }; + +// let mut cluster_servers_tmp = Vec::new(); +// for _ in 0..amount { +// let name = lread_string!(reader, |msg| LOGGER.error(msg), "cluster name"); +// let ip = lread_string!(reader, |msg| LOGGER.error(msg), "cluster IP"); +// let port = match reader.read_u16().await { +// Ok(port) => port, +// Err(_) => { +// LOGGER.error("Failed to read the cluster port."); +// continue; +// } +// }; +// let max_connections = match reader.read_u32().await { +// Ok(max_connections) => max_connections, +// Err(_) => { +// LOGGER.error("Failed to read the cluster max connections."); +// continue; +// } +// }; + +// cluster_servers_tmp.push(ClusterInfo { +// name, +// ip, +// port, +// max_connections, +// }); +// } + +// { +// { +// let mut 
cluster_servers = CLUSTER_SERVERS.write().await; +// *cluster_servers = cluster_servers_tmp; + +// LOGGER.success(format!("Received {amount} Cluster servers from the {connection_type}.").as_str()); +// println!("{:?}", *cluster_servers); +// } +// } +// }, +// x if x == ToClient::DisconnectCluster as u8 => todo!(), +// x if x == ToClient::LeaveCluster as u8 => todo!(), + +// x if x == ToClient::VersionOfKey as u8 => todo!(), +// x if x == ToClient::SendPubKey as u8 => todo!(), +// x if x == ToClient::Authenticate as u8 => todo!(), + +// x if x == ToClient::Move as u8 => todo!(), +// cmd => plugin.receive_cluster(tx.clone(), cmd, &mut reader).await, +// } +// _ => (), +// } +// } +// result = rx.recv() => { +// if let Some(data) = result { +// if data.is_empty() { +// writer.shutdown().await.expect("Failed to shutdown the writer."); +// LOGGER.info("Closing connection..."); +// break; +// } + +// writer.write_all(&data).await.expect("Failed to write to the Server."); +// writer.flush().await.expect("Failed to flush the writer."); +// LOGGER.info(format!("Sent {data:?} as data to the {connection_type}.").as_str()); +// } else { +// writer.shutdown().await.expect("Failed to shutdown the writer."); +// LOGGER.info("Shutting down connection..."); +// break; +// } +// } +// } +// }); + +// let _ = handler.await; +// } + +// pub async fn send_data(tx: &Sender>, data: Box<[u8]>) { +// tx.send(data).await.expect("Failed to send data to the Server."); +// } + +// pub async fn join_cluster(tx: &Sender>, id: usize) { +// if id < (0 as usize) { +// LOGGER.error("Failed to join a cluster. The cluster ID is invalid (less than 0)."); +// return; +// } + +// let cluster_servers = CLUSTER_SERVERS.read().await; +// if cluster_servers.is_empty() { +// LOGGER.error("Failed to join a cluster. No cluster servers are available."); +// return; +// } + +// if id >= cluster_servers.len() { +// LOGGER.error( +// "Failed to join a cluster. The cluster ID is invalid (greater than the amount of clusters)." +// ); +// return; +// } + +// let cluster = ( +// match cluster_servers.get(id) { +// Some(cluster) => cluster, +// None => { +// LOGGER.error("Failed to join a cluster. The cluster ID is invalid."); +// return; +// } +// } +// ).clone(); + +// LOGGER.success(format!("Client is joining cluster {}", cluster.name).as_str()); + +// let connection = match std::panic::catch_unwind(|| Connection::from(cluster)) { +// Ok(connection) => connection, +// Err(_) => { +// LOGGER.error("Failed to create a connection with the Cluster Server."); +// return; +// } +// }; +// { +// // Overwrite the current connection with the cluster connection. 
+// *CONNECTION.write().await = Some(connection); +// stop(tx).await; +// } +// } + +// async fn stop(tx: &Sender>) { +// tx.send(Box::new([])).await.expect("Failed to shutdown."); +// } diff --git a/rust/client/src/main.rs b/rust/client/src/main.rs index 8716635..cb17d23 100644 --- a/rust/client/src/main.rs +++ b/rust/client/src/main.rs @@ -1,71 +1,96 @@ -use sustenet_shared as shared; +use sustenet_client::Client; +use sustenet_shared::lselect; -use tokio::sync::mpsc::Sender; - -use shared::{ lselect, utils }; -use sustenet_client::{ CONNECTION, LOGGER, cleanup, start}; - -struct DefaultPlugin { - sender: std::sync::OnceLock>>, -} -impl shared::ClientPlugin for DefaultPlugin { - fn set_sender(&self, tx: Sender>) { - // Set the sender - if self.sender.set(tx).is_err() { - LOGGER.error("Failed to set sender"); +#[tokio::main] +pub async fn main() { + let address = "127.0.0.1"; + let port = 6256; + let mut client = match Client::connect(address, port).await { + Ok(client) => client, + Err(e) => { + eprintln!("Failed to connect to the server: {e}"); + return; + } + }; + lselect! { + Some(event) = client.next_event() => { + println!("Received event: {:?}", event); + }, + else => { + println!("No event received."); + break; } } +} - fn receive_master<'plug>( - &self, - _tx: Sender>, - command: u8, - _reader: &'plug mut tokio::io::BufReader> - ) -> std::pin::Pin + Send>> { - Box::pin(async move { - match command { - 0 => println!("Command 0 received"), - 1 => println!("Command 1 received"), - _ => println!("Unknown command received"), - } - }) - } +// use sustenet_shared as shared; - fn receive_cluster<'plug>( - &self, - _tx: Sender>, - command: u8, - _reader: &'plug mut tokio::io::BufReader> - ) -> std::pin::Pin + Send>> { - Box::pin(async move { - match command { - 0 => println!("Command 0 received"), - 1 => println!("Command 1 received"), - _ => println!("Unknown command received"), - } - }) - } +// use tokio::sync::mpsc::Sender; - fn info(&self, _: &str) {} -} +// use shared::{ lselect, utils }; +// use sustenet_client::{ CONNECTION, LOGGER, cleanup, start}; -#[tokio::main] -pub async fn main() { - let mut shutdown_rx = utils::shutdown_channel().expect("Error creating shutdown channel."); +// struct DefaultPlugin { +// sender: std::sync::OnceLock>>, +// } +// impl shared::ClientPlugin for DefaultPlugin { +// fn set_sender(&self, tx: Sender>) { +// // Set the sender +// if self.sender.set(tx).is_err() { +// LOGGER.error("Failed to set sender"); +// } +// } - lselect! 
{ - _ = shutdown_rx.recv() => { - LOGGER.warning("Shutting down..."); - break; - } - _ = start(DefaultPlugin { sender: std::sync::OnceLock::new() }) => { - if CONNECTION.read().await.is_none() { - LOGGER.warning("Closing client..."); - break; - } - } - } - cleanup().await; - LOGGER.success("The Client has been shut down."); -} +// fn receive_cluster<'plug>( +// &self, +// _tx: Sender<Box<[u8]>>, +// command: u8, +// _reader: &'plug mut tokio::io::BufReader<tokio::net::tcp::ReadHalf<'plug>> +// ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>> { +// Box::pin(async move { +// match command { +// 0 => println!("Command 0 received"), +// 1 => println!("Command 1 received"), +// _ => println!("Unknown command received"), +// } +// }) +// } + +// fn info(&self, _: &str) {} +// } + +// #[tokio::main] +// pub async fn main() { +// let mut shutdown_rx = utils::shutdown_channel().expect("Error creating shutdown channel."); + +// lselect! { +// _ = shutdown_rx.recv() => { +// LOGGER.warning("Shutting down..."); +// break; +// } +// _ = start(DefaultPlugin { sender: std::sync::OnceLock::new() }) => { +// if CONNECTION.read().await.is_none() { +// LOGGER.warning("Closing client..."); +// break; +// } +// } +// } + +// cleanup().await; +// LOGGER.success("The Client has been shut down."); +// } diff --git a/rust/cluster/Cargo.toml b/rust/cluster/Cargo.toml index c0b58de..dd135c0 100644 --- a/rust/cluster/Cargo.toml +++ b/rust/cluster/Cargo.toml @@ -13,8 +13,9 @@ workspace = true [dependencies] aes.workspace = true +bytes.workspace = true dashmap.workspace = true -lazy_static.workspace = true +public-ip.workspace = true sustenet-shared.workspace = true tokio = { workspace = true, features = [ # "socket2", @@ -25,4 +26,3 @@ tokio = { workspace = true, features = [ "io-util", "time", ] } -public-ip.workspace = true diff --git a/rust/cluster/src/cluster.rs b/rust/cluster/src/cluster.rs new file mode 100644 index 0000000..67705cf --- /dev/null +++ b/rust/cluster/src/cluster.rs @@ -0,0 +1,150 @@ +//! The cluster is a server that hosts many worlds and also knows about other clusters. +//! The cluster gets this information from the master server. + +use sustenet_shared::ServerPlugin; +use sustenet_shared::config::cluster::{ Settings, read }; +use sustenet_shared::logging::{ LogType, Logger }; +use sustenet_shared::network::ClusterInfo; +use sustenet_shared::packets::Diagnostics; + +use std::collections::HashMap; +use std::io::Error; +use std::sync::LazyLock; + +use bytes::Bytes; +use tokio::io; +use tokio::net::TcpListener; +use tokio::sync::mpsc; + +use crate::cluster_client::ClusterClient; +use crate::master_connection::MasterConnection; + +/// Global logger for the cluster module. +pub static LOGGER: LazyLock<Logger> = LazyLock::new(|| Logger::new(LogType::Cluster)); + +/// Events emitted by the cluster server to notify listeners. +#[derive(Debug, Clone)] +pub enum ClusterEvent { + MasterConnected, + MasterDisconnected, + MasterCommandSent(u8), + MasterMessageSent(Bytes), + MasterCommandReceived(u8), + MasterMessageReceived(Bytes), + + /// When a connection is established with a client or server. + Connected(u64), + /// When a connection is closed with a client or server.
Disconnected(u64), + + DiagnosticsReceived(Diagnostics, Bytes), + Shutdown, + Error(String), +} + +/// Handles connections and interactions with Cluster Servers and Clients. +pub struct ClusterServer<P: ServerPlugin> { + _plugin: P, + + _max_connections: u32, + bind: String, + port: u16, + + // sender: mpsc::Sender<Bytes>, + event_tx: mpsc::Sender<ClusterEvent>, + _event_rx: mpsc::Receiver<ClusterEvent>, + + connections: HashMap<u64, ClusterClient>, + /// Only used to store cluster servers so clients can switch between them. + _cluster_servers: Vec<ClusterInfo>, + _master_connection: MasterConnection, + next_id: u64, +} + +impl<P: ServerPlugin> ClusterServer<P>
{ + pub async fn new(settings: Settings, plugin: P) -> io::Result<Self> { + let (event_tx, event_rx) = mpsc::channel::<ClusterEvent>(16); + + let port = settings.port; + let master_connection = MasterConnection::connect( + &settings.master_ip, + settings.master_port + ).await?; + + Ok(Self { + _plugin: plugin, + + _max_connections: settings.max_connections, + bind: settings.bind, + port, + + event_tx, + _event_rx: event_rx, + + connections: HashMap::new(), + _cluster_servers: Vec::new(), + _master_connection: master_connection, + next_id: 0, + }) + } + + pub async fn new_from_cli() -> io::Result<Self> { + // TODO (low priority): Load the configuration from CLI arguments + todo!() + } + + pub async fn new_from_config(plugin: P) -> io::Result<Self> { + let settings = read(); + + Self::new(settings, plugin).await + } + + /// Binds the listener and accepts incoming connections until an error or shutdown. + pub async fn start(&mut self) -> io::Result<()> { + // Create Listener + let addr = format!("{}:{}", self.bind, self.port); + let listener = match TcpListener::bind(&addr).await { + Ok(l) => { + LOGGER.success(&format!("Cluster server started on {addr}")); + l + } + Err(e) => { + LOGGER.error(&format!("Failed to bind to {addr}")); + return Err(Error::new(e.kind(), format!("Failed to bind to ({addr}): {e}"))); + } + }; + + // TODO: Improve starting here. + loop { + let (stream, peer) = match listener.accept().await { + Ok(pair) => pair, + Err(e) => { + LOGGER.error(&format!("Failed to accept connection: {e}")); + continue; + } + }; + LOGGER.debug(&format!("Accepted connection from {peer}")); + + // TODO: This is on the right path, but it was AI-generated. Move it to a struct. + + // Assign the next available connection ID and create a ClusterClient for this stream. + let id = self.next_id; + self.next_id += 1; + let connection = match ClusterClient::new(id, stream, self.event_tx.clone()).await { + Ok(c) => c, + Err(e) => { + LOGGER.error(&format!("Failed to create connection: {e}")); + continue; + } + }; + + // Store the connection in the connections map + self.connections.insert(id, connection); + } + } + + // TODO: Add a tick + // async fn tick(&mut self) -> io::Result<()> { + // LOGGER.debug("Ticking cluster server..."); + // Ok(()) + // } +} diff --git a/rust/cluster/src/cluster_client.rs b/rust/cluster/src/cluster_client.rs new file mode 100644 index 0000000..62d56d4 --- /dev/null +++ b/rust/cluster/src/cluster_client.rs @@ -0,0 +1,175 @@ +use sustenet_shared::lselect; +use sustenet_shared::packets::{ ClusterSetup, Connection, Diagnostics }; + +use std::io::Error; + +use bytes::Bytes; +use tokio::io; +use tokio::io::{ AsyncReadExt, AsyncWriteExt }; +use tokio::net::TcpStream; +use tokio::sync::mpsc; +use tokio::sync::mpsc::error::SendError; + +use crate::cluster::{ ClusterEvent, LOGGER }; + +/// Handles connections that clients and other servers establish with +/// this cluster server. +pub struct ClusterClient { + sender: mpsc::Sender<Bytes>, +} + +impl ClusterClient { + pub async fn new( + id: u64, + stream: TcpStream, + event_tx: mpsc::Sender<ClusterEvent> + ) -> io::Result<Self> { + let (sender, receiver) = mpsc::channel::<Bytes>(16); + let connection = Self { sender }; + + if let Err(e) = Self::receive(id, stream, connection.sender.clone(), receiver, event_tx) { + LOGGER.error(&format!("Failed to start connection #{id}")); + return Err(Error::new(e.kind(), format!("Failed to start connection #{id}: {e}"))); + } + + Ok(connection) + } + + /// Sends a message to the sender to close the connection. + /// + /// This should be called before getting rid of this ClusterClient. + pub async fn close(&self) { + self.sender.send(Bytes::new()).await.unwrap(); + } + + /// Receives messages from clients and handles them.
+ /// + /// It also enables the ClusterServer to send messages through this + /// struct's sender. + pub fn receive( + id: u64, + mut stream: TcpStream, + sender: mpsc::Sender<Bytes>, + mut receiver: mpsc::Receiver<Bytes>, + event_tx: mpsc::Sender<ClusterEvent> + ) -> io::Result<()> { + tokio::spawn(async move { + let (reader, mut writer) = stream.split(); + let mut reader = io::BufReader::new(reader); + + lselect!( + // Handle local requests to send a message to the other side of the connection. + msg = receiver.recv() => { + match msg { + Some(msg) => { + if msg.is_empty() { + LOGGER.warning("Received empty message, shutting down connection"); + Self::handle_shutdown(writer, event_tx, id).await; + break; + } + + LOGGER.debug(&format!("Sending message: {:?}", msg)); + if let Err(e) = writer.write_all(&msg).await { + let msg = format!("Failed to send message to server: {e}"); + LOGGER.error(&msg); + let _ = event_tx.send(ClusterEvent::Error(msg)).await; + } else { + // TODO: Still need to decide if we should notify about messages sent on a server. + // let _ = event_tx.send(ClusterEvent::MessageSent(msg)).await; + } + }, + None => { + LOGGER.warning("Connection closed"); + Self::handle_shutdown(writer, event_tx, id).await; + break; + } + } + }, + command = reader.read_u8() => { + match command { + Ok(command) => { + LOGGER.debug(&format!("Received command: {command}")); + + Self::handle_command(command, &sender, &mut reader, &mut writer, &event_tx).await; + + // Notify listeners about the received message. + // TODO: Should we? I'm leaning more towards not notifying about commands. + // It could hurt performance. + // let _ = event_tx.send(ClusterEvent::CommandReceived(command)).await; + }, + Err(e) => { + let msg = format!("Failed to read command for connection #{}: {e}", id); + LOGGER.error(&msg); + let _ = event_tx.send(ClusterEvent::Error(msg)).await; + } + } + } + ); + }); + + Ok(()) + } + + /// An external method to allow the cluster server to send messages to the client. + pub async fn send(&self, bytes: Bytes) -> Result<(), SendError<Bytes>> { + if let Err(e) = self.sender.send(bytes).await { + LOGGER.error(&format!("Failed to send message to client: {e}")); + return Err(e); + } + Ok(()) + } + + async fn handle_shutdown( + mut writer: tokio::net::tcp::WriteHalf<'_>, + event_tx: mpsc::Sender<ClusterEvent>, + id: u64 + ) { + if let Err(e) = writer.shutdown().await { + let msg = format!("Failed to shutdown writer: {e}"); + LOGGER.error(&msg); + let _ = event_tx.send(ClusterEvent::Error(msg)).await; + } + let _ = event_tx.send(ClusterEvent::Disconnected(id)).await; + } + + async fn handle_command( + command: u8, + _sender: &mpsc::Sender<Bytes>, + _reader: &mut io::BufReader<tokio::net::tcp::ReadHalf<'_>>, + _writer: &mut tokio::net::tcp::WriteHalf<'_>, + event_tx: &mpsc::Sender<ClusterEvent> + ) { + // Handle the command received from the client.
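+        // A sketch of what answering a diagnostics request could look like once these
+        // arms are implemented. The 4-byte big-endian response encoding is an
+        // assumption, not something this change defines:
+        //
+        // x if x == (Diagnostics::CheckServerPlayerCount as u8) => {
+        //     let player_count = 0u32; // placeholder; would come from server state
+        //     let _ = _writer.write_u32(player_count).await; // AsyncWriteExt, big-endian
+        // }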
+ match command { + x if x == (Connection::Connect as u8) => { + LOGGER.info("Handling Connection Connect"); + } + x if x == (Connection::Disconnect as u8) => { + LOGGER.info("Handling Connection Disconnect"); + } + + x if x == (Diagnostics::CheckServerType as u8) => { + LOGGER.info("Handling Diagnostics Check Server Type"); + } + x if x == (Diagnostics::CheckServerUptime as u8) => { + LOGGER.info("Handling Diagnostics Check Server Uptime"); + } + x if x == (Diagnostics::CheckServerPlayerCount as u8) => { + LOGGER.info("Handling Diagnostics Check Server Player Count"); + } + + x if x == (ClusterSetup::Init as u8) => { + LOGGER.info("Handling Cluster Setup Init"); + } + x if x == (ClusterSetup::AnswerSecret as u8) => { + LOGGER.info("Handling Cluster Setup Answer Secret"); + } + + _ => { + let msg = format!("Unknown command received: {command}"); + LOGGER.error(&msg); + let _ = event_tx.send(ClusterEvent::Error(msg)); + } + } + } +} diff --git a/rust/cluster/src/lib.rs b/rust/cluster/src/lib.rs index f6949d5..a99cfdf 100644 --- a/rust/cluster/src/lib.rs +++ b/rust/cluster/src/lib.rs @@ -1,421 +1,427 @@ -use sustenet_shared as shared; - -use std::collections::BTreeSet; -use std::sync::{ Arc, LazyLock }; -use std::{ net::Ipv4Addr, str::FromStr }; - -use tokio::io::{ AsyncReadExt, AsyncWriteExt, BufReader }; -use tokio::net::{ TcpListener, TcpStream }; -use tokio::select; -use tokio::sync::mpsc::Sender; -use tokio::sync::{ Mutex, RwLock, mpsc }; - -use dashmap::DashMap; - -use public_ip::addr; - -use shared::config::cluster::{ Settings, read }; -use shared::logging::{ LogType, Logger }; -use shared::network::{ ClusterInfo, Event }; -use shared::packets::cluster::FromClient; -use shared::packets::master::{ FromUnknown, ToUnknown }; -use shared::security::aes::{ create_keys_dir, decrypt, generate_key, load_key, save_key }; -use shared::utils::constants::{ self, DEFAULT_IP }; -use shared::{ ServerPlugin, lselect }; - -lazy_static::lazy_static! { - static ref CLUSTER_IDS: Arc>> = Arc::new( - RwLock::new(BTreeSet::new()) - ); -} -pub static LOGGER: LazyLock = LazyLock::new(|| Logger::new(LogType::Cluster)); - -pub fn get_ip(ip: &str) -> Ipv4Addr { - Ipv4Addr::from_str(ip).unwrap_or(Ipv4Addr::from_str(DEFAULT_IP).unwrap_or(Ipv4Addr::LOCALHOST)) -} - -pub async fn cleanup() {} - -pub async fn start_with_config
<P>
(plugin: P) where P: ServerPlugin + Send + Sync + 'static { - start(plugin, read()).await; -} - -pub async fn start
<P>
(plugin: P, settings: Settings) where P: ServerPlugin + Send + Sync + 'static { - let Settings { - server_name, - max_connections, - port, - key_name, - master_ip, - master_port, - domain_pub_key: _, - } = settings; - - let plugin = Arc::new(plugin); - - LOGGER.set_plugin({ - let plugin = Arc::clone(&plugin); - move |msg| plugin.info(msg) - }); - - let key = match load_key(key_name.as_str()) { - Ok(key) => key, - Err(_) => { - if let Err(e) = create_keys_dir() { - LOGGER.error(e.to_string().as_str()); - panic!("{e:?}"); - } - - let key = generate_key(); - if save_key(key_name.as_str(), key).is_err() { - LOGGER.error("Failed to save the generated key."); - panic!("Failed to save the generated key."); - } - - LOGGER.warning( - format!( - "A new AES key at 'keys/{key_name}' has been generated and saved. Make sure the Master Server also has this key for authentication." - ).as_str() - ); - - key - } - }; - - let (tx, mut rx) = mpsc::channel::>(10); - plugin.set_sender(tx.clone()); - let tx_clone = tx.clone(); - - // Cluster Server's connection to the Master Server. - tokio::spawn(async move { - let mut stream = TcpStream::connect( - format!("{}:{}", get_ip(&master_ip), master_port) - ).await.expect("Failed to connect to the Master Server."); - - let (reader, mut writer) = stream.split(); - let mut reader = BufReader::new(reader); - - loop { - select! { - command = reader.read_u8() => { - if command.is_err() { - continue; - } - - LOGGER.debug(format!("Cluster Server received data: {:?}", command).as_str()); - - match command.unwrap() { - x if x == ToUnknown::VerifyCluster as u8 => { - let len = reader.read_u8().await.unwrap() as usize; - let mut passphrase = vec![0u8; len]; - match reader.read_exact(&mut passphrase).await { - Ok(_) => {}, - Err(e) => { - LOGGER.error(format!("Failed to read passphrase to String: {:?}", e).as_str()); - continue; - } - } - - let mut data = vec![FromUnknown::AnswerCluster as u8]; - - let decrypted_passphrase = decrypt(passphrase.as_slice(), &key); - - data.push(decrypted_passphrase.len() as u8); - data.extend_from_slice(&decrypted_passphrase); - data.push(server_name.len() as u8); - data.extend_from_slice(&server_name.as_bytes()); - - if let Some(ip) = addr().await { - let ip_string = ip.to_string(); - let ip_bytes = ip_string.as_bytes(); - data.push(ip_bytes.len() as u8); - data.extend_from_slice(ip_bytes); - } else { - LOGGER.error("Failed to get the public IP address."); - return; - } - - data.extend_from_slice(&port.to_be_bytes()); - data.extend_from_slice(&max_connections.to_be_bytes()); - - - send_data(&tx, data.into_boxed_slice()).await; - } - x if x == ToUnknown::CreateCluster as u8 => { - LOGGER.success("We did it! We verified the cluster!"); - } - cmd => plugin.receive(tx.clone(), cmd, &mut reader).await, - } - } - result = rx.recv() => { - if let Some(data) = result { - writer.write_all(&data).await.expect("Failed to write to the Master Server."); - writer.flush().await.expect("Failed to flush the writer."); - } else { - writer.shutdown().await.expect("Failed to shutdown the writer."); - LOGGER.info("Cluster Server is shutting down its client writer."); - break; - } - } - } - } - - let (event_sender, mut event_receiver) = mpsc::channel::(100); - - let clients: DashMap = DashMap::new(); - let released_ids: Arc>> = Arc::new(Mutex::new(BTreeSet::new())); // In the future, think about reserving cluster ids. Sometimes a cluster can get a high ID, causing RAM to stay high during low loads. 
- - { - let tcp_listener = TcpListener::bind( - format!("{}:{}", constants::DEFAULT_IP, port) - ).await.expect("Failed to bind to the specified port."); - - loop { - select! { - event = event_receiver.recv() => { - if let Some(event) = event { - match event { - Event::Connection(id) => on_connection(id), - Event::Disconnection(id) => { - LOGGER.debug(format!("Client#{id} disconnected.").as_str()); - clients.remove(&id); - - if id >= clients.len() as u32 { - LOGGER.info(format!("Client#{id} wasn't added to the released IDs list.").as_str()); - continue; - } - - let mut ids = released_ids.lock().await; - if !(*ids).insert(id) { - LOGGER.error(format!("ID {} already exists in the released IDs.", id).as_str()); - continue; - }; - }, - Event::ReceivedData(id, data) => on_received_data(id, &data), - } - } - } - // Listen and add clients. - res = tcp_listener.accept() => { - if let Ok((stream, addr)) = res { - LOGGER.debug(format!("Accepted connection from {:?}", addr).as_str()); - - // If the max_connections is reached, return an error. - if max_connections != 0 && clients.len() >= (max_connections as usize) { - LOGGER.error("Max connections reached."); - continue; - } - - // Get the next available ID and insert it. - let released_id: u32 = released_ids - .lock().await - .pop_first() - .unwrap_or(clients.len() as u32); - let mut client = ServerClient::new(released_id); - client.handle_data(event_sender.clone(), stream).await; - clients.insert(released_id, client); - - event_sender.send(Event::Connection(released_id)).await.unwrap(); - } - } - } - } - } - }); - - // Send a request to the Master Server to become a cluster. - { - let command = FromUnknown::BecomeCluster as u8; - - let mut data = [command].to_vec(); - data.push(key_name.len() as u8); - data.extend_from_slice(key_name.as_bytes()); - - let data = data.into_boxed_slice(); - send_data(&tx_clone, data).await; - } - - // Cluster Server Listener - { - let (event_sender, mut event_receiver) = mpsc::channel::(100); - - let clients: DashMap = DashMap::new(); - let released_ids: Arc>> = Arc::new(Mutex::new(BTreeSet::new())); - - { - let max_connections_str = match max_connections { - 0 => "unlimited max connections".to_string(), - 1 => "1 max connection".to_string(), - _ => format!("{} max connections", max_connections), - }; - - LOGGER.debug( - format!("Starting the Cluster Server on port {} with {max_connections_str}...", port).as_str() - ); - } - - // Listen - { - let tcp_listener = TcpListener::bind( - format!("{}:{}", constants::DEFAULT_IP, port) - ).await.expect("Failed to bind to the specified port."); - - lselect! { - event = event_receiver.recv() => { - if let Some(event) = event { - match event { - Event::Connection(id) => on_connection(id), - Event::Disconnection(id) => { - LOGGER.debug(format!("Client#{id} disconnected.").as_str()); - clients.remove(&id); - - if id >= clients.len() as u32 { - LOGGER.info(format!("Client#{id} wasn't added to the released IDs list.").as_str()); - continue; - } - - let mut ids = released_ids.lock().await; - if !(*ids).insert(id) { - LOGGER.error(format!("ID {} already exists in the released IDs.", id).as_str()); - continue; - }; - }, - Event::ReceivedData(id, data) => on_received_data(id, &data), - } - } - } - // Listen and add clients. - res = tcp_listener.accept() => { - if let Ok((stream, addr)) = res { - LOGGER.debug(format!("Accepted connection from {:?}", addr).as_str()); - - // If the max_connections is reached, return an error. 
- if max_connections != 0 && clients.len() >= (max_connections as usize) { - LOGGER.error("Max connections reached."); - continue; - } - - // Get the next available ID and insert it. - let released_id: u32 = released_ids - .lock().await - .pop_first() - .unwrap_or(clients.len() as u32); - let mut client = ServerClient::new(released_id); - client.handle_data(event_sender.clone(), stream).await; - clients.insert(released_id, client); - - event_sender.send(Event::Connection(released_id)).await.unwrap(); - } - } - } - } - } -} - -async fn send_data(tx: &mpsc::Sender>, data: Box<[u8]>) { - tx.send(data).await.expect("Failed to send data to the Server."); -} - -// region: Events -fn on_connection(id: u32) { - LOGGER.debug(format!("Client#{id} connected").as_str()); -} - -fn on_received_data(id: u32, data: &[u8]) { - LOGGER.debug(format!("Received data from Client#{id}: {:?}", data).as_str()); - todo!() -} - -// fn on_client_connected(id: u32) { -// LOGGER.debug(format!("Client connected: {}", id).as_str()); -// todo!() +pub mod cluster; +pub mod cluster_client; +pub mod master_connection; + +pub use cluster::ClusterServer; + +// use sustenet_shared as shared; + +// use std::collections::BTreeSet; +// use std::sync::{ Arc, LazyLock }; +// use std::{ net::Ipv4Addr, str::FromStr }; + +// use tokio::io::{ AsyncReadExt, AsyncWriteExt, BufReader }; +// use tokio::net::{ TcpListener, TcpStream }; +// use tokio::select; +// use tokio::sync::mpsc::Sender; +// use tokio::sync::{ Mutex, RwLock, mpsc }; + +// use dashmap::DashMap; + +// use public_ip::addr; + +// use shared::config::cluster::{ Settings, read }; +// use shared::logging::{ LogType, Logger }; +// use shared::network::{ ClusterInfo, Event }; +// use shared::packets::cluster::FromClient; +// use shared::packets::master::{ FromUnknown, ToUnknown }; +// use shared::security::aes::{ create_keys_dir, decrypt, generate_key, load_key, save_key }; +// use shared::utils::constants::{ self, DEFAULT_IP }; +// use shared::{ ServerPlugin, lselect }; + +// lazy_static::lazy_static! { +// static ref CLUSTER_IDS: Arc>> = Arc::new( +// RwLock::new(BTreeSet::new()) +// ); // } +// pub static LOGGER: LazyLock = LazyLock::new(|| Logger::new(LogType::Cluster)); -// fn on_client_disconnected(id: u32, protocol: Protocols) { -// LOGGER.debug(format!("Client disconnected: {} {}", id, protocol as u8).as_str()); -// todo!() +// pub fn get_ip(ip: &str) -> Ipv4Addr { +// Ipv4Addr::from_str(ip).unwrap_or(Ipv4Addr::from_str(DEFAULT_IP).unwrap_or(Ipv4Addr::LOCALHOST)) // } -// fn on_client_received_data(id: u32, protocol: Protocols, data: &[u8]) { -// LOGGER.debug(format!("Client received data: {} {} {:?}", id, protocol as u8, data).as_str()); +// pub async fn cleanup() {} + +// pub async fn start_with_config
<P>
(plugin: P) where P: ServerPlugin + Send + Sync + 'static { +// start(plugin, read()).await; +// } + +// pub async fn start
<P>
(plugin: P, settings: Settings) where P: ServerPlugin + Send + Sync + 'static { +// let Settings { +// server_name, +// max_connections, +// port, +// key_name, +// master_ip, +// master_port, +// domain_pub_key: _, +// } = settings; + +// let plugin = Arc::new(plugin); + +// LOGGER.set_plugin({ +// let plugin = Arc::clone(&plugin); +// move |msg| plugin.info(msg) +// }); + +// let key = match load_key(key_name.as_str()) { +// Ok(key) => key, +// Err(_) => { +// if let Err(e) = create_keys_dir() { +// LOGGER.error(e.to_string().as_str()); +// panic!("{e:?}"); +// } + +// let key = generate_key(); +// if save_key(key_name.as_str(), key).is_err() { +// LOGGER.error("Failed to save the generated key."); +// panic!("Failed to save the generated key."); +// } + +// LOGGER.warning( +// format!( +// "A new AES key at 'keys/{key_name}' has been generated and saved. Make sure the Master Server also has this key for authentication." +// ).as_str() +// ); + +// key +// } +// }; + +// let (tx, mut rx) = mpsc::channel::>(10); +// plugin.set_sender(tx.clone()); +// let tx_clone = tx.clone(); + +// // Cluster Server's connection to the Master Server. +// tokio::spawn(async move { +// let mut stream = TcpStream::connect( +// format!("{}:{}", get_ip(&master_ip), master_port) +// ).await.expect("Failed to connect to the Master Server."); + +// let (reader, mut writer) = stream.split(); +// let mut reader = BufReader::new(reader); + +// loop { +// select! { +// command = reader.read_u8() => { +// if command.is_err() { +// continue; +// } + +// LOGGER.debug(format!("Cluster Server received data: {:?}", command).as_str()); + +// match command.unwrap() { +// x if x == ToUnknown::VerifyCluster as u8 => { +// let len = reader.read_u8().await.unwrap() as usize; +// let mut passphrase = vec![0u8; len]; +// match reader.read_exact(&mut passphrase).await { +// Ok(_) => {}, +// Err(e) => { +// LOGGER.error(format!("Failed to read passphrase to String: {:?}", e).as_str()); +// continue; +// } +// } + +// let mut data = vec![FromUnknown::AnswerCluster as u8]; + +// let decrypted_passphrase = decrypt(passphrase.as_slice(), &key); + +// data.push(decrypted_passphrase.len() as u8); +// data.extend_from_slice(&decrypted_passphrase); +// data.push(server_name.len() as u8); +// data.extend_from_slice(&server_name.as_bytes()); + +// if let Some(ip) = addr().await { +// let ip_string = ip.to_string(); +// let ip_bytes = ip_string.as_bytes(); +// data.push(ip_bytes.len() as u8); +// data.extend_from_slice(ip_bytes); +// } else { +// LOGGER.error("Failed to get the public IP address."); +// return; +// } + +// data.extend_from_slice(&port.to_be_bytes()); +// data.extend_from_slice(&max_connections.to_be_bytes()); + + +// send_data(&tx, data.into_boxed_slice()).await; +// } +// x if x == ToUnknown::CreateCluster as u8 => { +// LOGGER.success("We did it! 
We verified the cluster!"); +// } +// cmd => plugin.receive(tx.clone(), cmd, &mut reader).await, +// } +// } +// result = rx.recv() => { +// if let Some(data) = result { +// writer.write_all(&data).await.expect("Failed to write to the Master Server."); +// writer.flush().await.expect("Failed to flush the writer."); +// } else { +// writer.shutdown().await.expect("Failed to shutdown the writer."); +// LOGGER.info("Cluster Server is shutting down its client writer."); +// break; +// } +// } +// } +// } + +// let (event_sender, mut event_receiver) = mpsc::channel::(100); + +// let clients: DashMap = DashMap::new(); +// let released_ids: Arc>> = Arc::new(Mutex::new(BTreeSet::new())); // In the future, think about reserving cluster ids. Sometimes a cluster can get a high ID, causing RAM to stay high during low loads. + +// { +// let tcp_listener = TcpListener::bind( +// format!("{}:{}", constants::DEFAULT_IP, port) +// ).await.expect("Failed to bind to the specified port."); + +// loop { +// select! { +// event = event_receiver.recv() => { +// if let Some(event) = event { +// match event { +// Event::Connection(id) => on_connection(id), +// Event::Disconnection(id) => { +// LOGGER.debug(format!("Client#{id} disconnected.").as_str()); +// clients.remove(&id); + +// if id >= clients.len() as u32 { +// LOGGER.info(format!("Client#{id} wasn't added to the released IDs list.").as_str()); +// continue; +// } + +// let mut ids = released_ids.lock().await; +// if !(*ids).insert(id) { +// LOGGER.error(format!("ID {} already exists in the released IDs.", id).as_str()); +// continue; +// }; +// }, +// Event::ReceivedData(id, data) => on_received_data(id, &data), +// } +// } +// } +// // Listen and add clients. +// res = tcp_listener.accept() => { +// if let Ok((stream, addr)) = res { +// LOGGER.debug(format!("Accepted connection from {:?}", addr).as_str()); + +// // If the max_connections is reached, return an error. +// if max_connections != 0 && clients.len() >= (max_connections as usize) { +// LOGGER.error("Max connections reached."); +// continue; +// } + +// // Get the next available ID and insert it. +// let released_id: u32 = released_ids +// .lock().await +// .pop_first() +// .unwrap_or(clients.len() as u32); +// let mut client = ServerClient::new(released_id); +// client.handle_data(event_sender.clone(), stream).await; +// clients.insert(released_id, client); + +// event_sender.send(Event::Connection(released_id)).await.unwrap(); +// } +// } +// } +// } +// } +// }); + +// // Send a request to the Master Server to become a cluster. 
+// { +// let command = FromUnknown::BecomeCluster as u8; + +// let mut data = [command].to_vec(); +// data.push(key_name.len() as u8); +// data.extend_from_slice(key_name.as_bytes()); + +// let data = data.into_boxed_slice(); +// send_data(&tx_clone, data).await; +// } + +// // Cluster Server Listener +// { +// let (event_sender, mut event_receiver) = mpsc::channel::(100); + +// let clients: DashMap = DashMap::new(); +// let released_ids: Arc>> = Arc::new(Mutex::new(BTreeSet::new())); + +// { +// let max_connections_str = match max_connections { +// 0 => "unlimited max connections".to_string(), +// 1 => "1 max connection".to_string(), +// _ => format!("{} max connections", max_connections), +// }; + +// LOGGER.debug( +// format!("Starting the Cluster Server on port {} with {max_connections_str}...", port).as_str() +// ); +// } + +// // Listen +// { +// let tcp_listener = TcpListener::bind( +// format!("{}:{}", constants::DEFAULT_IP, port) +// ).await.expect("Failed to bind to the specified port."); + +// lselect! { +// event = event_receiver.recv() => { +// if let Some(event) = event { +// match event { +// Event::Connection(id) => on_connection(id), +// Event::Disconnection(id) => { +// LOGGER.debug(format!("Client#{id} disconnected.").as_str()); +// clients.remove(&id); + +// if id >= clients.len() as u32 { +// LOGGER.info(format!("Client#{id} wasn't added to the released IDs list.").as_str()); +// continue; +// } + +// let mut ids = released_ids.lock().await; +// if !(*ids).insert(id) { +// LOGGER.error(format!("ID {} already exists in the released IDs.", id).as_str()); +// continue; +// }; +// }, +// Event::ReceivedData(id, data) => on_received_data(id, &data), +// } +// } +// } +// // Listen and add clients. +// res = tcp_listener.accept() => { +// if let Ok((stream, addr)) = res { +// LOGGER.debug(format!("Accepted connection from {:?}", addr).as_str()); + +// // If the max_connections is reached, return an error. +// if max_connections != 0 && clients.len() >= (max_connections as usize) { +// LOGGER.error("Max connections reached."); +// continue; +// } + +// // Get the next available ID and insert it. +// let released_id: u32 = released_ids +// .lock().await +// .pop_first() +// .unwrap_or(clients.len() as u32); +// let mut client = ServerClient::new(released_id); +// client.handle_data(event_sender.clone(), stream).await; +// clients.insert(released_id, client); + +// event_sender.send(Event::Connection(released_id)).await.unwrap(); +// } +// } +// } +// } +// } +// } + +// async fn send_data(tx: &mpsc::Sender>, data: Box<[u8]>) { +// tx.send(data).await.expect("Failed to send data to the Server."); +// } + +// // region: Events +// fn on_connection(id: u32) { +// LOGGER.debug(format!("Client#{id} connected").as_str()); +// } + +// fn on_received_data(id: u32, data: &[u8]) { +// LOGGER.debug(format!("Received data from Client#{id}: {:?}", data).as_str()); // todo!() // } -// endregion - -pub struct ServerClient { - pub id: u32, - pub name: Arc>>, - pub sender: Option>>, -} - -impl ServerClient { - pub fn new(id: u32) -> Self { - ServerClient { - id, - name: Arc::new(RwLock::new(None)), - sender: None, - } - } - - /// Handle the data from the client. - pub async fn handle_data(&mut self, event_sender: Sender, mut stream: TcpStream) { - let id = self.id; - let _name = self.name.clone(); // TODO: Implement name handling. 
- let (tx, mut rx) = tokio::sync::mpsc::channel(10); - self.sender = Some(tx.clone()); - - tokio::spawn(async move { - let (reader, mut writer) = stream.split(); - - let mut reader = BufReader::new(reader); - - loop { - select! { - // Incoming data from the client. - command = reader.read_u8() => { - if command.is_err() { - event_sender.send(Event::Disconnection(id)).await.expect("Failed to send disconnection event."); - break; - } - - LOGGER.debug(format!("Cluster Server received data: {:?}", command).as_str()); - - match command.unwrap() { - x if x == FromClient::RequestClusters as u8 => { - let mut data = vec![ToUnknown::SendClusters as u8]; - let cluster_ids = CLUSTER_IDS.read().await; - data.push(cluster_ids.len() as u8); - for cluster in (*cluster_ids).iter() { - data.push(cluster.name.len() as u8); - data.extend_from_slice(cluster.name.as_bytes()); - data.push(cluster.ip.len() as u8); - data.extend_from_slice(cluster.ip.as_bytes()); - data.extend_from_slice(&cluster.port.to_be_bytes()); - data.extend_from_slice(&cluster.max_connections.to_be_bytes()); - } - Self::send_data(&tx, data.into_boxed_slice()).await; - }, - _ => (), - } - } - // Outgoing data to the client. - result = rx.recv() => { - if let Some(data) = result { - writer.write_all(&data).await.expect("Failed to write to the Master Server."); - writer.flush().await.expect("Failed to flush the writer."); - } else { - writer.shutdown().await.expect("Failed to shutdown the writer."); - LOGGER.info("Cluster Server is shutting down its client writer."); - event_sender.send(Event::Disconnection(id)).await.expect("Failed to send disconnection event."); - break; - } - } - } - } - }); - } - - async fn send_data(tx: &mpsc::Sender>, data: Box<[u8]>) { - tx.send(data).await.expect("Failed to send data out."); - } -} + +// // fn on_client_connected(id: u32) { +// // LOGGER.debug(format!("Client connected: {}", id).as_str()); +// // todo!() +// // } + +// // fn on_client_disconnected(id: u32, protocol: Protocols) { +// // LOGGER.debug(format!("Client disconnected: {} {}", id, protocol as u8).as_str()); +// // todo!() +// // } + +// // fn on_client_received_data(id: u32, protocol: Protocols, data: &[u8]) { +// // LOGGER.debug(format!("Client received data: {} {} {:?}", id, protocol as u8, data).as_str()); +// // todo!() +// // } +// // endregion + +// pub struct ServerClient { +// pub id: u32, +// pub name: Arc>>, +// pub sender: Option>>, +// } + +// impl ServerClient { +// pub fn new(id: u32) -> Self { +// ServerClient { +// id, +// name: Arc::new(RwLock::new(None)), +// sender: None, +// } +// } + +// /// Handle the data from the client. +// pub async fn handle_data(&mut self, event_sender: Sender, mut stream: TcpStream) { +// let id = self.id; +// let _name = self.name.clone(); // TODO: Implement name handling. +// let (tx, mut rx) = tokio::sync::mpsc::channel(10); +// self.sender = Some(tx.clone()); + +// tokio::spawn(async move { +// let (reader, mut writer) = stream.split(); + +// let mut reader = BufReader::new(reader); + +// loop { +// select! { +// // Incoming data from the client. 
+//                     command = reader.read_u8() => {
+//                         if command.is_err() {
+//                             event_sender.send(Event::Disconnection(id)).await.expect("Failed to send disconnection event.");
+//                             break;
+//                         }
+
+//                         LOGGER.debug(format!("Cluster Server received data: {:?}", command).as_str());
+
+//                         match command.unwrap() {
+//                             x if x == FromClient::RequestClusters as u8 => {
+//                                 let mut data = vec![ToUnknown::SendClusters as u8];
+//                                 let cluster_ids = CLUSTER_IDS.read().await;
+//                                 data.push(cluster_ids.len() as u8);
+//                                 for cluster in (*cluster_ids).iter() {
+//                                     data.push(cluster.name.len() as u8);
+//                                     data.extend_from_slice(cluster.name.as_bytes());
+//                                     data.push(cluster.ip.len() as u8);
+//                                     data.extend_from_slice(cluster.ip.as_bytes());
+//                                     data.extend_from_slice(&cluster.port.to_be_bytes());
+//                                     data.extend_from_slice(&cluster.max_connections.to_be_bytes());
+//                                 }
+//                                 Self::send_data(&tx, data.into_boxed_slice()).await;
+//                             },
+//                             _ => (),
+//                         }
+//                     }
+//                     // Outgoing data to the client.
+//                     result = rx.recv() => {
+//                         if let Some(data) = result {
+//                             writer.write_all(&data).await.expect("Failed to write to the Master Server.");
+//                             writer.flush().await.expect("Failed to flush the writer.");
+//                         } else {
+//                             writer.shutdown().await.expect("Failed to shutdown the writer.");
+//                             LOGGER.info("Cluster Server is shutting down its client writer.");
+//                             event_sender.send(Event::Disconnection(id)).await.expect("Failed to send disconnection event.");
+//                             break;
+//                         }
+//                     }
+//                 }
+//             }
+//         });
+//     }
+
+//     async fn send_data(tx: &mpsc::Sender<Box<[u8]>>, data: Box<[u8]>) {
+//         tx.send(data).await.expect("Failed to send data out.");
+//     }
+// }
diff --git a/rust/cluster/src/main.rs b/rust/cluster/src/main.rs
index ffce0c6..be01e2e 100644
--- a/rust/cluster/src/main.rs
+++ b/rust/cluster/src/main.rs
@@ -1,27 +1,42 @@
-use sustenet_shared as shared;
+use sustenet_cluster::ClusterServer;
+use sustenet_shared::{ PluginPin, SenderBytes, ServerPlugin };
+use tokio::{ io::BufReader, net::tcp::ReadHalf };
 
-use tokio::{ select, sync::mpsc::Sender };
+#[tokio::main]
+async fn main() {
+    let plugin = DefaultPlugin {
+        sender: std::sync::OnceLock::new(),
+    };
+    let mut cluster = ClusterServer::new_from_config(plugin).await.unwrap();
 
-use shared::utils;
-use sustenet_cluster::{ cleanup, start_with_config, LOGGER };
+    // Wait for the shutdown signal or start the server
+    tokio::select! {
+        _ = tokio::signal::ctrl_c() => {
+            println!("Shutting down...");
+        },
+        _ = cluster.start() => {
+            println!("Cluster server started.");
+        }
+    }
+}
 
 struct DefaultPlugin {
-    sender: std::sync::OnceLock<Sender<Box<[u8]>>>,
+    sender: std::sync::OnceLock<SenderBytes>,
 }
 
-impl shared::ServerPlugin for DefaultPlugin {
-    fn set_sender(&self, tx: Sender<Box<[u8]>>) {
+impl ServerPlugin for DefaultPlugin {
+    fn set_sender(&self, tx: SenderBytes) {
         // Set the sender
         if self.sender.set(tx).is_err() {
-            LOGGER.error("Failed to set sender");
+            println!("Failed to set sender");
         }
     }
 
     fn receive<'plug>(
         &self,
-        _tx: Sender<Box<[u8]>>,
+        _tx: SenderBytes,
         command: u8,
-        _reader: &'plug mut tokio::io::BufReader<tokio::net::tcp::ReadHalf<'plug>>
-    ) -> std::pin::Pin<Box<dyn Future<Output = ()> + Send + 'plug>> {
+        _reader: &'plug mut BufReader<ReadHalf<'plug>>
+    ) -> PluginPin<'plug> {
         Box::pin(async move {
             match command {
                 0 => println!("Command 0 received"),
@@ -34,17 +49,38 @@ impl shared::ServerPlugin for DefaultPlugin {
 
     fn info(&self, _: &str) {}
 }
 
-#[tokio::main]
-async fn main() {
-    let mut shutdown_rx = utils::shutdown_channel().expect("Error creating shutdown channel.");
-
-    select! {
-        _ = shutdown_rx.recv() => {
-            LOGGER.warning("Shutting down...");
-        }
-        _ = start_with_config(DefaultPlugin { sender: std::sync::OnceLock::new() }) => {}
-    }
+// use sustenet_shared as shared;
 
-    cleanup().await;
-    LOGGER.success("The Cluster Server has been shut down.");
-}
+// use tokio::{ select, sync::mpsc::Sender };
+
+// use shared::utils;
+// use sustenet_cluster::{ cleanup, start_with_config, LOGGER };
+
+// struct DefaultPlugin {
+//     sender: std::sync::OnceLock<Sender<Box<[u8]>>>,
+// }
+// impl shared::ServerPlugin for DefaultPlugin {
+//     fn set_sender(&self, tx: Sender<Box<[u8]>>) {
+//         // Set the sender
+//         if self.sender.set(tx).is_err() {
+//             LOGGER.error("Failed to set sender");
+//         }
+//     }
+
+//     fn receive<'plug>(
+//         &self,
+//         _tx: Sender<Box<[u8]>>,
+//         command: u8,
+//         _reader: &'plug mut tokio::io::BufReader<tokio::net::tcp::ReadHalf<'plug>>
+//     ) -> std::pin::Pin<Box<dyn Future<Output = ()> + Send + 'plug>> {
+//         Box::pin(async move {
+//             match command {
+//                 0 => println!("Command 0 received"),
+//                 1 => println!("Command 1 received"),
+//                 _ => println!("Unknown command received"),
+//             }
+//         })
+//     }
+
+//     fn info(&self, _: &str) {}
+// }
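
The `DefaultPlugin` above is a no-op dispatcher. A minimal sketch of a custom plugin against the same `ServerPlugin` surface shown in this diff; `MyPlugin` and command id `2` are illustrative, and the reader type is assumed to match the `receive` signature above:

```rust
use sustenet_shared::{ PluginPin, SenderBytes, ServerPlugin };
use tokio::io::AsyncReadExt;
use tokio::{ io::BufReader, net::tcp::ReadHalf };

struct MyPlugin;

impl ServerPlugin for MyPlugin {
    fn set_sender(&self, _tx: SenderBytes) {}

    fn receive<'plug>(
        &self,
        _tx: SenderBytes,
        command: u8,
        reader: &'plug mut BufReader<ReadHalf<'plug>>
    ) -> PluginPin<'plug> {
        Box::pin(async move {
            if command == 2 {
                // Same framing the handshake uses: a u8 length, then bytes.
                if let Ok(len) = reader.read_u8().await {
                    let mut name = vec![0u8; len as usize];
                    if reader.read_exact(&mut name).await.is_ok() {
                        println!("hello, {}", String::from_utf8_lossy(&name));
                    }
                }
            }
        })
    }

    fn info(&self, _: &str) {}
}
```
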
diff --git a/rust/cluster/src/master_connection.rs b/rust/cluster/src/master_connection.rs
new file mode 100644
index 0000000..7d84a04
--- /dev/null
+++ b/rust/cluster/src/master_connection.rs
@@ -0,0 +1,164 @@
+use sustenet_shared::lselect;
+use sustenet_shared::packets::{ ClusterSetup, Connection, Diagnostics }; // TODO: Add Messaging
+use tokio::net::tcp::{ ReadHalf, WriteHalf };
+use tokio::sync::mpsc::error::SendError;
+
+use std::io::Error;
+use std::net::IpAddr;
+
+use bytes::Bytes;
+use tokio::io::{ AsyncReadExt, BufReader };
+use tokio::io::{ self, AsyncWriteExt };
+use tokio::net::TcpStream;
+use tokio::sync::{ broadcast, mpsc };
+
+use crate::cluster::{ ClusterEvent, LOGGER };
+
+pub struct MasterConnection {
+    ip: IpAddr,
+    port: u16,
+    sender: mpsc::Sender<Bytes>,
+}
+
+impl MasterConnection {
+    pub async fn connect(address: &str, port: u16) -> io::Result<Self> {
+        let addr = format!("{}:{}", address, port);
+        LOGGER.info(&format!("Connecting to the master server at {addr}..."));
+
+        // Establish a connection to the master server.
+        let mut stream = match TcpStream::connect(&addr).await {
+            Ok(s) => {
+                LOGGER.success(&format!("Connected to {addr}"));
+                s
+            }
+            Err(e) => {
+                LOGGER.error(&format!("Failed to connect to {addr}"));
+                return Err(Error::new(e.kind(), format!("Failed to connect to ({addr}): {e}")));
+            }
+        };
+
+        let ip = stream.peer_addr()?.ip();
+
+        let (sender, mut receiver) = mpsc::channel::<Bytes>(16);
+        let (event_tx, _event_rx) = broadcast::channel::<ClusterEvent>(16);
+
+        let sender_clone = sender.clone();
+        let event_tx_clone = event_tx.clone();
+
+        tokio::spawn(async move {
+            let (reader, mut writer) = stream.split();
+            let mut reader = BufReader::new(reader);
+
+            lselect!(
+                // Handle local requests to send a message to the master server.
+                msg = receiver.recv() => {
+                    match msg {
+                        Some(msg) => {
+                            if msg.is_empty() {
+                                LOGGER.warning("Received empty message, shutting down client");
+                                Self::handle_shutdown(writer, event_tx_clone).await;
+                                break;
+                            }
+
+                            LOGGER.debug(&format!("Sending message: {:?}", msg));
+                            if let Err(e) = writer.write_all(&msg).await {
+                                let msg = format!("Failed to send message to master server: {e}");
+                                LOGGER.error(&msg);
+                                let _ = event_tx_clone.send(ClusterEvent::Error(msg));
+                            } else {
+                                let _ = event_tx_clone.send(ClusterEvent::MasterMessageReceived(msg));
+                            }
+                        },
+                        None => {
+                            LOGGER.warning("Connection closed");
+                            Self::handle_shutdown(writer, event_tx_clone).await;
+                            break;
+                        }
+                    }
+                },
+                command = reader.read_u8() => {
+                    match command {
+                        Ok(command) => {
+                            LOGGER.debug(&format!("Received command: {command}"));
+
+                            Self::handle_command(command, &sender_clone, &mut reader, &mut writer, &event_tx_clone).await;
+                        },
+                        Err(e) => {
+                            LOGGER.error(&format!("Failed to read command: {e}"));
+                        }
+                    }
+                }
+            );
+        });
+
+        Ok(Self { ip, port, sender })
+    }
+
+    async fn handle_shutdown(
+        mut writer: WriteHalf<'_>,
+        event_tx_clone: broadcast::Sender<ClusterEvent>
+    ) {
+        if let Err(e) = writer.shutdown().await {
+            let msg = format!("Failed to shutdown writer: {e}");
+            LOGGER.error(&msg);
+            let _ = event_tx_clone.send(ClusterEvent::Error(msg));
+        }
+        let _ = event_tx_clone.send(ClusterEvent::MasterDisconnected);
+    }
+
+    /// Handles commands received from the server.
+    /// This function is called in a separate task to handle incoming commands.
+    async fn handle_command(
+        command: u8,
+        _sender: &mpsc::Sender<Bytes>,
+        _reader: &mut BufReader<ReadHalf<'_>>,
+        _writer: &mut WriteHalf<'_>,
+        event_tx: &broadcast::Sender<ClusterEvent>
+    ) {
+        // Handle the command received from the server.
+        match command {
+            x if x == (Connection::Connect as u8) => {
+                LOGGER.info("Handling Connection Connect");
+            }
+            x if x == (Connection::Disconnect as u8) => {
+                LOGGER.info("Handling Connection Disconnect");
+            }
+
+            x if x == (ClusterSetup::Init as u8) => {
+                LOGGER.info("Handling Cluster Setup Init");
+            }
+            x if x == (ClusterSetup::AnswerSecret as u8) => {
+                LOGGER.info("Handling Cluster Setup Answer Secret");
+            }
+
+            x if x == (Diagnostics::CheckServerType as u8) => {
+                LOGGER.info("Handling Diagnostics Check Server Type");
+            }
+            x if x == (Diagnostics::CheckServerUptime as u8) => {
+                LOGGER.info("Handling Diagnostics Check Server Uptime");
+            }
+            x if x == (Diagnostics::CheckServerPlayerCount as u8) => {
+                LOGGER.info("Handling Diagnostics Check Server Player Count");
+            }
+
+            _ => {
+                let msg = format!("Unknown command received: {command}");
+                LOGGER.error(&msg);
+                let _ = event_tx.send(ClusterEvent::Error(msg));
+            }
+        }
+    }
+
+    pub fn get_ip(&self) -> IpAddr {
+        self.ip
+    }
+
+    pub fn get_port(&self) -> u16 {
+        self.port
+    }
+
+    pub async fn send(&self, data: Bytes) -> Result<(), SendError<Bytes>> {
+        self.sender.send(data).await
+    }
+}
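
A minimal sketch of driving `MasterConnection` from cluster code. The address and port are placeholders; real values come from the `[cluster]` section of `data/Config.toml` below.

```rust
use bytes::Bytes;
use sustenet_cluster::master_connection::MasterConnection;

#[tokio::main]
async fn main() -> std::io::Result<()> {
    // 7777 is a placeholder; the real port comes from `master_port`.
    let conn = MasterConnection::connect("127.0.0.1", 7777).await?;
    println!("master at {}:{}", conn.get_ip(), conn.get_port());

    // Any non-empty payload; an empty `Bytes` would shut the writer down.
    conn.send(Bytes::from_static(&[0u8])).await
        .map_err(|e| std::io::Error::new(std::io::ErrorKind::BrokenPipe, e.to_string()))
}
```
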
diff --git a/rust/data/Config.toml b/rust/data/Config.toml
new file mode 100644
index 0000000..ef81b88
--- /dev/null
+++ b/rust/data/Config.toml
@@ -0,0 +1,17 @@
+[master]
+max_connections = 0
+port = 0
+bind = "0.0.0.0" # This is the IP address to bind to.
+
+[cluster]
+name = "Cluster Server"
+
+max_connections = 0
+port = 0
+bind = "0.0.0.0" # This is the IP address to bind to.
+
+key_name = "cluster_key" # This is the name of the key that must exist BOTH on the cluster and the master.
+master_ip = "127.0.0.1" # This is the IP address of the master server.
+master_port = 0 # This is the port of the master server.
+
+domain_pub_key = "https://site-cdn.playreia.com/game/pubkey.pub" # Remove this if you want to use the server's bandwidth to send a key to a user directly.
diff --git a/rust/master/Cargo.toml b/rust/master/Cargo.toml
index 0caff80..fd74ab9 100644
--- a/rust/master/Cargo.toml
+++ b/rust/master/Cargo.toml
@@ -11,10 +11,14 @@ homepage.workspace = true
 [lints]
 workspace = true
 
+[features]
+ignored_tests = []
+
 [dependencies]
+bytes.workspace = true
 dashmap.workspace = true
 getrandom.workspace = true
-lazy_static.workspace = true
+num_cpus.workspace = true
 sustenet-shared.workspace = true
 tokio = { workspace = true, features = [
     # "socket2",
diff --git a/rust/master/src/lib.rs b/rust/master/src/lib.rs
index 48fc851..053edd3 100644
--- a/rust/master/src/lib.rs
+++ b/rust/master/src/lib.rs
@@ -1,402 +1,409 @@
-use sustenet_shared as shared;
+pub mod master;
+pub mod master_client;
+pub mod security;
 
-use std::collections::BTreeSet;
-use std::sync::{ Arc, LazyLock };
+pub use master::MasterServer;
+pub use master_client::MasterClient;
 
-use dashmap::DashMap;
+// use sustenet_shared as shared;
 
-use tokio::io::{ AsyncReadExt, AsyncWriteExt, BufReader };
-use tokio::net::{ TcpListener, TcpStream };
-use tokio::select;
-use tokio::sync::mpsc::{ self, Sender };
-use tokio::sync::{ Mutex, RwLock };
+// use std::collections::BTreeSet;
+// use std::sync::{ Arc, LazyLock };
 
-use shared::config::master::{ Settings, read };
-use shared::logging::{ LogType, Logger };
-use shared::network::*;
-use shared::packets::master::*;
-use shared::security::aes::*;
-use shared::utils::constants;
+// use dashmap::DashMap;
 
-pub mod security;
+// use tokio::io::{ AsyncReadExt, AsyncWriteExt, BufReader };
+// use tokio::net::{ TcpListener, TcpStream };
+// use tokio::select;
+// use tokio::sync::mpsc::{ self, Sender };
+// use tokio::sync::{ Mutex, RwLock };
 
-lazy_static::lazy_static! {
-    static ref CLUSTER_IDS: Arc<RwLock<BTreeSet<ClusterInfo>>> = Arc::new(
-        RwLock::new(BTreeSet::new())
-    );
-}
-pub static LOGGER: LazyLock<Logger> = LazyLock::new(|| Logger::new(LogType::Cluster));
-
-#[derive(Eq)]
-struct ClusterInfo {
-    id: u32,
-    name: String,
-    ip: String,
-    port: u16,
-    max_connections: u32,
-}
-
-impl Ord for ClusterInfo {
-    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
-        // Define how to compare two ClusterInfo instances
-        // For example, if ClusterInfo has a field `id` of type i32:
-        self.id.cmp(&other.id)
-    }
-}
-
-impl PartialOrd for ClusterInfo {
-    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
-        Some(self.cmp(other))
-    }
-}
-
-impl PartialEq for ClusterInfo {
-    fn eq(&self, other: &Self) -> bool {
-        // Define when two ClusterInfo instances are equal
-        // For example, if ClusterInfo has a field `id` of type i32:
-        self.id == other.id
-    }
-}
-
-pub async fn start_with_config() {
-    start(read()).await;
-}
-
-/// This function starts the master server.
-/// It listens for an event
-pub async fn start(settings: Settings) {
-    let Settings { server_name: _, max_connections, port } = settings;
-    let (event_sender, mut event_receiver) = mpsc::channel::<Event>(100);
-
-    let clients: DashMap<u32, ServerClient> = DashMap::new();
-    let released_ids: Arc<Mutex<BTreeSet<u32>>> = Arc::new(Mutex::new(BTreeSet::new())); // In the future, think about reserving cluster ids. Sometimes a cluster can get a high ID, causing RAM to stay high during low loads.
- - { - let max_connections_str = match max_connections { - 0 => "unlimited max connections".to_string(), - 1 => "1 max connection".to_string(), - _ => format!("{} max connections", max_connections), - }; - - LOGGER.debug( - format!("Starting the Master Server on port {} with {max_connections_str}...", port).as_str() - ); - } - - // Listen - { - let tcp_listener = TcpListener::bind( - format!("{}:{}", constants::DEFAULT_IP, port) - ).await.expect("Failed to bind to the specified port."); - - loop { - select! { - event = event_receiver.recv() => { - if let Some(event) = event { - match event { - Event::Connection(id) => on_connection(id), - Event::Disconnection(id) => { - LOGGER.debug(format!("Client#{id} disconnected.").as_str()); - clients.remove(&id); - - if id >= clients.len() as u32 { - LOGGER.info(format!("Client#{id} wasn't added to the released IDs list.").as_str()); - continue; - } - - let mut ids = released_ids.lock().await; - if !(*ids).insert(id) { - LOGGER.error(format!("ID {} already exists in the released IDs.", id).as_str()); - continue; - }; - }, - Event::ReceivedData(id, data) => on_received_data(id, &data), - } - } - } - // Listen and add clients. - res = tcp_listener.accept() => { - if let Ok((stream, addr)) = res { - LOGGER.debug(format!("Accepted connection from {:?}", addr).as_str()); - - // If the max_connections is reached, return an error. - if max_connections != 0 && clients.len() >= (max_connections as usize) { - LOGGER.error("Max connections reached."); - continue; - } - - // Get the next available ID and insert it. - let released_id: u32 = released_ids - .lock().await - .pop_first() - .unwrap_or(clients.len() as u32); - let mut client = ServerClient::new(released_id); - client.handle_data(event_sender.clone(), stream).await; - clients.insert(released_id, client); - - event_sender.send(Event::Connection(released_id)).await.unwrap(); - } - } - } - } - } -} - -// region: Events -fn on_connection(id: u32) { - LOGGER.debug(format!("Client#{id} connected").as_str()); -} - -fn on_received_data(id: u32, data: &[u8]) { - LOGGER.debug(format!("Received data from Client#{id}: {:?}", data).as_str()); - todo!() -} - -// fn on_client_connected(id: u32) { -// LOGGER.debug(format!("Client connected: {}", id).as_str()); -// todo!() +// use shared::config::master::{ Settings, read }; +// use shared::logging::{ LogType, Logger }; +// use shared::network::*; +// use shared::packets::master::*; +// use shared::security::aes::*; +// use shared::utils::constants; + +// pub mod security; + +// lazy_static::lazy_static! 
{ +// static ref CLUSTER_IDS: Arc>> = Arc::new( +// RwLock::new(BTreeSet::new()) +// ); +// } +// pub static LOGGER: LazyLock = LazyLock::new(|| Logger::new(LogType::Cluster)); + +// #[derive(Eq)] +// struct ClusterInfo { +// id: u32, +// name: String, +// ip: String, +// port: u16, +// max_connections: u32, // } -// fn on_client_disconnected(id: u32, protocol: Protocols) { -// LOGGER.debug(format!("Client disconnected: {} {}", id, protocol as u8).as_str()); -// todo!() +// impl Ord for ClusterInfo { +// fn cmp(&self, other: &Self) -> std::cmp::Ordering { +// // Define how to compare two ClusterInfo instances +// // For example, if ClusterInfo has a field `id` of type i32: +// self.id.cmp(&other.id) +// } +// } + +// impl PartialOrd for ClusterInfo { +// fn partial_cmp(&self, other: &Self) -> Option { +// Some(self.cmp(other)) +// } +// } + +// impl PartialEq for ClusterInfo { +// fn eq(&self, other: &Self) -> bool { +// // Define when two ClusterInfo instances are equal +// // For example, if ClusterInfo has a field `id` of type i32: +// self.id == other.id +// } // } -// fn on_client_received_data(id: u32, protocol: Protocols, data: &[u8]) { -// LOGGER.debug(format!("Client received data: {} {} {:?}", id, protocol as u8, data).as_str()); +// pub async fn start_with_config() { +// start(read()).await; +// } + +// /// This function starts the master server. +// /// It listens for an event +// pub async fn start(settings: Settings) { +// let Settings { server_name: _, max_connections, port } = settings; +// let (event_sender, mut event_receiver) = mpsc::channel::(100); + +// let clients: DashMap = DashMap::new(); +// let released_ids: Arc>> = Arc::new(Mutex::new(BTreeSet::new())); // In the future, think about reserving cluster ids. Sometimes a cluster can get a high ID, causing RAM to stay high during low loads. + +// { +// let max_connections_str = match max_connections { +// 0 => "unlimited max connections".to_string(), +// 1 => "1 max connection".to_string(), +// _ => format!("{} max connections", max_connections), +// }; + +// LOGGER.debug( +// format!("Starting the Master Server on port {} with {max_connections_str}...", port).as_str() +// ); +// } + +// // Listen +// { +// let tcp_listener = TcpListener::bind( +// format!("{}:{}", constants::DEFAULT_IP, port) +// ).await.expect("Failed to bind to the specified port."); + +// loop { +// select! { +// event = event_receiver.recv() => { +// if let Some(event) = event { +// match event { +// Event::Connection(id) => on_connection(id), +// Event::Disconnection(id) => { +// LOGGER.debug(format!("Client#{id} disconnected.").as_str()); +// clients.remove(&id); + +// if id >= clients.len() as u32 { +// LOGGER.info(format!("Client#{id} wasn't added to the released IDs list.").as_str()); +// continue; +// } + +// let mut ids = released_ids.lock().await; +// if !(*ids).insert(id) { +// LOGGER.error(format!("ID {} already exists in the released IDs.", id).as_str()); +// continue; +// }; +// }, +// Event::ReceivedData(id, data) => on_received_data(id, &data), +// } +// } +// } +// // Listen and add clients. +// res = tcp_listener.accept() => { +// if let Ok((stream, addr)) = res { +// LOGGER.debug(format!("Accepted connection from {:?}", addr).as_str()); + +// // If the max_connections is reached, return an error. +// if max_connections != 0 && clients.len() >= (max_connections as usize) { +// LOGGER.error("Max connections reached."); +// continue; +// } + +// // Get the next available ID and insert it. 
+// let released_id: u32 = released_ids +// .lock().await +// .pop_first() +// .unwrap_or(clients.len() as u32); +// let mut client = ServerClient::new(released_id); +// client.handle_data(event_sender.clone(), stream).await; +// clients.insert(released_id, client); + +// event_sender.send(Event::Connection(released_id)).await.unwrap(); +// } +// } +// } +// } +// } +// } + +// // region: Events +// fn on_connection(id: u32) { +// LOGGER.debug(format!("Client#{id} connected").as_str()); +// } + +// fn on_received_data(id: u32, data: &[u8]) { +// LOGGER.debug(format!("Received data from Client#{id}: {:?}", data).as_str()); // todo!() // } -// endregion - -pub struct ServerClient { - pub id: u32, - pub name: Arc>>, - pub sender: Option>>, -} - -impl ServerClient { - pub fn new(id: u32) -> Self { - ServerClient { - id, - name: Arc::new(RwLock::new(None)), - sender: None, - } - } - - /// Handle the data from the client. - pub async fn handle_data(&mut self, event_sender: Sender, mut stream: TcpStream) { - let id = self.id; - let name = self.name.clone(); - let (tx, mut rx) = tokio::sync::mpsc::channel(10); - self.sender = Some(tx.clone()); - - tokio::spawn(async move { - let (reader, mut writer) = stream.split(); - - let mut reader = BufReader::new(reader); - - loop { - select! { - // Incoming data from the client. - command = reader.read_u8() => { - if command.is_err() { - event_sender.send(Event::Disconnection(id)).await.expect("Failed to send disconnection event."); - break; - } - - LOGGER.debug(format!("Received data from Client#{id}: {:?}", command).as_str()); - - match command.unwrap() { - x if x == FromUnknown::RequestClusters as u8 => { - let mut data = vec![ToUnknown::SendClusters as u8]; - let cluster_ids = CLUSTER_IDS.read().await; - data.push(cluster_ids.len() as u8); - for cluster in (*cluster_ids).iter() { - data.push(cluster.name.len() as u8); - data.extend_from_slice(cluster.name.as_bytes()); - data.push(cluster.ip.len() as u8); - data.extend_from_slice(cluster.ip.as_bytes()); - data.extend_from_slice(&cluster.port.to_be_bytes()); - data.extend_from_slice(&cluster.max_connections.to_be_bytes()); - } - Self::send_data(&tx, data.into_boxed_slice()).await; - }, - x if x == FromUnknown::JoinCluster as u8 => { - event_sender.send(Event::Disconnection(id)).await.expect("Failed to send disconnection event."); - break; - }, - x if x == FromUnknown::BecomeCluster as u8 => { - let len = match reader.read_u8().await { - Ok(len) => len, - Err(e) => { - LOGGER.error(format!("Failed to read cluster name length: {:?}", e).as_str()); - continue; - } - } as usize; - let mut key_name = vec![0u8; len]; - match reader.read_exact(&mut key_name).await { - Ok(_) => {}, - Err(e) => { - LOGGER.error(format!("Failed to read cluster name to String: {:?}", e).as_str()); - continue; - } - } - let key_name = String::from_utf8(key_name).unwrap(); - let key = match security::AES_KEYS.get(&key_name) { - Some(key) => key, - None => { - LOGGER.error(format!("Key {} doesn't exist.", key_name).as_str()); - continue; - } - }; - - let mut data = vec![ToUnknown::VerifyCluster as u8]; - - let passphrase = &security::generate_passphrase(); - let encrypted_passphrase = encrypt(passphrase, key); - - data.push(encrypted_passphrase.len() as u8); - data.extend_from_slice(&encrypted_passphrase); - - { - let mut name = name.write().await; - *name = Some(String::from_utf8(passphrase.to_vec()).unwrap()); - } - Self::send_data(&tx, data.into_boxed_slice()).await; - }, - x if x == FromUnknown::AnswerCluster as u8 => { - let len = 
reader.read_u8().await.unwrap() as usize; - let mut passphrase = vec![0u8; len]; - match reader.read_exact(&mut passphrase).await { - Ok(_) => {}, - Err(e) => { - LOGGER.error(format!("Failed to read the passphrase to String: {:?}", e).as_str()); - continue; - } - } - - { - let passphrase = match String::from_utf8(passphrase) { - Ok(passphrase) => passphrase, - Err(e) => { - LOGGER.error(format!("Failed to convert passphrase to String: {:?}", e).as_str()); - continue; - } - }; - - let name = name.read().await; - if (*name).is_none() || passphrase != *name.as_ref().expect("Failed to get saved passphrase.") { - LOGGER.error("The passphrase doesn't match the name."); - continue; - } else { - LOGGER.success(format!("The passphrase matches the name: {:?} is {}", *name, passphrase).as_str()); - } - } - - // Read their new name they sent. - let len = reader.read_u8().await.unwrap() as usize; - let mut server_name = vec![0u8; len]; - match reader.read_exact(&mut server_name).await { - Ok(_) => {}, - Err(e) => { - LOGGER.error(format!("Failed to read the server name to String: {:?}", e).as_str()); - continue; - } - }; - - let server_name = match String::from_utf8(server_name) { - Ok(server_name) => server_name, - Err(e) => { - LOGGER.error(format!("Failed to convert server name to String: {:?}", e).as_str()); - continue; - } - }; - - { - let mut name = name.write().await; - *name = Some(server_name.clone()); - } - - { - // Read IP to len. Read port u16. Read max connections u32. - let len = match reader.read_u8().await { - Ok(len) => len, - Err(e) => { - LOGGER.error(format!("Failed to read the IP length: {:?}", e).as_str()); - continue; - } - } as usize; - let mut ip = vec![0u8; len]; - match reader.read_exact(&mut ip).await { - Ok(_) => {}, - Err(e) => { - LOGGER.error(format!("Failed to read the IP to String: {:?}", e).as_str()); - continue; - } - } - let ip = match String::from_utf8(ip) { - Ok(ip) => ip, - Err(e) => { - LOGGER.error(format!("Failed to convert IP to String: {:?}", e).as_str()); - continue; - } - }; - let port = match reader.read_u16().await { - Ok(port) => port, - Err(e) => { - LOGGER.error(format!("Failed to read the port: {:?}", e).as_str()); - continue; - } - }; - let max_connections = match reader.read_u32().await { - Ok(max_connections) => max_connections, - Err(e) => { - LOGGER.error(format!("Failed to read the max connections: {:?}", e).as_str()); - continue; - } - }; - - let mut cluster_ids = CLUSTER_IDS.write().await; - if (*cluster_ids).insert(ClusterInfo { - id, - name: server_name, - ip, - port, - max_connections, - }) { - LOGGER.success(format!("Client#{id} has become a cluster.").as_str()); - } else { - LOGGER.error(format!("Client#{id} failed to become a cluster.").as_str()); - continue; - } - } - - Self::send_data(&tx, Box::new([ToUnknown::CreateCluster as u8])).await; - }, - - // Cluster Section - - _ => (), - } - } - // Outgoing data to the client. 
- result = rx.recv() => { - if let Some(data) = result { - writer.write_all(&data).await.expect("Failed to write to the Master Server."); - writer.flush().await.expect("Failed to flush the writer."); - } else { - writer.shutdown().await.expect("Failed to shutdown the writer."); - LOGGER.info("Cluster Server is shutting down its client writer."); - event_sender.send(Event::Disconnection(id)).await.expect("Failed to send disconnection event."); - break; - } - } - } - } - }); - } - - async fn send_data(tx: &mpsc::Sender>, data: Box<[u8]>) { - tx.send(data).await.expect("Failed to send data out."); - } -} + +// // fn on_client_connected(id: u32) { +// // LOGGER.debug(format!("Client connected: {}", id).as_str()); +// // todo!() +// // } + +// // fn on_client_disconnected(id: u32, protocol: Protocols) { +// // LOGGER.debug(format!("Client disconnected: {} {}", id, protocol as u8).as_str()); +// // todo!() +// // } + +// // fn on_client_received_data(id: u32, protocol: Protocols, data: &[u8]) { +// // LOGGER.debug(format!("Client received data: {} {} {:?}", id, protocol as u8, data).as_str()); +// // todo!() +// // } +// // endregion + +// pub struct ServerClient { +// pub id: u32, +// pub name: Arc>>, +// pub sender: Option>>, +// } + +// impl ServerClient { +// pub fn new(id: u32) -> Self { +// ServerClient { +// id, +// name: Arc::new(RwLock::new(None)), +// sender: None, +// } +// } + +// /// Handle the data from the client. +// pub async fn handle_data(&mut self, event_sender: Sender, mut stream: TcpStream) { +// let id = self.id; +// let name = self.name.clone(); +// let (tx, mut rx) = tokio::sync::mpsc::channel(10); +// self.sender = Some(tx.clone()); + +// tokio::spawn(async move { +// let (reader, mut writer) = stream.split(); + +// let mut reader = BufReader::new(reader); + +// loop { +// select! { +// // Incoming data from the client. 
+// command = reader.read_u8() => { +// if command.is_err() { +// event_sender.send(Event::Disconnection(id)).await.expect("Failed to send disconnection event."); +// break; +// } + +// LOGGER.debug(format!("Received data from Client#{id}: {:?}", command).as_str()); + +// match command.unwrap() { +// x if x == FromUnknown::RequestClusters as u8 => { +// let mut data = vec![ToUnknown::SendClusters as u8]; +// let cluster_ids = CLUSTER_IDS.read().await; +// data.push(cluster_ids.len() as u8); +// for cluster in (*cluster_ids).iter() { +// data.push(cluster.name.len() as u8); +// data.extend_from_slice(cluster.name.as_bytes()); +// data.push(cluster.ip.len() as u8); +// data.extend_from_slice(cluster.ip.as_bytes()); +// data.extend_from_slice(&cluster.port.to_be_bytes()); +// data.extend_from_slice(&cluster.max_connections.to_be_bytes()); +// } +// Self::send_data(&tx, data.into_boxed_slice()).await; +// }, +// x if x == FromUnknown::JoinCluster as u8 => { +// event_sender.send(Event::Disconnection(id)).await.expect("Failed to send disconnection event."); +// break; +// }, +// x if x == FromUnknown::BecomeCluster as u8 => { +// let len = match reader.read_u8().await { +// Ok(len) => len, +// Err(e) => { +// LOGGER.error(format!("Failed to read cluster name length: {:?}", e).as_str()); +// continue; +// } +// } as usize; +// let mut key_name = vec![0u8; len]; +// match reader.read_exact(&mut key_name).await { +// Ok(_) => {}, +// Err(e) => { +// LOGGER.error(format!("Failed to read cluster name to String: {:?}", e).as_str()); +// continue; +// } +// } +// let key_name = String::from_utf8(key_name).unwrap(); +// let key = match security::AES_KEYS.get(&key_name) { +// Some(key) => key, +// None => { +// LOGGER.error(format!("Key {} doesn't exist.", key_name).as_str()); +// continue; +// } +// }; + +// let mut data = vec![ToUnknown::VerifyCluster as u8]; + +// let passphrase = &security::generate_passphrase(); +// let encrypted_passphrase = encrypt(passphrase, key); + +// data.push(encrypted_passphrase.len() as u8); +// data.extend_from_slice(&encrypted_passphrase); + +// { +// let mut name = name.write().await; +// *name = Some(String::from_utf8(passphrase.to_vec()).unwrap()); +// } +// Self::send_data(&tx, data.into_boxed_slice()).await; +// }, +// x if x == FromUnknown::AnswerCluster as u8 => { +// let len = reader.read_u8().await.unwrap() as usize; +// let mut passphrase = vec![0u8; len]; +// match reader.read_exact(&mut passphrase).await { +// Ok(_) => {}, +// Err(e) => { +// LOGGER.error(format!("Failed to read the passphrase to String: {:?}", e).as_str()); +// continue; +// } +// } + +// { +// let passphrase = match String::from_utf8(passphrase) { +// Ok(passphrase) => passphrase, +// Err(e) => { +// LOGGER.error(format!("Failed to convert passphrase to String: {:?}", e).as_str()); +// continue; +// } +// }; + +// let name = name.read().await; +// if (*name).is_none() || passphrase != *name.as_ref().expect("Failed to get saved passphrase.") { +// LOGGER.error("The passphrase doesn't match the name."); +// continue; +// } else { +// LOGGER.success(format!("The passphrase matches the name: {:?} is {}", *name, passphrase).as_str()); +// } +// } + +// // Read their new name they sent. 
+// let len = reader.read_u8().await.unwrap() as usize; +// let mut server_name = vec![0u8; len]; +// match reader.read_exact(&mut server_name).await { +// Ok(_) => {}, +// Err(e) => { +// LOGGER.error(format!("Failed to read the server name to String: {:?}", e).as_str()); +// continue; +// } +// }; + +// let server_name = match String::from_utf8(server_name) { +// Ok(server_name) => server_name, +// Err(e) => { +// LOGGER.error(format!("Failed to convert server name to String: {:?}", e).as_str()); +// continue; +// } +// }; + +// { +// let mut name = name.write().await; +// *name = Some(server_name.clone()); +// } + +// { +// // Read IP to len. Read port u16. Read max connections u32. +// let len = match reader.read_u8().await { +// Ok(len) => len, +// Err(e) => { +// LOGGER.error(format!("Failed to read the IP length: {:?}", e).as_str()); +// continue; +// } +// } as usize; +// let mut ip = vec![0u8; len]; +// match reader.read_exact(&mut ip).await { +// Ok(_) => {}, +// Err(e) => { +// LOGGER.error(format!("Failed to read the IP to String: {:?}", e).as_str()); +// continue; +// } +// } +// let ip = match String::from_utf8(ip) { +// Ok(ip) => ip, +// Err(e) => { +// LOGGER.error(format!("Failed to convert IP to String: {:?}", e).as_str()); +// continue; +// } +// }; +// let port = match reader.read_u16().await { +// Ok(port) => port, +// Err(e) => { +// LOGGER.error(format!("Failed to read the port: {:?}", e).as_str()); +// continue; +// } +// }; +// let max_connections = match reader.read_u32().await { +// Ok(max_connections) => max_connections, +// Err(e) => { +// LOGGER.error(format!("Failed to read the max connections: {:?}", e).as_str()); +// continue; +// } +// }; + +// let mut cluster_ids = CLUSTER_IDS.write().await; +// if (*cluster_ids).insert(ClusterInfo { +// id, +// name: server_name, +// ip, +// port, +// max_connections, +// }) { +// LOGGER.success(format!("Client#{id} has become a cluster.").as_str()); +// } else { +// LOGGER.error(format!("Client#{id} failed to become a cluster.").as_str()); +// continue; +// } +// } + +// Self::send_data(&tx, Box::new([ToUnknown::CreateCluster as u8])).await; +// }, + +// // Cluster Section + +// _ => (), +// } +// } +// // Outgoing data to the client. +// result = rx.recv() => { +// if let Some(data) = result { +// writer.write_all(&data).await.expect("Failed to write to the Master Server."); +// writer.flush().await.expect("Failed to flush the writer."); +// } else { +// writer.shutdown().await.expect("Failed to shutdown the writer."); +// LOGGER.info("Cluster Server is shutting down its client writer."); +// event_sender.send(Event::Disconnection(id)).await.expect("Failed to send disconnection event."); +// break; +// } +// } +// } +// } +// }); +// } + +// async fn send_data(tx: &mpsc::Sender>, data: Box<[u8]>) { +// tx.send(data).await.expect("Failed to send data out."); +// } +// } diff --git a/rust/master/src/main.rs b/rust/master/src/main.rs index 1bd69ef..0039ef0 100644 --- a/rust/master/src/main.rs +++ b/rust/master/src/main.rs @@ -1,23 +1,14 @@ -use sustenet_shared as shared; - -use sustenet_master::{ LOGGER, start_with_config }; - -use tokio::select; - -use shared::utils; - -pub mod security; +use sustenet_master::MasterServer; #[tokio::main] async fn main() { - let mut shutdown_rx = utils::shutdown_channel().expect("Error creating shutdown channel."); - - select! 
-        _ = shutdown_rx.recv() => {
-            LOGGER.warning("Shutting down...");
-        }
-        _ = start_with_config() => {}
+    let mut master = MasterServer::new_from_config().await.unwrap();
+
+    // Wait for the shutdown signal or start the server
+    tokio::select! {
+        _ = tokio::signal::ctrl_c() => {
+            println!("Shutting down...");
+        },
+        _ = master.start() => println!("Master server started.")
     }
-
-    LOGGER.success("The Master Server has been shut down.");
 }
diff --git a/rust/master/src/master.rs b/rust/master/src/master.rs
new file mode 100644
index 0000000..16ddabe
--- /dev/null
+++ b/rust/master/src/master.rs
@@ -0,0 +1,266 @@
+//! The master server acts as a load balancer.
+//!
+//! When a client connects to it, it will redirect them to a registered cluster.
+//!
+//! The focus for the Master Server is to accept connections fast, so it should
+//! avoid doing too much work and distribute users to other servers as fast as
+//! possible.
+
+use crate::master_client::MasterClient;
+use sustenet_shared::config::master::{ Settings, read };
+use sustenet_shared::logging::{ LogType, Logger };
+use sustenet_shared::lselect;
+use sustenet_shared::network::ClusterInfo;
+use sustenet_shared::packets::Diagnostics;
+
+use std::collections::HashMap;
+use std::io::Error;
+use std::net::SocketAddr;
+use std::sync::{ Arc, LazyLock };
+
+use bytes::Bytes;
+use dashmap::DashMap;
+// use num_cpus; // TODO
+use tokio::net::{ TcpListener, TcpStream };
+use tokio::sync::mpsc;
+use tokio::{ io /* join */ };
+
+/// Global logger for the master module.
+pub static LOGGER: LazyLock<Logger> = LazyLock::new(|| Logger::new(LogType::Master));
+
+#[derive(Debug, Clone)]
+pub enum MasterEvent {
+    /// When a connection is established with a client or server.
+    Connected(u64),
+    /// When a connection is closed with a client or server.
+    Disconnected(u64),
+
+    /// When a cluster server is initialized with a passphrase.
+    ClusterInit(u64, [u8; 20]),
+    /// When a cluster server answers the passphrase correctly.
+    ClusterRegistered(u64, String),
+    /// When a cluster server fails to register with the master server.
+    /// This is usually due to a wrong passphrase, but it can also be due to a timeout.
+    ClusterRegistrationFailed(u64),
+
+    DiagnosticsReceived(Diagnostics, Bytes),
+    Shutdown,
+    Error(String),
+}
+
+pub type SharedConnections = Arc<DashMap<u64, MasterClient>>;
+
+/// Handles connections and interactions with Cluster Servers and Clients.
+pub struct MasterServer {
+    max_connections: u32,
+    bind: String,
+    port: u16,
+
+    // sender: mpsc::Sender<Bytes>,
+    event_tx: mpsc::Sender<MasterEvent>,
+    event_rx: mpsc::Receiver<MasterEvent>,
+
+    connections: SharedConnections,
+    cluster_servers: HashMap<u64, ClusterInfo>,
+    cluster_passphrases: HashMap<u64, [u8; 20]>,
+    next_id: u64,
+}
+
+impl MasterServer {
+    pub async fn new(settings: Settings) -> io::Result<Self> {
+        let (event_tx, event_rx) = mpsc::channel::<MasterEvent>(200_000);
+
+        Ok(Self {
+            max_connections: settings.max_connections,
+            bind: settings.bind,
+            port: settings.port,
+
+            event_tx,
+            event_rx,
+
+            connections: Arc::new(DashMap::new()),
+            cluster_servers: HashMap::new(),
+            cluster_passphrases: HashMap::new(),
+            next_id: 0,
+        })
+    }
+
+    pub async fn new_from_cli() -> io::Result<Self> {
+        // TODO (low priority): Load the configuration from CLI arguments
+        todo!()
+    }
+
+    pub async fn new_from_config() -> io::Result<Self> {
+        let settings = read();
+
+        Self::new(settings).await
+    }
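
A sketch of the intended lifecycle, assuming `start` (below) returns once a `MasterEvent::Shutdown` breaks its event loop:

```rust
use sustenet_master::MasterServer;

#[tokio::main]
async fn main() -> std::io::Result<()> {
    let mut master = MasterServer::new_from_config().await?;
    master.start().await?;  // runs until a MasterEvent::Shutdown breaks the loop
    master.cleanup().await; // then close every tracked connection
    Ok(())
}
```
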
+    /// Starts the master server and begins listening for connections.
+    pub async fn start(&mut self) -> io::Result<()> {
+        // Create Listener
+        let addr = format!("{}:{}", self.bind, self.port);
+        let listener = match TcpListener::bind(&addr).await {
+            Ok(l) => {
+                LOGGER.success(&format!("Master server started on {addr}"));
+                l
+            }
+            Err(e) => {
+                LOGGER.error(&format!("Failed to bind to {addr}"));
+                return Err(Error::new(e.kind(), format!("Failed to bind to ({addr}): {e}")));
+            }
+        };
+
+        lselect!(
+            event = self.event_rx.recv() => {
+                if let Some(event) = event {
+                    if !self.handle_events(event).await? {
+                        println!("Cleaning up master server...");
+                        break;
+                    }
+                }
+            }
+            res = listener.accept() => self.handle_listener(res).await?
+        );
+
+        Ok(())
+    }
+
+    pub async fn handle_events(&mut self, event: MasterEvent) -> io::Result<bool> {
+        match event {
+            MasterEvent::Connected(id) => {
+                LOGGER.debug(&format!("Client #{id} connected"));
+            }
+            MasterEvent::Disconnected(id) => {
+                // The connection is already scheduled to close, so no need
+                // to call close() on the MasterClient.
+                if self.connections.remove(&id).is_none() {
+                    LOGGER.warning(&format!("Disconnected client #{id} not found"));
+                    return Ok(true);
+                }
+                LOGGER.debug(&format!("Client #{id} disconnected"));
+            }
+            MasterEvent::ClusterInit(id, passphrase) => {
+                LOGGER.debug(&format!("Cluster #{id} initialized with passphrase: {passphrase:?}"));
+                self.cluster_passphrases.insert(id, passphrase);
+            }
+            MasterEvent::ClusterRegistered(id, name) => {
+                LOGGER.success(&format!("Cluster ({name}) registered with ID #{id}"));
+            }
+            MasterEvent::ClusterRegistrationFailed(id) => {
+                LOGGER.error(&format!("Cluster registration failed for ID {id}"));
+            }
+            MasterEvent::DiagnosticsReceived(diagnostics, _bytes) => {
+                LOGGER.debug(&format!("Diagnostics received: {diagnostics:?}"));
+            }
+            MasterEvent::Error(msg) => {
+                LOGGER.error(&format!("Error: {msg}"));
+            }
+            MasterEvent::Shutdown => {
+                LOGGER.info("Received shutdown event, cleaning up...");
+                return Ok(false);
+            }
+        }
+        Ok(true)
+    }
+
+    pub async fn handle_listener(
+        &mut self,
+        res: io::Result<(TcpStream, SocketAddr)>
+    ) -> io::Result<()> {
+        let (mut stream, peer) = match res {
+            Ok(pair) => pair,
+            Err(e) => {
+                LOGGER.error(&format!("Failed to accept connection: {e}"));
+                return Err(Error::new(e.kind(), format!("Failed to accept connection: {e}")));
+            }
+        };
+
+        if self.max_connections != 0 && (self.connections.len() as u32) >= self.max_connections {
+            LOGGER.warning("Max connections reached, rejecting new connection");
+            let _ = io::AsyncWriteExt::shutdown(&mut stream).await;
+            return Ok(());
+        }
+
+        // Add connection
+        let connection = MasterClient::new(self.next_id, stream, self.event_tx.clone()).await?;
+        self.connections.insert(self.next_id, connection);
+
+        LOGGER.debug(&format!("Accepted connection from {peer}"));
+        let _ = self.event_tx.send(MasterEvent::Connected(self.next_id)).await;
+        self.next_id += 1;
+
+        Ok(())
+    }
+
+    /// Sends a message to a specific client ID.
+    pub async fn send_to(&self, id: &u64, bytes: Bytes) -> io::Result<()> {
+        if let Some(client) = self.connections.get(id) {
+            client.send(bytes).await?;
+        } else {
+            LOGGER.warning(&format!("Client {id} not found"));
+            return Err(Error::new(std::io::ErrorKind::NotFound, format!("Client {id} not found")));
+        }
+        Ok(())
+    }
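
The handshake code in this diff frames every payload as `[command][len][bytes…]` with integers in big-endian. A hedged sketch of building such a frame for `send_to` with the `bytes` crate; the command id and fields here are illustrative, not a real packet from `sustenet_shared::packets`:

```rust
use bytes::{ BufMut, Bytes, BytesMut };

/// Builds a `[command][len][name bytes][port]` frame, mirroring the
/// length-prefixed layout the cluster handshake uses.
fn frame(command: u8, name: &str, port: u16) -> Bytes {
    let mut buf = BytesMut::with_capacity(4 + name.len());
    buf.put_u8(command);
    buf.put_u8(name.len() as u8); // u8 length prefix caps names at 255 bytes
    buf.put_slice(name.as_bytes());
    buf.put_u16(port); // big-endian, like `to_be_bytes` elsewhere in the diff
    buf.freeze()
}
```
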
+    /// Sends a message to all connections.
+    pub async fn send_to_all(&self, bytes: Bytes) -> io::Result<()> {
+        // let mut handles = vec![];
+        // for i in 0..MAX_THREADS {
+        //     let map_clone = Arc::clone(&map);
+        //     let handle = thread::spawn(move || {
+        //         for j in (i * MAX_ITERS) / MAX_THREADS..((i + 1) * MAX_ITERS) / MAX_THREADS {
+        //             let _ = map_clone.get(&j);
+        //         }
+        //     });
+        //     handles.push(handle);
+        // }
+        // for handle in handles {
+        //     handle.join().unwrap();
+        // }
+
+        for client in self.connections.iter() {
+            let client = client.value();
+            client.send(bytes.clone()).await?;
+        }
+        Ok(())
+    }
+
+    /// Sends a message to all cluster servers.
+    pub async fn send_to_clusters(&self, bytes: Bytes) -> io::Result<()> {
+        for cluster in self.cluster_servers.values() {
+            if let Err(e) = self.send_to(&cluster.id, bytes.clone()).await {
+                LOGGER.error(&format!("Failed to send message to cluster {}: {e}", cluster.name));
+            }
+        }
+        Ok(())
+    }
+
+    pub async fn cleanup(&mut self) {
+        LOGGER.info("Shutting down master server, closing all connections...");
+        // Stop listening for new connections.
+        // TODO: This might be doing nothing...
+        if let Err(e) = self.event_tx.send(MasterEvent::Shutdown).await {
+            LOGGER.error(&format!("Failed to send shutdown event: {e}"));
+        }
+
+        // Clear all cluster connections and passphrases.
+        self.cluster_servers.clear();
+        self.cluster_passphrases.clear();
+
+        // Close all connections for shutdown.
+        let keys: Vec<u64> = self.connections
+            .iter()
+            .map(|entry| *entry.key())
+            .collect();
+        for id in keys {
+            if let Some((_, client)) = self.connections.remove(&id) {
+                if let Err(e) = client.close().await {
+                    LOGGER.error(&format!("Failed to close connection #{id}: {e}"));
+                }
+            }
+        }
+        LOGGER.cleanup();
+    }
+}
diff --git a/rust/master/src/master_client.rs b/rust/master/src/master_client.rs
new file mode 100644
index 0000000..eac357b
--- /dev/null
+++ b/rust/master/src/master_client.rs
@@ -0,0 +1,284 @@
+use sustenet_shared::lselect;
+use sustenet_shared::packets::{ ClusterSetup, Connection, Diagnostics };
+
+use std::io::{ Error, ErrorKind };
+
+use bytes::Bytes;
+use tokio::io;
+use tokio::io::{ AsyncReadExt, AsyncWriteExt };
+use tokio::net::TcpStream;
+use tokio::sync::mpsc;
+use tokio::sync::mpsc::error::SendError;
+
+use crate::master::{ LOGGER, MasterEvent };
+use crate::security;
+
+/// Handles connections that clients and cluster servers establish with the
+/// master server.
+pub struct MasterClient {
+    sender: mpsc::Sender<Bytes>,
+}
+
+impl MasterClient {
+    pub async fn new(
+        id: u64,
+        stream: TcpStream,
+        event_tx: mpsc::Sender<MasterEvent>
+    ) -> io::Result<Self> {
+        let (sender, receiver) = mpsc::channel::<Bytes>(16);
+        let connection = Self { sender };
+
+        if let Err(e) = Self::receive(id, stream, connection.sender.clone(), receiver, event_tx) {
+            LOGGER.error(&format!("Failed to start connection #{id}"));
+            return Err(Error::new(e.kind(), format!("Failed to start connection #{id}: {e}")));
+        }
+
+        Ok(connection)
+    }
+
+    /// Sends a message to the sender to close the connection.
+    ///
+    /// This should be called before getting rid of this MasterClient.
+    ///
+    /// It doesn't need to be called if the Disconnected event is triggered,
+    /// since that event only fires when the connection is closed by the
+    /// client, so it's already handled.
+    pub async fn close(&self) -> Result<(), SendError<Bytes>> {
+        self.sender.send(Bytes::new()).await
+    }
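
A sketch of the accept-side wiring `MasterServer::handle_listener` performs: every accepted socket becomes a `MasterClient` that reports back over the shared `MasterEvent` channel. Id `0` is illustrative; the real server assigns `next_id`.

```rust
use sustenet_master::{ MasterClient, master::MasterEvent };
use tokio::net::TcpListener;
use tokio::sync::mpsc;

async fn accept_one(
    listener: &TcpListener,
    event_tx: mpsc::Sender<MasterEvent>
) -> std::io::Result<MasterClient> {
    let (stream, peer) = listener.accept().await?;
    println!("accepted {peer}");
    MasterClient::new(0, stream, event_tx).await
}
```
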
+    /// Receives messages from clients and handles them.
+    ///
+    /// It also enables the MasterServer to send messages through this
+    /// struct's sender.
+    pub fn receive(
+        id: u64,
+        mut stream: TcpStream,
+        sender: mpsc::Sender<Bytes>,
+        mut receiver: mpsc::Receiver<Bytes>,
+        event_tx: mpsc::Sender<MasterEvent>
+    ) -> io::Result<()> {
+        tokio::spawn(async move {
+            let (reader, mut writer) = stream.split();
+            let mut reader = io::BufReader::new(reader);
+
+            lselect!(
+                // Handle local requests to send a message to the other side of the connection.
+                msg = receiver.recv() => {
+                    match msg {
+                        Some(msg) => {
+                            if msg.is_empty() {
+                                LOGGER.warning("Received empty message, shutting down connection");
+                                Self::handle_shutdown(writer, event_tx, id).await;
+                                break;
+                            }
+
+                            LOGGER.debug(&format!("Sending message: {:?}", msg));
+                            if let Err(e) = writer.write_all(&msg).await {
+                                let msg = format!("Failed to send message to server: {e}");
+                                LOGGER.error(&msg);
+                                let _ = event_tx.send(MasterEvent::Error(msg)).await;
+                            } else {
+                                // TODO: Still need to decide if we should notify about messages sent on a server.
+                                // let _ = event_tx.send(MasterEvent::MessageSent(msg)).await;
+                            }
+                        },
+                        None => {
+                            LOGGER.warning("Connection closed");
+                            Self::handle_shutdown(writer, event_tx, id).await;
+                            break;
+                        }
+                    }
+                },
+                command = reader.read_u8() => {
+                    match command {
+                        Ok(command) => {
+                            LOGGER.debug(&format!("Received command: {command}"));
+
+                            Self::handle_command(command, id, &sender, &mut reader, &mut writer, &event_tx).await;
+
+                            // Notify listeners about the received message.
+                            // TODO: Should we? I'm leaning more towards not notifying about commands.
+                            // It could ruin performance.
+                            // let _ = event_tx.send(MasterEvent::CommandReceived(command)).await;
+                        },
+                        Err(e) => {
+                            match e.kind() {
+                                ErrorKind::UnexpectedEof => {
+                                    LOGGER.warning(&format!("Connection #{id} closed by peer (EOF)"));
+                                    Self::handle_shutdown(writer, event_tx, id).await;
+                                    break;
+                                }
+                                ErrorKind::ConnectionReset => {
+                                    LOGGER.info(&format!("Connection #{id} reset by peer"));
+                                    Self::handle_shutdown(writer, event_tx, id).await;
+                                    break;
+                                }
+                                ErrorKind::ConnectionAborted => {
+                                    LOGGER.info(&format!("Connection #{id} aborted"));
+                                    Self::handle_shutdown(writer, event_tx, id).await;
+                                    break;
+                                }
+                                ErrorKind::TimedOut => {
+                                    LOGGER.warning(&format!("Connection #{id} timed out"));
+                                    Self::handle_shutdown(writer, event_tx, id).await;
+                                    break;
+                                }
+                                ErrorKind::BrokenPipe => {
+                                    LOGGER.info(&format!("Connection #{id} broken pipe"));
+                                    Self::handle_shutdown(writer, event_tx, id).await;
+                                    break;
+                                }
+                                ErrorKind::NotConnected => {
+                                    LOGGER.info(&format!("Connection #{id} not connected"));
+                                    Self::handle_shutdown(writer, event_tx, id).await;
+                                    break;
+                                }
+                                _ => {
+                                    let msg = format!("Failed to read command for connection #{}: {e}", id);
+                                    LOGGER.error(&msg);
+                                    let _ = event_tx.send(MasterEvent::Error(msg)).await;
+                                    Self::handle_shutdown(writer, event_tx, id).await;
+                                    break;
+                                }
+                            }
+                        }
+                    }
+                }
+            );
+        });
+
+        Ok(())
+    }
+
+    pub async fn write(writer: &mut tokio::net::tcp::WriteHalf<'_>, bytes: Bytes) -> io::Result<()> {
+        if let Err(e) = writer.write_all(&bytes).await {
+            LOGGER.error(&format!("Failed to write to stream: {e}"));
+            return Err(Error::new(ErrorKind::Other, format!("Failed to write to stream: {e}")));
+        }
+        Ok(())
+    }
+
+    /// An external method to allow the master server to send messages to the client.
+    /// An external method to allow the master server to send messages to the client.
+    pub async fn send(&self, bytes: Bytes) -> io::Result<()> {
+        if let Err(e) = self.sender.send(bytes).await {
+            LOGGER.error(&format!("Failed to send message to client: {e}"));
+            return Err(
+                Error::new(ErrorKind::Other, format!("Failed to send message to client: {e}"))
+            );
+        }
+        Ok(())
+    }
+
+    async fn handle_shutdown(
+        mut writer: tokio::net::tcp::WriteHalf<'_>,
+        event_tx: mpsc::Sender<MasterEvent>,
+        id: u64
+    ) {
+        if let Err(e) = writer.shutdown().await {
+            let msg = format!("Failed to shutdown writer: {e}");
+            LOGGER.error(&msg);
+            let _ = event_tx.send(MasterEvent::Error(msg)).await;
+        }
+        let _ = event_tx.send(MasterEvent::Disconnected(id)).await;
+    }
+
+    async fn handle_command(
+        command: u8,
+        id: u64,
+        _sender: &mpsc::Sender<Bytes>,
+        reader: &mut io::BufReader<tokio::net::tcp::ReadHalf<'_>>,
+        writer: &mut tokio::net::tcp::WriteHalf<'_>,
+        event_tx: &mpsc::Sender<MasterEvent>
+    ) {
+        // TODO: Handle the remaining commands received from the client.
+        match command {
+            x if x == (Connection::Connect as u8) => {
+                LOGGER.info("Handling Connection Connect");
+            }
+            x if x == (Connection::Disconnect as u8) => {
+                LOGGER.info("Handling Connection Disconnect");
+            }
+
+            x if x == (Diagnostics::CheckServerType as u8) => {
+                LOGGER.info("Handling Diagnostics Check Server Type");
+            }
+            x if x == (Diagnostics::CheckServerUptime as u8) => {
+                LOGGER.info("Handling Diagnostics Check Server Uptime");
+            }
+            x if x == (Diagnostics::CheckServerPlayerCount as u8) => {
+                LOGGER.info("Handling Diagnostics Check Server Player Count");
+            }
+
+            x if x == (ClusterSetup::Init as u8) => {
+                // Read the length of the key name.
+                let len = (match reader.read_u8().await {
+                    Ok(len) => len,
+                    Err(e) => {
+                        LOGGER.error(&format!("Failed to read cluster key name length: {:?}", e));
+                        return;
+                    }
+                }) as usize;
+
+                // Read the key name.
+                let mut key_name = vec![0u8; len];
+                if let Err(e) = reader.read_exact(&mut key_name).await {
+                    LOGGER.error(&format!("Failed to read cluster key name: {:?}", e));
+                    return;
+                };
+
+                // Load the key.
+                let key_name = match String::from_utf8(key_name) {
+                    Ok(key_name) => key_name,
+                    Err(e) => {
+                        LOGGER.error(&format!("Failed to convert cluster key name to String: {:?}", e));
+                        return;
+                    }
+                };
+                let key = match security::AES_KEYS.get(&key_name) {
+                    Some(key) => key,
+                    None => {
+                        LOGGER.error(&format!("Key {key_name} doesn't exist."));
+                        return;
+                    }
+                };
+
+                // Generate the secret.
+                let mut data = vec![ClusterSetup::Init as u8];
+                let passphrase = match security::generate_passphrase() {
+                    Ok(passphrase) => passphrase,
+                    Err(e) => {
+                        LOGGER.error(&format!("Failed to generate passphrase: {:?}", e));
+                        return;
+                    }
+                };
+                let encrypted_passphrase = sustenet_shared::security::aes::encrypt(&passphrase, key);
+                data.push(encrypted_passphrase.len() as u8);
+                data.extend_from_slice(&encrypted_passphrase);
+
+                // Tell the MasterServer to store the passphrase for this ID.
+                if let Err(e) = event_tx.send(MasterEvent::ClusterInit(id, passphrase)).await {
+                    LOGGER.error(&format!("Failed to send passphrase: {:?}", e));
+                    return;
+                }
+
+                // Send the encrypted passphrase to the would-be cluster.
+                if let Err(e) = writer.write_all(&data).await {
+                    LOGGER.error(&format!("Failed to send passphrase: {:?}", e));
+                    return;
+                }
+            }
+            x if x == (ClusterSetup::AnswerSecret as u8) => {
+                LOGGER.info("Handling Cluster Setup Answer Secret");
+            }
+
+            _ => {
+                let msg = format!("Unknown command received: {command}");
+                LOGGER.error(&msg);
+                let _ = event_tx.send(MasterEvent::Error(msg)).await;
+            }
+        }
+    }
+}
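Worth pinning down the wire format this handler expects. The doc in packets.rs describes Init as CMD + Len&VersionNumber + Len&Key Name, but the code above only reads one length byte followed by the key name, so a conforming cluster today would send just that. A hedged client-side sketch (the helper name `send_cluster_init` is illustrative, not from the patch):

```rust
use tokio::io::AsyncWriteExt;
use tokio::net::TcpStream;

/// Illustrative helper: frames ClusterSetup::Init as CMD + len + key name,
/// matching what handle_command reads back on the master side.
async fn send_cluster_init(stream: &mut TcpStream, key_name: &str) -> std::io::Result<()> {
    const CLUSTER_SETUP_INIT: u8 = 245; // ClusterSetup::Init's discriminant.
    let name = key_name.as_bytes();
    assert!(name.len() <= u8::MAX as usize, "key name must fit a u8 length prefix");

    let mut packet = Vec::with_capacity(2 + name.len());
    packet.push(CLUSTER_SETUP_INIT);
    packet.push(name.len() as u8);
    packet.extend_from_slice(name);
    stream.write_all(&packet).await
}
```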
diff --git a/rust/master/src/security.rs b/rust/master/src/security.rs
index 72b106b..cf77a98 100644
--- a/rust/master/src/security.rs
+++ b/rust/master/src/security.rs
@@ -1,29 +1,27 @@
-use sustenet_shared as shared;
+use std::sync::LazyLock;
 
-use shared::security::aes::{ load_all_keys, KeyMap };
+use sustenet_shared::security::aes::{ load_all_keys, KeyMap };
 
-lazy_static::lazy_static! {
-    pub static ref AES_KEYS: KeyMap = match load_all_keys() {
-        Ok(keys) => keys,
-        Err(e) => {
-            println!("Failed to load keys: {:?}", e);
-            KeyMap::new()
-        }
-    };
-}
+pub static AES_KEYS: LazyLock<KeyMap> = LazyLock::new(|| match load_all_keys() {
+    Ok(keys) => keys,
+    Err(e) => {
+        println!("Failed to load keys: {:?}", e);
+        KeyMap::new()
+    }
+});
 
 const PASSWORD_LEN: usize = 20;
 
-pub fn generate_passphrase() -> [u8; PASSWORD_LEN] {
+pub fn generate_passphrase() -> Result<[u8; PASSWORD_LEN], getrandom::Error> {
     const CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ\
                             abcdefghijklmnopqrstuvwxyz\
                             0123456789)(*&^%$#@!~";
     let mut password = [0u8; PASSWORD_LEN];
-    getrandom::fill(&mut password).expect("Failed to generate password.");
+    getrandom::fill(&mut password)?;
 
     for byte in &mut password {
        *byte = CHARSET[(*byte as usize) % CHARSET.len()];
     }
 
-    password
+    Ok(password)
 }
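Minor caveat on generate_passphrase: reducing random bytes with `%` over a 73-character set carries a slight modulo bias, since 256 is not a multiple of 73. It's unlikely to matter for a 20-byte shared secret, but a bias-free variant using rejection sampling would look like this (a sketch, not part of the patch):

```rust
/// Illustrative only: rejection sampling removes the modulo bias that
/// `byte % CHARSET.len()` introduces whenever 256 % CHARSET.len() != 0.
fn unbiased_char(charset: &[u8]) -> Result<u8, getrandom::Error> {
    // Largest multiple of charset.len() that fits in the u8 range.
    let limit = (256 / charset.len()) * charset.len();
    loop {
        let mut b = [0u8; 1];
        getrandom::fill(&mut b)?;
        if (b[0] as usize) < limit {
            return Ok(charset[(b[0] as usize) % charset.len()]);
        }
    }
}
```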
+"# + ); + std::fs::write(config_path, default_config)?; + } + + Ok(()) +} + +pub fn create_data_dir() -> std::io::Result<()> { + if std::fs::DirBuilder::new().recursive(true).create("data").is_err() { + return Err( + std::io::Error::new(std::io::ErrorKind::Other, "Failed to create the 'data' directory.") + ); + } + + Ok(()) +} + pub mod master { use config::{ Config, File, FileFormat::Toml }; - use crate::utils::constants::MASTER_PORT; + use crate::utils::constants::{ DEFAULT_IP, MASTER_PORT }; + /// The settings for a master server. pub struct Settings { - pub server_name: String, - pub max_connections: u32, pub port: u16, + pub bind: String, } + /// Reads the Master configuration from the Config.toml file. pub fn read() -> Settings { + super::init().expect("Failed to initialize the configuration file."); let settings = Config::builder() - .add_source(File::new("Config.toml", Toml)) + .add_source(File::new("data/Config.toml", Toml)) .build() .expect("Failed to read the configuration file."); Settings { - server_name: settings - .get::("all.server_name") - .unwrap_or("Master Server".to_string()), - max_connections: settings.get::("all.max_connections").unwrap_or(0), port: match settings.get::("all.port") { Ok(port) => @@ -30,6 +71,7 @@ pub mod master { } Err(_) => MASTER_PORT, }, + bind: settings.get::("all.bind").unwrap_or(DEFAULT_IP.to_string()), } } } @@ -39,11 +81,13 @@ pub mod cluster { use crate::utils::constants::{ CLUSTER_PORT, DEFAULT_IP, MASTER_PORT }; + /// The settings for a cluster server. pub struct Settings { - pub server_name: String, + pub name: String, pub max_connections: u32, pub port: u16, + pub bind: String, pub key_name: String, pub master_ip: String, @@ -52,19 +96,19 @@ pub mod cluster { pub domain_pub_key: Option, } + /// Reads the Cluster configuration from the Config.toml file. pub fn read() -> Settings { + super::init().expect("Failed to initialize the configuration file."); let settings = Config::builder() - .add_source(File::new("Config.toml", Toml)) + .add_source(File::new("data/Config.toml", Toml)) .build() .expect("Failed to read the configuration file."); Settings { - server_name: settings - .get::("all.server_name") - .unwrap_or("Cluster Server".to_string()), + name: settings.get::("cluster.name").unwrap_or("Cluster Server".to_string()), - max_connections: settings.get::("max_connections").unwrap_or(0), - port: match settings.get::("all.port") { + max_connections: settings.get::("cluster.max_connections").unwrap_or(0), + port: match settings.get::("cluster.port") { Ok(port) => match port { 0 => CLUSTER_PORT, @@ -72,6 +116,7 @@ pub mod cluster { } Err(_) => CLUSTER_PORT, }, + bind: settings.get::("cluster.bind").unwrap_or(DEFAULT_IP.to_string()), key_name: settings .get::("cluster.key_name") @@ -79,7 +124,7 @@ pub mod cluster { master_ip: settings .get::("cluster.master_ip") .unwrap_or(DEFAULT_IP.to_string()), - master_port: match settings.get::("cluster.master_port") { + master_port: match settings.get::("master.port") { Ok(port) => match port { 0 => MASTER_PORT, diff --git a/rust/shared/src/lib.rs b/rust/shared/src/lib.rs index 2364a7b..4132479 100644 --- a/rust/shared/src/lib.rs +++ b/rust/shared/src/lib.rs @@ -1,4 +1,9 @@ -use tokio::sync::mpsc::Sender; +//! This crate contains the shared code for the master server, cluster server, and client. 
diff --git a/rust/shared/src/lib.rs b/rust/shared/src/lib.rs
index 2364a7b..4132479 100644
--- a/rust/shared/src/lib.rs
+++ b/rust/shared/src/lib.rs
@@ -1,4 +1,9 @@
-use tokio::sync::mpsc::Sender;
+//! This crate contains the shared code for the master server, cluster server, and client.
+
+use std::{ future::Future, pin::Pin };
+
+use bytes::Bytes;
+use tokio::{ io::BufReader, net::tcp::ReadHalf, sync::mpsc::Sender };
 
 pub mod config;
 pub mod logging;
@@ -10,36 +15,32 @@
 pub mod security;
 
 pub mod macros;
 
+pub type SenderBytes = Sender<Bytes>;
+pub type PluginPin<'plug> = Pin<Box<dyn Future<Output = ()> + Send + 'plug>>;
+
 pub trait ServerPlugin: Send + Sync {
-    fn set_sender(&self, tx: Sender<Vec<u8>>);
+    fn set_sender(&self, tx: SenderBytes);
 
     fn receive<'plug>(
         &self,
-        tx: Sender<Vec<u8>>,
+        tx: SenderBytes,
         command: u8,
-        reader: &'plug mut tokio::io::BufReader<tokio::net::tcp::ReadHalf<'plug>>
-    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send + 'plug>>;
+        reader: &'plug mut BufReader<ReadHalf<'plug>>
+    ) -> PluginPin<'plug>;
 
     /// Only used when debugging is enabled.
     fn info(&self, message: &str);
 }
 
 pub trait ClientPlugin: Send + Sync {
-    fn set_sender(&self, tx: Sender<Vec<u8>>);
-
-    fn receive_master<'plug>(
-        &self,
-        tx: Sender<Vec<u8>>,
-        command: u8,
-        reader: &'plug mut tokio::io::BufReader<tokio::net::tcp::ReadHalf<'plug>>
-    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send + 'plug>>;
-
-    fn receive_cluster<'plug>(
+    fn set_sender(&self, tx: SenderBytes);
+
+    fn receive<'plug>(
         &self,
-        tx: Sender<Vec<u8>>,
+        tx: SenderBytes,
         command: u8,
-        reader: &'plug mut tokio::io::BufReader<tokio::net::tcp::ReadHalf<'plug>>
-    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send + 'plug>>;
+        reader: &'plug mut BufReader<ReadHalf<'plug>>
+    ) -> PluginPin<'plug>;
 
     /// Only used when debugging is enabled.
     fn info(&self, message: &str);
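Since both traits now return PluginPin<'plug>, the practical way to implement them is `Box::pin(async move { ... })`. A minimal ServerPlugin implementation against these definitions (the `EchoPlugin` name is illustrative):

```rust
use sustenet_shared::{ PluginPin, SenderBytes, ServerPlugin };
use tokio::io::BufReader;
use tokio::net::tcp::ReadHalf;

/// Illustrative only: the smallest plugin that satisfies ServerPlugin.
struct EchoPlugin;

impl ServerPlugin for EchoPlugin {
    fn set_sender(&self, _tx: SenderBytes) {}

    fn receive<'plug>(
        &self,
        _tx: SenderBytes,
        command: u8,
        _reader: &'plug mut BufReader<ReadHalf<'plug>>
    ) -> PluginPin<'plug> {
        // Box::pin turns the async block into the Pin<Box<dyn Future>> the trait expects.
        Box::pin(async move {
            println!("plugin saw command {command}");
        })
    }

    fn info(&self, message: &str) {
        println!("[EchoPlugin] {message}");
    }
}
```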
diff --git a/rust/shared/src/logging.rs b/rust/shared/src/logging.rs
index 642e021..8ac9119 100644
--- a/rust/shared/src/logging.rs
+++ b/rust/shared/src/logging.rs
@@ -46,76 +46,103 @@
             LogType::System => "[System]",
         };
 
-        println!("{}{} {}{}", level_str, type_str, format!($($arg)*), TERMINAL_DEFAULT);
+        format!("{}{} {}{}", level_str, type_str, format!($($arg)*), TERMINAL_DEFAULT)
     }
 };
 }
 
-use crate::{ log_message, utils::constants::DEBUGGING };
+use crate::{ log_message, utils::constants::DEBUGGING, utils::constants::PERFORMANCE };
 
+/// Logger struct to handle logging messages with different log levels and types.
 pub struct Logger {
     plugin_info: std::sync::OnceLock<Box<dyn Fn(&str) + Send + Sync>>,
     log_type: LogType,
+    task: tokio::task::JoinHandle<()>,
+    sender: tokio::sync::mpsc::Sender<String>,
 }
 
 impl Logger {
+    /// Creates a new Logger instance with the specified log type.
     pub fn new(log_type: LogType) -> Self {
+        let (sender, mut receiver) = tokio::sync::mpsc::channel::<String>(100_000); // Bounded to 100k messages
+
+        let task = tokio::spawn(async move {
+            while let Some(msg) = receiver.recv().await {
+                println!("{msg}");
+            }
+        });
+
         Logger {
             plugin_info: std::sync::OnceLock::new(),
             log_type,
+            task,
+            sender,
         }
     }
 
-    pub fn set_plugin<F>(&self, plugin: F) where F: Fn(&str) + Send + Sync + 'static {
+    pub fn cleanup(&self) {
+        // Abort the background logging task; any still-queued messages are dropped.
+        self.task.abort();
+        println!("Logger successfully cleaned up.");
+    }
+
+    /// Sets the plugin info function to be called when logging messages.
+    pub fn set_plugin_info<F>(&self, plugin: F) where F: Fn(&str) + Send + Sync + 'static {
         let _ = self.plugin_info.set(Box::new(plugin));
     }
 
+    /// Logs a debug message if debugging is enabled.
     pub fn debug(&self, message: &str) {
-        if !DEBUGGING {
+        if !DEBUGGING || PERFORMANCE {
             return;
         }
         if let Some(plugin_info) = self.plugin_info.get() {
             plugin_info(message);
         }
-        log_message!(LogLevel::Debug, self.log_type, "{}", message);
+        let _ = self.sender.try_send(log_message!(LogLevel::Debug, self.log_type, "{}", message));
     }
 
+    /// Logs an info message.
     pub fn info(&self, message: &str) {
-        if !DEBUGGING {
+        if PERFORMANCE {
             return;
         }
         if let Some(plugin_info) = self.plugin_info.get() {
             plugin_info(message);
         }
-        log_message!(LogLevel::Info, self.log_type, "{}", message);
+        let _ = self.sender.try_send(log_message!(LogLevel::Info, self.log_type, "{}", message));
    }
 
+    /// Logs a warning message if debugging is enabled.
     pub fn warning(&self, message: &str) {
-        if !DEBUGGING {
+        if !DEBUGGING || PERFORMANCE {
             return;
         }
         if let Some(plugin_info) = self.plugin_info.get() {
             plugin_info(message);
         }
-        log_message!(LogLevel::Warning, self.log_type, "{}", message);
+        let _ = self.sender.try_send(log_message!(LogLevel::Warning, self.log_type, "{}", message));
     }
 
+    /// Logs an error message.
     pub fn error(&self, message: &str) {
-        if !DEBUGGING {
+        if PERFORMANCE {
             return;
         }
         if let Some(plugin_info) = self.plugin_info.get() {
             plugin_info(message);
         }
-        log_message!(LogLevel::Error, self.log_type, "{}", message);
+        let _ = self.sender.try_send(log_message!(LogLevel::Error, self.log_type, "{}", message));
     }
 
+    /// Logs a success message.
     pub fn success(&self, message: &str) {
-        if !DEBUGGING {
+        if PERFORMANCE {
             return;
         }
         if let Some(plugin_info) = self.plugin_info.get() {
             plugin_info(message);
         }
-        log_message!(LogLevel::Success, self.log_type, "{}", message);
+        let _ = self.sender.try_send(log_message!(LogLevel::Success, self.log_type, "{}", message));
     }
 }
\ No newline at end of file
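One caveat the channel-backed Logger introduces: Logger::new calls tokio::spawn, so the first touch of a lazily initialized logger has to happen on a runtime, and cleanup() aborts the print task, which can drop still-queued messages. A hedged sketch of safe construction:

```rust
use sustenet_shared::logging::{ LogType, Logger };

fn main() {
    // Logger::new calls tokio::spawn, which panics outside a runtime,
    // so the runtime has to exist before the logger does.
    let rt = tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .expect("failed to build runtime");
    rt.block_on(async {
        let logger = Logger::new(LogType::System);
        logger.info("logger lives on the runtime that spawned its task");
        // cleanup() aborts the print task; anything still queued is lost.
        logger.cleanup();
    });
}
```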
diff --git a/rust/shared/src/network.rs b/rust/shared/src/network.rs
index 85b1dcd..cd92d65 100644
--- a/rust/shared/src/network.rs
+++ b/rust/shared/src/network.rs
@@ -1,3 +1,5 @@
+use std::net::IpAddr;
+
 pub enum Protocols {
     TCP,
     UDP,
@@ -10,33 +12,12 @@
 pub enum Event {
     ReceivedData(u32, Vec<u8>),
 }
 
-#[derive(Eq)]
+/// Used to store cluster information that we can reuse.
 pub struct ClusterInfo {
-    pub id: u32,
+    pub id: u64,
     pub name: String,
-    pub ip: String,
+    pub ip: IpAddr,
     pub port: u16,
     pub max_connections: u32,
-}
-
-impl Ord for ClusterInfo {
-    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
-        // Define how to compare two ClusterInfo instances
-        // For example, if ClusterInfo has a field `id` of type i32:
-        self.id.cmp(&other.id)
-    }
-}
-
-impl PartialOrd for ClusterInfo {
-    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
-        Some(self.cmp(other))
-    }
-}
-
-impl PartialEq for ClusterInfo {
-    fn eq(&self, other: &Self) -> bool {
-        // Define when two ClusterInfo instances are equal
-        // For example, if ClusterInfo has a field `id` of type i32:
-        self.id == other.id
-    }
+    pub start_time: u32,
 }
\ No newline at end of file
diff --git a/rust/shared/src/packets.rs b/rust/shared/src/packets.rs
index 91616e3..ed3833e 100644
--- a/rust/shared/src/packets.rs
+++ b/rust/shared/src/packets.rs
@@ -1,73 +1,218 @@
-pub mod master {
-    #[repr(u8)]
-    pub enum FromUnknown {
-        /// Sends a list of names and IPs to whoever requested it.
-        RequestClusters,
-        /// Just a way to gracefully disconnect the client.
-        JoinCluster,
-
-        /// They send the name of the cluster's key to the Master Server.
-        /// If the key doesn't exist, the server will do nothing but
-        /// stay silent. If it does exist, it will send a generated
-        /// passphrase that's encrypted with AES.
-        BecomeCluster,
-        /// When they send the decrypted key back to the Master Server.
-        AnswerCluster,
-    }
-    #[repr(u8)]
-    pub enum ToUnknown {
-        /// Sends a list of cluster servers containing their name, ip, and port.
-        SendClusters,
-
-        /// Generates a passphrase that's encrypted with AES and sends
-        /// it waiting for it to be sent back. It's stored in their name.
-        VerifyCluster,
-        /// Once validated, the cluster is moved to the cluster list and
-        /// notifies them that they're now a cluster.
-        CreateCluster,
-
-        // Cluster things go here.
-    }
+#[repr(u8)]
+/// Handles packets related to messaging and chat.
+pub enum Messaging {
+    /// Send a message to the server.
+    SendGlobalMessage = 200,
+    /// Send a message to a specific player.
+    SendPrivateMessage,
+    /// Send a message to the party.
+    SendPartyMessage,
+    /// Send a local message.
+    SendLocalMessage,
+}
+
+#[repr(u8)]
+pub enum Connection {
+    /// The client is requesting to connect to the server.
+    ///
+    /// 1. From Client to Server: CMD + Len&VersionNumber
+    ///
+    /// TODO: Run the check version function.
+    Connect = 240,
+    /// The client is disconnecting from the server.
+    /// This can be sent with a DisconnectReason but doesn't have to be.
+    Disconnect,
+    /// Authenticate the client with the server.
+    Authenticate,
+    /// Requests a list of clusters from the server.
+    /// 1. From Client to Server: CMD
+    /// 2. From Server to Client: CMD + ListVersionNumber + HashMapLen + ID + Len&Name + IP + Port + Max Connections + Start Time
+    RequestClusters,
+    /// Sends a list of clusters to the client, but this time it's ID + Player count.
+    /// We do this because the name, IP, port, and other information is generally
+    /// static.
+    /// We can prepend a version number at the start of the packet
+    /// to ensure the client is up to date. It increments every time a
+    /// major change is made to the cluster server list, like when a
+    /// new cluster is added, removed, or updated.
+    /// 1. From Server to Client: CMD + VersionNumber + ID + PlayerCount
+    /// If the version number doesn't match, then RequestClusters is sent on the next request.
+    RefreshClusters,
+}
+/// These are the reasons a connection can be closed.
+#[repr(u8)]
+pub enum DisconnectReason {
+    /// The client disconnected gracefully.
+    Graceful,
+    /// The server is full.
+    Full,
+    /// Server is shutting down.
+    Shutdown,
+    /// The client disconnected due to an error.
+    Error,
 }
 
-pub mod cluster {
-    pub enum FromClient {
-        /// Sends a list of names and IPs to whoever requested it.
-        RequestClusters,
-        /// Gracefully disconnect the client and connects to the new cluster.
-        JoinCluster,
-        /// Gracefully disconnect the client.
-        LeaveCluster,
-
-        /// Only works if the server doesn't have a domain config. Sends the pub key.
-        /// Ran if the cache key doesn't match.
-        RequestKey,
-        /// Encrypts the password and sends it.
-        SendPassword,
-
-        /// Moves the player's position.
-        Move
+#[repr(u8)]
+/// Used to set up the cluster server.
+pub enum ClusterSetup {
+    /// They send the name of the cluster's key to the Master Server.
+    /// If the key doesn't exist, the server will do nothing but
+    /// stay silent. If it does exist, it will send a generated
+    /// passphrase that's encrypted with AES.
+    ///
+    /// 1. From Cluster to Master: CMD + Len&VersionNumber + Len&Key Name
+    /// 2. From Master to Cluster: CMD + Encrypted Passphrase
+    /// TODO: Then you need to temporarily store them in a DashMap outside of clusters.
+    ///
+    /// If the key doesn't exist, do nothing.
+    Init = 245,
+    /// When they send the decrypted key back to the Master Server.
+    ///
+    /// 1. From Cluster to Master: CMD + Decrypted Passphrase + IP + Port + Max connections + Len&Name
+    /// 2. From Master to Cluster: CMD
+    ///
+    /// If it fails, say nothing.
+    AnswerSecret,
+}
+
+#[repr(u8)]
+#[derive(Debug, Clone)]
+/// Used to request information about the server.
+pub enum Diagnostics {
+    /// Requests information about a server's type.
+    CheckServerType = 250,
+    /// Requests information about a server's uptime.
+    CheckServerUptime,
+    /// Requests information about how many players are connected to a server.
+    CheckServerPlayerCount,
+    /// Requests information about all players across all servers.
+    CheckTotalPlayerCount,
+}
+
+#[cfg(test)]
+pub mod tests {
+    use crate::packets::{ ClusterSetup, Connection, Diagnostics, Messaging };
+
+    #[test]
+    fn test_enum_size() {
+        assert_eq!(std::mem::size_of::<Messaging>(), 1);
+        assert_eq!(std::mem::size_of::<Connection>(), 1);
+        assert_eq!(std::mem::size_of::<ClusterSetup>(), 1);
+        assert_eq!(std::mem::size_of::<Diagnostics>(), 1);
+    }
+
+    #[test]
+    /// This test checks that all enum values are unique.
+    /// The values should range between 200 and 255.
+    /// That's the range for reserved commands for Sustenet.
+    fn test_enum_unique_values() {
+        use std::collections::HashSet;
+
+        macro_rules! enum_values {
+            ($enum:ty, [$($variant:path),* $(,)?]) => {
+                vec![$($variant as u8),*]
+            };
+        }
+
+        let all_enums = [
+            enum_values!(Messaging, [
+                Messaging::SendGlobalMessage,
+                Messaging::SendPrivateMessage,
+                Messaging::SendPartyMessage,
+                Messaging::SendLocalMessage,
+            ]),
+            enum_values!(Connection, [
+                Connection::Connect,
+                Connection::Disconnect,
+                Connection::Authenticate,
+                Connection::RequestClusters,
+                Connection::RefreshClusters,
+            ]),
+            enum_values!(ClusterSetup, [ClusterSetup::Init, ClusterSetup::AnswerSecret]),
+            enum_values!(Diagnostics, [
+                Diagnostics::CheckServerType,
+                Diagnostics::CheckServerUptime,
+                Diagnostics::CheckServerPlayerCount,
+                Diagnostics::CheckTotalPlayerCount,
+            ]),
+        ].concat();
+
+        let mut set = HashSet::new();
+        for val in all_enums {
+            assert!(set.insert(val), "Duplicate value found: {val}");
+        }
+    }
+}
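Because these are plain #[repr(u8)] enums without TryFrom impls, the handlers elsewhere match with `x if x == (Variant as u8)` guards. An alternative that keeps the match arms flat is binding the discriminants to consts, which are valid patterns; a sketch, not part of the patch:

```rust
use sustenet_shared::packets::{ ClusterSetup, Connection };

// Illustrative: naming the discriminants lets a dispatcher match on them
// directly instead of chaining `x if x == ...` guards.
const CONNECT: u8 = Connection::Connect as u8;
const DISCONNECT: u8 = Connection::Disconnect as u8;
const CLUSTER_INIT: u8 = ClusterSetup::Init as u8;

fn dispatch(command: u8) {
    match command {
        CONNECT => println!("handle connect"),
        DISCONNECT => println!("handle disconnect"),
        CLUSTER_INIT => println!("handle cluster init"),
        other => println!("unknown command {other}"),
    }
}
```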
+
+// pub mod master {
+//     #[repr(u8)]
+//     pub enum FromUnknown {
+//         /// Sends a list of names and IPs to whoever requested it.
+//         RequestClusters,
+//         /// Just a way to gracefully disconnect the client.
+//         JoinCluster,
+
+//         /// They send the name of the cluster's key to the Master Server.
+//         /// If the key doesn't exist, the server will do nothing but
+//         /// stay silent. If it does exist, it will send a generated
+//         /// passphrase that's encrypted with AES.
+//         BecomeCluster,
+//         /// When they send the decrypted key back to the Master Server.
+//         AnswerCluster,
+//     }
+//     #[repr(u8)]
+//     pub enum ToUnknown {
+//         /// Sends a list of cluster servers containing their name, ip, and port.
+//         SendClusters,
+
+//         /// Generates a passphrase that's encrypted with AES and sends
+//         /// it waiting for it to be sent back. It's stored in their name.
+//         VerifyCluster,
+//         /// Once validated, the cluster is moved to the cluster list and
+//         /// notifies them that they're now a cluster.
+//         CreateCluster,
+//         // Cluster things go here.
+//     }
+// }
+
+// pub mod cluster {
+//     pub enum FromClient {
+//         /// Sends a list of names and IPs to whoever requested it.
+//         RequestClusters,
+//         /// Gracefully disconnect the client and connects to the new cluster.
+//         JoinCluster,
+//         /// Gracefully disconnect the client.
+//         LeaveCluster,
+
+//         /// Only works if the server doesn't have a domain config. Sends the pub key.
+//         /// Ran if the cache key doesn't match.
+//         RequestKey,
+//         /// Encrypts the password and sends it.
+//         SendPassword,
+
+//         /// Moves the player's position.
+//         Move,
+//     }
+
+//     pub enum ToClient {
+//         /// Sends a list of cluster servers containing their name, ip, and port.
+//         SendClusters,
+//         /// Disconnects the client from the cluster.
+//         DisconnectCluster,
+//         /// Disconnects the client from the cluster.
+//         LeaveCluster,
+
+//         // Sends the cached version of the key.
+//         VersionOfKey,
+//         /// Sends the public key to the client. This only works if "domain_pub_key"
+//         /// is not set in the Config.
+//         SendPubKey,
+//         /// This sends back the status to the user. It'll have a status code.
+//         /// 20 = 200, 44 = 404, 40 = 400, 50 = 500.
+//         /// If 20, it will send the user their ID (this was assigned on initial connection).
+//         Authenticate,
+
+//         /// Sends the player's new position.
+//         Move,
+//     }
+// }
diff --git a/rust/shared/src/security.rs b/rust/shared/src/security.rs
index ddce975..8b11508 100644
--- a/rust/shared/src/security.rs
+++ b/rust/shared/src/security.rs
@@ -25,7 +25,7 @@
     };
 
     pub fn create_keys_dir() -> std::io::Result<()> {
-        if std::fs::DirBuilder::new().recursive(true).create("keys").is_err() {
+        if std::fs::DirBuilder::new().recursive(true).create("data/keys").is_err() {
             return Err(
                 std::io::Error::new(
                     std::io::ErrorKind::Other,
@@ -42,13 +42,14 @@
     }
 
     pub fn save_key(name: &str, key: Key<Aes256Gcm>) -> std::io::Result<()> {
-        let mut file = File::create(format!("keys/{name}"))?;
+        create_keys_dir()?;
+        let mut file = File::create(format!("data/keys/{name}"))?;
         file.write_all(key.as_slice())?;
         Ok(())
     }
 
     pub fn load_key(name: &str) -> std::io::Result<Key<Aes256Gcm>> {
-        let mut file = match File::open(format!("keys/{name}")) {
+        let mut file = match File::open(format!("data/keys/{name}")) {
             Ok(file) => file,
             Err(_) => {
                 return Err(
@@ -83,7 +84,7 @@
 
     pub fn load_all_keys() -> std::io::Result<HashMap<String, Key<Aes256Gcm>>> {
         let mut keys = HashMap::new();
-        let entries = match std::fs::read_dir("keys") {
+        let entries = match std::fs::read_dir("data/keys") {
             Ok(entries) => entries,
             Err(_) => {
                 return Err(
@@ -142,7 +143,7 @@
     pub fn test_create_keys_dir() {
         match create_keys_dir() {
             Ok(_) =>
-                assert!(std::path::Path::new("keys").exists(), "Keys directory does not exist."),
+                assert!(std::path::Path::new("data/keys").exists(), "Keys directory does not exist."),
             Err(e) => {
                 panic!("Failed to create keys directory: {:?}", e);
             }
@@ -223,10 +224,6 @@
 
     #[test]
     pub fn test_load_all_keys() {
-        if let Err(e) = create_keys_dir() {
-            panic!("Failed to create keys directory: {:?}", e);
-        }
-
         match save_key("cluster_key_testrunner4", generate_key()) {
             Ok(_) => {}
             Err(e) => {
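With save_key now creating data/keys/ on demand, provisioning a shared cluster key is a two-call round trip. A usage sketch, assuming these helpers are public as the tests above suggest:

```rust
use sustenet_shared::security::aes::{ generate_key, load_key, save_key };

fn main() -> std::io::Result<()> {
    // save_key creates data/keys/ itself, so this round trip is all it
    // takes to provision the key named in the [cluster] config section.
    save_key("cluster_key", generate_key())?;
    let key = load_key("cluster_key")?;
    println!("loaded {} key bytes from data/keys/cluster_key", key.as_slice().len());
    Ok(())
}
```

The same file has to exist on both the master and the cluster for the Init handshake to succeed.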
-pub fn shutdown_channel() -> Result<tokio::sync::broadcast::Receiver<bool>, ctrlc::Error> {
-    let (tx, rx) = tokio::sync::broadcast::channel::<bool>(1);
-
-    // Handle shutdowns gracefully.
-    ctrlc
-        ::set_handler(move || {
-            tx.send(true).unwrap();
-        })
-        .expect("Error setting Ctrl-C handler");
-
-    Ok(rx)
-}
-
 pub mod constants {
     pub const VERSION: &str = "0.1.4";
     pub const DEBUGGING: bool = cfg!(debug_assertions);
 
+    #[cfg(feature = "perf")]
+    pub(crate) const PERFORMANCE: bool = true;
+
+    #[cfg(not(feature = "perf"))]
+    pub(crate) const PERFORMANCE: bool = false;
+
     /// How many ticks are in a second.
     pub const TICK_RATE: i32 = 30;
     pub const MS_PER_TICK: u64 = 1000 / (TICK_RATE as u64);
 
-    pub const DEFAULT_IP: &str = "127.0.0.1";
+    /// Default IP for binding the server.
+    pub const DEFAULT_IP: &str = "0.0.0.0";
 
     pub const MASTER_PORT: u16 = 6256;
     pub const CLUSTER_PORT: u16 = 6257;
 
+    /// Optional. If it's used, this domain's pub key will be used
+    /// to encrypt data between the cluster and client for one-way
+    /// communication.
+    pub const DOMAIN_PUB_KEY: &str = "https://site-cdn.playreia.com/game/pubkey.pub";
+
     pub const TERMINAL_BG_GRAY: &str = "\x1b[47m";
     pub const TERMINAL_DEFAULT: &str = "\x1b[39m";
     pub const TERMINAL_BLACK: &str = "\x1b[30m";
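The new perf feature makes PERFORMANCE a compile-time switch: info/error/success return immediately, and debug additionally requires a debug build. Downstream crates would opt in with something like `cargo build --features perf`, or a `features = ["perf"]` entry on the sustenet-shared dependency. Because the flag is a const, the compiler can drop the gated branch entirely; a self-contained sketch of the same pattern:

```rust
// Illustrative: mirrors the PERFORMANCE gate in the constants above.
// With the value known at compile time, the branch is dead-code eliminated.
const PERFORMANCE: bool = cfg!(feature = "perf");

fn info(message: &str) {
    if PERFORMANCE {
        // In perf builds this early return makes the println unreachable,
        // so the whole function optimizes down to a no-op.
        return;
    }
    println!("[Info] {message}");
}

fn main() {
    info("printed only when the perf feature is off");
}
```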
diff --git a/rust/tests/Cargo.toml b/rust/tests/Cargo.toml
new file mode 100644
index 0000000..46bdcb2
--- /dev/null
+++ b/rust/tests/Cargo.toml
@@ -0,0 +1,30 @@
+[package]
+name = "sustenet-tests"
+version.workspace = true
+edition.workspace = true
+publish = false
+description = "Sustenet tests."
+
+license.workspace = true
+authors.workspace = true
+homepage.workspace = true
+
+[lints]
+workspace = true
+
+[features]
+ignored_tests = []
+
+[dependencies]
+dashmap.workspace = true
+sustenet-master.workspace = true
+sustenet-shared.workspace = true
+tokio = { workspace = true, features = [
+    # "socket2",
+    "macros",
+    "rt-multi-thread",
+    "net",
+    # "sync",
+    "io-util",
+    "time",
+] }
diff --git a/rust/tests/src/lib.rs b/rust/tests/src/lib.rs
new file mode 100644
index 0000000..f18575a
--- /dev/null
+++ b/rust/tests/src/lib.rs
@@ -0,0 +1,4 @@
+mod test_dyn_invoke;
+pub mod test_tcplistener;
+mod test_maps;
+mod test_stress;
\ No newline at end of file
diff --git a/rust/tests/src/main.rs b/rust/tests/src/main.rs
new file mode 100644
index 0000000..380ec6e
--- /dev/null
+++ b/rust/tests/src/main.rs
@@ -0,0 +1,3 @@
+fn main() {
+    sustenet_tests::test_tcplistener::test_create_connections();
+}
\ No newline at end of file
diff --git a/rust/tests/src/test_dyn_invoke.rs b/rust/tests/src/test_dyn_invoke.rs
new file mode 100644
index 0000000..1cb4e30
--- /dev/null
+++ b/rust/tests/src/test_dyn_invoke.rs
@@ -0,0 +1,52 @@
+#[cfg(test)]
+mod tests {
+    // The point of these tests is to check how long it takes to invoke a function dynamically
+    // vs statically. The DynPlugin trait below stands in for traits like
+    // sustenet_shared::ServerPlugin.
+
+    const MAX_INVOKES: usize = 1_000_000_000;
+
+    /// Invokes the function statically.
+    /// Static invoke took: 4.2954641s ~ 9% faster than dynamic invoke.
+    #[test]
+    #[ignore]
+    fn test_static_invoke() {
+        let mut count = 0;
+        let start = std::time::Instant::now();
+        for _ in 0..MAX_INVOKES {
+            count_invoke(&mut count);
+        }
+        let duration = start.elapsed();
+        println!("Static invoke took: {:?}", duration);
+        assert_eq!(count, MAX_INVOKES as u32);
+    }
+
+    fn count_invoke(count: &mut u32) {
+        *count += 1;
+    }
+
+    /// Invokes the function dynamically.
+    /// Dynamic invoke took: 4.6038441s
+    #[test]
+    #[ignore]
+    fn test_dynamic_invoke() {
+        let mut count = 0;
+        let plugin: Box<dyn DynPlugin> = Box::new(TestPlugin);
+        let start = std::time::Instant::now();
+        for _ in 0..MAX_INVOKES {
+            plugin.count_invoke(&mut count);
+        }
+        let duration = start.elapsed();
+        println!("Dynamic invoke took: {:?}", duration);
+        assert_eq!(count, MAX_INVOKES as u32);
+    }
+
+    pub trait DynPlugin: Send + Sync {
+        fn count_invoke(&self, count: &mut u32);
+    }
+    struct TestPlugin;
+    impl DynPlugin for TestPlugin {
+        fn count_invoke(&self, count: &mut u32) {
+            *count += 1;
+        }
+    }
+}
\ No newline at end of file
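A caveat on these invoke benchmarks: in release mode the optimizer can inline count_invoke and potentially collapse the static loop, so the two numbers may not be comparing equal work. `std::hint::black_box` is the usual guard; a hedged variant of the static case:

```rust
use std::hint::black_box;

fn count_invoke(count: &mut u32) {
    *count += 1;
}

fn main() {
    let mut count = 0u32;
    let start = std::time::Instant::now();
    for _ in 0..1_000_000 {
        // black_box keeps the optimizer from folding the loop into `count += N`.
        count_invoke(black_box(&mut count));
    }
    println!("took {:?}, count = {}", start.elapsed(), black_box(count));
}
```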
diff --git a/rust/tests/src/test_maps.rs b/rust/tests/src/test_maps.rs
new file mode 100644
index 0000000..848d5b6
--- /dev/null
+++ b/rust/tests/src/test_maps.rs
@@ -0,0 +1,452 @@
+#[cfg(test)]
+mod tests {
+    const MAX_ITERS: usize = 10_000_000;
+    const MAX_THREADS: usize = 8;
+
+    /// Tests the speed of adding, getting, updating, and removing items from a HashMap.
+    /// Time taken to add 10000000 items: 7.9314173s
+    /// Time taken to get 10000000 items: 4.9250668s
+    /// Time taken to update 10000000 items: 5.0271649s
+    /// Time taken to remove 10000000 items: 7.4092275s
+    #[test]
+    #[ignore]
+    fn test_hashmaps() {
+        use std::collections::HashMap;
+        use std::time::Instant;
+
+        let mut map: HashMap<usize, usize> = HashMap::new();
+
+        // Test adding items
+        let start = Instant::now();
+        for i in 0..MAX_ITERS {
+            map.insert(i, i);
+        }
+        let duration_add = start.elapsed();
+        println!("Time taken to add {} items: {:?}", MAX_ITERS, duration_add);
+
+        // Test getting items
+        let start = Instant::now();
+        for i in 0..MAX_ITERS {
+            let _ = map.get(&i);
+        }
+        let duration_get = start.elapsed();
+        println!("Time taken to get {} items: {:?}", MAX_ITERS, duration_get);
+
+        // Test updating items
+        let start = Instant::now();
+        for i in 0..MAX_ITERS {
+            map.insert(i, i + 1);
+        }
+        let duration_update = start.elapsed();
+        println!("Time taken to update {} items: {:?}", MAX_ITERS, duration_update);
+
+        // Test removing items
+        let start = Instant::now();
+        for i in 0..MAX_ITERS {
+            map.remove(&i);
+        }
+        let duration_remove = start.elapsed();
+        println!("Time taken to remove {} items: {:?}", MAX_ITERS, duration_remove);
+    }
+
+    /// Tests the speed of adding, getting, updating, and removing items from a DashMap.
+    /// Time taken to add 10000000 items: 12.5746573s
+    /// Time taken to get 10000000 items: 6.2193888s
+    /// Time taken to update 10000000 items: 6.5075287s
+    /// Time taken to remove 10000000 items: 8.9363106s
+    #[test]
+    #[ignore]
+    fn test_dashmaps() {
+        use dashmap::DashMap;
+        use std::time::Instant;
+
+        let map: DashMap<usize, usize> = DashMap::new();
+
+        // Test adding items
+        let start = Instant::now();
+        for i in 0..MAX_ITERS {
+            map.insert(i, i);
+        }
+        let duration_add = start.elapsed();
+        println!("Time taken to add {} items: {:?}", MAX_ITERS, duration_add);
+
+        // Test getting items
+        let start = Instant::now();
+        for i in 0..MAX_ITERS {
+            let _ = map.get(&i);
+        }
+        let duration_get = start.elapsed();
+        println!("Time taken to get {} items: {:?}", MAX_ITERS, duration_get);
+
+        // Test updating items
+        let start = Instant::now();
+        for i in 0..MAX_ITERS {
+            map.insert(i, i + 1);
+        }
+        let duration_update = start.elapsed();
+        println!("Time taken to update {} items: {:?}", MAX_ITERS, duration_update);
+
+        // Test removing items
+        let start = Instant::now();
+        for i in 0..MAX_ITERS {
+            map.remove(&i);
+        }
+        let duration_remove = start.elapsed();
+        println!("Time taken to remove {} items: {:?}", MAX_ITERS, duration_remove);
+    }
+
+    /// Tests the speed of adding, getting, updating, and removing items from a HashMap with threads.
+    /// Time taken to add 10000000 items with threads: 12.3948596s
+    /// Time taken to get 10000000 items with threads: 9.8344117s
+    /// Time taken to update 10000000 items with threads: 9.98254s
+    /// Time taken to remove 10000000 items with threads: 13.0771387s
+    #[test]
+    #[ignore]
+    fn test_hashmaps_with_threads() {
+        use std::collections::HashMap;
+        use std::sync::{ Arc, Mutex };
+        use std::thread;
+        use std::time::Instant;
+
+        let map: Arc<Mutex<HashMap<usize, usize>>> = Arc::new(Mutex::new(HashMap::new()));
+
+        // Test adding items with threads
+        let start = Instant::now();
+        let mut handles = vec![];
+        for i in 0..MAX_THREADS {
+            let map_clone = Arc::clone(&map);
+            let handle = thread::spawn(move || {
+                for j in (i * MAX_ITERS) / MAX_THREADS..((i + 1) * MAX_ITERS) / MAX_THREADS {
+                    map_clone.lock().unwrap().insert(j, j);
+                }
+            });
+            handles.push(handle);
+        }
+        for handle in handles {
+            handle.join().unwrap();
+        }
+        let duration_add = start.elapsed();
+        println!("Time taken to add {} items with threads: {:?}", MAX_ITERS, duration_add);
+
+        // Test getting items with threads
+        let start = Instant::now();
+        let mut handles = vec![];
+        for i in 0..MAX_THREADS {
+            let map_clone = Arc::clone(&map);
+            let handle = thread::spawn(move || {
+                for j in (i * MAX_ITERS) / MAX_THREADS..((i + 1) * MAX_ITERS) / MAX_THREADS {
+                    let _ = map_clone.lock().unwrap().get(&j);
+                }
+            });
+            handles.push(handle);
+        }
+        for handle in handles {
+            handle.join().unwrap();
+        }
+        let duration_get = start.elapsed();
+        println!("Time taken to get {} items with threads: {:?}", MAX_ITERS, duration_get);
+
+        // Test updating items with threads
+        let start = Instant::now();
+        let mut handles = vec![];
+        for i in 0..MAX_THREADS {
+            let map_clone = Arc::clone(&map);
+            let handle = thread::spawn(move || {
+                for j in (i * MAX_ITERS) / MAX_THREADS..((i + 1) * MAX_ITERS) / MAX_THREADS {
+                    map_clone
+                        .lock()
+                        .unwrap()
+                        .insert(j, j + 1);
+                }
+            });
+            handles.push(handle);
+        }
+        for handle in handles {
+            handle.join().unwrap();
+        }
+        let duration_update = start.elapsed();
+        println!("Time taken to update {} items with threads: {:?}", MAX_ITERS, duration_update);
+
+        // Test removing items with threads
+        let start = Instant::now();
+        let mut handles = vec![];
+        for i in 0..MAX_THREADS {
+            let map_clone = Arc::clone(&map);
+            let handle = thread::spawn(move || {
+                for j in (i * MAX_ITERS) / MAX_THREADS..((i + 1) * MAX_ITERS) / MAX_THREADS {
+                    map_clone.lock().unwrap().remove(&j);
+                }
+            });
+            handles.push(handle);
+        }
+        for handle in handles {
+            handle.join().unwrap();
+        }
+        let duration_remove = start.elapsed();
+        println!("Time taken to remove {} items with threads: {:?}", MAX_ITERS, duration_remove);
+    }
+
+    /// Tests the speed of adding, getting, updating, and removing items from a DashMap with threads.
+    /// Time taken to add 10000000 items with threads: 6.2227182s
+    /// Time taken to get 10000000 items with threads: 1.3018232s
+    /// Time taken to update 10000000 items with threads: 1.4454566s
+    /// Time taken to remove 10000000 items with threads: 2.0433425s
+    #[test]
+    #[ignore]
+    fn test_dashmaps_with_threads() {
+        use dashmap::DashMap;
+        use std::sync::Arc;
+        use std::thread;
+        use std::time::Instant;
+
+        let map: Arc<DashMap<usize, usize>> = Arc::new(DashMap::new());
+
+        // Test adding items with threads
+        let start = Instant::now();
+        let mut handles = vec![];
+        for i in 0..MAX_THREADS {
+            let map_clone = Arc::clone(&map);
+            let handle = thread::spawn(move || {
+                for j in (i * MAX_ITERS) / MAX_THREADS..((i + 1) * MAX_ITERS) / MAX_THREADS {
+                    map_clone.insert(j, j);
+                }
+            });
+            handles.push(handle);
+        }
+        for handle in handles {
+            handle.join().unwrap();
+        }
+        let duration_add = start.elapsed();
+        println!("Time taken to add {} items with threads: {:?}", MAX_ITERS, duration_add);
+
+        // Test getting items with threads
+        let start = Instant::now();
+        let mut handles = vec![];
+        for i in 0..MAX_THREADS {
+            let map_clone = Arc::clone(&map);
+            let handle = thread::spawn(move || {
+                for j in (i * MAX_ITERS) / MAX_THREADS..((i + 1) * MAX_ITERS) / MAX_THREADS {
+                    let _ = map_clone.get(&j);
+                }
+            });
+            handles.push(handle);
+        }
+        for handle in handles {
+            handle.join().unwrap();
+        }
+        let duration_get = start.elapsed();
+        println!("Time taken to get {} items with threads: {:?}", MAX_ITERS, duration_get);
+
+        // Test updating items with threads
+        let start = Instant::now();
+        let mut handles = vec![];
+        for i in 0..MAX_THREADS {
+            let map_clone = Arc::clone(&map);
+            let handle = thread::spawn(move || {
+                for j in (i * MAX_ITERS) / MAX_THREADS..((i + 1) * MAX_ITERS) / MAX_THREADS {
+                    map_clone.insert(j, j + 1);
+                }
+            });
+            handles.push(handle);
+        }
+        for handle in handles {
+            handle.join().unwrap();
+        }
+        let duration_update = start.elapsed();
+        println!("Time taken to update {} items with threads: {:?}", MAX_ITERS, duration_update);
+
+        // Test removing items with threads
+        let start = Instant::now();
+        let mut handles = vec![];
+        for i in 0..MAX_THREADS {
+            let map_clone = Arc::clone(&map);
+            let handle = thread::spawn(move || {
+                for j in (i * MAX_ITERS) / MAX_THREADS..((i + 1) * MAX_ITERS) / MAX_THREADS {
+                    map_clone.remove(&j);
+                }
+            });
+            handles.push(handle);
+        }
+        for handle in handles {
+            handle.join().unwrap();
+        }
+        let duration_remove = start.elapsed();
+        println!("Time taken to remove {} items with threads: {:?}", MAX_ITERS, duration_remove);
+    }
+
+    /// Tests the speed of adding, getting, updating, and removing items from a HashMap on tokio.
+    /// Time taken to add 10000000 items with tokio: 8.4185739s
+    /// Time taken to get 10000000 items with tokio: 5.542175s
+    /// Time taken to update 10000000 items with tokio: 5.5900584s
+    /// Time taken to remove 10000000 items with tokio: 7.4953945s
+    #[tokio::test]
+    #[ignore]
+    async fn test_hashmaps_tokio() {
+        use std::collections::HashMap;
+        use std::sync::{ Arc, Mutex };
+        use tokio::time::Instant;
+
+        let map: Arc<Mutex<HashMap<usize, usize>>> = Arc::new(Mutex::new(HashMap::new()));
+
+        // Test adding items
+        let start = Instant::now();
+        let mut handles = vec![];
+        for i in 0..MAX_THREADS {
+            let map_clone = Arc::clone(&map);
+            let handle = tokio::spawn(async move {
+                for j in (i * MAX_ITERS) / MAX_THREADS..((i + 1) * MAX_ITERS) / MAX_THREADS {
+                    map_clone.lock().unwrap().insert(j, j);
+                }
+            });
+            handles.push(handle);
+        }
+        for handle in handles {
+            handle.await.unwrap();
+        }
+        let duration_add = start.elapsed();
+        println!("Time taken to add {} items with tokio: {:?}", MAX_ITERS, duration_add);
+
+        // Test getting items
+        let start = Instant::now();
+        let mut handles = vec![];
+        for i in 0..MAX_THREADS {
+            let map_clone = Arc::clone(&map);
+            let handle = tokio::spawn(async move {
+                for j in (i * MAX_ITERS) / MAX_THREADS..((i + 1) * MAX_ITERS) / MAX_THREADS {
+                    let _ = map_clone.lock().unwrap().get(&j);
+                }
+            });
+            handles.push(handle);
+        }
+        for handle in handles {
+            handle.await.unwrap();
+        }
+        let duration_get = start.elapsed();
+        println!("Time taken to get {} items with tokio: {:?}", MAX_ITERS, duration_get);
+
+        // Test updating items
+        let start = Instant::now();
+        let mut handles = vec![];
+        for i in 0..MAX_THREADS {
+            let map_clone = Arc::clone(&map);
+            let handle = tokio::spawn(async move {
+                for j in (i * MAX_ITERS) / MAX_THREADS..((i + 1) * MAX_ITERS) / MAX_THREADS {
+                    map_clone
+                        .lock()
+                        .unwrap()
+                        .insert(j, j + 1);
+                }
+            });
+            handles.push(handle);
+        }
+        for handle in handles {
+            handle.await.unwrap();
+        }
+        let duration_update = start.elapsed();
+        println!("Time taken to update {} items with tokio: {:?}", MAX_ITERS, duration_update);
+
+        // Test removing items
+        let start = Instant::now();
+        let mut handles = vec![];
+        for i in 0..MAX_THREADS {
+            let map_clone = Arc::clone(&map);
+            let handle = tokio::spawn(async move {
+                for j in (i * MAX_ITERS) / MAX_THREADS..((i + 1) * MAX_ITERS) / MAX_THREADS {
+                    map_clone.lock().unwrap().remove(&j);
+                }
+            });
+            handles.push(handle);
+        }
+        for handle in handles {
+            handle.await.unwrap();
+        }
+        let duration_remove = start.elapsed();
+        println!("Time taken to remove {} items with tokio: {:?}", MAX_ITERS, duration_remove);
+    }
+
+    /// Tests the speed of adding, getting, updating, and removing items from a DashMap with tokio.
+    /// Time taken to add 10000000 items with tokio: 12.3175472s
+    /// Time taken to get 10000000 items with tokio: 6.2607827s
+    /// Time taken to update 10000000 items with tokio: 7.4993663s
+    /// Time taken to remove 10000000 items with tokio: 9.1845476s
+    #[tokio::test]
+    #[ignore]
+    async fn test_dashmaps_tokio() {
+        use dashmap::DashMap;
+        use std::sync::Arc;
+        use tokio::time::Instant;
+
+        let map: Arc<DashMap<usize, usize>> = Arc::new(DashMap::new());
+
+        // Test adding items
+        let start = Instant::now();
+        let mut handles = vec![];
+        for i in 0..MAX_THREADS {
+            let map_clone = Arc::clone(&map);
+            let handle = tokio::spawn(async move {
+                for j in (i * MAX_ITERS) / MAX_THREADS..((i + 1) * MAX_ITERS) / MAX_THREADS {
+                    map_clone.insert(j, j);
+                }
+            });
+            handles.push(handle);
+        }
+        for handle in handles {
+            handle.await.unwrap();
+        }
+        let duration_add = start.elapsed();
+        println!("Time taken to add {} items with tokio: {:?}", MAX_ITERS, duration_add);
+
+        // Test getting items
+        let start = Instant::now();
+        let mut handles = vec![];
+        for i in 0..MAX_THREADS {
+            let map_clone = Arc::clone(&map);
+            let handle = tokio::spawn(async move {
+                for j in (i * MAX_ITERS) / MAX_THREADS..((i + 1) * MAX_ITERS) / MAX_THREADS {
+                    let _ = map_clone.get(&j);
+                }
+            });
+            handles.push(handle);
+        }
+        for handle in handles {
+            handle.await.unwrap();
+        }
+        let duration_get = start.elapsed();
+        println!("Time taken to get {} items with tokio: {:?}", MAX_ITERS, duration_get);
+
+        // Test updating items
+        let start = Instant::now();
+        let mut handles = vec![];
+        for i in 0..MAX_THREADS {
+            let map_clone = Arc::clone(&map);
+            let handle = tokio::spawn(async move {
+                for j in (i * MAX_ITERS) / MAX_THREADS..((i + 1) * MAX_ITERS) / MAX_THREADS {
+                    map_clone.insert(j, j + 1);
+                }
+            });
+            handles.push(handle);
+        }
+        for handle in handles {
+            handle.await.unwrap();
+        }
+        let duration_update = start.elapsed();
+        println!("Time taken to update {} items with tokio: {:?}", MAX_ITERS, duration_update);
+
+        // Test removing items
+        let start = Instant::now();
+        let mut handles = vec![];
+        for i in 0..MAX_THREADS {
+            let map_clone = Arc::clone(&map);
+            let handle = tokio::spawn(async move {
+                for j in (i * MAX_ITERS) / MAX_THREADS..((i + 1) * MAX_ITERS) / MAX_THREADS {
+                    map_clone.remove(&j);
+                }
+            });
+            handles.push(handle);
+        }
+        for handle in handles {
+            handle.await.unwrap();
+        }
+        let duration_remove = start.elapsed();
+        println!("Time taken to remove {} items with tokio: {:?}", MAX_ITERS, duration_remove);
+    }
+}
{i}: {e}"); + } + } + } + let duration = start.elapsed(); + println!("Connected {MAX_CONNS} clients in {:.2?}", duration); + } + + // Connected 10000 clients in 3.26s (Power Save | No Turbo | 1165G7) + #[tokio::test] + #[ignore] + async fn test_with_threads() { + let mut server = MasterServer::new_from_config().await.unwrap(); + tokio::spawn(async move { + server.start().await.unwrap(); + }); + + // Simulate clients connecting + let start = tokio::time::Instant::now(); + let mut handles = vec![]; + + let addr: &str = "127.0.0.1:6256"; + for i in 0..MAX_CONNS { + let handle = tokio::spawn(async move { + match tokio::net::TcpStream::connect(addr).await { + Ok(mut stream) => { + // Simulate sending a message + let _ = stream.write_all(format!("Hello from client {i}").as_bytes()).await; + } + Err(e) => { + eprintln!("Failed to connect client {i}: {e}"); + } + } + }); + handles.push(handle); + } + for handle in handles { + if let Err(e) = handle.await { + eprintln!("Failed to join thread: {e}"); + } + } + let duration = start.elapsed(); + println!("Connected {MAX_CONNS} clients in {:.2?}", duration); + } +} \ No newline at end of file diff --git a/rust/tests/src/test_tcplistener.rs b/rust/tests/src/test_tcplistener.rs new file mode 100644 index 0000000..6091b91 --- /dev/null +++ b/rust/tests/src/test_tcplistener.rs @@ -0,0 +1,238 @@ +use std::net::SocketAddr; + +pub const MAX_CONNS: usize = 10_000; +pub const MAX_THREADS: usize = 20; +pub const ADDR: &str = "0.0.0.0:6256"; + +/// Creates connections to a TCP listener based on the number of connections specified in the CLI. +pub fn test_create_connections() { + if std::env::args().any(|arg| (arg == "--help" || arg == "-h")) { + println!("Usage: test_create_connections [options]"); + println!("-d | --dest to specify the destination address (default: {})", ADDR); + println!("-c | --conns to specify the number of connections (default: {})", MAX_CONNS); + println!("-t | --threads to specify the number of threads (default: {})", MAX_THREADS); + return; + } + println!("Add --help or -h for usage information."); + let mut addr = ADDR.to_string(); + let mut max_conns = MAX_CONNS; + let mut max_threads = MAX_THREADS; + let mut args = std::env::args().peekable(); + while let Some(arg) = args.next() { + match arg.as_str() { + "-d" | "--dest" => { + if let Some(val) = args.next() { + addr = val; + } + } + "-c" | "--conns" => { + if let Some(val) = args.next() { + if let Ok(num) = val.parse::() { + if num > 0 { + max_conns = num; + } else { + eprintln!("Number of connections must be greater than 0."); + } + } + } + } + "-t" | "--threads" => { + if let Some(val) = args.next() { + if let Ok(num) = val.parse::() { + if num > 0 { + max_threads = num; + } else { + eprintln!("Number of threads must be greater than 0."); + } + } + } + } + _ => {} + } + } + let mut handles = vec![]; + println!("Starting to create {} connections to {}", max_conns, addr); + println!("Press Enter to begin..."); + let mut input = String::new(); + let _ = std::io::stdin().read_line(&mut input); + let start = std::time::Instant::now(); + + let connections_per_thread = max_conns / max_threads; + let remainder = max_conns % max_threads; + // Break it up over the threads + for i in 0..max_threads { + let addr_clone: SocketAddr = addr.parse().unwrap(); + let num_conns = if i < remainder { + connections_per_thread + 1 + } else { + connections_per_thread + }; + let handle = std::thread::spawn(move || { + for _ in 0..num_conns { + match std::net::TcpStream::connect(&addr_clone) { + Ok(mut 
diff --git a/rust/tests/src/test_tcplistener.rs b/rust/tests/src/test_tcplistener.rs
new file mode 100644
index 0000000..6091b91
--- /dev/null
+++ b/rust/tests/src/test_tcplistener.rs
@@ -0,0 +1,238 @@
+use std::net::SocketAddr;
+
+pub const MAX_CONNS: usize = 10_000;
+pub const MAX_THREADS: usize = 20;
+pub const ADDR: &str = "0.0.0.0:6256";
+
+/// Creates connections to a TCP listener based on the number of connections specified in the CLI.
+pub fn test_create_connections() {
+    if std::env::args().any(|arg| (arg == "--help" || arg == "-h")) {
+        println!("Usage: test_create_connections [options]");
+        println!("-d | --dest to specify the destination address (default: {})", ADDR);
+        println!("-c | --conns to specify the number of connections (default: {})", MAX_CONNS);
+        println!("-t | --threads to specify the number of threads (default: {})", MAX_THREADS);
+        return;
+    }
+    println!("Add --help or -h for usage information.");
+
+    let mut addr = ADDR.to_string();
+    let mut max_conns = MAX_CONNS;
+    let mut max_threads = MAX_THREADS;
+    let mut args = std::env::args().peekable();
+    while let Some(arg) = args.next() {
+        match arg.as_str() {
+            "-d" | "--dest" => {
+                if let Some(val) = args.next() {
+                    addr = val;
+                }
+            }
+            "-c" | "--conns" => {
+                if let Some(val) = args.next() {
+                    if let Ok(num) = val.parse::<usize>() {
+                        if num > 0 {
+                            max_conns = num;
+                        } else {
+                            eprintln!("Number of connections must be greater than 0.");
+                        }
+                    }
+                }
+            }
+            "-t" | "--threads" => {
+                if let Some(val) = args.next() {
+                    if let Ok(num) = val.parse::<usize>() {
+                        if num > 0 {
+                            max_threads = num;
+                        } else {
+                            eprintln!("Number of threads must be greater than 0.");
+                        }
+                    }
+                }
+            }
+            _ => {}
+        }
+    }
+
+    let mut handles = vec![];
+    println!("Starting to create {} connections to {}", max_conns, addr);
+    println!("Press Enter to begin...");
+    let mut input = String::new();
+    let _ = std::io::stdin().read_line(&mut input);
+    let start = std::time::Instant::now();
+
+    let connections_per_thread = max_conns / max_threads;
+    let remainder = max_conns % max_threads;
+    // Break it up over the threads
+    for i in 0..max_threads {
+        let addr_clone: SocketAddr = addr.parse().unwrap();
+        let num_conns = if i < remainder {
+            connections_per_thread + 1
+        } else {
+            connections_per_thread
+        };
+        let handle = std::thread::spawn(move || {
+            for _ in 0..num_conns {
+                match std::net::TcpStream::connect(&addr_clone) {
+                    Ok(mut _stream) => {
+                        // Optionally, send a message to the server
+                        // let _ = _stream.write_all(b"Hello from client").unwrap();
+                    }
+                    Err(e) => {
+                        eprintln!("Failed to connect: {}", e);
+                    }
+                }
+            }
+        });
+        handles.push(handle);
+    }
+    for handle in handles {
+        let _ = handle.join();
+    }
+    let duration = start.elapsed();
+    println!("Created {} connections in {:?}", max_conns, duration);
+}
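Since tests/src/main.rs routes straight into test_create_connections, these knobs are reachable from cargo, e.g. `cargo run -p sustenet-tests --release -- --dest 127.0.0.1:6256 --conns 5000 --threads 10` (the package name comes from tests/Cargo.toml above; note the run waits for Enter before it starts connecting).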
+
+#[cfg(test)]
+pub mod tests {
+    // use std::io::Write;
+
+    use std::sync::atomic::AtomicUsize;
+
+    use super::{ ADDR, MAX_CONNS, MAX_THREADS };
+
+    /// Tests the time it takes to accept a fixed number of TCP connections.
+    #[test]
+    #[ignore]
+    pub fn test_tcplistener_default() {
+        let server = std::net::TcpListener::bind(ADDR).unwrap();
+        println!("TCP listener bound to {}", ADDR);
+
+        // Accept the first connection, then start timing
+        let _ = server.accept().unwrap();
+        println!("First connection accepted, starting timer...");
+
+        let start = std::time::Instant::now();
+        for _ in 1..MAX_CONNS {
+            let _ = server.accept().unwrap();
+        }
+        let duration = start.elapsed();
+        println!("Time taken to accept {} connections: {:?}", MAX_CONNS, duration);
+    }
+
+    /// Tests the time it takes to accept a fixed number of TCP connections with tokio.
+    #[tokio::test]
+    #[ignore]
+    async fn test_tcplistener_tokio() {
+        let server = tokio::net::TcpListener::bind(ADDR).await.unwrap();
+        println!("Tokio TCP listener bound to {}", ADDR);
+
+        // Accept the first connection, then start timing
+        let _ = server.accept().await.unwrap();
+        println!("First connection accepted, starting timer...");
+
+        let start = tokio::time::Instant::now();
+        for _ in 1..MAX_CONNS {
+            let _ = server.accept().await.unwrap();
+        }
+        let duration = start.elapsed();
+        println!("Time taken to accept {} connections with Tokio: {:?}", MAX_CONNS, duration);
+    }
+
+    /// Tests the time it takes to accept a fixed number of TCP connections with threads.
+    #[test]
+    #[ignore]
+    fn test_tcplistener_threads() {
+        let server = std::sync::Arc::new(std::net::TcpListener::bind(ADDR).unwrap());
+        println!("TCP listener bound to {}", ADDR);
+
+        // Accept the first connection, then start timing
+        let _ = server.accept().unwrap();
+        println!("First connection accepted, starting timer...");
+
+        let start = std::time::Instant::now();
+        let mut handles = vec![];
+
+        let connections_per_thread = MAX_CONNS / MAX_THREADS;
+        let remainder = MAX_CONNS % MAX_THREADS;
+        let curr: std::sync::Arc<AtomicUsize> = std::sync::Arc::new(AtomicUsize::new(0));
+        // Break it up over the threads
+        for i in 0..MAX_THREADS {
+            let curr_clone = curr.clone();
+            let server_clone = server.clone();
+            let num_conns = if i < remainder {
+                connections_per_thread + 1
+            } else {
+                connections_per_thread
+            };
+            let handle = std::thread::spawn(move || {
+                println!("Thread {} accepting {} connections...", i, num_conns);
+
+                for _ in 0..num_conns {
+                    let _ = server_clone.accept();
+
+                    curr_clone.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
+
+                    // If curr is a multiple of 100, print progress
+                    if curr_clone.load(std::sync::atomic::Ordering::SeqCst) % 100 == 0 {
+                        println!(
+                            "Accepted {} connections so far...",
+                            curr_clone.load(std::sync::atomic::Ordering::SeqCst)
+                        );
+                    }
+                }
+            });
+            handles.push(handle);
+        }
+        for handle in handles {
+            handle.join().unwrap();
+        }
+        let duration = start.elapsed();
+        println!("Time taken to accept {} connections with threads: {:?}", MAX_CONNS, duration);
+    }
+
+    /// Tests the time it takes to accept a fixed number of TCP connections with threads and tokio.
+    #[tokio::test]
+    #[ignore]
+    async fn test_tcplistener_threads_tokio() {
+        let server = std::sync::Arc::new(tokio::net::TcpListener::bind(ADDR).await.unwrap());
+        println!("Tokio TCP listener bound to {}", ADDR);
+
+        // Accept the first connection, then start timing
+        let _ = server.accept().await.unwrap();
+        println!("First connection accepted, starting timer...");
+
+        let start = tokio::time::Instant::now();
+        let mut handles = vec![];
+
+        let connections_per_thread = MAX_CONNS / MAX_THREADS;
+        let remainder = MAX_CONNS % MAX_THREADS;
+        let curr: std::sync::Arc<AtomicUsize> = std::sync::Arc::new(AtomicUsize::new(0));
+        // Break it up over the threads
+        for i in 0..MAX_THREADS {
+            let curr_clone = curr.clone();
+            let server_clone = server.clone();
+            let num_conns = if i < remainder {
+                connections_per_thread + 1
+            } else {
+                connections_per_thread
+            };
+            let handle = tokio::spawn(async move {
+                for _ in 0..num_conns {
+                    let _ = server_clone.accept().await;
+                    curr_clone.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
+
+                    // If curr is a multiple of 100, print progress
+                    if curr_clone.load(std::sync::atomic::Ordering::SeqCst) % 100 == 0 {
+                        println!(
+                            "Accepted {} connections so far...",
+                            curr_clone.load(std::sync::atomic::Ordering::SeqCst)
+                        );
+                    }
+                }
+            });
+            handles.push(handle);
+        }
+        for handle in handles {
+            handle.await.unwrap();
+        }
+        let duration = start.elapsed();
+        println!(
+            "Time taken to accept {} connections with threads and Tokio: {:?}",
+            MAX_CONNS,
+            duration
+        );
+    }
+}