Skip to content

Commit f958b93

Browse files
authored
CSHARP-4632: Consolidate new and init expression translation (#1094)
1 parent 8160082 commit f958b93

File tree

13 files changed

+297
-345
lines changed

13 files changed

+297
-345
lines changed

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslator.cs

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,19 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg
2727
internal static class MemberInitExpressionToAggregationExpressionTranslator
2828
{
2929
public static AggregationExpression Translate(TranslationContext context, MemberInitExpression expression)
30+
=> Translate(context, expression, expression.NewExpression, expression.Bindings);
31+
32+
public static AggregationExpression Translate(
33+
TranslationContext context,
34+
Expression expression,
35+
NewExpression newExpression,
36+
IReadOnlyList<MemberBinding> bindings)
3037
{
31-
var newExpression = expression.NewExpression;
3238
var constructorInfo = newExpression.Constructor; // note: can be null when using the default constructor with a struct
3339
var constructorArguments = newExpression.Arguments;
3440
var computedFields = new List<AstComputedField>();
3541

36-
var classMap = CreateClassMap(expression.Type, constructorInfo, out var creatorMap);
42+
var classMap = CreateClassMap(newExpression.Type, constructorInfo, out var creatorMap);
3743
if (constructorInfo != null && creatorMap != null)
3844
{
3945
var creatorMapParameters = creatorMap.Arguments?.ToArray();
@@ -50,12 +56,13 @@ public static AggregationExpression Translate(TranslationContext context, Member
5056
var constructorArgumentType = constructorArgumentExpression.Type;
5157
var constructorArgumentSerializer = constructorArgumentTranslation.Serializer ?? BsonSerializer.LookupSerializer(constructorArgumentType);
5258
var memberMap = EnsureMemberMap(expression, classMap, creatorMapParameter);
59+
EnsureDefaultValue(memberMap);
5360
memberMap.SetSerializer(constructorArgumentSerializer);
5461
computedFields.Add(AstExpression.ComputedField(memberMap.ElementName, constructorArgumentTranslation.Ast));
5562
}
5663
}
5764

58-
foreach (var binding in expression.Bindings)
65+
foreach (var binding in bindings)
5966
{
6067
var memberAssignment = (MemberAssignment)binding;
6168
var member = memberAssignment.Member;
@@ -68,7 +75,7 @@ public static AggregationExpression Translate(TranslationContext context, Member
6875

6976
var ast = AstExpression.ComputedDocument(computedFields);
7077
classMap.Freeze();
71-
var serializerType = typeof(BsonClassMapSerializer<>).MakeGenericType(expression.Type);
78+
var serializerType = typeof(BsonClassMapSerializer<>).MakeGenericType(newExpression.Type);
7279
var serializer = (IBsonSerializer)Activator.CreateInstance(serializerType, classMap);
7380

7481
return new AggregationExpression(expression, ast, serializer);
@@ -77,7 +84,7 @@ public static AggregationExpression Translate(TranslationContext context, Member
7784
private static BsonClassMap CreateClassMap(Type classType, ConstructorInfo constructorInfo, out BsonCreatorMap creatorMap)
7885
{
7986
BsonClassMap baseClassMap = null;
80-
if (classType.BaseType != null)
87+
if (classType.BaseType != null)
8188
{
8289
baseClassMap = CreateClassMap(classType.BaseType, null, out _);
8390
}
@@ -132,6 +139,17 @@ static bool MemberMapMatchesCreatorMapParameter(BsonMemberMap memberMap, MemberI
132139
}
133140
}
134141

142+
private static void EnsureDefaultValue(BsonMemberMap memberMap)
143+
{
144+
if (memberMap.IsDefaultValueSpecified)
145+
{
146+
return;
147+
}
148+
149+
var defaultValue = memberMap.MemberType.IsValueType ? Activator.CreateInstance(memberMap.MemberType) : null;
150+
memberMap.SetDefaultValue(defaultValue);
151+
}
152+
135153
private static BsonMemberMap FindMemberMap(Expression expression, BsonClassMap classMap, string memberName)
136154
{
137155
foreach (var memberMap in classMap.DeclaredMemberMaps)

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewExpressionToAggregationExpressionTranslator.cs

Lines changed: 1 addition & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,7 @@
1515

1616
using System;
1717
using System.Collections.Generic;
18-
using System.Collections.ObjectModel;
19-
using System.Linq;
2018
using System.Linq.Expressions;
21-
using System.Reflection;
22-
using MongoDB.Bson.Serialization;
23-
using MongoDB.Driver.Linq.Linq3Implementation.Ast;
24-
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions;
2519

2620
namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators
2721
{
@@ -30,9 +24,6 @@ internal static class NewExpressionToAggregationExpressionTranslator
3024
public static AggregationExpression Translate(TranslationContext context, NewExpression expression)
3125
{
3226
var expressionType = expression.Type;
33-
var constructorInfo = expression.Constructor;
34-
var arguments = expression.Arguments.ToArray();
35-
var members = expression.Members;
3627

3728
if (expressionType == typeof(DateTime))
3829
{
@@ -50,91 +41,7 @@ public static AggregationExpression Translate(TranslationContext context, NewExp
5041
{
5142
return NewTupleExpressionToAggregationExpressionTranslator.Translate(context, expression);
5243
}
53-
54-
var classMapType = typeof(BsonClassMap<>).MakeGenericType(expressionType);
55-
var classMap = (BsonClassMap)Activator.CreateInstance(classMapType);
56-
var computedFields = new List<AstComputedField>();
57-
58-
// if Members is not null then trust Members more than the constructor parameter names (which are compiler generated for anonymous types)
59-
if (members == null)
60-
{
61-
var membersList = constructorInfo.GetParameters().Select(p => GetMatchingMember(expression, p.Name)).ToList();
62-
members = new ReadOnlyCollection<MemberInfo>(membersList);
63-
}
64-
65-
for (var i = 0; i < arguments.Length; i++)
66-
{
67-
var valueExpression = arguments[i];
68-
var valueTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, valueExpression);
69-
var valueType = valueExpression.Type;
70-
var valueSerializer = valueTranslation.Serializer ?? BsonSerializer.LookupSerializer(valueType);
71-
var defaultValue = GetDefaultValue(valueType);
72-
var memberMap = classMap.MapMember(members[i]).SetSerializer(valueSerializer).SetDefaultValue(defaultValue);
73-
computedFields.Add(AstExpression.ComputedField(memberMap.ElementName, valueTranslation.Ast));
74-
}
75-
76-
// map any public fields or properties that didn't match a constructor argument
77-
foreach (var member in expressionType.GetFields().Cast<MemberInfo>().Concat(expressionType.GetProperties()))
78-
{
79-
if (!members.Contains(member))
80-
{
81-
var valueType = member switch
82-
{
83-
FieldInfo fieldInfo => fieldInfo.FieldType,
84-
PropertyInfo propertyInfo => propertyInfo.PropertyType,
85-
_ => throw new Exception($"Unexpected member type: {member.MemberType}")
86-
};
87-
var valueSerializer = context.KnownSerializersRegistry.GetSerializer(expression, valueType);
88-
var defaultValue = GetDefaultValue(valueType);
89-
classMap.MapMember(member).SetSerializer(valueSerializer).SetDefaultValue(defaultValue);
90-
}
91-
}
92-
93-
classMap.MapConstructor(constructorInfo, members.Select(m => m.Name).ToArray());
94-
classMap.Freeze();
95-
96-
var ast = AstExpression.ComputedDocument(computedFields);
97-
var serializerType = typeof(BsonClassMapSerializer<>).MakeGenericType(expression.Type);
98-
// Note that we should use context.KnownSerializersRegistry to find the serializer,
99-
// but the above implementation builds up computedFields during the mapping process.
100-
// We need to figure out how to resolve the serializer from KnownSerializers and then
101-
// populate computedFields from that resolved serializer.
102-
var serializer = (IBsonSerializer)Activator.CreateInstance(serializerType, classMap);
103-
104-
return new AggregationExpression(expression, ast, serializer);
105-
}
106-
107-
private static object GetDefaultValue(Type type)
108-
{
109-
if (type.IsValueType)
110-
{
111-
return Activator.CreateInstance(type);
112-
}
113-
else
114-
{
115-
return null;
116-
}
117-
}
118-
119-
private static MemberInfo GetMatchingMember(NewExpression expression, string constructorParameterName)
120-
{
121-
foreach (var field in expression.Type.GetFields())
122-
{
123-
if (field.Name.Equals(constructorParameterName, StringComparison.OrdinalIgnoreCase))
124-
{
125-
return field;
126-
}
127-
}
128-
129-
foreach (var property in expression.Type.GetProperties())
130-
{
131-
if (property.Name.Equals(constructorParameterName, StringComparison.OrdinalIgnoreCase))
132-
{
133-
return property;
134-
}
135-
}
136-
137-
throw new ExpressionNotSupportedException(expression, because: $"constructor parameter {constructorParameterName} does not match any public field or property");
44+
return MemberInitExpressionToAggregationExpressionTranslator.Translate(context, expression, expression, Array.Empty<MemberBinding>());
13845
}
13946
}
14047
}

tests/MongoDB.Driver.Tests/Jira/CSharp4172Tests.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ public void Aggregate_uses_the_expected_serializer(
5959
var expectedStage = linqProvider switch
6060
{
6161
LinqProvider.V2 => "{ $project : { Id : '$_id', HasAnyRefund : { $anyElementTrue : { $map : { input : '$Items', as : 'i', in : { $eq : ['$$i.Type', 1] } } } }, _id : 0 } }",
62-
LinqProvider.V3 => "{ $project : { Id : '$_id', HasAnyRefund : { $anyElementTrue : { $map : { input : '$Items', as : 'i', in : { $eq : ['$$i.Type', 'refund'] } } } }, _id : 0 } }",
62+
LinqProvider.V3 => "{ $project : { _id : '$_id', HasAnyRefund : { $anyElementTrue : { $map : { input : '$Items', as : 'i', in : { $eq : ['$$i.Type', 'refund'] } } } } } }",
6363
_ => throw new ArgumentException($"Invalid linqProvider: {linqProvider}.", nameof(linqProvider))
6464
};
6565
AssertStages(stages, expectedStage);

tests/MongoDB.Driver.Tests/Linq/Linq2ImplementationTestsOnLinq3/MongoQueryableTests.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1230,8 +1230,8 @@ public void Select_followed_by_group()
12301230

12311231
Assert(query,
12321232
2,
1233-
"{ $project : { Id : '$_id', First : '$A', Second : '$B', _id : 0 } }",
1234-
"{ $group : { _id : '$First', __agg0 : { $push : { Id : '$Id', Second : '$Second' } } } }",
1233+
"{ $project : { _id : '$_id', First : '$A', Second : '$B' } }",
1234+
"{ $group : { _id : '$First', __agg0 : { $push : { _id : '$_id', Second : '$Second' } } } }",
12351235
"{ $project : { First : '$_id', Stuff : '$__agg0', _id : 0 } }");
12361236
}
12371237

tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationTests/Jira/CSharp3236Tests.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ public void Select_should_work()
3939
var stages = Translate(collection, queryable);
4040
AssertStages(
4141
stages,
42-
"{ $project : { Id : '$_id', Comments : { $filter : { input : '$Comments', as : 'c', cond : { $gte : [{ $indexOfCP : ['$$c.Text', 'test'] }, 0] } } }, _id : 0 } }");
42+
"{ $project : { _id : '$_id', Comments : { $filter : { input : '$Comments', as : 'c', cond : { $gte : [{ $indexOfCP : ['$$c.Text', 'test'] }, 0] } } } } }");
4343

4444
var result = queryable.Single();
4545
result.Id.Should().Be(1);

0 commit comments

Comments
 (0)