@@ -209,8 +209,14 @@ impl<W: AsyncWrite + Unpin> FrameWrite<W> {
209
209
210
210
#[ cfg( test) ]
211
211
mod tests {
212
+ use std:: pin:: Pin ;
213
+ use std:: task:: Context ;
214
+ use std:: task:: Poll ;
215
+
216
+ use bytes:: Bytes ;
212
217
use rand:: Rng ;
213
218
use rand:: thread_rng;
219
+ use tokio:: io:: AsyncWrite ;
214
220
use tokio:: io:: AsyncWriteExt ;
215
221
216
222
use super :: * ;
@@ -308,5 +314,109 @@ mod tests {
308
314
assert_eq ! ( err. kind( ) , std:: io:: ErrorKind :: UnexpectedEof ) ;
309
315
}
310
316
311
- // todo: test cancellation, frame size
317
+ /// A wrapper around an `AsyncWrite` that throttles how many bytes
318
+ /// may be written per poll.
319
+ ///
320
+ /// We are going to use this to simulate partial writes to test
321
+ /// cancellation safety: when the budget is 0, `poll_write`
322
+ /// returns `Poll::Pending` and calls the waker so the task is
323
+ /// scheduled to be polled again later.
324
+ struct Throttled < W > {
325
+ inner : W ,
326
+ // Number of bytes allowed to be written in the next poll. If
327
+ // 0, writes return `Poll::Pending`.
328
+ budget : usize ,
329
+ }
330
+
331
+ impl < W > Throttled < W > {
332
+ fn new ( inner : W ) -> Self {
333
+ Self {
334
+ inner,
335
+ budget : usize:: MAX ,
336
+ }
337
+ }
338
+
339
+ fn set_budget ( & mut self , n : usize ) {
340
+ self . budget = n;
341
+ }
342
+ }
343
+
344
+ impl < W : AsyncWrite + Unpin > AsyncWrite for Throttled < W > {
345
+ fn poll_write (
346
+ mut self : Pin < & mut Self > ,
347
+ cx : & mut Context < ' _ > ,
348
+ buf : & [ u8 ] ,
349
+ ) -> Poll < std:: io:: Result < usize > > {
350
+ // No budget left this poll. Return "not ready" and ask to
351
+ // be polled again later.
352
+ if self . budget == 0 {
353
+ cx. waker ( ) . wake_by_ref ( ) ;
354
+ return Poll :: Pending ;
355
+ }
356
+ let n = buf. len ( ) . min ( self . budget ) ;
357
+ self . budget -= n;
358
+ // Delegate a write of the first `n` bytes to the inner
359
+ // writer.
360
+ Pin :: new ( & mut self . inner ) . poll_write ( cx, & buf[ ..n] )
361
+ }
362
+
363
+ fn poll_flush ( mut self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < std:: io:: Result < ( ) > > {
364
+ // Delegate to `inner` for flushing.
365
+ Pin :: new ( & mut self . inner ) . poll_flush ( cx)
366
+ }
367
+
368
+ fn poll_shutdown (
369
+ mut self : Pin < & mut Self > ,
370
+ cx : & mut Context < ' _ > ,
371
+ ) -> Poll < std:: io:: Result < ( ) > > {
372
+ // Delegate to `inner` (ensure resources are released and
373
+ // `EOF` is signaled downstream).
374
+ Pin :: new ( & mut self . inner ) . poll_shutdown ( cx)
375
+ }
376
+ }
377
+
378
+ #[ tokio:: test]
379
+ #[ allow( clippy:: disallowed_methods) ]
380
+ async fn test_writer_cancellation_resume ( ) {
381
+ let ( a, b) = tokio:: io:: duplex ( 4096 ) ;
382
+ let ( r, _wu) = tokio:: io:: split ( a) ;
383
+ let ( _ru, w) = tokio:: io:: split ( b) ;
384
+
385
+ let w = Throttled :: new ( w) ;
386
+ // 256 bytes, all = 0x2A ('*'), "the answer"
387
+ let body = Bytes :: from_static ( & [ 42u8 ; 256 ] ) ;
388
+ let mut reader = FrameReader :: new ( r, 1024 * 1024 ) ;
389
+ let mut fw = FrameWrite :: new ( w, body. clone ( ) ) ;
390
+
391
+ // Allow only the 8-byte length to be written, then cancel.
392
+ fw. writer . set_budget ( 8 ) ;
393
+ let fut = fw. send ( ) ;
394
+ tokio:: select! {
395
+ _ = fut => panic!( "send unexpectedly completed" ) ,
396
+ _ = tokio:: time:: sleep( std:: time:: Duration :: from_millis( 5 ) ) => { }
397
+ }
398
+ // The `fut` is dropped here i.e. "cancellation".
399
+ assert ! (
400
+ tokio:: time:: timeout( std:: time:: Duration :: from_millis( 20 ) , async {
401
+ reader. next( ) . await
402
+ } )
403
+ . await
404
+ . is_err( ) ,
405
+ "a full frame isn't available yet, so reader.next().await should block"
406
+ ) ;
407
+
408
+ // Now allow the remaining body to flush and complete the
409
+ // frame.
410
+ fw. writer . set_budget ( usize:: MAX ) ;
411
+ fw. send ( ) . await . unwrap ( ) ;
412
+ let mut w = fw. complete ( ) ;
413
+ let got = reader. next ( ) . await . unwrap ( ) . unwrap ( ) ;
414
+ assert_eq ! ( got, body) ;
415
+
416
+ // Shutdown and test for EOF on boundary.
417
+ w. shutdown ( ) . await . unwrap ( ) ;
418
+ assert ! ( reader. next( ) . await . unwrap( ) . is_none( ) ) ;
419
+ }
420
+
421
+ // todo: frame size
312
422
}
0 commit comments