Skip to content

Commit c96476f

Browse files
DOsingaDouwe Osinga
andauthored
Make create_session work concurrently (block#4954)
Co-authored-by: Douwe Osinga <[email protected]>
1 parent a3fbeb0 commit c96476f

File tree

1 file changed

+132
-30
lines changed

1 file changed

+132
-30
lines changed

crates/goose/src/session/session_manager.rs

Lines changed: 132 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -156,36 +156,10 @@ impl SessionManager {
156156
}
157157

158158
pub async fn create_session(working_dir: PathBuf, description: String) -> Result<Session> {
159-
let today = chrono::Utc::now().format("%Y%m%d").to_string();
160-
let storage = Self::instance().await?;
161-
162-
let mut tx = storage.pool.begin().await?;
163-
164-
let max_idx = sqlx::query_scalar::<_, Option<i32>>(
165-
"SELECT MAX(CAST(SUBSTR(id, 10) AS INTEGER)) FROM sessions WHERE id LIKE ?",
166-
)
167-
.bind(format!("{}_%", today))
168-
.fetch_one(&mut *tx)
169-
.await?
170-
.unwrap_or(0);
171-
172-
let session_id = format!("{}_{}", today, max_idx + 1);
173-
174-
sqlx::query(
175-
r#"
176-
INSERT INTO sessions (id, description, working_dir, extension_data)
177-
VALUES (?, ?, ?, '{}')
178-
"#,
179-
)
180-
.bind(&session_id)
181-
.bind(&description)
182-
.bind(working_dir.to_string_lossy().as_ref())
183-
.execute(&mut *tx)
184-
.await?;
185-
186-
tx.commit().await?;
187-
188-
Self::get_session(&session_id, false).await
159+
Self::instance()
160+
.await?
161+
.create_session(working_dir, description)
162+
.await
189163
}
190164

191165
pub async fn get_session(id: &str, include_messages: bool) -> Result<Session> {
@@ -606,6 +580,32 @@ impl SessionStorage {
606580
Ok(())
607581
}
608582

583+
async fn create_session(&self, working_dir: PathBuf, description: String) -> Result<Session> {
584+
let today = chrono::Utc::now().format("%Y%m%d").to_string();
585+
Ok(sqlx::query_as(
586+
r#"
587+
INSERT INTO sessions (id, description, working_dir, extension_data)
588+
VALUES (
589+
? || '_' || CAST(COALESCE((
590+
SELECT MAX(CAST(SUBSTR(id, 10) AS INTEGER))
591+
FROM sessions
592+
WHERE id LIKE ? || '_%'
593+
), 0) + 1 AS TEXT),
594+
?,
595+
?,
596+
'{}'
597+
)
598+
RETURNING *
599+
"#,
600+
)
601+
.bind(&today)
602+
.bind(&today)
603+
.bind(&description)
604+
.bind(working_dir.to_string_lossy().as_ref())
605+
.fetch_one(&self.pool)
606+
.await?)
607+
}
608+
609609
async fn get_session(&self, id: &str, include_messages: bool) -> Result<Session> {
610610
let mut session = sqlx::query_as::<_, Session>(
611611
r#"
@@ -859,3 +859,105 @@ impl SessionStorage {
859859
})
860860
}
861861
}
862+
863+
#[cfg(test)]
864+
mod tests {
865+
use super::*;
866+
use crate::conversation::message::{Message, MessageContent};
867+
use tempfile::TempDir;
868+
869+
const NUM_CONCURRENT_SESSIONS: i32 = 10;
870+
871+
#[tokio::test]
872+
async fn test_concurrent_session_creation() {
873+
let temp_dir = TempDir::new().unwrap();
874+
let db_path = temp_dir.path().join("test_sessions.db");
875+
876+
let storage = Arc::new(SessionStorage::create(&db_path).await.unwrap());
877+
878+
let mut handles = vec![];
879+
880+
for i in 0..NUM_CONCURRENT_SESSIONS {
881+
let session_storage = Arc::clone(&storage);
882+
let handle = tokio::spawn(async move {
883+
let working_dir = PathBuf::from(format!("/tmp/test_{}", i));
884+
let description = format!("Test session {}", i);
885+
886+
let session = session_storage
887+
.create_session(working_dir.clone(), description)
888+
.await
889+
.unwrap();
890+
891+
session_storage
892+
.add_message(
893+
&session.id,
894+
&Message {
895+
id: None,
896+
role: Role::User,
897+
created: chrono::Utc::now().timestamp_millis(),
898+
content: vec![MessageContent::text("hello world")],
899+
metadata: Default::default(),
900+
},
901+
)
902+
.await
903+
.unwrap();
904+
905+
session_storage
906+
.add_message(
907+
&session.id,
908+
&Message {
909+
id: None,
910+
role: Role::Assistant,
911+
created: chrono::Utc::now().timestamp_millis(),
912+
content: vec![MessageContent::text("sup world?")],
913+
metadata: Default::default(),
914+
},
915+
)
916+
.await
917+
.unwrap();
918+
919+
session_storage
920+
.apply_update(
921+
SessionUpdateBuilder::new(session.id.clone())
922+
.description(format!("Updated session {}", i))
923+
.total_tokens(Some(100 * i)),
924+
)
925+
.await
926+
.unwrap();
927+
928+
let updated = session_storage
929+
.get_session(&session.id, true)
930+
.await
931+
.unwrap();
932+
assert_eq!(updated.message_count, 2);
933+
assert_eq!(updated.total_tokens, Some(100 * i));
934+
935+
session.id
936+
});
937+
handles.push(handle);
938+
}
939+
940+
let mut results = vec![];
941+
for handle in handles {
942+
results.push(handle.await.unwrap());
943+
}
944+
945+
assert_eq!(results.len(), NUM_CONCURRENT_SESSIONS as usize);
946+
947+
let unique_ids: std::collections::HashSet<_> = results.iter().collect();
948+
assert_eq!(unique_ids.len(), NUM_CONCURRENT_SESSIONS as usize);
949+
950+
let sessions = storage.list_sessions().await.unwrap();
951+
assert_eq!(sessions.len(), NUM_CONCURRENT_SESSIONS as usize);
952+
953+
for session in &sessions {
954+
assert_eq!(session.message_count, 2);
955+
assert!(session.description.starts_with("Updated session"));
956+
}
957+
958+
let insights = storage.get_insights().await.unwrap();
959+
assert_eq!(insights.total_sessions, NUM_CONCURRENT_SESSIONS as usize);
960+
let expected_tokens = 100 * NUM_CONCURRENT_SESSIONS * (NUM_CONCURRENT_SESSIONS - 1) / 2;
961+
assert_eq!(insights.total_tokens, expected_tokens as i64);
962+
}
963+
}

0 commit comments

Comments
 (0)