Skip to content

Commit be9b610

Browse files
committed
CSHARP-3314: Implement Known Serializers strategy.
1 parent 8b4a4bd commit be9b610

File tree

25 files changed

+677
-57
lines changed

25 files changed

+677
-57
lines changed

src/MongoDB.Driver/Linq/Linq3Implementation/Misc/ConvertHelper.cs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
using System;
1717
using System.Linq.Expressions;
18-
using MongoDB.Driver.Linq;
1918

2019
namespace MongoDB.Driver.Linq.Linq3Implementation.Misc
2120
{
@@ -37,6 +36,22 @@ public static Expression RemoveConvertToMongoQueryable(Expression expression)
3736
throw new ExpressionNotSupportedException(expression);
3837
}
3938

39+
public static Expression RemoveConvertToEnumUnderlyingType(Expression expression)
40+
{
41+
if (expression.NodeType == ExpressionType.Convert)
42+
{
43+
var convertExpression = (UnaryExpression)expression;
44+
var sourceType = convertExpression.Operand.Type;
45+
var targetType = convertExpression.Type;
46+
if (sourceType.IsEnum() && targetType == Enum.GetUnderlyingType(sourceType))
47+
{
48+
return convertExpression.Operand;
49+
}
50+
}
51+
52+
return expression;
53+
}
54+
4055
public static Expression RemoveWideningConvert(Expression expression)
4156
{
4257
if (expression.NodeType == ExpressionType.Convert)

src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/KnownSerializers/KnownSerializerFinder.cs

Lines changed: 117 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,12 @@
1313
* limitations under the License.
1414
*/
1515

16+
using System;
1617
using System.Linq.Expressions;
18+
using MongoDB.Bson.Serialization;
19+
using MongoDB.Bson.Serialization.Serializers;
20+
using MongoDB.Driver.Linq.Linq3Implementation.Misc;
21+
using MongoDB.Driver.Linq.Linq3Implementation.Reflection;
1722
using ExpressionVisitor = System.Linq.Expressions.ExpressionVisitor;
1823

1924
namespace MongoDB.Driver.Linq.Linq3Implementation.Serializers.KnownSerializers
@@ -22,32 +27,136 @@ internal class KnownSerializerFinder : ExpressionVisitor
2227
{
2328
#region static
2429
// public static methods
25-
public static KnownSerializersRegistry FindKnownSerializers(Expression root)
30+
public static KnownSerializersRegistry FindKnownSerializers(Expression root, IBsonDocumentSerializer rootSerializer)
2631
{
27-
var visitor = new KnownSerializerFinder();
32+
var visitor = new KnownSerializerFinder(root, rootSerializer);
2833
visitor.Visit(root);
2934
return visitor._registry;
3035
}
3136
#endregion
3237

3338
// private fields
34-
private KnownSerializersNode _expressionKnownSerializers = null;
35-
private readonly KnownSerializersRegistry _registry = new KnownSerializersRegistry();
39+
private KnownSerializersNode _currentKnownSerializersNode;
40+
private IBsonDocumentSerializer _currentSerializer;
41+
private readonly KnownSerializersRegistry _registry = new();
42+
private readonly Expression _root;
43+
private readonly IBsonDocumentSerializer _rootSerializer;
3644

3745
// constructors
38-
public KnownSerializerFinder()
46+
private KnownSerializerFinder(Expression root, IBsonDocumentSerializer rootSerializer)
3947
{
48+
_rootSerializer = rootSerializer;
49+
_root = root;
4050
}
4151

4252
// public methods
4353
public override Expression Visit(Expression node)
4454
{
45-
_expressionKnownSerializers = new KnownSerializersNode(_expressionKnownSerializers);
46-
_registry.Add(node, _expressionKnownSerializers);
55+
if (node == null)
56+
{
57+
return null;
58+
}
59+
60+
_currentKnownSerializersNode = new KnownSerializersNode(_currentKnownSerializersNode);
61+
62+
if (node == _root)
63+
{
64+
_currentSerializer = _rootSerializer;
65+
}
4766

4867
var result = base.Visit(node);
68+
_registry.Add(node, _currentKnownSerializersNode);
69+
_currentKnownSerializersNode = _currentKnownSerializersNode.Parent;
70+
return result;
71+
}
72+
73+
protected override Expression VisitMember(MemberExpression node)
74+
{
75+
var result = base.VisitMember(node);
76+
if (_currentSerializer != null &&
77+
_currentSerializer.TryGetMemberSerializationInfo(node.Member.Name, out var memberSerializationInfo))
78+
{
79+
_currentKnownSerializersNode.AddKnownSerializer(node.Type, memberSerializationInfo.Serializer);
80+
81+
if (memberSerializationInfo.Serializer is IBsonDocumentSerializer bsonDocumentSerializer)
82+
{
83+
_currentSerializer = bsonDocumentSerializer;
84+
}
85+
else
86+
{
87+
_currentSerializer = null;
88+
}
89+
}
90+
return result;
91+
}
92+
93+
protected override Expression VisitMethodCall(MethodCallExpression node)
94+
{
95+
var result = base.VisitMethodCall(node);
96+
97+
if (node.Method.Is(QueryableMethod.OfType) || node.Method.Is(EnumerableMethod.OfType))
98+
{
99+
var actualType = node.Method.GetGenericArguments()[0];
100+
var serializer = BsonSerializer.LookupSerializer(actualType);
101+
_currentKnownSerializersNode.AddKnownSerializer(node.Type, serializer);
102+
}
103+
104+
return result;
105+
}
106+
107+
protected override Expression VisitNew(NewExpression node)
108+
{
109+
var result = base.VisitNew(node);
110+
111+
if (node.Type == _rootSerializer.ValueType)
112+
{
113+
return result;
114+
}
115+
116+
IBsonSerializer serializer;
117+
if (node.Type == typeof(DateTime))
118+
{
119+
serializer = new DateTimeSerializer();
120+
}
121+
else if (node.Type == typeof(DateTimeOffset))
122+
{
123+
serializer = new DateTimeOffsetSerializer();
124+
}
125+
else
126+
{
127+
var classMapType = typeof(BsonClassMap<>).MakeGenericType(node.Type);
128+
var classMap = (BsonClassMap)Activator.CreateInstance(classMapType);
129+
classMap.AutoMap();
130+
classMap.Freeze();
131+
132+
var serializerType = typeof(BsonClassMapSerializer<>).MakeGenericType(node.Type);
133+
serializer = (IBsonSerializer)Activator.CreateInstance(serializerType, classMap);
134+
}
135+
136+
_currentKnownSerializersNode.AddKnownSerializer(node.Type, serializer);
137+
138+
return result;
139+
}
140+
141+
protected override Expression VisitParameter(ParameterExpression node)
142+
{
143+
var result = base.VisitParameter(node);
144+
145+
if (node.Type == _rootSerializer.ValueType)
146+
{
147+
_currentSerializer = _rootSerializer;
148+
_currentKnownSerializersNode.AddKnownSerializer(node.Type, _rootSerializer);
149+
}
150+
151+
if (_currentSerializer is IBsonArraySerializer arraySerializer &&
152+
arraySerializer.TryGetItemSerializationInfo(out var itemSerializationInfo) &&
153+
node.Type == itemSerializationInfo.NominalType &&
154+
itemSerializationInfo.Serializer is IBsonDocumentSerializer documentSerializer)
155+
{
156+
_currentSerializer = documentSerializer;
157+
_currentKnownSerializersNode.AddKnownSerializer(node.Type, documentSerializer);
158+
}
49159

50-
_expressionKnownSerializers = _expressionKnownSerializers.Parent;
51160
return result;
52161
}
53162
}

src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/KnownSerializers/KnownSerializersNode.cs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
using System.Linq;
1919
using MongoDB.Bson.Serialization;
2020
using MongoDB.Driver.Linq.Linq3Implementation.Misc;
21+
using MongoDB.Driver.Support;
2122

2223
namespace MongoDB.Driver.Linq.Linq3Implementation.Serializers.KnownSerializers
2324
{
@@ -47,6 +48,8 @@ public void AddKnownSerializer(Type type, IBsonSerializer serializer)
4748
}
4849

4950
set.Add(serializer);
51+
52+
_parent?.AddKnownSerializer(type, serializer);
5053
}
5154

