|
| 1 | +//! Relay client implementation for connecting to relay servers. |
| 2 | +
|
| 3 | +use super::protocol::{NodeId, RelayError, RelayMessage}; |
| 4 | +use std::net::SocketAddr; |
| 5 | +use std::sync::Arc; |
| 6 | +use std::time::{Duration, Instant}; |
| 7 | +use tokio::net::UdpSocket; |
| 8 | +use tokio::sync::{Mutex, mpsc}; |
| 9 | +use tokio::time; |
| 10 | + |
| 11 | +/// Type alias for the message receiver |
| 12 | +type MessageReceiver = Arc<Mutex<mpsc::UnboundedReceiver<(NodeId, Vec<u8>)>>>; |
| 13 | + |
| 14 | +/// Relay client state |
| 15 | +#[derive(Debug, Clone, Copy, PartialEq, Eq)] |
| 16 | +pub enum RelayClientState { |
| 17 | + /// Disconnected from relay |
| 18 | + Disconnected, |
| 19 | + /// Connecting to relay |
| 20 | + Connecting, |
| 21 | + /// Registering with relay |
| 22 | + Registering, |
| 23 | + /// Connected and registered |
| 24 | + Connected, |
| 25 | + /// Error state |
| 26 | + Error, |
| 27 | +} |
| 28 | + |
| 29 | +/// Relay client for communicating with relay servers |
| 30 | +pub struct RelayClient { |
| 31 | + /// Local node ID |
| 32 | + node_id: NodeId, |
| 33 | + /// Relay server address |
| 34 | + relay_addr: SocketAddr, |
| 35 | + /// UDP socket for communication |
| 36 | + socket: Arc<UdpSocket>, |
| 37 | + /// Current client state |
| 38 | + state: Arc<Mutex<RelayClientState>>, |
| 39 | + /// Receiver for incoming messages |
| 40 | + rx: MessageReceiver, |
| 41 | + /// Sender for message processing |
| 42 | + tx: mpsc::UnboundedSender<(NodeId, Vec<u8>)>, |
| 43 | + /// Last keepalive time |
| 44 | + last_keepalive: Arc<Mutex<Instant>>, |
| 45 | +} |
| 46 | + |
| 47 | +impl RelayClient { |
| 48 | + /// Connect to a relay server |
| 49 | + /// |
| 50 | + /// # Arguments |
| 51 | + /// |
| 52 | + /// * `addr` - Relay server address |
| 53 | + /// * `node_id` - Local node identifier |
| 54 | + /// |
| 55 | + /// # Errors |
| 56 | + /// |
| 57 | + /// Returns error if connection fails or times out. |
| 58 | + pub async fn connect(addr: SocketAddr, node_id: NodeId) -> Result<Self, RelayError> { |
| 59 | + // Bind local UDP socket |
| 60 | + let socket = UdpSocket::bind("0.0.0.0:0").await?; |
| 61 | + socket.connect(addr).await?; |
| 62 | + |
| 63 | + let (tx, rx) = mpsc::unbounded_channel(); |
| 64 | + |
| 65 | + let client = Self { |
| 66 | + node_id, |
| 67 | + relay_addr: addr, |
| 68 | + socket: Arc::new(socket), |
| 69 | + state: Arc::new(Mutex::new(RelayClientState::Disconnected)), |
| 70 | + rx: Arc::new(Mutex::new(rx)), |
| 71 | + tx, |
| 72 | + last_keepalive: Arc::new(Mutex::new(Instant::now())), |
| 73 | + }; |
| 74 | + |
| 75 | + // Update state to connecting |
| 76 | + *client.state.lock().await = RelayClientState::Connecting; |
| 77 | + |
| 78 | + Ok(client) |
| 79 | + } |
| 80 | + |
| 81 | + /// Register with the relay server |
| 82 | + /// |
| 83 | + /// # Arguments |
| 84 | + /// |
| 85 | + /// * `public_key` - Client's public key for verification |
| 86 | + /// |
| 87 | + /// # Errors |
| 88 | + /// |
| 89 | + /// Returns error if registration fails or times out. |
| 90 | + pub async fn register(&mut self, public_key: &[u8; 32]) -> Result<(), RelayError> { |
| 91 | + *self.state.lock().await = RelayClientState::Registering; |
| 92 | + |
| 93 | + let msg = RelayMessage::Register { |
| 94 | + node_id: self.node_id, |
| 95 | + public_key: *public_key, |
| 96 | + }; |
| 97 | + |
| 98 | + let bytes = msg.to_bytes()?; |
| 99 | + self.socket.send(&bytes).await?; |
| 100 | + |
| 101 | + // Wait for RegisterAck with timeout |
| 102 | + let mut buf = vec![0u8; 65536]; |
| 103 | + let len = time::timeout(Duration::from_secs(10), self.socket.recv(&mut buf)) |
| 104 | + .await |
| 105 | + .map_err(|_| RelayError::Timeout)??; |
| 106 | + |
| 107 | + let response = RelayMessage::from_bytes(&buf[..len])?; |
| 108 | + |
| 109 | + match response { |
| 110 | + RelayMessage::RegisterAck { |
| 111 | + success, |
| 112 | + error, |
| 113 | + relay_id: _, |
| 114 | + } => { |
| 115 | + if success { |
| 116 | + *self.state.lock().await = RelayClientState::Connected; |
| 117 | + *self.last_keepalive.lock().await = Instant::now(); |
| 118 | + Ok(()) |
| 119 | + } else { |
| 120 | + *self.state.lock().await = RelayClientState::Error; |
| 121 | + Err(RelayError::Internal( |
| 122 | + error.unwrap_or_else(|| "Registration failed".to_string()), |
| 123 | + )) |
| 124 | + } |
| 125 | + } |
| 126 | + RelayMessage::Error { code, message: _ } => { |
| 127 | + *self.state.lock().await = RelayClientState::Error; |
| 128 | + Err(code.into()) |
| 129 | + } |
| 130 | + _ => { |
| 131 | + *self.state.lock().await = RelayClientState::Error; |
| 132 | + Err(RelayError::InvalidMessage) |
| 133 | + } |
| 134 | + } |
| 135 | + } |
| 136 | + |
| 137 | + /// Send a packet to a peer through the relay |
| 138 | + /// |
| 139 | + /// # Arguments |
| 140 | + /// |
| 141 | + /// * `dest` - Destination node ID |
| 142 | + /// * `data` - Packet payload (already encrypted) |
| 143 | + /// |
| 144 | + /// # Errors |
| 145 | + /// |
| 146 | + /// Returns error if send fails or client not registered. |
| 147 | + pub async fn send_to_peer(&self, dest: NodeId, data: &[u8]) -> Result<(), RelayError> { |
| 148 | + if *self.state.lock().await != RelayClientState::Connected { |
| 149 | + return Err(RelayError::NotRegistered); |
| 150 | + } |
| 151 | + |
| 152 | + let msg = RelayMessage::SendPacket { |
| 153 | + dest_id: dest, |
| 154 | + payload: data.to_vec(), |
| 155 | + }; |
| 156 | + |
| 157 | + let bytes = msg.to_bytes()?; |
| 158 | + self.socket.send(&bytes).await?; |
| 159 | + |
| 160 | + Ok(()) |
| 161 | + } |
| 162 | + |
| 163 | + /// Receive a packet from a peer through the relay |
| 164 | + /// |
| 165 | + /// # Errors |
| 166 | + /// |
| 167 | + /// Returns error if receive fails or timeout occurs. |
| 168 | + pub async fn recv_from_peer(&self) -> Result<(NodeId, Vec<u8>), RelayError> { |
| 169 | + let mut rx = self.rx.lock().await; |
| 170 | + rx.recv() |
| 171 | + .await |
| 172 | + .ok_or_else(|| RelayError::Internal("Channel closed".to_string())) |
| 173 | + } |
| 174 | + |
| 175 | + /// Send keepalive message to maintain connection |
| 176 | + /// |
| 177 | + /// # Errors |
| 178 | + /// |
| 179 | + /// Returns error if send fails. |
| 180 | + pub async fn keepalive(&self) -> Result<(), RelayError> { |
| 181 | + let msg = RelayMessage::Keepalive; |
| 182 | + let bytes = msg.to_bytes()?; |
| 183 | + self.socket.send(&bytes).await?; |
| 184 | + |
| 185 | + *self.last_keepalive.lock().await = Instant::now(); |
| 186 | + Ok(()) |
| 187 | + } |
| 188 | + |
| 189 | + /// Disconnect from relay server |
| 190 | + /// |
| 191 | + /// # Errors |
| 192 | + /// |
| 193 | + /// Returns error if disconnect message fails to send. |
| 194 | + pub async fn disconnect(&mut self) -> Result<(), RelayError> { |
| 195 | + let msg = RelayMessage::Disconnect; |
| 196 | + let bytes = msg.to_bytes()?; |
| 197 | + self.socket.send(&bytes).await?; |
| 198 | + |
| 199 | + *self.state.lock().await = RelayClientState::Disconnected; |
| 200 | + Ok(()) |
| 201 | + } |
| 202 | + |
| 203 | + /// Get current client state |
| 204 | + #[must_use] |
| 205 | + pub async fn state(&self) -> RelayClientState { |
| 206 | + *self.state.lock().await |
| 207 | + } |
| 208 | + |
| 209 | + /// Get relay server address |
| 210 | + #[must_use] |
| 211 | + pub fn relay_addr(&self) -> SocketAddr { |
| 212 | + self.relay_addr |
| 213 | + } |
| 214 | + |
| 215 | + /// Start background message processing task |
| 216 | + /// |
| 217 | + /// This task receives messages from the relay and forwards them to the channel. |
| 218 | + pub fn spawn_receiver(&self) { |
| 219 | + let socket = self.socket.clone(); |
| 220 | + let tx = self.tx.clone(); |
| 221 | + let state = self.state.clone(); |
| 222 | + |
| 223 | + tokio::spawn(async move { |
| 224 | + let mut buf = vec![0u8; 65536]; |
| 225 | + |
| 226 | + loop { |
| 227 | + match socket.recv(&mut buf).await { |
| 228 | + Ok(len) => { |
| 229 | + if let Ok(msg) = RelayMessage::from_bytes(&buf[..len]) { |
| 230 | + match msg { |
| 231 | + RelayMessage::RecvPacket { src_id, payload } => { |
| 232 | + let _ = tx.send((src_id, payload)); |
| 233 | + } |
| 234 | + RelayMessage::PeerOnline { peer_id: _ } => { |
| 235 | + // Could notify application layer |
| 236 | + } |
| 237 | + RelayMessage::PeerOffline { peer_id: _ } => { |
| 238 | + // Could notify application layer |
| 239 | + } |
| 240 | + RelayMessage::Error { code, message: _ } => { |
| 241 | + eprintln!("Relay error: {:?}", code); |
| 242 | + *state.lock().await = RelayClientState::Error; |
| 243 | + } |
| 244 | + _ => { |
| 245 | + // Ignore other messages |
| 246 | + } |
| 247 | + } |
| 248 | + } |
| 249 | + } |
| 250 | + Err(e) => { |
| 251 | + eprintln!("Receive error: {}", e); |
| 252 | + *state.lock().await = RelayClientState::Error; |
| 253 | + break; |
| 254 | + } |
| 255 | + } |
| 256 | + } |
| 257 | + }); |
| 258 | + } |
| 259 | + |
| 260 | + /// Check if keepalive is needed and send if necessary |
| 261 | + /// |
| 262 | + /// # Errors |
| 263 | + /// |
| 264 | + /// Returns error if keepalive send fails. |
| 265 | + pub async fn maybe_keepalive(&self, interval: Duration) -> Result<(), RelayError> { |
| 266 | + let last = *self.last_keepalive.lock().await; |
| 267 | + if last.elapsed() >= interval { |
| 268 | + self.keepalive().await?; |
| 269 | + } |
| 270 | + Ok(()) |
| 271 | + } |
| 272 | +} |
| 273 | + |
| 274 | +#[cfg(test)] |
| 275 | +mod tests { |
| 276 | + use super::*; |
| 277 | + |
| 278 | + #[tokio::test] |
| 279 | + async fn test_relay_client_creation() { |
| 280 | + let node_id = [1u8; 32]; |
| 281 | + let addr = "127.0.0.1:8000".parse().unwrap(); |
| 282 | + |
| 283 | + let result = RelayClient::connect(addr, node_id).await; |
| 284 | + // May fail if relay not running, but constructor should succeed |
| 285 | + assert!(result.is_ok() || matches!(result, Err(RelayError::Io(_)))); |
| 286 | + } |
| 287 | + |
| 288 | + #[tokio::test] |
| 289 | + async fn test_relay_client_state() { |
| 290 | + let node_id = [1u8; 32]; |
| 291 | + let addr = "127.0.0.1:8001".parse().unwrap(); |
| 292 | + |
| 293 | + if let Ok(client) = RelayClient::connect(addr, node_id).await { |
| 294 | + let state = client.state().await; |
| 295 | + assert_eq!(state, RelayClientState::Connecting); |
| 296 | + } |
| 297 | + } |
| 298 | + |
| 299 | + #[tokio::test] |
| 300 | + async fn test_relay_client_relay_addr() { |
| 301 | + let node_id = [1u8; 32]; |
| 302 | + let addr: SocketAddr = "127.0.0.1:8002".parse().unwrap(); |
| 303 | + |
| 304 | + if let Ok(client) = RelayClient::connect(addr, node_id).await { |
| 305 | + assert_eq!(client.relay_addr(), addr); |
| 306 | + } |
| 307 | + } |
| 308 | + |
| 309 | + #[test] |
| 310 | + fn test_relay_client_state_transitions() { |
| 311 | + assert_eq!( |
| 312 | + RelayClientState::Disconnected, |
| 313 | + RelayClientState::Disconnected |
| 314 | + ); |
| 315 | + assert_ne!(RelayClientState::Connecting, RelayClientState::Connected); |
| 316 | + } |
| 317 | +} |
0 commit comments