Skip to content

Commit fa6ec4a

Browse files
rstamDmitryLukyanov
authored andcommitted
CSHARP-4524: Handle projections using constructors for classes that use public fields instead of public properties.
1 parent 41604c9 commit fa6ec4a

File tree

2 files changed

+136
-23
lines changed

2 files changed

+136
-23
lines changed

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewExpressionToAggregationExpressionTranslator.cs

Lines changed: 35 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@
1515

1616
using System;
1717
using System.Collections.Generic;
18+
using System.Collections.ObjectModel;
1819
using System.Linq;
1920
using System.Linq.Expressions;
21+
using System.Reflection;
2022
using MongoDB.Bson.Serialization;
2123
using MongoDB.Driver.Linq.Linq3Implementation.Ast;
2224
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions;
@@ -49,40 +51,42 @@ public static AggregationExpression Translate(TranslationContext context, NewExp
4951
var classMap = (BsonClassMap)Activator.CreateInstance(classMapType);
5052
var computedFields = new List<AstComputedField>();
5153

52-
string[] propertyNames;
53-
if (members != null)
54+
// if Members is not null then trust Members more than the constructor parameter names (which are compiler generated for anonymous types)
55+
if (members == null)
5456
{
55-
// if Members is not null then trust Members more than the constructor parameter names (which are compiler generated for anonymous types)
56-
propertyNames = members.Select(member => member.Name).ToArray();
57-
}
58-
else
59-
{
60-
propertyNames = constructorInfo.GetParameters().Select(p => GetMatchingPropertyName(expression, p.Name)).ToArray();
57+
var membersList = constructorInfo.GetParameters().Select(p => GetMatchingMember(expression, p.Name)).ToList();
58+
members = new ReadOnlyCollection<MemberInfo>(membersList);
6159
}
6260

6361
for (var i = 0; i < arguments.Length; i++)
6462
{
65-
var propertyName = propertyNames[i];
6663
var valueExpression = arguments[i];
6764
var valueTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, valueExpression);
68-
var valueSerializer = valueTranslation.Serializer ?? BsonSerializer.LookupSerializer(valueExpression.Type);
69-
var defaultValue = GetDefaultValue(valueSerializer.ValueType);
70-
classMap.MapProperty(propertyName).SetSerializer(valueSerializer).SetDefaultValue(defaultValue);
71-
computedFields.Add(AstExpression.ComputedField(propertyName, valueTranslation.Ast));
65+
var valueType = valueExpression.Type;
66+
var valueSerializer = valueTranslation.Serializer ?? BsonSerializer.LookupSerializer(valueType);
67+
var defaultValue = GetDefaultValue(valueType);
68+
var memberMap = classMap.MapMember(members[i]).SetSerializer(valueSerializer).SetDefaultValue(defaultValue);
69+
computedFields.Add(AstExpression.ComputedField(memberMap.ElementName, valueTranslation.Ast));
7270
}
7371

74-
// map any properties that didn't match a constructor argument
75-
foreach (var property in expressionType.GetProperties())
72+
// map any public fields or properties that didn't match a constructor argument
73+
foreach (var member in expressionType.GetFields().Cast<MemberInfo>().Concat(expressionType.GetProperties()))
7674
{
77-
if (!propertyNames.Contains(property.Name))
75+
if (!members.Contains(member))
7876
{
79-
var valueSerializer = context.KnownSerializersRegistry.GetSerializer(expression, property.PropertyType);
80-
var defaultValue = GetDefaultValue(valueSerializer.ValueType);
81-
classMap.MapProperty(property.Name).SetSerializer(valueSerializer).SetDefaultValue(defaultValue);
77+
var valueType = member switch
78+
{
79+
FieldInfo fieldInfo => fieldInfo.FieldType,
80+
PropertyInfo propertyInfo => propertyInfo.PropertyType,
81+
_ => throw new Exception($"Unexpected member type: {member.MemberType}")
82+
};
83+
var valueSerializer = context.KnownSerializersRegistry.GetSerializer(expression, valueType);
84+
var defaultValue = GetDefaultValue(valueType);
85+
classMap.MapMember(member).SetSerializer(valueSerializer).SetDefaultValue(defaultValue);
8286
}
8387
}
8488

85-
classMap.MapConstructor(constructorInfo, propertyNames);
89+
classMap.MapConstructor(constructorInfo, members.Select(m => m.Name).ToArray());
8690
classMap.Freeze();
8791

8892
var ast = AstExpression.ComputedDocument(computedFields);
@@ -108,17 +112,25 @@ private static object GetDefaultValue(Type type)
108112
}
109113
}
110114

