Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions crates/chat-cli/src/cli/chat/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ impl ChatArgs {
.build(os, Box::new(std::io::stderr()), !self.non_interactive)
.await?;
let tool_config = tool_manager.load_tools(os, &mut stderr).await?;

let mut tool_permissions = ToolPermissions::new(tool_config.len());

if self.trust_all_tools {
Expand All @@ -279,6 +280,9 @@ impl ChatArgs {
tool_permissions.untrust_tool(&tool.name);
}
}
} else {
// CLI args has precendence over Database config
tool_permissions.trust_from_database(&os.database);
}

ChatSession::new(
Expand Down
11 changes: 11 additions & 0 deletions crates/chat-cli/src/cli/chat/tools/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ use use_aws::UseAws;

use super::consts::MAX_TOOL_RESPONSE_SIZE;
use super::util::images::RichImageBlocks;
use crate::database::Database;
use crate::database::settings::Setting;
use crate::os::Os;

/// Represents an executable tool use.
Expand Down Expand Up @@ -224,6 +226,15 @@ impl ToolPermissions {
self.permissions.contains_key(tool_name)
}

pub fn trust_from_database(&mut self, database: &Database) {
database
.settings
.get_array::<String>(Setting::TrustedTools)
.into_iter()
.flatten()
.for_each(|tool_name| self.trust_tool(&tool_name));
}

/// Provide default permission labels for the built-in set of tools.
// This "static" way avoids needing to construct a tool instance.
fn default_permission_label(&self, tool_name: &str) -> String {
Expand Down
27 changes: 27 additions & 0 deletions crates/chat-cli/src/database/settings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ pub enum Setting {
McpNoInteractiveTimeout,
McpLoadedBefore,
ChatDefaultModel,
TrustedTools,
}

impl AsRef<str> for Setting {
Expand All @@ -54,6 +55,7 @@ impl AsRef<str> for Setting {
Self::McpNoInteractiveTimeout => "mcp.noInteractiveTimeout",
Self::McpLoadedBefore => "mcp.loadedBefore",
Self::ChatDefaultModel => "chat.defaultModel",
Self::TrustedTools => "tools.trusted",
}
}
}
Expand Down Expand Up @@ -85,6 +87,7 @@ impl TryFrom<&str> for Setting {
"mcp.noInteractiveTimeout" => Ok(Self::McpNoInteractiveTimeout),
"mcp.loadedBefore" => Ok(Self::McpLoadedBefore),
"chat.defaultModel" => Ok(Self::ChatDefaultModel),
"tools.trusted" => Ok(Self::TrustedTools),
_ => Err(DatabaseError::InvalidSetting(value.to_string())),
}
}
Expand Down Expand Up @@ -150,6 +153,19 @@ impl Settings {
self.get(key).and_then(|value| value.as_str().map(|s| s.into()))
}

pub fn get_array<T>(&self, key: Setting) -> Option<Vec<T>>
where
T: serde::de::DeserializeOwned,
{
self.get(key).and_then(|value| {
value.as_array().and_then(|arr| {
arr.iter()
.map(|v| serde_json::from_value(v.clone()).ok())
.collect::<Option<Vec<T>>>()
})
})
}

pub fn get_int(&self, key: Setting) -> Option<i64> {
self.get(key).and_then(|value| value.as_i64())
}
Expand Down Expand Up @@ -204,12 +220,17 @@ mod test {
assert_eq!(settings.get(Setting::ShareCodeWhispererContent), None);
assert_eq!(settings.get(Setting::McpLoadedBefore), None);
assert_eq!(settings.get(Setting::ChatDefaultModel), None);
assert_eq!(settings.get(Setting::TrustedTools), None);

settings.set(Setting::TelemetryEnabled, true).await.unwrap();
settings.set(Setting::OldClientId, "test").await.unwrap();
settings.set(Setting::ShareCodeWhispererContent, false).await.unwrap();
settings.set(Setting::McpLoadedBefore, true).await.unwrap();
settings.set(Setting::ChatDefaultModel, "model 1").await.unwrap();
settings
.set(Setting::TrustedTools, r#"["tool_a","tool_b"]"#)
.await
.unwrap();

assert_eq!(settings.get(Setting::TelemetryEnabled), Some(&Value::Bool(true)));
assert_eq!(
Expand All @@ -225,15 +246,21 @@ mod test {
settings.get(Setting::ChatDefaultModel),
Some(&Value::String("model 1".to_string()))
);
assert_eq!(
settings.get(Setting::TrustedTools),
Some(&Value::String(r#"["tool_a","tool_b"]"#.to_string()))
);

settings.remove(Setting::TelemetryEnabled).await.unwrap();
settings.remove(Setting::OldClientId).await.unwrap();
settings.remove(Setting::ShareCodeWhispererContent).await.unwrap();
settings.remove(Setting::McpLoadedBefore).await.unwrap();
settings.remove(Setting::TrustedTools).await.unwrap();

assert_eq!(settings.get(Setting::TelemetryEnabled), None);
assert_eq!(settings.get(Setting::OldClientId), None);
assert_eq!(settings.get(Setting::ShareCodeWhispererContent), None);
assert_eq!(settings.get(Setting::McpLoadedBefore), None);
assert_eq!(settings.get(Setting::TrustedTools), None);
}
}