Skip to content

Commit f4ac403

Browse files
committed
Some cleanup
1 parent cfe15d4 commit f4ac403

File tree

21 files changed

+920
-4003
lines changed

21 files changed

+920
-4003
lines changed

crates/agent/src/agent/agent_config/parse.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,8 +151,6 @@ pub enum ToolParseErrorKind {
151151
SchemaFailure(String),
152152
#[error("The tool arguments failed validation: {}", .0)]
153153
InvalidArgs(String),
154-
#[error("The tool name could not be resolved: {}", .0)]
155-
AmbiguousToolName(String),
156154
#[error("An unexpected error occurred parsing the tools: {}", .0)]
157155
Other(#[from] AgentError),
158156
}

crates/agent/src/agent/agent_loop/mod.rs

Lines changed: 2 additions & 175 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ use protocol::{
2323
StreamMetadata,
2424
UserTurnMetadata,
2525
};
26-
use rand::seq::IndexedRandom;
2726
use serde::{
2827
Deserialize,
2928
Serialize,
@@ -77,10 +76,6 @@ impl AgentLoopId {
7776
rand: rand::random::<u32>(),
7877
}
7978
}
80-
81-
pub fn agent_id(&self) -> &AgentId {
82-
&self.agent_id
83-
}
8479
}
8580

8681
impl std::fmt::Display for AgentLoopId {
@@ -89,23 +84,6 @@ impl std::fmt::Display for AgentLoopId {
8984
}
9085
}
9186

92-
// impl FromStr for AgentLoopId {
93-
// type Err = String;
94-
//
95-
// fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
96-
// match s.find("/") {
97-
// Some(i) => Ok(Self {
98-
// agent_id: s[..i].to_string(),
99-
// rand: match s[i + 1..].to_string().parse() {
100-
// Ok(v) => v,
101-
// Err(_) => return Err(s.to_string()),
102-
// },
103-
// }),
104-
// None => Err(s.to_string()),
105-
// }
106-
// }
107-
// }
108-
10987
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize, strum::Display, strum::EnumString)]
11088
#[serde(rename_all = "camelCase")]
11189
#[strum(serialize_all = "camelCase")]
@@ -126,14 +104,6 @@ pub enum LoopState {
126104
Errored,
127105
}
128106

