diff --git a/.github/workflows/docs-generator.yaml b/.github/workflows/docs-generator.yaml new file mode 100644 index 0000000000..6d10c74b32 --- /dev/null +++ b/.github/workflows/docs-generator.yaml @@ -0,0 +1,62 @@ +name: docs-generator + +on: + # workflow_dispatch: + # inputs: + # pr_number: + # description: 'Number of PR to document' + # required: true + # type: string + push: + +jobs: + generate_docs: + runs-on: ubuntu-latest + env: + AMAZON_Q_SIGV4: 1 + CHAT_DOWNLOAD_ROLE_ARN: ${{ secrets.CHAT_DOWNLOAD_ROLE_ARN }} + CHAT_BUILD_BUCKET_NAME: ${{ secrets.CHAT_BUILD_BUCKET_NAME }} + PR_FILE: "pr-contents.txt" + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + GIT_HASH: "latest" + TEST_PR_NUMBER: 2533 + permissions: + id-token: write + contents: write + pull-requests: write + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Configure AWS credentials + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: ${{ secrets.AWS_TB_ROLE }} + aws-region: us-east-1 + + - name: Make scripts executable + run: | + chmod +x docs-generation/setup_amazon_q.sh + chmod +x docs-generation/create-docs-pr.sh + chmod +x docs-generation/read-pr.sh + chmod +x docs-generation/update-docs.sh + + - name: Run setup script + run: bash docs-generation/setup_amazon_q.sh + + - name: Generate PR contents file + run: bash docs-generation/read-pr.sh ${{ env.TEST_PR_NUMBER }} + + - name: Update docs + run: bash docs-generation/update-docs.sh + + - name: Create PR + if: success() + run: bash docs-generation/create-docs-pr.sh ${{ env.TEST_PR_NUMBER }} + + + + + + \ No newline at end of file diff --git a/docs-generation/create-docs-pr.sh b/docs-generation/create-docs-pr.sh new file mode 100755 index 0000000000..fc9b01a44c --- /dev/null +++ b/docs-generation/create-docs-pr.sh @@ -0,0 +1,30 @@ +#!/bin/bash +set -e + +PR_NUMBER=$1 +BRANCH_NAME="docs-update-for-pr-$PR_NUMBER" + +# Ensure we have changes to merge +if [ -z "$(git status --porcelain)" ]; then + echo 
"No changes to commit" + exit 0 +fi + +git config user.name "docs-generator[bot]" +git config user.email "docs-generator[bot]@amazon.com" + +# Create branch and push +git checkout -b "$BRANCH_NAME" +git add . +git commit -m "Update docs based on PR #$PR_NUMBER + +Auto-generated by Q" + +git push origin "$BRANCH_NAME" + +# Create PR +gh pr create \ + --title "Update docs based on PR #$PR_NUMBER" \ + --body "Auto-generated documentation updates based on changes in PR #$PR_NUMBER" \ + --base main \ + --head "$BRANCH_NAME" diff --git a/docs-generation/read-pr.sh b/docs-generation/read-pr.sh new file mode 100755 index 0000000000..ab084275b9 --- /dev/null +++ b/docs-generation/read-pr.sh @@ -0,0 +1,29 @@ +#!/bin/bash +set -e + +PR_NUMBER=$1 + +# Add PR information +echo "====== PR Information ======\n" > $PR_FILE +gh pr view $PR_NUMBER --json title,body --jq '"Title: " + .title + "\nDescription: " + .body' >> $PR_FILE + +# Include updated files +echo -e "\n====== Updated files ======\n" >> $PR_FILE +gh pr view $PR_NUMBER --json files --jq ".files[].path" | while read file; do + case "$file" in + *.lock|*-lock.*|*.min.*|dist/*|build/*|target/*) + continue + ;; + esac + if [ -f "$file" ]; then + echo "---- $file ----" >> $PR_FILE + cat "$file" >> $PR_FILE + echo -e "\n" >> $PR_FILE + fi +done + + + + + + \ No newline at end of file diff --git a/docs-generation/setup_amazon_q.sh b/docs-generation/setup_amazon_q.sh new file mode 100755 index 0000000000..e5c2c923ee --- /dev/null +++ b/docs-generation/setup_amazon_q.sh @@ -0,0 +1,54 @@ +#!/bin/bash +set -e +# if git hash empty then set to latest auto +sudo apt-get update +sudo apt-get install -y curl wget unzip jq + +# Create AWS credentials from environment variables +mkdir -p ~/.aws +cat > ~/.aws/credentials << EOF +[default] +aws_access_key_id = ${AWS_ACCESS_KEY_ID} +aws_secret_access_key = ${AWS_SECRET_ACCESS_KEY} +aws_session_token = ${AWS_SESSION_TOKEN} +EOF +chmod 600 ~/.aws/credentials + +cat > ~/.aws/config << EOF 
+[default] +region = us-east-1 +EOF +chmod 600 ~/.aws/config + +# Assume role and capture temporary credentials --> needed for s3 bucket access for build +echo "Assuming AWS s3 role" +TEMP_CREDENTIALS=$(aws sts assume-role --role-arn ${CHAT_DOWNLOAD_ROLE_ARN} --role-session-name S3AccessSession 2>/dev/null || echo '{}') +QCHAT_ACCESSKEY=$(echo $TEMP_CREDENTIALS | jq -r '.Credentials.AccessKeyId') +Q_SECRET_ACCESS_KEY=$(echo $TEMP_CREDENTIALS | jq -r '.Credentials.SecretAccessKey') +Q_SESSION_TOKEN=$(echo $TEMP_CREDENTIALS | jq -r '.Credentials.SessionToken') + +# Download specific build from S3 based on commit hash +echo "Downloading Amazon Q CLI build from S3..." +S3_PREFIX="main/${GIT_HASH}/x86_64-unknown-linux-musl" +echo "Downloading qchat.zip from s3://.../${S3_PREFIX}/qchat.zip" + +# Try download, if hash is invalid we fail. +AWS_ACCESS_KEY_ID="$QCHAT_ACCESSKEY" AWS_SECRET_ACCESS_KEY="$Q_SECRET_ACCESS_KEY" AWS_SESSION_TOKEN="$Q_SESSION_TOKEN" \ + aws s3 cp s3://${CHAT_BUILD_BUCKET_NAME}/${S3_PREFIX}/qchat.zip ./qchat.zip --region us-east-1 + +# Handle the zip file, copy the qchat executable to /usr/local/bin + symlink from old code +echo "Extracting qchat.zip..." +unzip -q qchat.zip + +# move it to /usr/local/bin/qchat for path as qchat may not work otherwise +if cp qchat /usr/local/bin/ && chmod +x /usr/local/bin/qchat; then + ln -sf /usr/local/bin/qchat /usr/local/bin/q + echo "qchat installed successfully" +else + echo "ERROR: Failed to install qchat" + exit 1 +fi + +echo "Cleaning q zip" +rm -f qchat.zip +rm -rf qchat diff --git a/docs-generation/update-docs.sh b/docs-generation/update-docs.sh new file mode 100755 index 0000000000..3ffe551574 --- /dev/null +++ b/docs-generation/update-docs.sh @@ -0,0 +1,17 @@ +#!/bin/bash +set -e + +if [ ! -f "$PR_FILE" ]; then + echo "PR file not found, aborting" + exit 1 +fi + +PROMPT="Before making any changes, read the 'docs' directory for the project's current +documentation. 
Then read 'pr-contents.txt' to see the contents of the current PR.\n\n +After reading both the directory and the PR file, update the files in the 'docs' directory +with new documentation reflecting the proposed changes in the PR. Make new files as appropriate." + +cat pr-contents.txt +echo "Would prompt q chat here" +# timeout 10m echo -e $PROMPT | qchat chat --non-interactive --trust-all-tools +exit $? \ No newline at end of file diff --git a/pr-contents.txt b/pr-contents.txt new file mode 100644 index 0000000000..8aef575cf5 --- /dev/null +++ b/pr-contents.txt @@ -0,0 +1,10442 @@ +====== PR Information ======\n +Title: Add to-do list functionality to QCLI +Description: Adds a to-do list tool (called todo_list) with several commands that allow Q to create a to-do list and update it as it completes tasks, along with a slash command (/todos) that allows users to view and manage their in-progress to-do lists. + + +By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. + + +====== Updated files ====== + +---- build-config/buildspec-linux.yml ---- +version: 0.2 + +env: + shell: bash + +phases: + install: + run-as: root + commands: + - dnf update -y + - dnf install -y python cmake bash zsh unzip git jq + - dnf swap -y gnupg2-minimal gnupg2-full + pre_build: + commands: + - export HOME=/home/codebuild-user + - export PATH="$HOME/.local/bin:$PATH" + - mkdir -p "$HOME/.local/bin" + # Create fish config dir to prevent rustup from failing + - mkdir -p "$HOME/.config/fish/conf.d" + # Install cargo + - curl --retry 5 --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y + - . "$HOME/.cargo/env" + - rustup toolchain install `cat rust-toolchain.toml | grep channel | cut -d '=' -f2 | tr -d ' "'` + # Install cross only if the musl env var is set and not null + - if [ ! 
-z "${AMAZON_Q_BUILD_MUSL:+x}" ]; then cargo install cross --git https://github.com/cross-rs/cross; fi + # Install python/node via mise (https://mise.jdx.dev/continuous-integration.html) + - curl --retry 5 --proto '=https' --tlsv1.2 -sSf https://mise.run | sh + - mise install + - eval "$(mise activate bash --shims)" + # Install python deps + - pip3 install -r scripts/requirements.txt + build: + commands: + - python3.11 scripts/main.py build + +artifacts: + discard-paths: "yes" + base-directory: "build" + files: + - ./*.tar.gz + - ./*.zip + # Hashes + - ./*.sha256 + # Signatures + - ./*.asc + - ./*.sig + + + +---- build-config/buildspec-macos.yml ---- +version: 0.2 + +phases: + pre_build: + commands: + - whoami + - echo "$HOME" + - echo "$SHELL" + - pwd + - ls + - mkdir -p "$HOME/.local/bin" + - export PATH="$HOME/.local/bin:$PATH" + # Create fish config dir to prevent rustup from failing + - mkdir -p "$HOME/.config/fish/conf.d" + # Install cargo + - export CARGO_HOME="$HOME/.cargo" + - curl --retry 5 --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y + - . "$HOME/.cargo/env" + - rustup toolchain install `cat rust-toolchain.toml | grep channel | cut -d '=' -f2 | tr -d ' "'` + # Install cross only if the musl env var is set and not null + - if [ ! 
-z "${AMAZON_Q_BUILD_MUSL:+x}" ]; then cargo install cross --git https://github.com/cross-rs/cross; fi + # Install python/node via mise (https://mise.jdx.dev/continuous-integration.html) + - curl --retry 5 --proto '=https' --tlsv1.2 -sSf https://mise.run | sh + - mise install + - eval "$(mise activate zsh --shims)" + # Install python deps + - python3 -m venv scripts/.env + - source scripts/.env/bin/activate + - pip3 install -r scripts/requirements.txt + build: + commands: + - python3 scripts/main.py build + +artifacts: + discard-paths: "yes" + base-directory: "build" + files: + - ./*.zip + # Hashes + - ./*.sha256 + + + +---- crates/chat-cli/src/cli/agent/mod.rs ---- +pub mod hook; +mod legacy; +mod mcp_config; +mod root_command_args; +mod wrapper_types; + +use std::borrow::Borrow; +use std::collections::{ + HashMap, + HashSet, +}; +use std::ffi::OsStr; +use std::io::{ + self, + Write, +}; +use std::path::{ + Path, + PathBuf, +}; + +use crossterm::style::{ + Color, + Stylize as _, +}; +use crossterm::{ + execute, + queue, + style, +}; +use eyre::bail; +pub use mcp_config::McpServerConfig; +pub use root_command_args::*; +use schemars::{ + JsonSchema, + schema_for, +}; +use serde::{ + Deserialize, + Serialize, +}; +use thiserror::Error; +use tokio::fs::ReadDir; +use tracing::{ + error, + info, + warn, +}; +use wrapper_types::ResourcePath; +pub use wrapper_types::{ + OriginalToolName, + ToolSettingTarget, + alias_schema, + tool_settings_schema, +}; + +use super::chat::tools::{ + DEFAULT_APPROVE, + NATIVE_TOOLS, + ToolOrigin, +}; +use crate::cli::agent::hook::{ + Hook, + HookTrigger, +}; +use crate::database::settings::Setting; +use crate::os::Os; +use crate::util::{ + self, + MCP_SERVER_TOOL_DELIMITER, + directories, +}; + +pub const DEFAULT_AGENT_NAME: &str = "q_cli_default"; + +#[derive(Debug, Error)] +pub enum AgentConfigError { + #[error("Json supplied at {} is invalid: {}", path.display(), error)] + InvalidJson { error: serde_json::Error, path: PathBuf }, + 
#[error( + "Agent config is malformed at {}: {}", error.instance_path, error + )] + SchemaMismatch { + #[from] + error: Box>, + }, + #[error("Encountered directory error: {0}")] + Directories(#[from] util::directories::DirectoryError), + #[error("Encountered io error: {0}")] + Io(#[from] std::io::Error), + #[error("Failed to parse legacy mcp config: {0}")] + BadLegacyMcpConfig(#[from] eyre::Report), +} + +/// An [Agent] is a declarative way of configuring a given instance of q chat. Currently, it is +/// impacting q chat in via influenicng [ContextManager] and [ToolManager]. +/// Changes made to [ContextManager] and [ToolManager] do not persist across sessions. +/// +/// To increase the usability of the agent config, (both from the perspective of CLI and the users +/// who would need to write these config), the agent config has two states of existence: "cold" and +/// "warm". +/// +/// A "cold" state describes the config as it is written. And a "warm" state is an alternate form +/// of the same config, modified for the convenience of the business logic that relies on it in the +/// application. +/// +/// For example, the "cold" state does not require the field of "path" to be populated. This is +/// because it would be redundant and tedious for user to have to write the path of the file they +/// had created in said file. This field is thus populated during its parsing. +/// +/// Another example is the mcp config. To support backwards compatibility of users existing global +/// mcp.json, we allow users to supply a flag to denote whether they would want to include servers +/// from the legacy global mcp.json. If this flag exists, we would need to read the legacy mcp +/// config and merge it with what is in the agent mcp servers field. Conversely, when we write this +/// config to file, we would want to filter out the servers that belong only in the mcp.json. 
+/// +/// Where agents are instantiated from their config, we would need to convert them from "cold" to +/// "warm". +#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, JsonSchema)] +#[serde(rename_all = "camelCase", deny_unknown_fields)] +#[schemars(description = "An Agent is a declarative way of configuring a given instance of q chat.")] +pub struct Agent { + #[serde(rename = "$schema", default = "default_schema")] + pub schema: String, + /// Name of the agent + pub name: String, + /// This field is not model facing and is mostly here for users to discern between agents + #[serde(default)] + pub description: Option, + /// The intention for this field is to provide high level context to the + /// agent. This should be seen as the same category of context as a system prompt. + #[serde(default)] + pub prompt: Option, + /// Configuration for Model Context Protocol (MCP) servers + #[serde(default)] + pub mcp_servers: McpServerConfig, + /// List of tools the agent can see. Use \"@{MCP_SERVER_NAME}/tool_name\" to specify tools from + /// mcp servers. To include all tools from a server, use \"@{MCP_SERVER_NAME}\" + #[serde(default)] + pub tools: Vec, + /// Tool aliases for remapping tool names + #[serde(default)] + #[schemars(schema_with = "alias_schema")] + pub tool_aliases: HashMap, + /// List of tools the agent is explicitly allowed to use + #[serde(default)] + pub allowed_tools: HashSet, + /// Files to include in the agent's context + #[serde(default)] + pub resources: Vec, + /// Commands to run when a chat session is created + #[serde(default)] + pub hooks: HashMap>, + /// Settings for specific tools. These are mostly for native tools. 
The actual schema differs by + /// tools and is documented in detail in our documentation + #[serde(default)] + #[schemars(schema_with = "tool_settings_schema")] + pub tools_settings: HashMap, + /// Whether or not to include the legacy ~/.aws/amazonq/mcp.json in the agent + /// You can reference tools brought in by these servers as just as you would with the servers + /// you configure in the mcpServers field in this config + #[serde(default)] + pub use_legacy_mcp_json: bool, + #[serde(skip)] + pub path: Option, +} + +impl Default for Agent { + fn default() -> Self { + Self { + schema: default_schema(), + name: DEFAULT_AGENT_NAME.to_string(), + description: Some("Default agent".to_string()), + prompt: Default::default(), + mcp_servers: Default::default(), + tools: vec!["*".to_string()], + tool_aliases: Default::default(), + allowed_tools: { + let mut set = HashSet::::new(); + let default_approve = DEFAULT_APPROVE.iter().copied().map(str::to_string); + set.extend(default_approve); + set + }, + resources: vec!["file://AmazonQ.md", "file://README.md", "file://.amazonq/rules/**/*.md"] + .into_iter() + .map(Into::into) + .collect::>(), + hooks: Default::default(), + tools_settings: Default::default(), + use_legacy_mcp_json: true, + path: None, + } + } +} + +impl Agent { + /// This function mutates the agent to a state that is writable. + /// Practically this means reverting some fields back to their original values as they were + /// written in the config. + fn freeze(&mut self) { + let Self { mcp_servers, .. } = self; + + mcp_servers + .mcp_servers + .retain(|_name, config| !config.is_from_legacy_mcp_json); + } + + /// This function mutates the agent to a state that is usable for runtime. + /// Practically this means to convert some of the fields value to their usable counterpart. + /// For example, converting the mcp array to actual mcp config and populate the agent file path. 
+ fn thaw(&mut self, path: &Path, legacy_mcp_config: Option<&McpServerConfig>) -> Result<(), AgentConfigError> { + let Self { mcp_servers, .. } = self; + + self.path = Some(path.to_path_buf()); + + let mut stderr = std::io::stderr(); + if let (true, Some(legacy_mcp_config)) = (self.use_legacy_mcp_json, legacy_mcp_config) { + for (name, legacy_server) in &legacy_mcp_config.mcp_servers { + if mcp_servers.mcp_servers.contains_key(name) { + let _ = queue!( + stderr, + style::SetForegroundColor(Color::Yellow), + style::Print("WARNING: "), + style::ResetColor, + style::Print("MCP server '"), + style::SetForegroundColor(Color::Green), + style::Print(name), + style::ResetColor, + style::Print( + "' is already configured in agent config. Skipping duplicate from legacy mcp.json.\n" + ) + ); + continue; + } + let mut server_clone = legacy_server.clone(); + server_clone.is_from_legacy_mcp_json = true; + mcp_servers.mcp_servers.insert(name.clone(), server_clone); + } + } + + stderr.flush()?; + + Ok(()) + } + + pub fn print_overridden_permissions(&self, output: &mut impl Write) -> Result<(), AgentConfigError> { + let execute_name = if cfg!(windows) { "execute_cmd" } else { "execute_bash" }; + for allowed_tool in &self.allowed_tools { + if let Some(settings) = self.tools_settings.get(allowed_tool.as_str()) { + // currently we only have four native tools that offers tool settings + let overridden_settings_key = match allowed_tool.as_str() { + "fs_read" | "fs_write" => Some("allowedPaths"), + "use_aws" => Some("allowedServices"), + name if name == execute_name => Some("allowedCommands"), + _ => None, + }; + + if let Some(key) = overridden_settings_key { + if let Some(ref override_settings) = settings.get(key).map(|value| format!("{key}: {value}")) { + queue_permission_override_warning(allowed_tool.as_str(), override_settings, output)?; + } + } + } + } + + Ok(()) + } + + pub fn to_str_pretty(&self) -> eyre::Result { + let mut agent_clone = self.clone(); + agent_clone.freeze(); + 
Ok(serde_json::to_string_pretty(&agent_clone)?) + } + + /// Retrieves an agent by name. It does so via first seeking the given agent under local dir, + /// and falling back to global dir if it does not exist in local. + pub async fn get_agent_by_name(os: &Os, agent_name: &str) -> eyre::Result<(Agent, PathBuf)> { + let config_path: Result = 'config: { + // local first, and then fall back to looking at global + let local_config_dir = directories::chat_local_agent_dir(os)?.join(format!("{agent_name}.json")); + if os.fs.exists(&local_config_dir) { + break 'config Ok(local_config_dir); + } + + let global_config_dir = directories::chat_global_agent_path(os)?.join(format!("{agent_name}.json")); + if os.fs.exists(&global_config_dir) { + break 'config Ok(global_config_dir); + } + + Err(global_config_dir) + }; + + match config_path { + Ok(config_path) => { + let content = os.fs.read(&config_path).await?; + let mut agent = serde_json::from_slice::(&content)?; + let legacy_mcp_config = if agent.use_legacy_mcp_json { + load_legacy_mcp_config(os).await.unwrap_or(None) + } else { + None + }; + + agent.thaw(&config_path, legacy_mcp_config.as_ref())?; + Ok((agent, config_path)) + }, + _ => bail!("Agent {agent_name} does not exist"), + } + } + + pub async fn load( + os: &Os, + agent_path: impl AsRef, + legacy_mcp_config: &mut Option, + mcp_enabled: bool, + ) -> Result { + let content = os.fs.read(&agent_path).await?; + let mut agent = serde_json::from_slice::(&content).map_err(|e| AgentConfigError::InvalidJson { + error: e, + path: agent_path.as_ref().to_path_buf(), + })?; + + if mcp_enabled { + if agent.use_legacy_mcp_json && legacy_mcp_config.is_none() { + let config = load_legacy_mcp_config(os).await.unwrap_or_default(); + if let Some(config) = config { + legacy_mcp_config.replace(config); + } + } + agent.thaw(agent_path.as_ref(), legacy_mcp_config.as_ref())?; + } else { + agent.clear_mcp_configs(); + // Thaw the agent with empty MCP config to finalize normalization. 
+ agent.thaw(agent_path.as_ref(), None)?; + } + Ok(agent) + } + + /// Clear all MCP configurations while preserving built-in tools + pub fn clear_mcp_configs(&mut self) { + self.mcp_servers = McpServerConfig::default(); + self.use_legacy_mcp_json = false; + + // Transform tools: "*" → "@builtin", remove MCP refs + self.tools = self + .tools + .iter() + .filter_map(|tool| match tool.as_str() { + "*" => Some("@builtin".to_string()), + t if !is_mcp_tool_ref(t) => Some(t.to_string()), + _ => None, + }) + .collect(); + + // Remove MCP references from other fields + self.allowed_tools.retain(|tool| !is_mcp_tool_ref(tool)); + self.tool_aliases.retain(|orig, _| !is_mcp_tool_ref(&orig.to_string())); + self.tools_settings + .retain(|target, _| !is_mcp_tool_ref(&target.to_string())); + } +} + +/// Result of evaluating tool permissions, indicating whether a tool should be allowed, +/// require user confirmation, or be denied with specific reasons. +#[derive(Debug, PartialEq)] +pub enum PermissionEvalResult { + /// Tool is allowed to execute without user confirmation + Allow, + /// Tool requires user confirmation before execution + Ask, + /// Denial with specific reasons explaining why the tool was denied + /// Tools are free to overload what these reasons are + Deny(Vec), +} + +#[derive(Clone, Default, Debug)] +pub struct Agents { + /// Mapping from agent name to an [Agent]. + pub agents: HashMap, + /// Agent name. 
+ pub active_idx: String, + pub trust_all_tools: bool, +} + +impl Agents { + /// This function assumes the relevant transformation to the tool names have been done: + /// - model tool name -> host tool name + /// - custom tool namespacing + pub fn trust_tools(&mut self, tool_names: Vec) { + if let Some(agent) = self.get_active_mut() { + agent.allowed_tools.extend(tool_names); + } + } + + /// This function assumes the relevant transformation to the tool names have been done: + /// - model tool name -> host tool name + /// - custom tool namespacing + pub fn untrust_tools(&mut self, tool_names: &[String]) { + if let Some(agent) = self.get_active_mut() { + agent.allowed_tools.retain(|t| !tool_names.contains(t)); + } + } + + pub fn get_active(&self) -> Option<&Agent> { + self.agents.get(&self.active_idx) + } + + pub fn get_active_mut(&mut self) -> Option<&mut Agent> { + self.agents.get_mut(&self.active_idx) + } + + pub fn switch(&mut self, name: &str) -> eyre::Result<&Agent> { + if !self.agents.contains_key(name) { + eyre::bail!("No agent with name {name} found"); + } + self.active_idx = name.to_string(); + self.agents + .get(name) + .ok_or(eyre::eyre!("No agent with name {name} found")) + } + + /// This function does a number of things in the following order: + /// 1. Migrates old profiles if applicable + /// 2. Loads local agents + /// 3. Loads global agents + /// 4. Resolve agent conflicts and merge the two sets of agents + /// 5. 
Validates the active agent config and surfaces error to output accordingly + /// + /// # Arguments + /// * `os` - Operating system interface for file system operations and database access + /// * `agent_name` - Optional specific agent name to activate; if None, falls back to default + /// agent selection + /// * `skip_migration` - If true, skips migration of old profiles to new format + /// * `output` - Writer for outputting warnings, errors, and status messages during loading + pub async fn load( + os: &mut Os, + agent_name: Option<&str>, + skip_migration: bool, + output: &mut impl Write, + mcp_enabled: bool, + ) -> (Self, AgentsLoadMetadata) { + if !mcp_enabled { + let _ = execute!( + output, + style::SetForegroundColor(Color::Yellow), + style::Print("\n"), + style::Print("⚠️ WARNING: "), + style::SetForegroundColor(Color::Reset), + style::Print("MCP functionality has been disabled by your administrator.\n\n"), + ); + } + + // Tracking metadata about the performed load operation. + let mut load_metadata = AgentsLoadMetadata::default(); + + let new_agents = if !skip_migration { + match legacy::migrate(os, false).await { + Ok(Some(new_agents)) => { + let migrated_count = new_agents.len(); + info!(migrated_count, "Profile migration successful"); + load_metadata.migration_performed = true; + load_metadata.migrated_count = migrated_count as u32; + new_agents + }, + Ok(None) => { + info!("Migration was not performed"); + vec![] + }, + Err(e) => { + error!("Migration did not happen for the following reason: {e}"); + vec![] + }, + } + } else { + vec![] + }; + + let mut global_mcp_config = None::; + + let mut local_agents = 'local: { + // We could be launching from the home dir, in which case the global and local agents + // are the same set of agents. If that is the case, we simply skip this. 
+ match (std::env::current_dir(), directories::home_dir(os)) { + (Ok(cwd), Ok(home_dir)) if cwd == home_dir => break 'local Vec::::new(), + _ => { + // noop, we keep going with the extraction of local agents (even if we have an + // error retrieving cwd or home_dir) + }, + } + + let Ok(path) = directories::chat_local_agent_dir(os) else { + break 'local Vec::::new(); + }; + let Ok(files) = os.fs.read_dir(path).await else { + break 'local Vec::::new(); + }; + + let mut agents = Vec::::new(); + let results = load_agents_from_entries(files, os, &mut global_mcp_config, mcp_enabled).await; + for result in results { + match result { + Ok(agent) => agents.push(agent), + Err(e) => { + load_metadata.load_failed_count += 1; + let _ = queue!( + output, + style::SetForegroundColor(Color::Red), + style::Print("Error: "), + style::ResetColor, + style::Print(e), + style::Print("\n"), + ); + }, + } + } + + agents + }; + + let mut global_agents = 'global: { + let Ok(path) = directories::chat_global_agent_path(os) else { + break 'global Vec::::new(); + }; + let files = match os.fs.read_dir(&path).await { + Ok(files) => files, + Err(e) => { + if matches!(e.kind(), io::ErrorKind::NotFound) { + if let Err(e) = os.fs.create_dir_all(&path).await { + error!("Error creating global agent dir: {:?}", e); + } + } + break 'global Vec::::new(); + }, + }; + + let mut agents = Vec::::new(); + let results = load_agents_from_entries(files, os, &mut global_mcp_config, mcp_enabled).await; + for result in results { + match result { + Ok(agent) => agents.push(agent), + Err(e) => { + load_metadata.load_failed_count += 1; + let _ = queue!( + output, + style::SetForegroundColor(Color::Red), + style::Print("Error: "), + style::ResetColor, + style::Print(e), + style::Print("\n"), + ); + }, + } + } + + agents + } + .into_iter() + .chain(new_agents) + .collect::>(); + + // Here we also want to make sure the example config is written to disk if it's not already + // there. 
+ // Note that this config is not what q chat uses. It merely serves as an example. + 'example_config: { + let Ok(path) = directories::example_agent_config(os) else { + error!("Error obtaining example agent path."); + break 'example_config; + }; + if os.fs.exists(&path) { + break 'example_config; + } + + // At this point the agents dir would have been created. All we have to worry about is + // the creation of the example config + if let Err(e) = os.fs.create_new(&path).await { + error!("Error creating example agent config: {e}."); + break 'example_config; + } + + let example_agent = Agent { + // This is less important than other fields since names are derived from the name + // of the config file and thus will not be persisted + name: "example".to_string(), + description: Some("This is an example agent config (and will not be loaded unless you change it to have .json extension)".to_string()), + tools: { + NATIVE_TOOLS + .iter() + .copied() + .map(str::to_string) + .chain(vec![ + format!("@mcp_server_name{MCP_SERVER_TOOL_DELIMITER}mcp_tool_name"), + "@mcp_server_name_without_tool_specification_to_include_all_tools".to_string(), + ]) + .collect::>() + }, + ..Default::default() + }; + let Ok(content) = example_agent.to_str_pretty() else { + error!("Error serializing example agent config"); + break 'example_config; + }; + if let Err(e) = os.fs.write(&path, &content).await { + error!("Error writing example agent config to file: {e}"); + break 'example_config; + }; + } + + let local_names = local_agents.iter().map(|a| a.name.as_str()).collect::>(); + global_agents.retain(|a| { + // If there is a naming conflict for agents, we would retain the local instance + let name = a.name.as_str(); + if local_names.contains(name) { + let _ = queue!( + output, + style::SetForegroundColor(style::Color::Yellow), + style::Print("WARNING: "), + style::ResetColor, + style::Print("Agent conflict for "), + style::SetForegroundColor(style::Color::Green), + style::Print(name), + 
style::ResetColor, + style::Print(". Using workspace version.\n") + ); + false + } else { + true + } + }); + + local_agents.append(&mut global_agents); + let mut all_agents = local_agents; + + // Assume agent in the following order of priority: + // 1. The agent name specified by the start command via --agent (this is the agent_name that's + // passed in) + // 2. If the above is missing or invalid, assume one that is specified by chat.defaultAgent + // 3. If the above is missing or invalid, assume the in-memory default + let active_idx = 'active_idx: { + if let Some(name) = agent_name { + if all_agents.iter().any(|a| a.name.as_str() == name) { + break 'active_idx name.to_string(); + } + let _ = queue!( + output, + style::SetForegroundColor(Color::Red), + style::Print("Error"), + style::SetForegroundColor(Color::Yellow), + style::Print(format!( + ": no agent with name {} found. Falling back to user specified default", + name + )), + style::Print("\n"), + style::SetForegroundColor(Color::Reset) + ); + } + + if let Some(user_set_default) = os.database.settings.get_string(Setting::ChatDefaultAgent) { + if all_agents.iter().any(|a| a.name == user_set_default) { + break 'active_idx user_set_default; + } + let _ = queue!( + output, + style::SetForegroundColor(Color::Red), + style::Print("Error"), + style::SetForegroundColor(Color::Yellow), + style::Print(format!( + ": user defined default {} not found. Falling back to in-memory default", + user_set_default + )), + style::Print("\n"), + style::SetForegroundColor(Color::Reset) + ); + } + + all_agents.push({ + let mut agent = Agent::default(); + if mcp_enabled { + 'load_legacy_mcp_json: { + if global_mcp_config.is_none() { + let Ok(global_mcp_path) = directories::chat_legacy_global_mcp_config(os) else { + tracing::error!("Error obtaining legacy mcp json path. 
Skipping"); + break 'load_legacy_mcp_json; + }; + let legacy_mcp_config = match McpServerConfig::load_from_file(os, global_mcp_path).await { + Ok(config) => config, + Err(e) => { + tracing::error!("Error loading global mcp json path: {e}. Skipping"); + break 'load_legacy_mcp_json; + }, + }; + global_mcp_config.replace(legacy_mcp_config); + } + } + + if let Some(config) = &global_mcp_config { + agent.mcp_servers = config.clone(); + } + } else { + agent.mcp_servers = McpServerConfig::default(); + } + agent + }); + + DEFAULT_AGENT_NAME.to_string() + }; + + let _ = output.flush(); + + // Post parsing validation here + let schema = schema_for!(Agent); + let agents = all_agents + .into_iter() + .map(|a| (a.name.clone(), a)) + .collect::>(); + let active_agent = agents.get(&active_idx); + + 'validate: { + match (serde_json::to_value(schema), active_agent) { + (Ok(schema), Some(agent)) => { + let Ok(instance) = serde_json::to_value(agent) else { + let name = &agent.name; + error!("Error converting active agent {name} to value for validation. Skipping"); + break 'validate; + }; + if let Err(e) = jsonschema::validate(&schema, &instance).map_err(|e| e.to_owned()) { + let name = &agent.name; + let _ = execute!( + output, + style::SetForegroundColor(Color::Yellow), + style::Print("WARNING "), + style::ResetColor, + style::Print("Agent config "), + style::SetForegroundColor(Color::Green), + style::Print(name), + style::ResetColor, + style::Print(" is malformed at "), + style::SetForegroundColor(Color::Yellow), + style::Print(&e.instance_path), + style::ResetColor, + style::Print(format!(": {e}\n")), + ); + } + }, + (Err(e), _) => { + error!("Failed to convert agent definition to schema: {e}. 
Skipping validation"); + }, + (_, None) => { + warn!("Skipping config validation because there is no active agent"); + }, + } + } + + load_metadata.launched_agent = active_idx.clone(); + ( + Self { + agents, + active_idx, + ..Default::default() + }, + load_metadata, + ) + } + + /// Returns a label to describe the permission status for a given tool. + pub fn display_label(&self, tool_name: &str, origin: &ToolOrigin) -> String { + use crate::util::pattern_matching::matches_any_pattern; + + let tool_trusted = self.get_active().is_some_and(|a| { + if matches!(origin, &ToolOrigin::Native) { + return matches_any_pattern(&a.allowed_tools, tool_name); + } + + a.allowed_tools.iter().any(|name| { + name.strip_prefix("@").is_some_and(|remainder| { + remainder + .split_once(MCP_SERVER_TOOL_DELIMITER) + .is_some_and(|(_left, right)| right == tool_name) + || remainder == >::borrow(origin) + }) || { + if let Some(server_name) = name.strip_prefix("@").and_then(|s| s.split('/').next()) { + if server_name == >::borrow(origin) { + let tool_pattern = format!("@{}/{}", server_name, tool_name); + matches_any_pattern(&a.allowed_tools, &tool_pattern) + } else { + false + } + } else { + false + } + } + }) + }); + + if tool_trusted || self.trust_all_tools { + format!("* {}", "trusted".dark_green().bold()) + } else { + self.default_permission_label(tool_name) + } + } + + /// Provide default permission labels for the built-in set of tools. + // This "static" way avoids needing to construct a tool instance. 
+ fn default_permission_label(&self, tool_name: &str) -> String { + let label = match tool_name { + "fs_read" => "trusted".dark_green().bold(), + "fs_write" => "not trusted".dark_grey(), + #[cfg(not(windows))] + "execute_bash" => "trust read-only commands".dark_grey(), + #[cfg(windows)] + "execute_cmd" => "trust read-only commands".dark_grey(), + "use_aws" => "trust read-only commands".dark_grey(), + "report_issue" => "trusted".dark_green().bold(), + "thinking" => "trusted (prerelease)".dark_green().bold(), + _ if self.trust_all_tools => "trusted".dark_grey().bold(), + _ => "not trusted".dark_grey(), + }; + + format!("{} {label}", "*".reset()) + } +} + +/// Metadata from the executed [Agents::load] operation. +#[derive(Debug, Clone, Default)] +pub struct AgentsLoadMetadata { + pub migration_performed: bool, + pub migrated_count: u32, + pub load_count: u32, + pub load_failed_count: u32, + pub launched_agent: String, +} + +async fn load_agents_from_entries( + mut files: ReadDir, + os: &Os, + global_mcp_config: &mut Option, + mcp_enabled: bool, +) -> Vec> { + let mut res = Vec::>::new(); + + while let Ok(Some(file)) = files.next_entry().await { + let file_path = &file.path(); + if file_path + .extension() + .and_then(OsStr::to_str) + .is_some_and(|s| s == "json") + { + res.push(Agent::load(os, file_path, global_mcp_config, mcp_enabled).await); + } + } + + res +} + +/// Loads legacy mcp config by combining workspace and global config. +/// In case of a server naming conflict, the workspace config is prioritized. 
+async fn load_legacy_mcp_config(os: &Os) -> eyre::Result> { + let global_mcp_path = directories::chat_legacy_global_mcp_config(os)?; + let global_mcp_config = match McpServerConfig::load_from_file(os, global_mcp_path).await { + Ok(config) => Some(config), + Err(e) => { + tracing::error!("Error loading global mcp json path: {e}."); + None + }, + }; + + let workspace_mcp_path = directories::chat_legacy_workspace_mcp_config(os)?; + let workspace_mcp_config = match McpServerConfig::load_from_file(os, workspace_mcp_path).await { + Ok(config) => Some(config), + Err(e) => { + tracing::error!("Error loading global mcp json path: {e}."); + None + }, + }; + + Ok(match (workspace_mcp_config, global_mcp_config) { + (Some(mut wc), Some(gc)) => { + for (server_name, config) in gc.mcp_servers { + // We prioritize what is in the workspace + wc.mcp_servers.entry(server_name).or_insert(config); + } + + Some(wc) + }, + (None, Some(gc)) => Some(gc), + (Some(wc), None) => Some(wc), + _ => None, + }) +} + +pub fn queue_permission_override_warning( + tool_name: &str, + overridden_settings: &str, + output: &mut impl Write, +) -> Result<(), std::io::Error> { + Ok(queue!( + output, + style::SetForegroundColor(Color::Yellow), + style::Print("WARNING: "), + style::ResetColor, + style::Print("You have trusted "), + style::SetForegroundColor(Color::Green), + style::Print(tool_name), + style::ResetColor, + style::Print(" tool, which overrides the toolsSettings: "), + style::SetForegroundColor(Color::Cyan), + style::Print(overridden_settings), + style::ResetColor, + style::Print("\n"), + )?) 
+} + +fn default_schema() -> String { + "https://raw.githubusercontent.com/aws/amazon-q-developer-cli/refs/heads/main/schemas/agent-v1.json".into() +} + +// Check if a tool reference is MCP-specific (not @builtin and starts with @) +pub fn is_mcp_tool_ref(s: &str) -> bool { + // @builtin is not MCP, it's a reference to all built-in tools + // Any other @ prefix is MCP (e.g., "@git", "@git/git_status") + !s.starts_with("@builtin") && s.starts_with('@') +} + +#[cfg(test)] +fn validate_agent_name(name: &str) -> eyre::Result<()> { + // Check if name is empty + if name.is_empty() { + eyre::bail!("Agent name cannot be empty"); + } + + // Check if name contains only allowed characters and starts with an alphanumeric character + let re = regex::Regex::new(r"^[a-zA-Z0-9][a-zA-Z0-9_-]*$")?; + if !re.is_match(name) { + eyre::bail!( + "Agent name must start with an alphanumeric character and can only contain alphanumeric characters, hyphens, and underscores" + ); + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use serde_json::json; + + use super::*; + const INPUT: &str = r#" + { + "name": "some_agent", + "description": "My developer agent is used for small development tasks like solving open issues.", + "prompt": "You are a principal developer who uses multiple agents to accomplish difficult engineering tasks", + "mcpServers": { + "fetch": { "command": "fetch3.1", "args": [] }, + "git": { "command": "git-mcp", "args": [] } + }, + "tools": [ + "@git" + ], + "toolAliases": { + "@gits/some_tool": "some_tool2" + }, + "allowedTools": [ + "fs_read", + "@fetch", + "@gits/git_status" + ], + "resources": [ + "file://~/my-genai-prompts/unittest.md" + ], + "toolsSettings": { + "fs_write": { "allowedPaths": ["~/**"] }, + "@git/git_status": { "git_user": "$GIT_USER" } + } + } + "#; + + #[test] + fn test_deser() { + let agent = serde_json::from_str::(INPUT).expect("Deserializtion failed"); + assert!(agent.mcp_servers.mcp_servers.contains_key("fetch")); + 
assert!(agent.mcp_servers.mcp_servers.contains_key("git")); + assert!(agent.tool_aliases.contains_key("@gits/some_tool")); + } + + #[test] + fn test_get_active() { + let mut collection = Agents::default(); + assert!(collection.get_active().is_none()); + + let agent = Agent::default(); + let agent_name = agent.name.clone(); + collection.agents.insert(agent_name.clone(), agent); + collection.active_idx = agent_name.clone(); + + assert!(collection.get_active().is_some()); + assert_eq!(collection.get_active().unwrap().name, agent_name); + } + + #[test] + fn test_get_active_mut() { + let mut collection = Agents::default(); + assert!(collection.get_active_mut().is_none()); + + let agent = Agent::default(); + collection.agents.insert("default".to_string(), agent); + collection.active_idx = "default".to_string(); + + assert!(collection.get_active_mut().is_some()); + let active = collection.get_active_mut().unwrap(); + active.description = Some("Modified description".to_string()); + + assert_eq!( + collection.agents.get("default").unwrap().description, + Some("Modified description".to_string()) + ); + } + + #[test] + fn test_switch() { + let mut collection = Agents::default(); + + let default_agent = Agent::default(); + let dev_agent = Agent { + name: "dev".to_string(), + description: Some("Developer agent".to_string()), + ..Default::default() + }; + + collection.agents.insert("default".to_string(), default_agent); + collection.agents.insert("dev".to_string(), dev_agent); + collection.active_idx = "default".to_string(); + + // Test successful switch + let result = collection.switch("dev"); + assert!(result.is_ok()); + assert_eq!(result.unwrap().name, "dev"); + + // Test switch to non-existent agent + let result = collection.switch("nonexistent"); + assert!(result.is_err()); + assert_eq!(result.unwrap_err().to_string(), "No agent with name nonexistent found"); + } + + #[test] + fn test_validate_agent_name() { + // Valid names + assert!(validate_agent_name("valid").is_ok()); 
+ assert!(validate_agent_name("valid123").is_ok()); + assert!(validate_agent_name("valid-name").is_ok()); + assert!(validate_agent_name("valid_name").is_ok()); + assert!(validate_agent_name("123valid").is_ok()); + + // Invalid names + assert!(validate_agent_name("").is_err()); + assert!(validate_agent_name("-invalid").is_err()); + assert!(validate_agent_name("_invalid").is_err()); + assert!(validate_agent_name("invalid!").is_err()); + assert!(validate_agent_name("invalid space").is_err()); + } + + #[test] + fn test_clear_mcp_configs_with_builtin_variants() { + let mut agent: Agent = serde_json::from_value(json!({ + "name": "test", + "tools": [ + "@builtin", + "@builtin/fs_read", + "@builtin/execute_bash", + "@git", + "@git/status", + "fs_write" + ], + "allowedTools": [ + "@builtin/fs_read", + "@git/status", + "fs_write" + ], + "toolAliases": { + "@builtin/fs_read": "read", + "@git/status": "git_st" + }, + "toolsSettings": { + "@builtin/fs_write": { "allowedPaths": ["~/**"] }, + "@git/commit": { "sign": true } + } + })) + .unwrap(); + + agent.clear_mcp_configs(); + + // All @builtin variants should be preserved while MCP tools should be removed + assert!(agent.tools.contains(&"@builtin".to_string())); + assert!(agent.tools.contains(&"@builtin/fs_read".to_string())); + assert!(agent.tools.contains(&"@builtin/execute_bash".to_string())); + assert!(agent.tools.contains(&"fs_write".to_string())); + assert!(!agent.tools.contains(&"@git".to_string())); + assert!(!agent.tools.contains(&"@git/status".to_string())); + + assert!(agent.allowed_tools.contains("@builtin/fs_read")); + assert!(agent.allowed_tools.contains("fs_write")); + assert!(!agent.allowed_tools.contains("@git/status")); + + // Check tool aliases - need to iterate since we can't construct OriginalToolName directly + let has_builtin_alias = agent + .tool_aliases + .iter() + .any(|(k, v)| k.to_string() == "@builtin/fs_read" && v == "read"); + assert!(has_builtin_alias, "@builtin/fs_read alias should be 
preserved"); + + let has_git_alias = agent.tool_aliases.iter().any(|(k, _)| k.to_string() == "@git/status"); + assert!(!has_git_alias, "@git/status alias should be removed"); + + // Check tool settings - need to iterate since we can't construct ToolSettingTarget directly + let has_builtin_setting = agent + .tools_settings + .iter() + .any(|(k, _)| k.to_string() == "@builtin/fs_write"); + assert!(has_builtin_setting, "@builtin/fs_write settings should be preserved"); + + let has_git_setting = agent.tools_settings.iter().any(|(k, _)| k.to_string() == "@git/commit"); + assert!(!has_git_setting, "@git/commit settings should be removed"); + } + + #[test] + fn test_display_label_no_active_agent() { + let agents = Agents::default(); + + let label = agents.display_label("fs_read", &ToolOrigin::Native); + // With no active agent, it should fall back to default permissions + // fs_read has a default of "trusted" + assert!( + label.contains("trusted"), + "fs_read should show default trusted permission, instead found: {}", + label + ); + } + + #[test] + fn test_display_label_trust_all_tools() { + let agents = Agents { + trust_all_tools: true, + ..Default::default() + }; + + // Should be trusted even if not in allowed_tools + let label = agents.display_label("random_tool", &ToolOrigin::Native); + assert!( + label.contains("trusted"), + "trust_all_tools should make everything trusted, instead found: {}", + label + ); + } + + #[test] + fn test_display_label_default_permissions() { + let agents = Agents::default(); + + // Test default permissions for known tools + let fs_read_label = agents.display_label("fs_read", &ToolOrigin::Native); + assert!( + fs_read_label.contains("trusted"), + "fs_read should be trusted by default, instead found: {}", + fs_read_label + ); + + let fs_write_label = agents.display_label("fs_write", &ToolOrigin::Native); + assert!( + fs_write_label.contains("not trusted"), + "fs_write should not be trusted by default, instead found: {}", + fs_write_label + ); 
+ + let execute_name = if cfg!(windows) { "execute_cmd" } else { "execute_bash" }; + let execute_bash_label = agents.display_label(execute_name, &ToolOrigin::Native); + assert!( + execute_bash_label.contains("read-only"), + "execute_bash should show read-only by default, instead found: {}", + execute_bash_label + ); + } + + #[test] + fn test_display_label_comprehensive_patterns() { + let mut agents = Agents::default(); + + // Create agent with all types of patterns + let mut allowed_tools = HashSet::new(); + // Native exact match + allowed_tools.insert("fs_read".to_string()); + // Native wildcard + allowed_tools.insert("execute_*".to_string()); + // MCP server exact (allows all tools from that server) + allowed_tools.insert("@server1".to_string()); + // MCP tool exact + allowed_tools.insert("@server2/specific_tool".to_string()); + // MCP tool wildcard + allowed_tools.insert("@server3/tool_*".to_string()); + + let agent = Agent { + schema: "test".to_string(), + name: "test-agent".to_string(), + description: None, + prompt: None, + mcp_servers: Default::default(), + tools: Vec::new(), + tool_aliases: Default::default(), + allowed_tools, + tools_settings: Default::default(), + resources: Vec::new(), + hooks: Default::default(), + use_legacy_mcp_json: false, + path: None, + }; + + agents.agents.insert("test-agent".to_string(), agent); + agents.active_idx = "test-agent".to_string(); + + // Test 1: Native exact match + let label = agents.display_label("fs_read", &ToolOrigin::Native); + assert!( + label.contains("trusted"), + "fs_read should be trusted (exact match), instead found: {}", + label + ); + + // Test 2: Native wildcard match + let label = agents.display_label("execute_bash", &ToolOrigin::Native); + assert!( + label.contains("trusted"), + "execute_bash should match execute_* pattern, instead found: {}", + label + ); + + // Test 3: Native no match + let label = agents.display_label("fs_write", &ToolOrigin::Native); + assert!( + !label.contains("trusted") || 
label.contains("not trusted"), + "fs_write should not be trusted, instead found: {}", + label + ); + + // Test 4: MCP server exact match (allows any tool from server1) + let label = agents.display_label("any_tool", &ToolOrigin::McpServer("server1".to_string())); + assert!( + label.contains("trusted"), + "Server-level permission should allow any tool, instead found: {}", + label + ); + + // Test 5: MCP tool exact match + let label = agents.display_label("specific_tool", &ToolOrigin::McpServer("server2".to_string())); + assert!( + label.contains("trusted"), + "Exact MCP tool should be trusted, instead found: {}", + label + ); + + // Test 6: MCP tool wildcard match + let label = agents.display_label("tool_read", &ToolOrigin::McpServer("server3".to_string())); + assert!( + label.contains("trusted"), + "tool_read should match @server3/tool_* pattern, instead found: {}", + label + ); + + // Test 7: MCP tool no match + let label = agents.display_label("other_tool", &ToolOrigin::McpServer("server2".to_string())); + assert!( + !label.contains("trusted") || label.contains("not trusted"), + "Non-matching MCP tool should not be trusted, instead found: {}", + label + ); + + // Test 8: MCP server no match + let label = agents.display_label("some_tool", &ToolOrigin::McpServer("unknown_server".to_string())); + assert!( + !label.contains("trusted") || label.contains("not trusted"), + "Unknown server should not be trusted, instead found: {}", + label + ); + } +} + + +---- crates/chat-cli/src/cli/chat/cli/mod.rs ---- +pub mod clear; +pub mod compact; +pub mod context; +pub mod editor; +pub mod hooks; +pub mod knowledge; +pub mod mcp; +pub mod model; +pub mod persist; +pub mod profile; +pub mod prompts; +pub mod subscribe; +pub mod tangent; +pub mod tools; +pub mod usage; + +use clap::Parser; +use clear::ClearArgs; +use compact::CompactArgs; +use context::ContextSubcommand; +use editor::EditorArgs; +use hooks::HooksArgs; +use knowledge::KnowledgeSubcommand; +use mcp::McpArgs; +use 
model::ModelArgs; +use persist::PersistSubcommand; +use profile::AgentSubcommand; +use prompts::PromptsArgs; +use tangent::TangentArgs; +use tools::ToolsArgs; + +use crate::cli::chat::cli::subscribe::SubscribeArgs; +use crate::cli::chat::cli::usage::UsageArgs; +use crate::cli::chat::consts::AGENT_MIGRATION_DOC_URL; +use crate::cli::chat::{ + ChatError, + ChatSession, + ChatState, + EXTRA_HELP, +}; +use crate::cli::issue; +use crate::os::Os; + +/// q (Amazon Q Chat) +#[derive(Debug, PartialEq, Parser)] +#[command(color = clap::ColorChoice::Always, term_width = 0, after_long_help = EXTRA_HELP)] +pub enum SlashCommand { + /// Quit the application + #[command(aliases = ["q", "exit"])] + Quit, + /// Clear the conversation history + Clear(ClearArgs), + /// Manage agents + #[command(subcommand)] + Agent(AgentSubcommand), + #[command(hide = true)] + Profile, + /// Manage context files for the chat session + #[command(subcommand)] + Context(ContextSubcommand), + /// (Beta) Manage knowledge base for persistent context storage. 
Requires "q settings + /// chat.enableKnowledge true" + #[command(subcommand, hide = true)] + Knowledge(KnowledgeSubcommand), + /// Open $EDITOR (defaults to vi) to compose a prompt + #[command(name = "editor")] + PromptEditor(EditorArgs), + /// Summarize the conversation to free up context space + Compact(CompactArgs), + /// View tools and permissions + Tools(ToolsArgs), + /// Create a new Github issue or make a feature request + Issue(issue::IssueArgs), + /// View and retrieve prompts + Prompts(PromptsArgs), + /// View context hooks + Hooks(HooksArgs), + /// Show current session's context window usage + Usage(UsageArgs), + /// See mcp server loaded + Mcp(McpArgs), + /// Select a model for the current conversation session + Model(ModelArgs), + /// Upgrade to a Q Developer Pro subscription for increased query limits + Subscribe(SubscribeArgs), + /// Toggle tangent mode for isolated conversations + Tangent(TangentArgs), + #[command(flatten)] + Persist(PersistSubcommand), + // #[command(flatten)] + // Root(RootSubcommand), +} + +impl SlashCommand { + pub async fn execute(self, os: &mut Os, session: &mut ChatSession) -> Result { + match self { + Self::Quit => Ok(ChatState::Exit), + Self::Clear(args) => args.execute(session).await, + Self::Agent(subcommand) => subcommand.execute(os, session).await, + Self::Profile => { + use crossterm::{ + execute, + style, + }; + execute!( + session.stderr, + style::SetForegroundColor(style::Color::Yellow), + style::Print("This command has been deprecated. 
Use"), + style::SetForegroundColor(style::Color::Cyan), + style::Print(" /agent "), + style::SetForegroundColor(style::Color::Yellow), + style::Print("instead.\nSee "), + style::Print(AGENT_MIGRATION_DOC_URL), + style::Print(" for more detail"), + style::Print("\n"), + style::ResetColor, + )?; + + Ok(ChatState::PromptUser { + skip_printing_tools: true, + }) + }, + Self::Context(args) => args.execute(os, session).await, + Self::Knowledge(subcommand) => subcommand.execute(os, session).await, + Self::PromptEditor(args) => args.execute(session).await, + Self::Compact(args) => args.execute(os, session).await, + Self::Tools(args) => args.execute(session).await, + Self::Issue(args) => { + if let Err(err) = args.execute(os).await { + return Err(ChatError::Custom(err.to_string().into())); + } + + Ok(ChatState::PromptUser { + skip_printing_tools: true, + }) + }, + Self::Prompts(args) => args.execute(session).await, + Self::Hooks(args) => args.execute(session).await, + Self::Usage(args) => args.execute(os, session).await, + Self::Mcp(args) => args.execute(session).await, + Self::Model(args) => args.execute(os, session).await, + Self::Subscribe(args) => args.execute(os, session).await, + Self::Tangent(args) => args.execute(os, session).await, + Self::Persist(subcommand) => subcommand.execute(os, session).await, + // Self::Root(subcommand) => { + // if let Err(err) = subcommand.execute(os, database, telemetry).await { + // return Err(ChatError::Custom(err.to_string().into())); + // } + // + // Ok(ChatState::PromptUser { + // skip_printing_tools: true, + // }) + // }, + } + } + + pub fn command_name(&self) -> &'static str { + match self { + Self::Quit => "quit", + Self::Clear(_) => "clear", + Self::Agent(_) => "agent", + Self::Profile => "profile", + Self::Context(_) => "context", + Self::Knowledge(_) => "knowledge", + Self::PromptEditor(_) => "editor", + Self::Compact(_) => "compact", + Self::Tools(_) => "tools", + Self::Issue(_) => "issue", + Self::Prompts(_) => "prompts", + 
Self::Hooks(_) => "hooks", + Self::Usage(_) => "usage", + Self::Mcp(_) => "mcp", + Self::Model(_) => "model", + Self::Subscribe(_) => "subscribe", + Self::Tangent(_) => "tangent", + Self::Persist(sub) => match sub { + PersistSubcommand::Save { .. } => "save", + PersistSubcommand::Load { .. } => "load", + }, + } + } + + pub fn subcommand_name(&self) -> Option<&'static str> { + match self { + SlashCommand::Agent(sub) => Some(sub.name()), + SlashCommand::Context(sub) => Some(sub.name()), + SlashCommand::Knowledge(sub) => Some(sub.name()), + SlashCommand::Tools(arg) => arg.subcommand_name(), + SlashCommand::Prompts(arg) => arg.subcommand_name(), + _ => None, + } + } +} + + +---- crates/chat-cli/src/cli/chat/conversation.rs ---- +use std::collections::{ + HashMap, + HashSet, + VecDeque, +}; +use std::io::Write; +use std::sync::atomic::Ordering; + +use chrono::Utc; +use crossterm::style::Color; +use crossterm::{ + execute, + style, +}; +use serde::{ + Deserialize, + Serialize, +}; +use tracing::{ + debug, + warn, +}; + +use super::cli::compact::CompactStrategy; +use super::cli::model::context_window_tokens; +use super::consts::{ + DUMMY_TOOL_NAME, + MAX_CONVERSATION_STATE_HISTORY_LEN, +}; +use super::context::{ + ContextManager, + calc_max_context_files_size, +}; +use super::line_tracker::FileLineTracker; +use super::message::{ + AssistantMessage, + ToolUseResult, + UserMessage, +}; +use super::parser::RequestMetadata; +use super::token_counter::{ + CharCount, + CharCounter, + TokenCounter, +}; +use super::tool_manager::ToolManager; +use super::tools::{ + InputSchema, + QueuedTool, + ToolOrigin, + ToolSpec, +}; +use super::util::serde_value_to_document; +use crate::api_client::model::{ + ChatMessage, + ConversationState as FigConversationState, + ImageBlock, + Tool, + ToolInputSchema, + ToolSpecification, + UserInputMessage, +}; +use crate::cli::agent::Agents; +use crate::cli::agent::hook::{ + Hook, + HookTrigger, +}; +use crate::cli::chat::ChatError; +use 
crate::cli::chat::cli::model::{ + ModelInfo, + get_model_info, +}; +use crate::mcp_client::Prompt; +use crate::os::Os; + +pub const CONTEXT_ENTRY_START_HEADER: &str = "--- CONTEXT ENTRY BEGIN ---\n"; +pub const CONTEXT_ENTRY_END_HEADER: &str = "--- CONTEXT ENTRY END ---\n\n"; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct HistoryEntry { + user: UserMessage, + assistant: AssistantMessage, + #[serde(default)] + request_metadata: Option, +} + +/// Tracks state related to an ongoing conversation. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ConversationState { + /// Randomly generated on creation. + conversation_id: String, + /// The next user message to be sent as part of the conversation. Required to be [Some] before + /// calling [Self::as_sendable_conversation_state]. + next_message: Option, + history: VecDeque, + /// The range in the history sendable to the backend (start inclusive, end exclusive). + valid_history_range: (usize, usize), + /// Similar to history in that stores user and assistant responses, except that it is not used + /// in message requests. Instead, the responses are expected to be in human-readable format, + /// e.g user messages prefixed with '> '. Should also be used to store errors posted in the + /// chat. + pub transcript: VecDeque, + pub tools: HashMap>, + /// Context manager for handling sticky context files + pub context_manager: Option, + /// Tool manager for handling tool and mcp related activities + #[serde(skip)] + pub tool_manager: ToolManager, + /// Cached value representing the length of the user context message. + context_message_length: Option, + /// Stores the latest conversation summary created by /compact + latest_summary: Option<(String, RequestMetadata)>, + #[serde(skip)] + pub agents: Agents, + /// Unused, kept only to maintain deserialization backwards compatibility with <=v1.13.3 + /// Model explicitly selected by the user in this conversation state via `/model`. 
+ #[serde(default, skip_serializing_if = "Option::is_none")] + pub model: Option, + /// Model explicitly selected by the user in this conversation state via `/model`. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub model_info: Option, + /// Used to track agent vs user updates to file modifications. + /// + /// Maps from a file path to [FileLineTracker] + #[serde(default)] + pub file_line_tracker: HashMap, + #[serde(default = "default_true")] + pub mcp_enabled: bool, + /// Tangent mode checkpoint - stores main conversation when in tangent mode + #[serde(default, skip_serializing_if = "Option::is_none")] + tangent_state: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct ConversationCheckpoint { + /// Main conversation history stored while in tangent mode + main_history: VecDeque, + /// Main conversation next message + main_next_message: Option, + /// Main conversation transcript + main_transcript: VecDeque, +} + +impl ConversationState { + pub async fn new( + conversation_id: &str, + agents: Agents, + tool_config: HashMap, + tool_manager: ToolManager, + current_model_id: Option, + os: &Os, + mcp_enabled: bool, + ) -> Self { + let model = if let Some(model_id) = current_model_id { + match get_model_info(&model_id, os).await { + Ok(info) => Some(info), + Err(e) => { + tracing::warn!("Failed to get model info for {}: {}, using default", model_id, e); + Some(ModelInfo::from_id(model_id)) + }, + } + } else { + None + }; + + let context_manager = if let Some(agent) = agents.get_active() { + ContextManager::from_agent(agent, calc_max_context_files_size(model.as_ref())).ok() + } else { + None + }; + + Self { + conversation_id: conversation_id.to_string(), + next_message: None, + history: VecDeque::new(), + valid_history_range: Default::default(), + transcript: VecDeque::with_capacity(MAX_CONVERSATION_STATE_HISTORY_LEN), + tools: tool_config + .into_values() + .fold(HashMap::>::new(), |mut acc, v| { + let tool = 
Tool::ToolSpecification(ToolSpecification { + name: v.name, + description: v.description, + input_schema: v.input_schema.into(), + }); + acc.entry(v.tool_origin) + .and_modify(|tools| tools.push(tool.clone())) + .or_insert(vec![tool]); + acc + }), + context_manager, + tool_manager, + context_message_length: None, + latest_summary: None, + agents, + model: None, + model_info: model, + file_line_tracker: HashMap::new(), + mcp_enabled, + tangent_state: None, + } + } + + pub fn latest_summary(&self) -> Option<&str> { + self.latest_summary.as_ref().map(|(s, _)| s.as_str()) + } + + pub fn history(&self) -> &VecDeque { + &self.history + } + + /// Clears the conversation history and optionally the summary. + pub fn clear(&mut self, preserve_summary: bool) { + self.next_message = None; + self.history.clear(); + if !preserve_summary { + self.latest_summary = None; + } + } + + /// Check if currently in tangent mode + pub fn is_in_tangent_mode(&self) -> bool { + self.tangent_state.is_some() + } + + /// Create a checkpoint of current conversation state + fn create_checkpoint(&self) -> ConversationCheckpoint { + ConversationCheckpoint { + main_history: self.history.clone(), + main_next_message: self.next_message.clone(), + main_transcript: self.transcript.clone(), + } + } + + /// Restore conversation state from checkpoint + fn restore_from_checkpoint(&mut self, checkpoint: ConversationCheckpoint) { + self.history = checkpoint.main_history; + self.next_message = checkpoint.main_next_message; + self.transcript = checkpoint.main_transcript; + self.valid_history_range = (0, self.history.len()); + } + + /// Enter tangent mode - creates checkpoint of current state + pub fn enter_tangent_mode(&mut self) { + if self.tangent_state.is_none() { + self.tangent_state = Some(self.create_checkpoint()); + } + } + + /// Exit tangent mode - restore from checkpoint + pub fn exit_tangent_mode(&mut self) { + if let Some(checkpoint) = self.tangent_state.take() { + 
self.restore_from_checkpoint(checkpoint); + } + } + + /// Appends a collection prompts into history and returns the last message in the collection. + /// It asserts that the collection ends with a prompt that assumes the role of user. + pub fn append_prompts(&mut self, mut prompts: VecDeque) -> Option { + debug_assert!(self.next_message.is_none(), "next_message should not exist"); + debug_assert!(prompts.back().is_some_and(|p| p.role == crate::mcp_client::Role::User)); + let last_msg = prompts.pop_back()?; + let (mut candidate_user, mut candidate_asst) = (None::, None::); + while let Some(prompt) = prompts.pop_front() { + let Prompt { role, content } = prompt; + match role { + crate::mcp_client::Role::User => { + let user_msg = UserMessage::new_prompt(content.to_string(), None); + candidate_user.replace(user_msg); + }, + crate::mcp_client::Role::Assistant => { + let assistant_msg = AssistantMessage::new_response(None, content.into()); + candidate_asst.replace(assistant_msg); + }, + } + if candidate_asst.is_some() && candidate_user.is_some() { + let assistant = candidate_asst.take().unwrap(); + let user = candidate_user.take().unwrap(); + self.append_assistant_transcript(&assistant); + self.history.push_back(HistoryEntry { + user, + assistant, + request_metadata: None, + }); + } + } + Some(last_msg.content.to_string()) + } + + pub fn next_user_message(&self) -> Option<&UserMessage> { + self.next_message.as_ref() + } + + pub fn reset_next_user_message(&mut self) { + self.next_message = None; + } + + pub async fn set_next_user_message(&mut self, input: String) { + debug_assert!(self.next_message.is_none(), "next_message should not exist"); + if let Some(next_message) = self.next_message.as_ref() { + warn!(?next_message, "next_message should not exist"); + } + + let input = if input.is_empty() { + warn!("input must not be empty when adding new messages"); + "Empty prompt".to_string() + } else { + input + }; + + let msg = UserMessage::new_prompt(input, 
Some(Utc::now())); + self.next_message = Some(msg); + } + + /// Sets the response message according to the currently set [Self::next_message]. + pub fn push_assistant_message( + &mut self, + os: &mut Os, + message: AssistantMessage, + request_metadata: Option, + ) { + debug_assert!(self.next_message.is_some(), "next_message should exist"); + let next_user_message = self.next_message.take().expect("next user message should exist"); + + self.append_assistant_transcript(&message); + self.history.push_back(HistoryEntry { + user: next_user_message, + assistant: message, + request_metadata, + }); + + if let Ok(cwd) = std::env::current_dir() { + os.database.set_conversation_by_path(cwd, self).ok(); + } + } + + /// Returns the conversation id. + pub fn conversation_id(&self) -> &str { + self.conversation_id.as_ref() + } + + /// Returns the message id associated with the last assistant message, if present. + /// + /// This is equivalent to `utterance_id` in the Q API. + pub fn message_id(&self) -> Option<&str> { + self.history + .back() + .and_then(|HistoryEntry { assistant, .. }| assistant.message_id()) + } + + pub fn latest_tool_use_ids(&self) -> Option { + self.history + .back() + .and_then(|HistoryEntry { assistant, .. }| assistant.tool_uses()) + .map(|tools| (tools.iter().map(|t| t.id.as_str()).collect::>().join(","))) + } + + pub fn latest_tool_use_names(&self) -> Option { + self.history + .back() + .and_then(|HistoryEntry { assistant, .. }| assistant.tool_uses()) + .map(|tools| (tools.iter().map(|t| t.name.as_str()).collect::>().join(","))) + } + + /// Updates the history so that, when non-empty, the following invariants are in place: + /// 1. The history length is `<= MAX_CONVERSATION_STATE_HISTORY_LEN`. Oldest messages are + /// dropped. + /// 2. The first message is from the user, and does not contain tool results. Oldest messages + /// are dropped. + /// 3. 
If the last message from the assistant contains tool results, and a next user message is + /// set without tool results, then the user message will have "cancelled" tool results. + pub fn enforce_conversation_invariants(&mut self) { + self.valid_history_range = + enforce_conversation_invariants(&mut self.history, &mut self.next_message, &self.tools); + } + + /// Here we also need to make sure that the tool result corresponds to one of the tools + /// in the list. Otherwise we will see validation error from the backend. There are three + /// such circumstances where intervention would be needed: + /// 1. The model had decided to call a tool with its partial name AND there is only one such + /// tool, in which case we would automatically resolve this tool call to its correct name. + /// This will NOT result in an error in its tool result. The intervention here is to + /// substitute the partial name with its full name. + /// 2. The model had decided to call a tool with its partial name AND there are multiple tools + /// it could be referring to, in which case we WILL return an error in the tool result. The + /// intervention here is to substitute the ambiguous, partial name with a dummy. + /// 3. The model had decided to call a tool that does not exist. The intervention here is to + /// substitute the non-existent tool name with a dummy. 
+ pub fn enforce_tool_use_history_invariants(&mut self) { + enforce_tool_use_history_invariants(&mut self.history, &self.tools); + } + + pub fn add_tool_results(&mut self, tool_results: Vec) { + debug_assert!(self.next_message.is_none()); + self.next_message = Some(UserMessage::new_tool_use_results(tool_results)); + } + + pub fn add_tool_results_with_images(&mut self, tool_results: Vec, images: Vec) { + debug_assert!(self.next_message.is_none()); + self.next_message = Some(UserMessage::new_tool_use_results_with_images( + tool_results, + images, + Some(Utc::now()), + )); + } + + /// Sets the next user message with "cancelled" tool results. + pub fn abandon_tool_use(&mut self, tools_to_be_abandoned: &[QueuedTool], deny_input: String) { + self.next_message = Some(UserMessage::new_cancelled_tool_uses( + Some(deny_input), + tools_to_be_abandoned.iter().map(|t| t.id.as_str()), + Some(Utc::now()), + )); + } + + /// Returns a [FigConversationState] capable of being sent by [api_client::StreamingClient]. 
+ /// + /// Params: + /// - `run_hooks` - whether hooks should be executed and included as context + pub async fn as_sendable_conversation_state( + &mut self, + os: &Os, + stderr: &mut impl Write, + run_perprompt_hooks: bool, + ) -> Result { + debug_assert!(self.next_message.is_some()); + self.enforce_conversation_invariants(); + self.history.drain(self.valid_history_range.1..); + self.history.drain(..self.valid_history_range.0); + + let context = self.backend_conversation_state(os, run_perprompt_hooks, stderr).await?; + if !context.dropped_context_files.is_empty() { + execute!( + stderr, + style::SetForegroundColor(Color::DarkYellow), + style::Print("\nSome context files are dropped due to size limit, please run "), + style::SetForegroundColor(Color::DarkGreen), + style::Print("/context show "), + style::SetForegroundColor(Color::DarkYellow), + style::Print("to learn more.\n"), + style::SetForegroundColor(style::Color::Reset) + ) + .ok(); + } + + Ok(context + .into_fig_conversation_state() + .expect("unable to construct conversation state")) + } + + pub async fn update_state(&mut self, force_update: bool) { + let needs_update = self.tool_manager.has_new_stuff.load(Ordering::Acquire) || force_update; + if !needs_update { + return; + } + self.tool_manager.update().await; + // TODO: make this more targeted so we don't have to clone the entire list of tools + self.tools = self + .tool_manager + .schema + .values() + .fold(HashMap::>::new(), |mut acc, v| { + let tool = Tool::ToolSpecification(ToolSpecification { + name: v.name.clone(), + description: v.description.clone(), + input_schema: v.input_schema.clone().into(), + }); + acc.entry(v.tool_origin.clone()) + .and_modify(|tools| tools.push(tool.clone())) + .or_insert(vec![tool]); + acc + }); + self.tool_manager.has_new_stuff.store(false, Ordering::Release); + // We call this in [Self::enforce_conversation_invariants] as well. 
But we need to call it + // here as well because when it's being called in [Self::enforce_conversation_invariants] + // it is only checking the last entry. + self.enforce_tool_use_history_invariants(); + } + + /// Returns a conversation state representation which reflects the exact conversation to send + /// back to the model. + pub async fn backend_conversation_state( + &mut self, + os: &Os, + run_perprompt_hooks: bool, + output: &mut impl Write, + ) -> Result, ChatError> { + self.update_state(false).await; + self.enforce_conversation_invariants(); + + // Run hooks and add to conversation start and next user message. + let mut agent_spawn_context = None; + if let Some(cm) = self.context_manager.as_mut() { + let user_prompt = self.next_message.as_ref().and_then(|m| m.prompt()); + let agent_spawn = cm.run_hooks(HookTrigger::AgentSpawn, output, user_prompt).await?; + agent_spawn_context = format_hook_context(&agent_spawn, HookTrigger::AgentSpawn); + + if let (true, Some(next_message)) = (run_perprompt_hooks, self.next_message.as_mut()) { + let per_prompt = cm + .run_hooks(HookTrigger::UserPromptSubmit, output, next_message.prompt()) + .await?; + if let Some(ctx) = format_hook_context(&per_prompt, HookTrigger::UserPromptSubmit) { + next_message.additional_context = ctx; + } + } + } + + let (context_messages, dropped_context_files) = self.context_messages(os, agent_spawn_context).await; + + Ok(BackendConversationState { + conversation_id: self.conversation_id.as_str(), + next_user_message: self.next_message.as_ref(), + history: self + .history + .range(self.valid_history_range.0..self.valid_history_range.1), + context_messages, + dropped_context_files, + tools: &self.tools, + model_id: self.model_info.as_ref().map(|m| m.model_id.as_str()), + }) + } + + /// Returns a [FigConversationState] capable of replacing the history of the current + /// conversation with a summary generated by the model. 
+ /// + /// The resulting summary should update the state by immediately following with + /// [ConversationState::replace_history_with_summary]. + pub async fn create_summary_request( + &mut self, + os: &Os, + custom_prompt: Option>, + strategy: CompactStrategy, + ) -> Result { + let mut summary_content = match custom_prompt { + Some(custom_prompt) => { + // Make the custom instructions much more prominent and directive + format!( + "[SYSTEM NOTE: This is an automated summarization request, not from the user]\n\n\ + FORMAT REQUIREMENTS: Create a structured, concise summary in bullet-point format. DO NOT respond conversationally. DO NOT address the user directly.\n\n\ + IMPORTANT CUSTOM INSTRUCTION: {}\n\n\ + Your task is to create a structured summary document containing:\n\ + 1) A bullet-point list of key topics/questions covered\n\ + 2) Bullet points for all significant tools executed and their results\n\ + 3) Bullet points for any code or technical information shared\n\ + 4) A section of key insights gained\n\n\ + FORMAT THE SUMMARY IN THIRD PERSON, NOT AS A DIRECT RESPONSE. Example format:\n\n\ + ## CONVERSATION SUMMARY\n\ + * Topic 1: Key information\n\ + * Topic 2: Key information\n\n\ + ## TOOLS EXECUTED\n\ + * Tool X: Result Y\n\n\ + Remember this is a DOCUMENT not a chat response. The custom instruction above modifies what to prioritize.\n\ + FILTER OUT CHAT CONVENTIONS (greetings, offers to help, etc).", + custom_prompt.as_ref() + ) + }, + None => { + // Default prompt + "[SYSTEM NOTE: This is an automated summarization request, not from the user]\n\n\ + FORMAT REQUIREMENTS: Create a structured, concise summary in bullet-point format. DO NOT respond conversationally. 
DO NOT address the user directly.\n\n\ + Your task is to create a structured summary document containing:\n\ + 1) A bullet-point list of key topics/questions covered\n\ + 2) Bullet points for all significant tools executed and their results\n\ + 3) Bullet points for any code or technical information shared\n\ + 4) A section of key insights gained\n\n\ + FORMAT THE SUMMARY IN THIRD PERSON, NOT AS A DIRECT RESPONSE. Example format:\n\n\ + ## CONVERSATION SUMMARY\n\ + * Topic 1: Key information\n\ + * Topic 2: Key information\n\n\ + ## TOOLS EXECUTED\n\ + * Tool X: Result Y\n\n\ + Remember this is a DOCUMENT not a chat response.\n\ + FILTER OUT CHAT CONVENTIONS (greetings, offers to help, etc).".to_string() + }, + }; + if let Some((summary, _)) = &self.latest_summary { + summary_content.push_str("\n\n"); + summary_content.push_str(CONTEXT_ENTRY_START_HEADER); + summary_content.push_str("This summary contains ALL relevant information from our previous conversation including tool uses, results, code analysis, and file operations. YOU MUST be sure to include this information when creating your summarization document.\n\n"); + summary_content.push_str("SUMMARY CONTENT:\n"); + summary_content.push_str(summary); + summary_content.push('\n'); + summary_content.push_str(CONTEXT_ENTRY_END_HEADER); + } + + let conv_state = self.backend_conversation_state(os, false, &mut vec![]).await?; + let mut summary_message = Some(UserMessage::new_prompt(summary_content.clone(), None)); + + // Create the history according to the passed compact strategy. + let mut history = conv_state.history.cloned().collect::>(); + history.drain((history.len().saturating_sub(strategy.messages_to_exclude))..); + if strategy.truncate_large_messages { + for HistoryEntry { user, .. } in &mut history { + user.truncate_safe(strategy.max_message_length); + } + } + + // Only send the dummy tool spec in order to prevent the model from ever attempting a tool + // use. 
+ let mut tools = self.tools.clone(); + tools.retain(|k, v| match k { + ToolOrigin::Native => { + v.retain(|tool| match tool { + Tool::ToolSpecification(tool_spec) => tool_spec.name == DUMMY_TOOL_NAME, + }); + true + }, + ToolOrigin::McpServer(_) => false, + }); + + enforce_conversation_invariants(&mut history, &mut summary_message, &tools); + + Ok(FigConversationState { + conversation_id: Some(self.conversation_id.clone()), + user_input_message: summary_message + .unwrap_or(UserMessage::new_prompt(summary_content, None)) // should not happen + .into_user_input_message(self.model_info.as_ref().map(|m| m.model_id.clone()), &tools), + history: Some(flatten_history(history.iter())), + }) + } + + /// `strategy` - The [CompactStrategy] used for the corresponding + /// [ConversationState::create_summary_request]. + pub fn replace_history_with_summary( + &mut self, + summary: String, + strategy: CompactStrategy, + request_metadata: RequestMetadata, + ) { + self.history + .drain(..(self.history.len().saturating_sub(strategy.messages_to_exclude))); + self.latest_summary = Some((summary, request_metadata)); + } + + pub fn current_profile(&self) -> Option<&str> { + if let Some(cm) = self.context_manager.as_ref() { + Some(cm.current_profile.as_str()) + } else { + None + } + } + + /// Returns pairs of user and assistant messages to include as context in the message history + /// including both summaries and context files if available, and the dropped context files. + /// + /// TODO: + /// - Either add support for multiple context messages if the context is too large to fit inside + /// a single user message, or handle this case more gracefully. For now, always return 2 + /// messages. + /// - Cache this return for some period of time. 
+ async fn context_messages( + &mut self, + os: &Os, + additional_context: Option, + ) -> (Option>, Vec<(String, String)>) { + let mut context_content = String::new(); + let mut dropped_context_files = Vec::new(); + if let Some((summary, _)) = &self.latest_summary { + context_content.push_str(CONTEXT_ENTRY_START_HEADER); + context_content.push_str("This summary contains ALL relevant information from our previous conversation including tool uses, results, code analysis, and file operations. YOU MUST reference this information when answering questions and explicitly acknowledge specific details from the summary when they're relevant to the current question.\n\n"); + context_content.push_str("SUMMARY CONTENT:\n"); + context_content.push_str(summary); + context_content.push('\n'); + context_content.push_str(CONTEXT_ENTRY_END_HEADER); + } + + // Add context files if available + if let Some(context_manager) = self.context_manager.as_mut() { + match context_manager.collect_context_files_with_limit(os).await { + Ok((files_to_use, files_dropped)) => { + if !files_dropped.is_empty() { + dropped_context_files.extend(files_dropped); + } + + if !files_to_use.is_empty() { + context_content.push_str(CONTEXT_ENTRY_START_HEADER); + for (filename, content) in files_to_use { + context_content.push_str(&format!("[{}]\n{}\n", filename, content)); + } + context_content.push_str(CONTEXT_ENTRY_END_HEADER); + } + }, + Err(e) => { + warn!("Failed to get context files: {}", e); + }, + } + } + + if let Some(context) = additional_context { + context_content.push_str(&context); + } + + if let Some(agent_prompt) = self.agents.get_active().and_then(|a| a.prompt.as_ref()) { + context_content.push_str(&format!("Follow this instruction: {}", agent_prompt)); + } + + if !context_content.is_empty() { + self.context_message_length = Some(context_content.len()); + let user = UserMessage::new_prompt(context_content, None); + let assistant = AssistantMessage::new_response(None, "I will fully incorporate 
this information when generating my responses, and explicitly acknowledge relevant parts of the summary when answering questions.".into()); + ( + Some(vec![HistoryEntry { + user, + assistant, + request_metadata: None, + }]), + dropped_context_files, + ) + } else { + (None, dropped_context_files) + } + } + + /// The length of the user message used as context, if any. + pub fn context_message_length(&self) -> Option { + self.context_message_length + } + + /// Calculate the total character count in the conversation + pub async fn calculate_char_count(&mut self, os: &Os) -> Result { + Ok(self + .backend_conversation_state(os, false, &mut vec![]) + .await? + .char_count()) + } + + /// Get the current token warning level + pub async fn get_token_warning_level(&mut self, os: &Os) -> Result { + let total_chars = self.calculate_char_count(os).await?; + let max_chars = TokenCounter::token_to_chars(context_window_tokens(self.model_info.as_ref())); + + Ok(if *total_chars >= max_chars { + TokenWarningLevel::Critical + } else { + TokenWarningLevel::None + }) + } + + pub fn append_user_transcript(&mut self, message: &str) { + self.append_transcript(format!("> {}", message.replace("\n", "> \n"))); + } + + pub fn append_assistant_transcript(&mut self, message: &AssistantMessage) { + let tool_uses = message.tool_uses().map_or("none".to_string(), |tools| { + tools.iter().map(|tool| tool.name.clone()).collect::>().join(",") + }); + self.append_transcript(format!("{}\n[Tool uses: {tool_uses}]", message.content())); + } + + pub fn append_transcript(&mut self, message: String) { + if self.transcript.len() >= MAX_CONVERSATION_STATE_HISTORY_LEN { + self.transcript.pop_front(); + } + self.transcript.push_back(message); + } + + /// Swapping agent involves the following: + /// - Reinstantiate the context manager + /// - Swap agent on tool manager + pub async fn swap_agent( + &mut self, + os: &mut Os, + output: &mut impl Write, + agent_name: &str, + ) -> Result<(), ChatError> { + let agent = 
self.agents.switch(agent_name).map_err(ChatError::AgentSwapError)?; + self.context_manager.replace({ + ContextManager::from_agent(agent, calc_max_context_files_size(self.model_info.as_ref())) + .map_err(|e| ChatError::Custom(format!("Context manager has failed to instantiate: {e}").into()))? + }); + + self.tool_manager + .swap_agent(os, output, agent) + .await + .map_err(ChatError::AgentSwapError)?; + + self.update_state(true).await; + + Ok(()) + } +} + +/// Represents a conversation state that can be converted into a [FigConversationState] (the type +/// used by the API client). Represents borrowed data, and reflects an exact [FigConversationState] +/// that can be generated from [ConversationState] at any point in time. +/// +/// This is intended to provide us ways to accurately assess the exact state that is sent to the +/// model without having to needlessly clone and mutate [ConversationState] in strange ways. +pub type BackendConversationState<'a> = + BackendConversationStateImpl<'a, std::collections::vec_deque::Iter<'a, HistoryEntry>, Option>>; + +/// See [BackendConversationState] +#[derive(Debug, Clone)] +pub struct BackendConversationStateImpl<'a, T, U> { + pub conversation_id: &'a str, + pub next_user_message: Option<&'a UserMessage>, + pub history: T, + pub context_messages: U, + pub dropped_context_files: Vec<(String, String)>, + pub tools: &'a HashMap>, + pub model_id: Option<&'a str>, +} + +impl BackendConversationStateImpl<'_, std::collections::vec_deque::Iter<'_, HistoryEntry>, Option>> { + fn into_fig_conversation_state(self) -> eyre::Result { + let history = flatten_history(self.context_messages.unwrap_or_default().iter().chain(self.history)); + let user_input_message: UserInputMessage = self + .next_user_message + .cloned() + .map(|msg| msg.into_user_input_message(self.model_id.map(str::to_string), self.tools)) + .ok_or(eyre::eyre!("next user message is not set"))?; + + Ok(FigConversationState { + conversation_id: 
Some(self.conversation_id.to_string()), + user_input_message, + history: Some(history), + }) + } + + pub fn calculate_conversation_size(&self) -> ConversationSize { + let mut user_chars = 0; + let mut assistant_chars = 0; + let mut context_chars = 0; + + // Count the chars used by the messages in the history. + // this clone is cheap + let history = self.history.clone(); + for HistoryEntry { user, assistant, .. } in history { + user_chars += *user.char_count(); + assistant_chars += *assistant.char_count(); + } + + // Add any chars from context messages, if available. + context_chars += self + .context_messages + .as_ref() + .map(|v| { + v.iter().fold(0, |acc, HistoryEntry { user, assistant, .. }| { + acc + *user.char_count() + *assistant.char_count() + }) + }) + .unwrap_or_default(); + + ConversationSize { + context_messages: context_chars.into(), + user_messages: user_chars.into(), + assistant_messages: assistant_chars.into(), + } + } +} + +/// Reflects a detailed accounting of the context window utilization for a given conversation. +#[derive(Debug, Clone, Copy)] +pub struct ConversationSize { + pub context_messages: CharCount, + pub user_messages: CharCount, + pub assistant_messages: CharCount, +} + +/// Converts a list of user/assistant message pairs into a flattened list of ChatMessage. +fn flatten_history<'a, T>(history: T) -> Vec +where + T: Iterator, +{ + history.fold(Vec::new(), |mut acc, HistoryEntry { user, assistant, .. 
}| { + acc.push(ChatMessage::UserInputMessage(user.clone().into_history_entry())); + acc.push(ChatMessage::AssistantResponseMessage(assistant.clone().into())); + acc + }) +} + +/// Character count warning levels for conversation size +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum TokenWarningLevel { + /// No warning, conversation is within normal limits + None, + /// Critical level - at single warning threshold (600K characters) + Critical, +} + +impl From for ToolInputSchema { + fn from(value: InputSchema) -> Self { + Self { + json: Some(serde_value_to_document(value.0).into()), + } + } +} + +/// Formats hook output to be used within context blocks (e.g., in context messages or in new user +/// prompts). +/// +/// # Returns +/// [Option::Some] if `hook_results` is not empty and at least one hook has content. Otherwise, +/// [Option::None] +fn format_hook_context(hook_results: &[((HookTrigger, Hook), String)], trigger: HookTrigger) -> Option { + if hook_results.iter().all(|(_, content)| content.is_empty()) { + return None; + } + + let mut context_content = String::new(); + + context_content.push_str(CONTEXT_ENTRY_START_HEADER); + context_content.push_str("This section (like others) contains important information that I want you to use in your responses. I have gathered this context from valuable programmatic script hooks. 
You must follow any requests and consider all of the information in this section"); + if trigger == HookTrigger::AgentSpawn { + context_content.push_str(" for the entire conversation"); + } + context_content.push_str("\n\n"); + + for (_, output) in hook_results.iter().filter(|((h_trigger, _), _)| *h_trigger == trigger) { + context_content.push_str(&format!("{output}\n\n")); + } + context_content.push_str(CONTEXT_ENTRY_END_HEADER); + Some(context_content) +} + +fn enforce_conversation_invariants( + history: &mut VecDeque, + next_message: &mut Option, + tools: &HashMap>, +) -> (usize, usize) { + // First set the valid range as the entire history - this will be truncated as necessary + // later below. + let mut valid_history_range = (0, history.len()); + + // Trim the conversation history by finding the second oldest message from the user without + // tool results - this will be the new oldest message in the history. + // + // Note that we reserve extra slots for [ConversationState::context_messages]. + if (history.len() * 2) > MAX_CONVERSATION_STATE_HISTORY_LEN - 6 { + match history + .iter() + .enumerate() + .skip(1) + .find(|(_, HistoryEntry { user, .. })| -> bool { !user.has_tool_use_results() }) + .map(|v| v.0) + { + Some(i) => { + debug!("removing the first {i} user/assistant response pairs in the history"); + valid_history_range.0 = i; + }, + None => { + debug!("no valid starting user message found in the history, clearing"); + valid_history_range = (0, 0); + // Edge case: if the next message contains tool results, then we have to just + // abandon them. + if next_message.as_ref().is_some_and(|m| m.has_tool_use_results()) { + debug!("abandoning tool results"); + *next_message = Some(UserMessage::new_prompt( + "The conversation history has overflowed, clearing state".to_string(), + None, + )); + } + }, + } + } + + // If the first message contains tool results, then we add the results to the content field + // instead. This is required to avoid validation errors. 
+ if let Some(HistoryEntry { user, .. }) = history.front_mut() { + if user.has_tool_use_results() { + user.replace_content_with_tool_use_results(); + } + } + + // If the next message is set with tool results, but the previous assistant message is not a + // tool use, then we add the results to the content field instead. + match ( + next_message.as_mut(), + history.range(valid_history_range.0..valid_history_range.1).last(), + ) { + (Some(next_message), prev_msg) if next_message.has_tool_use_results() => match prev_msg { + // None | Some((_, AssistantMessage::Response { .. }, _)) => { + None + | Some(HistoryEntry { + assistant: AssistantMessage::Response { .. }, + .. + }) => { + next_message.replace_content_with_tool_use_results(); + }, + _ => (), + }, + (_, _) => (), + } + + // If the last message from the assistant contains tool uses AND next_message is set, we need to + // ensure that next_message contains tool results. + if let ( + Some(HistoryEntry { + assistant: AssistantMessage::ToolUse { tool_uses, .. }, + .. + }), + Some(user_msg), + ) = ( + history.range(valid_history_range.0..valid_history_range.1).last(), + next_message, + ) { + if !user_msg.has_tool_use_results() { + debug!( + "last assistant message contains tool uses, but next message is set and does not contain tool results. setting tool results as cancelled" + ); + *user_msg = UserMessage::new_cancelled_tool_uses( + user_msg.prompt().map(|p| p.to_string()), + tool_uses.iter().map(|t| t.id.as_str()), + None, + ); + } + } + + enforce_tool_use_history_invariants(history, tools); + + valid_history_range +} + +fn enforce_tool_use_history_invariants(history: &mut VecDeque, tools: &HashMap>) { + let tool_names: HashSet<_> = tools + .values() + .flat_map(|tools| { + tools.iter().map(|tool| match tool { + Tool::ToolSpecification(tool_specification) => tool_specification.name.as_str(), + }) + }) + .filter(|name| *name != DUMMY_TOOL_NAME) + .collect(); + + for HistoryEntry { assistant, .. 
} in history { + if let AssistantMessage::ToolUse { tool_uses, .. } = assistant { + for tool_use in tool_uses { + if tool_names.contains(tool_use.name.as_str()) { + continue; + } + + if tool_names.contains(tool_use.orig_name.as_str()) { + tool_use.name = tool_use.orig_name.clone(); + tool_use.args = tool_use.orig_args.clone(); + continue; + } + + let names: Vec<&str> = tool_names + .iter() + .filter_map(|name| { + if name.ends_with(&tool_use.name) { + Some(*name) + } else { + None + } + }) + .collect(); + + // There's only one tool use matching, so we can just replace it with the + // found name. + if names.len() == 1 { + tool_use.name = (*names.first().unwrap()).to_string(); + continue; + } + + // Otherwise, we have to replace it with a dummy. + tool_use.name = DUMMY_TOOL_NAME.to_string(); + } + } + } +} + +fn default_true() -> bool { + true +} +#[cfg(test)] +mod tests { + use super::super::message::AssistantToolUse; + use super::*; + use crate::api_client::model::{ + AssistantResponseMessage, + ToolResultStatus, + }; + use crate::cli::agent::{ + Agent, + Agents, + }; + use crate::cli::chat::tool_manager::ToolManager; + + const AMAZONQ_FILENAME: &str = "AmazonQ.md"; + + fn assert_conversation_state_invariants(state: FigConversationState, assertion_iteration: usize) { + if let Some(Some(msg)) = state.history.as_ref().map(|h| h.first()) { + assert!( + matches!(msg, ChatMessage::UserInputMessage(_)), + "{assertion_iteration}: First message in the history must be from the user, instead found: {:?}", + msg + ); + } + if let Some(Some(msg)) = state.history.as_ref().map(|h| h.last()) { + assert!( + matches!(msg, ChatMessage::AssistantResponseMessage(_)), + "{assertion_iteration}: Last message in the history must be from the assistant, instead found: {:?}", + msg + ); + // If the last message from the assistant contains tool uses, then the next user + // message must contain tool results. 
+ match (state.user_input_message.user_input_message_context.as_ref(), msg) { + ( + Some(os), + ChatMessage::AssistantResponseMessage(AssistantResponseMessage { + tool_uses: Some(tool_uses), + .. + }), + ) if !tool_uses.is_empty() => { + assert!( + os.tool_results.as_ref().is_some_and(|r| !r.is_empty()), + "The user input message must contain tool results when the last assistant message contains tool uses" + ); + }, + _ => {}, + } + } + + if let Some(history) = state.history.as_ref() { + for (i, msg) in history.iter().enumerate() { + // User message checks. + if let ChatMessage::UserInputMessage(user) = msg { + assert!( + user.user_input_message_context + .as_ref() + .is_none_or(|os| os.tools.is_none()), + "the tool specification should be empty for all user messages in the history" + ); + + // Check that messages with tool results are immediately preceded by an + // assistant message with tool uses. + if user + .user_input_message_context + .as_ref() + .is_some_and(|os| os.tool_results.as_ref().is_some_and(|r| !r.is_empty())) + { + match history.get(i.checked_sub(1).unwrap_or_else(|| { + panic!( + "{assertion_iteration}: first message in the history should not contain tool results" + ) + })) { + Some(ChatMessage::AssistantResponseMessage(assistant)) => { + assert!(assistant.tool_uses.is_some()); + }, + _ => panic!( + "expected an assistant response message with tool uses at index: {}", + i - 1 + ), + } + } + } + } + } + + let actual_history_len = state.history.unwrap_or_default().len(); + assert!( + actual_history_len <= MAX_CONVERSATION_STATE_HISTORY_LEN, + "history should not extend past the max limit of {}, instead found length {}", + MAX_CONVERSATION_STATE_HISTORY_LEN, + actual_history_len + ); + + let os = state + .user_input_message + .user_input_message_context + .as_ref() + .expect("user input message context must exist"); + assert!( + os.tools.is_some(), + "Currently, the tool spec must be included in the next user message" + ); + } + + #[tokio::test] + 
async fn test_conversation_state_history_handling_truncation() { + let mut os = Os::new().await.unwrap(); + let agents = Agents::default(); + let mut output = vec![]; + + let mut tool_manager = ToolManager::default(); + let mut conversation = ConversationState::new( + "fake_conv_id", + agents, + tool_manager.load_tools(&mut os, &mut output).await.unwrap(), + tool_manager, + None, + &os, + false, + ) + .await; + + // First, build a large conversation history. We need to ensure that the order is always + // User -> Assistant -> User -> Assistant ...and so on. + conversation.set_next_user_message("start".to_string()).await; + for i in 0..=(MAX_CONVERSATION_STATE_HISTORY_LEN + 100) { + let s = conversation + .as_sendable_conversation_state(&os, &mut vec![], true) + .await + .unwrap(); + assert_conversation_state_invariants(s, i); + conversation.push_assistant_message(&mut os, AssistantMessage::new_response(None, i.to_string()), None); + conversation.set_next_user_message(i.to_string()).await; + } + } + + #[tokio::test] + async fn test_conversation_state_history_handling_with_tool_results() { + let mut os = Os::new().await.unwrap(); + let agents = Agents::default(); + + // Build a long conversation history of tool use results. 
+ let mut tool_manager = ToolManager::default(); + let tool_config = tool_manager.load_tools(&mut os, &mut vec![]).await.unwrap(); + let mut conversation = ConversationState::new( + "fake_conv_id", + agents.clone(), + tool_config.clone(), + tool_manager.clone(), + None, + &os, + false, + ) + .await; + conversation.set_next_user_message("start".to_string()).await; + for i in 0..=(MAX_CONVERSATION_STATE_HISTORY_LEN + 100) { + let s = conversation + .as_sendable_conversation_state(&os, &mut vec![], true) + .await + .unwrap(); + assert_conversation_state_invariants(s, i); + + conversation.push_assistant_message( + &mut os, + AssistantMessage::new_tool_use(None, i.to_string(), vec![AssistantToolUse { + id: "tool_id".to_string(), + name: "tool name".to_string(), + args: serde_json::Value::Null, + ..Default::default() + }]), + None, + ); + conversation.add_tool_results(vec![ToolUseResult { + tool_use_id: "tool_id".to_string(), + content: vec![], + status: ToolResultStatus::Success, + }]); + } + + // Build a long conversation history of user messages mixed in with tool results. 
+ let mut conversation = ConversationState::new( + "fake_conv_id", + agents, + tool_config.clone(), + tool_manager.clone(), + None, + &os, + false, + ) + .await; + conversation.set_next_user_message("start".to_string()).await; + for i in 0..=(MAX_CONVERSATION_STATE_HISTORY_LEN + 100) { + let s = conversation + .as_sendable_conversation_state(&os, &mut vec![], true) + .await + .unwrap(); + assert_conversation_state_invariants(s, i); + if i % 3 == 0 { + conversation.push_assistant_message( + &mut os, + AssistantMessage::new_tool_use(None, i.to_string(), vec![AssistantToolUse { + id: "tool_id".to_string(), + name: "tool name".to_string(), + args: serde_json::Value::Null, + ..Default::default() + }]), + None, + ); + conversation.add_tool_results(vec![ToolUseResult { + tool_use_id: "tool_id".to_string(), + content: vec![], + status: ToolResultStatus::Success, + }]); + } else { + conversation.push_assistant_message(&mut os, AssistantMessage::new_response(None, i.to_string()), None); + conversation.set_next_user_message(i.to_string()).await; + } + } + } + + #[tokio::test] + async fn test_conversation_state_with_context_files() { + let mut os = Os::new().await.unwrap(); + let agents = { + let mut agents = Agents::default(); + let mut agent = Agent::default(); + agent.resources.push(AMAZONQ_FILENAME.into()); + agents.agents.insert("TestAgent".to_string(), agent); + agents.switch("TestAgent").expect("Agent switch failed"); + agents + }; + os.fs.write(AMAZONQ_FILENAME, "test context").await.unwrap(); + let mut output = vec![]; + + let mut tool_manager = ToolManager::default(); + let mut conversation = ConversationState::new( + "fake_conv_id", + agents, + tool_manager.load_tools(&mut os, &mut output).await.unwrap(), + tool_manager, + None, + &os, + false, + ) + .await; + + // First, build a large conversation history. We need to ensure that the order is always + // User -> Assistant -> User -> Assistant ...and so on. 
+ conversation.set_next_user_message("start".to_string()).await; + for i in 0..=(MAX_CONVERSATION_STATE_HISTORY_LEN + 100) { + let s = conversation + .as_sendable_conversation_state(&os, &mut vec![], true) + .await + .unwrap(); + + // Ensure that the first two messages are the fake context messages. + let hist = s.history.as_ref().unwrap(); + let user = &hist[0]; + let assistant = &hist[1]; + match (user, assistant) { + (ChatMessage::UserInputMessage(user), ChatMessage::AssistantResponseMessage(_)) => { + assert!( + user.content.contains("test context"), + "expected context message to contain context file, instead found: {}", + user.content + ); + }, + _ => panic!("Expected the first two messages to be from the user and the assistant"), + } + + assert_conversation_state_invariants(s, i); + + conversation.push_assistant_message(&mut os, AssistantMessage::new_response(None, i.to_string()), None); + conversation.set_next_user_message(i.to_string()).await; + } + } + + #[tokio::test] + async fn test_tangent_mode() { + let mut os = Os::new().await.unwrap(); + let agents = Agents::default(); + let mut tool_manager = ToolManager::default(); + let mut conversation = ConversationState::new( + "fake_conv_id", + agents, + tool_manager.load_tools(&mut os, &mut vec![]).await.unwrap(), + tool_manager, + None, + &os, + false, // mcp_enabled + ) + .await; + + // Initially not in tangent mode + assert!(!conversation.is_in_tangent_mode()); + + // Add some main conversation history + conversation + .set_next_user_message("main conversation".to_string()) + .await; + conversation.push_assistant_message( + &mut os, + AssistantMessage::new_response(None, "main response".to_string()), + None, + ); + conversation.transcript.push_back("main transcript".to_string()); + + let main_history_len = conversation.history.len(); + let main_transcript_len = conversation.transcript.len(); + + // Enter tangent mode (toggle from normal to tangent) + conversation.enter_tangent_mode(); + 
assert!(conversation.is_in_tangent_mode()); + + // History should be preserved for tangent (not cleared) + assert_eq!(conversation.history.len(), main_history_len); + assert_eq!(conversation.transcript.len(), main_transcript_len); + assert!(conversation.next_message.is_none()); + + // Add tangent conversation + conversation + .set_next_user_message("tangent conversation".to_string()) + .await; + conversation.push_assistant_message( + &mut os, + AssistantMessage::new_response(None, "tangent response".to_string()), + None, + ); + + // During tangent mode, history should have grown + assert_eq!(conversation.history.len(), main_history_len + 1); + assert_eq!(conversation.transcript.len(), main_transcript_len + 1); + + // Exit tangent mode (toggle from tangent to normal) + conversation.exit_tangent_mode(); + assert!(!conversation.is_in_tangent_mode()); + + // Main conversation should be restored (tangent additions discarded) + assert_eq!(conversation.history.len(), main_history_len); // Back to original length + assert_eq!(conversation.transcript.len(), main_transcript_len); // Back to original length + assert!(conversation.transcript.contains(&"main transcript".to_string())); + assert!(!conversation.transcript.iter().any(|t| t.contains("tangent"))); + + // Test multiple toggles + conversation.enter_tangent_mode(); + assert!(conversation.is_in_tangent_mode()); + conversation.exit_tangent_mode(); + assert!(!conversation.is_in_tangent_mode()); + } +} + + +---- crates/chat-cli/src/cli/chat/mod.rs ---- +pub mod cli; +mod consts; +pub mod context; +mod conversation; +mod error_formatter; +mod input_source; +mod message; +mod parse; +use std::path::MAIN_SEPARATOR; +mod line_tracker; +mod parser; +mod prompt; +mod prompt_parser; +mod server_messenger; +#[cfg(unix)] +mod skim_integration; +mod token_counter; +pub mod tool_manager; +pub mod tools; +pub mod util; +use std::borrow::Cow; +use std::collections::{ + HashMap, + VecDeque, +}; +use std::io::{ + IsTerminal, + Read, + 
Write, +}; +use std::process::ExitCode; +use std::sync::Arc; +use std::time::{ + Duration, + Instant, +}; + +use amzn_codewhisperer_client::types::SubscriptionStatus; +use clap::{ + Args, + CommandFactory, + Parser, +}; +use cli::compact::CompactStrategy; +use cli::model::{ + get_available_models, + select_model, +}; +pub use conversation::ConversationState; +use conversation::TokenWarningLevel; +use crossterm::style::{ + Attribute, + Color, + Stylize, +}; +use crossterm::{ + cursor, + execute, + queue, + style, + terminal, +}; +use eyre::{ + Report, + Result, + bail, + eyre, +}; +use input_source::InputSource; +use message::{ + AssistantMessage, + AssistantToolUse, + ToolUseResult, + ToolUseResultBlock, +}; +use parse::{ + ParseState, + interpret_markdown, +}; +use parser::{ + RecvErrorKind, + RequestMetadata, + SendMessageStream, +}; +use regex::Regex; +use spinners::{ + Spinner, + Spinners, +}; +use thiserror::Error; +use time::OffsetDateTime; +use token_counter::TokenCounter; +use tokio::signal::ctrl_c; +use tokio::sync::{ + Mutex, + broadcast, +}; +use tool_manager::{ + PromptQuery, + PromptQueryResult, + ToolManager, + ToolManagerBuilder, +}; +use tools::gh_issue::GhIssueContext; +use tools::{ + NATIVE_TOOLS, + OutputKind, + QueuedTool, + Tool, + ToolSpec, +}; +use tracing::{ + debug, + error, + info, + trace, + warn, +}; +use util::images::RichImageBlock; +use util::ui::draw_box; +use util::{ + animate_output, + play_notification_bell, +}; +use winnow::Partial; +use winnow::stream::Offset; + +use super::agent::{ + DEFAULT_AGENT_NAME, + PermissionEvalResult, +}; +use crate::api_client::model::ToolResultStatus; +use crate::api_client::{ + self, + ApiClientError, +}; +use crate::auth::AuthError; +use crate::auth::builder_id::is_idc_user; +use crate::cli::agent::Agents; +use crate::cli::chat::cli::SlashCommand; +use crate::cli::chat::cli::model::find_model; +use crate::cli::chat::cli::prompts::{ + GetPromptError, + PromptsSubcommand, +}; +use 
crate::cli::chat::util::sanitize_unicode_tags; +use crate::database::settings::Setting; +use crate::mcp_client::Prompt; +use crate::os::Os; +use crate::telemetry::core::{ + AgentConfigInitArgs, + ChatAddedMessageParams, + ChatConversationType, + MessageMetaTag, + RecordUserTurnCompletionArgs, + ToolUseEventBuilder, +}; +use crate::telemetry::{ + ReasonCode, + TelemetryResult, + get_error_reason, +}; +use crate::util::MCP_SERVER_TOOL_DELIMITER; + +const LIMIT_REACHED_TEXT: &str = color_print::cstr! { "You've used all your free requests for this month. You have two options: +1. Upgrade to a paid subscription for increased limits. See our Pricing page for what's included> https://aws.amazon.com/q/developer/pricing/ +2. Wait until next month when your limit automatically resets." }; + +pub const EXTRA_HELP: &str = color_print::cstr! {" +MCP: +You can now configure the Amazon Q CLI to use MCP servers. \nLearn how: https://docs.aws.amazon.com/amazonq/latest/qdeveloper-ug/qdev-mcp.html + +Tips: +!{command} Quickly execute a command in your current session +Ctrl(^) + j Insert new-line to provide multi-line prompt + Alternatively, [Alt(⌥) + Enter(⏎)] +Ctrl(^) + s Fuzzy search commands and context files + Use Tab to select multiple items + Change the keybind using: q settings chat.skimCommandKey x +Ctrl(^) + t Toggle tangent mode for isolated conversations + Change the keybind using: q settings chat.tangentModeKey x +chat.editMode The prompt editing mode (vim or emacs) + Change using: q settings chat.skimCommandKey x +"}; + +#[derive(Debug, Clone, PartialEq, Eq, Default, Args)] +pub struct ChatArgs { + /// Resumes the previous conversation from this directory. + #[arg(short, long)] + pub resume: bool, + /// Context profile to use + #[arg(long = "agent", alias = "profile")] + pub agent: Option, + /// Current model to use + #[arg(long = "model")] + pub model: Option, + /// Allows the model to use any tool to run commands without asking for confirmation. 
+ #[arg(short = 'a', long)] + pub trust_all_tools: bool, + /// Trust only this set of tools. Example: trust some tools: + /// '--trust-tools=fs_read,fs_write', trust no tools: '--trust-tools=' + #[arg(long, value_delimiter = ',', value_name = "TOOL_NAMES")] + pub trust_tools: Option>, + /// Whether the command should run without expecting user input + #[arg(long, alias = "non-interactive")] + pub no_interactive: bool, + /// The first question to ask + pub input: Option, +} + +impl ChatArgs { + pub async fn execute(mut self, os: &mut Os) -> Result { + let mut input = self.input; + + if self.no_interactive && input.is_none() { + if !std::io::stdin().is_terminal() { + let mut buffer = String::new(); + match std::io::stdin().read_to_string(&mut buffer) { + Ok(_) => { + if !buffer.trim().is_empty() { + input = Some(buffer.trim().to_string()); + } + }, + Err(e) => { + eprintln!("Error reading from stdin: {}", e); + }, + } + } + + if input.is_none() { + bail!("Input must be supplied when running in non-interactive mode"); + } + } + + let stdout = std::io::stdout(); + let mut stderr = std::io::stderr(); + + let args: Vec = std::env::args().collect(); + if args + .iter() + .any(|arg| arg == "--profile" || arg.starts_with("--profile=")) + { + execute!( + stderr, + style::SetForegroundColor(Color::Yellow), + style::Print("WARNING: "), + style::SetForegroundColor(Color::Reset), + style::Print("--profile is deprecated, use "), + style::SetForegroundColor(Color::Green), + style::Print("--agent"), + style::SetForegroundColor(Color::Reset), + style::Print(" instead\n") + )?; + } + + let conversation_id = uuid::Uuid::new_v4().to_string(); + info!(?conversation_id, "Generated new conversation id"); + + // Check MCP status once at the beginning of the session + let mcp_enabled = match os.client.is_mcp_enabled().await { + Ok(enabled) => enabled, + Err(err) => { + tracing::warn!(?err, "Failed to check MCP configuration, defaulting to enabled"); + true + }, + }; + + let agents = { + let 
skip_migration = self.no_interactive; + let (mut agents, md) = + Agents::load(os, self.agent.as_deref(), skip_migration, &mut stderr, mcp_enabled).await; + agents.trust_all_tools = self.trust_all_tools; + + os.telemetry + .send_agent_config_init(&os.database, conversation_id.clone(), AgentConfigInitArgs { + agents_loaded_count: md.load_count as i64, + agents_loaded_failed_count: md.load_failed_count as i64, + legacy_profile_migration_executed: md.migration_performed, + legacy_profile_migrated_count: md.migrated_count as i64, + launched_agent: md.launched_agent, + }) + .await + .map_err(|err| error!(?err, "failed to send agent config init telemetry")) + .ok(); + + // Only show MCP safety message if MCP is enabled and has servers + if mcp_enabled + && agents + .get_active() + .is_some_and(|a| !a.mcp_servers.mcp_servers.is_empty()) + { + if !self.no_interactive && !os.database.settings.get_bool(Setting::McpLoadedBefore).unwrap_or(false) { + execute!( + stderr, + style::Print( + "To learn more about MCP safety, see https://docs.aws.amazon.com/amazonq/latest/qdeveloper-ug/command-line-mcp-security.html\n\n" + ) + )?; + } + os.database.settings.set(Setting::McpLoadedBefore, true).await?; + } + + if let Some(trust_tools) = self.trust_tools.take() { + for tool in &trust_tools { + if !tool.starts_with("@") && !NATIVE_TOOLS.contains(&tool.as_str()) { + let _ = queue!( + stderr, + style::SetForegroundColor(Color::Yellow), + style::Print("WARNING: "), + style::SetForegroundColor(Color::Reset), + style::Print("--trust-tools arg for custom tool "), + style::SetForegroundColor(Color::Cyan), + style::Print(tool), + style::SetForegroundColor(Color::Reset), + style::Print(" needs to be prepended with "), + style::SetForegroundColor(Color::Green), + style::Print("@{MCPSERVERNAME}/"), + style::SetForegroundColor(Color::Reset), + style::Print("\n"), + ); + } + } + + let _ = stderr.flush(); + + if let Some(a) = agents.get_active_mut() { + a.allowed_tools.extend(trust_tools); + } + } + + 
agents + }; + + // If modelId is specified, verify it exists before starting the chat + // Otherwise, CLI will use a default model when starting chat + let (models, default_model_opt) = get_available_models(os).await?; + let model_id: Option = if let Some(requested) = self.model.as_ref() { + if let Some(m) = find_model(&models, requested) { + Some(m.model_id.clone()) + } else { + let available = models + .iter() + .map(|m| m.model_name.as_deref().unwrap_or(&m.model_id)) + .collect::>() + .join(", "); + bail!("Model '{}' does not exist. Available models: {}", requested, available); + } + } else if let Some(saved) = os.database.settings.get_string(Setting::ChatDefaultModel) { + find_model(&models, &saved) + .map(|m| m.model_id.clone()) + .or(Some(default_model_opt.model_id.clone())) + } else { + Some(default_model_opt.model_id.clone()) + }; + + let (prompt_request_sender, prompt_request_receiver) = tokio::sync::broadcast::channel::(5); + let (prompt_response_sender, prompt_response_receiver) = + tokio::sync::broadcast::channel::(5); + let mut tool_manager = ToolManagerBuilder::default() + .prompt_query_result_sender(prompt_response_sender) + .prompt_query_receiver(prompt_request_receiver) + .prompt_query_sender(prompt_request_sender.clone()) + .prompt_query_result_receiver(prompt_response_receiver.resubscribe()) + .conversation_id(&conversation_id) + .agent(agents.get_active().cloned().unwrap_or_default()) + .build(os, Box::new(std::io::stderr()), !self.no_interactive) + .await?; + let tool_config = tool_manager.load_tools(os, &mut stderr).await?; + + ChatSession::new( + os, + stdout, + stderr, + &conversation_id, + agents, + input, + InputSource::new(os, prompt_request_sender, prompt_response_receiver)?, + self.resume, + || terminal::window_size().map(|s| s.columns.into()).ok(), + tool_manager, + model_id, + tool_config, + !self.no_interactive, + mcp_enabled, + ) + .await? 
+ .spawn(os) + .await + .map(|_| ExitCode::SUCCESS) + } +} + +const WELCOME_TEXT: &str = color_print::cstr! {" + ⢠⣶⣶⣦⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⣤⣶⣿⣿⣿⣶⣦⡀⠀ + ⠀⠀⠀⣾⡿⢻⣿⡆⠀⠀⠀⢀⣄⡄⢀⣠⣤⣤⡀⢀⣠⣤⣤⡀⠀⠀⢀⣠⣤⣤⣤⣄⠀⠀⢀⣤⣤⣤⣤⣤⣤⡀⠀⠀⣀⣤⣤⣤⣀⠀⠀⠀⢠⣤⡀⣀⣤⣤⣄⡀⠀⠀⠀⠀⠀⠀⢠⣿⣿⠋⠀⠀⠀⠙⣿⣿⡆ + ⠀⠀⣼⣿⠇⠀⣿⣿⡄⠀⠀⢸⣿⣿⠛⠉⠻⣿⣿⠛⠉⠛⣿⣿⠀⠀⠘⠛⠉⠉⠻⣿⣧⠀⠈⠛⠛⠛⣻⣿⡿⠀⢀⣾⣿⠛⠉⠻⣿⣷⡀⠀⢸⣿⡟⠛⠉⢻⣿⣷⠀⠀⠀⠀⠀⠀⣼⣿⡏⠀⠀⠀⠀⠀⢸⣿⣿ + ⠀⢰⣿⣿⣤⣤⣼⣿⣷⠀⠀⢸⣿⣿⠀⠀⠀⣿⣿⠀⠀⠀⣿⣿⠀⠀⢀⣴⣶⣶⣶⣿⣿⠀⠀⠀⣠⣾⡿⠋⠀⠀⢸⣿⣿⠀⠀⠀⣿⣿⡇⠀⢸⣿⡇⠀⠀⢸⣿⣿⠀⠀⠀⠀⠀⠀⢹⣿⣇⠀⠀⠀⠀⠀⢸⣿⡿ + ⢀⣿⣿⠋⠉⠉⠉⢻⣿⣇⠀⢸⣿⣿⠀⠀⠀⣿⣿⠀⠀⠀⣿⣿⠀⠀⣿⣿⡀⠀⣠⣿⣿⠀⢀⣴⣿⣋⣀⣀⣀⡀⠘⣿⣿⣄⣀⣠⣿⣿⠃⠀⢸⣿⡇⠀⠀⢸⣿⣿⠀⠀⠀⠀⠀⠀⠈⢿⣿⣦⣀⣀⣀⣴⣿⡿⠃ + ⠚⠛⠋⠀⠀⠀⠀⠘⠛⠛⠀⠘⠛⠛⠀⠀⠀⠛⠛⠀⠀⠀⠛⠛⠀⠀⠙⠻⠿⠟⠋⠛⠛⠀⠘⠛⠛⠛⠛⠛⠛⠃⠀⠈⠛⠿⠿⠿⠛⠁⠀⠀⠘⠛⠃⠀⠀⠘⠛⠛⠀⠀⠀⠀⠀⠀⠀⠀⠙⠛⠿⢿⣿⣿⣋⠀⠀ + ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠛⠿⢿⡧"}; + +const SMALL_SCREEN_WELCOME_TEXT: &str = color_print::cstr! {"Welcome to Amazon Q!"}; +const RESUME_TEXT: &str = color_print::cstr! {"Picking up where we left off..."}; + +// Only show the model-related tip for now to make users aware of this feature. +const ROTATING_TIPS: [&str; 17] = [ + color_print::cstr! {"You can resume the last conversation from your current directory by launching with + q chat --resume"}, + color_print::cstr! {"Get notified whenever Q CLI finishes responding. + Just run q settings chat.enableNotifications true"}, + color_print::cstr! {"You can use + /editor to edit your prompt with a vim-like experience"}, + color_print::cstr! {"/usage shows you a visual breakdown of your current context window usage"}, + color_print::cstr! {"Get notified whenever Q CLI finishes responding. Just run q settings + chat.enableNotifications true"}, + color_print::cstr! {"You can execute bash commands by typing + ! followed by the command"}, + color_print::cstr! {"Q can use tools without asking for + confirmation every time. Give /tools trust a try"}, + color_print::cstr! {"You can + programmatically inject context to your prompts by using hooks. Check out /context hooks + help"}, + color_print::cstr! 
{"You can use /compact to replace the conversation + history with its summary to free up the context space"}, + color_print::cstr! {"If you want to file an issue + to the Q CLI team, just tell me, or run q issue"}, + color_print::cstr! {"You can enable + custom tools with MCP servers. Learn more with /help"}, + color_print::cstr! {"You can + specify wait time (in ms) for mcp server loading with q settings mcp.initTimeout {timeout in + int}. Servers that takes longer than the specified time will continue to load in the background. Use + /tools to see pending servers."}, + color_print::cstr! {"You can see the server load status as well as any + warnings or errors associated with /mcp"}, + color_print::cstr! {"Use /model to select the model to use for this conversation"}, + color_print::cstr! {"Set a default model by running q settings chat.defaultModel MODEL. Run /model to learn more."}, + color_print::cstr! {"Run /prompts to learn how to build & run repeatable workflows"}, + color_print::cstr! {"Use /tangent or ctrl + t (customizable) to start isolated conversations ( ↯ ) that don't affect your main chat history"}, +]; + +const GREETING_BREAK_POINT: usize = 80; + +const POPULAR_SHORTCUTS: &str = color_print::cstr! {"/help all commands ctrl + j new lines ctrl + s fuzzy search"}; +const SMALL_SCREEN_POPULAR_SHORTCUTS: &str = color_print::cstr! {"/help all commands +ctrl + j new lines +ctrl + s fuzzy search +"}; + +const RESPONSE_TIMEOUT_CONTENT: &str = "Response timed out - message took too long to generate"; +const TRUST_ALL_TEXT: &str = color_print::cstr! {"All tools are now trusted (!). Amazon Q will execute tools without asking for confirmation.\ +\nAgents can sometimes do unexpected things so understand the risks. 
+\nLearn more at https://docs.aws.amazon.com/amazonq/latest/qdeveloper-ug/command-line-chat-security.html#command-line-chat-trustall-safety"}; + +const TOOL_BULLET: &str = " ● "; +const CONTINUATION_LINE: &str = " ⋮ "; +const PURPOSE_ARROW: &str = " ↳ "; +const SUCCESS_TICK: &str = " ✓ "; +const ERROR_EXCLAMATION: &str = " ❗ "; + +/// Enum used to denote the origin of a tool use event +enum ToolUseStatus { + /// Variant denotes that the tool use event associated with chat context is a direct result of + /// a user request + Idle, + /// Variant denotes that the tool use event associated with the chat context is a result of a + /// retry for one or more previously attempted tool use. The tuple is the utterance id + /// associated with the original user request that necessitated the tool use + RetryInProgress(String), +} + +#[derive(Debug, Error)] +pub enum ChatError { + #[error("{0}")] + Client(Box), + #[error("{0}")] + Auth(#[from] AuthError), + #[error("{0}")] + SendMessage(Box), + #[error("{0}")] + ResponseStream(Box), + #[error("{0}")] + Std(#[from] std::io::Error), + #[error("{0}")] + Readline(#[from] rustyline::error::ReadlineError), + #[error("{0}")] + Custom(Cow<'static, str>), + #[error("interrupted")] + Interrupted { tool_uses: Option> }, + #[error(transparent)] + GetPromptError(#[from] GetPromptError), + #[error( + "Tool approval required but --no-interactive was specified. Use --trust-all-tools to automatically approve tools." 
+ )] + NonInteractiveToolApproval, + #[error("The conversation history is too large to compact")] + CompactHistoryFailure, + #[error("Failed to swap to agent: {0}")] + AgentSwapError(eyre::Report), +} + +impl ChatError { + fn status_code(&self) -> Option { + match self { + ChatError::Client(e) => e.status_code(), + ChatError::Auth(_) => None, + ChatError::SendMessage(e) => e.status_code(), + ChatError::ResponseStream(_) => None, + ChatError::Std(_) => None, + ChatError::Readline(_) => None, + ChatError::Custom(_) => None, + ChatError::Interrupted { .. } => None, + ChatError::GetPromptError(_) => None, + ChatError::NonInteractiveToolApproval => None, + ChatError::CompactHistoryFailure => None, + ChatError::AgentSwapError(_) => None, + } + } +} + +impl ReasonCode for ChatError { + fn reason_code(&self) -> String { + match self { + ChatError::Client(e) => e.reason_code(), + ChatError::SendMessage(e) => e.reason_code(), + ChatError::ResponseStream(e) => e.reason_code(), + ChatError::Std(_) => "StdIoError".to_string(), + ChatError::Readline(_) => "ReadlineError".to_string(), + ChatError::Custom(_) => "GenericError".to_string(), + ChatError::Interrupted { .. 
} => "Interrupted".to_string(), + ChatError::GetPromptError(_) => "GetPromptError".to_string(), + ChatError::Auth(_) => "AuthError".to_string(), + ChatError::NonInteractiveToolApproval => "NonInteractiveToolApproval".to_string(), + ChatError::CompactHistoryFailure => "CompactHistoryFailure".to_string(), + ChatError::AgentSwapError(_) => "AgentSwapError".to_string(), + } + } +} + +impl From for ChatError { + fn from(value: ApiClientError) -> Self { + Self::Client(Box::new(value)) + } +} + +impl From for ChatError { + fn from(value: parser::SendMessageError) -> Self { + Self::SendMessage(Box::new(value)) + } +} + +impl From for ChatError { + fn from(value: parser::RecvError) -> Self { + Self::ResponseStream(Box::new(value)) + } +} + +pub struct ChatSession { + /// For output read by humans and machine + pub stdout: std::io::Stdout, + /// For display output, only read by humans + pub stderr: std::io::Stderr, + initial_input: Option, + /// Whether we're starting a new conversation or continuing an old one. + existing_conversation: bool, + input_source: InputSource, + /// Width of the terminal, required for [ParseState]. + terminal_width_provider: fn() -> Option, + spinner: Option, + /// [ConversationState]. + conversation: ConversationState, + /// Tool uses requested by the model that are actively being handled. + tool_uses: Vec, + /// An index into [Self::tool_uses] to represent the current tool use being handled. + pending_tool_index: Option, + /// The time immediately after having received valid tool uses from the model. + /// + /// Used to track the time taken from initially prompting the user to tool execute + /// completion. + tool_turn_start_time: Option, + /// [RequestMetadata] about the ongoing operation. + user_turn_request_metadata: Vec, + /// Telemetry events to be sent as part of the conversation. The HashMap key is tool_use_id. 
+ tool_use_telemetry_events: HashMap, + /// State used to keep track of tool use relation + tool_use_status: ToolUseStatus, + /// Any failed requests that could be useful for error report/debugging + failed_request_ids: Vec, + /// Pending prompts to be sent + pending_prompts: VecDeque, + interactive: bool, + inner: Option, + ctrlc_rx: broadcast::Receiver<()>, +} + +impl ChatSession { + #[allow(clippy::too_many_arguments)] + pub async fn new( + os: &mut Os, + stdout: std::io::Stdout, + mut stderr: std::io::Stderr, + conversation_id: &str, + mut agents: Agents, + mut input: Option, + input_source: InputSource, + resume_conversation: bool, + terminal_width_provider: fn() -> Option, + tool_manager: ToolManager, + model_id: Option, + tool_config: HashMap, + interactive: bool, + mcp_enabled: bool, + ) -> Result { + // Reload prior conversation + let mut existing_conversation = false; + let previous_conversation = std::env::current_dir() + .ok() + .and_then(|cwd| os.database.get_conversation_by_path(cwd).ok()) + .flatten(); + + // Only restore conversations where there were actual messages. + // Prevents edge case where user clears conversation then exits without chatting. + let conversation = match resume_conversation + && previous_conversation + .as_ref() + .is_some_and(|cs| !cs.history().is_empty()) + { + true => { + let mut cs = previous_conversation.unwrap(); + existing_conversation = true; + input = Some(input.unwrap_or("In a few words, summarize our conversation so far.".to_owned())); + cs.tool_manager = tool_manager; + if let Some(profile) = cs.current_profile() { + if agents.switch(profile).is_err() { + execute!( + stderr, + style::SetForegroundColor(Color::Red), + style::Print("Error"), + style::ResetColor, + style::Print(format!( + ": cannot resume conversation with {profile} because it no longer exists. 
Using default.\n" + )) + )?; + let _ = agents.switch(DEFAULT_AGENT_NAME); + } + } + cs.agents = agents; + cs.mcp_enabled = mcp_enabled; + cs.update_state(true).await; + cs.enforce_tool_use_history_invariants(); + cs + }, + false => { + ConversationState::new( + conversation_id, + agents, + tool_config, + tool_manager, + model_id, + os, + mcp_enabled, + ) + .await + }, + }; + + // Spawn a task for listening and broadcasting sigints. + let (ctrlc_tx, ctrlc_rx) = tokio::sync::broadcast::channel(4); + tokio::spawn(async move { + loop { + match ctrl_c().await { + Ok(_) => { + let _ = ctrlc_tx + .send(()) + .map_err(|err| error!(?err, "failed to send ctrlc to broadcast channel")); + }, + Err(err) => { + error!(?err, "Encountered an error while receiving a ctrl+c"); + }, + } + } + }); + + Ok(Self { + stdout, + stderr, + initial_input: input, + existing_conversation, + input_source, + terminal_width_provider, + spinner: None, + conversation, + tool_uses: vec![], + user_turn_request_metadata: vec![], + pending_tool_index: None, + tool_turn_start_time: None, + tool_use_telemetry_events: HashMap::new(), + tool_use_status: ToolUseStatus::Idle, + failed_request_ids: Vec::new(), + pending_prompts: VecDeque::new(), + interactive, + inner: Some(ChatState::default()), + ctrlc_rx, + }) + } + + pub async fn next(&mut self, os: &mut Os) -> Result<(), ChatError> { + // Update conversation state with new tool information + self.conversation.update_state(false).await; + + let mut ctrl_c_stream = self.ctrlc_rx.resubscribe(); + let result = match self.inner.take().expect("state must always be Some") { + ChatState::PromptUser { skip_printing_tools } => { + match (self.interactive, self.tool_uses.is_empty()) { + (false, true) => { + self.inner = Some(ChatState::Exit); + return Ok(()); + }, + (false, false) => { + return Err(ChatError::NonInteractiveToolApproval); + }, + _ => (), + }; + + self.prompt_user(os, skip_printing_tools).await + }, + ChatState::HandleInput { input } => { + 
tokio::select! { + res = self.handle_input(os, input) => res, + Ok(_) = ctrl_c_stream.recv() => Err(ChatError::Interrupted { tool_uses: Some(self.tool_uses.clone()) }) + } + }, + ChatState::CompactHistory { + prompt, + show_summary, + strategy, + } => { + // compact_history manages ctrl+c handling + self.compact_history(os, prompt, show_summary, strategy).await + }, + ChatState::ExecuteTools => { + let tool_uses_clone = self.tool_uses.clone(); + tokio::select! { + res = self.tool_use_execute(os) => res, + Ok(_) = ctrl_c_stream.recv() => Err(ChatError::Interrupted { tool_uses: Some(tool_uses_clone) }) + } + }, + ChatState::ValidateTools { tool_uses } => { + tokio::select! { + res = self.validate_tools(os, tool_uses) => res, + Ok(_) = ctrl_c_stream.recv() => Err(ChatError::Interrupted { tool_uses: None }) + } + }, + ChatState::HandleResponseStream(conversation_state) => { + let request_metadata: Arc>> = Arc::new(Mutex::new(None)); + let request_metadata_clone = Arc::clone(&request_metadata); + + tokio::select! { + res = self.handle_response(os, conversation_state, request_metadata_clone) => res, + Ok(_) = ctrl_c_stream.recv() => { + debug!(?request_metadata, "ctrlc received"); + // Wait for handle_response to finish handling the ctrlc. + tokio::time::sleep(Duration::from_millis(5)).await; + if let Some(request_metadata) = request_metadata.lock().await.take() { + self.user_turn_request_metadata.push(request_metadata); + } + self.send_chat_telemetry(os, TelemetryResult::Cancelled, None, None, None, true).await; + Err(ChatError::Interrupted { tool_uses: None }) + } + } + }, + ChatState::RetryModelOverload => tokio::select! { + res = self.retry_model_overload(os) => res, + Ok(_) = ctrl_c_stream.recv() => { + Err(ChatError::Interrupted { tool_uses: None }) + } + }, + ChatState::Exit => return Ok(()), + }; + + let err = match result { + Ok(state) => { + self.inner = Some(state); + return Ok(()); + }, + Err(err) => err, + }; + + // We encountered an error. Handle it. 
+ error!(?err, "An error occurred processing the current state"); + let (reason, reason_desc) = get_error_reason(&err); + self.send_error_telemetry(os, reason, Some(reason_desc), err.status_code()) + .await; + + if self.spinner.is_some() { + drop(self.spinner.take()); + queue!( + self.stderr, + terminal::Clear(terminal::ClearType::CurrentLine), + cursor::MoveToColumn(0), + )?; + } + + let (context, report, display_err_message) = match err { + ChatError::Interrupted { tool_uses: ref inter } => { + execute!(self.stderr, style::Print("\n\n"))?; + + // If there was an interrupt during tool execution, then we add fake + // messages to "reset" the chat state. + match inter { + Some(tool_uses) if !tool_uses.is_empty() => { + self.conversation + .abandon_tool_use(tool_uses, "The user interrupted the tool execution.".to_string()); + let _ = self + .conversation + .as_sendable_conversation_state(os, &mut self.stderr, false) + .await?; + self.conversation.push_assistant_message( + os, + AssistantMessage::new_response( + None, + "Tool uses were interrupted, waiting for the next user prompt".to_string(), + ), + None, + ); + }, + _ => (), + } + + ("Tool use was interrupted", Report::from(err), false) + }, + ChatError::CompactHistoryFailure => { + // This error is not retryable - the user must take manual intervention to manage + // their context. + execute!( + self.stderr, + style::SetForegroundColor(Color::Red), + style::Print("Your conversation is too large to continue.\n"), + style::SetForegroundColor(Color::Reset), + style::Print(format!( + "• Run {} to compact your conversation. 
See {} for compaction options\n", + "/compact".green(), + "/compact --help".green() + )), + style::Print(format!("• Run {} to analyze your context usage\n", "/usage".green())), + style::Print(format!("• Run {} to reset your conversation state\n", "/clear".green())), + style::SetAttribute(Attribute::Reset), + style::Print("\n\n"), + )?; + ("Unable to compact the conversation history", eyre!(err), true) + }, + ChatError::SendMessage(err) => match err.source { + // Errors from attempting to send too large of a conversation history. In + // this case, attempt to automatically compact the history for the user. + ApiClientError::ContextWindowOverflow { .. } => { + if os + .database + .settings + .get_bool(Setting::ChatDisableAutoCompaction) + .unwrap_or(false) + { + execute!( + self.stderr, + style::SetForegroundColor(Color::Red), + style::Print("The conversation history has overflowed.\n"), + style::SetForegroundColor(Color::Reset), + style::Print(format!("• Run {} to compact your conversation\n", "/compact".green())), + style::SetAttribute(Attribute::Reset), + style::Print("\n\n"), + )?; + ("The conversation history has overflowed", eyre!(err), false) + } else { + self.inner = Some(ChatState::CompactHistory { + prompt: None, + show_summary: false, + strategy: CompactStrategy { + truncate_large_messages: self.conversation.history().len() <= 2, + max_message_length: if self.conversation.history().len() <= 2 { + 25_000 + } else { + Default::default() + }, + ..Default::default() + }, + }); + + execute!( + self.stdout, + style::SetForegroundColor(Color::Yellow), + style::Print("The context window has overflowed, summarizing the history..."), + style::SetAttribute(Attribute::Reset), + style::Print("\n\n"), + )?; + + return Ok(()); + } + }, + ApiClientError::QuotaBreach { + message: _, + status_code: _, + } => { + let err = "Request quota exceeded. 
Please wait a moment and try again.".to_string(); + self.conversation.append_transcript(err.clone()); + execute!( + self.stderr, + style::SetAttribute(Attribute::Bold), + style::SetForegroundColor(Color::Red), + style::Print(" ⚠️ Amazon Q rate limit reached:\n"), + style::Print(format!(" {}\n\n", err.clone())), + style::SetAttribute(Attribute::Reset), + style::SetForegroundColor(Color::Reset), + )?; + ("Amazon Q is having trouble responding right now", eyre!(err), false) + }, + ApiClientError::ModelOverloadedError { request_id, .. } => { + if self.interactive { + execute!( + self.stderr, + style::SetAttribute(Attribute::Bold), + style::SetForegroundColor(Color::Red), + style::Print( + "\nThe model you've selected is temporarily unavailable. Please select a different model.\n" + ), + style::SetAttribute(Attribute::Reset), + style::SetForegroundColor(Color::Reset), + )?; + + if let Some(id) = request_id { + self.conversation + .append_transcript(format!("Model unavailable (Request ID: {})", id)); + } + + self.inner = Some(ChatState::RetryModelOverload); + + return Ok(()); + } + + // non-interactive throws this error + let model_instruction = "Please relaunch with '--model ' to use a different model."; + let err = format!( + "The model you've selected is temporarily unavailable. {}{}\n\n", + model_instruction, + match request_id { + Some(id) => format!("\n Request ID: {}", id), + None => "".to_owned(), + } + ); + self.conversation.append_transcript(err.clone()); + execute!( + self.stderr, + style::SetAttribute(Attribute::Bold), + style::SetForegroundColor(Color::Red), + style::Print("Amazon Q is having trouble responding right now:\n"), + style::Print(format!(" {}\n", err.clone())), + style::SetAttribute(Attribute::Reset), + style::SetForegroundColor(Color::Reset), + )?; + ("Amazon Q is having trouble responding right now", eyre!(err), false) + }, + ApiClientError::MonthlyLimitReached { .. 
} => { + let subscription_status = get_subscription_status(os).await; + if subscription_status.is_err() { + execute!( + self.stderr, + style::SetForegroundColor(Color::Red), + style::Print(format!( + "Unable to verify subscription status: {}\n\n", + subscription_status.as_ref().err().unwrap() + )), + style::SetForegroundColor(Color::Reset), + )?; + } + + execute!( + self.stderr, + style::SetForegroundColor(Color::Yellow), + style::Print("Monthly request limit reached"), + style::SetForegroundColor(Color::Reset), + )?; + + let limits_text = format!( + "The limits reset on {:02}/01.", + OffsetDateTime::now_utc().month().next() as u8 + ); + + if subscription_status.is_err() + || subscription_status.is_ok_and(|s| s == ActualSubscriptionStatus::None) + { + execute!( + self.stderr, + style::Print(format!("\n\n{LIMIT_REACHED_TEXT} {limits_text}")), + style::SetForegroundColor(Color::DarkGrey), + style::Print("\n\nUse "), + style::SetForegroundColor(Color::Green), + style::Print("/subscribe"), + style::SetForegroundColor(Color::DarkGrey), + style::Print(" to upgrade your subscription.\n\n"), + style::SetForegroundColor(Color::Reset), + )?; + } else { + execute!( + self.stderr, + style::SetForegroundColor(Color::Yellow), + style::Print(format!(" - {limits_text}\n\n")), + style::SetForegroundColor(Color::Reset), + )?; + } + + self.inner = Some(ChatState::PromptUser { + skip_printing_tools: false, + }); + + return Ok(()); + }, + _ => ( + "Amazon Q is having trouble responding right now", + Report::from(err), + true, + ), + }, + _ => ( + "Amazon Q is having trouble responding right now", + Report::from(err), + true, + ), + }; + + if display_err_message { + // Remove non-ASCII and ANSI characters. 
+ let re = Regex::new(r"((\x9B|\x1B\[)[0-?]*[ -\/]*[@-~])|([^\x00-\x7F]+)").unwrap(); + + queue!( + self.stderr, + style::SetAttribute(Attribute::Bold), + style::SetForegroundColor(Color::Red), + )?; + + let text = re.replace_all(&format!("{}: {:?}\n", context, report), "").into_owned(); + + queue!(self.stderr, style::Print(&text),)?; + self.conversation.append_transcript(text); + + execute!( + self.stderr, + style::SetAttribute(Attribute::Reset), + style::SetForegroundColor(Color::Reset), + )?; + } + + self.conversation.enforce_conversation_invariants(); + self.conversation.reset_next_user_message(); + self.pending_tool_index = None; + self.tool_turn_start_time = None; + self.reset_user_turn(); + + self.inner = Some(ChatState::PromptUser { + skip_printing_tools: false, + }); + + Ok(()) + } +} + +impl Drop for ChatSession { + fn drop(&mut self) { + if let Some(spinner) = &mut self.spinner { + spinner.stop(); + } + + execute!( + self.stderr, + cursor::MoveToColumn(0), + style::SetAttribute(Attribute::Reset), + style::ResetColor, + cursor::Show + ) + .ok(); + } +} + +/// The chat execution state. +/// +/// Intended to provide more robust handling around state transitions while dealing with, e.g., +/// tool validation, execution, response stream handling, etc. +#[allow(clippy::large_enum_variant)] +#[derive(Debug)] +pub enum ChatState { + /// Prompt the user with `tool_uses`, if available. + PromptUser { + /// Used to avoid displaying the tool info at inappropriate times, e.g. after clear or help + /// commands. + skip_printing_tools: bool, + }, + /// Handle the user input, depending on if any tools require execution. + HandleInput { input: String }, + /// Validate the list of tool uses provided by the model. + ValidateTools { tool_uses: Vec }, + /// Execute the list of tools. + ExecuteTools, + /// Consume the response stream and display to the user. + HandleResponseStream(crate::api_client::model::ConversationState), + /// Compact the chat history. 
+ CompactHistory { + /// Custom prompt to include as part of history compaction. + prompt: Option, + /// Whether or not the summary should be shown on compact success. + show_summary: bool, + /// Parameters for how to perform the compaction request. + strategy: CompactStrategy, + }, + /// Retry the current request if we encounter a model overloaded error. + RetryModelOverload, + /// Exit the chat. + Exit, +} + +impl Default for ChatState { + fn default() -> Self { + Self::PromptUser { + skip_printing_tools: false, + } + } +} + +impl ChatSession { + /// Sends a request to the SendMessage API. Emits error telemetry on failure. + async fn send_message( + &mut self, + os: &mut Os, + conversation_state: api_client::model::ConversationState, + request_metadata_lock: Arc>>, + message_meta_tags: Option>, + ) -> Result { + match SendMessageStream::send_message(&os.client, conversation_state, request_metadata_lock, message_meta_tags) + .await + { + Ok(res) => Ok(res), + Err(err) => { + let (reason, reason_desc) = get_error_reason(&err); + self.send_chat_telemetry( + os, + TelemetryResult::Failed, + Some(reason), + Some(reason_desc), + err.status_code(), + true, // We never retry failed requests, so this always ends the current turn. 
+ ) + .await; + Err(err.into()) + }, + } + } + + async fn spawn(&mut self, os: &mut Os) -> Result<()> { + let is_small_screen = self.terminal_width() < GREETING_BREAK_POINT; + if os + .database + .settings + .get_bool(Setting::ChatGreetingEnabled) + .unwrap_or(true) + { + let welcome_text = match self.existing_conversation { + true => RESUME_TEXT, + false => match is_small_screen { + true => SMALL_SCREEN_WELCOME_TEXT, + false => WELCOME_TEXT, + }, + }; + + execute!(self.stderr, style::Print(welcome_text), style::Print("\n\n"),)?; + + let tip = ROTATING_TIPS[usize::try_from(rand::random::()).unwrap_or(0) % ROTATING_TIPS.len()]; + if is_small_screen { + // If the screen is small, print the tip in a single line + execute!( + self.stderr, + style::Print("💡 ".to_string()), + style::Print(tip), + style::Print("\n") + )?; + } else { + draw_box( + &mut self.stderr, + "Did you know?", + tip, + GREETING_BREAK_POINT, + Color::DarkGrey, + )?; + } + + execute!( + self.stderr, + style::Print("\n"), + style::Print(match is_small_screen { + true => SMALL_SCREEN_POPULAR_SHORTCUTS, + false => POPULAR_SHORTCUTS, + }), + style::Print("\n"), + style::Print( + "━" + .repeat(if is_small_screen { 0 } else { GREETING_BREAK_POINT }) + .dark_grey() + ) + )?; + execute!(self.stderr, style::Print("\n"), style::SetForegroundColor(Color::Reset))?; + } + + if self.all_tools_trusted() { + queue!( + self.stderr, + style::Print(format!( + "{}{TRUST_ALL_TEXT}\n\n", + if !is_small_screen { "\n" } else { "" } + )) + )?; + } + + if let Some(agent) = self.conversation.agents.get_active() { + agent.print_overridden_permissions(&mut self.stderr)?; + } + + self.stderr.flush()?; + + if let Some(ref model_info) = self.conversation.model_info { + let (models, _default_model) = get_available_models(os).await?; + if let Some(model_option) = models.iter().find(|option| option.model_id == model_info.model_id) { + let display_name = model_option.model_name.as_deref().unwrap_or(&model_option.model_id); + execute!( + 
self.stderr, + style::SetForegroundColor(Color::Cyan), + style::Print(format!("🤖 You are chatting with {}\n", display_name)), + style::SetForegroundColor(Color::Reset), + style::Print("\n") + )?; + } + } + + if let Some(user_input) = self.initial_input.take() { + self.inner = Some(ChatState::HandleInput { input: user_input }); + } + + while !matches!(self.inner, Some(ChatState::Exit)) { + self.next(os).await?; + } + + Ok(()) + } + + /// Compacts the conversation history using the strategy specified by [CompactStrategy], + /// replacing the history with a summary generated by the model. + /// + /// If the compact request itself fails, it will be retried depending on [CompactStrategy] + /// + /// If [CompactStrategy::messages_to_exclude] is greater than 0, and + /// [CompactStrategy::truncate_large_messages] is true, then compaction will not be retried and + /// will fail with [ChatError::CompactHistoryFailure]. + async fn compact_history( + &mut self, + os: &mut Os, + custom_prompt: Option, + show_summary: bool, + strategy: CompactStrategy, + ) -> Result { + // Same pattern as is done for handle_response for getting request metadata on sigint. + let request_metadata: Arc>> = Arc::new(Mutex::new(None)); + let request_metadata_clone = Arc::clone(&request_metadata); + let mut ctrl_c_stream = self.ctrlc_rx.resubscribe(); + + tokio::select! { + res = self.compact_history_impl(os, custom_prompt, show_summary, strategy, request_metadata_clone) => res, + Ok(_) = ctrl_c_stream.recv() => { + debug!(?request_metadata, "ctrlc received in compact history"); + // Wait for handle_response to finish handling the ctrlc. 
+ tokio::time::sleep(Duration::from_millis(5)).await; + if let Some(request_metadata) = request_metadata.lock().await.take() { + self.user_turn_request_metadata.push(request_metadata); + } + self.send_chat_telemetry( + os, + TelemetryResult::Cancelled, + None, + None, + None, + true, + ) + .await; + Err(ChatError::Interrupted { tool_uses: Some(self.tool_uses.clone()) }) + } + } + } + + async fn compact_history_impl( + &mut self, + os: &mut Os, + custom_prompt: Option, + show_summary: bool, + strategy: CompactStrategy, + request_metadata_lock: Arc>>, + ) -> Result { + let hist = self.conversation.history(); + debug!(?strategy, ?hist, "compacting history"); + + if self.conversation.history().is_empty() { + execute!( + self.stderr, + style::SetForegroundColor(Color::Yellow), + style::Print("\nConversation too short to compact.\n\n"), + style::SetForegroundColor(Color::Reset) + )?; + + return Ok(ChatState::PromptUser { + skip_printing_tools: true, + }); + } + + if strategy.truncate_large_messages { + info!("truncating large messages"); + execute!( + self.stderr, + terminal::Clear(terminal::ClearType::CurrentLine), + cursor::MoveToColumn(0), + style::SetForegroundColor(Color::Yellow), + style::Print("Truncating large messages..."), + style::SetAttribute(Attribute::Reset), + style::Print("\n\n"), + )?; + } + + let summary_state = self + .conversation + .create_summary_request(os, custom_prompt.as_ref(), strategy) + .await?; + + if self.interactive { + execute!(self.stderr, cursor::Hide, style::Print("\n"))?; + self.spinner = Some(Spinner::new(Spinners::Dots, "Creating summary...".to_string())); + } + + let mut response = match self + .send_message( + os, + summary_state, + request_metadata_lock, + Some(vec![MessageMetaTag::Compact]), + ) + .await + { + Ok(res) => res, + Err(err) => { + if self.interactive { + self.spinner.take(); + execute!( + self.stderr, + terminal::Clear(terminal::ClearType::CurrentLine), + cursor::MoveToColumn(0), + 
style::SetAttribute(Attribute::Reset) + )?; + } + + // If the request fails due to context window overflow, then we'll see if it's + // retryable according to the passed strategy. + let history_len = self.conversation.history().len(); + match err { + ChatError::SendMessage(err) + if matches!(err.source, ApiClientError::ContextWindowOverflow { .. }) => + { + error!(?strategy, "failed to send compaction request"); + // If there's only two messages in the history, we have no choice but to + // truncate it. We use two messages since it's almost guaranteed to contain: + // 1. A small user prompt + // 2. A large user tool use result + if history_len <= 2 && !strategy.truncate_large_messages { + return Ok(ChatState::CompactHistory { + prompt: custom_prompt, + show_summary, + strategy: CompactStrategy { + truncate_large_messages: true, + max_message_length: 25_000, + messages_to_exclude: 0, + }, + }); + } + + // Otherwise, we will first exclude the most recent message, and only then + // truncate. If both of these have already been set, then return an error. 
+ if history_len > 2 && strategy.messages_to_exclude < 1 { + return Ok(ChatState::CompactHistory { + prompt: custom_prompt, + show_summary, + strategy: CompactStrategy { + messages_to_exclude: 1, + ..strategy + }, + }); + } else if !strategy.truncate_large_messages { + return Ok(ChatState::CompactHistory { + prompt: custom_prompt, + show_summary, + strategy: CompactStrategy { + truncate_large_messages: true, + max_message_length: 25_000, + ..strategy + }, + }); + } else { + return Err(ChatError::CompactHistoryFailure); + } + }, + err => return Err(err), + } + }, + }; + + let (summary, request_metadata) = { + loop { + match response.recv().await { + Some(Ok(parser::ResponseEvent::EndStream { + message, + request_metadata, + })) => { + self.user_turn_request_metadata.push(request_metadata.clone()); + break (message.content().to_string(), request_metadata); + }, + Some(Ok(_)) => (), + Some(Err(err)) => { + if let Some(request_id) = &err.request_metadata.request_id { + self.failed_request_ids.push(request_id.clone()); + }; + + self.user_turn_request_metadata.push(err.request_metadata.clone()); + + let (reason, reason_desc) = get_error_reason(&err); + self.send_chat_telemetry( + os, + TelemetryResult::Failed, + Some(reason), + Some(reason_desc), + err.status_code(), + true, + ) + .await; + + return Err(err.into()); + }, + None => { + error!("response stream receiver closed before receiving a stop event"); + return Err(ChatError::Custom("Stream failed during compaction".into())); + }, + } + } + }; + + if self.spinner.is_some() { + drop(self.spinner.take()); + queue!( + self.stderr, + terminal::Clear(terminal::ClearType::CurrentLine), + cursor::MoveToColumn(0), + cursor::Show + )?; + } + + self.conversation + .replace_history_with_summary(summary.clone(), strategy, request_metadata); + + // If a next message is set, then retry the request. + let should_retry = self.conversation.next_user_message().is_some(); + + // If we retry, then don't end the current turn. 
+ self.send_chat_telemetry(os, TelemetryResult::Succeeded, None, None, None, !should_retry) + .await; + + // Print output to the user. + { + execute!( + self.stderr, + style::SetForegroundColor(Color::Green), + style::Print("✔ Conversation history has been compacted successfully!\n\n"), + style::SetForegroundColor(Color::DarkGrey) + )?; + + let mut output = Vec::new(); + if let Some(custom_prompt) = &custom_prompt { + execute!( + output, + style::Print(format!("• Custom prompt applied: {}\n", custom_prompt)) + )?; + } + animate_output(&mut self.stderr, &output)?; + + // Display the summary if the show_summary flag is set + if show_summary { + // Add a border around the summary for better visual separation + let terminal_width = self.terminal_width(); + let border = "═".repeat(terminal_width.min(80)); + execute!( + self.stderr, + style::Print("\n"), + style::SetForegroundColor(Color::Cyan), + style::Print(&border), + style::Print("\n"), + style::SetAttribute(Attribute::Bold), + style::Print(" CONVERSATION SUMMARY"), + style::Print("\n"), + style::Print(&border), + style::SetAttribute(Attribute::Reset), + style::Print("\n\n"), + )?; + + execute!( + output, + style::Print(&summary), + style::Print("\n\n"), + style::SetForegroundColor(Color::Cyan), + style::Print("The conversation history has been replaced with this summary.\n"), + style::Print("It contains all important details from previous interactions.\n"), + )?; + animate_output(&mut self.stderr, &output)?; + + execute!( + self.stderr, + style::Print(&border), + style::Print("\n\n"), + style::SetForegroundColor(Color::Reset) + )?; + } + } + + if should_retry { + Ok(ChatState::HandleResponseStream( + self.conversation + .as_sendable_conversation_state(os, &mut self.stderr, false) + .await?, + )) + } else { + // Otherwise, return back to the prompt for any pending tool uses. + Ok(ChatState::PromptUser { + skip_printing_tools: true, + }) + } + } + + /// Read input from the user. 
+ async fn prompt_user(&mut self, os: &Os, skip_printing_tools: bool) -> Result { + execute!(self.stderr, cursor::Show)?; + + // Check token usage and display warnings if needed + if self.pending_tool_index.is_none() { + // Only display warnings when not waiting for tool approval + if let Err(err) = self.display_char_warnings(os).await { + warn!("Failed to display character limit warnings: {}", err); + } + } + + let show_tool_use_confirmation_dialog = !skip_printing_tools && self.pending_tool_index.is_some(); + if show_tool_use_confirmation_dialog { + execute!( + self.stderr, + style::SetForegroundColor(Color::DarkGrey), + style::Print("\nAllow this action? Use '"), + style::SetForegroundColor(Color::Green), + style::Print("t"), + style::SetForegroundColor(Color::DarkGrey), + style::Print("' to trust (always allow) this tool for the session. ["), + style::SetForegroundColor(Color::Green), + style::Print("y"), + style::SetForegroundColor(Color::DarkGrey), + style::Print("/"), + style::SetForegroundColor(Color::Green), + style::Print("n"), + style::SetForegroundColor(Color::DarkGrey), + style::Print("/"), + style::SetForegroundColor(Color::Green), + style::Print("t"), + style::SetForegroundColor(Color::DarkGrey), + style::Print("]:\n\n"), + style::SetForegroundColor(Color::Reset), + )?; + } + + // Do this here so that the skim integration sees an updated view of the context *during the current + // q session*. (e.g., if I add files to context, that won't show up for skim for the current + // q session unless we do this in prompt_user... 
unless you can find a better way) + #[cfg(unix)] + if let Some(ref context_manager) = self.conversation.context_manager { + use std::sync::Arc; + + use crate::cli::chat::consts::DUMMY_TOOL_NAME; + + let tool_names = self + .conversation + .tool_manager + .tn_map + .keys() + .filter(|name| *name != DUMMY_TOOL_NAME) + .cloned() + .collect::>(); + self.input_source + .put_skim_command_selector(os, Arc::new(context_manager.clone()), tool_names); + } + + execute!( + self.stderr, + style::SetForegroundColor(Color::Reset), + style::SetAttribute(Attribute::Reset) + )?; + let prompt = self.generate_tool_trust_prompt(); + let user_input = match self.read_user_input(&prompt, false) { + Some(input) => input, + None => return Ok(ChatState::Exit), + }; + + self.conversation.append_user_transcript(&user_input); + Ok(ChatState::HandleInput { input: user_input }) + } + + async fn handle_input(&mut self, os: &mut Os, mut user_input: String) -> Result { + queue!(self.stderr, style::Print('\n'))?; + user_input = sanitize_unicode_tags(&user_input); + let input = user_input.trim(); + + // handle image path + if let Some(chat_state) = does_input_reference_file(input) { + return Ok(chat_state); + } + if let Some(mut args) = input.strip_prefix("/").and_then(shlex::split) { + // Required for printing errors correctly. + let orig_args = args.clone(); + + // We set the binary name as a dummy name "slash_command" which we + // replace anytime we error out and print a usage statement. 
+ args.insert(0, "slash_command".to_owned()); + + match SlashCommand::try_parse_from(args) { + Ok(command) => { + let command_name = command.command_name().to_string(); + let subcommand_name = command.subcommand_name().map(|s| s.to_string()); + + match command.execute(os, self).await { + Ok(chat_state) => { + let _ = self + .send_slash_command_telemetry( + os, + command_name, + subcommand_name, + TelemetryResult::Succeeded, + None, + ) + .await; + + if matches!(chat_state, ChatState::Exit) + || matches!(chat_state, ChatState::HandleResponseStream(_)) + || matches!(chat_state, ChatState::HandleInput { input: _ }) + // TODO(bskiser): this is just a hotfix for handling state changes + // from manually running /compact, without impacting behavior of + // other slash commands. + || matches!(chat_state, ChatState::CompactHistory { .. }) + { + return Ok(chat_state); + } + }, + Err(err) => { + queue!( + self.stderr, + style::SetForegroundColor(Color::Red), + style::Print(format!("\nFailed to execute command: {}\n", err)), + style::SetForegroundColor(Color::Reset) + )?; + let _ = self + .send_slash_command_telemetry( + os, + command_name, + subcommand_name, + TelemetryResult::Failed, + Some(err.to_string()), + ) + .await; + }, + } + + writeln!(self.stderr)?; + }, + Err(err) => { + // Replace the dummy name with a slash. Also have to check for an ansi sequence + // for invalid slash commands (e.g. on a "/doesntexist" input). + let ansi_output = err + .render() + .ansi() + .to_string() + .replace("slash_command ", "/") + .replace("slash_command\u{1b}[0m ", "/"); + + writeln!(self.stderr, "{}", ansi_output)?; + + // Print the subcommand help, if available. Required since by default we won't + // show what the actual arguments are, requiring an unnecessary --help call. 
+ if let clap::error::ErrorKind::InvalidValue + | clap::error::ErrorKind::UnknownArgument + | clap::error::ErrorKind::InvalidSubcommand + | clap::error::ErrorKind::MissingRequiredArgument = err.kind() + { + let mut cmd = SlashCommand::command(); + for arg in &orig_args { + match cmd.find_subcommand(arg) { + Some(subcmd) => cmd = subcmd.clone(), + None => break, + } + } + let help = cmd.help_template("{all-args}").render_help(); + writeln!(self.stderr, "{}", help.ansi())?; + } + }, + } + + Ok(ChatState::PromptUser { + skip_printing_tools: false, + }) + } else if let Some(command) = input.strip_prefix("@") { + let input_parts = + shlex::split(command).ok_or(ChatError::Custom("Error splitting prompt command".into()))?; + + let mut iter = input_parts.into_iter(); + let prompt_name = iter + .next() + .ok_or(ChatError::Custom("Prompt name needs to be specified".into()))?; + + let args: Vec = iter.collect(); + let arguments = if args.is_empty() { None } else { Some(args) }; + + let subcommand = PromptsSubcommand::Get { + orig_input: Some(command.to_string()), + name: prompt_name, + arguments, + }; + return subcommand.execute(self).await; + } else if let Some(command) = input.strip_prefix("!") { + // Use platform-appropriate shell + let result = if cfg!(target_os = "windows") { + std::process::Command::new("cmd").args(["/C", command]).status() + } else { + std::process::Command::new("bash").args(["-c", command]).status() + }; + + // Handle the result and provide appropriate feedback + match result { + Ok(status) => { + if !status.success() { + queue!( + self.stderr, + style::SetForegroundColor(Color::Yellow), + style::Print(format!("Self exited with status: {}\n", status)), + style::SetForegroundColor(Color::Reset) + )?; + } + }, + Err(e) => { + queue!( + self.stderr, + style::SetForegroundColor(Color::Red), + style::Print(format!("\nFailed to execute command: {}\n", e)), + style::SetForegroundColor(Color::Reset) + )?; + }, + } + + Ok(ChatState::PromptUser { + 
skip_printing_tools: false, + }) + } else { + // Check for a pending tool approval + if let Some(index) = self.pending_tool_index { + let is_trust = ["t", "T"].contains(&input); + let tool_use = &mut self.tool_uses[index]; + if ["y", "Y"].contains(&input) || is_trust { + if is_trust { + let formatted_tool_name = self + .conversation + .tool_manager + .tn_map + .get(&tool_use.name) + .map(|info| { + format!( + "@{}{MCP_SERVER_TOOL_DELIMITER}{}", + info.server_name, info.host_tool_name + ) + }) + .clone() + .unwrap_or(tool_use.name.clone()); + self.conversation.agents.trust_tools(vec![formatted_tool_name]); + + if let Some(agent) = self.conversation.agents.get_active() { + agent + .print_overridden_permissions(&mut self.stderr) + .map_err(|_e| ChatError::Custom("Failed to validate agent tool settings".into()))?; + } + } + tool_use.accepted = true; + + return Ok(ChatState::ExecuteTools); + } + } else if !self.pending_prompts.is_empty() { + let prompts = self.pending_prompts.drain(0..).collect(); + user_input = self + .conversation + .append_prompts(prompts) + .ok_or(ChatError::Custom("Prompt append failed".into()))?; + } + + // Otherwise continue with normal chat on 'n' or other responses + self.tool_use_status = ToolUseStatus::Idle; + + if self.pending_tool_index.is_some() { + // If the user just enters "n", replace the message we send to the model with + // something more substantial. + // TODO: Update this flow to something that does *not* require two requests just to + // get a meaningful response from the user - this is a short term solution before + // we decide on a better flow. + let user_input = if ["n", "N"].contains(&user_input.trim()) { + "I deny this tool request. 
Ask a follow up question clarifying the expected action".to_string() + } else { + user_input + }; + self.conversation.abandon_tool_use(&self.tool_uses, user_input); + } else { + self.conversation.set_next_user_message(user_input).await; + } + + self.reset_user_turn(); + + let conv_state = self + .conversation + .as_sendable_conversation_state(os, &mut self.stderr, true) + .await?; + self.send_tool_use_telemetry(os).await; + + queue!(self.stderr, style::SetForegroundColor(Color::Magenta))?; + queue!(self.stderr, style::SetForegroundColor(Color::Reset))?; + queue!(self.stderr, cursor::Hide)?; + + if self.interactive { + self.spinner = Some(Spinner::new(Spinners::Dots, "Thinking...".to_owned())); + } + + Ok(ChatState::HandleResponseStream(conv_state)) + } + } + + async fn tool_use_execute(&mut self, os: &mut Os) -> Result { + // Verify tools have permissions. + for i in 0..self.tool_uses.len() { + let tool = &mut self.tool_uses[i]; + + // Manually accepted by the user or otherwise verified already. 
+ if tool.accepted { + continue; + } + + let mut denied_match_set = None::>; + let allowed = + self.conversation + .agents + .get_active() + .is_some_and(|a| match tool.tool.requires_acceptance(os, a) { + PermissionEvalResult::Allow => true, + PermissionEvalResult::Ask => false, + PermissionEvalResult::Deny(matches) => { + denied_match_set.replace(matches); + false + }, + }) + || self.conversation.agents.trust_all_tools; + + if let Some(match_set) = denied_match_set { + let formatted_set = match_set.into_iter().fold(String::new(), |mut acc, rule| { + acc.push_str(&format!("\n - {rule}")); + acc + }); + + execute!( + self.stderr, + style::SetForegroundColor(Color::Red), + style::Print("Command "), + style::SetForegroundColor(Color::Yellow), + style::Print(&tool.name), + style::SetForegroundColor(Color::Red), + style::Print(" is rejected because it matches one or more rules on the denied list:"), + style::Print(formatted_set), + style::Print("\n"), + style::SetForegroundColor(Color::Reset), + )?; + + return Ok(ChatState::HandleInput { + input: format!( + "Tool use with {} was rejected because the arguments supplied were forbidden", + tool.name + ), + }); + } + + if os + .database + .settings + .get_bool(Setting::ChatEnableNotifications) + .unwrap_or(false) + { + play_notification_bell(!allowed); + } + + // TODO: Control flow is hacky here because of borrow rules + let _ = tool; + self.print_tool_description(os, i, allowed).await?; + let tool = &mut self.tool_uses[i]; + + if allowed { + tool.accepted = true; + self.tool_use_telemetry_events + .entry(tool.id.clone()) + .and_modify(|ev| ev.is_trusted = true); + continue; + } + + self.pending_tool_index = Some(i); + + return Ok(ChatState::PromptUser { + skip_printing_tools: false, + }); + } + + // Execute the requested tools. 
+ let mut tool_results = vec![]; + let mut image_blocks: Vec = Vec::new(); + + for tool in &self.tool_uses { + let tool_start = std::time::Instant::now(); + let mut tool_telemetry = self.tool_use_telemetry_events.entry(tool.id.clone()); + tool_telemetry = tool_telemetry.and_modify(|ev| { + ev.is_accepted = true; + }); + + // Extract AWS service name and operation name if available + if let Some(additional_info) = tool.tool.get_additional_info() { + if let Some(aws_service_name) = additional_info.get("aws_service_name").and_then(|v| v.as_str()) { + tool_telemetry = + tool_telemetry.and_modify(|ev| ev.aws_service_name = Some(aws_service_name.to_string())); + } + if let Some(aws_operation_name) = additional_info.get("aws_operation_name").and_then(|v| v.as_str()) { + tool_telemetry = + tool_telemetry.and_modify(|ev| ev.aws_operation_name = Some(aws_operation_name.to_string())); + } + } + + let invoke_result = tool + .tool + .invoke(os, &mut self.stdout, &mut self.conversation.file_line_tracker) + .await; + + if self.spinner.is_some() { + queue!( + self.stderr, + terminal::Clear(terminal::ClearType::CurrentLine), + cursor::MoveToColumn(0), + cursor::Show + )?; + } + execute!(self.stdout, style::Print("\n"))?; + + let tool_end_time = Instant::now(); + let tool_time = tool_end_time.duration_since(tool_start); + tool_telemetry = tool_telemetry.and_modify(|ev| { + ev.execution_duration = Some(tool_time); + ev.turn_duration = self.tool_turn_start_time.map(|t| tool_end_time.duration_since(t)); + }); + if let Tool::Custom(ct) = &tool.tool { + tool_telemetry = tool_telemetry.and_modify(|ev| { + ev.is_custom_tool = true; + // legacy fields previously implemented for only MCP tools + ev.custom_tool_call_latency = Some(tool_time.as_secs() as usize); + ev.input_token_size = Some(ct.get_input_token_size()); + }); + } + let tool_time = format!("{}.{}", tool_time.as_secs(), tool_time.subsec_millis()); + match invoke_result { + Ok(result) => { + match result.output { + 
OutputKind::Text(ref text) => { + debug!("Output is Text: {}", text); + }, + OutputKind::Json(ref json) => { + debug!("Output is JSON: {}", json); + }, + OutputKind::Images(ref image) => { + image_blocks.extend(image.clone()); + }, + OutputKind::Mixed { ref text, ref images } => { + debug!("Output is Mixed: text = {:?}, images = {}", text, images.len()); + image_blocks.extend(images.clone()); + }, + } + + debug!("tool result output: {:#?}", result); + execute!( + self.stdout, + style::Print(CONTINUATION_LINE), + style::Print("\n"), + style::SetForegroundColor(Color::Green), + style::SetAttribute(Attribute::Bold), + style::Print(format!(" ● Completed in {}s", tool_time)), + style::SetForegroundColor(Color::Reset), + style::Print("\n\n"), + )?; + + tool_telemetry = tool_telemetry.and_modify(|ev| ev.is_success = Some(true)); + if let Tool::Custom(_) = &tool.tool { + tool_telemetry + .and_modify(|ev| ev.output_token_size = Some(TokenCounter::count_tokens(&result.as_str()))); + } + + // Send telemetry for agent contribution + if let Tool::FsWrite(w) = &tool.tool { + let sanitized_path_str = w.path(os).to_string_lossy().to_string(); + let conversation_id = self.conversation.conversation_id().to_string(); + let message_id = self.conversation.message_id().map(|s| s.to_string()); + if let Some(tracker) = self.conversation.file_line_tracker.get_mut(&sanitized_path_str) { + let lines_by_agent = tracker.lines_by_agent(); + let lines_by_user = tracker.lines_by_user(); + + os.telemetry + .send_agent_contribution_metric( + &os.database, + conversation_id, + message_id, + Some(tool.id.clone()), // Already a String + Some(tool.name.clone()), // Already a String + Some(lines_by_agent), + Some(lines_by_user), + ) + .await + .ok(); + + tracker.prev_fswrite_lines = tracker.after_fswrite_lines; + } + } + + tool_results.push(ToolUseResult { + tool_use_id: tool.id.clone(), + content: vec![result.into()], + status: ToolResultStatus::Success, + }); + }, + Err(err) => { + error!(?err, "An 
error occurred processing the tool"); + execute!( + self.stderr, + style::Print(CONTINUATION_LINE), + style::Print("\n"), + style::SetAttribute(Attribute::Bold), + style::SetForegroundColor(Color::Red), + style::Print(format!(" ● Execution failed after {}s:\n", tool_time)), + style::SetAttribute(Attribute::Reset), + style::SetForegroundColor(Color::Red), + style::Print(&err), + style::SetAttribute(Attribute::Reset), + style::Print("\n\n"), + )?; + + tool_telemetry.and_modify(|ev| { + ev.is_success = Some(false); + ev.reason_desc = Some(err.to_string()); + }); + tool_results.push(ToolUseResult { + tool_use_id: tool.id.clone(), + content: vec![ToolUseResultBlock::Text(format!( + "An error occurred processing the tool: \n{}", + &err + ))], + status: ToolResultStatus::Error, + }); + if let ToolUseStatus::Idle = self.tool_use_status { + self.tool_use_status = ToolUseStatus::RetryInProgress( + self.conversation + .message_id() + .map_or("No utterance id found".to_string(), |v| v.to_string()), + ); + } + }, + } + } + + if !image_blocks.is_empty() { + let images = image_blocks.into_iter().map(|(block, _)| block).collect(); + self.conversation.add_tool_results_with_images(tool_results, images); + execute!( + self.stderr, + style::SetAttribute(Attribute::Reset), + style::SetForegroundColor(Color::Reset), + style::Print("\n") + )?; + } else { + self.conversation.add_tool_results(tool_results); + } + + execute!(self.stderr, cursor::Hide)?; + execute!(self.stderr, style::Print("\n"), style::SetAttribute(Attribute::Reset))?; + if self.interactive { + self.spinner = Some(Spinner::new(Spinners::Dots, "Thinking...".to_string())); + } + + self.send_chat_telemetry(os, TelemetryResult::Succeeded, None, None, None, false) + .await; + self.send_tool_use_telemetry(os).await; + return Ok(ChatState::HandleResponseStream( + self.conversation + .as_sendable_conversation_state(os, &mut self.stderr, false) + .await?, + )); + } + + /// Sends a [crate::api_client::ApiClient::send_message] 
request to the backend and consumes + /// the response stream. + /// + /// In order to handle sigints while also keeping track of metadata about how the + /// response stream was handled, we need an extra parameter: + /// * `request_metadata_lock` - Updated with the [RequestMetadata] once it has been received + /// (either though a successful request, or on an error). + async fn handle_response( + &mut self, + os: &mut Os, + state: crate::api_client::model::ConversationState, + request_metadata_lock: Arc>>, + ) -> Result { + let mut rx = self.send_message(os, state, request_metadata_lock, None).await?; + + let request_id = rx.request_id().map(String::from); + + let mut buf = String::new(); + let mut offset = 0; + let mut ended = false; + let mut state = ParseState::new( + Some(self.terminal_width()), + os.database.settings.get_bool(Setting::ChatDisableMarkdownRendering), + ); + let mut response_prefix_printed = false; + + let mut tool_uses = Vec::new(); + let mut tool_name_being_recvd: Option = None; + + if self.spinner.is_some() { + drop(self.spinner.take()); + queue!( + self.stderr, + style::SetForegroundColor(Color::Reset), + cursor::MoveToColumn(0), + cursor::Show, + terminal::Clear(terminal::ClearType::CurrentLine), + )?; + } + + loop { + match rx.recv().await { + Some(Ok(msg_event)) => { + trace!("Consumed: {:?}", msg_event); + match msg_event { + parser::ResponseEvent::ToolUseStart { name } => { + // We need to flush the buffer here, otherwise text will not be + // printed while we are receiving tool use events. + buf.push('\n'); + tool_name_being_recvd = Some(name); + }, + parser::ResponseEvent::AssistantText(text) => { + // Add Q response prefix before the first assistant text. 
+ if !response_prefix_printed && !text.trim().is_empty() { + queue!( + self.stdout, + style::SetForegroundColor(Color::Green), + style::Print("> "), + style::SetForegroundColor(Color::Reset) + )?; + response_prefix_printed = true; + } + buf.push_str(&text); + }, + parser::ResponseEvent::ToolUse(tool_use) => { + if self.spinner.is_some() { + drop(self.spinner.take()); + queue!( + self.stderr, + terminal::Clear(terminal::ClearType::CurrentLine), + cursor::MoveToColumn(0), + cursor::Show + )?; + } + tool_uses.push(tool_use); + tool_name_being_recvd = None; + }, + parser::ResponseEvent::EndStream { + message, + request_metadata: rm, + } => { + // This log is attempting to help debug instances where users encounter + // the response timeout message. + if message.content() == RESPONSE_TIMEOUT_CONTENT { + error!(?request_id, ?message, "Encountered an unexpected model response"); + } + self.conversation.push_assistant_message(os, message, Some(rm.clone())); + self.user_turn_request_metadata.push(rm); + ended = true; + }, + } + }, + Some(Err(recv_error)) => { + if let Some(request_id) = &recv_error.request_metadata.request_id { + self.failed_request_ids.push(request_id.clone()); + }; + + self.user_turn_request_metadata + .push(recv_error.request_metadata.clone()); + let (reason, reason_desc) = get_error_reason(&recv_error); + let status_code = recv_error.status_code(); + + match recv_error.source { + RecvErrorKind::StreamTimeout { source, duration } => { + self.send_chat_telemetry( + os, + TelemetryResult::Failed, + Some(reason), + Some(reason_desc), + status_code, + false, // We retry the request, so don't end the current turn yet. 
+ ) + .await; + + error!( + recv_error.request_metadata.request_id, + ?source, + "Encountered a stream timeout after waiting for {}s", + duration.as_secs() + ); + + execute!(self.stderr, cursor::Hide)?; + self.spinner = Some(Spinner::new(Spinners::Dots, "Dividing up the work...".to_string())); + + // For stream timeouts, we'll tell the model to try and split its response into + // smaller chunks. + self.conversation.push_assistant_message( + os, + AssistantMessage::new_response(None, RESPONSE_TIMEOUT_CONTENT.to_string()), + None, + ); + self.conversation + .set_next_user_message( + "You took too long to respond - try to split up the work into smaller steps." + .to_string(), + ) + .await; + self.send_tool_use_telemetry(os).await; + return Ok(ChatState::HandleResponseStream( + self.conversation + .as_sendable_conversation_state(os, &mut self.stderr, false) + .await?, + )); + }, + RecvErrorKind::UnexpectedToolUseEos { + tool_use_id, + name, + message, + .. + } => { + self.send_chat_telemetry( + os, + TelemetryResult::Failed, + Some(reason), + Some(reason_desc), + status_code, + false, // We retry the request, so don't end the current turn yet. 
+ ) + .await; + + error!( + recv_error.request_metadata.request_id, + tool_use_id, name, "The response stream ended before the entire tool use was received" + ); + self.conversation + .push_assistant_message(os, *message, Some(recv_error.request_metadata)); + let tool_results = vec![ToolUseResult { + tool_use_id, + content: vec![ToolUseResultBlock::Text( + "The generated tool was too large, try again but this time split up the work between multiple tool uses".to_string(), + )], + status: ToolResultStatus::Error, + }]; + self.conversation.add_tool_results(tool_results); + self.send_tool_use_telemetry(os).await; + return Ok(ChatState::HandleResponseStream( + self.conversation + .as_sendable_conversation_state(os, &mut self.stderr, false) + .await?, + )); + }, + _ => { + self.send_chat_telemetry( + os, + TelemetryResult::Failed, + Some(reason), + Some(reason_desc), + status_code, + true, // Hard fail -> end the current user turn. + ) + .await; + + return Err(recv_error.into()); + }, + } + }, + None => { + warn!("response stream receiver closed before receiving a stop event"); + ended = true; + }, + } + + // Fix for the markdown parser copied over from q chat: + // this is a hack since otherwise the parser might report Incomplete with useful data + // still left in the buffer. I'm not sure how this is intended to be handled. 
+ if ended { + buf.push('\n'); + } + + if tool_name_being_recvd.is_none() && !buf.is_empty() && self.spinner.is_some() { + drop(self.spinner.take()); + queue!( + self.stderr, + terminal::Clear(terminal::ClearType::CurrentLine), + cursor::MoveToColumn(0), + cursor::Show + )?; + } + + // Print the response for normal cases + loop { + let input = Partial::new(&buf[offset..]); + match interpret_markdown(input, &mut self.stdout, &mut state) { + Ok(parsed) => { + offset += parsed.offset_from(&input); + self.stdout.flush()?; + state.newline = state.set_newline; + state.set_newline = false; + }, + Err(err) => match err.into_inner() { + Some(err) => return Err(ChatError::Custom(err.to_string().into())), + None => break, // Data was incomplete + }, + } + + // TODO: We should buffer output based on how much we have to parse, not as a constant + // Do not remove unless you are nabochay :) + tokio::time::sleep(Duration::from_millis(8)).await; + } + + // Set spinner after showing all of the assistant text content so far. 
+ if tool_name_being_recvd.is_some() { + queue!(self.stderr, cursor::Hide)?; + if self.interactive { + self.spinner = Some(Spinner::new(Spinners::Dots, "Thinking...".to_string())); + } + } + + if ended { + if os + .database + .settings + .get_bool(Setting::ChatEnableNotifications) + .unwrap_or(false) + { + // For final responses (no tools suggested), always play the bell + play_notification_bell(tool_uses.is_empty()); + } + + queue!(self.stderr, style::ResetColor, style::SetAttribute(Attribute::Reset))?; + execute!(self.stdout, style::Print("\n"))?; + + for (i, citation) in &state.citations { + queue!( + self.stdout, + style::Print("\n"), + style::SetForegroundColor(Color::Blue), + style::Print(format!("[^{i}]: ")), + style::SetForegroundColor(Color::DarkGrey), + style::Print(format!("{citation}\n")), + style::SetForegroundColor(Color::Reset) + )?; + } + + break; + } + } + + if !tool_uses.is_empty() { + Ok(ChatState::ValidateTools { tool_uses }) + } else { + self.tool_uses.clear(); + self.pending_tool_index = None; + self.tool_turn_start_time = None; + + self.send_chat_telemetry(os, TelemetryResult::Succeeded, None, None, None, true) + .await; + + Ok(ChatState::PromptUser { + skip_printing_tools: false, + }) + } + } + + async fn validate_tools(&mut self, os: &Os, tool_uses: Vec) -> Result { + let conv_id = self.conversation.conversation_id().to_owned(); + debug!(?tool_uses, "Validating tool uses"); + let mut queued_tools: Vec = Vec::new(); + let mut tool_results: Vec = Vec::new(); + + for tool_use in tool_uses { + let tool_use_id = tool_use.id.clone(); + let tool_use_name = tool_use.name.clone(); + let mut tool_telemetry = ToolUseEventBuilder::new( + conv_id.clone(), + tool_use.id.clone(), + self.conversation.model_info.as_ref().map(|m| m.model_id.clone()), + ) + .set_tool_use_id(tool_use_id.clone()) + .set_tool_name(tool_use.name.clone()) + .utterance_id(self.conversation.message_id().map(|s| s.to_string())); + match 
self.conversation.tool_manager.get_tool_from_tool_use(tool_use) { + Ok(mut tool) => { + // Apply non-Q-generated context to tools + self.contextualize_tool(&mut tool); + + match tool.validate(os).await { + Ok(()) => { + tool_telemetry.is_valid = Some(true); + queued_tools.push(QueuedTool { + id: tool_use_id.clone(), + name: tool_use_name, + tool, + accepted: false, + }); + }, + Err(err) => { + tool_telemetry.is_valid = Some(false); + tool_results.push(ToolUseResult { + tool_use_id: tool_use_id.clone(), + content: vec![ToolUseResultBlock::Text(format!( + "Failed to validate tool parameters: {err}" + ))], + status: ToolResultStatus::Error, + }); + }, + }; + }, + Err(err) => { + tool_telemetry.is_valid = Some(false); + tool_results.push(err.into()); + }, + } + self.tool_use_telemetry_events.insert(tool_use_id, tool_telemetry); + } + + // If we have any validation errors, then return them immediately to the model. + if !tool_results.is_empty() { + debug!(?tool_results, "Error found in the model tools"); + queue!( + self.stderr, + style::SetAttribute(Attribute::Bold), + style::Print("Tool validation failed: "), + style::SetAttribute(Attribute::Reset), + )?; + for tool_result in &tool_results { + for block in &tool_result.content { + let content: Option> = match block { + ToolUseResultBlock::Text(t) => Some(t.as_str().into()), + ToolUseResultBlock::Json(d) => serde_json::to_string(d) + .map_err(|err| error!(?err, "failed to serialize tool result content")) + .map(Into::into) + .ok(), + }; + if let Some(content) = content { + queue!( + self.stderr, + style::Print("\n"), + style::SetForegroundColor(Color::Red), + style::Print(format!("{}\n", content)), + style::SetForegroundColor(Color::Reset), + )?; + } + } + } + + self.conversation.add_tool_results(tool_results); + self.send_chat_telemetry(os, TelemetryResult::Succeeded, None, None, None, false) + .await; + self.send_tool_use_telemetry(os).await; + if let ToolUseStatus::Idle = self.tool_use_status { + 
self.tool_use_status = ToolUseStatus::RetryInProgress( + self.conversation + .message_id() + .map_or("No utterance id found".to_string(), |v| v.to_string()), + ); + } + + return Ok(ChatState::HandleResponseStream( + self.conversation + .as_sendable_conversation_state(os, &mut self.stderr, false) + .await?, + )); + } + + self.tool_uses = queued_tools; + self.pending_tool_index = Some(0); + self.tool_turn_start_time = Some(Instant::now()); + Ok(ChatState::ExecuteTools) + } + + async fn retry_model_overload(&mut self, os: &mut Os) -> Result { + os.client.invalidate_model_cache().await; + match select_model(os, self).await { + Ok(Some(_)) => (), + Ok(None) => { + // User did not select a model, so reset the current request state. + self.conversation.enforce_conversation_invariants(); + self.conversation.reset_next_user_message(); + self.pending_tool_index = None; + self.tool_turn_start_time = None; + return Ok(ChatState::PromptUser { + skip_printing_tools: false, + }); + }, + Err(err) => return Err(err), + } + + if self.interactive { + self.spinner = Some(Spinner::new(Spinners::Dots, "Thinking...".to_owned())); + } + + Ok(ChatState::HandleResponseStream( + self.conversation + .as_sendable_conversation_state(os, &mut self.stderr, true) + .await?, + )) + } + + /// Apply program context to tools that Q may not have. + // We cannot attach this any other way because Tools are constructed by deserializing + // output from Amazon Q. + // TODO: Is there a better way? + fn contextualize_tool(&self, tool: &mut Tool) { + if let Tool::GhIssue(gh_issue) = tool { + let allowed_tools = self + .conversation + .agents + .get_active() + .map(|a| a.allowed_tools.iter().cloned().collect::>()) + .unwrap_or_default(); + gh_issue.set_context(GhIssueContext { + // Ideally we avoid cloning, but this function is not called very often. + // Using references with lifetimes requires a large refactor, and Arc> + // seems like overkill and may incur some performance cost anyway. 
+ context_manager: self.conversation.context_manager.clone(), + transcript: self.conversation.transcript.clone(), + failed_request_ids: self.failed_request_ids.clone(), + tool_permissions: allowed_tools, + }); + } + } + + async fn print_tool_description(&mut self, os: &Os, tool_index: usize, trusted: bool) -> Result<(), ChatError> { + let tool_use = &self.tool_uses[tool_index]; + + queue!( + self.stdout, + style::SetForegroundColor(Color::Magenta), + style::Print(format!( + "🛠️ Using tool: {}{}", + tool_use.tool.display_name(), + if trusted { " (trusted)".dark_green() } else { "".reset() } + )), + style::SetForegroundColor(Color::Reset) + )?; + if let Tool::Custom(ref tool) = tool_use.tool { + queue!( + self.stdout, + style::SetForegroundColor(Color::Reset), + style::Print(" from mcp server "), + style::SetForegroundColor(Color::Magenta), + style::Print(tool.client.get_server_name()), + style::SetForegroundColor(Color::Reset), + )?; + } + + execute!( + self.stdout, + style::Print("\n"), + style::Print(CONTINUATION_LINE), + style::Print("\n"), + style::Print(TOOL_BULLET) + )?; + + tool_use + .tool + .queue_description(os, &mut self.stdout) + .await + .map_err(|e| ChatError::Custom(format!("failed to print tool, `{}`: {}", tool_use.name, e).into()))?; + + Ok(()) + } + + /// Helper function to read user input with a prompt and Ctrl+C handling + fn read_user_input(&mut self, prompt: &str, exit_on_single_ctrl_c: bool) -> Option { + let mut ctrl_c = false; + loop { + match (self.input_source.read_line(Some(prompt)), ctrl_c) { + (Ok(Some(line)), _) => { + if line.trim().is_empty() { + continue; // Reprompt if the input is empty + } + return Some(line); + }, + (Ok(None), false) => { + if exit_on_single_ctrl_c { + return None; + } + execute!( + self.stderr, + style::Print(format!( + "\n(To exit the CLI, press Ctrl+C or Ctrl+D again or type {})\n\n", + "/quit".green() + )) + ) + .unwrap_or_default(); + ctrl_c = true; + }, + (Ok(None), true) => return None, // Exit if Ctrl+C 
was pressed twice + (Err(_), _) => return None, + } + } + } + + /// Helper function to generate a prompt based on the current context + fn generate_tool_trust_prompt(&mut self) -> String { + let profile = self.conversation.current_profile().map(|s| s.to_string()); + let all_trusted = self.all_tools_trusted(); + let tangent_mode = self.conversation.is_in_tangent_mode(); + prompt::generate_prompt(profile.as_deref(), all_trusted, tangent_mode) + } + + async fn send_tool_use_telemetry(&mut self, os: &Os) { + for (_, mut event) in self.tool_use_telemetry_events.drain() { + event.user_input_id = match self.tool_use_status { + ToolUseStatus::Idle => self.conversation.message_id(), + ToolUseStatus::RetryInProgress(ref id) => Some(id.as_str()), + } + .map(|v| v.to_string()); + + os.telemetry.send_tool_use_suggested(&os.database, event).await.ok(); + } + } + + fn terminal_width(&self) -> usize { + (self.terminal_width_provider)().unwrap_or(80) + } + + fn all_tools_trusted(&self) -> bool { + self.conversation.agents.trust_all_tools + } + + /// Display character limit warnings based on current conversation size + async fn display_char_warnings(&mut self, os: &Os) -> Result<(), ChatError> { + let warning_level = self.conversation.get_token_warning_level(os).await?; + + match warning_level { + TokenWarningLevel::Critical => { + // Memory constraint warning with gentler wording + execute!( + self.stderr, + style::SetForegroundColor(Color::Yellow), + style::SetAttribute(Attribute::Bold), + style::Print("\n⚠️ This conversation is getting lengthy.\n"), + style::SetAttribute(Attribute::Reset), + style::Print( + "To ensure continued smooth operation, please use /compact to summarize the conversation.\n\n" + ), + style::SetForegroundColor(Color::Reset) + )?; + }, + TokenWarningLevel::None => { + // No warning needed + }, + } + + Ok(()) + } + + /// Resets state associated with the active user turn. 
+ /// + /// This should *always* be called whenever a new user prompt is sent to the backend. Note + /// that includes tool use rejections. + fn reset_user_turn(&mut self) { + info!(?self.user_turn_request_metadata, "Resetting the current user turn"); + self.user_turn_request_metadata.clear(); + } + + /// Sends an "codewhispererterminal_addChatMessage" telemetry event. + /// + /// This *MUST* be called in the following cases: + /// 1. After the end of a user turn + /// 2. After tool use execution has completed + /// 3. After an error was encountered during the handling of the response stream, tool use + /// validation, or tool use execution. + /// + /// [Self::user_turn_request_metadata] must contain the [RequestMetadata] associated with the + /// current user turn. + #[allow(clippy::too_many_arguments)] + async fn send_chat_telemetry( + &self, + os: &Os, + result: TelemetryResult, + reason: Option, + reason_desc: Option, + status_code: Option, + is_end_turn: bool, + ) { + // Get metadata for the most recent request. 
+ let md = self.user_turn_request_metadata.last(); + + let conversation_id = self.conversation.conversation_id().to_owned(); + let data = ChatAddedMessageParams { + request_id: md.and_then(|md| md.request_id.clone()), + message_id: md.map(|md| md.message_id.clone()), + context_file_length: self.conversation.context_message_length(), + model: md.and_then(|m| m.model_id.clone()), + reason: reason.clone(), + reason_desc: reason_desc.clone(), + status_code, + time_to_first_chunk_ms: md.and_then(|md| md.time_to_first_chunk.map(|d| d.as_secs_f64() * 1000.0)), + time_between_chunks_ms: md.map(|md| { + md.time_between_chunks + .iter() + .map(|d| d.as_secs_f64() * 1000.0) + .collect::>() + }), + chat_conversation_type: md.and_then(|md| md.chat_conversation_type), + tool_use_id: self.conversation.latest_tool_use_ids(), + tool_name: self.conversation.latest_tool_use_names(), + assistant_response_length: md.map(|md| md.response_size as i32), + message_meta_tags: { + let mut tags = md.map(|md| md.message_meta_tags.clone()).unwrap_or_default(); + if self.conversation.is_in_tangent_mode() { + tags.push(crate::telemetry::core::MessageMetaTag::TangentMode); + } + tags + }, + }; + os.telemetry + .send_chat_added_message(&os.database, conversation_id.clone(), result, data) + .await + .ok(); + + if is_end_turn { + let mds = &self.user_turn_request_metadata; + + // Get the user turn duration. 
+ let start_time = mds.first().map(|md| md.request_start_timestamp_ms); + let end_time = mds.last().map(|md| md.stream_end_timestamp_ms); + let user_turn_duration_seconds = match (start_time, end_time) { + // Convert ms back to seconds + (Some(start), Some(end)) => end.saturating_sub(start) as i64 / 1000, + _ => 0, + }; + + os.telemetry + .send_record_user_turn_completion(&os.database, conversation_id, result, RecordUserTurnCompletionArgs { + message_ids: mds.iter().map(|md| md.message_id.clone()).collect::<_>(), + request_ids: mds.iter().map(|md| md.request_id.clone()).collect::<_>(), + reason, + reason_desc, + status_code, + time_to_first_chunks_ms: mds + .iter() + .map(|md| md.time_to_first_chunk.map(|d| d.as_secs_f64() * 1000.0)) + .collect::<_>(), + chat_conversation_type: md.and_then(|md| md.chat_conversation_type), + assistant_response_length: mds.iter().map(|md| md.response_size as i64).sum(), + message_meta_tags: mds.last().map(|md| md.message_meta_tags.clone()).unwrap_or_default(), + user_prompt_length: mds.first().map(|md| md.user_prompt_length).unwrap_or_default() as i64, + user_turn_duration_seconds, + follow_up_count: mds + .iter() + .filter(|md| matches!(md.chat_conversation_type, Some(ChatConversationType::ToolUse))) + .count() as i64, + }) + .await + .ok(); + } + } + + async fn send_error_telemetry( + &self, + os: &Os, + reason: String, + reason_desc: Option, + status_code: Option, + ) { + let md = self.user_turn_request_metadata.last(); + os.telemetry + .send_response_error( + &os.database, + self.conversation.conversation_id().to_owned(), + self.conversation.context_message_length(), + TelemetryResult::Failed, + Some(reason), + reason_desc, + status_code, + md.and_then(|md| md.request_id.clone()), + md.map(|md| md.message_id.clone()), + ) + .await + .ok(); + } + + pub async fn send_slash_command_telemetry( + &self, + os: &Os, + command: String, + subcommand: Option, + result: TelemetryResult, + reason: Option, + ) { + let conversation_id = 
self.conversation.conversation_id().to_owned(); + if let Err(e) = os + .telemetry + .send_chat_slash_command_executed(&os.database, conversation_id, command, subcommand, result, reason) + .await + { + tracing::warn!("Failed to send slash command telemetry: {}", e); + } + } +} + +/// Replaces amzn_codewhisperer_client::types::SubscriptionStatus with a more descriptive type. +/// See response expectations in [`get_subscription_status`] for reasoning. +#[derive(Debug, Clone, PartialEq, Eq)] +enum ActualSubscriptionStatus { + Active, // User has paid for this month + Expiring, // User has paid for this month but cancelled + None, // User has not paid for this month +} + +// NOTE: The subscription API behaves in a non-intuitive way. We expect the following responses: +// +// 1. SubscriptionStatus::Active: +// - The user *has* a subscription, but it is set to *not auto-renew* (i.e., cancelled). +// - We return ActualSubscriptionStatus::Expiring to indicate they are eligible to re-subscribe +// +// 2. SubscriptionStatus::Inactive: +// - The user has no subscription at all (no Pro access). +// - We return ActualSubscriptionStatus::None to indicate they are eligible to subscribe. +// +// 3. ConflictException (as an error): +// - The user already has an active subscription *with auto-renewal enabled*. +// - We return ActualSubscriptionStatus::Active since they don’t need to subscribe again. +// +// Also, it is currently not possible to subscribe or re-subscribe via console, only IDE/CLI. +async fn get_subscription_status(os: &mut Os) -> Result { + if is_idc_user(&os.database).await? 
{ + return Ok(ActualSubscriptionStatus::Active); + } + + match os.client.create_subscription_token().await { + Ok(response) => match response.status() { + SubscriptionStatus::Active => Ok(ActualSubscriptionStatus::Expiring), + SubscriptionStatus::Inactive => Ok(ActualSubscriptionStatus::None), + _ => Ok(ActualSubscriptionStatus::None), + }, + Err(ApiClientError::CreateSubscriptionToken(e)) => { + let sdk_error_code = e.as_service_error().and_then(|err| err.meta().code()); + + if sdk_error_code.is_some_and(|c| c.contains("ConflictException")) { + Ok(ActualSubscriptionStatus::Active) + } else { + Err(e.into()) + } + }, + Err(e) => Err(e.into()), + } +} + +async fn get_subscription_status_with_spinner( + os: &mut Os, + output: &mut impl Write, +) -> Result { + return with_spinner(output, "Checking subscription status...", || async { + get_subscription_status(os).await + }) + .await; +} + +async fn with_spinner(output: &mut impl std::io::Write, spinner_text: &str, f: F) -> Result +where + F: FnOnce() -> Fut, + Fut: std::future::Future>, +{ + queue!(output, cursor::Hide,).ok(); + let spinner = Some(Spinner::new(Spinners::Dots, spinner_text.to_owned())); + + let result = f().await; + + if let Some(mut s) = spinner { + s.stop(); + let _ = queue!( + output, + terminal::Clear(terminal::ClearType::CurrentLine), + cursor::MoveToColumn(0), + ); + } + + result +} + +/// Checks if an input may be referencing a file and should not be handled as a typical slash +/// command. If true, then return [Option::Some], otherwise [Option::None]. 
+fn does_input_reference_file(input: &str) -> Option { + let after_slash = input.strip_prefix("/")?; + + if let Some(first) = shlex::split(after_slash).unwrap_or_default().first() { + let looks_like_path = + first.contains(MAIN_SEPARATOR) || first.contains('/') || first.contains('\\') || first.contains('.'); + + if looks_like_path { + return Some(ChatState::HandleInput { + input: after_slash.to_string(), + }); + } + } + + None +} + +#[cfg(test)] +mod tests { + use std::path::PathBuf; + + use super::*; + use crate::cli::agent::Agent; + + async fn get_test_agents(os: &Os) -> Agents { + const AGENT_PATH: &str = "/persona/TestAgent.json"; + let mut agents = Agents::default(); + let agent = Agent { + path: Some(PathBuf::from(AGENT_PATH)), + ..Default::default() + }; + if let Ok(false) = os.fs.try_exists(AGENT_PATH).await { + let content = agent.to_str_pretty().expect("Failed to serialize test agent to file"); + let agent_path = PathBuf::from(AGENT_PATH); + os.fs + .create_dir_all( + agent_path + .parent() + .expect("Failed to obtain parent path for agent config"), + ) + .await + .expect("Failed to create test agent dir"); + os.fs + .write(agent_path, &content) + .await + .expect("Failed to write test agent to file"); + } + agents.agents.insert("TestAgent".to_string(), agent); + agents.switch("TestAgent").expect("Failed to switch agent"); + agents + } + + #[tokio::test] + async fn test_flow() { + let mut os = Os::new().await.unwrap(); + os.client.set_mock_output(serde_json::json!([ + [ + "Sure, I'll create a file for you", + { + "tool_use_id": "1", + "name": "fs_write", + "args": { + "command": "create", + "file_text": "Hello, world!", + "path": "/file.txt", + } + } + ], + [ + "Hope that looks good to you!", + ], + ])); + + let agents = get_test_agents(&os).await; + let tool_manager = ToolManager::default(); + let tool_config = serde_json::from_str::>(include_str!("tools/tool_index.json")) + .expect("Tools failed to load"); + ChatSession::new( + &mut os, + 
std::io::stdout(), + std::io::stderr(), + "fake_conv_id", + agents, + None, + InputSource::new_mock(vec![ + "create a new file".to_string(), + "y".to_string(), + "exit".to_string(), + ]), + false, + || Some(80), + tool_manager, + None, + tool_config, + true, + false, + ) + .await + .unwrap() + .spawn(&mut os) + .await + .unwrap(); + + assert_eq!(os.fs.read_to_string("/file.txt").await.unwrap(), "Hello, world!\n"); + } + + #[tokio::test] + async fn test_flow_tool_permissions() { + let mut os = Os::new().await.unwrap(); + os.client.set_mock_output(serde_json::json!([ + [ + "Ok", + { + "tool_use_id": "1", + "name": "fs_write", + "args": { + "command": "create", + "file_text": "Hello, world!", + "path": "/file1.txt", + } + } + ], + [ + "Done", + ], + [ + "Ok", + { + "tool_use_id": "1", + "name": "fs_write", + "args": { + "command": "create", + "file_text": "Hello, world!", + "path": "/file2.txt", + } + } + ], + [ + "Done", + ], + [ + "Ok", + { + "tool_use_id": "1", + "name": "fs_write", + "args": { + "command": "create", + "file_text": "Hello, world!", + "path": "/file3.txt", + } + } + ], + [ + "Done", + ], + [ + "Ok", + { + "tool_use_id": "1", + "name": "fs_write", + "args": { + "command": "create", + "file_text": "Hello, world!", + "path": "/file4.txt", + } + } + ], + [ + "Ok, I won't make it.", + ], + [ + "Ok", + { + "tool_use_id": "1", + "name": "fs_write", + "args": { + "command": "create", + "file_text": "Hello, world!", + "path": "/file5.txt", + } + } + ], + [ + "Done", + ], + [ + "Ok", + { + "tool_use_id": "1", + "name": "fs_write", + "args": { + "command": "create", + "file_text": "Hello, world!", + "path": "/file6.txt", + } + } + ], + [ + "Ok, I won't make it.", + ], + ])); + + let agents = get_test_agents(&os).await; + let tool_manager = ToolManager::default(); + let tool_config = serde_json::from_str::>(include_str!("tools/tool_index.json")) + .expect("Tools failed to load"); + ChatSession::new( + &mut os, + std::io::stdout(), + std::io::stderr(), + 
"fake_conv_id", + agents, + None, + InputSource::new_mock(vec![ + "/tools".to_string(), + "/tools help".to_string(), + "create a new file".to_string(), + "y".to_string(), + "create a new file".to_string(), + "t".to_string(), + "create a new file".to_string(), // should make without prompting due to 't' + "/tools untrust fs_write".to_string(), + "create a file".to_string(), // prompt again due to untrust + "n".to_string(), // cancel + "/tools trust fs_write".to_string(), + "create a file".to_string(), // again without prompting due to '/tools trust' + "/tools reset".to_string(), + "create a file".to_string(), // prompt again due to reset + "n".to_string(), // cancel + "exit".to_string(), + ]), + false, + || Some(80), + tool_manager, + None, + tool_config, + true, + false, + ) + .await + .unwrap() + .spawn(&mut os) + .await + .unwrap(); + + assert_eq!(os.fs.read_to_string("/file2.txt").await.unwrap(), "Hello, world!\n"); + assert_eq!(os.fs.read_to_string("/file3.txt").await.unwrap(), "Hello, world!\n"); + assert!(!os.fs.exists("/file4.txt")); + assert_eq!(os.fs.read_to_string("/file5.txt").await.unwrap(), "Hello, world!\n"); + // TODO: fix this with agent change (dingfeli) + // assert!(!ctx.fs.exists("/file6.txt")); + } + + #[tokio::test] + async fn test_flow_multiple_tools() { + // let _ = tracing_subscriber::fmt::try_init(); + let mut os = Os::new().await.unwrap(); + os.client.set_mock_output(serde_json::json!([ + [ + "Sure, I'll create a file for you", + { + "tool_use_id": "1", + "name": "fs_write", + "args": { + "command": "create", + "file_text": "Hello, world!", + "path": "/file1.txt", + } + }, + { + "tool_use_id": "2", + "name": "fs_write", + "args": { + "command": "create", + "file_text": "Hello, world!", + "path": "/file2.txt", + } + } + ], + [ + "Done", + ], + [ + "Sure, I'll create a file for you", + { + "tool_use_id": "1", + "name": "fs_write", + "args": { + "command": "create", + "file_text": "Hello, world!", + "path": "/file3.txt", + } + }, + { + 
"tool_use_id": "2", + "name": "fs_write", + "args": { + "command": "create", + "file_text": "Hello, world!", + "path": "/file4.txt", + } + } + ], + [ + "Done", + ], + ])); + + let agents = get_test_agents(&os).await; + let tool_manager = ToolManager::default(); + let tool_config = serde_json::from_str::>(include_str!("tools/tool_index.json")) + .expect("Tools failed to load"); + ChatSession::new( + &mut os, + std::io::stdout(), + std::io::stderr(), + "fake_conv_id", + agents, + None, + InputSource::new_mock(vec![ + "create 2 new files parallel".to_string(), + "t".to_string(), + "/tools reset".to_string(), + "create 2 new files parallel".to_string(), + "y".to_string(), + "y".to_string(), + "exit".to_string(), + ]), + false, + || Some(80), + tool_manager, + None, + tool_config, + true, + false, + ) + .await + .unwrap() + .spawn(&mut os) + .await + .unwrap(); + + assert_eq!(os.fs.read_to_string("/file1.txt").await.unwrap(), "Hello, world!\n"); + assert_eq!(os.fs.read_to_string("/file2.txt").await.unwrap(), "Hello, world!\n"); + assert_eq!(os.fs.read_to_string("/file3.txt").await.unwrap(), "Hello, world!\n"); + assert_eq!(os.fs.read_to_string("/file4.txt").await.unwrap(), "Hello, world!\n"); + } + + #[tokio::test] + async fn test_flow_tools_trust_all() { + // let _ = tracing_subscriber::fmt::try_init(); + let mut os = Os::new().await.unwrap(); + os.client.set_mock_output(serde_json::json!([ + [ + "Sure, I'll create a file for you", + { + "tool_use_id": "1", + "name": "fs_write", + "args": { + "command": "create", + "file_text": "Hello, world!", + "path": "/file1.txt", + } + } + ], + [ + "Done", + ], + [ + "Sure, I'll create a file for you", + { + "tool_use_id": "1", + "name": "fs_write", + "args": { + "command": "create", + "file_text": "Hello, world!", + "path": "/file3.txt", + } + } + ], + [ + "Ok I won't.", + ], + ])); + + let agents = get_test_agents(&os).await; + let tool_manager = ToolManager::default(); + let tool_config = 
serde_json::from_str::>(include_str!("tools/tool_index.json")) + .expect("Tools failed to load"); + ChatSession::new( + &mut os, + std::io::stdout(), + std::io::stderr(), + "fake_conv_id", + agents, + None, + InputSource::new_mock(vec![ + "/tools trust-all".to_string(), + "create a new file".to_string(), + "/tools reset".to_string(), + "create a new file".to_string(), + "exit".to_string(), + ]), + false, + || Some(80), + tool_manager, + None, + tool_config, + true, + false, + ) + .await + .unwrap() + .spawn(&mut os) + .await + .unwrap(); + + assert_eq!(os.fs.read_to_string("/file1.txt").await.unwrap(), "Hello, world!\n"); + assert!(!os.fs.exists("/file2.txt")); + } + + #[test] + fn test_editor_content_processing() { + // Since we no longer have template replacement, this test is simplified + let cases = vec![ + ("My content", "My content"), + ("My content with newline\n", "My content with newline"), + ("", ""), + ]; + + for (input, expected) in cases { + let processed = input.trim().to_string(); + assert_eq!(processed, expected.trim().to_string(), "Failed for input: {}", input); + } + } + + #[tokio::test] + #[cfg(unix)] + async fn test_subscribe_flow() { + let mut os = Os::new().await.unwrap(); + os.client.set_mock_output(serde_json::Value::Array(vec![])); + let agents = get_test_agents(&os).await; + + let tool_manager = ToolManager::default(); + let tool_config = serde_json::from_str::>(include_str!("tools/tool_index.json")) + .expect("Tools failed to load"); + ChatSession::new( + &mut os, + std::io::stdout(), + std::io::stderr(), + "fake_conv_id", + agents, + None, + InputSource::new_mock(vec!["/subscribe".to_string(), "y".to_string(), "/quit".to_string()]), + false, + || Some(80), + tool_manager, + None, + tool_config, + true, + false, + ) + .await + .unwrap() + .spawn(&mut os) + .await + .unwrap(); + } + + #[test] + fn test_does_input_reference_file() { + let tests = &[ + ( + r"/Users/user/Desktop/Screenshot\ 2025-06-30\ at\ 2.13.34 PM.png read this image for 
me", + true, + ), + ("/path/to/file.json", true), + ("/save output.json", false), + ("~/does/not/start/with/slash", false), + ]; + for (input, expected) in tests { + let actual = does_input_reference_file(input).is_some(); + assert_eq!(actual, *expected, "expected {} for input {}", expected, input); + } + } +} + + +---- crates/chat-cli/src/cli/chat/prompt.rs ---- +use std::borrow::Cow; +use std::cell::RefCell; + +use eyre::Result; +use rustyline::completion::{ + Completer, + FilenameCompleter, + extract_word, +}; +use rustyline::error::ReadlineError; +use rustyline::highlight::{ + CmdKind, + Highlighter, +}; +use rustyline::hint::Hinter as RustylineHinter; +use rustyline::history::DefaultHistory; +use rustyline::validate::{ + ValidationContext, + ValidationResult, + Validator, +}; +use rustyline::{ + Cmd, + Completer, + CompletionType, + Config, + Context, + EditMode, + Editor, + EventHandler, + Helper, + Hinter, + KeyCode, + KeyEvent, + Modifiers, +}; +use winnow::stream::AsChar; + +pub use super::prompt_parser::generate_prompt; +use super::prompt_parser::parse_prompt_components; +use super::tool_manager::{ + PromptQuery, + PromptQueryResult, +}; +use crate::database::settings::Setting; +use crate::os::Os; + +pub const COMMANDS: &[&str] = &[ + "/clear", + "/help", + "/editor", + "/issue", + "/quit", + "/tools", + "/tools trust", + "/tools untrust", + "/tools trust-all", + "/tools reset", + "/mcp", + "/model", + "/agent", + "/agent help", + "/agent list", + "/agent create", + "/agent delete", + "/agent rename", + "/agent set", + "/agent schema", + "/prompts", + "/context", + "/context help", + "/context show", + "/context show --expand", + "/context add", + "/context rm", + "/context clear", + "/hooks", + "/hooks help", + "/hooks add", + "/hooks rm", + "/hooks enable", + "/hooks disable", + "/hooks enable-all", + "/hooks disable-all", + "/compact", + "/compact help", + "/usage", + "/save", + "/load", + "/subscribe", +]; + +pub type PromptQuerySender = 
tokio::sync::broadcast::Sender; +pub type PromptQueryResponseReceiver = tokio::sync::broadcast::Receiver; + +/// Complete commands that start with a slash +fn complete_command(word: &str, start: usize) -> (usize, Vec) { + ( + start, + COMMANDS + .iter() + .filter(|p| p.starts_with(word)) + .map(|s| (*s).to_owned()) + .collect(), + ) +} + +/// A wrapper around FilenameCompleter that provides enhanced path detection +/// and completion capabilities for the chat interface. +pub struct PathCompleter { + /// The underlying filename completer from rustyline + filename_completer: FilenameCompleter, +} + +impl PathCompleter { + /// Creates a new PathCompleter instance + pub fn new() -> Self { + Self { + filename_completer: FilenameCompleter::new(), + } + } + + /// Attempts to complete a file path at the given position in the line + pub fn complete_path( + &self, + line: &str, + pos: usize, + os: &Context<'_>, + ) -> Result<(usize, Vec), ReadlineError> { + // Use the filename completer to get path completions + match self.filename_completer.complete(line, pos, os) { + Ok((pos, completions)) => { + // Convert the filename completer's pairs to strings + let file_completions: Vec = completions.iter().map(|pair| pair.replacement.clone()).collect(); + + // Return the completions if we have any + Ok((pos, file_completions)) + }, + Err(err) => Err(err), + } + } +} + +pub struct PromptCompleter { + sender: PromptQuerySender, + receiver: RefCell, +} + +impl PromptCompleter { + fn new(sender: PromptQuerySender, receiver: PromptQueryResponseReceiver) -> Self { + PromptCompleter { + sender, + receiver: RefCell::new(receiver), + } + } + + fn complete_prompt(&self, word: &str) -> Result, ReadlineError> { + let sender = &self.sender; + let receiver = self.receiver.borrow_mut(); + let query = PromptQuery::Search(if !word.is_empty() { Some(word.to_string()) } else { None }); + + sender + .send(query) + .map_err(|e| ReadlineError::Io(std::io::Error::other(e.to_string())))?; + // We only want 
stuff from the current tail end onward + let mut new_receiver = receiver.resubscribe(); + + // Here we poll on the receiver for [max_attempts] number of times. + // The reason for this is because we are trying to receive something managed by an async + // channel from a sync context. + // If we ever switch back to a single threaded runtime for whatever reason, this function + // will not panic but nothing will be fetched because the thread that is doing + // try_recv is also the thread that is supposed to be doing the sending. + let mut attempts = 0; + let max_attempts = 5; + let query_res = loop { + match new_receiver.try_recv() { + Ok(result) => break result, + Err(_e) if attempts < max_attempts - 1 => { + attempts += 1; + std::thread::sleep(std::time::Duration::from_millis(100)); + }, + Err(e) => { + return Err(ReadlineError::Io(std::io::Error::other(eyre::eyre!( + "Failed to receive prompt info from complete prompt after {} attempts: {:?}", + max_attempts, + e + )))); + }, + } + }; + let matches = match query_res { + PromptQueryResult::Search(list) => list.into_iter().map(|n| format!("@{n}")).collect::>(), + PromptQueryResult::List(_) => { + return Err(ReadlineError::Io(std::io::Error::other(eyre::eyre!( + "Wrong query response type received", + )))); + }, + }; + + Ok(matches) + } +} + +pub struct ChatCompleter { + path_completer: PathCompleter, + prompt_completer: PromptCompleter, +} + +impl ChatCompleter { + fn new(sender: PromptQuerySender, receiver: PromptQueryResponseReceiver) -> Self { + Self { + path_completer: PathCompleter::new(), + prompt_completer: PromptCompleter::new(sender, receiver), + } + } +} + +impl Completer for ChatCompleter { + type Candidate = String; + + fn complete( + &self, + line: &str, + pos: usize, + _os: &Context<'_>, + ) -> Result<(usize, Vec), ReadlineError> { + let (start, word) = extract_word(line, pos, None, |c| c.is_space()); + + // Handle command completion + if word.starts_with('/') { + return Ok(complete_command(word, 
start)); + } + + if line.starts_with('@') { + let search_word = line.strip_prefix('@').unwrap_or(""); + if let Ok(completions) = self.prompt_completer.complete_prompt(search_word) { + if !completions.is_empty() { + return Ok((0, completions)); + } + } + } + + // Handle file path completion as fallback + if let Ok((pos, completions)) = self.path_completer.complete_path(line, pos, _os) { + if !completions.is_empty() { + return Ok((pos, completions)); + } + } + + // Default: no completions + Ok((start, Vec::new())) + } +} + +/// Custom hinter that provides shadowtext suggestions +pub struct ChatHinter { + /// Command history for providing suggestions based on past commands + history: Vec, + /// Whether history-based hints are enabled + history_hints_enabled: bool, +} + +impl ChatHinter { + /// Creates a new ChatHinter instance + pub fn new(history_hints_enabled: bool) -> Self { + Self { + history: Vec::new(), + history_hints_enabled, + } + } + + /// Updates the history with a new command + pub fn update_history(&mut self, command: &str) { + let command = command.trim(); + if !command.is_empty() && !command.contains('\n') && !command.contains('\r') { + self.history.push(command.to_string()); + } + } + + /// Finds the best hint for the current input + fn find_hint(&self, line: &str) -> Option { + // If line is empty, no hint + if line.is_empty() { + return None; + } + + // If line starts with a slash, try to find a command hint + if line.starts_with('/') { + return COMMANDS + .iter() + .find(|cmd| cmd.starts_with(line)) + .map(|cmd| cmd[line.len()..].to_string()); + } + + // Try to find a hint from history if history hints are enabled + if self.history_hints_enabled { + return self.history + .iter() + .rev() // Start from most recent + .find(|cmd| cmd.starts_with(line) && cmd.len() > line.len()) + .map(|cmd| cmd[line.len()..].to_string()); + } + + None + } +} + +impl RustylineHinter for ChatHinter { + type Hint = String; + + fn hint(&self, line: &str, pos: usize, _ctx: 
&Context<'_>) -> Option { + // Only provide hints when cursor is at the end of the line + if pos < line.len() { + return None; + } + + self.find_hint(line) + } +} + +/// Custom validator for multi-line input +pub struct MultiLineValidator; + +impl Validator for MultiLineValidator { + fn validate(&self, os: &mut ValidationContext<'_>) -> rustyline::Result { + let input = os.input(); + + // Check for code block markers + if input.contains("```") { + // Count the number of ``` occurrences + let triple_backtick_count = input.matches("```").count(); + + // If we have an odd number of ```, we're in an incomplete code block + if triple_backtick_count % 2 == 1 { + return Ok(ValidationResult::Incomplete); + } + } + + // Check for backslash continuation + if input.ends_with('\\') { + return Ok(ValidationResult::Incomplete); + } + + Ok(ValidationResult::Valid(None)) + } +} + +#[derive(Helper, Completer, Hinter)] +pub struct ChatHelper { + #[rustyline(Completer)] + completer: ChatCompleter, + #[rustyline(Hinter)] + hinter: ChatHinter, + validator: MultiLineValidator, +} + +impl ChatHelper { + /// Updates the history of the ChatHinter with a new command + pub fn update_hinter_history(&mut self, command: &str) { + self.hinter.update_history(command); + } +} + +impl Validator for ChatHelper { + fn validate(&self, os: &mut ValidationContext<'_>) -> rustyline::Result { + self.validator.validate(os) + } +} + +impl Highlighter for ChatHelper { + fn highlight_hint<'h>(&self, hint: &'h str) -> Cow<'h, str> { + Cow::Owned(format!("\x1b[38;5;240m{hint}\x1b[m")) + } + + fn highlight<'l>(&self, line: &'l str, _pos: usize) -> Cow<'l, str> { + Cow::Borrowed(line) + } + + fn highlight_char(&self, _line: &str, _pos: usize, _kind: CmdKind) -> bool { + false + } + + fn highlight_prompt<'b, 's: 'b, 'p: 'b>(&'s self, prompt: &'p str, _default: bool) -> Cow<'b, str> { + use crossterm::style::Stylize; + + // Parse the plain text prompt to extract profile and warning information + // and apply colors 
using crossterm's ANSI escape codes + if let Some(components) = parse_prompt_components(prompt) { + let mut result = String::new(); + + // Add profile part if present (cyan) + if let Some(profile) = components.profile { + result.push_str(&format!("[{}] ", profile).cyan().to_string()); + } + + // Add tangent indicator if present (yellow) + if components.tangent_mode { + result.push_str(&"↯ ".yellow().to_string()); + } + + // Add warning symbol if present (red) + if components.warning { + result.push_str(&"!".red().to_string()); + } + + // Add the prompt symbol (magenta) + result.push_str(&"> ".magenta().to_string()); + + Cow::Owned(result) + } else { + // If we can't parse the prompt, return it as-is + Cow::Borrowed(prompt) + } + } +} + +pub fn rl( + os: &Os, + sender: PromptQuerySender, + receiver: PromptQueryResponseReceiver, +) -> Result> { + let edit_mode = match os.database.settings.get_string(Setting::ChatEditMode).as_deref() { + Some("vi" | "vim") => EditMode::Vi, + _ => EditMode::Emacs, + }; + let config = Config::builder() + .history_ignore_space(true) + .completion_type(CompletionType::List) + .edit_mode(edit_mode) + .build(); + + // Default to disabled if setting doesn't exist + let history_hints_enabled = os + .database + .settings + .get_bool(Setting::ChatEnableHistoryHints) + .unwrap_or(false); + let h = ChatHelper { + completer: ChatCompleter::new(sender, receiver), + hinter: ChatHinter::new(history_hints_enabled), + validator: MultiLineValidator, + }; + + let mut rl = Editor::with_config(config)?; + rl.set_helper(Some(h)); + + // Add custom keybinding for Alt+Enter to insert a newline + rl.bind_sequence( + KeyEvent(KeyCode::Enter, Modifiers::ALT), + EventHandler::Simple(Cmd::Insert(1, "\n".to_string())), + ); + + // Add custom keybinding for Ctrl+J to insert a newline + rl.bind_sequence( + KeyEvent(KeyCode::Char('j'), Modifiers::CTRL), + EventHandler::Simple(Cmd::Insert(1, "\n".to_string())), + ); + + // Add custom keybinding for Ctrl+F to accept 
hint (like fish shell) + rl.bind_sequence( + KeyEvent(KeyCode::Char('f'), Modifiers::CTRL), + EventHandler::Simple(Cmd::CompleteHint), + ); + + // Add custom keybinding for Ctrl+T to toggle tangent mode (configurable) + let tangent_key_char = match os.database.settings.get_string(Setting::TangentModeKey) { + Some(key) if key.len() == 1 => key.chars().next().unwrap_or('t'), + _ => 't', // Default to 't' if setting is missing or invalid + }; + rl.bind_sequence( + KeyEvent(KeyCode::Char(tangent_key_char), Modifiers::CTRL), + EventHandler::Simple(Cmd::Insert(1, "/tangent".to_string())), + ); + + Ok(rl) +} + +#[cfg(test)] +mod tests { + use crossterm::style::Stylize; + use rustyline::highlight::Highlighter; + + use super::*; + + #[test] + fn test_chat_completer_command_completion() { + let (prompt_request_sender, _) = tokio::sync::broadcast::channel::(5); + let (_, prompt_response_receiver) = tokio::sync::broadcast::channel::(5); + let completer = ChatCompleter::new(prompt_request_sender, prompt_response_receiver); + let line = "/h"; + let pos = 2; // Position at the end of "/h" + + // Create a mock context with empty history + let empty_history = DefaultHistory::new(); + let os = Context::new(&empty_history); + + // Get completions + let (start, completions) = completer.complete(line, pos, &os).unwrap(); + + // Verify start position + assert_eq!(start, 0); + + // Verify completions contain expected commands + assert!(completions.contains(&"/help".to_string())); + } + + #[test] + fn test_chat_completer_no_completion() { + let (prompt_request_sender, _) = tokio::sync::broadcast::channel::(5); + let (_, prompt_response_receiver) = tokio::sync::broadcast::channel::(5); + let completer = ChatCompleter::new(prompt_request_sender, prompt_response_receiver); + let line = "Hello, how are you?"; + let pos = line.len(); + + // Create a mock context with empty history + let empty_history = DefaultHistory::new(); + let os = Context::new(&empty_history); + + // Get completions + let 
(_, completions) = completer.complete(line, pos, &os).unwrap(); + + // Verify no completions are returned for regular text + assert!(completions.is_empty()); + } + + #[test] + fn test_highlight_prompt_basic() { + let (prompt_request_sender, _) = tokio::sync::broadcast::channel::(5); + let (_, prompt_response_receiver) = tokio::sync::broadcast::channel::(5); + let helper = ChatHelper { + completer: ChatCompleter::new(prompt_request_sender, prompt_response_receiver), + hinter: ChatHinter::new(true), + validator: MultiLineValidator, + }; + + // Test basic prompt highlighting + let highlighted = helper.highlight_prompt("> ", true); + + assert_eq!(highlighted, "> ".magenta().to_string()); + } + + #[test] + fn test_highlight_prompt_with_warning() { + let (prompt_request_sender, _) = tokio::sync::broadcast::channel::(5); + let (_, prompt_response_receiver) = tokio::sync::broadcast::channel::(5); + let helper = ChatHelper { + completer: ChatCompleter::new(prompt_request_sender, prompt_response_receiver), + hinter: ChatHinter::new(true), + validator: MultiLineValidator, + }; + + // Test warning prompt highlighting + let highlighted = helper.highlight_prompt("!> ", true); + + assert_eq!(highlighted, format!("{}{}", "!".red(), "> ".magenta())); + } + + #[test] + fn test_highlight_prompt_with_profile() { + let (prompt_request_sender, _) = tokio::sync::broadcast::channel::(5); + let (_, prompt_response_receiver) = tokio::sync::broadcast::channel::(5); + let helper = ChatHelper { + completer: ChatCompleter::new(prompt_request_sender, prompt_response_receiver), + hinter: ChatHinter::new(true), + validator: MultiLineValidator, + }; + + // Test profile prompt highlighting + let highlighted = helper.highlight_prompt("[test-profile] > ", true); + + assert_eq!(highlighted, format!("{}{}", "[test-profile] ".cyan(), "> ".magenta())); + } + + #[test] + fn test_highlight_prompt_with_profile_and_warning() { + let (prompt_request_sender, _) = tokio::sync::broadcast::channel::(5); + let (_, 
prompt_response_receiver) = tokio::sync::broadcast::channel::(5); + let helper = ChatHelper { + completer: ChatCompleter::new(prompt_request_sender, prompt_response_receiver), + hinter: ChatHinter::new(true), + validator: MultiLineValidator, + }; + + // Test profile + warning prompt highlighting + let highlighted = helper.highlight_prompt("[dev] !> ", true); + // Should have cyan profile + red warning + cyan bold prompt + assert_eq!( + highlighted, + format!("{}{}{}", "[dev] ".cyan(), "!".red(), "> ".magenta()) + ); + } + + #[test] + fn test_highlight_prompt_invalid_format() { + let (prompt_request_sender, _) = tokio::sync::broadcast::channel::(5); + let (_, prompt_response_receiver) = tokio::sync::broadcast::channel::(5); + let helper = ChatHelper { + completer: ChatCompleter::new(prompt_request_sender, prompt_response_receiver), + hinter: ChatHinter::new(true), + validator: MultiLineValidator, + }; + + // Test invalid prompt format (should return as-is) + let invalid_prompt = "invalid prompt format"; + let highlighted = helper.highlight_prompt(invalid_prompt, true); + assert_eq!(highlighted, invalid_prompt); + } + + #[test] + fn test_highlight_prompt_tangent_mode() { + let (prompt_request_sender, _) = tokio::sync::broadcast::channel::(1); + let (_, prompt_response_receiver) = tokio::sync::broadcast::channel::(1); + let helper = ChatHelper { + completer: ChatCompleter::new(prompt_request_sender, prompt_response_receiver), + hinter: ChatHinter::new(true), + validator: MultiLineValidator, + }; + + // Test tangent mode prompt highlighting - ↯ yellow, > magenta + let highlighted = helper.highlight_prompt("↯ > ", true); + assert_eq!(highlighted, format!("{}{}", "↯ ".yellow(), "> ".magenta())); + } + + #[test] + fn test_highlight_prompt_tangent_mode_with_warning() { + let (prompt_request_sender, _) = tokio::sync::broadcast::channel::(1); + let (_, prompt_response_receiver) = tokio::sync::broadcast::channel::(1); + let helper = ChatHelper { + completer: 
ChatCompleter::new(prompt_request_sender, prompt_response_receiver), + hinter: ChatHinter::new(true), + validator: MultiLineValidator, + }; + + // Test tangent mode with warning - ↯ yellow, ! red, > magenta + let highlighted = helper.highlight_prompt("↯ !> ", true); + assert_eq!(highlighted, format!("{}{}{}", "↯ ".yellow(), "!".red(), "> ".magenta())); + } + + #[test] + fn test_highlight_prompt_profile_with_tangent_mode() { + let (prompt_request_sender, _) = tokio::sync::broadcast::channel::(1); + let (_, prompt_response_receiver) = tokio::sync::broadcast::channel::(1); + let helper = ChatHelper { + completer: ChatCompleter::new(prompt_request_sender, prompt_response_receiver), + hinter: ChatHinter::new(true), + validator: MultiLineValidator, + }; + + // Test profile with tangent mode - [dev] cyan, ↯ yellow, > magenta + let highlighted = helper.highlight_prompt("[dev] ↯ > ", true); + assert_eq!( + highlighted, + format!("{}{}{}", "[dev] ".cyan(), "↯ ".yellow(), "> ".magenta()) + ); + } + + #[test] + fn test_chat_hinter_command_hint() { + let hinter = ChatHinter::new(true); + + // Test hint for a command + let line = "/he"; + let pos = line.len(); + let empty_history = DefaultHistory::new(); + let ctx = Context::new(&empty_history); + + let hint = hinter.hint(line, pos, &ctx); + assert_eq!(hint, Some("lp".to_string())); + + // Test hint when cursor is not at the end + let hint = hinter.hint(line, 1, &ctx); + assert_eq!(hint, None); + + // Test hint for a non-existent command + let line = "/xyz"; + let pos = line.len(); + let hint = hinter.hint(line, pos, &ctx); + assert_eq!(hint, None); + + // Test hint for a multi-line command + let line = "/abcd\nefg"; + let pos = line.len(); + let hint = hinter.hint(line, pos, &ctx); + assert_eq!(hint, None); + } + + #[test] + fn test_chat_hinter_history_hint_disabled() { + let mut hinter = ChatHinter::new(false); + + // Add some history + hinter.update_history("Hello, world!"); + hinter.update_history("How are you?"); + + // 
Test hint from history - should be None since history hints are disabled + let line = "How"; + let pos = line.len(); + let empty_history = DefaultHistory::new(); + let ctx = Context::new(&empty_history); + + let hint = hinter.hint(line, pos, &ctx); + assert_eq!(hint, None); + } +} + + +---- crates/chat-cli/src/cli/chat/tool_manager.rs ---- +use std::borrow::Borrow; +use std::collections::{ + HashMap, + HashSet, +}; +use std::future::Future; +use std::hash::{ + DefaultHasher, + Hasher, +}; +use std::io::{ + BufWriter, + Write, +}; +use std::path::PathBuf; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::{ + AtomicBool, + Ordering, +}; +use std::time::{ + Duration, + Instant, +}; + +use crossterm::{ + cursor, + execute, + queue, + style, + terminal, +}; +use eyre::Report; +use futures::{ + StreamExt, + future, + stream, +}; +use regex::Regex; +use tokio::signal::ctrl_c; +use tokio::sync::{ + Mutex, + Notify, + RwLock, +}; +use tokio::task::JoinHandle; +use tracing::{ + error, + info, + warn, +}; + +use super::tools::custom_tool::CustomToolConfig; +use crate::api_client::model::{ + ToolResult, + ToolResultContentBlock, + ToolResultStatus, +}; +use crate::cli::agent::{ + Agent, + McpServerConfig, +}; +use crate::cli::chat::cli::prompts::GetPromptError; +use crate::cli::chat::consts::DUMMY_TOOL_NAME; +use crate::cli::chat::message::AssistantToolUse; +use crate::cli::chat::server_messenger::{ + ServerMessengerBuilder, + UpdateEventMessage, +}; +use crate::cli::chat::tools::custom_tool::{ + CustomTool, + CustomToolClient, +}; +use crate::cli::chat::tools::execute::ExecuteCommand; +use crate::cli::chat::tools::fs_read::FsRead; +use crate::cli::chat::tools::fs_write::FsWrite; +use crate::cli::chat::tools::gh_issue::GhIssue; +use crate::cli::chat::tools::knowledge::Knowledge; +use crate::cli::chat::tools::thinking::Thinking; +use crate::cli::chat::tools::use_aws::UseAws; +use crate::cli::chat::tools::{ + Tool, + ToolOrigin, + ToolSpec, +}; +use 
crate::database::Database; +use crate::database::settings::Setting; +use crate::mcp_client::{ + JsonRpcResponse, + Messenger, + PromptGet, +}; +use crate::os::Os; +use crate::telemetry::TelemetryThread; +use crate::util::MCP_SERVER_TOOL_DELIMITER; +use crate::util::directories::home_dir; + +const NAMESPACE_DELIMITER: &str = "___"; +// This applies for both mcp server and tool name since in the end the tool name as seen by the +// model is just {server_name}{NAMESPACE_DELIMITER}{tool_name} +const VALID_TOOL_NAME: &str = "^[a-zA-Z][a-zA-Z0-9_]*$"; +const SPINNER_CHARS: [char; 10] = ['⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏']; + +pub fn workspace_mcp_config_path(os: &Os) -> eyre::Result { + Ok(os.env.current_dir()?.join(".amazonq").join("mcp.json")) +} + +pub fn global_mcp_config_path(os: &Os) -> eyre::Result { + Ok(home_dir(os)?.join(".aws").join("amazonq").join("mcp.json")) +} + +/// Messages used for communication between the tool initialization thread and the loading +/// display thread. These messages control the visual loading indicators shown to +/// the user during tool initialization. +enum LoadingMsg { + /// Indicates a tool has finished initializing successfully and should be removed from + /// the loading display. The String parameter is the name of the tool that + /// completed initialization. + Done { name: String, time: String }, + /// Represents an error that occurred during tool initialization. + /// Contains the name of the server that failed to initialize and the error message. + Error { + name: String, + msg: eyre::Report, + time: String, + }, + /// Represents a warning that occurred during tool initialization. + /// Contains the name of the server that generated the warning and the warning message. + Warn { + name: String, + msg: eyre::Report, + time: String, + }, + /// Signals that the loading display thread should terminate. + /// This is sent when all tool initialization is complete or when the application is shutting + /// down. 
+ Terminate { still_loading: Vec }, +} + +/// Used to denote the loading outcome associated with a server. +/// This is mainly used in the non-interactive mode to determine if there is any fatal errors to +/// surface (since we would only want to surface fatal errors in non-interactive mode). +#[derive(Clone, Debug)] +pub enum LoadingRecord { + Success(String), + Warn(String), + Err(String), +} + +pub struct ToolManagerBuilder { + prompt_query_result_sender: Option>, + prompt_query_receiver: Option>, + prompt_query_sender: Option>, + prompt_query_result_receiver: Option>, + messenger_builder: Option, + conversation_id: Option, + has_new_stuff: Arc, + mcp_load_record: Arc>>>, + new_tool_specs: NewToolSpecs, + is_first_launch: bool, + agent: Option>>, +} + +impl Default for ToolManagerBuilder { + fn default() -> Self { + Self { + prompt_query_result_sender: Default::default(), + prompt_query_receiver: Default::default(), + prompt_query_sender: Default::default(), + prompt_query_result_receiver: Default::default(), + messenger_builder: Default::default(), + conversation_id: Default::default(), + has_new_stuff: Default::default(), + mcp_load_record: Default::default(), + new_tool_specs: Default::default(), + is_first_launch: true, + agent: Default::default(), + } + } +} + +impl From<&mut ToolManager> for ToolManagerBuilder { + fn from(value: &mut ToolManager) -> Self { + Self { + conversation_id: Some(value.conversation_id.clone()), + agent: Some(value.agent.clone()), + prompt_query_sender: value + .prompts_sender_receiver_pair + .as_ref() + .map(|(sender, _)| sender.clone()), + prompt_query_result_receiver: value.prompts_sender_receiver_pair.take().map(|(_, receiver)| receiver), + messenger_builder: value.messenger_builder.take(), + has_new_stuff: value.has_new_stuff.clone(), + mcp_load_record: value.mcp_load_record.clone(), + new_tool_specs: value.new_tool_specs.clone(), + // if we are getting a builder from an instantiated tool manager this field would be + // false 
+ is_first_launch: false, + ..Default::default() + } + } +} + +impl ToolManagerBuilder { + pub fn prompt_query_result_sender(mut self, sender: tokio::sync::broadcast::Sender) -> Self { + self.prompt_query_result_sender.replace(sender); + self + } + + pub fn prompt_query_receiver(mut self, receiver: tokio::sync::broadcast::Receiver) -> Self { + self.prompt_query_receiver.replace(receiver); + self + } + + pub fn prompt_query_sender(mut self, sender: tokio::sync::broadcast::Sender) -> Self { + self.prompt_query_sender.replace(sender); + self + } + + pub fn prompt_query_result_receiver( + mut self, + receiver: tokio::sync::broadcast::Receiver, + ) -> Self { + self.prompt_query_result_receiver.replace(receiver); + self + } + + pub fn conversation_id(mut self, conversation_id: &str) -> Self { + self.conversation_id.replace(conversation_id.to_string()); + self + } + + pub fn agent(mut self, agent: Agent) -> Self { + let agent = Arc::new(Mutex::new(agent)); + self.agent.replace(agent); + self + } + + /// Creates a [ToolManager] based on the current fields populated, which consists of the + /// following: + /// - Instantiates child processes associated with the list of mcp servers in scope + /// - Spawns a loading display task that is used to show server loading status (if applicable) + /// - Spawns the orchestrator task (see [spawn_orchestrator_task] for more detail) (if + /// applicable) + /// - Finally, creates an instance of [ToolManager] + pub async fn build( + mut self, + os: &mut Os, + mut output: Box, + interactive: bool, + ) -> eyre::Result { + let McpServerConfig { mcp_servers } = match &self.agent { + Some(agent) => agent.lock().await.mcp_servers.clone(), + None => Default::default(), + }; + debug_assert!(self.conversation_id.is_some()); + let conversation_id = self.conversation_id.ok_or(eyre::eyre!("Missing conversation id"))?; + + // Separate enabled and disabled servers + let (enabled_servers, disabled_servers): (Vec<_>, Vec<_>) = mcp_servers + .into_iter() + 
.partition(|(_, server_config)| !server_config.disabled); + + // Prepare disabled servers for display + let disabled_servers_display: Vec = disabled_servers + .iter() + .map(|(server_name, _)| server_name.clone()) + .collect(); + + let pre_initialized = enabled_servers + .into_iter() + .filter_map(|(server_name, server_config)| { + if server_name == "builtin" { + let _ = queue!( + output, + style::SetForegroundColor(style::Color::Red), + style::Print("✗ Invalid server name "), + style::SetForegroundColor(style::Color::Blue), + style::Print(&server_name), + style::ResetColor, + style::Print(". Server name cannot contain reserved word "), + style::SetForegroundColor(style::Color::Yellow), + style::Print("builtin"), + style::ResetColor, + style::Print(" (it is used to denote native tools)\n") + ); + None + } else { + let custom_tool_client = CustomToolClient::from_config(server_name.clone(), server_config, os); + Some((server_name, custom_tool_client)) + } + }) + .collect::>(); + + let mut loading_servers = HashMap::::new(); + for (server_name, _) in &pre_initialized { + let init_time = std::time::Instant::now(); + loading_servers.insert(server_name.clone(), init_time); + } + let total = loading_servers.len(); + + // Spawn a task for displaying the mcp loading statuses. + // This is only necessary when we are in interactive mode AND there are servers to load. + // Otherwise we do not need to be spawning this. 
+ let (loading_display_task, loading_status_sender) = + spawn_display_task(interactive, total, disabled_servers, output); + + let mut clients = HashMap::>::new(); + let new_tool_specs = self.new_tool_specs; + let has_new_stuff = self.has_new_stuff; + let pending = Arc::new(RwLock::new(HashSet::::new())); + let notify = Arc::new(Notify::new()); + let load_record = self.mcp_load_record; + let agent = self.agent.unwrap_or_default(); + let database = os.database.clone(); + let mut messenger_builder = self.messenger_builder.take(); + + // This is the orchestrator task that serves as a bridge between tool manager and mcp + // clients for server initiated async events + if let (Some(prompt_list_sender), Some(prompt_list_receiver)) = ( + self.prompt_query_result_sender.clone(), + self.prompt_query_receiver.as_ref().map(|r| r.resubscribe()), + ) { + let (msg_rx, builder) = ServerMessengerBuilder::new(20); + messenger_builder.replace(builder); + + let has_new_stuff = has_new_stuff.clone(); + let notify_weak = Arc::downgrade(¬ify); + let telemetry = os.telemetry.clone(); + let loading_status_sender = loading_status_sender.clone(); + let new_tool_specs = new_tool_specs.clone(); + let conv_id = conversation_id.clone(); + let pending = pending.clone(); + let regex = Regex::new(VALID_TOOL_NAME)?; + + spawn_orchestrator_task( + has_new_stuff, + loading_servers, + msg_rx, + prompt_list_receiver, + prompt_list_sender, + pending, + agent.clone(), + database, + regex, + notify_weak, + load_record.clone(), + telemetry, + loading_status_sender, + new_tool_specs, + total, + conv_id, + ); + } + + debug_assert!(messenger_builder.is_some()); + let messenger_builder = messenger_builder.unwrap(); + for (mut name, init_res) in pre_initialized { + let mut messenger = messenger_builder.build_with_name(name.clone()); + match init_res { + Ok(mut client) => { + let pid = client.get_pid(); + messenger.pid = pid; + client.assign_messenger(Box::new(messenger)); + let mut client = Arc::new(client); + 
while let Some(collided_client) = clients.insert(name.clone(), client) { + // to avoid server name collision we are going to circumvent this by + // appending the name with 1 + name.push('1'); + client = collided_client; + } + }, + Err(e) => { + error!("Error initializing mcp client for server {}: {:?}", name, &e); + os.telemetry + .send_mcp_server_init( + &os.database, + conversation_id.clone(), + name, + Some(e.to_string()), + 0, + Some("".to_string()), + Some("".to_string()), + 0, + ) + .await + .ok(); + let _ = messenger.send_tools_list_result(Err(e)).await; + }, + } + } + + Ok(ToolManager { + conversation_id, + clients, + pending_clients: pending, + notify: Some(notify), + loading_status_sender, + loading_display_task, + new_tool_specs, + has_new_stuff, + is_interactive: interactive, + mcp_load_record: load_record, + agent, + disabled_servers: disabled_servers_display, + prompts_sender_receiver_pair: { + if let (Some(sender), Some(receiver)) = (self.prompt_query_sender, self.prompt_query_result_receiver) { + Some((sender, receiver)) + } else { + None + } + }, + messenger_builder: Some(messenger_builder), + is_first_launch: self.is_first_launch, + ..Default::default() + }) + } +} + +#[derive(Clone, Debug)] +/// A collection of information that is used for the following purposes: +/// - Checking if prompt info cached is out of date +/// - Retrieve new prompt info +pub struct PromptBundle { + /// The server name from which the prompt is offered / exposed + pub server_name: String, + /// The prompt get (info with which a prompt is retrieved) cached + pub prompt_get: PromptGet, +} + +#[derive(Clone, Debug)] +pub enum PromptQuery { + List, + Search(Option), +} + +#[derive(Clone, Debug)] +pub enum PromptQueryResult { + List(HashMap>), + Search(Vec), +} + +/// Categorizes different types of tool name validation failures: +/// - `TooLong`: The tool name exceeds the maximum allowed length +/// - `IllegalChar`: The tool name contains characters that are not allowed +/// 
- `EmptyDescription`: The tool description is empty or missing +#[allow(dead_code)] +enum OutOfSpecName { + TooLong(String), + IllegalChar(String), + EmptyDescription(String), +} + +#[derive(Clone, Default, Debug, Eq, PartialEq)] +pub struct ToolInfo { + pub server_name: String, + pub host_tool_name: HostToolName, +} + +impl Borrow for ToolInfo { + fn borrow(&self) -> &HostToolName { + &self.host_tool_name + } +} + +impl std::hash::Hash for ToolInfo { + fn hash(&self, state: &mut H) { + self.host_tool_name.hash(state); + } +} + +/// Tool name as recognized by the model. This is [HostToolName] post sanitization. +type ModelToolName = String; + +/// Tool name as recognized by the host (i.e. Q CLI). This is identical to how each MCP server +/// exposed them. +type HostToolName = String; + +/// MCP server name as they are defined in the config +type ServerName = String; + +/// A list of new tools to be included in the main chat loop. +/// The vector of [ToolSpec] is a comprehensive list of all tools exposed by the server. +/// The hashmap of [ModelToolName]: [HostToolName] are mapping of tool names that have been changed +/// (which is a subset of the tools that are in the aforementioned vector) +/// Note that [ToolSpec] is model facing and thus will have names that are model facing (i.e. model +/// tool name). +type NewToolSpecs = Arc, Vec)>>>; + +/// A pair of channels used for prompt list communication between the tool manager and chat helper. +/// The sender broadcasts a list of available prompt names, while the receiver listens for +/// search queries to filter the prompt list. +type PromptsChannelPair = ( + tokio::sync::broadcast::Sender, + tokio::sync::broadcast::Receiver, +); + +#[derive(Default, Debug)] +/// Manages the lifecycle and interactions with tools from various sources, including MCP servers. +/// This struct is responsible for initializing tools, handling tool requests, and maintaining +/// a cache of available prompts from connected servers. 
pub struct ToolManager {
    /// Unique identifier for the current conversation.
    /// This ID is used to track and associate tools with a specific chat session.
    pub conversation_id: String,

    /// Map of server names to their corresponding client instances.
    /// These clients are used to communicate with MCP servers.
    pub clients: HashMap<String, Arc<CustomToolClient>>,

    /// A list of client names that are still in the process of being initialized
    pub pending_clients: Arc<RwLock<HashSet<String>>>,

    /// Flag indicating whether new tool specifications have been added since the last update.
    /// When set to true, it signals that the tool manager needs to refresh its internal state
    /// to incorporate newly available tools from MCP servers.
    pub has_new_stuff: Arc<AtomicBool>,

    /// Used by methods on the [ToolManager] to retrieve information from the orchestrator thread
    prompts_sender_receiver_pair: Option<PromptsChannelPair>,

    /// Storage for newly discovered tool specifications from MCP servers that haven't yet been
    /// integrated into the main tool registry. This field holds a thread-safe reference to a map
    /// of server names to their tool specifications and name mappings, allowing concurrent updates
    /// from server initialization processes.
    new_tool_specs: NewToolSpecs,

    /// A notifier to understand if the initial loading has completed.
    /// This is only used for initial loading and is discarded after.
    notify: Option<Arc<Notify>>,

    /// Channel sender for communicating with the loading display thread.
    /// Used to send status updates about tool initialization progress.
    loading_status_sender: Option<LoadingStatusSender>,

    /// This is here so we can await it to avoid output buffer from the display task interleaving
    /// with other buffer displayed by chat.
    loading_display_task: Option<DisplayTaskJoinHandle>,

    /// Mapping from sanitized (model-facing) tool names to original tool info.
    /// This is used to handle tool name transformations that may occur during initialization
    /// to ensure tool names comply with naming requirements.
    pub tn_map: HashMap<ModelToolName, ToolInfo>,

    /// A cache of tool's input schema for all of the available tools.
    /// This is mainly used to show the user what the tools look like from the perspective of the
    /// model.
    pub schema: HashMap<ModelToolName, ToolSpec>,

    is_interactive: bool,

    /// This serves as a record of the loading of mcp servers.
    /// The key of which is the server name as they are recognized by the current instance of chat
    /// (which may be different than how it is written in the config, depending of the presence of
    /// invalid characters).
    /// The value is the load message (i.e. load time, warnings, and errors)
    pub mcp_load_record: Arc<Mutex<HashMap<String, Vec<LoadingRecord>>>>,

    /// List of disabled MCP server names for display purposes
    disabled_servers: Vec<String>,

    /// A builder for mcp clients to communicate with the orchestrator task
    /// We need to store this for when we switch agent - we need to be spawning messengers that are
    /// already listened to by the orchestrator task
    messenger_builder: Option<ServerMessengerBuilder>,

    /// A collection of preferences that pertains to the conversation
    /// As far as tool manager goes, this is relevant for tool and server filters
    /// We need to put this behind a lock because the orchestrator task depends on agent
    pub agent: Arc<Mutex<Agent>>,

    is_first_launch: bool,
}

// Manual Clone: channel halves, the notifier, and the display task cannot be
// meaningfully duplicated, so those fields fall back to `Default` in the clone.
impl Clone for ToolManager {
    fn clone(&self) -> Self {
        Self {
            conversation_id: self.conversation_id.clone(),
            clients: self.clients.clone(),
            has_new_stuff: self.has_new_stuff.clone(),
            new_tool_specs: self.new_tool_specs.clone(),
            tn_map: self.tn_map.clone(),
            schema: self.schema.clone(),
            is_interactive: self.is_interactive,
            mcp_load_record: self.mcp_load_record.clone(),
            disabled_servers: self.disabled_servers.clone(),
            ..Default::default()
        }
    }
}

impl ToolManager {
    /// Swapping agent involves the following:
    /// - Dropping all of the clients first to avoid resource contention
    /// - Clearing fields that are already referenced by background tasks.
We can't simply spawn new + /// instances of these fields because one or more background tasks are already depending on it + /// - Building a new tool manager builder from the current tool manager + /// - Building a tool manager from said tool manager builder + /// - Swapping the old with the new (the old would be dropped after we exit the scope of this + /// function) + /// - Calling load tools + pub async fn swap_agent(&mut self, os: &mut Os, output: &mut impl Write, agent: &Agent) -> eyre::Result<()> { + self.clients.clear(); + + let mut agent_lock = self.agent.lock().await; + *agent_lock = agent.clone(); + drop(agent_lock); + + self.mcp_load_record.lock().await.clear(); + + let builder = ToolManagerBuilder::from(&mut *self); + let mut new_tool_manager = builder.build(os, Box::new(std::io::sink()), true).await?; + std::mem::swap(self, &mut new_tool_manager); + + // we can discard the output here and let background server load take care of getting the + // new tools + let _ = self.load_tools(os, output).await?; + + Ok(()) + } + + pub async fn load_tools( + &mut self, + os: &mut Os, + stderr: &mut impl Write, + ) -> eyre::Result> { + let tx = self.loading_status_sender.take(); + let notify = self.notify.take(); + self.schema = { + let tool_list = &self.agent.lock().await.tools; + let is_allow_all = tool_list.len() == 1 && tool_list.first().is_some_and(|n| n == "*"); + let is_allow_native = tool_list.iter().any(|t| t.as_str() == "@builtin"); + let mut tool_specs = + serde_json::from_str::>(include_str!("tools/tool_index.json"))? 
+ .into_iter() + .filter(|(name, _)| { + name == DUMMY_TOOL_NAME + || is_allow_all + || is_allow_native + || tool_list.contains(name) + || tool_list.contains(&format!("@builtin/{name}")) + }) + .collect::>(); + if !crate::cli::chat::tools::thinking::Thinking::is_enabled(os) { + tool_specs.remove("thinking"); + } + if !crate::cli::chat::tools::knowledge::Knowledge::is_enabled(os) { + tool_specs.remove("knowledge"); + } + + #[cfg(windows)] + { + use serde_json::json; + + use crate::cli::chat::tools::InputSchema; + + tool_specs.remove("execute_bash"); + + tool_specs.insert("execute_cmd".to_string(), ToolSpec { + name: "execute_cmd".to_string(), + description: "Execute the specified Windows command.".to_string(), + input_schema: InputSchema(json!({ + "type": "object", + "properties": { + "command": { + "type": "string", + "description": "Windows command to execute" + }, + "summary": { + "type": "string", + "description": "A brief explanation of what the command does" + } + }, + "required": ["command"]})), + tool_origin: ToolOrigin::Native, + }); + } + + tool_specs + }; + let load_tools = self + .clients + .values() + .map(|c| { + let clone = Arc::clone(c); + async move { clone.init().await } + }) + .collect::>(); + let initial_poll = stream::iter(load_tools) + .map(|async_closure| tokio::spawn(async_closure)) + .buffer_unordered(20); + tokio::spawn(async move { + initial_poll.collect::>().await; + }); + // We need to cast it to erase the type otherwise the compiler will default to static + // dispatch, which would result in an error of inconsistent match arm return type. 
+ let timeout_fut: Pin>> = if self.clients.is_empty() || !self.is_first_launch { + // If there is no server loaded, we want to resolve immediately + Box::pin(future::ready(())) + } else if self.is_interactive { + let init_timeout = os + .database + .settings + .get_int(Setting::McpInitTimeout) + .map_or(5000_u64, |s| s as u64); + Box::pin(tokio::time::sleep(std::time::Duration::from_millis(init_timeout))) + } else { + // if it is non-interactive we will want to use the "mcp.noInteractiveTimeout" + let init_timeout = os + .database + .settings + .get_int(Setting::McpNoInteractiveTimeout) + .map_or(30_000_u64, |s| s as u64); + Box::pin(tokio::time::sleep(std::time::Duration::from_millis(init_timeout))) + }; + let server_loading_fut: Pin>> = if let Some(notify) = notify { + Box::pin(async move { notify.notified().await }) + } else { + Box::pin(future::ready(())) + }; + let loading_display_task = self.loading_display_task.take(); + tokio::select! { + _ = timeout_fut => { + if let Some(tx) = tx { + let still_loading = self.pending_clients.read().await.iter().cloned().collect::>(); + let _ = tx.send(LoadingMsg::Terminate { still_loading }).await; + if let Some(task) = loading_display_task { + let _ = tokio::time::timeout( + std::time::Duration::from_millis(80), + task + ).await; + } + } + if !self.clients.is_empty() && !self.is_interactive { + let _ = queue!( + stderr, + style::Print( + "Not all mcp servers loaded. 
Configure non-interactive timeout with q settings mcp.noInteractiveTimeout" + ), + style::Print("\n------\n") + ); + } + }, + _ = server_loading_fut => { + if let Some(tx) = tx { + let still_loading = self.pending_clients.read().await.iter().cloned().collect::>(); + let _ = tx.send(LoadingMsg::Terminate { still_loading }).await; + } + } + _ = ctrl_c() => { + if self.is_interactive { + if let Some(tx) = tx { + let still_loading = self.pending_clients.read().await.iter().cloned().collect::>(); + let _ = tx.send(LoadingMsg::Terminate { still_loading }).await; + } + } else { + return Err(eyre::eyre!("User interrupted mcp server loading in non-interactive mode. Ending.")); + } + } + } + if !self.is_interactive + && self + .mcp_load_record + .lock() + .await + .iter() + .any(|(_, records)| records.iter().any(|record| matches!(record, LoadingRecord::Err(_)))) + { + queue!( + stderr, + style::Print( + "One or more mcp server did not load correctly. See $TMPDIR/qlog/chat.log for more details." + ), + style::Print("\n------\n") + )?; + } + stderr.flush()?; + self.update().await; + Ok(self.schema.clone()) + } + + pub fn get_tool_from_tool_use(&self, value: AssistantToolUse) -> Result { + let map_err = |parse_error| ToolResult { + tool_use_id: value.id.clone(), + content: vec![ToolResultContentBlock::Text(format!( + "Failed to validate tool parameters: {parse_error}. The model has either suggested tool parameters which are incompatible with the existing tools, or has suggested one or more tool that does not exist in the list of known tools." + ))], + status: ToolResultStatus::Error, + }; + + Ok(match value.name.as_str() { + "fs_read" => Tool::FsRead(serde_json::from_value::(value.args).map_err(map_err)?), + "fs_write" => Tool::FsWrite(serde_json::from_value::(value.args).map_err(map_err)?), + #[cfg(windows)] + "execute_cmd" => { + Tool::ExecuteCommand(serde_json::from_value::(value.args).map_err(map_err)?) 
+ }, + #[cfg(not(windows))] + "execute_bash" => { + Tool::ExecuteCommand(serde_json::from_value::(value.args).map_err(map_err)?) + }, + "use_aws" => Tool::UseAws(serde_json::from_value::(value.args).map_err(map_err)?), + "report_issue" => Tool::GhIssue(serde_json::from_value::(value.args).map_err(map_err)?), + "thinking" => Tool::Thinking(serde_json::from_value::(value.args).map_err(map_err)?), + "knowledge" => Tool::Knowledge(serde_json::from_value::(value.args).map_err(map_err)?), + // Note that this name is namespaced with server_name{DELIMITER}tool_name + name => { + // Note: tn_map also has tools that underwent no transformation. In otherwords, if + // it is a valid tool name, we should get a hit. + let ToolInfo { + server_name, + host_tool_name: tool_name, + } = match self.tn_map.get(name) { + Some(tool_info) => Ok::<&ToolInfo, ToolResult>(tool_info), + None => { + // No match, we throw an error + Err(ToolResult { + tool_use_id: value.id.clone(), + content: vec![ToolResultContentBlock::Text(format!( + "No tool with \"{name}\" is found" + ))], + status: ToolResultStatus::Error, + }) + }, + }?; + let Some(client) = self.clients.get(server_name) else { + return Err(ToolResult { + tool_use_id: value.id, + content: vec![ToolResultContentBlock::Text(format!( + "The tool, \"{server_name}\" is not supported by the client" + ))], + status: ToolResultStatus::Error, + }); + }; + // The tool input schema has the shape of { type, properties }. + // The field "params" expected by MCP is { name, arguments }, where name is the + // name of the tool being invoked, + // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/tools/#calling-tools. + // The field "arguments" is where ToolUse::args belong. 
+ let mut params = serde_json::Map::::new(); + params.insert("name".to_owned(), serde_json::Value::String(tool_name.to_owned())); + params.insert("arguments".to_owned(), value.args); + let params = serde_json::Value::Object(params); + let custom_tool = CustomTool { + name: tool_name.to_owned(), + client: client.clone(), + method: "tools/call".to_owned(), + params: Some(params), + }; + Tool::Custom(custom_tool) + }, + }) + } + + /// Updates tool managers various states with new information + pub async fn update(&mut self) { + // A hashmap of + let mut tool_specs = HashMap::::new(); + let new_tools = { + let mut new_tool_specs = self.new_tool_specs.lock().await; + new_tool_specs.drain().fold( + HashMap::, Vec)>::new(), + |mut acc, (server_name, v)| { + acc.insert(server_name, v); + acc + }, + ) + }; + + let mut updated_servers = HashSet::::new(); + let mut conflicts = HashMap::::new(); + for (server_name, (tool_name_map, specs)) in new_tools { + // First we evict the tools that were already in the tn_map + self.tn_map.retain(|_, tool_info| tool_info.server_name != server_name); + + // And update them with the new tools queried + // valid: tools that do not have conflicts in naming + let (valid, invalid) = tool_name_map + .into_iter() + .partition::, _>(|(model_tool_name, _)| { + !self.tn_map.contains_key(model_tool_name) + }); + // We reject tools that are conflicting with the existing tools by not including them + // in the tn_map. We would also want to report this error. + if !invalid.is_empty() { + let msg = invalid.into_iter().fold("The following tools are rejected because they conflict with existing tools in names. 
Avoid this via setting aliases for them: \n".to_string(), |mut acc, (model_tool_name, tool_info)| { + acc.push_str(&format!(" - {} from {}\n", model_tool_name, tool_info.server_name)); + acc + }); + conflicts.insert(server_name, msg); + } + if let Some(spec) = specs.first() { + updated_servers.insert(spec.tool_origin.clone()); + } + // We want to filter for specs that are valid + // Note that [ToolSpec::name] is a model facing name (thus you should be comparing it + // with the keys of a tn_map) + for spec in specs.into_iter().filter(|spec| valid.contains_key(&spec.name)) { + tool_specs.insert(spec.name.clone(), spec); + } + + self.tn_map.extend(valid); + } + + // Update schema + // As we are writing over the ensemble of tools in a given server, we will need to first + // remove everything that it has. + self.schema + .retain(|_tool_name, spec| !updated_servers.contains(&spec.tool_origin)); + self.schema.extend(tool_specs); + + // if block here to avoid repeatedly asking for loc + if !conflicts.is_empty() { + let mut record_lock = self.mcp_load_record.lock().await; + for (server_name, msg) in conflicts { + let record = LoadingRecord::Err(msg); + record_lock + .entry(server_name) + .and_modify(|v| v.push(record.clone())) + .or_insert(vec![record]); + } + } + } + + pub async fn list_prompts(&self) -> Result>, GetPromptError> { + if let Some((query_sender, query_result_receiver)) = &self.prompts_sender_receiver_pair { + let mut new_receiver = query_result_receiver.resubscribe(); + query_sender + .send(PromptQuery::List) + .map_err(|e| GetPromptError::General(eyre::eyre!(e)))?; + let query_result = new_receiver + .recv() + .await + .map_err(|e| GetPromptError::General(eyre::eyre!(e)))?; + + Ok(match query_result { + PromptQueryResult::List(list) => list, + PromptQueryResult::Search(_) => return Err(GetPromptError::IncorrectResponseType), + }) + } else { + Err(GetPromptError::MissingChannel) + } + } + + pub async fn get_prompt( + &self, + name: String, + arguments: 
Option>, + ) -> Result { + let (server_name, prompt_name) = match name.split_once('/') { + None => (None::, Some(name.clone())), + Some((server_name, prompt_name)) => (Some(server_name.to_string()), Some(prompt_name.to_string())), + }; + let prompt_name = prompt_name.ok_or(GetPromptError::MissingPromptName)?; + + if let Some((query_sender, query_result_receiver)) = &self.prompts_sender_receiver_pair { + query_sender + .send(PromptQuery::List) + .map_err(|e| GetPromptError::General(eyre::eyre!(e)))?; + let prompts = query_result_receiver + .resubscribe() + .recv() + .await + .map_err(|e| GetPromptError::General(eyre::eyre!(e)))?; + let PromptQueryResult::List(prompts) = prompts else { + return Err(GetPromptError::IncorrectResponseType); + }; + + match (prompts.get(&prompt_name), server_name.as_ref()) { + // If we have more than one eligible clients but no server name specified + (Some(bundles), None) if bundles.len() > 1 => { + Err(GetPromptError::AmbiguousPrompt(prompt_name.clone(), { + bundles.iter().fold("\n".to_string(), |mut acc, b| { + acc.push_str(&format!("- @{}/{}\n", b.server_name, prompt_name)); + acc + }) + })) + }, + // Normal case where we have enough info to proceed + // Note that if bundle exists, it should never be empty + (Some(bundles), sn) => { + let bundle = if bundles.len() > 1 { + let Some(sn) = sn else { + return Err(GetPromptError::AmbiguousPrompt(prompt_name.clone(), { + bundles.iter().fold("\n".to_string(), |mut acc, b| { + acc.push_str(&format!("- @{}/{}\n", b.server_name, prompt_name)); + acc + }) + })); + }; + let bundle = bundles.iter().find(|b| b.server_name == *sn); + match bundle { + Some(bundle) => bundle, + None => { + return Err(GetPromptError::AmbiguousPrompt(prompt_name.clone(), { + bundles.iter().fold("\n".to_string(), |mut acc, b| { + acc.push_str(&format!("- @{}/{}\n", b.server_name, prompt_name)); + acc + }) + })); + }, + } + } else { + bundles.first().ok_or(GetPromptError::MissingPromptInfo)? 
+ }; + + let server_name = &bundle.server_name; + let client = self.clients.get(server_name).ok_or(GetPromptError::MissingClient)?; + let PromptBundle { prompt_get, .. } = bundle; + let args = if let (Some(schema), Some(value)) = (&prompt_get.arguments, &arguments) { + let params = schema.iter().zip(value.iter()).fold( + HashMap::::new(), + |mut acc, (prompt_get_arg, value)| { + acc.insert(prompt_get_arg.name.clone(), value.clone()); + acc + }, + ); + Some(serde_json::json!(params)) + } else { + None + }; + let params = { + let mut params = serde_json::Map::new(); + params.insert("name".to_string(), serde_json::Value::String(prompt_name)); + if let Some(args) = args { + params.insert("arguments".to_string(), args); + } + Some(serde_json::Value::Object(params)) + }; + let resp = client.request("prompts/get", params).await?; + Ok(resp) + }, + (None, _) => Err(GetPromptError::PromptNotFound(prompt_name)), + } + } else { + Err(GetPromptError::MissingChannel) + } + } + + pub async fn pending_clients(&self) -> Vec { + self.pending_clients.read().await.iter().cloned().collect::>() + } +} + +type DisplayTaskJoinHandle = JoinHandle>; +type LoadingStatusSender = tokio::sync::mpsc::Sender; + +/// This function spawns a background task whose sole responsibility is to listen for incoming +/// server loading status and display them to the output. +/// It returns a join handle to the task as well as a sender with which loading status is to be +/// reported. 
+fn spawn_display_task( + interactive: bool, + total: usize, + disabled_servers: Vec<(String, CustomToolConfig)>, + mut output: Box, +) -> (Option, Option) { + if interactive && (total > 0 || !disabled_servers.is_empty()) { + let (tx, mut rx) = tokio::sync::mpsc::channel::(50); + ( + Some(tokio::task::spawn(async move { + let mut spinner_logo_idx: usize = 0; + let mut complete: usize = 0; + let mut failed: usize = 0; + + // Show disabled servers immediately + for (server_name, _) in &disabled_servers { + queue_disabled_message(server_name, &mut output)?; + } + + if total > 0 { + queue_init_message(spinner_logo_idx, complete, failed, total, &mut output)?; + } + + loop { + match tokio::time::timeout(Duration::from_millis(50), rx.recv()).await { + Ok(Some(recv_result)) => match recv_result { + LoadingMsg::Done { name, time } => { + complete += 1; + execute!( + output, + cursor::MoveToColumn(0), + cursor::MoveUp(1), + terminal::Clear(terminal::ClearType::CurrentLine), + )?; + queue_success_message(&name, &time, &mut output)?; + queue_init_message(spinner_logo_idx, complete, failed, total, &mut output)?; + }, + LoadingMsg::Error { name, msg, time } => { + failed += 1; + execute!( + output, + cursor::MoveToColumn(0), + cursor::MoveUp(1), + terminal::Clear(terminal::ClearType::CurrentLine), + )?; + queue_failure_message(&name, &msg, time.as_str(), &mut output)?; + queue_init_message(spinner_logo_idx, complete, failed, total, &mut output)?; + }, + LoadingMsg::Warn { name, msg, time } => { + complete += 1; + execute!( + output, + cursor::MoveToColumn(0), + cursor::MoveUp(1), + terminal::Clear(terminal::ClearType::CurrentLine), + )?; + let msg = eyre::eyre!(msg.to_string()); + queue_warn_message(&name, &msg, time.as_str(), &mut output)?; + queue_init_message(spinner_logo_idx, complete, failed, total, &mut output)?; + }, + LoadingMsg::Terminate { still_loading } => { + if !still_loading.is_empty() && total > 0 { + execute!( + output, + cursor::MoveToColumn(0), + 
cursor::MoveUp(1), + terminal::Clear(terminal::ClearType::CurrentLine), + )?; + let msg = still_loading.iter().fold(String::new(), |mut acc, server_name| { + acc.push_str(format!("\n - {server_name}").as_str()); + acc + }); + let msg = eyre::eyre!(msg); + queue_incomplete_load_message(complete, total, &msg, &mut output)?; + } else if total > 0 { + // Clear the loading line if we have enabled servers + execute!( + output, + cursor::MoveToColumn(0), + cursor::MoveUp(1), + terminal::Clear(terminal::ClearType::CurrentLine), + )?; + } + execute!(output, style::Print("\n"),)?; + break; + }, + }, + Err(_e) => { + spinner_logo_idx = (spinner_logo_idx + 1) % SPINNER_CHARS.len(); + execute!( + output, + cursor::SavePosition, + cursor::MoveToColumn(0), + cursor::MoveUp(1), + style::Print(SPINNER_CHARS[spinner_logo_idx]), + cursor::RestorePosition + )?; + }, + _ => break, + } + output.flush()?; + } + Ok::<_, eyre::Report>(()) + })), + Some(tx), + ) + } else { + (None, None) + } +} + +/// This function spawns the orchestrator task that has the following responsibilities: +/// - Listens for server driven events (see [UpdateEventMessage] for a list of current applicable +/// events). These are things such as tool list (because we fetch tools in the background), prompt +/// list, tool list update, and prompt list updates. In the future, if when we support sampling +/// and we have not yet moved to the official rust MCP crate, we would also be using this task to +/// facilitate it. +/// - Listens for prompt list request and serve them. Unlike tools, we do *not* cache prompts on the +/// conversation state. This is because prompts do not need to be sent to the model every turn. +/// Instead, the prompts are cached in a hashmap that is owned by the orchestrator task. +/// +/// Note that there should be exactly one instance of this task running per session. Should there +/// be any need to instantiate a new [ToolManager] (e.g. 
swapping agents), see +/// [ToolManager::swap_agent] for how this should be done. +#[allow(clippy::too_many_arguments)] +fn spawn_orchestrator_task( + has_new_stuff: Arc, + mut loading_servers: HashMap, + mut msg_rx: tokio::sync::mpsc::Receiver, + mut prompt_list_receiver: tokio::sync::broadcast::Receiver, + mut prompt_list_sender: tokio::sync::broadcast::Sender, + pending: Arc>>, + agent: Arc>, + database: Database, + regex: Regex, + notify_weak: std::sync::Weak, + load_record: Arc>>>, + telemetry: TelemetryThread, + loading_status_sender: Option, + new_tool_specs: NewToolSpecs, + total: usize, + conv_id: String, +) { + tokio::spawn(async move { + use tokio::sync::broadcast::Sender as BroadcastSender; + use tokio::sync::mpsc::Sender as MpscSender; + + let mut record_temp_buf = Vec::::new(); + let mut initialized = HashSet::::new(); + let mut prompts = HashMap::>::new(); + + enum ToolFilter { + All, + List(HashSet), + } + + impl ToolFilter { + pub fn should_include(&self, tool_name: &str) -> bool { + match self { + Self::All => true, + Self::List(set) => set.contains(tool_name), + } + } + } + + // We separate this into its own function for ease of maintenance since things written + // in select arms don't have type hints + #[inline] + async fn handle_prompt_queries( + query: PromptQuery, + prompts: &HashMap>, + prompt_query_response_sender: &mut BroadcastSender, + ) { + match query { + PromptQuery::List => { + let query_res = PromptQueryResult::List(prompts.clone()); + if let Err(e) = prompt_query_response_sender.send(query_res) { + error!("Error sending prompts to chat helper: {:?}", e); + } + }, + PromptQuery::Search(search_word) => { + let filtered_prompts = prompts + .iter() + .flat_map(|(prompt_name, bundles)| { + if bundles.len() > 1 { + bundles + .iter() + .map(|b| format!("{}/{}", b.server_name, prompt_name)) + .collect() + } else { + vec![prompt_name.to_owned()] + } + }) + .filter(|n| { + if let Some(p) = &search_word { + n.contains(p) + } else { + true + 
} + }) + .collect::>(); + + let query_res = PromptQueryResult::Search(filtered_prompts); + if let Err(e) = prompt_query_response_sender.send(query_res) { + error!("Error sending prompts to chat helper: {:?}", e); + } + }, + } + } + + // We separate this into its own function for ease of maintenance since things written + // in select arms don't have type hints + #[inline] + #[allow(clippy::too_many_arguments)] + async fn handle_messenger_msg( + msg: UpdateEventMessage, + loading_servers: &mut HashMap, + record_temp_buf: &mut Vec, + pending: &Arc>>, + agent: &Arc>, + database: &Database, + conv_id: &str, + regex: &Regex, + telemetry_clone: &TelemetryThread, + mut loading_status_sender: Option<&MpscSender>, + new_tool_specs: &NewToolSpecs, + has_new_stuff: &Arc, + load_record: &Arc>>>, + notify_weak: &std::sync::Weak, + initialized: &mut HashSet, + prompts: &mut HashMap>, + total: usize, + ) { + record_temp_buf.clear(); + // For now we will treat every list result as if they contain the + // complete set of tools. This is not necessarily true in the future when + // request method on the mcp client no longer buffers all the pages from + // list calls. 
+ match msg { + UpdateEventMessage::ToolsListResult { + server_name, + result, + pid, + } => { + let time_taken = loading_servers + .remove(&server_name) + .map_or("0.0".to_owned(), |init_time| { + let time_taken = (std::time::Instant::now() - init_time).as_secs_f64().abs(); + format!("{:.2}", time_taken) + }); + pending.write().await.remove(&server_name); + + let result_tools = match &result { + Ok(tools_result) => { + let names: Vec = tools_result + .tools + .iter() + .filter_map(|tool| tool.get("name")?.as_str().map(String::from)) + .collect(); + names + }, + Err(_) => vec![], + }; + + let (tool_filter, alias_list) = { + let agent_lock = agent.lock().await; + + // We will assume all tools are allowed if the tool list consists of 1 + // element and it's a * + let tool_filter = if agent_lock.tools.len() == 1 + && agent_lock.tools.first().map(String::as_str).is_some_and(|c| c == "*") + { + ToolFilter::All + } else { + let set = agent_lock + .tools + .iter() + .filter(|tool_name| tool_name.starts_with(&format!("@{server_name}"))) + .map(|full_name| { + match full_name.split_once(MCP_SERVER_TOOL_DELIMITER) { + Some((_, tool_name)) if !tool_name.is_empty() => tool_name, + _ => "*", + } + .to_string() + }) + .collect::>(); + + if set.contains("*") { + ToolFilter::All + } else { + ToolFilter::List(set) + } + }; + + let server_prefix = format!("@{server_name}"); + let alias_list = agent_lock.tool_aliases.iter().fold( + HashMap::::new(), + |mut acc, (full_path, model_tool_name)| { + if full_path.starts_with(&server_prefix) { + if let Some((_, host_tool_name)) = full_path.split_once(MCP_SERVER_TOOL_DELIMITER) { + acc.insert(host_tool_name.to_string(), model_tool_name.clone()); + } + } + acc + }, + ); + + (tool_filter, alias_list) + }; + + match result { + Ok(result) => { + if pid.is_none_or(|pid| !is_process_running(pid)) { + let pid = pid.map_or("unknown".to_string(), |pid| pid.to_string()); + info!( + "Received tool list result from {server_name} but its associated 
process {pid} is no longer running. Ignoring." + ); + + let mut buf_writer = BufWriter::new(&mut *record_temp_buf); + let _ = queue_failure_message( + &server_name, + &eyre::eyre!("Process associated is no longer running"), + &time_taken, + &mut buf_writer, + ); + let _ = buf_writer.flush(); + drop(buf_writer); + let record_content = String::from_utf8_lossy(record_temp_buf).to_string(); + let record = LoadingRecord::Err(record_content); + + load_record + .lock() + .await + .entry(server_name.clone()) + .and_modify(|load_record| { + load_record.push(record.clone()); + }) + .or_insert(vec![record]); + + return; + } + + let mut specs = result + .tools + .into_iter() + .filter_map(|v| serde_json::from_value::(v).ok()) + .filter(|spec| tool_filter.should_include(&spec.name)) + .collect::>(); + let mut sanitized_mapping = HashMap::::new(); + let process_result = process_tool_specs( + database, + conv_id, + &server_name, + &mut specs, + &mut sanitized_mapping, + &alias_list, + regex, + telemetry_clone, + &result_tools, + ) + .await; + if let Some(sender) = &loading_status_sender { + // Anomalies here are not considered fatal, thus we shall give + // warnings. 
+ let msg = match process_result { + Ok(_) => LoadingMsg::Done { + name: server_name.clone(), + time: time_taken.clone(), + }, + Err(ref e) => LoadingMsg::Warn { + name: server_name.clone(), + msg: eyre::eyre!(e.to_string()), + time: time_taken.clone(), + }, + }; + if let Err(e) = sender.send(msg).await { + warn!( + "Error sending update message to display task: {:?}\nAssume display task has completed", + e + ); + loading_status_sender.take(); + } + } + new_tool_specs + .lock() + .await + .insert(server_name.clone(), (sanitized_mapping, specs)); + has_new_stuff.store(true, Ordering::Release); + // Maintain a record of the server load: + let mut buf_writer = BufWriter::new(&mut *record_temp_buf); + if let Err(e) = &process_result { + let _ = + queue_warn_message(server_name.as_str(), e, time_taken.as_str(), &mut buf_writer); + } else { + let _ = + queue_success_message(server_name.as_str(), time_taken.as_str(), &mut buf_writer); + } + let _ = buf_writer.flush(); + drop(buf_writer); + let record = String::from_utf8_lossy(record_temp_buf).to_string(); + let record = if process_result.is_err() { + LoadingRecord::Warn(record) + } else { + LoadingRecord::Success(record) + }; + load_record + .lock() + .await + .entry(server_name.clone()) + .and_modify(|load_record| { + load_record.push(record.clone()); + }) + .or_insert(vec![record]); + }, + Err(e) => { + // Log error to chat Log + error!("Error loading server {server_name}: {:?}", e); + // Maintain a record of the server load: + let mut buf_writer = BufWriter::new(&mut *record_temp_buf); + let _ = queue_failure_message(server_name.as_str(), &e, &time_taken, &mut buf_writer); + let _ = buf_writer.flush(); + drop(buf_writer); + let record = String::from_utf8_lossy(record_temp_buf).to_string(); + let record = LoadingRecord::Err(record); + load_record + .lock() + .await + .entry(server_name.clone()) + .and_modify(|load_record| { + load_record.push(record.clone()); + }) + .or_insert(vec![record]); + // Errors surfaced at this 
point (i.e. before [process_tool_specs] + // is called) are fatals and should be considered errors + if let Some(sender) = &loading_status_sender { + let msg = LoadingMsg::Error { + name: server_name.clone(), + msg: e, + time: time_taken, + }; + if let Err(e) = sender.send(msg).await { + warn!( + "Error sending update message to display task: {:?}\nAssume display task has completed", + e + ); + loading_status_sender.take(); + } + } + }, + } + if let Some(notify) = notify_weak.upgrade() { + initialized.insert(server_name); + if initialized.len() >= total { + notify.notify_one(); + } + } + }, + UpdateEventMessage::PromptsListResult { + server_name, + result, + pid, + } => match result { + Ok(prompt_list_result) if pid.is_some() => { + let pid = pid.unwrap(); + if !is_process_running(pid) { + info!( + "Received prompt list result from {server_name} but its associated process {pid} is no longer running. Ignoring." + ); + return; + } + // We first need to clear all the PromptGets that are associated with + // this server because PromptsListResult is declaring what is available + // (and not the diff) + prompts + .values_mut() + .for_each(|bundles| bundles.retain(|bundle| bundle.server_name != server_name)); + + // And then we update them with the new comers + for result in prompt_list_result.prompts { + let Ok(prompt_get) = serde_json::from_value::(result) else { + error!("Failed to deserialize prompt get from server {server_name}"); + continue; + }; + prompts + .entry(prompt_get.name.clone()) + .and_modify(|bundles| { + bundles.push(PromptBundle { + server_name: server_name.clone(), + prompt_get: prompt_get.clone(), + }); + }) + .or_insert_with(|| { + vec![PromptBundle { + server_name: server_name.clone(), + prompt_get, + }] + }); + } + }, + Ok(_) => { + error!("Received prompt list result without pid from {server_name}. 
Ignoring."); + }, + Err(e) => { + error!("Error fetching prompts from server {server_name}: {:?}", e); + let mut buf_writer = BufWriter::new(&mut *record_temp_buf); + let _ = queue_prompts_load_error_message(&server_name, &e, &mut buf_writer); + let _ = buf_writer.flush(); + drop(buf_writer); + let record = String::from_utf8_lossy(record_temp_buf).to_string(); + let record = LoadingRecord::Err(record); + load_record + .lock() + .await + .entry(server_name.clone()) + .and_modify(|load_record| { + load_record.push(record.clone()); + }) + .or_insert(vec![record]); + }, + }, + UpdateEventMessage::ResourcesListResult { + server_name: _, + result: _, + pid: _, + } => {}, + UpdateEventMessage::ResourceTemplatesListResult { + server_name: _, + result: _, + pid: _, + } => {}, + UpdateEventMessage::InitStart { server_name, .. } => { + pending.write().await.insert(server_name.clone()); + loading_servers.insert(server_name, std::time::Instant::now()); + }, + UpdateEventMessage::Deinit { server_name, .. } => { + // Only prompts are stored here so we'll just be clearing that + // In the future if we are also storing tools, we need to make sure that + // the tools are also pruned. + for (_prompt_name, bundles) in prompts.iter_mut() { + bundles.retain(|bundle| bundle.server_name != server_name); + } + prompts.retain(|_, bundles| !bundles.is_empty()); + has_new_stuff.store(true, Ordering::Release); + }, + } + } + + loop { + tokio::select! 
{ + Ok(query) = prompt_list_receiver.recv() => { + handle_prompt_queries(query, &prompts, &mut prompt_list_sender).await; + }, + Some(msg) = msg_rx.recv() => { + handle_messenger_msg( + msg, + &mut loading_servers, + &mut record_temp_buf, + &pending, + &agent, + &database, + conv_id.as_str(), + ®ex, + &telemetry, + loading_status_sender.as_ref(), + &new_tool_specs, + &has_new_stuff, + &load_record, + ¬ify_weak, + &mut initialized, + &mut prompts, + total + ).await; + }, + // Nothing else to poll + else => { + tracing::info!("Tool manager orchestrator task exited"); + break; + }, + } + } + }); +} + +#[allow(clippy::too_many_arguments)] +async fn process_tool_specs( + database: &Database, + conversation_id: &str, + server_name: &str, + specs: &mut Vec, + tn_map: &mut HashMap, + alias_list: &HashMap, + regex: &Regex, + telemetry: &TelemetryThread, + result_tools: &[String], +) -> eyre::Result<()> { + // Tools are subjected to the following validations: + // 1. ^[a-zA-Z][a-zA-Z0-9_]*$, + // 2. less than 64 characters in length + // 3. a non-empty description + // + // For non-compliance due to point 1, we shall change it on behalf of the users. + // For the rest, we simply throw a warning and reject the tool. 
+ let mut out_of_spec_tool_names = Vec::::new(); + let mut hasher = DefaultHasher::new(); + let mut number_of_tools = 0_usize; + + let number_of_tools_in_mcp_server = result_tools.len(); + + let all_tool_names = if !result_tools.is_empty() { + Some(result_tools.join(",")) + } else { + None + }; + + for spec in specs.iter_mut() { + let model_tool_name = alias_list.get(&spec.name).cloned().unwrap_or({ + if !regex.is_match(&spec.name) { + let mut sn = sanitize_name(spec.name.clone(), regex, &mut hasher); + while tn_map.contains_key(&sn) { + sn.push('1'); + } + sn + } else { + spec.name.clone() + } + }); + if model_tool_name.len() > 64 { + out_of_spec_tool_names.push(OutOfSpecName::TooLong(spec.name.clone())); + continue; + } else if spec.description.is_empty() { + out_of_spec_tool_names.push(OutOfSpecName::EmptyDescription(spec.name.clone())); + continue; + } + tn_map.insert(model_tool_name.clone(), ToolInfo { + server_name: server_name.to_string(), + host_tool_name: spec.name.clone(), + }); + spec.name = model_tool_name; + spec.tool_origin = ToolOrigin::McpServer(server_name.to_string()); + number_of_tools += 1; + } + // Native origin is the default, and since this function never reads native tools, if we still + // have it, that would indicate a tool that should not be included. + specs.retain(|spec| !matches!(spec.tool_origin, ToolOrigin::Native)); + let loaded_tool_names = if specs.is_empty() { + None + } else { + Some(specs.iter().map(|spec| spec.name.clone()).collect::>().join(",")) + }; + // Send server load success metric datum + let conversation_id = conversation_id.to_string(); + let _ = telemetry + .send_mcp_server_init( + database, + conversation_id, + server_name.to_string(), + None, + number_of_tools, + all_tool_names, + loaded_tool_names, + number_of_tools_in_mcp_server, + ) + .await; + // Tool name translation. This is beyond of the scope of what is + // considered a "server load". 
Reasoning being: + // - Failures here are not related to server load + // - There is not a whole lot we can do with this data + if !out_of_spec_tool_names.is_empty() { + Err(eyre::eyre!(out_of_spec_tool_names.iter().fold( + String::from( + "The following tools are out of spec. They will be excluded from the list of available tools:\n", + ), + |mut acc, name| { + let (tool_name, msg) = match name { + OutOfSpecName::TooLong(tool_name) => ( + tool_name.as_str(), + "tool name exceeds max length of 64 when combined with server name", + ), + OutOfSpecName::IllegalChar(tool_name) => ( + tool_name.as_str(), + "tool name must be compliant with ^[a-zA-Z][a-zA-Z0-9_]*$", + ), + OutOfSpecName::EmptyDescription(tool_name) => { + (tool_name.as_str(), "tool schema contains empty description") + }, + }; + acc.push_str(format!(" - {} ({})\n", tool_name, msg).as_str()); + acc + }, + ))) + } else { + Ok(()) + } +} + +fn sanitize_name(orig: String, regex: ®ex::Regex, hasher: &mut impl Hasher) -> String { + if regex.is_match(&orig) && !orig.contains(NAMESPACE_DELIMITER) { + return orig; + } + let sanitized: String = orig + .chars() + .filter(|c| c.is_ascii_alphabetic() || c.is_ascii_digit() || *c == '_') + .collect::() + .replace(NAMESPACE_DELIMITER, ""); + if sanitized.is_empty() { + hasher.write(orig.as_bytes()); + let hash = format!("{:03}", hasher.finish() % 1000); + return format!("a{}", hash); + } + match sanitized.chars().next() { + Some(c) if c.is_ascii_alphabetic() => sanitized, + Some(_) => { + format!("a{}", sanitized) + }, + None => { + hasher.write(orig.as_bytes()); + format!("a{}", hasher.finish()) + }, + } +} + +// Add this function to check if a process is still running +fn is_process_running(pid: u32) -> bool { + #[cfg(unix)] + { + let system = sysinfo::System::new_all(); + system.process(sysinfo::Pid::from(pid as usize)).is_some() + } + #[cfg(windows)] + { + // TODO: fill in the process health check for windows when when we officially support + // windows + _ = pid; + 
true + } +} + +fn queue_success_message(name: &str, time_taken: &str, output: &mut impl Write) -> eyre::Result<()> { + Ok(queue!( + output, + style::SetForegroundColor(style::Color::Green), + style::Print("✓ "), + style::SetForegroundColor(style::Color::Blue), + style::Print(name), + style::ResetColor, + style::Print(" loaded in "), + style::SetForegroundColor(style::Color::Yellow), + style::Print(format!("{time_taken} s\n")), + style::ResetColor, + )?) +} + +fn queue_init_message( + spinner_logo_idx: usize, + complete: usize, + failed: usize, + total: usize, + output: &mut impl Write, +) -> eyre::Result<()> { + if total == complete { + queue!( + output, + style::SetForegroundColor(style::Color::Green), + style::Print("✓"), + style::ResetColor, + )?; + } else if total == complete + failed { + queue!( + output, + style::SetForegroundColor(style::Color::Red), + style::Print("✗"), + style::ResetColor, + )?; + } else { + queue!(output, style::Print(SPINNER_CHARS[spinner_logo_idx]))?; + } + queue!( + output, + style::SetForegroundColor(style::Color::Blue), + style::Print(format!(" {}", complete)), + style::ResetColor, + style::Print(" of "), + style::SetForegroundColor(style::Color::Blue), + style::Print(format!("{} ", total)), + style::ResetColor, + style::Print("mcp servers initialized."), + )?; + if total > complete + failed { + queue!( + output, + style::SetForegroundColor(style::Color::Blue), + style::Print(" ctrl-c "), + style::ResetColor, + style::Print("to start chatting now") + )?; + } + Ok(queue!(output, style::Print("\n"))?) 
+} + +fn queue_failure_message( + name: &str, + fail_load_msg: &eyre::Report, + time: &str, + output: &mut impl Write, +) -> eyre::Result<()> { + use crate::util::CHAT_BINARY_NAME; + Ok(queue!( + output, + style::SetForegroundColor(style::Color::Red), + style::Print("✗ "), + style::SetForegroundColor(style::Color::Blue), + style::Print(name), + style::ResetColor, + style::Print(" has failed to load after"), + style::SetForegroundColor(style::Color::Yellow), + style::Print(format!(" {time} s")), + style::ResetColor, + style::Print("\n - "), + style::Print(fail_load_msg), + style::Print("\n"), + style::Print(format!( + " - run with Q_LOG_LEVEL=trace and see $TMPDIR/{CHAT_BINARY_NAME} for detail\n" + )), + style::ResetColor, + )?) +} + +fn queue_warn_message(name: &str, msg: &eyre::Report, time: &str, output: &mut impl Write) -> eyre::Result<()> { + Ok(queue!( + output, + style::SetForegroundColor(style::Color::Yellow), + style::Print("⚠ "), + style::SetForegroundColor(style::Color::Blue), + style::Print(name), + style::ResetColor, + style::Print(" has loaded in"), + style::SetForegroundColor(style::Color::Yellow), + style::Print(format!(" {time} s")), + style::ResetColor, + style::Print(" with the following warning:\n"), + style::Print(msg), + style::ResetColor, + )?) +} + +fn queue_disabled_message(name: &str, output: &mut impl Write) -> eyre::Result<()> { + Ok(queue!( + output, + style::SetForegroundColor(style::Color::DarkGrey), + style::Print("○ "), + style::SetForegroundColor(style::Color::Blue), + style::Print(name), + style::ResetColor, + style::Print(" is disabled\n"), + style::ResetColor, + )?) 
+} + +fn queue_incomplete_load_message( + complete: usize, + total: usize, + msg: &eyre::Report, + output: &mut impl Write, +) -> eyre::Result<()> { + Ok(queue!( + output, + style::SetForegroundColor(style::Color::Yellow), + style::Print("⚠"), + style::SetForegroundColor(style::Color::Blue), + style::Print(format!(" {}", complete)), + style::ResetColor, + style::Print(" of "), + style::SetForegroundColor(style::Color::Blue), + style::Print(format!("{} ", total)), + style::ResetColor, + style::Print("mcp servers initialized."), + style::ResetColor, + // We expect the message start with a newline + style::Print(" Servers still loading:"), + style::Print(msg), + style::ResetColor, + )?) +} + +fn queue_prompts_load_error_message(name: &str, msg: &eyre::Report, output: &mut impl Write) -> eyre::Result<()> { + Ok(queue!( + output, + style::Print(format!("Prompt list for {name} failed with the following message: \n")), + style::Print(msg), + )?) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_sanitize_server_name() { + let regex = regex::Regex::new(VALID_TOOL_NAME).unwrap(); + let mut hasher = DefaultHasher::new(); + let orig_name = "@awslabs.cdk-mcp-server"; + let sanitized_server_name = sanitize_name(orig_name.to_string(), ®ex, &mut hasher); + assert_eq!(sanitized_server_name, "awslabscdkmcpserver"); + + let orig_name = "good_name"; + let sanitized_good_name = sanitize_name(orig_name.to_string(), ®ex, &mut hasher); + assert_eq!(sanitized_good_name, orig_name); + + let all_bad_name = "@@@@@"; + let sanitized_all_bad_name = sanitize_name(all_bad_name.to_string(), ®ex, &mut hasher); + assert!(regex.is_match(&sanitized_all_bad_name)); + + let with_delim = format!("a{}b{}c", NAMESPACE_DELIMITER, NAMESPACE_DELIMITER); + let sanitized = sanitize_name(with_delim, ®ex, &mut hasher); + assert_eq!(sanitized, "abc"); + } +} + + +---- crates/chat-cli/src/cli/chat/tools/mod.rs ---- +pub mod custom_tool; +pub mod execute; +pub mod fs_read; +pub mod fs_write; +pub 
mod gh_issue; +pub mod knowledge; +pub mod thinking; +pub mod use_aws; + +use std::borrow::{ + Borrow, + Cow, +}; +use std::collections::HashMap; +use std::io::Write; +use std::path::{ + Path, + PathBuf, +}; + +use crossterm::queue; +use crossterm::style::{ + self, + Color, +}; +use custom_tool::CustomTool; +use execute::ExecuteCommand; +use eyre::Result; +use fs_read::FsRead; +use fs_write::FsWrite; +use gh_issue::GhIssue; +use knowledge::Knowledge; +use serde::{ + Deserialize, + Serialize, +}; +use thinking::Thinking; +use tracing::error; +use use_aws::UseAws; + +use super::consts::{ + MAX_TOOL_RESPONSE_SIZE, + USER_AGENT_APP_NAME, + USER_AGENT_ENV_VAR, + USER_AGENT_VERSION_KEY, + USER_AGENT_VERSION_VALUE, +}; +use super::util::images::RichImageBlocks; +use crate::cli::agent::{ + Agent, + PermissionEvalResult, +}; +use crate::cli::chat::line_tracker::FileLineTracker; +use crate::os::Os; + +pub const DEFAULT_APPROVE: [&str; 1] = ["fs_read"]; +pub const NATIVE_TOOLS: [&str; 7] = [ + "fs_read", + "fs_write", + #[cfg(windows)] + "execute_cmd", + #[cfg(not(windows))] + "execute_bash", + "use_aws", + "gh_issue", + "knowledge", + "thinking", +]; + +/// Represents an executable tool use. 
+#[allow(clippy::large_enum_variant)] +#[derive(Debug, Clone)] +pub enum Tool { + FsRead(FsRead), + FsWrite(FsWrite), + ExecuteCommand(ExecuteCommand), + UseAws(UseAws), + Custom(CustomTool), + GhIssue(GhIssue), + Knowledge(Knowledge), + Thinking(Thinking), +} + +impl Tool { + /// The display name of a tool + pub fn display_name(&self) -> String { + match self { + Tool::FsRead(_) => "fs_read", + Tool::FsWrite(_) => "fs_write", + #[cfg(windows)] + Tool::ExecuteCommand(_) => "execute_cmd", + #[cfg(not(windows))] + Tool::ExecuteCommand(_) => "execute_bash", + Tool::UseAws(_) => "use_aws", + Tool::Custom(custom_tool) => &custom_tool.name, + Tool::GhIssue(_) => "gh_issue", + Tool::Knowledge(_) => "knowledge", + Tool::Thinking(_) => "thinking (prerelease)", + } + .to_owned() + } + + /// Whether or not the tool should prompt the user to accept before [Self::invoke] is called. + pub fn requires_acceptance(&self, os: &Os, agent: &Agent) -> PermissionEvalResult { + match self { + Tool::FsRead(fs_read) => fs_read.eval_perm(os, agent), + Tool::FsWrite(fs_write) => fs_write.eval_perm(os, agent), + Tool::ExecuteCommand(execute_command) => execute_command.eval_perm(os, agent), + Tool::UseAws(use_aws) => use_aws.eval_perm(os, agent), + Tool::Custom(custom_tool) => custom_tool.eval_perm(os, agent), + Tool::GhIssue(_) => PermissionEvalResult::Allow, + Tool::Thinking(_) => PermissionEvalResult::Allow, + Tool::Knowledge(knowledge) => knowledge.eval_perm(os, agent), + } + } + + /// Invokes the tool asynchronously + pub async fn invoke( + &self, + os: &Os, + stdout: &mut impl Write, + line_tracker: &mut HashMap, + ) -> Result { + match self { + Tool::FsRead(fs_read) => fs_read.invoke(os, stdout).await, + Tool::FsWrite(fs_write) => fs_write.invoke(os, stdout, line_tracker).await, + Tool::ExecuteCommand(execute_command) => execute_command.invoke(os, stdout).await, + Tool::UseAws(use_aws) => use_aws.invoke(os, stdout).await, + Tool::Custom(custom_tool) => custom_tool.invoke(os, 
stdout).await, + Tool::GhIssue(gh_issue) => gh_issue.invoke(os, stdout).await, + Tool::Knowledge(knowledge) => knowledge.invoke(os, stdout).await, + Tool::Thinking(think) => think.invoke(stdout).await, + } + } + + /// Queues up a tool's intention in a human readable format + pub async fn queue_description(&self, os: &Os, output: &mut impl Write) -> Result<()> { + match self { + Tool::FsRead(fs_read) => fs_read.queue_description(os, output).await, + Tool::FsWrite(fs_write) => fs_write.queue_description(os, output), + Tool::ExecuteCommand(execute_command) => execute_command.queue_description(output), + Tool::UseAws(use_aws) => use_aws.queue_description(output), + Tool::Custom(custom_tool) => custom_tool.queue_description(output), + Tool::GhIssue(gh_issue) => gh_issue.queue_description(output), + Tool::Knowledge(knowledge) => knowledge.queue_description(os, output).await, + Tool::Thinking(thinking) => thinking.queue_description(output), + } + } + + /// Validates the tool with the arguments supplied + pub async fn validate(&mut self, os: &Os) -> Result<()> { + match self { + Tool::FsRead(fs_read) => fs_read.validate(os).await, + Tool::FsWrite(fs_write) => fs_write.validate(os).await, + Tool::ExecuteCommand(execute_command) => execute_command.validate(os).await, + Tool::UseAws(use_aws) => use_aws.validate(os).await, + Tool::Custom(custom_tool) => custom_tool.validate(os).await, + Tool::GhIssue(gh_issue) => gh_issue.validate(os).await, + Tool::Knowledge(knowledge) => knowledge.validate(os).await, + Tool::Thinking(think) => think.validate(os).await, + } + } + + /// Returns additional information about the tool if available + pub fn get_additional_info(&self) -> Option { + match self { + Tool::UseAws(use_aws) => Some(use_aws.get_additional_info()), + // Add other tool types here as they implement get_additional_info() + _ => None, + } + } +} + +/// A tool specification to be sent to the model as part of a conversation. Maps to +/// [BedrockToolSpecification]. 
+#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolSpec { + pub name: String, + pub description: String, + #[serde(alias = "inputSchema")] + pub input_schema: InputSchema, + #[serde(skip_serializing, default = "tool_origin")] + pub tool_origin: ToolOrigin, +} + +#[derive(Debug, Clone, Eq, PartialEq)] +pub enum ToolOrigin { + Native, + McpServer(String), +} + +impl std::hash::Hash for ToolOrigin { + fn hash(&self, state: &mut H) { + std::mem::discriminant(self).hash(state); + match self { + Self::Native => {}, + Self::McpServer(name) => name.hash(state), + } + } +} + +impl Borrow for ToolOrigin { + fn borrow(&self) -> &str { + match self { + Self::McpServer(name) => name.as_str(), + Self::Native => "native", + } + } +} + +impl<'de> Deserialize<'de> for ToolOrigin { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let s = String::deserialize(deserializer)?; + if s == "native___" { + Ok(ToolOrigin::Native) + } else { + Ok(ToolOrigin::McpServer(s)) + } + } +} + +impl Serialize for ToolOrigin { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + match self { + ToolOrigin::Native => serializer.serialize_str("native___"), + ToolOrigin::McpServer(server) => serializer.serialize_str(server), + } + } +} + +impl std::fmt::Display for ToolOrigin { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ToolOrigin::Native => write!(f, "Built-in"), + ToolOrigin::McpServer(server) => write!(f, "{} (MCP)", server), + } + } +} + +fn tool_origin() -> ToolOrigin { + ToolOrigin::Native +} + +#[derive(Debug, Clone)] +pub struct QueuedTool { + pub id: String, + pub name: String, + pub accepted: bool, + pub tool: Tool, +} + +/// The schema specification describing a tool's fields. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct InputSchema(pub serde_json::Value); + +/// The output received from invoking a [Tool]. 
+#[derive(Debug, Default)] +pub struct InvokeOutput { + pub output: OutputKind, +} + +impl InvokeOutput { + pub fn as_str(&self) -> Cow<'_, str> { + match &self.output { + OutputKind::Text(s) => s.as_str().into(), + OutputKind::Json(j) => serde_json::to_string(j) + .map_err(|err| error!(?err, "failed to serialize tool to json")) + .unwrap_or_default() + .into(), + OutputKind::Images(_) => "".into(), + OutputKind::Mixed { text, .. } => text.as_str().into(), // Return the text part + } + } +} + +#[non_exhaustive] +#[derive(Debug)] +pub enum OutputKind { + Text(String), + Json(serde_json::Value), + Images(RichImageBlocks), + Mixed { text: String, images: RichImageBlocks }, +} + +impl Default for OutputKind { + fn default() -> Self { + Self::Text(String::new()) + } +} + +/// Performs tilde expansion and other required sanitization modifications for handling tool use +/// path arguments. +/// +/// Required since path arguments are defined by the model. +#[allow(dead_code)] +pub fn sanitize_path_tool_arg(os: &Os, path: impl AsRef) -> PathBuf { + let mut res = PathBuf::new(); + // Expand `~` only if it is the first part. + let mut path = path.as_ref().components(); + match path.next() { + Some(p) if p.as_os_str() == "~" => { + res.push(os.env.home().unwrap_or_default()); + }, + Some(p) => res.push(p), + None => return res, + } + for p in path { + res.push(p); + } + // For testing scenarios, we need to make sure paths are appropriately handled in chroot test + // file systems since they are passed directly from the model. + os.fs.chroot_path(res) +} + +/// Converts `path` to a relative path according to the current working directory `cwd`. 
+fn absolute_to_relative(cwd: impl AsRef, path: impl AsRef) -> Result { + let cwd = cwd.as_ref().canonicalize()?; + let path = path.as_ref().canonicalize()?; + let mut cwd_parts = cwd.components().peekable(); + let mut path_parts = path.components().peekable(); + + // Skip common prefix + while let (Some(a), Some(b)) = (cwd_parts.peek(), path_parts.peek()) { + if a == b { + cwd_parts.next(); + path_parts.next(); + } else { + break; + } + } + + // ".." for any uncommon parts, then just append the rest of the path. + let mut relative = PathBuf::new(); + for _ in cwd_parts { + relative.push(".."); + } + for part in path_parts { + relative.push(part); + } + + Ok(relative) +} + +/// Small helper for formatting the path as a relative path, if able. +fn format_path(cwd: impl AsRef, path: impl AsRef) -> String { + absolute_to_relative(cwd, path.as_ref()) + .map(|p| p.to_string_lossy().to_string()) + // If we have three consecutive ".." then it should probably just stay as an absolute path. + .map(|p| { + let three_up = format!("..{}..{}..", std::path::MAIN_SEPARATOR, std::path::MAIN_SEPARATOR); + if p.starts_with(&three_up) { + path.as_ref().to_string_lossy().to_string() + } else { + p + } + }) + .unwrap_or(path.as_ref().to_string_lossy().to_string()) +} + +fn supports_truecolor(os: &Os) -> bool { + // Simple override to disable truecolor since shell_color doesn't use Context. 
+ !os.env.get("Q_DISABLE_TRUECOLOR").is_ok_and(|s| !s.is_empty()) + && shell_color::get_color_support().contains(shell_color::ColorSupport::TERM24BIT) +} + +/// Helper function to display a purpose if available (for execute commands) +pub fn display_purpose(purpose: Option<&String>, updates: &mut impl Write) -> Result<()> { + if let Some(purpose) = purpose { + queue!( + updates, + style::Print(super::CONTINUATION_LINE), + style::Print("\n"), + style::Print(super::PURPOSE_ARROW), + style::SetForegroundColor(Color::Blue), + style::Print("Purpose: "), + style::ResetColor, + style::Print(purpose), + style::Print("\n"), + )?; + } + Ok(()) +} + +/// Helper function to format function results with consistent styling +/// +/// # Parameters +/// * `result` - The result text to display +/// * `updates` - The output to write to +/// * `is_error` - Whether this is an error message (changes formatting) +/// * `use_bullet` - Whether to use a bullet point instead of a tick/exclamation +pub fn queue_function_result(result: &str, updates: &mut impl Write, is_error: bool, use_bullet: bool) -> Result<()> { + let lines = result.lines().collect::>(); + + // Determine symbol and color + let (symbol, color) = match (is_error, use_bullet) { + (true, _) => (super::ERROR_EXCLAMATION, Color::Red), + (false, true) => (super::TOOL_BULLET, Color::Reset), + (false, false) => (super::SUCCESS_TICK, Color::Green), + }; + + queue!(updates, style::Print("\n"))?; + + // Print first line with symbol + if let Some(first_line) = lines.first() { + queue!( + updates, + style::SetForegroundColor(color), + style::Print(symbol), + style::ResetColor, + style::Print(first_line), + style::Print("\n"), + )?; + } + + // Print remaining lines with indentation + for line in lines.iter().skip(1) { + queue!( + updates, + style::Print(" "), // 3 spaces for alignment + style::Print(line), + style::Print("\n"), + )?; + } + + Ok(()) +} + +/// Helper function to set up environment variables with user agent metadata for 
CloudTrail tracking +pub fn env_vars_with_user_agent(os: &Os) -> std::collections::HashMap { + let mut env_vars: std::collections::HashMap = std::env::vars().collect(); + + // Set up additional metadata for the AWS CLI user agent + let user_agent_metadata_value = format!( + "{} {}/{}", + USER_AGENT_APP_NAME, USER_AGENT_VERSION_KEY, USER_AGENT_VERSION_VALUE + ); + + // Check if the user agent metadata env var already exists using Os + let existing_value = os.env.get(USER_AGENT_ENV_VAR).ok(); + + // If the user agent metadata env var already exists, append to it, otherwise set it + if let Some(existing_value) = existing_value { + if !existing_value.is_empty() { + env_vars.insert( + USER_AGENT_ENV_VAR.to_string(), + format!("{} {}", existing_value, user_agent_metadata_value), + ); + } else { + env_vars.insert(USER_AGENT_ENV_VAR.to_string(), user_agent_metadata_value); + } + } else { + env_vars.insert(USER_AGENT_ENV_VAR.to_string(), user_agent_metadata_value); + } + + env_vars +} + +#[cfg(test)] +mod tests { + use std::path::MAIN_SEPARATOR; + + use super::*; + use crate::os::ACTIVE_USER_HOME; + + #[tokio::test] + async fn test_tilde_path_expansion() { + let os = Os::new().await.unwrap(); + + let actual = sanitize_path_tool_arg(&os, "~"); + let expected_home = os.env.home().unwrap_or_default(); + assert_eq!(actual, os.fs.chroot_path(&expected_home), "tilde should expand"); + let actual = sanitize_path_tool_arg(&os, "~/hello"); + assert_eq!( + actual, + os.fs.chroot_path(expected_home.join("hello")), + "tilde should expand" + ); + let actual = sanitize_path_tool_arg(&os, "/~"); + assert_eq!( + actual, + os.fs.chroot_path("/~"), + "tilde should not expand when not the first component" + ); + } + + #[tokio::test] + async fn test_format_path() { + async fn assert_paths(cwd: &str, path: &str, expected: &str) { + let os = Os::new().await.unwrap(); + let cwd = sanitize_path_tool_arg(&os, cwd); + let path = sanitize_path_tool_arg(&os, path); + let fs = os.fs; + 
fs.create_dir_all(&cwd).await.unwrap(); + fs.create_dir_all(&path).await.unwrap(); + + let formatted = format_path(&cwd, &path); + + if Path::new(expected).is_absolute() { + // If the expected path is relative, we need to ensure it is relative to the cwd. + let expected = fs.chroot_path_str(expected); + + assert!(formatted == expected, "Expected '{}' to be '{}'", formatted, expected); + + return; + } + + assert!( + formatted.contains(expected), + "Expected '{}' to be '{}'", + formatted, + expected + ); + } + + // Test relative path from src to Downloads (sibling directories) + assert_paths( + format!("{ACTIVE_USER_HOME}{MAIN_SEPARATOR}src").as_str(), + format!("{ACTIVE_USER_HOME}{MAIN_SEPARATOR}Downloads").as_str(), + format!("..{MAIN_SEPARATOR}Downloads").as_str(), + ) + .await; + + // Test absolute path that should stay absolute (going up too many levels) + assert_paths( + format!("{ACTIVE_USER_HOME}{MAIN_SEPARATOR}projects{MAIN_SEPARATOR}some{MAIN_SEPARATOR}project").as_str(), + format!("{ACTIVE_USER_HOME}{MAIN_SEPARATOR}other").as_str(), + format!("{ACTIVE_USER_HOME}{MAIN_SEPARATOR}other").as_str(), + ) + .await; + } +} + + +---- crates/chat-cli/src/cli/chat/tools/tool_index.json ---- +{ + "dummy": { + "name": "dummy", + "description": "This is a dummy tool. If you are seeing this that means the tool associated with this tool call is not in the list of available tools. This could be because a wrong tool name was supplied or the list of tools has changed since the conversation has started. 
Do not show this when user asks you to list tools.", + "input_schema": { + "type": "object", + "properties": {}, + "required": [] + } + }, + "execute_bash": { + "name": "execute_bash", + "description": "Execute the specified bash command.", + "input_schema": { + "type": "object", + "properties": { + "command": { + "type": "string", + "description": "Bash command to execute" + }, + "summary": { + "type": "string", + "description": "A brief explanation of what the command does" + } + }, + "required": [ + "command" + ] + } + }, + "fs_read": { + "name": "fs_read", + "description": "Tool for reading files, directories and images. Always provide an 'operations' array.\n\nFor single operation: provide array with one element.\nFor batch operations: provide array with multiple elements.\n\nAvailable modes:\n- Line: Read lines from a file\n- Directory: List directory contents\n- Search: Search for patterns in files\n- Image: Read and process images\n\nExamples:\n1. Single: {\"operations\": [{\"mode\": \"Line\", \"path\": \"/file.txt\"}]}\n2. Batch: {\"operations\": [{\"mode\": \"Line\", \"path\": \"/file1.txt\"}, {\"mode\": \"Search\", \"path\": \"/file2.txt\", \"pattern\": \"test\"}]}", + "input_schema": { + "type": "object", + "properties": { + "operations": { + "type": "array", + "description": "Array of operations to execute. Provide one element for single operation, multiple for batch.", + "items": { + "type": "object", + "properties": { + "mode": { + "type": "string", + "enum": [ + "Line", + "Directory", + "Search", + "Image" + ], + "description": "The operation mode to run in: `Line`, `Directory`, `Search`. `Line` and `Search` are only for text files, and `Directory` is only for directories. `Image` is for image files, in this mode `image_paths` is required." + }, + "path": { + "type": "string", + "description": "Path to the file or directory. The path should be absolute, or otherwise start with ~ for the user's home (required for Line, Directory, Search modes)." 
+ }, + "image_paths": { + "type": "array", + "items": { + "type": "string" + }, + "description": "List of paths to the images. This is currently supported by the Image mode." + }, + "start_line": { + "type": "integer", + "description": "Starting line number (optional, for Line mode). A negative index represents a line number starting from the end of the file.", + "default": 1 + }, + "end_line": { + "type": "integer", + "description": "Ending line number (optional, for Line mode). A negative index represents a line number starting from the end of the file.", + "default": -1 + }, + "pattern": { + "type": "string", + "description": "Pattern to search for (required, for Search mode). Case insensitive. The pattern matching is performed per line." + }, + "context_lines": { + "type": "integer", + "description": "Number of context lines around search results (optional, for Search mode)", + "default": 2 + }, + "depth": { + "type": "integer", + "description": "Depth of a recursive directory listing (optional, for Directory mode)", + "default": 0 + } + }, + "required": [ + "mode" + ] + }, + "minItems": 1 + }, + "summary": { + "type": "string", + "description": "Optional description of the purpose of this batch operation (mainly useful for multiple operations)" + } + }, + "required": [ + "operations" + ] + } + }, + "fs_write": { + "name": "fs_write", + "description": "A tool for creating and editing files\n * The `create` command will override the file at `path` if it already exists as a file, and otherwise create a new file\n * The `append` command will add content to the end of an existing file, automatically adding a newline if the file doesn't end with one. The file must exist.\n Notes for using the `str_replace` command:\n * The `old_str` parameter should match EXACTLY one or more consecutive lines from the original file. Be mindful of whitespaces!\n * If the `old_str` parameter is not unique in the file, the replacement will not be performed. 
Make sure to include enough context in `old_str` to make it unique\n * The `new_str` parameter should contain the edited lines that should replace the `old_str`.", + "input_schema": { + "type": "object", + "properties": { + "command": { + "type": "string", + "enum": [ + "create", + "str_replace", + "insert", + "append" + ], + "description": "The commands to run. Allowed options are: `create`, `str_replace`, `insert`, `append`." + }, + "file_text": { + "description": "Required parameter of `create` command, with the content of the file to be created.", + "type": "string" + }, + "insert_line": { + "description": "Required parameter of `insert` command. The `new_str` will be inserted AFTER the line `insert_line` of `path`.", + "type": "integer" + }, + "new_str": { + "description": "Required parameter of `str_replace` command containing the new string. Required parameter of `insert` command containing the string to insert. Required parameter of `append` command containing the content to append to the file.", + "type": "string" + }, + "old_str": { + "description": "Required parameter of `str_replace` command containing the string in `path` to replace.", + "type": "string" + }, + "path": { + "description": "Absolute path to file or directory, e.g. `/repo/file.py` or `/repo`.", + "type": "string" + }, + "summary": { + "description": "A brief explanation of what the file change does or why it's being made.", + "type": "string" + } + }, + "required": [ + "command", + "path" + ] + } + }, + "use_aws": { + "name": "use_aws", + "description": "Make an AWS CLI api call with the specified service, operation, and parameters. All arguments MUST conform to the AWS CLI specification. Should the output of the invocation indicate a malformed command, invoke help to obtain the the correct command.", + "input_schema": { + "type": "object", + "properties": { + "service_name": { + "type": "string", + "description": "The name of the AWS service. 
If you want to query s3, you should use s3api if possible." + }, + "operation_name": { + "type": "string", + "description": "The name of the operation to perform." + }, + "parameters": { + "type": "object", + "description": "The parameters for the operation. The parameter keys MUST conform to the AWS CLI specification. You should prefer to use JSON Syntax over shorthand syntax wherever possible. For parameters that are booleans, prioritize using flags with no value. Denote these flags with flag names as key and an empty string as their value. You should also prefer kebab case." + }, + "region": { + "type": "string", + "description": "Region name for calling the operation on AWS." + }, + "profile_name": { + "type": "string", + "description": "Optional: AWS profile name to use from ~/.aws/credentials. Defaults to default profile if not specified." + }, + "label": { + "type": "string", + "description": "Human readable description of the api that is being called." + } + }, + "required": [ + "region", + "service_name", + "operation_name", + "label" + ] + } + }, + "gh_issue": { + "name": "report_issue", + "description": "Opens the browser to a pre-filled gh (GitHub) issue template to report chat issues, bugs, or feature requests. Pre-filled information includes the conversation transcript, chat context, and chat request IDs from the service.", + "input_schema": { + "type": "object", + "properties": { + "title": { + "type": "string", + "description": "The title of the GitHub issue." + }, + "expected_behavior": { + "type": "string", + "description": "Optional: The expected chat behavior or action that did not happen." + }, + "actual_behavior": { + "type": "string", + "description": "Optional: The actual chat behavior that happened and demonstrates the issue or lack of a feature." + }, + "steps_to_reproduce": { + "type": "string", + "description": "Optional: Previous user chat requests or steps that were taken that may have resulted in the issue or error response." 
+ } + }, + "required": [ + "title" + ] + } + }, + "thinking": { + "name": "thinking", + "description": "Thinking is an internal reasoning mechanism improving the quality of complex tasks by breaking their atomic actions down; use it specifically for multi-step problems requiring step-by-step dependencies, reasoning through multiple constraints, synthesizing results from previous tool calls, planning intricate sequences of actions, troubleshooting complex errors, or making decisions involving multiple trade-offs. Avoid using it for straightforward tasks, basic information retrieval, summaries, always clearly define the reasoning challenge, structure thoughts explicitly, consider multiple perspectives, and summarize key insights before important decisions or complex tool interactions.", + "input_schema": { + "type": "object", + "properties": { + "thought": { + "type": "string", + "description": "A reflective note or intermediate reasoning step such as \"The user needs to prepare their application for production. I need to complete three major asks including 1: building their code from source, 2: bundling their release artifacts together, and 3: signing the application bundle." + } + }, + "required": [ + "thought" + ] + } + }, + "knowledge": { + "name": "knowledge", + "description": "Store and retrieve information in knowledge base across chat sessions. 
Provides semantic search capabilities for files, directories, and text content.", + "input_schema": { + "type": "object", + "properties": { + "command": { + "type": "string", + "enum": [ + "show", + "add", + "remove", + "clear", + "search", + "update", + "status", + "cancel" + ], + "description": "The knowledge operation to perform:\n- 'show': List all knowledge contexts (no additional parameters required)\n- 'add': Add content to knowledge base (requires 'name' and 'value')\n- 'remove': Remove content from knowledge base (requires one of: 'name', 'context_id', or 'path')\n- 'clear': Remove all knowledge contexts.\n- 'search': Search across knowledge contexts (requires 'query', optional 'context_id')\n- 'update': Update existing context with new content (requires 'path' and one of: 'name', 'context_id')\n- 'status': Show background operation status and progress\n- 'cancel': Cancel background operations (optional 'operation_id' to cancel specific operation, or cancel all if not provided)" + }, + "name": { + "type": "string", + "description": "A descriptive name for the knowledge context. Required for 'add' operations. Can be used for 'remove' and 'update' operations to identify the context." + }, + "value": { + "type": "string", + "description": "The content to store in knowledge base. Required for 'add' operations. Can be either text content or a file/directory path. If it's a valid file or directory path, the content will be indexed; otherwise it's treated as text." + }, + "context_id": { + "type": "string", + "description": "The unique context identifier for targeted operations. Can be obtained from 'show' command. Used for 'remove', 'update', and 'search' operations to specify which context to operate on." + }, + "path": { + "type": "string", + "description": "File or directory path. Used in 'remove' operations to remove contexts by their source path, and required for 'update' operations to specify the new content location." 
+ }, + "query": { + "type": "string", + "description": "The search query string. Required for 'search' operations. Performs semantic search across knowledge contexts to find relevant content." + }, + "operation_id": { + "type": "string", + "description": "Optional operation ID to cancel a specific operation. Used with 'cancel' command. If not provided, all active operations will be cancelled. Can be either the full operation ID or the short 8-character ID." + } + }, + "required": [ + "command" + ] + } + } +} + +---- crates/chat-cli/src/cli/mod.rs ---- +mod agent; +mod chat; +mod debug; +mod diagnostics; +mod feed; +mod issue; +mod mcp; +mod settings; +mod user; + +use std::fmt::Display; +use std::io::{ + Write as _, + stdout, +}; +use std::process::ExitCode; + +use agent::AgentArgs; +use anstream::println; +pub use chat::ConversationState; +use clap::{ + ArgAction, + CommandFactory, + Parser, + Subcommand, + ValueEnum, +}; +use crossterm::style::Stylize; +use eyre::{ + Result, + bail, +}; +use feed::Feed; +use serde::Serialize; +use tracing::{ + Level, + debug, +}; + +use crate::cli::chat::ChatArgs; +use crate::cli::mcp::McpSubcommand; +use crate::cli::user::{ + LoginArgs, + WhoamiArgs, +}; +use crate::logging::{ + LogArgs, + initialize_logging, +}; +use crate::os::Os; +use crate::util::directories::logs_dir; +use crate::util::{ + CLI_BINARY_NAME, + GOV_REGIONS, +}; + +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)] +pub enum OutputFormat { + /// Outputs the results as markdown + #[default] + Plain, + /// Outputs the results as JSON + Json, + /// Outputs the results as pretty print JSON + JsonPretty, +} + +impl OutputFormat { + pub fn print(&self, text_fn: TFn, json_fn: JFn) + where + T: std::fmt::Display, + TFn: FnOnce() -> T, + J: Serialize, + JFn: FnOnce() -> J, + { + match self { + OutputFormat::Plain => println!("{}", text_fn()), + OutputFormat::Json => println!("{}", serde_json::to_string(&json_fn()).unwrap()), + OutputFormat::JsonPretty => 
println!("{}", serde_json::to_string_pretty(&json_fn()).unwrap()), + } + } +} + +/// The Amazon Q CLI +#[deny(missing_docs)] +#[derive(Debug, PartialEq, Subcommand)] +pub enum RootSubcommand { + /// Manage agents + Agent(AgentArgs), + /// AI assistant in your terminal + Chat(ChatArgs), + /// Log in to Amazon Q + Login(LoginArgs), + /// Log out of Amazon Q + Logout, + /// Print info about the current login session + Whoami(WhoamiArgs), + /// Show the profile associated with this idc user + Profile, + /// Customize appearance & behavior + #[command(alias("setting"))] + Settings(settings::SettingsArgs), + /// Run diagnostic tests + #[command(alias("diagnostics"))] + Diagnostic(diagnostics::DiagnosticArgs), + /// Create a new Github issue + Issue(issue::IssueArgs), + /// Version + #[command(hide = true)] + Version { + /// Show the changelog (use --changelog=all for all versions, or --changelog=x.x.x for a + /// specific version) + #[arg(long, num_args = 0..=1, default_missing_value = "")] + changelog: Option, + }, + /// Model Context Protocol (MCP) + #[command(subcommand)] + Mcp(McpSubcommand), +} + +impl RootSubcommand { + /// Whether the command should have an associated telemetry event. + /// + /// Emitting telemetry takes a long time so the answer is usually no. + pub fn valid_for_telemetry(&self) -> bool { + matches!(self, Self::Chat(_) | Self::Login(_) | Self::Profile | Self::Issue(_)) + } + + pub fn requires_auth(&self) -> bool { + matches!(self, Self::Chat(_) | Self::Profile) + } + + pub async fn execute(self, os: &mut Os) -> Result { + // Check for auth on subcommands that require it. + if self.requires_auth() && !crate::auth::is_logged_in(&mut os.database).await { + bail!( + "You are not logged in, please log in with {}", + format!("{CLI_BINARY_NAME} login").bold() + ); + } + + // Send executed telemetry. 
+ if self.valid_for_telemetry() { + os.telemetry + .send_cli_subcommand_executed(&os.database, &self) + .await + .ok(); + } + + match self { + Self::Agent(args) => args.execute(os).await, + Self::Diagnostic(args) => args.execute(os).await, + Self::Login(args) => args.execute(os).await, + Self::Logout => user::logout(os).await, + Self::Whoami(args) => args.execute(os).await, + Self::Profile => user::profile(os).await, + Self::Settings(settings_args) => settings_args.execute(os).await, + Self::Issue(args) => args.execute(os).await, + Self::Version { changelog } => Cli::print_version(changelog), + Self::Chat(args) => args.execute(os).await, + Self::Mcp(args) => args.execute(os, &mut std::io::stderr()).await, + } + } +} + +impl Default for RootSubcommand { + fn default() -> Self { + Self::Chat(ChatArgs::default()) + } +} + +impl Display for RootSubcommand { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let name = match self { + Self::Agent(_) => "agent", + Self::Chat(_) => "chat", + Self::Login(_) => "login", + Self::Logout => "logout", + Self::Whoami(_) => "whoami", + Self::Profile => "profile", + Self::Settings(_) => "settings", + Self::Diagnostic(_) => "diagnostic", + Self::Issue(_) => "issue", + Self::Version { .. } => "version", + Self::Mcp(_) => "mcp", + }; + + write!(f, "{name}") + } +} + +#[derive(Debug, Parser, PartialEq, Default)] +#[command(version, about, name = crate::util::CHAT_BINARY_NAME)] +pub struct Cli { + #[command(subcommand)] + pub subcommand: Option, + /// Increase logging verbosity + #[arg(long, short = 'v', action = ArgAction::Count, global = true)] + pub verbose: u8, + /// Print help for all subcommands + #[arg(long)] + help_all: bool, +} + +impl Cli { + pub async fn execute(self) -> Result { + let subcommand = self.subcommand.unwrap_or_default(); + + // Initialize our logger and keep around the guard so logging can perform as expected. 
+ let _log_guard = initialize_logging(LogArgs { + log_level: match self.verbose > 0 { + true => Some( + match self.verbose { + 1 => Level::WARN, + 2 => Level::INFO, + 3 => Level::DEBUG, + _ => Level::TRACE, + } + .to_string(), + ), + false => None, + }, + log_to_stdout: std::env::var_os("Q_LOG_STDOUT").is_some() || self.verbose > 0, + log_file_path: match subcommand { + RootSubcommand::Chat { .. } => Some(logs_dir().expect("home dir must be set").join("qchat.log")), + _ => None, + }, + delete_old_log_file: false, + }); + + // Check for region support. + if let Ok(region) = std::env::var("AWS_REGION") { + if GOV_REGIONS.contains(®ion.as_str()) { + bail!("AWS GovCloud ({region}) is not supported.") + } + } + + debug!(command =? std::env::args().collect::>(), "Command being ran"); + + let mut os = Os::new().await?; + let result = subcommand.execute(&mut os).await; + + let telemetry_result = os.telemetry.finish().await; + let exit_code = result?; + telemetry_result?; + + Ok(exit_code) + } + + fn print_changelog_entry(entry: &feed::Entry) -> Result<()> { + println!("Version {} ({})", entry.version, entry.date); + + if entry.changes.is_empty() { + println!(" No changes recorded for this version."); + } else { + for change in &entry.changes { + let type_label = match change.change_type.as_str() { + "added" => "Added", + "fixed" => "Fixed", + "changed" => "Changed", + other => other, + }; + + println!(" - {}: {}", type_label, change.description); + } + } + + println!(); + Ok(()) + } + + fn print_version(changelog: Option) -> Result { + // If no changelog is requested, display normal version information + if changelog.is_none() { + let _ = writeln!(stdout(), "{}", Self::command().render_version()); + return Ok(ExitCode::SUCCESS); + } + + let changelog_value = changelog.unwrap_or_default(); + let feed = Feed::load(); + + // Display changelog for all versions + if changelog_value == "all" { + let entries = feed.get_all_changelogs(); + if entries.is_empty() { + println!("No 
changelog information available."); + } else { + println!("Changelog for all versions:"); + for entry in entries { + Self::print_changelog_entry(&entry)?; + } + } + return Ok(ExitCode::SUCCESS); + } + + // Display changelog for a specific version (--changelog=x.x.x) + if !changelog_value.is_empty() { + match feed.get_version_changelog(&changelog_value) { + Some(entry) => { + println!("Changelog for version {}:", changelog_value); + Self::print_changelog_entry(&entry)?; + return Ok(ExitCode::SUCCESS); + }, + None => { + println!("No changelog information available for version {}.", changelog_value); + return Ok(ExitCode::SUCCESS); + }, + } + } + + // Display changelog for the current version (--changelog only) + let current_version = env!("CARGO_PKG_VERSION"); + match feed.get_version_changelog(current_version) { + Some(entry) => { + println!("Changelog for version {}:", current_version); + Self::print_changelog_entry(&entry)?; + }, + None => { + println!("No changelog information available for version {}.", current_version); + }, + } + + Ok(ExitCode::SUCCESS) + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::util::CHAT_BINARY_NAME; + use crate::util::test::assert_parse; + + #[test] + fn debug_assert() { + Cli::command().debug_assert(); + } + + /// Test flag parsing for the top level [Cli] + #[test] + fn test_flags() { + assert_eq!(Cli::parse_from([CHAT_BINARY_NAME, "-v"]), Cli { + subcommand: None, + verbose: 1, + help_all: false, + }); + + assert_eq!(Cli::parse_from([CHAT_BINARY_NAME, "-vvv"]), Cli { + subcommand: None, + verbose: 3, + help_all: false, + }); + + assert_eq!(Cli::parse_from([CHAT_BINARY_NAME, "--help-all"]), Cli { + subcommand: None, + verbose: 0, + help_all: true, + }); + + assert_eq!(Cli::parse_from([CHAT_BINARY_NAME, "chat", "-vv"]), Cli { + subcommand: Some(RootSubcommand::Chat(ChatArgs { + resume: false, + input: None, + agent: None, + model: None, + trust_all_tools: false, + trust_tools: None, + no_interactive: false, + })), + 
verbose: 2, + help_all: false, + }); + } + + #[test] + fn test_version_changelog() { + assert_parse!(["version", "--changelog"], RootSubcommand::Version { + changelog: Some("".to_string()), + }); + } + + #[test] + fn test_version_changelog_all() { + assert_parse!(["version", "--changelog=all"], RootSubcommand::Version { + changelog: Some("all".to_string()), + }); + } + + #[test] + fn test_version_changelog_specific() { + assert_parse!(["version", "--changelog=1.8.0"], RootSubcommand::Version { + changelog: Some("1.8.0".to_string()), + }); + } + + #[test] + fn test_chat_with_context_profile() { + assert_parse!( + ["chat", "--profile", "my-profile"], + RootSubcommand::Chat(ChatArgs { + resume: false, + input: None, + agent: Some("my-profile".to_string()), + model: None, + trust_all_tools: false, + trust_tools: None, + no_interactive: false, + }) + ); + } + + #[test] + fn test_chat_with_context_profile_and_input() { + assert_parse!( + ["chat", "--profile", "my-profile", "Hello"], + RootSubcommand::Chat(ChatArgs { + resume: false, + input: Some("Hello".to_string()), + agent: Some("my-profile".to_string()), + model: None, + trust_all_tools: false, + trust_tools: None, + no_interactive: false, + }) + ); + } + + #[test] + fn test_chat_with_context_profile_and_accept_all() { + assert_parse!( + ["chat", "--profile", "my-profile", "--trust-all-tools"], + RootSubcommand::Chat(ChatArgs { + resume: false, + input: None, + agent: Some("my-profile".to_string()), + model: None, + trust_all_tools: true, + trust_tools: None, + no_interactive: false, + }) + ); + } + + #[test] + fn test_chat_with_no_interactive_and_resume() { + assert_parse!( + ["chat", "--no-interactive", "--resume"], + RootSubcommand::Chat(ChatArgs { + resume: true, + input: None, + agent: None, + model: None, + trust_all_tools: false, + trust_tools: None, + no_interactive: true, + }) + ); + assert_parse!( + ["chat", "--non-interactive", "-r"], + RootSubcommand::Chat(ChatArgs { + resume: true, + input: None, + 
agent: None, + model: None, + trust_all_tools: false, + trust_tools: None, + no_interactive: true, + }) + ); + } + + #[test] + fn test_chat_with_tool_trust_all() { + assert_parse!( + ["chat", "--trust-all-tools"], + RootSubcommand::Chat(ChatArgs { + resume: false, + input: None, + agent: None, + model: None, + trust_all_tools: true, + trust_tools: None, + no_interactive: false, + }) + ); + } + + #[test] + fn test_chat_with_tool_trust_none() { + assert_parse!( + ["chat", "--trust-tools="], + RootSubcommand::Chat(ChatArgs { + resume: false, + input: None, + agent: None, + model: None, + trust_all_tools: false, + trust_tools: Some(vec!["".to_string()]), + no_interactive: false, + }) + ); + } + + #[test] + fn test_chat_with_tool_trust_some() { + assert_parse!( + ["chat", "--trust-tools=fs_read,fs_write"], + RootSubcommand::Chat(ChatArgs { + resume: false, + input: None, + agent: None, + model: None, + trust_all_tools: false, + trust_tools: Some(vec!["fs_read".to_string(), "fs_write".to_string()]), + no_interactive: false, + }) + ); + } +} + +