Skip to content

Commit fac48e3

Browse files
committed
Ensure MongoClient instances are closed in Scala tests (#830)
Co-authored-by: Ross Lawley <[email protected]> JAVA-4412
1 parent 9e94de6 commit fac48e3

File tree

6 files changed

+134
-92
lines changed

6 files changed

+134
-92
lines changed

driver-scala/src/it/scala/org/mongodb/scala/ClientSideEncryptionBypassAutoEncryptionSpec.scala

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -60,25 +60,30 @@ class ClientSideEncryptionBypassAutoEncryptionSpec extends RequiresMongoDBISpec
6060
.codecRegistry(DEFAULT_CODEC_REGISTRY)
6161
.build
6262

63-
val clientEncrypted = MongoClient(clientSettings)
63+
withTempClient(
64+
clientSettings,
65+
clientEncrypted => {
6466

65-
val fieldValue = BsonString("123456789")
67+
val fieldValue = BsonString("123456789")
6668

67-
val dataKeyId = clientEncryption.createDataKey("local", DataKeyOptions()).head().futureValue
69+
val dataKeyId = clientEncryption.createDataKey("local", DataKeyOptions()).head().futureValue
6870

69-
val encryptedFieldValue = clientEncryption
70-
.encrypt(fieldValue, EncryptOptions("AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic").keyId(dataKeyId))
71-
.head()
72-
.futureValue
71+
val encryptedFieldValue = clientEncryption
72+
.encrypt(fieldValue, EncryptOptions("AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic").keyId(dataKeyId))
73+
.head()
74+
.futureValue
7375

74-
val collection: MongoCollection[Document] =
75-
clientEncrypted.getDatabase(databaseName).getCollection[Document]("test")
76+
val collection: MongoCollection[Document] =
77+
clientEncrypted.getDatabase(databaseName).getCollection[Document]("test")
7678

77-
collection.insertOne(Document("encryptedField" -> encryptedFieldValue)).futureValue
79+
collection.insertOne(Document("encryptedField" -> encryptedFieldValue)).futureValue
7880

79-
val result = collection.find().first().head().futureValue
81+
val result = collection.find().first().head().futureValue
8082

81-
result.get[BsonString]("encryptedField") should equal(Some(fieldValue))
83+
result.get[BsonString]("encryptedField") should equal(Some(fieldValue))
84+
85+
}
86+
)
8287
}
8388

8489
}

driver-scala/src/it/scala/org/mongodb/scala/RequiresMongoDBISpec.scala

Lines changed: 30 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -16,30 +16,24 @@
1616

1717
package org.mongodb.scala
1818

19-
import com.mongodb.ClusterFixture.getServerApi
2019
import com.mongodb.connection.ServerVersion
2120
import org.mongodb.scala.bson.BsonString
2221
import org.scalatest._
2322

2423
import scala.collection.JavaConverters._
2524
import scala.concurrent.duration.{ Duration, _ }
2625
import scala.concurrent.{ Await, ExecutionContext }
27-
import scala.util.{ Properties, Try }
2826

