@@ -133,7 +133,7 @@ impl Client for TestClient {
133133
134134#[ derive( Clone ) ]
135135struct TestAgent {
136- sessions : Arc < Mutex < std:: collections:: HashSet < SessionId > > > ,
136+ sessions : Arc < Mutex < std:: collections:: HashMap < SessionId , std :: path :: PathBuf > > > ,
137137 prompts_received : Arc < Mutex < Vec < PromptReceived > > > ,
138138 cancellations_received : Arc < Mutex < Vec < SessionId > > > ,
139139 extension_notifications : Arc < Mutex < Vec < ( String , ExtNotification ) > > > ,
@@ -144,7 +144,7 @@ type PromptReceived = (SessionId, Vec<ContentBlock>);
144144impl TestAgent {
145145 fn new ( ) -> Self {
146146 Self {
147- sessions : Arc :: new ( Mutex :: new ( std:: collections:: HashSet :: new ( ) ) ) ,
147+ sessions : Arc :: new ( Mutex :: new ( std:: collections:: HashMap :: new ( ) ) ) ,
148148 prompts_received : Arc :: new ( Mutex :: new ( Vec :: new ( ) ) ) ,
149149 cancellations_received : Arc :: new ( Mutex :: new ( Vec :: new ( ) ) ) ,
150150 extension_notifications : Arc :: new ( Mutex :: new ( Vec :: new ( ) ) ) ,
@@ -163,9 +163,12 @@ impl Agent for TestAgent {
163163 Ok ( AuthenticateResponse :: default ( ) )
164164 }
165165
166- async fn new_session ( & self , _arguments : NewSessionRequest ) -> Result < NewSessionResponse > {
166+ async fn new_session ( & self , arguments : NewSessionRequest ) -> Result < NewSessionResponse > {
167167 let session_id = SessionId :: new ( "test-session-123" ) ;
168- self . sessions . lock ( ) . unwrap ( ) . insert ( session_id. clone ( ) ) ;
168+ self . sessions
169+ . lock ( )
170+ . unwrap ( )
171+ . insert ( session_id. clone ( ) , arguments. cwd ) ;
169172 Ok ( NewSessionResponse :: new ( session_id) )
170173 }
171174
@@ -210,8 +213,30 @@ impl Agent for TestAgent {
210213 & self ,
211214 _args : agent_client_protocol_schema:: ListSessionsRequest ,
212215 ) -> Result < agent_client_protocol_schema:: ListSessionsResponse > {
216+ let sessions = self . sessions . lock ( ) . unwrap ( ) ;
217+ let session_infos: Vec < _ > = sessions
218+ . iter ( )
219+ . map ( |( id, cwd) | {
220+ agent_client_protocol_schema:: SessionInfo :: new ( id. clone ( ) , cwd. clone ( ) )
221+ } )
222+ . collect ( ) ;
213223 Ok ( agent_client_protocol_schema:: ListSessionsResponse :: new (
214- vec ! [ ] ,
224+ session_infos,
225+ ) )
226+ }
227+
228+ #[ cfg( feature = "unstable_session_fork" ) ]
229+ async fn fork_session (
230+ & self ,
231+ args : agent_client_protocol_schema:: ForkSessionRequest ,
232+ ) -> Result < agent_client_protocol_schema:: ForkSessionResponse > {
233+ let new_session_id = SessionId :: new ( format ! ( "fork-of-{}" , args. session_id. 0 ) ) ;
234+ self . sessions
235+ . lock ( )
236+ . unwrap ( )
237+ . insert ( new_session_id. clone ( ) , args. cwd ) ;
238+ Ok ( agent_client_protocol_schema:: ForkSessionResponse :: new (
239+ new_session_id,
215240 ) )
216241 }
217242
@@ -665,3 +690,86 @@ async fn test_extension_methods_and_notifications() {
665690 } )
666691 . await ;
667692}
693+
694+ #[ cfg( feature = "unstable_session_fork" ) ]
695+ #[ tokio:: test]
696+ async fn test_fork_session ( ) {
697+ let local_set = tokio:: task:: LocalSet :: new ( ) ;
698+ local_set
699+ . run_until ( async {
700+ let client = TestClient :: new ( ) ;
701+ let agent = TestAgent :: new ( ) ;
702+
703+ let ( agent_conn, _client_conn) = create_connection_pair ( & client, & agent) ;
704+
705+ // First create a session
706+ let new_session_response = agent_conn
707+ . new_session ( NewSessionRequest :: new ( "/test" ) )
708+ . await
709+ . expect ( "new_session failed" ) ;
710+
711+ let original_session_id = new_session_response. session_id ;
712+
713+ // Fork the session
714+ let fork_response = agent_conn
715+ . fork_session ( agent_client_protocol_schema:: ForkSessionRequest :: new (
716+ original_session_id. clone ( ) ,
717+ "/test" ,
718+ ) )
719+ . await
720+ . expect ( "fork_session failed" ) ;
721+
722+ // Verify the forked session has a different ID
723+ assert_ne ! ( fork_response. session_id, original_session_id) ;
724+ assert_eq ! (
725+ fork_response. session_id. 0 . as_ref( ) ,
726+ format!( "fork-of-{}" , original_session_id. 0 )
727+ ) ;
728+
729+ // Verify the forked session was added to the agent's sessions
730+ let sessions = agent. sessions . lock ( ) . unwrap ( ) ;
731+ assert ! ( sessions. contains_key( & fork_response. session_id) ) ;
732+ } )
733+ . await ;
734+ }
735+
736+ #[ cfg( feature = "unstable_session_list" ) ]
737+ #[ tokio:: test]
738+ async fn test_list_sessions ( ) {
739+ let local_set = tokio:: task:: LocalSet :: new ( ) ;
740+ local_set
741+ . run_until ( async {
742+ let client = TestClient :: new ( ) ;
743+ let agent = TestAgent :: new ( ) ;
744+
745+ let ( agent_conn, _client_conn) = create_connection_pair ( & client, & agent) ;
746+
747+ // First create a session
748+ let new_session_response = agent_conn
749+ . new_session ( NewSessionRequest :: new ( "/test" ) )
750+ . await
751+ . expect ( "new_session failed" ) ;
752+
753+ // Verify the session was created
754+ assert ! ( !new_session_response. session_id. 0 . is_empty( ) ) ;
755+
756+ // List sessions
757+ let list_response = agent_conn
758+ . list_sessions ( agent_client_protocol_schema:: ListSessionsRequest :: new ( ) )
759+ . await
760+ . expect ( "list_sessions failed" ) ;
761+
762+ // Verify the response contains our session
763+ assert_eq ! ( list_response. sessions. len( ) , 1 ) ;
764+ assert_eq ! (
765+ list_response. sessions[ 0 ] . session_id,
766+ new_session_response. session_id
767+ ) ;
768+ assert_eq ! (
769+ list_response. sessions[ 0 ] . cwd,
770+ std:: path:: PathBuf :: from( "/test" )
771+ ) ;
772+ assert ! ( list_response. next_cursor. is_none( ) ) ;
773+ } )
774+ . await ;
775+ }
0 commit comments