
Commit 4bcd58e

combine api res and openai model
1 parent b0d6177 commit 4bcd58e

File tree: 7 files changed, +86 −102 lines

crates/chat-cli/src/api_client/mod.rs

Lines changed: 19 additions & 4 deletions

@@ -6,7 +6,6 @@ pub mod model;
 mod opt_out;
 pub mod profile;
 pub mod send_message_output;
-
 use std::sync::Arc;
 use std::time::Duration;
 
@@ -18,6 +17,7 @@ use amzn_codewhisperer_client::types::{
     OptOutPreference,
     SubscriptionStatus,
     TelemetryEvent,
+    TokenLimits,
     UserContext,
 };
 use amzn_codewhisperer_streaming_client::Client as CodewhispererStreamingClient;
@@ -275,9 +275,7 @@ impl ApiClient {
             models.extend(models_output.models().iter().cloned());
 
             if default_model.is_none() {
-                if let Some(model) = models_output.default_model().cloned() {
-                    default_model = Some(model);
-                }
+                default_model = Some(models_output.default_model().clone());
             }
         }
 
@@ -308,6 +306,23 @@ impl ApiClient {
         tracing::info!("Model cache invalidated");
     }
 
+    pub async fn get_available_models(&self, region: &str) -> Result<(Vec<Model>, Option<Model>), ApiClientError> {
+        let (mut models, default_model) = self.list_available_models_cached().await?;
+
+        if region == "us-east-1" {
+            let gpt_oss = Model::builder()
+                .model_id("OPENAI_GPT_OSS_120B_1_0")
+                .model_name("openai-gpt-oss-120b-preview")
+                .token_limits(TokenLimits::builder().max_input_tokens(128_000).build())
+                .build()
+                .map_err(ApiClientError::from)?;
+
+            models.push(gpt_oss);
+        }
+
+        Ok((models, default_model))
+    }
+
     pub async fn create_subscription_token(&self) -> Result<CreateSubscriptionTokenOutput, ApiClientError> {
         if cfg!(test) {
             return Ok(CreateSubscriptionTokenOutput::builder()
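
Note: ApiClient::get_available_models above merges the cached service listing with a hard-coded GPT-OSS preview entry that is only offered in us-east-1 (IAD). Below is a minimal standalone sketch of that gating logic; ModelInfo and with_region_extras are simplified stand-ins invented for illustration (the real code uses the amzn_codewhisperer_client Model/TokenLimits builders shown in the diff), while the ids and the 128,000-token limit come from the diff itself.

// Sketch of the region gating in get_available_models. ModelInfo is a
// simplified stand-in for the SDK's Model type, not the real API.
#[derive(Clone, Debug)]
struct ModelInfo {
    model_id: String,
    model_name: String,
    max_input_tokens: usize,
}

// Append the client-side GPT-OSS preview entry only for us-east-1 (IAD).
fn with_region_extras(mut models: Vec<ModelInfo>, region: &str) -> Vec<ModelInfo> {
    if region == "us-east-1" {
        models.push(ModelInfo {
            model_id: "OPENAI_GPT_OSS_120B_1_0".into(),
            model_name: "openai-gpt-oss-120b-preview".into(),
            max_input_tokens: 128_000,
        });
    }
    models
}

fn main() {
    let base = vec![ModelInfo {
        model_id: "CLAUDE_SONNET_4_20250514_V1_0".into(),
        model_name: "claude-4-sonnet".into(),
        max_input_tokens: 200_000,
    }];
    // The extra entry appears only in us-east-1.
    assert_eq!(with_region_extras(base.clone(), "us-east-1").len(), 2);
    assert_eq!(with_region_extras(base, "eu-west-1").len(), 1);
}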

crates/chat-cli/src/cli/chat/cli/context.rs

Lines changed: 2 additions & 1 deletion

@@ -222,7 +222,8 @@ impl ContextSubcommand {
             execute!(session.stderr, style::Print(format!("{}\n\n", "▔".repeat(3))),)?;
         }
 
-        let context_files_max_size = calc_max_context_files_size(session.conversation.model.as_deref());
+        let context_files_max_size =
+            calc_max_context_files_size(session.conversation.model.as_deref(), os).await;
         let mut files_as_vec = profile_context_files
             .iter()
             .map(|(path, content, _)| (path.clone(), content.clone()))

crates/chat-cli/src/cli/chat/cli/model.rs

Lines changed: 32 additions & 81 deletions

@@ -1,3 +1,4 @@
+use amzn_codewhisperer_client::types::Model;
 use clap::Args;
 use crossterm::style::{
     self,
@@ -20,35 +21,6 @@ use crate::cli::chat::{
     ChatState,
 };
 use crate::os::Os;
-
-pub struct ModelOption {
-    /// Display name
-    pub name: &'static str,
-    /// Actual model id to send in the API
-    pub model_id: &'static str,
-    /// Size of the model's context window, in tokens
-    pub context_window_tokens: usize,
-}
-
-const MODEL_OPTIONS: [ModelOption; 2] = [
-    ModelOption {
-        name: "claude-4-sonnet",
-        model_id: "CLAUDE_SONNET_4_20250514_V1_0",
-        context_window_tokens: 200_000,
-    },
-    ModelOption {
-        name: "claude-3.7-sonnet",
-        model_id: "CLAUDE_3_7_SONNET_20250219_V1_0",
-        context_window_tokens: 200_000,
-    },
-];
-
-const GPT_OSS_120B: ModelOption = ModelOption {
-    name: "openai-gpt-oss-120b-preview",
-    model_id: "OPENAI_GPT_OSS_120B_1_0",
-    context_window_tokens: 128_000,
-};
-
 #[deny(missing_docs)]
 #[derive(Debug, PartialEq, Args)]
 pub struct ModelArgs;
@@ -65,11 +37,7 @@ pub async fn select_model(os: &Os, session: &mut ChatSession) -> Result<Option<C
     queue!(session.stderr, style::Print("\n"))?;
 
     // Fetch available models from service
-    let (models, _default_model) = os
-        .client
-        .list_available_models_cached()
-        .await
-        .map_err(|e| ChatError::Custom(format!("Failed to fetch available models: {}", e).into()))?;
+    let (models, _default_model) = get_available_models(os).await?;
 
     if models.is_empty() {
         queue!(
@@ -82,15 +50,16 @@ pub async fn select_model(os: &Os, session: &mut ChatSession) -> Result<Option<C
     }
 
     let active_model_id = session.conversation.model.as_deref();
-    let model_options = get_model_options(os).await?;
 
-    let labels: Vec<String> = model_options
+    let labels: Vec<String> = models
         .iter()
         .map(|model| {
+            let display_name = model.model_name().unwrap_or(model.model_id());
+
             if Some(model.model_id()) == active_model_id {
-                format!("{} (active)", model.model_id())
+                format!("{} (active)", display_name)
             } else {
-                model.model_id().to_owned()
+                display_name.to_owned()
             }
         })
        .collect();
@@ -119,11 +88,12 @@ pub async fn select_model(os: &Os, session: &mut ChatSession) -> Result<Option<C
     let selected = &models[index];
     let model_id_str = selected.model_id.clone();
     session.conversation.model = Some(model_id_str.clone());
+    let display_name = selected.model_name().unwrap_or(selected.model_id());
 
     queue!(
         session.stderr,
         style::Print("\n"),
-        style::Print(format!(" Using {}\n\n", model_id_str)),
+        style::Print(format!(" Using {}\n\n", display_name)),
         style::ResetColor,
         style::SetForegroundColor(Color::Reset),
         style::SetBackgroundColor(Color::Reset),
@@ -160,60 +130,41 @@ pub async fn default_model_id(os: &Os) -> String {
     "claude-4-sonnet".to_string()
 }
 
-/// Returns the available models for use.
-pub async fn get_model_options(os: &Os) -> Result<Vec<ModelOption>, ChatError> {
-    let mut model_options = MODEL_OPTIONS.into_iter().collect::<Vec<_>>();
-
-    // GPT OSS is only accessible in IAD.
+/// Get available models with caching support
+pub async fn get_available_models(os: &Os) -> Result<(Vec<Model>, Option<Model>), ChatError> {
     let endpoint = Endpoint::configured_value(&os.database);
-    if endpoint.region().as_ref() != "us-east-1" {
-        return Ok(model_options);
-    }
+    let region = endpoint.region().as_ref();
 
-    model_options.push(GPT_OSS_120B);
-    Ok(model_options)
+    os.client
+        .get_available_models(region)
+        .await
+        .map_err(|e| ChatError::Custom(format!("Failed to fetch available models: {}", e).into()))
 }
 
 /// Returns the context window length in tokens for the given model_id.
-pub fn context_window_tokens(model_id: Option<&str>) -> usize {
+/// Uses cached model data when available
+pub async fn context_window_tokens(model_id: Option<&str>, os: &Os) -> usize {
     const DEFAULT_CONTEXT_WINDOW_LENGTH: usize = 200_000;
 
+    // If no model_id provided, return default
     let Some(model_id) = model_id else {
         return DEFAULT_CONTEXT_WINDOW_LENGTH;
     };
 
-    MODEL_OPTIONS
-        .iter()
-        .chain(std::iter::once(&GPT_OSS_120B))
-        .find(|m| m.model_id == model_id)
-        .map_or(DEFAULT_CONTEXT_WINDOW_LENGTH, |m| m.context_window_tokens)
-}
-
-/// Returns the available models for use.
-pub async fn get_model_options(os: &Os) -> Result<Vec<ModelOption>, ChatError> {
-    let mut model_options = MODEL_OPTIONS.into_iter().collect::<Vec<_>>();
-
-    // GPT OSS is only accessible in IAD.
-    let endpoint = Endpoint::configured_value(&os.database);
-    if endpoint.region().as_ref() != "us-east-1" {
-        return Ok(model_options);
-    }
-
-    model_options.push(GPT_OSS_120B);
-    Ok(model_options)
-}
-
-/// Returns the context window length in tokens for the given model_id.
-pub fn context_window_tokens(model_id: Option<&str>) -> usize {
-    const DEFAULT_CONTEXT_WINDOW_LENGTH: usize = 200_000;
-
-    let Some(model_id) = model_id else {
-        return DEFAULT_CONTEXT_WINDOW_LENGTH;
+    // Try to get from cached models first
+    let (models, _) = match get_available_models(os).await {
+        Ok(models) => models,
+        Err(_) => {
+            // If we can't get models, return default
+            return DEFAULT_CONTEXT_WINDOW_LENGTH;
+        },
     };
 
-    MODEL_OPTIONS
+    models
         .iter()
-        .chain(std::iter::once(&GPT_OSS_120B))
-        .find(|m| m.model_id == model_id)
-        .map_or(DEFAULT_CONTEXT_WINDOW_LENGTH, |m| m.context_window_tokens)
+        .find(|m| m.model_id() == model_id)
+        .and_then(|m| m.token_limits())
+        .and_then(|limits| limits.max_input_tokens())
+        .map(|tokens| tokens as usize)
+        .unwrap_or(DEFAULT_CONTEXT_WINDOW_LENGTH)
 }
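
Note: the rewritten context_window_tokens derives the window size from service-provided model metadata rather than the deleted MODEL_OPTIONS table, defaulting to 200_000 tokens when the model is unknown or reports no limits. A standalone sketch of that fallback chain, with Option-based stand-ins (Model and TokenLimits here are simplified for illustration, not the SDK types):

// Stand-ins for the SDK accessors used above, simplified for illustration.
const DEFAULT_CONTEXT_WINDOW_LENGTH: usize = 200_000;

struct TokenLimits { max_input_tokens: Option<i32> }
struct Model { model_id: String, token_limits: Option<TokenLimits> }

// Resolve the context window: find the model by id, read its max input
// tokens, and fall back to the default at every missing step.
fn context_window_tokens(model_id: Option<&str>, models: &[Model]) -> usize {
    let Some(model_id) = model_id else {
        return DEFAULT_CONTEXT_WINDOW_LENGTH;
    };
    models
        .iter()
        .find(|m| m.model_id == model_id)
        .and_then(|m| m.token_limits.as_ref())
        .and_then(|l| l.max_input_tokens)
        .map(|t| t as usize)
        .unwrap_or(DEFAULT_CONTEXT_WINDOW_LENGTH)
}

fn main() {
    let models = vec![Model {
        model_id: "OPENAI_GPT_OSS_120B_1_0".into(),
        token_limits: Some(TokenLimits { max_input_tokens: Some(128_000) }),
    }];
    assert_eq!(context_window_tokens(Some("OPENAI_GPT_OSS_120B_1_0"), &models), 128_000);
    assert_eq!(context_window_tokens(Some("unknown-model"), &models), 200_000);
    assert_eq!(context_window_tokens(None, &models), 200_000);
}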

crates/chat-cli/src/cli/chat/cli/usage.rs

Lines changed: 1 addition & 2 deletions

@@ -62,8 +62,7 @@ impl UsageArgs {
         // set a max width for the progress bar for better aesthetic
         let progress_bar_width = std::cmp::min(window_width, 80);
 
-        let context_window_size = context_window_tokens(session.conversation.model.as_deref());
-
+        let context_window_size = context_window_tokens(session.conversation.model.as_deref(), os).await;
         let context_width =
             ((context_token_count.value() as f64 / context_window_size as f64) * progress_bar_width as f64) as usize;
         let assistant_width =
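
Note: only the window-size lookup became async here; the bar arithmetic is unchanged. A worked example with illustrative numbers (not taken from the commit): 30,000 context tokens against a 200,000-token window fill 15% of an 80-column bar, i.e. 12 columns.

// Worked example of the usage-bar arithmetic above; token counts are
// illustrative, not from the commit.
fn main() {
    let window_width = 120;
    let progress_bar_width = std::cmp::min(window_width, 80); // capped for aesthetics
    let context_token_count = 30_000_u64;
    let context_window_size = 200_000_usize; // e.g. claude-4-sonnet window
    let context_width =
        ((context_token_count as f64 / context_window_size as f64) * progress_bar_width as f64) as usize;
    assert_eq!(context_width, 12); // 15% of an 80-column bar
}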

crates/chat-cli/src/cli/chat/context.rs

Lines changed: 7 additions & 5 deletions

@@ -255,9 +255,9 @@ impl ContextManager {
 }
 
 /// Calculates the maximum context files size to use for the given model id.
-pub fn calc_max_context_files_size(model_id: Option<&str>) -> usize {
+pub async fn calc_max_context_files_size(model_id: Option<&str>, os: &Os) -> usize {
     // Sets the max as 75% of the context window
-    context_window_tokens(model_id).saturating_mul(3) / 4
+    context_window_tokens(model_id, os).await.saturating_mul(3) / 4
 }
 
 /// Process a path, handling glob patterns and file types.
@@ -432,11 +432,13 @@ mod tests {
     }
 
     #[test]
-    fn test_calc_max_context_files_size() {
+    async fn test_calc_max_context_files_size() {
+        let os = Os::new().await.unwrap();
+
         assert_eq!(
-            calc_max_context_files_size(Some("CLAUDE_SONNET_4_20250514_V1_0")),
+            calc_max_context_files_size(Some("CLAUDE_SONNET_4_20250514_V1_0"), os),
             150_000
         );
-        assert_eq!(calc_max_context_files_size(Some("OPENAI_GPT_OSS_120B_1_0")), 96_000);
+        assert_eq!(calc_max_context_files_size(Some("OPENAI_GPT_OSS_120B_1_0"), os), 96_000);
     }
 }
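
Note: the test expectations follow directly from the 75% rule in calc_max_context_files_size: 200_000 × 3 / 4 = 150_000 for claude-4-sonnet and 128_000 × 3 / 4 = 96_000 for the GPT-OSS preview. A self-contained check:

// The 75% rule behind the test expectations above.
fn calc_max_context_files_size(context_window_tokens: usize) -> usize {
    context_window_tokens.saturating_mul(3) / 4
}

fn main() {
    assert_eq!(calc_max_context_files_size(200_000), 150_000); // claude-4-sonnet
    assert_eq!(calc_max_context_files_size(128_000), 96_000); // openai-gpt-oss-120b-preview
}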

crates/chat-cli/src/cli/chat/conversation.rs

Lines changed: 19 additions & 4 deletions

@@ -118,9 +118,14 @@ impl ConversationState {
         tool_config: HashMap<String, ToolSpec>,
         tool_manager: ToolManager,
         current_model_id: Option<String>,
+        os: &Os,
     ) -> Self {
         let context_manager = if let Some(agent) = agents.get_active() {
-            ContextManager::from_agent(agent, calc_max_context_files_size(current_model_id.as_deref())).ok()
+            ContextManager::from_agent(
+                agent,
+                calc_max_context_files_size(current_model_id.as_deref(), os).await,
+            )
+            .ok()
         } else {
             None
         };
@@ -638,7 +643,7 @@ impl ConversationState {
     /// Get the current token warning level
     pub async fn get_token_warning_level(&mut self, os: &Os) -> Result<TokenWarningLevel, ChatError> {
         let total_chars = self.calculate_char_count(os).await?;
-        let max_chars = TokenCounter::token_to_chars(context_window_tokens(self.model.as_deref()));
+        let max_chars = TokenCounter::token_to_chars(context_window_tokens(self.model.as_deref(), os).await);
 
         Ok(if *total_chars >= max_chars {
             TokenWarningLevel::Critical
@@ -1061,6 +1066,7 @@ mod tests {
             tool_manager.load_tools(&mut os, &mut output).await.unwrap(),
             tool_manager,
             None,
+            os,
         )
         .await;
 
@@ -1092,6 +1098,7 @@ mod tests {
             tool_config.clone(),
             tool_manager.clone(),
             None,
+            os,
         )
         .await;
         conversation.set_next_user_message("start".to_string()).await;
@@ -1120,8 +1127,15 @@ mod tests {
         }
 
         // Build a long conversation history of user messages mixed in with tool results.
-        let mut conversation =
-            ConversationState::new("fake_conv_id", agents, tool_config.clone(), tool_manager.clone(), None).await;
+        let mut conversation = ConversationState::new(
+            "fake_conv_id",
+            agents,
+            tool_config.clone(),
+            tool_manager.clone(),
+            None,
+            os,
+        )
+        .await;
         conversation.set_next_user_message("start".to_string()).await;
         for i in 0..=(MAX_CONVERSATION_STATE_HISTORY_LEN + 100) {
             let s = conversation
@@ -1173,6 +1187,7 @@ mod tests {
             tool_manager.load_tools(&mut os, &mut output).await.unwrap(),
             tool_manager,
            None,
+            os,
         )
         .await;
 
crates/chat-cli/src/cli/chat/mod.rs

Lines changed: 6 additions & 5 deletions

@@ -42,7 +42,7 @@ use clap::{
 };
 use cli::compact::CompactStrategy;
 use cli::model::{
-    get_model_options,
+    get_available_models,
     select_model,
 };
 pub use conversation::ConversationState;
@@ -629,7 +629,7 @@ impl ChatSession {
                 cs.enforce_tool_use_history_invariants();
                 cs
             },
-            false => ConversationState::new(conversation_id, agents, tool_config, tool_manager, model_id).await,
+            false => ConversationState::new(conversation_id, agents, tool_config, tool_manager, model_id, os).await,
         };
 
         // Spawn a task for listening and broadcasting sigints.
@@ -1185,12 +1185,13 @@ impl ChatSession {
         self.stderr.flush()?;
 
         if let Some(ref id) = self.conversation.model {
-            let model_options = get_model_options(os).await?;
-            if let Some(model_option) = model_options.iter().find(|option| option.model_id == *id) {
+            let (models, _default_model) = get_available_models(os).await?;
+            if let Some(model_option) = models.iter().find(|option| option.model_id == *id) {
+                let display_name = model_option.model_name().unwrap_or_else(|| &model_option.model_id);
                 execute!(
                     self.stderr,
                     style::SetForegroundColor(Color::Cyan),
-                    style::Print(format!("🤖 You are chatting with {}\n", model_option.name)),
+                    style::Print(format!("🤖 You are chatting with {}\n", display_name)),
                     style::SetForegroundColor(Color::Reset),
                     style::Print("\n")
                 )?;
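
Note: both the model selector and this chat banner now prefer the service-supplied display name and fall back to the raw model id when no name is present. A minimal sketch of that fallback (the display_name helper is illustrative, not from the commit):

// Display-name fallback used by the selector and the chat banner: prefer
// the human-readable model_name, else show the raw model_id.
fn display_name<'a>(model_name: Option<&'a str>, model_id: &'a str) -> &'a str {
    model_name.unwrap_or(model_id)
}

fn main() {
    assert_eq!(
        display_name(Some("openai-gpt-oss-120b-preview"), "OPENAI_GPT_OSS_120B_1_0"),
        "openai-gpt-oss-120b-preview"
    );
    assert_eq!(
        display_name(None, "CLAUDE_3_7_SONNET_20250219_V1_0"),
        "CLAUDE_3_7_SONNET_20250219_V1_0"
    );
}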
