Skip to content

Commit eada8a2

Browse files
authored
Add SSLContext configuration per KMS provider (#820)
JAVA-4374
1 parent 84331ac commit eada8a2

File tree

8 files changed

+233
-60
lines changed

8 files changed

+233
-60
lines changed

driver-core/src/main/com/mongodb/AutoEncryptionSettings.java

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,13 @@
2020
import com.mongodb.lang.Nullable;
2121
import org.bson.BsonDocument;
2222

23+
import javax.net.ssl.SSLContext;
2324
import java.util.Collections;
25+
import java.util.HashMap;
2426
import java.util.Map;
2527

2628
import static com.mongodb.assertions.Assertions.notNull;
29+
import static java.util.Collections.unmodifiableMap;
2730

2831
/**
2932
* The client-side automatic encryption settings. Client side encryption enables an application to specify what fields in a collection
@@ -58,6 +61,7 @@ public final class AutoEncryptionSettings {
5861
private final MongoClientSettings keyVaultMongoClientSettings;
5962
private final String keyVaultNamespace;
6063
private final Map<String, Map<String, Object>> kmsProviders;
64+
private final Map<String, SSLContext> kmsProviderSslContextMap;
6165
private final Map<String, BsonDocument> schemaMap;
6266
private final Map<String, Object> extraOptions;
6367
private final boolean bypassAutoEncryption;
@@ -71,6 +75,7 @@ public static final class Builder {
7175
private MongoClientSettings keyVaultMongoClientSettings;
7276
private String keyVaultNamespace;
7377
private Map<String, Map<String, Object>> kmsProviders;
78+
private Map<String, SSLContext> kmsProviderSslContextMap = new HashMap<>();
7479
private Map<String, BsonDocument> schemaMap = Collections.emptyMap();
7580
private Map<String, Object> extraOptions = Collections.emptyMap();
7681
private boolean bypassAutoEncryption;
@@ -111,6 +116,19 @@ public Builder kmsProviders(final Map<String, Map<String, Object>> kmsProviders)
111116
return this;
112117
}
113118

119+
/**
120+
* Sets the KMS provider to SSLContext map
121+
*
122+
* @param kmsProviderSslContextMap the KMS provider to SSLContext map, which may not be null
123+
* @return this
124+
* @see #getKmsProviderSslContextMap()
125+
* @since 4.4
126+
*/
127+
public Builder kmsProviderSslContextMap(final Map<String, SSLContext> kmsProviderSslContextMap) {
128+
this.kmsProviderSslContextMap = notNull("kmsProviderSslContextMap", kmsProviderSslContextMap);
129+
return this;
130+
}
131+
114132
/**
115133
* Sets the map from namespace to local schema document
116134
*
@@ -250,7 +268,22 @@ public String getKeyVaultNamespace() {
250268
* @return map of KMS provider properties
251269
*/
252270
public Map<String, Map<String, Object>> getKmsProviders() {
253-
return kmsProviders;
271+
return unmodifiableMap(kmsProviders);
272+
}
273+
274+
/**
275+
* Gets the KMS provider to SSLContext map.
276+
*
277+
* <p>
278+
* If a KMS provider is mapped to a non-null {@link SSLContext}, the context will be used to establish a TLS connection to the KMS.
279+
* Otherwise, the default context will be used.
280+
* </p>
281+
*
282+
* @return the KMS provider to SSLContext map
283+
* @since 4.4
284+
*/
285+
public Map<String, SSLContext> getKmsProviderSslContextMap() {
286+
return unmodifiableMap(kmsProviderSslContextMap);
254287
}
255288

256289
/**
@@ -321,6 +354,7 @@ private AutoEncryptionSettings(final Builder builder) {
321354
this.keyVaultMongoClientSettings = builder.keyVaultMongoClientSettings;
322355
this.keyVaultNamespace = notNull("keyVaultNamespace", builder.keyVaultNamespace);
323356
this.kmsProviders = notNull("kmsProviders", builder.kmsProviders);
357+
this.kmsProviderSslContextMap = notNull("kmsProviderSslContextMap", builder.kmsProviderSslContextMap);
324358
this.schemaMap = notNull("schemaMap", builder.schemaMap);
325359
this.extraOptions = notNull("extraOptions", builder.extraOptions);
326360
this.bypassAutoEncryption = builder.bypassAutoEncryption;

driver-core/src/main/com/mongodb/ClientEncryptionSettings.java

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,12 @@
1818

1919
import com.mongodb.annotations.NotThreadSafe;
2020

21+
import javax.net.ssl.SSLContext;
22+
import java.util.HashMap;
2123
import java.util.Map;
2224

2325
import static com.mongodb.assertions.Assertions.notNull;
26+
import static java.util.Collections.unmodifiableMap;
2427

2528
/**
2629
* The client-side settings for data key creation and explicit encryption.
@@ -36,7 +39,7 @@ public final class ClientEncryptionSettings {
3639
private final MongoClientSettings keyVaultMongoClientSettings;
3740
private final String keyVaultNamespace;
3841
private final Map<String, Map<String, Object>> kmsProviders;
39-
42+
private final Map<String, SSLContext> kmsProviderSslContextMap;
4043
/**
4144
* A builder for {@code ClientEncryptionSettings} so that {@code ClientEncryptionSettings} can be immutable, and to support easier
4245
* construction through chaining.
@@ -46,6 +49,7 @@ public static final class Builder {
4649
private MongoClientSettings keyVaultMongoClientSettings;
4750
private String keyVaultNamespace;
4851
private Map<String, Map<String, Object>> kmsProviders;
52+
private Map<String, SSLContext> kmsProviderSslContextMap = new HashMap<>();
4953

5054
/**
5155
* Sets the key vault settings.
@@ -83,6 +87,19 @@ public Builder kmsProviders(final Map<String, Map<String, Object>> kmsProviders)
8387
return this;
8488
}
8589

90+
/**
91+
* Sets the KMS provider to SSLContext map
92+
*
93+
* @param kmsProviderSslContextMap the KMS provider to SSLContext map, which may not be null
94+
* @return this
95+
* @see #getKmsProviderSslContextMap()
96+
* @since 4.4
97+
*/
98+
public Builder kmsProviderSslContextMap(final Map<String, SSLContext> kmsProviderSslContextMap) {
99+
this.kmsProviderSslContextMap = notNull("kmsProviderSslContextMap", kmsProviderSslContextMap);
100+
return this;
101+
}
102+
86103
/**
87104
* Build an instance of {@code ClientEncryptionSettings}.
88105
*
@@ -184,13 +201,29 @@ public String getKeyVaultNamespace() {
184201
* @return map of KMS provider properties
185202
*/
186203
public Map<String, Map<String, Object>> getKmsProviders() {
187-
return kmsProviders;
204+
return unmodifiableMap(kmsProviders);
205+
}
206+
207+
/**
208+
* Gets the KMS provider to SSLContext map.
209+
*
210+
* <p>
211+
* If a KMS provider is mapped to a non-null {@link SSLContext}, the context will be used to establish a TLS connection to the KMS.
212+
* Otherwise, the default context will be used.
213+
* </p>
214+
*
215+
* @return the KMS provider to SSLContext map
216+
* @since 4.4
217+
*/
218+
public Map<String, SSLContext> getKmsProviderSslContextMap() {
219+
return unmodifiableMap(kmsProviderSslContextMap);
188220
}
189221

190222
private ClientEncryptionSettings(final Builder builder) {
191223
this.keyVaultMongoClientSettings = builder.keyVaultMongoClientSettings;
192224
this.keyVaultNamespace = notNull("keyVaultNamespace", builder.keyVaultNamespace);
193225
this.kmsProviders = notNull("kmsProviders", builder.kmsProviders);
226+
this.kmsProviderSslContextMap = notNull("kmsProviderSslContextMap", builder.kmsProviderSslContextMap);
194227
}
195228

196229
}

driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/crypt/Crypts.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
import javax.net.ssl.SSLContext;
3030
import java.security.NoSuchAlgorithmException;
31+
import java.util.Map;
3132

3233
import static com.mongodb.internal.capi.MongoCryptHelper.createMongoCryptOptions;
3334

@@ -54,7 +55,7 @@ public static Crypt createCrypt(final MongoClientImpl client, final AutoEncrypti
5455
options.isBypassAutoEncryption() ? null : new CollectionInfoRetriever(collectionInfoRetrieverClient),
5556
new CommandMarker(options.isBypassAutoEncryption(), options.getExtraOptions()),
5657
new KeyRetriever(keyVaultClient, new MongoNamespace(options.getKeyVaultNamespace())),
57-
createKeyManagementService(),
58+
createKeyManagementService(options.getKmsProviderSslContextMap()),
5859
options.isBypassAutoEncryption(),
5960
internalClient);
6061
}
@@ -63,11 +64,11 @@ public static Crypt create(final MongoClient keyVaultClient, final ClientEncrypt
6364
return new Crypt(MongoCrypts.create(
6465
createMongoCryptOptions(options.getKmsProviders(), null)),
6566
new KeyRetriever(keyVaultClient, new MongoNamespace(options.getKeyVaultNamespace())),
66-
createKeyManagementService());
67+
createKeyManagementService(options.getKmsProviderSslContextMap()));
6768
}
6869

