Skip to content

Commit 08a16b3

Browse files
authored
DGS-21595 Allow alternate KMS key IDs on a KEK (#2508)
* First cut * Minor cleanup * Add tests * Add test * Minor cleanup
1 parent 3ffb642 commit 08a16b3

File tree

3 files changed

+185
-17
lines changed

3 files changed

+185
-17
lines changed

src/Confluent.SchemaRegistry.Encryption/EncryptionExecutor.cs

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ public static void Register()
3838
public static readonly string EncryptKmsType = "encrypt.kms.type";
3939
public static readonly string EncryptDekAlgorithm = "encrypt.dek.algorithm";
4040
public static readonly string EncryptDekExpiryDays = "encrypt.dek.expiry.days";
41+
public static readonly string EncryptAlternateKmsKeyIds = "encrypt.alternate.kms.key.ids";
4142

4243
public static readonly string KmsTypeSuffix = "://";
4344

@@ -337,7 +338,7 @@ private async Task<RegisteredDek> GetOrCreateDek(RuleContext ctx, int? version)
337338
byte[] encryptedDek = null;
338339
if (!kek.Shared)
339340
{
340-
kmsClient = GetKmsClient(executor.Configs, kek);
341+
kmsClient = new KmsClientWrapper(executor.Configs, kek);
341342
// Generate new dek
342343
byte[] rawDek = cryptor.GenerateKey();
343344
encryptedDek = await kmsClient.Encrypt(rawDek)
@@ -363,7 +364,7 @@ private async Task<RegisteredDek> GetOrCreateDek(RuleContext ctx, int? version)
363364
{
364365
if (kmsClient == null)
365366
{
366-
kmsClient = GetKmsClient(executor.Configs, kek);
367+
kmsClient = new KmsClientWrapper(executor.Configs, kek);
367368
}
368369

369370
byte[] rawDek = await kmsClient.Decrypt(dek.EncryptedKeyMaterialBytes)
@@ -566,21 +567,6 @@ private byte[] PrefixVersion(int version, byte[] ciphertext)
566567
}
567568
}
568569
}
569-
570-
private static IKmsClient GetKmsClient(IEnumerable<KeyValuePair<string, string>> configs, RegisteredKek kek)
571-
{
572-
string keyUrl = kek.KmsType + EncryptionExecutor.KmsTypeSuffix + kek.KmsKeyId;
573-
IKmsClient kmsClient = KmsRegistry.GetKmsClient(keyUrl);
574-
if (kmsClient == null)
575-
{
576-
IKmsDriver kmsDriver = KmsRegistry.GetKmsDriver(keyUrl);
577-
kmsClient = kmsDriver.NewKmsClient(
578-
configs.ToDictionary(it => it.Key, it => it.Value), keyUrl);
579-
KmsRegistry.RegisterKmsClient(kmsClient);
580-
}
581-
582-
return kmsClient;
583-
}
584570
}
585571

586572
public interface IClock
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Threading.Tasks;
5+
6+
namespace Confluent.SchemaRegistry.Encryption
7+
{
8+
public class KmsClientWrapper : IKmsClient
9+
{
10+
public IEnumerable<KeyValuePair<string, string>> Configs { get; }
11+
12+
public RegisteredKek Kek { get; }
13+
14+
public string KekId { get; }
15+
16+
public IList<string> KmsKeyIds { get; }
17+
18+
public KmsClientWrapper(IEnumerable<KeyValuePair<string, string>> configs, RegisteredKek kek)
19+
{
20+
Configs = configs;
21+
Kek = kek;
22+
KekId = kek.KmsType + EncryptionExecutor.KmsTypeSuffix + kek.KmsKeyId;
23+
KmsKeyIds = GetKmsKeyIds();
24+
}
25+
26+
public bool DoesSupport(string uri)
27+
{
28+
return KekId == uri;
29+
}
30+
31+
public async Task<byte[]> Encrypt(byte[] plaintext)
32+
{
33+
for (int i = 0; i < KmsKeyIds.Count; i++)
34+
{
35+
try
36+
{
37+
IKmsClient kmsClient = GetKmsClient(Configs, Kek.KmsType, KmsKeyIds[i]);
38+
return await kmsClient.Encrypt(plaintext).ConfigureAwait(false);
39+
}
40+
catch (Exception e)
41+
{
42+
if (i == KmsKeyIds.Count - 1)
43+
{
44+
throw new RuleException("Failed to encrypt with all KEKs", e);
45+
}
46+
}
47+
}
48+
throw new RuleException("No KEK found for encryption");
49+
}
50+
51+
public async Task<byte[]> Decrypt(byte[] ciphertext)
52+
{
53+
for (int i = 0; i < KmsKeyIds.Count; i++)
54+
{
55+
try
56+
{
57+
IKmsClient kmsClient = GetKmsClient(Configs, Kek.KmsType, KmsKeyIds[i]);
58+
return await kmsClient.Decrypt(ciphertext).ConfigureAwait(false);
59+
}
60+
catch (Exception e)
61+
{
62+
if (i == KmsKeyIds.Count - 1)
63+
{
64+
throw new RuleException("Failed to decrypt with all KEKs", e);
65+
}
66+
}
67+
}
68+
throw new RuleException("No KEK found for decryption");
69+
}
70+
71+
private IList<string> GetKmsKeyIds()
72+
{
73+
IList<string> kmsKeyIds = new List<string>();
74+
kmsKeyIds.Add(Kek.KmsKeyId);
75+
string alternateKmsKeyIds = null;
76+
if (Kek.KmsProps != null)
77+
{
78+
Kek.KmsProps.TryGetValue(EncryptionExecutor.EncryptAlternateKmsKeyIds,
79+
out alternateKmsKeyIds);
80+
}
81+
if (string.IsNullOrEmpty(alternateKmsKeyIds))
82+
{
83+
var kvp = Configs.FirstOrDefault(x =>
84+
x.Key == EncryptionExecutor.EncryptAlternateKmsKeyIds);
85+
if (!kvp.Equals(default(KeyValuePair<string, string>)))
86+
{
87+
alternateKmsKeyIds = kvp.Value;
88+
}
89+
}
90+
if (!string.IsNullOrEmpty(alternateKmsKeyIds))
91+
{
92+
string[] ids = alternateKmsKeyIds.Split(',', StringSplitOptions.RemoveEmptyEntries);
93+
foreach (string id in ids)
94+
{
95+
if (!string.IsNullOrEmpty(id))
96+
{
97+
kmsKeyIds.Add(id);
98+
}
99+
}
100+
}
101+
return kmsKeyIds;
102+
}
103+
104+
private static IKmsClient GetKmsClient(IEnumerable<KeyValuePair<string, string>> configs, string kmsType, string kmsKeyId)
105+
{
106+
string keyUrl = kmsType + EncryptionExecutor.KmsTypeSuffix + kmsKeyId;
107+
IKmsClient kmsClient = KmsRegistry.GetKmsClient(keyUrl);
108+
if (kmsClient == null)
109+
{
110+
IKmsDriver kmsDriver = KmsRegistry.GetKmsDriver(keyUrl);
111+
kmsClient = kmsDriver.NewKmsClient(
112+
configs.ToDictionary(it => it.Key, it => it.Value), keyUrl);
113+
KmsRegistry.RegisterKmsClient(kmsClient);
114+
}
115+
116+
return kmsClient;
117+
}
118+
}
119+
}

