Skip to content

Commit 164667d

Browse files
committed
CSHARP-4586: Map constructor in MemberInit and simplify $project for simple field inclusion.
1 parent 77da1b3 commit 164667d

File tree

12 files changed

+253
-16
lines changed

12 files changed

+253
-16
lines changed

src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstSimplifier.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,20 @@ exprFilter.Expression is AstConstantExpression constantExpression &&
322322
return node;
323323
}
324324

325+
public override AstNode VisitProjectStageSetFieldSpecification(AstProjectStageSetFieldSpecification node)
326+
{
327+
node = (AstProjectStageSetFieldSpecification)base.VisitProjectStageSetFieldSpecification(node);
328+
329+
// { path : '$path' } => { path : 1 }
330+
if (node.Value is AstFieldPathExpression fieldPathExpression &&
331+
fieldPathExpression.Path == $"${node.Path}")
332+
{
333+
return AstProject.Include(node.Path);
334+
}
335+
336+
return node;
337+
}
338+
325339
public override AstNode VisitUnaryExpression(AstUnaryExpression node)
326340
{
327341
// { $first : <arg> } => { $arrayElemAt : [<arg>, 0] } (or -1 for $last)

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslator.cs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
using System;
1717
using System.Collections.Generic;
18+
using System.Linq;
1819
using System.Linq.Expressions;
1920
using System.Reflection;
2021
using MongoDB.Bson.Serialization;
@@ -31,8 +32,10 @@ public static AggregationExpression Translate(TranslationContext context, Member
3132
var classMap = CreateClassMap(expression.Type);
3233

3334
var newExpression = expression.NewExpression;
34-
var constructorParameters = newExpression.Constructor.GetParameters();
35+
var constructorInfo = newExpression.Constructor;
36+
var constructorParameters = constructorInfo.GetParameters();
3537
var constructorArguments = newExpression.Arguments;
38+
var memberNames = new string[constructorParameters.Length];
3639
for (var i = 0; i < constructorParameters.Length; i++)
3740
{
3841
var constructorParameter = constructorParameters[i];
@@ -43,6 +46,7 @@ public static AggregationExpression Translate(TranslationContext context, Member
4346
computedFields.Add(AstExpression.ComputedField(memberMap.ElementName, argumentTranslation.Ast));
4447

4548
memberMap.SetSerializer(argumentTranslation.Serializer);
49+
memberNames[i] = memberMap.MemberName;
4650
}
4751

4852
foreach (var binding in expression.Bindings)
@@ -60,6 +64,7 @@ public static AggregationExpression Translate(TranslationContext context, Member
6064

6165
var ast = AstExpression.ComputedDocument(computedFields);
6266

67+
classMap.MapConstructor(constructorInfo, memberNames);
6368
classMap.Freeze();
6469
var serializerType = typeof(BsonClassMapSerializer<>).MakeGenericType(expression.Type);
6570
var serializer = (IBsonSerializer)Activator.CreateInstance(serializerType, classMap);

tests/MongoDB.Driver.Tests/Linq/Linq2ImplementationTestsOnLinq3/MongoQueryableTests.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ public void Distinct_document_preceded_by_select_where()
215215

216216
Assert(query,
217217
1,
218-
"{ $project: { 'A': '$A', 'B': '$B', '_id': 0 } }",
218+
"{ $project: { 'A': 1, 'B': 1, '_id': 0 } }",
219219
"{ $match: { 'A': 'Awesome' } }",
220220
"{ $group: { '_id': '$$ROOT' } }",
221221
"{ $replaceRoot : { newRoot : '$_id' } }");
@@ -233,7 +233,7 @@ public void Distinct_document_preceded_by_where_select()
233233
Assert(query,
234234
1,
235235
"{ $match : { 'A' : 'Awesome' } }",
236-
"{ $project : { A : '$A', B : '$B', _id : 0 } }",
236+
"{ $project : { A : 1, B : 1, _id : 0 } }",
237237
"{ $group : { '_id' : '$$ROOT' } }",
238238
"{ $replaceRoot : { newRoot : '$_id' } }");
239239
}
@@ -999,7 +999,7 @@ public void Select_new_of_same()
999999

10001000
Assert(query,
10011001
2,
1002-
"{ $project : { _id : '$_id', A : '$A' } }");
1002+
"{ $project : { _id : 1, A : 1 } }");
10031003
}
10041004

10051005
[Fact]

tests/MongoDB.Driver.Tests/Linq/Linq2ImplementationTestsOnLinq3/Translators/AggregateGroupTranslatorTests.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ public void Should_translate_just_id()
7070
AssertStages(
7171
result.Stages,
7272
"{ $group : { _id : '$A' } }",
73-
"{ $project : { _id : '$_id' } }");
73+
"{ $project : { _id : 1 } }");
7474

