Skip to content

Commit e33629a

Browse files
authored
CSHARP-3458: Extend IAsyncCursor and IAsyncCursorSource to support IAsyncEnumerable (#1708)
1 parent ca6234d commit e33629a

14 files changed

+593
-38
lines changed

src/MongoDB.Driver/Core/IAsyncCursor.cs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,17 @@ public static class IAsyncCursorExtensions
360360
return new AsyncCursorEnumerableOneTimeAdapter<TDocument>(cursor, cancellationToken);
361361
}
362362

363+
/// <summary>
364+
/// Wraps a cursor in an IAsyncEnumerable that can be enumerated one time.
365+
/// </summary>
366+
/// <typeparam name="TDocument">The type of the document.</typeparam>
367+
/// <param name="cursor">The cursor.</param>
368+
/// <returns>An IAsyncEnumerable.</returns>
369+
public static IAsyncEnumerable<TDocument> ToAsyncEnumerable<TDocument>(this IAsyncCursor<TDocument> cursor)
370+
{
371+
return new AsyncCursorEnumerableOneTimeAdapter<TDocument>(cursor, CancellationToken.None);
372+
}
373+
363374
/// <summary>
364375
/// Returns a list containing all the documents returned by a cursor.
365376
/// </summary>

src/MongoDB.Driver/Core/IAsyncCursorSource.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,18 @@ public static class IAsyncCursorSourceExtensions
336336
return new AsyncCursorSourceEnumerableAdapter<TDocument>(source, cancellationToken);
337337
}
338338

339+
/// <summary>
340+
/// Wraps a cursor source in an IAsyncEnumerable. Each time GetAsyncEnumerator is called a new enumerator is returned and a new cursor
341+
/// is fetched from the cursor source on the first call to MoveNextAsync.
342+
/// </summary>
343+
/// <typeparam name="TDocument">The type of the document.</typeparam>
344+
/// <param name="source">The source.</param>
345+
/// <returns>An IAsyncEnumerable.</returns>
346+
public static IAsyncEnumerable<TDocument> ToAsyncEnumerable<TDocument>(this IAsyncCursorSource<TDocument> source)
347+
{
348+
return new AsyncCursorSourceEnumerableAdapter<TDocument>(source, CancellationToken.None);
349+
}
350+
339351
/// <summary>
340352
/// Returns a list containing all the documents returned by the cursor returned by a cursor source.
341353
/// </summary>

