Skip to content

Commit 8803d6b

Browse files
committed
CSHARP-1228: fixed issue with enumerations in typeless field expressions not getting recognized because of a Unary conversion operator.
1 parent edf4c62 commit 8803d6b

File tree

6 files changed

+123
-72
lines changed

6 files changed

+123
-72
lines changed

src/MongoDB.Driver.Tests/FieldDefinitionTests.cs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
using System;
1717
using System.Collections.Generic;
1818
using System.Linq;
19+
using System.Linq.Expressions;
1920
using System.Text;
2021
using System.Threading.Tasks;
2122
using FluentAssertions;
@@ -162,13 +163,39 @@ public void Should_resolve_array_name_with_positional_operator_with_multiple_dot
162163
renderedField.FieldSerializer.Should().BeOfType<StringSerializer>();
163164
}
164165

166+
[Test]
167+
public void Should_resolve_an_enum_with_field_type()
168+
{
169+
var subject = new ExpressionFieldDefinition<Person, Gender>(x => x.Gender);
170+
171+
var renderedField = subject.Render(BsonSerializer.SerializerRegistry.GetSerializer<Person>(), BsonSerializer.SerializerRegistry);
172+
173+
renderedField.FieldName.Should().Be("g");
174+
renderedField.FieldSerializer.Should().BeOfType<EnumSerializer<Gender>>();
175+
}
176+
177+
[Test]
178+
public void Should_resolve_an_enum_without_field_type()
179+
{
180+
Expression<Func<Person, object>> exp = x => x.Gender;
181+
var subject = new ExpressionFieldDefinition<Person>(exp);
182+
183+
var renderedField = subject.Render(BsonSerializer.SerializerRegistry.GetSerializer<Person>(), BsonSerializer.SerializerRegistry);
184+
185+
renderedField.FieldName.Should().Be("g");
186+
renderedField.FieldSerializer.Should().BeOfType<EnumSerializer<Gender>>();
187+
}
188+
165189
private class Person
166190
{
167191
[BsonElement("name")]
168192
public Name Name { get; set; }
169193

170194
[BsonElement("pets")]
171195
public IEnumerable<Pet> Pets { get; set; }
196+
197+
[BsonElement("g")]
198+
public Gender Gender { get; set; }
172199
}
173200

174201
private class Name
@@ -187,5 +214,11 @@ private class Pet
187214
[BsonElement("name")]
188215
public Name Name { get; set; }
189216
}
217+
218+
private enum Gender
219+
{
220+
Male,
221+
Female
222+
}
190223
}
191224
}

src/MongoDB.Driver.Tests/Linq/Translators/LegacyPredicateTranslatorTests.cs

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ public void Setup()
5151
{
5252
new C { Id = _id2, X = 2, LX = 2, Y = 11, Date = new DateTime(2000, 1, 1, 1, 1, 1, 1, DateTimeKind.Utc), D = new D { Z = 22 }, NullableDouble = 2, A = new[] { 2, 3, 4 }, DA = new List<D> { new D { Y = 11, Z = 111 }, new D { Z = 222 } }, L = new List<int> { 2, 3, 4 } },
5353
new C { Id = _id1, X = 1, LX = 1, Y = 11, Date = new DateTime(2000, 2, 2, 2, 2, 2, 2, DateTimeKind.Utc), D = new D { Z = 11 }, NullableDouble = 2, S = "abc", SA = new string[] { "Tom", "Dick", "Harry" } },
54-
new C { Id = _id3, X = 3, LX = 3, Y = 33, Date = new DateTime(2001, 1, 1, 1, 1, 1, 1, DateTimeKind.Utc), D = new D { Z = 33 }, NullableDouble = 5, B = true, BA = new bool[] { true }, E = E.A, EA = new E[] { E.A, E.B } },
54+
new C { Id = _id3, X = 3, LX = 3, Y = 33, Date = new DateTime(2001, 1, 1, 1, 1, 1, 1, DateTimeKind.Utc), D = new D { Z = 33 }, NullableDouble = 5, B = true, BA = new bool[] { true }, E = E.A, ENullable = E.A, EA = new E[] { E.A, E.B } },
5555
new C { Id = _id5, X = 5, LX = 5, Y = 44, Date = new DateTime(2001, 2, 2, 2, 2, 2, 2, DateTimeKind.Utc), D = new D { Z = 55 }, DBRef = new MongoDBRef("db", "c", 1), F = new F { G = new G { H = 10 } } },
5656
new C { Id = _id4, X = 4, LX = 4, Y = 44, Date = new DateTime(2001, 3, 3, 3, 3, 3, 3, DateTimeKind.Utc), D = new D { Z = 44 }, S = " xyz ", DA = new List<D> { new D { Y = 33, Z = 333 }, new D { Y = 44, Z = 444 } } }
5757
}).GetAwaiter().GetResult();
@@ -345,6 +345,30 @@ public void TestWhereENotEqualsANot()
345345
Assert<C>(c => !(c.E != E.A), 1, "{ \"e\" : \"A\" }");
346346
}
347347

