1
+ use http_body_util:: StreamBody ;
2
+ use hyper:: body:: Bytes ;
3
+ use hyper:: body:: Frame ;
4
+ use hyper:: rt:: { Read , ReadBufCursor , Write } ;
5
+ use hyper:: server:: conn:: http1;
6
+ use hyper:: service:: service_fn;
7
+ use hyper:: { Response , StatusCode } ;
8
+ use pin_project_lite:: pin_project;
9
+ use std:: convert:: Infallible ;
10
+ use std:: io;
11
+ use std:: pin:: Pin ;
12
+ use std:: task:: { ready, Context , Poll } ;
13
+ use tokio:: sync:: mpsc;
14
+ use tracing:: { error, info} ;
15
+
16
+ pin_project ! {
17
+ #[ derive( Debug ) ]
18
+ pub struct TxReadyStream {
19
+ #[ pin]
20
+ read_rx: mpsc:: UnboundedReceiver <Vec <u8 >>,
21
+ write_tx: mpsc:: UnboundedSender <Vec <u8 >>,
22
+ read_buffer: Vec <u8 >,
23
+ poll_since_write: bool ,
24
+ flush_count: usize ,
25
+ }
26
+ }
27
+
28
+ impl TxReadyStream {
29
+ fn new (
30
+ read_rx : mpsc:: UnboundedReceiver < Vec < u8 > > ,
31
+ write_tx : mpsc:: UnboundedSender < Vec < u8 > > ,
32
+ ) -> Self {
33
+ Self {
34
+ read_rx,
35
+ write_tx,
36
+ read_buffer : Vec :: new ( ) ,
37
+ poll_since_write : true ,
38
+ flush_count : 0 ,
39
+ }
40
+ }
41
+
42
+ /// Create a new pair of connected ReadyStreams. Returns two streams that are connected to each other.
43
+ fn new_pair ( ) -> ( Self , Self ) {
44
+ let ( s1_tx, s2_rx) = mpsc:: unbounded_channel ( ) ;
45
+ let ( s2_tx, s1_rx) = mpsc:: unbounded_channel ( ) ;
46
+ let s1 = Self :: new ( s1_rx, s1_tx) ;
47
+ let s2 = Self :: new ( s2_rx, s2_tx) ;
48
+ ( s1, s2)
49
+ }
50
+
51
+ /// Send data to the other end of the stream (this will be available for reading on the other stream)
52
+ fn send ( & self , data : & [ u8 ] ) -> Result < ( ) , mpsc:: error:: SendError < Vec < u8 > > > {
53
+ self . write_tx . send ( data. to_vec ( ) )
54
+ }
55
+
56
+
57
+ /// Receive data written to this stream by the other end (async)
58
+ async fn recv ( & mut self ) -> Option < Vec < u8 > > {
59
+ self . read_rx . recv ( ) . await
60
+ }
61
+ }
62
+
63
+ impl Read for TxReadyStream {
64
+ fn poll_read (
65
+ mut self : Pin < & mut Self > ,
66
+ cx : & mut Context < ' _ > ,
67
+ mut buf : ReadBufCursor < ' _ > ,
68
+ ) -> Poll < io:: Result < ( ) > > {
69
+ let mut this = self . as_mut ( ) . project ( ) ;
70
+
71
+ // First, try to satisfy the read request from the internal buffer
72
+ if !this. read_buffer . is_empty ( ) {
73
+ let to_read = std:: cmp:: min ( this. read_buffer . len ( ) , buf. remaining ( ) ) ;
74
+ // Copy data from internal buffer to the read buffer
75
+ buf. put_slice ( & this. read_buffer [ ..to_read] ) ;
76
+ // Remove the consumed data from the internal buffer
77
+ this. read_buffer . drain ( ..to_read) ;
78
+ return Poll :: Ready ( Ok ( ( ) ) ) ;
79
+ }
80
+
81
+ // If internal buffer is empty, try to get data from the channel
82
+ match this. read_rx . try_recv ( ) {
83
+ Ok ( data) => {
84
+ // Copy as much data as we can fit in the buffer
85
+ let to_read = std:: cmp:: min ( data. len ( ) , buf. remaining ( ) ) ;
86
+ buf. put_slice ( & data[ ..to_read] ) ;
87
+
88
+ // Store any remaining data in the internal buffer for next time
89
+ if to_read < data. len ( ) {
90
+ let remaining = & data[ to_read..] ;
91
+ this. read_buffer . extend_from_slice ( remaining) ;
92
+ }
93
+ Poll :: Ready ( Ok ( ( ) ) )
94
+ }
95
+ Err ( mpsc:: error:: TryRecvError :: Empty ) => {
96
+ match ready ! ( this. read_rx. poll_recv( cx) ) {
97
+ Some ( data) => {
98
+ // Copy as much data as we can fit in the buffer
99
+ let to_read = std:: cmp:: min ( data. len ( ) , buf. remaining ( ) ) ;
100
+ buf. put_slice ( & data[ ..to_read] ) ;
101
+
102
+ // Store any remaining data in the internal buffer for next time
103
+ if to_read < data. len ( ) {
104
+ let remaining = & data[ to_read..] ;
105
+ this. read_buffer . extend_from_slice ( remaining) ;
106
+ }
107
+ Poll :: Ready ( Ok ( ( ) ) )
108
+ }
109
+ None => Poll :: Ready ( Ok ( ( ) ) ) ,
110
+ }
111
+ }
112
+ Err ( mpsc:: error:: TryRecvError :: Disconnected ) => {
113
+ // Channel closed, return EOF
114
+ Poll :: Ready ( Ok ( ( ) ) )
115
+ }
116
+ }
117
+ }
118
+ }
119
+
120
+ impl Write for TxReadyStream {
121
+ fn poll_write (
122
+ mut self : Pin < & mut Self > ,
123
+ _cx : & mut Context < ' _ > ,
124
+ buf : & [ u8 ] ,
125
+ ) -> Poll < io:: Result < usize > > {
126
+ if !self . poll_since_write {
127
+ return Poll :: Pending ;
128
+ }
129
+ self . poll_since_write = false ;
130
+ let this = self . project ( ) ;
131
+ let buf = Vec :: from ( & buf[ ..buf. len ( ) ] ) ;
132
+ let len = buf. len ( ) ;
133
+
134
+ // Send data through the channel - this should always be ready for unbounded channels
135
+ match this. write_tx . send ( buf) {
136
+ Ok ( _) => {
137
+ // Increment write count
138
+ Poll :: Ready ( Ok ( len) )
139
+ }
140
+ Err ( _) => {
141
+ error ! ( "ReadyStream::poll_write failed - channel closed" ) ;
142
+ Poll :: Ready ( Err ( io:: Error :: new (
143
+ io:: ErrorKind :: BrokenPipe ,
144
+ "Write channel closed" ,
145
+ ) ) )
146
+ }
147
+ }
148
+ }
149
+
150
+ fn poll_flush ( mut self : Pin < & mut Self > , _cx : & mut Context < ' _ > ) -> Poll < io:: Result < ( ) > > {
151
+ self . flush_count += 1 ;
152
+ // We require two flushes to complete each chunk, simulating a success at the end of the old
153
+ // poll loop. After all chunks are written, we always succeed on flush to allow for finish.
154
+ if self . flush_count % 2 != 0 && self . flush_count < TOTAL_CHUNKS * 2 {
155
+ return Poll :: Pending ;
156
+ }
157
+ self . poll_since_write = true ;
158
+ Poll :: Ready ( Ok ( ( ) ) )
159
+ }
160
+
161
+ fn poll_shutdown ( self : Pin < & mut Self > , _cx : & mut Context < ' _ > ) -> Poll < io:: Result < ( ) > > {
162
+ Poll :: Ready ( Ok ( ( ) ) )
163
+ }
164
+ }
165
+
166
+ fn init_tracing ( ) {
167
+ use std:: sync:: Once ;
168
+ static INIT : Once = Once :: new ( ) ;
169
+ INIT . call_once ( || {
170
+ tracing_subscriber:: fmt ( )
171
+ . with_max_level ( tracing:: Level :: INFO )
172
+ . with_target ( true )
173
+ . with_thread_ids ( true )
174
+ . with_thread_names ( true )
175
+ . init ( ) ;
176
+ } ) ;
177
+ }
178
+
179
+ const TOTAL_CHUNKS : usize = 16 ;
180
+
181
+ #[ tokio:: test( flavor = "multi_thread" , worker_threads = 2 ) ]
182
+ async fn body_test ( ) {
183
+ init_tracing ( ) ;
184
+ // Create a pair of connected streams
185
+ let ( server_stream, mut client_stream) = TxReadyStream :: new_pair ( ) ;
186
+
187
+ let mut http_builder = http1:: Builder :: new ( ) ;
188
+ http_builder. max_buf_size ( CHUNK_SIZE ) ;
189
+ const CHUNK_SIZE : usize = 64 * 1024 ;
190
+ let service = service_fn ( |_| async move {
191
+ info ! (
192
+ "Creating payload of {} chunks of {} KiB each ({} MiB total)..." ,
193
+ TOTAL_CHUNKS ,
194
+ CHUNK_SIZE / 1024 ,
195
+ TOTAL_CHUNKS * CHUNK_SIZE / ( 1024 * 1024 )
196
+ ) ;
197
+ let bytes = Bytes :: from ( vec ! [ 0 ; CHUNK_SIZE ] ) ;
198
+ let data = vec ! [ bytes. clone( ) ; TOTAL_CHUNKS ] ;
199
+ let stream = futures_util:: stream:: iter (
200
+ data. into_iter ( )
201
+ . map ( |b| Ok :: < _ , Infallible > ( Frame :: data ( b) ) ) ,
202
+ ) ;
203
+ let body = StreamBody :: new ( stream) ;
204
+ info ! ( "Server: Sending data response..." ) ;
205
+ Ok :: < _ , hyper:: Error > (
206
+ Response :: builder ( )
207
+ . status ( StatusCode :: OK )
208
+ . header ( "content-type" , "application/octet-stream" )
209
+ . header ( "content-length" , ( TOTAL_CHUNKS * CHUNK_SIZE ) . to_string ( ) )
210
+ . body ( body)
211
+ . unwrap ( ) ,
212
+ )
213
+ } ) ;
214
+
215
+ let server_task = tokio:: spawn ( async move {
216
+ let conn = http_builder. serve_connection ( server_stream, service) ;
217
+ if let Err ( e) = conn. await {
218
+ error ! ( "Server connection error: {}" , e) ;
219
+ }
220
+ } ) ;
221
+
222
+ let get_request = "GET / HTTP/1.1\r \n Host: localhost\r \n Connection: close\r \n \r \n " ;
223
+ client_stream. send ( get_request. as_bytes ( ) ) . unwrap ( ) ;
224
+
225
+ info ! ( "Client is reading response..." ) ;
226
+ let mut bytes_received = 0 ;
227
+ while let Some ( chunk) = client_stream. recv ( ) . await {
228
+ bytes_received += chunk. len ( ) ;
229
+ }
230
+ // Clean up
231
+ server_task. abort ( ) ;
232
+
233
+ info ! ( bytes_received, "Client done receiving bytes" ) ;
234
+ }
0 commit comments