@@ -201,3 +201,197 @@ where
201
201
}
202
202
}
203
203
}
204
+
205
+ #[ cfg( test) ]
206
+ mod tests {
207
+ use super :: * ;
208
+ use crate :: listener_select_proto;
209
+ use async_std:: future:: timeout;
210
+ use async_std:: net:: { TcpListener , TcpStream } ;
211
+ use log:: info;
212
+ use quickcheck:: { Arbitrary , Gen , GenRange } ;
213
+ use std:: time:: Duration ;
214
+
215
+ #[ test]
216
+ fn select_proto_basic ( ) {
217
+ async fn run ( version : Version ) {
218
+ let ( client_connection, server_connection) = futures_ringbuf:: Endpoint :: pair ( 100 , 100 ) ;
219
+
220
+ let server = async_std:: task:: spawn ( async move {
221
+ let protos = vec ! [ "/proto1" , "/proto2" ] ;
222
+ let ( proto, mut io) = listener_select_proto ( server_connection, protos)
223
+ . await
224
+ . unwrap ( ) ;
225
+ assert_eq ! ( proto, "/proto2" ) ;
226
+
227
+ let mut out = vec ! [ 0 ; 32 ] ;
228
+ let n = io. read ( & mut out) . await . unwrap ( ) ;
229
+ out. truncate ( n) ;
230
+ assert_eq ! ( out, b"ping" ) ;
231
+
232
+ io. write_all ( b"pong" ) . await . unwrap ( ) ;
233
+ io. flush ( ) . await . unwrap ( ) ;
234
+ } ) ;
235
+
236
+ let client = async_std:: task:: spawn ( async move {
237
+ let protos = vec ! [ "/proto3" , "/proto2" ] ;
238
+ let ( proto, mut io) = dialer_select_proto ( client_connection, protos, version)
239
+ . await
240
+ . unwrap ( ) ;
241
+ assert_eq ! ( proto, "/proto2" ) ;
242
+
243
+ io. write_all ( b"ping" ) . await . unwrap ( ) ;
244
+ io. flush ( ) . await . unwrap ( ) ;
245
+
246
+ let mut out = vec ! [ 0 ; 32 ] ;
247
+ let n = io. read ( & mut out) . await . unwrap ( ) ;
248
+ out. truncate ( n) ;
249
+ assert_eq ! ( out, b"pong" ) ;
250
+ } ) ;
251
+
252
+ server. await ;
253
+ client. await ;
254
+ }
255
+
256
+ async_std:: task:: block_on ( run ( Version :: V1 ) ) ;
257
+ async_std:: task:: block_on ( run ( Version :: V1Lazy ) ) ;
258
+ }
259
+
260
+ /// Tests the expected behaviour of failed negotiations.
261
+ #[ test]
262
+ fn negotiation_failed ( ) {
263
+ fn prop (
264
+ version : Version ,
265
+ DialerProtos ( dial_protos) : DialerProtos ,
266
+ ListenerProtos ( listen_protos) : ListenerProtos ,
267
+ DialPayload ( dial_payload) : DialPayload ,
268
+ ) {
269
+ let _ = env_logger:: try_init ( ) ;
270
+
271
+ async_std:: task:: block_on ( async move {
272
+ let listener = TcpListener :: bind ( "0.0.0.0:0" ) . await . unwrap ( ) ;
273
+ let addr = listener. local_addr ( ) . unwrap ( ) ;
274
+
275
+ let server = async_std:: task:: spawn ( async move {
276
+ let server_connection = listener. accept ( ) . await . unwrap ( ) . 0 ;
277
+
278
+ let io = match timeout (
279
+ Duration :: from_secs ( 2 ) ,
280
+ listener_select_proto ( server_connection, listen_protos) ,
281
+ )
282
+ . await
283
+ . unwrap ( )
284
+ {
285
+ Ok ( ( _, io) ) => io,
286
+ Err ( NegotiationError :: Failed ) => return ,
287
+ Err ( NegotiationError :: ProtocolError ( e) ) => {
288
+ panic ! ( "Unexpected protocol error {e}" )
289
+ }
290
+ } ;
291
+ match io. complete ( ) . await {
292
+ Err ( NegotiationError :: Failed ) => { }
293
+ _ => panic ! ( ) ,
294
+ }
295
+ } ) ;
296
+
297
+ let client = async_std:: task:: spawn ( async move {
298
+ let client_connection = TcpStream :: connect ( addr) . await . unwrap ( ) ;
299
+
300
+ let mut io = match timeout (
301
+ Duration :: from_secs ( 2 ) ,
302
+ dialer_select_proto ( client_connection, dial_protos, version) ,
303
+ )
304
+ . await
305
+ . unwrap ( )
306
+ {
307
+ Err ( NegotiationError :: Failed ) => return ,
308
+ Ok ( ( _, io) ) => io,
309
+ Err ( _) => panic ! ( ) ,
310
+ } ;
311
+ // The dialer may write a payload that is even sent before it
312
+ // got confirmation of the last proposed protocol, when `V1Lazy`
313
+ // is used.
314
+
315
+ info ! ( "Writing early data" ) ;
316
+
317
+ io. write_all ( & dial_payload) . await . unwrap ( ) ;
318
+ match io. complete ( ) . await {
319
+ Err ( NegotiationError :: Failed ) => { }
320
+ _ => panic ! ( ) ,
321
+ }
322
+ } ) ;
323
+
324
+ server. await ;
325
+ client. await ;
326
+
327
+ info ! ( "---------------------------------------" )
328
+ } ) ;
329
+ }
330
+
331
+ quickcheck:: QuickCheck :: new ( )
332
+ . tests ( 1000 )
333
+ . quickcheck ( prop as fn ( _, _, _, _) ) ;
334
+ }
335
+
336
+ #[ async_std:: test]
337
+ async fn v1_lazy_do_not_wait_for_negotiation_on_poll_close ( ) {
338
+ let ( client_connection, _server_connection) =
339
+ futures_ringbuf:: Endpoint :: pair ( 1024 * 1024 , 1 ) ;
340
+
341
+ let client = async_std:: task:: spawn ( async move {
342
+ // Single protocol to allow for lazy (or optimistic) protocol negotiation.
343
+ let protos = vec ! [ "/proto1" ] ;
344
+ let ( proto, mut io) = dialer_select_proto ( client_connection, protos, Version :: V1Lazy )
345
+ . await
346
+ . unwrap ( ) ;
347
+ assert_eq ! ( proto, "/proto1" ) ;
348
+
349
+ // client can close the connection even though protocol negotiation is not yet done, i.e.
350
+ // `_server_connection` had been untouched.
351
+ io. close ( ) . await . unwrap ( ) ;
352
+ } ) ;
353
+
354
+ async_std:: future:: timeout ( Duration :: from_secs ( 10 ) , client)
355
+ . await
356
+ . unwrap ( ) ;
357
+ }
358
+
359
+ #[ derive( Clone , Debug ) ]
360
+ struct DialerProtos ( Vec < & ' static str > ) ;
361
+
362
+ impl Arbitrary for DialerProtos {
363
+ fn arbitrary ( g : & mut Gen ) -> Self {
364
+ if bool:: arbitrary ( g) {
365
+ DialerProtos ( vec ! [ "/proto1" ] )
366
+ } else {
367
+ DialerProtos ( vec ! [ "/proto1" , "/proto2" ] )
368
+ }
369
+ }
370
+ }
371
+
372
+ #[ derive( Clone , Debug ) ]
373
+ struct ListenerProtos ( Vec < & ' static str > ) ;
374
+
375
+ impl Arbitrary for ListenerProtos {
376
+ fn arbitrary ( g : & mut Gen ) -> Self {
377
+ if bool:: arbitrary ( g) {
378
+ ListenerProtos ( vec ! [ "/proto3" ] )
379
+ } else {
380
+ ListenerProtos ( vec ! [ "/proto3" , "/proto4" ] )
381
+ }
382
+ }
383
+ }
384
+
385
+ #[ derive( Clone , Debug ) ]
386
+ struct DialPayload ( Vec < u8 > ) ;
387
+
388
+ impl Arbitrary for DialPayload {
389
+ fn arbitrary ( g : & mut Gen ) -> Self {
390
+ DialPayload (
391
+ ( 0 ..g. gen_range ( 0 ..2u8 ) )
392
+ . map ( |_| g. gen_range ( 1 ..255 ) ) // We can generate 0 as that will produce a different error.
393
+ . collect ( ) ,
394
+ )
395
+ }
396
+ }
397
+ }
0 commit comments