5255
public HashSet<IBsonSerializer> GetPossibleSerializers(Type type)
@@ -82,7 +85,8 @@ private HashSet<IBsonSerializer> GetPossibleSerializersAtThisLevel(Type type)
8285
var possibleSerializers = new HashSet<IBsonSerializer>();
8386
foreach (var serializer in _knownSerializers.Values.SelectMany(hashset => hashset))
8487
{
85-
if (serializer.ValueType.IsAssignableFrom(type))
88+
var valueType = serializer.ValueType;
89+
if (valueType == type || valueType.IsEnum() && Enum.GetUnderlyingType(valueType) == type)
8690
{
8791
possibleSerializers.Add(serializer);
8892
}
@@ -96,7 +100,7 @@ private HashSet<IBsonSerializer> GetPossibleSerializersAtThisLevel(Type type)
96100
}
97101
}
98102

99-
if (serializer.ValueType == itemType)
103+
if (valueType == itemType)
100104
{
101105
var ienumerableSerializer = IEnumerableSerializer.Create(serializer);
102106
possibleSerializers.Add(ienumerableSerializer);

src/MongoDB.Driver/Linq/Linq3Implementation/Serializers/KnownSerializers/KnownSerializersRegistry.cs

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
using System;
1717
using System.Collections.Generic;
18+
using System.Linq;
1819
using System.Linq.Expressions;
1920
using MongoDB.Bson.Serialization;
2021

@@ -23,24 +24,26 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Serializers.KnownSerializers
2324
internal class KnownSerializersRegistry
2425
{
2526
// private fields
26-
private readonly Dictionary<Expression, KnownSerializersNode> _registry = new Dictionary<Expression, KnownSerializersNode>();
27+
private readonly Dictionary<Expression, KnownSerializersNode> _registry = new();
2728

2829
// public methods
2930
public void Add(Expression expression, KnownSerializersNode knownSerializers)
3031
{
32+
if (_registry.ContainsKey(expression)) return;
33+
3134
_registry.Add(expression, knownSerializers);
3235
}
3336

34-
public HashSet<IBsonSerializer> GetPossibleSerializers(Expression expression, Type type)
37+
public IBsonSerializer GetSerializer(Expression expression, IBsonSerializer defaultSerializer = null)
3538
{
36-
if (_registry.TryGetValue(expression, out var knownSerializers))
37-
{
38-
return knownSerializers.GetPossibleSerializers(type);
39-
}
40-
else
39+
var expressionType = expression is LambdaExpression lambdaExpression ? lambdaExpression.ReturnType : expression.Type;
40+
var possibleSerializers = _registry.TryGetValue(expression, out var knownSerializers) ? knownSerializers.GetPossibleSerializers(expressionType) : new HashSet<IBsonSerializer>();
41+
return possibleSerializers.Count switch
4142
{
42-
return new HashSet<IBsonSerializer>();
43-
}
43+
0 => defaultSerializer ?? throw new InvalidOperationException($"Cannot find serializer for {expression}."),
44+
> 1 => throw new InvalidOperationException($"More than one possible serializer found for {expression}."),
45+
_ => possibleSerializers.First()
46+
};
4447
}
4548
}
4649
}

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/BinaryExpressionToAggregationExpressionTranslator.cs

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@
1313
* limitations under the License.
1414
*/
1515

