Skip to content

Commit 717ac88

Browse files
committed
Apply client-side encryption in transactions on sharded clusters
This fixes a bug in both sync and async drivers where client-side encryption is not applied when in a transaction. JAVA-3752
1 parent b263d52 commit 717ac88

File tree

10 files changed

+440
-77
lines changed

10 files changed

+440
-77
lines changed

driver-async/src/main/com/mongodb/async/client/ClientSessionBinding.java

Lines changed: 41 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import com.mongodb.async.SingleResultCallback;
2222
import com.mongodb.binding.AsyncConnectionSource;
2323
import com.mongodb.binding.AsyncReadWriteBinding;
24-
import com.mongodb.binding.AsyncSingleServerBinding;
2524
import com.mongodb.connection.AsyncConnection;
2625
import com.mongodb.connection.ClusterType;
2726
import com.mongodb.connection.Server;
@@ -53,77 +52,45 @@ public ReadPreference getReadPreference() {
5352

5453
@Override
5554
public void getReadConnectionSource(final SingleResultCallback<AsyncConnectionSource> callback) {
56-
wrapped.getReadConnectionSource(new SingleResultCallback<AsyncConnectionSource>() {
57-
@Override
58-
public void onResult(final AsyncConnectionSource result, final Throwable t) {
59-
if (t != null) {
60-
callback.onResult(null, t);
61-
} else {
62-
wrapConnectionSource(result, callback);
63-
}
64-
}
65-
});
55+
if (isActiveShardedTxn()) {
56+
getPinnedConnectionSource(callback);
57+
} else {
58+
wrapped.getReadConnectionSource(new WrappingCallback(callback));
59+
}
6660
}
6761

6862
public void getWriteConnectionSource(final SingleResultCallback<AsyncConnectionSource> callback) {
69-
wrapped.getWriteConnectionSource(new SingleResultCallback<AsyncConnectionSource>() {
70-
@Override
71-
public void onResult(final AsyncConnectionSource result, final Throwable t) {
72-
if (t != null) {
73-
callback.onResult(null, t);
74-
} else {
75-
wrapConnectionSource(result, callback);
76-
}
77-
}
78-
});
63+
if (isActiveShardedTxn()) {
64+
getPinnedConnectionSource(callback);
65+
} else {
66+
wrapped.getWriteConnectionSource(new WrappingCallback(callback));
67+
}
7968
}
8069

8170
@Override
8271
public SessionContext getSessionContext() {
8372
return sessionContext;
8473
}
8574

86-
private void wrapConnectionSource(final AsyncConnectionSource connectionSource,
87-
final SingleResultCallback<AsyncConnectionSource> callback) {
88-
if (isActiveShardedTxn()) {
89-
if (session.getPinnedServerAddress() == null) {
90-
wrapped.getCluster().selectServerAsync(
91-
new ReadPreferenceServerSelector(wrapped.getReadPreference()),
92-
new SingleResultCallback<Server>() {
93-
@Override
94-
public void onResult(final Server server, final Throwable t) {
95-
if (t != null) {
96-
callback.onResult(null, t);
97-
} else {
98-
session.setPinnedServerAddress(server.getDescription().getAddress());
99-
setSingleServerBindingConnectionSource(callback);
100-
}
75+
private void getPinnedConnectionSource(final SingleResultCallback<AsyncConnectionSource> callback) {
76+
if (session.getPinnedServerAddress() == null) {
77+
wrapped.getCluster().selectServerAsync(
78+
new ReadPreferenceServerSelector(wrapped.getReadPreference()), new SingleResultCallback<Server>() {
79+
@Override
80+
public void onResult(final Server server, final Throwable t) {
81+
if (t != null) {
82+
callback.onResult(null, t);
83+
} else {
84+
session.setPinnedServerAddress(server.getDescription().getAddress());
85+
wrapped.getConnectionSource(session.getPinnedServerAddress(), new WrappingCallback(callback));
10186
}
102-
});
103-
} else {
104-
setSingleServerBindingConnectionSource(callback);
105-
}
87+
}
88+
});
10689
} else {
107-
callback.onResult(new SessionBindingAsyncConnectionSource(connectionSource), null);
90+
wrapped.getConnectionSource(session.getPinnedServerAddress(), new WrappingCallback(callback));
10891
}
10992
}
11093

111-
private void setSingleServerBindingConnectionSource(final SingleResultCallback<AsyncConnectionSource> callback) {
112-
final AsyncSingleServerBinding binding =
113-
new AsyncSingleServerBinding(wrapped.getCluster(), session.getPinnedServerAddress(), wrapped.getReadPreference());
114-
binding.getWriteConnectionSource(new SingleResultCallback<AsyncConnectionSource>() {
115-
@Override
116-
public void onResult(final AsyncConnectionSource result, final Throwable t) {
117-
binding.release();
118-
if (t != null) {
119-
callback.onResult(null, t);
120-
} else {
121-
callback.onResult(new SessionBindingAsyncConnectionSource(result), null);
122-
}
123-
}
124-
});
125-
}
126-
12794
@Override
12895
public int getCount() {
12996
return wrapped.getCount();
@@ -225,4 +192,21 @@ public ReadConcern getReadConcern() {
225192
}
226193
}
227194
}
195+
196+
private class WrappingCallback implements SingleResultCallback<AsyncConnectionSource> {
197+
private final SingleResultCallback<AsyncConnectionSource> callback;
198+
199+
WrappingCallback(final SingleResultCallback<AsyncConnectionSource> callback) {
200+
this.callback = callback;
201+
}
202+
203+
@Override
204+
public void onResult(final AsyncConnectionSource result, final Throwable t) {
205+
if (t != null) {
206+
callback.onResult(null, t);
207+
} else {
208+
callback.onResult(new SessionBindingAsyncConnectionSource(result), null);
209+
}
210+
}
211+
}
228212
}

driver-async/src/main/com/mongodb/async/client/internal/AsyncCryptBinding.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package com.mongodb.async.client.internal;
1818

1919
import com.mongodb.ReadPreference;
20+
import com.mongodb.ServerAddress;
2021
import com.mongodb.async.SingleResultCallback;
2122
import com.mongodb.binding.AsyncConnectionSource;
2223
import com.mongodb.binding.AsyncReadWriteBinding;
@@ -74,6 +75,20 @@ public void onResult(final AsyncConnectionSource result, final Throwable t) {
7475
});
7576
}
7677

