Skip to content

Commit 8b94105

Browse files
committed
tests: add first unit test
1 parent d6d05ae commit 8b94105

File tree

3 files changed

+89
-5
lines changed

3 files changed

+89
-5
lines changed

rustv1/examples/bedrock-agent-runtime/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,6 @@ edition = "2021"
66
[dependencies]
77
aws-config = "1.6.3"
88
aws-sdk-bedrockagentruntime = "1.98.0"
9+
aws-smithy-types = "1.3.2"
10+
mockall = "0.13.1"
911
tokio = { version = "1.45.1", features = ["full"] }
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
// List test modules here. This mod.rs is gated by #[cfg(test)] in the crate mod.rs.
5+
pub mod scenario_with_mocks;

rustv1/examples/bedrock-agent-runtime/src/bin/invoke-agent.rs

Lines changed: 82 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,52 @@
11
use aws_config::{BehaviorVersion, SdkConfig};
2-
use aws_sdk_bedrockagentruntime::{self as bedrockagentruntime, types::ResponseStream};
2+
use aws_sdk_bedrockagentruntime::{
3+
self as bedrockagentruntime,
4+
types::{error::ResponseStreamError, ResponseStream},
5+
};
6+
#[allow(unused_imports)]
7+
use mockall::automock;
38

49
const BEDROCK_AGENT_ID: &str = "AJBHXXILZN";
510
const BEDROCK_AGENT_ALIAS_ID: &str = "AVKP1ITZAA";
611
const BEDROCK_AGENT_REGION: &str = "us-east-1";
712

13+
#[cfg(not(test))]
14+
pub use EventReceiverImpl as EventReceiver;
15+
#[cfg(test)]
16+
pub use MockEventReceiverImpl as EventReceiver;
17+
18+
pub struct EventReceiverImpl {
19+
inner: aws_sdk_bedrockagentruntime::primitives::event_stream::EventReceiver<
20+
ResponseStream,
21+
ResponseStreamError,
22+
>,
23+
}
24+
25+
#[cfg_attr(test, automock)]
26+
impl EventReceiverImpl {
27+
#[allow(dead_code)]
28+
pub fn new(
29+
inner: aws_sdk_bedrockagentruntime::primitives::event_stream::EventReceiver<
30+
ResponseStream,
31+
ResponseStreamError,
32+
>,
33+
) -> Self {
34+
Self { inner }
35+
}
36+
37+
pub async fn recv(
38+
&mut self,
39+
) -> Result<
40+
Option<ResponseStream>,
41+
aws_sdk_bedrockagentruntime::error::SdkError<
42+
ResponseStreamError,
43+
aws_smithy_types::event_stream::RawMessage,
44+
>,
45+
> {
46+
self.inner.recv().await
47+
}
48+
}
49+
850
#[tokio::main]
951
async fn main() -> Result<(), Box<bedrockagentruntime::Error>> {
1052
let result = invoke_bedrock_agent("I need help.".to_string(), "123".to_string()).await?;
@@ -31,11 +73,19 @@ async fn invoke_bedrock_agent(
3173

3274
let response = command_builder.send().await?;
3375

34-
let mut response_stream = response.completion;
76+
let response_stream = response.completion;
77+
78+
let event_receiver = EventReceiver::new(response_stream);
79+
80+
process_agent_response_stream(event_receiver).await
81+
}
82+
83+
async fn process_agent_response_stream(
84+
mut event_receiver: EventReceiver,
85+
) -> Result<String, bedrockagentruntime::Error> {
3586
let mut full_agent_text_response = String::new();
3687

37-
println!("Processing Bedrock agent response stream...");
38-
while let Some(event_result) = response_stream.recv().await? {
88+
while let Some(event_result) = event_receiver.recv().await? {
3989
match event_result {
4090
ResponseStream::Chunk(chunk) => {
4191
if let Some(bytes) = chunk.bytes {
@@ -54,6 +104,33 @@ async fn invoke_bedrock_agent(
54104
}
55105
}
56106
}
57-
58107
Ok(full_agent_text_response)
59108
}
109+
110+
#[cfg(test)]
111+
mod test {
112+
use super::*;
113+
114+
#[tokio::test]
115+
async fn test_process_agent_response_stream() {
116+
let mut mock = MockEventReceiverImpl::default();
117+
mock.expect_recv().times(1).returning(|| {
118+
Ok(Some(
119+
aws_sdk_bedrockagentruntime::types::ResponseStream::Chunk(
120+
aws_sdk_bedrockagentruntime::types::PayloadPart::builder()
121+
.set_bytes(Some(aws_smithy_types::Blob::new(vec![
122+
116, 101, 115, 116, 32, 99, 111, 109, 112, 108, 101, 116, 105, 111, 110,
123+
])))
124+
.build(),
125+
),
126+
))
127+
});
128+
129+
// end the stream
130+
mock.expect_recv().times(1).returning(|| Ok(None));
131+
132+
let response = process_agent_response_stream(mock).await.unwrap();
133+
134+
assert_eq!("test completion", response);
135+
}
136+
}

0 commit comments

Comments
 (0)