Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 3 additions & 17 deletions src/Confluent.SchemaRegistry.Encryption/EncryptionExecutor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ public static void Register()
public static readonly string EncryptKmsType = "encrypt.kms.type";
public static readonly string EncryptDekAlgorithm = "encrypt.dek.algorithm";
public static readonly string EncryptDekExpiryDays = "encrypt.dek.expiry.days";
public static readonly string EncryptAlternateKmsKeyIds = "encrypt.alternate.kms.key.ids";

public static readonly string KmsTypeSuffix = "://";

Expand Down Expand Up @@ -337,7 +338,7 @@ private async Task<RegisteredDek> GetOrCreateDek(RuleContext ctx, int? version)
byte[] encryptedDek = null;
if (!kek.Shared)
{
kmsClient = GetKmsClient(executor.Configs, kek);
kmsClient = new KmsClientWrapper(executor.Configs, kek);
// Generate new dek
byte[] rawDek = cryptor.GenerateKey();
encryptedDek = await kmsClient.Encrypt(rawDek)
Expand All @@ -363,7 +364,7 @@ private async Task<RegisteredDek> GetOrCreateDek(RuleContext ctx, int? version)
{
if (kmsClient == null)
{
kmsClient = GetKmsClient(executor.Configs, kek);
kmsClient = new KmsClientWrapper(executor.Configs, kek);
}

byte[] rawDek = await kmsClient.Decrypt(dek.EncryptedKeyMaterialBytes)
Expand Down Expand Up @@ -566,21 +567,6 @@ private byte[] PrefixVersion(int version, byte[] ciphertext)
}
}
}

private static IKmsClient GetKmsClient(IEnumerable<KeyValuePair<string, string>> configs, RegisteredKek kek)
{
string keyUrl = kek.KmsType + EncryptionExecutor.KmsTypeSuffix + kek.KmsKeyId;
IKmsClient kmsClient = KmsRegistry.GetKmsClient(keyUrl);
if (kmsClient == null)
{
IKmsDriver kmsDriver = KmsRegistry.GetKmsDriver(keyUrl);
kmsClient = kmsDriver.NewKmsClient(
configs.ToDictionary(it => it.Key, it => it.Value), keyUrl);
KmsRegistry.RegisterKmsClient(kmsClient);
}

return kmsClient;
}
}

public interface IClock
Expand Down
119 changes: 119 additions & 0 deletions src/Confluent.SchemaRegistry.Encryption/KmsClientWrapper.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;