348+
[Test]
349+
public void TestWhereENullableEqualsA()
350+
{
351+
Assert<C>(c => c.ENullable == E.A, 1, "{ \"en\" : \"A\" }");
352+
}
353+
354+
[Test]
355+
public void TestWhereENullableEqualsNull()
356+
{
357+
Assert<C>(c => c.ENullable == null, 4, "{ \"en\" : null }");
358+
}
359+
360+
[Test]
361+
public void TestWhereENullabeEqualsAReversed()
362+
{
363+
Assert<C>(c => E.A == c.ENullable, 1, "{ \"en\" : \"A\" }");
364+
}
365+
366+
[Test]
367+
public void TestWhereENullabeEqualsNullReversed()
368+
{
369+
Assert<C>(c => null == c.ENullable, 4, "{ \"en\" : null }");
370+
}
371+
348372
[Test]
349373
public void TestWhereLContains2()
350374
{
@@ -1188,6 +1212,9 @@ private class C
11881212
[BsonIgnoreIfDefault]
11891213
[BsonRepresentation(BsonType.String)]
11901214
public E E { get; set; }
1215+
[BsonElement("en")]
1216+
[BsonRepresentation(BsonType.String)]
1217+
public E? ENullable { get; set; }
11911218
[BsonElement("ea")]
11921219
[BsonIgnoreIfNull]
11931220
public E[] EA { get; set; }

src/MongoDB.Driver/Linq/Processors/Normalizer.cs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,8 @@ protected override Expression VisitBinary(BinaryExpression node)
8080

8181
// VB introduces a Convert on the LHS with a Nothing comparison, so we make it look like C# which does not have
8282
// any with a comparison to null
83-
if ((node.NodeType == ExpressionType.Equal || node.NodeType == ExpressionType.NotEqual) &&
84-
node.Left.NodeType == ExpressionType.Convert &&
83+
if ((node.NodeType == ExpressionType.Equal || node.NodeType == ExpressionType.NotEqual) &&
84+
node.Left.NodeType == ExpressionType.Convert &&
8585
node.Right.NodeType == ExpressionType.Constant)
8686
{
8787
var left = (UnaryExpression)node.Left;
@@ -163,7 +163,9 @@ private BinaryExpression EnsureConstantIsOnRight(BinaryExpression node)
163163
var left = node.Left;
164164
var right = node.Right;
165165
var operatorType = node.NodeType;
166-
if (left.NodeType == ExpressionType.Constant)
166+
if (left.NodeType == ExpressionType.Constant ||
167+
(left.NodeType == ExpressionType.Convert && ((UnaryExpression)left).Operand.NodeType == ExpressionType.Constant) ||
168+
(left.NodeType == ExpressionType.ConvertChecked && ((UnaryExpression)left).Operand.NodeType == ExpressionType.Constant))
167169
{
168170
right = node.Left;
169171
left = node.Right;

src/MongoDB.Driver/Linq/Processors/SerializationInfoBinder.cs

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
using System.Reflection;
2121
using MongoDB.Bson.Serialization;
2222
using MongoDB.Driver.Linq.Expressions;
23+
using MongoDB.Driver.Linq.Utils;
2324

2425
namespace MongoDB.Driver.Linq.Processors
2526
{
@@ -166,24 +167,16 @@ protected override Expression VisitUnary(UnaryExpression node)
166167
var unaryExpression = newNode as UnaryExpression;
167168
if (node != newNode &&
168169
unaryExpression != null &&
169-
!unaryExpression.Operand.Type.IsEnum && // enums are weird, so we skip them
170170
(newNode.NodeType == ExpressionType.Convert || newNode.NodeType == ExpressionType.ConvertChecked))
171171
{
172-
if (unaryExpression.Operand.Type.IsGenericType && unaryExpression.Operand.Type.GetGenericTypeDefinition() == typeof(Nullable<>))
173-
{
174-
var underlyingType = Nullable.GetUnderlyingType(node.Operand.Type);
175-
if (underlyingType.IsEnum)
176-
{
177-
// we skip enums because they are weird
178-
return newNode;
179-
}
180-
}
181-
182172
var serializationExpression = unaryExpression.Operand as ISerializationExpression;
183173
if (serializationExpression != null)
184174
{
185175
BsonSerializationInfo serializationInfo;
186-
if (!unaryExpression.Type.IsAssignableFrom(unaryExpression.Operand.Type))
176+
var operandType = unaryExpression.Operand.Type;
177+
if (!unaryExpression.Operand.Type.IsEnum &&
178+
!TypeHelper.IsNullableEnum(operandType) &&
179+
!unaryExpression.Type.IsAssignableFrom(unaryExpression.Operand.Type))
187180
{
188181
// only lookup a new serializer if the cast is "unnecessary"
189182
var serializer = _serializerRegistry.GetSerializer(node.Type);

src/MongoDB.Driver/Linq/Translators/PredicateTranslator.cs

Lines changed: 27 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -333,73 +333,44 @@ private FilterDefinition<BsonDocument> BuildComparisonQuery(BinaryExpression bin
333333

334334
private FilterDefinition<BsonDocument> BuildComparisonQuery(Expression variableExpression, ExpressionType operatorType, ConstantExpression constantExpression)
335335
{
336-
BsonSerializationInfo serializationInfo = null;
337336
var value = constantExpression.Value;
338337

339-
var unaryExpression = variableExpression as UnaryExpression;
340-
if (unaryExpression != null && (unaryExpression.NodeType == ExpressionType.Convert || unaryExpression.NodeType == ExpressionType.ConvertChecked))
338+
var methodCallExpression = variableExpression as MethodCallExpression;
339+
if (methodCallExpression != null && value is bool)
341340
{
342-
if (unaryExpression.Operand.Type.IsEnum)
343-
{
344-
var enumType = unaryExpression.Operand.Type;
345-
if (unaryExpression.Type == Enum.GetUnderlyingType(enumType))
346-
{
347-
serializationInfo = GetSerializationInfo(unaryExpression.Operand);
348-
value = Enum.ToObject(enumType, value); // serialize enum instead of underlying integer
349-
}
350-
}
351-
else if (
352-
unaryExpression.Type.IsGenericType &&
353-
unaryExpression.Type.GetGenericTypeDefinition() == typeof(Nullable<>) &&
354-
unaryExpression.Operand.Type.IsGenericType &&
355-
unaryExpression.Operand.Type.GetGenericTypeDefinition() == typeof(Nullable<>) &&
356-
unaryExpression.Operand.Type.GetGenericArguments()[0].IsEnum)
357-
{
358-
var enumType = unaryExpression.Operand.Type.GetGenericArguments()[0];
359-
if (unaryExpression.Type.GetGenericArguments()[0] == Enum.GetUnderlyingType(enumType))
360-
{
361-
serializationInfo = GetSerializationInfo(unaryExpression.Operand);
362-
if (value != null)
363-
{
364-
value = Enum.ToObject(enumType, value); // serialize enum instead of underlying integer
365-
}
366-
}
367-
}
368-
else
369-
{
370-
//Allows a cast, which would be required for compilation, such as (float){object} >= 25f to be built as __builder.GTE({object}, 25)
371-
serializationInfo = GetSerializationInfo(unaryExpression.Operand);
372-
}
341+
var boolValue = (bool)value;
342+
var query = this.BuildMethodCallQuery(methodCallExpression);
343+
344+
var isTrueComparison = (boolValue && operatorType == ExpressionType.Equal)
345+
|| (!boolValue && operatorType == ExpressionType.NotEqual);
346+
347+
return isTrueComparison ? query : __builder.Not(query);
373348
}
374-
else
349+
350+
var serializationInfo = GetSerializationInfo(variableExpression);
351+
var valueType = serializationInfo.Serializer.ValueType;
352+
if (valueType.IsEnum || TypeHelper.IsNullableEnum(valueType))
375353
{
376-
var methodCallExpression = variableExpression as MethodCallExpression;
377-
if (methodCallExpression != null && value is bool)
354+
if (!valueType.IsEnum && value != null)
378355
{
379-
var boolValue = (bool)value;
380-
var query = this.BuildMethodCallQuery(methodCallExpression);
381-
382-
var isTrueComparison = (boolValue && operatorType == ExpressionType.Equal)
383-
|| (!boolValue && operatorType == ExpressionType.NotEqual);
384-
385-
return isTrueComparison ? query : __builder.Not(query);
356+
valueType = TypeHelper.GetNullableUnderlyingType(valueType);
386357
}
387358

388-
serializationInfo = GetSerializationInfo(variableExpression);
359+
if (value != null)
360+
{
361+
value = Enum.ToObject(valueType, value);
362+
}
389363
}
390364

391-
if (serializationInfo != null)
365+
var serializedValue = serializationInfo.SerializeValue(value);
366+
switch (operatorType)
392367
{
393-
var serializedValue = serializationInfo.SerializeValue(value);
394-
switch (operatorType)
395-
{
396-
case ExpressionType.Equal: return __builder.Eq(serializationInfo.ElementName, serializedValue);
397-
case ExpressionType.GreaterThan: return __builder.Gt(serializationInfo.ElementName, serializedValue);
398-
case ExpressionType.GreaterThanOrEqual: return __builder.Gte(serializationInfo.ElementName, serializedValue);
399-
case ExpressionType.LessThan: return __builder.Lt(serializationInfo.ElementName, serializedValue);
400-
case ExpressionType.LessThanOrEqual: return __builder.Lte(serializationInfo.ElementName, serializedValue);
401-
case ExpressionType.NotEqual: return __builder.Ne(serializationInfo.ElementName, serializedValue);
402-
}
368+
case ExpressionType.Equal: return __builder.Eq(serializationInfo.ElementName, serializedValue);
369+
case ExpressionType.GreaterThan: return __builder.Gt(serializationInfo.ElementName, serializedValue);
370+
case ExpressionType.GreaterThanOrEqual: return __builder.Gte(serializationInfo.ElementName, serializedValue);
371+
case ExpressionType.LessThan: return __builder.Lt(serializationInfo.ElementName, serializedValue);
372+
case ExpressionType.LessThanOrEqual: return __builder.Lte(serializationInfo.ElementName, serializedValue);
373+
case ExpressionType.NotEqual: return __builder.Ne(serializationInfo.ElementName, serializedValue);
403374
}
404375

405376
return null;

src/MongoDB.Driver/Linq/Utils/TypeHelper.cs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,31 @@ internal static bool ImplementsInterface(Type candidate, Type iface)
6666
return candidate.GetInterfaces().Any(i => TypeHelper.ImplementsInterface(i, iface));
6767
}
6868

69+
internal static bool IsNullable(Type type)
70+
{
71+
return type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Nullable<>);
72+
}
73+
74+
internal static bool IsNullableEnum(Type type)
75+
{
76+
if (!IsNullable(type))
77+
{
78+
return false;
79+
}
80+
81+
return GetNullableUnderlyingType(type).IsEnum;
82+
}
83+
84+
internal static Type GetNullableUnderlyingType(Type type)
85+
{
86+
if (!IsNullable(type))
87+
{
88+
throw new ArgumentException("Type must be nullable.", "type");
89+
}
90+
91+
return type.GetGenericArguments()[0];
92+
}
93+
6994
private static Type FindIEnumerable(Type seqType)
7095
{
7196
if (seqType == null || seqType == typeof(string))

0 commit comments

Comments
 (0)