src/MongoDB.Driver/Core/Operations/AsyncCursorEnumerableOneTimeAdapter.cs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
namespace MongoDB.Driver.Core.Operations
2323
{
24-
internal sealed class AsyncCursorEnumerableOneTimeAdapter<TDocument> : IEnumerable<TDocument>
24+
internal sealed class AsyncCursorEnumerableOneTimeAdapter<TDocument> : IEnumerable<TDocument>, IAsyncEnumerable<TDocument>
2525
{
2626
private readonly CancellationToken _cancellationToken;
2727
private readonly IAsyncCursor<TDocument> _cursor;
@@ -33,6 +33,16 @@ public AsyncCursorEnumerableOneTimeAdapter(IAsyncCursor<TDocument> cursor, Cance
3333
_cancellationToken = cancellationToken;
3434
}
3535

36+
public IAsyncEnumerator<TDocument> GetAsyncEnumerator(CancellationToken cancellationToken = default)
37+
{
38+
if (_hasBeenEnumerated)
39+
{
40+
throw new InvalidOperationException("An IAsyncCursor can only be enumerated once.");
41+
}
42+
_hasBeenEnumerated = true;
43+
return new AsyncCursorEnumerator<TDocument>(_cursor, cancellationToken);
44+
}
45+
3646
public IEnumerator<TDocument> GetEnumerator()
3747
{
3848
if (_hasBeenEnumerated)

src/MongoDB.Driver/Core/Operations/AsyncCursorEnumerator.cs

Lines changed: 45 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,12 @@
1717
using System.Collections;
1818
using System.Collections.Generic;
1919
using System.Threading;
20+
using System.Threading.Tasks;
2021
using MongoDB.Driver.Core.Misc;
2122

2223
namespace MongoDB.Driver.Core.Operations
2324
{
24-
internal class AsyncCursorEnumerator<TDocument> : IEnumerator<TDocument>
25+
internal sealed class AsyncCursorEnumerator<TDocument> : IEnumerator<TDocument>, IAsyncEnumerator<TDocument>
2526
{
2627
// private fields
2728
private IEnumerator<TDocument> _batchEnumerator;
@@ -72,6 +73,15 @@ public void Dispose()
7273
}
7374
}
7475

76+
public ValueTask DisposeAsync()
77+
{
78+
// TODO: implement true async disposal (CSHARP-5630)
79+
Dispose();
80+
81+
// TODO: convert to ValueTask.CompletedTask once we stop supporting older target frameworks
82+
return default; // Equivalent to ValueTask.CompletedTask which is not available on older target frameworks.
83+
}
84+
7585
public bool MoveNext()
7686
{
7787
ThrowIfDisposed();
@@ -82,24 +92,46 @@ public bool MoveNext()
8292
return true;
8393
}
8494

85-
while (true)
95+
while (_cursor.MoveNext(_cancellationToken))
8696
{
87-
if (_cursor.MoveNext(_cancellationToken))
97+
_batchEnumerator?.Dispose();
98+
_batchEnumerator = _cursor.Current.GetEnumerator();
99+
if (_batchEnumerator.MoveNext())
88100
{
89-
_batchEnumerator?.Dispose();
90-
_batchEnumerator = _cursor.Current.GetEnumerator();
91-
if (_batchEnumerator.MoveNext())
92-
{
93-
return true;
94-
}
101+
return true;
95102
}
96-
else
103+
}
104+
105+
_batchEnumerator?.Dispose();
106+
_batchEnumerator = null;
107+
_finished = true;
108+
return false;
109+
}
110+
111+
public async ValueTask<bool> MoveNextAsync()
112+
{
113+
ThrowIfDisposed();
114+
_started = true;
115+
116+
if (_batchEnumerator != null && _batchEnumerator.MoveNext())
117+
{
118+
return true;
119+
}
120+
121+
while (await _cursor.MoveNextAsync(_cancellationToken).ConfigureAwait(false))
122+
{
123+
_batchEnumerator?.Dispose();
124+
_batchEnumerator = _cursor.Current.GetEnumerator();
125+
if (_batchEnumerator.MoveNext())
97126
{
98-
_batchEnumerator = null;
99-
_finished = true;
100-
return false;
127+
return true;
101128
}
102129
}
130+
131+
_batchEnumerator?.Dispose();
132+
_batchEnumerator = null;
133+
_finished = true;
134+
return false;
103135
}
104136

105137
public void Reset()

src/MongoDB.Driver/Core/Operations/AsyncCursorSourceEnumerableAdapter.cs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,14 @@
1313
* limitations under the License.
1414
*/
1515

16-
using System;
1716
using System.Collections;
1817
using System.Collections.Generic;
1918
using System.Threading;
2019
using MongoDB.Driver.Core.Misc;
2120

2221
namespace MongoDB.Driver.Core.Operations
2322
{
24-
internal class AsyncCursorSourceEnumerableAdapter<TDocument> : IEnumerable<TDocument>
23+
internal sealed class AsyncCursorSourceEnumerableAdapter<TDocument> : IEnumerable<TDocument>, IAsyncEnumerable<TDocument>
2524
{
2625
// private fields
2726
private readonly CancellationToken _cancellationToken;
@@ -34,6 +33,11 @@ public AsyncCursorSourceEnumerableAdapter(IAsyncCursorSource<TDocument> source,
3433
_cancellationToken = cancellationToken;
3534
}
3635

36+
public IAsyncEnumerator<TDocument> GetAsyncEnumerator(CancellationToken cancellationToken = default)
37+
{
38+
return new AsyncCursorSourceEnumerator<TDocument>(_source, cancellationToken);
39+
}
40+
3741
// public methods
3842
public IEnumerator<TDocument> GetEnumerator()
3943
{
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
/* Copyright 2010-present MongoDB Inc.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
using System;
17+
using System.Collections.Generic;
18+
using System.Threading;
19+
using System.Threading.Tasks;
20+
using MongoDB.Driver.Core.Misc;
21+
22+
namespace MongoDB.Driver.Core.Operations
23+
{
24+
#pragma warning disable CA1001
25+
// we are suppressing this warning as we currently use the old Microsoft.CodeAnalysis.FxCopAnalyzers which doesn't
26+
// have a concept of IAsyncDisposable.
27+
// TODO: remove this suppression once we update our analyzers to use Microsoft.CodeAnalysis.NetAnalyzers
28+
internal sealed class AsyncCursorSourceEnumerator<TDocument> : IAsyncEnumerator<TDocument>
29+
#pragma warning restore CA1001
30+
{
31+
private readonly CancellationToken _cancellationToken;
32+
private AsyncCursorEnumerator<TDocument> _cursorEnumerator;
33+
private readonly IAsyncCursorSource<TDocument> _cursorSource;
34+
private bool _disposed;
35+
36+
public AsyncCursorSourceEnumerator(IAsyncCursorSource<TDocument> cursorSource, CancellationToken cancellationToken)
37+
{
38+
_cursorSource = Ensure.IsNotNull(cursorSource, nameof(cursorSource));
39+
_cancellationToken = cancellationToken;
40+
}
41+
42+
public TDocument Current
43+
{
44+
get
45+
{
46+
if (_cursorEnumerator == null)
47+
{
48+
throw new InvalidOperationException("Enumeration has not started. Call MoveNextAsync.");
49+
}
50+
return _cursorEnumerator.Current;
51+
}
52+
}
53+
54+
public async ValueTask DisposeAsync()
55+
{
56+
if (!_disposed)
57+
{
58+
_disposed = true;
59+
60+
if (_cursorEnumerator != null)
61+
{
62+
await _cursorEnumerator.DisposeAsync().ConfigureAwait(false);
63+
}
64+
}
65+
}
66+
67+
public async ValueTask<bool> MoveNextAsync()
68+
{
69+
ThrowIfDisposed();
70+
71+
if (_cursorEnumerator == null)
72+
{
73+
var cursor = await _cursorSource.ToCursorAsync(_cancellationToken).ConfigureAwait(false);
74+
_cursorEnumerator = new AsyncCursorEnumerator<TDocument>(cursor, _cancellationToken);
75+
}
76+
77+
return await _cursorEnumerator.MoveNextAsync().ConfigureAwait(false);
78+
}
79+
80+
public void Reset()
81+
{
82+
ThrowIfDisposed();
83+
throw new NotSupportedException();
84+
}
85+
86+
// private methods
87+
private void ThrowIfDisposed()
88+
{
89+
if (_disposed)
90+
{
91+
throw new ObjectDisposedException(GetType().Name);
92+
}
93+
}
94+
}
95+
}

src/MongoDB.Driver/Linq/MongoQueryable.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3385,6 +3385,18 @@ public static IQueryable<TSource> Take<TSource>(this IQueryable<TSource> source,
33853385
Expression.Constant(count)));
33863386
}
33873387

3388+
/// <summary>
3389+
/// Returns an <see cref="IAsyncEnumerable{T}" /> which can be enumerated asynchronously.
3390+
/// </summary>
3391+
/// <typeparam name="TSource">The type of the elements of <paramref name="source" />.</typeparam>
3392+
/// <param name="source">A sequence of values.</param>
3393+
/// <returns>An IAsyncEnumerable for the query results.</returns>
3394+
public static IAsyncEnumerable<TSource> ToAsyncEnumerable<TSource>(this IQueryable<TSource> source)
3395+
{
3396+
var cursorSource = GetCursorSource(source);
3397+
return cursorSource.ToAsyncEnumerable();
3398+
}
3399+
33883400
/// <summary>
33893401
/// Executes the LINQ query and returns a cursor to the results.
33903402
/// </summary>

tests/MongoDB.Driver.Tests/BulkWriteErrorTests.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
using System.Threading.Tasks;
2121
using FluentAssertions;
2222
using MongoDB.Bson;
23+
using MongoDB.Driver.Core.Operations;
2324
using Xunit;
2425

2526
namespace MongoDB.Driver.Tests
@@ -34,7 +35,7 @@ public class BulkWriteErrorTests
3435
[InlineData(12582, ServerErrorCategory.DuplicateKey)]
3536
public void Should_translate_category_correctly(int code, ServerErrorCategory expectedCategory)
3637
{
37-
var coreError = new Core.Operations.BulkWriteOperationError(0, code, "blah", new BsonDocument());
38+
var coreError = new BulkWriteOperationError(0, code, "blah", new BsonDocument());
3839
var subject = BulkWriteError.FromCore(coreError);
3940

4041
subject.Category.Should().Be(expectedCategory);

tests/MongoDB.Driver.Tests/Core/IAsyncCursorExtensionsTests.cs

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
using System;
1717
using System.Collections.Generic;
1818
using System.Linq;
19+
using System.Threading;
20+
using System.Threading.Tasks;
1921
using FluentAssertions;
2022
using MongoDB.Bson;
2123
using MongoDB.Bson.Serialization.Serializers;
@@ -201,6 +203,55 @@ public void SingleOrDefault_should_throw_when_cursor_has_wrong_number_of_documen
201203
action.ShouldThrow<InvalidOperationException>();
202204
}
203205

206+
[Fact]
207+
public void ToAsyncEnumerable_result_should_only_be_enumerable_one_time()
208+
{
209+
var cursor = CreateCursor(2);
210+
var enumerable = cursor.ToAsyncEnumerable();
211+
enumerable.GetAsyncEnumerator();
212+
213+
Record.Exception(() => enumerable.GetAsyncEnumerator()).Should().BeOfType<InvalidOperationException>();
214+
}
215+
216+
[Fact]
217+
public async Task ToAsyncEnumerable_should_respect_cancellation_token()
218+
{
219+
var source = CreateCursor(5);
220+
using var cts = new CancellationTokenSource();
221+
222+
var count = 0;
223+
var exception = await Record.ExceptionAsync(async () =>
224+
{
225+
await foreach (var doc in source.ToAsyncEnumerable().WithCancellation(cts.Token))
226+
{
227+
count++;
228+
if (count == 2)
229+
cts.Cancel();
230+
}
231+
});
232+
233+
exception.Should().BeOfType<OperationCanceledException>();
234+
}
235+
236+
[Fact]
237+
public async Task ToAsyncEnumerable_should_return_expected_result()
238+
{
239+
var cursor = CreateCursor(2);
240+
var expectedDocuments = new[]
241+
{
242+
new BsonDocument("_id", 0),
243+
new BsonDocument("_id", 1)
244+
};
245+
246+
var result = new List<BsonDocument>();
247+
await foreach (var doc in cursor.ToAsyncEnumerable())
248+
{
249+
result.Add(doc);
250+
}
251+
252+
result.Should().Equal(expectedDocuments);
253+
}
254+
204255
[Fact]
205256
public void ToEnumerable_result_should_only_be_enumerable_one_time()
206257
{

0 commit comments

Comments
 (0)