1
+ using System ;
2
+ using System . Collections . Generic ;
3
+ using System . Linq ;
4
+ using System . Threading . Tasks ;
5
+
6
+ namespace Confluent . SchemaRegistry . Encryption
7
+ {
8
+ public class KmsClientWrapper : IKmsClient
9
+ {
10
+ public IEnumerable < KeyValuePair < string , string > > Configs { get ; }
11
+
12
+ public RegisteredKek Kek { get ; }
13
+
14
+ public string KekId { get ; }
15
+
16
+ public IList < string > KmsKeyIds { get ; }
17
+
18
+ public KmsClientWrapper ( IEnumerable < KeyValuePair < string , string > > configs , RegisteredKek kek )
19
+ {
20
+ Configs = configs ;
21
+ Kek = kek ;
22
+ KekId = kek . KmsType + EncryptionExecutor . KmsTypeSuffix + kek . KmsKeyId ;
23
+ KmsKeyIds = GetKmsKeyIds ( ) ;
24
+ }
25
+
26
+ public bool DoesSupport ( string uri )
27
+ {
28
+ return KekId == uri ;
29
+ }
30
+
31
+ public async Task < byte [ ] > Encrypt ( byte [ ] plaintext )
32
+ {
33
+ for ( int i = 0 ; i < KmsKeyIds . Count ; i ++ )
34
+ {
35
+ try
36
+ {
37
+ IKmsClient kmsClient = GetKmsClient ( Configs , Kek . KmsType , KmsKeyIds [ i ] ) ;
38
+ return await kmsClient . Encrypt ( plaintext ) . ConfigureAwait ( false ) ;
39
+ }
40
+ catch ( Exception e )
41
+ {
42
+ if ( i == KmsKeyIds . Count - 1 )
43
+ {
44
+ throw new RuleException ( "Failed to encrypt with all KEKs" , e ) ;
45
+ }
46
+ }
47
+ }
48
+ return null ;
49
+ }
50
+
51
+ public async Task < byte [ ] > Decrypt ( byte [ ] ciphertext )
52
+ {
53
+ for ( int i = 0 ; i < KmsKeyIds . Count ; i ++ )
54
+ {
55
+ try
56
+ {
57
+ IKmsClient kmsClient = GetKmsClient ( Configs , Kek . KmsType , KmsKeyIds [ i ] ) ;
58
+ return await kmsClient . Decrypt ( ciphertext ) . ConfigureAwait ( false ) ;
59
+ }
60
+ catch ( Exception e )
61
+ {
62
+ if ( i == KmsKeyIds . Count - 1 )
63
+ {
64
+ throw new RuleException ( "Failed to decrypt with all KEKs" , e ) ;
65
+ }
66
+ }
67
+ }
68
+ return null ;
69
+ }
70
+
71
+ private IList < string > GetKmsKeyIds ( )
72
+ {
73
+ IList < string > kmsKeyIds = new List < string > ( ) ;
74
+ kmsKeyIds . Add ( Kek . KmsKeyId ) ;
75
+ if ( Kek . KmsProps != null )
76
+ {
77
+ if ( Kek . KmsProps . TryGetValue ( EncryptionExecutor . EncryptAlternateKmsKeyIds , out string alternateKmsKeyIds ) )
78
+ {
79
+ char [ ] separators = { ',' } ;
80
+ string [ ] ids = alternateKmsKeyIds . Split ( separators , StringSplitOptions . RemoveEmptyEntries ) ;
81
+ foreach ( string id in ids ) {
82
+ if ( ! string . IsNullOrEmpty ( id ) ) {
83
+ kmsKeyIds . Add ( id ) ;
84
+ }
85
+ }
86
+ }
87
+ }
88
+ return kmsKeyIds ;
89
+ }
90
+
91
+ private static IKmsClient GetKmsClient ( IEnumerable < KeyValuePair < string , string > > configs , string kmsType , string kmsKeyId )
92
+ {
93
+ string keyUrl = kmsType + EncryptionExecutor . KmsTypeSuffix + kmsKeyId ;
94
+ IKmsClient kmsClient = KmsRegistry . GetKmsClient ( keyUrl ) ;
95
+ if ( kmsClient == null )
96
+ {
97
+ IKmsDriver kmsDriver = KmsRegistry . GetKmsDriver ( keyUrl ) ;
98
+ kmsClient = kmsDriver . NewKmsClient (
99
+ configs . ToDictionary ( it => it . Key , it => it . Value ) , keyUrl ) ;
100
+ KmsRegistry . RegisterKmsClient ( kmsClient ) ;
101
+ }
102
+
103
+ return kmsClient ;
104
+ }
105
+ }
106
+ }
0 commit comments