Skip to content

Commit ef0a16d

Browse files
committed
feat: implement database setting for trusting tools
1 parent 560c4ae commit ef0a16d

File tree

3 files changed

+65
-112
lines changed

3 files changed

+65
-112
lines changed

crates/chat-cli/src/cli/chat/mod.rs

Lines changed: 21 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -19,121 +19,52 @@ pub mod tools;
1919
pub mod util;
2020

2121
use std::borrow::Cow;
22-
use std::collections::{
23-
HashMap,
24-
HashSet,
25-
VecDeque,
26-
};
22+
use std::collections::{HashMap, HashSet, VecDeque};
2723
use std::io::Write;
2824
use std::process::ExitCode;
2925
use std::time::Duration;
3026

3127
use amzn_codewhisperer_client::types::SubscriptionStatus;
32-
use clap::{
33-
Args,
34-
CommandFactory,
35-
Parser,
36-
};
28+
use clap::{Args, CommandFactory, Parser};
3729
use context::ContextManager;
3830
pub use conversation::ConversationState;
3931
use conversation::TokenWarningLevel;
40-
use crossterm::style::{
41-
Attribute,
42-
Color,
43-
Stylize,
44-
};
45-
use crossterm::{
46-
cursor,
47-
execute,
48-
queue,
49-
style,
50-
terminal,
51-
};
52-
use eyre::{
53-
Report,
54-
Result,
55-
bail,
56-
eyre,
57-
};
32+
use crossterm::style::{Attribute, Color, Stylize};
33+
use crossterm::{cursor, execute, queue, style, terminal};
34+
use eyre::{Report, Result, bail, eyre};
5835
use input_source::InputSource;
59-
use message::{
60-
AssistantMessage,
61-
AssistantToolUse,
62-
ToolUseResult,
63-
ToolUseResultBlock,
64-
};
65-
use parse::{
66-
ParseState,
67-
interpret_markdown,
68-
};
69-
use parser::{
70-
RecvErrorKind,
71-
ResponseParser,
72-
};
36+
use message::{AssistantMessage, AssistantToolUse, ToolUseResult, ToolUseResultBlock};
37+
use parse::{ParseState, interpret_markdown};
38+
use parser::{RecvErrorKind, ResponseParser};
7339
use regex::Regex;
74-
use spinners::{
75-
Spinner,
76-
Spinners,
77-
};
40+
use spinners::{Spinner, Spinners};
7841
use thiserror::Error;
7942
use time::OffsetDateTime;
8043
use token_counter::TokenCounter;
8144
use tokio::signal::ctrl_c;
82-
use tool_manager::{
83-
McpServerConfig,
84-
ToolManager,
85-
ToolManagerBuilder,
86-
};
45+
use tool_manager::{McpServerConfig, ToolManager, ToolManagerBuilder};
8746
use tools::gh_issue::GhIssueContext;
88-
use tools::{
89-
OutputKind,
90-
QueuedTool,
91-
Tool,
92-
ToolPermissions,
93-
ToolSpec,
94-
};
95-
use tracing::{
96-
debug,
97-
error,
98-
info,
99-
trace,
100-
warn,
101-
};
47+
use tools::{OutputKind, QueuedTool, Tool, ToolPermissions, ToolSpec};
48+
use tracing::{debug, error, info, trace, warn};
10249
use util::images::RichImageBlock;
10350
use util::ui::draw_box;
104-
use util::{
105-
animate_output,
106-
play_notification_bell,
107-
};
51+
use util::{animate_output, play_notification_bell};
10852
use winnow::Partial;
10953
use winnow::stream::Offset;
11054

11155
use crate::api_client::ApiClientError;
112-
use crate::api_client::model::{
113-
Tool as FigTool,
114-
ToolResultStatus,
115-
};
56+
use crate::api_client::model::{Tool as FigTool, ToolResultStatus};
11657
use crate::api_client::send_message_output::SendMessageOutput;
11758
use crate::auth::AuthError;
11859
use crate::auth::builder_id::is_idc_user;
11960
use crate::cli::chat::cli::SlashCommand;
120-
use crate::cli::chat::cli::model::{
121-
MODEL_OPTIONS,
122-
default_model_id,
123-
};
124-
use crate::cli::chat::cli::prompts::{
125-
GetPromptError,
126-
PromptsSubcommand,
127-
};
61+
use crate::cli::chat::cli::model::{MODEL_OPTIONS, default_model_id};
62+
use crate::cli::chat::cli::prompts::{GetPromptError, PromptsSubcommand};
12863
use crate::database::settings::Setting;
12964
use crate::mcp_client::Prompt;
13065
use crate::os::Os;
13166
use crate::telemetry::core::ToolUseEventBuilder;
132-
use crate::telemetry::{
133-
ReasonCode,
134-
TelemetryResult,
135-
get_error_reason,
136-
};
67+
use crate::telemetry::{ReasonCode, TelemetryResult, get_error_reason};
13768

