Skip to content

Commit e926abe

Browse files
Address concurrency bug (#23)
1 parent 311f959 commit e926abe

File tree

3 files changed

+83
-44
lines changed

3 files changed

+83
-44
lines changed

src/Immediate.Cache.Shared/ApplicationCacheBase.cs

Lines changed: 42 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -133,27 +133,42 @@ public async ValueTask<TResponse> GetValue(CancellationToken cancellationToken)
133133
[SuppressMessage("Maintainability", "CA1508:Avoid dead conditional code", Justification = "Double-checked lock pattern")]
134134
private Task<TResponse> GetHandlerTask()
135135
{
136-
if (_responseSource is not null)
136+
if (_responseSource is { Task.Status: not (TaskStatus.Faulted or TaskStatus.Canceled) })
137137
return _responseSource.Task;
138138

139139
lock (_lock)
140140
{
141-
if (_responseSource is not null)
141+
if (_responseSource is { Task.Status: not (TaskStatus.Faulted or TaskStatus.Canceled) })
142142
return _responseSource.Task;
143143

144-
var ts = _tokenSource = new();
145-
_responseSource = new();
144+
// escape current sync context
145+
_ = Task.Factory.StartNew(
146+
RunHandler,
147+
CancellationToken.None,
148+
TaskCreationOptions.PreferFairness,
149+
TaskScheduler.Current
150+
);
146151

147-
return Task.Run(() => RunHandler(ts));
152+
return (_responseSource = new()).Task;
148153
}
149154
}
150155

151-
private async Task<TResponse> RunHandler(CancellationTokenSource tokenSource)
156+
private async Task RunHandler()
152157
{
153158
while (true)
154159
{
155-
if (_responseSource?.Task is { IsCompletedSuccessfully: true } task)
156-
return await task.ConfigureAwait(false);
160+
CancellationTokenSource tokenSource;
161+
162+
lock (_lock)
163+
{
164+
if (_responseSource?.Task is { IsCompletedSuccessfully: true })
165+
return;
166+
167+
if (_tokenSource is null or { IsCancellationRequested: true })
168+
_tokenSource = new();
169+
170+
tokenSource = _tokenSource;
171+
}
157172

158173
try
159174
{
@@ -171,26 +186,25 @@ private async Task<TResponse> RunHandler(CancellationTokenSource tokenSource)
171186

172187
lock (_lock)
173188
{
174-
if (!token.IsCancellationRequested)
175-
{
176-
var rs = _responseSource ??= new();
177-
rs.SetResult(response);
178-
179-
return response;
180-
}
189+
if (!tokenSource.IsCancellationRequested)
190+
_responseSource!.SetResult(response);
181191
}
182192
}
183193
}
184194
catch (OperationCanceledException) when (tokenSource.IsCancellationRequested)
185195
{
186196
}
187-
188-
lock (_lock)
197+
#pragma warning disable CA1031 // Do not catch general exception types
198+
// no one is listening to `RunHandler`; return the exception via `SetException`
199+
catch (Exception ex)
200+
#pragma warning restore CA1031
189201
{
190-
if (_tokenSource is null or { IsCancellationRequested: true })
191-
_tokenSource = new();
192-
193-
tokenSource = _tokenSource;
202+
lock (_lock)
203+
{
204+
if (!tokenSource.IsCancellationRequested)
205+
_responseSource?.SetException(ex);
206+
return;
207+
}
194208
}
195209
}
196210
}
@@ -199,8 +213,11 @@ public void SetValue(TResponse response)
199213
{
200214
lock (_lock)
201215
{
202-
_responseSource = new TaskCompletionSource<TResponse>();
216+
if (_responseSource is null or { Task.IsCompleted: true })
217+
_responseSource = new TaskCompletionSource<TResponse>();
218+
203219
_responseSource.SetResult(response);
220+
204221
_tokenSource?.Cancel();
205222
_tokenSource = null;
206223
}
@@ -210,7 +227,9 @@ public void RemoveValue()
210227
{
211228
lock (_lock)
212229
{
213-
_responseSource = null;
230+
if (_responseSource is { Task.IsCompleted: true })
231+
_responseSource = null;
232+
214233
_tokenSource?.Cancel();
215234
_tokenSource = null;
216235
}

tests/Immediate.Cache.FunctionalTests/ApplicationCacheTests.cs

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -80,14 +80,12 @@ public async Task SimultaneousAccessIsSerialized()
8080
{
8181
Value = 1,
8282
Name = "Request1",
83-
CompletionSource = new(),
8483
};
8584

8685
var request2 = new DelayGetValue.Query()
8786
{
8887
Value = 1,
8988
Name = "Request2",
90-
CompletionSource = new(),
9189
};
9290

9391
var cache = _serviceProvider.GetRequiredService<DelayGetValueCache>();
@@ -102,7 +100,7 @@ public async Task SimultaneousAccessIsSerialized()
102100
Assert.Equal(0, request2.TimesExecuted);
103101

104102
// request2 does nothing at this point
105-
request2.CompletionSource.SetResult();
103+
request2.WaitForTestToContinueOperation.SetResult();
106104

107105
Assert.False(response1Task.IsCompleted);
108106
Assert.False(response2Task.IsCompleted);
@@ -111,7 +109,7 @@ public async Task SimultaneousAccessIsSerialized()
111109
Assert.Equal(0, request2.TimesExecuted);
112110

113111
// trigger request1, which should run exactly once
114-
request1.CompletionSource.SetResult();
112+
request1.WaitForTestToContinueOperation.SetResult();
115113

116114
var response1 = await response1Task;
117115
var response2 = await response2Task;
@@ -130,7 +128,6 @@ public async Task ProperlyUsesCancellationToken()
130128
{
131129
Value = 1,
132130
Name = "Request1",
133-
CompletionSource = new(),
134131
};
135132

136133
using var tcs = new CancellationTokenSource();
@@ -144,14 +141,13 @@ public async Task ProperlyUsesCancellationToken()
144141
Assert.False(request.CancellationToken.IsCancellationRequested);
145142

146143
// actual handler will continue executing in spite of no remaining callers
147-
request.CompletionSource.SetResult();
144+
request.WaitForTestToContinueOperation.SetResult();
148145

149146
// check that value is now properly in cache
150147
var request2 = new DelayGetValue.Query()
151148
{
152149
Value = 1,
153150
Name = "Request2",
154-
CompletionSource = new(),
155151
};
156152

157153
var response = await cache.GetValue(request2);
@@ -168,15 +164,13 @@ public async Task CancellingFirstAccessOperatesCorrectly()
168164
{
169165
Value = 1,
170166
Name = "Request1",
171-
CompletionSource = new(),
172167
};
173168

174169
using var cts2 = new CancellationTokenSource();
175170
var request2 = new DelayGetValue.Query()
176171
{
177172
Value = 1,
178173
Name = "Request2",
179-
CompletionSource = new(),
180174
};
181175

182176
var cache = _serviceProvider.GetRequiredService<DelayGetValueCache>();
@@ -210,15 +204,13 @@ public async Task CancellingSecondAccessOperatesCorrectly()
210204
{
211205
Value = 1,
212206
Name = "Request1",
213-
CompletionSource = new(),
214207
};
215208

216209
using var cts2 = new CancellationTokenSource();
217210
var request2 = new DelayGetValue.Query()
218211
{
219212
Value = 1,
220213
Name = "Request2",
221-
CompletionSource = new(),
222214
};
223215

224216
var cache = _serviceProvider.GetRequiredService<DelayGetValueCache>();
@@ -251,19 +243,18 @@ public async Task RemovingValueCancelsExistingOperation()
251243
{
252244
Value = 1,
253245
Name = "Request1",
254-
CompletionSource = new(),
255246
};
256247

257248
using var tcs = new CancellationTokenSource();
258249
var cache = _serviceProvider.GetRequiredService<DelayGetValueCache>();
259250
var responseTask = cache.GetValue(request, tcs.Token);
260251

261-
cache.RemoveValue(request);
252+
await request.WaitForTestToStartExecuting.Task;
262253

263-
// allow IC task to be run
264-
await Task.Delay(10);
254+
cache.RemoveValue(request);
255+
await request.WaitForTestToFinalize.Task;
265256

266-
request.CompletionSource.SetResult();
257+
request.WaitForTestToContinueOperation.SetResult();
267258

268259
var response = await responseTask;
269260

@@ -281,22 +272,39 @@ public async Task SettingValueCancelsExistingOperation()
281272
{
282273
Value = 1,
283274
Name = "Request1",
284-
CompletionSource = new(),
285275
};
286276

287277
var cache = _serviceProvider.GetRequiredService<DelayGetValueCache>();
288278
var responseTask = cache.GetValue(request, default);
289279

290-
// allow IC task to be run
291-
await Task.Delay(10);
280+
await request.WaitForTestToStartExecuting.Task;
292281

293282
cache.SetValue(request, new(5, ExecutedHandler: false, Guid.NewGuid()));
294283

295284
var response = await responseTask;
296285
Assert.Equal(5, response.Value);
297286
Assert.False(response.ExecutedHandler);
298287

288+
await request.WaitForTestToFinalize.Task;
289+
299290
Assert.Equal(0, request.TimesExecuted);
300291
Assert.Equal(1, request.TimesCancelled);
301292
}
293+
294+
[Test]
295+
public async Task ExceptionGetsPropagatedCorrectly()
296+
{
297+
var request = new DelayGetValue.Query()
298+
{
299+
Value = 1,
300+
Name = "Request1",
301+
ThrowException = true,
302+
};
303+
304+
var cache = _serviceProvider.GetRequiredService<DelayGetValueCache>();
305+
var responseTask = cache.GetValue(request, default);
306+
307+
var ex = await Assert.ThrowsAsync<InvalidOperationException>(async () => await responseTask);
308+
Assert.Equal("Test Exception 1", ex.Message);
309+
}
302310
}

