Skip to content

Commit a6c2dbb

Browse files
author
Meyn
committed
Fix CTS cancellation issue
Enhance RequestContainer concurrency Expand unit test coverage
1 parent fb43fbf commit a6c2dbb

File tree

5 files changed

+505
-69
lines changed

5 files changed

+505
-69
lines changed

Requests/Request.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ private void RegisterNewCTS()
136136
_cts?.Dispose();
137137
_ctr.Unregister();
138138
_cts = CreateNewCTS();
139-
_ctr = Token.Register(() => SynchronizationContext.Post((o) => Options.RequestCancelled?.Invoke((IRequest)o!), this));
139+
_ctr = Token.Register(() => { Cancel(); SynchronizationContext.Post((o) => Options.RequestCancelled?.Invoke((IRequest)o!), this); });
140140
}
141141

142142
/// <summary>
@@ -221,7 +221,7 @@ async Task IRequest.StartRequestAsync()
221221
SetResult(returnItem);
222222
}
223223

224-
private async Task<Request<TOptions, TCompleated, TFailed>.RequestReturn> TryRunRequestAsync()
224+
private async Task<RequestReturn> TryRunRequestAsync()
225225
{
226226
RequestReturn returnItem = new();
227227
try
@@ -246,7 +246,7 @@ private void SetResult(RequestReturn returnItem)
246246
SetTaskState();
247247
}
248248

249-
private void EvalueateRequest(Request<TOptions, TCompleated, TFailed>.RequestReturn returnItem)
249+
private void EvalueateRequest(RequestReturn returnItem)
250250
{
251251
if (State != RequestState.Running)
252252
return;

Requests/RequestContainer.cs

Lines changed: 67 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,15 @@ namespace Requests
1010
public class RequestContainer<TRequest> : IEnumerable<TRequest>, IRequest where TRequest : IRequest
1111
{
1212
private volatile TRequest[] _requests = Array.Empty<TRequest>();
13+
private int _count;
1314
private bool _isrunning = true;
1415
private bool _isCanceled = false;
1516
private bool _disposed = false;
1617
private TaskCompletionSource? _task;
1718
private CancellationTokenSource _taskCancelationTokenSource = new();
1819
private RequestState _state = RequestState.Paused;
1920
private RequestPriority _priority = RequestPriority.Normal;
21+
private int _writeInProgress; // 0 means no write, 1 means write in progress
2022

2123
/// <summary>
2224
/// Represents the combined task of the requests.
@@ -51,7 +53,7 @@ protected set
5153
/// <summary>
5254
/// Gets the count of <see cref="IRequest"/> instances contained in the <see cref="RequestContainer{TRequest}"/>.
5355
/// </summary>
54-
public int Length => _requests.Length;
56+
public int Length => _count; // Updated to use the new _count field
5557

5658
/// <summary>
5759
/// The synchronization context captured when this object was created. This will never be null.
@@ -61,7 +63,7 @@ protected set
6163
/// <summary>
6264
/// All exceptions that were thrown by the requests.
6365
/// </summary>
64-
public AggregateException? Exception => new(_requests.Where(x => x.Exception != null).Select(x => x.Exception!));
66+
public AggregateException? Exception => new(GetStored().Where(x => x?.Exception != null).Select(x => x!.Exception!));
6567

6668
/// <summary>
6769
/// Constructor that merges <see cref="IRequest"/> instances together.
@@ -127,7 +129,16 @@ public virtual void Add(TRequest request)
127129
request.Pause();
128130

129131
request.StateChanged += OnStateChanged;
130-
_requests = CreateArrayWithNewItems(request);
132+
while (Interlocked.CompareExchange(ref _writeInProgress, 1, 0) == 1)
133+
Thread.Yield();
134+
if (_requests.Length == _count)
135+
Grow();
136+
137+
_requests[_count] = request;
138+
_count++;
139+
140+
Interlocked.Exchange(ref _writeInProgress, 0); // Release the write lock
141+
131142
NewTaskCompletion();
132143
OnStateChanged(this, request.State);
133144
}
@@ -141,7 +152,8 @@ private void NewTaskCompletion()
141152
_taskCancelationTokenSource = new();
142153
if (Task.IsCompleted)
143154
_task = new(TaskCreationOptions.RunContinuationsAsynchronously);
144-
Task.WhenAll(_requests.Select(request => request.Task)).ContinueWith(task => _task?.TrySetResult(), _taskCancelationTokenSource.Token);
155+
156+
Task.WhenAll(GetStored().Select(request => request.Task)).ContinueWith(task => _task?.TrySetResult(), _taskCancelationTokenSource.Token);
145157
}
146158