namespace Confluent.SchemaRegistry.Encryption
{
public class KmsClientWrapper : IKmsClient
{
public IEnumerable<KeyValuePair<string, string>> Configs { get; }

public RegisteredKek Kek { get; }

public string KekId { get; }

public IList<string> KmsKeyIds { get; }

public KmsClientWrapper(IEnumerable<KeyValuePair<string, string>> configs, RegisteredKek kek)
{
Configs = configs;
Kek = kek;
KekId = kek.KmsType + EncryptionExecutor.KmsTypeSuffix + kek.KmsKeyId;
KmsKeyIds = GetKmsKeyIds();
}

public bool DoesSupport(string uri)
{
return KekId == uri;
}

public async Task<byte[]> Encrypt(byte[] plaintext)
{
for (int i = 0; i < KmsKeyIds.Count; i++)
{
try
{
IKmsClient kmsClient = GetKmsClient(Configs, Kek.KmsType, KmsKeyIds[i]);
return await kmsClient.Encrypt(plaintext).ConfigureAwait(false);
}
catch (Exception e)
{
if (i == KmsKeyIds.Count - 1)
{
throw new RuleException("Failed to encrypt with all KEKs", e);
}
}
}
throw new RuleException("No KEK found for encryption");
}

public async Task<byte[]> Decrypt(byte[] ciphertext)
{
for (int i = 0; i < KmsKeyIds.Count; i++)
{
try
{
IKmsClient kmsClient = GetKmsClient(Configs, Kek.KmsType, KmsKeyIds[i]);
return await kmsClient.Decrypt(ciphertext).ConfigureAwait(false);
}
catch (Exception e)
{
if (i == KmsKeyIds.Count - 1)
{
throw new RuleException("Failed to decrypt with all KEKs", e);
}
}
}
throw new RuleException("No KEK found for decryption");
}

private IList<string> GetKmsKeyIds()
{
IList<string> kmsKeyIds = new List<string>();
kmsKeyIds.Add(Kek.KmsKeyId);
string alternateKmsKeyIds = null;
if (Kek.KmsProps != null)
{
Kek.KmsProps.TryGetValue(EncryptionExecutor.EncryptAlternateKmsKeyIds,
out alternateKmsKeyIds);
}
if (string.IsNullOrEmpty(alternateKmsKeyIds))
{
var kvp = Configs.FirstOrDefault(x =>
x.Key == EncryptionExecutor.EncryptAlternateKmsKeyIds);
if (!kvp.Equals(default(KeyValuePair<string, string>)))
{
alternateKmsKeyIds = kvp.Value;
}
}
if (!string.IsNullOrEmpty(alternateKmsKeyIds))
{
string[] ids = alternateKmsKeyIds.Split(',', StringSplitOptions.RemoveEmptyEntries);
foreach (string id in ids)
{
if (!string.IsNullOrEmpty(id))
{
kmsKeyIds.Add(id);
}
}
}
return kmsKeyIds;
}

private static IKmsClient GetKmsClient(IEnumerable<KeyValuePair<string, string>> configs, string kmsType, string kmsKeyId)
{
string keyUrl = kmsType + EncryptionExecutor.KmsTypeSuffix + kmsKeyId;
IKmsClient kmsClient = KmsRegistry.GetKmsClient(keyUrl);
if (kmsClient == null)
{
IKmsDriver kmsDriver = KmsRegistry.GetKmsDriver(keyUrl);
kmsClient = kmsDriver.NewKmsClient(
configs.ToDictionary(it => it.Key, it => it.Value), keyUrl);
KmsRegistry.RegisterKmsClient(kmsClient);
}

return kmsClient;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -649,6 +649,69 @@ public void ISpecificRecordPayloadEncryption()
Assert.True(pic.SequenceEqual(result.picture));
}

[Fact]
public void ISpecificRecordEncryptionAlternateKeks()
{
var schemaStr = "{\"type\":\"record\",\"name\":\"UserWithPic\",\"namespace\":\"Confluent.Kafka.Examples.AvroSpecific" +
"\",\"fields\":[{\"name\":\"name\",\"type\":\"string\"},{\"name\":\"favorite_number\"," +
"\"type\":[\"int\",\"null\"]},{\"name\":\"favorite_color\",\"type\":[\"string\",\"null\"]}," +
"{\"name\":\"picture\",\"type\":[\"null\",\"bytes\"],\"default\":null}]}";

var schema = new RegisteredSchema("topic-value", 1, 1, schemaStr, SchemaType.Avro, null);
schema.Metadata = new Metadata(new Dictionary<string, ISet<string>>
{
["Confluent.Kafka.Examples.AvroSpecific.UserWithPic.name"] = new HashSet<string> { "PII" },
["Confluent.Kafka.Examples.AvroSpecific.UserWithPic.picture"] = new HashSet<string> { "PII" }

}, new Dictionary<string, string>(), new HashSet<string>()
);
schema.RuleSet = new RuleSet(new List<Rule>(), new List<Rule>(),
new List<Rule>
{
new Rule("encryptPII", RuleKind.Transform, RuleMode.WriteRead, "ENCRYPT_PAYLOAD", null,
new Dictionary<string, string>
{
["encrypt.kek.name"] = "kek1",
["encrypt.kms.type"] = "local-kms",
["encrypt.kms.key.id"] = "mykey"
})
}
);
store[schemaStr] = 1;
subjectStore["topic-value"] = new List<RegisteredSchema> { schema };
var config = new AvroSerializerConfig
{
AutoRegisterSchemas = false,
UseLatestVersion = true
};
config.Set("rules.secret", "mysecret");
config.Set("rules.encrypt.alternate.kms.key.ids", "mykey2,mykey3");
RuleRegistry ruleRegistry = new RuleRegistry();
IRuleExecutor ruleExecutor = new EncryptionExecutor(dekRegistryClient, clock);
ruleRegistry.RegisterExecutor(ruleExecutor);
var serializer = new AvroSerializer<UserWithPic>(schemaRegistryClient, config, ruleRegistry);
var deserializer = new AvroDeserializer<UserWithPic>(schemaRegistryClient, null, ruleRegistry);

var pic = new byte[] { 1, 2, 3 };
var user = new UserWithPic()
{
favorite_color = "blue",
favorite_number = 100,
name = "awesome",
picture = pic
};

Headers headers = new Headers();
var bytes = serializer.SerializeAsync(user, new SerializationContext(MessageComponentType.Value, testTopic, headers)).Result;
var result = deserializer.DeserializeAsync(bytes, false, new SerializationContext(MessageComponentType.Value, testTopic, headers)).Result;

// The user name has been modified
Assert.Equal("awesome", result.name);
Assert.Equal(user.favorite_color, result.favorite_color);
Assert.Equal(user.favorite_number, result.favorite_number);
Assert.True(pic.SequenceEqual(result.picture));
}

[Fact]
public void ISpecificRecordFieldEncryptionDekRotation()
{
Expand Down