test/Confluent.SchemaRegistry.Serdes.UnitTests/SerializeDeserialize.cs

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -649,6 +649,69 @@ public void ISpecificRecordPayloadEncryption()
649649
Assert.True(pic.SequenceEqual(result.picture));
650650
}
651651

652+
[Fact]
653+
public void ISpecificRecordEncryptionAlternateKeks()
654+
{
655+
var schemaStr = "{\"type\":\"record\",\"name\":\"UserWithPic\",\"namespace\":\"Confluent.Kafka.Examples.AvroSpecific" +
656+
"\",\"fields\":[{\"name\":\"name\",\"type\":\"string\"},{\"name\":\"favorite_number\"," +
657+
"\"type\":[\"int\",\"null\"]},{\"name\":\"favorite_color\",\"type\":[\"string\",\"null\"]}," +
658+
"{\"name\":\"picture\",\"type\":[\"null\",\"bytes\"],\"default\":null}]}";
659+
660+
var schema = new RegisteredSchema("topic-value", 1, 1, schemaStr, SchemaType.Avro, null);
661+
schema.Metadata = new Metadata(new Dictionary<string, ISet<string>>
662+
{
663+
["Confluent.Kafka.Examples.AvroSpecific.UserWithPic.name"] = new HashSet<string> { "PII" },
664+
["Confluent.Kafka.Examples.AvroSpecific.UserWithPic.picture"] = new HashSet<string> { "PII" }
665+
666+
}, new Dictionary<string, string>(), new HashSet<string>()
667+
);
668+
schema.RuleSet = new RuleSet(new List<Rule>(), new List<Rule>(),
669+
new List<Rule>
670+
{
671+
new Rule("encryptPII", RuleKind.Transform, RuleMode.WriteRead, "ENCRYPT_PAYLOAD", null,
672+
new Dictionary<string, string>
673+
{
674+
["encrypt.kek.name"] = "kek1",
675+
["encrypt.kms.type"] = "local-kms",
676+
["encrypt.kms.key.id"] = "mykey"
677+
})
678+
}
679+
);
680+
store[schemaStr] = 1;
681+
subjectStore["topic-value"] = new List<RegisteredSchema> { schema };
682+
var config = new AvroSerializerConfig
683+
{
684+
AutoRegisterSchemas = false,
685+
UseLatestVersion = true
686+
};
687+
config.Set("rules.secret", "mysecret");
688+
config.Set("rules.encrypt.alternate.kms.key.ids", "mykey2,mykey3");
689+
RuleRegistry ruleRegistry = new RuleRegistry();
690+
IRuleExecutor ruleExecutor = new EncryptionExecutor(dekRegistryClient, clock);
691+
ruleRegistry.RegisterExecutor(ruleExecutor);
692+
var serializer = new AvroSerializer<UserWithPic>(schemaRegistryClient, config, ruleRegistry);
693+
var deserializer = new AvroDeserializer<UserWithPic>(schemaRegistryClient, null, ruleRegistry);
694+
695+
var pic = new byte[] { 1, 2, 3 };
696+
var user = new UserWithPic()
697+
{
698+
favorite_color = "blue",
699+
favorite_number = 100,
700+
name = "awesome",
701+
picture = pic
702+
};
703+
704+
Headers headers = new Headers();
705+
var bytes = serializer.SerializeAsync(user, new SerializationContext(MessageComponentType.Value, testTopic, headers)).Result;
706+
var result = deserializer.DeserializeAsync(bytes, false, new SerializationContext(MessageComponentType.Value, testTopic, headers)).Result;
707+
708+
// The user name has been modified
709+
Assert.Equal("awesome", result.name);
710+
Assert.Equal(user.favorite_color, result.favorite_color);
711+
Assert.Equal(user.favorite_number, result.favorite_number);
712+
Assert.True(pic.SequenceEqual(result.picture));
713+
}
714+
652715
[Fact]
653716
public void ISpecificRecordFieldEncryptionDekRotation()
654717
{

0 commit comments

Comments
 (0)