Skip to content

Commit aada8e8

Browse files
committed
CSHARP-4702: Recognize all Contains methods that are equivalent to Enumerable.Contains.
1 parent 51cf2ff commit aada8e8

File tree

4 files changed

+188
-3
lines changed

4 files changed

+188
-3
lines changed

src/MongoDB.Driver/Linq/Linq3Implementation/Misc/TypeExtensions.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,13 @@ public static bool ImplementsIEnumerable(this Type type, out Type itemType)
9595
return false;
9696
}
9797

98+
public static bool ImplementsIEnumerableOf(this Type type, Type itemType)
99+
{
100+
return
101+
ImplementsIEnumerable(type, out var actualItemType) &&
102+
actualItemType == itemType;
103+
}
104+
98105
public static bool Is(this Type type, Type comparand)
99106
{
100107
if (type == comparand)

src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/EnumerableMethod.cs

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
using System.Collections;
1818
using System.Collections.Generic;
1919
using System.Linq;
20+
using System.Linq.Expressions;
2021
using System.Reflection;
22+
using MongoDB.Driver.Linq.Linq3Implementation.Misc;
2123

2224
namespace MongoDB.Driver.Linq.Linq3Implementation.Reflection
2325
{
@@ -510,6 +512,45 @@ static EnumerableMethod()
510512
public static MethodInfo Zip => __zip;
511513

512514
// public methods
515+
public static bool IsContainsMethod(MethodCallExpression methodCallExpression, out Expression sourceExpression, out Expression valueExpression)
516+
{
517+
var method = methodCallExpression.Method;
518+
var parameters = method.GetParameters();
519+
var arguments = methodCallExpression.Arguments;
520+
521+
if (method.Name == "Contains" && method.ReturnType == typeof(bool))
522+
{
523+
if (method.IsStatic)
524+
{
525+
if (parameters.Length == 2)
526+
{
527+
if (parameters[0].ParameterType.ImplementsIEnumerableOf(parameters[1].ParameterType))
528+
{
529+
sourceExpression = arguments[0];
530+
valueExpression = arguments[1];
531+
return true;
532+
}
533+
}
534+
}
535+
else
536+
{
537+
if (parameters.Length == 1)
538+
{
539+
if (method.DeclaringType.ImplementsIEnumerableOf(parameters[0].ParameterType))
540+
{
541+
sourceExpression = methodCallExpression.Object;
542+
valueExpression = arguments[0];
543+
return true;
544+
}
545+
}
546+
}
547+
}
548+
549+
sourceExpression = null;
550+
valueExpression = null;
551+
return false;
552+
}
553+
513554
public static MethodInfo MakeSelect(Type sourceType, Type resultType)
514555
{
515556
return __select.MakeGenericMethod(sourceType, resultType);

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/MethodTranslators/AnyWithContainsInPredicateMethodToFilterTranslator.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
using System.Collections;
1717
using System.Linq.Expressions;
18+
using System.Reflection;
1819
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Filters;
1920
using MongoDB.Driver.Linq.Linq3Implementation.Misc;
2021
using MongoDB.Driver.Linq.Linq3Implementation.Reflection;
@@ -72,10 +73,9 @@ public static AstFilter Translate(TranslationContext context, Expression arrayFi
7273
private static bool IsContainsParameterExpression(Expression predicateBody, ParameterExpression predicateParameter, out Expression innerSourceExpression)
7374
{
7475
if (predicateBody is MethodCallExpression methodCallExpression &&
75-
methodCallExpression.Method.Is(EnumerableMethod.Contains) &&
76-
methodCallExpression.Arguments[1] == predicateParameter)
76+
EnumerableMethod.IsContainsMethod(methodCallExpression, out innerSourceExpression, out var valueExpression) &&
77+
valueExpression == predicateParameter)
7778
{
78-
innerSourceExpression = methodCallExpression.Arguments[0];
7979
return true;
8080
}
8181

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
/* Copyright 2010-present MongoDB Inc.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
using System.Collections.Generic;
17+
using System.Linq;
18+
using FluentAssertions;
19+
using MongoDB.Bson;
20+
using MongoDB.Driver.Linq;
21+
using MongoDB.TestHelpers.XunitExtensions;
22+
using Xunit;
23+
24+
namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira
25+
{
26+
public class CSharp702Tests : Linq3IntegrationTest
27+
{
28+
[Theory]
29+
[ParameterAttributeData]
30+
public void Query1_using_list_should_work(
31+
[Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider)
32+
{
33+
var collection = GetCollection(linqProvider);
34+
var lookingFor = new List<string> { "value1", "value2" };
35+
36+
var queryable = collection.AsQueryable()
37+
.Where(model => lookingFor.Any(value => model.List.Contains(value)));
38+
39+
var stages = Translate(collection, queryable);
40+
AssertStages(stages, "{ $match : { List : { $in : ['value1', 'value2'] } } }");
41+
42+
var results = queryable.ToList();
43+
results.Select(model => model.Id).Should().BeEquivalentTo(4, 5);
44+
}
45+
46+
[Theory]
47+
[ParameterAttributeData]
48+
public void Query2_using_list_should_work(
49+
[Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider)
50+
{
51+
var collection = GetCollection(linqProvider);
52+
var lookingFor = new List<string> { "value1", "value2" };
53+
54+
var queryable = collection.AsQueryable()
55+
.Where(model => model.List.Any(value => lookingFor.Contains(value)));
56+
57+
var stages = Translate(collection, queryable);
58+
if (linqProvider == LinqProvider.V2)
59+
{
60+
AssertStages(stages, "{ $match : { List : { $elemMatch : { $in : ['value1', 'value2'] } } } }");
61+
}
62+
else
63+
{
64+
AssertStages(stages, "{ $match : { List : { $in : ['value1', 'value2'] } } }");
65+
}
66+
67+
var results = queryable.ToList();
68+
results.Select(model => model.Id).Should().BeEquivalentTo(4, 5);
69+
}
70+
71+
[Theory]
72+
[ParameterAttributeData]
73+
public void Query1_using_hashset_should_work(
74+
[Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider)
75+
{
76+
var collection = GetCollection(linqProvider);
77+
var lookingFor = new List<string> { "value1", "value2" };
78+
79+
var queryable = collection.AsQueryable()
80+
.Where(model => lookingFor.Any(value => model.HashSet.Contains(value)));
81+
82+
var stages = Translate(collection, queryable);
83+
AssertStages(stages, "{ $match : { HashSet : { $in : ['value1', 'value2'] } } }");
84+
85+
var results = queryable.ToList();
86+
results.Select(model => model.Id).Should().BeEquivalentTo(4, 5);
87+
}
88+
89+
[Theory]
90+
[ParameterAttributeData]
91+
public void Query2_using_hashset_should_work(
92+
[Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider)
93+
{
94+
var collection = GetCollection(linqProvider);
95+
var lookingFor = new List<string> { "value1", "value2" };
96+
97+
var queryable = collection.AsQueryable()
98+
.Where(model => model.HashSet.Any(value => lookingFor.Contains(value)));
99+
100+
var stages = Translate(collection, queryable);
101+
if (linqProvider == LinqProvider.V2)
102+
{
103+
AssertStages(stages, "{ $match : { HashSet : { $elemMatch : { $in : ['value1', 'value2'] } } } }");
104+
}
105+
else
106+
{
107+
AssertStages(stages, "{ $match : { HashSet : { $in : ['value1', 'value2'] } } }");
108+
}
109+
110+
var results = queryable.ToList();
111+
results.Select(model => model.Id).Should().BeEquivalentTo(4, 5);
112+
}
113+
114+
private IMongoCollection<Model> GetCollection(LinqProvider linqProvider)
115+
{
116+
var collection = GetCollection<Model>("test", linqProvider);
117+
var documentsCollection = GetCollection<BsonDocument>("test");
118+
CreateCollection(
119+
documentsCollection,
120+
BsonDocument.Parse("{ _id : 1 }"),
121+
BsonDocument.Parse("{ _id : 2, List : null, HashSet : null }"),
122+
BsonDocument.Parse("{ _id : 3, List : [], HashSet : [] }"),
123+
BsonDocument.Parse("{ _id : 4, List : ['value1'], HashSet : ['value1'] }"),
124+
BsonDocument.Parse("{ _id : 5, List : ['value1', 'value2'], HashSet : ['value1', 'value2'] }"),
125+
BsonDocument.Parse("{ _id : 6, List : ['value3'], HashSet : ['value3'] }"),
126+
BsonDocument.Parse("{ _id : 7, List : ['value3', 'value4'], HashSet : ['value3', 'value4'] }"));
127+
return collection;
128+
}
129+
130+
private class Model
131+
{
132+
public int Id { get; set; }
133+
public List<string> List { get; set; }
134+
public HashSet<string> HashSet { get; set; }
135+
}
136+
}
137+
}

0 commit comments

Comments
 (0)