129-
// #[derive(Debug)]
130-
// struct StreamRequest {
131-
// model: Box<dyn AgentLoopModel>,
132-
// messages: Vec<Message>,
133-
// tool_specs: Option<Vec<ToolSpec>>,
134-
// system_prompt: Option<String>,
135-
// }
136-
137107
/// Tracks the execution of a user turn, ending when either the model returns a response with no
138108
/// tool uses, or a non-retryable error is encountered.
139109
pub struct AgentLoop {
@@ -147,6 +117,7 @@ pub struct AgentLoop {
147117
cancel_token: CancellationToken,
148118

149119
/// The current response stream future being received along with it's associated parse state
120+
#[allow(clippy::type_complexity)]
150121
curr_stream: Option<(
151122
StreamParseState,
152123
Pin<Box<dyn Stream<Item = Result<StreamEvent, StreamError>> + Send>>,
@@ -201,15 +172,14 @@ impl AgentLoop {
201172
/// the spawned task.
202173
pub fn spawn(mut self) -> AgentLoopHandle {
203174
let id_clone = self.id.clone();
204-
let cancel_token_clone = self.cancel_token.clone();
205175
let loop_event_rx = self.loop_event_rx.take().expect("loop_event_rx should exist");
206176
let loop_req_tx = self.loop_req_tx.take().expect("loop_req_tx should exist");
207177
let handle = tokio::spawn(async move {
208178
info!("agent loop start");
209179
self.run().await;
210180
info!("agent loop end");
211181
});
212-
AgentLoopHandle::new(id_clone, loop_req_tx, loop_event_rx, cancel_token_clone, handle)
182+
AgentLoopHandle::new(id_clone, loop_req_tx, loop_event_rx, handle)
213183
}
214184

215185
async fn run(mut self) {
@@ -349,15 +319,6 @@ impl AgentLoop {
349319

350320
Ok(AgentLoopResponse::Metadata(metadata))
351321
},
352-
353-
AgentLoopRequest::GetPendingToolUses => {
354-
if self.execution_state != LoopState::PendingToolUseResults {
355-
return Ok(AgentLoopResponse::PendingToolUses(None));
356-
}
357-
let tool_uses = self.stream_states.last().map(|s| s.tool_uses.clone());
358-
debug_assert!(tool_uses.as_ref().is_some_and(|v| !v.is_empty()));
359-
Ok(AgentLoopResponse::PendingToolUses(tool_uses))
360-
},
361322
}
362323
}
363324

@@ -648,8 +609,6 @@ pub struct AgentLoopHandle {
648609
/// Sender for sending requests to the agent loop
649610
sender: RequestSender<AgentLoopRequest, AgentLoopResponse, AgentLoopResponseError>,
650611
loop_event_rx: mpsc::Receiver<AgentLoopEventKind>,
651-
/// A [CancellationToken] used for gracefully closing the agent loop.
652-
cancel_token: CancellationToken,
653612
/// The [JoinHandle] to the task executing the agent loop.
654613
handle: JoinHandle<()>,
655614
}
@@ -659,14 +618,12 @@ impl AgentLoopHandle {
659618
id: AgentLoopId,
660619
sender: RequestSender<AgentLoopRequest, AgentLoopResponse, AgentLoopResponseError>,
661620
loop_event_rx: mpsc::Receiver<AgentLoopEventKind>,
662-
cancel_token: CancellationToken,
663621
handle: JoinHandle<()>,
664622
) -> Self {
665623
Self {
666624
id,
667625
sender,
668626
loop_event_rx,
669-
cancel_token,
670627
handle,
671628
}
672629
}
@@ -676,19 +633,6 @@ impl AgentLoopHandle {
676633
&self.id
677634
}
678635

679-
/// Id of the agent this loop was created for.
680-
pub fn agent_id(&self) -> &AgentId {
681-
self.id.agent_id()
682-
}
683-
684-
pub fn clone_weak(&self) -> AgentLoopWeakHandle {
685-
AgentLoopWeakHandle {
686-
id: self.id.clone(),
687-
sender: self.sender.clone(),
688-
cancel_token: self.cancel_token.clone(),
689-
}
690-
}
691-
692636
pub async fn recv(&mut self) -> Option<AgentLoopEventKind> {
693637
self.loop_event_rx.recv().await
694638
}
@@ -722,21 +666,6 @@ impl AgentLoopHandle {
722666
}
723667
}
724668

725-
pub async fn get_pending_tool_uses(&self) -> Result<Option<Vec<ToolUseBlock>>, AgentLoopResponseError> {
726-
match self
727-
.sender
728-
.send_recv(AgentLoopRequest::GetPendingToolUses)
729-
.await
730-
.unwrap_or(Err(AgentLoopResponseError::AgentLoopExited))?
731-
{
732-
AgentLoopResponse::PendingToolUses(v) => Ok(v),
733-
other => Err(AgentLoopResponseError::Custom(format!(
734-
"unknown response getting stream metadata: {:?}",
735-
other,
736-
))),
737-
}
738-
}
739-
740669
/// Ends the agent loop
741670
pub async fn close(&self) -> Result<UserTurnMetadata, AgentLoopResponseError> {
742671
match self
@@ -760,105 +689,3 @@ impl Drop for AgentLoopHandle {
760689
self.handle.abort();
761690
}
762691
}
763-
764-
/// A weak handle to an executing agent loop.
765-
///
766-
/// Where [AgentLoopHandle] can receive agent loop events and abort the task on drop,
767-
/// [AgentLoopWeakHandle] is only used for sending messages to the agent loop.
768-
#[derive(Debug, Clone)]
769-
pub struct AgentLoopWeakHandle {
770-
id: AgentLoopId,
771-
sender: RequestSender<AgentLoopRequest, AgentLoopResponse, AgentLoopResponseError>,
772-
cancel_token: CancellationToken,
773-
}
774-
775-
impl AgentLoopWeakHandle {
776-
pub async fn send_request<M: AgentLoopModel>(
777-
&self,
778-
model: M,
779-
args: SendRequestArgs,
780-
) -> Result<AgentLoopResponse, AgentLoopResponseError> {
781-
self.sender
782-
.send_recv(AgentLoopRequest::SendRequest {
783-
model: Box::new(model),
784-
args,
785-
})
786-
.await
787-
.unwrap_or(Err(AgentLoopResponseError::AgentLoopExited))
788-
}
789-
790-
pub async fn get_loop_state(&self) -> Result<LoopState, AgentLoopResponseError> {
791-
match self
792-
.sender
793-
.send_recv(AgentLoopRequest::GetExecutionState)
794-
.await
795-
.unwrap_or(Err(AgentLoopResponseError::AgentLoopExited))?
796-
{
797-
AgentLoopResponse::ExecutionState(state) => Ok(state),
798-
other => Err(AgentLoopResponseError::Custom(format!(
799-
"unknown response getting execution state: {:?}",
800-
other,
801-
))),
802-
}
803-
}
804-
805-
pub async fn get_pending_tool_uses(&self) -> Result<Option<Vec<ToolUseBlock>>, AgentLoopResponseError> {
806-
match self
807-
.sender
808-
.send_recv(AgentLoopRequest::GetPendingToolUses)
809-
.await
810-
.unwrap_or(Err(AgentLoopResponseError::AgentLoopExited))?
811-
{
812-
AgentLoopResponse::PendingToolUses(v) => Ok(v),
813-
other => Err(AgentLoopResponseError::Custom(format!(
814-
"unknown response getting stream metadata: {:?}",
815-
other,
816-
))),
817-
}
818-
}
819-
820-
/// Ends the agent loop
821-
pub async fn close(&self) -> Result<UserTurnMetadata, AgentLoopResponseError> {
822-
match self
823-
.sender
824-
.send_recv(AgentLoopRequest::Close)
825-
.await
826-
.unwrap_or(Err(AgentLoopResponseError::AgentLoopExited))?
827-
{
828-
AgentLoopResponse::Metadata(md) => Ok(md),
829-
other => Err(AgentLoopResponseError::Custom(format!(
830-
"unknown response getting execution state: {:?}",
831-
other,
832-
))),
833-
}
834-
}
835-
836-
/// Cancel the executing loop for graceful shutdown.
837-
fn cancel(&self) {
838-
self.cancel_token.cancel();
839-
}
840-
}
841-
842-
#[cfg(test)]
843-
mod tests {
844-
use std::sync::Arc;
845-
846-
use super::*;
847-
use crate::api_client::error::{
848-
ConverseStreamError,
849-
ConverseStreamErrorKind,
850-
};
851-
852-
#[test]
853-
fn test_other_stream_err_downcasting() {
854-
let err = StreamError::new(StreamErrorKind::Interrupted).with_source(Arc::new(ConverseStreamError::new(
855-
ConverseStreamErrorKind::ModelOverloadedError,
856-
None::<aws_smithy_types::error::operation::BuildError>, /* annoying type inference
857-
* required */
858-
)));
859-
assert!(
860-
err.as_rts_error()
861-
.is_some_and(|r| matches!(r.kind, ConverseStreamErrorKind::ModelOverloadedError))
862-
);
863-
}
864-
}

crates/agent/src/agent/agent_loop/model.rs

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,6 @@ pub enum Models {
4242
}
4343

4444
impl Models {
45-
pub fn supported_model(&self) -> SupportedModel {
46-
match self {
47-
Models::Rts(_) => SupportedModel::Rts,
48-
Models::Test(_) => SupportedModel::Test,
49-
}
50-
}
51-
5245
pub fn state(&self) -> ModelsState {
5346
match self {
5447
Models::Rts(v) => ModelsState::Rts {
@@ -79,17 +72,6 @@ impl Default for ModelsState {
7972
}
8073
}
8174

82-
/// Identifier for the models we support.
83-
///
84-
/// TODO - probably not required, use [ModelsState] instead
85-
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, strum::Display, strum::EnumString)]
86-
#[serde(rename_all = "camelCase")]
87-
#[strum(serialize_all = "camelCase")]
88-
pub enum SupportedModel {
89-
Rts,
90-
Test,
91-
}
92-
9375
impl Model for Models {
9476
fn stream(
9577
&self,
@@ -100,7 +82,7 @@ impl Model for Models {
10082
) -> Pin<Box<dyn Stream<Item = Result<StreamEvent, StreamError>> + Send + 'static>> {
10183
match self {
10284
Models::Rts(rts_model) => rts_model.stream(messages, tool_specs, system_prompt, cancel_token),
103-
Models::Test(test_model) => todo!(),
85+
Models::Test(test_model) => test_model.stream(messages, tool_specs, system_prompt, cancel_token),
10486
}
10587
}
10688
}
@@ -113,3 +95,15 @@ impl TestModel {
11395
Self {}
11496
}
11597
}
98+
99+
impl Model for TestModel {
100+
fn stream(
101+
&self,
102+
messages: Vec<Message>,
103+
tool_specs: Option<Vec<ToolSpec>>,
104+
system_prompt: Option<String>,
105+
cancel_token: CancellationToken,
106+
) -> Pin<Box<dyn Stream<Item = Result<StreamEvent, StreamError>> + Send + 'static>> {
107+
todo!()
108+
}
109+
}

crates/agent/src/agent/agent_loop/protocol.rs

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ use super::{
2424
InvalidToolUse,
2525
LoopState,
2626
};
27-
use crate::agent::types::AgentId;
2827

2928
#[derive(Debug)]
3029
pub enum AgentLoopRequest {
@@ -33,7 +32,6 @@ pub enum AgentLoopRequest {
3332
model: Box<dyn AgentLoopModel>,
3433
args: SendRequestArgs,
3534
},
36-
GetPendingToolUses,
3735
/// Ends the agent loop
3836
Close,
3937
}
@@ -93,11 +91,6 @@ impl AgentLoopEvent {
9391
pub fn new(id: AgentLoopId, kind: AgentLoopEventKind) -> Self {
9492
Self { id, kind }
9593
}
96-
97-
/// Id of the agent this loop event is associated with
98-
pub fn agent_id(&self) -> &AgentId {
99-
self.id.agent_id()
100-
}
10194
}
10295

10396
#[derive(Debug, Clone, Serialize, Deserialize)]

0 commit comments

Comments
 (0)