78+
@Override
79+
public void getConnectionSource(final ServerAddress serverAddress, final SingleResultCallback<AsyncConnectionSource> callback) {
80+
wrapped.getConnectionSource(serverAddress, new SingleResultCallback<AsyncConnectionSource>() {
81+
@Override
82+
public void onResult(final AsyncConnectionSource result, final Throwable t) {
83+
if (t != null) {
84+
callback.onResult(null, t);
85+
} else {
86+
callback.onResult(new AsyncCryptConnectionSource(result), null);
87+
}
88+
}
89+
});
90+
}
91+
7792
@Override
7893
public int getCount() {
7994
return wrapped.getCount();
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
/*
2+
* Copyright 2008-present MongoDB, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.mongodb.async.client;
18+
19+
import com.mongodb.AutoEncryptionSettings;
20+
import com.mongodb.MongoClientSettings;
21+
import com.mongodb.MongoNamespace;
22+
import com.mongodb.WriteConcern;
23+
import com.mongodb.async.FutureResultCallback;
24+
import com.mongodb.client.test.CollectionHelper;
25+
import org.bson.BsonDocument;
26+
import org.bson.BsonString;
27+
import org.bson.codecs.BsonDocumentCodec;
28+
import org.junit.After;
29+
import org.junit.Before;
30+
import org.junit.Test;
31+
import org.junit.runner.RunWith;
32+
import org.junit.runners.Parameterized;
33+
34+
import java.io.File;
35+
import java.io.IOException;
36+
import java.net.URISyntaxException;
37+
import java.util.Arrays;
38+
import java.util.Base64;
39+
import java.util.Collection;
40+
import java.util.HashMap;
41+
import java.util.Map;
42+
43+
import static com.mongodb.ClusterFixture.isNotAtLeastJava8;
44+
import static com.mongodb.ClusterFixture.isStandalone;
45+
import static com.mongodb.ClusterFixture.serverVersionAtLeast;
46+
import static com.mongodb.async.client.Fixture.getDefaultDatabaseName;
47+
import static com.mongodb.async.client.Fixture.getMongoClient;
48+
import static com.mongodb.async.client.Fixture.getMongoClientBuilderFromConnectionString;
49+
import static org.junit.Assert.assertEquals;
50+
import static org.junit.Assert.assertTrue;
51+
import static org.junit.Assume.assumeFalse;
52+
import static org.junit.Assume.assumeTrue;
53+
import static util.JsonPoweredTestHelper.getTestDocument;
54+
55+
@RunWith(Parameterized.class)
56+
public class ClientSideEncryptionSessionTest {
57+
private static final String COLLECTION_NAME = "clientSideEncryptionSessionsTest";
58+
59+
private MongoClient client = getMongoClient();
60+
private MongoClient clientEncrypted;
61+
private final boolean useTransaction;
62+
63+
@Parameterized.Parameters(name = "useTransaction: {0}")
64+
public static Collection<Object[]> data() {
65+
return Arrays.asList(new Object[]{true}, new Object[]{false});
66+
}
67+
68+
public ClientSideEncryptionSessionTest(final boolean useTransaction) {
69+
this.useTransaction = useTransaction;
70+
}
71+
72+
@Before
73+
public void setUp() throws Throwable {
74+
assumeFalse(isNotAtLeastJava8());
75+
assumeTrue(serverVersionAtLeast(4, 2));
76+
assumeFalse(isStandalone());
77+
78+
/* Step 1: get unencrypted client and recreate keys collection */
79+
client = getMongoClient();
80+
MongoDatabase keyVaultDatabase = client.getDatabase("keyvault");
81+
MongoCollection<BsonDocument> dataKeys = keyVaultDatabase.getCollection("datakeys", BsonDocument.class)
82+
.withWriteConcern(WriteConcern.MAJORITY);
83+
FutureResultCallback<Void> voidCallback = new FutureResultCallback<Void>();
84+
dataKeys.drop(voidCallback);
85+
voidCallback.get();
86+
87+
voidCallback = new FutureResultCallback<Void>();
88+
dataKeys.insertOne(bsonDocumentFromPath("external-key.json"), voidCallback);
89+
voidCallback.get();
90+
91+
/* Step 2: create encryption objects. */
92+
Map<String, Map<String, Object>> kmsProviders = new HashMap<String, Map<String, Object>>();
93+
Map<String, Object> localMasterkey = new HashMap<String, Object>();
94+
Map<String, BsonDocument> schemaMap = new HashMap<String, BsonDocument>();
95+
96+
byte[] localMasterKeyBytes = Base64.getDecoder().decode("Mng0NCt4ZHVUYUJCa1kxNkVyNUR1QURhZ2h2UzR2d2RrZzh0cFBwM3R6NmdWMDFBM"
97+
+ "UN3YkQ5aXRRMkhGRGdQV09wOGVNYUMxT2k3NjZKelhaQmRCZGJkTXVyZG9uSjFk");
98+
localMasterkey.put("key", localMasterKeyBytes);
99+
kmsProviders.put("local", localMasterkey);
100+
schemaMap.put(getDefaultDatabaseName() + "." + COLLECTION_NAME, bsonDocumentFromPath("external-schema.json"));
101+
102+
MongoClientSettings clientSettings = getMongoClientBuilderFromConnectionString()
103+
.autoEncryptionSettings(AutoEncryptionSettings.builder()
104+
.keyVaultNamespace("keyvault.datakeys")
105+
.kmsProviders(kmsProviders)
106+
.schemaMap(schemaMap).build())
107+
.build();
108+
clientEncrypted = MongoClients.create(clientSettings);
109+
110+
CollectionHelper<BsonDocument> collectionHelper =
111+
new CollectionHelper<BsonDocument>(new BsonDocumentCodec(), new MongoNamespace(getDefaultDatabaseName(), COLLECTION_NAME));
112+
collectionHelper.drop();
113+
collectionHelper.create();
114+
}
115+
116+
@After
117+
public void after() {
118+
if (clientEncrypted != null) {
119+
try {
120+
clientEncrypted.close();
121+
} catch (Exception e) {
122+
// ignore
123+
}
124+
}
125+
}
126+
127+
@Test
128+
public void testWithExplicitSession() throws Throwable {
129+
BsonString unencryptedValue = new BsonString("test");
130+
131+
FutureResultCallback<ClientSession> clientSessionCallback = new FutureResultCallback<ClientSession>();
132+
clientEncrypted.startSession(clientSessionCallback);
133+
ClientSession clientSession = clientSessionCallback.get();
134+
try {
135+
if (useTransaction) {
136+
clientSession.startTransaction();
137+
}
138+
MongoCollection<BsonDocument> autoEncryptedCollection = clientEncrypted.getDatabase(getDefaultDatabaseName())
139+
.getCollection(COLLECTION_NAME, BsonDocument.class);
140+
FutureResultCallback<Void> insertCallback = new FutureResultCallback<Void>();
141+
autoEncryptedCollection.insertOne(clientSession, new BsonDocument().append("encrypted", new BsonString("test")),
142+
insertCallback);
143+
insertCallback.get();
144+
145+
FutureResultCallback<BsonDocument> findCallback = new FutureResultCallback<BsonDocument>();
146+
autoEncryptedCollection.find(clientSession).first(findCallback);
147+
BsonDocument unencryptedDocument = findCallback.get();
148+
assertEquals(unencryptedValue, unencryptedDocument.getString("encrypted"));
149+
150+
if (useTransaction) {
151+
FutureResultCallback<Void> commitCallback = new FutureResultCallback<Void>();
152+
clientSession.commitTransaction(commitCallback);
153+
commitCallback.get();
154+
}
155+
} finally {
156+
clientSession.close();
157+
}
158+
159+
MongoCollection<BsonDocument> encryptedCollection = client.getDatabase(getDefaultDatabaseName())
160+
.getCollection(COLLECTION_NAME, BsonDocument.class);
161+
FutureResultCallback<BsonDocument> findCallback = new FutureResultCallback<BsonDocument>();
162+
encryptedCollection.find().first(findCallback);
163+
BsonDocument encryptedDocument = findCallback.get();
164+
assertTrue(encryptedDocument.isBinary("encrypted"));
165+
assertEquals(6, encryptedDocument.getBinary("encrypted").getType());
166+
}
167+
168+
private static BsonDocument bsonDocumentFromPath(final String path) throws IOException, URISyntaxException {
169+
return getTestDocument(new File(ClientSideEncryptionSessionTest.class
170+
.getResource("/client-side-encryption-external/" + path).toURI()));
171+
}
172+
}

