Skip to content

Commit 0e710fd

Browse files
committed
First cut
1 parent 3a75aa2 commit 0e710fd

File tree

2 files changed

+109
-17
lines changed

2 files changed

+109
-17
lines changed

src/Confluent.SchemaRegistry.Encryption/EncryptionExecutor.cs

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ public static void Register()
3838
public static readonly string EncryptKmsType = "encrypt.kms.type";
3939
public static readonly string EncryptDekAlgorithm = "encrypt.dek.algorithm";
4040
public static readonly string EncryptDekExpiryDays = "encrypt.dek.expiry.days";
41+
public static readonly string EncryptAlternateKmsKeyIds = "encrypt.alternate.kms.key.ids";
4142

4243
public static readonly string KmsTypeSuffix = "://";
4344

@@ -337,7 +338,7 @@ private async Task<RegisteredDek> GetOrCreateDek(RuleContext ctx, int? version)
337338
byte[] encryptedDek = null;
338339
if (!kek.Shared)
339340
{
340-
kmsClient = GetKmsClient(executor.Configs, kek);
341+
kmsClient = new KmsClientWrapper(executor.Configs, kek);
341342
// Generate new dek
342343
byte[] rawDek = cryptor.GenerateKey();
343344
encryptedDek = await kmsClient.Encrypt(rawDek)
@@ -363,7 +364,7 @@ private async Task<RegisteredDek> GetOrCreateDek(RuleContext ctx, int? version)
363364
{
364365
if (kmsClient == null)
365366
{
366-
kmsClient = GetKmsClient(executor.Configs, kek);
367+
kmsClient = new KmsClientWrapper(executor.Configs, kek);
367368
}
368369

369370
byte[] rawDek = await kmsClient.Decrypt(dek.EncryptedKeyMaterialBytes)
@@ -566,21 +567,6 @@ private byte[] PrefixVersion(int version, byte[] ciphertext)
566567
}
567568
}
568569
}
569-
570-
private static IKmsClient GetKmsClient(IEnumerable<KeyValuePair<string, string>> configs, RegisteredKek kek)
571-
{
572-
string keyUrl = kek.KmsType + EncryptionExecutor.KmsTypeSuffix + kek.KmsKeyId;
573-
IKmsClient kmsClient = KmsRegistry.GetKmsClient(keyUrl);
574-
if (kmsClient == null)
575-
{
576-
IKmsDriver kmsDriver = KmsRegistry.GetKmsDriver(keyUrl);
577-
kmsClient = kmsDriver.NewKmsClient(
578-
configs.ToDictionary(it => it.Key, it => it.Value), keyUrl);
579-
KmsRegistry.RegisterKmsClient(kmsClient);
580-
}
581-
582-
return kmsClient;
583-
}
584570
}
585571

586572
public interface IClock
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Threading.Tasks;
5+
6+
namespace Confluent.SchemaRegistry.Encryption
7+
{
8+
public class KmsClientWrapper : IKmsClient
9+
{
10+
public IEnumerable<KeyValuePair<string, string>> Configs { get; }
11+
12+
public RegisteredKek Kek { get; }
13+
14+
public string KekId { get; }
15+
16+
public IList<string> KmsKeyIds { get; }
17+
18+
public KmsClientWrapper(IEnumerable<KeyValuePair<string, string>> configs, RegisteredKek kek)
19+
{
20+
Configs = configs;
21+
Kek = kek;
22+
KekId = kek.KmsType + EncryptionExecutor.KmsTypeSuffix + kek.KmsKeyId;
23+
KmsKeyIds = GetKmsKeyIds();
24+
}
25+
26+
public bool DoesSupport(string uri)
27+
{
28+
return KekId == uri;
29+
}
30+
31+
public async Task<byte[]> Encrypt(byte[] plaintext)
32+
{
33+
for (int i = 0; i < KmsKeyIds.Count; i++)
34+
{
35+
try
36+
{
37+
IKmsClient kmsClient = GetKmsClient(Configs, Kek.KmsType, KmsKeyIds[i]);
38+
return await kmsClient.Encrypt(plaintext).ConfigureAwait(false);
39+
}
40+
catch (Exception e)
41+
{
42+
if (i == KmsKeyIds.Count - 1)
43+
{
44+
throw new RuleException("Failed to encrypt with all KEKs", e);
45+
}
46+
}
47+
}
48+
return null;
49+
}
50+
51+
public async Task<byte[]> Decrypt(byte[] ciphertext)
52+
{
53+
for (int i = 0; i < KmsKeyIds.Count; i++)
54+
{
55+
try
56+
{
57+
IKmsClient kmsClient = GetKmsClient(Configs, Kek.KmsType, KmsKeyIds[i]);
58+
return await kmsClient.Decrypt(ciphertext).ConfigureAwait(false);
59+
}
60+
catch (Exception e)
61+
{
62+
if (i == KmsKeyIds.Count - 1)
63+
{
64+
throw new RuleException("Failed to decrypt with all KEKs", e);
65+
}
66+
}
67+
}
68+
return null;
69+
}
70+
71+
private IList<string> GetKmsKeyIds()
72+
{
73+
IList<string> kmsKeyIds = new List<string>();
74+
kmsKeyIds.Add(Kek.KmsKeyId);
75+
if (Kek.KmsProps != null)
76+
{
77+
if (Kek.KmsProps.TryGetValue(EncryptionExecutor.EncryptAlternateKmsKeyIds, out string alternateKmsKeyIds))
78+
{
79+
char[] separators = { ',' };
80+
string[] ids = alternateKmsKeyIds.Split(separators, StringSplitOptions.RemoveEmptyEntries);
81+
foreach (string id in ids) {
82+
if (!string.IsNullOrEmpty(id)) {
83+
kmsKeyIds.Add(id);
84+
}
85+
}
86+
}
87+
}
88+
return kmsKeyIds;
89+
}
90+
91+
private static IKmsClient GetKmsClient(IEnumerable<KeyValuePair<string, string>> configs, string kmsType, string kmsKeyId)
92+
{
93+
string keyUrl = kmsType + EncryptionExecutor.KmsTypeSuffix + kmsKeyId;
94+
IKmsClient kmsClient = KmsRegistry.GetKmsClient(keyUrl);
95+
if (kmsClient == null)
96+
{
97+
IKmsDriver kmsDriver = KmsRegistry.GetKmsDriver(keyUrl);
98+
kmsClient = kmsDriver.NewKmsClient(
99+
configs.ToDictionary(it => it.Key, it => it.Value), keyUrl);
100+
KmsRegistry.RegisterKmsClient(kmsClient);
101+
}
102+
103+
return kmsClient;
104+
}
105+
}
106+
}

0 commit comments

Comments
 (0)