Skip to content

Commit fc04833

Browse files
committed
CSHARP-4859: Nested AsQueryable.
1 parent 7ed17f8 commit fc04833

File tree

66 files changed

+6200
-220
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

66 files changed

+6200
-220
lines changed

src/MongoDB.Driver/Linq/Linq3Implementation/Misc/ExpressionHelper.cs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
* limitations under the License.
1414
*/
1515

16+
using System.Linq;
1617
using System.Linq.Expressions;
18+
using System.Reflection;
1719
using MongoDB.Driver.Core.Misc;
1820

1921
namespace MongoDB.Driver.Linq.Linq3Implementation.Misc
@@ -27,5 +29,20 @@ public static LambdaExpression UnquoteLambda(Expression expression)
2729
var unaryExpression = (UnaryExpression)expression;
2830
return (LambdaExpression)unaryExpression.Operand;
2931
}
32+
33+
public static LambdaExpression UnquoteLambdaIfQueryableMethod(MethodInfo method, Expression expression)
34+
{
35+
Ensure.IsNotNull(method, nameof(method));
36+
Ensure.IsNotNull(expression, nameof(expression));
37+
38+
if (method.DeclaringType == typeof(Queryable))
39+
{
40+
return UnquoteLambda(expression);
41+
}
42+
else
43+
{
44+
return (LambdaExpression)expression;
45+
}
46+
}
3047
}
3148
}
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
/* Copyright 2010-present MongoDB Inc.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
using System.Linq;
17+
using System.Linq.Expressions;
18+
using MongoDB.Driver.Linq.Linq3Implementation.Serializers;
19+
using MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators;
20+
21+
namespace MongoDB.Driver.Linq.Linq3Implementation.Misc
22+
{
23+
internal static class NestedAsQueryableHelper
24+
{
25+
public static void EnsureQueryableMethodHasNestedAsQueryableSource(MethodCallExpression expression, AggregationExpression sourceTranslation)
26+
{
27+
if (expression.Method.DeclaringType == typeof(Queryable) &&
28+
sourceTranslation.Serializer is not INestedAsQueryableSerializer)
29+
{
30+
throw new ExpressionNotSupportedException(expression, because: "source serializer is not a NestedAsQueryableSerializer");
31+
}
32+
}
33+
34+
public static void EnsureQueryableMethodHasNestedAsOrderedQueryableSource(MethodCallExpression expression, AggregationExpression sourceTranslation)
35+
{
36+
if (expression.Method.DeclaringType == typeof(Queryable) &&
37+
sourceTranslation.Serializer is not INestedAsOrderedQueryableSerializer)
38+
{
39+
throw new ExpressionNotSupportedException(expression, because: "source serializer is not a NestedAsOrderedQueryableSerializer");
40+
}
41+
}
42+
}
43+
}

src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/EnumerableMethod.cs

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -551,6 +551,41 @@ public static bool IsContainsMethod(MethodCallExpression methodCallExpression, o
551551
return false;
552552
}
553553

554+
public static bool IsToArrayMethod(MethodCallExpression methodCallExpression, out Expression sourceExpression)
555+
{
556+
var method = methodCallExpression.Method;
557+
var parameters = method.GetParameters();
558+
var arguments = methodCallExpression.Arguments;
559+
560+
if (method.Name == "ToArray")
561+
{
562+
var returnType = method.ReturnType;
563+
if (returnType.IsArray)
564+
{
565+
var returnItemType = returnType.GetElementType();
566+
567+
sourceExpression = method switch
568+
{
569+
_ when method.IsStatic && parameters.Length == 1 => arguments[0],
570+
_ when !method.IsStatic && parameters.Length == 0 => methodCallExpression.Object,
571+
_ => null
572+
};
573+
if (sourceExpression != null)
574+
{
575+
var sourceType = sourceExpression.Type;
576+
if (sourceType.ImplementsIEnumerable(out var sourceItemType) &&
577+
sourceItemType == returnItemType)
578+
{
579+
return true;
580+
}
581+
}
582+
}
583+
}
584+
585+
sourceExpression = null;
586+
return false;
587+
}
588+
554589
public static MethodInfo MakeSelect(Type sourceType, Type resultType)
555590
{
556591
return __select.MakeGenericMethod(sourceType, resultType);

src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/QueryableMethod.cs

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,13 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Reflection
2424
internal static class QueryableMethod
2525
{
2626
// private static fields
27-
private static readonly MethodInfo __aggregate;
27+
private static readonly MethodInfo __aggregateWithFunc;
2828
private static readonly MethodInfo __aggregateWithSeedAndFunc;
29-
private static readonly MethodInfo __aggregateWithSeedFuncAndSelector;
29+
private static readonly MethodInfo __aggregateWithSeedFuncAndResultSelector;
3030
private static readonly MethodInfo __all;
3131
private static readonly MethodInfo __any;
3232
private static readonly MethodInfo __anyWithPredicate;
33+
private static readonly MethodInfo __asQueryable;
3334
private static readonly MethodInfo __averageDecimal;
3435
private static readonly MethodInfo __averageDecimalWithSelector;
3536
private static readonly MethodInfo __averageDouble;
@@ -128,12 +129,13 @@ internal static class QueryableMethod
128129
// static constructor
129130
static QueryableMethod()
130131
{
131-
__aggregate = ReflectionInfo.Method((IQueryable<object> source, Expression<Func<object, object, object>> func) => source.Aggregate(func));
132+
__aggregateWithFunc = ReflectionInfo.Method((IQueryable<object> source, Expression<Func<object, object, object>> func) => source.Aggregate(func));
132133
__aggregateWithSeedAndFunc = ReflectionInfo.Method((IQueryable<object> source, object seed, Expression<Func<object, object, object>> func) => source.Aggregate(seed, func));
133-
__aggregateWithSeedFuncAndSelector = ReflectionInfo.Method((IQueryable<object> source, object seed, Expression<Func<object, object, object>> func, Expression<Func<object, object>> selector) => source.Aggregate(seed, func, selector));
134+
__aggregateWithSeedFuncAndResultSelector = ReflectionInfo.Method((IQueryable<object> source, object seed, Expression<Func<object, object, object>> func, Expression<Func<object, object>> selector) => source.Aggregate(seed, func, selector));
134135
__all = ReflectionInfo.Method((IQueryable<object> source, Expression<Func<object, bool>> predicate) => source.All(predicate));
135136
__any = ReflectionInfo.Method((IQueryable<object> source) => source.Any());
136137
__anyWithPredicate = ReflectionInfo.Method((IQueryable<object> source, Expression<Func<object, bool>> predicate) => source.Any(predicate));
138+
__asQueryable = ReflectionInfo.Method((IEnumerable<object> source) => source.AsQueryable());
137139
__averageDecimal = ReflectionInfo.Method((IQueryable<decimal> source) => source.Average());
138140
__averageDecimalWithSelector = ReflectionInfo.Method((IQueryable<object> source, Expression<Func<object, decimal>> selector) => source.Average(selector));
139141
__averageDouble = ReflectionInfo.Method((IQueryable<double> source) => source.Average());
@@ -149,7 +151,7 @@ static QueryableMethod()
149151
__averageNullableInt32 = ReflectionInfo.Method((IQueryable<int?> source) => source.Average());
150152
__averageNullableInt32WithSelector = ReflectionInfo.Method((IQueryable<object> source, Expression<Func<object, int?>> selector) => source.Average(selector));
151153
__averageNullableInt64 = ReflectionInfo.Method((IQueryable<long?> source) => source.Average());
152-
__averageNullableInt64WithSelector = ReflectionInfo.Method((IQueryable<object> source, Expression<Func<object, double?>> selector) => source.Average(selector));
154+
__averageNullableInt64WithSelector = ReflectionInfo.Method((IQueryable<object> source, Expression<Func<object, long?>> selector) => source.Average(selector));
153155
__averageNullableSingle = ReflectionInfo.Method((IQueryable<float?> source) => source.Average());
154156
__averageNullableSingleWithSelector = ReflectionInfo.Method((IQueryable<object> source, Expression<Func<object, float?>> selector) => source.Average(selector));
155157
__averageSingle = ReflectionInfo.Method((IQueryable<float> source) => source.Average());
@@ -231,11 +233,13 @@ static QueryableMethod()
231233
}
232234

233235
// public properties
234-
public static MethodInfo Aggregate => __aggregate;
236+
public static MethodInfo AggregateWithFunc => __aggregateWithFunc;
235237
public static MethodInfo AggregateWithSeedAndFunc => __aggregateWithSeedAndFunc;
236-
public static MethodInfo AggregateWithSeedFuncAndSelector => __aggregateWithSeedFuncAndSelector;
238+
public static MethodInfo AggregateWithSeedFuncAndResultSelector => __aggregateWithSeedFuncAndResultSelector;
237239
public static MethodInfo All => __all;
238240
public static MethodInfo Any => __any;
241+
public static MethodInfo AnyWithPredicate => __anyWithPredicate;
242+
public static MethodInfo AsQueryable => __asQueryable;
239243
public static MethodInfo AverageDecimal => __averageDecimal;
240244
public static MethodInfo AverageDecimalWithSelector => __averageDecimalWithSelector;
241245
public static MethodInfo AverageDouble => __averageDouble;
@@ -256,7 +260,6 @@ static QueryableMethod()
256260
public static MethodInfo AverageNullableSingleWithSelector => __averageNullableSingleWithSelector;
257261
public static MethodInfo AverageSingle => __averageSingle;
258262
public static MethodInfo AverageSingleWithSelector => __averageSingleWithSelector;
259-
public static MethodInfo AnyWithPredicate => __anyWithPredicate;
260263
public static MethodInfo Cast => __cast;
261264
public static MethodInfo Concat => __concat;
262265
public static MethodInfo Contains => __contains;

src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/IEnumerableSerializer.cs

Lines changed: 4 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -16,53 +16,19 @@
1616
using System;
1717
using System.Collections.Generic;
1818
using MongoDB.Bson.Serialization;
19-
using MongoDB.Bson.Serialization.Serializers;
20-
using MongoDB.Driver.Core.Misc;
2119

2220
namespace MongoDB.Driver.Linq.Linq3Implementation.Serializers
2321
{
24-
internal class IEnumerableSerializer<TItem> : SerializerBase<IEnumerable<TItem>>, IBsonArraySerializer
22+
internal class IEnumerableSerializer<TItem> : IEnumerableSerializerBase<IEnumerable<TItem>, TItem>
2523
{
26-
// private fields
27-
private readonly IBsonSerializer<TItem> _itemSerializer;
28-
2924
// constructors
3025
public IEnumerableSerializer(IBsonSerializer<TItem> itemSerializer)
26+
: base(itemSerializer)
3127
{
32-
_itemSerializer = Ensure.IsNotNull(itemSerializer, nameof(itemSerializer));
3328
}
3429

35-
// public methods
36-
public override IEnumerable<TItem> Deserialize(BsonDeserializationContext context, BsonDeserializationArgs args)
37-
{
38-
var reader = context.Reader;
39-
reader.ReadStartArray();
40-
var value = new List<TItem>();
41-
while (reader.ReadBsonType() != 0)
42-
{
43-
var item = _itemSerializer.Deserialize(context);
44-
value.Add(item);
45-
}
46-
reader.ReadEndArray();
47-
return value;
48-
}
49-
50-
public override void Serialize(BsonSerializationContext context, BsonSerializationArgs args, IEnumerable<TItem> value)
51-
{
52-
var writer = context.Writer;
53-
writer.WriteStartArray();
54-
foreach (var item in value)
55-
{
56-
_itemSerializer.Serialize(context, item);
57-
}
58-
writer.WriteEndArray();
59-
}
60-
61-
public bool TryGetItemSerializationInfo(out BsonSerializationInfo serializationInfo)
62-
{
63-
serializationInfo = new BsonSerializationInfo(null, _itemSerializer, typeof(TItem));
64-
return true;
65-
}
30+
// protected methods
31+
protected override IEnumerable<TItem> CreateDeserializedValue(List<TItem> items) => items;
6632
}
6733

6834
internal static class IEnumerableSerializer
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
/* Copyright 2010-present MongoDB Inc.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
using System.Collections.Generic;
17+
using MongoDB.Bson.Serialization;
18+
using MongoDB.Bson.Serialization.Serializers;
19+
using MongoDB.Driver.Core.Misc;
20+
21+
namespace MongoDB.Driver.Linq.Linq3Implementation.Serializers
22+
{
23+
internal abstract class IEnumerableSerializerBase<TEnumerable, TItem> : SerializerBase<TEnumerable>, IBsonArraySerializer
24+
where TEnumerable : IEnumerable<TItem>
25+
{
26+
// private fields
27+
private readonly IBsonSerializer<TItem> _itemSerializer;
28+
29+
// constructors
30+
public IEnumerableSerializerBase(IBsonSerializer<TItem> itemSerializer)
31+
{
32+
_itemSerializer = Ensure.IsNotNull(itemSerializer, nameof(itemSerializer));
33+
}
34+
35+
// public methods
36+
public override TEnumerable Deserialize(BsonDeserializationContext context, BsonDeserializationArgs args)
37+
{
38+
var reader = context.Reader;
39+
reader.ReadStartArray();
40+
var items = new List<TItem>();
41+
while (reader.ReadBsonType() != 0)
42+
{
43+
var item = _itemSerializer.Deserialize(context);
44+
items.Add(item);
45+
}
46+
reader.ReadEndArray();
47+
return CreateDeserializedValue(items);
48+
}
49+
50+
public override void Serialize(BsonSerializationContext context, BsonSerializationArgs args, TEnumerable value)
51+
{
52+
var writer = context.Writer;
53+
writer.WriteStartArray();
54+
foreach (var item in value)
55+
{
56+
_itemSerializer.Serialize(context, item);
57+
}
58+
writer.WriteEndArray();
59+
}
60+
61+
public bool TryGetItemSerializationInfo(out BsonSerializationInfo serializationInfo)
62+
{
63+
serializationInfo = new BsonSerializationInfo(null, _itemSerializer, typeof(TItem));
64+
return true;
65+
}
66+
67+
protected abstract TEnumerable CreateDeserializedValue(List<TItem> items);
68+
}
69+
}
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
/* Copyright 2010-present MongoDB Inc.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
using System;
17+
using System.Collections;
18+
using System.Collections.Generic;
19+
using System.Linq;
20+
using MongoDB.Bson.Serialization;
21+
22+
namespace MongoDB.Driver.Linq.Linq3Implementation.Serializers
23+
{
24+
internal class IOrderedEnumerableSerializer<TItem> : IEnumerableSerializerBase<IOrderedEnumerable<TItem>, TItem>
25+
{
26+
// constructors
27+
public IOrderedEnumerableSerializer(IBsonSerializer<TItem> itemSerializer)
28+
: base(itemSerializer)
29+
{
30+
}
31+
32+
// protected methods
33+
protected override IOrderedEnumerable<TItem> CreateDeserializedValue(List<TItem> items) => new IOrderedEnumerableWrapper(items);
34+
35+
private class IOrderedEnumerableWrapper : IOrderedEnumerable<TItem>
36+
{
37+
private readonly IEnumerable<TItem> _items;
38+
public IOrderedEnumerableWrapper(IEnumerable<TItem> items) => _items = items;
39+
public IOrderedEnumerable<TItem> CreateOrderedEnumerable<TKey>(Func<TItem, TKey> keySelector, IComparer<TKey> comparer, bool descending) => throw new InvalidOperationException("ThenBy or ThenByDescending cannot be executed client-side and should be moved to the LINQ query.");
40+
public IEnumerator<TItem> GetEnumerator() => _items.GetEnumerator();
41+
IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
42+
}
43+
}
44+
45+
internal static class IOrderedEnumerableSerializer
46+
{
47+
public static IBsonSerializer Create(IBsonSerializer itemSerializer)
48+
{
49+
var itemType = itemSerializer.ValueType;
50+
var serializerType = typeof(IOrderedEnumerableSerializer<>).MakeGenericType(itemType);
51+
return (IBsonSerializer)Activator.CreateInstance(serializerType, itemSerializer);
52+
}
53+
}
54+
}

0 commit comments

Comments
 (0)