driver-core/src/main/com/mongodb/binding/AsyncClusterBinding.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import com.mongodb.ReadConcern;
2020
import com.mongodb.ReadPreference;
21+
import com.mongodb.ServerAddress;
2122
import com.mongodb.async.SingleResultCallback;
2223
import com.mongodb.connection.AsyncConnection;
2324
import com.mongodb.connection.Cluster;
@@ -27,6 +28,7 @@
2728
import com.mongodb.internal.binding.AsyncClusterAwareReadWriteBinding;
2829
import com.mongodb.internal.connection.ReadConcernAwareNoOpSessionContext;
2930
import com.mongodb.selector.ReadPreferenceServerSelector;
31+
import com.mongodb.selector.ServerAddressSelector;
3032
import com.mongodb.selector.ServerSelector;
3133
import com.mongodb.selector.WritableServerSelector;
3234
import com.mongodb.session.SessionContext;
@@ -102,6 +104,11 @@ public void getWriteConnectionSource(final SingleResultCallback<AsyncConnectionS
102104
getAsyncClusterBindingConnectionSource(new WritableServerSelector(), callback);
103105
}
104106

107+
@Override
108+
public void getConnectionSource(final ServerAddress serverAddress, final SingleResultCallback<AsyncConnectionSource> callback) {
109+
getAsyncClusterBindingConnectionSource(new ServerAddressSelector(serverAddress), callback);
110+
}
111+
105112
private void getAsyncClusterBindingConnectionSource(final ServerSelector serverSelector,
106113
final SingleResultCallback<AsyncConnectionSource> callback) {
107114
cluster.selectServerAsync(serverSelector, new SingleResultCallback<Server>() {

0 commit comments

Comments
 (0)