@@ -16,13 +16,13 @@ import org.mockito.ReflectionUtils.InvocationOnMockOps
1616import org .mockito .internal .configuration .plugins .Plugins .getMockMaker
1717import org .mockito .internal .creation .MockSettingsImpl
1818import org .mockito .internal .exceptions .Reporter .notAMockPassedToVerifyNoMoreInteractions
19- import org .mockito .internal .handler .ScalaMockHandler
19+ import org .mockito .internal .handler .{ ScalaMockHandler , ThreadAwareMockHandler }
2020import org .mockito .internal .progress .ThreadSafeMockingProgress .mockingProgress
2121import org .mockito .internal .stubbing .answers .ScalaThrowsException
2222import org .mockito .internal .util .MockUtil
2323import org .mockito .internal .util .reflection .LenientCopyTool
2424import org .mockito .internal .{ ValueClassExtractor , ValueClassWrapper }
25- import org .mockito .invocation .InvocationOnMock
25+ import org .mockito .invocation .{ Invocation , InvocationContainer , InvocationOnMock , MockHandler }
2626import org .mockito .mock .MockCreationSettings
2727import org .mockito .stubbing ._
2828import org .mockito .verification .{ VerificationAfterDelay , VerificationMode , VerificationWithTimeout }
@@ -472,7 +472,8 @@ private[mockito] trait MockitoEnhancer extends MockCreator {
472472 * <code>verify(aMock).iHaveSomeDefaultArguments("I'm not gonna pass the second argument", "default value")</code>
473473 * as the value for the second parameter would have been null...
474474 */
475- override def mock [T <: AnyRef : ClassTag : WeakTypeTag ](implicit defaultAnswer : DefaultAnswer , $pt : Prettifier ): T = mock(withSettings)
475+ override def mock [T <: AnyRef : ClassTag : WeakTypeTag ](implicit defaultAnswer : DefaultAnswer , $pt : Prettifier ): T =
476+ createMock(withSettings)
476477
477478 /**
478479 * Delegates to <code>Mockito.mock(type: Class[T], defaultAnswer: Answer[_])</code>
@@ -489,7 +490,7 @@ private[mockito] trait MockitoEnhancer extends MockCreator {
489490 * as the value for the second parameter would have been null...
490491 */
491492 override def mock [T <: AnyRef : ClassTag : WeakTypeTag ](defaultAnswer : DefaultAnswer )(implicit $pt : Prettifier ): T =
492- mock (withSettings(defaultAnswer))
493+ createMock (withSettings(defaultAnswer))
493494
494495 /**
495496 * Delegates to <code>Mockito.mock(type: Class[T], mockSettings: MockSettings)</code>
@@ -505,7 +506,13 @@ private[mockito] trait MockitoEnhancer extends MockCreator {
505506 * <code>verify(aMock).iHaveSomeDefaultArguments("I'm not gonna pass the second argument", "default value")</code>
506507 * as the value for the second parameter would have been null...
507508 */
508- override def mock [T <: AnyRef : ClassTag : WeakTypeTag ](mockSettings : MockSettings )(implicit $pt : Prettifier ): T = {
509+ override def mock [T <: AnyRef : ClassTag : WeakTypeTag ](mockSettings : MockSettings )(implicit $pt : Prettifier ): T =
510+ createMock(mockSettings)
511+
512+ private def createMock [T <: AnyRef : ClassTag : WeakTypeTag ](
513+ mockSettings : MockSettings ,
514+ mockHandler : (MockCreationSettings [T ], Prettifier ) => MockHandler [T ] = (settings : MockCreationSettings [T ], pt : Prettifier ) => ScalaMockHandler (settings)(pt)
515+ )(implicit $pt : Prettifier ): T = {
509516 val interfaces = ReflectionUtils .extraInterfaces
510517
511518 val realClass : Class [T ] = mockSettings match {
@@ -520,7 +527,7 @@ private[mockito] trait MockitoEnhancer extends MockCreator {
520527 else mockSettings
521528
522529 def createMock (settings : MockCreationSettings [T ]): T = {
523- val mock = getMockMaker.createMock(settings, ScalaMockHandler (settings))
530+ val mock = getMockMaker.createMock(settings, mockHandler (settings, $pt ))
524531 val spiedInstance = settings.getSpiedInstance
525532 if (spiedInstance != null ) new LenientCopyTool ().copyToMock(spiedInstance, mock)
526533 mock
@@ -620,12 +627,21 @@ private[mockito] trait MockitoEnhancer extends MockCreator {
620627 /**
621628 * Mocks the specified object only for the context of the block
622629 */
623- def withObjectMocked [O <: AnyRef : ClassTag ](block : => Any ): Unit = {
624- val moduleField = clazz[O ].getDeclaredField(" MODULE$" )
625- val realImpl = moduleField.get(null )
626- ReflectionUtils .setFinalStatic(moduleField, mock[O ])
627- try block
628- finally ReflectionUtils .setFinalStatic(moduleField, realImpl)
630+ def withObjectMocked [O <: AnyRef : ClassTag ](block : => Any )(implicit defaultAnswer : DefaultAnswer , $pt : Prettifier ): Unit = {
631+ val objectClass = clazz[O ]
632+ objectClass.synchronized {
633+ val moduleField = objectClass.getDeclaredField(" MODULE$" )
634+ val realImpl : O = moduleField.get(null ).asInstanceOf [O ]
635+
636+ val threadAwareMock = createMock(
637+ withSettings(defaultAnswer),
638+ (settings : MockCreationSettings [O ], pt : Prettifier ) => ThreadAwareMockHandler (settings, realImpl)(pt)
639+ )
640+
641+ ReflectionUtils .setFinalStatic(moduleField, threadAwareMock)
642+ try block
643+ finally ReflectionUtils .setFinalStatic(moduleField, realImpl)
644+ }
629645 }
630646}
631647
0 commit comments