Skip to content

Commit 6c20aa7

Browse files
committed
CSHARP-4688: Calling queryable.First when queryable is of type IMongoQueryable<T> doesn't add { $limit : 1 } stage to pipeline.
1 parent 5974753 commit 6c20aa7

File tree

6 files changed

+250
-6
lines changed

6 files changed

+250
-6
lines changed

src/MongoDB.Driver.Core/IAsyncCursorSource.cs

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,11 @@ public static class IAsyncCursorSourceExtensions
5757
/// <returns>True if the cursor contains any documents.</returns>
5858
public static bool Any<TDocument>(this IAsyncCursorSource<TDocument> source, CancellationToken cancellationToken = default(CancellationToken))
5959
{
60+
if (source is IQueryable<TDocument> queryable && !cancellationToken.CanBeCanceled)
61+
{
62+
return Queryable.Any(queryable);
63+
}
64+
6065
using (var cursor = source.ToCursor(cancellationToken))
6166
{
6267
return cursor.Any(cancellationToken);
@@ -72,6 +77,11 @@ public static class IAsyncCursorSourceExtensions
7277
/// <returns>A Task whose result is true if the cursor contains any documents.</returns>
7378
public static async Task<bool> AnyAsync<TDocument>(this IAsyncCursorSource<TDocument> source, CancellationToken cancellationToken = default(CancellationToken))
7479
{
80+
if (source is IMongoQueryableForwarder<TDocument> queryableForwarder)
81+
{
82+
return await queryableForwarder.AnyAsync(cancellationToken).ConfigureAwait(false);
83+
}
84+
7585
using (var cursor = await source.ToCursorAsync(cancellationToken).ConfigureAwait(false))
7686
{
7787
return await cursor.AnyAsync(cancellationToken).ConfigureAwait(false);
@@ -87,6 +97,11 @@ public static class IAsyncCursorSourceExtensions
8797
/// <returns>The first document.</returns>
8898
public static TDocument First<TDocument>(this IAsyncCursorSource<TDocument> source, CancellationToken cancellationToken = default(CancellationToken))
8999
{
100+
if (source is IQueryable<TDocument> queryable && !cancellationToken.CanBeCanceled)
101+
{
102+
return Queryable.First(queryable);
103+
}
104+
90105
using (var cursor = source.ToCursor(cancellationToken))
91106
{
92107
return cursor.First(cancellationToken);
@@ -102,6 +117,11 @@ public static class IAsyncCursorSourceExtensions
102117
/// <returns>A Task whose result is the first document.</returns>
103118
public static async Task<TDocument> FirstAsync<TDocument>(this IAsyncCursorSource<TDocument> source, CancellationToken cancellationToken = default(CancellationToken))
104119
{
120+
if (source is IMongoQueryableForwarder<TDocument> queryableForwarder)
121+
{
122+
return await queryableForwarder.FirstAsync(cancellationToken).ConfigureAwait(false);
123+
}
124+
105125
using (var cursor = await source.ToCursorAsync(cancellationToken).ConfigureAwait(false))
106126
{
107127
return await cursor.FirstAsync(cancellationToken).ConfigureAwait(false);
@@ -117,6 +137,11 @@ public static class IAsyncCursorSourceExtensions
117137
/// <returns>The first document of the cursor, or a default value if the cursor contains no documents.</returns>
118138
public static TDocument FirstOrDefault<TDocument>(this IAsyncCursorSource<TDocument> source, CancellationToken cancellationToken = default(CancellationToken))
119139
{
140+
if (source is IQueryable<TDocument> queryable && !cancellationToken.CanBeCanceled)
141+
{
142+
return Queryable.FirstOrDefault(queryable);
143+
}
144+
120145
using (var cursor = source.ToCursor(cancellationToken))
121146
{
122147
return cursor.FirstOrDefault(cancellationToken);
@@ -132,6 +157,11 @@ public static class IAsyncCursorSourceExtensions
132157
/// <returns>A Task whose result is the first document of the cursor, or a default value if the cursor contains no documents.</returns>
133158
public static async Task<TDocument> FirstOrDefaultAsync<TDocument>(this IAsyncCursorSource<TDocument> source, CancellationToken cancellationToken = default(CancellationToken))
134159
{
160+
if (source is IMongoQueryableForwarder<TDocument> queryableForwarder)
161+
{
162+
return await queryableForwarder.FirstOrDefaultAsync(cancellationToken).ConfigureAwait(false);
163+
}
164+
135165
using (var cursor = await source.ToCursorAsync(cancellationToken).ConfigureAwait(false))
136166
{
137167
return await cursor.FirstOrDefaultAsync(cancellationToken).ConfigureAwait(false);
@@ -221,6 +251,11 @@ public static class IAsyncCursorSourceExtensions
221251
/// <returns>The only document of a cursor.</returns>
222252
public static TDocument Single<TDocument>(this IAsyncCursorSource<TDocument> source, CancellationToken cancellationToken = default(CancellationToken))
223253
{
254+
if (source is IQueryable<TDocument> queryable && !cancellationToken.CanBeCanceled)
255+
{
256+
return Queryable.Single(queryable);
257+
}
258+
224259
using (var cursor = source.ToCursor(cancellationToken))
225260
{
226261
return cursor.Single(cancellationToken);
@@ -236,6 +271,11 @@ public static class IAsyncCursorSourceExtensions
236271
/// <returns>A Task whose result is the only document of a cursor.</returns>
237272
public static async Task<TDocument> SingleAsync<TDocument>(this IAsyncCursorSource<TDocument> source, CancellationToken cancellationToken = default(CancellationToken))
238273
{
274+
if (source is IMongoQueryableForwarder<TDocument> queryableForwarder)
275+
{
276+
return await queryableForwarder.SingleAsync(cancellationToken).ConfigureAwait(false);
277+
}
278+
239279
using (var cursor = await source.ToCursorAsync(cancellationToken).ConfigureAwait(false))
240280
{
241281
return await cursor.SingleAsync(cancellationToken).ConfigureAwait(false);
@@ -252,6 +292,11 @@ public static class IAsyncCursorSourceExtensions
252292
/// <returns>The only document of a cursor, or a default value if the cursor contains no documents.</returns>
253293
public static TDocument SingleOrDefault<TDocument>(this IAsyncCursorSource<TDocument> source, CancellationToken cancellationToken = default(CancellationToken))
254294
{
295+
if (source is IQueryable<TDocument> queryable && !cancellationToken.CanBeCanceled)
296+
{
297+
return Queryable.SingleOrDefault(queryable);
298+
}
299+
255300
using (var cursor = source.ToCursor(cancellationToken))
256301
{
257302
return cursor.SingleOrDefault(cancellationToken);
@@ -268,6 +313,11 @@ public static class IAsyncCursorSourceExtensions
268313
/// <returns>A Task whose result is the only document of a cursor, or a default value if the cursor contains no documents.</returns>
269314
public static async Task<TDocument> SingleOrDefaultAsync<TDocument>(this IAsyncCursorSource<TDocument> source, CancellationToken cancellationToken = default(CancellationToken))
270315
{
316+
if (source is IMongoQueryableForwarder<TDocument> queryableForwarder)
317+
{
318+
return await queryableForwarder.SingleOrDefaultAsync(cancellationToken).ConfigureAwait(false);
319+
}
320+
271321
using (var cursor = await source.ToCursorAsync(cancellationToken).ConfigureAwait(false))
272322
{
273323
return await cursor.SingleOrDefaultAsync(cancellationToken).ConfigureAwait(false);
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
/* Copyright 2010-present MongoDB Inc.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
using System.Threading;
17+
using System.Threading.Tasks;
18+
19+
namespace MongoDB.Driver
20+
{
21+
internal interface IMongoQueryableForwarder<T>
22+
{
23+
Task<bool> AnyAsync(CancellationToken cancellationToken);
24+
Task<T> FirstAsync(CancellationToken cancellationToken);
25+
Task<T> FirstOrDefaultAsync(CancellationToken cancellationToken);
26+
Task<T> SingleAsync(CancellationToken cancellationToken);
27+
Task<T> SingleOrDefaultAsync(CancellationToken cancellationToken);
28+
}
29+
}

src/MongoDB.Driver/Linq/Linq3Implementation/MongoQuery.cs

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ internal abstract class MongoQuery<TOutput>
3030
public abstract Task<IAsyncCursor<TOutput>> ExecuteAsync();
3131
}
3232

33-
internal class MongoQuery<TDocument, TOutput> : MongoQuery<TOutput>, IOrderedMongoQueryable<TOutput>
33+
internal class MongoQuery<TDocument, TOutput> : MongoQuery<TOutput>, IOrderedMongoQueryable<TOutput>, IMongoQueryableForwarder<TOutput>
3434
{
3535
// private fields
3636
private readonly Expression _expression;
@@ -62,13 +62,13 @@ public MongoQuery(MongoQueryProvider<TDocument> provider, Expression expression)
6262
public override IAsyncCursor<TOutput> Execute()
6363
{
6464
var executableQuery = ExpressionToExecutableQueryTranslator.Translate<TDocument, TOutput>(_provider, _expression);
65-
return executableQuery.Execute(_provider.Session, CancellationToken.None);
65+
return _provider.Execute(executableQuery);
6666
}
6767

6868
public override Task<IAsyncCursor<TOutput>> ExecuteAsync()
6969
{
7070
var executableQuery = ExpressionToExecutableQueryTranslator.Translate<TDocument, TOutput>(_provider, _expression);
71-
return executableQuery.ExecuteAsync(_provider.Session, CancellationToken.None);
71+
return _provider.ExecuteAsync(executableQuery);
7272
}
7373

7474
public IEnumerator<TOutput> GetEnumerator()
@@ -90,13 +90,13 @@ public QueryableExecutionModel GetExecutionModel()
9090
public IAsyncCursor<TOutput> ToCursor(CancellationToken cancellationToken = default)
9191
{
9292
var executableQuery = ExpressionToExecutableQueryTranslator.Translate<TDocument, TOutput>(_provider, _expression);
93-
return executableQuery.Execute(_provider.Session, cancellationToken);
93+
return _provider.Execute(executableQuery, cancellationToken);
9494
}
9595

9696
public Task<IAsyncCursor<TOutput>> ToCursorAsync(CancellationToken cancellationToken = default)
9797
{
9898
var executableQuery = ExpressionToExecutableQueryTranslator.Translate<TDocument, TOutput>(_provider, _expression);
99-
return executableQuery.ExecuteAsync(_provider.Session, cancellationToken);
99+
return _provider.ExecuteAsync(executableQuery, cancellationToken);
100100
}
101101

102102
public override string ToString()
@@ -115,5 +115,11 @@ public override string ToString()
115115
return ex.ToString();
116116
}
117117
}
118+
119+
Task<bool> IMongoQueryableForwarder<TOutput>.AnyAsync(CancellationToken cancellationToken) => MongoQueryable.AnyAsync(this, cancellationToken);
120+
Task<TOutput> IMongoQueryableForwarder<TOutput>.FirstAsync(CancellationToken cancellationToken) => MongoQueryable.FirstAsync(this, cancellationToken);
121+
Task<TOutput> IMongoQueryableForwarder<TOutput>.FirstOrDefaultAsync(CancellationToken cancellationToken) => MongoQueryable.FirstOrDefaultAsync(this, cancellationToken);
122+
Task<TOutput> IMongoQueryableForwarder<TOutput>.SingleAsync(CancellationToken cancellationToken) => MongoQueryable.SingleAsync(this, cancellationToken);
123+
Task<TOutput> IMongoQueryableForwarder<TOutput>.SingleOrDefaultAsync(CancellationToken cancellationToken) => MongoQueryable.SingleOrDefaultAsync(this, cancellationToken);
118124
}
119125
}

src/MongoDB.Driver/Linq/Linq3Implementation/MongoQueryProvider.cs

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,13 +115,33 @@ public override object Execute(Expression expression)
115115
public override TResult Execute<TResult>(Expression expression)
116116
{
117117
var executableQuery = ExpressionToExecutableQueryTranslator.TranslateScalar<TDocument, TResult>(this, expression);
118+
return Execute(executableQuery);
119+
}
120+
121+
public TResult Execute<TResult>(ExecutableQuery<TDocument, TResult> executableQuery)
122+
{
123+
return Execute(executableQuery, CancellationToken.None);
124+
}
125+
126+
public TResult Execute<TResult>(ExecutableQuery<TDocument, TResult> executableQuery, CancellationToken cancellationToken)
127+
{
118128
_mostRecentExecutableQuery = executableQuery;
119-
return executableQuery.Execute(_session, CancellationToken.None);
129+
return executableQuery.Execute(_session, cancellationToken);
120130
}
121131

122132
public override Task<TResult> ExecuteAsync<TResult>(Expression expression, CancellationToken cancellationToken)
123133
{
124134
var executableQuery = ExpressionToExecutableQueryTranslator.TranslateScalar<TDocument, TResult>(this, expression);
135+
return ExecuteAsync(executableQuery, cancellationToken);
136+
}
137+
138+
public Task<TResult> ExecuteAsync<TResult>(ExecutableQuery<TDocument, TResult> executableQuery)
139+
{
140+
return ExecuteAsync(executableQuery, CancellationToken.None);
141+
}
142+
143+
public Task<TResult> ExecuteAsync<TResult>(ExecutableQuery<TDocument, TResult> executableQuery, CancellationToken cancellationToken)
144+
{
125145
_mostRecentExecutableQuery = executableQuery;
126146
return executableQuery.ExecuteAsync(_session, cancellationToken);
127147
}
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
/* Copyright 2010-present MongoDB Inc.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
using FluentAssertions;
17+
using MongoDB.Driver.Linq;
18+
using MongoDB.TestHelpers.XunitExtensions;
19+
using Xunit;
20+
21+
namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira
22+
{
23+
public class CSharp4688Tests : Linq3IntegrationTest
24+
{
25+
[Theory]
26+
[ParameterAttributeData]
27+
public void IMongoQueryable_Any_should_add_expected_stages(
28+
[Values(false, true)] bool async)
29+
{
30+
var collection = GetCollection();
31+
var queryable = collection.AsQueryable();
32+
33+
var (stages, result) = ExecuteQueryCapturingStages(
34+
queryable,
35+
queryable => async ? queryable.AnyAsync().Result : queryable.Any());
36+
37+
AssertStages(
38+
stages,
39+
"{ $limit : 1 }",
40+
"{ $project : { _id : 0, _v : null } }");
41+
result.Should().Be(true);
42+
}
43+
44+
[Theory]
45+
[ParameterAttributeData]
46+
public void IMongoQueryable_First_should_add_expected_stages(
47+
[Values(false, true)] bool async)
48+
{
49+
var collection = GetCollection();
50+
var queryable = collection.AsQueryable();
51+
52+
var (stages, result) = ExecuteQueryCapturingStages(
53+
queryable,
54+
queryable => async ? queryable.FirstAsync().Result : queryable.First());
55+
56+
AssertStages(stages, "{ $limit : 1 }");
57+
result.Id.Should().Be(1);
58+
}
59+
60+
[Theory]
61+
[ParameterAttributeData]
62+
public void IMongoQueryable_FirstOrDefault_should_add_expected_stages(
63+
[Values(false, true)] bool async)
64+
{
65+
var collection = GetCollection();
66+
var queryable = collection.AsQueryable();
67+
68+
var (stages, result) = ExecuteQueryCapturingStages(
69+
queryable,
70+
queryable => async ? queryable.FirstOrDefaultAsync().Result : queryable.FirstOrDefault());
71+
72+
AssertStages(stages, "{ $limit : 1 }");
73+
result.Id.Should().Be(1);
74+
}
75+
76+
[Theory]
77+
[ParameterAttributeData]
78+
public void IMongoQueryable_Single_should_add_expected_stages(
79+
[Values(false, true)] bool async)
80+
{
81+
var collection = GetCollection();
82+
var queryable = collection.AsQueryable().Where(x => x.X == 1);
83+
84+
var (stages, result) = ExecuteQueryCapturingStages(
85+
queryable,
86+
queryable => async ? queryable.SingleAsync().Result : queryable.Single());
87+
88+
AssertStages(
89+
stages,
90+
"{ $match : { X : 1 } }",
91+
"{ $limit : 2 }");
92+
result.Id.Should().Be(1);
93+
}
94+
95+
[Theory]
96+
[ParameterAttributeData]
97+
public void IMongoQueryable_SingleOrDefault_should_add_expected_stages(
98+
[Values(false, true)] bool async)
99+
{
100+
var collection = GetCollection();
101+
var queryable = collection.AsQueryable().Where(x => x.X == 1);
102+
103+
var (stages, result) = ExecuteQueryCapturingStages(
104+
queryable,
105+
queryable => async ? queryable.SingleOrDefaultAsync().Result : queryable.SingleOrDefault());
106+
107+
AssertStages(
108+
stages,
109+
"{ $match : { X : 1 } }",
110+
"{ $limit : 2 }");
111+
result.Id.Should().Be(1);
112+
}
113+
114+
private IMongoCollection<C> GetCollection()
115+
{
116+
var collection = GetCollection<C>();
117+
CreateCollection(
118+
collection,
119+
new C { Id = 1, X = 1 },
120+
new C { Id = 2, X = 2 });
121+
return collection;
122+
}
123+
124+
private class C
125+
{
126+
public int Id { get; set; }
127+
public int X { get; set; }
128+
}
129+
}
130+
}

tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Linq3IntegrationTest.cs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
* limitations under the License.
1414
*/
1515

16+
using System;
1617
using System.Collections.Generic;
1718
using System.Linq;
1819
using FluentAssertions;
@@ -57,6 +58,14 @@ protected void CreateCollection<TDocument>(IMongoCollection<TDocument> collectio
5758
CreateCollection(collection, (IEnumerable<TDocument>)documents); ;
5859
}
5960

61+
protected (BsonDocument[] Stages, TResult Result) ExecuteQueryCapturingStages<TInput, TResult>(IMongoQueryable<TInput> queryable, Func<IMongoQueryable<TInput>, TResult> executor)
62+
{
63+
var provider = (MongoQueryProvider)queryable.Provider;
64+
var result = executor(queryable);
65+
var stages = provider.GetMostRecentPipelineStages();
66+
return (stages, result);
67+
}
68+
6069
protected IMongoCollection<TDocument> GetCollection<TDocument>(string collectionName = null, LinqProvider linqProvider = LinqProvider.V3)
6170
{
6271
return GetCollection<TDocument>(databaseName: null, collectionName, linqProvider);

0 commit comments

Comments
 (0)