Skip to content

Commit d6716c1

Browse files
CSHARP-3210: Support Azure and GCP keystores in FLE.
1 parent a73ed25 commit d6716c1

25 files changed

+12351
-750
lines changed

evergreen/evergreen.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,11 @@ functions:
251251
set +x
252252
export FLE_AWS_ACCESS_KEY_ID=${FLE_AWS_ACCESS_KEY_ID}
253253
export FLE_AWS_SECRET_ACCESS_KEY=${FLE_AWS_SECRET_ACCESS_KEY}
254+
export FLE_AZURE_TENANT_ID=${FLE_AZURE_TENANT_ID}
255+
export FLE_AZURE_CLIENT_ID=${FLE_AZURE_CLIENT_ID}
256+
export FLE_AZURE_CLIENT_SECRET=${FLE_AZURE_CLIENT_SECRET}
257+
export FLE_GCP_EMAIL=${FLE_GCP_EMAIL}
258+
export FLE_GCP_PRIVATE_KEY=${FLE_GCP_PRIVATE_KEY}
254259
${PREPARE_SHELL}
255260
SSL=${SSL} evergreen/add-certs-if-needed.sh
256261
AUTH=${AUTH} \

src/MongoDB.Driver.Core/Core/Clusters/CryptClientCreator.cs

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -65,24 +65,16 @@ private CryptClient CreateCryptClient(CryptOptions options)
6565

6666
private CryptOptions CreateCryptOptions()
6767
{
68-
Dictionary<KmsType, IKmsCredentials> kmsProvidersMap = null;
68+
List<KmsCredentials> kmsProviders = null;
6969
if (_kmsProviders != null && _kmsProviders.Count > 0)
7070
{
71-
kmsProvidersMap = new Dictionary<KmsType, IKmsCredentials>();
72-
if (_kmsProviders.TryGetValue("aws", out var awsProvider))
71+
kmsProviders = new List<KmsCredentials>();
72+
foreach (var kmsProvider in _kmsProviders)
7373
{
74-
if (awsProvider.TryGetValue("accessKeyId", out var accessKeyId) &&
75-
awsProvider.TryGetValue("secretAccessKey", out var secretAccessKey))
76-
{
77-
kmsProvidersMap.Add(KmsType.Aws, new AwsKmsCredentials((string)secretAccessKey, (string)accessKeyId));
78-
}
79-
}
80-
if (_kmsProviders.TryGetValue("local", out var localProvider))
81-
{
82-
if (localProvider.TryGetValue("key", out var keyObject) && keyObject is byte[] key)
83-
{
84-
kmsProvidersMap.Add(KmsType.Local, new LocalKmsCredentials(key));
85-
}
74+
var kmsTypeDocumentKey = kmsProvider.Key.ToLower();
75+
var kmsProviderDocument = CreateProviderDocument(kmsTypeDocumentKey, kmsProvider.Value);
76+
var kmsCredentials = new KmsCredentials(credentialsBytes: kmsProviderDocument.ToBson());
77+
kmsProviders.Add(kmsCredentials);
8678
}
8779
}
8880
else
@@ -105,7 +97,18 @@ private CryptOptions CreateCryptOptions()
10597
schemaBytes = schemaDocument.ToBson(writerSettings: writerSettings);
10698
}
10799

108-
return new CryptOptions(kmsProvidersMap, schemaBytes);
100+
return new CryptOptions(kmsProviders, schemaBytes);
101+
}
102+
103+
private BsonDocument CreateProviderDocument(string kmsType, IReadOnlyDictionary<string, object> data)
104+
{
105+
var providerContent = new BsonDocument();
106+
foreach (var record in data)
107+
{
108+
providerContent.Add(new BsonElement(record.Key, BsonValue.Create(record.Value)));
109+
}
110+
var providerDocument = new BsonDocument(kmsType, providerContent);
111+
return providerDocument;
109112
}
110113
}
111114
}

src/MongoDB.Driver.Core/MongoDB.Driver.Core.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@
128128
<ItemGroup>
129129
<PackageReference Include="DnsClient" Version="1.3.1" />
130130
<PackageReference Include="Microsoft.CodeAnalysis.FxCopAnalyzers" Version="2.6.2" PrivateAssets="All" />
131-
<PackageReference Include="MongoDB.Libmongocrypt" Version="1.0.0" />
131+
<PackageReference Include="MongoDB.Libmongocrypt" Version="1.1.0-beta02" />
132132
<PackageReference Include="SharpCompress" Version="0.23.0" />
133133
<PackageReference Include="System.Buffers" Version="4.4.0" />
134134
</ItemGroup>

src/MongoDB.Driver/Encryption/ExplicitEncryptionLibMongoCryptController.cs

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
using System.Threading.Tasks;
2121
using MongoDB.Bson;
2222
using MongoDB.Bson.IO;
23-
using MongoDB.Bson.Serialization.Serializers;
2423
using MongoDB.Driver.Core.Misc;
2524
using MongoDB.Libmongocrypt;
2625

