From f0fe9a61db1a357c7f91608d4d042bca706f9b65 Mon Sep 17 00:00:00 2001
From: xmh0511 <970252187@qq.com>
Date: Wed, 6 Dec 2023 14:29:12 +0800
Subject: [PATCH 1/4] simplify the io copy for dispatch stream
---
clash_lib/src/app/dispatcher/dispatcher.rs | 5 +-
clash_lib/src/common/io.rs | 128 ++++++++++++++++++---
2 files changed, 113 insertions(+), 20 deletions(-)
diff --git a/clash_lib/src/app/dispatcher/dispatcher.rs b/clash_lib/src/app/dispatcher/dispatcher.rs
index 3657a1812..b055eb264 100644
--- a/clash_lib/src/app/dispatcher/dispatcher.rs
+++ b/clash_lib/src/app/dispatcher/dispatcher.rs
@@ -11,6 +11,7 @@ use crate::proxy::AnyInboundDatagram;
use crate::session::Session;
use futures::SinkExt;
use futures::StreamExt;
+use tokio::io::AsyncReadExt;
use std::collections::HashMap;
use std::fmt::{Debug, Formatter};
use std::net::SocketAddr;
@@ -136,8 +137,8 @@ impl Dispatcher {
let mut rhs =
TrackedStream::new(rhs, self.manager.clone(), sess.clone(), rule).await;
match copy_buf_bidirectional_with_timeout(
- &mut lhs,
- &mut rhs,
+ lhs,
+ rhs,
4096,
Duration::from_secs(10),
Duration::from_secs(10),
diff --git a/clash_lib/src/common/io.rs b/clash_lib/src/common/io.rs
index b9e5fa59b..51df4647f 100644
--- a/clash_lib/src/common/io.rs
+++ b/clash_lib/src/common/io.rs
@@ -6,7 +6,7 @@ use std::task::{Context, Poll};
use std::time::Duration;
use futures::ready;
-use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
+use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
#[derive(Debug)]
pub struct CopyBuffer {
@@ -264,28 +264,120 @@ where
}
}
+// pub async fn copy_buf_bidirectional_with_timeout(
+// a: &mut A,
+// b: &mut B,
+// size: usize,
+// a_to_b_timeout_duration: Duration,
+// b_to_a_timeout_duration: Duration,
+// ) -> Result<(u64, u64), std::io::Error>
+// where
+// A: AsyncRead + AsyncWrite + Unpin + ?Sized,
+// B: AsyncRead + AsyncWrite + Unpin + ?Sized,
+// {
+// CopyBidirectional {
+// a,
+// b,
+// a_to_b: TransferState::Running(CopyBuffer::new_with_capacity(size)?),
+// b_to_a: TransferState::Running(CopyBuffer::new_with_capacity(size)?),
+// a_to_b_count: 0,
+// b_to_a_count: 0,
+// a_to_b_delay: None,
+// b_to_a_delay: None,
+// a_to_b_timeout_duration,
+// b_to_a_timeout_duration,
+// }
+// .await
+// }
+
pub async fn copy_buf_bidirectional_with_timeout(
- a: &mut A,
- b: &mut B,
+ a: A,
+ b: B,
size: usize,
a_to_b_timeout_duration: Duration,
b_to_a_timeout_duration: Duration,
) -> Result<(u64, u64), std::io::Error>
where
- A: AsyncRead + AsyncWrite + Unpin + ?Sized,
- B: AsyncRead + AsyncWrite + Unpin + ?Sized,
+ A: AsyncRead + AsyncWrite + Unpin + Sized,
+ B: AsyncRead + AsyncWrite + Unpin + Sized,
{
- CopyBidirectional {
- a,
- b,
- a_to_b: TransferState::Running(CopyBuffer::new_with_capacity(size)?),
- b_to_a: TransferState::Running(CopyBuffer::new_with_capacity(size)?),
- a_to_b_count: 0,
- b_to_a_count: 0,
- a_to_b_delay: None,
- b_to_a_delay: None,
- a_to_b_timeout_duration,
- b_to_a_timeout_duration,
- }
- .await
+ let (mut a_reader, mut a_writer) = tokio::io::split(a);
+ let (mut b_reader, mut b_writer) = tokio::io::split(b);
+ let a_to_b = tokio::spawn(async move {
+ let mut upload_total_size = 0;
+ let mut buf = Vec::new();
+ buf.resize(size, 0);
+ loop {
+ match tokio::time::timeout(a_to_b_timeout_duration, a_reader.read(&mut buf[..])).await {
+ Ok(Ok(size)) => {
+ if size == 0 {
+ return Ok(upload_total_size);
+ }
+ match b_writer.write_all(&buf[..size]).await {
+ Ok(_) => {
+ upload_total_size += size;
+ continue;
+ }
+ Err(e) => {
+ return Err(e);
+ }
+ }
+ }
+ Ok(Err(e)) => {
+ return Err(e);
+ }
+ Err(_) => {
+ return Ok(upload_total_size);
+ }
+ }
+ }
+ });
+ let b_to_a = tokio::spawn(async move {
+ let download_total_size = 0;
+ let mut buf = Vec::new();
+ buf.resize(size, 0);
+ loop {
+ match tokio::time::timeout(b_to_a_timeout_duration, b_reader.read(&mut buf[..])).await {
+ Ok(Ok(size)) => {
+ if size == 0 {
+ return Ok(download_total_size);
+ }
+ match a_writer.write_all(&buf[..size]).await {
+ Ok(_) => {
+ download_total_size += size;
+ continue;
+ }
+ Err(e) => {
+ return Err(e);
+ }
+ }
+ }
+ Ok(Err(e)) => {
+ return Err(e);
+ }
+ Err(_) => {
+ return Ok(download_total_size);
+ }
+ }
+ }
+ });
+ let up = match a_to_b.await {
+ Ok(Ok(up)) => up,
+ Ok(Err(e)) => {
+ return Err(e);
+ }
+ Err(e) => {
+ return Err(std::io::Error::new(std::io::ErrorKind::Other, ""));
+ }
+ };
+ let down = match b_to_a.await {
+ Ok(Ok(up)) => up,
+ Ok(Err(e)) => {
+ return Err(e);
+ }
+ Err(e) => {
+ return Err(std::io::Error::new(std::io::ErrorKind::Other, ""));
+ }
+ };
+ Ok((up as u64, down as u64))
}
From 452af1ffa46a3767e5bcd766aa523ae4a72f8f15 Mon Sep 17 00:00:00 2001
From: xmh0511 <970252187@qq.com>
Date: Wed, 6 Dec 2023 15:13:33 +0800
Subject: [PATCH 2/4] fix some errors
---
clash_lib/src/app/dispatcher/dispatcher.rs | 6 +-
clash_lib/src/common/io.rs | 494 ++++++++++----------
clash_lib/src/proxy/mixed/mod.rs | 4 +-
clash_lib/src/proxy/socks/inbound/mod.rs | 4 +-
clash_lib/src/proxy/socks/inbound/stream.rs | 4 +-
5 files changed, 256 insertions(+), 256 deletions(-)
diff --git a/clash_lib/src/app/dispatcher/dispatcher.rs b/clash_lib/src/app/dispatcher/dispatcher.rs
index b055eb264..f09b5129d 100644
--- a/clash_lib/src/app/dispatcher/dispatcher.rs
+++ b/clash_lib/src/app/dispatcher/dispatcher.rs
@@ -11,7 +11,7 @@ use crate::proxy::AnyInboundDatagram;
use crate::session::Session;
use futures::SinkExt;
use futures::StreamExt;
-use tokio::io::AsyncReadExt;
+
use std::collections::HashMap;
use std::fmt::{Debug, Formatter};
use std::net::SocketAddr;
@@ -78,7 +78,7 @@ impl Dispatcher {
#[instrument(skip(lhs))]
pub async fn dispatch_stream(&self, sess: Session, mut lhs: S)
where
- S: AsyncRead + AsyncWrite + Unpin + Send,
+ S: AsyncRead + AsyncWrite + Unpin + Send +'static,
{
let sess = if self.resolver.fake_ip_enabled() {
match sess.destination {
@@ -134,7 +134,7 @@ impl Dispatcher {
{
Ok(rhs) => {
debug!("remote connection established {}", sess);
- let mut rhs =
+ let rhs =
TrackedStream::new(rhs, self.manager.clone(), sess.clone(), rule).await;
match copy_buf_bidirectional_with_timeout(
lhs,
diff --git a/clash_lib/src/common/io.rs b/clash_lib/src/common/io.rs
index 51df4647f..c8d01317e 100644
--- a/clash_lib/src/common/io.rs
+++ b/clash_lib/src/common/io.rs
@@ -1,268 +1,268 @@
/// copy of https://github.com/eycorsican/leaf/blob/a77a1e497ae034f3a2a89c8628d5e7ebb2af47f0/leaf/src/common/io.rs
-use std::future::Future;
-use std::io;
-use std::pin::Pin;
-use std::task::{Context, Poll};
+// use std::future::Future;
+// use std::io;
+// use std::pin::Pin;
+// use std::task::{Context, Poll};
use std::time::Duration;
-use futures::ready;
-use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
+// use futures::ready;
+use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
-#[derive(Debug)]
-pub struct CopyBuffer {
- read_done: bool,
- need_flush: bool,
- pos: usize,
- cap: usize,
- amt: u64,
- buf: Box<[u8]>,
-}
+// #[derive(Debug)]
+// pub struct CopyBuffer {
+// read_done: bool,
+// need_flush: bool,
+// pos: usize,
+// cap: usize,
+// amt: u64,
+// buf: Box<[u8]>,
+// }
-impl CopyBuffer {
- #[allow(unused)]
- pub fn new() -> Self {
- Self {
- read_done: false,
- need_flush: false,
- pos: 0,
- cap: 0,
- amt: 0,
- buf: vec![0; 2 * 1024].into_boxed_slice(),
- }
- }
+// impl CopyBuffer {
+// #[allow(unused)]
+// pub fn new() -> Self {
+// Self {
+// read_done: false,
+// need_flush: false,
+// pos: 0,
+// cap: 0,
+// amt: 0,
+// buf: vec![0; 2 * 1024].into_boxed_slice(),
+// }
+// }
- pub fn new_with_capacity(size: usize) -> Result {
- let mut buf = Vec::new();
- buf.try_reserve(size).map_err(|e| {
- std::io::Error::new(
- std::io::ErrorKind::Other,
- format!("new buffer failed: {}", e),
- )
- })?;
- buf.resize(size, 0);
- Ok(Self {
- read_done: false,
- need_flush: false,
- pos: 0,
- cap: 0,
- amt: 0,
- buf: buf.into_boxed_slice(),
- })
- }
+// pub fn new_with_capacity(size: usize) -> Result {
+// let mut buf = Vec::new();
+// buf.try_reserve(size).map_err(|e| {
+// std::io::Error::new(
+// std::io::ErrorKind::Other,
+// format!("new buffer failed: {}", e),
+// )
+// })?;
+// buf.resize(size, 0);
+// Ok(Self {
+// read_done: false,
+// need_flush: false,
+// pos: 0,
+// cap: 0,
+// amt: 0,
+// buf: buf.into_boxed_slice(),
+// })
+// }
- pub fn amount_transfered(&self) -> u64 {
- self.amt
- }
+// pub fn amount_transfered(&self) -> u64 {
+// self.amt
+// }
- pub fn poll_copy(
- &mut self,
- cx: &mut Context<'_>,
- mut reader: Pin<&mut R>,
- mut writer: Pin<&mut W>,
- ) -> Poll>
- where
- R: AsyncRead + ?Sized,
- W: AsyncWrite + ?Sized,
- {
- loop {
- // If our buffer is empty, then we need to read some data to
- // continue.
- if self.pos == self.cap && !self.read_done {
- let me = &mut *self;
- let mut buf = ReadBuf::new(&mut me.buf);
+// pub fn poll_copy(
+// &mut self,
+// cx: &mut Context<'_>,
+// mut reader: Pin<&mut R>,
+// mut writer: Pin<&mut W>,
+// ) -> Poll>
+// where
+// R: AsyncRead + ?Sized,
+// W: AsyncWrite + ?Sized,
+// {
+// loop {
+// // If our buffer is empty, then we need to read some data to
+// // continue.
+// if self.pos == self.cap && !self.read_done {
+// let me = &mut *self;
+// let mut buf = ReadBuf::new(&mut me.buf);
- match reader.as_mut().poll_read(cx, &mut buf) {
- Poll::Ready(Ok(_)) => (),
- Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
- Poll::Pending => {
- // Try flushing when the reader has no progress to avoid deadlock
- // when the reader depends on buffered writer.
- if self.need_flush {
- ready!(writer.as_mut().poll_flush(cx))?;
- self.need_flush = false;
- }
+// match reader.as_mut().poll_read(cx, &mut buf) {
+// Poll::Ready(Ok(_)) => (),
+// Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
+// Poll::Pending => {
+// // Try flushing when the reader has no progress to avoid deadlock
+// // when the reader depends on buffered writer.
+// if self.need_flush {
+// ready!(writer.as_mut().poll_flush(cx))?;
+// self.need_flush = false;
+// }
- return Poll::Pending;
- }
- }
+// return Poll::Pending;
+// }
+// }
- let n = buf.filled().len();
- if n == 0 {
- self.read_done = true;
- } else {
- self.pos = 0;
- self.cap = n;
- }
- }
+// let n = buf.filled().len();
+// if n == 0 {
+// self.read_done = true;
+// } else {
+// self.pos = 0;
+// self.cap = n;
+// }
+// }
- // If our buffer has some data, let's write it out!
- while self.pos < self.cap {
- let me = &mut *self;
- let i = ready!(writer.as_mut().poll_write(cx, &me.buf[me.pos..me.cap]))?;
- if i == 0 {
- return Poll::Ready(Err(io::Error::new(
- io::ErrorKind::WriteZero,
- "write zero byte into writer",
- )));
- } else {
- self.pos += i;
- self.amt += i as u64;
- self.need_flush = true;
- }
- }
+// // If our buffer has some data, let's write it out!
+// while self.pos < self.cap {
+// let me = &mut *self;
+// let i = ready!(writer.as_mut().poll_write(cx, &me.buf[me.pos..me.cap]))?;
+// if i == 0 {
+// return Poll::Ready(Err(io::Error::new(
+// io::ErrorKind::WriteZero,
+// "write zero byte into writer",
+// )));
+// } else {
+// self.pos += i;
+// self.amt += i as u64;
+// self.need_flush = true;
+// }
+// }
- // If pos larger than cap, this loop will never stop.
- // In particular, user's wrong poll_write implementation returning
- // incorrect written length may lead to thread blocking.
- debug_assert!(
- self.pos <= self.cap,
- "writer returned length larger than input slice"
- );
+// // If pos larger than cap, this loop will never stop.
+// // In particular, user's wrong poll_write implementation returning
+// // incorrect written length may lead to thread blocking.
+// debug_assert!(
+// self.pos <= self.cap,
+// "writer returned length larger than input slice"
+// );
- // If we've written all the data and we've seen EOF, flush out the
- // data and finish the transfer.
- if self.pos == self.cap && self.read_done {
- ready!(writer.as_mut().poll_flush(cx))?;
- return Poll::Ready(Ok(self.amt));
- }
- }
- }
-}
+// // If we've written all the data and we've seen EOF, flush out the
+// // data and finish the transfer.
+// if self.pos == self.cap && self.read_done {
+// ready!(writer.as_mut().poll_flush(cx))?;
+// return Poll::Ready(Ok(self.amt));
+// }
+// }
+// }
+// }
-enum TransferState {
- Running(CopyBuffer),
- ShuttingDown(u64),
- Done,
-}
+// enum TransferState {
+// Running(CopyBuffer),
+// ShuttingDown(u64),
+// Done,
+// }
-struct CopyBidirectional<'a, A: ?Sized, B: ?Sized> {
- a: &'a mut A,
- b: &'a mut B,
- a_to_b: TransferState,
- b_to_a: TransferState,
- a_to_b_count: u64,
- b_to_a_count: u64,
- a_to_b_delay: Option>>,
- b_to_a_delay: Option>>,
- a_to_b_timeout_duration: Duration,
- b_to_a_timeout_duration: Duration,
-}
+// struct CopyBidirectional<'a, A: ?Sized, B: ?Sized> {
+// a: &'a mut A,
+// b: &'a mut B,
+// a_to_b: TransferState,
+// b_to_a: TransferState,
+// a_to_b_count: u64,
+// b_to_a_count: u64,
+// a_to_b_delay: Option>>,
+// b_to_a_delay: Option>>,
+// a_to_b_timeout_duration: Duration,
+// b_to_a_timeout_duration: Duration,
+// }
-impl<'a, A, B> Future for CopyBidirectional<'a, A, B>
-where
- A: AsyncRead + AsyncWrite + Unpin + ?Sized,
- B: AsyncRead + AsyncWrite + Unpin + ?Sized,
-{
- type Output = io::Result<(u64, u64)>;
+// impl<'a, A, B> Future for CopyBidirectional<'a, A, B>
+// where
+// A: AsyncRead + AsyncWrite + Unpin + ?Sized,
+// B: AsyncRead + AsyncWrite + Unpin + ?Sized,
+// {
+// type Output = io::Result<(u64, u64)>;
- fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll {
- // Unpack self into mut refs to each field to avoid borrow check issues.
- let CopyBidirectional {
- a,
- b,
- a_to_b,
- b_to_a,
- a_to_b_count,
- b_to_a_count,
- a_to_b_delay,
- b_to_a_delay,
- a_to_b_timeout_duration,
- b_to_a_timeout_duration,
- } = &mut *self;
+// fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll {
+// // Unpack self into mut refs to each field to avoid borrow check issues.
+// let CopyBidirectional {
+// a,
+// b,
+// a_to_b,
+// b_to_a,
+// a_to_b_count,
+// b_to_a_count,
+// a_to_b_delay,
+// b_to_a_delay,
+// a_to_b_timeout_duration,
+// b_to_a_timeout_duration,
+// } = &mut *self;
- let mut a = Pin::new(a);
- let mut b = Pin::new(b);
+// let mut a = Pin::new(a);
+// let mut b = Pin::new(b);
- loop {
- match a_to_b {
- TransferState::Running(buf) => {
- let res = buf.poll_copy(cx, a.as_mut(), b.as_mut());
- match res {
- Poll::Ready(Ok(count)) => {
- *a_to_b = TransferState::ShuttingDown(count);
- continue;
- }
- Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
- Poll::Pending => {
- if let Some(delay) = a_to_b_delay {
- match delay.as_mut().poll(cx) {
- Poll::Ready(()) => {
- *a_to_b =
- TransferState::ShuttingDown(buf.amount_transfered());
- continue;
- }
- Poll::Pending => (),
- }
- }
- }
- }
- }
- TransferState::ShuttingDown(count) => {
- let res = b.as_mut().poll_shutdown(cx);
- match res {
- Poll::Ready(Ok(())) => {
- *a_to_b_count += *count;
- *a_to_b = TransferState::Done;
- b_to_a_delay
- .replace(Box::pin(tokio::time::sleep(*b_to_a_timeout_duration)));
- continue;
- }
- Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
- Poll::Pending => (),
- }
- }
- TransferState::Done => (),
- }
+// loop {
+// match a_to_b {
+// TransferState::Running(buf) => {
+// let res = buf.poll_copy(cx, a.as_mut(), b.as_mut());
+// match res {
+// Poll::Ready(Ok(count)) => {
+// *a_to_b = TransferState::ShuttingDown(count);
+// continue;
+// }
+// Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
+// Poll::Pending => {
+// if let Some(delay) = a_to_b_delay {
+// match delay.as_mut().poll(cx) {
+// Poll::Ready(()) => {
+// *a_to_b =
+// TransferState::ShuttingDown(buf.amount_transfered());
+// continue;
+// }
+// Poll::Pending => (),
+// }
+// }
+// }
+// }
+// }
+// TransferState::ShuttingDown(count) => {
+// let res = b.as_mut().poll_shutdown(cx);
+// match res {
+// Poll::Ready(Ok(())) => {
+// *a_to_b_count += *count;
+// *a_to_b = TransferState::Done;
+// b_to_a_delay
+// .replace(Box::pin(tokio::time::sleep(*b_to_a_timeout_duration)));
+// continue;
+// }
+// Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
+// Poll::Pending => (),
+// }
+// }
+// TransferState::Done => (),
+// }
- match b_to_a {
- TransferState::Running(buf) => {
- let res = buf.poll_copy(cx, b.as_mut(), a.as_mut());
- match res {
- Poll::Ready(Ok(count)) => {
- *b_to_a = TransferState::ShuttingDown(count);
- continue;
- }
- Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
- Poll::Pending => {
- if let Some(delay) = b_to_a_delay {
- match delay.as_mut().poll(cx) {
- Poll::Ready(()) => {
- *b_to_a =
- TransferState::ShuttingDown(buf.amount_transfered());
- continue;
- }
- Poll::Pending => (),
- }
- }
- }
- }
- }
- TransferState::ShuttingDown(count) => {
- let res = a.as_mut().poll_shutdown(cx);
- match res {
- Poll::Ready(Ok(())) => {
- *b_to_a_count += *count;
- *b_to_a = TransferState::Done;
- a_to_b_delay
- .replace(Box::pin(tokio::time::sleep(*a_to_b_timeout_duration)));
- continue;
- }
- Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
- Poll::Pending => (),
- }
- }
- TransferState::Done => (),
- }
+// match b_to_a {
+// TransferState::Running(buf) => {
+// let res = buf.poll_copy(cx, b.as_mut(), a.as_mut());
+// match res {
+// Poll::Ready(Ok(count)) => {
+// *b_to_a = TransferState::ShuttingDown(count);
+// continue;
+// }
+// Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
+// Poll::Pending => {
+// if let Some(delay) = b_to_a_delay {
+// match delay.as_mut().poll(cx) {
+// Poll::Ready(()) => {
+// *b_to_a =
+// TransferState::ShuttingDown(buf.amount_transfered());
+// continue;
+// }
+// Poll::Pending => (),
+// }
+// }
+// }
+// }
+// }
+// TransferState::ShuttingDown(count) => {
+// let res = a.as_mut().poll_shutdown(cx);
+// match res {
+// Poll::Ready(Ok(())) => {
+// *b_to_a_count += *count;
+// *b_to_a = TransferState::Done;
+// a_to_b_delay
+// .replace(Box::pin(tokio::time::sleep(*a_to_b_timeout_duration)));
+// continue;
+// }
+// Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
+// Poll::Pending => (),
+// }
+// }
+// TransferState::Done => (),
+// }
- match (&a_to_b, &b_to_a) {
- (TransferState::Done, TransferState::Done) => break,
- _ => return Poll::Pending,
- }
- }
+// match (&a_to_b, &b_to_a) {
+// (TransferState::Done, TransferState::Done) => break,
+// _ => return Poll::Pending,
+// }
+// }
- Poll::Ready(Ok((*a_to_b_count, *b_to_a_count)))
- }
-}
+// Poll::Ready(Ok((*a_to_b_count, *b_to_a_count)))
+// }
+// }
// pub async fn copy_buf_bidirectional_with_timeout(
// a: &mut A,
@@ -298,8 +298,8 @@ pub async fn copy_buf_bidirectional_with_timeout(
b_to_a_timeout_duration: Duration,
) -> Result<(u64, u64), std::io::Error>
where
- A: AsyncRead + AsyncWrite + Unpin + Sized,
- B: AsyncRead + AsyncWrite + Unpin + Sized,
+ A: AsyncRead + AsyncWrite + Unpin + Sized + Send + 'static,
+ B: AsyncRead + AsyncWrite + Unpin + Sized + Send + 'static,
{
let (mut a_reader, mut a_writer) = tokio::io::split(a);
let (mut b_reader, mut b_writer) = tokio::io::split(b);
@@ -333,7 +333,7 @@ where
}
});
let b_to_a = tokio::spawn(async move {
- let download_total_size = 0;
+ let mut download_total_size = 0;
let mut buf = Vec::new();
buf.resize(size, 0);
loop {
@@ -366,7 +366,7 @@ where
Ok(Err(e)) => {
return Err(e);
}
- Err(e) => {
+ Err(_) => {
return Err(std::io::Error::new(std::io::ErrorKind::Other, ""));
}
};
@@ -375,7 +375,7 @@ where
Ok(Err(e)) => {
return Err(e);
}
- Err(e) => {
+ Err(_) => {
return Err(std::io::Error::new(std::io::ErrorKind::Other, ""));
}
};
diff --git a/clash_lib/src/proxy/mixed/mod.rs b/clash_lib/src/proxy/mixed/mod.rs
index 9cd2326cc..949f7cd7a 100644
--- a/clash_lib/src/proxy/mixed/mod.rs
+++ b/clash_lib/src/proxy/mixed/mod.rs
@@ -53,7 +53,7 @@ impl InboundListener for Listener {
loop {
let (socket, _) = listener.accept().await?;
- let mut socket = apply_tcp_options(socket)?;
+ let socket = apply_tcp_options(socket)?;
let mut p = [0; 1];
let n = socket.peek(&mut p).await?;
@@ -75,7 +75,7 @@ impl InboundListener for Listener {
};
tokio::spawn(async move {
- socks::handle_tcp(&mut sess, &mut socket, dispatcher, authenticator).await
+ socks::handle_tcp(&mut sess, socket, dispatcher, authenticator).await
});
}
diff --git a/clash_lib/src/proxy/socks/inbound/mod.rs b/clash_lib/src/proxy/socks/inbound/mod.rs
index 4b37c807f..78faeb372 100644
--- a/clash_lib/src/proxy/socks/inbound/mod.rs
+++ b/clash_lib/src/proxy/socks/inbound/mod.rs
@@ -83,7 +83,7 @@ impl InboundListener for Listener {
loop {
let (socket, _) = listener.accept().await?;
- let mut socket = apply_tcp_options(socket)?;
+ let socket = apply_tcp_options(socket)?;
let mut sess = Session {
network: Network::Tcp,
@@ -97,7 +97,7 @@ impl InboundListener for Listener {
let authenticator = self.authenticator.clone();
tokio::spawn(async move {
- handle_tcp(&mut sess, &mut socket, dispatcher, authenticator).await
+ handle_tcp(&mut sess, socket, dispatcher, authenticator).await
});
}
}
diff --git a/clash_lib/src/proxy/socks/inbound/stream.rs b/clash_lib/src/proxy/socks/inbound/stream.rs
index d518c8ca9..fbd5cf9a9 100644
--- a/clash_lib/src/proxy/socks/inbound/stream.rs
+++ b/clash_lib/src/proxy/socks/inbound/stream.rs
@@ -19,7 +19,7 @@ use tracing::{instrument, trace, warn};
#[instrument(skip(s, dispatcher, authenticator))]
pub async fn handle_tcp<'a>(
sess: &'a mut Session,
- s: &'a mut TcpStream,
+ mut s: TcpStream,
dispatcher: Arc,
authenticator: ThreadSafeAuthenticator,
) -> io::Result<()> {
@@ -117,7 +117,7 @@ pub async fn handle_tcp<'a>(
));
}
- let dst = SocksAddr::read_from(s).await?;
+ let dst = SocksAddr::read_from(& mut s).await?;
match buf[1] {
socks_command::CONNECT => {
From 45f89859d0c726e946d9170848d98d89e5dd223f Mon Sep 17 00:00:00 2001
From: xmh0511 <970252187@qq.com>
Date: Thu, 7 Dec 2023 14:24:29 +0800
Subject: [PATCH 3/4] 1:1 implements the original logic of stream forward
---
clash_lib/src/common/io.rs | 217 ++++++++++++++++++++++++++++++++-----
config.yaml | 2 +
2 files changed, 191 insertions(+), 28 deletions(-)
create mode 100644 config.yaml
diff --git a/clash_lib/src/common/io.rs b/clash_lib/src/common/io.rs
index c8d01317e..37bd7ae3d 100644
--- a/clash_lib/src/common/io.rs
+++ b/clash_lib/src/common/io.rs
@@ -290,6 +290,108 @@ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
// .await
// }
+// #1 close the connection if no traffic can be read within the specified time
+// pub async fn copy_buf_bidirectional_with_timeout(
+// a: A,
+// b: B,
+// size: usize,
+// a_to_b_timeout_duration: Duration,
+// b_to_a_timeout_duration: Duration,
+// ) -> Result<(u64, u64), std::io::Error>
+// where
+// A: AsyncRead + AsyncWrite + Unpin + Sized + Send + 'static,
+// B: AsyncRead + AsyncWrite + Unpin + Sized + Send + 'static,
+// {
+// let (mut a_reader, mut a_writer) = tokio::io::split(a);
+// let (mut b_reader, mut b_writer) = tokio::io::split(b);
+// let a_to_b = tokio::spawn(async move {
+// let mut upload_total_size = 0;
+// let mut buf = Vec::new();
+// buf.resize(size, 0);
+// loop {
+// match tokio::time::timeout(a_to_b_timeout_duration, a_reader.read(&mut buf[..])).await {
+// Ok(Ok(size)) => {
+// if size == 0 {
+// return Ok(upload_total_size);
+// }
+// match b_writer.write_all(&buf[..size]).await {
+// Ok(_) => {
+// upload_total_size += size;
+// continue;
+// }
+// Err(e) => {
+// return Err(e);
+// }
+// }
+// }
+// Ok(Err(e)) => {
+// return Err(e);
+// }
+// Err(_) => {
+// return Ok(upload_total_size);
+// }
+// }
+// }
+// });
+// let b_to_a = tokio::spawn(async move {
+// let mut download_total_size = 0;
+// let mut buf = Vec::new();
+// buf.resize(size, 0);
+// loop {
+// match tokio::time::timeout(b_to_a_timeout_duration, b_reader.read(&mut buf[..])).await {
+// Ok(Ok(size)) => {
+// if size == 0 {
+// return Ok(download_total_size);
+// }
+// match a_writer.write_all(&buf[..size]).await {
+// Ok(_) => {
+// download_total_size += size;
+// continue;
+// }
+// Err(e) => {
+// return Err(e);
+// }
+// }
+// }
+// Ok(Err(e)) => {
+// return Err(e);
+// }
+// Err(_) => {
+// return Ok(download_total_size);
+// }
+// }
+// }
+// });
+// let up = match a_to_b.await {
+// Ok(Ok(up)) => up,
+// Ok(Err(e)) => {
+// return Err(e);
+// }
+// Err(_) => {
+// return Err(std::io::Error::new(std::io::ErrorKind::Other, ""));
+// }
+// };
+// let down = match b_to_a.await {
+// Ok(Ok(up)) => up,
+// Ok(Err(e)) => {
+// return Err(e);
+// }
+// Err(_) => {
+// return Err(std::io::Error::new(std::io::ErrorKind::Other, ""));
+// }
+// };
+// Ok((up as u64, down as u64))
+// }
+
+// #2 1:1 implement the original logic, that is, the other direction has a timeout requirement after one direction is ready
+// if both directions are pending, no timeout limition will be imposed on them
+
+enum DirectionMsg {
+ SetTimeOut,
+ Terminate(std::io::Error),
+ Done,
+}
+
pub async fn copy_buf_bidirectional_with_timeout(
a: A,
b: B,
@@ -303,62 +405,120 @@ where
{
let (mut a_reader, mut a_writer) = tokio::io::split(a);
let (mut b_reader, mut b_writer) = tokio::io::split(b);
+ let (a_to_b_msg_tx, mut a_to_b_msg_rx) = tokio::sync::mpsc::unbounded_channel::();
+ let (b_to_a_msg_tx, mut b_to_a_msg_rx) = tokio::sync::mpsc::unbounded_channel::();
let a_to_b = tokio::spawn(async move {
let mut upload_total_size = 0;
let mut buf = Vec::new();
buf.resize(size, 0);
+ let mut need_timeout = false;
loop {
- match tokio::time::timeout(a_to_b_timeout_duration, a_reader.read(&mut buf[..])).await {
- Ok(Ok(size)) => {
- if size == 0 {
- return Ok(upload_total_size);
+ tokio::select! {
+ msg = async {
+ if need_timeout{
+ tokio::time::sleep(a_to_b_timeout_duration).await;
+ DirectionMsg::Done
+ }else{
+ match a_to_b_msg_rx.recv().await{
+ Some(v)=>v,
+ None=>DirectionMsg::SetTimeOut // the other direction is done
+ }
+ }
+ } =>{
+ match msg{
+ DirectionMsg::SetTimeOut => {
+ need_timeout = true;
+ },
+ DirectionMsg::Terminate(e) =>{
+ return Err(e);
+ },
+ DirectionMsg::Done => return Ok(upload_total_size),
}
- match b_writer.write_all(&buf[..size]).await {
- Ok(_) => {
+ }
+ r = a_reader.read(&mut buf[..]) =>{
+ match r{
+ Ok(size)=>{
+ if size == 0{
+ // peer has shutdown
+ match b_writer.shutdown().await{
+ Ok(_)=>{
+ let _ = b_to_a_msg_tx.send(DirectionMsg::SetTimeOut);
+ }
+ Err(e)=>{
+ let _ = b_to_a_msg_tx.send(DirectionMsg::Terminate(e));
+ }
+ }
+ return Ok(upload_total_size);
+ }
+ if let Err(e) = b_writer.write_all(&buf[..size]).await{
+ return Err(e);
+ }
upload_total_size += size;
continue;
}
- Err(e) => {
+ Err(e)=>{
return Err(e);
}
}
}
- Ok(Err(e)) => {
- return Err(e);
- }
- Err(_) => {
- return Ok(upload_total_size);
- }
- }
+ };
}
});
let b_to_a = tokio::spawn(async move {
let mut download_total_size = 0;
let mut buf = Vec::new();
buf.resize(size, 0);
+ let mut need_timeout = false;
loop {
- match tokio::time::timeout(b_to_a_timeout_duration, b_reader.read(&mut buf[..])).await {
- Ok(Ok(size)) => {
- if size == 0 {
- return Ok(download_total_size);
+ tokio::select! {
+ msg = async {
+ if need_timeout{
+ tokio::time::sleep(b_to_a_timeout_duration).await;
+ DirectionMsg::Done
+ }else{
+ match b_to_a_msg_rx.recv().await{
+ Some(v)=>v,
+ None=>DirectionMsg::SetTimeOut // the other direction is done
+ }
+ }
+ } =>{
+ match msg{
+ DirectionMsg::SetTimeOut => {
+ need_timeout = true;
+ },
+ DirectionMsg::Terminate(e) =>{
+ return Err(e);
+ },
+ DirectionMsg::Done => return Ok(download_total_size),
}
- match a_writer.write_all(&buf[..size]).await {
- Ok(_) => {
+ }
+ r = b_reader.read(&mut buf[..]) =>{
+ match r{
+ Ok(size)=>{
+ if size == 0{
+ // peer has shutdown
+ match a_writer.shutdown().await{
+ Ok(_)=>{
+ let _ = a_to_b_msg_tx.send(DirectionMsg::SetTimeOut);
+ }
+ Err(e)=>{
+ let _ = a_to_b_msg_tx.send(DirectionMsg::Terminate(e));
+ }
+ }
+ return Ok(download_total_size);
+ }
+ if let Err(e) = a_writer.write_all(&buf[..size]).await{
+ return Err(e);
+ }
download_total_size += size;
continue;
}
- Err(e) => {
+ Err(e)=>{
return Err(e);
}
}
}
- Ok(Err(e)) => {
- return Err(e);
- }
- Err(_) => {
- return Ok(download_total_size);
- }
- }
+ };
}
});
let up = match a_to_b.await {
@@ -379,5 +539,6 @@ where
return Err(std::io::Error::new(std::io::ErrorKind::Other, ""));
}
};
+ //dbg!(up,down);
Ok((up as u64, down as u64))
}
diff --git a/config.yaml b/config.yaml
new file mode 100644
index 000000000..a45ddc832
--- /dev/null
+++ b/config.yaml
@@ -0,0 +1,2 @@
+port: 8080
+external-controller: 127.0.0.1:9090
From b8ef453ee9a2a37acbbd30d21920ef978f0be807 Mon Sep 17 00:00:00 2001
From: xmh0511 <970252187@qq.com>
Date: Thu, 7 Dec 2023 21:38:54 +0800
Subject: [PATCH 4/4] revamp the implementation of forwarding stream
---
clash_lib/src/common/io.rs | 166 +++++++++++++++----------------------
config.yaml | 2 -
2 files changed, 66 insertions(+), 102 deletions(-)
delete mode 100644 config.yaml
diff --git a/clash_lib/src/common/io.rs b/clash_lib/src/common/io.rs
index 37bd7ae3d..185196ee4 100644
--- a/clash_lib/src/common/io.rs
+++ b/clash_lib/src/common/io.rs
@@ -6,7 +6,11 @@
use std::time::Duration;
// use futures::ready;
+use futures::TryFutureExt;
+use std::io::Error;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
+use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
+use tokio::task::JoinHandle;
// #[derive(Debug)]
// pub struct CopyBuffer {
@@ -386,7 +390,7 @@ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
// #2 1:1 implement the original logic, that is, the other direction has a timeout requirement after one direction is ready
// if both directions are pending, no timeout limition will be imposed on them
-enum DirectionMsg {
+enum SyncMsg {
SetTimeOut,
Terminate(std::io::Error),
Done,
@@ -403,114 +407,96 @@ where
A: AsyncRead + AsyncWrite + Unpin + Sized + Send + 'static,
B: AsyncRead + AsyncWrite + Unpin + Sized + Send + 'static,
{
- let (mut a_reader, mut a_writer) = tokio::io::split(a);
- let (mut b_reader, mut b_writer) = tokio::io::split(b);
- let (a_to_b_msg_tx, mut a_to_b_msg_rx) = tokio::sync::mpsc::unbounded_channel::();
- let (b_to_a_msg_tx, mut b_to_a_msg_rx) = tokio::sync::mpsc::unbounded_channel::();
- let a_to_b = tokio::spawn(async move {
- let mut upload_total_size = 0;
- let mut buf = Vec::new();
- buf.resize(size, 0);
- let mut need_timeout = false;
- loop {
- tokio::select! {
- msg = async {
- if need_timeout{
- tokio::time::sleep(a_to_b_timeout_duration).await;
- DirectionMsg::Done
- }else{
- match a_to_b_msg_rx.recv().await{
- Some(v)=>v,
- None=>DirectionMsg::SetTimeOut // the other direction is done
- }
- }
- } =>{
- match msg{
- DirectionMsg::SetTimeOut => {
- need_timeout = true;
- },
- DirectionMsg::Terminate(e) =>{
- return Err(e);
- },
- DirectionMsg::Done => return Ok(upload_total_size),
- }
- }
- r = a_reader.read(&mut buf[..]) =>{
- match r{
- Ok(size)=>{
- if size == 0{
- // peer has shutdown
- match b_writer.shutdown().await{
- Ok(_)=>{
- let _ = b_to_a_msg_tx.send(DirectionMsg::SetTimeOut);
- }
- Err(e)=>{
- let _ = b_to_a_msg_tx.send(DirectionMsg::Terminate(e));
- }
- }
- return Ok(upload_total_size);
- }
- if let Err(e) = b_writer.write_all(&buf[..size]).await{
- return Err(e);
- }
- upload_total_size += size;
- continue;
- }
- Err(e)=>{
- return Err(e);
- }
- }
- }
- };
- }
- });
- let b_to_a = tokio::spawn(async move {
- let mut download_total_size = 0;
+ let (a_reader, a_writer) = tokio::io::split(a);
+ let (b_reader, b_writer) = tokio::io::split(b);
+ let (tx_for_a_to_b, rx_for_a_to_b) = tokio::sync::mpsc::unbounded_channel();
+ let (tx_for_b_to_a, rx_for_b_to_a) = tokio::sync::mpsc::unbounded_channel();
+ let a_to_b = copy_from_lhs_to_rhs(
+ a_reader,
+ b_writer,
+ size,
+ a_to_b_timeout_duration,
+ rx_for_a_to_b,
+ tx_for_b_to_a,
+ );
+ let b_to_a = copy_from_lhs_to_rhs(
+ b_reader,
+ a_writer,
+ size,
+ b_to_a_timeout_duration,
+ rx_for_b_to_a,
+ tx_for_a_to_b,
+ );
+
+ let up = a_to_b
+ .map_err(|_| std::io::Error::new(std::io::ErrorKind::Other, ""))
+ .await??;
+ let down = b_to_a
+ .map_err(|_| std::io::Error::new(std::io::ErrorKind::Other, ""))
+ .await??;
+ //dbg!(up,down);
+ Ok((up as u64, down as u64))
+}
+
+fn copy_from_lhs_to_rhs(
+ mut lhs_reader: A,
+ mut rhs_writer: B,
+ buf_size: usize,
+ timeout: Duration,
+ mut msg_rx: UnboundedReceiver,
+ other_side_sender: UnboundedSender,
+) -> JoinHandle>
+where
+ A: AsyncRead + Unpin + Sized + Send + 'static,
+ B: AsyncWrite + Unpin + Sized + Send + 'static,
+{
+ tokio::spawn(async move {
+ let mut transferred_size = 0;
let mut buf = Vec::new();
- buf.resize(size, 0);
+ buf.resize(buf_size, 0);
let mut need_timeout = false;
loop {
tokio::select! {
msg = async {
if need_timeout{
- tokio::time::sleep(b_to_a_timeout_duration).await;
- DirectionMsg::Done
+ tokio::time::sleep(timeout).await;
+ SyncMsg::Done
}else{
- match b_to_a_msg_rx.recv().await{
+ match msg_rx.recv().await{
Some(v)=>v,
- None=>DirectionMsg::SetTimeOut // the other direction is done
+ None=>SyncMsg::SetTimeOut // the other direction is done
}
}
} =>{
match msg{
- DirectionMsg::SetTimeOut => {
+ SyncMsg::SetTimeOut => {
need_timeout = true;
},
- DirectionMsg::Terminate(e) =>{
+ SyncMsg::Terminate(e) =>{
return Err(e);
},
- DirectionMsg::Done => return Ok(download_total_size),
+ SyncMsg::Done => return Ok(transferred_size),
}
}
- r = b_reader.read(&mut buf[..]) =>{
+ r = lhs_reader.read(&mut buf[..]) =>{
match r{
Ok(size)=>{
if size == 0{
// peer has shutdown
- match a_writer.shutdown().await{
+ match rhs_writer.shutdown().await{
Ok(_)=>{
- let _ = a_to_b_msg_tx.send(DirectionMsg::SetTimeOut);
+ let _ = other_side_sender.send(SyncMsg::SetTimeOut);
}
Err(e)=>{
- let _ = a_to_b_msg_tx.send(DirectionMsg::Terminate(e));
+ let _ = other_side_sender.send(SyncMsg::Terminate(e));
}
}
- return Ok(download_total_size);
+ return Ok(transferred_size);
}
- if let Err(e) = a_writer.write_all(&buf[..size]).await{
+ if let Err(e) = rhs_writer.write_all(&buf[..size]).await{
return Err(e);
}
- download_total_size += size;
+ transferred_size += size;
continue;
}
Err(e)=>{
@@ -520,25 +506,5 @@ where
}
};
}
- });
- let up = match a_to_b.await {
- Ok(Ok(up)) => up,
- Ok(Err(e)) => {
- return Err(e);
- }
- Err(_) => {
- return Err(std::io::Error::new(std::io::ErrorKind::Other, ""));
- }
- };
- let down = match b_to_a.await {
- Ok(Ok(up)) => up,
- Ok(Err(e)) => {
- return Err(e);
- }
- Err(_) => {
- return Err(std::io::Error::new(std::io::ErrorKind::Other, ""));
- }
- };
- //dbg!(up,down);
- Ok((up as u64, down as u64))
+ })
}
diff --git a/config.yaml b/config.yaml
deleted file mode 100644
index a45ddc832..000000000
--- a/config.yaml
+++ /dev/null
@@ -1,2 +0,0 @@
-port: 8080
-external-controller: 127.0.0.1:9090