@@ -37,6 +37,7 @@ const WAIT_FOR_EVENT: i32 = 1;
37
37
38
38
pub struct PipeListener {
39
39
first_instance : AtomicBool ,
40
+ shutting_down : AtomicBool ,
40
41
address : String ,
41
42
connection_event : isize ,
42
43
}
@@ -65,6 +66,7 @@ impl PipeListener {
65
66
let connection_event = create_event ( ) ?;
66
67
Ok ( PipeListener {
67
68
first_instance : AtomicBool :: new ( true ) ,
69
+ shutting_down : AtomicBool :: new ( false ) ,
68
70
address : sockaddr. to_string ( ) ,
69
71
connection_event
70
72
} )
@@ -98,6 +100,9 @@ impl PipeListener {
98
100
trace ! ( "listening for connection" ) ;
99
101
let result = unsafe { ConnectNamedPipe ( np. named_pipe , ol. as_mut_ptr ( ) ) } ;
100
102
if result != 0 {
103
+ if let Some ( error) = self . handle_shutdown ( & np) {
104
+ return Err ( error) ;
105
+ }
101
106
return Err ( io:: Error :: last_os_error ( ) ) ;
102
107
}
103
108
@@ -110,11 +115,17 @@ impl PipeListener {
110
115
return Err ( io:: Error :: last_os_error ( ) ) ;
111
116
}
112
117
_ => {
118
+ if let Some ( shutdown_signal) = self . handle_shutdown ( & np) {
119
+ return Err ( shutdown_signal) ;
120
+ }
113
121
Ok ( Some ( np) )
114
122
}
115
123
}
116
124
}
117
125
e if e. raw_os_error ( ) == Some ( ERROR_PIPE_CONNECTED as i32 ) => {
126
+ if let Some ( error) = self . handle_shutdown ( & np) {
127
+ return Err ( error) ;
128
+ }
118
129
Ok ( Some ( np) )
119
130
}
120
131
e => {
@@ -126,6 +137,17 @@ impl PipeListener {
126
137
}
127
138
}
128
139
140
+ fn handle_shutdown ( & self , np : & PipeConnection ) -> Option < io:: Error > {
141
+ if self . shutting_down . load ( Ordering :: SeqCst ) {
142
+ np. close ( ) . unwrap_or_else ( |err| trace ! ( "Failed to close the pipe {:?}" , err) ) ;
143
+ return Some ( io:: Error :: new (
144
+ io:: ErrorKind :: Other ,
145
+ "closing pipe" ,
146
+ ) ) ;
147
+ }
148
+ None
149
+ }
150
+
129
151
fn new_instance ( & self ) -> io:: Result < isize > {
130
152
let name = OsStr :: new ( & self . address . as_str ( ) )
131
153
. encode_wide ( )
@@ -153,6 +175,7 @@ impl PipeListener {
153
175
154
176
pub fn close ( & self ) -> Result < ( ) > {
155
177
// release the ConnectNamedPipe thread by signaling the event and clean up event handle
178
+ self . shutting_down . store ( true , Ordering :: SeqCst ) ;
156
179
set_event ( self . connection_event ) ?;
157
180
close_handle ( self . connection_event )
158
181
}
@@ -359,4 +382,77 @@ mod test {
359
382
}
360
383
}
361
384
}
385
+
386
+ #[ test]
387
+ fn should_accept_new_client ( ) {
388
+ let address = r"\\.\pipe\ttrpc-test-accept" ;
389
+ let listener = Arc :: new ( PipeListener :: new ( address) . unwrap ( ) ) ;
390
+
391
+ let listener_server = listener. clone ( ) ;
392
+ let thread = std:: thread:: spawn ( move || {
393
+ let quit_flag = Arc :: new ( AtomicBool :: new ( false ) ) ;
394
+ match listener_server. accept ( & quit_flag) {
395
+ Ok ( Some ( _) ) => {
396
+ // pipe is working
397
+ }
398
+ Ok ( None ) => {
399
+ assert ! ( false , "should get a working pipe" )
400
+ }
401
+ Err ( e) => {
402
+ assert ! ( false , "should not get error {}" , e. to_string( ) )
403
+ }
404
+ }
405
+ } ) ;
406
+
407
+ wait_socket_working ( address, 10 , 5 ) . unwrap ( ) ;
408
+ thread. join ( ) . unwrap ( ) ;
409
+ }
410
+
411
+ #[ test]
412
+ fn close_should_cancel_accept ( ) {
413
+ let listener = Arc :: new ( PipeListener :: new ( r"\\.\pipe\ttrpc-test-close" ) . unwrap ( ) ) ;
414
+
415
+ let listener_server = listener. clone ( ) ;
416
+ let thread = std:: thread:: spawn ( move || {
417
+ let quit_flag = Arc :: new ( AtomicBool :: new ( false ) ) ;
418
+ match listener_server. accept ( & quit_flag) {
419
+ Ok ( _) => {
420
+ assert ! ( false , "should not get pipe on close" )
421
+ }
422
+ Err ( e) => {
423
+ assert_eq ! ( e. to_string( ) , "closing pipe" )
424
+ }
425
+ }
426
+ } ) ;
427
+
428
+ // sleep for a moment to allow the pipe to start initialize and be ready to accept new connection.
429
+ // this simulates scenario where the thread is asleep and awaiting a connection
430
+ std:: thread:: sleep ( std:: time:: Duration :: from_millis ( 500 ) ) ;
431
+ listener. close ( ) . unwrap ( ) ;
432
+ thread. join ( ) . unwrap ( ) ;
433
+ }
434
+
435
+ fn wait_socket_working ( address : & str , interval_in_ms : u64 , count : u32 ) -> Result < ( ) > {
436
+ for _i in 0 ..count {
437
+ let client = match ClientConnection :: client_connect ( address) {
438
+ Ok ( c) => {
439
+ c
440
+ }
441
+ Err ( _) => {
442
+ std:: thread:: sleep ( std:: time:: Duration :: from_millis ( interval_in_ms) ) ;
443
+ continue ;
444
+ }
445
+ } ;
446
+
447
+ match client. get_pipe_connection ( ) {
448
+ Ok ( _) => {
449
+ return Ok ( ( ) ) ;
450
+ }
451
+ Err ( _) => {
452
+ std:: thread:: sleep ( std:: time:: Duration :: from_millis ( interval_in_ms) ) ;
453
+ }
454
+ }
455
+ }
456
+ Err ( Error :: Others ( "timed out" . to_string ( ) ) )
457
+ }
362
458
}
0 commit comments