@@ -9,6 +9,7 @@ namespace Bunit.Extensions.WaitForHelpers;
99/// </summary>
1010public abstract class WaitForHelper < T > : IDisposable
1111{
12+ private readonly Timer timer ;
1213 private readonly TaskCompletionSource < T > checkPassedCompletionSource ;
1314 private readonly Func < ( bool CheckPassed , T Content ) > completeChecker ;
1415 private readonly IRenderedFragmentBase renderedFragment ;
@@ -51,7 +52,13 @@ protected WaitForHelper(
5152
5253 logger = renderedFragment . Services . CreateLogger < WaitForHelper < T > > ( ) ;
5354 checkPassedCompletionSource = new TaskCompletionSource < T > ( ) ;
54- WaitTask = CreateWaitTask ( renderedFragment , timeout ) ;
55+ timer = new Timer ( _ =>
56+ {
57+ logger . LogWaiterTimedOut ( renderedFragment . ComponentId ) ;
58+ checkPassedCompletionSource . TrySetException ( new WaitForFailedException ( TimeoutErrorMessage , capturedException ) ) ;
59+ } ) ;
60+ WaitTask = CreateWaitTask ( renderedFragment ) ;
61+ timer . Change ( GetRuntimeTimeout ( timeout ) , Timeout . InfiniteTimeSpan ) ;
5562
5663 InitializeWaiting ( ) ;
5764 }
@@ -80,6 +87,7 @@ protected virtual void Dispose(bool disposing)
8087 return ;
8188
8289 isDisposed = true ;
90+ timer . Dispose ( ) ;
8391 checkPassedCompletionSource . TrySetCanceled ( ) ;
8492 renderedFragment . OnAfterRender -= OnAfterRender ;
8593 logger . LogWaiterDisposed ( renderedFragment . ComponentId ) ;
@@ -105,41 +113,30 @@ private void InitializeWaiting()
105113 }
106114 }
107115
108- private Task < T > CreateWaitTask ( IRenderedFragmentBase renderedFragment , TimeSpan ? timeout )
116+ private Task < T > CreateWaitTask ( IRenderedFragmentBase renderedFragment )
109117 {
110- var renderer = renderedFragment . Services . GetRequiredService < ITestRenderer > ( ) ;
118+ var renderer = renderedFragment
119+ . Services
120+ . GetRequiredService < ITestRenderer > ( ) ;
111121
112122 // Two to failure conditions, that the renderer captures an unhandled
113123 // exception from a component or itself, or that the timeout is reached,
114124 // are executed on the renderes scheduler, to ensure that OnAfterRender
115125 // and the continuations does not happen at the same time.
116126 var failureTask = renderer . Dispatcher . InvokeAsync ( ( ) =>
117127 {
118- var taskScheduler = TaskScheduler . FromCurrentSynchronizationContext ( ) ;
119-
120- var renderException = renderer
128+ return renderer
121129 . UnhandledException
122130 . ContinueWith (
123131 x => Task . FromException < T > ( x . Result ) ,
124132 CancellationToken . None ,
125133 TaskContinuationOptions . OnlyOnRanToCompletion | TaskContinuationOptions . ExecuteSynchronously ,
126- taskScheduler ) ;
127-
128- var timeoutTask = Task . Delay ( GetRuntimeTimeout ( timeout ) )
129- . ContinueWith (
130- x =>
131- {
132- logger . LogWaiterTimedOut ( renderedFragment . ComponentId ) ;
133- return Task . FromException < T > ( new WaitForFailedException ( TimeoutErrorMessage , capturedException ) ) ;
134- } ,
135- CancellationToken . None ,
136- TaskContinuationOptions . OnlyOnRanToCompletion | TaskContinuationOptions . ExecuteSynchronously ,
137- taskScheduler ) ;
138-
139- return Task . WhenAny ( renderException , timeoutTask ) . Unwrap ( ) ;
134+ TaskScheduler . FromCurrentSynchronizationContext ( ) ) ;
140135 } ) . Unwrap ( ) ;
141136
142- return Task . WhenAny ( failureTask , checkPassedCompletionSource . Task ) . Unwrap ( ) ;
137+ return Task
138+ . WhenAny ( checkPassedCompletionSource . Task , failureTask )
139+ . Unwrap ( ) ;
143140 }
144141
145142 private void OnAfterRender ( object ? sender , EventArgs args )
@@ -170,7 +167,8 @@ private void OnAfterRender(object? sender, EventArgs args)
170167
171168 if ( StopWaitingOnCheckException )
172169 {
173- checkPassedCompletionSource . TrySetException ( new WaitForFailedException ( CheckThrowErrorMessage , capturedException ) ) ;
170+ checkPassedCompletionSource . TrySetException (
171+ new WaitForFailedException ( CheckThrowErrorMessage , capturedException ) ) ;
174172 Dispose ( ) ;
175173 }
176174 }
0 commit comments