@@ -8,12 +8,15 @@ use std::convert::TryInto;
8
8
use std:: os:: unix:: io:: RawFd ;
9
9
use std:: sync:: { Arc , Mutex } ;
10
10
11
+ use async_trait:: async_trait;
11
12
use nix:: unistd:: close;
12
- use tokio:: { self , io :: split , sync:: mpsc, sync :: Notify } ;
13
+ use tokio:: { self , sync:: mpsc, task } ;
13
14
14
15
use crate :: common:: client_connect;
15
16
use crate :: error:: { Error , Result } ;
16
17
use crate :: proto:: { Code , Codec , GenMessage , Message , Request , Response , MESSAGE_TYPE_RESPONSE } ;
18
+ use crate :: r#async:: connection:: * ;
19
+ use crate :: r#async:: shutdown;
17
20
use crate :: r#async:: stream:: { ResultReceiver , ResultSender } ;
18
21
use crate :: r#async:: utils;
19
22
@@ -36,122 +39,12 @@ impl Client {
36
39
pub fn new ( fd : RawFd ) -> Client {
37
40
let stream = utils:: new_unix_stream_from_raw_fd ( fd) ;
38
41
39
- let ( mut reader, mut writer) = split ( stream) ;
40
- let ( req_tx, mut rx) : ( RequestSender , RequestReceiver ) = mpsc:: channel ( 100 ) ;
42
+ let ( req_tx, rx) : ( RequestSender , RequestReceiver ) = mpsc:: channel ( 100 ) ;
41
43
42
- let req_map = Arc :: new ( Mutex :: new ( HashMap :: new ( ) ) ) ;
43
- let req_map2 = req_map. clone ( ) ;
44
-
45
- let notify = Arc :: new ( Notify :: new ( ) ) ;
46
- let notify2 = notify. clone ( ) ;
47
-
48
- // Request sender
49
- let request_sender = tokio:: spawn ( async move {
50
- let mut stream_id: u32 = 1 ;
51
-
52
- while let Some ( ( mut msg, resp_tx) ) = rx. recv ( ) . await {
53
- let current_stream_id = stream_id;
54
- msg. header . set_stream_id ( current_stream_id) ;
55
- stream_id += 2 ;
56
-
57
- {
58
- let mut map = req_map2. lock ( ) . unwrap ( ) ;
59
- map. insert ( current_stream_id, resp_tx. clone ( ) ) ;
60
- }
61
-
62
- if let Err ( e) = msg. write_to ( & mut writer) . await {
63
- error ! ( "write_message got error: {:?}" , e) ;
64
-
65
- {
66
- let mut map = req_map2. lock ( ) . unwrap ( ) ;
67
- map. remove ( & current_stream_id) ;
68
- }
44
+ let delegate = ClientBuilder { rx : Some ( rx) } ;
69
45
70
- let e = Error :: Socket ( format ! ( "{:?}" , e) ) ;
71
- resp_tx
72
- . send ( Err ( e) )
73
- . await
74
- . unwrap_or_else ( |_e| error ! ( "The request has returned" ) ) ;
75
-
76
- break ; // The stream is dead, exit the loop.
77
- }
78
- }
79
-
80
- // rx.recv will abort when client.req_tx and client is dropped.
81
- // notify the response-receiver to quit at this time.
82
- notify. notify_one ( ) ;
83
- } ) ;
84
-
85
- // Response receiver
86
- tokio:: spawn ( async move {
87
- loop {
88
- tokio:: select! {
89
- _ = notify2. notified( ) => {
90
- break ;
91
- }
92
- res = GenMessage :: read_from( & mut reader) => {
93
- match res {
94
- Ok ( msg) => {
95
- trace!( "Got Message body {:?}" , msg. payload) ;
96
- let req_map = req_map. clone( ) ;
97
- tokio:: spawn( async move {
98
- let resp_tx2;
99
- {
100
- let mut map = req_map. lock( ) . unwrap( ) ;
101
- let resp_tx = match map. get( & msg. header. stream_id) {
102
- Some ( tx) => tx,
103
- None => {
104
- debug!(
105
- "Receiver got unknown packet {:?}" ,
106
- msg
107
- ) ;
108
- return ;
109
- }
110
- } ;
111
-
112
- resp_tx2 = resp_tx. clone( ) ;
113
- map. remove( & msg. header. stream_id) ; // Forget the result, just remove.
114
- }
115
-
116
- if msg. header. type_ != MESSAGE_TYPE_RESPONSE {
117
- resp_tx2
118
- . send( Err ( Error :: Others ( format!(
119
- "Recver got malformed packet {:?}" ,
120
- msg
121
- ) ) ) )
122
- . await
123
- . unwrap_or_else( |_e| error!( "The request has returned" ) ) ;
124
- return ;
125
- }
126
-
127
- resp_tx2. send( Ok ( msg) ) . await . unwrap_or_else( |_e| error!( "The request has returned" ) ) ;
128
- } ) ;
129
- }
130
- Err ( e) => {
131
- debug!( "Connection closed by the ttRPC server: {}" , e) ;
132
-
133
- // Abort the request sender task to prevent incoming RPC requests
134
- // from being processed.
135
- request_sender. abort( ) ;
136
- let _ = request_sender. await ;
137
-
138
- // Take all items out of `req_map`.
139
- let mut map = std:: mem:: take( & mut * req_map. lock( ) . unwrap( ) ) ;
140
- // Terminate outstanding RPC requests with the error.
141
- for ( _stream_id, resp_tx) in map. drain( ) {
142
- if let Err ( _e) = resp_tx. send( Err ( e. clone( ) ) ) . await {
143
- warn!( "Failed to terminate pending RPC: \
144
- the request has returned") ;
145
- }
146
- }
147
-
148
- break ;
149
- }
150
- }
151
- }
152
- } ;
153
- }
154
- } ) ;
46
+ let conn = Connection :: new ( stream, delegate) ;
47
+ tokio:: spawn ( async move { conn. run ( ) . await } ) ;
155
48
156
49
Client { req_tx }
157
50
}
@@ -208,3 +101,140 @@ impl Drop for ClientClose {
208
101
trace ! ( "All client is droped" ) ;
209
102
}
210
103
}
104
+
105
+ #[ derive( Debug ) ]
106
+ struct ClientBuilder {
107
+ rx : Option < RequestReceiver > ,
108
+ }
109
+
110
+ impl Builder for ClientBuilder {
111
+ type Reader = ClientReader ;
112
+ type Writer = ClientWriter ;
113
+
114
+ fn build ( & mut self ) -> ( Self :: Reader , Self :: Writer ) {
115
+ let ( notifier, waiter) = shutdown:: new ( ) ;
116
+ let req_map = Arc :: new ( Mutex :: new ( HashMap :: new ( ) ) ) ;
117
+ (
118
+ ClientReader {
119
+ shutdown_waiter : waiter,
120
+ req_map : req_map. clone ( ) ,
121
+ } ,
122
+ ClientWriter {
123
+ stream_id : 1 ,
124
+ rx : self . rx . take ( ) . unwrap ( ) ,
125
+ shutdown_notifier : notifier,
126
+ req_map,
127
+ } ,
128
+ )
129
+ }
130
+ }
131
+
132
+ struct ClientWriter {
133
+ stream_id : u32 ,
134
+ rx : RequestReceiver ,
135
+ shutdown_notifier : shutdown:: Notifier ,
136
+ req_map : Arc < Mutex < HashMap < u32 , ResultSender > > > ,
137
+ }
138
+
139
+ #[ async_trait]
140
+ impl WriterDelegate for ClientWriter {
141
+ async fn recv ( & mut self ) -> Option < GenMessage > {
142
+ if let Some ( ( mut msg, resp_tx) ) = self . rx . recv ( ) . await {
143
+ let current_stream_id = self . stream_id ;
144
+ msg. header . set_stream_id ( current_stream_id) ;
145
+ self . stream_id += 2 ;
146
+ {
147
+ let mut map = self . req_map . lock ( ) . unwrap ( ) ;
148
+ map. insert ( current_stream_id, resp_tx) ;
149
+ }
150
+ return Some ( msg) ;
151
+ } else {
152
+ return None ;
153
+ }
154
+ }
155
+
156
+ async fn disconnect ( & self , msg : & GenMessage , e : Error ) {
157
+ let resp_tx = {
158
+ let mut map = self . req_map . lock ( ) . unwrap ( ) ;
159
+ map. remove ( & msg. header . stream_id )
160
+ } ;
161
+
162
+ if let Some ( resp_tx) = resp_tx {
163
+ let e = Error :: Socket ( format ! ( "{:?}" , e) ) ;
164
+ resp_tx
165
+ . send ( Err ( e) )
166
+ . await
167
+ . unwrap_or_else ( |_e| error ! ( "The request has returned" ) ) ;
168
+ }
169
+ }
170
+
171
+ async fn exit ( & self ) {
172
+ self . shutdown_notifier . shutdown ( ) ;
173
+ }
174
+ }
175
+
176
+ struct ClientReader {
177
+ shutdown_waiter : shutdown:: Waiter ,
178
+ req_map : Arc < Mutex < HashMap < u32 , ResultSender > > > ,
179
+ }
180
+
181
+ #[ async_trait]
182
+ impl ReaderDelegate for ClientReader {
183
+ async fn wait_shutdown ( & self ) {
184
+ self . shutdown_waiter . wait_shutdown ( ) . await
185
+ }
186
+
187
+ async fn disconnect ( & self , e : Error , sender : & mut task:: JoinHandle < ( ) > ) {
188
+ // Abort the request sender task to prevent incoming RPC requests
189
+ // from being processed.
190
+ sender. abort ( ) ;
191
+ let _ = sender. await ;
192
+
193
+ // Take all items out of `req_map`.
194
+ let mut map = std:: mem:: take ( & mut * self . req_map . lock ( ) . unwrap ( ) ) ;
195
+ // Terminate outstanding RPC requests with the error.
196
+ for ( _stream_id, resp_tx) in map. drain ( ) {
197
+ if let Err ( _e) = resp_tx. send ( Err ( e. clone ( ) ) ) . await {
198
+ warn ! ( "Failed to terminate pending RPC: the request has returned" ) ;
199
+ }
200
+ }
201
+ }
202
+
203
+ async fn exit ( & self ) { }
204
+
205
+ async fn handle_msg ( & self , msg : GenMessage ) {
206
+ let req_map = self . req_map . clone ( ) ;
207
+ tokio:: spawn ( async move {
208
+ let resp_tx2;
209
+ {
210
+ let mut map = req_map. lock ( ) . unwrap ( ) ;
211
+ let resp_tx = match map. get ( & msg. header . stream_id ) {
212
+ Some ( tx) => tx,
213
+ None => {
214
+ debug ! ( "Receiver got unknown packet {:?}" , msg) ;
215
+ return ;
216
+ }
217
+ } ;
218
+
219
+ resp_tx2 = resp_tx. clone ( ) ;
220
+ map. remove ( & msg. header . stream_id ) ; // Forget the result, just remove.
221
+ }
222
+
223
+ if msg. header . type_ != MESSAGE_TYPE_RESPONSE {
224
+ resp_tx2
225
+ . send ( Err ( Error :: Others ( format ! (
226
+ "Recver got malformed packet {:?}" ,
227
+ msg
228
+ ) ) ) )
229
+ . await
230
+ . unwrap_or_else ( |_e| error ! ( "The request has returned" ) ) ;
231
+ return ;
232
+ }
233
+
234
+ resp_tx2
235
+ . send ( Ok ( msg) )
236
+ . await
237
+ . unwrap_or_else ( |_e| error ! ( "The request has returned" ) ) ;
238
+ } ) ;
239
+ }
240
+ }
0 commit comments