Skip to content

Commit 40cf653

Browse files
committed
serde rename fields
1 parent fe99abe commit 40cf653

File tree

8 files changed

+109
-84
lines changed

8 files changed

+109
-84
lines changed

crates/agent/src/agent/agent_loop/mod.rs

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ use protocol::{
2222
LoopError,
2323
SendRequestArgs,
2424
StreamMetadata,
25+
StreamResult,
2526
UserTurnMetadata,
2627
};
2728
use serde::{
@@ -124,10 +125,7 @@ pub struct AgentLoop {
124125

125126
/// The current response stream future being received along with it's associated parse state
126127
#[allow(clippy::type_complexity)]
127-
curr_stream: Option<(
128-
StreamParseState,
129-
Pin<Box<dyn Stream<Item = Result<StreamEvent, StreamError>> + Send>>,
130-
)>,
128+
curr_stream: Option<(StreamParseState, Pin<Box<dyn Stream<Item = StreamResult> + Send>>)>,
131129

132130
/// List of completed stream parse states
133131
stream_states: Vec<StreamParseState>,
@@ -434,7 +432,9 @@ impl StreamParseState {
434432
}
435433
}
436434

437-
pub fn next(&mut self, ev: Option<Result<StreamEvent, StreamError>>, buf: &mut Vec<AgentLoopEventKind>) {
435+
// pub fn next(&mut self, ev: Option<Result<StreamEvent, StreamError>>, buf: &mut
436+
// Vec<AgentLoopEventKind>) {
437+
pub fn next(&mut self, ev: Option<StreamResult>, buf: &mut Vec<AgentLoopEventKind>) {
438438
if self.errored {
439439
if let Some(ev) = ev {
440440
warn!(?ev, "ignoring unexpected event after having received an error");
@@ -457,13 +457,10 @@ impl StreamParseState {
457457

458458
// Pushing low-level stream events in case end users want to consume these directly. Likely
459459
// not required.
460-
match &ev {
461-
Ok(e) => buf.push(AgentLoopEventKind::StreamEvent(e.clone())),
462-
Err(e) => buf.push(AgentLoopEventKind::StreamError(e.clone())),
463-
}
460+
buf.push(AgentLoopEventKind::Stream(ev.clone()));
464461

465462
match ev {
466-
Ok(s) => match s {
463+
StreamResult::Ok(s) => match s {
467464
StreamEvent::MessageStart(ev) => {
468465
debug_assert!(ev.role == Role::Assistant);
469466
},
@@ -543,7 +540,7 @@ impl StreamParseState {
543540

544541
// Parse invariant - we don't expect any further events after receiving a single
545542
// error.
546-
Err(err) => {
543+
StreamResult::Err(err) => {
547544
debug_assert!(
548545
self.stream_err.is_none(),
549546
"Only one stream error event is expected. Previously found: {:?}, just received: {:?}",

crates/agent/src/agent/agent_loop/model.rs

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,9 @@ use serde::{
77
};
88
use tokio_util::sync::CancellationToken;
99

10+
use super::protocol::StreamResult;
1011
use super::types::{
1112
Message,
12-
StreamError,
13-
StreamEvent,
1413
ToolSpec,
1514
};
1615
use crate::agent::rts::RtsModel;
@@ -26,7 +25,7 @@ pub trait Model: std::fmt::Debug + Send + Sync + 'static {
2625
tool_specs: Option<Vec<ToolSpec>>,
2726
system_prompt: Option<String>,
2827
cancel_token: CancellationToken,
29-
) -> Pin<Box<dyn Stream<Item = Result<StreamEvent, StreamError>> + Send + 'static>>;
28+
) -> Pin<Box<dyn Stream<Item = StreamResult> + Send + 'static>>;
3029

3130
/// Dump serializable state required by the model implementation.
3231
///
@@ -82,7 +81,7 @@ impl Model for Models {
8281
tool_specs: Option<Vec<ToolSpec>>,
8382
system_prompt: Option<String>,
8483
cancel_token: CancellationToken,
85-
) -> Pin<Box<dyn Stream<Item = Result<StreamEvent, StreamError>> + Send + 'static>> {
84+
) -> Pin<Box<dyn Stream<Item = StreamResult> + Send + 'static>> {
8685
match self {
8786
Models::Rts(rts_model) => rts_model.stream(messages, tool_specs, system_prompt, cancel_token),
8887
Models::Test(test_model) => test_model.stream(messages, tool_specs, system_prompt, cancel_token),
@@ -106,7 +105,7 @@ impl Model for TestModel {
106105
_tool_specs: Option<Vec<ToolSpec>>,
107106
_system_prompt: Option<String>,
108107
_cancel_token: CancellationToken,
109-
) -> Pin<Box<dyn Stream<Item = Result<StreamEvent, StreamError>> + Send + 'static>> {
108+
) -> Pin<Box<dyn Stream<Item = StreamResult> + Send + 'static>> {
110109
panic!("unimplemented")
111110
}
112111
}

crates/agent/src/agent/agent_loop/protocol.rs

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,8 @@ impl AgentLoopEvent {
9595
}
9696

9797
#[derive(Debug, Clone, Serialize, Deserialize)]
98+
#[serde(tag = "kind", content = "content")]
99+
#[serde(rename_all = "camelCase")]
98100
pub enum AgentLoopEventKind {
99101
/// Text returned by the assistant.
100102
AssistantText(String),
@@ -139,12 +141,26 @@ pub enum AgentLoopEventKind {
139141
///
140142
/// This reflects the exact event the agent loop parses from a [Model::stream] response as part
141143
/// of executing a user turn.
142-
StreamEvent(StreamEvent),
143-
/// Low level event. Generally only useful for [AgentLoop].
144-
///
145-
/// This reflects the exact event the agent loop parses from a [Model::stream] response as part
146-
/// of executing a user turn.
147-
StreamError(StreamError),
144+
// Stream(StreamResult<StreamEvent, StreamError>),
145+
Stream(StreamResult),
146+
}
147+
148+
#[derive(Debug, Clone, Serialize, Deserialize)]
149+
#[serde(tag = "result")]
150+
#[serde(rename_all = "lowercase")]
151+
pub enum StreamResult {
152+
Ok(StreamEvent),
153+
#[serde(rename = "error")]
154+
Err(StreamError),
155+
}
156+
157+
impl StreamResult {
158+
pub fn unwrap_err(self) -> StreamError {
159+
match self {
160+
StreamResult::Ok(t) => panic!("called `StreamResult::unwrap_err()` on an `Ok` value: {:?}", &t),
161+
StreamResult::Err(e) => e,
162+
}
163+
}
148164
}
149165

150166
#[derive(Debug, Clone, Serialize, Deserialize, thiserror::Error)]

crates/agent/src/agent/agent_loop/types.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ impl std::error::Error for StreamError {
107107
}
108108

109109
#[derive(Debug, Clone, Serialize, Deserialize)]
110+
#[serde(rename_all = "camelCase")]
110111
pub enum StreamErrorKind {
111112
/// The request failed due to the context window overflowing.
112113
///
@@ -245,7 +246,7 @@ impl Message {
245246
}
246247

247248
#[derive(Debug, Clone, Serialize, Deserialize)]
248-
#[serde(rename_all = "lowercase")]
249+
#[serde(rename_all = "camelCase")]
249250
pub enum ContentBlock {
250251
Text(String),
251252
ToolUse(ToolUseBlock),

crates/agent/src/agent/protocol.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ use super::types::AgentSnapshot;
2626

2727
#[derive(Debug, Clone, Serialize, Deserialize)]
2828
#[allow(clippy::large_enum_variant)]
29+
#[serde(tag = "kind", content = "content")]
30+
#[serde(rename_all = "camelCase")]
2931
pub enum AgentEvent {
3032
/// Agent has finished initialization, and is ready to receive requests.
3133
///

0 commit comments

Comments
 (0)