Skip to content

Commit a8d5c59

Browse files
committed
♻️ refactor: implement service-oriented architecture for PostgreSQL frontend
- Create service traits for encryption and schema operations in services/mod.rs - Refactor Context to use service trait objects instead of direct Proxy dependency - Remove proxy field from Frontend, making it focus solely on PostgreSQL message handling - Context now serves as central coordinator with service delegation methods - Add backward compatibility with new_with_proxy() helper method - Implement service traits on Proxy for seamless migration - Add comprehensive test infrastructure with mock services - Maintain all existing functionality while improving separation of concerns This architectural change makes the frontend more focused, testable, and follows proper dependency injection patterns while preserving backward compatibility.
1 parent 73c35df commit a8d5c59

File tree

9 files changed

+298
-39
lines changed

9 files changed

+298
-39
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

packages/cipherstash-proxy/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ version = "2.0.0"
44
edition = "2021"
55

66
[dependencies]
7+
async-trait = "0.1"
78
aws-lc-rs = "1.13.3"
89
bigdecimal = { version = "0.4.6", features = ["serde-json"] }
910
arc-swap = "1.7.1"

packages/cipherstash-proxy/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ pub mod log;
99
pub mod postgresql;
1010
pub mod prometheus;
1111
pub mod proxy;
12+
pub mod services;
1213
pub mod tls;
1314

1415
pub use crate::cli::Args;

packages/cipherstash-proxy/src/postgresql/context/mod.rs

Lines changed: 183 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@ use super::{
66
Column,
77
};
88
use crate::{
9+
config::TandemConfig,
910
error::{EncryptError, Error},
1011
log::CONTEXT,
1112
prometheus::{STATEMENTS_EXECUTION_DURATION_SECONDS, STATEMENTS_SESSION_DURATION_SECONDS},
13+
services::{EncryptionService, SchemaService},
1214
};
1315
use cipherstash_client::IdentifiedBy;
1416
use eql_mapper::{Schema, TableResolver};
@@ -36,9 +38,12 @@ impl std::fmt::Display for KeysetIdentifier {
3638
}
3739
}
3840