7575
result.Value._id.Should().Be("Amazing");
7676
}

tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationTests/Jira/CSharp1555Tests.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ public void Select_new_Person_should_work()
4343
.Select(p => new Person { Id = p.Id, Name = p.Name });
4444

4545
var stages = Translate(collection, queryable);
46-
AssertStages(stages, "{ $project : { _id : '$_id', Name : '$Name' } }");
46+
AssertStages(stages, "{ $project : { _id : 1, Name : 1 } }");
4747

4848
var result = queryable.ToList().Single();
4949
result.ShouldBeEquivalentTo(new Person { Id = 1, Name = "A" });
@@ -57,7 +57,7 @@ public void Select_new_Person_without_Name_should_work()
5757
.Select(p => new Person { Id = p.Id });
5858

5959
var stages = Translate(collection, queryable);
60-
AssertStages(stages, "{ $project : { _id : '$_id' } }");
60+
AssertStages(stages, "{ $project : { _id : 1 } }");
6161

6262
var result = queryable.ToList().Single();
6363
result.ShouldBeEquivalentTo(new Person { Id = 1, Name = null });
@@ -71,7 +71,7 @@ public void Select_new_Person_without_Id_should_work()
7171
.Select(p => new Person { Name = p.Name });
7272

7373
var stages = Translate(collection, queryable);
74-
AssertStages(stages, "{ $project : { Name : '$Name', _id : 0 } }");
74+
AssertStages(stages, "{ $project : { Name : 1, _id : 0 } }");
7575

7676
var result = queryable.ToList().Single();
7777
result.ShouldBeEquivalentTo(new Person { Id = 0, Name = "A" });

tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationTests/Jira/CSharp2723Tests.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ public void Nested_Select_should_work()
115115
}",
116116
@"{
117117
'$project':{
118-
'_id':'$_id',
118+
'_id':1,
119119
'ParentName':'$Name',
120120
'Children':{
121121
'$map':{

tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationTests/Jira/CSharp3614Tests.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ public void Test()
4242
});
4343

4444
var stages = Translate(collection, queryable);
45-
AssertStages(stages, "{ $project : { _id : '$_id', PageCount : '$PageCount', Author : { $cond : { if : { $eq : ['$Author', null] }, then : null, else : { _id : '$Author._id', Name : '$Author.Name' } } } } }");
45+
AssertStages(stages, "{ $project : { _id : 1, PageCount : 1, Author : { $cond : { if : { $eq : ['$Author', null] }, then : null, else : { _id : '$Author._id', Name : '$Author.Name' } } } } }");
4646

4747
var results = queryable.ToList().OrderBy(r => r.Id).ToList();
4848
results.Should().HaveCount(2);

tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationTests/Jira/CSharp3922Tests.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ public void Select_with_constructor_call_should_work()
4949
var stages = Translate(collection, queryable);
5050
AssertStages(
5151
stages,
52-
"{ $project : { X : '$X', _id : 0 } }",
52+
"{ $project : { X : 1, _id : 0 } }",
5353
"{ $project : { R : '$X', S : '$Y', _id : 0 } }",
5454
"{ $project : { T : '$R', U : '$S', _id : 0 } }");
5555
}
@@ -67,7 +67,7 @@ public void Select_with_constructor_call_and_property_set_should_work()
6767
var stages = Translate(collection, queryable);
6868
AssertStages(
6969
stages,
70-
"{ $project : { X : '$X', Y : { $literal : 123 }, _id : 0 } }",
70+
"{ $project : { X : 1, Y : { $literal : 123 }, _id : 0 } }",
7171
"{ $project : { R : '$X', S : '$Y', _id : 0 } }",
7272
"{ $project : { T : '$R', U : '$S', _id : 0 } }");
7373
}

tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationTests/Jira/CSharp4468Tests.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ public void Query1_should_should_work(
5959
"{ $project : { _v : '$Lines', _id : 0 } }",
6060
"{ $unwind : '$_v' }",
6161
"{ $group : { _id : '$_v.ItemId', __agg0 : { $sum : '$_v.TotalAmount' } } }",
62-
"{ $project : { _id : '$_id', TotalAmount : '$__agg0' } }"
62+
"{ $project : { _id : 1, TotalAmount : '$__agg0' } }"
6363
};
6464
}
6565
AssertStages(stages, expectedStages);
@@ -96,7 +96,7 @@ public void Query2_should_should_work(
9696
expectedStages = new[]
9797
{
9898
"{ $group : { _id : '$_id', __agg0 : { $sum : '$TotalAmount' } } }",
99-
"{ $project : { _id : '$_id', TotalAmount : '$__agg0' } }" // only difference from LINQ2 is "_id" vs "Id"
99+
"{ $project : { _id : 1, TotalAmount : '$__agg0' } }" // only difference from LINQ2 is "_id" vs "Id"
100100
};
101101
}
102102
AssertStages(stages, expectedStages);
Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
/* Copyright 2010-present MongoDB Inc.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
using FluentAssertions;
17+
using MongoDB.Driver.Linq;
18+
using MongoDB.TestHelpers.XunitExtensions;
19+
using Xunit;
20+
21+
#if NET6_0_OR_GREATER
22+
namespace MongoDB.Driver.Tests.Linq.Linq3ImplementationTests.Jira
23+
{
24+
public class CSharp4586Tests : Linq3IntegrationTest
25+
{
26+
[Theory]
27+
[ParameterAttributeData]
28+
public void Project_View1_with_constructor_should_work(
29+
[Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider)
30+
{
31+
var collection = CreateCollection(linqProvider);
32+
var id = "a";
33+
var filter = Builders<Model>.Filter.Eq(m => m.Id, id);
34+
35+
var find = collection
36+
.Find(filter)
37+
.Project(Builders<Model>.Projection.Expression(m => new View1(m.Id)));
38+
39+
var projection = TranslateFindProjection(collection, find);
40+
if (linqProvider == LinqProvider.V2)
41+
{
42+
projection.Should().Be("{ _id : 1 }");
43+
}
44+
else
45+
{
46+
projection.Should().Be("{ Id : '$_id', _id : 0 }");
47+
}
48+
49+
var results = find.ToList();
50+
results.Should().HaveCount(1);
51+
results[0].Id.Should().Be("a");
52+
53+
var deleteResult = collection.DeleteOne(filter);
54+
deleteResult.DeletedCount.Should().Be(1);
55+
}
56+
57+
[Theory]
58+
[ParameterAttributeData]
59+
public void Project_View1_with_empty_initializer_should_work(
60+
[Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider)
61+
{
62+
var collection = CreateCollection(linqProvider);
63+
var id = "a";
64+
var filter = Builders<Model>.Filter.Eq(m => m.Id, id);
65+
66+
var find = collection
67+
.Find(filter)
68+
.Project(Builders<Model>.Projection.Expression(m => new View1(m.Id) { }));
69+
70+
var projection = TranslateFindProjection(collection, find);
71+
projection.Should().Be("{ _id : 1 }");
72+
73+
var results = find.ToList();
74+
results.Should().HaveCount(1);
75+
results[0].Id.Should().Be("a");
76+
77+
var deleteResult = collection.DeleteOne(filter);
78+
deleteResult.DeletedCount.Should().Be(1);
79+
}
80+
81+
[Theory]
82+
[ParameterAttributeData]
83+
public void Project_View2_with_constructor_should_work(
84+
[Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider)
85+
{
86+
var collection = CreateCollection(linqProvider);
87+
var id = "a";
88+
var filter = Builders<Model>.Filter.Eq(m => m.Id, id);
89+
90+
var find = collection
91+
.Find(filter)
92+
.Project(Builders<Model>.Projection.Expression(m => new View2(m.Id)));
93+
94+
var projection = TranslateFindProjection(collection, find);
95+
if (linqProvider == LinqProvider.V2)
96+
{
97+
projection.Should().Be("{ _id : 1 }");
98+
}
99+
else
100+
{
101+
projection.Should().Be("{ Id : '$_id', _id : 0 }");
102+
}
103+
104+
var results = find.ToList();
105+
results.Should().HaveCount(1);
106+
results[0].Id.Should().Be("a");
107+
results[0].Version.Should().NotHaveValue();
108+
109+
var deleteResult = collection.DeleteOne(filter);
110+
deleteResult.DeletedCount.Should().Be(1);
111+
}
112+
113+
[Theory]
114+
[ParameterAttributeData]
115+
public void Project_View2_with_empty_initializer_should_work(
116+
[Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider)
117+
{
118+
var collection = CreateCollection(linqProvider);
119+
var id = "a";
120+
var filter = Builders<Model>.Filter.Eq(m => m.Id, id);
121+
122+
var find = collection
123+
.Find(filter)
124+
.Project(Builders<Model>.Projection.Expression(m => new View2(m.Id) { }));
125+
126+
var projection = TranslateFindProjection(collection, find);
127+
projection.Should().Be("{ _id : 1 }");
128+
129+
var results = find.ToList();
130+
results.Should().HaveCount(1);
131+
results[0].Id.Should().Be("a");
132+
results[0].Version.Should().NotHaveValue();
133+
134+
var deleteResult = collection.DeleteOne(filter);
135+
deleteResult.DeletedCount.Should().Be(1);
136+
}
137+
138+
[Theory]
139+
[ParameterAttributeData]
140+
public void Project_View2_with_initializer_should_work(
141+
[Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider)
142+
{
143+
var collection = CreateCollection(linqProvider);
144+
var id = "a";
145+
var filter = Builders<Model>.Filter.Eq(m => m.Id, id);
146+
147+
var find = collection
148+
.Find(filter)
149+
.Project(Builders<Model>.Projection.Expression(m => new View2(m.Id) { Version = 1 }));
150+
151+
var projection = TranslateFindProjection(collection, find);
152+
if (linqProvider == LinqProvider.V2)
153+
{
154+
projection.Should().Be("{ _id : 1 }"); // apparently LINQ2 handles Version = 1 client side
155+
}
156+
else
157+
{
158+
projection.Should().Be("{ _id : 1, Version : { $literal : 1 } }");
159+
}
160+
161+
var results = find.ToList();
162+
results.Should().HaveCount(1);
163+
results[0].Id.Should().Be("a");
164+
results[0].Version.Should().Be(1);
165+
166+
var deleteResult = collection.DeleteOne(filter);
167+
deleteResult.DeletedCount.Should().Be(1);
168+
}
169+
170+
private IMongoCollection<Model> CreateCollection(LinqProvider linqProvider)
171+
{
172+
var collection = GetCollection<Model>("test", linqProvider);
173+
CreateCollection(collection, new Model("a"));
174+
return collection;
175+
}
176+
177+
private class Model
178+
{
179+
public Model(string id)
180+
{
181+
Id = id;
182+
}
183+
184+
public string Id { get; private set; }
185+
}
186+
187+
private class View1
188+
{
189+
public View1(string id)
190+
{
191+
Id = id;
192+
}
193+
194+
public string Id { get; }
195+
}
196+
197+
private class View2
198+
{
199+
public View2(string id)
200+
{
201+
Id = id;
202+
}
203+
204+
public string Id { get; }
205+
public int? Version { get; init; } // View1 does not have this property
206+
}
207+
}
208+
}
209+
#endif

0 commit comments

Comments
 (0)