From f0fe9a61db1a357c7f91608d4d042bca706f9b65 Mon Sep 17 00:00:00 2001 From: xmh0511 <970252187@qq.com> Date: Wed, 6 Dec 2023 14:29:12 +0800 Subject: [PATCH 1/4] simplify the io copy for dispatch stream --- clash_lib/src/app/dispatcher/dispatcher.rs | 5 +- clash_lib/src/common/io.rs | 128 ++++++++++++++++++--- 2 files changed, 113 insertions(+), 20 deletions(-) diff --git a/clash_lib/src/app/dispatcher/dispatcher.rs b/clash_lib/src/app/dispatcher/dispatcher.rs index 3657a1812..b055eb264 100644 --- a/clash_lib/src/app/dispatcher/dispatcher.rs +++ b/clash_lib/src/app/dispatcher/dispatcher.rs @@ -11,6 +11,7 @@ use crate::proxy::AnyInboundDatagram; use crate::session::Session; use futures::SinkExt; use futures::StreamExt; +use tokio::io::AsyncReadExt; use std::collections::HashMap; use std::fmt::{Debug, Formatter}; use std::net::SocketAddr; @@ -136,8 +137,8 @@ impl Dispatcher { let mut rhs = TrackedStream::new(rhs, self.manager.clone(), sess.clone(), rule).await; match copy_buf_bidirectional_with_timeout( - &mut lhs, - &mut rhs, + lhs, + rhs, 4096, Duration::from_secs(10), Duration::from_secs(10), diff --git a/clash_lib/src/common/io.rs b/clash_lib/src/common/io.rs index b9e5fa59b..51df4647f 100644 --- a/clash_lib/src/common/io.rs +++ b/clash_lib/src/common/io.rs @@ -6,7 +6,7 @@ use std::task::{Context, Poll}; use std::time::Duration; use futures::ready; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf}; #[derive(Debug)] pub struct CopyBuffer { @@ -264,28 +264,120 @@ where } } +// pub async fn copy_buf_bidirectional_with_timeout( +// a: &mut A, +// b: &mut B, +// size: usize, +// a_to_b_timeout_duration: Duration, +// b_to_a_timeout_duration: Duration, +// ) -> Result<(u64, u64), std::io::Error> +// where +// A: AsyncRead + AsyncWrite + Unpin + ?Sized, +// B: AsyncRead + AsyncWrite + Unpin + ?Sized, +// { +// CopyBidirectional { +// a, +// b, +// a_to_b: TransferState::Running(CopyBuffer::new_with_capacity(size)?), +// b_to_a: TransferState::Running(CopyBuffer::new_with_capacity(size)?), +// a_to_b_count: 0, +// b_to_a_count: 0, +// a_to_b_delay: None, +// b_to_a_delay: None, +// a_to_b_timeout_duration, +// b_to_a_timeout_duration, +// } +// .await +// } + pub async fn copy_buf_bidirectional_with_timeout( - a: &mut A, - b: &mut B, + a: A, + b: B, size: usize, a_to_b_timeout_duration: Duration, b_to_a_timeout_duration: Duration, ) -> Result<(u64, u64), std::io::Error> where - A: AsyncRead + AsyncWrite + Unpin + ?Sized, - B: AsyncRead + AsyncWrite + Unpin + ?Sized, + A: AsyncRead + AsyncWrite + Unpin + Sized, + B: AsyncRead + AsyncWrite + Unpin + Sized, { - CopyBidirectional { - a, - b, - a_to_b: TransferState::Running(CopyBuffer::new_with_capacity(size)?), - b_to_a: TransferState::Running(CopyBuffer::new_with_capacity(size)?), - a_to_b_count: 0, - b_to_a_count: 0, - a_to_b_delay: None, - b_to_a_delay: None, - a_to_b_timeout_duration, - b_to_a_timeout_duration, - } - .await + let (mut a_reader, mut a_writer) = tokio::io::split(a); + let (mut b_reader, mut b_writer) = tokio::io::split(b); + let a_to_b = tokio::spawn(async move { + let mut upload_total_size = 0; + let mut buf = Vec::new(); + buf.resize(size, 0); + loop { + match tokio::time::timeout(a_to_b_timeout_duration, a_reader.read(&mut buf[..])).await { + Ok(Ok(size)) => { + if size == 0 { + return Ok(upload_total_size); + } + match b_writer.write_all(&buf[..size]).await { + Ok(_) => { + upload_total_size += size; + continue; + } + Err(e) => { + return Err(e); + } + } + } + Ok(Err(e)) => { + return Err(e); + } + Err(_) => { + return Ok(upload_total_size); + } + } + } + }); + let b_to_a = tokio::spawn(async move { + let download_total_size = 0; + let mut buf = Vec::new(); + buf.resize(size, 0); + loop { + match tokio::time::timeout(b_to_a_timeout_duration, b_reader.read(&mut buf[..])).await { + Ok(Ok(size)) => { + if size == 0 { + return Ok(download_total_size); + } + match a_writer.write_all(&buf[..size]).await { + Ok(_) => { + download_total_size += size; + continue; + } + Err(e) => { + return Err(e); + } + } + } + Ok(Err(e)) => { + return Err(e); + } + Err(_) => { + return Ok(download_total_size); + } + } + } + }); + let up = match a_to_b.await { + Ok(Ok(up)) => up, + Ok(Err(e)) => { + return Err(e); + } + Err(e) => { + return Err(std::io::Error::new(std::io::ErrorKind::Other, "")); + } + }; + let down = match b_to_a.await { + Ok(Ok(up)) => up, + Ok(Err(e)) => { + return Err(e); + } + Err(e) => { + return Err(std::io::Error::new(std::io::ErrorKind::Other, "")); + } + }; + Ok((up as u64, down as u64)) } From 452af1ffa46a3767e5bcd766aa523ae4a72f8f15 Mon Sep 17 00:00:00 2001 From: xmh0511 <970252187@qq.com> Date: Wed, 6 Dec 2023 15:13:33 +0800 Subject: [PATCH 2/4] fix some errors --- clash_lib/src/app/dispatcher/dispatcher.rs | 6 +- clash_lib/src/common/io.rs | 494 ++++++++++---------- clash_lib/src/proxy/mixed/mod.rs | 4 +- clash_lib/src/proxy/socks/inbound/mod.rs | 4 +- clash_lib/src/proxy/socks/inbound/stream.rs | 4 +- 5 files changed, 256 insertions(+), 256 deletions(-) diff --git a/clash_lib/src/app/dispatcher/dispatcher.rs b/clash_lib/src/app/dispatcher/dispatcher.rs index b055eb264..f09b5129d 100644 --- a/clash_lib/src/app/dispatcher/dispatcher.rs +++ b/clash_lib/src/app/dispatcher/dispatcher.rs @@ -11,7 +11,7 @@ use crate::proxy::AnyInboundDatagram; use crate::session::Session; use futures::SinkExt; use futures::StreamExt; -use tokio::io::AsyncReadExt; + use std::collections::HashMap; use std::fmt::{Debug, Formatter}; use std::net::SocketAddr; @@ -78,7 +78,7 @@ impl Dispatcher { #[instrument(skip(lhs))] pub async fn dispatch_stream(&self, sess: Session, mut lhs: S) where - S: AsyncRead + AsyncWrite + Unpin + Send, + S: AsyncRead + AsyncWrite + Unpin + Send +'static, { let sess = if self.resolver.fake_ip_enabled() { match sess.destination { @@ -134,7 +134,7 @@ impl Dispatcher { { Ok(rhs) => { debug!("remote connection established {}", sess); - let mut rhs = + let rhs = TrackedStream::new(rhs, self.manager.clone(), sess.clone(), rule).await; match copy_buf_bidirectional_with_timeout( lhs, diff --git a/clash_lib/src/common/io.rs b/clash_lib/src/common/io.rs index 51df4647f..c8d01317e 100644 --- a/clash_lib/src/common/io.rs +++ b/clash_lib/src/common/io.rs @@ -1,268 +1,268 @@ /// copy of https://github.com/eycorsican/leaf/blob/a77a1e497ae034f3a2a89c8628d5e7ebb2af47f0/leaf/src/common/io.rs -use std::future::Future; -use std::io; -use std::pin::Pin; -use std::task::{Context, Poll}; +// use std::future::Future; +// use std::io; +// use std::pin::Pin; +// use std::task::{Context, Poll}; use std::time::Duration; -use futures::ready; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf}; +// use futures::ready; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -#[derive(Debug)] -pub struct CopyBuffer { - read_done: bool, - need_flush: bool, - pos: usize, - cap: usize, - amt: u64, - buf: Box<[u8]>, -} +// #[derive(Debug)] +// pub struct CopyBuffer { +// read_done: bool, +// need_flush: bool, +// pos: usize, +// cap: usize, +// amt: u64, +// buf: Box<[u8]>, +// } -impl CopyBuffer { - #[allow(unused)] - pub fn new() -> Self { - Self { - read_done: false, - need_flush: false, - pos: 0, - cap: 0, - amt: 0, - buf: vec![0; 2 * 1024].into_boxed_slice(), - } - } +// impl CopyBuffer { +// #[allow(unused)] +// pub fn new() -> Self { +// Self { +// read_done: false, +// need_flush: false, +// pos: 0, +// cap: 0, +// amt: 0, +// buf: vec![0; 2 * 1024].into_boxed_slice(), +// } +// } - pub fn new_with_capacity(size: usize) -> Result { - let mut buf = Vec::new(); - buf.try_reserve(size).map_err(|e| { - std::io::Error::new( - std::io::ErrorKind::Other, - format!("new buffer failed: {}", e), - ) - })?; - buf.resize(size, 0); - Ok(Self { - read_done: false, - need_flush: false, - pos: 0, - cap: 0, - amt: 0, - buf: buf.into_boxed_slice(), - }) - } +// pub fn new_with_capacity(size: usize) -> Result { +// let mut buf = Vec::new(); +// buf.try_reserve(size).map_err(|e| { +// std::io::Error::new( +// std::io::ErrorKind::Other, +// format!("new buffer failed: {}", e), +// ) +// })?; +// buf.resize(size, 0); +// Ok(Self { +// read_done: false, +// need_flush: false, +// pos: 0, +// cap: 0, +// amt: 0, +// buf: buf.into_boxed_slice(), +// }) +// } - pub fn amount_transfered(&self) -> u64 { - self.amt - } +// pub fn amount_transfered(&self) -> u64 { +// self.amt +// } - pub fn poll_copy( - &mut self, - cx: &mut Context<'_>, - mut reader: Pin<&mut R>, - mut writer: Pin<&mut W>, - ) -> Poll> - where - R: AsyncRead + ?Sized, - W: AsyncWrite + ?Sized, - { - loop { - // If our buffer is empty, then we need to read some data to - // continue. - if self.pos == self.cap && !self.read_done { - let me = &mut *self; - let mut buf = ReadBuf::new(&mut me.buf); +// pub fn poll_copy( +// &mut self, +// cx: &mut Context<'_>, +// mut reader: Pin<&mut R>, +// mut writer: Pin<&mut W>, +// ) -> Poll> +// where +// R: AsyncRead + ?Sized, +// W: AsyncWrite + ?Sized, +// { +// loop { +// // If our buffer is empty, then we need to read some data to +// // continue. +// if self.pos == self.cap && !self.read_done { +// let me = &mut *self; +// let mut buf = ReadBuf::new(&mut me.buf); - match reader.as_mut().poll_read(cx, &mut buf) { - Poll::Ready(Ok(_)) => (), - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), - Poll::Pending => { - // Try flushing when the reader has no progress to avoid deadlock - // when the reader depends on buffered writer. - if self.need_flush { - ready!(writer.as_mut().poll_flush(cx))?; - self.need_flush = false; - } +// match reader.as_mut().poll_read(cx, &mut buf) { +// Poll::Ready(Ok(_)) => (), +// Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), +// Poll::Pending => { +// // Try flushing when the reader has no progress to avoid deadlock +// // when the reader depends on buffered writer. +// if self.need_flush { +// ready!(writer.as_mut().poll_flush(cx))?; +// self.need_flush = false; +// } - return Poll::Pending; - } - } +// return Poll::Pending; +// } +// } - let n = buf.filled().len(); - if n == 0 { - self.read_done = true; - } else { - self.pos = 0; - self.cap = n; - } - } +// let n = buf.filled().len(); +// if n == 0 { +// self.read_done = true; +// } else { +// self.pos = 0; +// self.cap = n; +// } +// } - // If our buffer has some data, let's write it out! - while self.pos < self.cap { - let me = &mut *self; - let i = ready!(writer.as_mut().poll_write(cx, &me.buf[me.pos..me.cap]))?; - if i == 0 { - return Poll::Ready(Err(io::Error::new( - io::ErrorKind::WriteZero, - "write zero byte into writer", - ))); - } else { - self.pos += i; - self.amt += i as u64; - self.need_flush = true; - } - } +// // If our buffer has some data, let's write it out! +// while self.pos < self.cap { +// let me = &mut *self; +// let i = ready!(writer.as_mut().poll_write(cx, &me.buf[me.pos..me.cap]))?; +// if i == 0 { +// return Poll::Ready(Err(io::Error::new( +// io::ErrorKind::WriteZero, +// "write zero byte into writer", +// ))); +// } else { +// self.pos += i; +// self.amt += i as u64; +// self.need_flush = true; +// } +// } - // If pos larger than cap, this loop will never stop. - // In particular, user's wrong poll_write implementation returning - // incorrect written length may lead to thread blocking. - debug_assert!( - self.pos <= self.cap, - "writer returned length larger than input slice" - ); +// // If pos larger than cap, this loop will never stop. +// // In particular, user's wrong poll_write implementation returning +// // incorrect written length may lead to thread blocking. +// debug_assert!( +// self.pos <= self.cap, +// "writer returned length larger than input slice" +// ); - // If we've written all the data and we've seen EOF, flush out the - // data and finish the transfer. - if self.pos == self.cap && self.read_done { - ready!(writer.as_mut().poll_flush(cx))?; - return Poll::Ready(Ok(self.amt)); - } - } - } -} +// // If we've written all the data and we've seen EOF, flush out the +// // data and finish the transfer. +// if self.pos == self.cap && self.read_done { +// ready!(writer.as_mut().poll_flush(cx))?; +// return Poll::Ready(Ok(self.amt)); +// } +// } +// } +// } -enum TransferState { - Running(CopyBuffer), - ShuttingDown(u64), - Done, -} +// enum TransferState { +// Running(CopyBuffer), +// ShuttingDown(u64), +// Done, +// } -struct CopyBidirectional<'a, A: ?Sized, B: ?Sized> { - a: &'a mut A, - b: &'a mut B, - a_to_b: TransferState, - b_to_a: TransferState, - a_to_b_count: u64, - b_to_a_count: u64, - a_to_b_delay: Option>>, - b_to_a_delay: Option>>, - a_to_b_timeout_duration: Duration, - b_to_a_timeout_duration: Duration, -} +// struct CopyBidirectional<'a, A: ?Sized, B: ?Sized> { +// a: &'a mut A, +// b: &'a mut B, +// a_to_b: TransferState, +// b_to_a: TransferState, +// a_to_b_count: u64, +// b_to_a_count: u64, +// a_to_b_delay: Option>>, +// b_to_a_delay: Option>>, +// a_to_b_timeout_duration: Duration, +// b_to_a_timeout_duration: Duration, +// } -impl<'a, A, B> Future for CopyBidirectional<'a, A, B> -where - A: AsyncRead + AsyncWrite + Unpin + ?Sized, - B: AsyncRead + AsyncWrite + Unpin + ?Sized, -{ - type Output = io::Result<(u64, u64)>; +// impl<'a, A, B> Future for CopyBidirectional<'a, A, B> +// where +// A: AsyncRead + AsyncWrite + Unpin + ?Sized, +// B: AsyncRead + AsyncWrite + Unpin + ?Sized, +// { +// type Output = io::Result<(u64, u64)>; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - // Unpack self into mut refs to each field to avoid borrow check issues. - let CopyBidirectional { - a, - b, - a_to_b, - b_to_a, - a_to_b_count, - b_to_a_count, - a_to_b_delay, - b_to_a_delay, - a_to_b_timeout_duration, - b_to_a_timeout_duration, - } = &mut *self; +// fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { +// // Unpack self into mut refs to each field to avoid borrow check issues. +// let CopyBidirectional { +// a, +// b, +// a_to_b, +// b_to_a, +// a_to_b_count, +// b_to_a_count, +// a_to_b_delay, +// b_to_a_delay, +// a_to_b_timeout_duration, +// b_to_a_timeout_duration, +// } = &mut *self; - let mut a = Pin::new(a); - let mut b = Pin::new(b); +// let mut a = Pin::new(a); +// let mut b = Pin::new(b); - loop { - match a_to_b { - TransferState::Running(buf) => { - let res = buf.poll_copy(cx, a.as_mut(), b.as_mut()); - match res { - Poll::Ready(Ok(count)) => { - *a_to_b = TransferState::ShuttingDown(count); - continue; - } - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), - Poll::Pending => { - if let Some(delay) = a_to_b_delay { - match delay.as_mut().poll(cx) { - Poll::Ready(()) => { - *a_to_b = - TransferState::ShuttingDown(buf.amount_transfered()); - continue; - } - Poll::Pending => (), - } - } - } - } - } - TransferState::ShuttingDown(count) => { - let res = b.as_mut().poll_shutdown(cx); - match res { - Poll::Ready(Ok(())) => { - *a_to_b_count += *count; - *a_to_b = TransferState::Done; - b_to_a_delay - .replace(Box::pin(tokio::time::sleep(*b_to_a_timeout_duration))); - continue; - } - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), - Poll::Pending => (), - } - } - TransferState::Done => (), - } +// loop { +// match a_to_b { +// TransferState::Running(buf) => { +// let res = buf.poll_copy(cx, a.as_mut(), b.as_mut()); +// match res { +// Poll::Ready(Ok(count)) => { +// *a_to_b = TransferState::ShuttingDown(count); +// continue; +// } +// Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), +// Poll::Pending => { +// if let Some(delay) = a_to_b_delay { +// match delay.as_mut().poll(cx) { +// Poll::Ready(()) => { +// *a_to_b = +// TransferState::ShuttingDown(buf.amount_transfered()); +// continue; +// } +// Poll::Pending => (), +// } +// } +// } +// } +// } +// TransferState::ShuttingDown(count) => { +// let res = b.as_mut().poll_shutdown(cx); +// match res { +// Poll::Ready(Ok(())) => { +// *a_to_b_count += *count; +// *a_to_b = TransferState::Done; +// b_to_a_delay +// .replace(Box::pin(tokio::time::sleep(*b_to_a_timeout_duration))); +// continue; +// } +// Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), +// Poll::Pending => (), +// } +// } +// TransferState::Done => (), +// } - match b_to_a { - TransferState::Running(buf) => { - let res = buf.poll_copy(cx, b.as_mut(), a.as_mut()); - match res { - Poll::Ready(Ok(count)) => { - *b_to_a = TransferState::ShuttingDown(count); - continue; - } - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), - Poll::Pending => { - if let Some(delay) = b_to_a_delay { - match delay.as_mut().poll(cx) { - Poll::Ready(()) => { - *b_to_a = - TransferState::ShuttingDown(buf.amount_transfered()); - continue; - } - Poll::Pending => (), - } - } - } - } - } - TransferState::ShuttingDown(count) => { - let res = a.as_mut().poll_shutdown(cx); - match res { - Poll::Ready(Ok(())) => { - *b_to_a_count += *count; - *b_to_a = TransferState::Done; - a_to_b_delay - .replace(Box::pin(tokio::time::sleep(*a_to_b_timeout_duration))); - continue; - } - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), - Poll::Pending => (), - } - } - TransferState::Done => (), - } +// match b_to_a { +// TransferState::Running(buf) => { +// let res = buf.poll_copy(cx, b.as_mut(), a.as_mut()); +// match res { +// Poll::Ready(Ok(count)) => { +// *b_to_a = TransferState::ShuttingDown(count); +// continue; +// } +// Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), +// Poll::Pending => { +// if let Some(delay) = b_to_a_delay { +// match delay.as_mut().poll(cx) { +// Poll::Ready(()) => { +// *b_to_a = +// TransferState::ShuttingDown(buf.amount_transfered()); +// continue; +// } +// Poll::Pending => (), +// } +// } +// } +// } +// } +// TransferState::ShuttingDown(count) => { +// let res = a.as_mut().poll_shutdown(cx); +// match res { +// Poll::Ready(Ok(())) => { +// *b_to_a_count += *count; +// *b_to_a = TransferState::Done; +// a_to_b_delay +// .replace(Box::pin(tokio::time::sleep(*a_to_b_timeout_duration))); +// continue; +// } +// Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), +// Poll::Pending => (), +// } +// } +// TransferState::Done => (), +// } - match (&a_to_b, &b_to_a) { - (TransferState::Done, TransferState::Done) => break, - _ => return Poll::Pending, - } - } +// match (&a_to_b, &b_to_a) { +// (TransferState::Done, TransferState::Done) => break, +// _ => return Poll::Pending, +// } +// } - Poll::Ready(Ok((*a_to_b_count, *b_to_a_count))) - } -} +// Poll::Ready(Ok((*a_to_b_count, *b_to_a_count))) +// } +// } // pub async fn copy_buf_bidirectional_with_timeout( // a: &mut A, @@ -298,8 +298,8 @@ pub async fn copy_buf_bidirectional_with_timeout( b_to_a_timeout_duration: Duration, ) -> Result<(u64, u64), std::io::Error> where - A: AsyncRead + AsyncWrite + Unpin + Sized, - B: AsyncRead + AsyncWrite + Unpin + Sized, + A: AsyncRead + AsyncWrite + Unpin + Sized + Send + 'static, + B: AsyncRead + AsyncWrite + Unpin + Sized + Send + 'static, { let (mut a_reader, mut a_writer) = tokio::io::split(a); let (mut b_reader, mut b_writer) = tokio::io::split(b); @@ -333,7 +333,7 @@ where } }); let b_to_a = tokio::spawn(async move { - let download_total_size = 0; + let mut download_total_size = 0; let mut buf = Vec::new(); buf.resize(size, 0); loop { @@ -366,7 +366,7 @@ where Ok(Err(e)) => { return Err(e); } - Err(e) => { + Err(_) => { return Err(std::io::Error::new(std::io::ErrorKind::Other, "")); } }; @@ -375,7 +375,7 @@ where Ok(Err(e)) => { return Err(e); } - Err(e) => { + Err(_) => { return Err(std::io::Error::new(std::io::ErrorKind::Other, "")); } }; diff --git a/clash_lib/src/proxy/mixed/mod.rs b/clash_lib/src/proxy/mixed/mod.rs index 9cd2326cc..949f7cd7a 100644 --- a/clash_lib/src/proxy/mixed/mod.rs +++ b/clash_lib/src/proxy/mixed/mod.rs @@ -53,7 +53,7 @@ impl InboundListener for Listener { loop { let (socket, _) = listener.accept().await?; - let mut socket = apply_tcp_options(socket)?; + let socket = apply_tcp_options(socket)?; let mut p = [0; 1]; let n = socket.peek(&mut p).await?; @@ -75,7 +75,7 @@ impl InboundListener for Listener { }; tokio::spawn(async move { - socks::handle_tcp(&mut sess, &mut socket, dispatcher, authenticator).await + socks::handle_tcp(&mut sess, socket, dispatcher, authenticator).await }); } diff --git a/clash_lib/src/proxy/socks/inbound/mod.rs b/clash_lib/src/proxy/socks/inbound/mod.rs index 4b37c807f..78faeb372 100644 --- a/clash_lib/src/proxy/socks/inbound/mod.rs +++ b/clash_lib/src/proxy/socks/inbound/mod.rs @@ -83,7 +83,7 @@ impl InboundListener for Listener { loop { let (socket, _) = listener.accept().await?; - let mut socket = apply_tcp_options(socket)?; + let socket = apply_tcp_options(socket)?; let mut sess = Session { network: Network::Tcp, @@ -97,7 +97,7 @@ impl InboundListener for Listener { let authenticator = self.authenticator.clone(); tokio::spawn(async move { - handle_tcp(&mut sess, &mut socket, dispatcher, authenticator).await + handle_tcp(&mut sess, socket, dispatcher, authenticator).await }); } } diff --git a/clash_lib/src/proxy/socks/inbound/stream.rs b/clash_lib/src/proxy/socks/inbound/stream.rs index d518c8ca9..fbd5cf9a9 100644 --- a/clash_lib/src/proxy/socks/inbound/stream.rs +++ b/clash_lib/src/proxy/socks/inbound/stream.rs @@ -19,7 +19,7 @@ use tracing::{instrument, trace, warn}; #[instrument(skip(s, dispatcher, authenticator))] pub async fn handle_tcp<'a>( sess: &'a mut Session, - s: &'a mut TcpStream, + mut s: TcpStream, dispatcher: Arc, authenticator: ThreadSafeAuthenticator, ) -> io::Result<()> { @@ -117,7 +117,7 @@ pub async fn handle_tcp<'a>( )); } - let dst = SocksAddr::read_from(s).await?; + let dst = SocksAddr::read_from(& mut s).await?; match buf[1] { socks_command::CONNECT => { From 45f89859d0c726e946d9170848d98d89e5dd223f Mon Sep 17 00:00:00 2001 From: xmh0511 <970252187@qq.com> Date: Thu, 7 Dec 2023 14:24:29 +0800 Subject: [PATCH 3/4] 1:1 implements the original logic of stream forward --- clash_lib/src/common/io.rs | 217 ++++++++++++++++++++++++++++++++----- config.yaml | 2 + 2 files changed, 191 insertions(+), 28 deletions(-) create mode 100644 config.yaml diff --git a/clash_lib/src/common/io.rs b/clash_lib/src/common/io.rs index c8d01317e..37bd7ae3d 100644 --- a/clash_lib/src/common/io.rs +++ b/clash_lib/src/common/io.rs @@ -290,6 +290,108 @@ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; // .await // } +// #1 close the connection if no traffic can be read within the specified time +// pub async fn copy_buf_bidirectional_with_timeout( +// a: A, +// b: B, +// size: usize, +// a_to_b_timeout_duration: Duration, +// b_to_a_timeout_duration: Duration, +// ) -> Result<(u64, u64), std::io::Error> +// where +// A: AsyncRead + AsyncWrite + Unpin + Sized + Send + 'static, +// B: AsyncRead + AsyncWrite + Unpin + Sized + Send + 'static, +// { +// let (mut a_reader, mut a_writer) = tokio::io::split(a); +// let (mut b_reader, mut b_writer) = tokio::io::split(b); +// let a_to_b = tokio::spawn(async move { +// let mut upload_total_size = 0; +// let mut buf = Vec::new(); +// buf.resize(size, 0); +// loop { +// match tokio::time::timeout(a_to_b_timeout_duration, a_reader.read(&mut buf[..])).await { +// Ok(Ok(size)) => { +// if size == 0 { +// return Ok(upload_total_size); +// } +// match b_writer.write_all(&buf[..size]).await { +// Ok(_) => { +// upload_total_size += size; +// continue; +// } +// Err(e) => { +// return Err(e); +// } +// } +// } +// Ok(Err(e)) => { +// return Err(e); +// } +// Err(_) => { +// return Ok(upload_total_size); +// } +// } +// } +// }); +// let b_to_a = tokio::spawn(async move { +// let mut download_total_size = 0; +// let mut buf = Vec::new(); +// buf.resize(size, 0); +// loop { +// match tokio::time::timeout(b_to_a_timeout_duration, b_reader.read(&mut buf[..])).await { +// Ok(Ok(size)) => { +// if size == 0 { +// return Ok(download_total_size); +// } +// match a_writer.write_all(&buf[..size]).await { +// Ok(_) => { +// download_total_size += size; +// continue; +// } +// Err(e) => { +// return Err(e); +// } +// } +// } +// Ok(Err(e)) => { +// return Err(e); +// } +// Err(_) => { +// return Ok(download_total_size); +// } +// } +// } +// }); +// let up = match a_to_b.await { +// Ok(Ok(up)) => up, +// Ok(Err(e)) => { +// return Err(e); +// } +// Err(_) => { +// return Err(std::io::Error::new(std::io::ErrorKind::Other, "")); +// } +// }; +// let down = match b_to_a.await { +// Ok(Ok(up)) => up, +// Ok(Err(e)) => { +// return Err(e); +// } +// Err(_) => { +// return Err(std::io::Error::new(std::io::ErrorKind::Other, "")); +// } +// }; +// Ok((up as u64, down as u64)) +// } + +// #2 1:1 implement the original logic, that is, the other direction has a timeout requirement after one direction is ready +// if both directions are pending, no timeout limition will be imposed on them + +enum DirectionMsg { + SetTimeOut, + Terminate(std::io::Error), + Done, +} + pub async fn copy_buf_bidirectional_with_timeout( a: A, b: B, @@ -303,62 +405,120 @@ where { let (mut a_reader, mut a_writer) = tokio::io::split(a); let (mut b_reader, mut b_writer) = tokio::io::split(b); + let (a_to_b_msg_tx, mut a_to_b_msg_rx) = tokio::sync::mpsc::unbounded_channel::(); + let (b_to_a_msg_tx, mut b_to_a_msg_rx) = tokio::sync::mpsc::unbounded_channel::(); let a_to_b = tokio::spawn(async move { let mut upload_total_size = 0; let mut buf = Vec::new(); buf.resize(size, 0); + let mut need_timeout = false; loop { - match tokio::time::timeout(a_to_b_timeout_duration, a_reader.read(&mut buf[..])).await { - Ok(Ok(size)) => { - if size == 0 { - return Ok(upload_total_size); + tokio::select! { + msg = async { + if need_timeout{ + tokio::time::sleep(a_to_b_timeout_duration).await; + DirectionMsg::Done + }else{ + match a_to_b_msg_rx.recv().await{ + Some(v)=>v, + None=>DirectionMsg::SetTimeOut // the other direction is done + } + } + } =>{ + match msg{ + DirectionMsg::SetTimeOut => { + need_timeout = true; + }, + DirectionMsg::Terminate(e) =>{ + return Err(e); + }, + DirectionMsg::Done => return Ok(upload_total_size), } - match b_writer.write_all(&buf[..size]).await { - Ok(_) => { + } + r = a_reader.read(&mut buf[..]) =>{ + match r{ + Ok(size)=>{ + if size == 0{ + // peer has shutdown + match b_writer.shutdown().await{ + Ok(_)=>{ + let _ = b_to_a_msg_tx.send(DirectionMsg::SetTimeOut); + } + Err(e)=>{ + let _ = b_to_a_msg_tx.send(DirectionMsg::Terminate(e)); + } + } + return Ok(upload_total_size); + } + if let Err(e) = b_writer.write_all(&buf[..size]).await{ + return Err(e); + } upload_total_size += size; continue; } - Err(e) => { + Err(e)=>{ return Err(e); } } } - Ok(Err(e)) => { - return Err(e); - } - Err(_) => { - return Ok(upload_total_size); - } - } + }; } }); let b_to_a = tokio::spawn(async move { let mut download_total_size = 0; let mut buf = Vec::new(); buf.resize(size, 0); + let mut need_timeout = false; loop { - match tokio::time::timeout(b_to_a_timeout_duration, b_reader.read(&mut buf[..])).await { - Ok(Ok(size)) => { - if size == 0 { - return Ok(download_total_size); + tokio::select! { + msg = async { + if need_timeout{ + tokio::time::sleep(b_to_a_timeout_duration).await; + DirectionMsg::Done + }else{ + match b_to_a_msg_rx.recv().await{ + Some(v)=>v, + None=>DirectionMsg::SetTimeOut // the other direction is done + } + } + } =>{ + match msg{ + DirectionMsg::SetTimeOut => { + need_timeout = true; + }, + DirectionMsg::Terminate(e) =>{ + return Err(e); + }, + DirectionMsg::Done => return Ok(download_total_size), } - match a_writer.write_all(&buf[..size]).await { - Ok(_) => { + } + r = b_reader.read(&mut buf[..]) =>{ + match r{ + Ok(size)=>{ + if size == 0{ + // peer has shutdown + match a_writer.shutdown().await{ + Ok(_)=>{ + let _ = a_to_b_msg_tx.send(DirectionMsg::SetTimeOut); + } + Err(e)=>{ + let _ = a_to_b_msg_tx.send(DirectionMsg::Terminate(e)); + } + } + return Ok(download_total_size); + } + if let Err(e) = a_writer.write_all(&buf[..size]).await{ + return Err(e); + } download_total_size += size; continue; } - Err(e) => { + Err(e)=>{ return Err(e); } } } - Ok(Err(e)) => { - return Err(e); - } - Err(_) => { - return Ok(download_total_size); - } - } + }; } }); let up = match a_to_b.await { @@ -379,5 +539,6 @@ where return Err(std::io::Error::new(std::io::ErrorKind::Other, "")); } }; + //dbg!(up,down); Ok((up as u64, down as u64)) } diff --git a/config.yaml b/config.yaml new file mode 100644 index 000000000..a45ddc832 --- /dev/null +++ b/config.yaml @@ -0,0 +1,2 @@ +port: 8080 +external-controller: 127.0.0.1:9090 From b8ef453ee9a2a37acbbd30d21920ef978f0be807 Mon Sep 17 00:00:00 2001 From: xmh0511 <970252187@qq.com> Date: Thu, 7 Dec 2023 21:38:54 +0800 Subject: [PATCH 4/4] revamp the implementation of forwarding stream --- clash_lib/src/common/io.rs | 166 +++++++++++++++---------------------- config.yaml | 2 - 2 files changed, 66 insertions(+), 102 deletions(-) delete mode 100644 config.yaml diff --git a/clash_lib/src/common/io.rs b/clash_lib/src/common/io.rs index 37bd7ae3d..185196ee4 100644 --- a/clash_lib/src/common/io.rs +++ b/clash_lib/src/common/io.rs @@ -6,7 +6,11 @@ use std::time::Duration; // use futures::ready; +use futures::TryFutureExt; +use std::io::Error; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender}; +use tokio::task::JoinHandle; // #[derive(Debug)] // pub struct CopyBuffer { @@ -386,7 +390,7 @@ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; // #2 1:1 implement the original logic, that is, the other direction has a timeout requirement after one direction is ready // if both directions are pending, no timeout limition will be imposed on them -enum DirectionMsg { +enum SyncMsg { SetTimeOut, Terminate(std::io::Error), Done, @@ -403,114 +407,96 @@ where A: AsyncRead + AsyncWrite + Unpin + Sized + Send + 'static, B: AsyncRead + AsyncWrite + Unpin + Sized + Send + 'static, { - let (mut a_reader, mut a_writer) = tokio::io::split(a); - let (mut b_reader, mut b_writer) = tokio::io::split(b); - let (a_to_b_msg_tx, mut a_to_b_msg_rx) = tokio::sync::mpsc::unbounded_channel::(); - let (b_to_a_msg_tx, mut b_to_a_msg_rx) = tokio::sync::mpsc::unbounded_channel::(); - let a_to_b = tokio::spawn(async move { - let mut upload_total_size = 0; - let mut buf = Vec::new(); - buf.resize(size, 0); - let mut need_timeout = false; - loop { - tokio::select! { - msg = async { - if need_timeout{ - tokio::time::sleep(a_to_b_timeout_duration).await; - DirectionMsg::Done - }else{ - match a_to_b_msg_rx.recv().await{ - Some(v)=>v, - None=>DirectionMsg::SetTimeOut // the other direction is done - } - } - } =>{ - match msg{ - DirectionMsg::SetTimeOut => { - need_timeout = true; - }, - DirectionMsg::Terminate(e) =>{ - return Err(e); - }, - DirectionMsg::Done => return Ok(upload_total_size), - } - } - r = a_reader.read(&mut buf[..]) =>{ - match r{ - Ok(size)=>{ - if size == 0{ - // peer has shutdown - match b_writer.shutdown().await{ - Ok(_)=>{ - let _ = b_to_a_msg_tx.send(DirectionMsg::SetTimeOut); - } - Err(e)=>{ - let _ = b_to_a_msg_tx.send(DirectionMsg::Terminate(e)); - } - } - return Ok(upload_total_size); - } - if let Err(e) = b_writer.write_all(&buf[..size]).await{ - return Err(e); - } - upload_total_size += size; - continue; - } - Err(e)=>{ - return Err(e); - } - } - } - }; - } - }); - let b_to_a = tokio::spawn(async move { - let mut download_total_size = 0; + let (a_reader, a_writer) = tokio::io::split(a); + let (b_reader, b_writer) = tokio::io::split(b); + let (tx_for_a_to_b, rx_for_a_to_b) = tokio::sync::mpsc::unbounded_channel(); + let (tx_for_b_to_a, rx_for_b_to_a) = tokio::sync::mpsc::unbounded_channel(); + let a_to_b = copy_from_lhs_to_rhs( + a_reader, + b_writer, + size, + a_to_b_timeout_duration, + rx_for_a_to_b, + tx_for_b_to_a, + ); + let b_to_a = copy_from_lhs_to_rhs( + b_reader, + a_writer, + size, + b_to_a_timeout_duration, + rx_for_b_to_a, + tx_for_a_to_b, + ); + + let up = a_to_b + .map_err(|_| std::io::Error::new(std::io::ErrorKind::Other, "")) + .await??; + let down = b_to_a + .map_err(|_| std::io::Error::new(std::io::ErrorKind::Other, "")) + .await??; + //dbg!(up,down); + Ok((up as u64, down as u64)) +} + +fn copy_from_lhs_to_rhs( + mut lhs_reader: A, + mut rhs_writer: B, + buf_size: usize, + timeout: Duration, + mut msg_rx: UnboundedReceiver, + other_side_sender: UnboundedSender, +) -> JoinHandle> +where + A: AsyncRead + Unpin + Sized + Send + 'static, + B: AsyncWrite + Unpin + Sized + Send + 'static, +{ + tokio::spawn(async move { + let mut transferred_size = 0; let mut buf = Vec::new(); - buf.resize(size, 0); + buf.resize(buf_size, 0); let mut need_timeout = false; loop { tokio::select! { msg = async { if need_timeout{ - tokio::time::sleep(b_to_a_timeout_duration).await; - DirectionMsg::Done + tokio::time::sleep(timeout).await; + SyncMsg::Done }else{ - match b_to_a_msg_rx.recv().await{ + match msg_rx.recv().await{ Some(v)=>v, - None=>DirectionMsg::SetTimeOut // the other direction is done + None=>SyncMsg::SetTimeOut // the other direction is done } } } =>{ match msg{ - DirectionMsg::SetTimeOut => { + SyncMsg::SetTimeOut => { need_timeout = true; }, - DirectionMsg::Terminate(e) =>{ + SyncMsg::Terminate(e) =>{ return Err(e); }, - DirectionMsg::Done => return Ok(download_total_size), + SyncMsg::Done => return Ok(transferred_size), } } - r = b_reader.read(&mut buf[..]) =>{ + r = lhs_reader.read(&mut buf[..]) =>{ match r{ Ok(size)=>{ if size == 0{ // peer has shutdown - match a_writer.shutdown().await{ + match rhs_writer.shutdown().await{ Ok(_)=>{ - let _ = a_to_b_msg_tx.send(DirectionMsg::SetTimeOut); + let _ = other_side_sender.send(SyncMsg::SetTimeOut); } Err(e)=>{ - let _ = a_to_b_msg_tx.send(DirectionMsg::Terminate(e)); + let _ = other_side_sender.send(SyncMsg::Terminate(e)); } } - return Ok(download_total_size); + return Ok(transferred_size); } - if let Err(e) = a_writer.write_all(&buf[..size]).await{ + if let Err(e) = rhs_writer.write_all(&buf[..size]).await{ return Err(e); } - download_total_size += size; + transferred_size += size; continue; } Err(e)=>{ @@ -520,25 +506,5 @@ where } }; } - }); - let up = match a_to_b.await { - Ok(Ok(up)) => up, - Ok(Err(e)) => { - return Err(e); - } - Err(_) => { - return Err(std::io::Error::new(std::io::ErrorKind::Other, "")); - } - }; - let down = match b_to_a.await { - Ok(Ok(up)) => up, - Ok(Err(e)) => { - return Err(e); - } - Err(_) => { - return Err(std::io::Error::new(std::io::ErrorKind::Other, "")); - } - }; - //dbg!(up,down); - Ok((up as u64, down as u64)) + }) } diff --git a/config.yaml b/config.yaml deleted file mode 100644 index a45ddc832..000000000 --- a/config.yaml +++ /dev/null @@ -1,2 +0,0 @@ -port: 8080 -external-controller: 127.0.0.1:9090