Skip to content

Commit 28c9bf0

Browse files
author
rstam
committed
Using p.S.Count() == n in a LINQ query where p.S is of type string was erroneously mapping to an array length query. Now it is treated as an alternative for p.S.Length == n.
1 parent 4e5322b commit 28c9bf0

File tree

2 files changed

+59
-37
lines changed

2 files changed

+59
-37
lines changed

Driver/Linq/Translators/SelectQuery.cs

Lines changed: 36 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
using System.Collections.ObjectModel;
2020
using System.Linq;
2121
using System.Linq.Expressions;
22+
using System.Reflection;
2223
using System.Text;
2324
using System.Text.RegularExpressions;
2425

@@ -249,49 +250,40 @@ private IMongoQuery BuildArrayLengthQuery(Expression variableExpression, Express
249250
{
250251
return null;
251252
}
253+
var value = ToInt32(constantExpression);
252254

253255
BsonSerializationInfo serializationInfo = null;
254-
var value = ToInt32(constantExpression);
255256

256257
var unaryExpression = variableExpression as UnaryExpression;
257-
if (unaryExpression != null)
258+
if (unaryExpression != null && unaryExpression.NodeType == ExpressionType.ArrayLength)
258259
{
259-
if (unaryExpression.NodeType == ExpressionType.ArrayLength)
260+
var arrayMemberExpression = unaryExpression.Operand as MemberExpression;
261+
if (arrayMemberExpression != null)
260262
{
261-
var memberExpression = unaryExpression.Operand as MemberExpression;
262-
if (memberExpression != null)
263-
{
264-
serializationInfo = GetSerializationInfo(memberExpression);
265-
}
263+
serializationInfo = GetSerializationInfo(arrayMemberExpression);
266264
}
267265
}
268266

269-
var countPropertyExpression = variableExpression as MemberExpression;
270-
if (countPropertyExpression != null)
267+
var memberExpression = variableExpression as MemberExpression;
268+
if (memberExpression != null && memberExpression.Member.Name == "Count")
271269
{
272-
if (countPropertyExpression.Member.Name == "Count")
270+
var arrayMemberExpression = memberExpression.Expression as MemberExpression;
271+
if (arrayMemberExpression != null)
273272
{
274-
var memberExpression = countPropertyExpression.Expression as MemberExpression;
275-
if (memberExpression != null)
276-
{
277-
serializationInfo = GetSerializationInfo(memberExpression);
278-
}
273+
serializationInfo = GetSerializationInfo(arrayMemberExpression);
279274
}
280275
}
281276

282-
var countMethodCallExpression = variableExpression as MethodCallExpression;
283-
if (countMethodCallExpression != null)
277+
var methodCallExpression = variableExpression as MethodCallExpression;
278+
if (methodCallExpression != null && methodCallExpression.Method.Name == "Count" && methodCallExpression.Method.DeclaringType == typeof(Enumerable))
284279
{
285-
if (countMethodCallExpression.Method.Name == "Count")
280+
var arguments = methodCallExpression.Arguments.ToArray();
281+
if (arguments.Length == 1)
286282
{
287-
var arguments = countMethodCallExpression.Arguments.ToArray();
288-
if (arguments.Length == 1)
283+
var arrayMemberExpression = methodCallExpression.Arguments[0] as MemberExpression;
284+
if (arrayMemberExpression != null && arrayMemberExpression.Type != typeof(string))
289285
{
290-
var memberExpression = countMethodCallExpression.Arguments[0] as MemberExpression;
291-
if (memberExpression != null)
292-
{
293-
serializationInfo = GetSerializationInfo(memberExpression);
294-
}
286+
serializationInfo = GetSerializationInfo(arrayMemberExpression);
295287
}
296288
}
297289
}
@@ -756,23 +748,30 @@ private IMongoQuery BuildStringLengthQuery(Expression variableExpression, Expres
756748
{
757749
return null;
758750
}
751+
var value = ToInt32(constantExpression);
759752

760753
BsonSerializationInfo serializationInfo = null;
761-
var value = ToInt32(constantExpression);
762754

763-
var lengthPropertyExpression = variableExpression as MemberExpression;
764-
if (lengthPropertyExpression != null)
755+
var memberExpression = variableExpression as MemberExpression;
756+
if (memberExpression != null && memberExpression.Member.Name == "Length")
765757
{
766-
if (lengthPropertyExpression.Member.Name == "Length")
758+
var stringMemberExpression = memberExpression.Expression as MemberExpression;
759+
if (stringMemberExpression != null && stringMemberExpression.Type == typeof(string))
767760
{
768-
var memberExpression = lengthPropertyExpression.Expression as MemberExpression;
769-
if (memberExpression != null)
761+
serializationInfo = GetSerializationInfo(stringMemberExpression);
762+
}
763+
}
764+
765+
var methodCallExpression = variableExpression as MethodCallExpression;
766+
if (methodCallExpression != null && methodCallExpression.Method.Name == "Count" && methodCallExpression.Method.DeclaringType == typeof(Enumerable))
767+
{
768+
var args = methodCallExpression.Arguments.ToArray();
769+
if (args.Length == 1)
770+
{
771+
var stringMemberExpression = args[0] as MemberExpression;
772+
if (stringMemberExpression != null && stringMemberExpression.Type == typeof(string))
770773
{
771-
serializationInfo = GetSerializationInfo(memberExpression);
772-
if (serializationInfo != null && serializationInfo.NominalType != typeof(string))
773-
{
774-
serializationInfo = null;
775-
}
774+
serializationInfo = GetSerializationInfo(stringMemberExpression);
776775
}
777776
}
778777
}

DriverUnitTests/Linq/SelectQueryTests.cs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4282,6 +4282,29 @@ where c.S.Contains(".")
42824282
Assert.AreEqual(0, Consume(query));
42834283
}
42844284

4285+
[Test]
4286+
public void TestWhereSCountEquals3()
4287+
{
4288+
var query = from c in _collection.AsQueryable<C>()
4289+
where c.S.Count() == 3
4290+
select c;
4291+
4292+
var translatedQuery = MongoQueryTranslator.Translate(query);
4293+
Assert.IsInstanceOf<SelectQuery>(translatedQuery);
4294+
Assert.AreSame(_collection, translatedQuery.Collection);
4295+
Assert.AreSame(typeof(C), translatedQuery.DocumentType);
4296+
4297+
var selectQuery = (SelectQuery)translatedQuery;
4298+
Assert.AreEqual("(C c) => (Enumerable.Count<Char>(c.S) == 3)", ExpressionFormatter.ToString(selectQuery.Where));
4299+
Assert.IsNull(selectQuery.OrderBy);
4300+
Assert.IsNull(selectQuery.Projection);
4301+
Assert.IsNull(selectQuery.Skip);
4302+
Assert.IsNull(selectQuery.Take);
4303+
4304+
Assert.AreEqual("{ \"s\" : /^.{3}$/s }", selectQuery.BuildQuery().ToJson());
4305+
Assert.AreEqual(1, Consume(query));
4306+
}
4307+
42854308
[Test]
42864309
public void TestWhereSEndsWithAbc()
42874310
{

0 commit comments

Comments
 (0)