@@ -4,6 +4,8 @@ use std::sync::{
4
4
} ;
5
5
6
6
use amzn_codewhisperer_streaming_client:: Client as CodewhispererStreamingClient ;
7
+ use amzn_qdeveloper_streaming_client:: Client as QDeveloperStreamingClient ;
8
+ use amzn_qdeveloper_streaming_client:: types:: Origin ;
7
9
use aws_types:: request_id:: RequestId ;
8
10
use tracing:: {
9
11
debug,
@@ -12,6 +14,7 @@ use tracing::{
12
14
13
15
use super :: shared:: {
14
16
bearer_sdk_config,
17
+ sigv4_sdk_config,
15
18
stalled_stream_protection_config,
16
19
} ;
17
20
use crate :: api_client:: interceptor:: opt_out:: OptOutInterceptor ;
@@ -40,12 +43,14 @@ mod inner {
40
43
} ;
41
44
42
45
use amzn_codewhisperer_streaming_client:: Client as CodewhispererStreamingClient ;
46
+ use amzn_qdeveloper_streaming_client:: Client as QDeveloperStreamingClient ;
43
47
44
48
use crate :: api_client:: model:: ChatResponseStream ;
45
49
46
50
#[ derive( Clone , Debug ) ]
47
51
pub enum Inner {
48
52
Codewhisperer ( CodewhispererStreamingClient ) ,
53
+ QDeveloper ( QDeveloperStreamingClient ) ,
49
54
Mock ( Arc < Mutex < std:: vec:: IntoIter < Vec < ChatResponseStream > > > > ) ,
50
55
}
51
56
}
@@ -58,7 +63,13 @@ pub struct StreamingClient {
58
63
59
64
impl StreamingClient {
60
65
pub async fn new ( database : & mut Database ) -> Result < Self , ApiClientError > {
61
- Self :: new_codewhisperer_client ( database, & Endpoint :: load_codewhisperer ( database) ) . await
66
+ // If SIGV4_AUTH_ENABLED is true, use Q developer client
67
+ if std:: env:: var ( "AMAZON_Q_SIGV4" ) . is_ok_and ( |v| !v. is_empty ( ) ) {
68
+ Self :: new_qdeveloper_client ( database, & Endpoint :: load_q ( database) ) . await
69
+ } else {
70
+ // Default to CodeWhisperer client
71
+ Self :: new_codewhisperer_client ( database, & Endpoint :: load_codewhisperer ( database) ) . await
72
+ }
62
73
}
63
74
64
75
pub fn mock ( events : Vec < Vec < ChatResponseStream > > ) -> Self {
@@ -96,6 +107,25 @@ impl StreamingClient {
96
107
Ok ( Self { inner, profile } )
97
108
}
98
109
110
+ // Add SigV4 client creation method
111
+ pub async fn new_qdeveloper_client ( database : & Database , endpoint : & Endpoint ) -> Result < Self , ApiClientError > {
112
+ let conf_builder: amzn_qdeveloper_streaming_client:: config:: Builder =
113
+ ( & sigv4_sdk_config ( database, endpoint) . await ?) . into ( ) ;
114
+ let conf = conf_builder
115
+ . http_client ( crate :: aws_common:: http_client:: client ( ) )
116
+ . interceptor ( OptOutInterceptor :: new ( database) )
117
+ . interceptor ( UserAgentOverrideInterceptor :: new ( ) )
118
+ . app_name ( app_name ( ) )
119
+ . endpoint_url ( endpoint. url ( ) )
120
+ . stalled_stream_protection ( stalled_stream_protection_config ( ) )
121
+ . build ( ) ;
122
+ let client = QDeveloperStreamingClient :: from_conf ( conf) ;
123
+ Ok ( Self {
124
+ inner : inner:: Inner :: QDeveloper ( client) ,
125
+ profile : None ,
126
+ } )
127
+ }
128
+
99
129
pub async fn send_message ( & self , conversation : ConversationState ) -> Result < SendMessageOutput , ApiClientError > {
100
130
debug ! ( "Sending conversation: {:#?}" , conversation) ;
101
131
let ConversationState {
@@ -180,6 +210,51 @@ impl StreamingClient {
180
210
} ,
181
211
}
182
212
} ,
213
+ inner:: Inner :: QDeveloper ( client) => {
214
+ let conversation_state = amzn_qdeveloper_streaming_client:: types:: ConversationState :: builder ( )
215
+ . set_conversation_id ( conversation_id)
216
+ . current_message ( amzn_qdeveloper_streaming_client:: types:: ChatMessage :: UserInputMessage (
217
+ user_input_message. into ( ) ,
218
+ ) )
219
+ . chat_trigger_type ( amzn_qdeveloper_streaming_client:: types:: ChatTriggerType :: Manual )
220
+ . set_history (
221
+ history
222
+ . map ( |v| v. into_iter ( ) . map ( |i| i. try_into ( ) ) . collect :: < Result < Vec < _ > , _ > > ( ) )
223
+ . transpose ( ) ?,
224
+ )
225
+ . build ( )
226
+ . expect ( "building conversation_state should not fail" ) ;
227
+
228
+ let response = client
229
+ . send_message ( )
230
+ . conversation_state ( conversation_state)
231
+ . set_source ( Some ( Origin :: from ( "CLI" ) ) )
232
+ . send ( )
233
+ . await ;
234
+
235
+ match response {
236
+ Ok ( resp) => Ok ( SendMessageOutput :: QDeveloper ( resp) ) ,
237
+ Err ( e) => {
238
+ let status_code = e. raw_response ( ) . map ( |res| res. status ( ) . as_u16 ( ) ) ;
239
+ let is_quota_breach = e. raw_response ( ) . is_some_and ( |resp| resp. status ( ) . as_u16 ( ) == 429 ) ;
240
+ let is_context_window_overflow = e. as_service_error ( ) . is_some_and ( |err| {
241
+ matches ! ( err, err if err. meta( ) . code( ) == Some ( "ValidationException" )
242
+ && err. meta( ) . message( ) == Some ( "Input is too long." ) )
243
+ } ) ;
244
+
245
+ if is_quota_breach {
246
+ Err ( ApiClientError :: QuotaBreach {
247
+ message : "quota has reached its limit" ,
248
+ status_code,
249
+ } )
250
+ } else if is_context_window_overflow {
251
+ Err ( ApiClientError :: ContextWindowOverflow { status_code } )
252
+ } else {
253
+ Err ( e. into ( ) )
254
+ }
255
+ } ,
256
+ }
257
+ } ,
183
258
inner:: Inner :: Mock ( events) => {
184
259
let mut new_events = events. lock ( ) . unwrap ( ) . next ( ) . unwrap_or_default ( ) . clone ( ) ;
185
260
new_events. reverse ( ) ;
@@ -194,13 +269,15 @@ pub enum SendMessageOutput {
194
269
Codewhisperer (
195
270
amzn_codewhisperer_streaming_client:: operation:: generate_assistant_response:: GenerateAssistantResponseOutput ,
196
271
) ,
272
+ QDeveloper ( amzn_qdeveloper_streaming_client:: operation:: send_message:: SendMessageOutput ) ,
197
273
Mock ( Vec < ChatResponseStream > ) ,
198
274
}
199
275
200
276
impl SendMessageOutput {
201
277
pub fn request_id ( & self ) -> Option < & str > {
202
278
match self {
203
279
SendMessageOutput :: Codewhisperer ( output) => output. request_id ( ) ,
280
+ SendMessageOutput :: QDeveloper ( output) => output. request_id ( ) ,
204
281
SendMessageOutput :: Mock ( _) => None ,
205
282
}
206
283
}
@@ -212,6 +289,7 @@ impl SendMessageOutput {
212
289
. recv ( )
213
290
. await ?
214
291
. map ( |s| s. into ( ) ) ) ,
292
+ SendMessageOutput :: QDeveloper ( output) => Ok ( output. send_message_response . recv ( ) . await ?. map ( |s| s. into ( ) ) ) ,
215
293
SendMessageOutput :: Mock ( vec) => Ok ( vec. pop ( ) ) ,
216
294
}
217
295
}
@@ -221,6 +299,7 @@ impl RequestId for SendMessageOutput {
221
299
fn request_id ( & self ) -> Option < & str > {
222
300
match self {
223
301
SendMessageOutput :: Codewhisperer ( output) => output. request_id ( ) ,
302
+ SendMessageOutput :: QDeveloper ( output) => output. request_id ( ) ,
224
303
SendMessageOutput :: Mock ( _) => Some ( "<mock-request-id>" ) ,
225
304
}
226
305
}
@@ -242,6 +321,7 @@ mod tests {
242
321
243
322
let _ = StreamingClient :: new ( & mut database) . await ;
244
323
let _ = StreamingClient :: new_codewhisperer_client ( & mut database, & endpoint) . await ;
324
+ let _ = StreamingClient :: new_qdeveloper_client ( & database, & endpoint) . await ;
245
325
}
246
326
247
327
#[ tokio:: test]
0 commit comments