diff --git a/.tool-versions b/.tool-versions deleted file mode 100644 index c3507cd6..00000000 --- a/.tool-versions +++ /dev/null @@ -1 +0,0 @@ -rust stable \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index acd50f41..5083992f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -800,6 +800,7 @@ name = "cipherstash-proxy" version = "2.0.0" dependencies = [ "arc-swap", + "async-trait", "aws-lc-rs", "bigdecimal", "bytes", diff --git a/docs/errors.md b/docs/errors.md index 9f5c60e1..670df8b7 100644 --- a/docs/errors.md +++ b/docs/errors.md @@ -32,6 +32,8 @@ - Configuration errors: - [Missing or invalid TLS configuration](#config-missing-or-invalid-tls) + - [Network configuration change requires restart](#config-network-change-requires-restart) + @@ -651,3 +653,35 @@ Check that the certificate and private key are valid. + + +## Network configuration change requires restart + +A configuration reload was attempted with network-level changes that require a full restart. + +### Error message + +``` +Network configuration change requires restart +``` + +### Notes + +When receiving a SIGHUP signal, CipherStash Proxy attempts to reload application-level configuration without disrupting active connections. However, certain network-related configuration changes require stopping and restarting the proxy service to take effect. + +The following settings require a restart when changed: +- `server.host` - The host address the proxy listens on +- `server.port` - The port the proxy listens on +- `server.require_tls` - TLS requirement setting +- `server.worker_threads` - Number of worker threads +- `tls` - Any TLS certificate or key configuration + +### How to fix + +1. Stop the CipherStash Proxy service +2. Update the configuration as needed +3. Restart the CipherStash Proxy service + +Application-level configuration changes (database, auth, encrypt, log, prometheus, development) can be reloaded without restart using SIGHUP. + + diff --git a/mise.toml b/mise.toml index eb451dcf..d0256394 100644 --- a/mise.toml +++ b/mise.toml @@ -309,6 +309,8 @@ echo mise --env tcp run postgres:setup mise --env tls run postgres:setup +mise run test:integration:showcase + echo echo '###############################################' echo '# Test: Prometheus' @@ -377,10 +379,8 @@ echo '###############################################' echo '# Test: Showcase' echo '###############################################' echo -mise --env tls run proxy:up proxy-tls --extra-args "--detach --wait" -mise --env tls run test:wait_for_postgres_to_quack --port 6432 --max-retries 20 --tls -RUST_BACKTRACE=full cargo run -p showcase -mise --env tls run proxy:down + +mise run test:integration:showcase echo echo '###############################################' @@ -637,6 +637,15 @@ else fi """ +[tasks."test:integration:showcase"] +description = "Run Showcase integration test" +run = """ +mise --env tls run proxy:up proxy-tls --extra-args "--detach --wait" +mise --env tls run test:wait_for_postgres_to_quack --port 6432 --max-retries 20 --tls +RUST_BACKTRACE=full cargo run -p showcase +mise --env tls run proxy:down +""" + [tasks.release] description = "Publish release artifacts" depends = ["release:docker"] diff --git a/packages/cipherstash-proxy/Cargo.toml b/packages/cipherstash-proxy/Cargo.toml index e9373779..b47033da 100644 --- a/packages/cipherstash-proxy/Cargo.toml +++ b/packages/cipherstash-proxy/Cargo.toml @@ -4,6 +4,7 @@ version = "2.0.0" edition = "2021" [dependencies] +async-trait = "0.1" aws-lc-rs = "1.13.3" bigdecimal = { version = "0.4.6", features = ["serde-json"] } arc-swap = "1.7.1" diff --git a/packages/cipherstash-proxy/src/config/database.rs b/packages/cipherstash-proxy/src/config/database.rs index 464698cb..428d96ac 100644 --- a/packages/cipherstash-proxy/src/config/database.rs +++ b/packages/cipherstash-proxy/src/config/database.rs @@ -81,6 +81,21 @@ impl DatabaseConfig { })?; Ok(name) } + + #[cfg(test)] + pub fn for_testing() -> Self { + Self { + host: Self::default_host(), + port: Self::default_port(), + name: "test".to_string(), + username: "test".to_string(), + password: Protected::new("test".to_string()), + connection_timeout: None, + with_tls_verification: false, + config_reload_interval: Self::default_config_reload_interval(), + schema_reload_interval: Self::default_schema_reload_interval(), + } + } } /// diff --git a/packages/cipherstash-proxy/src/config/tandem.rs b/packages/cipherstash-proxy/src/config/tandem.rs index bf018e99..9b5e6a3b 100644 --- a/packages/cipherstash-proxy/src/config/tandem.rs +++ b/packages/cipherstash-proxy/src/config/tandem.rs @@ -273,6 +273,29 @@ impl TandemConfig { DEFAULT_THREAD_STACK_SIZE } + + #[cfg(test)] + pub fn for_testing() -> Self { + Self { + server: ServerConfig::default(), + database: DatabaseConfig::for_testing(), + auth: AuthConfig { + workspace_crn: "crn:ap-southeast-2.aws:IJGECSCWKREECNBS".parse().unwrap(), + client_access_key: "test".to_string(), + }, + encrypt: EncryptConfig { + client_id: "test".to_string(), + client_key: "test".to_string(), + default_keyset_id: Some( + Uuid::parse_str("00000000-0000-0000-0000-000000000000").unwrap(), + ), + }, + tls: None, + log: LogConfig::default(), + prometheus: PrometheusConfig::default(), + development: None, + } + } } impl PrometheusConfig { diff --git a/packages/cipherstash-proxy/src/config/tls.rs b/packages/cipherstash-proxy/src/config/tls.rs index e5c0647b..9585cff6 100644 --- a/packages/cipherstash-proxy/src/config/tls.rs +++ b/packages/cipherstash-proxy/src/config/tls.rs @@ -10,7 +10,7 @@ use crate::{error::TlsConfigError, log::CONFIG}; /// Server TLS Configuration /// This is listener/inbound connection config /// -#[derive(Clone, Debug, Deserialize)] +#[derive(Clone, Debug, Deserialize, PartialEq)] #[serde(untagged)] pub enum TlsConfig { Pem { diff --git a/packages/cipherstash-proxy/src/error.rs b/packages/cipherstash-proxy/src/error.rs index 2d09107c..8a1f0a8a 100644 --- a/packages/cipherstash-proxy/src/error.rs +++ b/packages/cipherstash-proxy/src/error.rs @@ -176,6 +176,9 @@ pub enum ConfigError { #[error("Expected an Encrypt configuration table")] MissingEncryptConfigTable, + #[error("Network configuration change requires restart For help visit {}#config-network-change-requires-restart", ERROR_DOC_BASE_URL)] + NetworkConfigurationChangeRequiresRestart, + #[error(transparent)] Parse(#[from] serde_json::Error), diff --git a/packages/cipherstash-proxy/src/main.rs b/packages/cipherstash-proxy/src/main.rs index fa3958c6..a11da143 100644 --- a/packages/cipherstash-proxy/src/main.rs +++ b/packages/cipherstash-proxy/src/main.rs @@ -1,12 +1,11 @@ use cipherstash_proxy::config::TandemConfig; use cipherstash_proxy::connect::{self, AsyncStream}; -use cipherstash_proxy::error::Error; +use cipherstash_proxy::error::{ConfigError, Error}; use cipherstash_proxy::prometheus::CLIENTS_ACTIVE_CONNECTIONS; use cipherstash_proxy::proxy::Proxy; use cipherstash_proxy::{cli, log, postgresql as pg, prometheus, tls, Args}; use clap::Parser; use metrics::gauge; -use tokio::net::TcpListener; use tokio::signal::unix::{signal, SignalKind}; use tokio_util::task::TaskTracker; use tracing::{error, info, warn}; @@ -55,7 +54,7 @@ fn main() -> Result<(), Box> { let mut proxy = init(config).await; - let mut listener = connect::bind_with_retry(&proxy.config.server).await; + let listener = connect::bind_with_retry(&proxy.config.server).await; let tracker = TaskTracker::new(); let mut client_id = 0; @@ -81,9 +80,8 @@ fn main() -> Result<(), Box> { break; }, _ = sighup() => { - info!(msg = "Received SIGHUP. Reloading configuration"); - (listener, proxy) = reload_config(listener, &args, proxy).await; - info!(msg = "Reloaded configuration"); + info!(msg = "Received SIGHUP. Reloading application configuration"); + proxy = reload_application_config(&proxy.config, &args).await.unwrap_or(proxy); }, _ = sigterm() => { info!(msg = "Received SIGTERM"); @@ -91,16 +89,15 @@ fn main() -> Result<(), Box> { }, Ok(client_stream) = AsyncStream::accept(&listener) => { - let proxy = proxy.clone(); - client_id += 1; + let context = proxy.context(client_id); + tracker.spawn(async move { - let proxy = proxy.clone(); gauge!(CLIENTS_ACTIVE_CONNECTIONS).increment(1); - match pg::handler(client_stream, proxy, client_id).await { + match pg::handler(client_stream,context).await { Ok(_) => (), Err(err) => { @@ -261,7 +258,15 @@ async fn sighup() -> std::io::Result<()> { Ok(()) } -async fn reload_config(listener: TcpListener, args: &Args, proxy: Proxy) -> (TcpListener, Proxy) { +fn has_network_config_changed(current: &TandemConfig, new: &TandemConfig) -> bool { + current.server.host != new.server.host + || current.server.port != new.server.port + || current.server.require_tls != new.server.require_tls + || current.server.worker_threads != new.server.worker_threads + || current.tls != new.tls +} + +async fn reload_application_config(config: &TandemConfig, args: &Args) -> Result { let new_config = match TandemConfig::load(args) { Ok(config) => config, Err(err) => { @@ -269,17 +274,19 @@ async fn reload_config(listener: TcpListener, args: &Args, proxy: Proxy) -> (Tcp msg = "Configuration could not be reloaded: {}", error = err.to_string() ); - return (listener, proxy); + return Err(err); } }; - let new_proxy = init(new_config).await; + // Check for network config changes that require restart + if has_network_config_changed(config, &new_config) { + let err = ConfigError::NetworkConfigurationChangeRequiresRestart; + warn!(msg = err.to_string()); - // Explicit drop needed here to free the network resources before binding if using the same address & port - std::mem::drop(listener); + return Err(err.into()); + } - ( - connect::bind_with_retry(&new_proxy.config.server).await, - new_proxy, - ) + info!(msg = "Configuration reloaded"); + let proxy = init(new_config).await; + Ok(proxy) } diff --git a/packages/cipherstash-proxy/src/postgresql/backend.rs b/packages/cipherstash-proxy/src/postgresql/backend.rs index 474d8309..f093716e 100644 --- a/packages/cipherstash-proxy/src/postgresql/backend.rs +++ b/packages/cipherstash-proxy/src/postgresql/backend.rs @@ -19,7 +19,7 @@ use crate::prometheus::{ DECRYPTION_ERROR_TOTAL, DECRYPTION_REQUESTS_TOTAL, ROWS_ENCRYPTED_TOTAL, ROWS_PASSTHROUGH_TOTAL, ROWS_TOTAL, SERVER_BYTES_RECEIVED_TOTAL, }; -use crate::proxy::Proxy; +use crate::proxy::EncryptionService; use bytes::BytesMut; use metrics::{counter, histogram}; use std::time::Instant; @@ -70,25 +70,25 @@ use tracing::{debug, error, info, warn}; /// - `RowDescription`: Result column metadata (modified for encrypted columns) /// - `ParameterDescription`: Parameter metadata (modified for encrypted parameters) /// - `ReadyForQuery`: Session ready state (triggers schema reload if needed) -pub struct Backend +pub struct Backend where R: AsyncRead + Unpin, + S: EncryptionService, { /// Sender for outgoing messages to client client_sender: Sender, /// Reader for incoming messages from server server_reader: R, - /// Encryption service for column decryption - proxy: Proxy, /// Session context with portal and statement metadata - context: Context, + context: Context, /// Buffer for batching DataRow messages before decryption buffer: MessageBuffer, } -impl Backend +impl Backend where R: AsyncRead + Unpin, + S: EncryptionService, { /// Creates a new Backend instance. /// @@ -98,12 +98,11 @@ where /// * `server_reader` - Stream for reading messages from the PostgreSQL server /// * `encrypt` - Encryption service for handling column decryption /// * `context` - Session context shared with the frontend - pub fn new(client_sender: Sender, server_reader: R, proxy: Proxy, context: Context) -> Self { + pub fn new(client_sender: Sender, server_reader: R, context: Context) -> Self { let buffer = MessageBuffer::new(); Backend { client_sender, server_reader, - proxy, context, buffer, } @@ -150,19 +149,17 @@ where /// Returns `Ok(())` on successful message processing, or an `Error` if a fatal /// error occurs that should terminate the connection. pub async fn rewrite(&mut self) -> Result<(), Error> { - let connection_timeout = self.proxy.config.database.connection_timeout(); - let (code, mut bytes) = protocol::read_message( &mut self.server_reader, self.context.client_id, - connection_timeout, + self.context.connection_timeout(), ) .await?; let sent: u64 = bytes.len() as u64; counter!(SERVER_BYTES_RECEIVED_TOTAL).increment(sent); - if self.proxy.is_passthrough() { + if self.context.is_passthrough() { debug!(target: DEVELOPMENT, client_id = self.context.client_id, msg = "Passthrough enabled" @@ -250,7 +247,7 @@ where msg = "ReadyForQuery" ); if self.context.schema_changed() { - self.proxy.reload_schema().await; + self.context.reload_schema().await; } } @@ -450,16 +447,12 @@ where ); // Decrypt CipherText -> Plaintext - let plaintexts = self - .proxy - .decrypt(keyset_id, ciphertexts) - .await - .inspect_err(|_| { - counter!(DECRYPTION_ERROR_TOTAL).increment(1); - })?; + let plaintexts = self.context.decrypt(ciphertexts).await.inspect_err(|_| { + counter!(DECRYPTION_ERROR_TOTAL).increment(1); + })?; // Avoid the iter calculation if we can - if self.proxy.config.prometheus_enabled() { + if self.context.prometheus_enabled() { let decrypted_count = plaintexts .iter() @@ -655,9 +648,10 @@ where } /// Implementation of PostgreSQL error handling for the Backend component. -impl PostgreSqlErrorHandler for Backend +impl PostgreSqlErrorHandler for Backend where R: AsyncRead + Unpin, + S: EncryptionService, { fn client_sender(&mut self) -> &mut Sender { &mut self.client_sender diff --git a/packages/cipherstash-proxy/src/postgresql/column_mapper.rs b/packages/cipherstash-proxy/src/postgresql/column_mapper.rs new file mode 100644 index 00000000..c63bba03 --- /dev/null +++ b/packages/cipherstash-proxy/src/postgresql/column_mapper.rs @@ -0,0 +1,167 @@ +use crate::{ + eql::Identifier, + error::{EncryptError, Error}, + log::MAPPER, + postgresql::Column, + proxy::EncryptConfig, +}; +use eql_mapper::{EqlTerm, TableColumn, TypeCheckedStatement}; +use postgres_types::Type; +use std::sync::Arc; +use tracing::{debug, warn}; + +/// Service responsible for processing columns from type-checked SQL statements +/// and mapping them to encryption configurations. +#[derive(Clone)] +pub struct ColumnMapper { + encrypt_config: Arc, +} + +impl ColumnMapper { + /// Create a new ColumnProcessor with the given schema service and client ID + pub fn new(encrypt_config: Arc) -> Self { + Self { encrypt_config } + } + + /// Maps typed statement projection columns to an Encrypt column configuration + /// + /// The returned `Vec` is of `Option` because the Projection columns are a mix of native and EQL types. + /// Only EQL columns will have a configuration. Native types are always None. + /// + /// Preserves the ordering and semantics of the projection to reduce the complexity of positional encryption. + pub fn get_projection_columns( + &self, + typed_statement: &TypeCheckedStatement<'_>, + ) -> Result>, Error> { + let mut projection_columns = vec![]; + + for col in typed_statement.projection.columns() { + let eql_mapper::ProjectionColumn { ty, .. } = col; + let configured_column = match &**ty { + eql_mapper::Type::Value(eql_mapper::Value::Eql(eql_term)) => { + let TableColumn { table, column } = eql_term.table_column(); + let identifier: Identifier = Identifier::from((table, column)); + + debug!( + target: MAPPER, + msg = "Configured column", + column = ?identifier, + ?eql_term, + ); + self.get_column(identifier, eql_term)? + } + _ => None, + }; + projection_columns.push(configured_column) + } + + Ok(projection_columns) + } + + /// Maps typed statement param columns to an Encrypt column configuration + /// + /// The returned `Vec` is of `Option` because the Param columns are a mix of native and EQL types. + /// Only EQL colunms will have a configuration. Native types are always None. + /// + /// Preserves the ordering and semantics of the projection to reduce the complexity of positional encryption. + pub fn get_param_columns( + &self, + typed_statement: &TypeCheckedStatement<'_>, + ) -> Result>, Error> { + let mut param_columns = vec![]; + + for param in typed_statement.params.iter() { + let configured_column = match param { + (_, eql_mapper::Value::Eql(eql_term)) => { + let TableColumn { table, column } = eql_term.table_column(); + let identifier = Identifier::from((table, column)); + + debug!( + target: MAPPER, + msg = "Encrypted parameter", + column = ?identifier, + ?eql_term, + ); + + self.get_column(identifier, eql_term)? + } + _ => None, + }; + param_columns.push(configured_column); + } + + Ok(param_columns) + } + + /// Maps typed statement literal columns to an Encrypt column configuration + pub fn get_literal_columns( + &self, + typed_statement: &TypeCheckedStatement<'_>, + ) -> Result>, Error> { + let mut literal_columns = vec![]; + + for (eql_term, _) in typed_statement.literals.iter() { + let TableColumn { table, column } = eql_term.table_column(); + let identifier = Identifier::from((table, column)); + + debug!( + target: MAPPER, + msg = "Encrypted literal", + column = ?identifier, + ?eql_term, + ); + let col = self.get_column(identifier, eql_term)?; + if col.is_some() { + literal_columns.push(col); + } + } + + Ok(literal_columns) + } + + /// Get the column configuration for the Identifier + /// Returns `EncryptError::UnknownColumn` if configuration cannot be found for the Identified column + /// if mapping enabled, and None if mapping is disabled. It'll log a warning either way. + fn get_column( + &self, + identifier: Identifier, + eql_term: &EqlTerm, + ) -> Result, Error> { + match self.encrypt_config.get_column_config(&identifier) { + Some(config) => { + debug!( + target: MAPPER, + msg = "Configured column", + column = ?identifier + ); + + // IndexTerm::SteVecSelector + let postgres_type = if matches!(eql_term, EqlTerm::JsonPath(_)) { + Some(Type::JSONPATH) + } else { + None + }; + + let eql_term = eql_term.variant(); + Ok(Some(Column::new( + identifier, + config, + postgres_type, + eql_term, + ))) + } + None => { + warn!( + target: MAPPER, + msg = "Configured column not found. Encryption configuration may have been deleted.", + ?identifier, + ); + Err(EncryptError::UnknownColumn { + table: identifier.table.to_owned(), + column: identifier.column.to_owned(), + } + .into()) + } + } + } +} diff --git a/packages/cipherstash-proxy/src/postgresql/context/mod.rs b/packages/cipherstash-proxy/src/postgresql/context/mod.rs index d8ed65b5..4b2d9aa0 100644 --- a/packages/cipherstash-proxy/src/postgresql/context/mod.rs +++ b/packages/cipherstash-proxy/src/postgresql/context/mod.rs @@ -1,14 +1,19 @@ pub mod column; +pub mod portal; +pub mod statement; +pub use self::{portal::Portal, statement::Statement}; use super::{ - format_code::FormatCode, + column_mapper::ColumnMapper, messages::{describe::Describe, Name, Target}, Column, }; use crate::{ + config::TandemConfig, error::{EncryptError, Error}, log::CONTEXT, prometheus::{STATEMENTS_EXECUTION_DURATION_SECONDS, STATEMENTS_SESSION_DURATION_SECONDS}, + proxy::{EncryptConfig, EncryptionService, ReloadCommand, ReloadSender}, }; use cipherstash_client::IdentifiedBy; use eql_mapper::{Schema, TableResolver}; @@ -19,7 +24,8 @@ use std::{ sync::{Arc, LazyLock, RwLock}, time::{Duration, Instant}, }; -use tracing::{debug, warn}; +use tokio::sync::oneshot; +use tracing::{debug, error, warn}; use uuid::Uuid; type DescribeQueue = Queue; @@ -36,9 +42,17 @@ impl std::fmt::Display for KeysetIdentifier { } } -#[derive(Clone, Debug)] -pub struct Context { +#[derive(Clone)] +pub struct Context +where + T: EncryptionService, +{ pub client_id: i32, + config: Arc, + encrypt_config: Arc, + encryption: T, + reload_sender: ReloadSender, + column_mapper: ColumnMapper, statements: Arc>>>, portals: Arc>>, describe: Arc>, @@ -91,28 +105,20 @@ pub struct Queue { pub queue: VecDeque, } -/// -/// Type Analysed parameters and projection -/// -#[derive(Debug, Clone, PartialEq)] -pub struct Statement { - pub param_columns: Vec>, - pub projection_columns: Vec>, - pub literal_columns: Vec>, - pub postgres_param_types: Vec, -} - -#[derive(Clone, Debug)] -pub enum Portal { - Encrypted { - format_codes: Vec, - statement: Arc, - }, - Passthrough, -} +impl Context +where + T: EncryptionService, +{ + pub fn new( + client_id: i32, + config: Arc, + encrypt_config: Arc, + schema: Arc, + encryption: T, + reload_sender: ReloadSender, + ) -> Context { + let column_mapper = ColumnMapper::new(encrypt_config.clone()); -impl Context { - pub fn new(client_id: i32, schema: Arc) -> Context { Context { statements: Arc::new(RwLock::new(HashMap::new())), portals: Arc::new(RwLock::new(HashMap::new())), @@ -122,6 +128,11 @@ impl Context { session_metrics: Arc::new(RwLock::from(Queue::new())), table_resolver: Arc::new(TableResolver::new_editable(schema)), client_id, + config, + encrypt_config, + column_mapper, + encryption, + reload_sender, unsafe_disable_mapping: false, keyset_id: Arc::new(RwLock::new(None)), } @@ -301,6 +312,10 @@ impl Context { } pub fn set_schema_changed(&self) { + debug!(target: CONTEXT, + client_id = self.client_id, + msg = "Schema changed" + ); let _ = self.schema_changed.write().map(|mut guard| *guard = true); } @@ -477,37 +492,133 @@ impl Context { pub fn keyset_identifier(&self) -> Option { self.keyset_id.read().ok().and_then(|k| k.clone()) } -} -impl Statement { - pub fn new( - param_columns: Vec>, - projection_columns: Vec>, - literal_columns: Vec>, - postgres_param_types: Vec, - ) -> Statement { - Statement { - param_columns, - projection_columns, - literal_columns, - postgres_param_types, + // Service delegation methods + pub async fn encrypt( + &self, + plaintexts: Vec>, + columns: &[Option], + ) -> Result>, Error> { + let keyset_id = self.keyset_identifier(); + + self.encryption + .encrypt(keyset_id, plaintexts, columns) + .await + } + + pub async fn decrypt( + &self, + ciphertexts: Vec>, + ) -> Result>, Error> { + let keyset_id = self.keyset_identifier(); + self.encryption.decrypt(keyset_id, ciphertexts).await + } + + pub async fn reload_schema(&self) { + let (responder, receiver) = oneshot::channel(); + match self + .reload_sender + .send(ReloadCommand::DatabaseSchema(responder)) + { + Ok(_) => (), + Err(err) => { + error!( + msg = "Database schema could not be reloaded", + error = err.to_string() + ); + } } + + debug!(target: CONTEXT, msg = "Waiting for schema reload"); + let response = receiver.await; + debug!(target: CONTEXT, msg = "Database schema reloaded", ?response); + } + + pub fn is_passthrough(&self) -> bool { + self.encrypt_config.is_empty() || self.config.mapping_disabled() + } + + // Column processing delegation methods + pub fn get_projection_columns( + &self, + typed_statement: &eql_mapper::TypeCheckedStatement<'_>, + ) -> Result>, Error> { + self.column_mapper.get_projection_columns(typed_statement) + } + + pub fn get_param_columns( + &self, + typed_statement: &eql_mapper::TypeCheckedStatement<'_>, + ) -> Result>, Error> { + self.column_mapper.get_param_columns(typed_statement) + } + + pub fn get_literal_columns( + &self, + typed_statement: &eql_mapper::TypeCheckedStatement<'_>, + ) -> Result>, Error> { + self.column_mapper.get_literal_columns(typed_statement) } - pub fn unencryped() -> Statement { - Statement::new(vec![], vec![], vec![], vec![]) + // Direct config access methods + pub fn connection_timeout(&self) -> Option { + self.config.database.connection_timeout() } - pub fn has_literals(&self) -> bool { - !self.literal_columns.is_empty() + pub fn mapping_disabled(&self) -> bool { + self.config.mapping_disabled() } - pub fn has_params(&self) -> bool { - !self.param_columns.is_empty() + pub fn mapping_errors_enabled(&self) -> bool { + self.config.mapping_errors_enabled() } - pub fn has_projection(&self) -> bool { - !self.projection_columns.is_empty() + pub fn prometheus_enabled(&self) -> bool { + self.config.prometheus_enabled() + } + + pub fn default_keyset_id(&self) -> Option { + self.config + .encrypt + .default_keyset_id + .map(|uuid| KeysetIdentifier(IdentifiedBy::Uuid(uuid))) + } + + // Additional config access methods for handler + pub fn database_socket_address(&self) -> String { + self.config.database.to_socket_address() + } + + pub fn database_username(&self) -> &str { + &self.config.database.username + } + + pub fn database_password(&self) -> String { + self.config.database.password() + } + + pub fn tls_config(&self) -> &Option { + &self.config.tls + } + + pub fn use_tls(&self) -> bool { + self.config.tls.is_some() + } + + pub fn require_tls(&self) -> bool { + self.config.server.require_tls + } + + pub fn use_structured_logging(&self) -> bool { + self.config.use_structured_logging() + } + + pub fn database_tls_disabled(&self) -> bool { + self.config.database_tls_disabled() + } + + pub fn config(&self) -> &crate::config::TandemConfig { + &self.config } } @@ -531,76 +642,69 @@ impl Queue { } } -impl Portal { - pub fn encrypted_with_format_codes( - statement: Arc, - format_codes: Vec, - ) -> Portal { - Portal::Encrypted { - statement, - format_codes, - } - } - - pub fn encrypted(statement: Arc) -> Portal { - let format_codes = vec![]; - Portal::Encrypted { - statement, - format_codes, - } - } - - pub fn passthrough() -> Portal { - Portal::Passthrough - } - - pub fn projection_columns(&self) -> &Vec> { - static EMPTY: Vec> = vec![]; - match self { - Portal::Encrypted { statement, .. } => &statement.projection_columns, - _ => &EMPTY, - } - } - - // FormatCodes should not be None at this point - // FormatCodes will be: - // - empty, in which case assume Text - // - single value, in which case use this for all columns - // - multiple values, in which case use the value for each column - pub fn format_codes(&self, row_len: usize) -> Vec { - match self { - Portal::Encrypted { format_codes, .. } => match format_codes.len() { - 0 => vec![FormatCode::Text; row_len], - 1 => { - let format_code = match format_codes.first() { - Some(code) => *code, - None => FormatCode::Text, - }; - vec![format_code; row_len] - } - _ => format_codes.clone(), - }, - Portal::Passthrough => { - unreachable!() - } - } - } -} - #[cfg(test)] mod tests { use super::{Context, Describe, KeysetIdentifier, Portal, Statement}; use crate::{ config::LogConfig, + error::Error, log, - postgresql::messages::{Name, Target}, + postgresql::{ + messages::{Name, Target}, + Column, + }, + proxy::{EncryptConfig, EncryptionService}, + TandemConfig, }; use cipherstash_client::IdentifiedBy; use eql_mapper::Schema; use sqltk::parser::{dialect::PostgreSqlDialect, parser::Parser}; use std::sync::Arc; + use tokio::sync::mpsc; use uuid::Uuid; + struct TestService {} + + #[async_trait::async_trait] + impl EncryptionService for TestService { + async fn encrypt( + &self, + _keyset_id: Option, + _plaintexts: Vec>, + _columns: &[Option], + ) -> Result>, Error> { + Ok(vec![]) + } + + async fn decrypt( + &self, + _keyset_id: Option, + _ciphertexts: Vec>, + ) -> Result>, Error> { + Ok(vec![]) + } + } + + fn create_context() -> Context { + let client_id = 1; + let config = Arc::new(TandemConfig::for_testing()); + let encrypt_config = Arc::new(EncryptConfig::default()); + let schema = Arc::new(Schema::new("public")); + + let (reload_sender, _reload_receiver) = mpsc::unbounded_channel(); + + let service = TestService {}; + + Context::new( + client_id, + config, + encrypt_config, + schema, + service, + reload_sender, + ) + } + fn statement() -> Statement { Statement { param_columns: vec![], @@ -627,9 +731,7 @@ mod tests { pub fn get_statement_from_describe() { log::init(LogConfig::default()); - let schema = Arc::new(Schema::new("public")); - - let mut context = Context::new(1, schema); + let mut context = create_context(); let name = Name::from("name"); @@ -652,9 +754,7 @@ mod tests { pub fn execution_flow() { log::init(LogConfig::default()); - let schema = Arc::new(Schema::new("public")); - - let mut context = Context::new(1, schema); + let mut context = create_context(); let statement_name = Name::from("statement"); let portal_name = Name::from("portal"); @@ -693,9 +793,7 @@ mod tests { pub fn add_and_close_portals() { log::init(LogConfig::default()); - let schema = Arc::new(Schema::new("public")); - - let mut context = Context::new(1, schema); + let mut context = create_context(); // Create multiple statements let statement_name_1 = Name::from("statement_1"); @@ -743,9 +841,7 @@ mod tests { pub fn pipeline_execution() { log::init(LogConfig::default()); - let schema = Arc::new(Schema::new("public")); - - let mut context = Context::new(1, schema); + let mut context = create_context(); let statement_name_1 = Name::from("statement_1"); let portal_name_1 = Name::unnamed(); @@ -816,8 +912,7 @@ mod tests { pub fn disable_mapping() { log::init(LogConfig::default()); - let schema = Arc::new(Schema::new("public")); - let mut context = Context::new(1, schema); + let mut context = create_context(); let sql = "SET CIPHERSTASH.UNSAFE_DISABLE_MAPPING = true"; let statement = parse_statement(sql); @@ -854,8 +949,6 @@ mod tests { pub fn set_keyset_id() { log::init(LogConfig::default()); - let schema = Arc::new(Schema::new("public")); - let uuid = Uuid::parse_str("7d4cbd7f-ba0d-4985-9ed2-ebe2ffe77590").unwrap(); let identifier = KeysetIdentifier(IdentifiedBy::Uuid(uuid)); @@ -867,7 +960,7 @@ mod tests { ]; for s in sql { - let mut context = Context::new(1, schema.clone()); + let mut context = create_context(); assert!(context.keyset_identifier().is_none()); let statement = parse_statement(s); @@ -886,8 +979,7 @@ mod tests { pub fn set_keyset_id_error_handling() { log::init(LogConfig::default()); - let schema = Arc::new(Schema::new("public")); - let mut context = Context::new(1, schema); + let mut context = create_context(); // Returns OK if unknown command let sql = "SET CIPHERSTASH.BLAH = 'keyset_id'"; @@ -922,8 +1014,6 @@ mod tests { pub fn set_keyset_name() { log::init(LogConfig::default()); - let schema = Arc::new(Schema::new("public")); - let sql = vec![ "SET CIPHERSTASH.KEYSET_NAME = 'test-keyset'", "SET SESSION CIPHERSTASH.KEYSET_NAME = 'test-keyset'", @@ -931,7 +1021,7 @@ mod tests { ]; for s in sql { - let mut context = Context::new(1, schema.clone()); + let mut context = create_context(); assert!(context.keyset_identifier().is_none()); let statement = parse_statement(s); @@ -951,8 +1041,7 @@ mod tests { pub fn set_keyset_name_error_handling() { log::init(LogConfig::default()); - let schema = Arc::new(Schema::new("public")); - let mut context = Context::new(1, schema); + let mut context = create_context(); // Returns OK if unknown command let sql = "SET CIPHERSTASH.BLAH = 'keyset_name'"; @@ -988,10 +1077,8 @@ mod tests { pub fn set_keyset_supports_numbers() { log::init(LogConfig::default()); - let schema = Arc::new(Schema::new("public")); - // Test keyset name with number - let mut context = Context::new(1, schema.clone()); + let mut context = create_context(); let sql = "SET CIPHERSTASH.KEYSET_NAME = 12345"; let statement = parse_statement(sql); @@ -1002,7 +1089,7 @@ mod tests { assert_eq!(Some(identifier.clone()), context.keyset_identifier()); // Test keyset id with numeric UUID (should work if it's a valid UUID) - let mut context = Context::new(2, schema); + let mut context = create_context(); // This will fail because 123 is not a valid UUID, but it shows the number is processed let sql = "SET CIPHERSTASH.KEYSET_ID = 123"; let statement = parse_statement(sql); @@ -1016,10 +1103,8 @@ mod tests { pub fn maybe_set_keyset_unified_function() { log::init(LogConfig::default()); - let schema = Arc::new(Schema::new("public")); - // Test that maybe_set_keyset handles both ID and name - let mut context = Context::new(1, schema.clone()); + let mut context = create_context(); // Test with keyset ID let keyset_id_sql = "SET CIPHERSTASH.KEYSET_ID = '7d4cbd7f-ba0d-4985-9ed2-ebe2ffe77590'"; @@ -1035,7 +1120,7 @@ mod tests { assert_eq!(Some(identifier.clone()), context.keyset_identifier()); // Test with keyset name - let mut context = Context::new(2, schema.clone()); + let mut context = create_context(); let keyset_name_sql = "SET CIPHERSTASH.KEYSET_NAME = 'test-keyset'"; let statement = parse_statement(keyset_name_sql); @@ -1047,7 +1132,7 @@ mod tests { assert_eq!(Some(identifier.clone()), context.keyset_identifier()); // Test with unknown command - let mut context = Context::new(3, schema); + let mut context = create_context(); let unknown_sql = "SET CIPHERSTASH.UNKNOWN = 'value'"; let statement = parse_statement(unknown_sql); let result = context.maybe_set_keyset(&statement); diff --git a/packages/cipherstash-proxy/src/postgresql/context/portal.rs b/packages/cipherstash-proxy/src/postgresql/context/portal.rs new file mode 100644 index 00000000..c6cc6585 --- /dev/null +++ b/packages/cipherstash-proxy/src/postgresql/context/portal.rs @@ -0,0 +1,68 @@ +use super::{super::format_code::FormatCode, Column}; +use crate::postgresql::context::statement::Statement; +use std::sync::Arc; + +#[derive(Clone, Debug)] +pub enum Portal { + Encrypted { + format_codes: Vec, + statement: Arc, + }, + Passthrough, +} + +impl Portal { + pub fn encrypted_with_format_codes( + statement: Arc, + format_codes: Vec, + ) -> Portal { + Portal::Encrypted { + statement, + format_codes, + } + } + + pub fn encrypted(statement: Arc) -> Portal { + let format_codes = vec![]; + Portal::Encrypted { + statement, + format_codes, + } + } + + pub fn passthrough() -> Portal { + Portal::Passthrough + } + + pub fn projection_columns(&self) -> &Vec> { + static EMPTY: Vec> = vec![]; + match self { + Portal::Encrypted { statement, .. } => &statement.projection_columns, + _ => &EMPTY, + } + } + + // FormatCodes should not be None at this point + // FormatCodes will be: + // - empty, in which case assume Text + // - single value, in which case use this for all columns + // - multiple values, in which case use the value for each column + pub fn format_codes(&self, row_len: usize) -> Vec { + match self { + Portal::Encrypted { format_codes, .. } => match format_codes.len() { + 0 => vec![FormatCode::Text; row_len], + 1 => { + let format_code = match format_codes.first() { + Some(code) => *code, + None => FormatCode::Text, + }; + vec![format_code; row_len] + } + _ => format_codes.clone(), + }, + Portal::Passthrough => { + unreachable!() + } + } + } +} diff --git a/packages/cipherstash-proxy/src/postgresql/context/statement.rs b/packages/cipherstash-proxy/src/postgresql/context/statement.rs new file mode 100644 index 00000000..68170d82 --- /dev/null +++ b/packages/cipherstash-proxy/src/postgresql/context/statement.rs @@ -0,0 +1,40 @@ +use super::Column; + +/// +/// Type Analysed parameters and projection +/// +#[derive(Debug, Clone, PartialEq)] +pub struct Statement { + pub param_columns: Vec>, + pub projection_columns: Vec>, + pub literal_columns: Vec>, + pub postgres_param_types: Vec, +} + +impl Statement { + pub fn new( + param_columns: Vec>, + projection_columns: Vec>, + literal_columns: Vec>, + postgres_param_types: Vec, + ) -> Statement { + Statement { + param_columns, + projection_columns, + literal_columns, + postgres_param_types, + } + } + + pub fn has_literals(&self) -> bool { + !self.literal_columns.is_empty() + } + + pub fn has_params(&self) -> bool { + !self.param_columns.is_empty() + } + + pub fn has_projection(&self) -> bool { + !self.projection_columns.is_empty() + } +} diff --git a/packages/cipherstash-proxy/src/postgresql/frontend.rs b/packages/cipherstash-proxy/src/postgresql/frontend.rs index 44d9311f..e5af3edb 100644 --- a/packages/cipherstash-proxy/src/postgresql/frontend.rs +++ b/packages/cipherstash-proxy/src/postgresql/frontend.rs @@ -9,9 +9,8 @@ use super::messages::FrontendCode as Code; use super::parser::SqlParser; use super::protocol::{self}; use crate::connect::Sender; -use crate::eql::Identifier; use crate::error::{EncryptError, Error, MappingError}; -use crate::log::{CONTEXT, MAPPER, PROTOCOL}; +use crate::log::{MAPPER, PROTOCOL}; use crate::postgresql::context::column::Column; use crate::postgresql::context::Portal; use crate::postgresql::data::literal_from_sql; @@ -25,14 +24,13 @@ use crate::prometheus::{ STATEMENTS_ENCRYPTED_TOTAL, STATEMENTS_PASSTHROUGH_MAPPING_DISABLED_TOTAL, STATEMENTS_PASSTHROUGH_TOTAL, STATEMENTS_UNMAPPABLE_TOTAL, }; -use crate::proxy::Proxy; +use crate::proxy::EncryptionService; use crate::EqlEncrypted; use bytes::BytesMut; use cipherstash_client::encryption::Plaintext; -use eql_mapper::{self, EqlMapperError, EqlTerm, TableColumn, TypeCheckedStatement}; +use eql_mapper::{self, EqlMapperError, EqlTerm, TypeCheckedStatement}; use metrics::{counter, histogram}; use pg_escape::quote_literal; -use postgres_types::Type; use serde::Serialize; use sqltk::parser::ast::{self, Value}; use sqltk::NodeKey; @@ -85,10 +83,11 @@ use tracing::{debug, error, info, warn}; /// Encryption and mapping errors are converted to appropriate PostgreSQL error responses /// and sent back to the client. The frontend maintains error state to properly handle /// the PostgreSQL extended query error recovery protocol. -pub struct Frontend +pub struct Frontend where R: AsyncRead + Unpin, W: AsyncWrite + Unpin, + S: EncryptionService, { /// Reader for incoming client messages client_reader: R, @@ -96,10 +95,8 @@ where client_sender: Sender, /// Writer for forwarding messages to server server_writer: W, - /// Proxy service for column encryption/decryption and configuration - proxy: Proxy, /// Session context tracking statements, portals, and keyset IDs - context: Context, + context: Context, /// Error state flag for extended query protocol error handling error_state: Option, } @@ -107,10 +104,11 @@ where #[derive(Debug)] struct ErrorState; -impl Frontend +impl Frontend where R: AsyncRead + Unpin, W: AsyncWrite + Unpin, + S: EncryptionService, { /// Creates a new Frontend instance. /// @@ -119,20 +117,17 @@ where /// * `client_reader` - Stream for reading messages from the PostgreSQL client /// * `client_sender` - Channel sender for sending messages back to client /// * `server_writer` - Stream for writing messages to the PostgreSQL server - /// * `encrypt` - Encryption service for handling column encryption/decryption - /// * `context` - Session context for tracking statements and portals + /// * `context` - Session context for tracking statements and portals with service access pub fn new( client_reader: R, client_sender: Sender, server_writer: W, - proxy: Proxy, - context: Context, + context: Context, ) -> Self { Frontend { client_reader, client_sender, server_writer, - proxy, context, error_state: None, } @@ -162,18 +157,17 @@ where /// Returns `Ok(())` on successful message processing, or an `Error` if a fatal /// error occurs that should terminate the connection. pub async fn rewrite(&mut self) -> Result<(), Error> { - let connection_timeout = self.proxy.config.database.connection_timeout(); let (code, mut bytes) = protocol::read_message( &mut self.client_reader, self.context.client_id, - connection_timeout, + self.context.connection_timeout(), ) .await?; let sent: u64 = bytes.len() as u64; counter!(CLIENTS_BYTES_RECEIVED_TOTAL).increment(sent); - if self.proxy.config.mapping_disabled() { + if self.context.mapping_disabled() { self.write_to_server(bytes).await?; return Ok(()); } @@ -273,7 +267,7 @@ where ); if self.context.schema_changed() { - self.proxy.reload_schema().await; + self.context.reload_schema().await; } if self.error_state.is_some() { @@ -421,7 +415,7 @@ where let typed_statement = match self.type_check(&statement) { Ok(ts) => ts, Err(err) => { - if self.proxy.config.mapping_errors_enabled() { + if self.context.mapping_errors_enabled() { return Err(err); } else { return Ok(None); @@ -533,15 +527,13 @@ where return Ok(vec![]); } - let keyset_id = self.context.keyset_identifier(); - let plaintexts = literals_to_plaintext(literal_values, literal_columns)?; let start = Instant::now(); let encrypted = self - .proxy - .encrypt(keyset_id, plaintexts, literal_columns) + .context + .encrypt(plaintexts, literal_columns) .await .inspect_err(|_| { counter!(ENCRYPTION_ERROR_TOTAL).increment(1); @@ -684,7 +676,7 @@ where let typed_statement = match self.type_check(&statement) { Ok(ts) => ts, Err(err) => { - if self.proxy.config.mapping_errors_enabled() { + if self.context.mapping_errors_enabled() { return Err(err); } else { return Ok(None); @@ -754,10 +746,6 @@ where let schema_changed = eql_mapper::collect_ddl(self.context.get_table_resolver(), statement); if schema_changed { - debug!(target: MAPPER, - client_id = self.context.client_id, - msg = "schema changed" - ); self.context.set_schema_changed(); } } @@ -772,10 +760,10 @@ where if let Some(keyset_identifier) = self.context.maybe_set_keyset(statement)? { debug!(client_id = self.context.client_id, ?keyset_identifier); - if self.proxy.config.encrypt.default_keyset_id.is_some() { + if self.context.default_keyset_id().is_some() { debug!(target: MAPPER, client_id = self.context.client_id, - default_keyset_id = ?self.proxy.config.encrypt.default_keyset_id, + default_keyset_id = ?self.context.default_keyset_id(), ?keyset_identifier ); return Err(EncryptError::UnexpectedSetKeyset.into()); @@ -799,9 +787,9 @@ where typed_statement: &TypeCheckedStatement<'_>, param_types: Vec, ) -> Result, Error> { - let param_columns = self.get_param_columns(typed_statement)?; - let projection_columns = self.get_projection_columns(typed_statement)?; - let literal_columns = self.get_literal_columns(typed_statement)?; + let param_columns = self.context.get_param_columns(typed_statement)?; + let projection_columns = self.context.get_projection_columns(typed_statement)?; + let literal_columns = self.context.get_literal_columns(typed_statement)?; let no_encrypted_param_columns = param_columns.iter().all(|c| c.is_none()); let no_encrypted_projection_columns = projection_columns.iter().all(|c| c.is_none()); @@ -933,28 +921,23 @@ where bind: &Bind, statement: &Statement, ) -> Result>, Error> { - let keyset_id = self.context.keyset_identifier(); let plaintexts = bind.to_plaintext(&statement.param_columns, &statement.postgres_param_types)?; debug!(target: MAPPER, client_id = self.context.client_id, plaintexts = ?plaintexts); - debug!(target: CONTEXT, - client_id = self.context.client_id, - ?keyset_id, - ); let start = Instant::now(); let encrypted = self - .proxy - .encrypt(keyset_id, plaintexts, &statement.param_columns) + .context + .encrypt(plaintexts, &statement.param_columns) .await .inspect_err(|_| { counter!(ENCRYPTION_ERROR_TOTAL).increment(1); })?; // Avoid the iter calculation if we can - if self.proxy.config.prometheus_enabled() { + if self.context.prometheus_enabled() { let encrypted_count = encrypted.iter().filter(|e| e.is_some()).count() as u64; counter!(ENCRYPTION_REQUESTS_TOTAL).increment(1); @@ -984,7 +967,7 @@ where warn!( client_id = self.context.client_id, msg = "Internal Error in EQL Mapper", - mapping_errors_enabled = self.proxy.config.mapping_errors_enabled(), + mapping_errors_enabled = self.context.mapping_errors_enabled(), error = str, ); counter!(STATEMENTS_UNMAPPABLE_TOTAL).increment(1); @@ -994,7 +977,7 @@ where warn!( client_id = self.context.client_id, msg = "Unmappable statement", - mapping_errors_enabled = self.proxy.config.mapping_errors_enabled(), + mapping_errors_enabled = self.context.mapping_errors_enabled(), error = err.to_string(), ); counter!(STATEMENTS_UNMAPPABLE_TOTAL).increment(1); @@ -1003,158 +986,6 @@ where } } - /// - /// Maps typed statement projection columns to an Encrypt column configuration - /// - /// The returned `Vec` is of `Option` because the Projection columns are a mix of native and EQL types. - /// Only EQL colunms will have a configuration. Native types are always None. - /// - /// Preserves the ordering and semantics of the projection to reduce the complexity of positional encryption. - /// - fn get_projection_columns( - &self, - typed_statement: &eql_mapper::TypeCheckedStatement<'_>, - ) -> Result>, Error> { - let mut projection_columns = vec![]; - - for col in typed_statement.projection.columns() { - let eql_mapper::ProjectionColumn { ty, .. } = col; - let configured_column = match &**ty { - eql_mapper::Type::Value(eql_mapper::Value::Eql(eql_term)) => { - let TableColumn { table, column } = eql_term.table_column(); - let identifier: Identifier = Identifier::from((table, column)); - - debug!( - target: MAPPER, - client_id = self.context.client_id, - msg = "Configured column", - column = ?identifier, - ?eql_term, - ); - self.get_column(identifier, eql_term)? - } - _ => None, - }; - projection_columns.push(configured_column) - } - - Ok(projection_columns) - } - - /// - /// Maps typed statement param columns to an Encrypt column configuration - /// - /// The returned `Vec` is of `Option` because the Param columns are a mix of native and EQL types. - /// Only EQL colunms will have a configuration. Native types are always None. - /// - /// Preserves the ordering and semantics of the projection to reduce the complexity of positional encryption. - /// - /// - fn get_param_columns( - &self, - typed_statement: &eql_mapper::TypeCheckedStatement<'_>, - ) -> Result>, Error> { - let mut param_columns = vec![]; - - for param in typed_statement.params.iter() { - let configured_column = match param { - (_, eql_mapper::Value::Eql(eql_term)) => { - let TableColumn { table, column } = eql_term.table_column(); - let identifier = Identifier::from((table, column)); - - debug!( - target: MAPPER, - client_id = self.context.client_id, - msg = "Encrypted parameter", - column = ?identifier, - ?eql_term, - ); - - self.get_column(identifier, eql_term)? - } - _ => None, - }; - param_columns.push(configured_column); - } - - Ok(param_columns) - } - - fn get_literal_columns( - &self, - typed_statement: &eql_mapper::TypeCheckedStatement<'_>, - ) -> Result>, Error> { - let mut literal_columns = vec![]; - - for (eql_term, _) in typed_statement.literals.iter() { - let TableColumn { table, column } = eql_term.table_column(); - let identifier = Identifier::from((table, column)); - - debug!( - target: MAPPER, - client_id = self.context.client_id, - msg = "Encrypted literal", - column = ?identifier, - ?eql_term, - ); - let col = self.get_column(identifier, eql_term)?; - if col.is_some() { - literal_columns.push(col); - } - } - - Ok(literal_columns) - } - - /// - /// Get the column configuration for the Identifier - /// Returns `EncryptError::UnknownColumn` if configuration cannot be found for the Identified column - /// if mapping enabled, and None if mapping is disabled. It'll log a warning either way. - fn get_column( - &self, - identifier: Identifier, - eql_term: &EqlTerm, - ) -> Result, Error> { - match self.proxy.get_column_config(&identifier) { - Some(config) => { - debug!( - target: MAPPER, - client_id = self.context.client_id, - msg = "Configured column", - column = ?identifier - ); - - // IndexTerm::SteVecSelector - let postgres_type = if matches!(eql_term, EqlTerm::JsonPath(_)) { - Some(Type::JSONPATH) - } else { - None - }; - - let eql_term = eql_term.variant(); - Ok(Some(Column::new( - identifier, - config, - postgres_type, - eql_term, - ))) - } - None => { - warn!( - target: MAPPER, - client_id = self.context.client_id, - msg = "Configured column not found. Encryption configuration may have been deleted.", - ?identifier, - ); - Err(EncryptError::UnknownColumn { - table: identifier.table.to_owned(), - column: identifier.column.to_owned(), - } - .into()) - } - } - } - /// /// Send an ReadyForQuery to the client and remove error state. /// @@ -1229,10 +1060,11 @@ where } /// Implementation of PostgreSQL error handling for the Frontend component. -impl PostgreSqlErrorHandler for Frontend +impl PostgreSqlErrorHandler for Frontend where R: AsyncRead + Unpin, W: AsyncWrite + Unpin, + S: EncryptionService, { fn client_sender(&mut self) -> &mut Sender { &mut self.client_sender diff --git a/packages/cipherstash-proxy/src/postgresql/handler.rs b/packages/cipherstash-proxy/src/postgresql/handler.rs index 0b3494a7..b13734b3 100644 --- a/packages/cipherstash-proxy/src/postgresql/handler.rs +++ b/packages/cipherstash-proxy/src/postgresql/handler.rs @@ -4,7 +4,6 @@ use super::protocol::StartupCode; use crate::connect::ChannelWriter; use crate::error::ConfigError; use crate::log::{AUTHENTICATION, PROTOCOL}; -use crate::postgresql::context::Context; use crate::postgresql::messages::authentication::auth::{AuthenticationMethod, SaslMechanism}; use crate::postgresql::messages::authentication::sasl::SASLResponse; use crate::postgresql::messages::authentication::{ @@ -12,10 +11,11 @@ use crate::postgresql::messages::authentication::{ }; use crate::postgresql::messages::error_response::ErrorResponse; use crate::postgresql::{protocol, startup}; +use crate::proxy::ZeroKms; use crate::{ connect::AsyncStream, error::{Error, ProtocolError}, - proxy::Proxy, + postgresql::context::Context, tls, }; use bytes::BytesMut; @@ -52,33 +52,27 @@ use tracing::{debug, error, info, warn}; /// Propagate and continue /// /// -pub async fn handler( - client_stream: AsyncStream, - proxy: Proxy, - client_id: i32, -) -> Result<(), Error> { +pub async fn handler(client_stream: AsyncStream, context: Context) -> Result<(), Error> { let mut client_stream = client_stream; + let client_id = context.client_id; // Connect to the database server, using TLS if configured - let stream = AsyncStream::connect(&proxy.config.database.to_socket_address()).await?; - let mut database_stream = startup::with_tls(stream, &proxy.config).await?; + let stream = AsyncStream::connect(&context.database_socket_address()).await?; + let mut database_stream = startup::with_tls(stream, context.config()).await?; info!( msg = "Client connected", - database = proxy.config.database.to_socket_address(), + database = context.database_socket_address(), client_id = client_id, ); loop { - let startup_message = startup::read_message( - &mut client_stream, - proxy.config.database.connection_timeout(), - ) - .await?; + let startup_message = + startup::read_message(&mut client_stream, context.connection_timeout()).await?; match &startup_message.code { StartupCode::SSLRequest => { - startup::send_ssl_response(&proxy, &mut client_stream).await?; - if let Some(ref tls) = proxy.config.tls { + startup::send_ssl_response(&mut client_stream, context.use_tls()).await?; + if let Some(ref tls) = context.tls_config() { match client_stream { AsyncStream::Tcp(stream) => { // The Client is connecting to our Server @@ -112,8 +106,8 @@ pub async fn handler( { let salt = generate_md5_password_salt(); - let username = proxy.config.database.username.as_bytes(); - let password = proxy.config.database.password(); + let username = context.database_username().as_bytes(); + let password = context.database_password(); let password = password.as_bytes(); @@ -123,7 +117,7 @@ pub async fn handler( let bytes = BytesMut::try_from(message)?; client_stream.write_all(&bytes).await?; - let connection_timeout = proxy.config.database.connection_timeout(); + let connection_timeout = context.connection_timeout(); let (_code, bytes) = protocol::read_message(&mut client_stream, client_id, connection_timeout).await?; @@ -161,15 +155,15 @@ pub async fn handler( } AuthenticationMethod::AuthenticationCleartextPassword => { debug!(target: AUTHENTICATION, msg = "AuthenticationCleartextPassword"); - let password = proxy.config.database.password(); + let password = context.database_password(); let message = PasswordMessage::new(password); let bytes = BytesMut::try_from(message)?; database_stream.write_all(&bytes).await?; } AuthenticationMethod::Md5Password { salt } => { debug!(target: AUTHENTICATION, msg = "Md5Password"); - let username = proxy.config.database.username.as_bytes(); - let password = proxy.config.database.password(); + let username = context.database_username().as_bytes(); + let password = context.database_password(); let password = password.as_bytes(); let hash = md5_hash(username, password, salt); @@ -186,7 +180,7 @@ pub async fn handler( // If we are connected via TLS, we can support SCRAM-SHA-256-PLUS // If we are not connected via TLS, the database won't ask for SCRAM-SHA-256-PLUS let channel_binding = database_stream.channel_binding(); - let password = proxy.config.database.password(); + let password = context.database_password(); let password = password.as_bytes(); scram_sha_256_plus_handler(&mut database_stream, mechanism, password, channel_binding) .await?; @@ -204,7 +198,7 @@ pub async fn handler( } } - if proxy.config.server.require_tls && !client_stream.is_tls() { + if context.require_tls() && !client_stream.is_tls() { let message = ErrorResponse::tls_required(); let bytes = BytesMut::try_from(message)?; client_stream.write_all(&bytes).await?; @@ -218,25 +212,16 @@ pub async fn handler( let channel_writer = ChannelWriter::new(client_writer, client_id); - let schema = proxy.schema.load(); - let context = Context::new(client_id, schema); - let mut frontend = Frontend::new( client_reader, channel_writer.sender(), server_writer, - proxy.clone(), - context.clone(), - ); - let mut backend = Backend::new( - channel_writer.sender(), - server_reader, - proxy.clone(), context.clone(), ); + let mut backend = Backend::new(channel_writer.sender(), server_reader, context.clone()); - if proxy.is_passthrough() { - if proxy.config.use_structured_logging() { + if context.is_passthrough() { + if context.use_structured_logging() { warn!(msg = "RUNNING IN PASSTHROUGH MODE"); warn!(msg = "DATA IS NOT PROTECTED WITH ENCRYPTION"); } else { diff --git a/packages/cipherstash-proxy/src/postgresql/messages/describe.rs b/packages/cipherstash-proxy/src/postgresql/messages/describe.rs index efe05e59..6f0bbd8e 100644 --- a/packages/cipherstash-proxy/src/postgresql/messages/describe.rs +++ b/packages/cipherstash-proxy/src/postgresql/messages/describe.rs @@ -28,7 +28,7 @@ use super::{FrontendCode, Name}; /// The name of the prepared statement or portal to describe (an empty string selects the unnamed prepared statement or portal). #[derive(Debug, Clone)] -pub(crate) struct Describe { +pub struct Describe { pub target: Target, pub name: Name, } diff --git a/packages/cipherstash-proxy/src/postgresql/mod.rs b/packages/cipherstash-proxy/src/postgresql/mod.rs index e589bb92..71c8df24 100644 --- a/packages/cipherstash-proxy/src/postgresql/mod.rs +++ b/packages/cipherstash-proxy/src/postgresql/mod.rs @@ -1,4 +1,5 @@ mod backend; +mod column_mapper; mod context; mod data; mod error_handler; @@ -12,6 +13,7 @@ mod protocol; mod startup; pub use context::column::Column; +pub use context::Context; pub use context::KeysetIdentifier; pub use handler::handler; diff --git a/packages/cipherstash-proxy/src/postgresql/startup.rs b/packages/cipherstash-proxy/src/postgresql/startup.rs index 87055887..a21e64f9 100644 --- a/packages/cipherstash-proxy/src/postgresql/startup.rs +++ b/packages/cipherstash-proxy/src/postgresql/startup.rs @@ -12,7 +12,6 @@ use crate::{ error::{Error, ProtocolError}, log::PROTOCOL, postgresql::{SSL_REQUEST, SSL_RESPONSE_NO, SSL_RESPONSE_YES}, - proxy::Proxy, tls, TandemConfig, SIZE_I32, }; @@ -151,13 +150,10 @@ pub async fn send_ssl_request( /// The SSLResponse MUST come before the TLS handshake /// pub async fn send_ssl_response( - proxy: &Proxy, stream: &mut T, + tls: bool, ) -> Result<(), Error> { - let response = match proxy.config.tls { - Some(_) => b'S', - None => b'N', - }; + let response = if tls { b'S' } else { b'N' }; debug!(target: PROTOCOL, msg = "SSLResponse to Client", SSLResponse = ?response); diff --git a/packages/cipherstash-proxy/src/proxy/config/mod.rs b/packages/cipherstash-proxy/src/proxy/config/mod.rs deleted file mode 100644 index 819e4961..00000000 --- a/packages/cipherstash-proxy/src/proxy/config/mod.rs +++ /dev/null @@ -1,4 +0,0 @@ -mod encrypt_config; -mod manager; - -pub use manager::EncryptConfigManager; diff --git a/packages/cipherstash-proxy/src/proxy/config/encrypt_config.rs b/packages/cipherstash-proxy/src/proxy/encrypt_config/config.rs similarity index 97% rename from packages/cipherstash-proxy/src/proxy/config/encrypt_config.rs rename to packages/cipherstash-proxy/src/proxy/encrypt_config/config.rs index 730e678e..65b758f0 100644 --- a/packages/cipherstash-proxy/src/proxy/config/encrypt_config.rs +++ b/packages/cipherstash-proxy/src/proxy/encrypt_config/config.rs @@ -1,7 +1,6 @@ use crate::{ eql, error::{ConfigError, Error}, - log::KEYSET, }; use cipherstash_client::schema::{ column::{Index, IndexType, TokenFilter, Tokenizer}, @@ -9,7 +8,13 @@ use cipherstash_client::schema::{ }; use serde::{Deserialize, Serialize}; use std::{collections::HashMap, str::FromStr}; -use tracing::debug; + +#[derive(Debug, Deserialize, Serialize, Clone, Default)] +pub struct ColumnEncryptionConfig { + #[serde(rename = "v")] + pub version: u32, + pub tables: Tables, +} #[derive(Debug, Deserialize, Serialize, Clone, Default)] pub struct Tables(HashMap); @@ -35,13 +40,6 @@ impl IntoIterator for Table { } } -#[derive(Debug, Deserialize, Serialize, Clone, Default)] -pub struct EncryptConfig { - #[serde(rename = "v")] - pub version: u32, - pub tables: Tables, -} - #[derive(Debug, Default, Deserialize, Serialize, Clone, PartialEq)] pub struct Column { #[serde(default)] @@ -133,7 +131,7 @@ impl From for ColumnType { } } -impl FromStr for EncryptConfig { +impl FromStr for ColumnEncryptionConfig { type Err = Error; fn from_str(data: &str) -> Result { @@ -142,7 +140,7 @@ impl FromStr for EncryptConfig { } } -impl EncryptConfig { +impl ColumnEncryptionConfig { pub fn is_empty(&self) -> bool { self.tables.0.is_empty() } @@ -151,7 +149,6 @@ impl EncryptConfig { let mut map = HashMap::new(); for (table_name, columns) in self.tables.into_iter() { for (column_name, column) in columns.into_iter() { - debug!(target: KEYSET, msg = "Configured column", table = table_name, column = column_name); let column_config = column.into_column_config(&column_name); let key = eql::Identifier::new(&table_name, &column_name); map.insert(key, column_config); @@ -201,7 +198,7 @@ mod tests { use super::*; fn parse(json: serde_json::Value) -> HashMap { - serde_json::from_value::(json) + serde_json::from_value::(json) .map(|config| config.into_config_map()) .expect("Error ok") } diff --git a/packages/cipherstash-proxy/src/proxy/config/manager.rs b/packages/cipherstash-proxy/src/proxy/encrypt_config/manager.rs similarity index 84% rename from packages/cipherstash-proxy/src/proxy/config/manager.rs rename to packages/cipherstash-proxy/src/proxy/encrypt_config/manager.rs index 73b003ca..ed8bb44f 100644 --- a/packages/cipherstash-proxy/src/proxy/config/manager.rs +++ b/packages/cipherstash-proxy/src/proxy/encrypt_config/manager.rs @@ -12,7 +12,7 @@ use std::{collections::HashMap, sync::Arc, time::Duration}; use tokio::{task::JoinHandle, time}; use tracing::{debug, error, info, warn}; -use super::encrypt_config::EncryptConfig; +use super::config::ColumnEncryptionConfig; /// /// Column configuration keyed by table name and column name @@ -20,10 +20,41 @@ use super::encrypt_config::EncryptConfig; /// type EncryptConfigMap = HashMap; +#[derive(Clone, Debug)] +pub struct EncryptConfig { + config: EncryptConfigMap, +} + +impl EncryptConfig { + pub fn new_from_config(config: EncryptConfigMap) -> Self { + Self { config } + } + + pub fn new() -> Self { + Self { + config: HashMap::new(), + } + } + + pub fn is_empty(&self) -> bool { + self.config.is_empty() + } + + pub fn get_column_config(&self, identifier: &eql::Identifier) -> Option { + self.config.get(identifier).cloned() + } +} + +impl Default for EncryptConfig { + fn default() -> Self { + Self::new() + } +} + #[derive(Clone, Debug)] pub struct EncryptConfigManager { config: DatabaseConfig, - encrypt_config: Arc>, + encrypt_config: Arc>, _reload_handle: Arc>, } @@ -33,7 +64,7 @@ impl EncryptConfigManager { init_reloader(config).await } - pub fn load(&self) -> Arc { + pub fn load(&self) -> Arc { self.encrypt_config.load().clone() } @@ -79,7 +110,7 @@ async fn init_reloader(config: DatabaseConfig) -> Result Result Result { +async fn load_encrypt_config_with_retry(config: &DatabaseConfig) -> Result { let mut retry_count = 0; let max_retry_count = 10; let max_backoff = Duration::from_secs(2); @@ -170,21 +199,23 @@ async fn load_encrypt_config_with_retry( } } -pub async fn load_encrypt_config(config: &DatabaseConfig) -> Result { +pub async fn load_encrypt_config(config: &DatabaseConfig) -> Result { let client = connect::database(config).await?; match client.query(ENCRYPT_CONFIG_QUERY, &[]).await { Ok(rows) => { if rows.is_empty() { - return Ok(EncryptConfigMap::new()); + return Ok(EncryptConfig::new()); }; // We know there is at least one row let row = rows.first().unwrap(); let json_value: Value = row.get("data"); - let encrypt_config: EncryptConfig = serde_json::from_value(json_value)?; - Ok(encrypt_config.into_config_map()) + let encrypt_config: ColumnEncryptionConfig = serde_json::from_value(json_value)?; + let encrypt_config = EncryptConfig::new_from_config(encrypt_config.into_config_map()); + + Ok(encrypt_config) } Err(err) => { if configuration_table_not_found(&err) { diff --git a/packages/cipherstash-proxy/src/proxy/encrypt_config/mod.rs b/packages/cipherstash-proxy/src/proxy/encrypt_config/mod.rs new file mode 100644 index 00000000..57e5563e --- /dev/null +++ b/packages/cipherstash-proxy/src/proxy/encrypt_config/mod.rs @@ -0,0 +1,4 @@ +mod config; +mod manager; + +pub use manager::{EncryptConfig, EncryptConfigManager}; diff --git a/packages/cipherstash-proxy/src/proxy/mod.rs b/packages/cipherstash-proxy/src/proxy/mod.rs index eff55840..dfabc6c0 100644 --- a/packages/cipherstash-proxy/src/proxy/mod.rs +++ b/packages/cipherstash-proxy/src/proxy/mod.rs @@ -1,18 +1,30 @@ +use std::sync::Arc; + use crate::{ config::TandemConfig, - connect, eql, + connect, error::Error, - log::PROXY, - postgresql::{Column, KeysetIdentifier}, - proxy::{config::EncryptConfigManager, schema::SchemaManager, zerokms::ZeroKms}, + postgresql::{Column, Context, KeysetIdentifier}, + proxy::{encrypt_config::EncryptConfigManager, schema::SchemaManager}, }; -use cipherstash_client::{encryption::Plaintext, schema::ColumnConfig}; +use cipherstash_client::encryption::Plaintext; +use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender}; +use tokio::sync::oneshot::Sender; use tracing::{debug, warn}; -mod config; +mod encrypt_config; mod schema; mod zerokms; +pub use encrypt_config::EncryptConfig; +pub use zerokms::ZeroKms; + +pub type ReloadSender = UnboundedSender; + +type ReloadReceiver = UnboundedReceiver; + +pub type ReloadResponder = Sender<()>; + /// SQL Statement for loading encrypt configuration from database const ENCRYPT_CONFIG_QUERY: &str = include_str!("./sql/select_config.sql"); @@ -22,17 +34,23 @@ const SCHEMA_QUERY: &str = include_str!("./sql/select_table_schemas.sql"); /// SQL Statement for loading aggregates as part of database schema const AGGREGATE_QUERY: &str = include_str!("./sql/select_aggregates.sql"); +#[derive(Debug)] +pub enum ReloadCommand { + DatabaseSchema(ReloadResponder), + EncryptSchema(ReloadResponder), +} + /// /// Core proxy service providing encryption, configuration, and schema management. /// -#[derive(Clone)] pub struct Proxy { - pub config: TandemConfig, - pub encrypt_config: EncryptConfigManager, - pub schema: SchemaManager, + pub config: Arc, + pub encrypt_config_manager: EncryptConfigManager, + pub schema_manager: SchemaManager, /// The EQL version installed in the database or `None` if it was not present pub eql_version: Option, zerokms: ZeroKms, + reload_sender: ReloadSender, } impl Proxy { @@ -43,96 +61,109 @@ impl Proxy { // Ensures error on start if credential or network issue zerokms.init_cipher(None).await?; - let encrypt_config = EncryptConfigManager::init(&config.database).await?; - // TODO: populate EqlTraitImpls based in config - let schema = SchemaManager::init(&config.database).await?; - - let eql_version = { - let client = connect::database(&config.database).await?; - let rows = client - .query("SELECT eql_v2.version() AS version;", &[]) - .await; - - match rows { - Ok(rows) => rows.first().map(|row| row.get("version")), - Err(err) => { - warn!( - msg = "Could not query EQL version from database", - error = err.to_string() - ); - None - } - } - }; + let encrypt_config_manager = EncryptConfigManager::init(&config.database).await?; + + let schema_manager = SchemaManager::init(&config.database).await?; + + let eql_version = Proxy::eql_version(&config).await?; + + let (reload_sender, reload_receiver) = mpsc::unbounded_channel(); + + Proxy::receive( + reload_receiver, + schema_manager.clone(), + encrypt_config_manager.clone(), + ); Ok(Proxy { - config, + config: Arc::new(config), zerokms, - encrypt_config, - schema, + encrypt_config_manager, + schema_manager, eql_version, + reload_sender, }) } + pub async fn eql_version(config: &TandemConfig) -> Result, Error> { + let client = connect::database(&config.database).await?; + let rows = client + .query("SELECT eql_v2.version() AS version;", &[]) + .await; + + let version = match rows { + Ok(rows) => rows.first().map(|row| row.get("version")), + Err(err) => { + warn!( + msg = "Could not query EQL version from database", + error = err.to_string() + ); + None + } + }; + Ok(version) + } + + pub fn receive( + mut reload_receiver: ReloadReceiver, + schema_manager: SchemaManager, + encrypt_config_manager: EncryptConfigManager, + ) { + tokio::task::spawn(async move { + while let Some(command) = reload_receiver.recv().await { + debug!(msg = "ReloadCommand received", ?command); + match command { + ReloadCommand::DatabaseSchema(responder) => { + schema_manager.reload().await; + encrypt_config_manager.reload().await; + let _ = responder.send(()); + } + ReloadCommand::EncryptSchema(responder) => { + encrypt_config_manager.reload().await; + let _ = responder.send(()); + } + } + } + }); + } + /// - /// Encrypt `Plaintexts` using the `Column` configuration + /// Create a new context from the Proxy settings /// - pub async fn encrypt( + pub fn context(&self, client_id: i32) -> Context { + let config = self.config.clone(); + let encrypt_config = self.encrypt_config_manager.load(); + let schema = self.schema_manager.load(); + let reload_sender = self.reload_sender.clone(); + let encryption = self.zerokms.clone(); + + Context::new( + client_id, + config, + encrypt_config, + schema, + encryption, + reload_sender, + ) + } +} + +#[async_trait::async_trait] +pub trait EncryptionService: Send + Sync { + /// Encrypt plaintexts for storage in the database + async fn encrypt( &self, keyset_id: Option, plaintexts: Vec>, columns: &[Option], - ) -> Result>, Error> { - debug!(target: PROXY, msg="Encrypt", ?keyset_id, default_keyset_id = ?self.config.encrypt.default_keyset_id); - - self.zerokms - .encrypt( - keyset_id, - plaintexts, - columns, - self.config.encrypt.default_keyset_id, - ) - .await - } + ) -> Result>, Error>; - /// - /// Decrypt eql::Ciphertext into Plaintext - /// - /// Database values are stored as `eql::Ciphertext` - /// - pub async fn decrypt( + /// Decrypt values retrieved from the database + async fn decrypt( &self, keyset_id: Option, - ciphertexts: Vec>, - ) -> Result>, Error> { - debug!(target: PROXY, msg="Decrypt", ?keyset_id, default_keyset_id = ?self.config.encrypt.default_keyset_id); - - self.zerokms - .decrypt( - keyset_id, - ciphertexts, - self.config.encrypt.default_keyset_id, - ) - .await - } - - pub fn get_column_config(&self, identifier: &eql::Identifier) -> Option { - let encrypt_config = self.encrypt_config.load(); - encrypt_config.get(identifier).cloned() - } - - pub async fn reload_schema(&self) { - self.schema.reload().await; - self.encrypt_config.reload().await; - } - - pub fn is_passthrough(&self) -> bool { - self.encrypt_config.is_empty() || self.config.mapping_disabled() - } - - pub fn is_empty_config(&self) -> bool { - self.encrypt_config.is_empty() - } + ciphertexts: Vec>, + ) -> Result>, Error>; } #[cfg(test)] diff --git a/packages/cipherstash-proxy/src/proxy/schema/manager.rs b/packages/cipherstash-proxy/src/proxy/schema/manager.rs index 368d6c6a..5fa2710d 100644 --- a/packages/cipherstash-proxy/src/proxy/schema/manager.rs +++ b/packages/cipherstash-proxy/src/proxy/schema/manager.rs @@ -71,7 +71,7 @@ async fn init_reloader(config: DatabaseConfig) -> Result { } Err(err) => { warn!( - msg = "Error loading Encrypt configuration", + msg = "Error loading database schema", error = err.to_string() ); } @@ -143,7 +143,6 @@ pub async fn load_schema(config: &DatabaseConfig) -> Result { Some("eql_v2_encrypted") => { debug!(target: SCHEMA, msg = "eql_v2_encrypted column", table = table_name, column = col); - // TODO - map config to the set of implemented traits let eql_traits = EqlTraits::all(); Column::eql(ident, eql_traits) } diff --git a/packages/cipherstash-proxy/src/proxy/zerokms/zerokms.rs b/packages/cipherstash-proxy/src/proxy/zerokms/zerokms.rs index c40a37fe..c45e1d39 100644 --- a/packages/cipherstash-proxy/src/proxy/zerokms/zerokms.rs +++ b/packages/cipherstash-proxy/src/proxy/zerokms/zerokms.rs @@ -5,6 +5,7 @@ use crate::{ log::{ENCRYPT, PROXY}, postgresql::{Column, KeysetIdentifier}, prometheus::{KEYSET_CIPHER_CACHE_HITS_TOTAL, KEYSET_CIPHER_INIT_TOTAL}, + proxy::EncryptionService, }; use cipherstash_client::{ encryption::QueryOp, @@ -14,6 +15,7 @@ use metrics::counter; use moka::future::Cache; use std::{sync::Arc, time::Duration}; use tracing::{debug, info, warn}; +use uuid::Uuid; use super::{ init_zerokms_client, plaintext_type_name, to_eql_encrypted, to_eql_encrypted_from_index_term, @@ -25,6 +27,7 @@ const SCOPED_CIPHER_SIZE: usize = std::mem::size_of::(); #[derive(Clone)] pub struct ZeroKms { + default_keyset_id: Option, zerokms_client: Arc, cipher_cache: Cache>, } @@ -43,7 +46,10 @@ impl ZeroKms { .time_to_live(Duration::from_secs(config.server.cipher_cache_ttl_seconds)) .build(); + let default_keyset_id = config.encrypt.default_keyset_id; + Ok(ZeroKms { + default_keyset_id, zerokms_client: Arc::new(zerokms_client), cipher_cache, }) @@ -118,21 +124,23 @@ impl ZeroKms { } } } +} +#[async_trait::async_trait] +impl EncryptionService for ZeroKms { /// /// Encrypt `Plaintexts` using the `Column` configuration /// - pub async fn encrypt( + async fn encrypt( &self, keyset_id: Option, plaintexts: Vec>, columns: &[Option], - default_keyset_id: Option, ) -> Result>, Error> { - debug!(target: ENCRYPT, msg="Encrypt", ?keyset_id, ?default_keyset_id); + debug!(target: ENCRYPT, msg="Encrypt", ?keyset_id, default_keyset_id = ?self.default_keyset_id); // A keyset is required if no default keyset has been configured - if default_keyset_id.is_none() && keyset_id.is_none() { + if self.default_keyset_id.is_none() && keyset_id.is_none() { return Err(EncryptError::MissingKeysetIdentifier.into()); } @@ -200,17 +208,17 @@ impl ZeroKms { /// /// Database values are stored as `eql::Ciphertext` /// - pub async fn decrypt( + async fn decrypt( &self, keyset_id: Option, ciphertexts: Vec>, - default_keyset_id: Option, ) -> Result>, Error> { + debug!(target: ENCRYPT, msg="Decrypt", ?keyset_id, default_keyset_id = ?self.default_keyset_id); + // A keyset is required if no default keyset has been configured - if default_keyset_id.is_none() && keyset_id.is_none() { + if self.default_keyset_id.is_none() && keyset_id.is_none() { return Err(EncryptError::MissingKeysetIdentifier.into()); } - debug!(target: ENCRYPT, msg="Decrypt", ?keyset_id, ?default_keyset_id); let cipher = self.init_cipher(keyset_id.clone()).await?; diff --git a/packages/eql-mapper/src/model/schema.rs b/packages/eql-mapper/src/model/schema.rs index 78ec4cf7..a06cd1ea 100644 --- a/packages/eql-mapper/src/model/schema.rs +++ b/packages/eql-mapper/src/model/schema.rs @@ -84,7 +84,6 @@ impl Schema { S: Into, { let name = Ident::new(name); - // name.quote_style = Some('"'); Self { name, diff --git a/packages/showcase/src/data.rs b/packages/showcase/src/data.rs index 47c77b61..a0085c19 100644 --- a/packages/showcase/src/data.rs +++ b/packages/showcase/src/data.rs @@ -496,8 +496,6 @@ pub async fn clear() { // // Deleting rows from the eql_v2_configuration table is not officially supported due to the risk of data loss. // - // TODO: EQL should support safe removal of config rows - at least in some kind of "test" or non-production - // mode. let sql = r#" DELETE FROM public.eql_v2_configuration diff --git a/packages/showcase/src/main.rs b/packages/showcase/src/main.rs index 1542a5a3..456fa1be 100644 --- a/packages/showcase/src/main.rs +++ b/packages/showcase/src/main.rs @@ -66,17 +66,18 @@ use crate::{ #[tokio::main] async fn main() -> Result<(), Box> { + println!("🩺 Healthcare Database Showcase - EQL v2 Searchable Encryption"); + println!("============================================================"); + trace(); clear().await; + setup_schema().await; insert_test_data().await; create_enhanced_jsonb_test_data().await; let client = connect_with_tls(PROXY).await; - println!("🩺 Healthcare Database Showcase - EQL v2 Searchable Encryption"); - println!("============================================================"); - // Query 1: Get the Aspirin medication ID let aspirin_id_sql = "SELECT id FROM medications WHERE name = 'Aspirin';"; let rows = client.query(aspirin_id_sql, &[]).await.unwrap();