Skip to content

Commit 61e262c

Browse files
author
rstam
committed
Implemented CSHARP-458. Added support for Nullable<Enum> in LINQ queries.
1 parent ca4e15c commit 61e262c

File tree

4 files changed

+204
-6
lines changed

4 files changed

+204
-6
lines changed

Driver/Linq/Expressions/ExpressionFormatter.cs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,11 @@ protected override Expression VisitConditional(ConditionalExpression node)
129129
/// <returns>The ConstantExpression.</returns>
130130
protected override Expression VisitConstant(ConstantExpression node)
131131
{
132+
// need to check node.Type instead of value.GetType() because boxed Nullable<T> values are boxed as <T>
133+
if (node.Type.IsGenericType && node.Type.GetGenericTypeDefinition() == typeof(Nullable<>))
134+
{
135+
_sb.AppendFormat("({0})", FriendlyClassName(node.Type));
136+
}
132137
VisitValue(node.Value);
133138
return node;
134139
}
@@ -380,7 +385,7 @@ protected override Expression VisitUnary(UnaryExpression node)
380385
switch (node.NodeType)
381386
{
382387
case ExpressionType.ArrayLength: break;
383-
case ExpressionType.Convert: _sb.AppendFormat("({0})", node.Type.Name); break;
388+
case ExpressionType.Convert: _sb.AppendFormat("({0})", FriendlyClassName(node.Type)); break;
384389
case ExpressionType.Negate: _sb.Append("-"); break;
385390
case ExpressionType.Not: _sb.Append("!"); break;
386391
case ExpressionType.Quote: break;
@@ -425,6 +430,12 @@ private string PublicClassName(Type type)
425430

426431
private void VisitValue(object value)
427432
{
433+
if (value == null)
434+
{
435+
_sb.Append("null");
436+
return;
437+
}
438+
428439
var a = value as Array;
429440
if (a != null && a.Rank == 1)
430441
{

Driver/Linq/Translators/PredicateTranslator.cs

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -271,13 +271,33 @@ private IMongoQuery BuildComparisonQuery(Expression variableExpression, Expressi
271271
var value = constantExpression.Value;
272272

273273
var unaryExpression = variableExpression as UnaryExpression;
274-
if (unaryExpression != null && (unaryExpression.NodeType == ExpressionType.Convert || unaryExpression.NodeType == ExpressionType.ConvertChecked) && unaryExpression.Operand.Type.IsEnum)
274+
if (unaryExpression != null && (unaryExpression.NodeType == ExpressionType.Convert || unaryExpression.NodeType == ExpressionType.ConvertChecked))
275275
{
276-
var enumType = unaryExpression.Operand.Type;
277-
if (unaryExpression.Type == Enum.GetUnderlyingType(enumType))
276+
if (unaryExpression.Operand.Type.IsEnum)
278277
{
279-
serializationInfo = _serializationInfoHelper.GetSerializationInfo(unaryExpression.Operand);
280-
value = Enum.ToObject(enumType, value); // serialize enum instead of underlying integer
278+
var enumType = unaryExpression.Operand.Type;
279+
if (unaryExpression.Type == Enum.GetUnderlyingType(enumType))
280+
{
281+
serializationInfo = _serializationInfoHelper.GetSerializationInfo(unaryExpression.Operand);
282+
value = Enum.ToObject(enumType, value); // serialize enum instead of underlying integer
283+
}
284+
}
285+
else if (
286+
unaryExpression.Type.IsGenericType &&
287+
unaryExpression.Type.GetGenericTypeDefinition() == typeof(Nullable<>) &&
288+
unaryExpression.Operand.Type.IsGenericType &&
289+
unaryExpression.Operand.Type.GetGenericTypeDefinition() == typeof(Nullable<>) &&
290+
unaryExpression.Operand.Type.GetGenericArguments()[0].IsEnum)
291+
{
292+
var enumType = unaryExpression.Operand.Type.GetGenericArguments()[0];
293+
if (unaryExpression.Type.GetGenericArguments()[0] == Enum.GetUnderlyingType(enumType))
294+
{
295+
serializationInfo = _serializationInfoHelper.GetSerializationInfo(unaryExpression.Operand);
296+
if (value != null)
297+
{
298+
value = Enum.ToObject(enumType, value); // serialize enum instead of underlying integer
299+
}
300+
}
281301
}
282302
}
283303
else

DriverUnitTests/DriverUnitTests.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@
163163
<Compile Include="Jira\CSharp93Tests.cs" />
164164
<Compile Include="Jira\CSharp98Tests.cs" />
165165
<Compile Include="Jira\CSharp100Tests.cs" />
166+
<Compile Include="Linq\SelectNullableTests.cs" />
166167
<Compile Include="Linq\SelectOfTypeHierarchicalTests.cs" />
167168
<Compile Include="Linq\SelectOfTypeTests.cs" />
168169
<Compile Include="Linq\SelectQueryTests.cs" />
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
/* Copyright 2010-2012 10gen Inc.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
using System;
17+
using System.Collections.Generic;
18+
using System.Linq;
19+
using System.Text;
20+
using NUnit.Framework;
21+
22+
using MongoDB.Bson;
23+
using MongoDB.Bson.Serialization.Attributes;
24+
using MongoDB.Driver;
25+
using MongoDB.Driver.Linq;
26+
27+
namespace MongoDB.DriverUnitTests.Linq
28+
{
29+
[TestFixture]
30+
public class SelectNullableTests
31+
{
32+
private enum E { None, A, B };
33+
34+
private class C
35+
{
36+
public ObjectId Id { get; set; }
37+
[BsonElement("e")]
38+
[BsonRepresentation(BsonType.String)]
39+
public E? E { get; set; }
40+
[BsonElement("x")]
41+
public int? X { get; set; }
42+
}
43+
44+
private MongoServer _server;
45+
private MongoDatabase _database;
46+
private MongoCollection<C> _collection;
47+
48+
[TestFixtureSetUp]
49+
public void Setup()
50+
{
51+
_server = Configuration.TestServer;
52+
_database = Configuration.TestDatabase;
53+
_collection = Configuration.GetTestCollection<C>();
54+
55+
_collection.Drop();
56+
_collection.Insert(new C { E = null });
57+
_collection.Insert(new C { E = E.A });
58+
_collection.Insert(new C { E = E.B });
59+
_collection.Insert(new C { X = null });
60+
_collection.Insert(new C { X = 1 });
61+
_collection.Insert(new C { X = 2 });
62+
}
63+
64+
[Test]
65+
public void TestWhereEEqualsA()
66+
{
67+
var query = from c in _collection.AsQueryable<C>()
68+
where c.E == E.A
69+
select c;
70+
71+
var translatedQuery = MongoQueryTranslator.Translate(query);
72+
Assert.IsInstanceOf<SelectQuery>(translatedQuery);
73+
Assert.AreSame(_collection, translatedQuery.Collection);
74+
Assert.AreSame(typeof(C), translatedQuery.DocumentType);
75+
76+
var selectQuery = (SelectQuery)translatedQuery;
77+
Assert.AreEqual("(C c) => ((Nullable<Int32>)c.E == (Nullable<Int32>)1)", ExpressionFormatter.ToString(selectQuery.Where));
78+
Assert.IsNull(selectQuery.OrderBy);
79+
Assert.IsNull(selectQuery.Projection);
80+
Assert.IsNull(selectQuery.Skip);
81+
Assert.IsNull(selectQuery.Take);
82+
83+
Assert.AreEqual("{ \"e\" : \"A\" }", selectQuery.BuildQuery().ToJson());
84+
Assert.AreEqual(1, Consume(query));
85+
}
86+
87+
[Test]
88+
public void TestWhereEEqualsNull()
89+
{
90+
var query = from c in _collection.AsQueryable<C>()
91+
where c.E == null
92+
select c;
93+
94+
var translatedQuery = MongoQueryTranslator.Translate(query);
95+
Assert.IsInstanceOf<SelectQuery>(translatedQuery);
96+
Assert.AreSame(_collection, translatedQuery.Collection);
97+
Assert.AreSame(typeof(C), translatedQuery.DocumentType);
98+
99+
var selectQuery = (SelectQuery)translatedQuery;
100+
Assert.AreEqual("(C c) => ((Nullable<Int32>)c.E == (Nullable<Int32>)null)", ExpressionFormatter.ToString(selectQuery.Where));
101+
Assert.IsNull(selectQuery.OrderBy);
102+
Assert.IsNull(selectQuery.Projection);
103+
Assert.IsNull(selectQuery.Skip);
104+
Assert.IsNull(selectQuery.Take);
105+
106+
Assert.AreEqual("{ \"e\" : null }", selectQuery.BuildQuery().ToJson());
107+
Assert.AreEqual(4, Consume(query));
108+
}
109+
110+
[Test]
111+
public void TestWhereXEquals1()
112+
{
113+
var query = from c in _collection.AsQueryable<C>()
114+
where c.X == 1
115+
select c;
116+
117+
var translatedQuery = MongoQueryTranslator.Translate(query);
118+
Assert.IsInstanceOf<SelectQuery>(translatedQuery);
119+
Assert.AreSame(_collection, translatedQuery.Collection);
120+
Assert.AreSame(typeof(C), translatedQuery.DocumentType);
121+
122+
var selectQuery = (SelectQuery)translatedQuery;
123+
Assert.AreEqual("(C c) => (c.X == (Nullable<Int32>)1)", ExpressionFormatter.ToString(selectQuery.Where));
124+
Assert.IsNull(selectQuery.OrderBy);
125+
Assert.IsNull(selectQuery.Projection);
126+
Assert.IsNull(selectQuery.Skip);
127+
Assert.IsNull(selectQuery.Take);
128+
129+
Assert.AreEqual("{ \"x\" : 1 }", selectQuery.BuildQuery().ToJson());
130+
Assert.AreEqual(1, Consume(query));
131+
}
132+
133+
[Test]
134+
public void TestWhereXEqualsNull()
135+
{
136+
var query = from c in _collection.AsQueryable<C>()
137+
where c.X == null
138+
select c;
139+
140+
var translatedQuery = MongoQueryTranslator.Translate(query);
141+
Assert.IsInstanceOf<SelectQuery>(translatedQuery);
142+
Assert.AreSame(_collection, translatedQuery.Collection);
143+
Assert.AreSame(typeof(C), translatedQuery.DocumentType);
144+
145+
var selectQuery = (SelectQuery)translatedQuery;
146+
Assert.AreEqual("(C c) => (c.X == (Nullable<Int32>)null)", ExpressionFormatter.ToString(selectQuery.Where));
147+
Assert.IsNull(selectQuery.OrderBy);
148+
Assert.IsNull(selectQuery.Projection);
149+
Assert.IsNull(selectQuery.Skip);
150+
Assert.IsNull(selectQuery.Take);
151+
152+
Assert.AreEqual("{ \"x\" : null }", selectQuery.BuildQuery().ToJson());
153+
Assert.AreEqual(4, Consume(query));
154+
}
155+
156+
private int Consume<T>(IQueryable<T> query)
157+
{
158+
var count = 0;
159+
foreach (var c in query)
160+
{
161+
count++;
162+
}
163+
return count;
164+
}
165+
}
166+
}

0 commit comments

Comments
 (0)