Skip to content

Commit 07e8388

Browse files
rstamJamesKovacs
authored andcommitted
SHARP-4172: Enum constant not serialized using the correct serializer.
1 parent 31bd36f commit 07e8388

File tree

8 files changed

+221
-27
lines changed

8 files changed

+221
-27
lines changed

src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/EnumUnderlyingTypeSerializer.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ public EnumUnderlyingTypeSerializer(IBsonSerializer<TEnum> enumSerializer)
3838
_enumSerializer = Ensure.IsNotNull(enumSerializer, nameof(enumSerializer));
3939
}
4040

41+
// public properties
42+
public IBsonSerializer<TEnum> EnumSerializer => _enumSerializer;
43+
4144
// public methods
4245
public override TEnumUnderlyingType Deserialize(BsonDeserializationContext context, BsonDeserializationArgs args)
4346
{

src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/KnownSerializers/KnownSerializerFinder.cs

Lines changed: 46 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
using MongoDB.Bson.Serialization.Serializers;
2121
using MongoDB.Driver.Linq.Linq3Implementation.Misc;
2222
using MongoDB.Driver.Linq.Linq3Implementation.Reflection;
23+
using MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators;
2324
using ExpressionVisitor = System.Linq.Expressions.ExpressionVisitor;
2425

2526
namespace MongoDB.Driver.Linq.Linq3Implementation.Serializers.KnownSerializers
@@ -71,23 +72,58 @@ public override Expression Visit(Expression node)
7172
return result;
7273
}
7374

