@@ -238,25 +238,42 @@ impl<P: LlmProvider + Clone + 'static, C: Channel, T: ToolExecutor> Agent<P, C,
238238 } ) ;
239239 }
240240
/// Test-only entry point (compiled under `cfg(test)`) that refreshes the
/// semantic-recall message held in `self.messages`.
///
/// This describes the added (+) side of the diff; lines marked with - are the
/// removed pre-change code.
///
/// Steps: drop existing recall messages via `remove_recall_messages`, await
/// `fetch_semantic_recall` with `self.memory_state`, `query` and
/// `token_budget`, and insert the returned message at index 1, but only when
/// `self.messages` already holds more than one entry.
///
/// # Errors
/// Propagates any error returned by `fetch_semantic_recall` via `?`.
241+ #[ cfg( test) ]
241242 pub ( super ) async fn inject_semantic_recall (
242243 & mut self ,
243244 query : & str ,
244245 token_budget : usize ,
245246 ) -> Result < ( ) , super :: error:: AgentError > {
// Stale recall messages are removed before the fetch, so a fetch that
// yields `None` leaves the list without any recall message.
246247 self . remove_recall_messages ( ) ;
247248
248- let Some ( memory) = & self . memory_state . memory else {
249- return Ok ( ( ) ) ;
249+ if let Some ( msg) =
250+ Self :: fetch_semantic_recall ( & self . memory_state , query, token_budget) . await ?
251+ {
// Index 0 is left untouched; presumably it holds the system prompt - confirm.
252+ if self . messages . len ( ) > 1 {
253+ self . messages . insert ( 1 , msg) ;
254+ }
255+ }
256+
257+ Ok ( ( ) )
258+ }
259+
/// Builds the semantic-recall message for `query` without mutating the agent.
///
/// This describes the added (+) side of the diff; lines marked with - are the
/// removed pre-change code that read from `self` and inserted directly.
///
/// Takes `memory_state` by shared reference instead of `&mut self`; the caller
/// decides whether and where to insert the returned message.
///
/// Returns `Ok(None)` when any of the following holds:
/// - `memory_state.memory` is `None`;
/// - `memory_state.recall_limit` or `token_budget` is zero;
/// - `memory.recall(query, recall_limit, None)` yields an empty result;
/// - `tokens_used` does not exceed `estimate_tokens(RECALL_PREFIX)`.
///
/// Otherwise returns `Ok(Some(..))` with a `Role::System` message holding a
/// single `MessagePart::Recall` part built from `recall_text`.
///
/// # Errors
/// Propagates the error of the awaited `recall` call via `?`.
///
/// NOTE(review): the code that appends recalled entries to `recall_text` and
/// accumulates `tokens_used` lies in lines elided by the hunk boundary inside
/// this function; presumably it stops adding entries once `token_budget` would
/// be exceeded - confirm against the full file.
260+ async fn fetch_semantic_recall (
261+ memory_state : & super :: MemoryState < P > ,
262+ query : & str ,
263+ token_budget : usize ,
264+ ) -> Result < Option < Message > , super :: error:: AgentError > {
// Guard: without a memory handle there is nothing to query.
265+ let Some ( memory) = & memory_state. memory else {
266+ return Ok ( None ) ;
250267 } ;
251- if self . memory_state . recall_limit == 0 || token_budget == 0 {
252- return Ok ( ( ) ) ;
// Guard: a zero limit or zero budget skips the recall query entirely.
268+ if memory_state. recall_limit == 0 || token_budget == 0 {
269+ return Ok ( None ) ;
253270 }
254271
255272 let recalled = memory
256- . recall ( query, self . memory_state . recall_limit , None )
273+ . recall ( query, memory_state. recall_limit , None )
257274 . await ?;
258275 if recalled. is_empty ( ) {
259- return Ok ( ( ) ) ;
276+ return Ok ( None ) ;
260277 }
261278
// The message text always starts with RECALL_PREFIX.
262279 let mut recall_text = String :: from ( RECALL_PREFIX ) ;
@@ -277,17 +294,14 @@ impl<P: LlmProvider + Clone + 'static, C: Channel, T: ToolExecutor> Agent<P, C,
277294 tokens_used += entry_tokens;
278295 }
279296
280- if tokens_used > estimate_tokens ( RECALL_PREFIX ) && self . messages . len ( ) > 1 {
281- self . messages . insert (
282- 1 ,
283- Message :: from_parts (
284- Role :: System ,
285- vec ! [ MessagePart :: Recall { text: recall_text } ] ,
286- ) ,
287- ) ;
297+ if tokens_used > estimate_tokens ( RECALL_PREFIX ) {
298+ Ok ( Some ( Message :: from_parts (
299+ Role :: System ,
300+ vec ! [ MessagePart :: Recall { text: recall_text } ] ,
301+ ) ) )
302+ } else {
303+ Ok ( None )
288304 }
289-
290- Ok ( ( ) )
291305 }
292306
293307 pub ( super ) fn remove_code_context_messages ( & mut self ) {
@@ -335,31 +349,47 @@ impl<P: LlmProvider + Clone + 'static, C: Channel, T: ToolExecutor> Agent<P, C,
335349 } ) ;
336350 }
337351
/// Test-only entry point (compiled under `cfg(test)`) that refreshes the
/// cross-session context message held in `self.messages`.
///
/// This describes the added (+) side of the diff; lines marked with - are the
/// removed pre-change code.
///
/// Steps: drop existing cross-session messages via
/// `remove_cross_session_messages`, await `fetch_cross_session` with
/// `self.memory_state`, `query` and `token_budget`, and insert the returned
/// message at index 1 when `self.messages` already holds more than one entry.
/// A debug event is emitted only when the insert actually happens.
///
/// # Errors
/// Propagates any error returned by `fetch_cross_session` via `?`.
352+ #[ cfg( test) ]
338353 async fn inject_cross_session_context (
339354 & mut self ,
340355 query : & str ,
341356 token_budget : usize ,
342357 ) -> Result < ( ) , super :: error:: AgentError > {
// Stale cross-session messages are removed before the fetch.
343358 self . remove_cross_session_messages ( ) ;
344359
345- let ( Some ( memory) , Some ( cid) ) =
346- ( & self . memory_state . memory , self . memory_state . conversation_id )
347- else {
348- return Ok ( ( ) ) ;
360+ if let Some ( msg) =
361+ Self :: fetch_cross_session ( & self . memory_state , query, token_budget) . await ?
362+ {
// Index 0 is left untouched; presumably it holds the system prompt - confirm.
363+ if self . messages . len ( ) > 1 {
364+ self . messages . insert ( 1 , msg) ;
365+ tracing:: debug!( "injected cross-session context" ) ;
366+ }
367+ }
368+
369+ Ok ( ( ) )
370+ }
371+
/// Builds the cross-session context message for `query` without mutating the
/// agent.
///
/// This describes the added (+) side of the diff; lines marked with - are the
/// removed pre-change code that read from `self` and inserted directly.
///
/// Returns `Ok(None)` when any of the following holds:
/// - `memory_state.memory` or `memory_state.conversation_id` is `None`;
/// - `token_budget` is zero;
/// - no search result has `score >= memory_state.cross_session_score_threshold`;
/// - `tokens_used` does not exceed `estimate_tokens(CROSS_SESSION_PREFIX)`.
///
/// Otherwise returns `Ok(Some(..))` with a `Role::System` message holding a
/// single `MessagePart::CrossSession` part built from `text`.
///
/// # Errors
/// Propagates the error of the awaited `search_session_summaries` call via `?`.
///
/// NOTE(review): the code that appends results to `text` and accumulates
/// `tokens_used` lies in lines elided by the hunk boundary inside this
/// function; presumably it stops once `token_budget` would be exceeded -
/// confirm against the full file.
372+ async fn fetch_cross_session (
373+ memory_state : & super :: MemoryState < P > ,
374+ query : & str ,
375+ token_budget : usize ,
376+ ) -> Result < Option < Message > , super :: error:: AgentError > {
// Guard: both a memory handle and a conversation id are required.
377+ let ( Some ( memory) , Some ( cid) ) = ( & memory_state. memory , memory_state. conversation_id ) else {
378+ return Ok ( None ) ;
349379 } ;
350380 if token_budget == 0 {
351- return Ok ( ( ) ) ;
381+ return Ok ( None ) ;
352382 }
353383
354- let threshold = self . memory_state . cross_session_score_threshold ;
384+ let threshold = memory_state. cross_session_score_threshold ;
// NOTE(review): the literal 5 is presumably a result limit and `Some(cid)`
// presumably scopes or excludes the current conversation - confirm against
// the `search_session_summaries` signature.
355385 let results: Vec < _ > = memory
356386 . search_session_summaries ( query, 5 , Some ( cid) )
357387 . await ?
358388 . into_iter ( )
359389 . filter ( |r| r. score >= threshold)
360390 . collect ( ) ;
361391 if results. is_empty ( ) {
362- return Ok ( ( ) ) ;
392+ return Ok ( None ) ;
363393 }
364394
// The message text always starts with CROSS_SESSION_PREFIX.
365395 let mut text = String :: from ( CROSS_SESSION_PREFIX ) ;
@@ -375,35 +405,47 @@ impl<P: LlmProvider + Clone + 'static, C: Channel, T: ToolExecutor> Agent<P, C,
375405 tokens_used += cost;
376406 }
377407
378- if tokens_used > estimate_tokens ( CROSS_SESSION_PREFIX ) && self . messages . len ( ) > 1 {
379- self . messages . insert (
380- 1 ,
381- Message :: from_parts ( Role :: System , vec ! [ MessagePart :: CrossSession { text } ] ) ,
382- ) ;
383- tracing:: debug!( tokens_used, "injected cross-session context" ) ;
408+ if tokens_used > estimate_tokens ( CROSS_SESSION_PREFIX ) {
409+ Ok ( Some ( Message :: from_parts (
410+ Role :: System ,
411+ vec ! [ MessagePart :: CrossSession { text } ] ,
412+ ) ) )
413+ } else {
414+ Ok ( None )
384415 }
385-
386- Ok ( ( ) )
387416 }
388417
/// Test-only entry point (compiled under `cfg(test)`) that refreshes the
/// summary message held in `self.messages`.
///
/// This describes the added (+) side of the diff; lines marked with - are the
/// removed pre-change code.
///
/// Steps: drop existing summary messages via `remove_summary_messages`, await
/// `fetch_summaries` with `self.memory_state` and `token_budget`, and insert
/// the returned message at index 1 when `self.messages` already holds more
/// than one entry. A debug event is emitted only when the insert happens.
///
/// # Errors
/// Propagates any error returned by `fetch_summaries` via `?`.
418+ #[ cfg( test) ]
389419 async fn inject_summaries (
390420 & mut self ,
391421 token_budget : usize ,
392422 ) -> Result < ( ) , super :: error:: AgentError > {
// Stale summary messages are removed before the fetch.
393423 self . remove_summary_messages ( ) ;
394424
395- let ( Some ( memory) , Some ( cid) ) =
396- ( & self . memory_state . memory , self . memory_state . conversation_id )
397- else {
398- return Ok ( ( ) ) ;
// Index 0 is left untouched; presumably it holds the system prompt - confirm.
425+ if let Some ( msg) = Self :: fetch_summaries ( & self . memory_state , token_budget) . await ? {
426+ if self . messages . len ( ) > 1 {
427+ self . messages . insert ( 1 , msg) ;
428+ tracing:: debug!( "injected summaries into context" ) ;
429+ }
430+ }
431+
432+ Ok ( ( ) )
433+ }
434+
/// Builds the conversation-summary message without mutating the agent.
///
/// This describes the added (+) side of the diff; lines marked with - are the
/// removed pre-change code that read from `self` and inserted directly.
///
/// Returns `Ok(None)` when any of the following holds:
/// - `memory_state.memory` or `memory_state.conversation_id` is `None`;
/// - `token_budget` is zero;
/// - `memory.load_summaries(cid)` yields an empty result;
/// - `tokens_used` does not exceed `estimate_tokens(SUMMARY_PREFIX)`.
///
/// Otherwise returns `Ok(Some(..))` with a `Role::System` message holding a
/// single `MessagePart::Summary` part built from `summary_text`.
///
/// # Errors
/// Propagates the error of the awaited `load_summaries` call via `?`.
///
/// NOTE(review): the code that appends summaries to `summary_text` and
/// accumulates `tokens_used` lies in lines elided by the hunk boundary inside
/// this function; presumably it stops once `token_budget` would be exceeded -
/// confirm against the full file.
435+ async fn fetch_summaries (
436+ memory_state : & super :: MemoryState < P > ,
437+ token_budget : usize ,
438+ ) -> Result < Option < Message > , super :: error:: AgentError > {
// Guard: both a memory handle and a conversation id are required.
439+ let ( Some ( memory) , Some ( cid) ) = ( & memory_state. memory , memory_state. conversation_id ) else {
440+ return Ok ( None ) ;
399441 } ;
400442 if token_budget == 0 {
401- return Ok ( ( ) ) ;
443+ return Ok ( None ) ;
402444 }
403445
404446 let summaries = memory. load_summaries ( cid) . await ?;
405447 if summaries. is_empty ( ) {
406- return Ok ( ( ) ) ;
448+ return Ok ( None ) ;
407449 }
408450
// The message text always starts with SUMMARY_PREFIX.
409451 let mut summary_text = String :: from ( SUMMARY_PREFIX ) ;
@@ -422,18 +464,14 @@ impl<P: LlmProvider + Clone + 'static, C: Channel, T: ToolExecutor> Agent<P, C,
422464 tokens_used += cost;
423465 }
424466
425- if tokens_used > estimate_tokens ( SUMMARY_PREFIX ) && self . messages . len ( ) > 1 {
426- self . messages . insert (
427- 1 ,
428- Message :: from_parts (
429- Role :: System ,
430- vec ! [ MessagePart :: Summary { text: summary_text } ] ,
431- ) ,
432- ) ;
433- tracing:: debug!( tokens_used, "injected summaries into context" ) ;
467+ if tokens_used > estimate_tokens ( SUMMARY_PREFIX ) {
468+ Ok ( Some ( Message :: from_parts (
469+ Role :: System ,
470+ vec ! [ MessagePart :: Summary { text: summary_text } ] ,
471+ ) ) )
472+ } else {
473+ Ok ( None )
434474 }
435-
436- Ok ( ( ) )
437475 }
438476
439477 fn trim_messages_to_budget ( & mut self , token_budget : usize ) {
@@ -485,16 +523,45 @@ impl<P: LlmProvider + Clone + 'static, C: Channel, T: ToolExecutor> Agent<P, C,
485523 let system_prompt = self . messages . first ( ) . map_or ( "" , |m| m. content . as_str ( ) ) ;
486524 let alloc = budget. allocate ( system_prompt, & self . skill_state . last_skills_prompt ) ;
487525
488- self . inject_summaries ( alloc. summaries ) . await ?;
526+ // Remove stale injected messages before concurrent fetch
527+ self . remove_summary_messages ( ) ;
528+ self . remove_cross_session_messages ( ) ;
529+ self . remove_recall_messages ( ) ;
530+ #[ cfg( feature = "index" ) ]
531+ self . remove_code_context_messages ( ) ;
489532
490- self . inject_cross_session_context ( query, alloc. cross_session )
491- . await ?;
533+ // Fetch all context sources concurrently
534+ #[ cfg( not( feature = "index" ) ) ]
535+ let ( summaries_msg, cross_session_msg, recall_msg) = tokio:: try_join!(
536+ Self :: fetch_summaries( & self . memory_state, alloc. summaries) ,
537+ Self :: fetch_cross_session( & self . memory_state, query, alloc. cross_session) ,
538+ Self :: fetch_semantic_recall( & self . memory_state, query, alloc. semantic_recall) ,
539+ ) ?;
492540
493- self . inject_semantic_recall ( query, alloc. semantic_recall )
494- . await ?;
541+ #[ cfg( feature = "index" ) ]
542+ let ( summaries_msg, cross_session_msg, recall_msg, code_rag_text) = tokio:: try_join!(
543+ Self :: fetch_summaries( & self . memory_state, alloc. summaries) ,
544+ Self :: fetch_cross_session( & self . memory_state, query, alloc. cross_session) ,
545+ Self :: fetch_semantic_recall( & self . memory_state, query, alloc. semantic_recall) ,
546+ Self :: fetch_code_rag( & self . index, query, alloc. code_context) ,
547+ ) ?;
548+
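549+ // Insert fetched messages at index 1 in the call order recall, cross-session, summaries; each insert shifts the earlier ones down, so the resulting order after the first message is summaries, cross-session, recall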
550+ if let Some ( msg) = recall_msg. filter ( |_| self . messages . len ( ) > 1 ) {
551+ self . messages . insert ( 1 , msg) ;
552+ }
553+ if let Some ( msg) = cross_session_msg. filter ( |_| self . messages . len ( ) > 1 ) {
554+ self . messages . insert ( 1 , msg) ;
555+ }
556+ if let Some ( msg) = summaries_msg. filter ( |_| self . messages . len ( ) > 1 ) {
557+ self . messages . insert ( 1 , msg) ;
558+ tracing:: debug!( "injected summaries into context" ) ;
559+ }
495560
496561 #[ cfg( feature = "index" ) ]
497- self . inject_code_rag ( query, alloc. code_context ) . await ?;
562+ if let Some ( text) = code_rag_text {
563+ self . inject_code_context ( & text) ;
564+ }
498565
499566 self . trim_messages_to_budget ( alloc. recent_history ) ;
500567
0 commit comments