Skip to content

Commit 6dbca6c

Browse files
committed
makes main chat loop update state if applicable
1 parent f9b0891 commit 6dbca6c

File tree

4 files changed

+93
-35
lines changed

4 files changed

+93
-35
lines changed

crates/chat-cli/src/cli/chat/conversation_state.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use std::collections::{
33
VecDeque,
44
};
55
use std::sync::Arc;
6+
use std::sync::atomic::Ordering;
67

78
use crossterm::style::Color;
89
use crossterm::{
@@ -354,6 +355,7 @@ impl ConversationState {
354355
/// - `run_hooks` - whether hooks should be executed and included as context
355356
pub async fn as_sendable_conversation_state(&mut self, run_hooks: bool) -> FigConversationState {
356357
debug_assert!(self.next_message.is_some());
358+
self.update_state().await;
357359
self.enforce_conversation_invariants();
358360
self.history.drain(self.valid_history_range.1..);
359361
self.history.drain(..self.valid_history_range.0);
@@ -379,6 +381,30 @@ impl ConversationState {
379381
.expect("unable to construct conversation state")
380382
}
381383

384+
pub async fn update_state(&mut self) {
385+
let needs_update = self.tool_manager.has_new_stuff.load(Ordering::Acquire);
386+
if !needs_update {
387+
return;
388+
}
389+
self.tool_manager.update().await;
390+
self.tools = self
391+
.tool_manager
392+
.schema
393+
.values()
394+
.fold(HashMap::<ToolOrigin, Vec<Tool>>::new(), |mut acc, v| {
395+
let tool = Tool::ToolSpecification(ToolSpecification {
396+
name: v.name.clone(),
397+
description: v.description.clone(),
398+
input_schema: v.input_schema.clone().into(),
399+
});
400+
acc.entry(v.tool_origin.clone())
401+
.and_modify(|tools| tools.push(tool.clone()))
402+
.or_insert(vec![tool]);
403+
acc
404+
});
405+
self.tool_manager.has_new_stuff.store(false, Ordering::Release);
406+
}
407+
382408
/// Returns a conversation state representation which reflects the exact conversation to send
383409
/// back to the model.
384410
pub async fn backend_conversation_state(&mut self, run_hooks: bool, quiet: bool) -> BackendConversationState<'_> {

crates/chat-cli/src/cli/chat/mod.rs

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,7 @@ pub async fn chat(
401401
let tool_config = tool_manager.load_tools().await?;
402402
let mut tool_permissions = ToolPermissions::new(tool_config.len());
403403
if accept_all || trust_all_tools {
404+
tool_permissions.trust_all = true;
404405
for tool in tool_config.values() {
405406
tool_permissions.trust_tool(&tool.name);
406407
}
@@ -2259,7 +2260,7 @@ impl ChatContext {
22592260
)?;
22602261
},
22612262
Some(ToolsSubcommand::ResetSingle { tool_name }) => {
2262-
if self.tool_permissions.has(&tool_name) {
2263+
if self.tool_permissions.has(&tool_name) || self.tool_permissions.trust_all {
22632264
self.tool_permissions.reset_tool(&tool_name);
22642265
queue!(
22652266
self.output,
@@ -2735,11 +2736,9 @@ impl ChatContext {
27352736
}
27362737

27372738
// If there is an override, we will use it. Otherwise fall back to Tool's default.
2738-
let allowed = if self.tool_permissions.has(&tool.name) {
2739-
self.tool_permissions.is_trusted(&tool.name)
2740-
} else {
2741-
!tool.tool.requires_acceptance(&self.ctx)
2742-
};
2739+
let allowed = self.tool_permissions.trust_all
2740+
|| (self.tool_permissions.has(&tool.name) && self.tool_permissions.is_trusted(&tool.name))
2741+
|| !tool.tool.requires_acceptance(&self.ctx);
27432742

27442743
if self.settings.get_bool_or("chat.enableNotifications", false) {
27452744
play_notification_bell(!allowed);

crates/chat-cli/src/cli/chat/tool_manager.rs

Lines changed: 51 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
use std::collections::HashMap;
1+
use std::collections::{
2+
HashMap,
3+
HashSet,
4+
};
25
use std::future::Future;
36
use std::hash::{
47
DefaultHasher,
@@ -103,7 +106,7 @@ pub enum GetPromptError {
103106
/// Messages used for communication between the tool initialization thread and the loading
104107
/// display thread. These messages control the visual loading indicators shown to
105108
/// the user during tool initialization.
106-
pub enum LoadingMsg {
109+
enum LoadingMsg {
107110
/// Indicates a new tool is being initialized and should be added to the loading
108111
/// display. The String parameter is the name of the tool being initialized.
109112
Add(String),
@@ -421,7 +424,7 @@ impl ToolManagerBuilder {
421424
.insert(server_name, (sanitized_mapping, specs));
422425
// We only want to set this flag when the display task has ended
423426
if load_msg_sender.is_none() {
424-
has_new_stuff_clone.store(true, Ordering::Relaxed);
427+
has_new_stuff_clone.store(true, Ordering::Release);
425428
}
426429
},
427430
UpdateEventMessage::PromptsListResult {
@@ -710,28 +713,8 @@ impl ToolManager {
710713
}
711714
}
712715
}
713-
let new_tools = {
714-
let mut new_tool_specs = self.new_tool_specs.lock().await;
715-
new_tool_specs.drain().fold(HashMap::new(), |mut acc, (k, v)| {
716-
acc.insert(k, v);
717-
acc
718-
})
719-
};
720-
for (_server_name, (tool_name_map, specs)) in new_tools {
721-
for (k, v) in tool_name_map {
722-
self.tn_map.insert(k, v);
723-
}
724-
for spec in specs {
725-
tool_specs.insert(spec.name.clone(), spec);
726-
}
727-
}
728-
// caching the tool names for skim operations
729-
for tool_name in tool_specs.keys() {
730-
if !self.tn_map.contains_key(tool_name) {
731-
self.tn_map.insert(tool_name.clone(), tool_name.clone());
732-
}
733-
}
734-
self.schema = tool_specs.clone();
716+
self.update().await;
717+
tool_specs.extend(self.schema.clone());
735718
Ok(tool_specs)
736719
}
737720

@@ -831,6 +814,49 @@ impl ToolManager {
831814
})
832815
}
833816

817+
/// Updates tool managers various states with new information
818+
pub async fn update(&mut self) {
819+
// A hashmap of <tool name, tool spec>
820+
let mut tool_specs = HashMap::<String, ToolSpec>::new();
821+
let new_tools = {
822+
let mut new_tool_specs = self.new_tool_specs.lock().await;
823+
new_tool_specs.drain().fold(HashMap::new(), |mut acc, (k, v)| {
824+
acc.insert(k, v);
825+
acc
826+
})
827+
};
828+
let mut updated_servers = HashSet::<ToolOrigin>::new();
829+
for (_server_name, (tool_name_map, specs)) in new_tools {
830+
// In a populated tn map (i.e. a partially initialized or outdated fleet of servers) there
831+
// will be incoming tools with names that are already in the tn map, we will be writing
832+
// over them (perhaps with the same information that they already had), and that's okay.
833+
// In an event where a server has removed tools, the tools that are no longer available
834+
// will linger in this map. This is also okay to not clean up as it does not affect the
835+
// look up of tool names that are still active.
836+
for (k, v) in tool_name_map {
837+
self.tn_map.insert(k, v);
838+
}
839+
if let Some(spec) = specs.first() {
840+
updated_servers.insert(spec.tool_origin.clone());
841+
}
842+
for spec in specs {
843+
tool_specs.insert(spec.name.clone(), spec);
844+
}
845+
}
846+
// Caching the tool names for skim operations
847+
for tool_name in tool_specs.keys() {
848+
if !self.tn_map.contains_key(tool_name) {
849+
self.tn_map.insert(tool_name.clone(), tool_name.clone());
850+
}
851+
}
852+
// Update schema
853+
// As we are writing over the ensemble of tools in a given server, we will need to first
854+
// remove everything that it has.
855+
self.schema
856+
.retain(|_tool_name, spec| !updated_servers.contains(&spec.tool_origin));
857+
self.schema.extend(tool_specs);
858+
}
859+
834860
#[allow(clippy::await_holding_lock)]
835861
pub async fn get_prompt(&self, get_command: PromptsGetCommand) -> Result<JsonRpcResponse, GetPromptError> {
836862
let (server_name, prompt_name) = match get_command.params.name.split_once('/') {

crates/chat-cli/src/cli/chat/tools/mod.rs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -126,30 +126,33 @@ pub struct ToolPermission {
126126
/// Tools that do not have an associated ToolPermission should use
127127
/// their default logic to determine to permission.
128128
pub struct ToolPermissions {
129+
// We need this field for any stragglers
130+
pub trust_all: bool,
129131
pub permissions: HashMap<String, ToolPermission>,
130132
}
131133

132134
impl ToolPermissions {
133135
pub fn new(capacity: usize) -> Self {
134136
Self {
137+
trust_all: false,
135138
permissions: HashMap::with_capacity(capacity),
136139
}
137140
}
138141

139142
pub fn is_trusted(&self, tool_name: &str) -> bool {
140-
self.permissions.get(tool_name).is_some_and(|perm| perm.trusted)
143+
self.trust_all || self.permissions.get(tool_name).is_some_and(|perm| perm.trusted)
141144
}
142145

143146
/// Returns a label to describe the permission status for a given tool.
144147
pub fn display_label(&self, tool_name: &str) -> String {
145-
if self.has(tool_name) {
148+
if self.has(tool_name) || self.trust_all {
146149
if self.is_trusted(tool_name) {
147150
format!(" {}", "trusted".dark_green().bold())
148151
} else {
149152
format!(" {}", "not trusted".dark_grey())
150153
}
151154
} else {
152-
Self::default_permission_label(tool_name)
155+
self.default_permission_label(tool_name)
153156
}
154157
}
155158

@@ -159,15 +162,18 @@ impl ToolPermissions {
159162
}
160163

161164
pub fn untrust_tool(&mut self, tool_name: &str) {
165+
self.trust_all = false;
162166
self.permissions
163167
.insert(tool_name.to_string(), ToolPermission { trusted: false });
164168
}
165169

166170
pub fn reset(&mut self) {
171+
self.trust_all = false;
167172
self.permissions.clear();
168173
}
169174

170175
pub fn reset_tool(&mut self, tool_name: &str) {
176+
self.trust_all = false;
171177
self.permissions.remove(tool_name);
172178
}
173179

@@ -178,14 +184,15 @@ impl ToolPermissions {
178184
/// Provide default permission labels for the built-in set of tools.
179185
/// Unknown tools are assumed to be "Per-request"
180186
// This "static" way avoids needing to construct a tool instance.
181-
fn default_permission_label(tool_name: &str) -> String {
187+
fn default_permission_label(&self, tool_name: &str) -> String {
182188
let label = match tool_name {
183189
"fs_read" => "trusted".dark_green().bold(),
184190
"fs_write" => "not trusted".dark_grey(),
185191
"execute_bash" => "trust read-only commands".dark_grey(),
186192
"use_aws" => "trust read-only commands".dark_grey(),
187193
"report_issue" => "trusted".dark_green().bold(),
188194
"thinking" => "trusted (prerelease)".dark_green().bold(),
195+
_ if self.trust_all => "trusted".dark_grey().bold(),
189196
_ => "not trusted".dark_grey(),
190197
};
191198

0 commit comments

Comments
 (0)