@@ -42,7 +42,7 @@ use headers::{
42
42
} ;
43
43
use mime:: Mime ;
44
44
use sha2:: Digest ;
45
- use tokio:: sync:: RwLock ;
45
+ use tokio:: sync:: { Mutex , RwLock } ;
46
46
use tower_http:: {
47
47
cors:: { Any , CorsLayer } ,
48
48
limit:: RequestBodyLimitLayer ,
@@ -116,6 +116,7 @@ struct Sessions {
116
116
// TODO: is that global lock alright?
117
117
inner : Arc < RwLock < HashMap < Ulid , Session > > > ,
118
118
ttl : Duration ,
119
+ generator : Arc < Mutex < ulid:: Generator > > ,
119
120
}
120
121
121
122
impl Sessions {
@@ -127,6 +128,16 @@ impl Sessions {
127
128
self . inner . write ( ) . await . remove ( & id) ;
128
129
} ) ;
129
130
}
131
+
132
+ async fn generate_id ( & self ) -> Ulid {
133
+ self . generator
134
+ . lock ( )
135
+ . await
136
+ . generate ( )
137
+ // This would panic the thread if too many IDs (more than 2^40) are generated on the same
138
+ // millisecond, which is very unlikely
139
+ . expect ( "Failed to generate random ID" )
140
+ }
130
141
}
131
142
132
143
impl Deref for Sessions {
@@ -159,7 +170,9 @@ async fn new_session(
159
170
payload : Bytes ,
160
171
) -> impl IntoResponse {
161
172
let ttl = sessions. ttl ;
162
- let id = Ulid :: new ( ) ;
173
+
174
+ let id = sessions. generate_id ( ) . await ;
175
+
163
176
let content_type =
164
177
content_type. map_or ( mime:: APPLICATION_OCTET_STREAM , |TypedHeader ( c) | c. into ( ) ) ;
165
178
let session = Session :: new ( payload, content_type, ttl) ;
@@ -240,6 +253,7 @@ where
240
253
let sessions = Sessions {
241
254
inner : Arc :: default ( ) ,
242
255
ttl,
256
+ generator : Arc :: default ( ) ,
243
257
} ;
244
258
245
259
let state = AppState :: new ( sessions) ;
@@ -382,6 +396,40 @@ mod tests {
382
396
assert_eq ! ( response. status( ) , StatusCode :: NOT_FOUND ) ;
383
397
}
384
398
399
+ #[ tokio:: test]
400
+ async fn test_monotonically_increasing ( ) {
401
+ let ttl = Duration :: from_secs ( 60 ) ;
402
+ let app = router ( "/" , ttl, 4096 ) ;
403
+
404
+ // Prepare a thousand requests
405
+ let mut requests = Vec :: with_capacity ( 1000 ) ;
406
+ for _ in 0 ..requests. capacity ( ) {
407
+ requests. push (
408
+ app. clone ( )
409
+ . oneshot ( Request :: post ( "/" ) . body ( String :: new ( ) ) . unwrap ( ) ) ,
410
+ ) ;
411
+ }
412
+
413
+ // Run them all in order
414
+ let mut responses = Vec :: with_capacity ( requests. len ( ) ) ;
415
+ for fut in requests {
416
+ responses. push ( fut. await ) ;
417
+ }
418
+
419
+ // Get the location out of them
420
+ let ids: Vec < _ > = responses
421
+ . iter ( )
422
+ . map ( |res| {
423
+ let res = res. as_ref ( ) . unwrap ( ) ;
424
+ assert_eq ! ( res. status( ) , StatusCode :: CREATED ) ;
425
+ res. headers ( ) . get ( LOCATION ) . unwrap ( ) . to_str ( ) . unwrap ( )
426
+ } )
427
+ . collect ( ) ;
428
+
429
+ // Check that all the IDs are monotonically increasing
430
+ assert ! ( ids. windows( 2 ) . all( |loc| loc[ 0 ] < loc[ 1 ] ) ) ;
431
+ }
432
+
385
433
#[ tokio:: test]
386
434
async fn test_post_max_bytes ( ) {
387
435
let ttl = Duration :: from_secs ( 60 ) ;
0 commit comments