74-
protected override Expression VisitMember(MemberExpression node)
75+
protected override Expression VisitBinary(BinaryExpression node)
7576
{
76-
var result = base.VisitMember(node);
77-
if (_currentSerializer != null &&
78-
_currentSerializer.TryGetMemberSerializationInfo(node.Member.Name, out var memberSerializationInfo))
79-
{
80-
_currentKnownSerializersNode.AddKnownSerializer(node.Type, memberSerializationInfo.Serializer);
77+
var result = base.VisitBinary(node);
8178

82-
if (memberSerializationInfo.Serializer is IBsonDocumentSerializer bsonDocumentSerializer)
79+
if (result is BinaryExpression binaryExpression)
80+
{
81+
if (BinaryExpressionToAggregationExpressionTranslator.IsEnumComparisonExpression(binaryExpression))
8382
{
84-
_currentSerializer = bsonDocumentSerializer;
83+
var leftExpression = ConvertHelper.RemoveConvertToEnumUnderlyingType(binaryExpression.Left);
84+
var rightExpression = ConvertHelper.RemoveConvertToEnumUnderlyingType(binaryExpression.Right);
85+
86+
if (leftExpression is ConstantExpression leftConstantExpression)
87+
{
88+
var rightExpressionSerializer = _registry.GetSerializer(rightExpression);
89+
var leftExpressionSerializer = EnumUnderlyingTypeSerializer.Create(rightExpressionSerializer);
90+
_registry.AddKnownSerializer(leftExpression, leftExpressionSerializer, allowPropagation: false);
91+
}
92+
93+
if (rightExpression is ConstantExpression rightConstantExpression)
94+
{
95+
var leftExpressionSerializer = _registry.GetSerializer(leftExpression);
96+
var rightExpressionSerializer = EnumUnderlyingTypeSerializer.Create(leftExpressionSerializer);
97+
_registry.AddKnownSerializer(rightExpression, rightExpressionSerializer, allowPropagation: false);
98+
}
8599
}
86-
else
100+
}
101+
102+
return result;
103+
}
104+
105+
protected override Expression VisitMember(MemberExpression node)
106+
{
107+
var result = base.VisitMember(node);
108+
109+
var containerSerializer = _registry.GetSerializer(node.Expression);
110+
if (containerSerializer is IBsonDocumentSerializer documentSerializer)
111+
{
112+
if (documentSerializer.TryGetMemberSerializationInfo(node.Member.Name, out var memberSerializationInfo))
87113
{
88-
_currentSerializer = null;
114+
_currentKnownSerializersNode.AddKnownSerializer(node.Type, memberSerializationInfo.Serializer);
115+
116+
if (memberSerializationInfo.Serializer is IBsonDocumentSerializer bsonDocumentSerializer)
117+
{
118+
_currentSerializer = bsonDocumentSerializer;
119+
}
120+
else
121+
{
122+
_currentSerializer = null;
123+
}
89124
}
90125
}
126+
91127
return result;
92128
}
93129

src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/KnownSerializers/KnownSerializersNode.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ public KnownSerializersNode(Expression expression, KnownSerializersNode parent)
4242
public KnownSerializersNode Parent => _parent;
4343

4444
// public methods
45-
public void AddKnownSerializer(Type type, IBsonSerializer serializer)
45+
public void AddKnownSerializer(Type type, IBsonSerializer serializer, bool allowPropagation = true)
4646
{
4747
if (!_knownSerializers.TryGetValue(type, out var set))
4848
{
@@ -52,7 +52,7 @@ public void AddKnownSerializer(Type type, IBsonSerializer serializer)
5252

5353
set.Add(serializer);
5454

55-
if (ShouldPropagateKnownSerializerToParent())
55+
if (allowPropagation && ShouldPropagateKnownSerializerToParent())
5656
{
5757
_parent.AddKnownSerializer(type, serializer);
5858
}

src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/KnownSerializers/KnownSerializersRegistry.cs

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,16 +42,53 @@ public void Add(Expression expression, KnownSerializersNode knownSerializers)
4242
_registry.Add(expression, knownSerializers);
4343
}
4444

45+
public void AddKnownSerializer(Expression expression, IBsonSerializer knownSerializer, bool allowPropagation = true)
46+
{
47+
if (knownSerializer.ValueType != expression.Type)
48+
{
49+
throw new ArgumentException($"Serializer value type {knownSerializer.ValueType} does not match expresion type {expression.Type}.", nameof(knownSerializer));
50+
}
51+
52+
if (!_registry.TryGetValue(expression, out var knownSerializers))
53+
{
54+
throw new InvalidOperationException("KnownSerializersNode does not exist yet for expression: {expression}.");
55+
}
56+
57+
knownSerializers.AddKnownSerializer(expression.Type, knownSerializer, allowPropagation);
58+
}
59+
4560
public IBsonSerializer GetSerializer(Expression expression, IBsonSerializer defaultSerializer = null)
4661
{
47-
var expressionType = expression is LambdaExpression lambdaExpression ? lambdaExpression.ReturnType : expression.Type;
48-
var possibleSerializers = _registry.TryGetValue(expression, out var knownSerializers) ? knownSerializers.GetPossibleSerializers(expressionType) : new HashSet<IBsonSerializer>();
62+
var type = expression is LambdaExpression lambdaExpression ? lambdaExpression.ReturnType : expression.Type;
63+
return GetSerializer(expression, type, defaultSerializer);
64+
}
65+
66+
private IBsonSerializer GetSerializer(Expression expression, Type type, IBsonSerializer defaultSerializer = null)
67+
{
68+
var possibleSerializers = _registry.TryGetValue(expression, out var knownSerializers) ? knownSerializers.GetPossibleSerializers(type) : new HashSet<IBsonSerializer>();
4969
return possibleSerializers.Count switch
5070
{
51-
0 => defaultSerializer ?? BsonSerializer.LookupSerializer(expressionType), // sometimes there is no known serializer from the context (e.g. CSHARP-4062)
71+
0 => defaultSerializer ?? LookupSerializer(expression, type), // sometimes there is no known serializer from the context (e.g. CSHARP-4062)
5272
1 => possibleSerializers.First(),
5373
_ => throw new InvalidOperationException($"More than one possible serializer found for {expression}.")
5474
};
5575
}
76+
77+
private IBsonSerializer LookupSerializer(Expression expression, Type type)
78+
{
79+
if (type.IsConstructedGenericType &&
80+
type.GetGenericTypeDefinition() == typeof(IGrouping<,>))
81+
{
82+
var genericArguments = type.GetGenericArguments();
83+
var keyType = genericArguments[0];
84+
var elementType = genericArguments[1];
85+
86+
var keySerializer = GetSerializer(expression, keyType);
87+
var elementSerializer = GetSerializer(expression, elementType);
88+
return IGroupingSerializer.Create(keySerializer, elementSerializer);
89+
}
90+
91+
return BsonSerializer.LookupSerializer(type);
92+
}
5693
}
5794
}

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/BinaryExpressionToAggregationExpressionTranslator.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ private static bool IsComparisonOperator(ExpressionType nodeType)
127127
};
128128
}
129129