@@ -219,26 +218,20 @@ public async Task<BsonBinaryData> EncryptFieldAsync(
219218
}
220219

221220
// private methods
222-
private IKmsKeyId GetKmsKeyId(string kmsProvider, IReadOnlyList<string> alternateKeyNames, BsonDocument masterKey)
221+
private KmsKeyId GetKmsKeyId(string kmsProvider, IReadOnlyList<string> alternateKeyNames, BsonDocument masterKey)
223222
{
224223
IEnumerable<byte[]> wrappedAlternateKeyNamesBytes = null;
225224
if (alternateKeyNames != null)
226225
{
227226
wrappedAlternateKeyNamesBytes = alternateKeyNames.Select(GetWrappedAlternateKeyNameBytes);
228227
}
229228

230-
switch (kmsProvider)
229+
var dataKeyDocument = new BsonDocument("provider", kmsProvider.ToLower());
230+
if (masterKey != null)
231231
{
232-
case "aws":
233-
var customerMasterKey = masterKey["key"].ToString();
234-
var endpoint = masterKey.GetValue("endpoint", null)?.ToString();
235-
var region = masterKey["region"].ToString();
236-
return new AwsKeyId(customerMasterKey, region, wrappedAlternateKeyNamesBytes, endpoint);
237-
case "local":
238-
return wrappedAlternateKeyNamesBytes != null ? new LocalKeyId(wrappedAlternateKeyNamesBytes) : new LocalKeyId();
239-
default:
240-
throw new ArgumentException($"Invalid kmsProvider {kmsProvider}.");
232+
dataKeyDocument.AddRange(masterKey.Elements);
241233
}
234+
return new KmsKeyId(dataKeyDocument.ToBson(), wrappedAlternateKeyNamesBytes);
242235
}
243236

244237
private byte[] GetWrappedAlternateKeyNameBytes(string value)

src/MongoDB.Driver/Encryption/LibMongoCryptControllerBase.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -235,9 +235,9 @@ private void SendKmsRequest(KmsRequest request, CancellationToken cancellation)
235235
var requestBytes = request.Message.ToArray();
236236
sslStream.Write(requestBytes);
237237

