@@ -26,7 +26,7 @@ use std::{
26
26
27
27
use axum:: {
28
28
body:: HttpBody ,
29
- extract:: { ContentLengthLimit , FromRef , Path , State } ,
29
+ extract:: { FromRef , Path , State } ,
30
30
http:: {
31
31
header:: { CONTENT_TYPE , ETAG , IF_MATCH , IF_NONE_MATCH , LOCATION } ,
32
32
StatusCode ,
@@ -37,17 +37,19 @@ use axum::{
37
37
} ;
38
38
use base64ct:: Encoding ;
39
39
use bytes:: Bytes ;
40
- use headers:: { ContentType , ETag , Expires , HeaderName , IfMatch , IfNoneMatch , LastModified } ;
40
+ use headers:: {
41
+ ContentType , ETag , Expires , HeaderName , HeaderValue , IfMatch , IfNoneMatch , LastModified ,
42
+ } ;
41
43
use mime:: Mime ;
42
44
use sha2:: Digest ;
43
45
use tokio:: sync:: RwLock ;
44
- use tower_http:: cors:: { Any , CorsLayer } ;
46
+ use tower_http:: {
47
+ cors:: { Any , CorsLayer } ,
48
+ limit:: RequestBodyLimitLayer ,
49
+ set_header:: SetResponseHeaderLayer ,
50
+ } ;
45
51
use uuid:: Uuid ;
46
52
47
- // TODO: config?
48
- const MAX_BYTES : u64 = 4096 ;
49
- const TTL : Duration = Duration :: from_secs ( 60 ) ;
50
-
51
53
struct Session {
52
54
hash : [ u8 ; 32 ] ,
53
55
data : Bytes ,
@@ -57,14 +59,14 @@ struct Session {
57
59
}
58
60
59
61
impl Session {
60
- fn new ( data : Bytes , content_type : Mime ) -> Self {
62
+ fn new ( data : Bytes , content_type : Mime , ttl : Duration ) -> Self {
61
63
let hash = sha2:: Sha256 :: digest ( & data) . into ( ) ;
62
64
let now = SystemTime :: now ( ) ;
63
65
Self {
64
66
hash,
65
67
data,
66
68
content_type,
67
- expires : now + TTL ,
69
+ expires : now + ttl ,
68
70
last_modified : now,
69
71
}
70
72
}
@@ -113,6 +115,7 @@ impl Session {
113
115
struct Sessions {
114
116
// TODO: is that global lock alright?
115
117
inner : Arc < RwLock < HashMap < Uuid , Session > > > ,
118
+ ttl : Duration ,
116
119
}
117
120
118
121
impl Sessions {
@@ -153,25 +156,19 @@ impl AppState {
153
156
async fn new_session (
154
157
State ( sessions) : State < Sessions > ,
155
158
content_type : Option < TypedHeader < ContentType > > ,
156
- // TODO: this requires a Content-Length header, is that alright?
157
- ContentLengthLimit ( payload) : ContentLengthLimit < Bytes , MAX_BYTES > ,
159
+ payload : Bytes ,
158
160
) -> impl IntoResponse {
161
+ let ttl = sessions. ttl ;
159
162
// TODO: should we use something else? Check for colisions?
160
163
let id = Uuid :: new_v4 ( ) ;
161
164
let content_type =
162
165
content_type. map_or ( mime:: APPLICATION_OCTET_STREAM , |TypedHeader ( c) | c. into ( ) ) ;
163
- let session = Session :: new ( payload, content_type) ;
166
+ let session = Session :: new ( payload, content_type, ttl ) ;
164
167
let headers = session. typed_headers ( ) ;
165
- sessions. insert ( id, session, TTL ) . await ;
168
+ sessions. insert ( id, session, ttl ) . await ;
166
169
167
170
let location = id. to_string ( ) ;
168
- let additional_headers = [
169
- ( LOCATION , location) ,
170
- (
171
- HeaderName :: from_static ( "x-max-bytes" ) ,
172
- MAX_BYTES . to_string ( ) ,
173
- ) ,
174
- ] ;
171
+ let additional_headers = [ ( LOCATION , location) ] ;
175
172
( StatusCode :: CREATED , headers, additional_headers)
176
173
}
177
174
@@ -188,7 +185,7 @@ async fn update_session(
188
185
Path ( id) : Path < Uuid > ,
189
186
content_type : Option < TypedHeader < ContentType > > ,
190
187
if_match : Option < TypedHeader < IfMatch > > ,
191
- ContentLengthLimit ( payload) : ContentLengthLimit < Bytes , MAX_BYTES > ,
188
+ payload : Bytes ,
192
189
) -> Response {
193
190
if let Some ( session) = sessions. write ( ) . await . get_mut ( & id) {
194
191
if let Some ( TypedHeader ( if_match) ) = if_match {
@@ -235,13 +232,16 @@ async fn get_session(
235
232
}
236
233
237
234
#[ must_use]
238
- pub fn router < B > ( prefix : & str ) -> Router < ( ) , B >
235
+ pub fn router < B > ( prefix : & str , ttl : Duration , max_bytes : usize ) -> Router < ( ) , B >
239
236
where
240
237
B : HttpBody + Send + ' static ,
241
238
<B as HttpBody >:: Data : Send ,
242
239
<B as HttpBody >:: Error : std:: error:: Error + Send + Sync ,
243
240
{
244
- let sessions = Sessions :: default ( ) ;
241
+ let sessions = Sessions {
242
+ inner : Arc :: default ( ) ,
243
+ ttl,
244
+ } ;
245
245
246
246
let state = AppState :: new ( sessions) ;
247
247
let router = Router :: with_state ( state)
@@ -251,13 +251,21 @@ where
251
251
get ( get_session) . put ( update_session) . delete ( delete_session) ,
252
252
) ;
253
253
254
- Router :: new ( ) . nest ( prefix, router) . layer (
255
- CorsLayer :: new ( )
256
- . allow_origin ( Any )
257
- . allow_methods ( Any )
258
- . allow_headers ( [ CONTENT_TYPE , IF_MATCH , IF_NONE_MATCH ] )
259
- . expose_headers ( [ ETAG , LOCATION , HeaderName :: from_static ( "x-max-bytes" ) ] ) ,
260
- )
254
+ Router :: new ( )
255
+ . nest ( prefix, router)
256
+ . layer ( RequestBodyLimitLayer :: new ( max_bytes) )
257
+ . layer ( SetResponseHeaderLayer :: if_not_present (
258
+ HeaderName :: from_static ( "x-max-bytes" ) ,
259
+ HeaderValue :: from_str ( & max_bytes. to_string ( ) )
260
+ . expect ( "Could not construct x-max-bytes header value" ) ,
261
+ ) )
262
+ . layer (
263
+ CorsLayer :: new ( )
264
+ . allow_origin ( Any )
265
+ . allow_methods ( Any )
266
+ . allow_headers ( [ CONTENT_TYPE , IF_MATCH , IF_NONE_MATCH ] )
267
+ . expose_headers ( [ ETAG , LOCATION , HeaderName :: from_static ( "x-max-bytes" ) ] ) ,
268
+ )
261
269
}
262
270
263
271
#[ cfg( test) ]
@@ -280,7 +288,8 @@ mod tests {
280
288
281
289
#[ tokio:: test]
282
290
async fn test_post_and_get ( ) {
283
- let app = router ( "/" ) ;
291
+ let ttl = Duration :: from_secs ( 60 ) ;
292
+ let app = router ( "/" , ttl, 4096 ) ;
284
293
285
294
let body = r#"{"hello": "world"}"# . to_string ( ) ;
286
295
let request = Request :: post ( "/" )
@@ -306,7 +315,7 @@ mod tests {
306
315
assert_eq ! ( response. headers( ) . get( ETAG ) . unwrap( ) , etag) ;
307
316
308
317
// Let the entry expire
309
- advance_time ( TTL + Duration :: from_secs ( 1 ) ) . await ;
318
+ advance_time ( ttl + Duration :: from_secs ( 1 ) ) . await ;
310
319
311
320
let body = hyper:: body:: to_bytes ( response. into_body ( ) ) . await . unwrap ( ) ;
312
321
assert_eq ! ( & body[ ..] , br#"{"hello": "world"}"# ) ;
@@ -318,7 +327,8 @@ mod tests {
318
327
319
328
#[ tokio:: test]
320
329
async fn test_post_and_get_if_none_match ( ) {
321
- let app = router ( "/" ) ;
330
+ let ttl = Duration :: from_secs ( 60 ) ;
331
+ let app = router ( "/" , ttl, 4096 ) ;
322
332
323
333
let body = r#"{"hello": "world"}"# . to_string ( ) ;
324
334
let request = Request :: post ( "/" )
@@ -344,7 +354,8 @@ mod tests {
344
354
345
355
#[ tokio:: test]
346
356
async fn test_post_and_put ( ) {
347
- let app = router ( "/" ) ;
357
+ let ttl = Duration :: from_secs ( 60 ) ;
358
+ let app = router ( "/" , ttl, 4096 ) ;
348
359
349
360
let body = r#"{"hello": "world"}"# . to_string ( ) ;
350
361
let request = Request :: post ( "/" )
@@ -370,7 +381,8 @@ mod tests {
370
381
371
382
#[ tokio:: test]
372
383
async fn test_post_and_put_if_match ( ) {
373
- let app = router ( "/" ) ;
384
+ let ttl = Duration :: from_secs ( 60 ) ;
385
+ let app = router ( "/" , ttl, 4096 ) ;
374
386
375
387
let body = r#"{"hello": "world"}"# . to_string ( ) ;
376
388
let request = Request :: post ( "/" )
@@ -407,7 +419,8 @@ mod tests {
407
419
408
420
#[ tokio:: test]
409
421
async fn test_post_delete_and_get ( ) {
410
- let app = router ( "/" ) ;
422
+ let ttl = Duration :: from_secs ( 60 ) ;
423
+ let app = router ( "/" , ttl, 4096 ) ;
411
424
412
425
let body = r#"{"hello": "world"}"# . to_string ( ) ;
413
426
let request = Request :: post ( "/" )
0 commit comments