Skip to content

Commit ccd63c2

Browse files
authored
Make object mocking thread safe (#312)
Protect against concurrent non-mocked access to the a mocked object (Fix #311)
1 parent 9926b8d commit ccd63c2

File tree

6 files changed

+93
-17
lines changed

6 files changed

+93
-17
lines changed

build.sbt

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,15 @@ lazy val commonSettings =
4343
Nil
4444
}
4545
},
46-
Test / scalacOptions += "-Ywarn-value-discard"
46+
Test / scalacOptions += "-Ywarn-value-discard",
47+
libraryDependencies ++= {
48+
CrossVersion.partialVersion(scalaVersion.value) match {
49+
case Some((2, major)) if major <= 12 =>
50+
Seq()
51+
case _ =>
52+
Seq("org.scala-lang.modules" %% "scala-parallel-collections" % "1.0.0-RC1")
53+
}
54+
}
4755
)
4856

4957
lazy val publishSettings = Seq(

common/src/main/scala/org/mockito/MockitoAPI.scala

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,13 @@ import org.mockito.ReflectionUtils.InvocationOnMockOps
1616
import org.mockito.internal.configuration.plugins.Plugins.getMockMaker
1717
import org.mockito.internal.creation.MockSettingsImpl
1818
import org.mockito.internal.exceptions.Reporter.notAMockPassedToVerifyNoMoreInteractions
19-
import org.mockito.internal.handler.ScalaMockHandler
19+
import org.mockito.internal.handler.{ ScalaMockHandler, ThreadAwareMockHandler }
2020
import org.mockito.internal.progress.ThreadSafeMockingProgress.mockingProgress
2121
import org.mockito.internal.stubbing.answers.ScalaThrowsException
2222
import org.mockito.internal.util.MockUtil
2323
import org.mockito.internal.util.reflection.LenientCopyTool
2424
import org.mockito.internal.{ ValueClassExtractor, ValueClassWrapper }
25-
import org.mockito.invocation.InvocationOnMock
25+
import org.mockito.invocation.{ Invocation, InvocationContainer, InvocationOnMock, MockHandler }
2626
import org.mockito.mock.MockCreationSettings
2727
import org.mockito.stubbing._
2828
import org.mockito.verification.{ VerificationAfterDelay, VerificationMode, VerificationWithTimeout }
@@ -472,7 +472,8 @@ private[mockito] trait MockitoEnhancer extends MockCreator {
472472
* <code>verify(aMock).iHaveSomeDefaultArguments("I'm not gonna pass the second argument", "default value")</code>
473473
* as the value for the second parameter would have been null...
474474
*/
475-
override def mock[T <: AnyRef: ClassTag: WeakTypeTag](implicit defaultAnswer: DefaultAnswer, $pt: Prettifier): T = mock(withSettings)
475+
override def mock[T <: AnyRef: ClassTag: WeakTypeTag](implicit defaultAnswer: DefaultAnswer, $pt: Prettifier): T =
476+
createMock(withSettings)
476477

477478
/**
478479
* Delegates to <code>Mockito.mock(type: Class[T], defaultAnswer: Answer[_])</code>
@@ -489,7 +490,7 @@ private[mockito] trait MockitoEnhancer extends MockCreator {
489490
* as the value for the second parameter would have been null...
490491
*/
491492
override def mock[T <: AnyRef: ClassTag: WeakTypeTag](defaultAnswer: DefaultAnswer)(implicit $pt: Prettifier): T =
492-
mock(withSettings(defaultAnswer))
493+
createMock(withSettings(defaultAnswer))
493494

494495
/**
495496
* Delegates to <code>Mockito.mock(type: Class[T], mockSettings: MockSettings)</code>
@@ -505,7 +506,13 @@ private[mockito] trait MockitoEnhancer extends MockCreator {
505506
* <code>verify(aMock).iHaveSomeDefaultArguments("I'm not gonna pass the second argument", "default value")</code>
506507
* as the value for the second parameter would have been null...
507508
*/
508-
override def mock[T <: AnyRef: ClassTag: WeakTypeTag](mockSettings: MockSettings)(implicit $pt: Prettifier): T = {
509+
override def mock[T <: AnyRef: ClassTag: WeakTypeTag](mockSettings: MockSettings)(implicit $pt: Prettifier): T =
510+
createMock(mockSettings)
511+
512+
private def createMock[T <: AnyRef: ClassTag: WeakTypeTag](
513+
mockSettings: MockSettings,
514+
mockHandler: (MockCreationSettings[T], Prettifier) => MockHandler[T] = (settings: MockCreationSettings[T], pt: Prettifier) => ScalaMockHandler(settings)(pt)
515+
)(implicit $pt: Prettifier): T = {
509516
val interfaces = ReflectionUtils.extraInterfaces
510517

511518
val realClass: Class[T] = mockSettings match {
@@ -520,7 +527,7 @@ private[mockito] trait MockitoEnhancer extends MockCreator {
520527
else mockSettings
521528

522529
def createMock(settings: MockCreationSettings[T]): T = {
523-
val mock = getMockMaker.createMock(settings, ScalaMockHandler(settings))
530+
val mock = getMockMaker.createMock(settings, mockHandler(settings, $pt))
524531
val spiedInstance = settings.getSpiedInstance
525532
if (spiedInstance != null) new LenientCopyTool().copyToMock(spiedInstance, mock)
526533
mock
@@ -620,12 +627,21 @@ private[mockito] trait MockitoEnhancer extends MockCreator {
620627
/**
621628
* Mocks the specified object only for the context of the block
622629
*/
623-
def withObjectMocked[O <: AnyRef: ClassTag](block: => Any): Unit = {
624-
val moduleField = clazz[O].getDeclaredField("MODULE$")
625-
val realImpl = moduleField.get(null)
626-
ReflectionUtils.setFinalStatic(moduleField, mock[O])
627-
try block
628-
finally ReflectionUtils.setFinalStatic(moduleField, realImpl)
630+
def withObjectMocked[O <: AnyRef: ClassTag](block: => Any)(implicit defaultAnswer: DefaultAnswer, $pt: Prettifier): Unit = {
631+
val objectClass = clazz[O]
632+
objectClass.synchronized {
633+
val moduleField = objectClass.getDeclaredField("MODULE$")
634+
val realImpl: O = moduleField.get(null).asInstanceOf[O]
635+
636+
val threadAwareMock = createMock(
637+
withSettings(defaultAnswer),
638+
(settings: MockCreationSettings[O], pt: Prettifier) => ThreadAwareMockHandler(settings, realImpl)(pt)
639+
)
640+
641+
ReflectionUtils.setFinalStatic(moduleField, threadAwareMock)
642+
try block
643+
finally ReflectionUtils.setFinalStatic(moduleField, realImpl)
644+
}
629645
}
630646
}
631647

common/src/main/scala/org/mockito/ReflectionUtils.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ object ReflectionUtils {
123123
.getOrElse(Seq.empty)
124124
}
125125

126-
def setFinalStatic(field: Field, newValue: Any) = {
126+
def setFinalStatic(field: Field, newValue: Any): Unit = {
127127
field.setAccessible(true)
128128
val modifiersField = classOf[Field].getDeclaredField("modifiers")
129129
modifiersField.setAccessible(true)
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
package org.mockito.internal.handler
2+
3+
import org.mockito.AdditionalAnswers
4+
import org.mockito.invocation.{ Invocation, InvocationContainer, MockHandler }
5+
import org.mockito.mock.MockCreationSettings
6+
import org.scalactic.Prettifier
7+
8+
class ThreadAwareMockHandler[T](settings: MockCreationSettings[T], realImpl: T)(implicit $pt: Prettifier) extends MockHandler[T] {
9+
private val currentThread = Thread.currentThread()
10+
private val mockDelegate = ScalaMockHandler(settings)
11+
private val realImplDelegate = AdditionalAnswers.delegatesTo(realImpl)
12+
13+
override def handle(invocation: Invocation): AnyRef =
14+
if (Thread.currentThread() == currentThread) mockDelegate.handle(invocation)
15+
else realImplDelegate.answer(invocation)
16+
17+
override def getMockSettings: MockCreationSettings[T] = mockDelegate.getMockSettings
18+
19+
override def getInvocationContainer: InvocationContainer = mockDelegate.getInvocationContainer
20+
}
21+
22+
object ThreadAwareMockHandler {
23+
def apply[T](settings: MockCreationSettings[T], realImpl: T)(implicit $pt: Prettifier): ThreadAwareMockHandler[T] =
24+
new ThreadAwareMockHandler(settings, realImpl)($pt)
25+
}

scalatest/src/test/scala/user/org/mockito/IdiomaticStubbingTest.scala

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
package user.org.mockito
22

3-
import java.lang.reflect.{ Field, Modifier }
43
import java.util.concurrent.atomic.AtomicInteger
54

65
import org.mockito.invocation.InvocationOnMock
7-
import org.mockito.{ clazz, ArgumentMatchersSugar, IdiomaticStubbing }
6+
import org.mockito.{ ArgumentMatchersSugar, IdiomaticStubbing }
87
import org.scalatest.matchers.should.Matchers
98
import org.scalatest.wordspec.AnyWordSpec
109
import user.org.mockito.matchers.{ ValueCaseClassInt, ValueCaseClassString, ValueClass }
1110

12-
import scala.reflect.ClassTag
11+
import scala.collection.parallel.immutable
12+
import scala.concurrent.{ Await, Future }
13+
import scala.util.Random
1314

1415
class IdiomaticStubbingTest extends AnyWordSpec with Matchers with ArgumentMatchersSugar with IdiomaticMockitoTestSetup with IdiomaticStubbing {
1516

@@ -313,5 +314,26 @@ class IdiomaticStubbingTest extends AnyWordSpec with Matchers with ArgumentMatch
313314

314315
FooObject.simpleMethod shouldBe "not mocked!"
315316
}
317+
318+
"object stubbing should be thread safe" in {
319+
immutable.ParSeq.range(1, 100).foreach { i =>
320+
withObjectMocked[FooObject.type] {
321+
FooObject.simpleMethod returns s"mocked!-$i"
322+
FooObject.simpleMethod shouldBe s"mocked!-$i"
323+
}
324+
}
325+
}
326+
327+
"object stubbing should be thread safe 2" in {
328+
val now = FooObject.stateDependantMethod
329+
immutable.ParSeq.range(1, 100).foreach { i =>
330+
if (i % 2 == 0)
331+
withObjectMocked[FooObject.type] {
332+
FooObject.stateDependantMethod returns i
333+
FooObject.stateDependantMethod shouldBe i
334+
}
335+
else FooObject.stateDependantMethod shouldBe now
336+
}
337+
}
316338
}
317339
}

scalatest/src/test/scala/user/org/mockito/TestModel.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package user.org.mockito
22
import user.org.mockito.matchers.{ ValueCaseClassInt, ValueCaseClassString, ValueClass }
33

44
import scala.annotation.varargs
5+
import scala.util.Random
56

67
trait FooTrait {
78
def bar = "not mocked"
@@ -128,5 +129,9 @@ class TestController(org: Org) {
128129
}
129130

130131
object FooObject {
132+
val now: Long = Random.nextLong()
133+
131134
def simpleMethod: String = "not mocked!"
135+
136+
def stateDependantMethod: Long = now
132137
}

0 commit comments

Comments
 (0)