16+
using System;
1617
using System.Linq.Expressions;
1718
using MongoDB.Bson.Serialization;
19+
using MongoDB.Bson.Serialization.Serializers;
1820
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions;
1921
using MongoDB.Driver.Linq.Linq3Implementation.Misc;
2022
using MongoDB.Driver.Support;
@@ -38,6 +40,12 @@ public static AggregationExpression Translate(TranslationContext context, Binary
3840
rightExpression = ConvertHelper.RemoveWideningConvert(rightExpression);
3941
}
4042

43+
if (IsEnumComparisonExpression(expression))
44+
{
45+
leftExpression = ConvertHelper.RemoveConvertToEnumUnderlyingType(leftExpression);
46+
rightExpression = ConvertHelper.RemoveConvertToEnumUnderlyingType(rightExpression);
47+
}
48+
4149
var leftTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, leftExpression);
4250
var rightTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, rightExpression);
4351

@@ -64,7 +72,24 @@ public static AggregationExpression Translate(TranslationContext context, Binary
6472
ExpressionType.Subtract => AstExpression.Subtract(leftTranslation.Ast, rightTranslation.Ast),
6573
_ => throw new ExpressionNotSupportedException(expression)
6674
};
67-
var serializer = BsonSerializer.LookupSerializer(expression.Type); // TODO: get correct serializer
75+
var serializer = expression.Type switch
76+
{
77+
Type t when t == typeof(bool) => new BooleanSerializer(),
78+
Type t when t == typeof(string) => new StringSerializer(),
79+
Type t when t == typeof(byte) => new ByteSerializer(),
80+
Type t when t == typeof(short) => new Int16Serializer(),
81+
Type t when t == typeof(ushort) => new UInt16Serializer(),
82+
Type t when t == typeof(int) => new Int32Serializer(),
83+
Type t when t == typeof(uint) => new UInt32Serializer(),
84+
Type t when t == typeof(long) => new Int64Serializer(),
85+
Type t when t == typeof(ulong) => new UInt64Serializer(),
86+
Type t when t == typeof(float) => new SingleSerializer(),
87+
Type t when t == typeof(double) => new DoubleSerializer(),
88+
Type t when t == typeof(decimal) => new DecimalSerializer(),
89+
Type { IsConstructedGenericType: true } t when t.GetGenericTypeDefinition() == typeof(Nullable<>) => (IBsonSerializer)Activator.CreateInstance(typeof(NullableSerializer<>).MakeGenericType(t.GenericTypeArguments[0])),
90+
Type { IsArray: true } t => (IBsonSerializer)Activator.CreateInstance(typeof(ArraySerializer<>).MakeGenericType(t.GetElementType())),
91+
_ => context.KnownSerializersRegistry.GetSerializer(expression) // Required for Coalesce
92+
};
6893