130-
private static bool IsEnumComparisonExpression(BinaryExpression expression)
130+
internal static bool IsEnumComparisonExpression(BinaryExpression expression)
131131
{
132132
return
133133
IsComparisonOperator(expression.NodeType) &&
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
/* Copyright 2019-present MongoDB Inc.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
using System;
17+
using System.Collections.Generic;
18+
using System.Linq;
19+
using FluentAssertions;
20+
using MongoDB.Bson.Serialization;
21+
using MongoDB.Bson.Serialization.Attributes;
22+
using MongoDB.Bson.Serialization.Serializers;
23+
using MongoDB.Driver.Linq;
24+
using MongoDB.Driver.Tests.Linq.Linq3ImplementationTests;
25+
using Xunit;
26+
27+
namespace MongoDB.Driver.Tests.Jira
28+
{
29+
public class CSharp4172Tests : Linq3IntegrationTest
30+
{
31+
[Theory]
32+
[InlineData(LinqProvider.V2)]
33+
[InlineData(LinqProvider.V3)]
34+
public void Find_uses_the_expected_serializer(LinqProvider linqProvider)
35+
{
36+
var collection = GetCollection<Order>(null, null, linqProvider);
37+
38+
var find = collection.Find(o => o.Items.Any(i => i.Type == ItemType.Refund));
39+
var result = find.ToString();
40+
41+
var expectedResult = "find({ \"Items\" : { \"$elemMatch\" : { \"Type\" : \"refund\" } } })";
42+
result.Should().Be(expectedResult);
43+
}
44+
45+
[Theory]
46+
[InlineData(LinqProvider.V2)]
47+
[InlineData(LinqProvider.V3)]
48+
public void Aggregate_uses_the_expected_serializer(LinqProvider linqProvider)
49+
{
50+
var collection = GetCollection<Order>(null, null, linqProvider);
51+
52+
var aggregate = collection
53+
.Aggregate()
54+
.Project((o) => new { o.Id, HasAnyRefund = o.Items.Any(i => i.Type == ItemType.Refund) });
55+
var stages = Translate(collection, aggregate);
56+
57+
// LINQ2 uses the wrong serializer but won't be fixed
58+
var expectedStage = linqProvider switch
59+
{
60+
LinqProvider.V2 => "{ $project : { Id : '$_id', HasAnyRefund : { $anyElementTrue : { $map : { input : '$Items', as : 'i', in : { $eq : ['$$i.Type', 1] } } } }, _id : 0 } }",
61+
LinqProvider.V3 => "{ $project : { Id : '$_id', HasAnyRefund : { $anyElementTrue : { $map : { input : '$Items', as : 'i', in : { $eq : ['$$i.Type', 'refund'] } } } }, _id : 0 } }",
62+
_ => throw new ArgumentException($"Invalid linqProvider: {linqProvider}.", nameof(linqProvider))
63+
};
64+
AssertStages(stages, expectedStage);
65+
}
66+
67+
public class Order
68+
{
69+
public int Id { get; set; }
70+
public List<Item> Items { get; set; }
71+
}
72+
73+
public class Item
74+
{
75+
[BsonSerializer(typeof(CamelCaseEnumSerializer<ItemType>))]
76+
public ItemType Type { get; set; }
77+
}
78+
79+
public enum ItemType
80+
{
81+
SaleItem,
82+
Refund
83+
}
84+
85+
public class CamelCaseEnumSerializer<T> : EnumSerializer<T>
86+
where T : struct, Enum
87+
{
88+
private static string ToCamelCase(string s)
89+
{
90+
return char.ToLowerInvariant(s[0]) + s.Substring(1);
91+
}
92+
93+
public override void Serialize(BsonSerializationContext context, BsonSerializationArgs args, T value)
94+
{
95+
context.Writer.WriteString(ToCamelCase(value.ToString()));
96+
}
97+
98+
public override T Deserialize(BsonDeserializationContext context, BsonDeserializationArgs args)
99+
{
100+
return (T)Enum.Parse(typeof(T), context.Reader.ReadString(), true);
101+
}
102+
}
103+
}
104+
}

tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationTests/Linq3IntegrationTest.cs

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,13 @@
1313
* limitations under the License.
1414
*/
1515

16+
using System;
1617
using System.Collections.Generic;
1718
using System.Linq;
1819
using FluentAssertions;
1920
using MongoDB.Bson;
2021
using MongoDB.Bson.Serialization;
22+
using MongoDB.Driver.Linq;
2123
using MongoDB.Driver.Linq.Linq3Implementation;
2224
using MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToExecutableQueryTranslators;
2325

@@ -56,16 +58,27 @@ protected void CreateCollection<TDocument>(IMongoCollection<TDocument> collectio
5658
CreateCollection(collection, (IEnumerable<TDocument>)documents); ;
5759
}
5860

61+
protected IMongoClient GetClient(LinqProvider linqProvider)
62+
{
63+
return linqProvider switch
64+
{
65+
LinqProvider.V2 => DriverTestConfiguration.Client,
66+
LinqProvider.V3 => DriverTestConfiguration.Linq3Client,
67+
_ => throw new ArgumentException($"Invalid linqProvider: {linqProvider}.", nameof(linqProvider))
68+
};
69+
}
70+
5971
protected IMongoCollection<TDocument> GetCollection<TDocument>(string collectionName = null)
6072
{
6173
var databaseName = DriverTestConfiguration.DatabaseNamespace.DatabaseName;
62-
collectionName ??= DriverTestConfiguration.CollectionNamespace.CollectionName;
6374
return GetCollection<TDocument>(databaseName, collectionName);
6475
}
6576

66-
protected IMongoCollection<TDocument> GetCollection<TDocument>(string databaseName, string collectionName)
77+
protected IMongoCollection<TDocument> GetCollection<TDocument>(string databaseName, string collectionName, LinqProvider linqProvider = LinqProvider.V3)
6778
{
68-
var client = DriverTestConfiguration.Linq3Client;
79+
databaseName ??= DriverTestConfiguration.DatabaseNamespace.DatabaseName;
80+
collectionName ??= DriverTestConfiguration.CollectionNamespace.CollectionName;
81+
var client = GetClient(linqProvider);
6982
var database = client.GetDatabase(databaseName);
7083
return database.GetCollection<TDocument>(collectionName);
7184
}

tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationTests/Serializers/KnownSerializers/KnownSerializerFinderTests.cs

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
using MongoDB.Bson;
2020
using MongoDB.Bson.Serialization;
2121
using MongoDB.Bson.Serialization.Attributes;
22+
using MongoDB.Driver.Linq.Linq3Implementation.Serializers;
2223
using MongoDB.Driver.Linq.Linq3Implementation.Serializers.KnownSerializers;
2324
using Xunit;
2425

@@ -89,17 +90,17 @@ public void Enum_property_expression_should_return_enum_serializer_with_int_repr
8990
}
9091

