Skip to content

Commit bc0d0f9

Browse files
use mapped mutex guard (#177)
1 parent e27ecce commit bc0d0f9

File tree

1 file changed

+43
-51
lines changed

1 file changed

+43
-51
lines changed

tower-sessions-core/src/session.rs

Lines changed: 43 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ use base64::{engine::general_purpose::URL_SAFE_NO_PAD, DecodeError, Engine as _}
1515
use serde::{de::DeserializeOwned, Deserialize, Serialize};
1616
use serde_json::Value;
1717
use time::{Duration, OffsetDateTime};
18-
use tokio::sync::{Mutex, MutexGuard};
18+
use tokio::sync::{MappedMutexGuard, Mutex, MutexGuard};
1919

2020
use crate::{session_store, SessionStore};
2121

@@ -97,21 +97,21 @@ impl Session {
9797
}
9898

9999
#[tracing::instrument(skip(self), err)]
100-
async fn get_record(&self) -> Result<MutexGuard<Option<Record>>> {
100+
async fn get_record(&self) -> Result<MappedMutexGuard<Record>> {
101101
let mut record_guard = self.record.lock().await;
102-
let session_id = *self.session_id.lock();
103102

104-
// Lazily load the record.
103+
// Lazily load the record since `None` here indicates we have no yet loaded it.
105104
if record_guard.is_none() {
106105
tracing::trace!("record not loaded from store; loading");
107106

108-
*record_guard = if let Some(session_id) = session_id {
109-
match self.store.load(&session_id).await.map_err(Error::Store)? {
110-
Some(mut loaded_record) => {
107+
let session_id = *self.session_id.lock();
108+
*record_guard = Some(if let Some(session_id) = session_id {
109+
match self.store.load(&session_id).await? {
110+
Some(loaded_record) => {
111111
tracing::trace!("record found in store");
112-
loaded_record.expiry_date = self.expiry_date();
113-
Some(loaded_record)
112+
loaded_record
114113
}
114+
115115
None => {
116116
// A well-behaved user agent should not send session cookies after
117117
// expiration. Even so it's possible for an expired session to be removed
@@ -120,16 +120,19 @@ impl Session {
120120
// malicious behavior.
121121
tracing::warn!("possibly suspicious activity: record not found in store");
122122
*self.session_id.lock() = None;
123-
Some(self.create_record())
123+
self.create_record()
124124
}
125125
}
126126
} else {
127127
tracing::trace!("session id not found");
128-
Some(self.create_record())
129-
}
128+
self.create_record()
129+
})
130130
}
131131

132-
Ok(record_guard)
132+
Ok(MutexGuard::map(record_guard, |opt| {
133+
opt.as_mut()
134+
.expect("Record should always be `Option::Some` at this point")
135+
}))
133136
}
134137

135138
/// Inserts a `impl Serialize` value into the session.
@@ -207,14 +210,13 @@ impl Session {
207210
/// - If the session has not been hydrated and loading from the store fails,
208211
/// we fail with [`Error::Store`].
209212
pub async fn insert_value(&self, key: &str, value: Value) -> Result<Option<Value>> {
210-
Ok(self.get_record().await?.as_mut().and_then(|record| {
211-
if record.data.get(key) != Some(&value) {
212-
self.is_modified.store(true, atomic::Ordering::Release);
213-
record.data.insert(key.to_string(), value)
214-
} else {
215-
None
216-
}
217-
}))
213+
let mut record_guard = self.get_record().await?;
214+
Ok(if record_guard.data.get(key) != Some(&value) {
215+
self.is_modified.store(true, atomic::Ordering::Release);
216+
record_guard.data.insert(key.to_string(), value)
217+
} else {
218+
None
219+
})
218220
}
219221

220222
/// Gets a value from the store.
@@ -275,11 +277,8 @@ impl Session {
275277
/// - If the session has not been hydrated and loading from the store fails,
276278
/// we fail with [`Error::Store`].
277279
pub async fn get_value(&self, key: &str) -> Result<Option<Value>> {
278-
Ok(self
279-
.get_record()
280-
.await?
281-
.as_ref()
282-
.and_then(|record| record.data.get(key).cloned()))
280+
let record_guard = self.get_record().await?;
281+
Ok(record_guard.data.get(key).cloned())
283282
}
284283

285284
/// Removes a value from the store, retuning the value of the key if it was
@@ -346,10 +345,9 @@ impl Session {
346345
/// - If the session has not been hydrated and loading from the store fails,
347346
/// we fail with [`Error::Store`].
348347
pub async fn remove_value(&self, key: &str) -> Result<Option<Value>> {
349-
Ok(self.get_record().await?.as_mut().and_then(|record| {
350-
self.is_modified.store(true, atomic::Ordering::Release);
351-
record.data.remove(key)
352-
}))
348+
let mut record_guard = self.get_record().await?;
349+
self.is_modified.store(true, atomic::Ordering::Release);
350+
Ok(record_guard.data.remove(key))
353351
}
354352

355353
/// Clears the session of all data but does not delete it from the store.
@@ -649,24 +647,21 @@ impl Session {
649647
/// - If saving to the store fails, we fail with [`Error::Store`].
650648
#[tracing::instrument(skip(self), err)]
651649
pub async fn save(&self) -> Result<()> {
652-
// N.B.: `get_record` will create a new record if one isn't found in the store.
653-
if let Some(record) = self.get_record().await?.as_mut() {
654-
record.expiry_date = self.expiry_date();
655-
656-
{
657-
let mut session_id_guard = self.session_id.lock();
658-
if session_id_guard.is_none() {
659-
// Generate a new ID here since e.g. flush may have been called, which will
660-
// not directly update the record ID.
661-
let id = Id::default();
662-
*session_id_guard = Some(id);
663-
record.id = id;
664-
}
650+
let mut record_guard = self.get_record().await?;
651+
record_guard.expiry_date = self.expiry_date();
652+
{
653+
let mut session_id_guard = self.session_id.lock();
654+
if session_id_guard.is_none() {
655+
// Generate a new ID here since e.g. flush may have been called, which will
656+
// not directly update the record ID.
657+
let id = Id::default();
658+
*session_id_guard = Some(id);
659+
record_guard.id = id;
665660
}
666-
667-
self.store.save(record).await.map_err(Error::Store)?;
668661
}
669662

663+
self.store.save(&record_guard).await.map_err(Error::Store)?;
664+
670665
Ok(())
671666
}
672667

@@ -829,13 +824,10 @@ impl Session {
829824
/// with [`Error::Store`].
830825
pub async fn cycle_id(&self) -> Result<()> {
831826
let mut record_guard = self.get_record().await?;
832-
let Some(record) = record_guard.as_mut() else {
833-
return Ok(());
834-
};
835827

836-
let old_session_id = record.id;
837-
record.id = Id::default();
838-
*self.session_id.lock() = Some(record.id);
828+
let old_session_id = record_guard.id;
829+
record_guard.id = Id::default();
830+
*self.session_id.lock() = Some(record_guard.id);
839831

840832
self.store
841833
.delete(&old_session_id)

0 commit comments

Comments
 (0)