Skip to content

Commit 9abd747

Browse files
committed
implemented CSHARP_462 to support local collection Contains operations translating into In queries.
1 parent 4517ce3 commit 9abd747

File tree

2 files changed

+94
-9
lines changed

2 files changed

+94
-9
lines changed

Driver/Linq/Translators/SelectQuery.cs

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -521,6 +521,11 @@ private IMongoQuery BuildContainsQuery(MethodCallExpression methodCallExpression
521521
return BuildStringQuery(methodCallExpression);
522522
}
523523

524+
if (methodCallExpression.Object != null && methodCallExpression.Object.NodeType == ExpressionType.Constant)
525+
{
526+
return BuildInQuery(methodCallExpression);
527+
}
528+
524529
BsonSerializationInfo serializationInfo = null;
525530
ConstantExpression valueExpression = null;
526531
var arguments = methodCallExpression.Arguments.ToArray();
@@ -536,6 +541,10 @@ private IMongoQuery BuildContainsQuery(MethodCallExpression methodCallExpression
536541
{
537542
if (methodCallExpression.Method.DeclaringType == typeof(Enumerable))
538543
{
544+
if (arguments[0].NodeType == ExpressionType.Constant)
545+
{
546+
return BuildInQuery(methodCallExpression);
547+
}
539548
serializationInfo = GetSerializationInfo(arguments[0]);
540549
valueExpression = arguments[1] as ConstantExpression;
541550
}
@@ -610,20 +619,46 @@ private IMongoQuery BuildEqualsQuery(MethodCallExpression methodCallExpression)
610619

611620
private IMongoQuery BuildInQuery(MethodCallExpression methodCallExpression)
612621
{
613-
if (methodCallExpression.Method.DeclaringType == typeof(LinqToMongo))
622+
var methodDeclaringType = methodCallExpression.Method.DeclaringType;
623+
var arguments = methodCallExpression.Arguments.ToArray();
624+
BsonSerializationInfo serializationInfo = null;
625+
ConstantExpression valuesExpression = null;
626+
if (methodDeclaringType == typeof(LinqToMongo))
614627
{
615-
var arguments = methodCallExpression.Arguments.ToArray();
616628
if (arguments.Length == 2)
617629
{
618-
var serializationInfo = GetSerializationInfo(arguments[0]);
619-
var valuesExpression = arguments[1] as ConstantExpression;
620-
if (serializationInfo != null && valuesExpression != null)
621-
{
622-
var serializedValues = SerializeValues(serializationInfo, (IEnumerable)valuesExpression.Value);
623-
return Query.In(serializationInfo.ElementName, serializedValues);
624-
}
630+
serializationInfo = GetSerializationInfo(arguments[0]);
631+
valuesExpression = arguments[1] as ConstantExpression;
625632
}
626633
}
634+
else if (methodDeclaringType == typeof(Enumerable) || methodDeclaringType == typeof(Queryable))
635+
{
636+
if (arguments.Length == 2)
637+
{
638+
serializationInfo = GetSerializationInfo(arguments[1]);
639+
valuesExpression = arguments[0] as ConstantExpression;
640+
}
641+
}
642+
else
643+
{
644+
if (methodDeclaringType.IsGenericType)
645+
{
646+
methodDeclaringType = methodDeclaringType.GetGenericTypeDefinition();
647+
}
648+
649+
bool contains = methodDeclaringType.GetInterface("ICollection`1") != null;
650+
if (contains && arguments.Length == 1)
651+
{
652+
serializationInfo = GetSerializationInfo(arguments[0]);
653+
valuesExpression = methodCallExpression.Object as ConstantExpression;
654+
}
655+
}
656+
657+
if (serializationInfo != null && valuesExpression != null)
658+
{
659+
var serializedValues = SerializeValues(serializationInfo, (IEnumerable)valuesExpression.Value);
660+
return Query.In(serializationInfo.ElementName, serializedValues);
661+
}
627662
return null;
628663
}
629664

DriverUnitTests/Linq/SelectQueryTests.cs

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1948,6 +1948,56 @@ where c.A.Any(a => a > 3)
19481948
query.ToList(); // execute query
19491949
}
19501950

1951+
[Test]
1952+
public void TestWhereLocalListContainsX()
1953+
{
1954+
var local = new List<int> { 1, 2, 3 };
1955+
1956+
var query = from c in _collection.AsQueryable<C>()
1957+
where local.Contains(c.X)
1958+
select c;
1959+
1960+
var translatedQuery = MongoQueryTranslator.Translate(query);
1961+
Assert.IsInstanceOf<SelectQuery>(translatedQuery);
1962+
Assert.AreSame(_collection, translatedQuery.Collection);
1963+
Assert.AreSame(typeof(C), translatedQuery.DocumentType);
1964+
1965+
var selectQuery = (SelectQuery)translatedQuery;
1966+
Assert.AreEqual("(C c) => System.Collections.Generic.List`1[System.Int32].Contains(c.X)", ExpressionFormatter.ToString(selectQuery.Where));
1967+
Assert.IsNull(selectQuery.OrderBy);
1968+
Assert.IsNull(selectQuery.Projection);
1969+
Assert.IsNull(selectQuery.Skip);
1970+
Assert.IsNull(selectQuery.Take);
1971+
1972+
Assert.AreEqual("{ \"x\" : { \"$in\" : [1, 2, 3] } }", selectQuery.BuildQuery().ToJson());
1973+
Assert.AreEqual(3, Consume(query));
1974+
}
1975+
1976+
[Test]
1977+
public void TestWhereLocalArrayContainsX()
1978+
{
1979+
var local = new [] { 1, 2, 3 };
1980+
1981+
var query = from c in _collection.AsQueryable<C>()
1982+
where local.Contains(c.X)
1983+
select c;
1984+
1985+
var translatedQuery = MongoQueryTranslator.Translate(query);
1986+
Assert.IsInstanceOf<SelectQuery>(translatedQuery);
1987+
Assert.AreSame(_collection, translatedQuery.Collection);
1988+
Assert.AreSame(typeof(C), translatedQuery.DocumentType);
1989+
1990+
var selectQuery = (SelectQuery)translatedQuery;
1991+
Assert.AreEqual("(C c) => Enumerable.Contains<Int32>(Int32[]:{ 1, 2, 3 }, c.X)", ExpressionFormatter.ToString(selectQuery.Where));
1992+
Assert.IsNull(selectQuery.OrderBy);
1993+
Assert.IsNull(selectQuery.Projection);
1994+
Assert.IsNull(selectQuery.Skip);
1995+
Assert.IsNull(selectQuery.Take);
1996+
1997+
Assert.AreEqual("{ \"x\" : { \"$in\" : [1, 2, 3] } }", selectQuery.BuildQuery().ToJson());
1998+
Assert.AreEqual(3, Consume(query));
1999+
}
2000+
19512001
[Test]
19522002
public void TestWhereAContains2()
19532003
{

0 commit comments

Comments
 (0)