Skip to content

Commit 701a3d9

Browse files
committed
wip
1 parent 7e21fa0 commit 701a3d9

File tree

15 files changed

+85
-130
lines changed

15 files changed

+85
-130
lines changed

crates/agent/src/agent/agent_loop/model.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,13 @@ pub trait Model {
2828
) -> Pin<Box<dyn Stream<Item = Result<StreamEvent, StreamError>> + Send + 'static>>;
2929
}
3030

31-
/// Required for defining [Model] with a [Box<dyn Model>] for [AgentLoopRequest].
31+
/// Required for defining [Model] with a [Box<dyn Model>] for [super::AgentLoopRequest].
3232
pub trait AgentLoopModel: Model + std::fmt::Debug + Send + Sync + 'static {}
3333

3434
// Helper blanket impl
3535
impl<T> AgentLoopModel for T where T: Model + std::fmt::Debug + Send + Sync + 'static {}
3636

37-
/// The supporte
37+
/// The supported backends
3838
#[derive(Debug, Clone)]
3939
pub enum Models {
4040
Rts(RtsModel),

crates/agent/src/agent/agent_loop/types.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,7 @@ pub struct ImageBlock {
272272
pub enum ImageFormat {
273273
Gif,
274274
#[serde(alias = "jpg")]
275+
#[strum(serialize = "jpeg", serialize = "jpg")]
275276
Jpeg,
276277
Png,
277278
Webp,
@@ -422,6 +423,8 @@ pub struct MetadataEvent {
422423
#[derive(Debug, Clone, Serialize, Deserialize)]
423424
#[serde(rename_all = "camelCase")]
424425
pub struct MetadataMetrics {
426+
pub request_start_time: DateTime<Utc>,
427+
pub request_end_time: DateTime<Utc>,
425428
pub time_to_first_chunk: Option<Duration>,
426429
pub time_between_chunks: Option<Vec<Duration>>,
427430
pub response_stream_len: u32,
@@ -479,6 +482,11 @@ mod tests {
479482
test_ser_deser!(ImageFormat, ImageFormat::Png, "png");
480483
test_ser_deser!(ImageFormat, ImageFormat::Webp, "webp");
481484
test_ser_deser!(ImageFormat, ImageFormat::Jpeg, "jpeg");
482-
assert_eq!(ImageFormat::from_str("jpg").unwrap(), ImageFormat::Jpeg);
485+
assert_eq!(
486+
ImageFormat::from_str("jpg").unwrap(),
487+
ImageFormat::Jpeg,
488+
"expected 'jpg' to parse to {}",
489+
ImageFormat::Jpeg
490+
);
483491
}
484492
}

crates/agent/src/agent/mod.rs

Lines changed: 0 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -2104,38 +2104,3 @@ pub enum HookStage {
21042104
/// Hooks after executing tool uses
21052105
PostToolUse { tool_results: Vec<ToolExecutorResult> },
21062106
}
2107-
2108-
#[cfg(test)]
2109-
mod tests {
2110-
use super::*;
2111-
2112-
#[tokio::test]
2113-
async fn test_collect_resources() {
2114-
let r = collect_resources(vec!["file://AGENTS.md"]).await;
2115-
println!("{:?}", r);
2116-
}
2117-
2118-
#[tokio::test]
2119-
async fn test_agent() {
2120-
let _ = tracing_subscriber::fmt::try_init();
2121-
2122-
let path = "/Users/bskiser/.aws/amazonq/cli-agents/idk.json";
2123-
let contents = tokio::fs::read_to_string(path).await.unwrap();
2124-
let cfg: Config = serde_json::from_str(&contents).unwrap();
2125-
let mut agent = Agent::from_config(cfg).await.unwrap().spawn();
2126-
let init_res = agent.recv().await.unwrap();
2127-
println!("Init res: {:?}", init_res);
2128-
2129-
agent
2130-
.send_prompt(SendPromptArgs {
2131-
content: vec![InputItem::Text("what tools do you have?".to_string())],
2132-
})
2133-
.await
2134-
.unwrap();
2135-
2136-
loop {
2137-
let res = agent.recv().await.unwrap();
2138-
println!("res: {:?}", res);
2139-
}
2140-
}
2141-
}

crates/agent/src/agent/rts/mod.rs

Lines changed: 11 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,12 @@ use std::sync::Arc;
66
use std::time::{
77
Duration,
88
Instant,
9-
SystemTime,
109
};
1110

11+
use chrono::{
12+
DateTime,
13+
Utc,
14+
};
1215
use eyre::Result;
1316
use futures::Stream;
1417
use tokio::sync::mpsc;
@@ -111,7 +114,7 @@ impl RtsModel {
111114
};
112115

113116
let request_start_time = Instant::now();
114-
let request_start_time_sys = SystemTime::now();
117+
let request_start_time_sys = Utc::now();
115118
let token_clone = cancel_token.clone();
116119
let result = tokio::select! {
117120
_ = token_clone.cancelled() => {
@@ -144,7 +147,7 @@ impl RtsModel {
144147
tx: mpsc::Sender<std::result::Result<StreamEvent, StreamError>>,
145148
token: CancellationToken,
146149
request_start_time: Instant,
147-
request_start_time_sys: SystemTime,
150+
request_start_time_sys: DateTime<Utc>,
148151
) {
149152
match res {
150153
Ok(output) => {
@@ -335,35 +338,7 @@ impl Model for RtsModel {
335338
.await;
336339
});
337340

338-
Box::pin(RtsDropWrapper {
339-
receiver_stream: ReceiverStream::new(rx),
340-
cancel_token,
341-
})
342-
}
343-
}
344-
345-
#[derive(Debug)]
346-
struct RtsDropWrapper {
347-
receiver_stream: ReceiverStream<Result<StreamEvent, StreamError>>,
348-
cancel_token: CancellationToken,
349-
}
350-
351-
impl Stream for RtsDropWrapper {
352-
type Item = Result<StreamEvent, StreamError>;
353-
354-
fn poll_next(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Option<Self::Item>> {
355-
Pin::new(&mut self.receiver_stream).poll_next(cx)
356-
}
357-
}
358-
359-
impl Drop for RtsDropWrapper {
360-
fn drop(&mut self) {
361-
// TODO - I don't think RtsDropWrapper is really required here.
362-
//
363-
// Cancelling is already handled by agent_loop correctly (when AgentLoop is dropped, the
364-
// cancel token will call cancel)
365-
// debug!("rts stream dropped, cancelling");
366-
// self.cancel_token.cancel();
341+
Box::pin(ReceiverStream::new(rx))
367342
}
368343
}
369344

@@ -393,7 +368,7 @@ struct ResponseParser {
393368
/// Time immediately before sending the request.
394369
request_start_time: Instant,
395370
/// Time immediately before sending the request, as a [SystemTime].
396-
request_start_time_sys: SystemTime,
371+
request_start_time_sys: DateTime<Utc>,
397372
time_to_first_chunk: Option<Duration>,
398373
time_between_chunks: Vec<Duration>,
399374
/// Total size (in bytes) of the response received so far.
@@ -407,7 +382,7 @@ impl ResponseParser {
407382
cancel_token: CancellationToken,
408383
request_id: Option<String>,
409384
request_start_time: Instant,
410-
request_start_time_sys: SystemTime,
385+
request_start_time_sys: DateTime<Utc>,
411386
) -> Self {
412387
Self {
413388
response,
@@ -621,6 +596,8 @@ impl ResponseParser {
621596
fn make_metadata(&self) -> StreamEvent {
622597
StreamEvent::Metadata(MetadataEvent {
623598
metrics: Some(MetadataMetrics {
599+
request_start_time: self.request_start_time_sys,
600+
request_end_time: Utc::now(),
624601
time_to_first_chunk: self.time_to_first_chunk,
625602
time_between_chunks: if self.time_between_chunks.is_empty() {
626603
None

crates/agent/src/agent/rts/types.rs

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -40,25 +40,6 @@ impl From<ToolUseBlock> for model::ToolUse {
4040
}
4141
}
4242

43-
// impl From<ToolResultBlock> for model::ToolResult {
44-
// fn from(v: ToolResultBlock) -> Self {
45-
// Self {
46-
// tool_use_id: v.tool_use_id,
47-
// content: v.content.into_iter().map(Into::into).collect(),
48-
// status: v.status.into(),
49-
// }
50-
// }
51-
// }
52-
53-
// impl From<ToolResultContentBlock> for model::ToolResultContentBlock {
54-
// fn from(v: ToolResultContentBlock) -> Self {
55-
// match v {
56-
// ToolResultContentBlock::Text(t) => Self::Text(t),
57-
// ToolResultContentBlock::Json(v) => Self::Json(serde_value_to_document(v)),
58-
// }
59-
// }
60-
// }
61-
6243
impl From<ToolResultStatus> for model::ToolResultStatus {
6344
fn from(value: ToolResultStatus) -> Self {
6445
match value {

crates/agent/src/agent/task_executor/mod.rs

Lines changed: 29 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ impl TaskExecutor {
216216
}
217217
});
218218
},
219-
HookConfig::Tool(tool) => (),
219+
HookConfig::Tool(_) => (),
220220
};
221221

222222
let start_time = Utc::now();
@@ -282,6 +282,12 @@ impl TaskExecutor {
282282
}
283283
}
284284

285+
impl Default for TaskExecutor {
286+
fn default() -> Self {
287+
Self::new()
288+
}
289+
}
290+
285291
#[derive(Debug)]
286292
pub enum ExecuteRequest {
287293
Tool(StartToolExecution),
@@ -336,6 +342,7 @@ struct ExecutingHook {
336342
}
337343

338344
#[derive(Debug, Clone, Serialize, Deserialize)]
345+
#[allow(clippy::large_enum_variant)]
339346
pub enum TaskExecutorEvent {
340347
/// A tool has started executing
341348
ToolExecutionStart(ToolExecutionStartEvent),
@@ -406,6 +413,7 @@ impl ToolExecutionId {
406413
}
407414

408415
#[derive(Debug, Clone, Serialize, Deserialize)]
416+
#[allow(clippy::large_enum_variant)]
409417
pub enum ExecutorResult {
410418
Tool(ToolExecutorResult),
411419
Hook(HookExecutorResult),
@@ -505,7 +513,7 @@ impl HookResult {
505513
pub fn is_success(&self) -> bool {
506514
match self {
507515
HookResult::Command(res) => res.as_ref().is_ok_and(|r| r.exit_code == 0),
508-
HookResult::Tool { .. } => todo!(),
516+
HookResult::Tool { .. } => panic!("unimplemented"),
509517
}
510518
}
511519

@@ -516,7 +524,6 @@ impl HookResult {
516524
pub fn output(&self) -> Option<&str> {
517525
match self {
518526
HookResult::Command(Ok(CommandResult { output, .. })) => Some(output),
519-
HookResult::Tool { output } => todo!(),
520527
_ => None,
521528
}
522529
}
@@ -668,44 +675,42 @@ fn sanitize_user_prompt(input: &str) -> String {
668675
#[cfg(test)]
669676
mod tests {
670677
use super::*;
671-
use crate::agent::types::AgentId;
672-
673-
const TEST_AGENT_NAME: &str = "test_agent";
674678

675679
const TEST_COMMAND_HOOK: &str = r#"
676680
{
677681
"command": "echo hello world"
678682
}
679683
"#;
680684

681-
async fn run_with_timeout<T: Future>(fut: T) {
682-
match tokio::time::timeout(std::time::Duration::from_millis(500), fut).await {
685+
async fn run_with_timeout<T: Future>(timeout: Duration, fut: T) {
686+
match tokio::time::timeout(timeout, fut).await {
683687
Ok(_) => (),
684688
Err(e) => panic!("Future failed to resolve within timeout: {}", e),
685689
}
686690
}
687691

688692
#[tokio::test]
689693
async fn test_hook_execution() {
690-
let mut bg = TaskExecutor::new();
691-
692-
let agent_id = AgentId::new(TEST_AGENT_NAME.to_string());
693-
bg.start_hook_execution(StartHookExecution {
694-
id: HookExecutionId {
695-
hook: Hook {
696-
trigger: HookTrigger::UserPromptSubmit,
697-
config: serde_json::from_str(TEST_COMMAND_HOOK).unwrap(),
694+
let mut executor = TaskExecutor::new();
695+
696+
executor
697+
.start_hook_execution(StartHookExecution {
698+
id: HookExecutionId {
699+
hook: Hook {
700+
trigger: HookTrigger::UserPromptSubmit,
701+
config: serde_json::from_str(TEST_COMMAND_HOOK).unwrap(),
702+
},
703+
tool_context: None,
698704
},
699-
tool_context: None,
700-
},
701-
prompt: None,
702-
})
703-
.await;
705+
prompt: None,
706+
})
707+
.await;
704708

705-
run_with_timeout(async move {
709+
run_with_timeout(Duration::from_millis(100), async move {
706710
let mut event_buf = Vec::new();
707711
loop {
708-
bg.recv_next(&mut event_buf).await;
712+
executor.recv_next(&mut event_buf).await;
713+
// Check if we get a "hello world" successful hook execution.
709714
if event_buf.iter().any(|ev| match ev {
710715
TaskExecutorEvent::HookExecutionEnd(HookExecutionEndEvent { result, .. }) => {
711716
let HookExecutorResult::Completed { result, .. } = result else {
@@ -720,9 +725,9 @@ mod tests {
720725
},
721726
_ => false,
722727
}) {
728+
// Hook succeeded with expected output, break.
723729
break;
724730
}
725-
println!("{:?}", event_buf);
726731
event_buf.drain(..);
727732
}
728733
})

crates/agent/src/agent/tool_utils.rs

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ use super::tools::BuiltInTool;
1818
/// Categorizes different types of tool name validation failures according to the requirements by
1919
/// the RTS API.
2020
#[derive(Debug, Clone)]
21+
#[allow(dead_code)] // TODO
2122
pub struct ToolValidationError {
2223
mcp_server_name: String,
2324
tool_spec: ToolSpec,
@@ -34,14 +35,18 @@ impl ToolValidationError {
3435
}
3536
}
3637

38+
// TODO - remove dead code. Keeping for debug purposes
3739
#[derive(Debug, Clone)]
3840
pub enum ToolValidationErrorKind {
39-
OutOfSpecName { transformed_name: String },
41+
OutOfSpecName {
42+
#[allow(dead_code)]
43+
transformed_name: String,
44+
},
4045
EmptyName,
4146
NameTooLong,
4247
EmptyDescription,
4348
DescriptionTooLong,
44-
NameCollision(CanonicalToolName),
49+
NameCollision(#[allow(dead_code)] CanonicalToolName),
4550
}
4651

4752
/// Represents a set of tool specs that conforms to backend validations.
@@ -113,9 +118,9 @@ impl SanitizedToolSpec {
113118
///
114119
/// - `canonical_names` - List of tool names to include in the generated tool specs
115120
/// - `mcp_tool_specs` - Map from an MCP server name to a list of tool specs as returned by the
116-
/// server
121+
/// server
117122
/// - `aliases` - Map from a canonical tool name to an aliased name. This refers to the `aliases`
118-
/// field in the agent config
123+
/// field in the agent config
119124
pub fn sanitize_tool_specs(
120125
canonical_names: Vec<CanonicalToolName>,
121126
mcp_tool_specs: HashMap<String, Vec<ToolSpec>>,

crates/agent/src/agent/tools/execute_cmd.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@ use std::collections::HashMap;
55
use std::process::Stdio;
66

77
use bstr::ByteSlice as _;
8-
use futures::StreamExt;
9-
use rand::seq::IndexedRandom;
108
use schemars::{
119
JsonSchema,
1210
schema_for,

0 commit comments

Comments
 (0)