diff --git a/crates/chat-cli/src/cli/chat/mod.rs b/crates/chat-cli/src/cli/chat/mod.rs index ccc8df7f4..7b38d5b60 100644 --- a/crates/chat-cli/src/cli/chat/mod.rs +++ b/crates/chat-cli/src/cli/chat/mod.rs @@ -255,6 +255,7 @@ impl ChatArgs { .build(os, Box::new(std::io::stderr()), !self.non_interactive) .await?; let tool_config = tool_manager.load_tools(os, &mut stderr).await?; + let mut tool_permissions = ToolPermissions::new(tool_config.len()); if self.trust_all_tools { @@ -279,6 +280,9 @@ impl ChatArgs { tool_permissions.untrust_tool(&tool.name); } } + } else { + // CLI args has precendence over Database config + tool_permissions.trust_from_database(&os.database); } ChatSession::new( diff --git a/crates/chat-cli/src/cli/chat/tools/mod.rs b/crates/chat-cli/src/cli/chat/tools/mod.rs index f45bcaba4..d19c9f21d 100644 --- a/crates/chat-cli/src/cli/chat/tools/mod.rs +++ b/crates/chat-cli/src/cli/chat/tools/mod.rs @@ -39,6 +39,8 @@ use use_aws::UseAws; use super::consts::MAX_TOOL_RESPONSE_SIZE; use super::util::images::RichImageBlocks; +use crate::database::Database; +use crate::database::settings::Setting; use crate::os::Os; /// Represents an executable tool use. @@ -224,6 +226,15 @@ impl ToolPermissions { self.permissions.contains_key(tool_name) } + pub fn trust_from_database(&mut self, database: &Database) { + database + .settings + .get_array::(Setting::TrustedTools) + .into_iter() + .flatten() + .for_each(|tool_name| self.trust_tool(&tool_name)); + } + /// Provide default permission labels for the built-in set of tools. // This "static" way avoids needing to construct a tool instance. fn default_permission_label(&self, tool_name: &str) -> String { diff --git a/crates/chat-cli/src/database/settings.rs b/crates/chat-cli/src/database/settings.rs index 85e56c21a..e001c9729 100644 --- a/crates/chat-cli/src/database/settings.rs +++ b/crates/chat-cli/src/database/settings.rs @@ -33,6 +33,7 @@ pub enum Setting { McpNoInteractiveTimeout, McpLoadedBefore, ChatDefaultModel, + TrustedTools, } impl AsRef for Setting { @@ -54,6 +55,7 @@ impl AsRef for Setting { Self::McpNoInteractiveTimeout => "mcp.noInteractiveTimeout", Self::McpLoadedBefore => "mcp.loadedBefore", Self::ChatDefaultModel => "chat.defaultModel", + Self::TrustedTools => "tools.trusted", } } } @@ -85,6 +87,7 @@ impl TryFrom<&str> for Setting { "mcp.noInteractiveTimeout" => Ok(Self::McpNoInteractiveTimeout), "mcp.loadedBefore" => Ok(Self::McpLoadedBefore), "chat.defaultModel" => Ok(Self::ChatDefaultModel), + "tools.trusted" => Ok(Self::TrustedTools), _ => Err(DatabaseError::InvalidSetting(value.to_string())), } } @@ -150,6 +153,19 @@ impl Settings { self.get(key).and_then(|value| value.as_str().map(|s| s.into())) } + pub fn get_array(&self, key: Setting) -> Option> + where + T: serde::de::DeserializeOwned, + { + self.get(key).and_then(|value| { + value.as_array().and_then(|arr| { + arr.iter() + .map(|v| serde_json::from_value(v.clone()).ok()) + .collect::>>() + }) + }) + } + pub fn get_int(&self, key: Setting) -> Option { self.get(key).and_then(|value| value.as_i64()) } @@ -204,12 +220,17 @@ mod test { assert_eq!(settings.get(Setting::ShareCodeWhispererContent), None); assert_eq!(settings.get(Setting::McpLoadedBefore), None); assert_eq!(settings.get(Setting::ChatDefaultModel), None); + assert_eq!(settings.get(Setting::TrustedTools), None); settings.set(Setting::TelemetryEnabled, true).await.unwrap(); settings.set(Setting::OldClientId, "test").await.unwrap(); settings.set(Setting::ShareCodeWhispererContent, false).await.unwrap(); settings.set(Setting::McpLoadedBefore, true).await.unwrap(); settings.set(Setting::ChatDefaultModel, "model 1").await.unwrap(); + settings + .set(Setting::TrustedTools, r#"["tool_a","tool_b"]"#) + .await + .unwrap(); assert_eq!(settings.get(Setting::TelemetryEnabled), Some(&Value::Bool(true))); assert_eq!( @@ -225,15 +246,21 @@ mod test { settings.get(Setting::ChatDefaultModel), Some(&Value::String("model 1".to_string())) ); + assert_eq!( + settings.get(Setting::TrustedTools), + Some(&Value::String(r#"["tool_a","tool_b"]"#.to_string())) + ); settings.remove(Setting::TelemetryEnabled).await.unwrap(); settings.remove(Setting::OldClientId).await.unwrap(); settings.remove(Setting::ShareCodeWhispererContent).await.unwrap(); settings.remove(Setting::McpLoadedBefore).await.unwrap(); + settings.remove(Setting::TrustedTools).await.unwrap(); assert_eq!(settings.get(Setting::TelemetryEnabled), None); assert_eq!(settings.get(Setting::OldClientId), None); assert_eq!(settings.get(Setting::ShareCodeWhispererContent), None); assert_eq!(settings.get(Setting::McpLoadedBefore), None); + assert_eq!(settings.get(Setting::TrustedTools), None); } }