Skip to content

Commit 0c23526

Browse files
[fix] Fixes issues with Tool Input parsing. (#2986)
* [fix] Fixes issues with Tool Input parsing. * Occasionally the model will generate a tool use whose parameters are not valid JSON. When this happens it corrupts the conversation history. * Here we first avoid storing the tool use and add the proper validation logic to the conversation history. * adds validation logic to safety * [fix] Update to use a new RecvErrorKind instead of custom error handling. * [fix] Gives a visual hint to the user that the request is being retried. --------- Co-authored-by: Kenneth S. <[email protected]>
1 parent 814f149 commit 0c23526

File tree

2 files changed

+165
-1
lines changed

2 files changed

+165
-1
lines changed

crates/chat-cli/src/cli/chat/mod.rs

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2694,6 +2694,53 @@ impl ChatSession {
26942694
.await?,
26952695
));
26962696
},
2697+
RecvErrorKind::ToolValidationError {
2698+
tool_use_id,
2699+
name,
2700+
message,
2701+
error_message,
2702+
} => {
2703+
self.send_chat_telemetry(
2704+
os,
2705+
TelemetryResult::Failed,
2706+
Some(reason),
2707+
Some(reason_desc),
2708+
status_code,
2709+
false, // We retry the request, so don't end the current turn yet.
2710+
)
2711+
.await;
2712+
2713+
error!(
2714+
recv_error.request_metadata.request_id,
2715+
tool_use_id, name, error_message, "Tool validation failed"
2716+
);
2717+
self.conversation
2718+
.push_assistant_message(os, *message, Some(recv_error.request_metadata));
2719+
let tool_results = vec![ToolUseResult {
2720+
tool_use_id,
2721+
content: vec![ToolUseResultBlock::Text(format!(
2722+
"Tool validation failed: {}. Please ensure tool arguments are provided as a valid JSON object.",
2723+
error_message
2724+
))],
2725+
status: ToolResultStatus::Error,
2726+
}];
2727+
// User hint of what happened
2728+
let _ = queue!(
2729+
self.stdout,
2730+
style::Print("\n\n"),
2731+
style::SetForegroundColor(Color::Yellow),
2732+
style::Print(format!("Tool validation failed: {}\n Retrying the request...", error_message)),
2733+
style::ResetColor,
2734+
style::Print("\n"),
2735+
);
2736+
self.conversation.add_tool_results(tool_results);
2737+
self.send_tool_use_telemetry(os).await;
2738+
return Ok(ChatState::HandleResponseStream(
2739+
self.conversation
2740+
.as_sendable_conversation_state(os, &mut self.stderr, false)
2741+
.await?,
2742+
));
2743+
},
26972744
_ => {
26982745
self.send_chat_telemetry(
26992746
os,

crates/chat-cli/src/cli/chat/parser.rs

Lines changed: 118 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ impl RecvError {
9191
RecvErrorKind::StreamTimeout { .. } => None,
9292
RecvErrorKind::UnexpectedToolUseEos { .. } => None,
9393
RecvErrorKind::Cancelled => None,
94+
RecvErrorKind::ToolValidationError { .. } => None,
9495
}
9596
}
9697
}
@@ -103,6 +104,7 @@ impl ReasonCode for RecvError {
103104
RecvErrorKind::StreamTimeout { .. } => "RecvErrorStreamTimeout".to_string(),
104105
RecvErrorKind::UnexpectedToolUseEos { .. } => "RecvErrorUnexpectedToolUseEos".to_string(),
105106
RecvErrorKind::Cancelled => "Interrupted".to_string(),
107+
RecvErrorKind::ToolValidationError { .. } => "RecvErrorToolValidation".to_string(),
106108
}
107109
}
108110
}
@@ -151,6 +153,14 @@ pub enum RecvErrorKind {
151153
/// The stream processing task was cancelled
152154
#[error("Stream handling was cancelled")]
153155
Cancelled,
156+
/// Tool validation failed due to invalid arguments
157+
#[error("Tool validation failed for tool: {} with id: {}", .name, .tool_use_id)]
158+
ToolValidationError {
159+
tool_use_id: String,
160+
name: String,
161+
message: Box<AssistantMessage>,
162+
error_message: String,
163+
},
154164
}
155165

