Skip to content

Commit 4592912

Browse files
author
rstam
committed
Implemented support for Enumerable.Any in LINQ where clauses.
1 parent 172ec77 commit 4592912

File tree

4 files changed

+113
-10
lines changed

4 files changed

+113
-10
lines changed

Driver/Linq/Expressions/ExpressionParameterReplacer.cs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,18 +30,18 @@ public class ExpressionParameterReplacer : ExpressionVisitor
3030
{
3131
// private fields
3232
private ParameterExpression _fromParameter;
33-
private ParameterExpression _toParameter;
33+
private Expression _toExpression;
3434

3535
// constructors
3636
/// <summary>
3737
/// Initializes a new instance of the ExpressionParameterReplacer class.
3838
/// </summary>
3939
/// <param name="fromParameter">The parameter to be replaced.</param>
40-
/// <param name="toParameter">The new parameter.</param>
41-
public ExpressionParameterReplacer(ParameterExpression fromParameter, ParameterExpression toParameter)
40+
/// <param name="toExpression">The expression that replaces the parameter.</param>
41+
public ExpressionParameterReplacer(ParameterExpression fromParameter, Expression toExpression)
4242
{
4343
_fromParameter = fromParameter;
44-
_toParameter = toParameter;
44+
_toExpression = toExpression;
4545
}
4646

4747
// public methods
@@ -50,11 +50,11 @@ public ExpressionParameterReplacer(ParameterExpression fromParameter, ParameterE
5050
/// </summary>
5151
/// <param name="node">The expression containing the parameter that should be replaced.</param>
5252
/// <param name="fromParameter">The from parameter.</param>
53-
/// <param name="toParameter">The to parameter.</param>
53+
/// <param name="toExpression">The expression that replaces the parameter.</param>
5454
/// <returns>The expression with all occurrences of the parameter replaced.</returns>
55-
public static Expression ReplaceParameter(Expression node, ParameterExpression fromParameter, ParameterExpression toParameter)
55+
public static Expression ReplaceParameter(Expression node, ParameterExpression fromParameter, Expression toExpression)
5656
{
57-
var replacer = new ExpressionParameterReplacer(fromParameter, toParameter);
57+
var replacer = new ExpressionParameterReplacer(fromParameter, toExpression);
5858
return replacer.Visit(node);
5959
}
6060

@@ -67,7 +67,7 @@ protected override Expression VisitParameter(ParameterExpression node)
6767
{
6868
if (node == _fromParameter)
6969
{
70-
return _toParameter;
70+
return _toExpression;
7171
}
7272
return node;
7373
}

Driver/Linq/LinqToMongo.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,16 @@ namespace MongoDB.Driver.Linq
3030
/// </summary>
3131
public static class LinqToMongo
3232
{
33+
/// <summary>
34+
/// Represents an arbitrary item in an array value. Can only be used in LINQ queries.
35+
/// </summary>
36+
/// <param name="source">The array value.</param>
37+
/// <returns>Throws an InvalidOperationException if called.</returns>
38+
public static TSource Arbitrary<TSource>(this IEnumerable<TSource> source)
39+
{
40+
throw new InvalidOperationException("The LinqToMongo.Arbitrary method is only intended to be used in LINQ Where clauses.");
41+
}
42+
3343
/// <summary>
3444
/// Determines whether a sequence contains all of the specified values.
3545
/// </summary>

Driver/Linq/Translators/SelectQuery.cs

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,36 @@ private IMongoQuery BuildAndAlsoQuery(BinaryExpression binaryExpression)
272272
return Query.And(BuildQuery(binaryExpression.Left), BuildQuery(binaryExpression.Right));
273273
}
274274

275+
private IMongoQuery BuildAnyQuery(MethodCallExpression methodCallExpression)
276+
{
277+
if (methodCallExpression.Method.DeclaringType == typeof(Enumerable))
278+
{
279+
var arguments = methodCallExpression.Arguments.ToArray();
280+
if (arguments.Length == 1)
281+
{
282+
var serializationInfo = GetSerializationInfo(arguments[0]);
283+
if (serializationInfo != null)
284+
{
285+
return Query.And(
286+
Query.NE(serializationInfo.ElementName, BsonNull.Value),
287+
Query.Not(serializationInfo.ElementName).Size(0));
288+
}
289+
}
290+
else if (arguments.Length == 2)
291+
{
292+
var sourceExpression = arguments[0];
293+
var lambda = (LambdaExpression)arguments[1];
294+
var parameter = lambda.Parameters[0];
295+
var body = lambda.Body;
296+
var arbitraryMethodInfo = typeof(LinqToMongo).GetMethod("Arbitrary").MakeGenericMethod(parameter.Type);
297+
var arbitraryMethodCallExpression = Expression.Call(arbitraryMethodInfo, sourceExpression);
298+
var modifiedBody = ExpressionParameterReplacer.ReplaceParameter(body, parameter, arbitraryMethodCallExpression);
299+
return BuildQuery(modifiedBody);
300+
}
301+
}
302+
return null;
303+
}
304+
275305
private IMongoQuery BuildArrayLengthQuery(BinaryExpression binaryExpression)
276306
{
277307
var leftUnaryExpression = binaryExpression.Left as UnaryExpression;
@@ -609,6 +639,7 @@ private IMongoQuery BuildMethodCallQuery(MethodCallExpression methodCallExpressi
609639
{
610640
switch (methodCallExpression.Method.Name)
611641
{
642+
case "Any": return BuildAnyQuery(methodCallExpression);
612643
case "Contains": return BuildContainsQuery(methodCallExpression);
613644
case "ContainsAll": return BuildContainsAllQuery(methodCallExpression);
614645
case "ContainsAny": return BuildContainsAnyQuery(methodCallExpression);
@@ -888,9 +919,25 @@ private BsonSerializationInfo GetSerializationInfo(IBsonSerializer serializer, E
888919
}
889920

890921
var methodCallExpression = expression as MethodCallExpression;
891-
if (methodCallExpression != null && methodCallExpression.Method.Name == "get_Item")
922+
if (methodCallExpression != null)
892923
{
893-
return GetSerializationInfoGetItem(serializer, methodCallExpression);
924+
switch (methodCallExpression.Method.Name)
925+
{
926+
case "Arbitrary":
927+
if (methodCallExpression.Method.DeclaringType == typeof(LinqToMongo))
928+
{
929+
var arraySerializationInfo = GetSerializationInfo(serializer, methodCallExpression.Arguments[0]);
930+
var itemSerializationInfo = arraySerializationInfo.Serializer.GetItemSerializationInfo();
931+
return new BsonSerializationInfo(
932+
arraySerializationInfo.ElementName,
933+
itemSerializationInfo.Serializer,
934+
itemSerializationInfo.NominalType,
935+
itemSerializationInfo.SerializationOptions);
936+
}
937+
break;
938+
case "get_Item":
939+
return GetSerializationInfoGetItem(serializer, methodCallExpression);
940+
}
894941
}
895942

896943
return null;

DriverOnlineTests/Linq/SelectQueryTests.cs

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1902,6 +1902,52 @@ public void TestUnionWithEqualityComparer()
19021902
query.ToList(); // execute query
19031903
}
19041904

1905+
[Test]
1906+
public void TestWhereAAny()
1907+
{
1908+
var query = from c in _collection.AsQueryable<C>()
1909+
where c.A.Any()
1910+
select c;
1911+
1912+
var translatedQuery = MongoQueryTranslator.Translate(query);
1913+
Assert.IsInstanceOf<SelectQuery>(translatedQuery);
1914+
Assert.AreSame(_collection, translatedQuery.Collection);
1915+
Assert.AreSame(typeof(C), translatedQuery.DocumentType);
1916+
1917+
var selectQuery = (SelectQuery)translatedQuery;
1918+
Assert.AreEqual("(C c) => Enumerable.Any<Int32>(c.A)", ExpressionFormatter.ToString(selectQuery.Where));
1919+
Assert.IsNull(selectQuery.OrderBy);
1920+
Assert.IsNull(selectQuery.Projection);
1921+
Assert.IsNull(selectQuery.Skip);
1922+
Assert.IsNull(selectQuery.Take);
1923+
1924+
Assert.AreEqual("{ \"a\" : { \"$ne\" : null, \"$not\" : { \"$size\" : 0 } } }", selectQuery.BuildQuery().ToJson());
1925+
Assert.AreEqual(1, Consume(query));
1926+
}
1927+
1928+
[Test]
1929+
public void TestWhereAAnyWithPredicate()
1930+
{
1931+
var query = from c in _collection.AsQueryable<C>()
1932+
where c.A.Any(a => a > 3)
1933+
select c;
1934+
1935+
var translatedQuery = MongoQueryTranslator.Translate(query);
1936+
Assert.IsInstanceOf<SelectQuery>(translatedQuery);
1937+
Assert.AreSame(_collection, translatedQuery.Collection);
1938+
Assert.AreSame(typeof(C), translatedQuery.DocumentType);
1939+
1940+
var selectQuery = (SelectQuery)translatedQuery;
1941+
Assert.AreEqual("(C c) => Enumerable.Any<Int32>(c.A, (Int32 a) => (a > 3))", ExpressionFormatter.ToString(selectQuery.Where));
1942+
Assert.IsNull(selectQuery.OrderBy);
1943+
Assert.IsNull(selectQuery.Projection);
1944+
Assert.IsNull(selectQuery.Skip);
1945+
Assert.IsNull(selectQuery.Take);
1946+
1947+
Assert.AreEqual("{ \"a\" : { \"$gt\" : 3 } }", selectQuery.BuildQuery().ToJson());
1948+
Assert.AreEqual(1, Consume(query));
1949+
}
1950+
19051951
[Test]
19061952
public void TestWhereAContains2()
19071953
{

0 commit comments

Comments
 (0)