@@ -10,8 +10,8 @@ use std::{
1010 task:: { Context , Poll } ,
1111 time:: Duration ,
1212} ;
13- use tokio:: sync:: { broadcast , oneshot} ;
14- use tracing:: { info, warn } ;
13+ use tokio:: sync:: oneshot;
14+ use tracing:: info;
1515
1616// 整个 文件的内容都是在模仿 AsyncRead 和 AsyncWrite 的实现,
1717// 只是加了一个 Addr 参数. 这一部分比较难懂.
@@ -330,18 +330,18 @@ pub const CP_UDP_TIMEOUT: time::Duration = Duration::from_secs(100); //todo: adj
330330pub const MAX_DATAGRAM_SIZE : usize = 65535 - 20 - 8 ;
331331pub const MTU : usize = 1400 ;
332332
333- async fn read_once < R1 : AddrReadTrait , W1 : AddrWriteTrait > (
334- r1 : & mut R1 ,
335- w1 : & mut W1 ,
333+ async fn rw_once < R : AddrReadTrait , W : AddrWriteTrait > (
334+ r : & mut R ,
335+ w : & mut W ,
336336 buf : & mut [ u8 ] ,
337337) -> io:: Result < usize > {
338- let ( u , a) = r1 . read ( buf) . await ?;
338+ let ( rn , a) = r . read ( buf) . await ?;
339339
340- let u = w1 . write ( & buf[ ..u ] , & a) . await ?;
340+ let wn = w . write ( & buf[ ..rn ] , & a) . await ?;
341341
342- let r = w1 . flush ( ) . await ;
342+ let r = w . flush ( ) . await ;
343343 match r {
344- Ok ( _) => Ok ( u ) ,
344+ Ok ( _) => Ok ( wn ) ,
345345 Err ( e) => Err ( e) ,
346346 }
347347}
@@ -355,7 +355,7 @@ pub async fn cp_addr<R: AddrReadTrait + 'static, W: AddrWriteTrait + 'static>(
355355 mut w : W ,
356356 name : String ,
357357 no_timeout : bool ,
358- mut shutdown_rx : broadcast :: Receiver < ( ) > ,
358+ mut shutdown_rx : oneshot :: Receiver < ( ) > ,
359359 is_d : bool ,
360360 opt : Option < Arc < GlobalTrafficRecorder > > ,
361361) -> Result < u64 , Error > {
@@ -365,78 +365,47 @@ pub async fn cp_addr<R: AddrReadTrait + 'static, W: AddrWriteTrait + 'static>(
365365 let mut whole_write = 0 ;
366366 let mut buf = Box :: new ( [ 0u8 ; MTU ] ) ;
367367
368- if no_timeout {
369- loop {
370- let rf = read_once ( & mut r, & mut w, buf. as_mut ( ) ) ;
371-
372- tokio:: select! {
373- r = rf =>{
374- match r {
375- Ok ( n) => whole_write+=n,
376- Err ( e) => {
377- match e. kind( ) {
378- io:: ErrorKind :: Other => {
379- debug!( "cp_addr got other e, will continue. {e}" ) ;
380- continue ;
381- } ,
382- _ => {
383- warn!( name = name, "cp_addr got e, will break {e}" ) ;
384- } ,
385- }
386-
387- break
388- } ,
389- }
390- }
391-
392- _ = shutdown_rx. recv( ) =>{
393- info!( "cp_addr got shutdown_rx, will break" ) ;
394-
395- break ;
368+ loop {
369+ tokio:: select! {
370+ r = rw_once( & mut r, & mut w, buf. as_mut( ) ) =>{
371+ match r {
372+ Ok ( n) => whole_write+=n,
373+ Err ( e) => {
374+ match e. kind( ) {
375+ io:: ErrorKind :: Other => {
376+ debug!( "cp_addr got other e, will continue: {e}" ) ;
377+ continue ;
378+ } ,
379+ _ => {
380+ // udp timeout 时 常会发生, 因此不能认为是错误
381+ info!( name = name, "cp_addr got e, will break: {e}" ) ;
382+ } ,
383+ }
384+
385+ break
386+ } ,
396387 }
397388 }
398- tokio:: task:: yield_now ( ) . await ; //necessary, or it is likely to cause stuck issue
399- } //loop
400- } else {
401- loop {
402- let rf = read_once ( & mut r, & mut w, buf. as_mut ( ) ) ;
403- let timeout_f = tokio:: time:: sleep ( CP_UDP_TIMEOUT ) ;
404-
405- tokio:: select! {
406- r = rf =>{
407- match r {
408- Ok ( n) => whole_write+=n,
409- Err ( e) => {
410- match e. kind( ) {
411- io:: ErrorKind :: Other => {
412- debug!( "cp_addr got other e, will continue: {e}" ) ;
413- continue ;
414- } ,
415- _ => {
416- // udp timeout 时 常会发生, 因此不能认为是错误
417- info!( name = name, "cp_addr got e, will break: {e}" ) ;
418- } ,
419- }
420-
421- break
422- } ,
423- }
389+ _ = async {
390+ if no_timeout{
391+ std:: future:: pending( ) . await
392+ } else{
393+ tokio:: time:: sleep( CP_UDP_TIMEOUT ) . await
424394 }
425- _ = timeout_f =>{
426- info!( timeout = ?CP_UDP_TIMEOUT , "cp_addr got timeout, will break" ) ;
395+ } =>{
396+ info!( timeout = ?CP_UDP_TIMEOUT , "cp_addr got timeout, will break" ) ;
427397
428- break ;
429- }
398+ break ;
399+ }
430400
431- _ = shutdown_rx. recv ( ) =>{
432- info!( "cp_addr got shutdown_rx, will break" ) ;
401+ _ = & mut shutdown_rx =>{
402+ info!( "cp_addr got shutdown_rx, will break" ) ;
433403
434- break ;
435- }
404+ break ;
436405 }
437- tokio :: task :: yield_now ( ) . await ; //necessary, or it is likely to cause stuck issue
438- } //loop
439- }
406+ }
407+ tokio :: task :: yield_now ( ) . await ; //necessary, or it is likely to cause stuck issue
408+ } //loop
440409
441410 let l = whole_write as u64 ;
442411 if let Some ( a) = opt {
@@ -462,18 +431,14 @@ pub async fn cp(
462431 shutdown_rx1 : Option < tokio:: sync:: oneshot:: Receiver < ( ) > > ,
463432 shutdown_rx2 : Option < tokio:: sync:: oneshot:: Receiver < ( ) > > ,
464433) -> Result < u64 , Error > {
465- let name1 = c1. cached_name . clone ( ) ;
466- let name2 = c2. cached_name . clone ( ) ;
467- let r1 = c1. r ;
468- let w1 = c1. w ;
469- let r2 = c2. r ;
470- let w2 = c2. w ;
434+ let n1 = c1. cached_name . clone ( ) + " to " + & c2. cached_name ;
435+ let n2 = c2. cached_name . clone ( ) + " to " + & c1. cached_name ;
471436
472- let ( tx1, rx1) = broadcast :: channel ( 10 ) ;
473- let ( tx2, rx2) = broadcast :: channel ( 10 ) ;
437+ let ( tx1, rx1) = oneshot :: channel ( ) ;
438+ let ( tx2, rx2) = oneshot :: channel ( ) ;
474439
475- let tx1c = tx1 . clone ( ) ;
476- let tx2c = tx1 . clone ( ) ;
440+ let cp1 = tokio :: spawn ( cp_addr ( c1 . r , c2 . w , n1 , no_timeout , rx1 , false , opt . clone ( ) ) ) ;
441+ let cp2 = tokio :: spawn ( cp_addr ( c2 . r , c1 . w , n2 , no_timeout , rx2 , true , opt . clone ( ) ) ) ;
477442
478443 let ( _tmpx0, tmp_rx0) = oneshot:: channel ( ) ;
479444 let shutdown_rx1 = if let Some ( x) = shutdown_rx1 {
@@ -489,20 +454,12 @@ pub async fn cp(
489454 tmp_rx
490455 } ;
491456
492- let n1 = name1. clone ( ) + " to " + & name2;
493- let n2 = name2 + " to " + & name1;
494-
495- let o1 = opt. clone ( ) ;
496- let o2 = opt. clone ( ) ;
497- let cp1 = tokio:: spawn ( cp_addr ( r1, w2, n1, no_timeout, rx1, false , o1) ) ;
498- let cp2 = tokio:: spawn ( cp_addr ( r2, w1, n2, no_timeout, rx2, true , o2) ) ;
499-
500457 let r = tokio:: select! {
501458 r = cp1 =>{
502459 if tracing:: enabled!( tracing:: Level :: DEBUG ) {
503- debug!( cid = %cid, "cp_addr end, u" ) ;
460+ debug!( cid = %cid, "addr_conn::cp end, u" ) ;
504461 }
505- let _ = tx1c . send( ( ) ) ;
462+ let _ = tx1 . send( ( ) ) ;
506463 let _ = tx2. send( ( ) ) ;
507464
508465 match r{
@@ -512,10 +469,10 @@ pub async fn cp(
512469 }
513470 r = cp2 =>{
514471 if tracing:: enabled!( tracing:: Level :: DEBUG ) {
515- debug!( cid = %cid, "cp_addr end, d" ) ;
472+ debug!( cid = %cid, "addr_conn::cp end, d" ) ;
516473 }
517- let _ = tx1c . send( ( ) ) ;
518- let _ = tx2c . send( ( ) ) ;
474+ let _ = tx1 . send( ( ) ) ;
475+ let _ = tx2 . send( ( ) ) ;
519476
520477 match r{
521478 Ok ( r) => r. unwrap_or( 0 ) ,
@@ -525,17 +482,16 @@ pub async fn cp(
525482 _ = shutdown_rx1 =>{
526483 debug!( "addrconn cp_between got shutdown1 signal" ) ;
527484
528- let _ = tx1c . send( ( ) ) ;
529- let _ = tx2c . send( ( ) ) ;
485+ let _ = tx1 . send( ( ) ) ;
486+ let _ = tx2 . send( ( ) ) ;
530487
531488 0
532489 }
533490
534491 _ = shutdown_rx2 =>{
535492 debug!( "addrconn cp_between got shutdown2 signal" ) ;
536-
537- let _ = tx1c. send( ( ) ) ;
538- let _ = tx2c. send( ( ) ) ;
493+ let _ = tx1. send( ( ) ) ;
494+ let _ = tx2. send( ( ) ) ;
539495
540496 0
541497 }
0 commit comments