2927
trait RequiresMongoDBISpec extends BaseSpec with BeforeAndAfterAll {
3028

3129
implicit val ec: ExecutionContext = ExecutionContext.Implicits.global
3230

33-
private val DEFAULT_URI: String = "mongodb://localhost:27017/"
34-
private val MONGODB_URI_SYSTEM_PROPERTY_NAME: String = "org.mongodb.test.uri"
3531
val WAIT_DURATION: Duration = 60.seconds
3632
private val DB_PREFIX = "mongo-scala-"
3733
private var _currentTestName: Option[String] = None
38-
private var mongoDBOnline: Boolean = false
3934

4035
protected override def runTest(testName: String, args: Args): Status = {
4136
_currentTestName = Some(testName.split("should")(1))
42-
mongoDBOnline = isMongoDBOnline()
4337
super.runTest(testName, args)
4438
}
4539

@@ -53,48 +47,33 @@ trait RequiresMongoDBISpec extends BaseSpec with BeforeAndAfterAll {
5347
*/
5448
def collectionName: String = _currentTestName.getOrElse(suiteName).filter(_.isLetterOrDigit)
5549

56-
val mongoClientURI: String = {
57-
val uri = Properties.propOrElse(MONGODB_URI_SYSTEM_PROPERTY_NAME, DEFAULT_URI)
58-
if (!uri.isBlank) uri else DEFAULT_URI
59-
}
60-
val connectionString: ConnectionString = ConnectionString(mongoClientURI)
50+
def mongoClientSettingsBuilder: MongoClientSettings.Builder = TestMongoClientHelper.mongoClientSettingsBuilder
6151

62-
def mongoClientSettingsBuilder: MongoClientSettings.Builder = {
63-
val builder = MongoClientSettings.builder().applyConnectionString(connectionString)
64-
if (getServerApi != null) {
65-
builder.serverApi(getServerApi)
66-
}
67-
builder
68-
}
52+
val mongoClientSettings: MongoClientSettings = TestMongoClientHelper.mongoClientSettings
6953

70-
val mongoClientSettings: MongoClientSettings = mongoClientSettingsBuilder.build()
54+
def mongoClient(): MongoClient = TestMongoClientHelper.mongoClient
7155

72-
def mongoClient(): MongoClient = MongoClient(mongoClientSettings)
73-
74-
def isMongoDBOnline(): Boolean = {
75-
Try(Await.result(MongoClient(mongoClientSettings).listDatabaseNames().toFuture(), WAIT_DURATION)).isSuccess
76-
}
77-
78-
def hasSingleHost(): Boolean = {
79-
new ConnectionString(mongoClientURI).getHosts.size() == 1
56+
def checkMongoDB(): Unit = {
57+
if (!TestMongoClientHelper.isMongoDBOnline) {
58+
cancel("No Available Database")
59+
}
8060
}
8161

82-
def checkMongoDB() {
83-
if (!mongoDBOnline) {
84-
cancel("No Available Database")
62+
def withTempClient(mongoClientSettings: MongoClientSettings, testCode: MongoClient => Any): Unit = {
63+
val client = MongoClient(mongoClientSettings)
64+
try {
65+
testCode(client)
66+
} finally {
67+
client.close()
8568
}
8669
}
8770

8871
def withClient(testCode: MongoClient => Any): Unit = {
8972
checkMongoDB()
90-
val client = mongoClient()
91-
try testCode(client) // loan the client
92-
finally {
93-
client.close()
94-
}
73+
testCode(TestMongoClientHelper.mongoClient) // loan the client
9574
}
9675

97-
def withDatabase(dbName: String)(testCode: MongoDatabase => Any) {
76+
def withDatabase(dbName: String)(testCode: MongoDatabase => Any): Unit = {
9877
withClient { client =>
9978
val databaseName = if (dbName.startsWith(DB_PREFIX)) dbName.take(63) else s"$DB_PREFIX$dbName".take(63) // scalastyle:ignore
10079
val mongoDatabase = client.getDatabase(databaseName)
@@ -108,7 +87,7 @@ trait RequiresMongoDBISpec extends BaseSpec with BeforeAndAfterAll {
10887

10988
def withDatabase(testCode: MongoDatabase => Any): Unit = withDatabase(databaseName)(testCode: MongoDatabase => Any)
11089

111-
def withCollection(testCode: MongoCollection[Document] => Any) {
90+
def withCollection(testCode: MongoCollection[Document] => Any): Unit = {
11291
withDatabase(databaseName) { mongoDatabase =>
11392
val mongoCollection = mongoDatabase.getCollection(collectionName)
11493
try testCode(mongoCollection) // "loan" the fixture to the test
@@ -119,19 +98,25 @@ trait RequiresMongoDBISpec extends BaseSpec with BeforeAndAfterAll {
11998
}
12099
}
121100

122-
lazy val isSharded: Boolean = if (!mongoDBOnline) {
101+
lazy val isSharded: Boolean = if (!TestMongoClientHelper.isMongoDBOnline) {
123102
false
124103
} else {
125104
Await
126-
.result(mongoClient().getDatabase("admin").runCommand(Document("isMaster" -> 1)).toFuture(), WAIT_DURATION)
105+
.result(
106+
mongoClient().getDatabase("admin").runCommand(Document("isMaster" -> 1)).toFuture(),
107+
WAIT_DURATION
108+
)
127109
.getOrElse("msg", BsonString(""))
128110
.asString()
129111
.getValue == "isdbgrid"
130112
}
131113

132114
lazy val buildInfo: Document = {
133-
if (mongoDBOnline) {
134-
Await.result(mongoClient().getDatabase("admin").runCommand(Document("buildInfo" -> 1)).toFuture(), WAIT_DURATION)
115+
if (TestMongoClientHelper.isMongoDBOnline) {
116+
Await.result(
117+
mongoClient().getDatabase("admin").runCommand(Document("buildInfo" -> 1)).toFuture(),
118+
WAIT_DURATION
119+
)
135120
} else {
136121
Document()
137122
}
@@ -158,26 +143,14 @@ trait RequiresMongoDBISpec extends BaseSpec with BeforeAndAfterAll {
158143
}
159144

160145
override def beforeAll() {
161-
if (mongoDBOnline) {
162-
val client = mongoClient()
163-
Await.result(client.getDatabase(databaseName).drop().toFuture(), WAIT_DURATION)
164-
client.close()
146+
if (TestMongoClientHelper.isMongoDBOnline) {
147+
Await.result(TestMongoClientHelper.mongoClient.getDatabase(databaseName).drop().toFuture(), WAIT_DURATION)
165148
}
166149
}
167150

168151
override def afterAll() {
169-
if (mongoDBOnline) {
170-
val client = mongoClient()
171-
Await.result(client.getDatabase(databaseName).drop().toFuture(), WAIT_DURATION)
172-
client.close()
173-
}
174-
}
175-
176-
Runtime.getRuntime.addShutdownHook(new ShutdownHook())
177-
178-
private[mongodb] class ShutdownHook extends Thread {
179-
override def run() {
180-
mongoClient().getDatabase(databaseName).drop()
152+
if (TestMongoClientHelper.isMongoDBOnline) {
153+
Await.result(TestMongoClientHelper.mongoClient.getDatabase(databaseName).drop().toFuture(), WAIT_DURATION)
181154
}
182155
}
183156

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
/*
2+
* Copyright (c) 2008 - 2013 10gen, Inc. <http://10gen.com>
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*
16+
*/
17+
18+
package org.mongodb.scala
19+
20+
import com.mongodb.ClusterFixture.getServerApi
21+
import org.mongodb.scala.syncadapter.WAIT_DURATION
22+
23+
import scala.concurrent.Await
24+
import scala.util.{ Properties, Try }
25+
26+
object TestMongoClientHelper {
27+
private val DEFAULT_URI: String = "mongodb://localhost:27017/"
28+
private val MONGODB_URI_SYSTEM_PROPERTY_NAME: String = "org.mongodb.test.uri"
29+
30+
val mongoClientURI: String = {
31+
val uri = Properties.propOrElse(MONGODB_URI_SYSTEM_PROPERTY_NAME, DEFAULT_URI)
32+
if (!uri.isBlank) uri else DEFAULT_URI
33+
}
34+
val connectionString: ConnectionString = ConnectionString(mongoClientURI)
35+
36+
def mongoClientSettingsBuilder: MongoClientSettings.Builder = {
37+
val builder = MongoClientSettings.builder().applyConnectionString(connectionString)
38+
if (getServerApi != null) {
39+
builder.serverApi(getServerApi)
40+
}
41+
builder
42+
}
43+
44+
val mongoClientSettings: MongoClientSettings = mongoClientSettingsBuilder.build()
45+
val mongoClient: MongoClient = MongoClient(mongoClientSettings)
46+
47+
def isMongoDBOnline: Boolean = {
48+
Try(Await.result(TestMongoClientHelper.mongoClient.listDatabaseNames().toFuture(), WAIT_DURATION)).isSuccess
49+
}
50+
51+
def hasSingleHost: Boolean = {
52+
TestMongoClientHelper.connectionString.getHosts.size() == 1
53+
}
54+
55+
Runtime.getRuntime.addShutdownHook(new ShutdownHook())
56+
57+
private[mongodb] class ShutdownHook extends Thread {
58+
override def run() {
59+
mongoClient.close()
60+
}
61+
}
62+
}

driver-scala/src/it/scala/org/mongodb/scala/documentation/DocumentationExampleSpec.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
package org.mongodb.scala.documentation
1818

1919
import java.util.concurrent.atomic.AtomicBoolean
20-
2120
import com.mongodb.client.model.changestream.{ ChangeStreamDocument, FullDocument }
21+
import org.mongodb.scala.TestMongoClientHelper.hasSingleHost
2222
import org.mongodb.scala._
2323
import org.mongodb.scala.bson.conversions.Bson
2424
import org.mongodb.scala.bson.{ BsonArray, BsonDocument, BsonNull, BsonString, BsonValue }
@@ -596,7 +596,7 @@ class DocumentationExampleSpec extends RequiresMongoDBISpec with FuturesSpec {
596596
}
597597

598598
it should "be able to watch" in withCollection { collection =>
599-
assume(serverVersionAtLeast(List(3, 6, 0)) && !hasSingleHost())
599+
assume(serverVersionAtLeast(List(3, 6, 0)) && !hasSingleHost)
600600
val inventory: MongoCollection[Document] = collection
601601
val stop: AtomicBoolean = new AtomicBoolean(false)
602602
new Thread(new Runnable {

driver-scala/src/it/scala/org/mongodb/scala/documentation/DocumentationTransactionsExampleSpec.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package org.mongodb.scala.documentation
1818

19+
import org.mongodb.scala.TestMongoClientHelper.hasSingleHost
1920
import org.mongodb.scala._
2021
import org.mongodb.scala.model.{ Filters, Updates }
2122
import org.mongodb.scala.result.{ InsertOneResult, UpdateResult }
@@ -36,7 +37,7 @@ class DocumentationTransactionsExampleSpec extends RequiresMongoDBISpec {
3637
// end implicit functions
3738

3839
"The Scala driver" should "be able to commit a transaction" in withClient { client =>
39-
assume(serverVersionAtLeast(List(4, 0, 0)) && !hasSingleHost())
40+
assume(serverVersionAtLeast(List(4, 0, 0)) && !hasSingleHost)
4041
client.getDatabase("hr").drop().execute()
4142
client.getDatabase("hr").createCollection("employees").execute()
4243
client.getDatabase("hr").createCollection("events").execute()

driver-scala/src/it/scala/org/mongodb/scala/gridfs/GridFSObservableSpec.scala

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -284,29 +284,30 @@ class GridFSObservableSpec extends RequiresMongoDBISpec with FuturesSpec with Be
284284
}
285285

286286
it should "use the user provided codec registries for encoding / decoding data" in {
287-
val client = MongoClient(
287+
withTempClient(
288288
mongoClientSettingsBuilder
289289
.uuidRepresentation(UuidRepresentation.STANDARD)
290-
.build()
290+
.build(),
291+
client => {
292+
val database = client.getDatabase(databaseName)
293+
val uuid = UUID.randomUUID()
294+
val fileMeta = new org.bson.Document("uuid", uuid)
295+
val bucket = GridFSBucket(database)
296+
297+
val fileId = bucket
298+
.uploadFromObservable(
299+
"myFile",
300+
Observable(Seq(ByteBuffer.wrap(multiChunkString.getBytes()))),
301+
new GridFSUploadOptions().metadata(fileMeta)
302+
)
303+
.head()
304+
.futureValue
305+
306+
val fileAsDocument = filesCollection.find[BsonDocument]().head().futureValue
307+
fileAsDocument.getDocument("metadata").getBinary("uuid").getType should equal(4.toByte)
308+
fileAsDocument.getDocument("metadata").getBinary("uuid").asUuid() should equal(uuid)
309+
}
291310
)
292-
293-
val database = client.getDatabase(databaseName)
294-
val uuid = UUID.randomUUID()
295-
val fileMeta = new org.bson.Document("uuid", uuid)
296-
val bucket = GridFSBucket(database)
297-
298-
val fileId = bucket
299-
.uploadFromObservable(
300-
"myFile",
301-
Observable(Seq(ByteBuffer.wrap(multiChunkString.getBytes()))),
302-
new GridFSUploadOptions().metadata(fileMeta)
303-
)
304-
.head()
305-
.futureValue
306-
307-
val fileAsDocument = filesCollection.find[BsonDocument]().head().futureValue
308-
fileAsDocument.getDocument("metadata").getBinary("uuid").getType should equal(4.toByte)
309-
fileAsDocument.getDocument("metadata").getBinary("uuid").asUuid() should equal(uuid)
310311
}
311312

312313
it should "handle missing file name data when downloading" in {

0 commit comments

Comments
 (0)