Skip to content

Commit a28da50

Browse files
committed
CSHARP-4802: Projection of nested field does not compute dotted field name properly.
1 parent ca977be commit a28da50

File tree

2 files changed

+104
-6
lines changed

2 files changed

+104
-6
lines changed

src/MongoDB.Driver/Linq/Linq3Implementation/Misc/ProjectionHelper.cs

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
* limitations under the License.
1414
*/
1515

16+
using System;
1617
using System.Collections.Generic;
1718
using MongoDB.Bson;
1819
using MongoDB.Bson.Serialization;
@@ -77,18 +78,42 @@ private static (IReadOnlyList<AstProjectStageSpecification>, IBsonSerializer) Cr
7778
private static (IReadOnlyList<AstProjectStageSpecification>, IBsonSerializer) CreateFindGetFieldProjection(AggregationExpression expression)
7879
{
7980
var getFieldExpressionAst = (AstGetFieldExpression)expression.Ast;
80-
if (getFieldExpressionAst.HasSafeFieldName(out var fieldName))
81+
if (IsGetFieldChainWithSafeFieldNames(getFieldExpressionAst))
8182
{
82-
var specifications = fieldName == "_id" ?
83-
new List<AstProjectStageSpecification> { AstProject.Include(fieldName) } :
84-
new List<AstProjectStageSpecification> { AstProject.Include(fieldName), AstProject.Exclude("_id") };
85-
var wrappedValueSerializer = WrappedValueSerializer.Create(fieldName, expression.Serializer);
86-
return (specifications, wrappedValueSerializer);
83+
var (dottedFieldName, serializer) = CreateGetFieldChainWithSafeFieldNamesProjection(getFieldExpressionAst, expression.Serializer);
84+
var specifications = new List<AstProjectStageSpecification> { AstProject.Include(dottedFieldName) };
85+
if (dottedFieldName != "_id" && !dottedFieldName.StartsWith("_id."))
86+
{
87+
specifications.Add(AstProject.Exclude("_id"));
88+
}
89+
return (specifications, serializer);
8790
}
8891

8992
return CreateWrappedValueProjection(expression);
9093
}
9194

95+
private static (string, IBsonSerializer) CreateGetFieldChainWithSafeFieldNamesProjection(AstGetFieldExpression getFieldExpression, IBsonSerializer serializer)
96+
{
97+
if (getFieldExpression.HasSafeFieldName(out var fieldName))
98+
{
99+
var wrappedValueSerializer = WrappedValueSerializer.Create(fieldName, serializer);
100+
var input = getFieldExpression.Input;
101+
102+
if (input is AstVarExpression varExpression && varExpression.Name == "ROOT")
103+
{
104+
return (fieldName, wrappedValueSerializer);
105+
}
106+
107+
if (input is AstGetFieldExpression outerGetFieldExpression)
108+
{
109+
var (outerDottedFieldName, outerWrappedValueSerializer) = CreateGetFieldChainWithSafeFieldNamesProjection(outerGetFieldExpression, wrappedValueSerializer);
110+
return (outerDottedFieldName + "." + fieldName, outerWrappedValueSerializer);
111+
}
112+
}
113+
114+
throw new ArgumentException($"{nameof(CreateGetFieldChainWithSafeFieldNamesProjection)} called with an invalid getFieldExpression.", nameof(getFieldExpression));
115+
}
116+
92117
private static (IReadOnlyList<AstProjectStageSpecification>, IBsonSerializer) CreateWrappedValueProjection(AggregationExpression expression)
93118
{
94119
var wrappedValueSerializer = WrappedValueSerializer.Create("_v", expression.Serializer);
@@ -101,6 +126,16 @@ private static (IReadOnlyList<AstProjectStageSpecification>, IBsonSerializer) Cr
101126
return (specifications, wrappedValueSerializer);
102127
}
103128

129+
private static bool IsGetFieldChainWithSafeFieldNames(AstGetFieldExpression getFieldExpression)
130+
{
131+
return
132+
getFieldExpression.HasSafeFieldName(out _) &&
133+
(
134+
(getFieldExpression.Input is AstVarExpression varExpression && varExpression.Name == "ROOT") ||
135+
(getFieldExpression.Input is AstGetFieldExpression nestedGetFieldExpression && IsGetFieldChainWithSafeFieldNames(nestedGetFieldExpression))
136+
);
137+
}
138+
104139
private static AstExpression QuoteIfNecessary(AstExpression expression)
105140
{
106141
if (expression is AstConstantExpression constantExpression)
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
/* Copyright 2010-present MongoDB Inc.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
using FluentAssertions;
17+
using MongoDB.Driver.Linq;
18+
using MongoDB.TestHelpers.XunitExtensions;
19+
using Xunit;
20+
21+
namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira
22+
{
23+
public class CSharp4802Tests : Linq3IntegrationTest
24+
{
25+
[Theory]
26+
[ParameterAttributeData]
27+
public void Find_with_projection_of_subfield_should_work(
28+
[Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider)
29+
{
30+
var collection = GetCollection(linqProvider);
31+
32+
var find = collection.Find(d => d.Status == "a").Project(d => d.SubDocument.Id);
33+
34+
var projection = TranslateFindProjection(collection, find);
35+
projection.Should().Be("{ 'SubDocument._id' : 1, _id : 0 }");
36+
37+
var result = find.Single();
38+
result.Should().Be(11);
39+
}
40+
41+
private IMongoCollection<Document> GetCollection(LinqProvider linqProvider)
42+
{
43+
var collection = GetCollection<Document>("test", linqProvider);
44+
CreateCollection(
45+
collection,
46+
new Document { Id = 1, Status = "a", SubDocument = new SubDocument { Id = 11 } },
47+
new Document { Id = 2, Status = "b", SubDocument = new SubDocument { Id = 22 } });
48+
return collection;
49+
}
50+
51+
private class Document
52+
{
53+
public int Id { get; set; }
54+
public string Status { get; set; }
55+
public SubDocument SubDocument { get; set; }
56+
}
57+
58+
public class SubDocument
59+
{
60+
public int Id { get; set; }
61+
}
62+
}
63+
}

0 commit comments

Comments
 (0)