Skip to content

Commit 5d319c7

Browse files
authored
perf: parallelize context preparation with try_join! (#331)
* perf: parallelize context preparation with try_join!
  Extract fetch_summaries, fetch_cross_session, fetch_semantic_recall, and fetch_code_rag as standalone async fns that return data instead of mutating self. Run them concurrently via tokio::try_join! in prepare_context. Closes #317
* fix: resolve clippy collapsible_if and large_futures warnings
1 parent 5f543ee commit 5d319c7

File tree

3 files changed

+133
-66
lines changed

3 files changed

+133
-66
lines changed

crates/zeph-core/src/agent/context.rs

Lines changed: 121 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -238,25 +238,42 @@ impl<P: LlmProvider + Clone + 'static, C: Channel, T: ToolExecutor> Agent<P, C,
238238
});
239239
}
240240

241+
#[cfg(test)]
241242
pub(super) async fn inject_semantic_recall(
242243
&mut self,
243244
query: &str,
244245
token_budget: usize,
245246
) -> Result<(), super::error::AgentError> {
246247
self.remove_recall_messages();
247248

248-
let Some(memory) = &self.memory_state.memory else {
249-
return Ok(());
249+
if let Some(msg) =
250+
Self::fetch_semantic_recall(&self.memory_state, query, token_budget).await?
251+
{
252+
if self.messages.len() > 1 {
253+
self.messages.insert(1, msg);
254+
}
255+
}
256+
257+
Ok(())
258+
}
259+
260+
async fn fetch_semantic_recall(
261+
memory_state: &super::MemoryState<P>,
262+
query: &str,
263+
token_budget: usize,
264+
) -> Result<Option<Message>, super::error::AgentError> {
265+
let Some(memory) = &memory_state.memory else {
266+
return Ok(None);
250267
};
251-
if self.memory_state.recall_limit == 0 || token_budget == 0 {
252-
return Ok(());
268+
if memory_state.recall_limit == 0 || token_budget == 0 {
269+
return Ok(None);
253270
}
254271

255272
let recalled = memory
256-
.recall(query, self.memory_state.recall_limit, None)
273+
.recall(query, memory_state.recall_limit, None)
257274
.await?;
258275
if recalled.is_empty() {
259-
return Ok(());
276+
return Ok(None);
260277
}
261278

262279
let mut recall_text = String::from(RECALL_PREFIX);
@@ -277,17 +294,14 @@ impl<P: LlmProvider + Clone + 'static, C: Channel, T: ToolExecutor> Agent<P, C,
277294
tokens_used += entry_tokens;
278295
}
279296

280-
if tokens_used > estimate_tokens(RECALL_PREFIX) && self.messages.len() > 1 {
281-
self.messages.insert(
282-
1,
283-
Message::from_parts(
284-
Role::System,
285-
vec![MessagePart::Recall { text: recall_text }],
286-
),
287-
);
297+
if tokens_used > estimate_tokens(RECALL_PREFIX) {
298+
Ok(Some(Message::from_parts(
299+
Role::System,
300+
vec![MessagePart::Recall { text: recall_text }],
301+
)))
302+
} else {
303+
Ok(None)
288304
}
289-
290-
Ok(())
291305
}
292306

293307
pub(super) fn remove_code_context_messages(&mut self) {
@@ -335,31 +349,47 @@ impl<P: LlmProvider + Clone + 'static, C: Channel, T: ToolExecutor> Agent<P, C,
335349
});
336350
}
337351

