@@ -15,7 +15,7 @@ use base64::{engine::general_purpose::URL_SAFE_NO_PAD, DecodeError, Engine as _}
1515use serde:: { de:: DeserializeOwned , Deserialize , Serialize } ;
1616use serde_json:: Value ;
1717use time:: { Duration , OffsetDateTime } ;
18- use tokio:: sync:: { Mutex , MutexGuard } ;
18+ use tokio:: sync:: { MappedMutexGuard , Mutex , MutexGuard } ;
1919
2020use crate :: { session_store, SessionStore } ;
2121
@@ -97,21 +97,21 @@ impl Session {
9797 }
9898
9999 #[ tracing:: instrument( skip( self ) , err) ]
100- async fn get_record ( & self ) -> Result < MutexGuard < Option < Record > > > {
100+ async fn get_record ( & self ) -> Result < MappedMutexGuard < Record > > {
101101 let mut record_guard = self . record . lock ( ) . await ;
102- let session_id = * self . session_id . lock ( ) ;
103102
104- // Lazily load the record.
103+ // Lazily load the record since `None` here indicates we have no yet loaded it .
105104 if record_guard. is_none ( ) {
106105 tracing:: trace!( "record not loaded from store; loading" ) ;
107106
108- * record_guard = if let Some ( session_id) = session_id {
109- match self . store . load ( & session_id) . await . map_err ( Error :: Store ) ? {
110- Some ( mut loaded_record) => {
107+ let session_id = * self . session_id . lock ( ) ;
108+ * record_guard = Some ( if let Some ( session_id) = session_id {
109+ match self . store . load ( & session_id) . await ? {
110+ Some ( loaded_record) => {
111111 tracing:: trace!( "record found in store" ) ;
112- loaded_record. expiry_date = self . expiry_date ( ) ;
113- Some ( loaded_record)
112+ loaded_record
114113 }
114+
115115 None => {
116116 // A well-behaved user agent should not send session cookies after
117117 // expiration. Even so it's possible for an expired session to be removed
@@ -120,16 +120,19 @@ impl Session {
120120 // malicious behavior.
121121 tracing:: warn!( "possibly suspicious activity: record not found in store" ) ;
122122 * self . session_id . lock ( ) = None ;
123- Some ( self . create_record ( ) )
123+ self . create_record ( )
124124 }
125125 }
126126 } else {
127127 tracing:: trace!( "session id not found" ) ;
128- Some ( self . create_record ( ) )
129- }
128+ self . create_record ( )
129+ } )
130130 }
131131
132- Ok ( record_guard)
132+ Ok ( MutexGuard :: map ( record_guard, |opt| {
133+ opt. as_mut ( )
134+ . expect ( "Record should always be `Option::Some` at this point" )
135+ } ) )
133136 }
134137
135138 /// Inserts a `impl Serialize` value into the session.
@@ -207,14 +210,13 @@ impl Session {
207210 /// - If the session has not been hydrated and loading from the store fails,
208211 /// we fail with [`Error::Store`].
209212 pub async fn insert_value ( & self , key : & str , value : Value ) -> Result < Option < Value > > {
210- Ok ( self . get_record ( ) . await ?. as_mut ( ) . and_then ( |record| {
211- if record. data . get ( key) != Some ( & value) {
212- self . is_modified . store ( true , atomic:: Ordering :: Release ) ;
213- record. data . insert ( key. to_string ( ) , value)
214- } else {
215- None
216- }
217- } ) )
213+ let mut record_guard = self . get_record ( ) . await ?;
214+ Ok ( if record_guard. data . get ( key) != Some ( & value) {
215+ self . is_modified . store ( true , atomic:: Ordering :: Release ) ;
216+ record_guard. data . insert ( key. to_string ( ) , value)
217+ } else {
218+ None
219+ } )
218220 }
219221
220222 /// Gets a value from the store.
@@ -275,11 +277,8 @@ impl Session {
275277 /// - If the session has not been hydrated and loading from the store fails,
276278 /// we fail with [`Error::Store`].
277279 pub async fn get_value ( & self , key : & str ) -> Result < Option < Value > > {
278- Ok ( self
279- . get_record ( )
280- . await ?
281- . as_ref ( )
282- . and_then ( |record| record. data . get ( key) . cloned ( ) ) )
280+ let record_guard = self . get_record ( ) . await ?;
281+ Ok ( record_guard. data . get ( key) . cloned ( ) )
283282 }
284283
285284 /// Removes a value from the store, retuning the value of the key if it was
@@ -346,10 +345,9 @@ impl Session {
346345 /// - If the session has not been hydrated and loading from the store fails,
347346 /// we fail with [`Error::Store`].
348347 pub async fn remove_value ( & self , key : & str ) -> Result < Option < Value > > {
349- Ok ( self . get_record ( ) . await ?. as_mut ( ) . and_then ( |record| {
350- self . is_modified . store ( true , atomic:: Ordering :: Release ) ;
351- record. data . remove ( key)
352- } ) )
348+ let mut record_guard = self . get_record ( ) . await ?;
349+ self . is_modified . store ( true , atomic:: Ordering :: Release ) ;
350+ Ok ( record_guard. data . remove ( key) )
353351 }
354352
355353 /// Clears the session of all data but does not delete it from the store.
@@ -649,24 +647,21 @@ impl Session {
649647 /// - If saving to the store fails, we fail with [`Error::Store`].
650648 #[ tracing:: instrument( skip( self ) , err) ]
651649 pub async fn save ( & self ) -> Result < ( ) > {
652- // N.B.: `get_record` will create a new record if one isn't found in the store.
653- if let Some ( record) = self . get_record ( ) . await ?. as_mut ( ) {
654- record. expiry_date = self . expiry_date ( ) ;
655-
656- {
657- let mut session_id_guard = self . session_id . lock ( ) ;
658- if session_id_guard. is_none ( ) {
659- // Generate a new ID here since e.g. flush may have been called, which will
660- // not directly update the record ID.
661- let id = Id :: default ( ) ;
662- * session_id_guard = Some ( id) ;
663- record. id = id;
664- }
650+ let mut record_guard = self . get_record ( ) . await ?;
651+ record_guard. expiry_date = self . expiry_date ( ) ;
652+ {
653+ let mut session_id_guard = self . session_id . lock ( ) ;
654+ if session_id_guard. is_none ( ) {
655+ // Generate a new ID here since e.g. flush may have been called, which will
656+ // not directly update the record ID.
657+ let id = Id :: default ( ) ;
658+ * session_id_guard = Some ( id) ;
659+ record_guard. id = id;
665660 }
666-
667- self . store . save ( record) . await . map_err ( Error :: Store ) ?;
668661 }
669662
663+ self . store . save ( & record_guard) . await . map_err ( Error :: Store ) ?;
664+
670665 Ok ( ( ) )
671666 }
672667
@@ -829,13 +824,10 @@ impl Session {
829824 /// with [`Error::Store`].
830825 pub async fn cycle_id ( & self ) -> Result < ( ) > {
831826 let mut record_guard = self . get_record ( ) . await ?;
832- let Some ( record) = record_guard. as_mut ( ) else {
833- return Ok ( ( ) ) ;
834- } ;
835827
836- let old_session_id = record . id ;
837- record . id = Id :: default ( ) ;
838- * self . session_id . lock ( ) = Some ( record . id ) ;
828+ let old_session_id = record_guard . id ;
829+ record_guard . id = Id :: default ( ) ;
830+ * self . session_id . lock ( ) = Some ( record_guard . id ) ;
839831
840832 self . store
841833 . delete ( & old_session_id)
0 commit comments