13869
const LIMIT_REACHED_TEXT: &str = color_print::cstr! { "You've used all your free requests for this month. You have two options:
13970
1. Upgrade to a paid subscription for increased limits. See our Pricing page for what's included> <blue!>https://aws.amazon.com/q/developer/pricing/</blue!>
@@ -255,6 +186,7 @@ impl ChatArgs {
255186
.build(os, Box::new(std::io::stderr()), !self.non_interactive)
256187
.await?;
257188
let tool_config = tool_manager.load_tools(os, &mut stderr).await?;
189+
258190
let mut tool_permissions = ToolPermissions::new(tool_config.len());
259191

260192
if self.trust_all_tools {
@@ -279,6 +211,9 @@ impl ChatArgs {
279211
tool_permissions.untrust_tool(&tool.name);
280212
}
281213
}
214+
} else {
215+
// CLI args has precendence over Database config
216+
tool_permissions.trust_from_database(&os.database);
282217
}
283218

284219
ChatSession::new(

crates/chat-cli/src/cli/chat/tools/mod.rs

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,38 +7,27 @@ pub mod knowledge;
77
pub mod thinking;
88
pub mod use_aws;
99

10-
use std::collections::{
11-
HashMap,
12-
HashSet,
13-
};
10+
use std::collections::{HashMap, HashSet};
1411
use std::io::Write;
15-
use std::path::{
16-
Path,
17-
PathBuf,
18-
};
12+
use std::path::{Path, PathBuf};
1913

2014
use crossterm::queue;
21-
use crossterm::style::{
22-
self,
23-
Color,
24-
Stylize,
25-
};
15+
use crossterm::style::{self, Color, Stylize};
2616
use custom_tool::CustomTool;
2717
use execute::ExecuteCommand;
2818
use eyre::Result;
2919
use fs_read::FsRead;
3020
use fs_write::FsWrite;
3121
use gh_issue::GhIssue;
3222
use knowledge::Knowledge;
33-
use serde::{
34-
Deserialize,
35-
Serialize,
36-
};
23+
use serde::{Deserialize, Serialize};
3724
use thinking::Thinking;
3825
use use_aws::UseAws;
3926

4027
use super::consts::MAX_TOOL_RESPONSE_SIZE;
4128
use super::util::images::RichImageBlocks;
29+
use crate::database::Database;
30+
use crate::database::settings::Setting;
4231
use crate::os::Os;
4332

4433
/// Represents an executable tool use.
@@ -224,6 +213,15 @@ impl ToolPermissions {
224213
self.permissions.contains_key(tool_name)
225214
}
226215

216+
pub fn trust_from_database(&mut self, database: &Database) {
217+
database
218+
.settings
219+
.get_array::<String>(Setting::TrustedTools)
220+
.into_iter()
221+
.flatten()
222+
.for_each(|tool_name| self.trust_tool(&tool_name));
223+
}
224+
227225
/// Provide default permission labels for the built-in set of tools.
228226
// This "static" way avoids needing to construct a tool instance.
229227
fn default_permission_label(&self, tool_name: &str) -> String {

crates/chat-cli/src/database/settings.rs

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,9 @@ use std::fmt::Display;
22
use std::io::SeekFrom;
33

44
use fd_lock::RwLock;
5-
use serde_json::{
6-
Map,
7-
Value,
8-
};
5+
use serde_json::{Map, Value};
96
use tokio::fs::File;
10-
use tokio::io::{
11-
AsyncReadExt,
12-
AsyncSeekExt,
13-
AsyncWriteExt,
14-
};
7+
use tokio::io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt};
158

169
use super::DatabaseError;
1710

@@ -33,6 +26,7 @@ pub enum Setting {
3326
McpNoInteractiveTimeout,
3427
McpLoadedBefore,
3528
ChatDefaultModel,
29+
TrustedTools,
3630
}
3731

3832
impl AsRef<str> for Setting {
@@ -54,6 +48,7 @@ impl AsRef<str> for Setting {
5448
Self::McpNoInteractiveTimeout => "mcp.noInteractiveTimeout",
5549
Self::McpLoadedBefore => "mcp.loadedBefore",
5650
Self::ChatDefaultModel => "chat.defaultModel",
51+
Self::TrustedTools => "tools.trusted",
5752
}
5853
}
5954
}
@@ -85,6 +80,7 @@ impl TryFrom<&str> for Setting {
8580
"mcp.noInteractiveTimeout" => Ok(Self::McpNoInteractiveTimeout),
8681
"mcp.loadedBefore" => Ok(Self::McpLoadedBefore),
8782
"chat.defaultModel" => Ok(Self::ChatDefaultModel),
83+
"tools.trusted" => Ok(Self::TrustedTools),
8884
_ => Err(DatabaseError::InvalidSetting(value.to_string())),
8985
}
9086
}
@@ -150,6 +146,19 @@ impl Settings {
150146
self.get(key).and_then(|value| value.as_str().map(|s| s.into()))
151147
}
152148