156166
/// Represents a response stream from a call to the SendMessage API.
@@ -472,7 +482,43 @@ impl ResponseParser {
472482
}
473483

474484
let args = match serde_json::from_str(&tool_string) {
475-
Ok(args) => args,
485+
Ok(args) => {
486+
// Ensure we have a valid JSON object
487+
match args {
488+
serde_json::Value::Object(_) => args,
489+
_ => {
490+
error!("Received non-object JSON for tool arguments: {:?}", args);
491+
let warning_args = serde_json::Value::Object(
492+
[(
493+
"key".to_string(),
494+
serde_json::Value::String(
495+
"WARNING: the actual tool use arguments were not a valid JSON object".to_string(),
496+
),
497+
)]
498+
.into_iter()
499+
.collect(),
500+
);
501+
self.tool_uses.push(AssistantToolUse {
502+
id: id.clone(),
503+
name: name.clone(),
504+
orig_name: name.clone(),
505+
args: warning_args.clone(),
506+
orig_args: warning_args.clone(),
507+
});
508+
let message = Box::new(AssistantMessage::new_tool_use(
509+
Some(self.message_id.clone()),
510+
std::mem::take(&mut self.assistant_text),
511+
self.tool_uses.clone().into_iter().collect(),
512+
));
513+
return Err(self.error(RecvErrorKind::ToolValidationError {
514+
tool_use_id: id,
515+
name,
516+
message,
517+
error_message: format!("Expected JSON object, got: {:?}", args),
518+
}));
519+
},
520+
}
521+
},
476522
Err(err) if !tool_string.is_empty() => {
477523
// If we failed deserializing after waiting for a long time, then this is most
478524
// likely bedrock responding with a stop event for some reason without actually
@@ -753,4 +799,75 @@ mod tests {
753799
"assistant text preceding a code reference should be ignored as this indicates licensed code is being returned"
754800
);
755801
}
802+
803+
#[tokio::test]
804+
async fn test_response_parser_avoid_invalid_json() {
805+
let content_to_ignore = "IGNORE ME PLEASE";
806+
let tool_use_id = "TEST_ID".to_string();
807+
let tool_name = "execute_bash".to_string();
808+
let tool_args = serde_json::json!("invalid json").to_string();
809+
let mut events = vec![
810+
ChatResponseStream::AssistantResponseEvent {
811+
content: "hi".to_string(),
812+
},
813+
ChatResponseStream::AssistantResponseEvent {
814+
content: " there".to_string(),
815+
},
816+
ChatResponseStream::AssistantResponseEvent {
817+
content: content_to_ignore.to_string(),
818+
},
819+
ChatResponseStream::CodeReferenceEvent(()),
820+
ChatResponseStream::ToolUseEvent {
821+
tool_use_id: tool_use_id.clone(),
822+
name: tool_name.clone(),
823+
input: None,
824+
stop: None,
825+
},
826+
ChatResponseStream::ToolUseEvent {
827+
tool_use_id: tool_use_id.clone(),
828+
name: tool_name.clone(),
829+
input: Some(tool_args),
830+
stop: None,
831+
},
832+
];
833+
events.reverse();
834+
let mock = SendMessageOutput::Mock(events);
835+
let mut parser = ResponseParser::new(
836+
mock,
837+
"".to_string(),
838+
None,
839+
1,
840+
vec![],
841+
mpsc::channel(32).0,
842+
Instant::now(),
843+
SystemTime::now(),
844+
CancellationToken::new(),
845+
Arc::new(Mutex::new(None)),
846+
);
847+
848+
let mut output = String::new();
849+
let mut found_validation_error = false;
850+
for _ in 0..5 {
851+
match parser.recv().await {
852+
Ok(event) => {
853+
output.push_str(&format!("{:?}", event));
854+
},
855+
Err(recv_error) => {
856+
if matches!(recv_error.source, RecvErrorKind::ToolValidationError { .. }) {
857+
found_validation_error = true;
858+
}
859+
break;
860+
},
861+
}
862+
}
863+
864+
assert!(
865+
!output.contains(content_to_ignore),
866+
"assistant text preceding a code reference should be ignored as this indicates licensed code is being returned"
867+
);
868+
assert!(
869+
found_validation_error,
870+
"Expected to find tool validation error for non-object JSON"
871+
);
872+
}
756873
}

0 commit comments

Comments
 (0)