Skip to content

Commit 072840c

Browse files
committed
add modelid in cs strucutr and setting about default model
1 parent 68ab0ed commit 072840c

File tree

4 files changed

+30
-10
lines changed

4 files changed

+30
-10
lines changed

crates/chat-cli/src/api_client/clients/streaming_client.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ impl StreamingClient {
126126
})
127127
}
128128

129+
//todo yifan generate response
129130
pub async fn send_message(
130131
&self,
131132
conversation_state: ConversationState,

crates/chat-cli/src/cli/chat/conversation_state.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,9 @@ pub struct ConversationState {
105105
latest_summary: Option<String>,
106106
#[serde(skip)]
107107
pub updates: Option<SharedWriter>,
108+
/// Model explicitly selected by the user in this conversation state via `/model`. (`None` == auto)
109+
#[serde(skip)]
110+
pub current_model_id: Option<String>,
108111
}
109112

110113
impl ConversationState {
@@ -115,6 +118,7 @@ impl ConversationState {
115118
profile: Option<String>,
116119
updates: Option<SharedWriter>,
117120
tool_manager: ToolManager,
121+
curren_model_id: Option<String>,
118122
) -> Self {
119123
// Initialize context manager
120124
let context_manager = match ContextManager::new(ctx, None).await {
@@ -157,6 +161,7 @@ impl ConversationState {
157161
context_message_length: None,
158162
latest_summary: None,
159163
updates,
164+
current_model_id: curren_model_id,
160165
}
161166
}
162167

@@ -1059,6 +1064,7 @@ mod tests {
10591064
None,
10601065
None,
10611066
tool_manager,
1067+
None,
10621068
)
10631069
.await;
10641070

@@ -1089,6 +1095,7 @@ mod tests {
10891095
None,
10901096
None,
10911097
tool_manager.clone(),
1098+
None,
10921099
)
10931100
.await;
10941101
conversation_state.set_next_user_message("start".to_string()).await;
@@ -1120,6 +1127,7 @@ mod tests {
11201127
None,
11211128
None,
11221129
tool_manager.clone(),
1130+
None,
11231131
)
11241132
.await;
11251133
conversation_state.set_next_user_message("start".to_string()).await;
@@ -1165,6 +1173,7 @@ mod tests {
11651173
None,
11661174
None,
11671175
tool_manager,
1176+
None,
11681177
)
11691178
.await;
11701179

@@ -1235,6 +1244,7 @@ mod tests {
12351244
None,
12361245
Some(SharedWriter::stdout()),
12371246
tool_manager,
1247+
None,
12381248
)
12391249
.await;
12401250

crates/chat-cli/src/cli/chat/mod.rs

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -562,6 +562,7 @@ impl ChatContext {
562562
let output_clone = output.clone();
563563

564564
let mut existing_conversation = false;
565+
let model_id = Some(database.settings.get_string(Setting::UserDefaultModel).unwrap());
565566
let conversation_state = if resume_conversation {
566567
let prior = std::env::current_dir()
567568
.ok()
@@ -588,6 +589,7 @@ impl ChatContext {
588589
profile,
589590
Some(output_clone),
590591
tool_manager,
592+
model_id
591593
)
592594
.await
593595
}
@@ -599,6 +601,7 @@ impl ChatContext {
599601
profile,
600602
Some(output_clone),
601603
tool_manager,
604+
model_id
602605
)
603606
.await
604607
};
@@ -3014,21 +3017,14 @@ impl ChatContext {
30143017
queue!(self.output, style::Print("\n"))?;
30153018
let labels: Vec<&str> = MODEL_OPTIONS.iter().map(|(l, _)| *l).collect();
30163019
let selection: Option<_> = match Select::with_theme(&crate::util::dialoguer_theme())
3017-
.with_prompt("choose your model")
3020+
.with_prompt("Select the model you want to use for chat")
30183021
.items(&labels)
30193022
.default(0)
30203023
.interact_on_opt(&dialoguer::console::Term::stdout())
30213024
{
30223025
Ok(sel) => sel,
30233026
// Ctrl‑C -> Err(Interrupted)
3024-
Err(DError::IO(ref e)) if e.kind() == io::ErrorKind::Interrupted => {
3025-
queue!(
3026-
self.output,
3027-
style::Print("\n"),
3028-
style::Print("⚠️ User cancelled selection\n\n")
3029-
)?;
3030-
None
3031-
},
3027+
Err(DError::IO(ref e)) if e.kind() == io::ErrorKind::Interrupted => None,
30323028
Err(e) => return Err(ChatError::Custom(format!("Failed to choose model: {e}").into())),
30333029
};
30343030

@@ -3039,7 +3035,7 @@ impl ChatContext {
30393035
queue,
30403036
style,
30413037
};
3042-
queue!(self.output, style::Print(format!("\n✅ change to : {}\n\n", label)))?;
3038+
queue!(self.output, style::Print("\n"), style::Print(format!(" Swtiched model to {}\n\n", label)))?;
30433039
}
30443040

30453041
ChatState::PromptUser {
@@ -3729,6 +3725,8 @@ impl ChatContext {
37293725

37303726
Ok(())
37313727
}
3728+
3729+
37323730
}
37333731

37343732
/// Prints hook configuration grouped by trigger: conversation session start or per user message

crates/chat-cli/src/database/settings.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ pub enum Setting {
3131
McpInitTimeout,
3232
McpNoInteractiveTimeout,
3333
McpLoadedBefore,
34+
UserDefaultModel,
3435
}
3536

3637
impl AsRef<str> for Setting {
@@ -50,6 +51,7 @@ impl AsRef<str> for Setting {
5051
Self::McpInitTimeout => "mcp.initTimeout",
5152
Self::McpNoInteractiveTimeout => "mcp.noInteractiveTimeout",
5253
Self::McpLoadedBefore => "mcp.loadedBefore",
54+
Self::UserDefaultModel => "chat.userDefaultModel",
5355
}
5456
}
5557
}
@@ -79,6 +81,7 @@ impl TryFrom<&str> for Setting {
7981
"mcp.initTimeout" => Ok(Self::McpInitTimeout),
8082
"mcp.noInteractiveTimeout" => Ok(Self::McpNoInteractiveTimeout),
8183
"mcp.loadedBefore" => Ok(Self::McpLoadedBefore),
84+
"chat.userDefaultModel" => Ok(Self::UserDefaultModel),
8285
_ => Err(DatabaseError::InvalidSetting(value.to_string())),
8386
}
8487
}
@@ -197,11 +200,15 @@ mod test {
197200
assert_eq!(settings.get(Setting::OldClientId), None);
198201
assert_eq!(settings.get(Setting::ShareCodeWhispererContent), None);
199202
assert_eq!(settings.get(Setting::McpLoadedBefore), None);
203+
assert_eq!(settings.get(Setting::UserDefaultModel), None);
204+
200205

201206
settings.set(Setting::TelemetryEnabled, true).await.unwrap();
202207
settings.set(Setting::OldClientId, "test").await.unwrap();
203208
settings.set(Setting::ShareCodeWhispererContent, false).await.unwrap();
204209
settings.set(Setting::McpLoadedBefore, true).await.unwrap();
210+
settings.set(Setting::UserDefaultModel, "model 1").await.unwrap();
211+
205212

206213
assert_eq!(settings.get(Setting::TelemetryEnabled), Some(&Value::Bool(true)));
207214
assert_eq!(
@@ -213,6 +220,10 @@ mod test {
213220
Some(&Value::Bool(false))
214221
);
215222
assert_eq!(settings.get(Setting::McpLoadedBefore), Some(&Value::Bool(true)));
223+
assert_eq!(
224+
settings.get(Setting::UserDefaultModel),
225+
Some(&Value::String("model 1".to_string()))
226+
);
216227

217228
settings.remove(Setting::TelemetryEnabled).await.unwrap();
218229
settings.remove(Setting::OldClientId).await.unwrap();

0 commit comments

Comments
 (0)