69-
private static KeyManagementService createKeyManagementService() {
70-
return new KeyManagementService(getSslContext(), 443, 10000);
70+
private static KeyManagementService createKeyManagementService(final Map<String, SSLContext> kmsProviderSslContextMap) {
71+
return new KeyManagementService(kmsProviderSslContextMap, 10000);
7172
}
7273

7374
private static SSLContext getSslContext() {

driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/crypt/KeyManagementService.java

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
import com.mongodb.connection.StreamFactory;
2626
import com.mongodb.connection.TlsChannelStreamFactoryFactory;
2727
import com.mongodb.crypt.capi.MongoKeyDecryptor;
28+
import com.mongodb.diagnostics.logging.Logger;
29+
import com.mongodb.diagnostics.logging.Loggers;
2830
import com.mongodb.internal.connection.AsynchronousChannelStream;
2931
import org.bson.ByteBuf;
3032
import org.bson.ByteBufNIO;
@@ -35,36 +37,41 @@
3537
import java.io.Closeable;
3638
import java.nio.channels.CompletionHandler;
3739
import java.util.List;
38-
import java.util.concurrent.TimeUnit;
40+
import java.util.Map;
3941

4042
import static java.util.Collections.singletonList;
4143
import static java.util.concurrent.TimeUnit.MILLISECONDS;
4244

4345
class KeyManagementService implements Closeable {
44-
private final int defaultPort;
46+
private static final Logger LOGGER = Loggers.getLogger("client");
47+
private final Map<String, SSLContext> kmsProviderSslContextMap;
48+
private final int timeoutMillis;
4549
private final TlsChannelStreamFactoryFactory tlsChannelStreamFactoryFactory;
46-
private final StreamFactory streamFactory;
4750

48-
KeyManagementService(final SSLContext sslContext, final int defaultPort, final int timeoutMillis) {
49-
this.defaultPort = defaultPort;
51+
KeyManagementService(final Map<String, SSLContext> kmsProviderSslContextMap, final int timeoutMillis) {
52+
this.kmsProviderSslContextMap = kmsProviderSslContextMap;
5053
this.tlsChannelStreamFactoryFactory = new TlsChannelStreamFactoryFactory();
51-
this.streamFactory = tlsChannelStreamFactoryFactory.create(SocketSettings.builder()
52-
.connectTimeout(timeoutMillis, TimeUnit.MILLISECONDS)
53-
.readTimeout(timeoutMillis, TimeUnit.MILLISECONDS)
54-
.build(),
55-
SslSettings.builder().enabled(true).context(sslContext).build());
54+
this.timeoutMillis = timeoutMillis;
5655
}
5756

5857
public void close() {
5958
tlsChannelStreamFactoryFactory.close();
6059
}
6160

6261
Mono<Void> decryptKey(final MongoKeyDecryptor keyDecryptor) {
62+
SocketSettings socketSettings = SocketSettings.builder()
63+
.connectTimeout(timeoutMillis, MILLISECONDS)
64+
.readTimeout(timeoutMillis, MILLISECONDS)
65+
.build();
66+
StreamFactory streamFactory = tlsChannelStreamFactoryFactory.create(socketSettings,
67+
SslSettings.builder().enabled(true).context(kmsProviderSslContextMap.get(keyDecryptor.getKmsProvider())).build());
68+
69+
ServerAddress serverAddress = new ServerAddress(keyDecryptor.getHostName());
70+
71+
LOGGER.info("Connecting to KMS server at " + serverAddress);
72+
6373
return Mono.<Void>create(sink -> {
64-
ServerAddress serverAddress = keyDecryptor.getHostName().contains(":")
65-
? new ServerAddress(keyDecryptor.getHostName())
66-
: new ServerAddress(keyDecryptor.getHostName(), defaultPort);
67-
final Stream stream = streamFactory.create(serverAddress);
74+
Stream stream = streamFactory.create(serverAddress);
6875
stream.openAsync(new AsyncCompletionHandler<Void>() {
6976
@Override
7077
public void completed(final Void ignored) {

driver-sync/src/main/com/mongodb/client/internal/Crypt.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,8 @@ private void decryptKeys(final MongoCryptContext cryptContext) {
302302
}
303303

304304
private void decryptKey(final MongoKeyDecryptor keyDecryptor) throws IOException {
305-
InputStream inputStream = keyManagementService.stream(keyDecryptor.getHostName(), keyDecryptor.getMessage());
305+
InputStream inputStream = keyManagementService.stream(keyDecryptor.getKmsProvider(), keyDecryptor.getHostName(),
306+
keyDecryptor.getMessage());
306307
try {
307308
int bytesNeeded = keyDecryptor.bytesNeeded();
308309

driver-sync/src/main/com/mongodb/client/internal/Crypts.java

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,14 @@
1818

1919
import com.mongodb.AutoEncryptionSettings;
2020
import com.mongodb.ClientEncryptionSettings;
21-
import com.mongodb.MongoClientException;
2221
import com.mongodb.MongoClientSettings;
2322
import com.mongodb.MongoNamespace;
2423
import com.mongodb.client.MongoClient;
2524
import com.mongodb.client.MongoClients;
2625
import com.mongodb.crypt.capi.MongoCrypts;
2726

2827
import javax.net.ssl.SSLContext;
29-
import java.security.NoSuchAlgorithmException;
28+
import java.util.Map;
3029

3130
import static com.mongodb.internal.capi.MongoCryptHelper.createMongoCryptOptions;
3231

@@ -50,7 +49,7 @@ public static Crypt createCrypt(final MongoClientImpl client, final AutoEncrypti
5049
options.isBypassAutoEncryption() ? null : new CollectionInfoRetriever(collectionInfoRetrieverClient),
5150
new CommandMarker(options.isBypassAutoEncryption(), options.getExtraOptions()),
5251
new KeyRetriever(keyVaultClient, new MongoNamespace(options.getKeyVaultNamespace())),
53-
createKeyManagementService(),
52+
createKeyManagementService(options.getKmsProviderSslContextMap()),
5453
options.isBypassAutoEncryption(),
5554
internalClient);
5655
}
@@ -59,26 +58,16 @@ static Crypt create(final MongoClient keyVaultClient, final ClientEncryptionSett
5958
return new Crypt(MongoCrypts.create(
6059
createMongoCryptOptions(options.getKmsProviders(), null)),
6160
createKeyRetriever(keyVaultClient, options.getKeyVaultNamespace()),
62-
createKeyManagementService());
61+
createKeyManagementService(options.getKmsProviderSslContextMap()));
6362
}
6463

6564
private static KeyRetriever createKeyRetriever(final MongoClient keyVaultClient,
6665
final String keyVaultNamespaceString) {
6766
return new KeyRetriever(keyVaultClient, new MongoNamespace(keyVaultNamespaceString));
6867
}
6968

70-
private static KeyManagementService createKeyManagementService() {
71-
return new KeyManagementService(getSslContext(), 443, 10000);
72-
}
73-
74-
private static SSLContext getSslContext() {
75-
SSLContext sslContext;
76-
try {
77-
sslContext = SSLContext.getDefault();
78-
} catch (NoSuchAlgorithmException e) {
79-
throw new MongoClientException("Unable to create default SSLContext", e);
80-
}
81-
return sslContext;
69+
private static KeyManagementService createKeyManagementService(final Map<String, SSLContext> kmsProviderSslContextMap) {
70+
return new KeyManagementService(kmsProviderSslContextMap, 10000);
8271
}
8372

8473
private Crypts() {

driver-sync/src/main/com/mongodb/client/internal/KeyManagementService.java

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,36 +17,48 @@
1717
package com.mongodb.client.internal;
1818

1919
import com.mongodb.ServerAddress;
20+
import com.mongodb.diagnostics.logging.Logger;
21+
import com.mongodb.diagnostics.logging.Loggers;
2022
import com.mongodb.internal.connection.SslHelper;
2123

24+
import javax.net.SocketFactory;
2225
import javax.net.ssl.SSLContext;
2326
import javax.net.ssl.SSLParameters;
2427
import javax.net.ssl.SSLSocket;
28+
import javax.net.ssl.SSLSocketFactory;
2529
import java.io.IOException;
2630
import java.io.InputStream;
2731
import java.io.OutputStream;
2832
import java.net.InetAddress;
2933
import java.net.InetSocketAddress;
3034
import java.net.Socket;
3135
import java.nio.ByteBuffer;
36+
import java.util.Map;
37+
38+
import static com.mongodb.assertions.Assertions.notNull;
3239

3340
class KeyManagementService {
34-
private final SSLContext sslContext;
35-
private final int defaultPort;
41+
private static final Logger LOGGER = Loggers.getLogger("client");
42+
private final Map<String, SSLContext> kmsProviderSslContextMap;
3643
private final int timeoutMillis;
3744

38-
KeyManagementService(final SSLContext sslContext, final int defaultPort, final int timeoutMillis) {
39-
this.sslContext = sslContext;
40-
this.defaultPort = defaultPort;
45+
KeyManagementService(final Map<String, SSLContext> kmsProviderSslContextMap, final int timeoutMillis) {
46+
this.kmsProviderSslContextMap = notNull("kmsProviderSslContextMap", kmsProviderSslContextMap);
4147
this.timeoutMillis = timeoutMillis;
4248
}
4349

44-
public InputStream stream(final String host, final ByteBuffer message) throws IOException {
45-
ServerAddress serverAddress = host.contains(":") ? new ServerAddress(host) : new ServerAddress(host, defaultPort);
46-
SSLSocket socket = (SSLSocket) sslContext.getSocketFactory().createSocket();
50+
public InputStream stream(final String kmsProvider, final String host, final ByteBuffer message) throws IOException {
51+
ServerAddress serverAddress = new ServerAddress(host);
52+
53+
LOGGER.info("Connecting to KMS server at " + serverAddress);
54+
SSLContext sslContext = kmsProviderSslContextMap.get(kmsProvider);
55+
56+
SocketFactory sslSocketFactory = sslContext == null
57+
? SSLSocketFactory.getDefault() : sslContext.getSocketFactory();
58+
SSLSocket socket = (SSLSocket) sslSocketFactory.createSocket();
59+
enableHostNameVerification(socket);
4760

4861
try {
49-
enableHostNameVerification(socket);
5062
socket.setSoTimeout(timeoutMillis);
5163
socket.connect(new InetSocketAddress(InetAddress.getByName(serverAddress.getHost()), serverAddress.getPort()), timeoutMillis);
5264
} catch (IOException e) {
@@ -83,10 +95,6 @@ private void enableHostNameVerification(final SSLSocket socket) {
8395
socket.setSSLParameters(sslParameters);
8496
}
8597

86-
public int getDefaultPort() {
87-
return defaultPort;
88-
}
89-
9098
private void closeSocket(final Socket socket) {
9199
try {
92100
socket.close();

0 commit comments

Comments
 (0)