@@ -91,6 +91,7 @@ impl RecvError {
9191 RecvErrorKind :: StreamTimeout { .. } => None ,
9292 RecvErrorKind :: UnexpectedToolUseEos { .. } => None ,
9393 RecvErrorKind :: Cancelled => None ,
94+ RecvErrorKind :: ToolValidationError { .. } => None ,
9495 }
9596 }
9697}
@@ -103,6 +104,7 @@ impl ReasonCode for RecvError {
103104 RecvErrorKind :: StreamTimeout { .. } => "RecvErrorStreamTimeout" . to_string ( ) ,
104105 RecvErrorKind :: UnexpectedToolUseEos { .. } => "RecvErrorUnexpectedToolUseEos" . to_string ( ) ,
105106 RecvErrorKind :: Cancelled => "Interrupted" . to_string ( ) ,
107+ RecvErrorKind :: ToolValidationError { .. } => "RecvErrorToolValidation" . to_string ( ) ,
106108 }
107109 }
108110}
@@ -151,6 +153,14 @@ pub enum RecvErrorKind {
151153 /// The stream processing task was cancelled
152154 #[ error( "Stream handling was cancelled" ) ]
153155 Cancelled ,
156+ /// Tool validation failed due to invalid arguments
157+ #[ error( "Tool validation failed for tool: {} with id: {}" , . name, . tool_use_id) ]
158+ ToolValidationError {
159+ tool_use_id : String ,
160+ name : String ,
161+ message : Box < AssistantMessage > ,
162+ error_message : String ,
163+ } ,
154164}
155165
156166/// Represents a response stream from a call to the SendMessage API.
@@ -472,7 +482,43 @@ impl ResponseParser {
472482 }
473483
474484 let args = match serde_json:: from_str ( & tool_string) {
475- Ok ( args) => args,
485+ Ok ( args) => {
486+ // Ensure we have a valid JSON object
487+ match args {
488+ serde_json:: Value :: Object ( _) => args,
489+ _ => {
490+ error ! ( "Received non-object JSON for tool arguments: {:?}" , args) ;
491+ let warning_args = serde_json:: Value :: Object (
492+ [ (
493+ "key" . to_string ( ) ,
494+ serde_json:: Value :: String (
495+ "WARNING: the actual tool use arguments were not a valid JSON object" . to_string ( ) ,
496+ ) ,
497+ ) ]
498+ . into_iter ( )
499+ . collect ( ) ,
500+ ) ;
501+ self . tool_uses . push ( AssistantToolUse {
502+ id : id. clone ( ) ,
503+ name : name. clone ( ) ,
504+ orig_name : name. clone ( ) ,
505+ args : warning_args. clone ( ) ,
506+ orig_args : warning_args. clone ( ) ,
507+ } ) ;
508+ let message = Box :: new ( AssistantMessage :: new_tool_use (
509+ Some ( self . message_id . clone ( ) ) ,
510+ std:: mem:: take ( & mut self . assistant_text ) ,
511+ self . tool_uses . clone ( ) . into_iter ( ) . collect ( ) ,
512+ ) ) ;
513+ return Err ( self . error ( RecvErrorKind :: ToolValidationError {
514+ tool_use_id : id,
515+ name,
516+ message,
517+ error_message : format ! ( "Expected JSON object, got: {:?}" , args) ,
518+ } ) ) ;
519+ } ,
520+ }
521+ } ,
476522 Err ( err) if !tool_string. is_empty ( ) => {
477523 // If we failed deserializing after waiting for a long time, then this is most
478524 // likely bedrock responding with a stop event for some reason without actually
@@ -753,4 +799,75 @@ mod tests {
753799 "assistant text preceding a code reference should be ignored as this indicates licensed code is being returned"
754800 ) ;
755801 }
802+
803+ #[ tokio:: test]
804+ async fn test_response_parser_avoid_invalid_json ( ) {
805+ let content_to_ignore = "IGNORE ME PLEASE" ;
806+ let tool_use_id = "TEST_ID" . to_string ( ) ;
807+ let tool_name = "execute_bash" . to_string ( ) ;
808+ let tool_args = serde_json:: json!( "invalid json" ) . to_string ( ) ;
809+ let mut events = vec ! [
810+ ChatResponseStream :: AssistantResponseEvent {
811+ content: "hi" . to_string( ) ,
812+ } ,
813+ ChatResponseStream :: AssistantResponseEvent {
814+ content: " there" . to_string( ) ,
815+ } ,
816+ ChatResponseStream :: AssistantResponseEvent {
817+ content: content_to_ignore. to_string( ) ,
818+ } ,
819+ ChatResponseStream :: CodeReferenceEvent ( ( ) ) ,
820+ ChatResponseStream :: ToolUseEvent {
821+ tool_use_id: tool_use_id. clone( ) ,
822+ name: tool_name. clone( ) ,
823+ input: None ,
824+ stop: None ,
825+ } ,
826+ ChatResponseStream :: ToolUseEvent {
827+ tool_use_id: tool_use_id. clone( ) ,
828+ name: tool_name. clone( ) ,
829+ input: Some ( tool_args) ,
830+ stop: None ,
831+ } ,
832+ ] ;
833+ events. reverse ( ) ;
834+ let mock = SendMessageOutput :: Mock ( events) ;
835+ let mut parser = ResponseParser :: new (
836+ mock,
837+ "" . to_string ( ) ,
838+ None ,
839+ 1 ,
840+ vec ! [ ] ,
841+ mpsc:: channel ( 32 ) . 0 ,
842+ Instant :: now ( ) ,
843+ SystemTime :: now ( ) ,
844+ CancellationToken :: new ( ) ,
845+ Arc :: new ( Mutex :: new ( None ) ) ,
846+ ) ;
847+
848+ let mut output = String :: new ( ) ;
849+ let mut found_validation_error = false ;
850+ for _ in 0 ..5 {
851+ match parser. recv ( ) . await {
852+ Ok ( event) => {
853+ output. push_str ( & format ! ( "{:?}" , event) ) ;
854+ } ,
855+ Err ( recv_error) => {
856+ if matches ! ( recv_error. source, RecvErrorKind :: ToolValidationError { .. } ) {
857+ found_validation_error = true ;
858+ }
859+ break ;
860+ } ,
861+ }
862+ }
863+
864+ assert ! (
865+ !output. contains( content_to_ignore) ,
866+ "assistant text preceding a code reference should be ignored as this indicates licensed code is being returned"
867+ ) ;
868+ assert ! (
869+ found_validation_error,
870+ "Expected to find tool validation error for non-object JSON"
871+ ) ;
872+ }
756873}
0 commit comments