39-
#[derive(Clone, Debug)]
41+
#[derive(Clone)]
4042
pub struct Context {
4143
pub client_id: i32,
44+
config: Arc<TandemConfig>,
45+
encryption: Arc<dyn EncryptionService>,
46+
schema: Arc<dyn SchemaService>,
4247
statements: Arc<RwLock<HashMap<Name, Arc<Statement>>>>,
4348
portals: Arc<RwLock<HashMap<Name, PortalQueue>>>,
4449
describe: Arc<RwLock<DescribeQueue>>,
@@ -112,7 +117,13 @@ pub enum Portal {
112117
}
113118

114119
impl Context {
115-
pub fn new(client_id: i32, schema: Arc<Schema>) -> Context {
120+
pub fn new(
121+
client_id: i32,
122+
schema: Arc<Schema>,
123+
config: Arc<TandemConfig>,
124+
encryption: Arc<dyn EncryptionService>,
125+
schema_service: Arc<dyn SchemaService>,
126+
) -> Context {
116127
Context {
117128
statements: Arc::new(RwLock::new(HashMap::new())),
118129
portals: Arc::new(RwLock::new(HashMap::new())),
@@ -122,6 +133,9 @@ impl Context {
122133
session_metrics: Arc::new(RwLock::from(Queue::new())),
123134
table_resolver: Arc::new(TableResolver::new_editable(schema)),
124135
client_id,
136+
config,
137+
encryption,
138+
schema: schema_service,
125139
unsafe_disable_mapping: false,
126140
keyset_id: Arc::new(RwLock::new(None)),
127141
}
@@ -477,6 +491,159 @@ impl Context {
477491
pub fn keyset_identifier(&self) -> Option<KeysetIdentifier> {
478492
self.keyset_id.read().ok().and_then(|k| k.clone())
479493
}
494+
495+
// Service delegation methods
496+
pub async fn encrypt(
497+
&self,
498+
keyset_id: Option<KeysetIdentifier>,
499+
plaintexts: Vec<Option<cipherstash_client::encryption::Plaintext>>,
500+
columns: &[Option<Column>],
501+
) -> Result<Vec<Option<crate::EqlEncrypted>>, Error> {
502+
self.encryption
503+
.encrypt(keyset_id, plaintexts, columns)
504+
.await
505+
}
506+
507+
pub async fn decrypt(
508+
&self,
509+
keyset_id: Option<KeysetIdentifier>,
510+
ciphertexts: Vec<Option<crate::EqlEncrypted>>,
511+
) -> Result<Vec<Option<cipherstash_client::encryption::Plaintext>>, Error> {
512+
self.encryption.decrypt(keyset_id, ciphertexts).await
513+
}
514+
515+
pub async fn reload_schema(&self) {
516+
self.schema.reload_schema().await;
517+
self.set_schema_changed();
518+
}
519+
520+
pub fn get_column_config(
521+
&self,
522+
identifier: &crate::eql::Identifier,
523+
) -> Option<cipherstash_client::schema::ColumnConfig> {
524+
self.schema.get_column_config(identifier)
525+
}
526+
527+
pub fn is_passthrough(&self) -> bool {
528+
self.schema.is_passthrough()
529+
}
530+
531+
pub fn is_empty_config(&self) -> bool {
532+
self.schema.is_empty_config()
533+
}
534+
535+
// Direct config access methods
536+
pub fn connection_timeout(&self) -> std::time::Duration {
537+
self.config
538+
.database
539+
.connection_timeout()
540+
.unwrap_or_else(|| std::time::Duration::from_secs(10))
541+
}
542+
543+
pub fn mapping_disabled(&self) -> bool {
544+
self.config.mapping_disabled()
545+
}
546+
547+
pub fn mapping_errors_enabled(&self) -> bool {
548+
self.config.mapping_errors_enabled()
549+
}
550+
551+
pub fn prometheus_enabled(&self) -> bool {
552+
self.config.prometheus_enabled()
553+
}
554+
555+
pub fn default_keyset_id(&self) -> Option<KeysetIdentifier> {
556+
self.config
557+
.encrypt
558+
.default_keyset_id
559+
.map(|uuid| KeysetIdentifier(IdentifiedBy::Uuid(uuid)))
560+
}
561+
562+
/// Helper function for gradual migration - creates Context with Proxy implementing services
563+
pub fn new_with_proxy(
564+
client_id: i32,
565+
schema: Arc<Schema>,
566+
proxy: crate::proxy::Proxy,
567+
) -> Context {
568+
let config = Arc::new(proxy.config.clone());
569+
let proxy = Arc::new(proxy);
570+
Self::new(
571+
client_id,
572+
schema,
573+
config,
574+
proxy.clone(), // as EncryptionService
575+
proxy, // as SchemaService
576+
)
577+
}
578+
579+
/// Helper function for tests - creates Context with minimal mock services
580+
#[cfg(test)]
581+
pub fn new_for_test(client_id: i32, schema: Arc<Schema>) -> Context {
582+
// Create minimal mock services for testing
583+
struct MockEncryptionService;
584+
struct MockSchemaService;
585+
586+
#[async_trait::async_trait]
587+
impl EncryptionService for MockEncryptionService {
588+
async fn encrypt(
589+
&self,
590+
_keyset_id: Option<KeysetIdentifier>,
591+
plaintexts: Vec<Option<cipherstash_client::encryption::Plaintext>>,
592+
_columns: &[Option<Column>],
593+
) -> Result<Vec<Option<crate::EqlEncrypted>>, Error> {
594+
Ok(plaintexts.into_iter().map(|_| None).collect())
595+
}
596+
597+
async fn decrypt(
598+
&self,
599+
_keyset_id: Option<KeysetIdentifier>,
600+
ciphertexts: Vec<Option<crate::EqlEncrypted>>,
601+
) -> Result<Vec<Option<cipherstash_client::encryption::Plaintext>>, Error> {
602+
Ok(ciphertexts.into_iter().map(|_| None).collect())
603+
}
604+
}
605+
606+
#[async_trait::async_trait]
607+
impl SchemaService for MockSchemaService {
608+
async fn reload_schema(&self) {}
609+
fn get_column_config(
610+
&self,
611+
_identifier: &crate::eql::Identifier,
612+
) -> Option<cipherstash_client::schema::ColumnConfig> {
613+
None
614+
}
615+
fn get_table_resolver(&self) -> Arc<TableResolver> {
616+
Arc::new(TableResolver::new_editable(Arc::new(Schema::new("public"))))
617+
}
618+
fn is_passthrough(&self) -> bool {
619+
true
620+
}
621+
fn is_empty_config(&self) -> bool {
622+
true
623+
}
624+
}
625+
626+
// Set up minimal environment variables for testing
627+
std::env::set_var("CS_DATABASE__USERNAME", "test");
628+
std::env::set_var("CS_DATABASE__PASSWORD", "test");
629+
std::env::set_var("CS_DATABASE__NAME", "test");
630+
std::env::set_var("CS_DATABASE__HOST", "localhost");
631+
std::env::set_var("CS_DATABASE__PORT", "5432");
632+
std::env::set_var("CS_AUTH__WORKSPACE_CRN", "crn:ap-southeast-2.aws:test");
633+
std::env::set_var("CS_AUTH__CLIENT_ACCESS_KEY", "test");
634+
std::env::set_var("CS_ENCRYPT__CLIENT_ID", "test");
635+
std::env::set_var("CS_ENCRYPT__CLIENT_KEY", "test");
636+
637+
let config = Arc::new(
638+
crate::config::TandemConfig::build("tests/config/unknown.toml")
639+
.expect("Failed to create test config"),
640+
);
641+
642+
let encryption = Arc::new(MockEncryptionService) as Arc<dyn EncryptionService>;
643+
let schema_service = Arc::new(MockSchemaService) as Arc<dyn SchemaService>;
644+
645+
Self::new(client_id, schema, config, encryption, schema_service)
646+
}
480647
}
481648

482649
impl Statement {
@@ -629,7 +796,7 @@ mod tests {
629796

630797
let schema = Arc::new(Schema::new("public"));
631798

632-
let mut context = Context::new(1, schema);
799+
let mut context = Context::new_for_test(1, schema);
633800

634801
let name = Name::from("name");
635802

@@ -654,7 +821,7 @@ mod tests {
654821

655822
let schema = Arc::new(Schema::new("public"));
656823

657-
let mut context = Context::new(1, schema);
824+
let mut context = Context::new_for_test(1, schema);
658825

659826
let statement_name = Name::from("statement");
660827
let portal_name = Name::from("portal");
@@ -695,7 +862,7 @@ mod tests {
695862

696863
let schema = Arc::new(Schema::new("public"));
697864

698-
let mut context = Context::new(1, schema);
865+
let mut context = Context::new_for_test(1, schema);
699866

700867
// Create multiple statements
701868
let statement_name_1 = Name::from("statement_1");
@@ -745,7 +912,7 @@ mod tests {
745912

746913
let schema = Arc::new(Schema::new("public"));
747914

748-
let mut context = Context::new(1, schema);
915+
let mut context = Context::new_for_test(1, schema);
749916

750917
let statement_name_1 = Name::from("statement_1");
751918
let portal_name_1 = Name::unnamed();
@@ -817,7 +984,7 @@ mod tests {
817984
log::init(LogConfig::default());
818985

819986
let schema = Arc::new(Schema::new("public"));
820-
let mut context = Context::new(1, schema);
987+
let mut context = Context::new_for_test(1, schema);
821988

822989
let sql = "SET CIPHERSTASH.UNSAFE_DISABLE_MAPPING = true";
823990
let statement = parse_statement(sql);
@@ -867,7 +1034,7 @@ mod tests {
8671034
];
8681035

8691036
for s in sql {
870-
let mut context = Context::new(1, schema.clone());
1037+
let mut context = Context::new_for_test(1, schema.clone());
8711038
assert!(context.keyset_identifier().is_none());
8721039

8731040
let statement = parse_statement(s);
@@ -887,7 +1054,7 @@ mod tests {
8871054
log::init(LogConfig::default());
8881055

8891056
let schema = Arc::new(Schema::new("public"));
890-
let mut context = Context::new(1, schema);
1057+
let mut context = Context::new_for_test(1, schema);
8911058

8921059
// Returns OK if unknown command
8931060
let sql = "SET CIPHERSTASH.BLAH = 'keyset_id'";
@@ -931,7 +1098,7 @@ mod tests {
9311098
];
9321099

9331100
for s in sql {
934-
let mut context = Context::new(1, schema.clone());
1101+
let mut context = Context::new_for_test(1, schema.clone());
9351102
assert!(context.keyset_identifier().is_none());
9361103

9371104
let statement = parse_statement(s);
@@ -952,7 +1119,7 @@ mod tests {
9521119
log::init(LogConfig::default());
9531120

9541121
let schema = Arc::new(Schema::new("public"));
955-
let mut context = Context::new(1, schema);
1122+
let mut context = Context::new_for_test(1, schema);
9561123

9571124
// Returns OK if unknown command
9581125
let sql = "SET CIPHERSTASH.BLAH = 'keyset_name'";
@@ -991,7 +1158,7 @@ mod tests {
9911158
let schema = Arc::new(Schema::new("public"));
9921159

9931160
// Test keyset name with number
994-
let mut context = Context::new(1, schema.clone());
1161+
let mut context = Context::new_for_test(1, schema.clone());
9951162
let sql = "SET CIPHERSTASH.KEYSET_NAME = 12345";
9961163
let statement = parse_statement(sql);
9971164

@@ -1002,7 +1169,7 @@ mod tests {
10021169
assert_eq!(Some(identifier.clone()), context.keyset_identifier());
10031170

10041171
// Test keyset id with numeric UUID (should work if it's a valid UUID)
1005-
let mut context = Context::new(2, schema);
1172+
let mut context = Context::new_for_test(2, schema);
10061173
// This will fail because 123 is not a valid UUID, but it shows the number is processed
10071174
let sql = "SET CIPHERSTASH.KEYSET_ID = 123";
10081175
let statement = parse_statement(sql);
@@ -1019,7 +1186,7 @@ mod tests {
10191186
let schema = Arc::new(Schema::new("public"));
10201187

10211188
// Test that maybe_set_keyset handles both ID and name
1022-
let mut context = Context::new(1, schema.clone());
1189+
let mut context = Context::new_for_test(1, schema.clone());
10231190

10241191
// Test with keyset ID
10251192
let keyset_id_sql = "SET CIPHERSTASH.KEYSET_ID = '7d4cbd7f-ba0d-4985-9ed2-ebe2ffe77590'";
@@ -1035,7 +1202,7 @@ mod tests {
10351202
assert_eq!(Some(identifier.clone()), context.keyset_identifier());
10361203

10371204
// Test with keyset name
1038-
let mut context = Context::new(2, schema.clone());
1205+
let mut context = Context::new_for_test(2, schema.clone());
10391206
let keyset_name_sql = "SET CIPHERSTASH.KEYSET_NAME = 'test-keyset'";
10401207
let statement = parse_statement(keyset_name_sql);
10411208

@@ -1047,7 +1214,7 @@ mod tests {
10471214
assert_eq!(Some(identifier.clone()), context.keyset_identifier());
10481215

10491216
// Test with unknown command
1050-
let mut context = Context::new(3, schema);
1217+
let mut context = Context::new_for_test(3, schema);
10511218
let unknown_sql = "SET CIPHERSTASH.UNKNOWN = 'value'";
10521219
let statement = parse_statement(unknown_sql);
10531220
let result = context.maybe_set_keyset(&statement);

0 commit comments

Comments
 (0)