147159
/// <summary>
@@ -163,22 +175,33 @@ public virtual void AddRange(params TRequest[] requests)
163175
else if (!_isrunning)
164176
Array.ForEach(requests, request => request.Pause());
165177
Array.ForEach(requests, request => request.StateChanged += OnStateChanged);
166-
_requests = CreateArrayWithNewItems(requests);
178+
179+
while (Interlocked.CompareExchange(ref _writeInProgress, 1, 0) == 1)
180+
Thread.Yield();
181+
while (_requests.Length < _count + requests.Length)
182+
Grow();
183+
184+
Array.Copy(requests, 0, _requests, _count, requests.Length);
185+
_count += requests.Length;
186+
187+
Interlocked.Exchange(ref _writeInProgress, 0); // Release the write lock
167188
NewTaskCompletion();
168189
State = CalculateState();
169190
}
170191

171192
/// <summary>
172-
/// Creates a new array that includes the existing requests and the new items.
193+
/// Increases the capacity of the <see cref="RequestContainer{TRequest}"/> to accommodate additional elements.
173194
/// </summary>
174-
/// <param name="items">The new items to be added to the array.</param>
175-
/// <returns>A new array containing the existing and new items.</returns>
176-
private TRequest[] CreateArrayWithNewItems(params TRequest[] items)
195+
private void Grow()
177196
{
178-
TRequest[] result = new TRequest[_requests.Length + items.Length];
179-
_requests.CopyTo(result, 0);
180-
items.CopyTo(result, _requests.Length);
181-
return result;
197+
const int MinimumGrow = 4;
198+
int capacity = (int)(_requests.Length * 2L);
199+
if (capacity < _requests.Length + MinimumGrow)
200+
capacity = _requests.Length + MinimumGrow;
201+
202+
TRequest[] newArray = new TRequest[capacity];
203+
_requests.CopyTo(newArray, 0);
204+
_requests = newArray;
182205
}
183206

184207
/// <summary>
@@ -202,7 +225,7 @@ private void OnStateChanged(object? sender, RequestState state)
202225
private RequestState CalculateState()
203226
{
204227
RequestState state;
205-
IEnumerable<int> states = _requests.Select(req => (int)req.State);
228+
IEnumerable<int> states = GetStored().Select(req => (int)req.State);
206229
int[] counter = new int[7];
207230
foreach (int value in states)
208231
counter[value]++;
@@ -217,7 +240,7 @@ private RequestState CalculateState()
217240
state = RequestState.Idle;
218241
else if (counter[4] > 0)
219242
state = RequestState.Waiting;
220-
else if (counter[2] == Length)
243+
else if (counter[2] == _count)
221244
state = RequestState.Compleated;
222245
else if (counter[3] > 0)
223246
state = RequestState.Paused;
@@ -232,12 +255,11 @@ async Task IRequest.StartRequestAsync()
232255
if (_isrunning)
233256
return;
234257
_isrunning = true;
235-
foreach (TRequest request in _requests)
236-
_ = TrySetIdle();
237-
foreach (TRequest request in _requests.Where(x => x.State == RequestState.Idle))
258+
foreach (TRequest request in GetStored())
259+
_ = request.TrySetIdle();
260+
foreach (TRequest request in GetStored().Where(x => x.State == RequestState.Idle))
238261
await request.StartRequestAsync();
239262

240-
241263
_isrunning = false;
242264
}
243265

