Skip to content

Commit 30d59a5

Browse files
committed
CSHARP-4883: Support SkipWhile and TakeWhile methods in LINQ.
1 parent 37ed2a8 commit 30d59a5

File tree

6 files changed

+294
-0
lines changed

6 files changed

+294
-0
lines changed

src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstExpression.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -868,6 +868,16 @@ public static AstExpression Sum(AstExpression array)
868868
return new AstUnaryExpression(AstUnaryOperator.Sum, array);
869869
}
870870

871+
public static AstExpression Switch(IEnumerable<AstSwitchExpressionBranch> branches, AstExpression @default = null)
872+
{
873+
return new AstSwitchExpression(branches, @default);
874+
}
875+
876+
public static AstExpression Switch(IEnumerable<(AstExpression Case, AstExpression Then)> branches, AstExpression @default = null)
877+
{
878+
return new AstSwitchExpression(branches.Select(branch => new AstSwitchExpressionBranch(branch.Case, branch.Then)), @default);
879+
}
880+
871881
public static AstExpression ToLower(AstExpression arg)
872882
{
873883
if (arg is AstConstantExpression constantExpression && constantExpression.Value.BsonType == BsonType.String)

src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/EnumerableMethod.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ internal static class EnumerableMethod
155155
private static readonly MethodInfo __singleOrDefaultWithPredicate;
156156
private static readonly MethodInfo __singleWithPredicate;
157157
private static readonly MethodInfo __skip;
158+
private static readonly MethodInfo __skipWhile;
158159
private static readonly MethodInfo __sumDecimal;
159160
private static readonly MethodInfo __sumDecimalWithSelector;
160161
private static readonly MethodInfo __sumDouble;
@@ -176,6 +177,7 @@ internal static class EnumerableMethod
176177
private static readonly MethodInfo __sumSingle;
177178
private static readonly MethodInfo __sumSingleWithSelector;
178179
private static readonly MethodInfo __take;
180+
private static readonly MethodInfo __takeWhile;
179181
private static readonly MethodInfo __thenBy;
180182
private static readonly MethodInfo __thenByDescending;
181183
private static readonly MethodInfo __toArray;
@@ -320,6 +322,7 @@ static EnumerableMethod()
320322
__singleOrDefaultWithPredicate = ReflectionInfo.Method((IEnumerable<object> source, Func<object, bool> predicate) => source.SingleOrDefault(predicate));
321323
__singleWithPredicate = ReflectionInfo.Method((IEnumerable<object> source, Func<object, bool> predicate) => source.Single(predicate));
322324
__skip = ReflectionInfo.Method((IEnumerable<object> source, int count) => source.Skip(count));
325+
__skipWhile = ReflectionInfo.Method((IEnumerable<object> source, Func<object, bool> predicate) => source.SkipWhile(predicate));
323326
__sumDecimal = ReflectionInfo.Method((IEnumerable<decimal> source) => source.Sum());
324327
__sumDecimalWithSelector = ReflectionInfo.Method((IEnumerable<object> source, Func<object, decimal> selector) => source.Sum(selector));
325328
__sumDouble = ReflectionInfo.Method((IEnumerable<double> source) => source.Sum());
@@ -341,6 +344,7 @@ static EnumerableMethod()
341344
__sumSingle = ReflectionInfo.Method((IEnumerable<float> source) => source.Sum());
342345
__sumSingleWithSelector = ReflectionInfo.Method((IEnumerable<object> source, Func<object, float> selector) => source.Sum(selector));
343346
__take = ReflectionInfo.Method((IEnumerable<object> source, int count) => source.Take(count));
347+
__takeWhile = ReflectionInfo.Method((IEnumerable<object> source, Func<object, bool> predicate) => source.TakeWhile(predicate));
344348
__thenBy = ReflectionInfo.Method((IOrderedEnumerable<object> source, Func<object, object> keySelector) => source.ThenBy(keySelector));
345349
__thenByDescending = ReflectionInfo.Method((IOrderedEnumerable<object> source, Func<object, object> keySelector) => source.ThenByDescending(keySelector));
346350
__toArray = ReflectionInfo.Method((IEnumerable<object> source) => source.ToArray());
@@ -484,6 +488,7 @@ static EnumerableMethod()
484488
public static MethodInfo SingleOrDefaultWithPredicate => __singleOrDefaultWithPredicate;
485489
public static MethodInfo SingleWithPredicate => __singleWithPredicate;
486490
public static MethodInfo Skip => __skip;
491+
public static MethodInfo SkipWhile => __skipWhile;
487492
public static MethodInfo SumDecimal => __sumDecimal;
488493
public static MethodInfo SumDecimalWithSelector => __sumDecimalWithSelector;
489494
public static MethodInfo SumDouble => __sumDouble;
@@ -505,6 +510,7 @@ static EnumerableMethod()
505510
public static MethodInfo SumSingle => __sumSingle;
506511
public static MethodInfo SumSingleWithSelector => __sumSingleWithSelector;
507512
public static MethodInfo Take => __take;
513+
public static MethodInfo TakeWhile => __takeWhile;
508514
public static MethodInfo ThenBy => __thenBy;
509515
public static MethodInfo ThenByDescending => __thenByDescending;
510516
public static MethodInfo ToArray => __toArray;

src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/QueryableMethod.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ internal static class QueryableMethod
101101
private static readonly MethodInfo __singleOrDefaultWithPredicate;
102102
private static readonly MethodInfo __singleWithPredicate;
103103
private static readonly MethodInfo __skip;
104+
private static readonly MethodInfo __skipWhile;
104105
private static readonly MethodInfo __sumDecimal;
105106
private static readonly MethodInfo __sumDecimalWithSelector;
106107
private static readonly MethodInfo __sumDouble;
@@ -122,6 +123,7 @@ internal static class QueryableMethod
122123
private static readonly MethodInfo __sumSingle;
123124
private static readonly MethodInfo __sumSingleWithSelector;
124125
private static readonly MethodInfo __take;
126+
private static readonly MethodInfo __takeWhile;
125127
private static readonly MethodInfo __thenBy;
126128
private static readonly MethodInfo __thenByDescending;
127129
private static readonly MethodInfo __union;
@@ -209,6 +211,7 @@ static QueryableMethod()
209211
__singleOrDefaultWithPredicate = ReflectionInfo.Method((IQueryable<object> source, Expression<Func<object, bool>> predicate) => source.SingleOrDefault(predicate));
210212
__singleWithPredicate = ReflectionInfo.Method((IQueryable<object> source, Expression<Func<object, bool>> predicate) => source.Single(predicate));
211213
__skip = ReflectionInfo.Method((IQueryable<object> source, int count) => Queryable.Skip(source, count));
214+
__skipWhile = ReflectionInfo.Method((IQueryable<object> source, Expression<Func<object, bool>> predicate) => Queryable.SkipWhile(source, predicate));
212215
__sumDecimal = ReflectionInfo.Method((IQueryable<decimal> source) => source.Sum());
213216
__sumDecimalWithSelector = ReflectionInfo.Method((IQueryable<object> source, Expression<Func<object, decimal>> selector) => source.Sum(selector));
214217
__sumDouble = ReflectionInfo.Method((IQueryable<double> source) => source.Sum());
@@ -230,6 +233,7 @@ static QueryableMethod()
230233
__sumSingle = ReflectionInfo.Method((IQueryable<float> source) => source.Sum());
231234
__sumSingleWithSelector = ReflectionInfo.Method((IQueryable<object> source, Expression<Func<object, float>> selector) => source.Sum(selector));
232235
__take = ReflectionInfo.Method((IQueryable<object> source, int count) => Queryable.Take(source, count));
236+
__takeWhile = ReflectionInfo.Method((IQueryable<object> source, Expression<Func<object, bool>> predicate) => Queryable.TakeWhile(source, predicate));
233237
__thenBy = ReflectionInfo.Method((IOrderedQueryable<object> source, Expression<Func<object, object>> keySelector) => source.ThenBy(keySelector));
234238
__thenByDescending = ReflectionInfo.Method((IOrderedQueryable<object> source, Expression<Func<object, object>> keySelector) => source.ThenByDescending(keySelector));
235239
__union = ReflectionInfo.Method((IQueryable<object> source1, IEnumerable<object> source2) => source1.Union(source2));
@@ -316,6 +320,7 @@ static QueryableMethod()
316320
public static MethodInfo SingleOrDefaultWithPredicate => __singleOrDefaultWithPredicate;
317321
public static MethodInfo SingleWithPredicate => __singleWithPredicate;
318322
public static MethodInfo Skip => __skip;
323+
public static MethodInfo SkipWhile => __skipWhile;
319324
public static MethodInfo SumDecimal => __sumDecimal;
320325
public static MethodInfo SumDecimalWithSelector => __sumDecimalWithSelector;
321326
public static MethodInfo SumDouble => __sumDouble;
@@ -337,6 +342,7 @@ static QueryableMethod()
337342
public static MethodInfo SumSingle => __sumSingle;
338343
public static MethodInfo SumSingleWithSelector => __sumSingleWithSelector;
339344
public static MethodInfo Take => __take;
345+
public static MethodInfo TakeWhile => __takeWhile;
340346
public static MethodInfo ThenBy => __thenBy;
341347
public static MethodInfo ThenByDescending => __thenByDescending;
342348
public static MethodInfo Union => __union;

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodCallExpressionToAggregationExpressionTranslator.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,10 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC
177177
case "Take":
178178
return SkipOrTakeMethodToAggregationExpressionTranslator.Translate(context, expression);
179179

180+
case "SkipWhile":
181+
case "TakeWhile":
182+
return SkipWhileOrTakeWhileMethodToAggregationExpressionTranslator.Translate(context, expression);
183+
180184
case "StandardDeviationPopulation":
181185
case "StandardDeviationSample":
182186
return StandardDeviationMethodsToAggregationExpressionTranslator.Translate(context, expression);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
/* Copyright 2010-present MongoDB Inc.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
using System.Linq;
17+
using System.Linq.Expressions;
18+
using System.Reflection;
19+
using MongoDB.Bson;
20+
using MongoDB.Driver.Linq.Linq3Implementation.Ast;
21+
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions;
22+
using MongoDB.Driver.Linq.Linq3Implementation.Misc;
23+
using MongoDB.Driver.Linq.Linq3Implementation.Reflection;
24+
using MongoDB.Driver.Linq.Linq3Implementation.Serializers;
25+
26+
namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators.MethodTranslators
27+
{
28+
internal static class SkipWhileOrTakeWhileMethodToAggregationExpressionTranslator
29+
{
30+
private static MethodInfo[] __skipWhileOrTakeWhileMethods =
31+
{
32+
EnumerableMethod.SkipWhile,
33+
EnumerableMethod.TakeWhile,
34+
QueryableMethod.SkipWhile,
35+
QueryableMethod.TakeWhile
36+
};
37+
38+
private static MethodInfo[] __skipWhileMethods =
39+
{
40+
EnumerableMethod.SkipWhile,
41+
QueryableMethod.SkipWhile
42+
};
43+
44+
private static MethodInfo[] __takeWhileMethods =
45+
{
46+
EnumerableMethod.TakeWhile,
47+
QueryableMethod.TakeWhile
48+
};
49+
50+
public static TranslatedExpression Translate(TranslationContext context, MethodCallExpression expression)
51+
{
52+
var method = expression.Method;
53+
var arguments = expression.Arguments;
54+
55+
if (method.IsOneOf(__skipWhileOrTakeWhileMethods))
56+
{
57+
var sourceExpression = arguments[0];
58+
var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression);
59+
NestedAsQueryableHelper.EnsureQueryableMethodHasNestedAsQueryableSource(expression, sourceTranslation);
60+
var itemSerializer = ArraySerializerHelper.GetItemSerializer(sourceTranslation.Serializer);
61+
62+
var predicateExpression = arguments[1];
63+
var predicateLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, predicateExpression);
64+
var predicateParameter = predicateLambda.Parameters.Single();
65+
var thisSymbol = context.CreateSymbol(predicateParameter, "this", itemSerializer);
66+
var predicateTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, predicateLambda, thisSymbol);
67+
68+
var (sourceBinding, sourceAst) = AstExpression.UseVarIfNotSimple("source", sourceTranslation.Ast);
69+
70+
var valueVar = AstExpression.Var("value");
71+
var valuePredicateField = AstExpression.GetField(valueVar, "predicate");
72+
var valueCountField = AstExpression.GetField(valueVar, "count");
73+
74+
var reduceAst = AstExpression.Reduce(
75+
input: sourceAst,
76+
initialValue: new BsonDocument { { "predicate", true }, { "count", 0 } },
77+
@in: AstExpression.Switch(
78+
branches:
79+
[
80+
(AstExpression.Not(valuePredicateField), valueVar),
81+
(predicateTranslation.Ast, AstExpression.ComputedDocument([new AstComputedField("predicate", true), new AstComputedField("count", AstExpression.Add(valueCountField, 1))]))
82+
],
83+
@default: AstExpression.ComputedDocument([new AstComputedField("predicate", false), new AstComputedField("count", valueCountField)])));
84+
85+
var whileVar = AstExpression.Var("while");
86+
var whileBinding = AstExpression.VarBinding(whileVar, reduceAst);
87+
var whileCountField = AstExpression.GetField(whileVar, "count");
88+
89+
var sliceAst = method switch
90+
{
91+
_ when method.IsOneOf(__skipWhileMethods) => AstExpression.Slice(sourceAst, whileCountField, int.MaxValue),
92+
_ when method.IsOneOf(__takeWhileMethods) => AstExpression.Slice(sourceAst, whileCountField),
93+
_ => throw new ExpressionNotSupportedException(expression)
94+
};
95+
96+
var ast = AstExpression.Let(
97+
sourceBinding,
98+
whileBinding,
99+
sliceAst);
100+
101+
var resultSerializer = NestedAsQueryableSerializer.CreateIEnumerableOrNestedAsQueryableSerializer(expression.Type, itemSerializer);
102+
return new TranslatedExpression(expression, ast, resultSerializer);
103+
}
104+
105+
throw new ExpressionNotSupportedException(expression);
106+
}
107+
}
108+
}

0 commit comments

Comments
 (0)