@@ -42,13 +42,13 @@ use headers::{
42
42
} ;
43
43
use mime:: Mime ;
44
44
use sha2:: Digest ;
45
- use tokio:: sync:: RwLock ;
45
+ use tokio:: sync:: { Mutex , RwLock } ;
46
46
use tower_http:: {
47
47
cors:: { Any , CorsLayer } ,
48
48
limit:: RequestBodyLimitLayer ,
49
49
set_header:: SetResponseHeaderLayer ,
50
50
} ;
51
- use uuid :: Uuid ;
51
+ use ulid :: Ulid ;
52
52
53
53
struct Session {
54
54
hash : [ u8 ; 32 ] ,
@@ -114,23 +114,34 @@ impl Session {
114
114
#[ derive( Clone , Default ) ]
115
115
struct Sessions {
116
116
// TODO: is that global lock alright?
117
- inner : Arc < RwLock < HashMap < Uuid , Session > > > ,
117
+ inner : Arc < RwLock < HashMap < Ulid , Session > > > ,
118
118
ttl : Duration ,
119
+ generator : Arc < Mutex < ulid:: Generator > > ,
119
120
}
120
121
121
122
impl Sessions {
122
- async fn insert ( self , id : Uuid , session : Session , ttl : Duration ) {
123
+ async fn insert ( self , id : Ulid , session : Session , ttl : Duration ) {
123
124
self . inner . write ( ) . await . insert ( id, session) ;
124
125
// TODO: cancel this task when an item gets deleted
125
126
tokio:: task:: spawn ( async move {
126
127
tokio:: time:: sleep ( ttl) . await ;
127
128
self . inner . write ( ) . await . remove ( & id) ;
128
129
} ) ;
129
130
}
131
+
132
+ async fn generate_id ( & self ) -> Ulid {
133
+ self . generator
134
+ . lock ( )
135
+ . await
136
+ . generate ( )
137
+ // This would panic the thread if too many IDs (more than 2^40) are generated on the same
138
+ // millisecond, which is very unlikely
139
+ . expect ( "Failed to generate random ID" )
140
+ }
130
141
}
131
142
132
143
impl Deref for Sessions {
133
- type Target = RwLock < HashMap < Uuid , Session > > ;
144
+ type Target = RwLock < HashMap < Ulid , Session > > ;
134
145
135
146
fn deref ( & self ) -> & Self :: Target {
136
147
& self . inner
@@ -159,8 +170,9 @@ async fn new_session(
159
170
payload : Bytes ,
160
171
) -> impl IntoResponse {
161
172
let ttl = sessions. ttl ;
162
- // TODO: should we use something else? Check for colisions?
163
- let id = Uuid :: new_v4 ( ) ;
173
+
174
+ let id = sessions. generate_id ( ) . await ;
175
+
164
176
let content_type =
165
177
content_type. map_or ( mime:: APPLICATION_OCTET_STREAM , |TypedHeader ( c) | c. into ( ) ) ;
166
178
let session = Session :: new ( payload, content_type, ttl) ;
@@ -172,7 +184,7 @@ async fn new_session(
172
184
( StatusCode :: CREATED , headers, additional_headers)
173
185
}
174
186
175
- async fn delete_session ( State ( sessions) : State < Sessions > , Path ( id) : Path < Uuid > ) -> StatusCode {
187
+ async fn delete_session ( State ( sessions) : State < Sessions > , Path ( id) : Path < Ulid > ) -> StatusCode {
176
188
if sessions. write ( ) . await . remove ( & id) . is_some ( ) {
177
189
StatusCode :: NO_CONTENT
178
190
} else {
@@ -182,7 +194,7 @@ async fn delete_session(State(sessions): State<Sessions>, Path(id): Path<Uuid>)
182
194
183
195
async fn update_session (
184
196
State ( sessions) : State < Sessions > ,
185
- Path ( id) : Path < Uuid > ,
197
+ Path ( id) : Path < Ulid > ,
186
198
content_type : Option < TypedHeader < ContentType > > ,
187
199
if_match : Option < TypedHeader < IfMatch > > ,
188
200
payload : Bytes ,
@@ -206,7 +218,7 @@ async fn update_session(
206
218
207
219
async fn get_session (
208
220
State ( sessions) : State < Sessions > ,
209
- Path ( id) : Path < Uuid > ,
221
+ Path ( id) : Path < Ulid > ,
210
222
if_none_match : Option < TypedHeader < IfNoneMatch > > ,
211
223
) -> Response {
212
224
let sessions = sessions. read ( ) . await ;
@@ -241,6 +253,7 @@ where
241
253
let sessions = Sessions {
242
254
inner : Arc :: default ( ) ,
243
255
ttl,
256
+ generator : Arc :: default ( ) ,
244
257
} ;
245
258
246
259
let state = AppState :: new ( sessions) ;
@@ -383,6 +396,40 @@ mod tests {
383
396
assert_eq ! ( response. status( ) , StatusCode :: NOT_FOUND ) ;
384
397
}
385
398
399
+ #[ tokio:: test]
400
+ async fn test_monotonically_increasing ( ) {
401
+ let ttl = Duration :: from_secs ( 60 ) ;
402
+ let app = router ( "/" , ttl, 4096 ) ;
403
+
404
+ // Prepare a thousand requests
405
+ let mut requests = Vec :: with_capacity ( 1000 ) ;
406
+ for _ in 0 ..requests. capacity ( ) {
407
+ requests. push (
408
+ app. clone ( )
409
+ . oneshot ( Request :: post ( "/" ) . body ( String :: new ( ) ) . unwrap ( ) ) ,
410
+ ) ;
411
+ }
412
+
413
+ // Run them all in order
414
+ let mut responses = Vec :: with_capacity ( requests. len ( ) ) ;
415
+ for fut in requests {
416
+ responses. push ( fut. await ) ;
417
+ }
418
+
419
+ // Get the location out of them
420
+ let ids: Vec < _ > = responses
421
+ . iter ( )
422
+ . map ( |res| {
423
+ let res = res. as_ref ( ) . unwrap ( ) ;
424
+ assert_eq ! ( res. status( ) , StatusCode :: CREATED ) ;
425
+ res. headers ( ) . get ( LOCATION ) . unwrap ( ) . to_str ( ) . unwrap ( )
426
+ } )
427
+ . collect ( ) ;
428
+
429
+ // Check that all the IDs are monotonically increasing
430
+ assert ! ( ids. windows( 2 ) . all( |loc| loc[ 0 ] < loc[ 1 ] ) ) ;
431
+ }
432
+
386
433
#[ tokio:: test]
387
434
async fn test_post_max_bytes ( ) {
388
435
let ttl = Duration :: from_secs ( 60 ) ;
0 commit comments