Commit 9fdb14d

Fix: correct history truncation, tool interruption on ctrl c (#718)
* fix: infinite loop when formatting paths (#713)
* fix: use rust 1.85 (#710)
* fix: use rust 1.85
* fix: typos failing on typ
* fix: codepipeline failures when running cargo (#711)
* fix: revert back to rust 1.84 (#712)
* fix: infinite loop when formatting paths
* remove tracing init from test
* fix: correct history truncation, tool interruption on ctrl c
* fix: lints
1 parent c306fab commit 9fdb14d

File tree: 8 files changed, +268 -129 lines

build-config/buildspec-linux-minimal.yml

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@ phases:
       # Install cargo
       - curl --retry 5 --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
       - . "$HOME/.cargo/env"
+      - rustup toolchain install `cat rust-toolchain.toml | grep channel | cut -d '=' -f2 | tr -d ' "'`
       # Install cross only if the musl env var is set and not null
       - if [ ! -z "${AMAZON_Q_BUILD_MUSL:+x}" ]; then cargo install cross --git https://github.com/cross-rs/cross; fi
       # Install python/node via mise (https://mise.jdx.dev/continuous-integration.html)
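The added step installs whichever toolchain rust-toolchain.toml pins, by pulling the channel value out of the file with grep/cut/tr and passing it to rustup toolchain install. As a rough Rust sketch of what that shell pipeline extracts (the helper below is hypothetical and assumes a line of the form channel = "..."):

use std::fs;

/// Returns the pinned toolchain channel, e.g. "1.84.0", by scanning
/// rust-toolchain.toml for a line like: channel = "1.84.0"
fn pinned_channel(path: &str) -> Option<String> {
    let contents = fs::read_to_string(path).ok()?;
    contents
        .lines()
        // `grep channel`
        .find(|line| line.contains("channel"))
        // `cut -d '=' -f2`
        .and_then(|line| line.split('=').nth(1))
        // `tr -d ' "'`
        .map(|value| value.replace(' ', "").replace('"', ""))
}

fn main() {
    // With the channel in hand, the buildspec runs `rustup toolchain install <channel>`.
    match pinned_channel("rust-toolchain.toml") {
        Some(channel) => println!("rustup toolchain install {channel}"),
        None => eprintln!("no channel found in rust-toolchain.toml"),
    }
}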

build-config/buildspec-linux-ubuntu.yml

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@ phases:
       # Install cargo
       - curl --retry 5 --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
       - . "$HOME/.cargo/env"
+      - rustup toolchain install `cat rust-toolchain.toml | grep channel | cut -d '=' -f2 | tr -d ' "'`
       # Install tauri-cli, required for building and bundling the desktop app
       - cargo install --version 1.6.2 tauri-cli
       # Install cross only if the musl env var is set and not null

crates/fig_api_client/src/clients/streaming_client.rs

Lines changed: 1 addition & 1 deletion
@@ -159,7 +159,7 @@ impl StreamingClient {
                ))
            },
            inner::Inner::Mock(events) => {
-                let mut new_events = events.lock().unwrap().next().unwrap().clone();
+                let mut new_events = events.lock().unwrap().next().unwrap_or_default().clone();
                new_events.reverse();
                Ok(SendMessageOutput::Mock(new_events))
            },
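The one-line change above swaps .unwrap() for .unwrap_or_default() on the mock event iterator, so an exhausted mock no longer panics. A minimal sketch of the Option behavior this relies on (the batches value is a hypothetical stand-in, not the crate's mock type):

fn main() {
    // Hypothetical stand-in for the mock client's queued event batches.
    let batches: Vec<Vec<&str>> = vec![vec!["first mock response"]];
    let mut iter = batches.into_iter();

    // First call: a batch is still queued.
    let first = iter.next().unwrap_or_default();
    assert_eq!(first, vec!["first mock response"]);

    // Second call: the iterator is exhausted, so `.next()` returns `None`.
    // `.unwrap()` would panic here; `.unwrap_or_default()` yields an empty Vec instead.
    let second = iter.next().unwrap_or_default();
    assert!(second.is_empty());
}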

crates/q_cli/src/cli/chat/conversation_state.rs

Lines changed: 164 additions & 109 deletions
@@ -27,7 +27,6 @@ use tracing::{
     debug,
     error,
     info,
-    trace,
     warn,
 };

@@ -137,75 +136,64 @@ impl ConversationState {
     }
 
     /// Updates the history so that, when non-empty, the following invariants are in place:
-    /// 1. The history length is <= MAX_CONVERSATION_STATE_HISTORY_LEN if the next user message does
-    ///    not contain tool results. Oldest messages are dropped.
+    /// 1. The history length is `<= MAX_CONVERSATION_STATE_HISTORY_LEN`. Oldest messages are
+    ///    dropped.
     /// 2. The first message is from the user, and does not contain tool results. Oldest messages
     ///    are dropped.
     /// 3. The last message is from the assistant. The last message is dropped if it is from the
     ///    user.
     pub fn fix_history(&mut self) {
-        if self.history.is_empty() {
-            return;
+        // Trim the conversation history by finding the second oldest message from the user without
+        // tool results - this will be the new oldest message in the history.
+        if self.history.len() > MAX_CONVERSATION_STATE_HISTORY_LEN {
+            match self
+                .history
+                .iter()
+                .enumerate()
+                // Skip the first message which should be from the user.
+                .skip(1)
+                .find(|(_, m)| -> bool {
+                    match m {
+                        ChatMessage::UserInputMessage(m) => {
+                            matches!(
+                                m.user_input_message_context.as_ref(),
+                                Some(ctx) if ctx.tool_results.as_ref().is_none_or(|v| v.is_empty())
+                            )
+                        },
+                        ChatMessage::AssistantResponseMessage(_) => false,
+                    }
+                })
+                .map(|v| v.0)
+            {
+                Some(i) => {
+                    debug!("removing the first {i} elements in the history");
+                    self.history.drain(..i);
+                },
+                None => {
+                    debug!("no valid starting user message found in the history, clearing");
+                    self.history.clear();
+
+                    // Edge case: if the next message contains tool results, then we have to just
+                    // abandon them.
+                    match &mut self.next_message {
+                        Some(UserInputMessage {
+                            ref mut content,
+                            user_input_message_context: Some(ctx),
+                            ..
+                        }) if ctx.tool_results.as_ref().is_some_and(|r| !r.is_empty()) => {
+                            *content = "The conversation history has overflowed, clearing state".to_string();
+                            ctx.tool_results.take();
+                        },
+                        _ => {},
+                    }
+                },
+            }
         }
 
-        // Invariant (3).
         if let Some(ChatMessage::UserInputMessage(msg)) = self.history.iter().last() {
             debug!(?msg, "last message in history is from the user, dropping");
             self.history.pop_back();
         }
-
-        // Check if the next message contains tool results - if it does, then return early.
-        // Required in the case that the entire history consists of tool results; every message is
-        // therefore required to avoid validation errors in the backend.
-        match self.next_message.as_ref() {
-            Some(UserInputMessage {
-                user_input_message_context: Some(ctx),
-                ..
-            }) if ctx.tool_results.as_ref().is_none_or(|r| r.is_empty()) => {
-                debug!(
-                    curr_history_len = self.history.len(),
-                    max_history_len = MAX_CONVERSATION_STATE_HISTORY_LEN,
-                    "next user message does not contain tool results, removing messages if required"
-                );
-            },
-            _ => {
-                debug!("next user message contains tool results, not modifying the history");
-                return;
-            },
-        }
-
-        // Invariant (1).
-        while self.history.len() > MAX_CONVERSATION_STATE_HISTORY_LEN {
-            self.history.pop_front();
-        }
-
-        // Invariant (2).
-        match self
-            .history
-            .iter()
-            .enumerate()
-            .find(|(_, m)| -> bool {
-                match m {
-                    ChatMessage::UserInputMessage(m) => {
-                        matches!(
-                            m.user_input_message_context.as_ref(),
-                            Some(ctx) if ctx.tool_results.as_ref().is_none_or(|v| v.is_empty())
-                        )
-                    },
-                    ChatMessage::AssistantResponseMessage(_) => false,
-                }
-            })
-            .map(|v| v.0)
-        {
-            Some(i) => {
-                trace!("removing the first {i} elements in the history");
-                self.history.drain(..i);
-            },
-            None => {
-                trace!("no valid starting user message found in the history, clearing");
-                self.history.clear();
-            },
-        }
     }
 
     pub fn add_tool_results(&mut self, tool_results: Vec<ToolResult>) {
229217
self.next_message = Some(msg);
230218
}
231219

220+
/// Sets the next user message with "cancelled" tool results.
232221
pub fn abandon_tool_use(&mut self, tools_to_be_abandoned: Vec<(String, super::tools::Tool)>, deny_input: String) {
233222
debug_assert!(self.next_message.is_none());
234223
let tool_results = tools_to_be_abandoned
@@ -260,6 +249,38 @@ impl ConversationState {
         self.next_message = Some(msg);
     }
 
+    /// Sets the next user message with "interrupted" tool results.
+    pub fn interrupt_tool_use(&mut self, interrupted_tools: Vec<(String, super::tools::Tool)>, deny_input: String) {
+        debug_assert!(self.next_message.is_none());
+        let tool_results = interrupted_tools
+            .into_iter()
+            .map(|(tool_use_id, _)| ToolResult {
+                tool_use_id,
+                content: vec![ToolResultContentBlock::Text(
+                    "Tool use was interrupted by the user".to_string(),
+                )],
+                status: fig_api_client::model::ToolResultStatus::Error,
+            })
+            .collect::<Vec<_>>();
+        let user_input_message_context = UserInputMessageContext {
+            shell_state: None,
+            env_state: Some(build_env_state()),
+            tool_results: Some(tool_results),
+            tools: if self.tools.is_empty() {
+                None
+            } else {
+                Some(self.tools.clone())
+            },
+            ..Default::default()
+        };
+        let msg = UserInputMessage {
+            content: deny_input,
+            user_input_message_context: Some(user_input_message_context),
+            user_intent: None,
+        };
+        self.next_message = Some(msg);
+    }
+
     /// Returns a [FigConversationState] capable of being sent by
     /// [fig_api_client::StreamingClient] while preparing the current conversation state to be sent
     /// in the next message.
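For context, the sketch below models the Ctrl-C flow that interrupt_tool_use supports: every tool use still in flight is reported back as an errored result rather than silently dropped, so assistant tool uses always get matching tool results in the next user message. All names here (PendingToolUse, ToolOutcome, interrupt) are hypothetical stand-ins, not the crate's API:

// Hypothetical, simplified stand-ins illustrating the interruption flow.
#[derive(Debug)]
struct PendingToolUse {
    tool_use_id: String,
}

#[derive(Debug)]
enum ToolStatus {
    Error,
}

#[derive(Debug)]
struct ToolOutcome {
    tool_use_id: String,
    message: String,
    status: ToolStatus,
}

/// Converts every interrupted tool use into an errored result that can be
/// queued as the next user turn.
fn interrupt(pending: Vec<PendingToolUse>) -> Vec<ToolOutcome> {
    pending
        .into_iter()
        .map(|tool_use| ToolOutcome {
            tool_use_id: tool_use.tool_use_id,
            message: "Tool use was interrupted by the user".to_string(),
            status: ToolStatus::Error,
        })
        .collect()
}

fn main() {
    let pending = vec![PendingToolUse {
        tool_use_id: "tool_id".to_string(),
    }];
    let outcomes = interrupt(pending);
    assert_eq!(outcomes.len(), 1);
    assert_eq!(outcomes[0].tool_use_id, "tool_id");
    assert_eq!(outcomes[0].message, "Tool use was interrupted by the user");
    assert!(matches!(outcomes[0].status, ToolStatus::Error));
}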
@@ -344,6 +365,7 @@ mod tests {
     use fig_api_client::model::{
         AssistantResponseMessage,
         ToolResultStatus,
+        ToolUse,
     };
 
     use super::*;
@@ -365,72 +387,83 @@ mod tests {
         println!("{env_state:?}");
     }
 
+    fn assert_conversation_state_invariants(state: FigConversationState, i: usize) {
+        if let Some(Some(msg)) = state.history.as_ref().map(|h| h.first()) {
+            assert!(
+                matches!(msg, ChatMessage::UserInputMessage(_)),
+                "{i}: First message in the history must be from the user, instead found: {:?}",
+                msg
+            );
+        }
+        if let Some(Some(msg)) = state.history.as_ref().map(|h| h.last()) {
+            assert!(
+                matches!(msg, ChatMessage::AssistantResponseMessage(_)),
+                "{i}: Last message in the history must be from the assistant, instead found: {:?}",
+                msg
+            );
+            // If the last message from the assistant contains tool uses, then the next user
+            // message must contain tool results.
+            match (state.user_input_message.user_input_message_context, msg) {
+                (
+                    Some(ctx),
+                    ChatMessage::AssistantResponseMessage(AssistantResponseMessage {
+                        tool_uses: Some(tool_uses),
+                        ..
+                    }),
+                ) if !tool_uses.is_empty() => {
+                    assert!(
+                        ctx.tool_results.is_some_and(|r| !r.is_empty()),
+                        "The user input message must contain tool results when the last assistant message contains tool uses"
+                    );
+                },
+                _ => {},
+            }
+        }
+
+        let actual_history_len = state.history.unwrap_or_default().len();
+        assert!(
+            actual_history_len <= MAX_CONVERSATION_STATE_HISTORY_LEN,
+            "history should not extend past the max limit of {}, instead found length {}",
+            MAX_CONVERSATION_STATE_HISTORY_LEN,
+            actual_history_len
+        );
+    }
+
     #[tokio::test]
-    async fn test_conversation_state_history_handling() {
+    async fn test_conversation_state_history_handling_truncation() {
         let mut conversation_state = ConversationState::new(load_tools().unwrap());
 
         // First, build a large conversation history. We need to ensure that the order is always
         // User -> Assistant -> User -> Assistant ...and so on.
         conversation_state.append_new_user_message("start".to_string());
-        for i in 0..=100 {
+        for i in 0..=(MAX_CONVERSATION_STATE_HISTORY_LEN + 100) {
             let s = conversation_state.as_sendable_conversation_state();
-            assert!(
-                s.history
-                    .as_ref()
-                    .is_none_or(|h| h.first().is_none_or(|m| matches!(m, ChatMessage::UserInputMessage(_)))),
-                "First message in the history must be from the user"
-            );
-            assert!(
-                s.history.as_ref().is_none_or(|h| h
-                    .last()
-                    .is_none_or(|m| matches!(m, ChatMessage::AssistantResponseMessage(_)))),
-                "Last message in the history must be from the assistant"
-            );
+            assert_conversation_state_invariants(s, i);
             conversation_state.push_assistant_message(AssistantResponseMessage {
                 message_id: None,
                 content: i.to_string(),
                 tool_uses: None,
             });
             conversation_state.append_new_user_message(i.to_string());
         }
-
-        let s = conversation_state.as_sendable_conversation_state();
-        assert_eq!(
-            s.history.as_ref().unwrap().len(),
-            MAX_CONVERSATION_STATE_HISTORY_LEN,
-            "history should be capped at {}",
-            MAX_CONVERSATION_STATE_HISTORY_LEN
-        );
-        let first_msg = s.history.as_ref().unwrap().first().unwrap();
-        match first_msg {
-            ChatMessage::UserInputMessage(_) => {},
-            other @ ChatMessage::AssistantResponseMessage(_) => {
-                panic!("First message should be from the user, instead found {:?}", other)
-            },
-        }
-        let last_msg = s.history.as_ref().unwrap().iter().last().unwrap();
-        match last_msg {
-            ChatMessage::AssistantResponseMessage(assistant_response_message) => {
-                assert_eq!(assistant_response_message.content, "100");
-            },
-            other @ ChatMessage::UserInputMessage(_) => {
-                panic!("Last message should be from the assistant, instead found {:?}", other)
-            },
-        }
     }
 
     #[tokio::test]
     async fn test_conversation_state_history_handling_with_tool_results() {
-        let mut conversation_state = ConversationState::new(load_tools().unwrap());
-
         // Build a long conversation history of tool use results.
+        let mut conversation_state = ConversationState::new(load_tools().unwrap());
         conversation_state.append_new_user_message("start".to_string());
         for i in 0..=(MAX_CONVERSATION_STATE_HISTORY_LEN + 100) {
-            let _ = conversation_state.as_sendable_conversation_state();
+            let s = conversation_state.as_sendable_conversation_state();
+            assert_conversation_state_invariants(s, i);
             conversation_state.push_assistant_message(AssistantResponseMessage {
                 message_id: None,
                 content: i.to_string(),
-                tool_uses: None,
+                tool_uses: Some(vec![ToolUse {
+                    tool_use_id: "tool_id".to_string(),
+                    name: "tool name".to_string(),
+                    input: aws_smithy_types::Document::Null,
+                }]),
             });
             conversation_state.add_tool_results(vec![ToolResult {
                 tool_use_id: "tool_id".to_string(),
@@ -439,13 +472,35 @@ mod tests {
             }]);
         }
 
-        let s = conversation_state.as_sendable_conversation_state();
-        let actual_history_len = s.history.as_ref().unwrap().len();
-        assert!(
-            actual_history_len > MAX_CONVERSATION_STATE_HISTORY_LEN,
-            "history should extend past the max limit of {}, instead found length {}",
-            MAX_CONVERSATION_STATE_HISTORY_LEN,
-            actual_history_len
-        );
+        // Build a long conversation history of user messages mixed in with tool results.
+        let mut conversation_state = ConversationState::new(load_tools().unwrap());
+        conversation_state.append_new_user_message("start".to_string());
+        for i in 0..=(MAX_CONVERSATION_STATE_HISTORY_LEN + 100) {
+            let s = conversation_state.as_sendable_conversation_state();
+            assert_conversation_state_invariants(s, i);
+            if i % 3 == 0 {
+                conversation_state.push_assistant_message(AssistantResponseMessage {
+                    message_id: None,
+                    content: i.to_string(),
+                    tool_uses: Some(vec![ToolUse {
+                        tool_use_id: "tool_id".to_string(),
+                        name: "tool name".to_string(),
+                        input: aws_smithy_types::Document::Null,
+                    }]),
+                });
+                conversation_state.add_tool_results(vec![ToolResult {
+                    tool_use_id: "tool_id".to_string(),
+                    content: vec![],
+                    status: ToolResultStatus::Success,
+                }]);
+            } else {
+                conversation_state.push_assistant_message(AssistantResponseMessage {
+                    message_id: None,
+                    content: i.to_string(),
+                    tool_uses: None,
+                });
+                conversation_state.append_new_user_message(i.to_string());
+            }
+        }
     }
 }
