1- use anyhow:: { Context , Result } ;
2- use async_trait:: async_trait;
1+ use anyhow:: Result ;
32use serde:: { Deserialize , Serialize } ;
4- use std:: {
5- path:: { Path , PathBuf } ,
6- time:: Duration ,
7- } ;
3+ use std:: path:: Path ;
4+ use std:: time:: Duration ;
85use tokio:: io:: { AsyncBufReadExt , AsyncWriteExt , BufReader } ;
6+ use tokio:: net:: UnixStream ;
97
108#[ derive( Clone , Copy , Debug ) ]
11- pub ( crate ) struct ConnectionConfig {
9+ struct ConnectionConfig {
1210 max_retries : u16 ,
1311 initial_delay_ms : u16 ,
1412 max_delay_ms : u32 ,
@@ -26,43 +24,23 @@ impl Default for ConnectionConfig {
2624 }
2725}
2826
29- #[ async_trait]
30- pub trait ConnectionTrait : Send {
31- async fn write_all ( & mut self , buf : & [ u8 ] ) -> Result < ( ) > ;
32- async fn read_line ( & mut self , buf : & mut String ) -> Result < usize > ;
33- }
34-
35- pub struct Connection {
36- inner : UnixConnection ,
37- }
38-
39- #[ cfg( unix) ]
40- pub struct UnixConnection {
41- stream : tokio:: net:: UnixStream ,
27+ #[ derive( Debug ) ]
28+ struct Connection {
29+ stream : UnixStream ,
4230}
4331
4432impl Connection {
45- pub async fn connect ( path : & Path ) -> Result < Box < dyn ConnectionTrait > > {
33+ async fn connect ( path : & Path ) -> Result < Self > {
4634 Self :: connect_with_config ( path, ConnectionConfig :: default ( ) ) . await
4735 }
4836
49- pub ( crate ) async fn connect_with_config (
50- path : & Path ,
51- config : ConnectionConfig ,
52- ) -> Result < Box < dyn ConnectionTrait > > {
37+ async fn connect_with_config ( path : & Path , config : ConnectionConfig ) -> Result < Self > {
5338 let mut current_delay = u64:: from ( config. initial_delay_ms ) ;
5439 let mut last_error = None ;
5540
5641 for attempt in 0 ..config. max_retries {
57- let result = {
58- let stream = tokio:: net:: UnixStream :: connect ( path) . await ;
59- stream
60- . map ( |s| Box :: new ( UnixConnection { stream : s } ) as Box < dyn ConnectionTrait > )
61- . context ( "Failed to connect to Unix socket" )
62- } ;
63-
64- match result {
65- Ok ( connection) => return Ok ( connection) ,
42+ match UnixStream :: connect ( path) . await {
43+ Ok ( stream) => return Ok ( Self { stream } ) ,
6644 Err ( e) => {
6745 last_error = Some ( e) ;
6846
@@ -76,15 +54,11 @@ impl Connection {
7654 }
7755 }
7856
79- Err ( last_error. unwrap_or_else ( || {
80- anyhow :: anyhow! ( "Failed to connect after {} attempts" , config . max_retries )
81- } ) )
57+ Err ( last_error
58+ . unwrap_or_else ( || std :: io :: Error :: new ( std :: io :: ErrorKind :: Other , "Unknown error" ) )
59+ . into ( ) )
8260 }
83- }
8461
85- #[ cfg( unix) ]
86- #[ async_trait]
87- impl ConnectionTrait for UnixConnection {
8862 async fn write_all ( & mut self , buf : & [ u8 ] ) -> Result < ( ) > {
8963 self . stream . write_all ( buf) . await ?;
9064 Ok ( ( ) )
@@ -103,9 +77,11 @@ pub struct Message<T> {
10377 pub content : T ,
10478}
10579
80+ #[ derive( Debug ) ]
10681pub struct Client {
107- connection : Box < dyn ConnectionTrait > ,
82+ connection : Connection ,
10883 message_id : u64 ,
84+ #[ cfg( test) ]
10985 socket_path : PathBuf ,
11086}
11187
@@ -115,6 +91,7 @@ impl Client {
11591 Ok ( Self {
11692 connection,
11793 message_id : 0 ,
94+ #[ cfg( test) ]
11895 socket_path : path. to_owned ( ) ,
11996 } )
12097 }
@@ -339,6 +316,7 @@ mod client_tests {
339316 use super :: * ;
340317 use std:: sync:: { Arc , Mutex } ;
341318
319+ #[ derive( Debug ) ]
342320 struct MockConnection {
343321 written : Arc < Mutex < Vec < u8 > > > ,
344322 responses : Vec < Result < String > > ,
@@ -355,8 +333,7 @@ mod client_tests {
355333 }
356334 }
357335
358- #[ async_trait:: async_trait]
359- impl crate :: client:: ConnectionTrait for MockConnection {
336+ impl MockConnection {
360337 async fn write_all ( & mut self , buf : & [ u8 ] ) -> Result < ( ) > {
361338 if self . response_index >= self . responses . len ( ) {
362339 return Err ( anyhow:: anyhow!( "Connection closed" ) ) ;
@@ -378,17 +355,56 @@ mod client_tests {
378355 }
379356 }
380357
381- #[ derive( Serialize , Deserialize , Debug , PartialEq ) ]
382- struct TestMessage {
383- value : String ,
358+ #[ derive( Debug ) ]
359+ struct TestClient {
360+ connection : MockConnection ,
361+ message_id : u64 ,
362+ socket_path : PathBuf ,
384363 }
385364
386- fn create_test_client ( mock_conn : MockConnection ) -> Client {
387- Client {
388- connection : Box :: new ( mock_conn) ,
389- message_id : 0 ,
390- socket_path : PathBuf :: from ( "/test/socket" ) ,
365+ impl TestClient {
366+ fn new ( mock_conn : MockConnection ) -> Self {
367+ Self {
368+ connection : mock_conn,
369+ message_id : 0 ,
370+ socket_path : PathBuf :: from ( "/test/socket" ) ,
371+ }
391372 }
373+
374+ async fn send < T , R > ( & mut self , content : T ) -> Result < R >
375+ where
376+ T : Serialize ,
377+ R : for < ' de > Deserialize < ' de > ,
378+ {
379+ self . message_id += 1 ;
380+ let message = Message {
381+ id : self . message_id ,
382+ content,
383+ } ;
384+
385+ let msg = serde_json:: to_string ( & message) ? + "\n " ;
386+ self . connection . write_all ( msg. as_bytes ( ) ) . await ?;
387+
388+ let mut buffer = String :: new ( ) ;
389+ self . connection . read_line ( & mut buffer) . await ?;
390+
391+ let response: Message < R > = serde_json:: from_str ( & buffer) ?;
392+
393+ if response. id != self . message_id {
394+ return Err ( anyhow:: anyhow!(
395+ "Message ID mismatch. Expected {}, got {}" ,
396+ self . message_id,
397+ response. id
398+ ) ) ;
399+ }
400+
401+ Ok ( response. content )
402+ }
403+ }
404+
405+ #[ derive( Serialize , Deserialize , Debug , PartialEq ) ]
406+ struct TestMessage {
407+ value : String ,
392408 }
393409
394410 #[ tokio:: test]
@@ -397,7 +413,7 @@ mod client_tests {
397413 r#"{"id":1,"content":{"value":"response"}}"# . to_string( )
398414 ) ] ) ;
399415
400- let mut client = create_test_client ( mock_conn) ;
416+ let mut client = TestClient :: new ( mock_conn) ;
401417
402418 let request = TestMessage {
403419 value : "test" . to_string ( ) ,
@@ -413,7 +429,7 @@ mod client_tests {
413429 async fn test_connection_error ( ) {
414430 let mock_conn = MockConnection :: new ( vec ! [ Err ( anyhow:: anyhow!( "Connection error" ) ) ] ) ;
415431
416- let mut client = create_test_client ( mock_conn) ;
432+ let mut client = TestClient :: new ( mock_conn) ;
417433
418434 let request = TestMessage {
419435 value : "test" . to_string ( ) ,
@@ -429,7 +445,7 @@ mod client_tests {
429445 r#"{"id":2,"content":{"value":"response"}}"# . to_string( )
430446 ) ] ) ;
431447
432- let mut client = create_test_client ( mock_conn) ;
448+ let mut client = TestClient :: new ( mock_conn) ;
433449
434450 let request = TestMessage {
435451 value : "test" . to_string ( ) ,
@@ -447,7 +463,7 @@ mod client_tests {
447463 async fn test_invalid_json_response ( ) {
448464 let mock_conn = MockConnection :: new ( vec ! [ Ok ( "invalid json" . to_string( ) ) ] ) ;
449465
450- let mut client = create_test_client ( mock_conn) ;
466+ let mut client = TestClient :: new ( mock_conn) ;
451467
452468 let request = TestMessage {
453469 value : "test" . to_string ( ) ,
@@ -463,7 +479,7 @@ mod client_tests {
463479 Ok ( r#"{"id":2,"content":{"value":"response2"}}"# . to_string( ) ) ,
464480 ] ) ;
465481
466- let mut client = create_test_client ( mock_conn) ;
482+ let mut client = TestClient :: new ( mock_conn) ;
467483
468484 let request1 = TestMessage {
469485 value : "test1" . to_string ( ) ,
0 commit comments