diff --git a/src/Confluent.SchemaRegistry.Encryption/EncryptionExecutor.cs b/src/Confluent.SchemaRegistry.Encryption/EncryptionExecutor.cs index 939ccd992..22ee77028 100644 --- a/src/Confluent.SchemaRegistry.Encryption/EncryptionExecutor.cs +++ b/src/Confluent.SchemaRegistry.Encryption/EncryptionExecutor.cs @@ -38,6 +38,7 @@ public static void Register() public static readonly string EncryptKmsType = "encrypt.kms.type"; public static readonly string EncryptDekAlgorithm = "encrypt.dek.algorithm"; public static readonly string EncryptDekExpiryDays = "encrypt.dek.expiry.days"; + public static readonly string EncryptAlternateKmsKeyIds = "encrypt.alternate.kms.key.ids"; public static readonly string KmsTypeSuffix = "://"; @@ -337,7 +338,7 @@ private async Task GetOrCreateDek(RuleContext ctx, int? version) byte[] encryptedDek = null; if (!kek.Shared) { - kmsClient = GetKmsClient(executor.Configs, kek); + kmsClient = new KmsClientWrapper(executor.Configs, kek); // Generate new dek byte[] rawDek = cryptor.GenerateKey(); encryptedDek = await kmsClient.Encrypt(rawDek) @@ -363,7 +364,7 @@ private async Task GetOrCreateDek(RuleContext ctx, int? version) { if (kmsClient == null) { - kmsClient = GetKmsClient(executor.Configs, kek); + kmsClient = new KmsClientWrapper(executor.Configs, kek); } byte[] rawDek = await kmsClient.Decrypt(dek.EncryptedKeyMaterialBytes) @@ -566,21 +567,6 @@ private byte[] PrefixVersion(int version, byte[] ciphertext) } } } - - private static IKmsClient GetKmsClient(IEnumerable> configs, RegisteredKek kek) - { - string keyUrl = kek.KmsType + EncryptionExecutor.KmsTypeSuffix + kek.KmsKeyId; - IKmsClient kmsClient = KmsRegistry.GetKmsClient(keyUrl); - if (kmsClient == null) - { - IKmsDriver kmsDriver = KmsRegistry.GetKmsDriver(keyUrl); - kmsClient = kmsDriver.NewKmsClient( - configs.ToDictionary(it => it.Key, it => it.Value), keyUrl); - KmsRegistry.RegisterKmsClient(kmsClient); - } - - return kmsClient; - } } public interface IClock diff --git a/src/Confluent.SchemaRegistry.Encryption/KmsClientWrapper.cs b/src/Confluent.SchemaRegistry.Encryption/KmsClientWrapper.cs new file mode 100644 index 000000000..fd13e2d06 --- /dev/null +++ b/src/Confluent.SchemaRegistry.Encryption/KmsClientWrapper.cs @@ -0,0 +1,119 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; + +namespace Confluent.SchemaRegistry.Encryption +{ + public class KmsClientWrapper : IKmsClient + { + public IEnumerable> Configs { get; } + + public RegisteredKek Kek { get; } + + public string KekId { get; } + + public IList KmsKeyIds { get; } + + public KmsClientWrapper(IEnumerable> configs, RegisteredKek kek) + { + Configs = configs; + Kek = kek; + KekId = kek.KmsType + EncryptionExecutor.KmsTypeSuffix + kek.KmsKeyId; + KmsKeyIds = GetKmsKeyIds(); + } + + public bool DoesSupport(string uri) + { + return KekId == uri; + } + + public async Task Encrypt(byte[] plaintext) + { + for (int i = 0; i < KmsKeyIds.Count; i++) + { + try + { + IKmsClient kmsClient = GetKmsClient(Configs, Kek.KmsType, KmsKeyIds[i]); + return await kmsClient.Encrypt(plaintext).ConfigureAwait(false); + } + catch (Exception e) + { + if (i == KmsKeyIds.Count - 1) + { + throw new RuleException("Failed to encrypt with all KEKs", e); + } + } + } + throw new RuleException("No KEK found for encryption"); + } + + public async Task Decrypt(byte[] ciphertext) + { + for (int i = 0; i < KmsKeyIds.Count; i++) + { + try + { + IKmsClient kmsClient = GetKmsClient(Configs, Kek.KmsType, KmsKeyIds[i]); + return await kmsClient.Decrypt(ciphertext).ConfigureAwait(false); + } + catch (Exception e) + { + if (i == KmsKeyIds.Count - 1) + { + throw new RuleException("Failed to decrypt with all KEKs", e); + } + } + } + throw new RuleException("No KEK found for decryption"); + } + + private IList GetKmsKeyIds() + { + IList kmsKeyIds = new List(); + kmsKeyIds.Add(Kek.KmsKeyId); + string alternateKmsKeyIds = null; + if (Kek.KmsProps != null) + { + Kek.KmsProps.TryGetValue(EncryptionExecutor.EncryptAlternateKmsKeyIds, + out alternateKmsKeyIds); + } + if (string.IsNullOrEmpty(alternateKmsKeyIds)) + { + var kvp = Configs.FirstOrDefault(x => + x.Key == EncryptionExecutor.EncryptAlternateKmsKeyIds); + if (!kvp.Equals(default(KeyValuePair))) + { + alternateKmsKeyIds = kvp.Value; + } + } + if (!string.IsNullOrEmpty(alternateKmsKeyIds)) + { + string[] ids = alternateKmsKeyIds.Split(',', StringSplitOptions.RemoveEmptyEntries); + foreach (string id in ids) + { + if (!string.IsNullOrEmpty(id)) + { + kmsKeyIds.Add(id); + } + } + } + return kmsKeyIds; + } + + private static IKmsClient GetKmsClient(IEnumerable> configs, string kmsType, string kmsKeyId) + { + string keyUrl = kmsType + EncryptionExecutor.KmsTypeSuffix + kmsKeyId; + IKmsClient kmsClient = KmsRegistry.GetKmsClient(keyUrl); + if (kmsClient == null) + { + IKmsDriver kmsDriver = KmsRegistry.GetKmsDriver(keyUrl); + kmsClient = kmsDriver.NewKmsClient( + configs.ToDictionary(it => it.Key, it => it.Value), keyUrl); + KmsRegistry.RegisterKmsClient(kmsClient); + } + + return kmsClient; + } + } +} \ No newline at end of file diff --git a/test/Confluent.SchemaRegistry.Serdes.UnitTests/SerializeDeserialize.cs b/test/Confluent.SchemaRegistry.Serdes.UnitTests/SerializeDeserialize.cs index 1e0a958a5..2a43e49f2 100644 --- a/test/Confluent.SchemaRegistry.Serdes.UnitTests/SerializeDeserialize.cs +++ b/test/Confluent.SchemaRegistry.Serdes.UnitTests/SerializeDeserialize.cs @@ -649,6 +649,69 @@ public void ISpecificRecordPayloadEncryption() Assert.True(pic.SequenceEqual(result.picture)); } + [Fact] + public void ISpecificRecordEncryptionAlternateKeks() + { + var schemaStr = "{\"type\":\"record\",\"name\":\"UserWithPic\",\"namespace\":\"Confluent.Kafka.Examples.AvroSpecific" + + "\",\"fields\":[{\"name\":\"name\",\"type\":\"string\"},{\"name\":\"favorite_number\"," + + "\"type\":[\"int\",\"null\"]},{\"name\":\"favorite_color\",\"type\":[\"string\",\"null\"]}," + + "{\"name\":\"picture\",\"type\":[\"null\",\"bytes\"],\"default\":null}]}"; + + var schema = new RegisteredSchema("topic-value", 1, 1, schemaStr, SchemaType.Avro, null); + schema.Metadata = new Metadata(new Dictionary> + { + ["Confluent.Kafka.Examples.AvroSpecific.UserWithPic.name"] = new HashSet { "PII" }, + ["Confluent.Kafka.Examples.AvroSpecific.UserWithPic.picture"] = new HashSet { "PII" } + + }, new Dictionary(), new HashSet() + ); + schema.RuleSet = new RuleSet(new List(), new List(), + new List + { + new Rule("encryptPII", RuleKind.Transform, RuleMode.WriteRead, "ENCRYPT_PAYLOAD", null, + new Dictionary + { + ["encrypt.kek.name"] = "kek1", + ["encrypt.kms.type"] = "local-kms", + ["encrypt.kms.key.id"] = "mykey" + }) + } + ); + store[schemaStr] = 1; + subjectStore["topic-value"] = new List { schema }; + var config = new AvroSerializerConfig + { + AutoRegisterSchemas = false, + UseLatestVersion = true + }; + config.Set("rules.secret", "mysecret"); + config.Set("rules.encrypt.alternate.kms.key.ids", "mykey2,mykey3"); + RuleRegistry ruleRegistry = new RuleRegistry(); + IRuleExecutor ruleExecutor = new EncryptionExecutor(dekRegistryClient, clock); + ruleRegistry.RegisterExecutor(ruleExecutor); + var serializer = new AvroSerializer(schemaRegistryClient, config, ruleRegistry); + var deserializer = new AvroDeserializer(schemaRegistryClient, null, ruleRegistry); + + var pic = new byte[] { 1, 2, 3 }; + var user = new UserWithPic() + { + favorite_color = "blue", + favorite_number = 100, + name = "awesome", + picture = pic + }; + + Headers headers = new Headers(); + var bytes = serializer.SerializeAsync(user, new SerializationContext(MessageComponentType.Value, testTopic, headers)).Result; + var result = deserializer.DeserializeAsync(bytes, false, new SerializationContext(MessageComponentType.Value, testTopic, headers)).Result; + + // The user name has been modified + Assert.Equal("awesome", result.name); + Assert.Equal(user.favorite_color, result.favorite_color); + Assert.Equal(user.favorite_number, result.favorite_number); + Assert.True(pic.SequenceEqual(result.picture)); + } + [Fact] public void ISpecificRecordFieldEncryptionDekRotation() {