Skip to content

Commit ee11ef0

Browse files
authored
Merge pull request #845 from Jason31569/main
Ability to mock protected methods with and without return value
2 parents 4122187 + 18191e2 commit ee11ef0

File tree

7 files changed

+393
-0
lines changed

7 files changed

+393
-0
lines changed

src/NSubstitute/Core/IThreadLocalContext.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,12 @@ public interface IThreadLocalContext
2424
void EnqueueArgumentSpecification(IArgumentSpecification spec);
2525
IList<IArgumentSpecification> DequeueAllArgumentSpecifications();
2626

27+
/// <summary>
28+
/// Peeks into the argument specifications
29+
/// </summary>
30+
/// <returns>Enqueued argument specifications</returns>
31+
IList<IArgumentSpecification> PeekAllArgumentSpecifications();
32+
2733
void SetPendingRaisingEventArgumentsFactory(Func<ICall, object?[]> getArguments);
2834
/// <summary>
2935
/// Returns the previously set arguments factory and resets the stored value.

src/NSubstitute/Core/ThreadLocalContext.cs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,23 @@ public IList<IArgumentSpecification> DequeueAllArgumentSpecifications()
108108
return queue;
109109
}
110110

111+
/// <inheritdoc/>
112+
public IList<IArgumentSpecification> PeekAllArgumentSpecifications()
113+
{
114+
var queue = _argumentSpecifications.Value;
115+
116+
if (queue?.Count > 0)
117+
{
118+
var items = new IArgumentSpecification[queue.Count];
119+
120+
queue.CopyTo(items, 0);
121+
122+
return items;
123+
}
124+
125+
return EmptySpecifications;
126+
}
127+
111128
public void SetPendingRaisingEventArgumentsFactory(Func<ICall, object?[]> getArguments)
112129
{
113130
_getArgumentsForRaisingEvent.Value = getArguments;
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
namespace NSubstitute.Exceptions;
2+
3+
public class ProtectedMethodNotFoundException(string message, Exception? innerException) : SubstituteException(message, innerException)
4+
{
5+
public ProtectedMethodNotFoundException() : this("", null)
6+
{ }
7+
8+
public ProtectedMethodNotFoundException(string message) : this(message, null)
9+
{ }
10+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
namespace NSubstitute.Exceptions;
2+
3+
public class ProtectedMethodNotVirtualException(string message, Exception? innerException) : SubstituteException(message, innerException)
4+
{
5+
public ProtectedMethodNotVirtualException() : this("", null)
6+
{ }
7+
8+
public ProtectedMethodNotVirtualException(string message) : this(message, null)
9+
{ }
10+
}
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
using System.Reflection;
2+
using NSubstitute.Core;
3+
using NSubstitute.Core.Arguments;
4+
using NSubstitute.Exceptions;
5+
6+
// Disable nullability for client API, so it does not affect clients.
7+
#nullable disable annotations
8+
9+
namespace NSubstitute.Extensions;
10+
11+
public static class ProtectedExtensions
12+
{
13+
/// <summary>
14+
/// Configure behavior for a protected method with return value
15+
/// </summary>
16+
/// <typeparam name="T"></typeparam>
17+
/// <param name="obj">The object.</param>
18+
/// <param name="methodName">Name of the method.</param>
19+
/// <param name="args">The method arguments.</param>
20+
/// <returns>Result object from the method invocation.</returns>
21+
/// <exception cref="NSubstitute.Exceptions.NullSubstituteReferenceException">Substitute - Cannot mock null object</exception>
22+
/// <exception cref="NSubstitute.Exceptions.ProtectedMethodNotFoundException">Error mocking method. Method must be protected virtual and with correct matching arguments and type</exception>
23+
/// <exception cref="System.ArgumentException">Must provide valid protected method name to mock - methodName</exception>
24+
public static object Protected<T>(this T obj, string methodName, params object[] args) where T : class
25+
{
26+
if (obj == null) { throw new NullSubstituteReferenceException(); }
27+
if (string.IsNullOrWhiteSpace(methodName)) { throw new ArgumentException("Must provide valid protected method name to mock", nameof(methodName)); }
28+
29+
IList<IArgumentSpecification> argTypes = SubstitutionContext.Current.ThreadContext.PeekAllArgumentSpecifications();
30+
MethodInfo mthdInfo = obj.GetType().GetMethod(methodName, BindingFlags.NonPublic | BindingFlags.Instance, Type.DefaultBinder, argTypes.Select(x => x.ForType).ToArray(), null);
31+
32+
if (mthdInfo == null)
33+
{
34+
_ = SubstitutionContext.Current.ThreadContext.DequeueAllArgumentSpecifications();
35+
throw new ProtectedMethodNotFoundException($"No protected virtual method found with signature {methodName}({string.Join(", ", argTypes.Select(x => x.ForType))}) in {obj.GetType().BaseType!.Name}. " +
36+
"Check that the method name and arguments are correct. Public virtual methods must use standard NSubstitute mocking. See the documentation for additional info.");
37+
}
38+
if (!mthdInfo.IsVirtual)
39+
{
40+
_ = SubstitutionContext.Current.ThreadContext.DequeueAllArgumentSpecifications();
41+
throw new ProtectedMethodNotVirtualException($"{mthdInfo} is not virtual. NSubstitute can only work with virtual members of the class that are overridable in the test assembly");
42+
}
43+
44+
return mthdInfo.Invoke(obj, args);
45+
}
46+
47+
/// <summary>
48+
/// Configure behavior for a protected method with no return vlaue
49+
/// </summary>
50+
/// <typeparam name="T"></typeparam>
51+
/// <param name="obj">The object.</param>
52+
/// <param name="methodName">Name of the method.</param>
53+
/// <param name="args">The method arguments.</param>
54+
/// <returns>WhenCalled&lt;T&gt;.</returns>
55+
/// <exception cref="NSubstitute.Exceptions.NullSubstituteReferenceException">Substitute - Cannot mock null object</exception>
56+
/// <exception cref="NSubstitute.Exceptions.ProtectedMethodNotFoundException">Error mocking method. Method must be protected virtual and with correct matching arguments and type</exception>
57+
/// <exception cref="System.ArgumentException">Must provide valid protected method name to mock - methodName</exception>
58+
public static WhenCalled<T> When<T>(this T obj, string methodName, params object[] args) where T : class
59+
{
60+
if (obj == null) { throw new NullSubstituteReferenceException(); }
61+
if (string.IsNullOrWhiteSpace(methodName)) { throw new ArgumentException("Must provide valid protected method name to mock", nameof(methodName)); }
62+
63+
IList<IArgumentSpecification> argTypes = SubstitutionContext.Current.ThreadContext.PeekAllArgumentSpecifications();
64+
MethodInfo mthdInfo = obj.GetType().GetMethod(methodName, BindingFlags.NonPublic | BindingFlags.Instance, Type.DefaultBinder, argTypes.Select(y => y.ForType).ToArray(), null);
65+
66+
if (mthdInfo == null)
67+
{
68+
_ = SubstitutionContext.Current.ThreadContext.DequeueAllArgumentSpecifications();
69+
throw new ProtectedMethodNotFoundException($"No protected virtual method found with signature {methodName}({string.Join(", ", argTypes.Select(x => x.ForType))}) in {obj.GetType().BaseType!.Name}. " +
70+
"Check that the method name and arguments are correct. Public virtual methods must use standard NSubstitute mocking. See the documentation for additional info.");
71+
}
72+
if (!mthdInfo.IsVirtual)
73+
{
74+
_ = SubstitutionContext.Current.ThreadContext.DequeueAllArgumentSpecifications();
75+
throw new ProtectedMethodNotVirtualException($"{mthdInfo} is not virtual. NSubstitute can only work with virtual members of the class that are overridable in the test assembly");
76+
}
77+
78+
return new WhenCalled<T>(SubstitutionContext.Current, obj, x => mthdInfo.Invoke(x, args), MatchArgs.AsSpecifiedInCall);
79+
}
80+
}
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
namespace NSubstitute.Acceptance.Specs.Infrastructure;
2+
3+
public abstract class AnotherClass
4+
{
5+
protected abstract string ProtectedMethod();
6+
7+
protected abstract string ProtectedMethod(int i);
8+
9+
protected abstract string ProtectedMethod(string msg, int i, char j);
10+
11+
protected abstract void ProtectedMethodWithNoReturn();
12+
13+
protected abstract void ProtectedMethodWithNoReturn(int i);
14+
15+
protected abstract void ProtectedMethodWithNoReturn(string msg, int i, char j);
16+
17+
public abstract void PublicVirtualMethod();
18+
19+
protected void ProtectedNonVirtualMethod()
20+
{ }
21+
22+
public string DoWork()
23+
{
24+
return ProtectedMethod();
25+
}
26+
27+
public string DoWork(int i)
28+
{
29+
return ProtectedMethod(i);
30+
}
31+
32+
public string DoWork(string msg, int i, char j)
33+
{
34+
return ProtectedMethod(msg, i, j);
35+
}
36+
37+
public void DoVoidWork()
38+
{
39+
ProtectedMethodWithNoReturn();
40+
}
41+
42+
public void DoVoidWork(int i)
43+
{
44+
ProtectedMethodWithNoReturn(i);
45+
}
46+
47+
public void DoVoidWork(string msg, int i, char j)
48+
{
49+
ProtectedMethodWithNoReturn(msg, i, j);
50+
}
51+
}

0 commit comments

Comments
 (0)