352+
#[cfg(test)]
338353
async fn inject_cross_session_context(
339354
&mut self,
340355
query: &str,
341356
token_budget: usize,
342357
) -> Result<(), super::error::AgentError> {
343358
self.remove_cross_session_messages();
344359

345-
let (Some(memory), Some(cid)) =
346-
(&self.memory_state.memory, self.memory_state.conversation_id)
347-
else {
348-
return Ok(());
360+
if let Some(msg) =
361+
Self::fetch_cross_session(&self.memory_state, query, token_budget).await?
362+
{
363+
if self.messages.len() > 1 {
364+
self.messages.insert(1, msg);
365+
tracing::debug!("injected cross-session context");
366+
}
367+
}
368+
369+
Ok(())
370+
}
371+
372+
async fn fetch_cross_session(
373+
memory_state: &super::MemoryState<P>,
374+
query: &str,
375+
token_budget: usize,
376+
) -> Result<Option<Message>, super::error::AgentError> {
377+
let (Some(memory), Some(cid)) = (&memory_state.memory, memory_state.conversation_id) else {
378+
return Ok(None);
349379
};
350380
if token_budget == 0 {
351-
return Ok(());
381+
return Ok(None);
352382
}
353383

354-
let threshold = self.memory_state.cross_session_score_threshold;
384+
let threshold = memory_state.cross_session_score_threshold;
355385
let results: Vec<_> = memory
356386
.search_session_summaries(query, 5, Some(cid))
357387
.await?
358388
.into_iter()
359389
.filter(|r| r.score >= threshold)
360390
.collect();
361391
if results.is_empty() {
362-
return Ok(());
392+
return Ok(None);
363393
}
364394

365395
let mut text = String::from(CROSS_SESSION_PREFIX);
@@ -375,35 +405,47 @@ impl<P: LlmProvider + Clone + 'static, C: Channel, T: ToolExecutor> Agent<P, C,
375405
tokens_used += cost;
376406
}
377407

378-
if tokens_used > estimate_tokens(CROSS_SESSION_PREFIX) && self.messages.len() > 1 {
379-
self.messages.insert(
380-
1,
381-
Message::from_parts(Role::System, vec![MessagePart::CrossSession { text }]),
382-
);
383-
tracing::debug!(tokens_used, "injected cross-session context");
408+
if tokens_used > estimate_tokens(CROSS_SESSION_PREFIX) {
409+
Ok(Some(Message::from_parts(
410+
Role::System,
411+
vec![MessagePart::CrossSession { text }],
412+
)))
413+
} else {
414+
Ok(None)
384415
}
385-
386-
Ok(())
387416
}
388417

418+
#[cfg(test)]
389419
async fn inject_summaries(
390420
&mut self,
391421
token_budget: usize,
392422
) -> Result<(), super::error::AgentError> {
393423
self.remove_summary_messages();
394424

395-
let (Some(memory), Some(cid)) =
396-
(&self.memory_state.memory, self.memory_state.conversation_id)
397-
else {
398-
return Ok(());
425+
if let Some(msg) = Self::fetch_summaries(&self.memory_state, token_budget).await? {
426+
if self.messages.len() > 1 {
427+
self.messages.insert(1, msg);
428+
tracing::debug!("injected summaries into context");
429+
}
430+
}
431+
432+
Ok(())
433+
}
434+
435+
async fn fetch_summaries(
436+
memory_state: &super::MemoryState<P>,
437+
token_budget: usize,
438+
) -> Result<Option<Message>, super::error::AgentError> {
439+
let (Some(memory), Some(cid)) = (&memory_state.memory, memory_state.conversation_id) else {
440+
return Ok(None);
399441
};
400442
if token_budget == 0 {
401-
return Ok(());
443+
return Ok(None);
402444
}
403445

404446
let summaries = memory.load_summaries(cid).await?;
405447
if summaries.is_empty() {
406-
return Ok(());
448+
return Ok(None);
407449
}
408450

409451
let mut summary_text = String::from(SUMMARY_PREFIX);
@@ -422,18 +464,14 @@ impl<P: LlmProvider + Clone + 'static, C: Channel, T: ToolExecutor> Agent<P, C,
422464
tokens_used += cost;
423465
}
424466

425-
if tokens_used > estimate_tokens(SUMMARY_PREFIX) && self.messages.len() > 1 {
426-
self.messages.insert(
427-
1,
428-
Message::from_parts(
429-
Role::System,
430-
vec![MessagePart::Summary { text: summary_text }],
431-
),
432-
);
433-
tracing::debug!(tokens_used, "injected summaries into context");
467+
if tokens_used > estimate_tokens(SUMMARY_PREFIX) {
468+
Ok(Some(Message::from_parts(
469+
Role::System,
470+
vec![MessagePart::Summary { text: summary_text }],
471+
)))
472+
} else {
473+
Ok(None)
434474
}
435-
436-
Ok(())
437475
}
438476

439477
fn trim_messages_to_budget(&mut self, token_budget: usize) {
@@ -485,16 +523,45 @@ impl<P: LlmProvider + Clone + 'static, C: Channel, T: ToolExecutor> Agent<P, C,
485523
let system_prompt = self.messages.first().map_or("", |m| m.content.as_str());
486524
let alloc = budget.allocate(system_prompt, &self.skill_state.last_skills_prompt);
487525

488-
self.inject_summaries(alloc.summaries).await?;
526+
// Remove stale injected messages before concurrent fetch
527+
self.remove_summary_messages();
528+
self.remove_cross_session_messages();
529+
self.remove_recall_messages();
530+
#[cfg(feature = "index")]
531+
self.remove_code_context_messages();
489532

490-
self.inject_cross_session_context(query, alloc.cross_session)
491-
.await?;
533+
// Fetch all context sources concurrently
534+
#[cfg(not(feature = "index"))]
535+
let (summaries_msg, cross_session_msg, recall_msg) = tokio::try_join!(
536+
Self::fetch_summaries(&self.memory_state, alloc.summaries),
537+
Self::fetch_cross_session(&self.memory_state, query, alloc.cross_session),
538+
Self::fetch_semantic_recall(&self.memory_state, query, alloc.semantic_recall),
539+
)?;
492540

493-
self.inject_semantic_recall(query, alloc.semantic_recall)
494-
.await?;
541+
#[cfg(feature = "index")]
542+
let (summaries_msg, cross_session_msg, recall_msg, code_rag_text) = tokio::try_join!(
543+
Self::fetch_summaries(&self.memory_state, alloc.summaries),
544+
Self::fetch_cross_session(&self.memory_state, query, alloc.cross_session),
545+
Self::fetch_semantic_recall(&self.memory_state, query, alloc.semantic_recall),
546+
Self::fetch_code_rag(&self.index, query, alloc.code_context),
547+
)?;
548+
549+
// Insert fetched messages (order: recall, cross-session, summaries at position 1)
550+
if let Some(msg) = recall_msg.filter(|_| self.messages.len() > 1) {
551+
self.messages.insert(1, msg);
552+
}
553+
if let Some(msg) = cross_session_msg.filter(|_| self.messages.len() > 1) {
554+
self.messages.insert(1, msg);
555+
}
556+
if let Some(msg) = summaries_msg.filter(|_| self.messages.len() > 1) {
557+
self.messages.insert(1, msg);
558+
tracing::debug!("injected summaries into context");
559+
}
495560

496561
#[cfg(feature = "index")]
497-
self.inject_code_rag(query, alloc.code_context).await?;
562+
if let Some(text) = code_rag_text {
563+
self.inject_code_context(&text);
564+
}
498565

499566
self.trim_messages_to_budget(alloc.recent_history);
500567

crates/zeph-core/src/agent/index.rs

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
use super::{Agent, Channel, LlmProvider, ToolExecutor};
22

33
impl<P: LlmProvider + Clone + 'static, C: Channel, T: ToolExecutor> Agent<P, C, T> {
4-
pub(super) async fn inject_code_rag(
5-
&mut self,
4+
pub(super) async fn fetch_code_rag(
5+
index: &super::IndexState<P>,
66
query: &str,
77
token_budget: usize,
8-
) -> Result<(), super::error::AgentError> {
9-
let Some(retriever) = &self.index.retriever else {
10-
return Ok(());
8+
) -> Result<Option<String>, super::error::AgentError> {
9+
let Some(retriever) = &index.retriever else {
10+
return Ok(None);
1111
};
1212
if token_budget == 0 {
13-
return Ok(());
13+
return Ok(None);
1414
}
1515

1616
let result = retriever
@@ -19,17 +19,17 @@ impl<P: LlmProvider + Clone + 'static, C: Channel, T: ToolExecutor> Agent<P, C,
1919
.map_err(|e| super::error::AgentError::Other(format!("{e:#}")))?;
2020
let context_text = zeph_index::retriever::format_as_context(&result);
2121

22-
if !context_text.is_empty() {
23-
self.inject_code_context(&context_text);
22+
if context_text.is_empty() {
23+
Ok(None)
24+
} else {
2425
tracing::debug!(
2526
strategy = ?result.strategy,
2627
chunks = result.chunks.len(),
2728
tokens = result.total_tokens,
28-
"code context injected"
29+
"code context fetched"
2930
);
31+
Ok(Some(context_text))
3032
}
31-
32-
Ok(())
3333
}
3434
}
3535

src/main.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -551,7 +551,7 @@ async fn main() -> anyhow::Result<()> {
551551

552552
warmup_provider(&warmup_provider_clone).await;
553553
tokio::spawn(forward_status_to_stderr(status_rx));
554-
agent.run().await
554+
Box::pin(agent.run()).await
555555
}
556556

557557
async fn forward_status_to_stderr(mut rx: tokio::sync::mpsc::UnboundedReceiver<String>) {

0 commit comments

Comments
 (0)