Skip to content

Commit 9de421a

Browse files
committed
adds copy change to server loading task
1 parent 2391800 commit 9de421a

File tree

4 files changed

+115
-39
lines changed

4 files changed

+115
-39
lines changed

crates/chat-cli/src/cli/chat/conversation_state.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,7 @@ impl ConversationState {
322322
let tool_name = tool_use.name.as_str();
323323
if !tool_name_list.contains(&tool_name) {
324324
tool_use.name = DUMMY_TOOL_NAME.to_string();
325+
tool_use.args = serde_json::json!({});
325326
}
326327
})
327328
.collect::<Vec<_>>();

crates/chat-cli/src/cli/chat/mod.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,10 @@ use rand::distr::{
8686
SampleString,
8787
};
8888
use tokio::signal::ctrl_c;
89-
use util::shared_writer::SharedWriter;
89+
use util::shared_writer::{
90+
NullWriter,
91+
SharedWriter,
92+
};
9093
use util::ui::draw_box;
9194

9295
use crate::api_client::StreamingClient;
@@ -396,7 +399,12 @@ pub async fn chat(
396399
.prompt_list_sender(prompt_response_sender)
397400
.prompt_list_receiver(prompt_request_receiver)
398401
.conversation_id(&conversation_id)
399-
.build()
402+
.interactive(interactive)
403+
.build(if interactive {
404+
Box::new(output.clone())
405+
} else {
406+
Box::new(NullWriter {})
407+
})
400408
.await?;
401409
let tool_config = tool_manager.load_tools().await?;
402410
let mut tool_permissions = ToolPermissions::new(tool_config.len());

crates/chat-cli/src/cli/chat/tool_manager.rs

Lines changed: 95 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ pub struct ToolManagerBuilder {
205205
prompt_list_sender: Option<std::sync::mpsc::Sender<Vec<String>>>,
206206
prompt_list_receiver: Option<std::sync::mpsc::Receiver<Option<String>>>,
207207
conversation_id: Option<String>,
208+
is_interactive: bool,
208209
}
209210

210211
impl ToolManagerBuilder {
@@ -228,12 +229,19 @@ impl ToolManagerBuilder {
228229
self
229230
}
230231

231-
pub async fn build(mut self) -> eyre::Result<ToolManager> {
232+
#[allow(dead_code)]
233+
pub fn interactive(mut self, is_interactive: bool) -> Self {
234+
self.is_interactive = is_interactive;
235+
self
236+
}
237+
238+
pub async fn build(mut self, mut output: Box<dyn Write + Send + Sync + 'static>) -> eyre::Result<ToolManager> {
232239
let McpServerConfig { mcp_servers } = self.mcp_server_config.ok_or(eyre::eyre!("Missing mcp server config"))?;
233240
debug_assert!(self.conversation_id.is_some());
234241
let conversation_id = self.conversation_id.ok_or(eyre::eyre!("Missing conversation id"))?;
235242
let regex = regex::Regex::new(VALID_TOOL_NAME)?;
236243
let mut hasher = DefaultHasher::new();
244+
let is_interactive = self.is_interactive;
237245
let pre_initialized = mcp_servers
238246
.into_iter()
239247
.map(|(server_name, server_config)| {
@@ -246,11 +254,7 @@ impl ToolManagerBuilder {
246254

247255
// Send up task to update user on server loading status
248256
let (tx, rx) = std::sync::mpsc::channel::<LoadingMsg>();
249-
// Using a hand rolled thread because it's just easier to do this than do deal with the Send
250-
// requirements that comes with holding onto the stdout lock.
251257
let loading_display_task = tokio::task::spawn_blocking(move || {
252-
let stdout = std::io::stdout();
253-
let mut stdout_lock = stdout.lock();
254258
let mut loading_servers = HashMap::<String, StatusLine>::new();
255259
let mut spinner_logo_idx: usize = 0;
256260
let mut complete: usize = 0;
@@ -262,16 +266,16 @@ impl ToolManagerBuilder {
262266
let init_time = std::time::Instant::now();
263267
let is_done = false;
264268
let status_line = StatusLine { init_time, is_done };
265-
execute!(stdout_lock, cursor::MoveToColumn(0))?;
269+
execute!(output, cursor::MoveToColumn(0))?;
266270
if !loading_servers.is_empty() {
267271
// TODO: account for terminal width
268-
execute!(stdout_lock, cursor::MoveUp(1))?;
272+
execute!(output, cursor::MoveUp(1))?;
269273
}
270274
loading_servers.insert(name.clone(), status_line);
271275
let total = loading_servers.len();
272-
execute!(stdout_lock, terminal::Clear(terminal::ClearType::CurrentLine))?;
273-
queue_init_message(spinner_logo_idx, complete, failed, total, &mut stdout_lock)?;
274-
stdout_lock.flush()?;
276+
execute!(output, terminal::Clear(terminal::ClearType::CurrentLine))?;
277+
queue_init_message(spinner_logo_idx, complete, failed, total, is_interactive, &mut output)?;
278+
output.flush()?;
275279
},
276280
LoadingMsg::Done(name) => {
277281
if let Some(status_line) = loading_servers.get_mut(&name) {
@@ -281,15 +285,22 @@ impl ToolManagerBuilder {
281285
(std::time::Instant::now() - status_line.init_time).as_secs_f64().abs();
282286
let time_taken = format!("{:.2}", time_taken);
283287
execute!(
284-
stdout_lock,
288+
output,
285289
cursor::MoveToColumn(0),
286290
cursor::MoveUp(1),
287291
terminal::Clear(terminal::ClearType::CurrentLine),
288292
)?;
289-
queue_success_message(&name, &time_taken, &mut stdout_lock)?;
293+
queue_success_message(&name, &time_taken, &mut output)?;
290294
let total = loading_servers.len();
291-
queue_init_message(spinner_logo_idx, complete, failed, total, &mut stdout_lock)?;
292-
stdout_lock.flush()?;
295+
queue_init_message(
296+
spinner_logo_idx,
297+
complete,
298+
failed,
299+
total,
300+
is_interactive,
301+
&mut output,
302+
)?;
303+
output.flush()?;
293304
}
294305
if loading_servers.iter().all(|(_, status)| status.is_done) {
295306
break;
@@ -300,14 +311,21 @@ impl ToolManagerBuilder {
300311
status_line.is_done = true;
301312
failed += 1;
302313
execute!(
303-
stdout_lock,
314+
output,
304315
cursor::MoveToColumn(0),
305316
cursor::MoveUp(1),
306317
terminal::Clear(terminal::ClearType::CurrentLine),
307318
)?;
308-
queue_failure_message(&name, &msg, &mut stdout_lock)?;
319+
queue_failure_message(&name, &msg, &mut output)?;
309320
let total = loading_servers.len();
310-
queue_init_message(spinner_logo_idx, complete, failed, total, &mut stdout_lock)?;
321+
queue_init_message(
322+
spinner_logo_idx,
323+
complete,
324+
failed,
325+
total,
326+
is_interactive,
327+
&mut output,
328+
)?;
311329
}
312330
if loading_servers.iter().all(|(_, status)| status.is_done) {
313331
break;
@@ -318,16 +336,23 @@ impl ToolManagerBuilder {
318336
status_line.is_done = true;
319337
complete += 1;
320338
execute!(
321-
stdout_lock,
339+
output,
322340
cursor::MoveToColumn(0),
323341
cursor::MoveUp(1),
324342
terminal::Clear(terminal::ClearType::CurrentLine),
325343
)?;
326344
let msg = eyre::eyre!(msg.to_string());
327-
queue_warn_message(&name, &msg, &mut stdout_lock)?;
345+
queue_warn_message(&name, &msg, &mut output)?;
328346
let total = loading_servers.len();
329-
queue_init_message(spinner_logo_idx, complete, failed, total, &mut stdout_lock)?;
330-
stdout_lock.flush()?;
347+
queue_init_message(
348+
spinner_logo_idx,
349+
complete,
350+
failed,
351+
total,
352+
is_interactive,
353+
&mut output,
354+
)?;
355+
output.flush()?;
331356
}
332357
if loading_servers.iter().all(|(_, status)| status.is_done) {
333358
break;
@@ -336,7 +361,7 @@ impl ToolManagerBuilder {
336361
LoadingMsg::Terminate => {
337362
if loading_servers.iter().any(|(_, status)| !status.is_done) {
338363
execute!(
339-
stdout_lock,
364+
output,
340365
cursor::MoveToColumn(0),
341366
cursor::MoveUp(1),
342367
terminal::Clear(terminal::ClearType::CurrentLine),
@@ -351,16 +376,17 @@ impl ToolManagerBuilder {
351376
acc
352377
});
353378
let msg = eyre::eyre!(msg);
354-
queue_incomplete_load_message(&msg, &mut stdout_lock)?;
355-
stdout_lock.flush()?;
379+
let total = loading_servers.len();
380+
queue_incomplete_load_message(complete, total, &msg, &mut output)?;
381+
output.flush()?;
356382
}
357383
break;
358384
},
359385
},
360386
Err(RecvTimeoutError::Timeout) => {
361387
spinner_logo_idx = (spinner_logo_idx + 1) % SPINNER_CHARS.len();
362388
execute!(
363-
stdout_lock,
389+
output,
364390
cursor::SavePosition,
365391
cursor::MoveToColumn(0),
366392
cursor::MoveUp(1),
@@ -570,6 +596,7 @@ impl ToolManagerBuilder {
570596
loading_status_sender,
571597
new_tool_specs,
572598
has_new_stuff,
599+
is_interactive,
573600
..Default::default()
574601
})
575602
}
@@ -640,6 +667,8 @@ pub struct ToolManager {
640667
/// This is mainly used to show the user what the tools look like from the perspective of the
641668
/// model.
642669
pub schema: HashMap<String, ToolSpec>,
670+
671+
is_interactive: bool,
643672
}
644673

645674
impl Clone for ToolManager {
@@ -652,6 +681,7 @@ impl Clone for ToolManager {
652681
prompts: self.prompts.clone(),
653682
tn_map: self.tn_map.clone(),
654683
schema: self.schema.clone(),
684+
is_interactive: self.is_interactive,
655685
..Default::default()
656686
}
657687
}
@@ -708,8 +738,12 @@ impl ToolManager {
708738
}
709739
},
710740
_ = ctrl_c() => {
711-
if let Some(tx) = tx {
712-
let _ = tx.send(LoadingMsg::Terminate);
741+
if self.is_interactive {
742+
if let Some(tx) = tx {
743+
let _ = tx.send(LoadingMsg::Terminate);
744+
}
745+
} else {
746+
return Err(eyre::eyre!("User interrupted mcp server loading in non-interactive mode. Ending."));
713747
}
714748
}
715749
}
@@ -1186,6 +1220,7 @@ fn queue_init_message(
11861220
complete: usize,
11871221
failed: usize,
11881222
total: usize,
1223+
is_interactive: bool,
11891224
output: &mut impl Write,
11901225
) -> eyre::Result<()> {
11911226
if total == complete {
@@ -1205,7 +1240,7 @@ fn queue_init_message(
12051240
} else {
12061241
queue!(output, style::Print(SPINNER_CHARS[spinner_logo_idx]))?;
12071242
}
1208-
Ok(queue!(
1243+
queue!(
12091244
output,
12101245
style::SetForegroundColor(style::Color::Blue),
12111246
style::Print(format!(" {}", complete)),
@@ -1214,11 +1249,22 @@ fn queue_init_message(
12141249
style::SetForegroundColor(style::Color::Blue),
12151250
style::Print(format!("{} ", total)),
12161251
style::ResetColor,
1217-
style::Print("mcp servers initialized. Press ctrl-c to load the remaining servers in the background\n"),
1218-
)?)
1252+
style::Print("mcp servers initialized."),
1253+
)?;
1254+
if is_interactive {
1255+
queue!(
1256+
output,
1257+
style::SetForegroundColor(style::Color::Blue),
1258+
style::Print(" ctrl-c "),
1259+
style::ResetColor,
1260+
style::Print("to start chatting now")
1261+
)?;
1262+
}
1263+
Ok(queue!(output, style::Print("\n"))?)
12191264
}
12201265

12211266
fn queue_failure_message(name: &str, fail_load_msg: &eyre::Report, output: &mut impl Write) -> eyre::Result<()> {
1267+
use crate::util::CHAT_BINARY_NAME;
12221268
Ok(queue!(
12231269
output,
12241270
style::SetForegroundColor(style::Color::Red),
@@ -1229,7 +1275,9 @@ fn queue_failure_message(name: &str, fail_load_msg: &eyre::Report, output: &mut
12291275
style::Print(" has failed to load:\n- "),
12301276
style::Print(fail_load_msg),
12311277
style::Print("\n"),
1232-
style::Print("- run with Q_LOG_LEVEL=trace and see $TMPDIR/qlog for detail\n"),
1278+
style::Print(format!(
1279+
"- run with Q_LOG_LEVEL=trace and see $TMPDIR/{CHAT_BINARY_NAME} for detail\n"
1280+
)),
12331281
style::ResetColor,
12341282
)?)
12351283
}
@@ -1248,14 +1296,27 @@ fn queue_warn_message(name: &str, msg: &eyre::Report, output: &mut impl Write) -
12481296
)?)
12491297
}
12501298

1251-
fn queue_incomplete_load_message(msg: &eyre::Report, output: &mut impl Write) -> eyre::Result<()> {
1299+
fn queue_incomplete_load_message(
1300+
complete: usize,
1301+
total: usize,
1302+
msg: &eyre::Report,
1303+
output: &mut impl Write,
1304+
) -> eyre::Result<()> {
12521305
Ok(queue!(
12531306
output,
12541307
style::SetForegroundColor(style::Color::Yellow),
1255-
style::Print("⚠ "),
1308+
style::Print("⚠"),
1309+
style::SetForegroundColor(style::Color::Blue),
1310+
style::Print(format!(" {}", complete)),
1311+
style::ResetColor,
1312+
style::Print(" of "),
1313+
style::SetForegroundColor(style::Color::Blue),
1314+
style::Print(format!("{} ", total)),
1315+
style::ResetColor,
1316+
style::Print("mcp servers initialized."),
12561317
style::ResetColor,
12571318
// We expect the message start with a newline
1258-
style::Print("following servers are still loading:"),
1319+
style::Print(" Servers still loading:"),
12591320
style::Print(msg),
12601321
style::ResetColor,
12611322
)?)

crates/chat-cli/src/cli/chat/tools/custom_tool.rs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -170,9 +170,15 @@ impl CustomTool {
170170
pub async fn invoke(&self, _ctx: &Context, _updates: &mut impl Write) -> Result<InvokeOutput> {
171171
// Assuming a response shape as per https://spec.modelcontextprotocol.io/specification/2024-11-05/server/tools/#calling-tools
172172
let resp = self.client.request(self.method.as_str(), self.params.clone()).await?;
173-
let result = resp
174-
.result
175-
.ok_or(eyre::eyre!("{} invocation failed to produce a result", self.name))?;
173+
let result = match resp.result {
174+
Some(result) => result,
175+
None => {
176+
let failure = resp.error.map_or("Unknown error encountered".to_string(), |err| {
177+
serde_json::to_string(&err).unwrap_or_default()
178+
});
179+
return Err(eyre::eyre!(failure));
180+
},
181+
};
176182

177183
match serde_json::from_value::<ToolCallResult>(result.clone()) {
178184
Ok(mut de_result) => {

0 commit comments

Comments
 (0)