9192
[Fact]
92-
public void Enum_comparison_expression_should_return_enum_serializer_with_int_representation()
93+
public void Enum_comparison_expression_should_use_underlying_type_serializer_for_constant_represented_as_int()
9394
{
9495
Expression<Func<C, bool>> expression = x => x.Ei == E.A;
9596
var collectionSerializer = GetCollectionSerializer();
9697

9798
var result = KnownSerializerFinder.FindKnownSerializers(expression, collectionSerializer);
9899

99100
var equalsExpression = (BinaryExpression)expression.Body;
100-
var serializer = result.GetSerializer(equalsExpression.Right);
101-
collectionSerializer.TryGetMemberSerializationInfo(nameof(C.Ei), out var expectedPropertySerializationInfo).Should().BeTrue();
102-
serializer.Should().Be(expectedPropertySerializationInfo.Serializer);
101+
var leftSerializer = result.GetSerializer(equalsExpression.Left);
102+
var rightSerializer = (EnumUnderlyingTypeSerializer<E, int>)result.GetSerializer(equalsExpression.Right);
103+
rightSerializer.EnumSerializer.Should().BeSameAs(leftSerializer);
103104
}
104105

105106
[Fact]
@@ -116,17 +117,17 @@ public void Enum_property_expression_should_return_enum_serializer_with_string_r
116117
}
117118

118119
[Fact]
119-
public void Enum_comparison_expression_should_return_enum_serializer_with_string_representation()
120+
public void Enum_comparison_expression_should_use_underlying_type_serializer_for_constant_represented_as_string()
120121
{
121122
Expression<Func<C, bool>> expression = x => x.Es == E.A;
122123
var collectionSerializer = GetCollectionSerializer();
123124

124125
var result = KnownSerializerFinder.FindKnownSerializers(expression, collectionSerializer);
125126

126127
var equalsExpression = (BinaryExpression)expression.Body;
127-
var serializer = result.GetSerializer(equalsExpression.Right);
128-
collectionSerializer.TryGetMemberSerializationInfo(nameof(C.Es), out var expectedPropertySerializationInfo).Should().BeTrue();
129-
serializer.Should().Be(expectedPropertySerializationInfo.Serializer);
128+
var leftSerializer = result.GetSerializer(equalsExpression.Left);
129+
var rightSerializer = (EnumUnderlyingTypeSerializer<E, int>)result.GetSerializer(equalsExpression.Right);
130+
rightSerializer.EnumSerializer.Should().BeSameAs(leftSerializer);
130131
}
131132

132133
[Fact]

0 commit comments

Comments
 (0)