Commit cb27bf9

change modelid in conversationstate into model info

1 parent 33c2537 · commit cb27bf9

File tree

6 files changed: +196 −68 lines changed


crates/chat-cli/src/cli/chat/cli/context.rs

Lines changed: 1 addition & 2 deletions

@@ -222,8 +222,7 @@ impl ContextSubcommand {
            execute!(session.stderr, style::Print(format!("{}\n\n", "▔".repeat(3))),)?;
        }

-        let context_files_max_size =
-            calc_max_context_files_size(session.conversation.model.as_deref(), os).await;
+        let context_files_max_size = calc_max_context_files_size(session.conversation.model.as_ref());
        let mut files_as_vec = profile_context_files
            .iter()
            .map(|(path, content, _)| (path.clone(), content.clone()))

crates/chat-cli/src/cli/chat/cli/model.rs

Lines changed: 138 additions & 34 deletions

@@ -9,6 +9,11 @@ use crossterm::{
     queue,
 };
 use dialoguer::Select;
+use serde::{
+    Deserialize,
+    Deserializer,
+    Serialize,
+};

 use crate::api_client::Endpoint;
 use crate::cli::chat::{
@@ -17,6 +22,110 @@ use crate::cli::chat::{
     ChatState,
 };
 use crate::os::Os;
+
+#[derive(Debug, Clone, Serialize)]
+pub struct ModelInfo {
+    /// Display name
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub model_name: Option<String>,
+    /// Actual model id to send in the API
+    pub model_id: String,
+    /// Size of the model's context window, in tokens
+    #[serde(default = "default_context_window")]
+    pub context_window_tokens: usize,
+}
+
+impl ModelInfo {
+    pub fn from_api_model(model: &Model) -> Self {
+        let context_window_tokens = model
+            .token_limits()
+            .and_then(|limits| limits.max_input_tokens())
+            .map_or(default_context_window(), |tokens| tokens as usize);
+        Self {
+            model_id: model.model_id().to_string(),
+            model_name: model.model_name().map(|s| s.to_string()),
+            context_window_tokens,
+        }
+    }
+
+    /// Create a default model with only a model_id (compatible with old stored model data)
+    pub fn from_id(model_id: String) -> Self {
+        Self {
+            model_id,
+            model_name: None,
+            context_window_tokens: 200_000,
+        }
+    }
+
+    pub fn display_name(&self) -> &str {
+        self.model_name.as_deref().unwrap_or(&self.model_id)
+    }
+}
+impl<'de> Deserialize<'de> for ModelInfo {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        use std::fmt;
+
+        use serde::de::{
+            self,
+            MapAccess,
+            Visitor,
+        };
+
+        struct ModelInfoVisitor;
+
+        impl<'de> Visitor<'de> for ModelInfoVisitor {
+            type Value = ModelInfo;
+
+            fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
+                formatter.write_str("a string or a ModelInfo object")
+            }
+
+            // Old format: a bare model_id string
+            fn visit_str<E>(self, value: &str) -> Result<ModelInfo, E>
+            where
+                E: de::Error,
+            {
+                Ok(ModelInfo {
+                    model_id: value.to_string(),
+                    model_name: None,
+                    context_window_tokens: default_context_window(),
+                })
+            }
+
+            // New format: a full ModelInfo object
+            fn visit_map<M>(self, mut map: M) -> Result<ModelInfo, M::Error>
+            where
+                M: MapAccess<'de>,
+            {
+                let mut model_id = None;
+                let mut model_name = None;
+                let mut context_window_tokens = None;
+
+                while let Some(key) = map.next_key::<String>()? {
+                    match key.as_str() {
+                        "model_id" => model_id = Some(map.next_value()?),
+                        "model_name" => model_name = map.next_value()?,
+                        "context_window_tokens" => context_window_tokens = Some(map.next_value()?),
+                        _ => {
+                            let _: serde::de::IgnoredAny = map.next_value()?;
+                        },
+                    }
+                }
+
+                Ok(ModelInfo {
+                    model_id: model_id.ok_or_else(|| de::Error::missing_field("model_id"))?,
+                    model_name,
+                    context_window_tokens: context_window_tokens.unwrap_or_else(default_context_window),
+                })
+            }
+        }
+
+        deserializer.deserialize_any(ModelInfoVisitor)
+    }
+}
 #[deny(missing_docs)]
 #[derive(Debug, PartialEq, Args)]
 pub struct ModelArgs;
@@ -45,14 +154,13 @@ pub async fn select_model(os: &Os, session: &mut ChatSession) -> Result<Option<C
         return Ok(None);
     }

-    let active_model_id = session.conversation.model.as_deref();
+    let active_model_id = session.conversation.model.as_ref().map(|m| m.model_id.as_str());

     let labels: Vec<String> = models
         .iter()
         .map(|model| {
-            let display_name = model.model_name().unwrap_or(model.model_id());
-
-            if Some(model.model_id()) == active_model_id {
+            let display_name = model.display_name();
+            if Some(model.model_id.as_str()) == active_model_id {
                 format!("{} (active)", display_name)
             } else {
                 display_name.to_owned()
@@ -81,10 +189,9 @@ pub async fn select_model(os: &Os, session: &mut ChatSession) -> Result<Option<C
     queue!(session.stderr, style::ResetColor)?;

     if let Some(index) = selection {
-        let selected = &models[index];
-        let model_id_str = selected.model_id.clone();
-        session.conversation.model = Some(model_id_str.clone());
-        let display_name = selected.model_name().unwrap_or(selected.model_id());
+        let selected = models[index].clone();
+        session.conversation.model = Some(selected.clone());
+        let display_name = selected.display_name();

         queue!(
             session.stderr,
@@ -103,41 +210,38 @@ pub async fn select_model(os: &Os, session: &mut ChatSession) -> Result<Option<C
     }))
 }

+pub async fn get_model_info(model_id: &str, os: &Os) -> Result<ModelInfo, ChatError> {
+    let (models, _) = get_available_models(os).await?;
+
+    models
+        .into_iter()
+        .find(|m| m.model_id == model_id)
+        .ok_or_else(|| ChatError::Custom(format!("Model '{}' not found", model_id).into()))
+}
+
 /// Get available models with caching support
-pub async fn get_available_models(os: &Os) -> Result<(Vec<Model>, Model), ChatError> {
+pub async fn get_available_models(os: &Os) -> Result<(Vec<ModelInfo>, ModelInfo), ChatError> {
     let endpoint = Endpoint::configured_value(&os.database);
     let region = endpoint.region().as_ref();

-    os.client
+    let (api_models, api_default) = os
+        .client
         .get_available_models(region)
         .await
-        .map_err(|e| ChatError::Custom(format!("Failed to fetch available models: {}", e).into()))
+        .map_err(|e| ChatError::Custom(format!("Failed to fetch available models: {}", e).into()))?;
+
+    let models: Vec<ModelInfo> = api_models.iter().map(ModelInfo::from_api_model).collect();
+    let default_model = ModelInfo::from_api_model(&api_default);
+
+    Ok((models, default_model))
 }

 /// Returns the context window length in tokens for the given model_id.
 /// Uses cached model data when available
-pub async fn context_window_tokens(model_id: Option<&str>, os: &Os) -> usize {
-    const DEFAULT_CONTEXT_WINDOW_LENGTH: usize = 200_000;
-
-    // If no model_id provided, return default
-    let Some(model_id) = model_id else {
-        return DEFAULT_CONTEXT_WINDOW_LENGTH;
-    };
-
-    // Try to get from cached models first
-    let (models, _) = match get_available_models(os).await {
-        Ok(models) => models,
-        Err(_) => {
-            // If we can't get models, return default
-            return DEFAULT_CONTEXT_WINDOW_LENGTH;
-        },
-    };
+pub fn context_window_tokens(model_info: Option<&ModelInfo>) -> usize {
+    model_info.map(|m| m.context_window_tokens).unwrap_or(200_000)
+}

-    models
-        .iter()
-        .find(|m| m.model_id() == model_id)
-        .and_then(|m| m.token_limits())
-        .and_then(|limits| limits.max_input_tokens())
-        .map(|tokens| tokens as usize)
-        .unwrap_or(DEFAULT_CONTEXT_WINDOW_LENGTH)
+fn default_context_window() -> usize {
+    200_000
 }
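
The hand-written Deserialize impl above is what keeps previously stored conversation state loadable: a bare model-id string (the old format) and a full object (the new format) both deserialize into ModelInfo. A minimal sketch of that behavior, assuming ModelInfo and its impls are in scope and serde_json is available as a dev-dependency; the test name and values are illustrative, not part of the commit:

#[cfg(test)]
mod model_info_compat_sketch {
    use super::*;

    #[test]
    fn accepts_both_stored_formats() {
        // Old format: conversation state stored the model as a bare id string.
        let legacy: ModelInfo = serde_json::from_str(r#""CLAUDE_SONNET_4_20250514_V1_0""#).unwrap();
        assert_eq!(legacy.model_id, "CLAUDE_SONNET_4_20250514_V1_0");
        // visit_str falls back to default_context_window() for the window size.
        assert_eq!(legacy.context_window_tokens, 200_000);

        // New format: the model is stored as a full ModelInfo object.
        let json = r#"{"model_id":"OPENAI_GPT_OSS_120B_1_0","model_name":"GPT","context_window_tokens":128000}"#;
        let new: ModelInfo = serde_json::from_str(json).unwrap();
        assert_eq!(new.display_name(), "GPT");
        assert_eq!(new.context_window_tokens, 128_000);
    }
}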

crates/chat-cli/src/cli/chat/cli/usage.rs

Lines changed: 1 addition & 1 deletion

@@ -62,7 +62,7 @@ impl UsageArgs {
        // set a max width for the progress bar for better aesthetic
        let progress_bar_width = std::cmp::min(window_width, 80);

-        let context_window_size = context_window_tokens(session.conversation.model.as_deref(), os).await;
+        let context_window_size = context_window_tokens(session.conversation.model.as_ref());
        let context_width =
            ((context_token_count.value() as f64 / context_window_size as f64) * progress_bar_width as f64) as usize;
        let assistant_width =
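
For a sense of the bar arithmetic: with 50_000 context tokens in a 200_000-token window and an 80-column bar, context_width comes out to (50_000 / 200_000) × 80 = 20 columns.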

crates/chat-cli/src/cli/chat/context.rs

Lines changed: 17 additions & 7 deletions

@@ -23,6 +23,7 @@ use crate::cli::agent::hook::{
 };
 use crate::cli::chat::ChatError;
 use crate::cli::chat::cli::hooks::HookExecutor;
+use crate::cli::chat::cli::model::ModelInfo;
 use crate::os::Os;

 #[derive(Debug, Clone)]
@@ -255,9 +256,9 @@ impl ContextManager {
 }

 /// Calculates the maximum context files size to use for the given model id.
-pub async fn calc_max_context_files_size(model_id: Option<&str>, os: &Os) -> usize {
+pub fn calc_max_context_files_size(model: Option<&ModelInfo>) -> usize {
     // Sets the max as 75% of the context window
-    context_window_tokens(model_id, os).await.saturating_mul(3) / 4
+    context_window_tokens(model).saturating_mul(3) / 4
 }

 /// Process a path, handling glob patterns and file types.
@@ -432,13 +433,22 @@ mod tests {
     }

     #[test]
-    async fn test_calc_max_context_files_size() {
-        let os = Os::new().await.unwrap();
-
+    fn test_calc_max_context_files_size() {
         assert_eq!(
-            calc_max_context_files_size(Some("CLAUDE_SONNET_4_20250514_V1_0"), os),
+            calc_max_context_files_size(Some(&ModelInfo {
+                model_id: "CLAUDE_SONNET_4_20250514_V1_0".to_string(),
+                model_name: Some("Claude".to_string()),
+                context_window_tokens: 200_000,
+            })),
             150_000
         );
-        assert_eq!(calc_max_context_files_size(Some("OPENAI_GPT_OSS_120B_1_0"), os), 96_000);
+        assert_eq!(
+            calc_max_context_files_size(Some(&ModelInfo {
+                model_id: "OPENAI_GPT_OSS_120B_1_0".to_string(),
+                model_name: Some("GPT".to_string()),
+                context_window_tokens: 128_000,
+            })),
+            96_000
+        );
     }
 }
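
The updated test expectations follow directly from the 75% rule in calc_max_context_files_size: 200_000 × 3 / 4 = 150_000 and 128_000 × 3 / 4 = 96_000.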