6994
return new AggregationExpression(expression, ast, serializer);
7095
}
@@ -88,6 +113,40 @@ private static bool IsArithmeticOperator(ExpressionType nodeType)
88113
};
89114
}
90115

116+
private static bool IsComparisonOperator(ExpressionType nodeType)
117+
{
118+
return nodeType switch
119+
{
120+
ExpressionType.Equal => true,
121+
ExpressionType.GreaterThan => true,
122+
ExpressionType.GreaterThanOrEqual => true,
123+
ExpressionType.LessThan => true,
124+
ExpressionType.LessThanOrEqual => true,
125+
ExpressionType.NotEqual => true,
126+
_ => false
127+
};
128+
}
129+
130+
private static bool IsEnumComparisonExpression(BinaryExpression expression)
131+
{
132+
return
133+
IsComparisonOperator(expression.NodeType) &&
134+
(IsConvertToEnumUnderlyingType(expression.Left) || IsConvertToEnumUnderlyingType(expression.Right));
135+
136+
static bool IsConvertToEnumUnderlyingType(Expression expression)
137+
{
138+
if (expression.NodeType == ExpressionType.Convert)
139+
{
140+
var convertExpression = (UnaryExpression)expression;
141+
var sourceType = convertExpression.Operand.Type;
142+
var targetType = convertExpression.Type;
143+
return sourceType.IsEnum() && targetType == Enum.GetUnderlyingType(sourceType);
144+
}
145+
146+
return false;
147+
}
148+
}
149+
91150
private static bool IsStringConcatenationExpression(BinaryExpression expression)
92151
{
93152
return

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConditionalExpressionToAggregationExpressionTranslator.cs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
*/
1515

1616
using System.Linq.Expressions;
17-
using MongoDB.Bson.Serialization;
1817
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions;
1918

2019
namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators
@@ -32,7 +31,7 @@ public static AggregationExpression Translate(TranslationContext context, Condit
3231
var ifFalseExpression = expression.IfFalse;
3332
var ifFalseTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, ifFalseExpression);
3433
var ast = AstExpression.Cond(testTranslation.Ast, ifTrueTranslation.Ast, ifFalseTranslation.Ast);
35-
var serializer = BsonSerializer.LookupSerializer(expression.Type); // TODO: use known serializer
34+
var serializer = context.KnownSerializersRegistry.GetSerializer(expression);
3635
return new AggregationExpression(expression, ast, serializer);
3736
}
3837

0 commit comments

Comments
 (0)