@@ -6,9 +6,11 @@ use super::{
66 Column ,
77} ;
88use crate :: {
9+ config:: TandemConfig ,
910 error:: { EncryptError , Error } ,
1011 log:: CONTEXT ,
1112 prometheus:: { STATEMENTS_EXECUTION_DURATION_SECONDS , STATEMENTS_SESSION_DURATION_SECONDS } ,
13+ services:: { EncryptionService , SchemaService } ,
1214} ;
1315use cipherstash_client:: IdentifiedBy ;
1416use eql_mapper:: { Schema , TableResolver } ;
@@ -36,9 +38,12 @@ impl std::fmt::Display for KeysetIdentifier {
3638 }
3739}
3840
39- #[ derive( Clone , Debug ) ]
41+ #[ derive( Clone ) ]
4042pub struct Context {
4143 pub client_id : i32 ,
44+ config : Arc < TandemConfig > ,
45+ encryption : Arc < dyn EncryptionService > ,
46+ schema : Arc < dyn SchemaService > ,
4247 statements : Arc < RwLock < HashMap < Name , Arc < Statement > > > > ,
4348 portals : Arc < RwLock < HashMap < Name , PortalQueue > > > ,
4449 describe : Arc < RwLock < DescribeQueue > > ,
@@ -112,7 +117,13 @@ pub enum Portal {
112117}
113118
114119impl Context {
115- pub fn new ( client_id : i32 , schema : Arc < Schema > ) -> Context {
120+ pub fn new (
121+ client_id : i32 ,
122+ schema : Arc < Schema > ,
123+ config : Arc < TandemConfig > ,
124+ encryption : Arc < dyn EncryptionService > ,
125+ schema_service : Arc < dyn SchemaService > ,
126+ ) -> Context {
116127 Context {
117128 statements : Arc :: new ( RwLock :: new ( HashMap :: new ( ) ) ) ,
118129 portals : Arc :: new ( RwLock :: new ( HashMap :: new ( ) ) ) ,
@@ -122,6 +133,9 @@ impl Context {
122133 session_metrics : Arc :: new ( RwLock :: from ( Queue :: new ( ) ) ) ,
123134 table_resolver : Arc :: new ( TableResolver :: new_editable ( schema) ) ,
124135 client_id,
136+ config,
137+ encryption,
138+ schema : schema_service,
125139 unsafe_disable_mapping : false ,
126140 keyset_id : Arc :: new ( RwLock :: new ( None ) ) ,
127141 }
@@ -477,6 +491,159 @@ impl Context {
477491 pub fn keyset_identifier ( & self ) -> Option < KeysetIdentifier > {
478492 self . keyset_id . read ( ) . ok ( ) . and_then ( |k| k. clone ( ) )
479493 }
494+
495+ // Service delegation methods
496+ pub async fn encrypt (
497+ & self ,
498+ keyset_id : Option < KeysetIdentifier > ,
499+ plaintexts : Vec < Option < cipherstash_client:: encryption:: Plaintext > > ,
500+ columns : & [ Option < Column > ] ,
501+ ) -> Result < Vec < Option < crate :: EqlEncrypted > > , Error > {
502+ self . encryption
503+ . encrypt ( keyset_id, plaintexts, columns)
504+ . await
505+ }
506+
507+ pub async fn decrypt (
508+ & self ,
509+ keyset_id : Option < KeysetIdentifier > ,
510+ ciphertexts : Vec < Option < crate :: EqlEncrypted > > ,
511+ ) -> Result < Vec < Option < cipherstash_client:: encryption:: Plaintext > > , Error > {
512+ self . encryption . decrypt ( keyset_id, ciphertexts) . await
513+ }
514+
515+ pub async fn reload_schema ( & self ) {
516+ self . schema . reload_schema ( ) . await ;
517+ self . set_schema_changed ( ) ;
518+ }
519+
520+ pub fn get_column_config (
521+ & self ,
522+ identifier : & crate :: eql:: Identifier ,
523+ ) -> Option < cipherstash_client:: schema:: ColumnConfig > {
524+ self . schema . get_column_config ( identifier)
525+ }
526+
527+ pub fn is_passthrough ( & self ) -> bool {
528+ self . schema . is_passthrough ( )
529+ }
530+
531+ pub fn is_empty_config ( & self ) -> bool {
532+ self . schema . is_empty_config ( )
533+ }
534+
535+ // Direct config access methods
536+ pub fn connection_timeout ( & self ) -> std:: time:: Duration {
537+ self . config
538+ . database
539+ . connection_timeout ( )
540+ . unwrap_or_else ( || std:: time:: Duration :: from_secs ( 10 ) )
541+ }
542+
543+ pub fn mapping_disabled ( & self ) -> bool {
544+ self . config . mapping_disabled ( )
545+ }
546+
547+ pub fn mapping_errors_enabled ( & self ) -> bool {
548+ self . config . mapping_errors_enabled ( )
549+ }
550+
551+ pub fn prometheus_enabled ( & self ) -> bool {
552+ self . config . prometheus_enabled ( )
553+ }
554+
555+ pub fn default_keyset_id ( & self ) -> Option < KeysetIdentifier > {
556+ self . config
557+ . encrypt
558+ . default_keyset_id
559+ . map ( |uuid| KeysetIdentifier ( IdentifiedBy :: Uuid ( uuid) ) )
560+ }
561+
562+ /// Helper function for gradual migration - creates Context with Proxy implementing services
563+ pub fn new_with_proxy (
564+ client_id : i32 ,
565+ schema : Arc < Schema > ,
566+ proxy : crate :: proxy:: Proxy ,
567+ ) -> Context {
568+ let config = Arc :: new ( proxy. config . clone ( ) ) ;
569+ let proxy = Arc :: new ( proxy) ;
570+ Self :: new (
571+ client_id,
572+ schema,
573+ config,
574+ proxy. clone ( ) , // as EncryptionService
575+ proxy, // as SchemaService
576+ )
577+ }
578+
579+ /// Helper function for tests - creates Context with minimal mock services
580+ #[ cfg( test) ]
581+ pub fn new_for_test ( client_id : i32 , schema : Arc < Schema > ) -> Context {
582+ // Create minimal mock services for testing
583+ struct MockEncryptionService ;
584+ struct MockSchemaService ;
585+
586+ #[ async_trait:: async_trait]
587+ impl EncryptionService for MockEncryptionService {
588+ async fn encrypt (
589+ & self ,
590+ _keyset_id : Option < KeysetIdentifier > ,
591+ plaintexts : Vec < Option < cipherstash_client:: encryption:: Plaintext > > ,
592+ _columns : & [ Option < Column > ] ,
593+ ) -> Result < Vec < Option < crate :: EqlEncrypted > > , Error > {
594+ Ok ( plaintexts. into_iter ( ) . map ( |_| None ) . collect ( ) )
595+ }
596+
597+ async fn decrypt (
598+ & self ,
599+ _keyset_id : Option < KeysetIdentifier > ,
600+ ciphertexts : Vec < Option < crate :: EqlEncrypted > > ,
601+ ) -> Result < Vec < Option < cipherstash_client:: encryption:: Plaintext > > , Error > {
602+ Ok ( ciphertexts. into_iter ( ) . map ( |_| None ) . collect ( ) )
603+ }
604+ }
605+
606+ #[ async_trait:: async_trait]
607+ impl SchemaService for MockSchemaService {
608+ async fn reload_schema ( & self ) { }
609+ fn get_column_config (
610+ & self ,
611+ _identifier : & crate :: eql:: Identifier ,
612+ ) -> Option < cipherstash_client:: schema:: ColumnConfig > {
613+ None
614+ }
615+ fn get_table_resolver ( & self ) -> Arc < TableResolver > {
616+ Arc :: new ( TableResolver :: new_editable ( Arc :: new ( Schema :: new ( "public" ) ) ) )
617+ }
618+ fn is_passthrough ( & self ) -> bool {
619+ true
620+ }
621+ fn is_empty_config ( & self ) -> bool {
622+ true
623+ }
624+ }
625+
626+ // Set up minimal environment variables for testing
627+ std:: env:: set_var ( "CS_DATABASE__USERNAME" , "test" ) ;
628+ std:: env:: set_var ( "CS_DATABASE__PASSWORD" , "test" ) ;
629+ std:: env:: set_var ( "CS_DATABASE__NAME" , "test" ) ;
630+ std:: env:: set_var ( "CS_DATABASE__HOST" , "localhost" ) ;
631+ std:: env:: set_var ( "CS_DATABASE__PORT" , "5432" ) ;
632+ std:: env:: set_var ( "CS_AUTH__WORKSPACE_CRN" , "crn:ap-southeast-2.aws:test" ) ;
633+ std:: env:: set_var ( "CS_AUTH__CLIENT_ACCESS_KEY" , "test" ) ;
634+ std:: env:: set_var ( "CS_ENCRYPT__CLIENT_ID" , "test" ) ;
635+ std:: env:: set_var ( "CS_ENCRYPT__CLIENT_KEY" , "test" ) ;
636+
637+ let config = Arc :: new (
638+ crate :: config:: TandemConfig :: build ( "tests/config/unknown.toml" )
639+ . expect ( "Failed to create test config" ) ,
640+ ) ;
641+
642+ let encryption = Arc :: new ( MockEncryptionService ) as Arc < dyn EncryptionService > ;
643+ let schema_service = Arc :: new ( MockSchemaService ) as Arc < dyn SchemaService > ;
644+
645+ Self :: new ( client_id, schema, config, encryption, schema_service)
646+ }
480647}
481648
482649impl Statement {
@@ -629,7 +796,7 @@ mod tests {
629796
630797 let schema = Arc :: new ( Schema :: new ( "public" ) ) ;
631798
632- let mut context = Context :: new ( 1 , schema) ;
799+ let mut context = Context :: new_for_test ( 1 , schema) ;
633800
634801 let name = Name :: from ( "name" ) ;
635802
@@ -654,7 +821,7 @@ mod tests {
654821
655822 let schema = Arc :: new ( Schema :: new ( "public" ) ) ;
656823
657- let mut context = Context :: new ( 1 , schema) ;
824+ let mut context = Context :: new_for_test ( 1 , schema) ;
658825
659826 let statement_name = Name :: from ( "statement" ) ;
660827 let portal_name = Name :: from ( "portal" ) ;
@@ -695,7 +862,7 @@ mod tests {
695862
696863 let schema = Arc :: new ( Schema :: new ( "public" ) ) ;
697864
698- let mut context = Context :: new ( 1 , schema) ;
865+ let mut context = Context :: new_for_test ( 1 , schema) ;
699866
700867 // Create multiple statements
701868 let statement_name_1 = Name :: from ( "statement_1" ) ;
@@ -745,7 +912,7 @@ mod tests {
745912
746913 let schema = Arc :: new ( Schema :: new ( "public" ) ) ;
747914
748- let mut context = Context :: new ( 1 , schema) ;
915+ let mut context = Context :: new_for_test ( 1 , schema) ;
749916
750917 let statement_name_1 = Name :: from ( "statement_1" ) ;
751918 let portal_name_1 = Name :: unnamed ( ) ;
@@ -817,7 +984,7 @@ mod tests {
817984 log:: init ( LogConfig :: default ( ) ) ;
818985
819986 let schema = Arc :: new ( Schema :: new ( "public" ) ) ;
820- let mut context = Context :: new ( 1 , schema) ;
987+ let mut context = Context :: new_for_test ( 1 , schema) ;
821988
822989 let sql = "SET CIPHERSTASH.UNSAFE_DISABLE_MAPPING = true" ;
823990 let statement = parse_statement ( sql) ;
@@ -867,7 +1034,7 @@ mod tests {
8671034 ] ;
8681035
8691036 for s in sql {
870- let mut context = Context :: new ( 1 , schema. clone ( ) ) ;
1037+ let mut context = Context :: new_for_test ( 1 , schema. clone ( ) ) ;
8711038 assert ! ( context. keyset_identifier( ) . is_none( ) ) ;
8721039
8731040 let statement = parse_statement ( s) ;
@@ -887,7 +1054,7 @@ mod tests {
8871054 log:: init ( LogConfig :: default ( ) ) ;
8881055
8891056 let schema = Arc :: new ( Schema :: new ( "public" ) ) ;
890- let mut context = Context :: new ( 1 , schema) ;
1057+ let mut context = Context :: new_for_test ( 1 , schema) ;
8911058
8921059 // Returns OK if unknown command
8931060 let sql = "SET CIPHERSTASH.BLAH = 'keyset_id'" ;
@@ -931,7 +1098,7 @@ mod tests {
9311098 ] ;
9321099
9331100 for s in sql {
934- let mut context = Context :: new ( 1 , schema. clone ( ) ) ;
1101+ let mut context = Context :: new_for_test ( 1 , schema. clone ( ) ) ;
9351102 assert ! ( context. keyset_identifier( ) . is_none( ) ) ;
9361103
9371104 let statement = parse_statement ( s) ;
@@ -952,7 +1119,7 @@ mod tests {
9521119 log:: init ( LogConfig :: default ( ) ) ;
9531120
9541121 let schema = Arc :: new ( Schema :: new ( "public" ) ) ;
955- let mut context = Context :: new ( 1 , schema) ;
1122+ let mut context = Context :: new_for_test ( 1 , schema) ;
9561123
9571124 // Returns OK if unknown command
9581125 let sql = "SET CIPHERSTASH.BLAH = 'keyset_name'" ;
@@ -991,7 +1158,7 @@ mod tests {
9911158 let schema = Arc :: new ( Schema :: new ( "public" ) ) ;
9921159
9931160 // Test keyset name with number
994- let mut context = Context :: new ( 1 , schema. clone ( ) ) ;
1161+ let mut context = Context :: new_for_test ( 1 , schema. clone ( ) ) ;
9951162 let sql = "SET CIPHERSTASH.KEYSET_NAME = 12345" ;
9961163 let statement = parse_statement ( sql) ;
9971164
@@ -1002,7 +1169,7 @@ mod tests {
10021169 assert_eq ! ( Some ( identifier. clone( ) ) , context. keyset_identifier( ) ) ;
10031170
10041171 // Test keyset id with numeric UUID (should work if it's a valid UUID)
1005- let mut context = Context :: new ( 2 , schema) ;
1172+ let mut context = Context :: new_for_test ( 2 , schema) ;
10061173 // This will fail because 123 is not a valid UUID, but it shows the number is processed
10071174 let sql = "SET CIPHERSTASH.KEYSET_ID = 123" ;
10081175 let statement = parse_statement ( sql) ;
@@ -1019,7 +1186,7 @@ mod tests {
10191186 let schema = Arc :: new ( Schema :: new ( "public" ) ) ;
10201187
10211188 // Test that maybe_set_keyset handles both ID and name
1022- let mut context = Context :: new ( 1 , schema. clone ( ) ) ;
1189+ let mut context = Context :: new_for_test ( 1 , schema. clone ( ) ) ;
10231190
10241191 // Test with keyset ID
10251192 let keyset_id_sql = "SET CIPHERSTASH.KEYSET_ID = '7d4cbd7f-ba0d-4985-9ed2-ebe2ffe77590'" ;
@@ -1035,7 +1202,7 @@ mod tests {
10351202 assert_eq ! ( Some ( identifier. clone( ) ) , context. keyset_identifier( ) ) ;
10361203
10371204 // Test with keyset name
1038- let mut context = Context :: new ( 2 , schema. clone ( ) ) ;
1205+ let mut context = Context :: new_for_test ( 2 , schema. clone ( ) ) ;
10391206 let keyset_name_sql = "SET CIPHERSTASH.KEYSET_NAME = 'test-keyset'" ;
10401207 let statement = parse_statement ( keyset_name_sql) ;
10411208
@@ -1047,7 +1214,7 @@ mod tests {
10471214 assert_eq ! ( Some ( identifier. clone( ) ) , context. keyset_identifier( ) ) ;
10481215
10491216 // Test with unknown command
1050- let mut context = Context :: new ( 3 , schema) ;
1217+ let mut context = Context :: new_for_test ( 3 , schema) ;
10511218 let unknown_sql = "SET CIPHERSTASH.UNKNOWN = 'value'" ;
10521219 let statement = parse_statement ( unknown_sql) ;
10531220 let result = context. maybe_set_keyset ( & statement) ;
0 commit comments