149+
pub fn get_array<T>(&self, key: Setting) -> Option<Vec<T>>
150+
where
151+
T: serde::de::DeserializeOwned,
152+
{
153+
self.get(key).and_then(|value| {
154+
value.as_array().and_then(|arr| {
155+
arr.iter()
156+
.map(|v| serde_json::from_value(v.clone()).ok())
157+
.collect::<Option<Vec<T>>>()
158+
})
159+
})
160+
}
161+
153162
pub fn get_int(&self, key: Setting) -> Option<i64> {
154163
self.get(key).and_then(|value| value.as_i64())
155164
}
@@ -204,12 +213,17 @@ mod test {
204213
assert_eq!(settings.get(Setting::ShareCodeWhispererContent), None);
205214
assert_eq!(settings.get(Setting::McpLoadedBefore), None);
206215
assert_eq!(settings.get(Setting::ChatDefaultModel), None);
216+
assert_eq!(settings.get(Setting::TrustedTools), None);
207217

208218
settings.set(Setting::TelemetryEnabled, true).await.unwrap();
209219
settings.set(Setting::OldClientId, "test").await.unwrap();
210220
settings.set(Setting::ShareCodeWhispererContent, false).await.unwrap();
211221
settings.set(Setting::McpLoadedBefore, true).await.unwrap();
212222
settings.set(Setting::ChatDefaultModel, "model 1").await.unwrap();
223+
settings
224+
.set(Setting::TrustedTools, r#"["tool_a","tool_b"]"#)
225+
.await
226+
.unwrap();
213227

214228
assert_eq!(settings.get(Setting::TelemetryEnabled), Some(&Value::Bool(true)));
215229
assert_eq!(
@@ -225,15 +239,21 @@ mod test {
225239
settings.get(Setting::ChatDefaultModel),
226240
Some(&Value::String("model 1".to_string()))
227241
);
242+
assert_eq!(
243+
settings.get(Setting::TrustedTools),
244+
Some(&Value::String(r#"["tool_a","tool_b"]"#.to_string()))
245+
);
228246

229247
settings.remove(Setting::TelemetryEnabled).await.unwrap();
230248
settings.remove(Setting::OldClientId).await.unwrap();
231249
settings.remove(Setting::ShareCodeWhispererContent).await.unwrap();
232250
settings.remove(Setting::McpLoadedBefore).await.unwrap();
251+
settings.remove(Setting::TrustedTools).await.unwrap();
233252

234253
assert_eq!(settings.get(Setting::TelemetryEnabled), None);
235254
assert_eq!(settings.get(Setting::OldClientId), None);
236255
assert_eq!(settings.get(Setting::ShareCodeWhispererContent), None);
237256
assert_eq!(settings.get(Setting::McpLoadedBefore), None);
257+
assert_eq!(settings.get(Setting::TrustedTools), None);
238258
}
239259
}

0 commit comments

Comments
 (0)