238-
var buffer = new byte[4096];
239238
while (request.BytesNeeded > 0)
240239
{
240+
var buffer = new byte[request.BytesNeeded]; // BytesNeeded is the maximum number of bytes that libmongocrypt wants to receive.
241241
var count = sslStream.Read(buffer, 0, buffer.Length);
242242
var responseBytes = new byte[count];
243243
Buffer.BlockCopy(buffer, 0, responseBytes, 0, count);
@@ -264,9 +264,9 @@ private async Task SendKmsRequestAsync(KmsRequest request, CancellationToken can
264264
var requestBytes = request.Message.ToArray();
265265
await sslStream.WriteAsync(requestBytes, 0, requestBytes.Length).ConfigureAwait(false);
266266

267-
var buffer = new byte[4096];
268267
while (request.BytesNeeded > 0)
269268
{
269+
var buffer = new byte[request.BytesNeeded]; // BytesNeeded is the maximum number of bytes that libmongocrypt wants to receive.
270270
var count = await sslStream.ReadAsync(buffer, 0, buffer.Length).ConfigureAwait(false);
271271
var responseBytes = new byte[count];
272272
Buffer.BlockCopy(buffer, 0, responseBytes, 0, count);

src/MongoDB.Driver/MongoDB.Driver.csproj

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
</PropertyGroup>
4343
<ItemGroup>
4444
<PackageReference Include="Microsoft.CodeAnalysis.FxCopAnalyzers" Version="2.6.2" PrivateAssets="All" />
45-
<PackageReference Include="MongoDB.Libmongocrypt" Version="1.0.0" />
45+
<PackageReference Include="MongoDB.Libmongocrypt" Version="1.1.0-beta02" />
4646
</ItemGroup>
4747

4848
<ItemGroup Condition="'$(TargetFramework)' == 'netstandard1.5'">
@@ -63,7 +63,7 @@
6363

6464
<ItemGroup>
6565
<None Include="..\..\License.txt" Pack="true" PackagePath="$(PackageLicenseFile)" />
66-
<None Include="..\..\packageIcon.png" Pack="true" PackagePath=""/>
66+
<None Include="..\..\packageIcon.png" Pack="true" PackagePath="" />
6767
</ItemGroup>
6868

6969
</Project>

tests/MongoDB.Driver.Core.TestHelpers/CoreExceptionHelper.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,9 @@ public static Exception CreateException(string errorType)
100100
case "TimedOutSocketException":
101101
return new SocketException((int)SocketError.TimedOut);
102102

103+
case "ConnectionRefusedSocketException":
104+
return new SocketException((int)SocketError.ConnectionRefused);
105+
103106
default:
104107
throw new ArgumentException("Unknown error type.");
105108
}

tests/MongoDB.Driver.Tests/ClusterRegistryTests.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ public void GetOrCreateCluster_should_return_a_cluster_with_the_correct_settings
5454
};
5555
var kmsProviders = new Dictionary<string, IReadOnlyDictionary<string, object>>()
5656
{
57-
{ "local", new Dictionary<string, object>() { { "key" , "test" } } }
57+
{ "local", new Dictionary<string, object>() { { "key" , new byte[96] } } }
5858
};
5959
var schemaMap = new Dictionary<string, BsonDocument>()
6060
{

tests/MongoDB.Driver.Tests/Specifications/client-side-encryption/ClientSideEncryptionTestRunner.cs

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,11 @@ private AutoEncryptionOptions ConfigureAutoEncryptionOptions(BsonDocument autoEn
207207
return autoEncryptionOptions;
208208
}
209209

210+
private string GetEnvironmentVariableOrDefaultOrThrowIfNothing(string variableName, string defaultValue = null) =>
211+
Environment.GetEnvironmentVariable(variableName) ??
212+
defaultValue ??
213+
throw new Exception($"{variableName} environment variable must be configured on the machine.");
214+
210215
private ReadOnlyDictionary<string, IReadOnlyDictionary<string, object>> ParseKmsProviders(BsonDocument kmsProviders)
211216
{
212217
var providers = new Dictionary<string, IReadOnlyDictionary<string, object>>();
@@ -217,26 +222,41 @@ private ReadOnlyDictionary<string, IReadOnlyDictionary<string, object>> ParseKms
217222
{
218223
case "aws":
219224
{
220-
var awsRegion = Environment.GetEnvironmentVariable("FLE_AWS_REGION") ?? "us-east-1";
221-
var awsAccessKey = Environment.GetEnvironmentVariable("FLE_AWS_ACCESS_KEY_ID") ?? throw new Exception("The FLE_AWS_ACCESS_KEY_ID system variable should be configured on the machine.");
222-
var awsSecretAccessKey = Environment.GetEnvironmentVariable("FLE_AWS_SECRET_ACCESS_KEY") ?? throw new Exception("The FLE_AWS_SECRET_ACCESS_KEY system variable should be configured on the machine.");
223-
kmsOptions.Add("region", awsRegion);
225+
var awsAccessKey = GetEnvironmentVariableOrDefaultOrThrowIfNothing("FLE_AWS_ACCESS_KEY_ID");
226+
var awsSecretAccessKey = GetEnvironmentVariableOrDefaultOrThrowIfNothing("FLE_AWS_SECRET_ACCESS_KEY");
224227
kmsOptions.Add("accessKeyId", awsAccessKey);
225228
kmsOptions.Add("secretAccessKey", awsSecretAccessKey);
226229
}
227-
providers.Add(kmsProvider.Name, kmsOptions);
228230
break;
229231
case "local":
230232
if (kmsProvider.Value.AsBsonDocument.TryGetElement("key", out var key))
231233
{
232234
var binary = key.Value.AsBsonBinaryData;
233235
kmsOptions.Add(key.Name, binary.Bytes);
234236
}
235-
providers.Add(kmsProvider.Name, kmsOptions);
237+
break;
238+
case "azure":
239+
{
240+
var azureTenantId = GetEnvironmentVariableOrDefaultOrThrowIfNothing("FLE_AZURE_TENANT_ID");
241+
var azureClientId = GetEnvironmentVariableOrDefaultOrThrowIfNothing("FLE_AZURE_CLIENT_ID");
242+
var azureClientSecret = GetEnvironmentVariableOrDefaultOrThrowIfNothing("FLE_AZURE_CLIENT_SECRET");
243+
kmsOptions.Add("tenantId", azureTenantId);
244+
kmsOptions.Add("clientId", azureClientId);
245+
kmsOptions.Add("clientSecret", azureClientSecret);
246+
}
247+
break;
248+
case "gcp":
249+
{
250+
var gcpEmail = GetEnvironmentVariableOrDefaultOrThrowIfNothing("FLE_GCP_EMAIL");
251+
var gcpPrivateKey = GetEnvironmentVariableOrDefaultOrThrowIfNothing("FLE_GCP_PRIVATE_KEY");
252+
kmsOptions.Add("email", gcpEmail);
253+
kmsOptions.Add("privateKey", gcpPrivateKey);
254+
}
236255
break;
237256
default:
238257
throw new Exception($"Unexpected kms provider type {kmsProvider.Name}.");
239258
}
259+
providers.Add(kmsProvider.Name, kmsOptions);
240260
}
241261

242262
return new ReadOnlyDictionary<string, IReadOnlyDictionary<string, object>>(providers);

0 commit comments

Comments
 (0)