@@ -248,8 +270,15 @@ async Task IRequest.StartRequestAsync()
248270
public virtual void Remove(params TRequest[] requests)
249271
{
250272
Array.ForEach(requests, request => request.StateChanged -= StateChanged);
251-
_requests = _requests.Where(x => !requests.Any(y => y.Equals(x))).ToArray();
252-
if (_requests.Length > 0 && !Task.IsCompleted)
273+
274+
while (Interlocked.CompareExchange(ref _writeInProgress, 1, 0) == 1)
275+
Thread.Yield();
276+
_requests = GetStored().Where(x => !requests.Any(y => y.Equals(x))).ToArray();
277+
_count = _requests.Length;
278+
279+
Interlocked.Exchange(ref _writeInProgress, 0); // Release the write lock
280+
281+
if (_count > 0 && !Task.IsCompleted)
253282
NewTaskCompletion();
254283
else
255284
_task = null;
@@ -263,7 +292,9 @@ public virtual void Remove(params TRequest[] requests)
263292
public void Cancel()
264293
{
265294
_isCanceled = true;
266-
Array.ForEach(_requests, request => request.Cancel());
295+
foreach (var request in GetStored())
296+
request.Cancel();
297+
267298
}
268299

269300
/// <summary>
@@ -274,7 +305,7 @@ public void Start()
274305
if (_isrunning)
275306
return;
276307
_isrunning = true;
277-
foreach (TRequest? request in _requests)
308+
foreach (TRequest request in GetStored())
278309
request.Start();
279310
}
280311

@@ -284,7 +315,7 @@ public void Start()
284315
public void Pause()
285316
{
286317
_isrunning = false;
287-
foreach (TRequest? request in _requests)
318+
foreach (TRequest request in GetStored())
288319
request.Pause();
289320
}
290321

@@ -297,7 +328,7 @@ public void Dispose()
297328
return;
298329
_disposed = true;
299330
StateChanged = null;
300-
foreach (TRequest? request in _requests)
331+
foreach (TRequest request in GetStored())
301332
request.Dispose();
302333
GC.SuppressFinalize(this);
303334
}
@@ -306,10 +337,11 @@ public void Dispose()
306337
/// Provides an enumerator that iterates through the <see cref="RequestContainer{TRequest}"/>
307338
/// </summary>
308339
/// <returns> A <see cref="RequestContainer{TRequest}"/> .Enumerator for the <see cref="RequestContainer{TRequest}"/> .</returns>
309-
public IEnumerator<TRequest> GetEnumerator() => ((IEnumerable<TRequest>)_requests).GetEnumerator();
340+
public IEnumerator<TRequest> GetEnumerator() => GetStored().GetEnumerator();
341+
310342

311343
/// <inheritdoc/>
312-
IEnumerator IEnumerable.GetEnumerator() => _requests.GetEnumerator();
344+
IEnumerator IEnumerable.GetEnumerator() => GetStored().GetEnumerator();
313345

314346
/// <summary>
315347
/// Attempts to set all <see cref="IRequest"/> objects in the container's <see cref="State"/> to idle.
@@ -318,9 +350,15 @@ public void Dispose()
318350
/// <returns>True if all <see cref="IRequest"/> objects are in an idle <see cref="RequestState"/>, otherwise false.</returns>
319351
public bool TrySetIdle()
320352
{
321-
foreach (TRequest request in _requests)
353+
foreach (TRequest request in GetStored())
322354
_ = request.TrySetIdle();
323355
return State == RequestState.Idle;
324356
}
357+
358+
/// <summary>
359+
/// Gets the stored requests that are non-null.
360+
/// </summary>
361+
/// <returns>An enumerable of non-null requests.</returns>
362+
private IEnumerable<TRequest> GetStored() => _requests[.._count].Where(req => req != null)!;
325363
}
326-
}
364+
}

Requests/Requests.csproj

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@
1616
<PackageTags>async; channel; priority; request; parallel; </PackageTags>
1717
<RepositoryUrl>https://github.com/TypNull/Requests</RepositoryUrl>
1818
<PackageIcon>logo.png</PackageIcon>
19-
<Version>2.1.4</Version>
19+
<Version>2.1.5</Version>
2020
<PackageRequireLicenseAcceptance>True</PackageRequireLicenseAcceptance>
2121
<PackageLicenseFile>LICENSE.txt</PackageLicenseFile>
2222
<PackageId>Shard.Requests</PackageId>
23-
<PackageReleaseNotes>Request Dispose Task not finished Fix</PackageReleaseNotes>
23+
<PackageReleaseNotes>Fix CTS cancellation issue
24+
Enhance RequestContainer concurrency
25+
Expand unit test coverage</PackageReleaseNotes>
2426
</PropertyGroup>
2527

2628
<ItemGroup>

0 commit comments

Comments
 (0)