Skip to content

Commit 9be36ec

Browse files
committed
bug fixes in error handling
1 parent 720dbf3 commit 9be36ec

File tree

2 files changed

+40
-8
lines changed

2 files changed

+40
-8
lines changed

crates/agent/src/agent/agent_loop/mod.rs

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ use tracing::{
4141
use types::{
4242
ContentBlock,
4343
Message,
44+
MessageStartEvent,
4445
MessageStopEvent,
4546
MetadataEvent,
4647
Role,
@@ -403,6 +404,8 @@ struct StreamParseState {
403404
parsing_tool_use: Option<(String, String, String)>,
404405
/// Buffered metadata event returned from the response stream
405406
metadata: Option<MetadataEvent>,
407+
/// Buffered message start event returned from the response stream
408+
message_start: Option<MessageStartEvent>,
406409
/// Buffered message stop event returned from the response stream
407410
message_stop: Option<MessageStopEvent>,
408411
/// Buffered error event returned from the response stream
@@ -425,6 +428,7 @@ impl StreamParseState {
425428
user_message,
426429
message_id: None,
427430
metadata: None,
431+
message_start: None,
428432
message_stop: None,
429433
stream_err: None,
430434
ended_time: None,
@@ -433,15 +437,12 @@ impl StreamParseState {
433437
}
434438

435439
pub fn next(&mut self, ev: Option<StreamResult>, buf: &mut Vec<AgentLoopEventKind>) {
436-
if self.errored {
437-
if let Some(ev) = ev {
438-
warn!(?ev, "ignoring unexpected event after having received an error");
439-
}
440-
return;
441-
}
442-
443440
let Some(ev) = ev else {
444441
// No event received means the stream has ended.
442+
debug_assert!(
443+
self.ended_time.is_none(),
444+
"unexpected call to next after stream has already ended"
445+
);
445446
self.ended_time = Some(self.ended_time.unwrap_or(Instant::now()));
446447
self.errored = self.errored || !self.invalid_tool_uses.is_empty();
447448
let result = self.make_result();
@@ -453,14 +454,31 @@ impl StreamParseState {
453454
return;
454455
};
455456

457+
if self.errored {
458+
warn!(?ev, "ignoring unexpected event after having received an error");
459+
return;
460+
}
461+
462+
// Debug assertion that we always start with either a MessageStart, or an error.
463+
match &ev {
464+
StreamResult::Ok(StreamEvent::MessageStart(_)) | StreamResult::Err(_) => (),
465+
other @ StreamResult::Ok(_) => debug_assert!(
466+
self.message_start.is_some(),
467+
"received an unexpected event at the start of the response stream: {:?}",
468+
other
469+
),
470+
}
471+
456472
// Pushing low-level stream events in case end users want to consume these directly. Likely
457473
// not required.
458474
buf.push(AgentLoopEventKind::Stream(ev.clone()));
459475

460476
match ev {
461477
StreamResult::Ok(s) => match s {
462478
StreamEvent::MessageStart(ev) => {
479+
debug_assert!(self.message_start.is_none());
463480
debug_assert!(ev.role == Role::Assistant);
481+
self.message_start = Some(ev);
464482
},
465483
StreamEvent::MessageStop(ev) => {
466484
debug_assert!(self.message_stop.is_none());
@@ -547,7 +565,6 @@ impl StreamParseState {
547565
);
548566
self.stream_err = Some(err);
549567
self.errored = true;
550-
self.ended_time = Some(Instant::now());
551568
},
552569
}
553570
}

crates/agent/src/agent/rts/mod.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ use crate::agent::agent_loop::types::{
5555
ToolUseBlockDelta,
5656
ToolUseBlockStart,
5757
};
58+
use crate::agent_loop::types::MessageStartEvent;
5859
use crate::api_client::error::{
5960
ApiClientError,
6061
ConverseStreamError,
@@ -387,7 +388,12 @@ struct ResponseParser {
387388
/// Whether or not the stream has completed.
388389
ended: bool,
389390
/// Buffer to hold the next event in [SendMessageOutput].
391+
///
392+
/// Required since the RTS stream needs 1 look-ahead token to ensure we don't emit assistant
393+
/// response events that are immediately followed by a code reference event.
390394
peek: Option<ChatResponseStream>,
395+
/// Whether or not we have sent a [MessageStartEvent].
396+
message_start_pushed: bool,
391397
/// Whether or not we are currently receiving tool use delta events. Tuple of
392398
/// `Some((tool_use_id, name))` if true, [None] otherwise.
393399
parsing_tool_use: Option<(String, String)>,
@@ -421,6 +427,7 @@ impl ResponseParser {
421427
cancel_token,
422428
ended: false,
423429
peek: None,
430+
message_start_pushed: false,
424431
parsing_tool_use: None,
425432
tool_use_seen: false,
426433
buf: vec![],
@@ -601,6 +608,14 @@ impl ResponseParser {
601608
Ok(ev) => {
602609
trace!(?ev, "Received new event");
603610

611+
if !self.message_start_pushed {
612+
self.buf
613+
.push(StreamResult::Ok(StreamEvent::MessageStart(MessageStartEvent {
614+
role: Role::Assistant,
615+
})));
616+
self.message_start_pushed = true;
617+
}
618+
604619
// Track metadata about the chunk.
605620
self.time_to_first_chunk
606621
.get_or_insert_with(|| self.request_start_time.elapsed());

0 commit comments

Comments
 (0)