111-
private static string GetMatchingPropertyName(NewExpression expression, string constructorParameterName)
115+
private static MemberInfo GetMatchingMember(NewExpression expression, string constructorParameterName)
112116
{
117+
foreach (var field in expression.Type.GetFields())
118+
{
119+
if (field.Name.Equals(constructorParameterName, StringComparison.OrdinalIgnoreCase))
120+
{
121+
return field;
122+
}
123+
}
124+
113125
foreach (var property in expression.Type.GetProperties())
114126
{
115127
if (property.Name.Equals(constructorParameterName, StringComparison.OrdinalIgnoreCase))
116128
{
117-
return property.Name;
129+
return property;
118130
}
119131
}
120132

121-
throw new ExpressionNotSupportedException(expression, because: $"constructor parameter {constructorParameterName} does not match any property");
133+
throw new ExpressionNotSupportedException(expression, because: $"constructor parameter {constructorParameterName} does not match any public field or property");
122134
}
123135
}
124136
}
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
/* Copyright 2010-present MongoDB Inc.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
using System;
17+
using FluentAssertions;
18+
using MongoDB.Bson.Serialization;
19+
using MongoDB.Driver.Core.Misc;
20+
using MongoDB.Driver.Core.TestHelpers.XunitExtensions;
21+
using MongoDB.Driver.Linq;
22+
using Xunit;
23+
24+
namespace MongoDB.Driver.Tests.Linq.Linq3ImplementationTests.Jira
25+
{
26+
public class CSharp4524Tests : Linq3IntegrationTest
27+
{
28+
[Fact]
29+
public void Find_with_projection_using_LINQ2_should_work()
30+
{
31+
var collection = CreateCollection(LinqProvider.V2);
32+
var find = collection.Find("{}").Project(x => new SpawnData(x.StartDate, x.SpawnPeriod));
33+
34+
var results = find.ToList();
35+
var projection = find.Options.Projection;
36+
var serializerRegistry = BsonSerializer.SerializerRegistry;
37+
var documentSerializer = serializerRegistry.GetSerializer<MyData>();
38+
var renderedProjection = projection.Render(documentSerializer, serializerRegistry, LinqProvider.V2);
39+
renderedProjection.Document.Should().Be("{ SpawnPeriod : 1, StartDate : 1, _id : 0 }");
40+
41+
results.Should().HaveCount(1);
42+
results[0].Date.Should().Be(new DateTime(2023, 1, 2, 3, 4, 5, DateTimeKind.Utc));
43+
results[0].Period.Should().Be(SpawnPeriod.LIVE);
44+
}
45+
46+
[Fact]
47+
public void Find_with_projection_using_LINQ3_should_work()
48+
{
49+
RequireServer.Check().Supports(Feature.FindProjectionExpressions);
50+
var collection = CreateCollection(LinqProvider.V3);
51+
var find = collection.Find("{}").Project(x => new SpawnData(x.StartDate, x.SpawnPeriod));
52+
53+
var results = find.ToList();
54+
55+
var projection = find.Options.Projection;
56+
var serializerRegistry = BsonSerializer.SerializerRegistry;
57+
var documentSerializer = serializerRegistry.GetSerializer<MyData>();
58+
var renderedProjection = projection.Render(documentSerializer, serializerRegistry, LinqProvider.V3);
59+
renderedProjection.Document.Should().Be("{ Date : '$StartDate', Period : '$SpawnPeriod', _id : 0 }");
60+
61+
results.Should().HaveCount(1);
62+
results[0].Date.Should().Be(new DateTime(2023, 1, 2, 3, 4, 5, DateTimeKind.Utc));
63+
results[0].Period.Should().Be(SpawnPeriod.LIVE);
64+
}
65+
66+
private IMongoCollection<MyData> CreateCollection(LinqProvider linqProvider)
67+
{
68+
var collection = GetCollection<MyData>("data", linqProvider);
69+
70+
CreateCollection(
71+
collection,
72+
new MyData { Id = 1, StartDate = new DateTime(2023, 1, 2, 3, 4, 5, DateTimeKind.Utc), SpawnPeriod = SpawnPeriod.LIVE });
73+
74+
return collection;
75+
}
76+
77+
public class MyData
78+
{
79+
public int Id { get; set; }
80+
public DateTime StartDate;
81+
public SpawnPeriod SpawnPeriod;
82+
}
83+
84+
public enum SpawnPeriod { LIVE, MIDNIGHT, MORNING, EVENING }
85+
86+
public struct SpawnData
87+
{
88+
public readonly DateTime Date;
89+
public readonly SpawnPeriod Period;
90+
91+
public SpawnData(DateTime date, SpawnPeriod period)
92+
{
93+
// Normally there is more complex handling here, value-type semantics are important, there are custom comparison operators, etc. hence the point of this struct.
94+
Date = date;
95+
Period = period;
96+
}
97+
98+
public bool Equals(SpawnData other) => Date == other.Date && Period == other.Period;
99+
}
100+
}
101+
}

0 commit comments

Comments
 (0)