diff --git a/crates/chat-cli/Cargo.toml b/crates/chat-cli/Cargo.toml index 621f6abf48..10ed45d145 100644 --- a/crates/chat-cli/Cargo.toml +++ b/crates/chat-cli/Cargo.toml @@ -6,6 +6,7 @@ homepage.workspace = true publish.workspace = true version.workspace = true license.workspace = true +default-run = "chat_cli" [lints] workspace = true diff --git a/crates/chat-cli/src/api_client/clients/client.rs b/crates/chat-cli/src/api_client/clients/client.rs index 2d29adb674..c5ab2ddc02 100644 --- a/crates/chat-cli/src/api_client/clients/client.rs +++ b/crates/chat-cli/src/api_client/clients/client.rs @@ -8,16 +8,20 @@ use tracing::error; use super::shared::bearer_sdk_config; use crate::api_client::interceptor::opt_out::OptOutInterceptor; -use crate::api_client::profile::Profile; use crate::api_client::{ ApiClientError, Endpoint, }; +use crate::auth::AuthError; use crate::auth::builder_id::BearerResolver; use crate::aws_common::{ UserAgentOverrideInterceptor, app_name, }; +use crate::database::{ + AuthProfile, + Database, +}; mod inner { use amzn_codewhisperer_client::Client as CodewhispererClient; @@ -32,67 +36,48 @@ mod inner { #[derive(Clone, Debug)] pub struct Client { inner: inner::Inner, - profile_arn: Option, + profile: Option, } impl Client { - pub async fn new() -> Result { + pub async fn new(database: &mut Database, endpoint: Option) -> Result { if cfg!(test) { return Ok(Self { inner: inner::Inner::Mock, - profile_arn: None, + profile: None, }); } - let endpoint = Endpoint::load_codewhisperer(); - Ok(Self::new_codewhisperer_client(&endpoint).await) - } - - pub async fn new_codewhisperer_client(endpoint: &Endpoint) -> Self { - let conf_builder: amzn_codewhisperer_client::config::Builder = (&bearer_sdk_config(endpoint).await).into(); + let endpoint = endpoint.unwrap_or(Endpoint::load_codewhisperer(database)); + let conf_builder: amzn_codewhisperer_client::config::Builder = + (&bearer_sdk_config(database, &endpoint).await).into(); let conf = conf_builder .http_client(crate::aws_common::http_client::client()) - .interceptor(OptOutInterceptor::new()) + .interceptor(OptOutInterceptor::new(database)) .interceptor(UserAgentOverrideInterceptor::new()) - .bearer_token_resolver(BearerResolver) + .bearer_token_resolver(BearerResolver::new(database).await?) .app_name(app_name()) .endpoint_url(endpoint.url()) .build(); let inner = inner::Inner::Codewhisperer(CodewhispererClient::from_conf(conf)); - let profile_arn = match crate::settings::state::get_value("api.codewhisperer.profile") { - Ok(Some(profile)) => match profile.get("arn") { - Some(arn) => match arn.as_str() { - Some(arn) => Some(arn.to_string()), - None => { - error!("Stored arn is not a string. Instead it was: {arn}"); - None - }, - }, - None => { - error!("Stored profile does not contain an arn. Instead it was: {profile}"); - None - }, - }, - Ok(None) => None, + let profile = match database.get_auth_profile() { + Ok(profile) => profile, Err(err) => { - error!("Failed to retrieve profile: {}", err); + error!("Failed to get auth profile: {err}"); None }, }; - Self { inner, profile_arn } + Ok(Self { inner, profile }) } - // .telemetry_event(TelemetryEvent::UserTriggerDecisionEvent(user_trigger_decision_event)) - // .user_context(user_context) - // .opt_out_preference(opt_out_preference) pub async fn send_telemetry_event( &self, telemetry_event: TelemetryEvent, user_context: UserContext, - opt_out: OptOutPreference, + telemetry_enabled: bool, ) -> Result<(), ApiClientError> { match &self.inner { inner::Inner::Codewhisperer(client) => { @@ -100,8 +85,11 @@ impl Client { .send_telemetry_event() .telemetry_event(telemetry_event) .user_context(user_context) - .opt_out_preference(opt_out) - .set_profile_arn(self.profile_arn.clone()) + .opt_out_preference(match telemetry_enabled { + true => OptOutPreference::OptIn, + false => OptOutPreference::OptOut, + }) + .set_profile_arn(self.profile.as_ref().map(|p| p.arn.clone())) .send() .await; Ok(()) @@ -110,23 +98,23 @@ impl Client { } } - pub async fn list_available_profiles(&self) -> Result, ApiClientError> { + pub async fn list_available_profiles(&self) -> Result, ApiClientError> { match &self.inner { inner::Inner::Codewhisperer(client) => { let mut profiles = vec![]; let mut client = client.list_available_profiles().into_paginator().send(); while let Some(profiles_output) = client.next().await { - profiles.extend(profiles_output?.profiles().iter().cloned().map(Profile::from)); + profiles.extend(profiles_output?.profiles().iter().cloned().map(AuthProfile::from)); } Ok(profiles) }, inner::Inner::Mock => Ok(vec![ - Profile { + AuthProfile { arn: "my:arn:1".to_owned(), profile_name: "MyProfile".to_owned(), }, - Profile { + AuthProfile { arn: "my:arn:2".to_owned(), profile_name: "MyOtherProfile".to_owned(), }, @@ -147,15 +135,14 @@ mod tests { #[tokio::test] async fn create_clients() { - let endpoint = Endpoint::load_codewhisperer(); - - let _ = Client::new().await; - let _ = Client::new_codewhisperer_client(&endpoint).await; + let mut database = crate::database::Database::new().await.unwrap(); + let _ = Client::new(&mut database, None).await; } #[tokio::test] async fn test_mock() { - let client = Client::new().await.unwrap(); + let mut database = crate::database::Database::new().await.unwrap(); + let client = Client::new(&mut database, None).await.unwrap(); client .send_telemetry_event( TelemetryEvent::ChatAddMessageEvent( @@ -171,7 +158,7 @@ mod tests { .product("") .build() .unwrap(), - OptOutPreference::OptIn, + false, ) .await .unwrap(); diff --git a/crates/chat-cli/src/api_client/clients/shared.rs b/crates/chat-cli/src/api_client/clients/shared.rs index e54a276130..9e985caa36 100644 --- a/crates/chat-cli/src/api_client/clients/shared.rs +++ b/crates/chat-cli/src/api_client/clients/shared.rs @@ -14,14 +14,16 @@ use crate::api_client::{ Endpoint, }; use crate::aws_common::behavior_version; +use crate::database::Database; +use crate::database::settings::Setting; // TODO(bskiser): confirm timeout is updated to an appropriate value? const DEFAULT_TIMEOUT_DURATION: Duration = Duration::from_secs(60 * 5); -pub(crate) fn timeout_config() -> TimeoutConfig { - let timeout = crate::settings::settings::get_int("api.timeout") - .ok() - .flatten() +pub fn timeout_config(database: &Database) -> TimeoutConfig { + let timeout = database + .settings + .get_int(Setting::ApiTimeout) .and_then(|i| i.try_into().ok()) .map_or(DEFAULT_TIMEOUT_DURATION, Duration::from_millis); @@ -39,27 +41,31 @@ pub(crate) fn stalled_stream_protection_config() -> StalledStreamProtectionConfi .build() } -async fn base_sdk_config(region: Region, credentials_provider: impl ProvideCredentials + 'static) -> SdkConfig { +async fn base_sdk_config( + database: &Database, + region: Region, + credentials_provider: impl ProvideCredentials + 'static, +) -> SdkConfig { aws_config::defaults(behavior_version()) .region(region) .credentials_provider(credentials_provider) - .timeout_config(timeout_config()) + .timeout_config(timeout_config(database)) .retry_config(RetryConfig::adaptive()) .load() .await } -pub(crate) async fn bearer_sdk_config(endpoint: &Endpoint) -> SdkConfig { +pub async fn bearer_sdk_config(database: &Database, endpoint: &Endpoint) -> SdkConfig { let credentials = Credentials::new("xxx", "xxx", None, None, "xxx"); - base_sdk_config(endpoint.region().clone(), credentials).await + base_sdk_config(database, endpoint.region().clone(), credentials).await } -pub(crate) async fn sigv4_sdk_config(endpoint: &Endpoint) -> Result { +pub async fn sigv4_sdk_config(database: &Database, endpoint: &Endpoint) -> Result { let credentials_chain = CredentialsChain::new().await; if let Err(err) = credentials_chain.provide_credentials().await { return Err(ApiClientError::Credentials(err)); }; - Ok(base_sdk_config(endpoint.region().clone(), credentials_chain).await) + Ok(base_sdk_config(database, endpoint.region().clone(), credentials_chain).await) } diff --git a/crates/chat-cli/src/api_client/clients/streaming_client.rs b/crates/chat-cli/src/api_client/clients/streaming_client.rs index df7064f732..3b5c377a49 100644 --- a/crates/chat-cli/src/api_client/clients/streaming_client.rs +++ b/crates/chat-cli/src/api_client/clients/streaming_client.rs @@ -30,6 +30,10 @@ use crate::aws_common::{ UserAgentOverrideInterceptor, app_name, }; +use crate::database::{ + AuthProfile, + Database, +}; mod inner { use std::sync::{ @@ -53,72 +57,63 @@ mod inner { #[derive(Clone, Debug)] pub struct StreamingClient { inner: inner::Inner, - profile_arn: Option, + profile: Option, } impl StreamingClient { - pub async fn new() -> Result { - let client = if crate::util::system_info::in_cloudshell() - || std::env::var("Q_USE_SENDMESSAGE").is_ok_and(|v| !v.is_empty()) - { - Self::new_qdeveloper_client(&Endpoint::load_q()).await? - } else { - Self::new_codewhisperer_client(&Endpoint::load_codewhisperer()).await - }; - Ok(client) + pub async fn new(database: &mut Database) -> Result { + Ok( + if crate::util::system_info::in_cloudshell() + || std::env::var("Q_USE_SENDMESSAGE").is_ok_and(|v| !v.is_empty()) + { + Self::new_qdeveloper_client(database, &Endpoint::load_q(database)).await? + } else { + Self::new_codewhisperer_client(database, &Endpoint::load_codewhisperer(database)).await? + }, + ) } pub fn mock(events: Vec>) -> Self { Self { inner: inner::Inner::Mock(Arc::new(Mutex::new(events.into_iter()))), - profile_arn: None, + profile: None, } } - pub async fn new_codewhisperer_client(endpoint: &Endpoint) -> Self { + pub async fn new_codewhisperer_client( + database: &mut Database, + endpoint: &Endpoint, + ) -> Result { let conf_builder: amzn_codewhisperer_streaming_client::config::Builder = - (&bearer_sdk_config(endpoint).await).into(); + (&bearer_sdk_config(database, endpoint).await).into(); let conf = conf_builder .http_client(crate::aws_common::http_client::client()) - .interceptor(OptOutInterceptor::new()) + .interceptor(OptOutInterceptor::new(database)) .interceptor(UserAgentOverrideInterceptor::new()) - .bearer_token_resolver(BearerResolver) + .bearer_token_resolver(BearerResolver::new(database).await?) .app_name(app_name()) .endpoint_url(endpoint.url()) .stalled_stream_protection(stalled_stream_protection_config()) .build(); let inner = inner::Inner::Codewhisperer(CodewhispererStreamingClient::from_conf(conf)); - let profile_arn = match crate::settings::state::get_value("api.codewhisperer.profile") { - Ok(Some(profile)) => match profile.get("arn") { - Some(arn) => match arn.as_str() { - Some(arn) => Some(arn.to_string()), - None => { - error!("Stored arn is not a string. Instead it was: {arn}"); - None - }, - }, - None => { - error!("Stored profile does not contain an arn. Instead it was: {profile}"); - None - }, - }, - Ok(None) => None, + let profile = match database.get_auth_profile() { + Ok(profile) => profile, Err(err) => { - error!("Failed to retrieve profile: {}", err); + error!("Failed to get auth profile: {err}"); None }, }; - Self { inner, profile_arn } + Ok(Self { inner, profile }) } - pub async fn new_qdeveloper_client(endpoint: &Endpoint) -> Result { + pub async fn new_qdeveloper_client(database: &Database, endpoint: &Endpoint) -> Result { let conf_builder: amzn_qdeveloper_streaming_client::config::Builder = - (&sigv4_sdk_config(endpoint).await?).into(); + (&sigv4_sdk_config(database, endpoint).await?).into(); let conf = conf_builder .http_client(crate::aws_common::http_client::client()) - .interceptor(OptOutInterceptor::new()) + .interceptor(OptOutInterceptor::new(database)) .interceptor(UserAgentOverrideInterceptor::new()) .app_name(app_name()) .endpoint_url(endpoint.url()) @@ -127,7 +122,7 @@ impl StreamingClient { let client = QDeveloperStreamingClient::from_conf(conf); Ok(Self { inner: inner::Inner::QDeveloper(client), - profile_arn: None, + profile: None, }) } @@ -162,7 +157,7 @@ impl StreamingClient { let response = client .generate_assistant_response() .conversation_state(conversation_state) - .set_profile_arn(self.profile_arn.clone()) + .set_profile_arn(self.profile.as_ref().map(|p| p.arn.clone())) .send() .await; @@ -267,11 +262,12 @@ mod tests { #[tokio::test] async fn create_clients() { - let endpoint = Endpoint::load_codewhisperer(); + let mut database = Database::new().await.unwrap(); + let endpoint = Endpoint::load_codewhisperer(&database); - let _ = StreamingClient::new().await; - let _ = StreamingClient::new_codewhisperer_client(&endpoint).await; - let _ = StreamingClient::new_qdeveloper_client(&endpoint).await; + let _ = StreamingClient::new(&mut database).await; + let _ = StreamingClient::new_codewhisperer_client(&mut database, &endpoint).await; + let _ = StreamingClient::new_qdeveloper_client(&database, &endpoint).await; } #[tokio::test] @@ -311,7 +307,8 @@ mod tests { #[ignore] #[tokio::test] async fn assistant_response() { - let client = StreamingClient::new().await.unwrap(); + let mut database = Database::new().await.unwrap(); + let client = StreamingClient::new(&mut database).await.unwrap(); let mut response = client .send_message(ConversationState { conversation_id: None, diff --git a/crates/chat-cli/src/api_client/consts.rs b/crates/chat-cli/src/api_client/consts.rs index 3ba9b17013..c1ebc45d90 100644 --- a/crates/chat-cli/src/api_client/consts.rs +++ b/crates/chat-cli/src/api_client/consts.rs @@ -12,5 +12,4 @@ pub const PROD_CODEWHISPERER_FRA_ENDPOINT_URL: &str = "https://q.eu-central-1.am pub const PROD_CODEWHISPERER_FRA_ENDPOINT_REGION: Region = Region::from_static("eu-central-1"); // Opt out constants -pub const SHARE_CODEWHISPERER_CONTENT_SETTINGS_KEY: &str = "codeWhisperer.shareCodeWhispererContentWithAWS"; pub const X_AMZN_CODEWHISPERER_OPT_OUT_HEADER: &str = "x-amzn-codewhisperer-optout"; diff --git a/crates/chat-cli/src/api_client/customization.rs b/crates/chat-cli/src/api_client/customization.rs index 00df3e781d..98d22c8f67 100644 --- a/crates/chat-cli/src/api_client/customization.rs +++ b/crates/chat-cli/src/api_client/customization.rs @@ -5,10 +5,6 @@ use serde::{ Serialize, }; -use crate::settings::State; - -const CUSTOMIZATION_STATE_KEY: &str = "api.selectedCustomization"; - #[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] #[serde(rename_all = "camelCase")] pub struct Customization { @@ -19,23 +15,6 @@ pub struct Customization { pub description: Option, } -impl Customization { - /// Load the currently selected customization from state - pub fn load_selected(state: &State) -> Result, crate::settings::SettingsError> { - state.get(CUSTOMIZATION_STATE_KEY) - } - - /// Save the currently selected customization to state - pub fn save_selected(&self, state: &State) -> Result<(), crate::settings::SettingsError> { - state.set_value(CUSTOMIZATION_STATE_KEY, serde_json::to_value(self)?) - } - - /// Delete the currently selected customization from state - pub fn delete_selected(state: &State) -> Result<(), crate::settings::SettingsError> { - state.remove_value(CUSTOMIZATION_STATE_KEY) - } -} - impl From for CodewhispererCustomization { fn from(Customization { arn, name, description }: Customization) -> Self { CodewhispererCustomization::builder() @@ -115,23 +94,6 @@ mod tests { assert_eq!(custom_from_consolas.description, Some("description".into())); } - #[test] - fn test_customization_save_load() { - let state = State::new(); - - let value = Customization { - arn: "arn".into(), - name: Some("name".into()), - description: Some("description".into()), - }; - - value.save_selected(&state).unwrap(); - let loaded_value = Customization::load_selected(&state).unwrap(); - assert_eq!(loaded_value, Some(value)); - - Customization::delete_selected(&state).unwrap(); - } - #[test] fn test_customization_serde() { let customization = Customization { diff --git a/crates/chat-cli/src/api_client/endpoints.rs b/crates/chat-cli/src/api_client/endpoints.rs index 21688984c3..38c6152266 100644 --- a/crates/chat-cli/src/api_client/endpoints.rs +++ b/crates/chat-cli/src/api_client/endpoints.rs @@ -12,6 +12,8 @@ use crate::api_client::consts::{ PROD_Q_ENDPOINT_REGION, PROD_Q_ENDPOINT_URL, }; +use crate::database::Database; +use crate::database::settings::Setting; #[derive(Debug, Clone, PartialEq, Eq)] pub struct Endpoint { @@ -33,35 +35,30 @@ impl Endpoint { region: PROD_Q_ENDPOINT_REGION, }; - pub fn load_codewhisperer() -> Self { - let (endpoint, region) = - if let Ok(Some(Value::Object(o))) = crate::settings::settings::get_value("api.codewhisperer.service") { - // The following branch is evaluated in case the user has set their own endpoint. - ( - o.get("endpoint").and_then(|v| v.as_str()).map(|v| v.to_owned()), - o.get("region").and_then(|v| v.as_str()).map(|v| v.to_owned()), - ) - } else if let Ok(Some(Value::Object(o))) = crate::settings::state::get_value("api.codewhisperer.profile") { - // The following branch is evaluated in the case of user profile being set. - match o.get("arn").and_then(|v| v.as_str()).map(|v| v.to_owned()) { - Some(arn) => { - let region = arn.split(':').nth(3).unwrap_or_default().to_owned(); - match Self::CODEWHISPERER_ENDPOINTS - .iter() - .find(|e| e.region().as_ref() == region) - { - Some(endpoint) => (Some(endpoint.url().to_owned()), Some(region)), - None => { - error!("Failed to find endpoint for region: {region}"); - (None, None) - }, - } - }, - None => (None, None), - } - } else { - (None, None) - }; + pub fn load_codewhisperer(database: &Database) -> Self { + let (endpoint, region) = if let Some(Value::Object(o)) = database.settings.get(Setting::ApiCodeWhispererService) + { + // The following branch is evaluated in case the user has set their own endpoint. + ( + o.get("endpoint").and_then(|v| v.as_str()).map(|v| v.to_owned()), + o.get("region").and_then(|v| v.as_str()).map(|v| v.to_owned()), + ) + } else if let Ok(Some(profile)) = database.get_auth_profile() { + // The following branch is evaluated in the case of user profile being set. + let region = profile.arn.split(':').nth(3).unwrap_or_default().to_owned(); + match Self::CODEWHISPERER_ENDPOINTS + .iter() + .find(|e| e.region().as_ref() == region) + { + Some(endpoint) => (Some(endpoint.url().to_owned()), Some(region)), + None => { + error!("Failed to find endpoint for region: {region}"); + (None, None) + }, + } + } else { + (None, None) + }; match (endpoint, region) { (Some(endpoint), Some(region)) => Self { @@ -72,9 +69,9 @@ impl Endpoint { } } - pub fn load_q() -> Self { - match crate::settings::settings::get_value("api.q.service") { - Ok(Some(Value::Object(o))) => { + pub fn load_q(database: &Database) -> Self { + match database.settings.get(Setting::ApiQService) { + Some(Value::Object(o)) => { let endpoint = o.get("endpoint").and_then(|v| v.as_str()); let region = o.get("region").and_then(|v| v.as_str()); @@ -105,10 +102,11 @@ mod tests { use super::*; - #[test] - fn test_endpoints() { - let _ = Endpoint::load_codewhisperer(); - let _ = Endpoint::load_q(); + #[tokio::test] + async fn test_endpoints() { + let database = Database::new().await.unwrap(); + let _ = Endpoint::load_codewhisperer(&database); + let _ = Endpoint::load_q(&database); let prod = &Endpoint::DEFAULT_ENDPOINT; Url::parse(prod.url()).unwrap(); diff --git a/crates/chat-cli/src/api_client/error.rs b/crates/chat-cli/src/api_client/error.rs index 523fbdd0cd..64bbc99803 100644 --- a/crates/chat-cli/src/api_client/error.rs +++ b/crates/chat-cli/src/api_client/error.rs @@ -13,6 +13,7 @@ pub use aws_smithy_runtime_api::client::result::SdkError; use aws_smithy_types::event_stream::RawMessage; use thiserror::Error; +use crate::auth::AuthError; use crate::aws_common::SdkErrorDisplay; #[derive(Debug, Error)] @@ -61,6 +62,9 @@ pub enum ApiClientError { #[error(transparent)] ListAvailableProfilesError(#[from] SdkError), + + #[error(transparent)] + AuthError(#[from] AuthError), } #[cfg(test)] diff --git a/crates/chat-cli/src/api_client/interceptor/opt_out.rs b/crates/chat-cli/src/api_client/interceptor/opt_out.rs index ef7d59af06..fd99dbc64f 100644 --- a/crates/chat-cli/src/api_client/interceptor/opt_out.rs +++ b/crates/chat-cli/src/api_client/interceptor/opt_out.rs @@ -4,24 +4,28 @@ use aws_smithy_runtime_api::client::interceptors::context::BeforeTransmitInterce use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents; use aws_smithy_types::config_bag::ConfigBag; -use crate::api_client::consts::{ - SHARE_CODEWHISPERER_CONTENT_SETTINGS_KEY, - X_AMZN_CODEWHISPERER_OPT_OUT_HEADER, -}; +use crate::api_client::consts::X_AMZN_CODEWHISPERER_OPT_OUT_HEADER; +use crate::database::Database; +use crate::database::settings::Setting; -fn is_codewhisperer_content_optout() -> bool { - !crate::settings::settings::get_bool_or(SHARE_CODEWHISPERER_CONTENT_SETTINGS_KEY, true) +fn is_codewhisperer_content_optout(database: &Database) -> bool { + !database + .settings + .get_bool(Setting::ShareCodeWhispererContent) + .unwrap_or(true) } #[derive(Debug, Clone)] pub struct OptOutInterceptor { + is_codewhisperer_content_optout: bool, override_value: Option, _inner: (), } impl OptOutInterceptor { - pub const fn new() -> Self { + pub fn new(database: &Database) -> Self { Self { + is_codewhisperer_content_optout: is_codewhisperer_content_optout(database), override_value: None, _inner: (), } @@ -39,7 +43,7 @@ impl Intercept for OptOutInterceptor { _runtime_components: &RuntimeComponents, _cfg: &mut ConfigBag, ) -> Result<(), BoxError> { - let opt_out = self.override_value.unwrap_or_else(is_codewhisperer_content_optout); + let opt_out = self.override_value.unwrap_or(self.is_codewhisperer_content_optout); context .request_mut() .headers_mut() @@ -56,8 +60,8 @@ mod tests { use super::*; - #[test] - fn test_opt_out_interceptor() { + #[tokio::test] + async fn test_opt_out_interceptor() { let rc = RuntimeComponentsBuilder::for_tests().build().unwrap(); let mut cfg = ConfigBag::base(); @@ -65,7 +69,8 @@ mod tests { context.set_request(aws_smithy_runtime_api::http::Request::empty()); let mut context = BeforeTransmitInterceptorContextMut::from(&mut context); - let mut interceptor = OptOutInterceptor::new(); + let database = Database::new().await.unwrap(); + let mut interceptor = OptOutInterceptor::new(&database); println!("Interceptor: {}", interceptor.name()); interceptor diff --git a/crates/chat-cli/src/api_client/mod.rs b/crates/chat-cli/src/api_client/mod.rs index f3b5f4fbe6..9061433eaa 100644 --- a/crates/chat-cli/src/api_client/mod.rs +++ b/crates/chat-cli/src/api_client/mod.rs @@ -1,7 +1,7 @@ pub mod clients; pub(crate) mod consts; pub(crate) mod credentials; -mod customization; +pub mod customization; mod endpoints; mod error; pub(crate) mod interceptor; diff --git a/crates/chat-cli/src/api_client/profile.rs b/crates/chat-cli/src/api_client/profile.rs index ea8248f5f9..4a5527df75 100644 --- a/crates/chat-cli/src/api_client/profile.rs +++ b/crates/chat-cli/src/api_client/profile.rs @@ -1,35 +1,20 @@ -use serde::{ - Deserialize, - Serialize, -}; - use crate::api_client::Client; use crate::api_client::endpoints::Endpoint; +use crate::auth::AuthError; +use crate::database::{ + AuthProfile, + Database, +}; -#[derive(Debug, Deserialize, Serialize)] -pub struct Profile { - pub arn: String, - pub profile_name: String, -} - -impl From for Profile { - fn from(profile: amzn_codewhisperer_client::types::Profile) -> Self { - Self { - arn: profile.arn, - profile_name: profile.profile_name, - } - } -} - -pub async fn list_available_profiles() -> Vec { +pub async fn list_available_profiles(database: &mut Database) -> Result, AuthError> { let mut profiles = vec![]; for endpoint in Endpoint::CODEWHISPERER_ENDPOINTS { - let client = Client::new_codewhisperer_client(&endpoint).await; + let client = Client::new(database, Some(endpoint.clone())).await?; match client.list_available_profiles().await { Ok(mut p) => profiles.append(&mut p), Err(e) => tracing::error!("Failed to list profiles from endpoint {:?}: {:?}", endpoint, e), } } - profiles + Ok(profiles) } diff --git a/crates/chat-cli/src/auth/builder_id.rs b/crates/chat-cli/src/auth/builder_id.rs index f62fc29ac6..e277c1410b 100644 --- a/crates/chat-cli/src/auth/builder_id.rs +++ b/crates/chat-cli/src/auth/builder_id.rs @@ -41,7 +41,6 @@ use aws_smithy_runtime_api::client::identity::{ }; use aws_smithy_types::error::display::DisplayErrorContext; use aws_types::region::Region; -use aws_types::request_id::RequestId; use time::OffsetDateTime; use tracing::{ debug, @@ -52,16 +51,18 @@ use tracing::{ use crate::auth::AuthError; use crate::auth::consts::*; use crate::auth::scope::is_scopes; -use crate::auth::secret_store::{ +use crate::aws_common::app_name; +use crate::database::Database; +use crate::database::secret_store::{ Secret, SecretStore, }; -use crate::aws_common::app_name; -use crate::telemetry::send_refresh_credentials; #[derive(Debug, Copy, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] pub enum OAuthFlow { DeviceCode, + // This must remain backwards compatible + #[serde(rename = "PKCE")] Pkce, } @@ -69,7 +70,7 @@ impl std::fmt::Display for OAuthFlow { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match *self { OAuthFlow::DeviceCode => write!(f, "DeviceCode"), - OAuthFlow::Pkce => write!(f, "Pkce"), + OAuthFlow::Pkce => write!(f, "PKCE"), } } } @@ -85,18 +86,18 @@ pub(crate) fn oidc_url(region: &Region) -> String { format!("https://oidc.{region}.amazonaws.com") } -pub(crate) fn client(region: Region) -> Client { - let retry_config = RetryConfig::standard().with_max_attempts(3); - let sdk_config = aws_types::SdkConfig::builder() - .http_client(crate::aws_common::http_client::client()) - .behavior_version(BehaviorVersion::v2025_01_17()) - .endpoint_url(oidc_url(®ion)) - .region(region) - .retry_config(retry_config) - .sleep_impl(SharedAsyncSleep::new(TokioSleep::new())) - .app_name(app_name()) - .build(); - Client::new(&sdk_config) +pub fn client(region: Region) -> Client { + Client::new( + &aws_types::SdkConfig::builder() + .http_client(crate::aws_common::http_client::client()) + .behavior_version(BehaviorVersion::v2025_01_17()) + .endpoint_url(oidc_url(®ion)) + .region(region) + .retry_config(RetryConfig::standard().with_max_attempts(3)) + .sleep_impl(SharedAsyncSleep::new(TokioSleep::new())) + .app_name(app_name()) + .build(), + ) } /// Represents an OIDC registered client, resulting from the "register client" API call. @@ -156,11 +157,11 @@ impl DeviceRegistration { /// Loads the client saved in the secret store if available, otherwise registers a new client /// and saves it in the secret store. pub async fn init_device_code_registration( + database: &Database, client: &Client, - secret_store: &SecretStore, region: &Region, ) -> Result { - match Self::load_from_secret_store(secret_store, region).await { + match Self::load_from_secret_store(&database.secret_store, region).await { Ok(Some(registration)) if registration.oauth_flow == OAuthFlow::DeviceCode => match ®istration.scopes { Some(scopes) if is_scopes(scopes) => return Ok(registration), _ => warn!("Invalid scopes in device registration, ignoring"), @@ -189,7 +190,7 @@ impl DeviceRegistration { SCOPES.iter().map(|s| (*s).to_owned()).collect(), ); - if let Err(err) = device_registration.save(secret_store).await { + if let Err(err) = device_registration.save(&database.secret_store).await { error!(?err, "Failed to write device registration to keychain"); } @@ -225,7 +226,7 @@ pub struct StartDeviceAuthorizationResponse { /// Init a builder id request pub async fn start_device_authorization( - secret_store: &SecretStore, + database: &Database, start_url: Option, region: Option, ) -> Result { @@ -236,7 +237,7 @@ pub async fn start_device_authorization( client_id, client_secret, .. - } = DeviceRegistration::init_device_code_registration(&client, secret_store, ®ion).await?; + } = DeviceRegistration::init_device_code_registration(database, &client, ®ion).await?; let output = client .start_device_authorization() @@ -293,8 +294,8 @@ impl BuilderIdToken { } /// Load the token from the keychain, refresh the token if it is expired and return it - pub async fn load(secret_store: &SecretStore, force_refresh: bool) -> Result, AuthError> { - match secret_store.get(Self::SECRET_KEY).await { + pub async fn load(database: &mut Database) -> Result, AuthError> { + match database.secret_store.get(Self::SECRET_KEY).await { Ok(Some(secret)) => { let token: Option = serde_json::from_str(&secret.0)?; match token { @@ -303,8 +304,8 @@ impl BuilderIdToken { let client = client(region.clone()); // if token is expired try to refresh - if token.is_expired() || force_refresh { - token.refresh_token(&client, secret_store, ®ion).await + if token.is_expired() { + token.refresh_token(&client, &database.secret_store, ®ion).await } else { Ok(Some(token)) } @@ -315,7 +316,7 @@ impl BuilderIdToken { Ok(None) => Ok(None), Err(err) => { error!(%err, "Error getting builder id token from keychain"); - Err(err) + Err(err)? }, } } @@ -364,13 +365,6 @@ impl BuilderIdToken { .await { Ok(output) => { - send_refresh_credentials( - self.start_url.clone().unwrap_or_else(|| START_URL.to_owned()), - output.request_id().unwrap_or_default().into(), - registration.oauth_flow.to_string(), - ) - .await; - let token: BuilderIdToken = Self::from_output( output, region.clone(), @@ -392,13 +386,6 @@ impl BuilderIdToken { // if the error is the client's fault, clear the token if let SdkError::ServiceError(service_err) = &err { - send_refresh_credentials( - self.start_url.clone().unwrap_or_else(|| START_URL.to_owned()), - err.request_id().unwrap_or_default().into(), - registration.oauth_flow.to_string(), - ) - .await; - if !service_err.err().is_slow_down_exception() { if let Err(err) = self.delete(secret_store).await { error!(?err, "Failed to delete builder id token"); @@ -468,7 +455,7 @@ pub enum PollCreateToken { /// Poll for the create token response pub async fn poll_create_token( - secret_store: &SecretStore, + database: &Database, device_code: String, start_url: Option, region: Option, @@ -481,7 +468,7 @@ pub async fn poll_create_token( client_secret, scopes, .. - } = match DeviceRegistration::init_device_code_registration(&client, secret_store, ®ion).await { + } = match DeviceRegistration::init_device_code_registration(database, &client, ®ion).await { Ok(res) => res, Err(err) => { return PollCreateToken::Error(err); @@ -501,7 +488,7 @@ pub async fn poll_create_token( let token: BuilderIdToken = BuilderIdToken::from_output(output, region, start_url, OAuthFlow::DeviceCode, scopes); - if let Err(err) = token.save(secret_store).await { + if let Err(err) = token.save(&database.secret_store).await { error!(?err, "Failed to store builder id token"); }; @@ -517,21 +504,11 @@ pub async fn poll_create_token( } } -pub async fn builder_id_token() -> Result, AuthError> { - let secret_store = SecretStore::new().await?; - BuilderIdToken::load(&secret_store, false).await -} - -pub async fn refresh_token() -> Result, AuthError> { - let secret_store = SecretStore::new().await?; - BuilderIdToken::load(&secret_store, true).await +pub async fn is_logged_in(database: &mut Database) -> bool { + matches!(BuilderIdToken::load(database).await, Ok(Some(_))) } -pub async fn is_logged_in() -> bool { - matches!(builder_id_token().await, Ok(Some(_))) -} - -pub async fn logout() -> Result<(), AuthError> { +pub async fn logout(database: &mut Database) -> Result<(), AuthError> { let Ok(secret_store) = SecretStore::new().await else { return Ok(()); }; @@ -541,7 +518,7 @@ pub async fn logout() -> Result<(), AuthError> { secret_store.delete(DeviceRegistration::SECRET_KEY), ); - let profile_res = crate::settings::state::remove_value("api.codewhisperer.profile"); + let profile_res = database.unset_auth_profile(); builder_res?; device_res?; @@ -551,7 +528,17 @@ pub async fn logout() -> Result<(), AuthError> { } #[derive(Debug, Clone)] -pub struct BearerResolver; +pub struct BearerResolver { + token: Option, +} + +impl BearerResolver { + pub async fn new(database: &mut Database) -> Result { + Ok(Self { + token: BuilderIdToken::load(database).await?, + }) + } +} impl ResolveIdentity for BearerResolver { fn resolve_identity<'a>( @@ -560,11 +547,9 @@ impl ResolveIdentity for BearerResolver { _config_bag: &'a ConfigBag, ) -> IdentityFuture<'a> { IdentityFuture::new_boxed(Box::pin(async { - let secret_store = SecretStore::new().await?; - let token = BuilderIdToken::load(&secret_store, false).await?; - match token { + match &self.token { Some(token) => Ok(Identity::new( - Token::new(token.access_token.0, Some(token.expires_at.into())), + Token::new(token.access_token.0.clone(), Some(token.expires_at.into())), Some(token.expires_at.into()), )), None => Err(AuthError::NoToken.into()), @@ -593,11 +578,11 @@ mod tests { #[test] fn test_oauth_flow_ser_deser() { test_ser_deser!(OAuthFlow, OAuthFlow::DeviceCode, "DeviceCode"); - test_ser_deser!(OAuthFlow, OAuthFlow::Pkce, "Pkce"); + test_ser_deser!(OAuthFlow, OAuthFlow::Pkce, "PKCE"); } - #[test] - fn test_client() { + #[tokio::test] + async fn test_client() { println!("{:?}", client(US_EAST_1)); println!("{:?}", client(US_WEST_2)); } @@ -628,66 +613,4 @@ mod tests { token.start_url = Some("https://amzn.awsapps.com/start".into()); assert_eq!(token.token_type(), TokenType::IamIdentityCenter); } - - #[ignore = "not in ci"] - #[tokio::test] - async fn logout_test() { - logout().await.unwrap(); - } - - #[ignore = "login flow"] - #[tokio::test] - async fn test_login() { - let start_url = Some("https://amzn.awsapps.com/start".into()); - let region = Some("us-east-1".into()); - - // let start_url = None; - // let region = None; - - let secret_store = SecretStore::new().await.unwrap(); - let res: StartDeviceAuthorizationResponse = - start_device_authorization(&secret_store, start_url.clone(), region.clone()) - .await - .unwrap(); - - println!("{:?}", res); - - loop { - match poll_create_token( - &secret_store, - res.device_code.clone(), - start_url.clone(), - region.clone(), - ) - .await - { - PollCreateToken::Pending => { - tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; - }, - PollCreateToken::Complete => { - break; - }, - PollCreateToken::Error(err) => { - println!("{}", err); - break; - }, - } - } - } - - #[ignore = "not in ci"] - #[tokio::test] - async fn test_load() { - let secret_store = SecretStore::new().await.unwrap(); - let token = BuilderIdToken::load(&secret_store, false).await; - println!("{:?}", token); - // println!("{:?}", token.unwrap().unwrap().access_token.0); - } - - #[ignore = "not in ci"] - #[tokio::test] - async fn test_refresh() { - let token = refresh_token().await.unwrap().unwrap(); - println!("{:?}", token); - } } diff --git a/crates/chat-cli/src/auth/mod.rs b/crates/chat-cli/src/auth/mod.rs index 26eaec7153..4b425f2a6f 100644 --- a/crates/chat-cli/src/auth/mod.rs +++ b/crates/chat-cli/src/auth/mod.rs @@ -2,17 +2,14 @@ pub mod builder_id; mod consts; pub mod pkce; mod scope; -pub mod secret_store; use aws_sdk_ssooidc::error::SdkError; use aws_sdk_ssooidc::operation::create_token::CreateTokenError; use aws_sdk_ssooidc::operation::register_client::RegisterClientError; use aws_sdk_ssooidc::operation::start_device_authorization::StartDeviceAuthorizationError; pub use builder_id::{ - builder_id_token, is_logged_in, logout, - refresh_token, }; pub use consts::START_URL; use thiserror::Error; @@ -20,13 +17,13 @@ use thiserror::Error; #[derive(Debug, Error)] pub enum AuthError { #[error(transparent)] - Ssooidc(#[from] Box), + Ssooidc(Box), #[error(transparent)] - SdkRegisterClient(#[from] SdkError), + SdkRegisterClient(Box>), #[error(transparent)] - SdkCreateToken(#[from] SdkError), + SdkCreateToken(Box>), #[error(transparent)] - SdkStartDeviceAuthorization(#[from] SdkError), + SdkStartDeviceAuthorization(Box>), #[error(transparent)] Io(#[from] std::io::Error), #[error(transparent)] @@ -35,17 +32,8 @@ pub enum AuthError { Directories(#[from] crate::util::directories::DirectoryError), #[error(transparent)] SerdeJson(#[from] serde_json::Error), - #[cfg(target_os = "macos")] - #[error("Security error: {}", .0)] - Security(String), #[error(transparent)] - StringFromUtf8(#[from] std::string::FromUtf8Error), - #[error(transparent)] - StrFromUtf8(#[from] std::str::Utf8Error), - #[error(transparent)] - DbOpenError(#[from] crate::settings::error::DbOpenError), - #[error(transparent)] - Setting(#[from] crate::settings::SettingsError), + DbOpenError(#[from] crate::database::DbOpenError), #[error("No token")] NoToken, #[error("OAuth state mismatch. Actual: {} | Expected: {}", .actual, .expected)] @@ -56,4 +44,30 @@ pub enum AuthError { OAuthMissingCode, #[error("OAuth error: {0}")] OAuthCustomError(String), + #[error(transparent)] + DatabaseError(#[from] crate::database::DatabaseError), +} + +impl From for AuthError { + fn from(value: aws_sdk_ssooidc::Error) -> Self { + Self::Ssooidc(Box::new(value)) + } +} + +impl From> for AuthError { + fn from(value: SdkError) -> Self { + Self::SdkRegisterClient(Box::new(value)) + } +} + +impl From> for AuthError { + fn from(value: SdkError) -> Self { + Self::SdkCreateToken(Box::new(value)) + } +} + +impl From> for AuthError { + fn from(value: SdkError) -> Self { + Self::SdkStartDeviceAuthorization(Box::new(value)) + } } diff --git a/crates/chat-cli/src/auth/pkce.rs b/crates/chat-cli/src/auth/pkce.rs index e9d6a3911d..74b2f13705 100644 --- a/crates/chat-cli/src/auth/pkce.rs +++ b/crates/chat-cli/src/auth/pkce.rs @@ -52,11 +52,11 @@ use tracing::{ use crate::auth::builder_id::*; use crate::auth::consts::*; -use crate::auth::secret_store::SecretStore; use crate::auth::{ AuthError, START_URL, }; +use crate::database::Database; const DEFAULT_AUTHORIZATION_TIMEOUT: Duration = Duration::from_secs(60 * 3); @@ -227,11 +227,11 @@ impl PkceRegistration { }) } - /// Hosts a local HTTP server to listen for browser redirects. If a [`SecretStore`] is passed, + /// Hosts a local HTTP server to listen for browser redirects. If a [`Database`] is passed, /// then the access and refresh tokens will be saved. /// /// Only the first connection will be served. - pub async fn finish(self, client: &C, secret_store: Option<&SecretStore>) -> Result<(), AuthError> { + pub async fn finish(self, client: &C, database: Option<&Database>) -> Result<(), AuthError> { let code = tokio::select! { code = Self::recv_code(self.listener, self.state) => { code? @@ -269,18 +269,16 @@ impl PkceRegistration { C::scopes(), ); - let Some(secret_store) = secret_store else { - return Ok(()); - }; + if let Some(database) = database { + if let Err(err) = device_registration.save(&database.secret_store).await { + error!(?err, "Failed to store pkce registration to secret store"); + } - if let Err(err) = device_registration.save(secret_store).await { - error!(?err, "Failed to store pkce registration to secret store"); + if let Err(err) = token.save(&database.secret_store).await { + error!(?err, "Failed to store builder id token"); + }; } - if let Err(err) = token.save(secret_store).await { - error!(?err, "Failed to store builder id token"); - }; - Ok(()) } @@ -512,6 +510,7 @@ mod tests { #[tokio::test] async fn test_pkce_flow_e2e() { tracing_subscriber::fmt::init(); + let start_url = "https://amzn.awsapps.com/start".to_string(); let region = Region::new("us-east-1"); let client = client(region.clone()); @@ -523,8 +522,8 @@ mod tests { panic!("unable to open the URL"); } println!("Waiting for authorization to complete..."); - let secret_store = SecretStore::new().await.unwrap(); - registration.finish(&client, Some(&secret_store)).await.unwrap(); + + registration.finish(&client, None).await.unwrap(); println!("Authorization successful"); } diff --git a/crates/chat-cli/src/auth/secret_store/linux.rs b/crates/chat-cli/src/auth/secret_store/linux.rs deleted file mode 100644 index a3b4529f73..0000000000 --- a/crates/chat-cli/src/auth/secret_store/linux.rs +++ /dev/null @@ -1,28 +0,0 @@ -use super::Secret; -use super::sqlite::SqliteSecretStore; -use crate::Result; -use crate::auth::AuthError; - -pub struct SecretStoreImpl { - inner: SqliteSecretStore, -} - -impl SecretStoreImpl { - pub async fn new() -> Result { - Ok(Self { - inner: SqliteSecretStore::new().await?, - }) - } - - pub async fn set(&self, key: &str, password: &str) -> Result<(), AuthError> { - self.inner.set(key, password).await - } - - pub async fn get(&self, key: &str) -> Result, AuthError> { - self.inner.get(key).await - } - - pub async fn delete(&self, key: &str) -> Result<(), AuthError> { - self.inner.delete(key).await - } -} diff --git a/crates/chat-cli/src/auth/secret_store/sqlite.rs b/crates/chat-cli/src/auth/secret_store/sqlite.rs deleted file mode 100644 index 51f6024c41..0000000000 --- a/crates/chat-cli/src/auth/secret_store/sqlite.rs +++ /dev/null @@ -1,50 +0,0 @@ -#![allow(dead_code)] -use super::Secret; -use crate::auth::AuthError; -use crate::settings::sqlite::{ - Db, - database, -}; - -pub struct SqliteSecretStore { - db: &'static Db, -} - -impl SqliteSecretStore { - pub async fn new() -> Result { - Ok(Self { db: database()? }) - } - - pub async fn set(&self, key: &str, password: &str) -> Result<(), AuthError> { - Ok(self.db.set_auth_value(key, password)?) - } - - pub async fn get(&self, key: &str) -> Result, AuthError> { - Ok(self.db.get_auth_value(key)?.map(Secret)) - } - - pub async fn delete(&self, key: &str) -> Result<(), AuthError> { - Ok(self.db.unset_auth_value(key)?) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[tokio::test] - async fn test_set_get_delete() { - let store = SqliteSecretStore::new().await.unwrap(); - let key = "test_key"; - let password = "test_password"; - - store.set(key, password).await.unwrap(); - - let secret = store.get(key).await.unwrap(); - assert_eq!(secret, Some(Secret(password.to_string()))); - - store.delete(key).await.unwrap(); - let secret = store.get(key).await.unwrap(); - assert_eq!(secret, None); - } -} diff --git a/crates/chat-cli/src/aws_common/http_client.rs b/crates/chat-cli/src/aws_common/http_client.rs index 57a64f1682..85c2b482cf 100644 --- a/crates/chat-cli/src/aws_common/http_client.rs +++ b/crates/chat-cli/src/aws_common/http_client.rs @@ -16,7 +16,7 @@ use reqwest::Client as ReqwestClient; /// Returns a wrapper around the global [fig_request::client] that implements /// [HttpClient]. pub fn client() -> Client { - let client = crate::request::client().expect("failed to create http client"); + let client = crate::request::new_client().expect("failed to create http client"); Client::new(client.clone()) } diff --git a/crates/chat-cli/src/cli/chat/conversation_state.rs b/crates/chat-cli/src/cli/chat/conversation_state.rs index 94cdcf0817..bc6bfae5a3 100644 --- a/crates/chat-cli/src/cli/chat/conversation_state.rs +++ b/crates/chat-cli/src/cli/chat/conversation_state.rs @@ -837,6 +837,9 @@ mod tests { ToolResultStatus, }; use crate::cli::chat::tool_manager::ToolManager; + use crate::database::Database; + use crate::platform::Env; + use crate::telemetry::TelemetryThread; fn assert_conversation_state_invariants(state: FigConversationState, assertion_iteration: usize) { if let Some(Some(msg)) = state.history.as_ref().map(|h| h.first()) { @@ -928,11 +931,15 @@ mod tests { #[tokio::test] async fn test_conversation_state_history_handling_truncation() { + let env = Env::new(); + let mut database = Database::new().await.unwrap(); + let telemetry = TelemetryThread::new(&env, &mut database).await.unwrap(); + let mut tool_manager = ToolManager::default(); let mut conversation_state = ConversationState::new( Context::new(), "fake_conv_id", - tool_manager.load_tools().await.unwrap(), + tool_manager.load_tools(&database, &telemetry).await.unwrap(), None, None, ) @@ -951,12 +958,16 @@ mod tests { #[tokio::test] async fn test_conversation_state_history_handling_with_tool_results() { + let env = Env::new(); + let mut database = Database::new().await.unwrap(); + let telemetry = TelemetryThread::new(&env, &mut database).await.unwrap(); + // Build a long conversation history of tool use results. let mut tool_manager = ToolManager::default(); let mut conversation_state = ConversationState::new( Context::new(), "fake_conv_id", - tool_manager.load_tools().await.unwrap(), + tool_manager.load_tools(&database, &telemetry).await.unwrap(), None, None, ) @@ -984,7 +995,7 @@ mod tests { let mut conversation_state = ConversationState::new( Context::new(), "fake_conv_id", - tool_manager.load_tools().await.unwrap(), + tool_manager.load_tools(&database, &telemetry).await.unwrap(), None, None, ) @@ -1015,6 +1026,10 @@ mod tests { #[tokio::test] async fn test_conversation_state_with_context_files() { + let env = Env::new(); + let mut database = Database::new().await.unwrap(); + let telemetry = TelemetryThread::new(&env, &mut database).await.unwrap(); + let ctx = Context::builder().with_test_home().await.unwrap().build_fake(); ctx.fs().write(AMAZONQ_FILENAME, "test context").await.unwrap(); @@ -1022,7 +1037,7 @@ mod tests { let mut conversation_state = ConversationState::new( ctx, "fake_conv_id", - tool_manager.load_tools().await.unwrap(), + tool_manager.load_tools(&database, &telemetry).await.unwrap(), None, None, ) @@ -1060,6 +1075,10 @@ mod tests { async fn test_conversation_state_additional_context() { // tracing_subscriber::fmt::try_init().ok(); + let env = Env::new(); + let mut database = Database::new().await.unwrap(); + let telemetry = TelemetryThread::new(&env, &mut database).await.unwrap(); + let mut tool_manager = ToolManager::default(); let ctx = Context::builder().with_test_home().await.unwrap().build_fake(); let conversation_start_context = "conversation start context"; @@ -1087,7 +1106,7 @@ mod tests { let mut conversation_state = ConversationState::new( ctx, "fake_conv_id", - tool_manager.load_tools().await.unwrap(), + tool_manager.load_tools(&database, &telemetry).await.unwrap(), None, Some(SharedWriter::stdout()), ) diff --git a/crates/chat-cli/src/cli/chat/input_source.rs b/crates/chat-cli/src/cli/chat/input_source.rs index 1c659eba58..25aeed0174 100644 --- a/crates/chat-cli/src/cli/chat/input_source.rs +++ b/crates/chat-cli/src/cli/chat/input_source.rs @@ -4,6 +4,7 @@ use rustyline::error::ReadlineError; use super::prompt::rl; #[cfg(unix)] use super::skim_integration::SkimCommandSelector; +use crate::database::Database; #[derive(Debug)] pub struct InputSource(inner::Inner); @@ -27,15 +28,17 @@ mod inner { impl InputSource { pub fn new( + database: &Database, sender: std::sync::mpsc::Sender>, receiver: std::sync::mpsc::Receiver>, ) -> Result { - Ok(Self(inner::Inner::Readline(rl(sender, receiver)?))) + Ok(Self(inner::Inner::Readline(rl(database, sender, receiver)?))) } #[cfg(unix)] pub fn put_skim_command_selector( &mut self, + database: &crate::database::Database, context_manager: std::sync::Arc, tool_names: Vec, ) { @@ -44,8 +47,10 @@ impl InputSource { KeyEvent, }; + use crate::database::settings::Setting; + if let inner::Inner::Readline(rl) = &mut self.0 { - let key_char = match crate::settings::settings::get_string_opt("chat.skimCommandKey").as_deref() { + let key_char = match database.settings.get_string(Setting::SkimCommandKey) { Some(key) if key.len() == 1 => key.chars().next().unwrap_or('s'), _ => 's', // Default to 's' if setting is missing or invalid }; diff --git a/crates/chat-cli/src/cli/chat/mod.rs b/crates/chat-cli/src/cli/chat/mod.rs index 91394db378..3fe0504d31 100644 --- a/crates/chat-cli/src/cli/chat/mod.rs +++ b/crates/chat-cli/src/cli/chat/mod.rs @@ -95,13 +95,11 @@ use crate::api_client::model::{ Tool as FigTool, ToolResultStatus, }; +use crate::database::Database; +use crate::database::settings::Setting; use crate::platform::Context; -use crate::settings::{ - Settings, - State, -}; -use crate::telemetry::core::Event; -use crate::util::CHAT_BINARY_NAME; +use crate::telemetry::TelemetryThread; +use crate::telemetry::core::ToolUseEventBuilder; /// Help text for the compact command fn compact_help_text() -> String { @@ -292,14 +290,17 @@ const TRUST_ALL_TEXT: &str = color_print::cstr! {"All tools are now trus const TOOL_BULLET: &str = " ● "; const CONTINUATION_LINE: &str = " ⋮ "; -pub async fn launch_chat(args: cli::Chat) -> Result { +pub async fn launch_chat(database: &mut Database, telemetry: &TelemetryThread, args: cli::Chat) -> Result { let trust_tools = args.trust_tools.map(|mut tools| { if tools.len() == 1 && tools[0].is_empty() { tools.pop(); } tools }); + chat( + database, + telemetry, args.input, args.no_interactive, args.accept_all, @@ -310,7 +311,10 @@ pub async fn launch_chat(args: cli::Chat) -> Result { .await } +#[allow(clippy::too_many_arguments)] pub async fn chat( + database: &mut Database, + telemetry: &TelemetryThread, input: Option, no_interactive: bool, accept_all: bool, @@ -318,11 +322,8 @@ pub async fn chat( trust_all_tools: bool, trust_tools: Option>, ) -> Result { - if !crate::util::system_info::in_cloudshell() && !crate::auth::is_logged_in().await { - bail!( - "You are not logged in, please log in with {}", - format!("{CHAT_BINARY_NAME} login",).bold() - ); + if !crate::util::system_info::in_cloudshell() && !crate::auth::is_logged_in(database).await { + bail!("You are not logged in, please log in with {}", "q login".bold()); } region_check("chat")?; @@ -348,7 +349,7 @@ pub async fn chat( let client = match ctx.env().get("Q_MOCK_CHAT_RESPONSE") { Ok(json) => create_stream(serde_json::from_str(std::fs::read_to_string(json)?.as_str())?), - _ => StreamingClient::new().await?, + _ => StreamingClient::new(database).await?, }; let mcp_server_configs = match McpServerConfig::load_config(&mut output).await { @@ -397,9 +398,9 @@ pub async fn chat( .prompt_list_sender(prompt_response_sender) .prompt_list_receiver(prompt_request_receiver) .conversation_id(&conversation_id) - .build() + .build(telemetry) .await?; - let tool_config = tool_manager.load_tools().await?; + let tool_config = tool_manager.load_tools(database, telemetry).await?; let mut tool_permissions = ToolPermissions::new(tool_config.len()); if accept_all || trust_all_tools { for tool in tool_config.values() { @@ -429,11 +430,9 @@ pub async fn chat( let mut chat = ChatContext::new( ctx, &conversation_id, - Settings::new(), - State::new(), output, input, - InputSource::new(prompt_request_sender, prompt_response_receiver)?, + InputSource::new(database, prompt_request_sender, prompt_response_receiver)?, interactive, client, || terminal::window_size().map(|s| s.columns.into()).ok(), @@ -444,7 +443,7 @@ pub async fn chat( ) .await?; - let result = chat.try_chat().await.map(|_| ExitCode::SUCCESS); + let result = chat.try_chat(database, telemetry).await.map(|_| ExitCode::SUCCESS); drop(chat); // Explicit drop for clarity result @@ -485,9 +484,6 @@ pub enum ChatError { pub struct ChatContext { ctx: Arc, - settings: Settings, - /// The [State] to use for the chat context. - state: State, /// The [Write] destination for printing conversation text. output: SharedWriter, initial_input: Option, @@ -519,8 +515,6 @@ impl ChatContext { pub async fn new( ctx: Arc, conversation_id: &str, - settings: Settings, - state: State, output: SharedWriter, input: Option, input_source: InputSource, @@ -538,8 +532,6 @@ impl ChatContext { ConversationState::new(ctx_clone, conversation_id, tool_config, profile, Some(output_clone)).await; Ok(Self { ctx, - settings, - state, output, initial_input: input, input_source, @@ -684,9 +676,9 @@ impl ChatContext { Ok(content.trim().to_string()) } - async fn try_chat(&mut self) -> Result<()> { + async fn try_chat(&mut self, database: &mut Database, telemetry: &TelemetryThread) -> Result<()> { let is_small_screen = self.terminal_width() < GREETING_BREAK_POINT; - if self.interactive && self.settings.get_bool_or("chat.greeting.enabled", true) { + if self.interactive && database.settings.get_bool(Setting::ChatGreetingEnabled).unwrap_or(true) { execute!( self.output, style::Print(if is_small_screen { @@ -697,8 +689,7 @@ impl ChatContext { style::Print("\n\n"), )?; - let current_tip_index = - (self.state.get_int_or("chat.greeting.rotating_tips_current_index", 0) as usize) % ROTATING_TIPS.len(); + let current_tip_index = database.get_increment_rotating_tip().unwrap_or(0) % ROTATING_TIPS.len(); let tip = ROTATING_TIPS[current_tip_index]; if is_small_screen { @@ -733,11 +724,6 @@ impl ChatContext { ) )?; execute!(self.output, style::Print("\n"), style::SetForegroundColor(Color::Reset))?; - - // update the current tip index - let next_tip_index = (current_tip_index + 1) % ROTATING_TIPS.len(); - self.state - .set_value("chat.greeting.rotating_tips_current_index", next_tip_index)?; } if self.interactive && self.all_tools_trusted() { @@ -792,7 +778,7 @@ impl ChatContext { if !self.interactive { return Ok(()); } - self.prompt_user(tool_uses, pending_tool_index, skip_printing_tools) + self.prompt_user(database, tool_uses, pending_tool_index, skip_printing_tools) .await }, ChatState::HandleInput { @@ -802,7 +788,7 @@ impl ChatContext { } => { let tool_uses_clone = tool_uses.clone(); tokio::select! { - res = self.handle_input(input, tool_uses, pending_tool_index) => res, + res = self.handle_input(telemetry, input, tool_uses, pending_tool_index) => res, Ok(_) = ctrl_c_stream => Err(ChatError::Interrupted { tool_uses: tool_uses_clone }) } }, @@ -815,25 +801,25 @@ impl ChatContext { } => { let tool_uses_clone = tool_uses.clone(); tokio::select! { - res = self.compact_history(tool_uses, pending_tool_index, prompt, show_summary, help) => res, + res = self.compact_history(telemetry, tool_uses, pending_tool_index, prompt, show_summary, help) => res, Ok(_) = ctrl_c_stream => Err(ChatError::Interrupted { tool_uses: tool_uses_clone }) } }, ChatState::ExecuteTools(tool_uses) => { let tool_uses_clone = tool_uses.clone(); tokio::select! { - res = self.tool_use_execute(tool_uses) => res, + res = self.tool_use_execute(database, telemetry, tool_uses) => res, Ok(_) = ctrl_c_stream => Err(ChatError::Interrupted { tool_uses: Some(tool_uses_clone) }) } }, ChatState::ValidateTools(tool_uses) => { tokio::select! { - res = self.validate_tools(tool_uses) => res, + res = self.validate_tools(telemetry, tool_uses) => res, Ok(_) = ctrl_c_stream => Err(ChatError::Interrupted { tool_uses: None }) } }, ChatState::HandleResponseStream(response) => tokio::select! { - res = self.handle_response(response) => res, + res = self.handle_response(database, telemetry, response) => res, Ok(_) = ctrl_c_stream => Err(ChatError::Interrupted { tool_uses: None }) }, ChatState::Exit => return Ok(()), @@ -975,6 +961,7 @@ impl ChatContext { /// The last two user messages in the history are not included in the compaction process. async fn compact_history( &mut self, + telemetry: &TelemetryThread, tool_uses: Option>, pending_tool_index: Option, custom_prompt: Option, @@ -1085,12 +1072,13 @@ impl ChatContext { } if let Some(message_id) = self.conversation_state.message_id() { - crate::telemetry::send_chat_added_message( - self.conversation_state.conversation_id().to_owned(), - message_id.to_owned(), - self.conversation_state.context_message_length(), - ) - .await; + telemetry + .send_chat_added_message( + self.conversation_state.conversation_id().to_owned(), + message_id.to_owned(), + self.conversation_state.context_message_length(), + ) + .ok(); } self.conversation_state.replace_history_with_summary(summary.clone()); @@ -1171,6 +1159,7 @@ impl ChatContext { /// Read input from the user. async fn prompt_user( &mut self, + database: &Database, mut tool_uses: Option>, pending_tool_index: Option, skip_printing_tools: bool, @@ -1225,7 +1214,7 @@ impl ChatContext { .cloned() .collect::>(); self.input_source - .put_skim_command_selector(Arc::new(context_manager.clone()), tool_names); + .put_skim_command_selector(database, Arc::new(context_manager.clone()), tool_names); } execute!( self.output, @@ -1247,6 +1236,7 @@ impl ChatContext { async fn handle_input( &mut self, + telemetry: &TelemetryThread, mut user_input: String, tool_uses: Option>, pending_tool_index: Option, @@ -1313,7 +1303,7 @@ impl ChatContext { } let conv_state = self.conversation_state.as_sendable_conversation_state(true).await; - self.send_tool_use_telemetry().await; + self.send_tool_use_telemetry(telemetry).await; ChatState::HandleResponseStream(self.client.send_message(conv_state).await?) }, @@ -1378,8 +1368,15 @@ impl ChatContext { show_summary, help, } => { - self.compact_history(Some(tool_uses), pending_tool_index, prompt, show_summary, help) - .await? + self.compact_history( + telemetry, + Some(tool_uses), + pending_tool_index, + prompt, + show_summary, + help, + ) + .await? }, Command::Help => { execute!(self.output, style::Print(HELP_TEXT))?; @@ -2842,7 +2839,12 @@ impl ChatContext { }) } - async fn tool_use_execute(&mut self, mut tool_uses: Vec) -> Result { + async fn tool_use_execute( + &mut self, + database: &Database, + telemetry: &TelemetryThread, + mut tool_uses: Vec, + ) -> Result { // Verify tools have permissions. for (index, tool) in tool_uses.iter_mut().enumerate() { // Manually accepted by the user or otherwise verified already. @@ -2857,7 +2859,11 @@ impl ChatContext { !tool.tool.requires_acceptance(&self.ctx) }; - if self.settings.get_bool_or("chat.enableNotifications", false) { + if database + .settings + .get_bool(Setting::ChatEnableNotifications) + .unwrap_or(false) + { play_notification_bell(!allowed); } @@ -2998,7 +3004,7 @@ impl ChatContext { self.conversation_state.add_tool_results(tool_results); } - self.send_tool_use_telemetry().await; + self.send_tool_use_telemetry(telemetry).await; return Ok(ChatState::HandleResponseStream( self.client .send_message(self.conversation_state.as_sendable_conversation_state(false).await) @@ -3006,7 +3012,12 @@ impl ChatContext { )); } - async fn handle_response(&mut self, response: SendMessageOutput) -> Result { + async fn handle_response( + &mut self, + database: &Database, + telemetry: &TelemetryThread, + response: SendMessageOutput, + ) -> Result { let request_id = response.request_id().map(|s| s.to_string()); let mut buf = String::new(); let mut offset = 0; @@ -3086,7 +3097,7 @@ impl ChatContext { .to_string(), ) .await; - self.send_tool_use_telemetry().await; + self.send_tool_use_telemetry(telemetry).await; return Ok(ChatState::HandleResponseStream( self.client .send_message(self.conversation_state.as_sendable_conversation_state(false).await) @@ -3138,7 +3149,7 @@ impl ChatContext { status: ToolResultStatus::Error, }]; self.conversation_state.add_tool_results(tool_results); - self.send_tool_use_telemetry().await; + self.send_tool_use_telemetry(telemetry).await; return Ok(ChatState::HandleResponseStream( self.client .send_message(self.conversation_state.as_sendable_conversation_state(false).await) @@ -3196,15 +3207,21 @@ impl ChatContext { if ended { if let Some(message_id) = self.conversation_state.message_id() { - crate::telemetry::send_chat_added_message( - self.conversation_state.conversation_id().to_owned(), - message_id.to_owned(), - self.conversation_state.context_message_length(), - ) - .await; + telemetry + .send_chat_added_message( + self.conversation_state.conversation_id().to_owned(), + message_id.to_owned(), + self.conversation_state.context_message_length(), + ) + .ok(); } - if self.interactive && self.settings.get_bool_or("chat.enableNotifications", false) { + if self.interactive + && database + .settings + .get_bool(Setting::ChatEnableNotifications) + .unwrap_or(false) + { // For final responses (no tools suggested), always play the bell play_notification_bell(tool_uses.is_empty()); } @@ -3241,7 +3258,11 @@ impl ChatContext { } } - async fn validate_tools(&mut self, tool_uses: Vec) -> Result { + async fn validate_tools( + &mut self, + telemetry: &TelemetryThread, + tool_uses: Vec, + ) -> Result { let conv_id = self.conversation_state.conversation_id().to_owned(); debug!(?tool_uses, "Validating tool uses"); let mut queued_tools: Vec = Vec::new(); @@ -3319,7 +3340,7 @@ impl ChatContext { } } self.conversation_state.add_tool_results(tool_results); - self.send_tool_use_telemetry().await; + self.send_tool_use_telemetry(telemetry).await; if let ToolUseStatus::Idle = self.tool_use_status { self.tool_use_status = ToolUseStatus::RetryInProgress( self.conversation_state @@ -3433,7 +3454,7 @@ impl ChatContext { prompt::generate_prompt(self.conversation_state.current_profile(), self.all_tools_trusted()) } - async fn send_tool_use_telemetry(&mut self) { + async fn send_tool_use_telemetry(&mut self, telemetry: &TelemetryThread) { for (_, mut event) in self.tool_use_telemetry_events.drain() { event.user_input_id = match self.tool_use_status { ToolUseStatus::Idle => self.conversation_state.message_id(), @@ -3441,10 +3462,7 @@ impl ChatContext { } .map(|v| v.to_string()); - crate::telemetry::client() - .await - .send_event(Event::new(event.into())) - .await; + telemetry.send_tool_use_suggested(event).ok(); } } @@ -3525,75 +3543,6 @@ fn print_hook_section(output: &mut impl Write, hooks: &HashMap, tr Ok(()) } -#[derive(Debug)] -struct ToolUseEventBuilder { - pub conversation_id: String, - pub utterance_id: Option, - pub user_input_id: Option, - pub tool_use_id: Option, - pub tool_name: Option, - pub is_accepted: bool, - pub is_success: Option, - pub is_valid: Option, - pub is_custom_tool: bool, - pub input_token_size: Option, - pub output_token_size: Option, - pub custom_tool_call_latency: Option, -} - -impl ToolUseEventBuilder { - pub fn new(conv_id: String, tool_use_id: String) -> Self { - Self { - conversation_id: conv_id, - utterance_id: None, - user_input_id: None, - tool_use_id: Some(tool_use_id), - tool_name: None, - is_accepted: false, - is_success: None, - is_valid: None, - is_custom_tool: false, - input_token_size: None, - output_token_size: None, - custom_tool_call_latency: None, - } - } - - pub fn utterance_id(mut self, id: Option) -> Self { - self.utterance_id = id; - self - } - - pub fn set_tool_use_id(mut self, id: String) -> Self { - self.tool_use_id.replace(id); - self - } - - pub fn set_tool_name(mut self, name: String) -> Self { - self.tool_name.replace(name); - self - } -} - -impl From for crate::telemetry::EventType { - fn from(val: ToolUseEventBuilder) -> Self { - crate::telemetry::EventType::ToolUseSuggested { - conversation_id: val.conversation_id, - utterance_id: val.utterance_id, - user_input_id: val.user_input_id, - tool_use_id: val.tool_use_id, - tool_name: val.tool_name, - is_accepted: val.is_accepted, - is_success: val.is_success, - is_valid: val.is_valid, - is_custom_tool: val.is_custom_tool, - input_token_size: val.input_token_size, - output_token_size: val.output_token_size, - custom_tool_call_latency: val.custom_tool_call_latency, - } - } -} - /// Testing helper fn split_tool_use_event(value: &Map) -> Vec { let tool_use_id = value.get("tool_use_id").unwrap().as_str().unwrap().to_string(); @@ -3654,6 +3603,7 @@ fn create_stream(model_responses: serde_json::Value) -> StreamingClient { #[cfg(test)] mod tests { use super::*; + use crate::platform::Env; #[tokio::test] async fn test_flow() { @@ -3677,14 +3627,16 @@ mod tests { ], ])); + let env = Env::new(); + let mut database = Database::new().await.unwrap(); + let telemetry = TelemetryThread::new(&env, &mut database).await.unwrap(); + let tool_manager = ToolManager::default(); let tool_config = serde_json::from_str::>(include_str!("tools/tool_index.json")) .expect("Tools failed to load"); ChatContext::new( Arc::clone(&ctx), "fake_conv_id", - Settings::new(), - State::new(), SharedWriter::stdout(), None, InputSource::new_mock(vec![ @@ -3702,7 +3654,7 @@ mod tests { ) .await .unwrap() - .try_chat() + .try_chat(&mut database, &telemetry) .await .unwrap(); @@ -3806,14 +3758,16 @@ mod tests { ], ])); + let env = Env::new(); + let mut database = Database::new().await.unwrap(); + let telemetry = TelemetryThread::new(&env, &mut database).await.unwrap(); + let tool_manager = ToolManager::default(); let tool_config = serde_json::from_str::>(include_str!("tools/tool_index.json")) .expect("Tools failed to load"); ChatContext::new( Arc::clone(&ctx), "fake_conv_id", - Settings::new(), - State::new(), SharedWriter::stdout(), None, InputSource::new_mock(vec![ @@ -3844,7 +3798,7 @@ mod tests { ) .await .unwrap() - .try_chat() + .try_chat(&mut database, &telemetry) .await .unwrap(); @@ -3910,14 +3864,16 @@ mod tests { ], ])); + let env = Env::new(); + let mut database = Database::new().await.unwrap(); + let telemetry = TelemetryThread::new(&env, &mut database).await.unwrap(); + let tool_manager = ToolManager::default(); let tool_config = serde_json::from_str::>(include_str!("tools/tool_index.json")) .expect("Tools failed to load"); ChatContext::new( Arc::clone(&ctx), "fake_conv_id", - Settings::new(), - State::new(), SharedWriter::stdout(), None, InputSource::new_mock(vec![ @@ -3939,7 +3895,7 @@ mod tests { ) .await .unwrap() - .try_chat() + .try_chat(&mut database, &telemetry) .await .unwrap(); @@ -3986,14 +3942,16 @@ mod tests { ], ])); + let env = Env::new(); + let mut database = Database::new().await.unwrap(); + let telemetry = TelemetryThread::new(&env, &mut database).await.unwrap(); + let tool_manager = ToolManager::default(); let tool_config = serde_json::from_str::>(include_str!("tools/tool_index.json")) .expect("Tools failed to load"); ChatContext::new( Arc::clone(&ctx), "fake_conv_id", - Settings::new(), - State::new(), SharedWriter::stdout(), None, InputSource::new_mock(vec![ @@ -4013,7 +3971,7 @@ mod tests { ) .await .unwrap() - .try_chat() + .try_chat(&mut database, &telemetry) .await .unwrap(); diff --git a/crates/chat-cli/src/cli/chat/prompt.rs b/crates/chat-cli/src/cli/chat/prompt.rs index b936555660..782731b2ce 100644 --- a/crates/chat-cli/src/cli/chat/prompt.rs +++ b/crates/chat-cli/src/cli/chat/prompt.rs @@ -35,6 +35,9 @@ use rustyline::{ }; use winnow::stream::AsChar; +use crate::database::Database; +use crate::database::settings::Setting; + pub const COMMANDS: &[&str] = &[ "/clear", "/help", @@ -264,10 +267,11 @@ impl Highlighter for ChatHelper { } pub fn rl( + database: &Database, sender: std::sync::mpsc::Sender>, receiver: std::sync::mpsc::Receiver>, ) -> Result> { - let edit_mode = match crate::settings::settings::get_string_opt("chat.editMode").as_deref() { + let edit_mode = match database.settings.get_string(Setting::ChatEditMode).as_deref() { Some("vi" | "vim") => EditMode::Vi, _ => EditMode::Vi, }; diff --git a/crates/chat-cli/src/cli/chat/tool_manager.rs b/crates/chat-cli/src/cli/chat/tool_manager.rs index cb445d3876..e7bc910ac9 100644 --- a/crates/chat-cli/src/cli/chat/tool_manager.rs +++ b/crates/chat-cli/src/cli/chat/tool_manager.rs @@ -54,11 +54,12 @@ use crate::api_client::model::{ ToolResultContentBlock, ToolResultStatus, }; +use crate::database::Database; use crate::mcp_client::{ JsonRpcResponse, PromptGet, }; -use crate::telemetry::send_mcp_server_init; +use crate::telemetry::TelemetryThread; const NAMESPACE_DELIMITER: &str = "___"; // This applies for both mcp server and tool name since in the end the tool name as seen by the @@ -204,7 +205,7 @@ impl ToolManagerBuilder { self } - pub async fn build(mut self) -> eyre::Result { + pub async fn build(mut self, telemetry: &TelemetryThread) -> eyre::Result { let McpServerConfig { mcp_servers } = self.mcp_server_config.ok_or(eyre::eyre!("Missing mcp server config"))?; debug_assert!(self.conversation_id.is_some()); let conversation_id = self.conversation_id.ok_or(eyre::eyre!("Missing conversation id"))?; @@ -324,7 +325,9 @@ impl ToolManagerBuilder { }, Err(e) => { error!("Error initializing mcp client for server {}: {:?}", name, &e); - send_mcp_server_init(conversation_id.clone(), Some(e.to_string()), 0).await; + telemetry + .send_mcp_server_init(conversation_id.clone(), Some(e.to_string()), 0) + .ok(); let _ = tx.send(LoadingMsg::Error { name: name.clone(), @@ -498,23 +501,29 @@ pub struct ToolManager { } impl ToolManager { - pub async fn load_tools(&mut self) -> eyre::Result> { + pub async fn load_tools( + &mut self, + database: &Database, + telemetry: &TelemetryThread, + ) -> eyre::Result> { let tx = self.loading_status_sender.take(); let display_task = self.loading_display_task.take(); let tool_specs = { let mut tool_specs = serde_json::from_str::>(include_str!("tools/tool_index.json"))?; - if !crate::cli::chat::tools::thinking::Thinking::is_enabled() { - tool_specs.remove("q_think_tool"); + if !crate::cli::chat::tools::thinking::Thinking::is_enabled(database) { + tool_specs.remove("thinking"); } Arc::new(Mutex::new(tool_specs)) }; let conversation_id = self.conversation_id.clone(); let regex = Arc::new(regex::Regex::new(VALID_TOOL_NAME)?); + let load_tool = self .clients .iter() .map(|(server_name, client)| { + let telemetry = telemetry.clone(); let client_clone = client.clone(); let server_name_clone = server_name.clone(); let tx_clone = tx.clone(); @@ -568,7 +577,7 @@ impl ToolManager { } // Send server load success metric datum - send_mcp_server_init(conversation_id, None, number_of_tools).await; + telemetry.send_mcp_server_init(conversation_id, None, number_of_tools).ok(); // Tool name translation. This is beyond of the scope of what is // considered a "server load". Reasoning being: @@ -614,7 +623,7 @@ impl ToolManager { Err(e) => { error!("Error obtaining tool spec for {}: {:?}", server_name_clone, e); let init_failure_reason = Some(e.to_string()); - send_mcp_server_init(conversation_id, init_failure_reason, 0).await; + telemetry.send_mcp_server_init(conversation_id, init_failure_reason, 0).ok(); if let Some(tx_clone) = &tx_clone { if let Err(e) = tx_clone.send(LoadingMsg::Error { name: server_name_clone, @@ -677,7 +686,7 @@ impl ToolManager { "execute_bash" => Tool::ExecuteBash(serde_json::from_value::(value.args).map_err(map_err)?), "use_aws" => Tool::UseAws(serde_json::from_value::(value.args).map_err(map_err)?), "report_issue" => Tool::GhIssue(serde_json::from_value::(value.args).map_err(map_err)?), - "q_think_tool" => Tool::Thinking(serde_json::from_value::(value.args).map_err(map_err)?), + "thinking" => Tool::Thinking(serde_json::from_value::(value.args).map_err(map_err)?), // Note that this name is namespaced with server_name{DELIMITER}tool_name name => { // Note: tn_map also has tools that underwent no transformation. In otherwords, if diff --git a/crates/chat-cli/src/cli/chat/tools/thinking.rs b/crates/chat-cli/src/cli/chat/tools/thinking.rs index d6d9884b0c..b6fe099e4d 100644 --- a/crates/chat-cli/src/cli/chat/tools/thinking.rs +++ b/crates/chat-cli/src/cli/chat/tools/thinking.rs @@ -12,7 +12,8 @@ use super::{ InvokeOutput, OutputKind, }; -use crate::settings::settings; +use crate::database::Database; +use crate::database::settings::Setting; /// The Think tool allows the model to reason through complex problems during response generation. /// It provides a dedicated space for the model to process information from tool call results, @@ -28,9 +29,8 @@ pub struct Thinking { impl Thinking { /// Checks if the thinking feature is enabled in settings - pub fn is_enabled() -> bool { - // Default to enabled if setting doesn't exist or can't be read - settings::get_bool_or("chat.enableThinking", true) + pub fn is_enabled(database: &Database) -> bool { + database.settings.get_bool(Setting::EnabledThinking).unwrap_or(true) } /// Queues up a description of the think tool for the user diff --git a/crates/chat-cli/src/cli/chat/util/images.rs b/crates/chat-cli/src/cli/chat/util/images.rs index e3aa8ca9ab..64528256cd 100644 --- a/crates/chat-cli/src/cli/chat/util/images.rs +++ b/crates/chat-cli/src/cli/chat/util/images.rs @@ -266,7 +266,7 @@ mod tests { fn test_handle_images_size_limit_exceeded() { let temp_dir = tempfile::tempdir().unwrap(); let large_image_path = temp_dir.path().join("large_image.jpg"); - let large_image_size = MAX_IMAGE_SIZE as usize + 1; + let large_image_size = MAX_IMAGE_SIZE + 1; std::fs::write(&large_image_path, vec![0; large_image_size]).unwrap(); let buf = Arc::new(std::sync::Mutex::new(Vec::::new())); let test_writer = TestWriterWithSink { sink: buf.clone() }; diff --git a/crates/chat-cli/src/cli/debug.rs b/crates/chat-cli/src/cli/debug.rs index 1cf14e0fb3..9ef75501c0 100644 --- a/crates/chat-cli/src/cli/debug.rs +++ b/crates/chat-cli/src/cli/debug.rs @@ -1,11 +1,7 @@ -use std::process::ExitCode; - -use anstream::eprintln; use clap::{ Subcommand, ValueEnum, }; -use eyre::Result; #[derive(Debug, ValueEnum, Clone, PartialEq, Eq)] pub enum Build { @@ -87,23 +83,3 @@ pub enum InputMethodDebugAction { action: TISAction, }, } - -#[derive(Debug, PartialEq, Subcommand)] -pub enum DebugSubcommand { - RefreshAuthToken, -} - -impl DebugSubcommand { - pub async fn execute(&self) -> Result { - match self { - DebugSubcommand::RefreshAuthToken => match crate::auth::refresh_token().await? { - Some(_) => eprintln!("Refreshed token"), - None => { - eprintln!("No token to refresh"); - return Ok(ExitCode::FAILURE); - }, - }, - } - Ok(ExitCode::SUCCESS) - } -} diff --git a/crates/chat-cli/src/cli/mod.rs b/crates/chat-cli/src/cli/mod.rs index 73c12fc869..be01b61b9c 100644 --- a/crates/chat-cli/src/cli/mod.rs +++ b/crates/chat-cli/src/cli/mod.rs @@ -4,7 +4,6 @@ mod diagnostics; mod feed; mod issue; mod settings; -mod telemetry; mod user; use std::io::{ @@ -33,13 +32,12 @@ use tracing::{ Level, debug, }; +use user::UserSubcommand; -use self::user::RootUserSubcommand; use crate::logging::{ LogArgs, initialize_logging, }; -use crate::telemetry::send_cli_subcommand_executed; use crate::util::directories::logs_dir; use crate::util::{ CHAT_BINARY_NAME, @@ -83,9 +81,6 @@ pub enum Processes { #[deny(missing_docs)] #[derive(Debug, PartialEq, Subcommand)] pub enum CliRootCommands { - /// Debug the app - #[command(subcommand)] - Debug(debug::DebugSubcommand), /// Customize appearance & behavior #[command(alias("setting"))] Settings(settings::SettingsArgs), @@ -94,15 +89,9 @@ pub enum CliRootCommands { Diagnostic(diagnostics::DiagnosticArgs), /// Create a new Github issue Issue(issue::IssueArgs), - /// Root level user subcommands + /// User subcommands #[command(flatten)] - RootUser(user::RootUserSubcommand), - /// Manage your account - #[command(subcommand)] User(user::UserSubcommand), - /// Enable/disable telemetry - #[command(subcommand, hide = true)] - Telemetry(telemetry::TelemetrySubcommand), /// Version #[command(hide = true)] Version { @@ -117,18 +106,15 @@ pub enum CliRootCommands { } impl CliRootCommands { - fn name(&self) -> &'static str { + pub fn name(&self) -> &'static str { match self { - CliRootCommands::Debug(_) => "debug", CliRootCommands::Settings(_) => "settings", CliRootCommands::Diagnostic(_) => "diagnostics", CliRootCommands::Issue(_) => "issue", - CliRootCommands::RootUser(RootUserSubcommand::Login(_)) => "login", - CliRootCommands::RootUser(RootUserSubcommand::Logout) => "logout", - CliRootCommands::RootUser(RootUserSubcommand::Whoami { .. }) => "whoami", - CliRootCommands::RootUser(RootUserSubcommand::Profile) => "profile", - CliRootCommands::User(_) => "user", - CliRootCommands::Telemetry(_) => "telemetry", + CliRootCommands::User(UserSubcommand::Login(_)) => "login", + CliRootCommands::User(UserSubcommand::Logout) => "logout", + CliRootCommands::User(UserSubcommand::Whoami { .. }) => "whoami", + CliRootCommands::User(UserSubcommand::Profile) => "profile", CliRootCommands::Version { .. } => "version", CliRootCommands::Chat { .. } => "chat", } @@ -191,9 +177,19 @@ impl Cli { delete_old_log_file: false, }); - debug!(command =? std::env::args().collect::>(), "Command ran"); + debug!(command =? std::env::args().collect::>(), "Command being ran"); + + let env = crate::platform::Env::new(); + let mut database = crate::database::Database::new().await?; + let telemetry = crate::telemetry::TelemetryThread::new(&env, &mut database).await?; - self.send_telemetry().await; + let _ = match &self.subcommand { + None => telemetry.send_cli_subcommand_executed(None), + Some(subcommand) if ["diagnostic", "version"].contains(&subcommand.name()) => { + telemetry.send_cli_subcommand_executed(Some(subcommand)) + }, + _ => Ok(()), + }; if self.help_all { return Self::print_help_all(); @@ -201,30 +197,24 @@ impl Cli { let cli_context = CliContext::new(); - match self.subcommand { + let result = match self.subcommand { Some(subcommand) => match subcommand { CliRootCommands::Diagnostic(args) => args.execute().await, - CliRootCommands::User(user) => user.execute().await, - CliRootCommands::RootUser(root_user) => root_user.execute().await, - CliRootCommands::Settings(settings_args) => settings_args.execute(&cli_context).await, - CliRootCommands::Debug(debug_subcommand) => debug_subcommand.execute().await, + CliRootCommands::User(user) => user.execute(&mut database, &telemetry).await, + CliRootCommands::Settings(settings_args) => settings_args.execute(&mut database, &cli_context).await, CliRootCommands::Issue(args) => args.execute().await, - CliRootCommands::Telemetry(subcommand) => subcommand.execute().await, CliRootCommands::Version { changelog } => Self::print_version(changelog), - CliRootCommands::Chat(args) => chat::launch_chat(args).await, + CliRootCommands::Chat(args) => chat::launch_chat(&mut database, &telemetry, args).await, }, // Root command - None => chat::launch_chat(chat::cli::Chat::default()).await, - } - } + None => chat::launch_chat(&mut database, &telemetry, chat::cli::Chat::default()).await, + }; - async fn send_telemetry(&self) { - match &self.subcommand { - None => {}, - Some(subcommand) => { - send_cli_subcommand_executed(subcommand.name()).await; - }, - } + let telemetry_result = telemetry.finish().await; + + let exit_code = result?; + telemetry_result?; + Ok(exit_code) } fn print_help_all() -> Result { diff --git a/crates/chat-cli/src/cli/settings.rs b/crates/chat-cli/src/cli/settings.rs index 67b438dbb3..83196908ff 100644 --- a/crates/chat-cli/src/cli/settings.rs +++ b/crates/chat-cli/src/cli/settings.rs @@ -15,7 +15,8 @@ use globset::Glob; use serde_json::json; use super::OutputFormat; -use crate::settings::JsonStore; +use crate::database::Database; +use crate::database::settings::Setting; use crate::util::{ CliContext, directories, @@ -30,6 +31,9 @@ pub enum SettingsSubcommands { /// Format of the output #[arg(long, short, value_enum, default_value_t)] format: OutputFormat, + /// Whether or not we want to modify state instead + #[arg(long, short, hide = true)] + state: bool, }, } @@ -53,7 +57,7 @@ pub struct SettingsArgs { } impl SettingsArgs { - pub async fn execute(&self, cli_context: &CliContext) -> Result { + pub async fn execute(&self, database: &mut Database, cli_context: &CliContext) -> Result { match self.cmd { Some(SettingsSubcommands::Open) => { let file = directories::settings_path().context("Could not get settings path")?; @@ -64,8 +68,11 @@ impl SettingsArgs { bail!("The EDITOR environment variable is not set") } }, - Some(SettingsSubcommands::All { format }) => { - let settings = crate::settings::OldSettings::load()?.map().clone(); + Some(SettingsSubcommands::All { format, state }) => { + let settings = match state { + true => database.get_all_entries()?, + false => database.settings.map().clone(), + }; match format { OutputFormat::Plain => { @@ -86,8 +93,9 @@ impl SettingsArgs { return Ok(ExitCode::SUCCESS); }; + let key = Setting::try_from(key.as_str())?; match (&self.value, self.delete) { - (None, false) => match crate::settings::settings::get_value(key)? { + (None, false) => match database.settings.get(key) { Some(value) => { match self.format { OutputFormat::Plain => match value.as_str() { @@ -109,14 +117,15 @@ impl SettingsArgs { }, (Some(value_str), false) => { let value = serde_json::from_str(value_str).unwrap_or_else(|_| json!(value_str)); - crate::settings::settings::set_value(key, value)?; + database.settings.set(key, value).await?; Ok(ExitCode::SUCCESS) }, (None, true) => { - let glob = Glob::new(key).context("Could not create glob")?.compile_matcher(); - let settings = crate::settings::OldSettings::load()?; - let map = settings.map(); - let keys_to_remove = map.keys().filter(|key| glob.is_match(key)).collect::>(); + let glob = Glob::new(key.as_ref()) + .context("Could not create glob")? + .compile_matcher(); + let map = database.settings.map(); + let keys_to_remove = map.keys().filter(|key| glob.is_match(key)).cloned().collect::>(); match keys_to_remove.len() { 0 => { @@ -124,16 +133,17 @@ impl SettingsArgs { }, 1 => { println!("Removing {:?}", keys_to_remove[0]); - crate::settings::settings::remove_value(keys_to_remove[0])?; + database + .settings + .remove(Setting::try_from(keys_to_remove[0].as_str())?) + .await?; }, _ => { - println!("Removing:"); - for key in &keys_to_remove { - println!(" - {key}"); - } - for key in &keys_to_remove { - crate::settings::settings::remove_value(key)?; + if let Ok(key) = Setting::try_from(key.as_str()) { + println!("Removing `{key}`"); + database.settings.remove(key).await?; + } } }, } diff --git a/crates/chat-cli/src/cli/telemetry.rs b/crates/chat-cli/src/cli/telemetry.rs deleted file mode 100644 index 34b7dab895..0000000000 --- a/crates/chat-cli/src/cli/telemetry.rs +++ /dev/null @@ -1,53 +0,0 @@ -use std::process::ExitCode; - -use clap::Subcommand; -use crossterm::style::Stylize; -use eyre::Result; -use serde_json::json; - -use super::OutputFormat; - -const TELEMETRY_ENABLED_KEY: &str = "telemetry.enabled"; - -#[derive(Debug, PartialEq, Eq, Subcommand)] -pub enum TelemetrySubcommand { - Enable, - Disable, - Status { - /// Format of the output - #[arg(long, short, value_enum, default_value_t)] - format: OutputFormat, - }, -} - -impl TelemetrySubcommand { - pub async fn execute(&self) -> Result { - match self { - TelemetrySubcommand::Enable => { - crate::settings::settings::set_value(TELEMETRY_ENABLED_KEY, true)?; - Ok(ExitCode::SUCCESS) - }, - TelemetrySubcommand::Disable => { - crate::settings::settings::set_value(TELEMETRY_ENABLED_KEY, false)?; - Ok(ExitCode::SUCCESS) - }, - TelemetrySubcommand::Status { format } => { - let status = crate::settings::settings::get_bool_or(TELEMETRY_ENABLED_KEY, true); - format.print( - || { - format!( - "Telemetry status: {}", - if status { "enabled" } else { "disabled" }.bold() - ) - }, - || { - json!({ - TELEMETRY_ENABLED_KEY: status, - }) - }, - ); - Ok(ExitCode::SUCCESS) - }, - } - } -} diff --git a/crates/chat-cli/src/cli/uninstall.rs b/crates/chat-cli/src/cli/uninstall.rs deleted file mode 100644 index a75ba88401..0000000000 --- a/crates/chat-cli/src/cli/uninstall.rs +++ /dev/null @@ -1,54 +0,0 @@ -use std::process::ExitCode; - -use anstream::println; -use crossterm::style::Stylize; -use eyre::Result; - -use crate::util::{ - CLI_BINARY_NAME, - PRODUCT_NAME, - dialoguer_theme, -}; - -pub async fn uninstall_command(no_confirm: bool) -> Result { - if !no_confirm { - println!( - "\nIs {PRODUCT_NAME} not working? Try running {}\n", - format!("{CLI_BINARY_NAME} doctor").bold().magenta() - ); - let should_continue = dialoguer::Select::with_theme(&dialoguer_theme()) - .with_prompt(format!("Are you sure want to continue uninstalling {PRODUCT_NAME}?")) - .items(&["Yes", "No"]) - .default(0) - .interact_opt()?; - - if should_continue == Some(0) { - println!("Uninstalling {PRODUCT_NAME}"); - } else { - println!("Cancelled"); - return Ok(ExitCode::FAILURE); - } - }; - - cfg_if::cfg_if! { - if #[cfg(target_os = "macos")] { - uninstall().await?; - } else if #[cfg(target_os = "linux")] { - (); - } - } - - Ok(ExitCode::SUCCESS) -} - -#[cfg(target_os = "macos")] -async fn uninstall() -> Result<()> { - crate::auth::logout().await.ok(); - crate::install::uninstall().await?; - Ok(()) -} - -#[cfg(all(unix, not(any(target_os = "macos", target_os = "linux"))))] -async fn uninstall() -> Result<()> { - eyre::bail!("Guided uninstallation is not supported on this platform. Please uninstall manually."); -} diff --git a/crates/chat-cli/src/cli/update.rs b/crates/chat-cli/src/cli/update.rs deleted file mode 100644 index 5fc49feb2b..0000000000 --- a/crates/chat-cli/src/cli/update.rs +++ /dev/null @@ -1,54 +0,0 @@ -use std::process::ExitCode; - -use anstream::println; -use clap::Args; -use crossterm::style::Stylize; -use eyre::Result; - -use crate::util::CLI_BINARY_NAME; - -#[derive(Debug, PartialEq, Args)] -pub struct UpdateArgs { - /// Don't prompt for confirmation - #[arg(long, short = 'y')] - non_interactive: bool, - /// Relaunch into dashboard after update (false will launch in background) - #[arg(long, default_value = "true")] - relaunch_dashboard: bool, - /// Uses rollout - #[arg(long)] - rollout: bool, -} - -impl UpdateArgs { - pub async fn execute(&self) -> Result { - todo!(); - - // let res = self_update::backends::s3::Update::configure() - // .bucket_name("self_update_releases") - // .asset_prefix("something/self_update") - // .region("eu-west-2") - // .bin_name("self_update_example") - // .show_download_progress(true) - // .current_version(cargo_crate_version!()) - // .build()? - // .update(); - // - // match res { - // Ok(Status::UpToDate(_)) => { - // println!( - // "No updates available, \n{} is the latest version.", - // env!("CARGO_PKG_VERSION").bold() - // ); - // Ok(ExitCode::SUCCESS) - // }, - // Ok(Status::Updated(_)) => Ok(ExitCode::SUCCESS), - // Err(err) => { - // eyre::bail!( - // "{err}\n\nIf this is unexpected, try running {} and then try again.\n", - // format!("{CLI_BINARY_NAME} doctor").bold() - // ) - // }, - // } - } -} diff --git a/crates/chat-cli/src/cli/user.rs b/crates/chat-cli/src/cli/user.rs index 4d291f5c71..a661491708 100644 --- a/crates/chat-cli/src/cli/user.rs +++ b/crates/chat-cli/src/cli/user.rs @@ -29,18 +29,19 @@ use tracing::{ use super::OutputFormat; use crate::api_client::list_available_profiles; -use crate::api_client::profile::Profile; use crate::auth::builder_id::{ + BuilderIdToken, PollCreateToken, TokenType, poll_create_token, start_device_authorization, }; use crate::auth::pkce::start_pkce_authorization; -use crate::auth::secret_store::SecretStore; +use crate::database::Database; use crate::telemetry::{ QProfileSwitchIntent, TelemetryResult, + TelemetryThread, }; use crate::util::spinner::{ Spinner, @@ -54,22 +55,6 @@ use crate::util::{ input, }; -#[derive(Subcommand, Debug, PartialEq, Eq)] -pub enum RootUserSubcommand { - /// Login - Login(LoginArgs), - /// Logout - Logout, - /// Prints details about the current user - Whoami { - /// Output format to use - #[arg(long, short, value_enum, default_value_t)] - format: OutputFormat, - }, - /// Show the profile associated with this idc user - Profile, -} - #[derive(Args, Debug, PartialEq, Eq, Clone, Default)] pub struct LoginArgs { /// License type (pro for Identity Center, free for Builder ID) @@ -115,23 +100,39 @@ impl Display for AuthMethod { } } -impl RootUserSubcommand { - pub async fn execute(self) -> Result { +#[derive(Subcommand, Debug, PartialEq, Eq)] +pub enum UserSubcommand { + /// Login + Login(LoginArgs), + /// Logout + Logout, + /// Prints details about the current user + Whoami { + /// Output format to use + #[arg(long, short, value_enum, default_value_t)] + format: OutputFormat, + }, + /// Show the profile associated with this idc user + Profile, +} + +impl UserSubcommand { + pub async fn execute(self, database: &mut Database, telemetry: &TelemetryThread) -> Result { match self { Self::Login(args) => { - if crate::auth::is_logged_in().await { + if crate::auth::is_logged_in(database).await { eyre::bail!( "Already logged in, please logout with {} first", format!("{CHAT_BINARY_NAME} logout").magenta() ); } - login_interactive(args).await?; + login_interactive(database, telemetry, args).await?; Ok(ExitCode::SUCCESS) }, Self::Logout => { - let _ = crate::auth::logout().await; + let _ = crate::auth::logout(database).await; println!("You are now logged out"); println!( @@ -141,7 +142,7 @@ impl RootUserSubcommand { Ok(ExitCode::SUCCESS) }, Self::Whoami { format } => { - let builder_id = crate::auth::builder_id_token().await; + let builder_id = BuilderIdToken::load(database).await; match builder_id { Ok(Some(token)) => { @@ -168,9 +169,7 @@ impl RootUserSubcommand { ); if matches!(token.token_type(), TokenType::IamIdentityCenter) { - if let Ok(Some(profile)) = crate::settings::state::get::( - "api.codewhisperer.profile", - ) { + if let Ok(Some(profile)) = database.get_auth_profile() { color_print::cprintln!( "\nProfile:\n{}\n{}\n", profile.profile_name, @@ -187,20 +186,17 @@ impl RootUserSubcommand { } }, Self::Profile => { - if !crate::util::system_info::in_cloudshell() && !crate::auth::is_logged_in().await { - bail!( - "You are not logged in, please log in with {}", - format!("{CHAT_BINARY_NAME} login",).bold() - ); + if !crate::util::system_info::in_cloudshell() && !crate::auth::is_logged_in(database).await { + bail!("You are not logged in, please log in with {}", "q login".bold()); } - if let Ok(Some(token)) = crate::auth::builder_id_token().await { + if let Ok(Some(token)) = BuilderIdToken::load(database).await { if matches!(token.token_type(), TokenType::BuilderId) { bail!("This command is only available for Pro users"); } } - select_profile_interactive(false).await?; + select_profile_interactive(database, telemetry, false).await?; Ok(ExitCode::SUCCESS) }, @@ -208,21 +204,7 @@ impl RootUserSubcommand { } } -#[derive(Subcommand, Debug, PartialEq, Eq)] -pub enum UserSubcommand { - #[command(flatten)] - Root(RootUserSubcommand), -} - -impl UserSubcommand { - pub async fn execute(self) -> Result { - match self { - Self::Root(cmd) => cmd.execute().await, - } - } -} - -pub async fn login_interactive(args: LoginArgs) -> Result<()> { +pub async fn login_interactive(database: &mut Database, telemetry: &TelemetryThread, args: LoginArgs) -> Result<()> { let login_method = match args.license { Some(LicenseType::Free) => AuthMethod::BuilderId, Some(LicenseType::Pro) => AuthMethod::IdentityCenter, @@ -248,28 +230,29 @@ pub async fn login_interactive(args: LoginArgs) -> Result<()> { let (start_url, region) = match login_method { AuthMethod::BuilderId => (None, None), AuthMethod::IdentityCenter => { - let default_start_url = args - .identity_provider - .or_else(|| crate::settings::state::get_string("auth.idc.start-url").ok().flatten()); - let default_region = args - .region - .or_else(|| crate::settings::state::get_string("auth.idc.region").ok().flatten()); + let default_start_url = match args.identity_provider { + Some(start_url) => Some(start_url), + None => database.get_start_url()?, + }; + let default_region = match args.region { + Some(region) => Some(region), + None => database.get_idc_region()?, + }; let start_url = input("Enter Start URL", default_start_url.as_deref())?; let region = input("Enter Region", default_region.as_deref())?; - let _ = crate::settings::state::set_value("auth.idc.start-url", start_url.clone()); - let _ = crate::settings::state::set_value("auth.idc.region", region.clone()); + let _ = database.set_start_url(start_url.clone()); + let _ = database.set_idc_region(region.clone()); (Some(start_url), Some(region)) }, }; - let secret_store = SecretStore::new().await?; // Remote machine won't be able to handle browser opening and redirects, // hence always use device code flow. if is_remote() || args.use_device_flow { - try_device_authorization(&secret_store, start_url.clone(), region.clone()).await?; + try_device_authorization(database, telemetry, start_url.clone(), region.clone()).await?; } else { let (client, registration) = start_pkce_authorization(start_url.clone(), region.clone()).await?; @@ -282,14 +265,14 @@ pub async fn login_interactive(args: LoginArgs) -> Result<()> { ]); let ctrl_c_stream = ctrl_c(); tokio::select! { - res = registration.finish(&client, Some(&secret_store)) => res?, + res = registration.finish(&client, Some(database)) => res?, Ok(_) = ctrl_c_stream => { #[allow(clippy::exit)] exit(1); }, } - crate::telemetry::send_user_logged_in().await; - spinner.stop_with_message("Device authorized".into()); + telemetry.send_user_logged_in().ok(); + spinner.stop_with_message("Logged in".into()); }, // If we are unable to open the link with the browser, then fallback to // the device code flow. @@ -297,7 +280,7 @@ pub async fn login_interactive(args: LoginArgs) -> Result<()> { error!(%err, "Failed to open URL with browser, falling back to device code flow"); // Try device code flow. - try_device_authorization(&secret_store, start_url.clone(), region.clone()).await?; + try_device_authorization(database, telemetry, start_url.clone(), region.clone()).await?; }, } } @@ -305,20 +288,19 @@ pub async fn login_interactive(args: LoginArgs) -> Result<()> { }; if login_method == AuthMethod::IdentityCenter { - select_profile_interactive(true).await?; + select_profile_interactive(database, telemetry, true).await?; } - eprintln!("Logged in successfully"); - Ok(()) } async fn try_device_authorization( - secret_store: &SecretStore, + database: &mut Database, + telemetry: &TelemetryThread, start_url: Option, region: Option, ) -> Result<()> { - let device_auth = start_device_authorization(secret_store, start_url.clone(), region.clone()).await?; + let device_auth = start_device_authorization(database, start_url.clone(), region.clone()).await?; println!(); println!("Confirm the following code in the browser"); @@ -349,7 +331,7 @@ async fn try_device_authorization( } } match poll_create_token( - secret_store, + database, device_auth.device_code.clone(), start_url.clone(), region.clone(), @@ -358,8 +340,8 @@ async fn try_device_authorization( { PollCreateToken::Pending => {}, PollCreateToken::Complete => { - crate::telemetry::send_user_logged_in().await; - spinner.stop_with_message("Device authorized".into()); + telemetry.send_user_logged_in().ok(); + spinner.stop_with_message("Logged in".into()); break; }, PollCreateToken::Error(err) => { @@ -371,42 +353,42 @@ async fn try_device_authorization( Ok(()) } -async fn select_profile_interactive(whoami: bool) -> Result<()> { +async fn select_profile_interactive(database: &mut Database, telemetry: &TelemetryThread, whoami: bool) -> Result<()> { let mut spinner = Spinner::new(vec![ SpinnerComponent::Spinner, SpinnerComponent::Text(" Fetching profiles...".into()), ]); - let profiles = list_available_profiles().await; + let profiles = list_available_profiles(database).await?; if profiles.is_empty() { info!("Available profiles was empty"); return Ok(()); } - let sso_region: Option = crate::settings::state::get_string("auth.idc.region").ok().flatten(); + let sso_region = database.get_idc_region()?; let total_profiles = profiles.len() as i64; if whoami && profiles.len() == 1 { if let Some(profile_region) = profiles[0].arn.split(':').nth(3) { - crate::telemetry::send_profile_state( - QProfileSwitchIntent::Update, - profile_region.to_string(), - TelemetryResult::Succeeded, - sso_region, - ) - .await; + telemetry + .send_profile_state( + QProfileSwitchIntent::Update, + profile_region.to_string(), + TelemetryResult::Succeeded, + sso_region, + ) + .ok(); } + spinner.stop_with_message(String::new()); - return Ok(crate::settings::state::set_value( - "api.codewhisperer.profile", - serde_json::to_value(&profiles[0])?, - )?); + database.set_auth_profile(&profiles[0])?; + return Ok(()); } let mut items: Vec = profiles .iter() .map(|p| format!("{} (arn: {})", p.profile_name, p.arn)) .collect(); - let active_profile: Option = crate::settings::state::get("api.codewhisperer.profile")?; + let active_profile = database.get_auth_profile()?; if let Some(default_idx) = active_profile .as_ref() @@ -425,10 +407,8 @@ async fn select_profile_interactive(whoami: bool) -> Result<()> { match selected { Some(i) => { let chosen = &profiles[i]; - let profile = serde_json::to_value(chosen)?; - eprintln!("Set profile: {}\n", chosen.profile_name.as_str().green()); - crate::settings::state::set_value("api.codewhisperer.profile", profile)?; - crate::settings::state::remove_value("api.selectedCustomization")?; + eprintln!("Profile set"); + database.set_auth_profile(chosen)?; if let Some(profile_region) = chosen.arn.split(':').nth(3) { let intent = if whoami { @@ -436,36 +416,32 @@ async fn select_profile_interactive(whoami: bool) -> Result<()> { } else { QProfileSwitchIntent::User }; - crate::telemetry::send_did_select_profile( - intent, - profile_region.to_string(), - TelemetryResult::Succeeded, - sso_region, - Some(total_profiles), - ) - .await; + + telemetry + .send_did_select_profile( + intent, + profile_region.to_string(), + TelemetryResult::Succeeded, + sso_region, + Some(total_profiles), + ) + .ok(); } }, None => { - crate::telemetry::send_did_select_profile( - QProfileSwitchIntent::User, - "not-set".to_string(), - TelemetryResult::Cancelled, - sso_region, - Some(total_profiles), - ) - .await; + telemetry + .send_did_select_profile( + QProfileSwitchIntent::User, + "not-set".to_string(), + TelemetryResult::Cancelled, + sso_region, + Some(total_profiles), + ) + .ok(); + bail!("No profile selected.\n"); }, } Ok(()) } - -mod tests { - #[test] - #[ignore] - fn unset_profile() { - crate::settings::state::remove_value("api.codewhisperer.profile").unwrap(); - } -} diff --git a/crates/chat-cli/src/database/mod.rs b/crates/chat-cli/src/database/mod.rs new file mode 100644 index 0000000000..b5c32348c8 --- /dev/null +++ b/crates/chat-cli/src/database/mod.rs @@ -0,0 +1,472 @@ +pub mod secret_store; +pub mod settings; + +use std::ops::Deref; +use std::str::FromStr; +use std::sync::PoisonError; + +use aws_sdk_cognitoidentity::primitives::DateTimeFormat; +use aws_sdk_cognitoidentity::types::Credentials; +use r2d2::Pool; +use r2d2_sqlite::SqliteConnectionManager; +use rusqlite::types::FromSql; +use rusqlite::{ + Connection, + Error, + ToSql, + params, +}; +use secret_store::SecretStore; +use serde::de::DeserializeOwned; +use serde::{ + Deserialize, + Serialize, +}; +use serde_json::{ + Map, + Value, +}; +use settings::Settings; +use thiserror::Error; +use tracing::info; +use uuid::Uuid; + +use crate::util::directories::{ + DirectoryError, + database_path, +}; + +macro_rules! migrations { + ($($name:expr),*) => {{ + &[ + $( + Migration { + name: $name, + sql: include_str!(concat!("sqlite_migrations/", $name, ".sql")), + } + ),* + ] + }}; +} + +const CREDENTIALS_KEY: &str = "telemetry-cognito-credentials"; +const CLIENT_ID_KEY: &str = "telemetryClientId"; +const CODEWHISPERER_PROFILE_KEY: &str = "api.codewhisperer.profile"; +const START_URL_KEY: &str = "auth.idc.start-url"; +const IDC_REGION_KEY: &str = "auth.idc.region"; +// We include this key to remove for backwards compatibility +const CUSTOMIZATION_STATE_KEY: &str = "api.selectedCustomization"; +const ROTATING_TIP_KEY: &str = "chat.greeting.rotating_tips_current_index"; + +const MIGRATIONS: &[Migration] = migrations![ + "000_migration_table", + "001_history_table", + "002_drop_history_in_ssh_docker", + "003_improved_history_timing", + "004_state_table", + "005_auth_table", + "006_make_state_blob" +]; + +#[derive(Debug, serde::Deserialize, serde::Serialize)] +pub struct CredentialsJson { + pub access_key_id: Option, + pub secret_key: Option, + pub session_token: Option, + pub expiration: Option, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct AuthProfile { + pub arn: String, + pub profile_name: String, +} + +impl From for AuthProfile { + fn from(profile: amzn_codewhisperer_client::types::Profile) -> Self { + Self { + arn: profile.arn, + profile_name: profile.profile_name, + } + } +} + +// A cloneable error +#[derive(Debug, Clone, thiserror::Error)] +#[error("Failed to open database: {}", .0)] +pub struct DbOpenError(pub(crate) String); + +#[derive(Debug, Error)] +pub enum DatabaseError { + #[error(transparent)] + IoError(#[from] std::io::Error), + #[error(transparent)] + JsonError(#[from] serde_json::Error), + #[error(transparent)] + FigUtilError(#[from] crate::util::UtilError), + #[error(transparent)] + DirectoryError(#[from] DirectoryError), + #[error(transparent)] + Rusqlite(#[from] rusqlite::Error), + #[error(transparent)] + R2d2(#[from] r2d2::Error), + #[error(transparent)] + DbOpenError(#[from] DbOpenError), + #[error("{}", .0)] + PoisonError(String), + #[cfg(target_os = "macos")] + #[error("Security error: {}", .0)] + Security(String), + #[error(transparent)] + StringFromUtf8(#[from] std::string::FromUtf8Error), + #[error(transparent)] + StrFromUtf8(#[from] std::str::Utf8Error), + #[error("`{}` is not a valid setting", .0)] + InvalidSetting(String), +} + +impl From> for DatabaseError { + fn from(value: PoisonError) -> Self { + Self::PoisonError(value.to_string()) + } +} + +#[derive(Debug)] +pub enum Table { + /// The state table contains persistent application state. + State, + /// The conversations tables contains user chat conversations. + #[allow(dead_code)] + Conversations, + #[cfg(not(target_os = "macos"))] + /// The auth table contains + Auth, +} + +impl std::fmt::Display for Table { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Table::State => write!(f, "state"), + Table::Conversations => write!(f, "conversations"), + #[cfg(not(target_os = "macos"))] + Table::Auth => write!(f, "auth_kv"), + } + } +} + +#[derive(Debug)] +struct Migration { + name: &'static str, + sql: &'static str, +} + +#[derive(Debug)] +pub struct Database { + pool: Pool, + pub settings: Settings, + pub secret_store: SecretStore, +} + +impl Database { + pub async fn new() -> Result { + let path = match cfg!(test) { + true => { + return Self { + pool: Pool::builder().build(SqliteConnectionManager::memory()).unwrap(), + settings: Settings::new().await?, + secret_store: SecretStore::new().await?, + } + .migrate(); + }, + false => database_path()?, + }; + + // make the parent dir if it doesnt exist + if let Some(parent) = path.parent() { + if !parent.exists() { + std::fs::create_dir_all(parent)?; + } + } + + let conn = SqliteConnectionManager::file(&path); + let pool = Pool::builder().build(conn)?; + + // Check the unix permissions of the database file, set them to 0600 if they are not + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + let metadata = std::fs::metadata(&path)?; + let mut permissions = metadata.permissions(); + if permissions.mode() & 0o777 != 0o600 { + tracing::debug!(?path, "Setting database file permissions to 0600"); + permissions.set_mode(0o600); + std::fs::set_permissions(path, permissions)?; + } + } + + Ok(Self { + pool, + settings: Settings::new().await?, + secret_store: SecretStore::new().await?, + } + .migrate() + .map_err(|e| DbOpenError(e.to_string()))?) + } + + /// Get all entries for dumping the persistent application state. + pub fn get_all_entries(&self) -> Result, DatabaseError> { + self.all_entries(Table::State) + } + + /// Get cognito credentials used by toolkit telemetry. + pub fn get_credentials_entry(&mut self) -> Result, DatabaseError> { + self.get_json_entry::(Table::State, CREDENTIALS_KEY) + } + + /// Set cognito credentials used by toolkit telemetry. + pub fn set_credentials_entry(&mut self, credentials: &Credentials) -> Result { + self.set_json_entry(Table::State, CREDENTIALS_KEY, CredentialsJson { + access_key_id: credentials.access_key_id.clone(), + secret_key: credentials.secret_key.clone(), + session_token: credentials.session_token.clone(), + expiration: credentials + .expiration + .and_then(|t| t.fmt(DateTimeFormat::DateTime).ok()), + }) + } + + /// Get the current user profile used to determine API endpoints. + pub fn get_auth_profile(&self) -> Result, DatabaseError> { + self.get_json_entry(Table::State, CODEWHISPERER_PROFILE_KEY) + } + + /// Set the current user profile used to determine API endpoints. + pub fn set_auth_profile(&mut self, profile: &AuthProfile) -> Result<(), DatabaseError> { + self.set_json_entry(Table::State, CODEWHISPERER_PROFILE_KEY, profile)?; + self.delete_entry(Table::State, CUSTOMIZATION_STATE_KEY) + } + + /// Unset the current user profile used to determine API endpoints. + pub fn unset_auth_profile(&mut self) -> Result<(), DatabaseError> { + self.delete_entry(Table::State, CODEWHISPERER_PROFILE_KEY)?; + self.delete_entry(Table::State, CUSTOMIZATION_STATE_KEY) + } + + /// Get the client ID used for telemetry requests. + pub fn get_client_id(&mut self) -> Result, DatabaseError> { + Ok(self + .get_entry::(Table::State, CLIENT_ID_KEY)? + .and_then(|s| Uuid::from_str(&s).ok())) + } + + /// Set the client ID used for telemetry requests. + pub fn set_client_id(&mut self, client_id: Uuid) -> Result { + self.set_entry(Table::State, CLIENT_ID_KEY, client_id.to_string()) + } + + /// Get the start URL used for IdC login. + pub fn get_start_url(&mut self) -> Result, DatabaseError> { + self.get_json_entry::(Table::State, START_URL_KEY) + } + + /// Set the start URL used for IdC login. + pub fn set_start_url(&mut self, start_url: String) -> Result { + self.set_json_entry(Table::State, START_URL_KEY, start_url) + } + + /// Get the region used for IdC login. + pub fn get_idc_region(&mut self) -> Result, DatabaseError> { + // Annoyingly, this is encoded as a JSON string on older clients + self.get_json_entry::(Table::State, IDC_REGION_KEY) + } + + /// Set the region used for IdC login. + pub fn set_idc_region(&mut self, region: String) -> Result { + // Annoyingly, this is encoded as a JSON string on older clients + self.set_json_entry(Table::State, IDC_REGION_KEY, region) + } + + /// Get the rotating tip used for chat then post increment. + pub fn get_increment_rotating_tip(&mut self) -> Result { + let tip: usize = self.get_entry(Table::State, ROTATING_TIP_KEY)?.unwrap_or(0); + self.set_entry(Table::State, ROTATING_TIP_KEY, tip.wrapping_add(1))?; + Ok(tip) + } + + fn migrate(self) -> Result { + let mut conn = self.pool.get()?; + let transaction = conn.transaction()?; + + // select the max migration id + let max_id = max_migration(&transaction); + + for (version, migration) in MIGRATIONS.iter().enumerate() { + // skip migrations that already exist + match max_id { + Some(max_id) if max_id >= version as i64 => continue, + _ => (), + }; + + // execute the migration + transaction.execute_batch(migration.sql)?; + + info!(%version, name =% migration.name, "Applying migration"); + + // insert the migration entry + transaction.execute( + "INSERT INTO migrations (version, migration_time) VALUES (?1, strftime('%s', 'now'));", + params![version], + )?; + } + + // commit the transaction + transaction.commit()?; + + Ok(self) + } + + fn get_entry(&self, table: Table, key: impl AsRef) -> Result, DatabaseError> { + let conn = self.pool.get()?; + let mut stmt = conn.prepare(&format!("SELECT value FROM {table} WHERE key = ?1"))?; + match stmt.query_row([key.as_ref()], |row| row.get(0)) { + Ok(data) => Ok(Some(data)), + Err(Error::QueryReturnedNoRows) => Ok(None), + Err(err) => Err(err.into()), + } + } + + fn set_entry(&self, table: Table, key: impl AsRef, value: impl ToSql) -> Result { + Ok(self.pool.get()?.execute( + &format!("INSERT OR REPLACE INTO {table} (key, value) VALUES (?1, ?2)"), + params![key.as_ref(), value], + )?) + } + + fn get_json_entry( + &self, + table: Table, + key: impl AsRef, + ) -> Result, DatabaseError> { + Ok(match self.get_entry::(table, key.as_ref())? { + Some(value) => serde_json::from_str(&value)?, + None => None, + }) + } + + fn set_json_entry( + &self, + table: Table, + key: impl AsRef, + value: impl Serialize, + ) -> Result { + self.set_entry(table, key, serde_json::to_string(&value)?) + } + + fn delete_entry(&self, table: Table, key: impl AsRef) -> Result<(), DatabaseError> { + self.pool + .get()? + .execute(&format!("DELETE FROM {table} WHERE key = ?1"), [key.as_ref()])?; + Ok(()) + } + + fn all_entries(&self, table: Table) -> Result, DatabaseError> { + let conn = self.pool.get()?; + let mut stmt = conn.prepare(&format!("SELECT key, value FROM {table}"))?; + let rows = stmt.query_map([], |row| { + let key = row.get(0)?; + let value = Value::String(row.get(1)?); + Ok((key, value)) + })?; + + let mut map = Map::new(); + for row in rows { + let (key, value) = row?; + map.insert(key, value); + } + + Ok(map) + } +} + +fn max_migration>(conn: &C) -> Option { + let mut stmt = conn.prepare("SELECT MAX(id) FROM migrations").ok()?; + stmt.query_row([], |row| row.get(0)).ok() +} + +#[cfg(test)] +mod tests { + use super::*; + + fn all_errors() -> Vec { + vec![ + std::io::Error::new(std::io::ErrorKind::InvalidData, "oops").into(), + serde_json::from_str::<()>("oops").unwrap_err().into(), + crate::util::directories::DirectoryError::NoHomeDirectory.into(), + rusqlite::Error::SqliteSingleThreadedMode.into(), + // r2d2::Error + DbOpenError("oops".into()).into(), + PoisonError::<()>::new(()).into(), + ] + } + + #[test] + fn test_error_display_debug() { + for error in all_errors() { + eprintln!("{}", error); + eprintln!("{:?}", error); + } + } + + #[tokio::test] + async fn test_migrate() { + let db = Database::new().await.unwrap(); + + // assert migration count is correct + let max_migration = max_migration(&&*db.pool.get().unwrap()); + assert_eq!(max_migration, Some(MIGRATIONS.len() as i64)); + } + + #[test] + fn list_migrations() { + // Assert the migrations are in order + assert!(MIGRATIONS.windows(2).all(|w| w[0].name <= w[1].name)); + + // Assert the migrations start with their index + assert!( + MIGRATIONS + .iter() + .enumerate() + .all(|(i, m)| m.name.starts_with(&format!("{:03}_", i))) + ); + + // Assert all the files in migrations/ are in the list + let migration_folder = std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join("src/database/sqlite_migrations"); + let migration_count = std::fs::read_dir(migration_folder).unwrap().count(); + assert_eq!(MIGRATIONS.len(), migration_count); + } + + #[tokio::test] + async fn state_table_tests() { + let db = Database::new().await.unwrap(); + + // set + db.set_entry(Table::State, "test", "test").unwrap(); + db.set_entry(Table::State, "int", 1).unwrap(); + db.set_entry(Table::State, "float", 1.0).unwrap(); + db.set_entry(Table::State, "bool", true).unwrap(); + db.set_entry(Table::State, "array", vec![1, 2, 3]).unwrap(); + db.set_entry(Table::State, "object", serde_json::json!({ "test": "test" })) + .unwrap(); + db.set_entry(Table::State, "binary", b"test".to_vec()).unwrap(); + + // unset + db.delete_entry(Table::State, "test").unwrap(); + db.delete_entry(Table::State, "int").unwrap(); + + // is some + assert!(db.get_entry::(Table::State, "test").unwrap().is_none()); + assert!(db.get_entry::(Table::State, "int").unwrap().is_none()); + assert!(db.get_entry::(Table::State, "float").unwrap().is_some()); + assert!(db.get_entry::(Table::State, "bool").unwrap().is_some()); + } +} diff --git a/crates/chat-cli/src/database/secret_store/linux.rs b/crates/chat-cli/src/database/secret_store/linux.rs new file mode 100644 index 0000000000..86b70bd50d --- /dev/null +++ b/crates/chat-cli/src/database/secret_store/linux.rs @@ -0,0 +1,48 @@ +use super::Secret; +use super::sqlite::SqliteSecretStore; + +pub struct SqliteSecretStore { + db: &'static Database, +} + +impl SqliteSecretStore { + pub async fn new() -> Result { + Ok(Self { db: database()? }) + } + + pub async fn set(&self, key: &str, password: &str) -> Result<(), DatabaseError> { + Ok(self.db.set_auth_value(key, password)?) + } + + pub async fn get(&self, key: &str) -> Result, DatabaseError> { + Ok(self.db.get_auth_entry(key)?.map(Secret)) + } + + pub async fn delete(&self, key: &str) -> Result<(), DatabaseError> { + Ok(self.db.unset_auth_value(key)?) + } +} + +pub struct SecretStoreImpl { + inner: SqliteSecretStore, +} + +impl SecretStoreImpl { + pub async fn new() -> Result { + Ok(Self { + inner: SqliteSecretStore::new().await?, + }) + } + + pub async fn set(&self, key: &str, password: &str) -> Result<(), DatabaseError> { + self.inner.set(key, password).await + } + + pub async fn get(&self, key: &str) -> Result, DatabaseError> { + self.inner.get(key).await + } + + pub async fn delete(&self, key: &str) -> Result<(), DatabaseError> { + self.inner.delete(key).await + } +} diff --git a/crates/chat-cli/src/auth/secret_store/macos.rs b/crates/chat-cli/src/database/secret_store/macos.rs similarity index 83% rename from crates/chat-cli/src/auth/secret_store/macos.rs rename to crates/chat-cli/src/database/secret_store/macos.rs index 68a7feb02f..7c4aa7fa8a 100644 --- a/crates/chat-cli/src/auth/secret_store/macos.rs +++ b/crates/chat-cli/src/database/secret_store/macos.rs @@ -1,5 +1,5 @@ use super::Secret; -use crate::auth::AuthError; +use crate::database::DatabaseError; /// Path to the `security` binary const SECURITY_BIN: &str = "/usr/bin/security"; @@ -12,12 +12,12 @@ pub struct SecretStoreImpl { } impl SecretStoreImpl { - pub async fn new() -> Result { + pub async fn new() -> Result { Ok(Self { _private: () }) } /// Sets the `key` to `password` on the keychain, this will override any existing value - pub async fn set(&self, key: &str, password: &str) -> Result<(), AuthError> { + pub async fn set(&self, key: &str, password: &str) -> Result<(), DatabaseError> { let output = tokio::process::Command::new(SECURITY_BIN) .args(["add-generic-password", "-U", "-s", key, "-a", ACCOUNT, "-w", password]) .output() @@ -25,7 +25,7 @@ impl SecretStoreImpl { if !output.status.success() { let stderr = std::str::from_utf8(&output.stderr)?; - return Err(AuthError::Security(stderr.into())); + return Err(DatabaseError::Security(stderr.into())); } Ok(()) @@ -34,7 +34,7 @@ impl SecretStoreImpl { /// Returns the password for the `key` /// /// If not found the result will be `Ok(None)`, other errors will be returned - pub async fn get(&self, key: &str) -> Result, AuthError> { + pub async fn get(&self, key: &str) -> Result, DatabaseError> { let output = tokio::process::Command::new(SECURITY_BIN) .args(["find-generic-password", "-s", key, "-a", ACCOUNT, "-w"]) .output() @@ -45,7 +45,7 @@ impl SecretStoreImpl { if stderr.contains("could not be found") { return Ok(None); } else { - return Err(AuthError::Security(stderr.into())); + return Err(DatabaseError::Security(stderr.into())); } } @@ -61,7 +61,7 @@ impl SecretStoreImpl { } /// Deletes the `key` from the keychain - pub async fn delete(&self, key: &str) -> Result<(), AuthError> { + pub async fn delete(&self, key: &str) -> Result<(), DatabaseError> { let output = tokio::process::Command::new(SECURITY_BIN) .args(["delete-generic-password", "-s", key, "-a", ACCOUNT]) .output() @@ -69,7 +69,7 @@ impl SecretStoreImpl { if !output.status.success() { let stderr = std::str::from_utf8(&output.stderr)?; - return Err(AuthError::Security(stderr.into())); + return Err(DatabaseError::Security(stderr.into())); } Ok(()) diff --git a/crates/chat-cli/src/auth/secret_store/mod.rs b/crates/chat-cli/src/database/secret_store/mod.rs similarity index 84% rename from crates/chat-cli/src/auth/secret_store/mod.rs rename to crates/chat-cli/src/database/secret_store/mod.rs index 737e3a6807..ebc0f8b401 100644 --- a/crates/chat-cli/src/auth/secret_store/mod.rs +++ b/crates/chat-cli/src/database/secret_store/mod.rs @@ -1,14 +1,10 @@ -#[cfg(any(target_os = "linux", windows))] -mod linux; -#[cfg(target_os = "macos")] -mod macos; -mod sqlite; -#[cfg(any(target_os = "linux", windows))] -use linux::SecretStoreImpl; -#[cfg(target_os = "macos")] -use macos::SecretStoreImpl; - -use super::AuthError; +#[cfg_attr(target_os = "macos", path = "macos.rs")] +#[cfg_attr(any(target_os = "linux", windows), path = "linux.rs")] +mod os; + +use os::SecretStoreImpl; + +use super::DatabaseError; #[derive(Clone, PartialEq, Eq, PartialOrd, Ord, serde::Serialize, serde::Deserialize)] #[serde(transparent)] @@ -34,19 +30,19 @@ pub struct SecretStore { } impl SecretStore { - pub async fn new() -> Result { + pub async fn new() -> Result { SecretStoreImpl::new().await.map(|inner| Self { inner }) } - pub async fn set(&self, key: &str, password: &str) -> Result<(), AuthError> { + pub async fn set(&self, key: &str, password: &str) -> Result<(), DatabaseError> { self.inner.set(key, password).await } - pub async fn get(&self, key: &str) -> Result, AuthError> { + pub async fn get(&self, key: &str) -> Result, DatabaseError> { self.inner.get(key).await } - pub async fn delete(&self, key: &str) -> Result<(), AuthError> { + pub async fn delete(&self, key: &str) -> Result<(), DatabaseError> { self.inner.delete(key).await } } diff --git a/crates/chat-cli/src/database/settings.rs b/crates/chat-cli/src/database/settings.rs new file mode 100644 index 0000000000..eb14c4c684 --- /dev/null +++ b/crates/chat-cli/src/database/settings.rs @@ -0,0 +1,213 @@ +use std::fmt::Display; +use std::io::SeekFrom; + +use fd_lock::RwLock; +use serde_json::{ + Map, + Value, +}; +use tokio::fs::File; +use tokio::io::{ + AsyncReadExt, + AsyncSeekExt, + AsyncWriteExt, +}; + +use super::DatabaseError; + +#[derive(Clone, Copy, Debug)] +pub enum Setting { + TelemetryEnabled, + OldClientId, + ShareCodeWhispererContent, + EnabledThinking, + SkimCommandKey, + ChatGreetingEnabled, + ApiTimeout, + ChatEditMode, + ChatEnableNotifications, + ApiCodeWhispererService, + ApiQService, +} + +impl AsRef for Setting { + fn as_ref(&self) -> &'static str { + match self { + Self::TelemetryEnabled => "telemetry.enabled", + Self::OldClientId => "telemetryClientId", + Self::ShareCodeWhispererContent => "codeWhisperer.shareCodeWhispererContentWithAWS", + Self::EnabledThinking => "chat.enableThinking", + Self::SkimCommandKey => "chat.skimCommandKey", + Self::ChatGreetingEnabled => "chat.greeting.enabled", + Self::ApiTimeout => "api.timeout", + Self::ChatEditMode => "chat.editMode", + Self::ChatEnableNotifications => "chat.enableNotifications", + Self::ApiCodeWhispererService => "api.codewhisperer.service", + Self::ApiQService => "api.q.service", + } + } +} + +impl Display for Setting { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.as_ref()) + } +} + +impl TryFrom<&str> for Setting { + type Error = DatabaseError; + + fn try_from(value: &str) -> Result { + match value { + "telemetry.enabled" => Ok(Self::TelemetryEnabled), + "telemetryClientId" => Ok(Self::OldClientId), + "codeWhisperer.shareCodeWhispererContentWithAWS" => Ok(Self::ShareCodeWhispererContent), + "chat.enableThinking" => Ok(Self::EnabledThinking), + "chat.skimCommandKey" => Ok(Self::SkimCommandKey), + "chat.greeting.enabled" => Ok(Self::ChatGreetingEnabled), + "api.timeout" => Ok(Self::ApiTimeout), + "chat.editMode" => Ok(Self::ChatEditMode), + "chat.enableNotifications" => Ok(Self::ChatEnableNotifications), + "api.codewhisperer.service" => Ok(Self::ApiCodeWhispererService), + "api.q.service" => Ok(Self::ApiQService), + _ => Err(DatabaseError::InvalidSetting(value.to_string())), + } + } +} + +#[derive(Debug, Clone, Default)] +pub struct Settings(Map); + +impl Settings { + pub async fn new() -> Result { + if cfg!(test) { + return Ok(Self::default()); + } + + let path = crate::util::directories::settings_path()?; + + // If the folder doesn't exist, create it. + if let Some(parent) = path.parent() { + if !parent.exists() { + std::fs::create_dir_all(parent)?; + } + } + + Ok(Self(match path.exists() { + true => { + let mut file = RwLock::new(File::open(&path).await?); + let mut buf = Vec::new(); + file.write()?.read_to_end(&mut buf).await?; + serde_json::from_slice(&buf)? + }, + false => { + let mut file = RwLock::new(File::create(path).await?); + file.write()?.write_all(b"{}").await?; + serde_json::Map::new() + }, + })) + } + + pub fn map(&self) -> &'_ Map { + &self.0 + } + + pub fn get(&self, key: Setting) -> Option<&Value> { + self.0.get(key.as_ref()) + } + + pub async fn set(&mut self, key: Setting, value: impl Into) -> Result<(), DatabaseError> { + self.0.insert(key.to_string(), value.into()); + self.save_to_file().await + } + + pub async fn remove(&mut self, key: Setting) -> Result, DatabaseError> { + let key = self.0.remove(key.as_ref()); + self.save_to_file().await?; + Ok(key) + } + + pub fn get_bool(&self, key: Setting) -> Option { + self.get(key).and_then(|value| value.as_bool()) + } + + pub fn get_string(&self, key: Setting) -> Option { + self.get(key).and_then(|value| value.as_str().map(|s| s.into())) + } + + pub fn get_int(&self, key: Setting) -> Option { + self.get(key).and_then(|value| value.as_i64()) + } + + async fn save_to_file(&self) -> Result<(), DatabaseError> { + if cfg!(test) { + return Ok(()); + } + + let path = crate::util::directories::settings_path()?; + + // If the folder doesn't exist, create it. + if let Some(parent) = path.parent() { + if !parent.exists() { + tokio::fs::create_dir_all(parent).await?; + } + } + + let mut file_opts = File::options(); + file_opts.create(true).write(true).truncate(true); + + #[cfg(unix)] + file_opts.mode(0o600); + let mut file = RwLock::new(file_opts.open(&path).await?); + let mut lock = file.write()?; + + match serde_json::to_string_pretty(&self.0) { + Ok(json) => lock.write_all(json.as_bytes()).await?, + Err(_err) => { + lock.seek(SeekFrom::Start(0)).await?; + lock.set_len(0).await?; + lock.write_all(b"{}").await?; + }, + } + lock.flush().await?; + + Ok(()) + } +} + +#[cfg(test)] +mod test { + use super::*; + + /// General read/write settings test + #[tokio::test] + async fn test_settings() { + let mut settings = Settings::new().await.unwrap(); + + assert_eq!(settings.get(Setting::TelemetryEnabled), None); + assert_eq!(settings.get(Setting::OldClientId), None); + assert_eq!(settings.get(Setting::ShareCodeWhispererContent), None); + + settings.set(Setting::TelemetryEnabled, true).await.unwrap(); + settings.set(Setting::OldClientId, "test").await.unwrap(); + settings.set(Setting::ShareCodeWhispererContent, false).await.unwrap(); + + assert_eq!(settings.get(Setting::TelemetryEnabled), Some(&Value::Bool(true))); + assert_eq!( + settings.get(Setting::OldClientId), + Some(&Value::String("test".to_string())) + ); + assert_eq!( + settings.get(Setting::ShareCodeWhispererContent), + Some(&Value::Bool(false)) + ); + + settings.remove(Setting::TelemetryEnabled).await.unwrap(); + settings.remove(Setting::OldClientId).await.unwrap(); + settings.remove(Setting::ShareCodeWhispererContent).await.unwrap(); + + assert_eq!(settings.get(Setting::TelemetryEnabled), None); + assert_eq!(settings.get(Setting::OldClientId), None); + assert_eq!(settings.get(Setting::ShareCodeWhispererContent), None); + } +} diff --git a/crates/chat-cli/src/settings/sqlite_migrations/000_migration_table.sql b/crates/chat-cli/src/database/sqlite_migrations/000_migration_table.sql similarity index 100% rename from crates/chat-cli/src/settings/sqlite_migrations/000_migration_table.sql rename to crates/chat-cli/src/database/sqlite_migrations/000_migration_table.sql diff --git a/crates/chat-cli/src/settings/sqlite_migrations/001_history_table.sql b/crates/chat-cli/src/database/sqlite_migrations/001_history_table.sql similarity index 100% rename from crates/chat-cli/src/settings/sqlite_migrations/001_history_table.sql rename to crates/chat-cli/src/database/sqlite_migrations/001_history_table.sql diff --git a/crates/chat-cli/src/settings/sqlite_migrations/002_drop_history_in_ssh_docker.sql b/crates/chat-cli/src/database/sqlite_migrations/002_drop_history_in_ssh_docker.sql similarity index 100% rename from crates/chat-cli/src/settings/sqlite_migrations/002_drop_history_in_ssh_docker.sql rename to crates/chat-cli/src/database/sqlite_migrations/002_drop_history_in_ssh_docker.sql diff --git a/crates/chat-cli/src/settings/sqlite_migrations/003_improved_history_timing.sql b/crates/chat-cli/src/database/sqlite_migrations/003_improved_history_timing.sql similarity index 100% rename from crates/chat-cli/src/settings/sqlite_migrations/003_improved_history_timing.sql rename to crates/chat-cli/src/database/sqlite_migrations/003_improved_history_timing.sql diff --git a/crates/chat-cli/src/settings/sqlite_migrations/004_state_table.sql b/crates/chat-cli/src/database/sqlite_migrations/004_state_table.sql similarity index 100% rename from crates/chat-cli/src/settings/sqlite_migrations/004_state_table.sql rename to crates/chat-cli/src/database/sqlite_migrations/004_state_table.sql diff --git a/crates/chat-cli/src/settings/sqlite_migrations/005_auth_table.sql b/crates/chat-cli/src/database/sqlite_migrations/005_auth_table.sql similarity index 100% rename from crates/chat-cli/src/settings/sqlite_migrations/005_auth_table.sql rename to crates/chat-cli/src/database/sqlite_migrations/005_auth_table.sql diff --git a/crates/chat-cli/src/database/sqlite_migrations/006_make_state_blob.sql b/crates/chat-cli/src/database/sqlite_migrations/006_make_state_blob.sql new file mode 100644 index 0000000000..fc3153823b --- /dev/null +++ b/crates/chat-cli/src/database/sqlite_migrations/006_make_state_blob.sql @@ -0,0 +1,7 @@ +ALTER TABLE state RENAME TO state_old; +CREATE TABLE state ( + key TEXT PRIMARY KEY, + value BLOB +); +INSERT INTO state SELECT key, value FROM state_old; +DROP TABLE state_old; \ No newline at end of file diff --git a/crates/chat-cli/src/install.rs b/crates/chat-cli/src/install.rs index ced83a66b1..b806856df8 100644 --- a/crates/chat-cli/src/install.rs +++ b/crates/chat-cli/src/install.rs @@ -10,7 +10,7 @@ pub enum Error { #[error(transparent)] Util(#[from] crate::util::UtilError), #[error(transparent)] - Settings(#[from] crate::settings::SettingsError), + Settings(#[from] crate::database::DatabaseError), #[error(transparent)] Reqwest(#[from] reqwest::Error), #[error(transparent)] diff --git a/crates/chat-cli/src/main.rs b/crates/chat-cli/src/main.rs index 4e7e150d68..e28354bfb4 100644 --- a/crates/chat-cli/src/main.rs +++ b/crates/chat-cli/src/main.rs @@ -2,12 +2,12 @@ mod api_client; mod auth; mod aws_common; mod cli; +mod database; mod install; mod logging; mod mcp_client; mod platform; mod request; -mod settings; mod telemetry; mod util; @@ -15,72 +15,28 @@ use std::process::ExitCode; use anstream::eprintln; use clap::Parser; -use clap::error::{ - ContextKind, - ErrorKind, -}; use crossterm::style::Stylize; use eyre::Result; use logging::get_log_level_max; use tracing::metadata::LevelFilter; -use crate::telemetry::finish_telemetry; -use crate::util::{ - CHAT_BINARY_NAME, - PRODUCT_NAME, -}; - #[global_allocator] static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc; fn main() -> Result { color_eyre::install()?; - let multithread = matches!( - std::env::args().nth(1).as_deref(), - Some("init" | "_" | "internal" | "completion" | "hook" | "chat") - ); - let parsed = match cli::Cli::try_parse() { Ok(cli) => cli, Err(err) => { - let _ = err.print(); - - let unknown_arg = matches!(err.kind(), ErrorKind::UnknownArgument | ErrorKind::InvalidSubcommand) - && !err.context().any(|(context_kind, _)| { - matches!( - context_kind, - ContextKind::SuggestedSubcommand | ContextKind::SuggestedArg - ) - }); - - if unknown_arg { - eprintln!( - "\nThis command may be valid in newer versions of the {PRODUCT_NAME} CLI. Try running {} {}.", - CHAT_BINARY_NAME.magenta(), - "update".magenta() - ); - } - + err.print().ok(); return Ok(ExitCode::from(err.exit_code().try_into().unwrap_or(2))); }, }; let verbose = parsed.verbose > 0; - - let runtime = if multithread { - tokio::runtime::Builder::new_multi_thread() - } else { - tokio::runtime::Builder::new_current_thread() - } - .enable_all() - .build()?; - - let result = runtime.block_on(async { - let result = parsed.execute().await; - finish_telemetry().await; - result - }); + let runtime = tokio::runtime::Builder::new_multi_thread().enable_all().build()?; + let result = runtime.block_on(parsed.execute()); match result { Ok(exit_code) => Ok(exit_code), @@ -90,6 +46,7 @@ fn main() -> Result { } else { eprintln!("{} {err}", "error:".bold().red()); } + Ok(ExitCode::FAILURE) }, } diff --git a/crates/chat-cli/src/request.rs b/crates/chat-cli/src/request.rs index db4f7d7ee1..e9b3abacc2 100644 --- a/crates/chat-cli/src/request.rs +++ b/crates/chat-cli/src/request.rs @@ -1,7 +1,4 @@ use std::env::current_exe; -use std::fs::File; -use std::io::BufReader; -use std::path::Path; use std::sync::{ Arc, LazyLock, @@ -12,71 +9,31 @@ use rustls::{ ClientConfig, RootCertStore, }; +use thiserror::Error; use url::ParseError; -#[derive(Debug)] +#[derive(Debug, Error)] pub enum RequestError { - Reqwest(reqwest::Error), - Serde(serde_json::Error), - Io(std::io::Error), - Dir(crate::util::directories::DirectoryError), - Settings(crate::settings::SettingsError), - UrlParseError(ParseError), + #[error(transparent)] + Reqwest(#[from] reqwest::Error), + #[error(transparent)] + Serde(#[from] serde_json::Error), + #[error(transparent)] + Io(#[from] std::io::Error), + #[error(transparent)] + Dir(#[from] crate::util::directories::DirectoryError), + #[error(transparent)] + Settings(#[from] crate::database::DatabaseError), + #[error(transparent)] + UrlParseError(#[from] ParseError), } -impl std::fmt::Display for RequestError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - RequestError::Reqwest(err) => write!(f, "Reqwest error: {err}"), - RequestError::Serde(err) => write!(f, "Serde error: {err}"), - RequestError::Io(err) => write!(f, "Io error: {err}"), - RequestError::Dir(err) => write!(f, "Dir error: {err}"), - RequestError::Settings(err) => write!(f, "Settings error: {err}"), - RequestError::UrlParseError(err) => write!(f, "Url parse error: {err}"), - } - } -} - -impl std::error::Error for RequestError {} - -impl From for RequestError { - fn from(e: reqwest::Error) -> Self { - RequestError::Reqwest(e) - } -} - -impl From for RequestError { - fn from(e: serde_json::Error) -> Self { - RequestError::Serde(e) - } -} - -impl From for RequestError { - fn from(e: std::io::Error) -> Self { - RequestError::Io(e) - } -} - -impl From for RequestError { - fn from(e: crate::util::directories::DirectoryError) -> Self { - RequestError::Dir(e) - } -} - -impl From for RequestError { - fn from(e: crate::settings::SettingsError) -> Self { - RequestError::Settings(e) - } -} - -impl From for RequestError { - fn from(e: ParseError) -> Self { - RequestError::UrlParseError(e) - } -} - -pub fn client() -> Option<&'static Client> { - CLIENT_NATIVE_CERTS.as_ref() +pub fn new_client() -> Result { + Ok(Client::builder() + .use_preconfigured_tls(client_config()) + .user_agent(USER_AGENT.chars().filter(|c| c.is_ascii_graphic()).collect::()) + .cookie_store(true) + .build()?) } pub fn create_default_root_cert_store() -> RootCertStore { @@ -89,29 +46,6 @@ pub fn create_default_root_cert_store() -> RootCertStore { let _ = root_cert_store.add(cert); } - let custom_cert = std::env::var("Q_CUSTOM_CERT") - .ok() - .or_else(|| crate::settings::state::get_string("Q_CUSTOM_CERT").ok().flatten()); - - if let Some(custom_cert) = custom_cert { - match File::open(Path::new(&custom_cert)) { - Ok(file) => { - let reader = &mut BufReader::new(file); - for cert in rustls_pemfile::certs(reader) { - match cert { - Ok(cert) => { - if let Err(err) = root_cert_store.add(cert) { - tracing::error!(path =% custom_cert, %err, "Failed to add custom cert"); - }; - }, - Err(err) => tracing::error!(path =% custom_cert, %err, "Failed to parse cert"), - } - } - }, - Err(err) => tracing::error!(path =% custom_cert, %err, "Failed to open cert at"), - } - } - root_cert_store } @@ -127,12 +61,6 @@ fn client_config() -> ClientConfig { .with_no_client_auth() } -static CLIENT_CONFIG_NATIVE_CERTS: LazyLock> = LazyLock::new(|| Arc::new(client_config())); - -pub fn client_config_cached() -> Arc { - CLIENT_CONFIG_NATIVE_CERTS.clone() -} - static USER_AGENT: LazyLock = LazyLock::new(|| { let name = current_exe() .ok() @@ -146,24 +74,13 @@ static USER_AGENT: LazyLock = LazyLock::new(|| { format!("{name}-{os}-{arch}-{version}") }); -pub static CLIENT_NATIVE_CERTS: LazyLock> = LazyLock::new(|| { - Some( - Client::builder() - .use_preconfigured_tls((*client_config_cached()).clone()) - .user_agent(USER_AGENT.chars().filter(|c| c.is_ascii_graphic()).collect::()) - .cookie_store(true) - .build() - .unwrap(), - ) -}); - #[cfg(test)] mod tests { use super::*; - #[test] - fn get_client() { - client().unwrap(); + #[tokio::test] + async fn get_client() { + new_client().unwrap(); } #[tokio::test] @@ -177,7 +94,7 @@ mod tests { .create(); let url = server.url(); - let client = client().unwrap(); + let client = new_client().unwrap(); let res = client.get(format!("{url}/hello")).send().await.unwrap(); assert_eq!(res.status(), 200); assert_eq!(res.headers()["content-type"], "text/plain"); diff --git a/crates/chat-cli/src/settings/actions.json b/crates/chat-cli/src/settings/actions.json deleted file mode 100644 index c5a886f665..0000000000 --- a/crates/chat-cli/src/settings/actions.json +++ /dev/null @@ -1,216 +0,0 @@ -[ - { - "identifier": "insertSelected", - "name": "Insert selected", - "description": "Insert selected suggestion", - "category": "Insertion", - "availability": "WHEN_FOCUSED", - "defaultBindings": ["enter"] - }, - { - "identifier": "insertCommonPrefix", - "name": "Insert common prefix", - "description": "Insert shared prefix of available suggestions. Shake if there's no common prefix.", - "category": "Insertion", - "availability": "WHEN_FOCUSED", - "defaultBindings": ["tab"] - }, - { - "identifier": "insertCommonPrefixOrNavigateDown", - "name": "Insert common prefix or navigate", - "description": "Insert shared prefix of available suggestions. Navigate if there's no common prefix.", - "category": "Insertion", - "availability": "WHEN_FOCUSED" - }, - { - "identifier": "insertCommonPrefixOrInsertSelected", - "name": "Insert common prefix or insert selected", - "description": "Insert shared prefix of available suggestions. Insert currently selected suggestion if there's not common prefix.", - "category": "Insertion", - "availability": "WHEN_FOCUSED" - }, - { - "identifier": "insertSelectedAndExecute", - "name": "Insert selected and execute", - "description": "Insert selected suggestion and then execute the current command.", - "category": "Insertion", - "availability": "WHEN_FOCUSED" - }, - { - "identifier": "execute", - "name": "Execute", - "description": "Execute the current command.", - "category": "Insertion", - "availability": "WHEN_FOCUSED" - }, - { - "identifier": "hideAutocomplete", - "name": "Hide autocomplete", - "description": "Hide the autocomplete window", - "category": "General", - "availability": "ALWAYS", - "defaultBindings": ["esc"] - }, - { - "identifier": "showAutocomplete", - "name": "Show autocomplete", - "description": "Show the autocomplete window", - "category": "General", - "availability": "ALWAYS" - }, - { - "identifier": "toggleAutocomplete", - "name": "Toggle autocomplete", - "description": "Toggle the visibility of the autocomplete window", - "availability": "ALWAYS" - }, - { - "identifier": "navigateUp", - "name": "Navigate up", - "description": "Scroll up one entry in the list of suggestions", - "category": "Navigation", - "availability": "WHEN_FOCUSED", - "defaultBindings": ["shift+tab", "up", "control+k", "control+p"] - }, - { - "identifier": "navigateDown", - "name": "Navigate down", - "description": "Scroll down one entry in the list of suggestions", - "category": "Navigation", - "availability": "WHEN_FOCUSED", - "defaultBindings": ["down", "control+j", "control+n"] - }, - { - "identifier": "selectSuggestion1", - "name": "Select 1st suggestion", - "description": "Select the 1st suggestion of the list", - "category": "Navigation", - "availability": "WHEN_FOCUSED", - "defaultBindings": ["control+1"] - }, - { - "identifier": "selectSuggestion2", - "name": "Select 2nd suggestion", - "description": "Select the 2nd suggestion of the list", - "category": "Navigation", - "availability": "WHEN_FOCUSED", - "defaultBindings": ["control+2"] - }, - { - "identifier": "selectSuggestion3", - "name": "Select 3rd suggestion", - "description": "Select the 3rd suggestion of the list", - "category": "Navigation", - "availability": "WHEN_FOCUSED", - "defaultBindings": ["control+3"] - }, - { - "identifier": "selectSuggestion4", - "name": "Select 4th suggestion", - "description": "Select the 4th suggestion of the list", - "category": "Navigation", - "availability": "WHEN_FOCUSED", - "defaultBindings": ["control+4"] - }, - { - "identifier": "selectSuggestion5", - "name": "Select 5th suggestion", - "description": "Select the 5th suggestion of the list", - "category": "Navigation", - "availability": "WHEN_FOCUSED", - "defaultBindings": ["control+5"] - }, - { - "identifier": "selectSuggestion6", - "name": "Select 6th suggestion", - "description": "Select the 6th suggestion of the list", - "category": "Navigation", - "availability": "WHEN_FOCUSED", - "defaultBindings": ["control+6"] - }, - { - "identifier": "selectSuggestion7", - "name": "Select 7th suggestion", - "description": "Select the 7th suggestion of the list", - "category": "Navigation", - "availability": "WHEN_FOCUSED", - "defaultBindings": ["control+7"] - }, - { - "identifier": "selectSuggestion8", - "name": "Select 8th suggestion", - "description": "Select the 8th suggestion of the list", - "category": "Navigation", - "availability": "WHEN_FOCUSED", - "defaultBindings": ["control+8"] - }, - { - "identifier": "selectSuggestion9", - "name": "Select 9th suggestion", - "description": "Select the 9th suggestion of the list", - "category": "Navigation", - "availability": "WHEN_FOCUSED", - "defaultBindings": ["control+9"] - }, - { - "identifier": "selectSuggestion10", - "name": "Select 10th suggestion", - "description": "Select the 10th suggestion of the list", - "category": "Navigation", - "availability": "WHEN_FOCUSED", - "defaultBindings": ["control+0"] - }, - { - "identifier": "hideDescription", - "name": "Hide description popout", - "description": "Hide autocomplete description popout", - "category": "Appearance", - "availability": "WHEN_FOCUSED" - }, - { - "identifier": "showDescription", - "name": "Show description popout", - "description": "Show autocomplete description popout", - "category": "Appearance", - "availability": "WHEN_FOCUSED" - }, - { - "identifier": "toggleDescription", - "name": "Toggle description popout", - "description": "Toggle visibility of autocomplete description popout", - "category": "Appearance", - "availability": "WHEN_FOCUSED", - "defaultBindings": ["command+i"] - }, - { - "identifier": "toggleHistoryMode", - "name": "Toggle history mode", - "description": "Toggle between history suggestions and autocomplete spec suggestions", - "category": "General", - "availability": "WHEN_FOCUSED", - "defaultBindings": ["control+r"] - }, - { - "identifier": "toggleFuzzySearch", - "name": "Toggle fuzzy search", - "description": "Toggle between normal prefix search and fuzzy search", - "category": "General", - "availability": "WHEN_FOCUSED" - }, - { - "identifier": "increaseSize", - "name": "Increase window size", - "description": "Increase the size of the autocomplete window", - "category": "Appearance", - "availability": "WHEN_FOCUSED", - "defaultBindings": ["command+="] - }, - { - "identifier": "decreaseSize", - "name": "Decrease window size", - "description": "Decrease the size of the autocomplete window", - "category": "Appearance", - "availability": "WHEN_FOCUSED", - "defaultBindings": ["command+-"] - } -] diff --git a/crates/chat-cli/src/settings/error.rs b/crates/chat-cli/src/settings/error.rs deleted file mode 100644 index 4393696f09..0000000000 --- a/crates/chat-cli/src/settings/error.rs +++ /dev/null @@ -1,70 +0,0 @@ -use std::sync::PoisonError; - -use thiserror::Error; - -use crate::util::directories::DirectoryError; - -// A cloneable error -#[derive(Debug, Clone, thiserror::Error)] -#[error("Failed to open database: {}", .0)] -pub struct DbOpenError(pub(crate) String); - -#[derive(Debug, Error)] -pub enum SettingsError { - #[error(transparent)] - IoError(#[from] std::io::Error), - #[error(transparent)] - JsonError(#[from] serde_json::Error), - #[error(transparent)] - FigUtilError(#[from] crate::util::UtilError), - #[error("settings file is not a json object")] - SettingsNotObject, - #[error(transparent)] - DirectoryError(#[from] DirectoryError), - #[error("memory backend is not used")] - MemoryBackendNotUsed, - #[error(transparent)] - Rusqlite(#[from] rusqlite::Error), - #[error(transparent)] - R2d2(#[from] r2d2::Error), - #[error(transparent)] - DbOpenError(#[from] DbOpenError), - #[error("{}", .0)] - PoisonError(String), -} - -impl From> for SettingsError { - fn from(value: PoisonError) -> Self { - Self::PoisonError(value.to_string()) - } -} - -pub type Result = std::result::Result; - -#[cfg(test)] -mod tests { - use super::*; - - fn all_errors() -> Vec { - vec![ - std::io::Error::new(std::io::ErrorKind::InvalidData, "oops").into(), - serde_json::from_str::<()>("oops").unwrap_err().into(), - crate::util::UtilError::UnsupportedPlatform.into(), - SettingsError::SettingsNotObject, - crate::util::directories::DirectoryError::NoHomeDirectory.into(), - SettingsError::MemoryBackendNotUsed, - rusqlite::Error::SqliteSingleThreadedMode.into(), - // r2d2::Error - DbOpenError("oops".into()).into(), - PoisonError::<()>::new(()).into(), - ] - } - - #[test] - fn test_error_display_debug() { - for error in all_errors() { - eprintln!("{}", error); - eprintln!("{:?}", error); - } - } -} diff --git a/crates/chat-cli/src/settings/keys.rs b/crates/chat-cli/src/settings/keys.rs deleted file mode 100644 index b46055b1ae..0000000000 --- a/crates/chat-cli/src/settings/keys.rs +++ /dev/null @@ -1 +0,0 @@ -pub const UPDATE_AVAILABLE_KEY: &str = "update.new-version-available"; diff --git a/crates/chat-cli/src/settings/mod.rs b/crates/chat-cli/src/settings/mod.rs deleted file mode 100644 index 058c86d22e..0000000000 --- a/crates/chat-cli/src/settings/mod.rs +++ /dev/null @@ -1,328 +0,0 @@ -pub mod error; -#[allow(clippy::module_inception)] -pub mod settings; -pub mod sqlite; -pub mod state; - -use std::fs::{ - self, - File, -}; -use std::io::{ - Seek, - SeekFrom, - Write, -}; -use std::path::PathBuf; - -pub use error::{ - Result, - SettingsError, -}; -use fd_lock::RwLock as FileRwLock; -use parking_lot::{ - MappedRwLockReadGuard, - MappedRwLockWriteGuard, - RwLock, - RwLockReadGuard, - RwLockWriteGuard, -}; -use serde_json::Value; -pub use settings::Settings; -pub use state::State; - -use crate::util::directories; - -pub type Map = serde_json::Map; - -static SETTINGS_FILE_LOCK: RwLock<()> = RwLock::new(()); - -static SETTINGS_DATA: RwLock> = RwLock::new(None); - -#[derive(Debug, Clone)] -pub enum Backend { - Global, - Memory(Map), -} - -pub enum ReadGuard<'a, T> { - Global(RwLockReadGuard<'a, Option>), - Memory(&'a T), -} - -impl<'a, T> ReadGuard<'a, T> { - pub fn try_map Option<&U>>(self, f: F) -> Option> { - match self { - ReadGuard::Global(guard) => RwLockReadGuard::<'a, Option>::try_map(guard, |data: &Option| { - f(data.as_ref().expect("global backend is not used")) - }) - .ok() - .map(MappedReadGuard::Global), - ReadGuard::Memory(data) => f(data).map(MappedReadGuard::Memory), - } - } -} - -impl std::ops::Deref for ReadGuard<'_, T> { - type Target = T; - - fn deref(&self) -> &Self::Target { - match self { - ReadGuard::Global(guard) => guard.as_ref().expect("global backend is not used"), - ReadGuard::Memory(data) => data, - } - } -} - -pub enum MappedReadGuard<'a, T> { - Global(MappedRwLockReadGuard<'a, T>), - Memory(&'a T), -} - -impl std::ops::Deref for MappedReadGuard<'_, T> { - type Target = T; - - fn deref(&self) -> &Self::Target { - match self { - MappedReadGuard::Global(guard) => guard, - MappedReadGuard::Memory(data) => data, - } - } -} - -pub enum WriteGuard<'a, T> { - Global(RwLockWriteGuard<'a, Option>), - Memory(&'a mut T), -} - -impl<'a, T> WriteGuard<'a, T> { - pub fn try_map Option<&mut U>>(self, f: F) -> Option> { - match self { - WriteGuard::Global(guard) => RwLockWriteGuard::<'a, Option>::try_map(guard, |data: &mut Option| { - f(data.as_mut().expect("global backend is not used")) - }) - .ok() - .map(MappedWriteGuard::Global), - WriteGuard::Memory(data) => f(data).map(MappedWriteGuard::Memory), - } - } -} - -impl std::ops::Deref for WriteGuard<'_, T> { - type Target = T; - - fn deref(&self) -> &Self::Target { - match self { - WriteGuard::Global(guard) => guard.as_ref().expect("global backend is not used"), - WriteGuard::Memory(data) => data, - } - } -} - -impl std::ops::DerefMut for WriteGuard<'_, T> { - fn deref_mut(&mut self) -> &mut Self::Target { - match self { - WriteGuard::Global(guard) => guard.as_mut().expect("global backend is not used"), - WriteGuard::Memory(data) => data, - } - } -} - -pub enum MappedWriteGuard<'a, T> { - Global(MappedRwLockWriteGuard<'a, T>), - Memory(&'a mut T), -} - -impl std::ops::Deref for MappedWriteGuard<'_, T> { - type Target = T; - - fn deref(&self) -> &Self::Target { - match self { - MappedWriteGuard::Global(guard) => guard, - MappedWriteGuard::Memory(data) => data, - } - } -} - -impl std::ops::DerefMut for MappedWriteGuard<'_, T> { - fn deref_mut(&mut self) -> &mut Self::Target { - match self { - MappedWriteGuard::Global(guard) => guard, - MappedWriteGuard::Memory(data) => data, - } - } -} - -pub trait JsonStore: Sized { - /// Path to the file - fn path() -> Result; - - /// In mem lock on the file - fn file_lock() -> &'static RwLock<()>; - - /// [RwLock] on the data, [None] if not using the global backend - fn data_lock() -> &'static RwLock>; - - fn new_from_backend(backend: Backend) -> Self; - - fn map(&self) -> ReadGuard<'_, Map>; - - fn map_mut(&mut self) -> WriteGuard<'_, Map>; - - fn load() -> Result { - let is_global = Self::data_lock().read().as_ref().is_some(); - if is_global { - Ok(Self::new_from_backend(Backend::Global)) - } else { - Ok(Self::new_from_backend(Backend::Memory(Self::load_from_file()?))) - } - } - - fn load_from_file() -> Result { - let path = Self::path()?; - - // If the folder doesn't exist, create it. - if let Some(parent) = path.parent() { - if !parent.exists() { - fs::create_dir_all(parent)?; - } - } - - let json: Map = { - let _lock_guard = Self::file_lock().write(); - - // If the file doesn't exist, create it. - if !path.exists() { - let mut file = FileRwLock::new(File::create(path)?); - file.write()?.write_all(b"{}")?; - serde_json::Map::new() - } else { - let mut file = FileRwLock::new(File::open(&path)?); - let mut read = file.write()?; - serde_json::from_reader(&mut *read)? - } - }; - - Ok(json) - } - - fn save_to_file(&self) -> Result<()> { - let path = Self::path()?; - - // If the folder doesn't exist, create it. - if let Some(parent) = path.parent() { - if !parent.exists() { - fs::create_dir_all(parent)?; - } - } - - let _lock_guard = Self::file_lock().write(); - - let mut file_opts = File::options(); - file_opts.create(true).write(true).truncate(true); - - #[cfg(unix)] - { - use std::os::unix::fs::OpenOptionsExt; - file_opts.mode(0o600); - } - - let mut file = FileRwLock::new(file_opts.open(&path)?); - let mut lock = file.write()?; - - if let Err(_err) = serde_json::to_writer_pretty(&mut *lock, &*self.map()) { - // Write {} to the file if the serialization failed - lock.seek(SeekFrom::Start(0))?; - lock.set_len(0)?; - lock.write_all(b"{}")?; - }; - lock.flush()?; - - Ok(()) - } - - fn set(&mut self, key: impl Into, value: impl Into) { - self.map_mut().insert(key.into(), value.into()); - } - - fn get(&self, key: impl AsRef) -> Option> { - self.map().try_map(|data| data.get(key.as_ref())) - } - - fn remove(&mut self, key: impl AsRef) -> Option { - self.map_mut().remove(key.as_ref()) - } - - #[allow(dead_code)] - fn get_mut(&mut self, key: impl Into) -> Option> { - self.map_mut().try_map(|data| data.get_mut(&key.into())) - } - - fn get_bool(&self, key: impl AsRef) -> Option { - self.get(key).and_then(|value| value.as_bool()) - } - - #[allow(dead_code)] - fn get_bool_or(&self, key: impl AsRef, default: bool) -> bool { - self.get_bool(key).unwrap_or(default) - } - - fn get_string(&self, key: impl AsRef) -> Option { - self.get(key).and_then(|value| value.as_str().map(|s| s.into())) - } - - #[allow(dead_code)] - fn get_string_or(&self, key: impl AsRef, default: String) -> String { - self.get_string(key).unwrap_or(default) - } - - fn get_int(&self, key: impl AsRef) -> Option { - self.get(key).and_then(|value| value.as_i64()) - } - - #[allow(dead_code)] - fn get_int_or(&self, key: impl AsRef, default: i64) -> i64 { - self.get_int(key).unwrap_or(default) - } -} - -pub struct OldSettings { - pub(crate) inner: Backend, -} - -impl JsonStore for OldSettings { - fn path() -> Result { - Ok(directories::settings_path()?) - } - - fn file_lock() -> &'static RwLock<()> { - &SETTINGS_FILE_LOCK - } - - fn data_lock() -> &'static RwLock> { - &SETTINGS_DATA - } - - fn new_from_backend(backend: Backend) -> Self { - match backend { - Backend::Global => Self { inner: Backend::Global }, - Backend::Memory(map) => Self { - inner: Backend::Memory(map), - }, - } - } - - fn map(&self) -> ReadGuard<'_, Map> { - match &self.inner { - Backend::Global => ReadGuard::Global(Self::data_lock().read()), - Backend::Memory(map) => ReadGuard::Memory(map), - } - } - - fn map_mut(&mut self) -> WriteGuard<'_, Map> { - match &mut self.inner { - Backend::Global => WriteGuard::Global(Self::data_lock().write()), - Backend::Memory(map) => WriteGuard::Memory(map), - } - } -} diff --git a/crates/chat-cli/src/settings/settings.rs b/crates/chat-cli/src/settings/settings.rs deleted file mode 100644 index 0e0e373ce0..0000000000 --- a/crates/chat-cli/src/settings/settings.rs +++ /dev/null @@ -1,231 +0,0 @@ -use std::sync::{ - Arc, - Mutex, -}; - -use serde::de::DeserializeOwned; -use serde_json::Map; - -use super::{ - JsonStore, - OldSettings, - Result, -}; - -#[derive(Debug, Clone, Default)] -pub struct Settings(inner::Inner); - -mod inner { - use std::sync::{ - Arc, - Mutex, - }; - - use serde_json::{ - Map, - Value, - }; - - #[derive(Debug, Clone, Default)] - pub enum Inner { - #[default] - Real, - Fake(Arc>>), - } -} - -impl Settings { - pub fn new() -> Self { - match cfg!(test) { - true => Self(inner::Inner::Fake(Arc::new(Mutex::new(Map::new())))), - false => Self(inner::Inner::Real), - } - } - - pub fn set_value(&self, key: impl Into, value: impl Into) -> Result<()> { - match &self.0 { - inner::Inner::Real => { - let mut settings = OldSettings::load()?; - settings.set(key, value); - settings.save_to_file()?; - Ok(()) - }, - inner::Inner::Fake(map) => { - map.lock()?.insert(key.into(), value.into()); - Ok(()) - }, - } - } - - pub fn remove_value(&self, key: impl AsRef) -> Result<()> { - match &self.0 { - inner::Inner::Real => { - let mut settings = OldSettings::load()?; - settings.remove(key); - settings.save_to_file()?; - Ok(()) - }, - inner::Inner::Fake(map) => { - map.lock()?.remove(key.as_ref()); - Ok(()) - }, - } - } - - pub fn get_value(&self, key: impl AsRef) -> Result> { - match &self.0 { - inner::Inner::Real => Ok(OldSettings::load()?.get(key.as_ref()).map(|v| v.clone())), - inner::Inner::Fake(map) => Ok(map.lock()?.get(key.as_ref()).cloned()), - } - } - - #[allow(dead_code)] - pub fn get(&self, key: impl AsRef) -> Result> { - match &self.0 { - inner::Inner::Real => { - let settings = OldSettings::load()?; - let v = settings.get(key); - match v.as_deref() { - Some(value) => Ok(Some(serde_json::from_value(value.clone())?)), - None => Ok(None), - } - }, - inner::Inner::Fake(map) => { - let value = map.lock()?.get(key.as_ref()).cloned(); - match value { - Some(value) => Ok(Some(serde_json::from_value(value)?)), - None => Ok(None), - } - }, - } - } - - pub fn get_bool(&self, key: impl AsRef) -> Result> { - match &self.0 { - inner::Inner::Real => Ok(OldSettings::load()?.get_bool(key.as_ref())), - inner::Inner::Fake(map) => Ok(map.lock()?.get(key.as_ref()).cloned().and_then(|v| v.as_bool())), - } - } - - pub fn get_bool_or(&self, key: impl AsRef, default: bool) -> bool { - self.get_bool(key).ok().flatten().unwrap_or(default) - } - - pub fn get_string(&self, key: impl AsRef) -> Result> { - match &self.0 { - inner::Inner::Real => Ok(OldSettings::load()?.get_string(key.as_ref())), - inner::Inner::Fake(map) => Ok(map - .lock()? - .get(key.as_ref()) - .cloned() - .and_then(|v| v.as_str().map(|s| s.to_owned()))), - } - } - - pub fn get_string_opt(&self, key: impl AsRef) -> Option { - self.get_string(key).ok().flatten() - } - - #[allow(dead_code)] - pub fn get_string_or(&self, key: impl AsRef, default: String) -> String { - self.get_string(key).ok().flatten().unwrap_or(default) - } - - pub fn get_int(&self, key: impl AsRef) -> Result> { - match &self.0 { - inner::Inner::Real => Ok(OldSettings::load()?.get_int(key.as_ref())), - inner::Inner::Fake(map) => Ok(map.lock()?.get(key.as_ref()).cloned().and_then(|v| v.as_i64())), - } - } - - #[allow(dead_code)] - pub fn get_int_or(&self, key: impl AsRef, default: i64) -> i64 { - self.get_int(key).ok().flatten().unwrap_or(default) - } -} - -pub fn set_value(key: impl Into, value: impl Into) -> Result<()> { - Settings::new().set_value(key, value) -} - -pub fn remove_value(key: impl AsRef) -> Result<()> { - Settings::new().remove_value(key) -} - -pub fn get_value(key: impl AsRef) -> Result> { - Settings::new().get_value(key) -} - -#[allow(dead_code)] -pub fn get(key: impl AsRef) -> Result> { - Settings::new().get(key) -} - -#[allow(dead_code)] -pub fn get_bool(key: impl AsRef) -> Result> { - Settings::new().get_bool(key) -} - -pub fn get_bool_or(key: impl AsRef, default: bool) -> bool { - Settings::new().get_bool_or(key, default) -} - -#[allow(dead_code)] -pub fn get_string(key: impl AsRef) -> Result> { - Settings::new().get_string(key) -} - -pub fn get_string_opt(key: impl AsRef) -> Option { - Settings::new().get_string_opt(key) -} - -#[allow(dead_code)] -pub fn get_string_or(key: impl AsRef, default: String) -> String { - Settings::new().get_string_or(key, default) -} - -pub fn get_int(key: impl AsRef) -> Result> { - Settings::new().get_int(key) -} - -#[allow(dead_code)] -pub fn get_int_or(key: impl AsRef, default: i64) -> i64 { - Settings::new().get_int_or(key, default) -} - -#[cfg(test)] -mod test { - use super::{ - Result, - Settings, - }; - - /// General read/write settings test - #[test] - fn test_settings() -> Result<()> { - let settings = Settings::new(); - - assert!(settings.get_value("test").unwrap().is_none()); - assert!(settings.get::("test").unwrap().is_none()); - settings.set_value("test", "hello :)")?; - assert!(settings.get_value("test").unwrap().is_some()); - assert!(settings.get::("test").unwrap().is_some()); - settings.remove_value("test")?; - assert!(settings.get_value("test").unwrap().is_none()); - assert!(settings.get::("test").unwrap().is_none()); - - assert!(!settings.get_bool_or("bool", false)); - settings.set_value("bool", true).unwrap(); - assert!(settings.get_bool("bool").unwrap().unwrap()); - - assert_eq!(settings.get_string_or("string", "hi".into()), "hi"); - settings.set_value("string", "hi").unwrap(); - assert_eq!(settings.get_string("string").unwrap().unwrap(), "hi"); - - assert_eq!(settings.get_int_or("int", 32), 32); - settings.set_value("int", 32).unwrap(); - assert_eq!(settings.get_int("int").unwrap().unwrap(), 32); - - Ok(()) - } -} diff --git a/crates/chat-cli/src/settings/sqlite.rs b/crates/chat-cli/src/settings/sqlite.rs deleted file mode 100644 index 34cd4fe8d4..0000000000 --- a/crates/chat-cli/src/settings/sqlite.rs +++ /dev/null @@ -1,436 +0,0 @@ -use std::ops::Deref; -use std::path::{ - Path, - PathBuf, -}; -use std::sync::LazyLock; - -use r2d2::Pool; -use r2d2_sqlite::SqliteConnectionManager; -use rusqlite::types::FromSql; -use rusqlite::{ - Connection, - Error, - ToSql, - params, -}; -use serde_json::Map; -use tracing::info; - -use super::error::DbOpenError; -use crate::settings::Result; -use crate::util::directories::fig_data_dir; - -const STATE_TABLE_NAME: &str = "state"; -const AUTH_TABLE_NAME: &str = "auth_kv"; - -pub static DATABASE: LazyLock> = LazyLock::new(|| { - let db = Db::new().map_err(|e| DbOpenError(e.to_string()))?; - db.migrate().map_err(|e| DbOpenError(e.to_string()))?; - Ok(db) -}); - -pub fn database() -> Result<&'static Db, DbOpenError> { - match DATABASE.as_ref() { - Ok(db) => Ok(db), - Err(err) => Err(err.clone()), - } -} - -#[derive(Debug)] -struct Migration { - name: &'static str, - sql: &'static str, -} - -macro_rules! migrations { - ($($name:expr),*) => {{ - &[ - $( - Migration { - name: $name, - sql: include_str!(concat!("sqlite_migrations/", $name, ".sql")), - } - ),* - ] - }}; -} - -const MIGRATIONS: &[Migration] = migrations![ - "000_migration_table", - "001_history_table", - "002_drop_history_in_ssh_docker", - "003_improved_history_timing", - "004_state_table", - "005_auth_table" -]; - -#[derive(Debug, Clone)] -pub struct Db { - pub(crate) pool: Pool, -} - -impl Db { - fn path() -> Result { - Ok(fig_data_dir()?.join("data.sqlite3")) - } - - pub fn new() -> Result { - Self::open(&Self::path()?) - } - - fn open(path: &Path) -> Result { - // make the parent dir if it doesnt exist - if let Some(parent) = path.parent() { - if !parent.exists() { - std::fs::create_dir_all(parent)?; - } - } - - let conn = SqliteConnectionManager::file(path); - let pool = Pool::builder().build(conn)?; - - // Check the unix permissions of the database file, set them to 0600 if they are not - #[cfg(unix)] - { - use std::os::unix::fs::PermissionsExt; - let metadata = std::fs::metadata(path)?; - let mut permissions = metadata.permissions(); - if permissions.mode() & 0o777 != 0o600 { - tracing::debug!(?path, "Setting database file permissions to 0600"); - permissions.set_mode(0o600); - std::fs::set_permissions(path, permissions)?; - } - } - - Ok(Self { pool }) - } - - pub(crate) fn mock() -> Self { - let conn = SqliteConnectionManager::memory(); - let pool = Pool::builder().build(conn).unwrap(); - Self { pool } - } - - pub fn migrate(&self) -> Result<()> { - let mut conn = self.pool.get()?; - let transaction = conn.transaction()?; - - // select the max migration id - let max_id = max_migration(&transaction); - - for (version, migration) in MIGRATIONS.iter().enumerate() { - // skip migrations that already exist - match max_id { - Some(max_id) if max_id >= version as i64 => continue, - _ => (), - }; - - // execute the migration - transaction.execute_batch(migration.sql)?; - - info!(%version, name =% migration.name, "Applying migration"); - - // insert the migration entry - transaction.execute( - "INSERT INTO migrations (version, migration_time) VALUES (?1, strftime('%s', 'now'));", - params![version], - )?; - } - - // commit the transaction - transaction.commit()?; - - Ok(()) - } - - fn get_value(&self, table: &'static str, key: impl AsRef) -> Result> { - let conn = self.pool.get()?; - let mut stmt = conn.prepare(&format!("SELECT value FROM {table} WHERE key = ?1"))?; - match stmt.query_row([key.as_ref()], |row| row.get(0)) { - Ok(data) => Ok(Some(data)), - Err(Error::QueryReturnedNoRows) => Ok(None), - Err(err) => Err(err.into()), - } - } - - pub fn get_state_value(&self, key: impl AsRef) -> Result> { - self.get_value(STATE_TABLE_NAME, key) - } - - pub fn get_auth_value(&self, key: impl AsRef) -> Result> { - self.get_value(AUTH_TABLE_NAME, key) - } - - fn set_value(&self, table: &'static str, key: impl AsRef, value: T) -> Result<()> { - self.pool.get()?.execute( - &format!("INSERT OR REPLACE INTO {table} (key, value) VALUES (?1, ?2)"), - params![key.as_ref(), value], - )?; - Ok(()) - } - - pub fn set_state_value(&self, key: impl AsRef, value: impl Into) -> Result<()> { - self.set_value(STATE_TABLE_NAME, key, value.into()) - } - - pub fn set_auth_value(&self, key: impl AsRef, value: impl Into) -> Result<()> { - self.set_value(AUTH_TABLE_NAME, key, value.into()) - } - - fn unset_value(&self, table: &'static str, key: impl AsRef) -> Result<()> { - self.pool - .get()? - .execute(&format!("DELETE FROM {table} WHERE key = ?1"), [key.as_ref()])?; - Ok(()) - } - - pub fn unset_state_value(&self, key: impl AsRef) -> Result<()> { - self.unset_value(STATE_TABLE_NAME, key) - } - - pub fn unset_auth_value(&self, key: impl AsRef) -> Result<()> { - self.unset_value(AUTH_TABLE_NAME, key) - } - - fn is_value_set(&self, table: &'static str, key: impl AsRef) -> Result { - let conn = self.pool.get()?; - let mut stmt = conn.prepare(&format!("SELECT value FROM {table} WHERE key = ?1"))?; - match stmt.query_row([key.as_ref()], |_| Ok(())) { - Ok(()) => Ok(true), - Err(Error::QueryReturnedNoRows) => Ok(false), - Err(err) => Err(err.into()), - } - } - - #[allow(dead_code)] - pub fn is_state_value_set(&self, key: impl AsRef) -> Result { - self.is_value_set(STATE_TABLE_NAME, key) - } - - #[allow(dead_code)] - pub fn is_auth_value_set(&self, key: impl AsRef) -> Result { - self.is_value_set(AUTH_TABLE_NAME, key) - } - - fn all_values(&self, table: &'static str) -> Result> { - let conn = self.pool.get()?; - let mut stmt = conn.prepare(&format!("SELECT key, value FROM {table}"))?; - let rows = stmt.query_map([], |row| { - let key = row.get(0)?; - let value = row.get(1)?; - Ok((key, value)) - })?; - - let mut map = Map::new(); - for row in rows { - let (key, value) = row?; - map.insert(key, value); - } - - Ok(map) - } - - pub fn all_state_values(&self) -> Result> { - self.all_values(STATE_TABLE_NAME) - } - - // atomic style operations - - fn atomic_op( - &self, - key: impl AsRef, - op: impl FnOnce(&Option) -> Option, - ) -> Result> { - let mut conn = self.pool.get()?; - let tx = conn.transaction()?; - - let value = tx.query_row::, _, _>( - &format!("SELECT value FROM {STATE_TABLE_NAME} WHERE key = ?1"), - [key.as_ref()], - |row| row.get(0), - ); - - let value_0: Option = match value { - Ok(value) => value, - Err(Error::QueryReturnedNoRows) => None, - Err(err) => return Err(err.into()), - }; - - let value_1 = op(&value_0); - - if let Some(value) = value_1 { - tx.execute( - &format!("INSERT OR REPLACE INTO {STATE_TABLE_NAME} (key, value) VALUES (?1, ?2)"), - params![key.as_ref(), value], - )?; - } else { - tx.execute( - &format!("DELETE FROM {STATE_TABLE_NAME} WHERE key = ?1"), - [key.as_ref()], - )?; - } - - tx.commit()?; - - Ok(value_0) - } - - /// Atomically get the value of a key, then perform an or operation on it - /// and set the new value. If the key does not exist, set it to the or value. - pub fn atomic_bool_or(&self, key: impl AsRef, or: bool) -> Result { - self.atomic_op::(key, |val| match val { - // Some(val) => Some(serde_json::Value::Bool( || or)), - Some(serde_json::Value::Bool(b)) => Some(serde_json::Value::Bool(*b || or)), - Some(_) | None => Some(serde_json::Value::Bool(or)), - }) - .map(|val| val.and_then(|val| val.as_bool()).unwrap_or(false)) - } -} - -fn max_migration>(conn: &C) -> Option { - let mut stmt = conn.prepare("SELECT MAX(id) FROM migrations").ok()?; - stmt.query_row([], |row| row.get(0)).ok() -} - -#[cfg(test)] -mod tests { - use super::*; - - fn mock() -> Db { - let db = Db::mock(); - db.migrate().unwrap(); - db - } - - #[test] - fn test_migrate() { - let db = mock(); - - // assert migration count is correct - let max_migration = max_migration(&&*db.pool.get().unwrap()); - assert_eq!(max_migration, Some(MIGRATIONS.len() as i64)); - } - - #[test] - fn list_migrations() { - // Assert the migrations are in order - assert!(MIGRATIONS.windows(2).all(|w| w[0].name <= w[1].name)); - - // Assert the migrations start with their index - assert!( - MIGRATIONS - .iter() - .enumerate() - .all(|(i, m)| m.name.starts_with(&format!("{:03}_", i))) - ); - - // Assert all the files in migrations/ are in the list - let migration_folder = std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join("src/settings/sqlite_migrations"); - let migration_count = std::fs::read_dir(migration_folder).unwrap().count(); - assert_eq!(MIGRATIONS.len(), migration_count); - } - - #[test] - fn state_table_tests() { - let db = mock(); - - // set - db.set_state_value("test", "test").unwrap(); - db.set_state_value("int", 1).unwrap(); - db.set_state_value("float", 1.0).unwrap(); - db.set_state_value("bool", true).unwrap(); - db.set_state_value("null", ()).unwrap(); - db.set_state_value("array", vec![1, 2, 3]).unwrap(); - db.set_state_value("object", serde_json::json!({ "test": "test" })) - .unwrap(); - db.set_state_value("binary", b"test".to_vec()).unwrap(); - - // get - assert_eq!(db.get_state_value("test").unwrap().unwrap(), "test"); - assert_eq!(db.get_state_value("int").unwrap().unwrap(), 1); - assert_eq!(db.get_state_value("float").unwrap().unwrap(), 1.0); - assert_eq!(db.get_state_value("bool").unwrap().unwrap(), true); - assert_eq!(db.get_state_value("null").unwrap().unwrap(), serde_json::Value::Null); - assert_eq!( - db.get_state_value("array").unwrap().unwrap(), - serde_json::json!([1, 2, 3]) - ); - assert_eq!( - db.get_state_value("object").unwrap().unwrap(), - serde_json::json!({ "test": "test" }) - ); - assert_eq!( - db.get_state_value("binary").unwrap().unwrap(), - serde_json::json!(b"test".to_vec()) - ); - - // unset - db.unset_state_value("test").unwrap(); - db.unset_state_value("int").unwrap(); - - // is_set - assert!(!db.is_state_value_set("test").unwrap()); - assert!(!db.is_state_value_set("int").unwrap()); - assert!(db.is_state_value_set("float").unwrap()); - assert!(db.is_state_value_set("bool").unwrap()); - } - - #[test] - fn auth_table_tests() { - let db = mock(); - - db.set_auth_value("test", "test").unwrap(); - assert_eq!(db.get_auth_value("test").unwrap().unwrap(), "test"); - assert!(db.is_auth_value_set("test").unwrap()); - db.unset_auth_value("test").unwrap(); - assert!(!db.is_auth_value_set("test").unwrap()); - - assert_eq!(db.get_auth_value("test2").unwrap(), None); - assert!(!db.is_auth_value_set("test2").unwrap()); - } - - #[test] - fn db_open_time() { - let tempdir = tempfile::tempdir().unwrap(); - let path = tempdir.path().join("data.sqlite3"); - - // init the db - let db = Db::open(&path).unwrap(); - db.migrate().unwrap(); - drop(db); - - let test_count = 100; - - let instant = std::time::Instant::now(); - let db = Db::open(&path).unwrap(); - for _ in 0..test_count { - db.set_state_value("test", "test").unwrap(); - db.get_state_value("test").unwrap().unwrap(); - } - let elapsed = instant.elapsed() / test_count; - println!("time: {:?}", elapsed); - } - - #[test] - fn test_atomic_bool() { - let key = "test"; - let db = mock(); - - let cases = [ - (None, false, false, false), - (None, true, false, true), - (Some(false), false, false, false), - (Some(false), true, false, true), - (Some(true), false, true, true), - (Some(true), true, true, true), - ]; - - for (a, b, c, d) in cases { - db.set_state_value(key, a).unwrap(); - assert_eq!(db.atomic_bool_or(key, b).unwrap(), c); - assert_eq!(db.get_state_value(key).unwrap().unwrap(), d); - db.unset_state_value(key).unwrap(); - } - } -} diff --git a/crates/chat-cli/src/settings/state.rs b/crates/chat-cli/src/settings/state.rs deleted file mode 100644 index 0a0685276b..0000000000 --- a/crates/chat-cli/src/settings/state.rs +++ /dev/null @@ -1,190 +0,0 @@ -use serde::de::DeserializeOwned; -use serde_json::{ - Map, - Value, -}; - -use super::sqlite::{ - Db, - database, -}; -use crate::settings::Result; - -#[derive(Debug, Clone, Default)] -pub struct State(inner::Inner); - -mod inner { - use super::*; - - #[derive(Debug, Clone, Default)] - pub enum Inner { - #[default] - Real, - Fake(Db), - } -} - -impl State { - pub fn new() -> Self { - if cfg!(test) { - let db = Db::mock(); - db.migrate().unwrap(); - return Self(inner::Inner::Fake(db)); - } - - Self::default() - } - - fn database(&self) -> Result<&Db> { - match &self.0 { - inner::Inner::Real => Ok(database()?), - inner::Inner::Fake(db) => Ok(db), - } - } - - pub fn all(&self) -> Result> { - self.database()?.all_state_values() - } - - pub fn set_value(&self, key: impl AsRef, value: impl Into) -> Result<()> { - self.database()?.set_state_value(key, value)?; - Ok(()) - } - - pub fn remove_value(&self, key: impl AsRef) -> Result<()> { - self.database()?.unset_state_value(key)?; - Ok(()) - } - - pub fn get_value(&self, key: impl AsRef) -> Result> { - self.database()?.get_state_value(key) - } - - pub fn get(&self, key: impl AsRef) -> Result> { - Ok(self - .database()? - .get_state_value(key)? - .map(|value| serde_json::from_value(value.clone())) - .transpose()?) - } - - pub fn get_bool(&self, key: impl AsRef) -> Result> { - Ok(self.database()?.get_state_value(key)?.and_then(|value| value.as_bool())) - } - - pub fn get_bool_or(&self, key: impl AsRef, default: bool) -> bool { - self.get_bool(key).ok().flatten().unwrap_or(default) - } - - pub fn get_string(&self, key: impl AsRef) -> Result> { - Ok(self.database()?.get_state_value(key)?.and_then(|value| match value { - Value::String(s) => Some(s), - _ => None, - })) - } - - pub fn get_string_or(&self, key: impl AsRef, default: impl Into) -> String { - self.get_string(key).ok().flatten().unwrap_or_else(|| default.into()) - } - - pub fn get_int(&self, key: impl AsRef) -> Result> { - Ok(self.database()?.get_state_value(key)?.and_then(|value| value.as_i64())) - } - - pub fn get_int_or(&self, key: impl AsRef, default: i64) -> i64 { - self.get_int(key).ok().flatten().unwrap_or(default) - } - - // Atomic style operations - - pub fn atomic_bool_or(&self, key: impl AsRef, or: bool) -> Result { - self.database()?.atomic_bool_or(key, or) - } -} - -#[allow(dead_code)] -pub fn all() -> Result> { - State::new().all() -} - -pub fn set_value(key: impl AsRef, value: impl Into) -> Result<()> { - State::new().set_value(key, value) -} - -pub fn remove_value(key: impl AsRef) -> Result<()> { - State::new().remove_value(key) -} - -pub fn get_value(key: impl AsRef) -> Result> { - State::new().get_value(key) -} - -pub fn get(key: impl AsRef) -> Result> { - State::new().get(key) -} - -#[allow(dead_code)] -pub fn get_bool(key: impl AsRef) -> Result> { - State::new().get_bool(key) -} - -#[allow(dead_code)] -pub fn get_bool_or(key: impl AsRef, default: bool) -> bool { - State::new().get_bool_or(key, default) -} - -pub fn get_string(key: impl AsRef) -> Result> { - State::new().get_string(key) -} - -#[allow(dead_code)] -pub fn get_string_or(key: impl AsRef, default: impl Into) -> String { - State::new().get_string_or(key, default) -} - -#[allow(dead_code)] -pub fn get_int(key: impl AsRef) -> Result> { - State::new().get_int(key) -} - -#[allow(dead_code)] -pub fn get_int_or(key: impl AsRef, default: i64) -> i64 { - State::new().get_int_or(key, default) -} - -#[cfg(test)] -mod tests { - use super::{ - Result, - State, - }; - - /// General read/write state test - #[test] - fn test_state() -> Result<()> { - let state = State::new(); - - assert!(state.get_value("test").unwrap().is_none()); - assert!(state.get::("test").unwrap().is_none()); - state.set_value("test", "hello :)")?; - assert!(state.get_value("test").unwrap().is_some()); - assert!(state.get::("test").unwrap().is_some()); - state.remove_value("test")?; - assert!(state.get_value("test").unwrap().is_none()); - assert!(state.get::("test").unwrap().is_none()); - - assert!(!state.get_bool_or("bool", false)); - state.set_value("bool", true).unwrap(); - assert!(state.get_bool("bool").unwrap().unwrap()); - - assert_eq!(state.get_string_or("string", "hi"), "hi"); - state.set_value("string", "hi").unwrap(); - assert_eq!(state.get_string("string").unwrap().unwrap(), "hi"); - - assert_eq!(state.get_int_or("int", 32), 32); - state.set_value("int", 32).unwrap(); - assert_eq!(state.get_int("int").unwrap().unwrap(), 32); - - Ok(()) - } -} diff --git a/crates/chat-cli/src/telemetry/cognito.rs b/crates/chat-cli/src/telemetry/cognito.rs index 6c23cde517..ad8dd155bc 100644 --- a/crates/chat-cli/src/telemetry/cognito.rs +++ b/crates/chat-cli/src/telemetry/cognito.rs @@ -10,21 +10,14 @@ use aws_sdk_cognitoidentity::primitives::{ }; use crate::aws_common::app_name; +use crate::database::{ + CredentialsJson, + Database, +}; use crate::telemetry::TelemetryStage; -const CREDENTIALS_KEY: &str = "telemetry-cognito-credentials"; - -const DATE_TIME_FORMAT: DateTimeFormat = DateTimeFormat::DateTime; - -#[derive(Debug, serde::Deserialize, serde::Serialize)] -struct CredentialsJson { - pub access_key_id: Option, - pub secret_key: Option, - pub session_token: Option, - pub expiration: Option, -} - -pub(crate) async fn get_cognito_credentials_send( +pub async fn get_cognito_credentials_send( + database: &mut Database, telemetry_stage: &TelemetryStage, ) -> Result { let conf = aws_sdk_cognitoidentity::Config::builder() @@ -54,14 +47,7 @@ pub(crate) async fn get_cognito_credentials_send( "no credentials from get_credentials_for_identity", ))?; - if let Ok(json) = serde_json::to_value(CredentialsJson { - access_key_id: credentials.access_key_id.clone(), - secret_key: credentials.secret_key.clone(), - session_token: credentials.session_token.clone(), - expiration: credentials.expiration.and_then(|t| t.fmt(DATE_TIME_FORMAT).ok()), - }) { - crate::settings::state::set_value(CREDENTIALS_KEY, json).ok(); - } + database.set_credentials_entry(&credentials).ok(); let Some(access_key_id) = credentials.access_key_id else { return Err(CredentialsError::provider_error("access key id not found")); @@ -80,22 +66,26 @@ pub(crate) async fn get_cognito_credentials_send( )) } -pub(crate) async fn get_cognito_credentials(telemetry_stage: &TelemetryStage) -> Result { - match crate::settings::state::get_string(CREDENTIALS_KEY).ok().flatten() { - Some(creds) => { - let CredentialsJson { - access_key_id, - secret_key, - session_token, - expiration, - }: CredentialsJson = serde_json::from_str(&creds).map_err(CredentialsError::provider_error)?; - +pub async fn get_cognito_credentials( + database: &mut Database, + telemetry_stage: &TelemetryStage, +) -> Result { + match database + .get_credentials_entry() + .map_err(CredentialsError::provider_error)? + { + Some(CredentialsJson { + access_key_id, + secret_key, + session_token, + expiration, + }) => { let Some(access_key_id) = access_key_id else { - return get_cognito_credentials_send(telemetry_stage).await; + return get_cognito_credentials_send(database, telemetry_stage).await; }; let Some(secret_key) = secret_key else { - return get_cognito_credentials_send(telemetry_stage).await; + return get_cognito_credentials_send(database, telemetry_stage).await; }; Ok(Credentials::new( @@ -103,23 +93,23 @@ pub(crate) async fn get_cognito_credentials(telemetry_stage: &TelemetryStage) -> secret_key, session_token, expiration - .and_then(|s| DateTime::from_str(&s, DATE_TIME_FORMAT).ok()) + .and_then(|s| DateTime::from_str(&s, DateTimeFormat::DateTime).ok()) .and_then(|dt| dt.try_into().ok()), "", )) }, - None => get_cognito_credentials_send(telemetry_stage).await, + None => get_cognito_credentials_send(database, telemetry_stage).await, } } #[derive(Debug)] -pub(crate) struct CognitoProvider { - telemetry_stage: TelemetryStage, +pub struct CognitoProvider { + credentials: Credentials, } impl CognitoProvider { - pub(crate) fn new(telemetry_stage: TelemetryStage) -> CognitoProvider { - CognitoProvider { telemetry_stage } + pub fn new(credentials: Credentials) -> CognitoProvider { + CognitoProvider { credentials } } } @@ -128,7 +118,7 @@ impl provider::ProvideCredentials for CognitoProvider { where Self: 'a, { - provider::future::ProvideCredentials::new(get_cognito_credentials(&self.telemetry_stage)) + provider::future::ProvideCredentials::new(async { Ok(self.credentials.clone()) }) } } @@ -139,7 +129,9 @@ mod test { #[tokio::test] async fn pools() { for telemetry_stage in [TelemetryStage::BETA, TelemetryStage::EXTERNAL_PROD] { - get_cognito_credentials_send(&telemetry_stage).await.unwrap(); + get_cognito_credentials_send(&mut Database::new().await.unwrap(), &telemetry_stage) + .await + .unwrap(); } } } diff --git a/crates/chat-cli/src/telemetry/core.rs b/crates/chat-cli/src/telemetry/core.rs index 86091ebfc3..ab143c066d 100644 --- a/crates/chat-cli/src/telemetry/core.rs +++ b/crates/chat-cli/src/telemetry/core.rs @@ -54,11 +54,6 @@ impl Event { } } - pub fn with_credential_start_url(mut self, credential_start_url: String) -> Self { - self.credential_start_url = Some(credential_start_url); - self - } - pub fn into_metric_datum(self) -> Option { match self.ty { EventType::UserLoggedIn {} => Some( @@ -286,6 +281,56 @@ pub enum EventType { }, } +#[derive(Debug)] +pub struct ToolUseEventBuilder { + pub conversation_id: String, + pub utterance_id: Option, + pub user_input_id: Option, + pub tool_use_id: Option, + pub tool_name: Option, + pub is_accepted: bool, + pub is_success: Option, + pub is_valid: Option, + pub is_custom_tool: bool, + pub input_token_size: Option, + pub output_token_size: Option, + pub custom_tool_call_latency: Option, +} + +impl ToolUseEventBuilder { + pub fn new(conv_id: String, tool_use_id: String) -> Self { + Self { + conversation_id: conv_id, + utterance_id: None, + user_input_id: None, + tool_use_id: Some(tool_use_id), + tool_name: None, + is_accepted: false, + is_success: None, + is_valid: None, + is_custom_tool: false, + input_token_size: None, + output_token_size: None, + custom_tool_call_latency: None, + } + } + + pub fn utterance_id(mut self, id: Option) -> Self { + self.utterance_id = id; + self + } + + pub fn set_tool_use_id(mut self, id: String) -> Self { + self.tool_use_id.replace(id); + self + } + + pub fn set_tool_name(mut self, name: String) -> Self { + self.tool_name.replace(name); + self + } +} + #[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)] pub enum SuggestionState { Accept, diff --git a/crates/chat-cli/src/telemetry/mod.rs b/crates/chat-cli/src/telemetry/mod.rs index 896109590f..430a327204 100644 --- a/crates/chat-cli/src/telemetry/mod.rs +++ b/crates/chat-cli/src/telemetry/mod.rs @@ -3,15 +3,14 @@ pub mod core; pub mod definitions; pub mod endpoint; mod install_method; -mod util; -use std::sync::LazyLock; +use core::ToolUseEventBuilder; +use std::str::FromStr; use amzn_codewhisperer_client::types::{ ChatAddMessageEvent, IdeCategory, OperatingSystem, - OptOutPreference, TelemetryEvent, UserContext, }; @@ -26,26 +25,36 @@ use amzn_toolkit_telemetry_client::{ Config, }; use aws_credential_types::provider::SharedCredentialsProvider; -use cognito::CognitoProvider; +use cognito::{ + CognitoProvider, + get_cognito_credentials, +}; use endpoint::StaticEndpoint; pub use install_method::{ InstallMethod, get_install_method, }; -use tokio::sync::{ - Mutex, - OnceCell, -}; -use tokio::task::JoinSet; +use tokio::sync::mpsc; +use tokio::task::JoinHandle; use tracing::{ debug, error, + trace, +}; +use uuid::{ + Uuid, + uuid, }; -use util::telemetry_is_disabled; -use uuid::Uuid; use crate::api_client::Client as CodewhispererClient; use crate::aws_common::app_name; +use crate::cli::CliRootCommands; +use crate::database::settings::Setting; +use crate::database::{ + Database, + DatabaseError, +}; +use crate::platform::Env; use crate::telemetry::core::Event; pub use crate::telemetry::core::{ EventType, @@ -57,21 +66,33 @@ use crate::util::system_info::os_version; #[derive(thiserror::Error, Debug)] pub enum TelemetryError { #[error(transparent)] - ClientError(#[from] amzn_toolkit_telemetry_client::operation::post_metrics::PostMetricsError), + Client(Box), + #[error(transparent)] + Send(Box>), + #[error(transparent)] + Auth(#[from] crate::auth::AuthError), + #[error(transparent)] + Join(#[from] tokio::task::JoinError), + #[error(transparent)] + Database(#[from] DatabaseError), } -const PRODUCT: &str = "CodeWhisperer"; -const PRODUCT_VERSION: &str = env!("CARGO_PKG_VERSION"); +impl From for TelemetryError { + fn from(value: amzn_toolkit_telemetry_client::operation::post_metrics::PostMetricsError) -> Self { + Self::Client(Box::new(value)) + } +} -// TODO: DO NOT USE OUTSIDE THIS FILE. Currently being used in one other place as part of a rewrite. -// but. -pub async fn client() -> &'static Client { - static CLIENT: OnceCell = OnceCell::const_new(); - CLIENT - .get_or_init(|| async { Client::new(TelemetryStage::EXTERNAL_PROD).await }) - .await +impl From> for TelemetryError { + fn from(value: mpsc::error::SendError) -> Self { + Self::Send(Box::new(value)) + } } +const PRODUCT: &str = "CodeWhisperer"; +const PRODUCT_VERSION: &str = env!("CARGO_PKG_VERSION"); +const CLIENT_ID_ENV_VAR: &str = "Q_TELEMETRY_CLIENT_ID"; + /// A IDE toolkit telemetry stage #[derive(Debug, Clone)] #[non_exhaustive] @@ -82,7 +103,7 @@ pub struct TelemetryStage { } impl TelemetryStage { - #[allow(dead_code)] + #[cfg(test)] const BETA: Self = Self::new( "https://7zftft3lj2.execute-api.us-east-1.amazonaws.com/Beta", "us-east-1:db7bfc9f-8ecd-4fbb-bea7-280c16069a99", @@ -103,83 +124,254 @@ impl TelemetryStage { } } -static JOIN_SET: LazyLock>> = LazyLock::new(|| Mutex::new(JoinSet::new())); +#[derive(Debug)] +pub struct TelemetryThread { + handle: Option>, + tx: mpsc::UnboundedSender, +} -/// Joins all current telemetry events -pub async fn finish_telemetry() { - let mut set = JOIN_SET.lock().await; - while let Some(res) = set.join_next().await { - if let Err(err) = res { - error!(%err, "Failed to join telemetry event"); +impl Clone for TelemetryThread { + fn clone(&self) -> Self { + Self { + handle: None, + tx: self.tx.clone(), } } } -/// Joins all current telemetry events and panics if any fail to join -#[cfg(test)] -pub async fn finish_telemetry_unwrap() { - let mut set = JOIN_SET.lock().await; - while let Some(res) = set.join_next().await { - res.unwrap(); +impl TelemetryThread { + pub async fn new(env: &Env, database: &mut Database) -> Result { + let telemetry_client = TelemetryClient::new(env, database).await?; + let (tx, mut rx) = mpsc::unbounded_channel(); + let handle = tokio::spawn(async move { + while let Some(event) = rx.recv().await { + trace!("Sending telemetry event: {:?}", event); + telemetry_client.send_event(event).await; + } + }); + + Ok(Self { + handle: Some(handle), + tx, + }) + } + + pub async fn finish(self) -> Result<(), TelemetryError> { + drop(self.tx); + if let Some(handle) = self.handle { + handle.await?; + } + + Ok(()) + } + + pub fn send_user_logged_in(&self) -> Result<(), TelemetryError> { + Ok(self.tx.send(Event::new(EventType::UserLoggedIn {}))?) + } + + pub fn send_cli_subcommand_executed(&self, subcommand: Option<&CliRootCommands>) -> Result<(), TelemetryError> { + let subcommand = match subcommand { + Some(subcommand) => subcommand.name(), + None => "chat", + } + .to_owned(); + + Ok(self + .tx + .send(Event::new(EventType::CliSubcommandExecuted { subcommand }))?) + } + + pub fn send_chat_added_message( + &self, + conversation_id: String, + message_id: String, + context_file_length: Option, + ) -> Result<(), TelemetryError> { + Ok(self.tx.send(Event::new(EventType::ChatAddedMessage { + conversation_id, + message_id, + context_file_length, + }))?) + } + + pub fn send_tool_use_suggested(&self, event: ToolUseEventBuilder) -> Result<(), TelemetryError> { + Ok(self.tx.send(Event::new(EventType::ToolUseSuggested { + conversation_id: event.conversation_id, + utterance_id: event.utterance_id, + user_input_id: event.user_input_id, + tool_use_id: event.tool_use_id, + tool_name: event.tool_name, + is_accepted: event.is_accepted, + is_success: event.is_success, + is_valid: event.is_valid, + is_custom_tool: event.is_custom_tool, + input_token_size: event.input_token_size, + output_token_size: event.output_token_size, + custom_tool_call_latency: event.custom_tool_call_latency, + }))?) } -} -fn opt_out_preference() -> OptOutPreference { - if telemetry_is_disabled() { - OptOutPreference::OptOut - } else { - OptOutPreference::OptIn + pub fn send_mcp_server_init( + &self, + conversation_id: String, + init_failure_reason: Option, + number_of_tools: usize, + ) -> Result<(), TelemetryError> { + Ok(self.tx.send(Event::new(crate::telemetry::EventType::McpServerInit { + conversation_id, + init_failure_reason, + number_of_tools, + }))?) + } + + pub fn send_did_select_profile( + &self, + source: QProfileSwitchIntent, + amazonq_profile_region: String, + result: TelemetryResult, + sso_region: Option, + profile_count: Option, + ) -> Result<(), TelemetryError> { + Ok(self.tx.send(Event::new(EventType::DidSelectProfile { + source, + amazonq_profile_region, + result, + sso_region, + profile_count, + }))?) + } + + pub fn send_profile_state( + &self, + source: QProfileSwitchIntent, + amazonq_profile_region: String, + result: TelemetryResult, + sso_region: Option, + ) -> Result<(), TelemetryError> { + Ok(self.tx.send(Event::new(EventType::ProfileState { + source, + amazonq_profile_region, + result, + sso_region, + }))?) } } #[derive(Debug, Clone)] -pub struct Client { +struct TelemetryClient { client_id: Uuid, + telemetry_enabled: bool, + codewhisperer_client: CodewhispererClient, toolkit_telemetry_client: Option, - codewhisperer_client: Option, } -impl Client { - pub async fn new(telemetry_stage: TelemetryStage) -> Self { - let client_id = util::get_client_id(); +impl TelemetryClient { + async fn new(env: &Env, database: &mut Database) -> Result { + let telemetry_enabled = !cfg!(test) + && env.get_os("Q_DISABLE_TELEMETRY").is_none() + && database.settings.get_bool(Setting::TelemetryEnabled).unwrap_or(true); + + // If telemetry is disabled we do not emit using toolkit_telemetry + let toolkit_telemetry_client = match telemetry_enabled { + true => match get_cognito_credentials(database, &TelemetryStage::EXTERNAL_PROD).await { + Ok(credentials) => Some(ToolkitTelemetryClient::from_conf( + Config::builder() + .http_client(crate::aws_common::http_client::client()) + .behavior_version(BehaviorVersion::v2025_01_17()) + .endpoint_resolver(StaticEndpoint(TelemetryStage::EXTERNAL_PROD.endpoint)) + .app_name(app_name()) + .region(TelemetryStage::EXTERNAL_PROD.region.clone()) + .credentials_provider(SharedCredentialsProvider::new(CognitoProvider::new(credentials))) + .build(), + )), + Err(err) => { + error!("Failed to acquire cognito credentials: {err}"); + None + }, + }, + false => None, + }; - if cfg!(test) { - return Self { - client_id, - toolkit_telemetry_client: None, - codewhisperer_client: CodewhispererClient::new().await.ok(), - }; - } + fn client_id(env: &Env, database: &mut Database, telemetry_enabled: bool) -> Result { + if !telemetry_enabled { + return Ok(uuid!("ffffffff-ffff-ffff-ffff-ffffffffffff")); + } - let toolkit_telemetry_client = Some(amzn_toolkit_telemetry_client::Client::from_conf( - Config::builder() - .http_client(crate::aws_common::http_client::client()) - .behavior_version(BehaviorVersion::v2025_01_17()) - .endpoint_resolver(StaticEndpoint(telemetry_stage.endpoint)) - .app_name(app_name()) - .region(telemetry_stage.region.clone()) - .credentials_provider(SharedCredentialsProvider::new(CognitoProvider::new(telemetry_stage))) - .build(), - )); - let codewhisperer_client = CodewhispererClient::new().await.ok(); + if let Ok(client_id) = env.get(CLIENT_ID_ENV_VAR) { + if let Ok(uuid) = Uuid::from_str(&client_id) { + return Ok(uuid); + } + } - Self { - client_id, - toolkit_telemetry_client, - codewhisperer_client, + Ok(match database.get_client_id()? { + Some(uuid) => uuid, + None => { + let uuid = database + .settings + .get_string(Setting::OldClientId) + .and_then(|id| Uuid::try_parse(&id).ok()) + .unwrap_or_else(Uuid::new_v4); + + if let Err(err) = database.set_client_id(uuid) { + error!(%err, "Failed to set client id in state"); + } + + uuid + }, + }) } - } - /// TODO: DO NOT USE OUTSIDE THIS FILE - pub async fn send_event(&self, event: Event) { - if telemetry_is_disabled() { - return; - } + Ok(Self { + client_id: client_id(env, database, telemetry_enabled)?, + telemetry_enabled, + toolkit_telemetry_client, + codewhisperer_client: CodewhispererClient::new(database, None).await?, + }) + } + async fn send_event(&self, event: Event) { + // This client will exist when telemetry is disabled. self.send_cw_telemetry_event(&event).await; + + // This client won't exist when telemetry is disabled. self.send_telemetry_toolkit_metric(event).await; } + async fn send_cw_telemetry_event(&self, event: &Event) { + if let EventType::ChatAddedMessage { + conversation_id, + message_id, + .. + } = &event.ty + { + let user_context = self.user_context().unwrap(); + + let chat_add_message_event = match ChatAddMessageEvent::builder() + .conversation_id(conversation_id) + .message_id(message_id) + .build() + { + Ok(event) => event, + Err(err) => { + error!(err =% DisplayErrorContext(err), "Failed to send telemetry event"); + return; + }, + }; + + if let Err(err) = self + .codewhisperer_client + .send_telemetry_event( + TelemetryEvent::ChatAddMessageEvent(chat_add_message_event), + user_context, + self.telemetry_enabled, + ) + .await + { + error!(err =% DisplayErrorContext(err), "Failed to send telemetry event"); + } + } + } + async fn send_telemetry_toolkit_metric(&self, event: Event) { let Some(toolkit_telemetry_client) = self.toolkit_telemetry_client.clone() else { return; @@ -189,45 +381,24 @@ impl Client { return; }; - let mut set = JOIN_SET.lock().await; - set.spawn({ - async move { - let product = AwsProduct::CodewhispererTerminal; - let product_version = env!("CARGO_PKG_VERSION"); - let os = std::env::consts::OS; - let os_architecture = std::env::consts::ARCH; - let os_version = os_version().map(|v| v.to_string()).unwrap_or_default(); - let metric_name = metric_datum.metric_name().to_owned(); - - debug!(?product, ?metric_datum, "Posting metrics"); - if let Err(err) = toolkit_telemetry_client - .post_metrics() - .aws_product(product) - .aws_product_version(product_version) - .client_id(client_id) - .os(os) - .os_architecture(os_architecture) - .os_version(os_version) - .metric_data(metric_datum) - .send() - .await - .map_err(DisplayErrorContext) - { - error!(%err, ?metric_name, "Failed to post metric"); - } - } - }); - } - - async fn send_cw_telemetry_event(&self, event: &Event) { - if let EventType::ChatAddedMessage { - conversation_id, - message_id, - .. - } = &event.ty + let product = AwsProduct::CodewhispererTerminal; + let metric_name = metric_datum.metric_name().to_owned(); + + debug!(?product, ?metric_datum, "Posting metrics"); + if let Err(err) = toolkit_telemetry_client + .post_metrics() + .aws_product(product) + .aws_product_version(env!("CARGO_PKG_VERSION")) + .client_id(client_id) + .os(std::env::consts::OS) + .os_architecture(std::env::consts::ARCH) + .os_version(os_version().map(|v| v.to_string()).unwrap_or_default()) + .metric_data(metric_datum) + .send() + .await + .map_err(DisplayErrorContext) { - self.send_cw_telemetry_chat_add_message_event(conversation_id.clone(), message_id.clone()) - .await; + error!(%err, ?metric_name, "Failed to post metric"); } } @@ -257,130 +428,6 @@ impl Client { }, } } - - async fn send_cw_telemetry_chat_add_message_event(&self, conversation_id: String, message_id: String) { - let Some(codewhisperer_client) = self.codewhisperer_client.clone() else { - return; - }; - let user_context = self.user_context().unwrap(); - let opt_out_preference = opt_out_preference(); - - let chat_add_message_event = match ChatAddMessageEvent::builder() - .conversation_id(conversation_id) - .message_id(message_id) - .build() - { - Ok(event) => event, - Err(err) => { - error!(err =% DisplayErrorContext(err), "Failed to send telemetry event"); - return; - }, - }; - - let mut set = JOIN_SET.lock().await; - set.spawn(async move { - if let Err(err) = codewhisperer_client - .send_telemetry_event( - TelemetryEvent::ChatAddMessageEvent(chat_add_message_event), - user_context, - opt_out_preference, - ) - .await - { - error!(err =% DisplayErrorContext(err), "Failed to send telemetry event"); - } - }); - } -} - -pub async fn send_user_logged_in() { - client().await.send_event(Event::new(EventType::UserLoggedIn {})).await; -} - -pub async fn send_refresh_credentials(credential_start_url: String, request_id: String, oauth_flow: String) { - client() - .await - .send_event( - Event::new(EventType::RefreshCredentials { - request_id, - result: TelemetryResult::Succeeded, - reason: None, - oauth_flow, - }) - .with_credential_start_url(credential_start_url), - ) - .await; -} - -pub async fn send_cli_subcommand_executed(subcommand: impl Into) { - client() - .await - .send_event(Event::new(EventType::CliSubcommandExecuted { - subcommand: subcommand.into(), - })) - .await; -} - -pub async fn send_chat_added_message(conversation_id: String, message_id: String, context_file_length: Option) { - client() - .await - .send_event(Event::new(EventType::ChatAddedMessage { - conversation_id, - message_id, - context_file_length, - })) - .await; -} - -pub async fn send_mcp_server_init( - conversation_id: String, - init_failure_reason: Option, - number_of_tools: usize, -) { - client() - .await - .send_event(Event::new(crate::telemetry::EventType::McpServerInit { - conversation_id, - init_failure_reason, - number_of_tools, - })) - .await; -} - -pub async fn send_did_select_profile( - source: QProfileSwitchIntent, - amazonq_profile_region: String, - result: TelemetryResult, - sso_region: Option, - profile_count: Option, -) { - client() - .await - .send_event(Event::new(EventType::DidSelectProfile { - source, - amazonq_profile_region, - result, - sso_region, - profile_count, - })) - .await; -} - -pub async fn send_profile_state( - source: QProfileSwitchIntent, - amazonq_profile_region: String, - result: TelemetryResult, - sso_region: Option, -) { - client() - .await - .send_event(Event::new(EventType::ProfileState { - source, - amazonq_profile_region, - result, - sso_region, - })) - .await; } #[cfg(test)] @@ -391,7 +438,8 @@ mod test { #[tokio::test] async fn client_context() { - let client = client().await; + let mut database = Database::new().await.unwrap(); + let client = TelemetryClient::new(&Env::new(), &mut database).await.unwrap(); let context = client.user_context().unwrap(); assert_eq!(context.ide_category, IdeCategory::Cli); @@ -411,7 +459,10 @@ mod test { #[tokio::test] #[ignore = "needs auth which is not in CI"] async fn test_send() { - finish_telemetry_unwrap().await; + let mut database = Database::new().await.unwrap(); + let thread = TelemetryThread::new(&Env::new(), &mut database).await.unwrap(); + thread.send_user_logged_in().ok(); + drop(thread); assert!(!logs_contain("ERROR")); assert!(!logs_contain("error")); @@ -424,11 +475,18 @@ mod test { #[tokio::test] #[ignore = "needs auth which is not in CI"] async fn test_all_telemetry() { - send_user_logged_in().await; - send_cli_subcommand_executed("doctor").await; - send_chat_added_message("debug".to_owned(), "debug".to_owned(), Some(123)).await; + let mut database = Database::new().await.unwrap(); + let thread = TelemetryThread::new(&Env::new(), &mut database).await.unwrap(); + + thread.send_user_logged_in().ok(); + thread + .send_cli_subcommand_executed(Some(&CliRootCommands::Version { changelog: None })) + .ok(); + thread + .send_chat_added_message("version".to_owned(), "version".to_owned(), Some(123)) + .ok(); - finish_telemetry_unwrap().await; + drop(thread); assert!(!logs_contain("ERROR")); assert!(!logs_contain("error")); @@ -440,11 +498,10 @@ mod test { #[tokio::test] #[ignore = "needs auth which is not in CI"] async fn test_without_optout() { - let client = Client::new(TelemetryStage::BETA).await; + let mut database = Database::new().await.unwrap(); + let client = TelemetryClient::new(&Env::new(), &mut database).await.unwrap(); client .codewhisperer_client - .as_ref() - .unwrap() .send_telemetry_event( TelemetryEvent::ChatAddMessageEvent( ChatAddMessageEvent::builder() @@ -454,7 +511,7 @@ mod test { .unwrap(), ), client.user_context().unwrap(), - OptOutPreference::OptIn, + false, ) .await .unwrap(); diff --git a/crates/chat-cli/src/telemetry/util.rs b/crates/chat-cli/src/telemetry/util.rs deleted file mode 100644 index c3452c68e1..0000000000 --- a/crates/chat-cli/src/telemetry/util.rs +++ /dev/null @@ -1,159 +0,0 @@ -use std::str::FromStr; - -use tracing::error; -use uuid::{ - Uuid, - uuid, -}; - -use crate::platform::Env; -use crate::settings::{ - Settings, - State, -}; - -const CLIENT_ID_STATE_KEY: &str = "telemetryClientId"; -const CLIENT_ID_ENV_VAR: &str = "Q_TELEMETRY_CLIENT_ID"; - -pub(crate) fn telemetry_is_disabled() -> bool { - let is_test = cfg!(test); - telemetry_is_disabled_inner(is_test, &Env::new(), &Settings::new()) -} - -/// Returns whether or not the user has disabled telemetry through settings or environment -fn telemetry_is_disabled_inner(is_test: bool, env: &Env, settings: &Settings) -> bool { - let env_var = env.get_os("Q_DISABLE_TELEMETRY").is_some(); - let setting = !settings - .get_value("telemetry.enabled") - .ok() - .flatten() - .and_then(|v| v.as_bool()) - .unwrap_or(true); - !is_test && (env_var || setting) -} - -pub(crate) fn get_client_id() -> Uuid { - get_client_id_inner(cfg!(test), &Env::new(), &State::new(), &Settings::new()) -} - -/// Generates or gets the client id and caches the result -/// -/// Based on: -pub(crate) fn get_client_id_inner(is_test: bool, env: &Env, state: &State, settings: &Settings) -> Uuid { - if is_test { - return uuid!("ffffffff-ffff-ffff-ffff-ffffffffffff"); - } - - if telemetry_is_disabled_inner(is_test, env, settings) { - return uuid!("11111111-1111-1111-1111-111111111111"); - } - - if let Ok(client_id) = env.get(CLIENT_ID_ENV_VAR) { - if let Ok(uuid) = Uuid::from_str(&client_id) { - return uuid; - } - } - - let state_uuid = state - .get_string(CLIENT_ID_STATE_KEY) - .ok() - .flatten() - .and_then(|s| Uuid::from_str(&s).ok()); - - match state_uuid { - Some(uuid) => uuid, - None => { - let uuid = old_client_id_inner(settings).unwrap_or_else(Uuid::new_v4); - if let Err(err) = state.set_value(CLIENT_ID_STATE_KEY, uuid.to_string()) { - error!(%err, "Failed to set client id in state"); - } - uuid - }, - } -} - -/// We accidently generates some clientIds in the settings file, we want to include those in the -/// telemetry events so we corolate those users with the correct clientIds -fn old_client_id_inner(settings: &Settings) -> Option { - settings - .get_string(CLIENT_ID_STATE_KEY) - .ok() - .flatten() - .and_then(|s| Uuid::from_str(&s).ok()) -} - -#[cfg(test)] -mod tests { - use super::*; - - const TEST_UUID_STR: &str = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"; - const TEST_UUID: Uuid = uuid!(TEST_UUID_STR); - - #[test] - fn test_is_telemetry_disabled() { - // disabled by default in tests - // let is_disabled = telemetry_is_disabled(); - // assert!(!is_disabled); - - // let settings = Settings::new_fake(); - - // let env = Env::from_slice(&[("Q_DISABLE_TELEMETRY", "1")]); - // assert!(telemetry_is_disabled_inner(true, &env, &settings)); - // assert!(telemetry_is_disabled_inner(false, &env, &settings)); - - // let env = Env::new_fake(); - // assert!(telemetry_is_disabled_inner(true, &env, &settings)); - // assert!(!telemetry_is_disabled_inner(false, &env, &settings)); - - // settings.set_value("telemetry.enabled", false).unwrap(); - // assert!(telemetry_is_disabled_inner(false, &env, &settings)); - // assert!(!telemetry_is_disabled_inner(true, &env, &settings)); - - // settings.set_value("telemetry.enabled", true).unwrap(); - // assert!(!telemetry_is_disabled_inner(false, &env, &settings)); - // assert!(!telemetry_is_disabled_inner(true, &env, &settings)); - } - - #[ignore] - #[test] - fn test_get_client_id() { - // max by default in tests - let id = get_client_id(); - assert!(id.is_max()); - - let state = State::new(); - let settings = Settings::new(); - - let env = Env::from_slice(&[(CLIENT_ID_ENV_VAR, TEST_UUID_STR)]); - assert_eq!(get_client_id_inner(false, &env, &state, &settings), TEST_UUID); - - let env = Env::new(); - - // in tests returns the test uuid - assert!(get_client_id_inner(true, &env, &state, &settings).is_max()); - - // returns the currently set client id if one is found - state.set_value(CLIENT_ID_STATE_KEY, TEST_UUID_STR).unwrap(); - assert_eq!(get_client_id_inner(false, &env, &state, &settings), TEST_UUID); - - // generates a new client id if none is found - state.remove_value(CLIENT_ID_STATE_KEY).unwrap(); - assert_eq!( - get_client_id_inner(false, &env, &state, &settings).to_string(), - state.get_string(CLIENT_ID_STATE_KEY).unwrap().unwrap() - ); - - // migrates the client id in settings - state.remove_value(CLIENT_ID_STATE_KEY).unwrap(); - settings.set_value(CLIENT_ID_STATE_KEY, TEST_UUID_STR).unwrap(); - assert_eq!(get_client_id_inner(false, &env, &state, &settings), TEST_UUID); - } - - #[test] - fn test_get_client_id_old() { - let settings = Settings::new(); - assert!(old_client_id_inner(&settings).is_none()); - settings.set_value(CLIENT_ID_STATE_KEY, TEST_UUID_STR).unwrap(); - assert_eq!(old_client_id_inner(&settings), Some(TEST_UUID)); - } -} diff --git a/crates/chat-cli/src/util/directories.rs b/crates/chat-cli/src/util/directories.rs index 925c7471ad..b185e70eea 100644 --- a/crates/chat-cli/src/util/directories.rs +++ b/crates/chat-cli/src/util/directories.rs @@ -3,7 +3,6 @@ use std::path::PathBuf; use thiserror::Error; use crate::platform::Context; -use crate::util::env_var::Q_PARENT; #[derive(Debug, Error)] pub enum DirectoryError { @@ -11,8 +10,6 @@ pub enum DirectoryError { NoHomeDirectory, #[error("runtime directory not found: neither XDG_RUNTIME_DIR nor TMPDIR were found")] NoRuntimeDirectory, - #[error("non absolute path: {0:?}")] - NonAbsolutePath(PathBuf), #[error("IO Error: {0}")] Io(#[from] std::io::Error), #[error(transparent)] @@ -25,10 +22,6 @@ pub enum DirectoryError { FromVecWithNul(#[from] std::ffi::FromVecWithNulError), #[error(transparent)] IntoString(#[from] std::ffi::IntoStringError), - #[error("{Q_PARENT} env variable not set")] - QParentNotSet, - #[error("must be ran from an appimage executable")] - NotAppImage, } type Result = std::result::Result; @@ -113,8 +106,7 @@ pub fn runtime_dir() -> Result { pub fn logs_dir() -> Result { cfg_if::cfg_if! { if #[cfg(unix)] { - use crate::util::CHAT_BINARY_NAME; - Ok(runtime_dir()?.join(format!("{CHAT_BINARY_NAME}log"))) + Ok(runtime_dir()?.join("qlog")) } else if #[cfg(windows)] { Ok(std::env::temp_dir().join("amazon-q").join("logs")) } @@ -136,6 +128,11 @@ pub fn settings_path() -> Result { Ok(fig_data_dir()?.join("settings.json")) } +/// The path to the local sqlite database +pub fn database_path() -> Result { + Ok(fig_data_dir()?.join("data.sqlite3")) +} + #[cfg(test)] mod linux_tests { use super::*; diff --git a/crates/chat-cli/src/util/mod.rs b/crates/chat-cli/src/util/mod.rs index 8ec154c645..185dcee0ea 100644 --- a/crates/chat-cli/src/util/mod.rs +++ b/crates/chat-cli/src/util/mod.rs @@ -29,22 +29,8 @@ use tracing::warn; pub enum UtilError { #[error("io operation error")] IoError(#[from] std::io::Error), - #[error("unsupported platform")] - UnsupportedPlatform, - #[error("unsupported architecture")] - UnsupportedArch, #[error(transparent)] Directory(#[from] directories::DirectoryError), - #[error("could not find the os hwid")] - HwidNotFound, - #[error("the shell, `{0}`, isn't supported yet")] - UnknownShell(String), - #[error("missing environment variable `{0}`")] - MissingEnv(&'static str), - #[error("unknown display server `{0}`")] - UnknownDisplayServer(String), - #[error("unknown desktop, checked environment variables: {0}")] - UnknownDesktop(UnknownDesktopErrContext), #[error(transparent)] StrUtf8Error(#[from] std::str::Utf8Error), #[error(transparent)]