@@ -9,6 +9,11 @@ use crossterm::{
99 queue,
1010} ;
1111use dialoguer:: Select ;
12+ use serde:: {
13+ Deserialize ,
14+ Deserializer ,
15+ Serialize ,
16+ } ;
1217
1318use crate :: api_client:: Endpoint ;
1419use crate :: cli:: chat:: {
@@ -17,6 +22,110 @@ use crate::cli::chat::{
1722 ChatState ,
1823} ;
1924use crate :: os:: Os ;
25+
26+ #[ derive( Debug , Clone , Serialize ) ]
27+ pub struct ModelInfo {
28+ /// Display name
29+ #[ serde( skip_serializing_if = "Option::is_none" ) ]
30+ pub model_name : Option < String > ,
31+ /// Actual model id to send in the API
32+ pub model_id : String ,
33+ /// Size of the model's context window, in tokens
34+ #[ serde( default = "default_context_window" ) ]
35+ pub context_window_tokens : usize ,
36+ }
37+
38+ impl ModelInfo {
39+ pub fn from_api_model ( model : & Model ) -> Self {
40+ let context_window_tokens = model
41+ . token_limits ( )
42+ . and_then ( |limits| limits. max_input_tokens ( ) )
43+ . map_or ( default_context_window ( ) , |tokens| tokens as usize ) ;
44+ Self {
45+ model_id : model. model_id ( ) . to_string ( ) ,
46+ model_name : model. model_name ( ) . map ( |s| s. to_string ( ) ) ,
47+ context_window_tokens,
48+ }
49+ }
50+
51+ /// create a defualt model with only model_id(be compatoble with old stored model data)
52+ pub fn from_id ( model_id : String ) -> Self {
53+ Self {
54+ model_id,
55+ model_name : None ,
56+ context_window_tokens : 200_000 ,
57+ }
58+ }
59+
60+ pub fn display_name ( & self ) -> & str {
61+ self . model_name . as_deref ( ) . unwrap_or ( & self . model_id )
62+ }
63+ }
64+ impl < ' de > Deserialize < ' de > for ModelInfo {
65+ fn deserialize < D > ( deserializer : D ) -> Result < Self , D :: Error >
66+ where
67+ D : Deserializer < ' de > ,
68+ {
69+ use std:: fmt;
70+
71+ use serde:: de:: {
72+ self ,
73+ MapAccess ,
74+ Visitor ,
75+ } ;
76+
77+ struct ModelInfoVisitor ;
78+
79+ impl < ' de > Visitor < ' de > for ModelInfoVisitor {
80+ type Value = ModelInfo ;
81+
82+ fn expecting ( & self , formatter : & mut fmt:: Formatter < ' _ > ) -> fmt:: Result {
83+ formatter. write_str ( "a string or a ModelInfo object" )
84+ }
85+
86+ // old version: modelid string
87+ fn visit_str < E > ( self , value : & str ) -> Result < ModelInfo , E >
88+ where
89+ E : de:: Error ,
90+ {
91+ Ok ( ModelInfo {
92+ model_id : value. to_string ( ) ,
93+ model_name : None ,
94+ context_window_tokens : default_context_window ( ) ,
95+ } )
96+ }
97+
98+ // new version: modelInfo object
99+ fn visit_map < M > ( self , mut map : M ) -> Result < ModelInfo , M :: Error >
100+ where
101+ M : MapAccess < ' de > ,
102+ {
103+ let mut model_id = None ;
104+ let mut model_name = None ;
105+ let mut context_window_tokens = None ;
106+
107+ while let Some ( key) = map. next_key :: < String > ( ) ? {
108+ match key. as_str ( ) {
109+ "model_id" => model_id = Some ( map. next_value ( ) ?) ,
110+ "model_name" => model_name = map. next_value ( ) ?,
111+ "context_window_tokens" => context_window_tokens = Some ( map. next_value ( ) ?) ,
112+ _ => {
113+ let _: serde:: de:: IgnoredAny = map. next_value ( ) ?;
114+ } ,
115+ }
116+ }
117+
118+ Ok ( ModelInfo {
119+ model_id : model_id. ok_or_else ( || de:: Error :: missing_field ( "model_id" ) ) ?,
120+ model_name,
121+ context_window_tokens : context_window_tokens. unwrap_or_else ( default_context_window) ,
122+ } )
123+ }
124+ }
125+
126+ deserializer. deserialize_any ( ModelInfoVisitor )
127+ }
128+ }
20129#[ deny( missing_docs) ]
21130#[ derive( Debug , PartialEq , Args ) ]
22131pub struct ModelArgs ;
@@ -45,14 +154,13 @@ pub async fn select_model(os: &Os, session: &mut ChatSession) -> Result<Option<C
45154 return Ok ( None ) ;
46155 }
47156
48- let active_model_id = session. conversation . model . as_deref ( ) ;
157+ let active_model_id = session. conversation . model . as_ref ( ) . map ( |m| m . model_id . as_str ( ) ) ;
49158
50159 let labels: Vec < String > = models
51160 . iter ( )
52161 . map ( |model| {
53- let display_name = model. model_name ( ) . unwrap_or ( model. model_id ( ) ) ;
54-
55- if Some ( model. model_id ( ) ) == active_model_id {
162+ let display_name = model. display_name ( ) ;
163+ if Some ( model. model_id . as_str ( ) ) == active_model_id {
56164 format ! ( "{} (active)" , display_name)
57165 } else {
58166 display_name. to_owned ( )
@@ -81,10 +189,9 @@ pub async fn select_model(os: &Os, session: &mut ChatSession) -> Result<Option<C
81189 queue ! ( session. stderr, style:: ResetColor ) ?;
82190
83191 if let Some ( index) = selection {
84- let selected = & models[ index] ;
85- let model_id_str = selected. model_id . clone ( ) ;
86- session. conversation . model = Some ( model_id_str. clone ( ) ) ;
87- let display_name = selected. model_name ( ) . unwrap_or ( selected. model_id ( ) ) ;
192+ let selected = models[ index] . clone ( ) ;
193+ session. conversation . model = Some ( selected. clone ( ) ) ;
194+ let display_name = selected. display_name ( ) ;
88195
89196 queue ! (
90197 session. stderr,
@@ -103,41 +210,38 @@ pub async fn select_model(os: &Os, session: &mut ChatSession) -> Result<Option<C
103210 } ) )
104211}
105212
213+ pub async fn get_model_info ( model_id : & str , os : & Os ) -> Result < ModelInfo , ChatError > {
214+ let ( models, _) = get_available_models ( os) . await ?;
215+
216+ models
217+ . into_iter ( )
218+ . find ( |m| m. model_id == model_id)
219+ . ok_or_else ( || ChatError :: Custom ( format ! ( "Model '{}' not found" , model_id) . into ( ) ) )
220+ }
221+
106222/// Get available models with caching support
107- pub async fn get_available_models ( os : & Os ) -> Result < ( Vec < Model > , Model ) , ChatError > {
223+ pub async fn get_available_models ( os : & Os ) -> Result < ( Vec < ModelInfo > , ModelInfo ) , ChatError > {
108224 let endpoint = Endpoint :: configured_value ( & os. database ) ;
109225 let region = endpoint. region ( ) . as_ref ( ) ;
110226
111- os. client
227+ let ( api_models, api_default) = os
228+ . client
112229 . get_available_models ( region)
113230 . await
114- . map_err ( |e| ChatError :: Custom ( format ! ( "Failed to fetch available models: {}" , e) . into ( ) ) )
231+ . map_err ( |e| ChatError :: Custom ( format ! ( "Failed to fetch available models: {}" , e) . into ( ) ) ) ?;
232+
233+ let models: Vec < ModelInfo > = api_models. iter ( ) . map ( ModelInfo :: from_api_model) . collect ( ) ;
234+ let default_model = ModelInfo :: from_api_model ( & api_default) ;
235+
236+ Ok ( ( models, default_model) )
115237}
116238
117239/// Returns the context window length in tokens for the given model_id.
118240/// Uses cached model data when available
119- pub async fn context_window_tokens ( model_id : Option < & str > , os : & Os ) -> usize {
120- const DEFAULT_CONTEXT_WINDOW_LENGTH : usize = 200_000 ;
121-
122- // If no model_id provided, return default
123- let Some ( model_id) = model_id else {
124- return DEFAULT_CONTEXT_WINDOW_LENGTH ;
125- } ;
126-
127- // Try to get from cached models first
128- let ( models, _) = match get_available_models ( os) . await {
129- Ok ( models) => models,
130- Err ( _) => {
131- // If we can't get models, return default
132- return DEFAULT_CONTEXT_WINDOW_LENGTH ;
133- } ,
134- } ;
241+ pub fn context_window_tokens ( model_info : Option < & ModelInfo > ) -> usize {
242+ model_info. map ( |m| m. context_window_tokens ) . unwrap_or ( 200_000 )
243+ }
135244
136- models
137- . iter ( )
138- . find ( |m| m. model_id ( ) == model_id)
139- . and_then ( |m| m. token_limits ( ) )
140- . and_then ( |limits| limits. max_input_tokens ( ) )
141- . map ( |tokens| tokens as usize )
142- . unwrap_or ( DEFAULT_CONTEXT_WINDOW_LENGTH )
/// Context window size, in tokens, used when no model-specific value is known.
const DEFAULT_CONTEXT_WINDOW_TOKENS: usize = 200_000;

/// Returns the fallback context window size in tokens.
fn default_context_window() -> usize {
    DEFAULT_CONTEXT_WINDOW_TOKENS
}
0 commit comments