Skip to content

Commit 538b465

Browse files
authored
feat(unstable): Add initial support for forking sessions (#33)
1 parent c009c23 commit 538b465

File tree

4 files changed

+154
-14
lines changed

4 files changed

+154
-14
lines changed

examples/agent.rs

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -142,15 +142,6 @@ impl acp::Agent for ExampleAgent {
142142
Ok(acp::SetSessionModelResponse::default())
143143
}
144144

145-
#[cfg(feature = "unstable_session_list")]
146-
async fn list_sessions(
147-
&self,
148-
args: acp::ListSessionsRequest,
149-
) -> Result<acp::ListSessionsResponse, acp::Error> {
150-
log::info!("Received list sessions request {args:?}");
151-
Ok(acp::ListSessionsResponse::new(vec![]))
152-
}
153-
154145
async fn ext_method(&self, args: acp::ExtRequest) -> Result<acp::ExtResponse, acp::Error> {
155146
log::info!(
156147
"Received extension method call: method={}, params={:?}",

src/agent-client-protocol/src/agent.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ use agent_client_protocol_schema::{
66
LoadSessionResponse, NewSessionRequest, NewSessionResponse, PromptRequest, PromptResponse,
77
Result, SetSessionModeRequest, SetSessionModeResponse,
88
};
9+
#[cfg(feature = "unstable_session_fork")]
10+
use agent_client_protocol_schema::{ForkSessionRequest, ForkSessionResponse};
911
#[cfg(feature = "unstable_session_list")]
1012
use agent_client_protocol_schema::{ListSessionsRequest, ListSessionsResponse};
1113
#[cfg(feature = "unstable_session_model")]
@@ -140,6 +142,18 @@ pub trait Agent {
140142
Err(Error::method_not_found())
141143
}
142144

145+
/// **UNSTABLE**
146+
///
147+
/// This capability is not part of the spec yet, and may be removed or changed at any point.
148+
///
149+
/// Forks an existing session, creating a new session with the same conversation history.
150+
///
151+
/// Only available if the Agent supports the `sessionCapabilities.fork` capability.
152+
#[cfg(feature = "unstable_session_fork")]
153+
async fn fork_session(&self, _args: ForkSessionRequest) -> Result<ForkSessionResponse> {
154+
Err(Error::method_not_found())
155+
}
156+
143157
/// Handles extension method requests from the client.
144158
///
145159
/// Extension methods provide a way to add custom functionality while maintaining
@@ -198,6 +212,10 @@ impl<T: Agent> Agent for Rc<T> {
198212
async fn list_sessions(&self, args: ListSessionsRequest) -> Result<ListSessionsResponse> {
199213
self.as_ref().list_sessions(args).await
200214
}
215+
#[cfg(feature = "unstable_session_fork")]
216+
async fn fork_session(&self, args: ForkSessionRequest) -> Result<ForkSessionResponse> {
217+
self.as_ref().fork_session(args).await
218+
}
201219
async fn ext_method(&self, args: ExtRequest) -> Result<ExtResponse> {
202220
self.as_ref().ext_method(args).await
203221
}
@@ -243,6 +261,10 @@ impl<T: Agent> Agent for Arc<T> {
243261
async fn list_sessions(&self, args: ListSessionsRequest) -> Result<ListSessionsResponse> {
244262
self.as_ref().list_sessions(args).await
245263
}
264+
#[cfg(feature = "unstable_session_fork")]
265+
async fn fork_session(&self, args: ForkSessionRequest) -> Result<ForkSessionResponse> {
266+
self.as_ref().fork_session(args).await
267+
}
246268
async fn ext_method(&self, args: ExtRequest) -> Result<ExtResponse> {
247269
self.as_ref().ext_method(args).await
248270
}

src/agent-client-protocol/src/lib.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,16 @@ impl Agent for ClientSideConnection {
165165
.await
166166
}
167167

168+
#[cfg(feature = "unstable_session_fork")]
169+
async fn fork_session(&self, args: ForkSessionRequest) -> Result<ForkSessionResponse> {
170+
self.conn
171+
.request(
172+
AGENT_METHOD_NAMES.session_fork,
173+
Some(ClientRequest::ForkSessionRequest(args)),
174+
)
175+
.await
176+
}
177+
168178
async fn ext_method(&self, args: ExtRequest) -> Result<ExtResponse> {
169179
self.conn
170180
.request(
@@ -532,6 +542,10 @@ impl Side for AgentSide {
532542
m if m == AGENT_METHOD_NAMES.session_list => serde_json::from_str(params.get())
533543
.map(ClientRequest::ListSessionsRequest)
534544
.map_err(Into::into),
545+
#[cfg(feature = "unstable_session_fork")]
546+
m if m == AGENT_METHOD_NAMES.session_fork => serde_json::from_str(params.get())
547+
.map(ClientRequest::ForkSessionRequest)
548+
.map_err(Into::into),
535549
m if m == AGENT_METHOD_NAMES.session_prompt => serde_json::from_str(params.get())
536550
.map(ClientRequest::PromptRequest)
537551
.map_err(Into::into),
@@ -606,6 +620,11 @@ impl<T: Agent> MessageHandler<AgentSide> for T {
606620
let response = self.list_sessions(args).await?;
607621
Ok(AgentResponse::ListSessionsResponse(response))
608622
}
623+
#[cfg(feature = "unstable_session_fork")]
624+
ClientRequest::ForkSessionRequest(args) => {
625+
let response = self.fork_session(args).await?;
626+
Ok(AgentResponse::ForkSessionResponse(response))
627+
}
609628
ClientRequest::ExtMethodRequest(args) => {
610629
let response = self.ext_method(args).await?;
611630
Ok(AgentResponse::ExtMethodResponse(response))

src/agent-client-protocol/src/rpc_tests.rs

Lines changed: 113 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ impl Client for TestClient {
133133

134134
#[derive(Clone)]
135135
struct TestAgent {
136-
sessions: Arc<Mutex<std::collections::HashSet<SessionId>>>,
136+
sessions: Arc<Mutex<std::collections::HashMap<SessionId, std::path::PathBuf>>>,
137137
prompts_received: Arc<Mutex<Vec<PromptReceived>>>,
138138
cancellations_received: Arc<Mutex<Vec<SessionId>>>,
139139
extension_notifications: Arc<Mutex<Vec<(String, ExtNotification)>>>,
@@ -144,7 +144,7 @@ type PromptReceived = (SessionId, Vec<ContentBlock>);
144144
impl TestAgent {
145145
fn new() -> Self {
146146
Self {
147-
sessions: Arc::new(Mutex::new(std::collections::HashSet::new())),
147+
sessions: Arc::new(Mutex::new(std::collections::HashMap::new())),
148148
prompts_received: Arc::new(Mutex::new(Vec::new())),
149149
cancellations_received: Arc::new(Mutex::new(Vec::new())),
150150
extension_notifications: Arc::new(Mutex::new(Vec::new())),
@@ -163,9 +163,12 @@ impl Agent for TestAgent {
163163
Ok(AuthenticateResponse::default())
164164
}
165165

166-
async fn new_session(&self, _arguments: NewSessionRequest) -> Result<NewSessionResponse> {
166+
async fn new_session(&self, arguments: NewSessionRequest) -> Result<NewSessionResponse> {
167167
let session_id = SessionId::new("test-session-123");
168-
self.sessions.lock().unwrap().insert(session_id.clone());
168+
self.sessions
169+
.lock()
170+
.unwrap()
171+
.insert(session_id.clone(), arguments.cwd);
169172
Ok(NewSessionResponse::new(session_id))
170173
}
171174

@@ -210,8 +213,30 @@ impl Agent for TestAgent {
210213
&self,
211214
_args: agent_client_protocol_schema::ListSessionsRequest,
212215
) -> Result<agent_client_protocol_schema::ListSessionsResponse> {
216+
let sessions = self.sessions.lock().unwrap();
217+
let session_infos: Vec<_> = sessions
218+
.iter()
219+
.map(|(id, cwd)| {
220+
agent_client_protocol_schema::SessionInfo::new(id.clone(), cwd.clone())
221+
})
222+
.collect();
213223
Ok(agent_client_protocol_schema::ListSessionsResponse::new(
214-
vec![],
224+
session_infos,
225+
))
226+
}
227+
228+
#[cfg(feature = "unstable_session_fork")]
229+
async fn fork_session(
230+
&self,
231+
args: agent_client_protocol_schema::ForkSessionRequest,
232+
) -> Result<agent_client_protocol_schema::ForkSessionResponse> {
233+
let new_session_id = SessionId::new(format!("fork-of-{}", args.session_id.0));
234+
self.sessions
235+
.lock()
236+
.unwrap()
237+
.insert(new_session_id.clone(), args.cwd);
238+
Ok(agent_client_protocol_schema::ForkSessionResponse::new(
239+
new_session_id,
215240
))
216241
}
217242

@@ -665,3 +690,86 @@ async fn test_extension_methods_and_notifications() {
665690
})
666691
.await;
667692
}
693+
694+
#[cfg(feature = "unstable_session_fork")]
695+
#[tokio::test]
696+
async fn test_fork_session() {
697+
let local_set = tokio::task::LocalSet::new();
698+
local_set
699+
.run_until(async {
700+
let client = TestClient::new();
701+
let agent = TestAgent::new();
702+
703+
let (agent_conn, _client_conn) = create_connection_pair(&client, &agent);
704+
705+
// First create a session
706+
let new_session_response = agent_conn
707+
.new_session(NewSessionRequest::new("/test"))
708+
.await
709+
.expect("new_session failed");
710+
711+
let original_session_id = new_session_response.session_id;
712+
713+
// Fork the session
714+
let fork_response = agent_conn
715+
.fork_session(agent_client_protocol_schema::ForkSessionRequest::new(
716+
original_session_id.clone(),
717+
"/test",
718+
))
719+
.await
720+
.expect("fork_session failed");
721+
722+
// Verify the forked session has a different ID
723+
assert_ne!(fork_response.session_id, original_session_id);
724+
assert_eq!(
725+
fork_response.session_id.0.as_ref(),
726+
format!("fork-of-{}", original_session_id.0)
727+
);
728+
729+
// Verify the forked session was added to the agent's sessions
730+
let sessions = agent.sessions.lock().unwrap();
731+
assert!(sessions.contains_key(&fork_response.session_id));
732+
})
733+
.await;
734+
}
735+
736+
#[cfg(feature = "unstable_session_list")]
737+
#[tokio::test]
738+
async fn test_list_sessions() {
739+
let local_set = tokio::task::LocalSet::new();
740+
local_set
741+
.run_until(async {
742+
let client = TestClient::new();
743+
let agent = TestAgent::new();
744+
745+
let (agent_conn, _client_conn) = create_connection_pair(&client, &agent);
746+
747+
// First create a session
748+
let new_session_response = agent_conn
749+
.new_session(NewSessionRequest::new("/test"))
750+
.await
751+
.expect("new_session failed");
752+
753+
// Verify the session was created
754+
assert!(!new_session_response.session_id.0.is_empty());
755+
756+
// List sessions
757+
let list_response = agent_conn
758+
.list_sessions(agent_client_protocol_schema::ListSessionsRequest::new())
759+
.await
760+
.expect("list_sessions failed");
761+
762+
// Verify the response contains our session
763+
assert_eq!(list_response.sessions.len(), 1);
764+
assert_eq!(
765+
list_response.sessions[0].session_id,
766+
new_session_response.session_id
767+
);
768+
assert_eq!(
769+
list_response.sessions[0].cwd,
770+
std::path::PathBuf::from("/test")
771+
);
772+
assert!(list_response.next_cursor.is_none());
773+
})
774+
.await;
775+
}

0 commit comments

Comments
 (0)