Skip to content

Commit 9e07e4c

Browse files
committed
Implement client side encryption custom endpoint test
JAVA-3464 KeyManagementService
1 parent 2c9a064 commit 9e07e4c

File tree

6 files changed

+366
-62
lines changed

6 files changed

+366
-62
lines changed

driver-async/src/main/com/mongodb/async/client/internal/Crypt.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ private void decryptKeys(final MongoCryptContext cryptContext, final String data
342342
@Override
343343
public void onResult(final Void result, final Throwable t) {
344344
if (t != null) {
345-
callback.onResult(null, t);
345+
callback.onResult(null, wrapInClientException(t));
346346
} else {
347347
decryptKeys(cryptContext, databaseName, callback);
348348
}

driver-async/src/main/com/mongodb/async/client/internal/KeyManagementService.java

Lines changed: 19 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,6 @@
1717
package com.mongodb.async.client.internal;
1818

1919
import com.mongodb.MongoSocketException;
20-
import com.mongodb.MongoSocketOpenException;
21-
import com.mongodb.MongoSocketReadException;
22-
import com.mongodb.MongoSocketReadTimeoutException;
23-
import com.mongodb.MongoSocketWriteException;
2420
import com.mongodb.ServerAddress;
2521
import com.mongodb.async.SingleResultCallback;
2622
import com.mongodb.connection.AsyncCompletionHandler;
@@ -36,19 +32,18 @@
3632

3733
import javax.net.ssl.SSLContext;
3834
import java.nio.channels.CompletionHandler;
39-
import java.nio.channels.InterruptedByTimeoutException;
4035
import java.util.Collections;
4136
import java.util.List;
4237
import java.util.concurrent.TimeUnit;
4338

4439
import static java.util.concurrent.TimeUnit.MILLISECONDS;
4540

4641
class KeyManagementService {
47-
private final int port;
42+
private final int defaultPort;
4843
private final StreamFactory streamFactory;
4944

50-
KeyManagementService(final SSLContext sslContext, final int port, final int timeoutMillis) {
51-
this.port = port;
45+
KeyManagementService(final SSLContext sslContext, final int defaultPort, final int timeoutMillis) {
46+
this.defaultPort = defaultPort;
5247
this.streamFactory = new TlsChannelStreamFactoryFactory().create(SocketSettings.builder()
5348
.connectTimeout(timeoutMillis, TimeUnit.MILLISECONDS)
5449
.readTimeout(timeoutMillis, TimeUnit.MILLISECONDS)
@@ -61,7 +56,10 @@ void decryptKey(final MongoKeyDecryptor keyDecryptor, final SingleResultCallback
6156
}
6257

6358
private void streamOpen(final MongoKeyDecryptor keyDecryptor, final SingleResultCallback<Void> callback) {
64-
final Stream stream = streamFactory.create(new ServerAddress(keyDecryptor.getHostName(), port));
59+
ServerAddress serverAddress = keyDecryptor.getHostName().contains(":")
60+
? new ServerAddress(keyDecryptor.getHostName())
61+
: new ServerAddress(keyDecryptor.getHostName(), defaultPort);
62+
final Stream stream = streamFactory.create(serverAddress);
6563
stream.openAsync(new AsyncCompletionHandler<Void>() {
6664
@Override
6765
public void completed(final Void aVoid) {
@@ -71,8 +69,7 @@ public void completed(final Void aVoid) {
7169
@Override
7270
public void failed(final Throwable t) {
7371
stream.close();
74-
callback.onResult(null, new MongoSocketOpenException("Exception opening connection to Key Management Service",
75-
getServerAddress(keyDecryptor), t));
72+
callback.onResult(null, wrapException(t));
7673
}
7774
});
7875
}
@@ -88,8 +85,7 @@ public void completed(final Void aVoid) {
8885
@Override
8986
public void failed(final Throwable t) {
9087
stream.close();
91-
callback.onResult(null, new MongoSocketWriteException("Exception sending message to Key Management Service",
92-
getServerAddress(keyDecryptor), t));
88+
callback.onResult(null, wrapException(t));
9389
}
9490
});
9591
}
@@ -105,24 +101,20 @@ private void streamRead(final Stream stream, final MongoKeyDecryptor keyDecrypto
105101
@Override
106102
public void completed(final Integer integer, final Void aVoid) {
107103
buffer.flip();
108-
keyDecryptor.feed(buffer.asNIO());
109-
buffer.release();
110-
streamRead(stream, keyDecryptor, callback);
104+
try {
105+
keyDecryptor.feed(buffer.asNIO());
106+
buffer.release();
107+
streamRead(stream, keyDecryptor, callback);
108+
} catch (Throwable t) {
109+
callback.onResult(null, t);
110+
}
111111
}
112112

113113
@Override
114114
public void failed(final Throwable t, final Void aVoid) {
115115
buffer.release();
116116
stream.close();
117-
MongoSocketException exception;
118-
if (t instanceof InterruptedByTimeoutException) {
119-
exception = new MongoSocketReadTimeoutException("Timeout while receiving message from Key Management "
120-
+ "Service", getServerAddress(keyDecryptor), t);
121-
} else {
122-
exception = new MongoSocketReadException("Exception receiving message from Key Management Service",
123-
getServerAddress(keyDecryptor), t);
124-
}
125-
callback.onResult(null, exception);
117+
callback.onResult(null, wrapException(t));
126118
}
127119
});
128120
} else {
@@ -131,7 +123,7 @@ public void failed(final Throwable t, final Void aVoid) {
131123
}
132124
}
133125

134-
private ServerAddress getServerAddress(final MongoKeyDecryptor keyDecryptor) {
135-
return new ServerAddress(keyDecryptor.getHostName(), port);
126+
private Throwable wrapException(final Throwable t) {
127+
return t instanceof MongoSocketException ? t.getCause() : t;
136128
}
137129
}
Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
/*
2+
* Copyright 2008-present MongoDB, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.mongodb.async.client;
18+
19+
import com.mongodb.ClientEncryptionSettings;
20+
import com.mongodb.MongoClientException;
21+
import com.mongodb.async.FutureResultCallback;
22+
import com.mongodb.async.client.vault.ClientEncryption;
23+
import com.mongodb.async.client.vault.ClientEncryptions;
24+
import com.mongodb.client.model.vault.DataKeyOptions;
25+
import com.mongodb.client.model.vault.EncryptOptions;
26+
import com.mongodb.crypt.capi.MongoCryptException;
27+
import com.mongodb.lang.Nullable;
28+
import org.bson.BsonBinary;
29+
import org.bson.BsonDocument;
30+
import org.bson.BsonString;
31+
import org.junit.After;
32+
import org.junit.Before;
33+
import org.junit.Test;
34+
import org.junit.runner.RunWith;
35+
import org.junit.runners.Parameterized;
36+
37+
import java.net.ConnectException;
38+
import java.util.ArrayList;
39+
import java.util.Collection;
40+
import java.util.HashMap;
41+
import java.util.List;
42+
import java.util.Map;
43+
import java.util.concurrent.TimeUnit;
44+
45+
import static com.mongodb.ClusterFixture.TIMEOUT;
46+
import static com.mongodb.ClusterFixture.isNotAtLeastJava8;
47+
import static com.mongodb.ClusterFixture.serverVersionAtLeast;
48+
import static org.junit.Assert.assertEquals;
49+
import static org.junit.Assert.assertNull;
50+
import static org.junit.Assert.assertTrue;
51+
import static org.junit.Assume.assumeFalse;
52+
import static org.junit.Assume.assumeTrue;
53+
54+
@RunWith(Parameterized.class)
55+
public class ClientEncryptionCustomEndpointTest {
56+
57+
private ClientEncryption clientEncryption;
58+
private BsonDocument masterKey;
59+
private final Class<? extends RuntimeException> exceptionClass;
60+
private final Class<? extends RuntimeException> wrappedExceptionClass;
61+
private final String messageContainedInException;
62+
63+
public ClientEncryptionCustomEndpointTest(@SuppressWarnings("unused") final String name,
64+
final BsonDocument masterKey,
65+
@Nullable final Class<? extends RuntimeException> exceptionClass,
66+
@Nullable final Class<? extends RuntimeException> wrappedExceptionClass,
67+
@Nullable final String messageContainedInException) {
68+
this.masterKey = masterKey;
69+
this.exceptionClass = exceptionClass;
70+
this.wrappedExceptionClass = wrappedExceptionClass;
71+
this.messageContainedInException = messageContainedInException;
72+
}
73+
74+
@Before
75+
public void setUp() {
76+
assumeFalse(isNotAtLeastJava8());
77+
assumeTrue(serverVersionAtLeast(4, 1));
78+
assumeTrue("Encryption test with external keyVault is disabled",
79+
System.getProperty("org.mongodb.test.awsAccessKeyId") != null
80+
&& !System.getProperty("org.mongodb.test.awsAccessKeyId").isEmpty());
81+
82+
Map<String, Map<String, Object>> kmsProviders = new HashMap<String, Map<String, Object>>();
83+
Map<String, Object> awsCreds = new HashMap<String, Object>();
84+
awsCreds.put("accessKeyId", System.getProperty("org.mongodb.test.awsAccessKeyId"));
85+
awsCreds.put("secretAccessKey", System.getProperty("org.mongodb.test.awsSecretAccessKey"));
86+
kmsProviders.put("aws", awsCreds);
87+
88+
ClientEncryptionSettings.Builder clientEncryptionSettingsBuilder = ClientEncryptionSettings.builder().
89+
keyVaultMongoClientSettings(Fixture.getMongoClientSettings())
90+
.kmsProviders(kmsProviders)
91+
.keyVaultNamespace("admin.datakeys");
92+
93+
ClientEncryptionSettings clientEncryptionSettings = clientEncryptionSettingsBuilder.build();
94+
clientEncryption = ClientEncryptions.create(clientEncryptionSettings);
95+
}
96+
97+
@After
98+
public void after() {
99+
if (clientEncryption != null) {
100+
try {
101+
clientEncryption.close();
102+
} catch (Exception e) {
103+
// ignore
104+
}
105+
}
106+
}
107+
108+
@Test
109+
public void testEndpoint() throws Exception {
110+
try {
111+
FutureResultCallback<BsonBinary> dataKeyCreationCallback = new FutureResultCallback<BsonBinary>();
112+
clientEncryption.createDataKey("aws", new DataKeyOptions()
113+
.masterKey(masterKey), dataKeyCreationCallback);
114+
115+
BsonBinary dataKeyId = dataKeyCreationCallback.get(TIMEOUT, TimeUnit.SECONDS);
116+
117+
assertNull("Expected exception, but encryption succeeded", exceptionClass);
118+
119+
FutureResultCallback<BsonBinary> encryptCallback = new FutureResultCallback<BsonBinary>();
120+
clientEncryption.encrypt(new BsonString("test"), new EncryptOptions("AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic")
121+
.keyId(dataKeyId), encryptCallback);
122+
encryptCallback.get();
123+
} catch (Exception e) {
124+
if (exceptionClass == null) {
125+
throw e;
126+
}
127+
try {
128+
assertEquals(exceptionClass, e.getClass());
129+
assertEquals(wrappedExceptionClass, e.getCause().getClass());
130+
} catch (AssertionError ae) {
131+
throw e;
132+
}
133+
if (messageContainedInException != null) {
134+
assertTrue(e.getCause().getMessage().contains(messageContainedInException));
135+
}
136+
}
137+
}
138+
139+
@Parameterized.Parameters(name = "{0}")
140+
public static Collection<Object[]> data() {
141+
List<Object[]> data = new ArrayList<Object[]>();
142+
143+
data.add(new Object[]{"default endpoint",
144+
getDefaultMasterKey(),
145+
null, null, null});
146+
data.add(new Object[]{"valid endpoint",
147+
getDefaultMasterKey().append("endpoint", new BsonString("kms.us-east-1.amazonaws.com")),
148+
null, null, null});
149+
data.add(new Object[]{"valid endpoint port",
150+
getDefaultMasterKey().append("endpoint", new BsonString("kms.us-east-1.amazonaws.com:443")),
151+
null, null, null});
152+
data.add(new Object[]{"invalid endpoint port",
153+
getDefaultMasterKey().append("endpoint", new BsonString("kms.us-east-1.amazonaws.com:12345")),
154+
MongoClientException.class, ConnectException.class, "Connection refused"});
155+
data.add(new Object[]{"invalid amazon region in endpoint",
156+
getDefaultMasterKey().append("endpoint", new BsonString("kms.us-east-2.amazonaws.com")),
157+
MongoClientException.class, MongoCryptException.class, "us-east-1"});
158+
data.add(new Object[]{"invalid endpoint host",
159+
getDefaultMasterKey().append("endpoint", new BsonString("example.com")),
160+
MongoClientException.class, MongoCryptException.class, "parse error"});
161+
162+
return data;
163+
}
164+
165+
private static BsonDocument getDefaultMasterKey() {
166+
return new BsonDocument()
167+
.append("region", new BsonString("us-east-1"))
168+
.append("key", new BsonString("arn:aws:kms:us-east-1:579766882180:key/89fcc2c4-08b0-4bd9-9f25-e30687b580d0"));
169+
}
170+
}

driver-sync/src/main/com/mongodb/client/internal/Crypt.java

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@
1919
import com.mongodb.MongoClientException;
2020
import com.mongodb.MongoException;
2121
import com.mongodb.MongoInternalException;
22-
import com.mongodb.MongoSocketReadException;
23-
import com.mongodb.ServerAddress;
2422
import com.mongodb.client.model.vault.DataKeyOptions;
2523
import com.mongodb.client.model.vault.EncryptOptions;
2624
import com.mongodb.crypt.capi.MongoCrypt;
@@ -269,7 +267,7 @@ private void mark(final MongoCryptContext cryptContext, final String databaseNam
269267
cryptContext.addMongoOperationResult(markedCommand);
270268
cryptContext.completeMongoOperation();
271269
} catch (Throwable t) {
272-
throw MongoException.fromThrowableNonNull(t);
270+
throw wrapInClientException(t);
273271
}
274272
}
275273

@@ -293,11 +291,11 @@ private void decryptKeys(final MongoCryptContext cryptContext) {
293291
}
294292
cryptContext.completeKeyDecryptors();
295293
} catch (Throwable t) {
296-
throw MongoException.fromThrowableNonNull(t);
294+
throw wrapInClientException(t);
297295
}
298296
}
299297

300-
private void decryptKey(final MongoKeyDecryptor keyDecryptor) {
298+
private void decryptKey(final MongoKeyDecryptor keyDecryptor) throws IOException {
301299
InputStream inputStream = keyManagementService.stream(keyDecryptor.getHostName(), keyDecryptor.getMessage());
302300
try {
303301
int bytesNeeded = keyDecryptor.bytesNeeded();
@@ -308,10 +306,6 @@ private void decryptKey(final MongoKeyDecryptor keyDecryptor) {
308306
keyDecryptor.feed(ByteBuffer.wrap(bytes, 0, bytesRead));
309307
bytesNeeded = keyDecryptor.bytesNeeded();
310308
}
311-
} catch (IOException e) {
312-
throw new MongoSocketReadException("Exception receiving message from key management service",
313-
new ServerAddress(keyDecryptor.getHostName(), keyManagementService.getPort()), e);
314-
// type
315309
} finally {
316310
try {
317311
inputStream.close();
@@ -321,7 +315,7 @@ private void decryptKey(final MongoKeyDecryptor keyDecryptor) {
321315
}
322316
}
323317

324-
private MongoClientException wrapInClientException(final MongoCryptException e) {
325-
return new MongoClientException("Exception in encryption library: " + e.getMessage(), e);
318+
private MongoClientException wrapInClientException(final Throwable t) {
319+
return new MongoClientException("Exception in encryption library: " + t.getMessage(), t);
326320
}
327321
}

0 commit comments

Comments
 (0)