tests/Immediate.Cache.FunctionalTests/DelayGetValue.cs

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,14 @@ public sealed class Query
99
{
1010
public required int Value { get; init; }
1111
public required string Name { get; init; }
12-
public required TaskCompletionSource CompletionSource { get; init; }
1312
public int TimesExecuted { get; set; }
1413
public int TimesCancelled { get; set; }
1514
public CancellationToken CancellationToken { get; set; }
15+
public bool ThrowException { get; init; }
16+
17+
public TaskCompletionSource WaitForTestToContinueOperation { get; } = new();
18+
public TaskCompletionSource WaitForTestToStartExecuting { get; } = new();
19+
public TaskCompletionSource WaitForTestToFinalize { get; } = new();
1620
}
1721

1822
public sealed record Response(int Value, bool ExecutedHandler, Guid RandomValue);
@@ -24,10 +28,14 @@ private static async ValueTask<Response> HandleAsync(
2428
CancellationToken token
2529
)
2630
{
31+
if (query.ThrowException)
32+
throw new InvalidOperationException($"Test Exception {query.Value}");
33+
2734
try
2835
{
2936
query.CancellationToken = token;
30-
await query.CompletionSource.Task.WaitAsync(token);
37+
_ = query.WaitForTestToStartExecuting.TrySetResult();
38+
await query.WaitForTestToContinueOperation.Task.WaitAsync(token);
3139

3240
lock (s_lock)
3341
query.TimesExecuted++;
@@ -40,5 +48,9 @@ CancellationToken token
4048
query.TimesCancelled++;
4149
throw;
4250
}
51+
finally
52+
{
53+
_ = query.WaitForTestToFinalize.TrySetResult();
54+
}
4355
}
4456
}

0 commit comments

Comments
 (0)