diff --git a/Cargo.lock b/Cargo.lock index f9d328e9a35..7c58d1167df 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3240,7 +3240,7 @@ dependencies = [ [[package]] name = "libp2p-webrtc-utils" -version = "0.4.0" +version = "0.4.1" dependencies = [ "asynchronous-codec", "bytes", diff --git a/Cargo.toml b/Cargo.toml index 22bc98cd32c..6c4870f623a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -110,7 +110,7 @@ libp2p-tls = { version = "0.6.2", path = "transports/tls" } libp2p-uds = { version = "0.43.0", path = "transports/uds" } libp2p-upnp = { version = "0.5.0", path = "protocols/upnp" } libp2p-webrtc = { version = "0.9.0-alpha.1", path = "transports/webrtc" } -libp2p-webrtc-utils = { version = "0.4.0", path = "misc/webrtc-utils" } +libp2p-webrtc-utils = { version = "0.4.1", path = "misc/webrtc-utils" } libp2p-webrtc-websys = { version = "0.4.0", path = "transports/webrtc-websys" } libp2p-websocket = { version = "0.45.1", path = "transports/websocket" } libp2p-websocket-websys = { version = "0.5.0", path = "transports/websocket-websys" } diff --git a/misc/webrtc-utils/CHANGELOG.md b/misc/webrtc-utils/CHANGELOG.md index 992f8354bba..764c9b6a8ad 100644 --- a/misc/webrtc-utils/CHANGELOG.md +++ b/misc/webrtc-utils/CHANGELOG.md @@ -1,3 +1,8 @@ +## 0.4.1 + +- FIN_ACK support for WebRTC streams. + See [PR 6084](https://github.com/libp2p/rust-libp2p/pull/6084). + ## 0.4.0 diff --git a/misc/webrtc-utils/Cargo.toml b/misc/webrtc-utils/Cargo.toml index 8c6eaedd1e3..6b35965afc5 100644 --- a/misc/webrtc-utils/Cargo.toml +++ b/misc/webrtc-utils/Cargo.toml @@ -7,7 +7,7 @@ license = "MIT" name = "libp2p-webrtc-utils" repository = "https://github.com/libp2p/rust-libp2p" rust-version = { workspace = true } -version = "0.4.0" +version = "0.4.1" publish = true [dependencies] diff --git a/misc/webrtc-utils/src/generated/message.proto b/misc/webrtc-utils/src/generated/message.proto index eab3ceb720b..4c08b6f4e79 100644 --- a/misc/webrtc-utils/src/generated/message.proto +++ b/misc/webrtc-utils/src/generated/message.proto @@ -12,9 +12,13 @@ message Message { // The sender abruptly terminates the sending part of the stream. The // receiver can discard any data that it already received on that stream. RESET = 2; + // Sending the FIN_ACK flag acknowledges the previous receipt of a message + // with the FIN flag set. Receiving a FIN_ACK flag gives the recipient + // confidence that the remote has received all sent messages. + FIN_ACK = 3; } - optional Flag flag=1; + optional Flag flag = 1; optional bytes message = 2; } diff --git a/misc/webrtc-utils/src/generated/webrtc/pb.rs b/misc/webrtc-utils/src/generated/webrtc/pb.rs index 9e33e97188c..9e1e282ebd3 100644 --- a/misc/webrtc-utils/src/generated/webrtc/pb.rs +++ b/misc/webrtc-utils/src/generated/webrtc/pb.rs @@ -57,6 +57,7 @@ pub enum Flag { FIN = 0, STOP_SENDING = 1, RESET = 2, + FIN_ACK = 3, } impl Default for Flag { @@ -71,6 +72,7 @@ impl From for Flag { 0 => Flag::FIN, 1 => Flag::STOP_SENDING, 2 => Flag::RESET, + 3 => Flag::FIN_ACK, _ => Self::default(), } } @@ -82,6 +84,7 @@ impl<'a> From<&'a str> for Flag { "FIN" => Flag::FIN, "STOP_SENDING" => Flag::STOP_SENDING, "RESET" => Flag::RESET, + "FIN_ACK" => Flag::FIN_ACK, _ => Self::default(), } } diff --git a/misc/webrtc-utils/src/stream.rs b/misc/webrtc-utils/src/stream.rs index 0ec420a103a..88845e159a3 100644 --- a/misc/webrtc-utils/src/stream.rs +++ b/misc/webrtc-utils/src/stream.rs @@ -135,6 +135,17 @@ where return Poll::Ready(Ok(n)); } + // Check if we need to send a FIN_ACK + if self.state.needs_fin_ack() { + ready!(self.io.poll_ready_unpin(cx))?; + self.io.start_send_unpin(Message { + flag: Some(Flag::FIN_ACK), + message: None, + })?; + ready!(self.io.poll_flush_unpin(cx))?; + self.state.fin_ack_sent(); + } + let Self { read_buffer, io, @@ -177,6 +188,17 @@ where cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { + // Check if we need to send a FIN_ACK first + if self.state.needs_fin_ack() { + ready!(self.io.poll_ready_unpin(cx))?; + self.io.start_send_unpin(Message { + flag: Some(Flag::FIN_ACK), + message: None, + })?; + ready!(self.io.poll_flush_unpin(cx))?; + self.state.fin_ack_sent(); + } + while self.state.read_flags_in_async_write() { // TODO: In case AsyncRead::poll_read encountered an error or returned None earlier, we // will poll the underlying I/O resource once more. Is that allowed? How @@ -193,7 +215,7 @@ where Poll::Ready(Some((Some(flag), message))) => { // Read side is closed. Discard any incoming messages. drop(message); - // But still handle flags, e.g. a `Flag::StopSending`. + // But still handle flags, e.g. a `Flag::StopSending` or `Flag::FIN_ACK`. state.handle_inbound_flag(flag, read_buffer) } Poll::Ready(Some((None, message))) => drop(message), @@ -216,10 +238,32 @@ where } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + // Check if we need to send a FIN_ACK first + if self.state.needs_fin_ack() { + ready!(self.io.poll_ready_unpin(cx))?; + self.io.start_send_unpin(Message { + flag: Some(Flag::FIN_ACK), + message: None, + })?; + ready!(self.io.poll_flush_unpin(cx))?; + self.state.fin_ack_sent(); + } + self.io.poll_flush_unpin(cx).map_err(Into::into) } fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + // Check if we need to send a FIN_ACK first + if self.state.needs_fin_ack() { + ready!(self.io.poll_ready_unpin(cx))?; + self.io.start_send_unpin(Message { + flag: Some(Flag::FIN_ACK), + message: None, + })?; + ready!(self.io.poll_flush_unpin(cx))?; + self.state.fin_ack_sent(); + } + loop { match self.state.close_write_barrier()? { Some(Closing::Requested) => { diff --git a/misc/webrtc-utils/src/stream/state.rs b/misc/webrtc-utils/src/stream/state.rs index be44c061fd9..454ac1f4c1b 100644 --- a/misc/webrtc-utils/src/stream/state.rs +++ b/misc/webrtc-utils/src/stream/state.rs @@ -35,10 +35,20 @@ pub(crate) enum State { inner: Closing, }, ClosingWrite { - /// Whether the write side of our channel was already closed. + /// Whether the read side of our channel was already closed. read_closed: bool, inner: Closing, }, + /// We've received a FIN and need to send a FIN_ACK + ReadClosedNeedFinAck { + /// Whether the write side of our channel was already closed. + write_closed: bool, + }, + /// We've sent a FIN and are waiting for a FIN_ACK + WriteSentFinWaitingForAck { + /// Whether the read side of our channel was already closed. + read_closed: bool, + }, BothClosed { reset: bool, }, @@ -61,9 +71,19 @@ impl State { match (current, flag) { (Self::Open, Flag::FIN) => { - *self = Self::ReadClosed; + *self = Self::ReadClosedNeedFinAck { + write_closed: false, + }; } (Self::WriteClosed, Flag::FIN) => { + *self = Self::ReadClosedNeedFinAck { write_closed: true }; + } + (Self::WriteSentFinWaitingForAck { read_closed: false }, Flag::FIN) => { + *self = Self::ReadClosedNeedFinAck { + write_closed: false, + }; + } + (Self::WriteSentFinWaitingForAck { read_closed: true }, Flag::FIN) => { *self = Self::BothClosed { reset: false }; } (Self::Open, Flag::STOP_SENDING) => { @@ -72,6 +92,39 @@ impl State { (Self::ReadClosed, Flag::STOP_SENDING) => { *self = Self::BothClosed { reset: false }; } + ( + Self::ReadClosedNeedFinAck { + write_closed: false, + }, + Flag::STOP_SENDING, + ) => { + *self = Self::ReadClosedNeedFinAck { write_closed: true }; + } + (Self::ReadClosedNeedFinAck { write_closed: true }, Flag::STOP_SENDING) => { + // Already closed, ignore + } + (Self::WriteSentFinWaitingForAck { read_closed: _ }, Flag::STOP_SENDING) => { + *self = Self::WriteSentFinWaitingForAck { read_closed: true }; + } + (Self::WriteSentFinWaitingForAck { read_closed: false }, Flag::FIN_ACK) => { + *self = Self::WriteClosed; + } + (Self::WriteSentFinWaitingForAck { read_closed: true }, Flag::FIN_ACK) => { + *self = Self::BothClosed { reset: false }; + } + ( + Self::ClosingWrite { + read_closed, + inner: Closing::MessageSent, + }, + Flag::FIN_ACK, + ) => { + *self = if read_closed { + Self::BothClosed { reset: false } + } else { + Self::WriteClosed + }; + } (_, Flag::RESET) => { buffer.clear(); *self = Self::BothClosed { reset: true }; @@ -88,7 +141,7 @@ impl State { } => { debug_assert!(matches!(inner, Closing::MessageSent)); - *self = State::BothClosed { reset: false }; + *self = State::WriteSentFinWaitingForAck { read_closed: true }; } State::ClosingWrite { read_closed: false, @@ -96,7 +149,13 @@ impl State { } => { debug_assert!(matches!(inner, Closing::MessageSent)); - *self = State::WriteClosed; + *self = State::WriteSentFinWaitingForAck { read_closed: false }; + } + State::ReadClosedNeedFinAck { .. } => { + unreachable!("write_closed called on ReadClosedNeedFinAck state") + } + State::WriteSentFinWaitingForAck { .. } => { + unreachable!("write_closed called on WriteSentFinWaitingForAck state") } State::Open | State::ReadClosed @@ -118,6 +177,12 @@ impl State { inner: Closing::MessageSent, }; } + State::ReadClosedNeedFinAck { .. } => { + unreachable!("close_write_message_sent called on ReadClosedNeedFinAck state") + } + State::WriteSentFinWaitingForAck { .. } => { + unreachable!("close_write_message_sent called on WriteSentFinWaitingForAck state") + } State::Open | State::ReadClosed | State::WriteClosed @@ -146,6 +211,12 @@ impl State { *self = State::ReadClosed; } + State::ReadClosedNeedFinAck { .. } => { + unreachable!("read_closed called on ReadClosedNeedFinAck state") + } + State::WriteSentFinWaitingForAck { .. } => { + unreachable!("read_closed called on WriteSentFinWaitingForAck state") + } State::Open | State::ReadClosed | State::WriteClosed @@ -169,6 +240,12 @@ impl State { inner: Closing::MessageSent, }; } + State::ReadClosedNeedFinAck { .. } => { + unreachable!("close_read_message_sent called on ReadClosedNeedFinAck state") + } + State::WriteSentFinWaitingForAck { .. } => { + unreachable!("close_read_message_sent called on WriteSentFinWaitingForAck state") + } State::Open | State::ReadClosed | State::WriteClosed @@ -184,7 +261,12 @@ impl State { /// This is necessary for read-closed streams because we would otherwise /// not read any more flags from the socket. pub(crate) fn read_flags_in_async_write(&self) -> bool { - matches!(self, Self::ReadClosed) + matches!( + self, + Self::ReadClosed + | Self::ReadClosedNeedFinAck { .. } + | Self::WriteSentFinWaitingForAck { .. } + ) } /// Acts as a "barrier" for [`futures::AsyncRead::poll_read`]. @@ -196,11 +278,14 @@ impl State { | WriteClosed | ClosingWrite { read_closed: false, .. - } => return Ok(()), + } + | WriteSentFinWaitingForAck { read_closed: false } => return Ok(()), ClosingWrite { read_closed: true, .. } | ReadClosed + | ReadClosedNeedFinAck { .. } + | WriteSentFinWaitingForAck { read_closed: true } | ClosingRead { .. } | BothClosed { reset: false } => io::ErrorKind::BrokenPipe, BothClosed { reset: true } => io::ErrorKind::ConnectionReset, @@ -216,6 +301,9 @@ impl State { let kind = match self { Open | ReadClosed + | ReadClosedNeedFinAck { + write_closed: false, + } | ClosingRead { write_closed: false, .. @@ -224,6 +312,8 @@ impl State { write_closed: true, .. } | WriteClosed + | ReadClosedNeedFinAck { write_closed: true } + | WriteSentFinWaitingForAck { .. } | ClosingWrite { .. } | BothClosed { reset: false } => io::ErrorKind::BrokenPipe, BothClosed { reset: true } => io::ErrorKind::ConnectionReset, @@ -240,6 +330,10 @@ impl State { State::ClosingWrite { inner, .. } => return Ok(Some(*inner)), + State::WriteSentFinWaitingForAck { .. } => { + return Err(io::Error::other("waiting for FIN_ACK before closing")) + } + State::Open => { *self = Self::ClosingWrite { read_closed: false, @@ -252,6 +346,17 @@ impl State { inner: Closing::Requested, }; } + State::ReadClosedNeedFinAck { + write_closed: false, + } => { + *self = Self::ClosingWrite { + read_closed: true, + inner: Closing::Requested, + }; + } + State::ReadClosedNeedFinAck { write_closed: true } => { + return Err(io::ErrorKind::BrokenPipe.into()) + } State::ClosingRead { write_closed: true, .. @@ -284,6 +389,8 @@ impl State { State::ClosingRead { inner, .. } => return Ok(Some(*inner)), + State::ReadClosedNeedFinAck { .. } => return Ok(None), + State::Open => { *self = Self::ClosingRead { write_closed: false, @@ -296,6 +403,15 @@ impl State { inner: Closing::Requested, }; } + State::WriteSentFinWaitingForAck { read_closed: false } => { + *self = Self::ClosingRead { + write_closed: false, + inner: Closing::Requested, + }; + } + State::WriteSentFinWaitingForAck { read_closed: true } => { + return Err(io::ErrorKind::BrokenPipe.into()) + } State::ClosingWrite { read_closed: true, .. @@ -318,6 +434,30 @@ impl State { } } } + + /// Returns whether the state requires sending a FIN_ACK. + /// This should be called by the stream implementation to check if it needs to send a FIN_ACK. + pub(crate) fn needs_fin_ack(&self) -> bool { + matches!(self, Self::ReadClosedNeedFinAck { .. }) + } + + /// Marks that a FIN_ACK has been sent in response to a received FIN. + /// This transitions from ReadClosedNeedFinAck to ReadClosed. + pub(crate) fn fin_ack_sent(&mut self) { + match self { + State::ReadClosedNeedFinAck { + write_closed: false, + } => { + *self = State::ReadClosed; + } + State::ReadClosedNeedFinAck { write_closed: true } => { + *self = State::BothClosed { reset: false }; + } + _ => { + unreachable!("fin_ack_sent called on wrong state") + } + } + } } #[cfg(test)] @@ -331,6 +471,8 @@ mod tests { let mut open = State::Open; open.handle_inbound_flag(Flag::FIN, &mut Bytes::default()); + // After receiving FIN, we're in ReadClosedNeedFinAck state but read barrier should still + // prevent reading let error = open.read_barrier().unwrap_err(); assert_eq!(error.kind(), ErrorKind::BrokenPipe) @@ -477,8 +619,18 @@ mod tests { open.close_write_message_sent(); open.write_closed(); + // After write_closed(), we're waiting for FIN_ACK, so close_write_barrier should return + // error + let result = open.close_write_barrier(); + assert!(result.is_err()); + assert_eq!( + result.unwrap_err().to_string(), + "waiting for FIN_ACK before closing" + ); + + // After receiving FIN_ACK, we should be in WriteClosed state + open.handle_inbound_flag(Flag::FIN_ACK, &mut Bytes::default()); let maybe = open.close_write_barrier().unwrap(); - assert!(maybe.is_none()) } @@ -504,4 +656,180 @@ mod tests { assert!(buffer.is_empty()); } + + #[test] + fn fin_requires_fin_ack() { + let mut open = State::Open; + + open.handle_inbound_flag(Flag::FIN, &mut Bytes::default()); + + assert!(open.needs_fin_ack()); + assert!(matches!( + open, + State::ReadClosedNeedFinAck { + write_closed: false + } + )); + } + + #[test] + fn fin_ack_sent_transitions_to_read_closed() { + let mut state = State::ReadClosedNeedFinAck { + write_closed: false, + }; + + state.fin_ack_sent(); + + assert!(!state.needs_fin_ack()); + assert!(matches!(state, State::ReadClosed)); + } + + #[test] + fn fin_ack_sent_with_write_closed_transitions_to_both_closed() { + let mut state = State::ReadClosedNeedFinAck { write_closed: true }; + + state.fin_ack_sent(); + + assert!(!state.needs_fin_ack()); + assert!(matches!(state, State::BothClosed { reset: false })); + } + + #[test] + fn fin_ack_completes_write_close() { + let mut state = State::WriteSentFinWaitingForAck { read_closed: false }; + + state.handle_inbound_flag(Flag::FIN_ACK, &mut Bytes::default()); + + assert!(matches!(state, State::WriteClosed)); + } + + #[test] + fn fin_ack_with_read_closed_transitions_to_both_closed() { + let mut state = State::WriteSentFinWaitingForAck { read_closed: true }; + + state.handle_inbound_flag(Flag::FIN_ACK, &mut Bytes::default()); + + assert!(matches!(state, State::BothClosed { reset: false })); + } + + #[test] + fn simultaneous_fin_exchange() { + let mut state = State::WriteSentFinWaitingForAck { read_closed: false }; + + // Receive FIN while waiting for FIN_ACK + state.handle_inbound_flag(Flag::FIN, &mut Bytes::default()); + + assert!(state.needs_fin_ack()); + assert!(matches!( + state, + State::ReadClosedNeedFinAck { + write_closed: false + } + )); + + // Send FIN_ACK + state.fin_ack_sent(); + + assert!(matches!(state, State::ReadClosed)); + } + + #[test] + fn write_close_waits_for_fin_ack() { + let mut state = State::WriteSentFinWaitingForAck { read_closed: false }; + + let result = state.close_write_barrier(); + + assert!(result.is_err()); + assert_eq!( + result.unwrap_err().to_string(), + "waiting for FIN_ACK before closing" + ); + } + + #[test] + fn read_flags_in_async_write_with_fin_ack_states() { + let state_need_ack = State::ReadClosedNeedFinAck { + write_closed: false, + }; + let state_waiting_ack = State::WriteSentFinWaitingForAck { read_closed: false }; + + assert!(state_need_ack.read_flags_in_async_write()); + assert!(state_waiting_ack.read_flags_in_async_write()); + } + + #[test] + fn complete_fin_ack_handshake_example() { + // This test demonstrates the complete FIN_ACK handshake as described in the spec: + // NodeA closes for writing, NodeB delays allowing the channel to close until it + // also finishes writing. + + let mut node_a = State::Open; + let mut node_b = State::Open; + + // NodeA wants to close for writing + node_a.close_write_barrier().unwrap(); + node_a.close_write_message_sent(); + node_a.write_closed(); + + // NodeA is now waiting for FIN_ACK + assert!(matches!( + node_a, + State::WriteSentFinWaitingForAck { read_closed: false } + )); + + // NodeB receives the FIN from NodeA + node_b.handle_inbound_flag(Flag::FIN, &mut Bytes::default()); + + // NodeB should now need to send a FIN_ACK + assert!(node_b.needs_fin_ack()); + assert!(matches!( + node_b, + State::ReadClosedNeedFinAck { + write_closed: false + } + )); + + // NodeB sends FIN_ACK (simulated by calling fin_ack_sent) + node_b.fin_ack_sent(); + assert!(matches!(node_b, State::ReadClosed)); + + // NodeA receives the FIN_ACK + node_a.handle_inbound_flag(Flag::FIN_ACK, &mut Bytes::default()); + + // NodeA's write side is now closed + assert!(matches!(node_a, State::WriteClosed)); + + // NodeB also wants to close for writing + node_b.close_write_barrier().unwrap(); + node_b.close_write_message_sent(); + node_b.write_closed(); + + // NodeB is now waiting for FIN_ACK + assert!(matches!( + node_b, + State::WriteSentFinWaitingForAck { read_closed: true } + )); + + // NodeA receives the FIN from NodeB + node_a.handle_inbound_flag(Flag::FIN, &mut Bytes::default()); + + // NodeA should now need to send a FIN_ACK + assert!(node_a.needs_fin_ack()); + assert!(matches!( + node_a, + State::ReadClosedNeedFinAck { write_closed: true } + )); + + // NodeA sends FIN_ACK + node_a.fin_ack_sent(); + assert!(matches!(node_a, State::BothClosed { reset: false })); + + // NodeB receives the FIN_ACK + node_b.handle_inbound_flag(Flag::FIN_ACK, &mut Bytes::default()); + + // NodeB is now fully closed + assert!(matches!(node_b, State::BothClosed { reset: false })); + + // Both nodes have successfully closed the channel without data loss + } }